@@ -6,29 +6,26 @@ import {
66 LayersModel ,
77 model ,
88 node ,
9- onesLike ,
109 Optimizer ,
1110 Scalar ,
1211 SymbolicTensor ,
1312 Tensor ,
14- TensorContainer ,
15- tidy ,
16- where ,
17- zerosLike
13+ TensorContainer
1814} from '@tensorflow/tfjs-node' ;
1915import { green , red } from 'chalk' ;
2016import { Table } from 'console-table-printer' ;
2117import { FeatureExtractor } from '../feature-engineering/feature-extractor' ;
2218import { prepareDatasetsForBinaryClassification } from '../feature-engineering/prepare-datasets-for-binary-classification' ;
2319import { ConfusionMatrix } from '../testing/confusion-matrix' ;
2420import { Metrics } from '../testing/metrics' ;
21+ import { binarize } from '../utils/binarize' ;
2522
2623export type BinaryClassificationTrainerOptions = {
27- batchSize : number ;
28- epochs : number ;
29- patience : number ;
30- inputFeatureExtractors : Array < FeatureExtractor < any , any > > ;
31- outputFeatureExtractor : FeatureExtractor < any , any > ;
24+ batchSize ? : number ;
25+ epochs ? : number ;
26+ patience ? : number ;
27+ inputFeatureExtractors ? : Array < FeatureExtractor < any , any > > ;
28+ outputFeatureExtractor ? : FeatureExtractor < any , any > ;
3229 model ?: LayersModel ;
3330 hiddenLayers ?: Array < layers . Layer > ;
3431 optimizer ?: string | Optimizer ;
@@ -40,60 +37,37 @@ export class BinaryClassificationTrainer {
4037 protected epochs : number ;
4138 protected patience : number ;
4239 protected tensorBoardLogsDirectory ?: string ;
43- protected inputFeatureExtractors : Array < FeatureExtractor < any , any > > ;
44- protected outputFeatureExtractor : FeatureExtractor < any , any > ;
40+ protected inputFeatureExtractors ? : Array < FeatureExtractor < any , any > > ;
41+ protected outputFeatureExtractor ? : FeatureExtractor < any , any > ;
4542 protected model ! : LayersModel ;
4643
44+ protected static DEFAULT_BATCH_SIZE : number = 32 ;
45+ protected static DEFAULT_EPOCHS : number = 1000 ;
46+ protected static DEFAULT_PATIENCE : number = 20 ;
47+
/**
 * Creates a trainer from the supplied options.
 *
 * `batchSize`, `epochs` and `patience` fall back to the class-level
 * `DEFAULT_BATCH_SIZE`, `DEFAULT_EPOCHS` and `DEFAULT_PATIENCE` when they are
 * `undefined` (or `null`). The TensorBoard directory and the feature
 * extractors are stored exactly as given and may remain `undefined`.
 * Model resolution and compilation are delegated to `initializeModel`.
 *
 * @param options Trainer configuration.
 * @throws Error Propagated from `initializeModel` when neither a model nor the
 * inputs needed to build one are supplied.
 */
constructor(options: BinaryClassificationTrainerOptions) {
  const {
    batchSize,
    epochs,
    patience,
    tensorBoardLogsDirectory,
    inputFeatureExtractors,
    outputFeatureExtractor
  } = options;

  // `??` (not `||`) so that an explicit 0 is kept rather than replaced by the default.
  this.batchSize = batchSize ?? BinaryClassificationTrainer.DEFAULT_BATCH_SIZE;
  this.epochs = epochs ?? BinaryClassificationTrainer.DEFAULT_EPOCHS;
  this.patience = patience ?? BinaryClassificationTrainer.DEFAULT_PATIENCE;

  this.tensorBoardLogsDirectory = tensorBoardLogsDirectory;
  this.inputFeatureExtractors = inputFeatureExtractors;
  this.outputFeatureExtractor = outputFeatureExtractor;

  this.initializeModel(options);
}
8458
8559 public async trainAndTest ( {
8660 data,
8761 trainingDataset,
8862 validationDataset,
8963 testingDataset,
90- printResults
64+ printTestingResults
9165 } : {
9266 data ?: Array < any > ,
9367 trainingDataset ?: data . Dataset < TensorContainer > ;
9468 validationDataset ?: data . Dataset < TensorContainer > ;
9569 testingDataset ?: data . Dataset < TensorContainer > ;
96- printResults ?: boolean ;
70+ printTestingResults ?: boolean ;
9771 } ) : Promise < {
9872 loss : number ;
9973 confusionMatrix : ConfusionMatrix ;
@@ -111,7 +85,15 @@ export class BinaryClassificationTrainer {
11185 callbacks . push ( node . tensorBoard ( this . tensorBoardLogsDirectory ) ) ;
11286 }
11387
114- if ( trainingDataset === undefined || validationDataset === undefined || testingDataset === undefined ) {
88+ if (
89+ trainingDataset === undefined ||
90+ validationDataset === undefined ||
91+ testingDataset === undefined
92+ ) {
93+ if ( this . inputFeatureExtractors === undefined || this . outputFeatureExtractor === undefined ) {
94+ throw new Error ( 'trainingDataset, validationDataset and testingDataset are required when inputFeatureExtractors and outputFeatureExtractor are not provided!' ) ;
95+ }
96+
11597 const datasets = await prepareDatasetsForBinaryClassification ( {
11698 data : data as Array < any > ,
11799 inputFeatureExtractors : this . inputFeatureExtractors ,
@@ -130,19 +112,50 @@ export class BinaryClassificationTrainer {
130112 callbacks
131113 } ) ;
132114
133- return await this . test ( { testingDataset, printResults } ) ;
115+ return await this . test ( { testingDataset, printTestingResults } ) ;
134116 }
135117
/**
 * Persists the trainer's model through the `file://` URL scheme.
 *
 * @param path Location appended to the `file://` prefix and handed to `model.save`.
 * @returns A promise that settles once `model.save` has completed.
 */
public async save(path: string): Promise<void> {
  const destinationUrl = `file://${path}`;

  await this.model.save(destinationUrl);
}
139121
/**
 * Resolves the model used by the trainer and compiles it.
 *
 * When `options.model` is provided it is used as-is. Otherwise a model is
 * assembled from an input layer whose width equals the number of input
 * feature extractors, the supplied hidden layers applied in order, and a
 * single-unit sigmoid dense output layer.
 *
 * In both cases the model is then compiled with the `binaryCrossentropy`
 * loss and the requested optimizer (`'adam'` when none is given).
 *
 * @param options Trainer configuration; reads `model`, `hiddenLayers`,
 * `inputFeatureExtractors` and `optimizer`.
 * @throws Error When no model is provided and either `hiddenLayers` or
 * `inputFeatureExtractors` is missing.
 */
private initializeModel(options: BinaryClassificationTrainerOptions): void {
  if (options.model !== undefined) {
    this.model = options.model;
  } else if (options.hiddenLayers === undefined || options.inputFeatureExtractors === undefined) {
    throw new Error('hiddenLayers and inputFeatureExtractors options are required when the model is not provided!');
  } else {
    // One scalar input per feature extractor.
    const inputs = input({ shape: [options.inputFeatureExtractors.length] });

    // Thread the symbolic tensor through every hidden layer, in the order given.
    const lastHiddenOutput = options.hiddenLayers.reduce<SymbolicTensor>(
      (previous, layer) => layer.apply(previous) as SymbolicTensor,
      inputs
    );

    const outputs = layers
      .dense({ units: 1, activation: 'sigmoid' })
      .apply(lastHiddenOutput) as SymbolicTensor;

    this.model = model({ inputs, outputs });
  }

  this.model.compile({
    optimizer: options.optimizer ?? 'adam',
    loss: 'binaryCrossentropy'
  });
}
152+
140153 private async test ( {
141154 testingDataset,
142- printResults
155+ printTestingResults
143156 } : {
144157 testingDataset : data . Dataset < TensorContainer > ;
145- printResults ?: boolean ;
158+ printTestingResults ?: boolean ;
146159 } ) : Promise < {
147160 loss : number ;
148161 confusionMatrix : ConfusionMatrix ;
@@ -151,23 +164,24 @@ export class BinaryClassificationTrainer {
151164 const lossTensor = ( await this . model . evaluateDataset ( testingDataset as data . Dataset < any > , { } ) ) as Scalar ;
152165 const [ loss ] = await lossTensor . data ( ) ;
153166
154- const testingData = ( await testingDataset . toArray ( ) ) as Array < {
167+ const [ testingData ] = ( await testingDataset . toArray ( ) ) as Array < {
155168 xs : Tensor ;
156169 ys : Tensor ;
157170 } > ;
158- const testXs = testingData [ 0 ] . xs ;
159- const testYs = testingData [ 0 ] . ys ;
171+
172+ const testXs = testingData . xs ;
173+ const testYs = testingData . ys ;
160174
161175 const predictions = this . model . predict ( testXs ) as Tensor ;
162- const binarizedPredictions = this . binarize ( predictions ) ;
176+ const binarizedPredictions = binarize ( predictions ) ;
163177
164- const trueValues = ( await testYs . data ( ) ) as Float32Array ;
165- const predictedValues = ( await binarizedPredictions . data ( ) ) as Float32Array ;
178+ const trueValues = await testYs . data < 'float32' > ( ) ;
179+ const predictedValues = await binarizedPredictions . data < 'float32' > ( ) ;
166180
167181 const confusionMatrix = this . calculateConfusionMatrix ( trueValues , predictedValues ) ;
168182 const metrics = this . calculateMetrics ( confusionMatrix ) ;
169183
170- if ( printResults ) {
184+ if ( printTestingResults ) {
171185 this . printTestResults ( loss , confusionMatrix , metrics ) ;
172186 }
173187
@@ -301,12 +315,4 @@ export class BinaryClassificationTrainer {
301315
302316 metricsTable . printTable ( ) ;
303317 }
304-
305- private binarize ( tensor : Tensor , threshold = 0.5 ) : Tensor {
306- return tidy ( ( ) => {
307- const condition = tensor . greater ( threshold ) ;
308-
309- return where ( condition , onesLike ( tensor ) , zerosLike ( tensor ) ) ;
310- } ) ;
311- }
312318}
0 commit comments