Skip to content

Commit cedc870

Browse files
author
Michael Ivonin
committed
feat(#29): add GPU usage
1 parent 4760df5 commit cedc870

File tree

5 files changed

+74
-1
lines changed

5 files changed

+74
-1
lines changed

package-lock.json

Lines changed: 56 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"private": true,
77
"dependencies": {
88
"@tensorflow/tfjs-node": "^4.1.0",
9+
"@tensorflow/tfjs-node-gpu": "^4.1.0",
910
"tslib": "^2.4.1"
1011
},
1112
"devDependencies": {
@@ -30,4 +31,3 @@
3031
"typescript": "4.8.4"
3132
}
3233
}
33-

packages/tfjs-node-helpers/src/classification/binary-classification-trainer.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ import { Metric } from '../testing/metric';
2222
import { MetricCalculator } from '../testing/metric-calculator';
2323
import { binarize } from '../utils/binarize';
2424
import { FeatureNormalizer } from '../feature-engineering/feature-normalizer';
25+
import { switchHardwareUsage } from '../utils/switch-hardware-usage';
2526

2627
export type BinaryClassificationTrainerOptions = {
28+
shouldUseGPU?: boolean;
2729
batchSize?: number;
2830
epochs?: number;
2931
patience?: number;
@@ -53,6 +55,8 @@ export class BinaryClassificationTrainer {
5355
protected static DEFAULT_PATIENCE: number = 20;
5456

5557
constructor(options: BinaryClassificationTrainerOptions) {
58+
switchHardwareUsage(options.shouldUseGPU);
59+
5660
this.batchSize = options.batchSize ?? BinaryClassificationTrainer.DEFAULT_BATCH_SIZE;
5761
this.epochs = options.epochs ?? BinaryClassificationTrainer.DEFAULT_EPOCHS;
5862
this.patience = options.patience ?? BinaryClassificationTrainer.DEFAULT_PATIENCE;

packages/tfjs-node-helpers/src/classification/binary-classifier.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
import { LayersModel, loadLayersModel, tensor, Tensor } from '@tensorflow/tfjs-node';
2+
import { switchHardwareUsage } from '../utils/switch-hardware-usage';
23

34
export type BinaryClassifierOptions = {
45
model: LayersModel;
6+
shouldUseGPU?: boolean;
57
};
68

79
export class BinaryClassifier {
810
protected model?: LayersModel;
911

1012
constructor(options?: BinaryClassifierOptions) {
13+
switchHardwareUsage(options?.shouldUseGPU);
14+
1115
this.model = options?.model;
1216
}
1317

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
export const switchHardwareUsage = (shouldUseGPU?: boolean): void => {
2+
if (shouldUseGPU) {
3+
console.log('Attempting to use GPU.');
4+
require('@tensorflow/tfjs-node-gpu');
5+
} else {
6+
console.log('Using CPU.');
7+
require('@tensorflow/tfjs-node');
8+
}
9+
};

0 commit comments

Comments
 (0)