Skip to content

Commit cf10ce9

Browse files
[ENH] Add DisjointCNN classifier and Regressor (#2316)
* add network * add to init * test input list for kernel init * fix bug test network * fix bug test * adding deep classifier * update api * fix test * add regressor and refactor * no test for mpl * bug copying cls to rgs * add rs as self
1 parent c267ab8 commit cf10ce9

26 files changed

+1665
-355
lines changed

aeon/classification/deep_learning/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
"TapNetClassifier",
1313
"LITETimeClassifier",
1414
"IndividualLITEClassifier",
15+
"DisjointCNNClassifier",
1516
]
1617
from aeon.classification.deep_learning._cnn import TimeCNNClassifier
18+
from aeon.classification.deep_learning._disjoint_cnn import DisjointCNNClassifier
1719
from aeon.classification.deep_learning._encoder import EncoderClassifier
1820
from aeon.classification.deep_learning._fcn import FCNClassifier
1921
from aeon.classification.deep_learning._inception_time import (

aeon/classification/deep_learning/_cnn.py

+21-14
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,19 @@ class TimeCNNClassifier(BaseDeepClassifier):
6161
The number of samples per gradient update.
6262
verbose : boolean, default = False
6363
Whether to output extra information.
64-
loss : string, default = "mean_squared_error"
65-
Fit parameter for the keras model.
66-
optimizer : keras.optimizer, default = keras.optimizers.Adam()
67-
metrics : list of strings, default = ["accuracy"]
68-
callbacks : keras.callbacks, default = model_checkpoint
69-
To save best model on training loss.
64+
loss : str, default = "mean_squared_error"
65+
The name of the keras training loss.
66+
optimizer : keras.optimizer, default = tf.keras.optimizers.Adam()
67+
The keras optimizer used for training.
68+
metrics : str or list[str], default="accuracy"
69+
The evaluation metrics to use during training. If
70+
a single string metric is provided, it will be
71+
used as the only metric. If a list of metrics are
72+
provided, all will be used for evaluation.
73+
callbacks : keras callback or list of callbacks,
74+
default = None
75+
The default list of callbacks are set to
76+
ModelCheckpoint.
7077
file_path : file_path for the best model
7178
Only used if checkpoint is used as callback.
7279
save_best_model : bool, default = False
@@ -131,7 +138,7 @@ def __init__(
131138
init_file_name="init_model",
132139
verbose=False,
133140
loss="mean_squared_error",
134-
metrics=None,
141+
metrics="accuracy",
135142
random_state=None,
136143
use_bias=True,
137144
optimizer=None,
@@ -201,18 +208,13 @@ def build_model(self, input_shape, n_classes, **kwargs):
201208
import numpy as np
202209
import tensorflow as tf
203210

204-
if self.metrics is None:
205-
metrics = ["accuracy"]
206-
else:
207-
metrics = self.metrics
208-
209211
rng = check_random_state(self.random_state)
210212
self.random_state_ = rng.randint(0, np.iinfo(np.int32).max)
211213
tf.keras.utils.set_random_seed(self.random_state_)
212214
input_layer, output_layer = self._network.build_network(input_shape, **kwargs)
213215

214216
output_layer = tf.keras.layers.Dense(
215-
units=n_classes, activation=self.activation, use_bias=self.use_bias
217+
units=n_classes, activation=self.activation
216218
)(output_layer)
217219

218220
self.optimizer_ = (
@@ -223,7 +225,7 @@ def build_model(self, input_shape, n_classes, **kwargs):
223225
model.compile(
224226
loss=self.loss,
225227
optimizer=self.optimizer_,
226-
metrics=metrics,
228+
metrics=self._metrics,
227229
)
228230

229231
return model
@@ -249,6 +251,11 @@ def _fit(self, X, y):
249251
# Transpose to conform to Keras input style.
250252
X = X.transpose(0, 2, 1)
251253

254+
if isinstance(self.metrics, list):
255+
self._metrics = self.metrics
256+
elif isinstance(self.metrics, str):
257+
self._metrics = [self.metrics]
258+
252259
self.input_shape = X.shape[1:]
253260
self.training_model_ = self.build_model(self.input_shape, self.n_classes_)
254261

0 commit comments

Comments
 (0)