@@ -61,12 +61,19 @@ class TimeCNNClassifier(BaseDeepClassifier):
61
61
The number of samples per gradient update.
62
62
verbose : boolean, default = False
63
63
Whether to output extra information.
64
- loss : string, default = "mean_squared_error"
65
- Fit parameter for the keras model.
66
- optimizer : keras.optimizer, default = keras.optimizers.Adam()
67
- metrics : list of strings, default = ["accuracy"]
68
- callbacks : keras.callbacks, default = model_checkpoint
69
- To save best model on training loss.
64
+ loss : str, default = "mean_squared_error"
65
+ The name of the keras training loss.
66
+ optimizer : keras.optimizer, default = tf.keras.optimizers.Adam()
67
+ The keras optimizer used for training.
68
+ metrics : str or list[str], default="accuracy"
69
+ The evaluation metrics to use during training. If
70
+ a single string metric is provided, it will be
71
+ used as the only metric. If a list of metrics are
72
+ provided, all will be used for evaluation.
73
+ callbacks : keras callback or list of callbacks,
74
+ default = None
75
+ The default list of callbacks are set to
76
+ ModelCheckpoint.
70
77
file_path : file_path for the best model
71
78
Only used if checkpoint is used as callback.
72
79
save_best_model : bool, default = False
@@ -131,7 +138,7 @@ def __init__(
131
138
init_file_name = "init_model" ,
132
139
verbose = False ,
133
140
loss = "mean_squared_error" ,
134
- metrics = None ,
141
+ metrics = "accuracy" ,
135
142
random_state = None ,
136
143
use_bias = True ,
137
144
optimizer = None ,
@@ -201,18 +208,13 @@ def build_model(self, input_shape, n_classes, **kwargs):
201
208
import numpy as np
202
209
import tensorflow as tf
203
210
204
- if self .metrics is None :
205
- metrics = ["accuracy" ]
206
- else :
207
- metrics = self .metrics
208
-
209
211
rng = check_random_state (self .random_state )
210
212
self .random_state_ = rng .randint (0 , np .iinfo (np .int32 ).max )
211
213
tf .keras .utils .set_random_seed (self .random_state_ )
212
214
input_layer , output_layer = self ._network .build_network (input_shape , ** kwargs )
213
215
214
216
output_layer = tf .keras .layers .Dense (
215
- units = n_classes , activation = self .activation , use_bias = self . use_bias
217
+ units = n_classes , activation = self .activation
216
218
)(output_layer )
217
219
218
220
self .optimizer_ = (
@@ -223,7 +225,7 @@ def build_model(self, input_shape, n_classes, **kwargs):
223
225
model .compile (
224
226
loss = self .loss ,
225
227
optimizer = self .optimizer_ ,
226
- metrics = metrics ,
228
+ metrics = self . _metrics ,
227
229
)
228
230
229
231
return model
@@ -249,6 +251,11 @@ def _fit(self, X, y):
249
251
# Transpose to conform to Keras input style.
250
252
X = X .transpose (0 , 2 , 1 )
251
253
254
+ if isinstance (self .metrics , list ):
255
+ self ._metrics = self .metrics
256
+ elif isinstance (self .metrics , str ):
257
+ self ._metrics = [self .metrics ]
258
+
252
259
self .input_shape = X .shape [1 :]
253
260
self .training_model_ = self .build_model (self .input_shape , self .n_classes_ )
254
261
0 commit comments