From e7e0139e16e4a4b9101f41af1ba7fff11c9fbc97 Mon Sep 17 00:00:00 2001 From: Ali El Hadi ISMAIL FAWAZ <54309336+hadifawaz1999@users.noreply.github.com> Date: Sat, 27 Jul 2024 21:58:40 +0200 Subject: [PATCH] [ENH,MNT,DOC] Tidying up deep learning modules clasification/regression (#1826) * initial init * add save load test * fix bug * add test to regression as well * maintainer * fix bug tmp * add file path bug fix * Update _cnn.py * Update _cnn.py * Update _cnn.py * Update _cnn.py * Update _cnn.py * Update _cnn.py * Update _cnn.py --- aeon/classification/deep_learning/__init__.py | 3 +- aeon/classification/deep_learning/_cnn.py | 348 ++++++++++++++++- aeon/classification/deep_learning/_encoder.py | 12 + aeon/classification/deep_learning/_fcn.py | 17 +- .../deep_learning/_inception_time.py | 39 +- .../deep_learning/_lite_time.py | 35 +- aeon/classification/deep_learning/_mlp.py | 16 +- aeon/classification/deep_learning/_resnet.py | 16 +- aeon/classification/deep_learning/_tapnet.py | 4 +- aeon/classification/deep_learning/base.py | 2 +- .../tests/test_deep_classifier_base.py | 2 +- .../tests/test_random_state_deep_learning.py | 42 ++- .../test_saving_loading_deep_learning_cls.py | 83 ++++ aeon/clustering/deep_learning/_ae_fcn.py | 2 +- aeon/clustering/deep_learning/_ae_resnet.py | 2 +- .../tests/test_deep_clusterer_base.py | 2 +- aeon/networks/__init__.py | 3 +- aeon/networks/_ae_bgru.py | 4 +- aeon/networks/_ae_fcn.py | 2 +- aeon/networks/_ae_resnet.py | 5 +- aeon/networks/_cnn.py | 177 ++++++++- aeon/networks/_encoder.py | 2 +- aeon/networks/_fcn.py | 5 +- aeon/networks/_inception.py | 4 +- aeon/networks/_lite.py | 5 +- aeon/networks/_mlp.py | 5 +- aeon/networks/_resnet.py | 5 +- aeon/networks/_tapnet.py | 4 +- aeon/regression/deep_learning/__init__.py | 5 +- aeon/regression/deep_learning/_cnn.py | 356 +++++++++++++++++- aeon/regression/deep_learning/_encoder.py | 14 +- aeon/regression/deep_learning/_fcn.py | 16 +- .../deep_learning/_inception_time.py | 27 +- 
aeon/regression/deep_learning/_lite_time.py | 27 +- aeon/regression/deep_learning/_mlp.py | 16 +- aeon/regression/deep_learning/_resnet.py | 16 +- aeon/regression/deep_learning/_tapnet.py | 4 +- .../tests/test_deep_regressor_base.py | 2 +- .../test_saving_loading_deep_learning_cls.py | 79 ++++ docs/api_reference/classification.rst | 2 + docs/api_reference/clustering.rst | 1 + docs/api_reference/distances.rst | 3 + docs/api_reference/networks.rst | 4 + docs/api_reference/regression.rst | 4 +- docs/api_reference/visualisation.rst | 1 + 45 files changed, 1332 insertions(+), 91 deletions(-) create mode 100644 aeon/classification/deep_learning/tests/test_saving_loading_deep_learning_cls.py create mode 100644 aeon/regression/deep_learning/tests/test_saving_loading_deep_learning_cls.py diff --git a/aeon/classification/deep_learning/__init__.py b/aeon/classification/deep_learning/__init__.py index 26ff063855..ba36690cfe 100644 --- a/aeon/classification/deep_learning/__init__.py +++ b/aeon/classification/deep_learning/__init__.py @@ -3,6 +3,7 @@ __all__ = [ "BaseDeepClassifier", "CNNClassifier", + "TimeCNNClassifier", "EncoderClassifier", "FCNClassifier", "InceptionTimeClassifier", @@ -13,7 +14,7 @@ "LITETimeClassifier", "IndividualLITEClassifier", ] -from aeon.classification.deep_learning._cnn import CNNClassifier +from aeon.classification.deep_learning._cnn import CNNClassifier, TimeCNNClassifier from aeon.classification.deep_learning._encoder import EncoderClassifier from aeon.classification.deep_learning._fcn import FCNClassifier from aeon.classification.deep_learning._inception_time import ( diff --git a/aeon/classification/deep_learning/_cnn.py b/aeon/classification/deep_learning/_cnn.py index 15fa496f3a..6f8a13a5ae 100644 --- a/aeon/classification/deep_learning/_cnn.py +++ b/aeon/classification/deep_learning/_cnn.py @@ -1,19 +1,27 @@ -"""Time Convolutional Neural Network (CNN) for classification.""" +"""Time Convolutional Neural Network (CNN) classifier.""" 
-__maintainer__ = [] -__all__ = ["CNNClassifier"] +__maintainer__ = ["hadifawaz1999"] +__all__ = ["CNNClassifier", "TimeCNNClassifier"] import gc import os import time from copy import deepcopy +from deprecated.sphinx import deprecated from sklearn.utils import check_random_state from aeon.classification.deep_learning.base import BaseDeepClassifier -from aeon.networks import CNNNetwork +from aeon.networks import CNNNetwork, TimeCNNNetwork +# TODO: remove v0.12.0 +@deprecated( + version="0.10.0", + reason="CNNClassifier has been renamed to TimeCNNClassifier" + "and will be removed in 0.12.0.", + category=FutureWarning, +) class CNNClassifier(BaseDeepClassifier): """ Time Convolutional Neural Network (CNN). @@ -76,12 +84,17 @@ class CNNClassifier(BaseDeepClassifier): save_last_model : bool, default = False Whether to save the last model, last epoch trained, using the base class method save_last_model_to_file. + save_init_model : bool, default = False + Whether to save the initialization of the model. best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter is discarded. last_file_name : str, default = "last_model" The name of the file of the last model, if save_last_model is set to False, this parameter is discarded. + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. 
Notes ----- @@ -120,8 +133,10 @@ def __init__( file_path="./", save_best_model=False, save_last_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", verbose=False, loss="mean_squared_error", metrics=None, @@ -144,7 +159,9 @@ def __init__( self.file_path = file_path self.save_best_model = save_best_model self.save_last_model = save_last_model + self.save_init_model = save_init_model self.best_file_name = best_file_name + self.init_file_name = init_file_name self.verbose = verbose self.loss = loss self.metrics = metrics @@ -243,6 +260,329 @@ def _fit(self, X, y): self.input_shape = X.shape[1:] self.training_model_ = self.build_model(self.input_shape, self.n_classes_) + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + + if self.verbose: + self.training_model_.summary() + + self.file_name_ = ( + self.best_file_name if self.save_best_model else str(time.time_ns()) + ) + + if self.callbacks is None: + self.callbacks_ = [ + tf.keras.callbacks.ModelCheckpoint( + filepath=self.file_path + self.file_name_ + ".keras", + monitor="loss", + save_best_only=True, + ), + ] + else: + self.callbacks_ = self._get_model_checkpoint_callback( + callbacks=self.callbacks, + file_path=self.file_path, + file_name=self.file_name_, + ) + + self.history = self.training_model_.fit( + X, + y_onehot, + batch_size=self.batch_size, + epochs=self.n_epochs, + verbose=self.verbose, + callbacks=self.callbacks_, + ) + + try: + self.model_ = tf.keras.models.load_model( + self.file_path + self.file_name_ + ".keras", compile=False + ) + if not self.save_best_model: + os.remove(self.file_path + self.file_name_ + ".keras") + except FileNotFoundError: + self.model_ = deepcopy(self.training_model_) + + if self.save_last_model: + self.save_last_model_to_file(file_path=self.file_path) + + gc.collect() + return self + + @classmethod + def get_test_params(cls, parameter_set="default"): + 
"""Return testing parameter settings for the estimator. + + Parameters + ---------- + parameter_set : str, default = "default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return "default" set. + For classifiers, a "default" set of parameters should be provided for + general testing, and a "results_comparison" set for comparing against + previously recorded results if the general set does not produce suitable + probabilities to compare against. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class. + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + `create_test_instance` uses the first (or only) dictionary in `params`. + """ + param1 = { + "n_epochs": 10, + "batch_size": 4, + "avg_pool_size": 4, + } + + test_params = [param1] + + return test_params + + +class TimeCNNClassifier(BaseDeepClassifier): + """ + Time Convolutional Neural Network (CNN). + + Adapted from the implementation used in [1]_. + + Parameters + ---------- + n_layers : int, default = 2 + The number of convolution layers in the network. + kernel_size : int or list of int, default = 7 + Kernel size of convolution layers, if not a list, the same kernel size + is used for all layer, len(list) should be n_layers. + n_filters : int or list of int, default = [6, 12] + Number of filters for each convolution layer, if not a list, the same n_filters + is used in all layers. + avg_pool_size : int or list of int, default = 3 + The size of the average pooling layer, if not a list, the same + max pooling size is used for all convolution layer. + activation : str or list of str, default = "sigmoid" + Keras activation function used in the model for each layer, if not a list, + the same activation is used for all layers. 
+ padding : str or list of str, default = 'valid' + The method of padding in convolution layers, if not a list, the same padding + used for all convolution layers. + strides : int or list of int, default = 1 + The strides of kernels in the convolution and max pooling layers, if not a + list, the same strides are used for all layers. + dilation_rate : int or list of int, default = 1 + The dilation rate of the convolution layers, if not a list, the same dilation + rate is used all over the network. + use_bias : bool or list of bool, default = True + Condition on whether to use bias values for convolution layers, + if not a list, the same condition is used for all layers. + random_state : int, RandomState instance or None, default=None + If `int`, random_state is the seed used by the random number generator; + If `RandomState` instance, random_state is the random number generator; + If `None`, the random number generator is the `RandomState` instance used + by `np.random`. + Seeded random number generation can only be guaranteed on CPU processing, + GPU processing will be non-deterministic. + n_epochs : int, default = 2000 + The number of epochs to train the model. + batch_size : int, default = 16 + The number of samples per gradient update. + verbose : boolean, default = False + Whether to output extra information. + loss : string, default = "mean_squared_error" + Fit parameter for the keras model. + optimizer : keras.optimizer, default = keras.optimizers.Adam() + metrics : list of strings, default = ["accuracy"] + callbacks : keras.callbacks, default = model_checkpoint + To save best model on training loss. + file_path : file_path for the best model + Only used if checkpoint is used as callback. + save_best_model : bool, default = False + Whether to save the best model, if the modelcheckpoint callback is used by + default, this condition, if True, will prevent the automatic deletion of the + best saved model from file and the user can choose the file name. 
+ save_last_model : bool, default = False + Whether to save the last model, last epoch trained, using the base class method + save_last_model_to_file. + save_init_model : bool, default = False + Whether to save the initialization of the model. + best_file_name : str, default = "best_model" + The name of the file of the best model, if save_best_model is set to False, + this parameter is discarded. + last_file_name : str, default = "last_model" + The name of the file of the last model, if save_last_model is set to False, + this parameter is discarded. + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. + + Notes + ----- + Adapted from the implementation from Fawaz et. al + https://github.com/hfawaz/dl-4-tsc/blob/master/classifiers/cnn.py + + References + ---------- + .. [1] Zhao et. al, Convolutional neural networks for time series classification, + Journal of Systems Engineering and Electronics, 28(1):2017. + + Examples + -------- + >>> from aeon.classification.deep_learning import TimeCNNClassifier + >>> from aeon.datasets import load_unit_test + >>> X_train, y_train = load_unit_test(split="train") + >>> X_test, y_test = load_unit_test(split="test") + >>> cnn = TimeCNNClassifier(n_epochs=20, batch_size=4) # doctest: +SKIP + >>> cnn.fit(X_train, y_train) # doctest: +SKIP + TimeCNNClassifier(...) 
+ """ + + def __init__( + self, + n_layers=2, + kernel_size=7, + n_filters=None, + avg_pool_size=3, + activation="sigmoid", + padding="valid", + strides=1, + dilation_rate=1, + n_epochs=2000, + batch_size=16, + callbacks=None, + file_path="./", + save_best_model=False, + save_last_model=False, + save_init_model=False, + best_file_name="best_model", + last_file_name="last_model", + init_file_name="init_model", + verbose=False, + loss="mean_squared_error", + metrics=None, + random_state=None, + use_bias=True, + optimizer=None, + ): + self.n_layers = n_layers + self.kernel_size = kernel_size + self.n_filters = n_filters + self.padding = padding + self.strides = strides + self.dilation_rate = dilation_rate + self.avg_pool_size = avg_pool_size + self.activation = activation + self.use_bias = use_bias + + self.n_epochs = n_epochs + self.callbacks = callbacks + self.file_path = file_path + self.save_best_model = save_best_model + self.save_last_model = save_last_model + self.save_init_model = save_init_model + self.best_file_name = best_file_name + self.init_file_name = init_file_name + self.verbose = verbose + self.loss = loss + self.metrics = metrics + self.optimizer = optimizer + + self.history = None + + super().__init__( + batch_size=batch_size, + random_state=random_state, + last_file_name=last_file_name, + ) + + self._network = TimeCNNNetwork( + n_layers=self.n_layers, + kernel_size=self.kernel_size, + n_filters=self.n_filters, + avg_pool_size=self.avg_pool_size, + activation=self.activation, + padding=self.padding, + strides=self.strides, + dilation_rate=self.dilation_rate, + use_bias=self.use_bias, + ) + + def build_model(self, input_shape, n_classes, **kwargs): + """Construct a compiled, un-trained, keras model that is ready for training. + + In aeon, time series are stored in numpy arrays of shape (d, m), where d + is the number of dimensions, m is the series length. Keras/tensorflow assume + data is in shape (m, d). This method also assumes (m, d). 
Transpose should + happen in fit. + + Parameters + ---------- + input_shape : tuple + The shape of the data fed into the input layer, should be (m, d) + n_classes : int + The number of classes, which becomes the size of the output layer + + Returns + ------- + output : a compiled Keras Model + """ + import numpy as np + import tensorflow as tf + + if self.metrics is None: + metrics = ["accuracy"] + else: + metrics = self.metrics + + rng = check_random_state(self.random_state) + self.random_state_ = rng.randint(0, np.iinfo(np.int32).max) + tf.keras.utils.set_random_seed(self.random_state_) + input_layer, output_layer = self._network.build_network(input_shape, **kwargs) + + output_layer = tf.keras.layers.Dense( + units=n_classes, activation=self.activation, use_bias=self.use_bias + )(output_layer) + + self.optimizer_ = ( + tf.keras.optimizers.Adam() if self.optimizer is None else self.optimizer + ) + + model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer) + model.compile( + loss=self.loss, + optimizer=self.optimizer_, + metrics=metrics, + ) + + return model + + def _fit(self, X, y): + """Fit the classifier on the training set (X, y). + + Parameters + ---------- + X : np.ndarray + The training input samples of shape (n_cases, n_channels, n_timepoints) + y : np.ndarray + The training data class labels of shape (n_cases,). + + + Returns + ------- + self : object + """ + import tensorflow as tf + + y_onehot = self.convert_y_to_keras(y) + # Transpose to conform to Keras input style. 
+ X = X.transpose(0, 2, 1) + + self.input_shape = X.shape[1:] + self.training_model_ = self.build_model(self.input_shape, self.n_classes_) + + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + if self.verbose: self.training_model_.summary() diff --git a/aeon/classification/deep_learning/_encoder.py b/aeon/classification/deep_learning/_encoder.py index 7f839456f2..2765c4cbbe 100644 --- a/aeon/classification/deep_learning/_encoder.py +++ b/aeon/classification/deep_learning/_encoder.py @@ -52,6 +52,8 @@ class EncoderClassifier(BaseDeepClassifier): Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file. + save_init_model : bool, default = False + Whether to save the initialization of the model. best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter @@ -60,6 +62,9 @@ class EncoderClassifier(BaseDeepClassifier): The name of the file of the last model, if save_last_model is set to False, this parameter is discarded. + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. 
random_state : int, RandomState instance or None, default=None If `int`, random_state is the seed used by the random number generator; If `RandomState` instance, random_state is the random number generator; @@ -102,8 +107,10 @@ def __init__( file_path="./", save_best_model=False, save_last_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", verbose=False, loss="categorical_crossentropy", metrics=None, @@ -124,7 +131,9 @@ def __init__( self.file_path = file_path self.save_best_model = save_best_model self.save_last_model = save_last_model + self.save_init_model = save_init_model self.best_file_name = best_file_name + self.init_file_name = init_file_name self.n_epochs = n_epochs self.verbose = verbose self.loss = loss @@ -225,6 +234,9 @@ def _fit(self, X, y): self.input_shape = X.shape[1:] self.training_model_ = self.build_model(self.input_shape, self.n_classes_) + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + if self.verbose: self.training_model_.summary() diff --git a/aeon/classification/deep_learning/_fcn.py b/aeon/classification/deep_learning/_fcn.py index e77457646b..ea14036b87 100644 --- a/aeon/classification/deep_learning/_fcn.py +++ b/aeon/classification/deep_learning/_fcn.py @@ -1,6 +1,6 @@ -"""Fully Convolutional Network (FCN) for classification.""" +"""Fully Convolutional Network (FCN) classifier.""" -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] __all__ = ["FCNClassifier"] import gc @@ -69,6 +69,8 @@ class FCNClassifier(BaseDeepClassifier): Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file. + save_init_model : bool, default = False + Whether to save the initialization of the model. 
best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter @@ -77,6 +79,10 @@ class FCNClassifier(BaseDeepClassifier): The name of the file of the last model, if save_last_model is set to False, this parameter is discarded. + init_file_name : str, default = "init_model" + The name of the file of the init model, if + save_init_model is set to False, + this parameter is discarded. callbacks : keras.callbacks, default = None Notes @@ -112,8 +118,10 @@ def __init__( file_path="./", save_best_model=False, save_last_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", n_epochs=2000, batch_size=16, use_mini_batch_size=False, @@ -145,7 +153,9 @@ def __init__( self.file_path = file_path self.save_best_model = save_best_model self.save_last_model = save_last_model + self.save_init_model = save_init_model self.best_file_name = best_file_name + self.init_file_name = init_file_name self.history = None @@ -238,6 +248,9 @@ def _fit(self, X, y): self.input_shape = X.shape[1:] self.training_model_ = self.build_model(self.input_shape, self.n_classes_) + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + if self.verbose: self.training_model_.summary() diff --git a/aeon/classification/deep_learning/_inception_time.py b/aeon/classification/deep_learning/_inception_time.py index d5cc3cec56..6b377d565a 100644 --- a/aeon/classification/deep_learning/_inception_time.py +++ b/aeon/classification/deep_learning/_inception_time.py @@ -1,6 +1,6 @@ -"""InceptionTime classifier.""" +"""InceptionTime and Inception classifiers.""" -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] __all__ = ["InceptionTimeClassifier"] import gc @@ -103,6 +103,8 @@ class InceptionTimeClassifier(BaseClassifier): Whether or not to save the last model, last epoch trained, using the base class method 
save_last_model_to_file + save_init_model : bool, default = False + Whether to save the initialization of the model. best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter @@ -111,6 +113,9 @@ class InceptionTimeClassifier(BaseClassifier): The name of the file of the last model, if save_last_model is set to False, this parameter is discarded + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. random_state : int, RandomState instance or None, default=None If `int`, random_state is the seed used by the random number generator; If `RandomState` instance, random_state is the random number generator; @@ -181,8 +186,10 @@ def __init__( file_path="./", save_last_model=False, save_best_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", batch_size=64, use_mini_batch_size=False, n_epochs=1500, @@ -218,8 +225,10 @@ def __init__( self.save_last_model = save_last_model self.save_best_model = save_best_model + self.save_init_model = save_init_model self.best_file_name = best_file_name self.last_file_name = last_file_name + self.init_file_name = init_file_name self.callbacks = callbacks self.random_state = random_state @@ -229,7 +238,7 @@ def __init__( self.metrics = metrics self.optimizer = optimizer - self.classifers_ = [] + self.classifiers_ = [] super().__init__() @@ -247,7 +256,7 @@ def _fit(self, X, y): ------- self : object """ - self.classifers_ = [] + self.classifiers_ = [] rng = check_random_state(self.random_state) for n in range(0, self.n_classifiers): @@ -269,8 +278,10 @@ def _fit(self, X, y): file_path=self.file_path, save_best_model=self.save_best_model, save_last_model=self.save_last_model, + save_init_model=self.save_init_model, best_file_name=self.best_file_name + str(n), last_file_name=self.last_file_name 
+ str(n), + init_file_name=self.init_file_name + str(n), batch_size=self.batch_size, use_mini_batch_size=self.use_mini_batch_size, n_epochs=self.n_epochs, @@ -282,7 +293,7 @@ def _fit(self, X, y): verbose=self.verbose, ) cls.fit(X, y) - self.classifers_.append(cls) + self.classifiers_.append(cls) gc.collect() return self @@ -323,7 +334,7 @@ def _predict_proba(self, X) -> np.ndarray: """ probs = np.zeros((X.shape[0], self.n_classes_)) - for cls in self.classifers_: + for cls in self.classifiers_: probs += cls._predict_proba(X) probs = probs / self.n_classifiers @@ -437,14 +448,19 @@ class IndividualInceptionClassifier(BaseDeepClassifier): Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file + save_init_model : bool, default = False + Whether to save the initialization of the model. best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter - is discarded + is discarded. last_file_name : str, default = "last_model" The name of the file of the last model, if save_last_model is set to False, this parameter - is discarded + is discarded. + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. 
random_state : int, RandomState instance or None, default=None If `int`, random_state is the seed used by the random number generator; If `RandomState` instance, random_state is the random number generator; @@ -504,8 +520,10 @@ def __init__( file_path="./", save_best_model=False, save_last_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", batch_size=64, use_mini_batch_size=False, n_epochs=1500, @@ -538,7 +556,9 @@ def __init__( self.save_best_model = save_best_model self.save_last_model = save_last_model + self.save_init_model = save_init_model self.best_file_name = best_file_name + self.init_file_name = init_file_name self.callbacks = callbacks self.verbose = verbose @@ -652,6 +672,9 @@ def _fit(self, X, y): mini_batch_size = self.batch_size self.training_model_ = self.build_model(self.input_shape, self.n_classes_) + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + if self.verbose: self.training_model_.summary() diff --git a/aeon/classification/deep_learning/_lite_time.py b/aeon/classification/deep_learning/_lite_time.py index ef2479a949..f16e136d71 100644 --- a/aeon/classification/deep_learning/_lite_time.py +++ b/aeon/classification/deep_learning/_lite_time.py @@ -1,6 +1,6 @@ -"""LITETime classifier.""" +"""LITETime and LITE classifiers.""" -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] __all__ = ["LITETimeClassifier"] import gc @@ -60,6 +60,8 @@ class LITETimeClassifier(BaseClassifier): Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file + save_init_model : bool, default = False + Whether to save the initialization of the model. 
best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter @@ -68,6 +70,9 @@ class LITETimeClassifier(BaseClassifier): The name of the file of the last model, if save_last_model is set to False, this parameter is discarded + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. random_state : int, RandomState instance or None, default=None If `int`, random_state is the seed used by the random number generator; If `RandomState` instance, random_state is the random number generator; @@ -120,8 +125,10 @@ def __init__( file_path="./", save_last_model=False, save_best_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", batch_size=64, use_mini_batch_size=False, n_epochs=1500, @@ -146,8 +153,10 @@ def __init__( self.save_last_model = save_last_model self.save_best_model = save_best_model + self.save_init_model = save_init_model self.best_file_name = best_file_name self.last_file_name = last_file_name + self.init_file_name = init_file_name self.callbacks = callbacks self.random_state = random_state @@ -157,7 +166,7 @@ def __init__( self.metrics = metrics self.optimizer = optimizer - self.classifers_ = [] + self.classifiers_ = [] super().__init__() @@ -175,7 +184,7 @@ def _fit(self, X, y): ------- self : object """ - self.classifers_ = [] + self.classifiers_ = [] rng = check_random_state(self.random_state) for n in range(0, self.n_classifiers): @@ -185,8 +194,10 @@ def _fit(self, X, y): file_path=self.file_path, save_best_model=self.save_best_model, save_last_model=self.save_last_model, + save_init_model=self.save_init_model, best_file_name=self.best_file_name + str(n), last_file_name=self.last_file_name + str(n), + init_file_name=self.init_file_name + str(n), batch_size=self.batch_size, 
use_mini_batch_size=self.use_mini_batch_size, n_epochs=self.n_epochs, @@ -198,7 +209,7 @@ def _fit(self, X, y): verbose=self.verbose, ) cls.fit(X, y) - self.classifers_.append(cls) + self.classifiers_.append(cls) gc.collect() return self @@ -239,7 +250,7 @@ def _predict_proba(self, X) -> np.ndarray: """ probs = np.zeros((X.shape[0], self.n_classes_)) - for cls in self.classifers_: + for cls in self.classifiers_: probs += cls._predict_proba(X) probs = probs / self.n_classifiers @@ -318,6 +329,8 @@ class IndividualLITEClassifier(BaseDeepClassifier): Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file + save_init_model : bool, default = False + Whether to save the initialization of the model. best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter @@ -326,6 +339,9 @@ class IndividualLITEClassifier(BaseDeepClassifier): The name of the file of the last model, if save_last_model is set to False, this parameter is discarded + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. 
random_state : int, RandomState instance or None, default=None If `int`, random_state is the seed used by the random number generator; If `RandomState` instance, random_state is the random number generator; @@ -369,8 +385,10 @@ def __init__( file_path="./", save_best_model=False, save_last_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", batch_size=64, use_mini_batch_size=False, n_epochs=1500, @@ -393,7 +411,9 @@ def __init__( self.save_best_model = save_best_model self.save_last_model = save_last_model + self.save_init_model = save_init_model self.best_file_name = best_file_name + self.init_file_name = init_file_name self.callbacks = callbacks self.verbose = verbose @@ -495,6 +515,9 @@ def _fit(self, X, y): mini_batch_size = self.batch_size self.training_model_ = self.build_model(self.input_shape, self.n_classes_) + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + if self.verbose: self.training_model_.summary() diff --git a/aeon/classification/deep_learning/_mlp.py b/aeon/classification/deep_learning/_mlp.py index 1624aa9904..48eb8f711e 100644 --- a/aeon/classification/deep_learning/_mlp.py +++ b/aeon/classification/deep_learning/_mlp.py @@ -1,6 +1,6 @@ -"""Multi Layer Perceptron Network (MLP) for classification.""" +"""Multi Layer Perceptron Network (MLP) classifier.""" -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] __all__ = ["MLPClassifier"] import gc @@ -51,6 +51,8 @@ class MLPClassifier(BaseDeepClassifier): Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file + save_init_model : bool, default = False + Whether to save the initialization of the model. 
best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter @@ -59,6 +61,9 @@ class MLPClassifier(BaseDeepClassifier): The name of the file of the last model, if save_last_model is set to False, this parameter is discarded + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. optimizer : keras.optimizer, default=keras.optimizers.Adadelta(), metrics : list of strings, default=["accuracy"], activation : string or a tf callable, default="sigmoid" @@ -101,8 +106,10 @@ def __init__( file_path="./", save_best_model=False, save_last_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", random_state=None, activation="sigmoid", use_bias=True, @@ -119,7 +126,9 @@ def __init__( self.file_path = file_path self.save_best_model = save_best_model self.save_last_model = save_last_model + self.save_init_model = save_init_model self.best_file_name = best_file_name + self.init_file_name = init_file_name self.optimizer = optimizer self.history = None @@ -204,6 +213,9 @@ def _fit(self, X, y): self.input_shape = X.shape[1:] self.training_model_ = self.build_model(self.input_shape, self.n_classes_) + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + if self.verbose: self.training_model_.summary() diff --git a/aeon/classification/deep_learning/_resnet.py b/aeon/classification/deep_learning/_resnet.py index ccd6d11f0d..963faec26b 100644 --- a/aeon/classification/deep_learning/_resnet.py +++ b/aeon/classification/deep_learning/_resnet.py @@ -1,6 +1,6 @@ -"""Residual Network (ResNet) for classification.""" +"""Residual Network (ResNet) classifier.""" -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] __all__ = ["ResNetClassifier"] import gc @@ -75,12 +75,17 @@ class 
ResNetClassifier(BaseDeepClassifier): save_last_model : bool, default = False Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file. + save_init_model : bool, default = False + Whether to save the initialization of the model. best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter is discarded. last_file_name : str, default = "last_model" The name of the file of the last model, if save_last_model is set to False, this parameter is discarded. + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. verbose : boolean, default = False whether to output extra information loss : string, default = "mean_squared_error" @@ -131,8 +136,10 @@ def __init__( file_path="./", save_best_model=False, save_last_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", optimizer=None, ): self.n_residual_blocks = n_residual_blocks @@ -153,7 +160,9 @@ def __init__( self.file_path = file_path self.save_best_model = save_best_model self.save_last_model = save_last_model + self.save_init_model = save_init_model self.best_file_name = best_file_name + self.init_file_name = init_file_name self.optimizer = optimizer self.history = None @@ -250,6 +259,9 @@ def _fit(self, X, y): self.input_shape = X.shape[1:] self.training_model_ = self.build_model(self.input_shape, self.n_classes_) + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + if self.verbose: self.training_model_.summary() diff --git a/aeon/classification/deep_learning/_tapnet.py b/aeon/classification/deep_learning/_tapnet.py index 4bfe51e4c2..2ad048827b 100644 --- a/aeon/classification/deep_learning/_tapnet.py +++ b/aeon/classification/deep_learning/_tapnet.py @@ -1,6 +1,6 @@ 
-"""Time Convolutional Neural Network (CNN) for classification.""" +"""Time series Attentional Prototype Network (TapNet) Classifier.""" -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] __all__ = [ "TapNetClassifier", ] diff --git a/aeon/classification/deep_learning/base.py b/aeon/classification/deep_learning/base.py index 45974298a0..837945eead 100644 --- a/aeon/classification/deep_learning/base.py +++ b/aeon/classification/deep_learning/base.py @@ -5,7 +5,7 @@ because we can generalise tags, _predict and _predict_proba """ -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] __all__ = ["BaseDeepClassifier"] from abc import ABC, abstractmethod diff --git a/aeon/classification/deep_learning/tests/test_deep_classifier_base.py b/aeon/classification/deep_learning/tests/test_deep_classifier_base.py index 203c256eef..e0c4157c7a 100644 --- a/aeon/classification/deep_learning/tests/test_deep_classifier_base.py +++ b/aeon/classification/deep_learning/tests/test_deep_classifier_base.py @@ -10,7 +10,7 @@ from aeon.testing.data_generation import make_example_2d_numpy_collection from aeon.utils.validation._dependencies import _check_soft_dependencies -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] class _DummyDeepClassifier(BaseDeepClassifier): diff --git a/aeon/classification/deep_learning/tests/test_random_state_deep_learning.py b/aeon/classification/deep_learning/tests/test_random_state_deep_learning.py index 4647de8b00..b426169cce 100644 --- a/aeon/classification/deep_learning/tests/test_random_state_deep_learning.py +++ b/aeon/classification/deep_learning/tests/test_random_state_deep_learning.py @@ -12,35 +12,37 @@ __maintainer__ = ["hadifawaz1999"] +_deep_cls_classes = [ + member[1] for member in inspect.getmembers(deep_learning, inspect.isclass) +] + + @pytest.mark.skipif( not _check_soft_dependencies(["tensorflow"], severity="none"), reason="skip test if required soft dependency not available", ) -def test_random_state_deep_learning_cls(): 
+@pytest.mark.parametrize("deep_cls", _deep_cls_classes) +def test_random_state_deep_learning_cls(deep_cls): """Test Deep Classifier seeding.""" - random_state = 42 - - X, y = make_example_3d_numpy(random_state=random_state) - - deep_cls_classes = [ - member[1] for member in inspect.getmembers(deep_learning, inspect.isclass) - ] - - for i in range(len(deep_cls_classes)): - if ( - "BaseDeepClassifier" in str(deep_cls_classes[i]) - or "InceptionTimeClassifier" in str(deep_cls_classes[i]) - or "LITETimeClassifier" in str(deep_cls_classes[i]) - or "TapNetClassifier" in str(deep_cls_classes[i]) - ): - continue - - deep_cls1 = deep_cls_classes[i](random_state=random_state, n_epochs=4) + if not ( + deep_cls.__name__ + in [ + "BaseDeepClassifier", + "InceptionTimeClassifier", + "LITETimeClassifier", + "TapNetClassifier", + ] + ): + random_state = 42 + + X, y = make_example_3d_numpy(random_state=random_state) + + deep_cls1 = deep_cls(random_state=random_state, n_epochs=4) deep_cls1.fit(X, y) layers1 = deep_cls1.training_model_.layers[1:] - deep_cls2 = deep_cls_classes[i](random_state=random_state, n_epochs=4) + deep_cls2 = deep_cls(random_state=random_state, n_epochs=4) deep_cls2.fit(X, y) layers2 = deep_cls2.training_model_.layers[1:] diff --git a/aeon/classification/deep_learning/tests/test_saving_loading_deep_learning_cls.py b/aeon/classification/deep_learning/tests/test_saving_loading_deep_learning_cls.py new file mode 100644 index 0000000000..d90393a369 --- /dev/null +++ b/aeon/classification/deep_learning/tests/test_saving_loading_deep_learning_cls.py @@ -0,0 +1,83 @@ +"""Unit tests for classifiers deep learners save/load functionalities.""" + +import inspect +import os +import tempfile +import time + +import numpy as np +import pytest + +from aeon.classification import deep_learning +from aeon.testing.data_generation import make_example_3d_numpy +from aeon.utils.validation._dependencies import _check_soft_dependencies + +__maintainer__ = ["hadifawaz1999"] + + 
+_deep_cls_classes = [ + member[1] for member in inspect.getmembers(deep_learning, inspect.isclass) +] + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="skip test if required soft dependency not available", +) +@pytest.mark.parametrize("deep_cls", _deep_cls_classes) +def test_saving_loading_deep_learning_cls(deep_cls): + """Test Deep Classifier saving.""" + with tempfile.TemporaryDirectory() as tmp: + if not ( + deep_cls.__name__ + in [ + "BaseDeepClassifier", + "InceptionTimeClassifier", + "LITETimeClassifier", + "TapNetClassifier", + ] + ): + if tmp[-1] != "/": + tmp = tmp + "/" + curr_time = str(time.time_ns()) + last_file_name = curr_time + "last" + best_file_name = curr_time + "best" + init_file_name = curr_time + "init" + + X, y = make_example_3d_numpy() + + deep_cls_train = deep_cls( + n_epochs=2, + save_best_model=True, + save_last_model=True, + save_init_model=True, + best_file_name=best_file_name, + last_file_name=last_file_name, + init_file_name=init_file_name, + file_path=tmp, + ) + deep_cls_train.fit(X, y) + + deep_cls_best = deep_cls() + deep_cls_best.load_model( + model_path=os.path.join(tmp, best_file_name + ".keras"), + classes=np.unique(y), + ) + ypred_best = deep_cls_best.predict(X) + assert len(ypred_best) == len(y) + + deep_cls_last = deep_cls() + deep_cls_last.load_model( + model_path=os.path.join(tmp, last_file_name + ".keras"), + classes=np.unique(y), + ) + ypred_last = deep_cls_last.predict(X) + assert len(ypred_last) == len(y) + + deep_cls_init = deep_cls() + deep_cls_init.load_model( + model_path=os.path.join(tmp, init_file_name + ".keras"), + classes=np.unique(y), + ) + ypred_init = deep_cls_init.predict(X) + assert len(ypred_init) == len(y) diff --git a/aeon/clustering/deep_learning/_ae_fcn.py b/aeon/clustering/deep_learning/_ae_fcn.py index 322a614a01..a9f33751ce 100644 --- a/aeon/clustering/deep_learning/_ae_fcn.py +++ b/aeon/clustering/deep_learning/_ae_fcn.py @@ -1,6 +1,6 @@ """Deep 
Learning Auto-Encoder using FCN Network.""" -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] __all__ = ["AEFCNClusterer"] import gc diff --git a/aeon/clustering/deep_learning/_ae_resnet.py b/aeon/clustering/deep_learning/_ae_resnet.py index 19ab56549c..2d1ccf13e0 100644 --- a/aeon/clustering/deep_learning/_ae_resnet.py +++ b/aeon/clustering/deep_learning/_ae_resnet.py @@ -1,6 +1,6 @@ """Residual Network (ResNet) for clustering.""" -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] __all__ = ["AEResNetClusterer"] import gc diff --git a/aeon/clustering/deep_learning/tests/test_deep_clusterer_base.py b/aeon/clustering/deep_learning/tests/test_deep_clusterer_base.py index cc8e952e85..421aa69a85 100644 --- a/aeon/clustering/deep_learning/tests/test_deep_clusterer_base.py +++ b/aeon/clustering/deep_learning/tests/test_deep_clusterer_base.py @@ -9,7 +9,7 @@ from aeon.testing.mock_estimators import MockDeepClusterer from aeon.utils.validation._dependencies import _check_soft_dependencies -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] @pytest.mark.skipif( diff --git a/aeon/networks/__init__.py b/aeon/networks/__init__.py index b7ec3ca75a..df56152dea 100644 --- a/aeon/networks/__init__.py +++ b/aeon/networks/__init__.py @@ -4,6 +4,7 @@ "BaseDeepNetwork", "BaseDeepLearningNetwork", "CNNNetwork", + "TimeCNNNetwork", "EncoderNetwork", "FCNNetwork", "InceptionNetwork", @@ -18,7 +19,7 @@ from aeon.networks._ae_bgru import AEBiGRUNetwork from aeon.networks._ae_fcn import AEFCNNetwork from aeon.networks._ae_resnet import AEResNetNetwork -from aeon.networks._cnn import CNNNetwork +from aeon.networks._cnn import CNNNetwork, TimeCNNNetwork from aeon.networks._encoder import EncoderNetwork from aeon.networks._fcn import FCNNetwork from aeon.networks._inception import InceptionNetwork diff --git a/aeon/networks/_ae_bgru.py b/aeon/networks/_ae_bgru.py index 7fc2616f5f..5e2e78a71e 100644 --- a/aeon/networks/_ae_bgru.py +++ b/aeon/networks/_ae_bgru.py @@ -1,4 +1,6 
@@ -"""Implement Auto-Encoder based on Bidirectional GRUs.""" +"""Auto-Encoder using Bidirectional GRU Network (AEBiGRUNetwork).""" + +__maintainer__ = ["aadya940", "hadifawaz1999"] from aeon.networks.base import BaseDeepLearningNetwork diff --git a/aeon/networks/_ae_fcn.py b/aeon/networks/_ae_fcn.py index 37bbdf1aa5..1cef7d5c31 100644 --- a/aeon/networks/_ae_fcn.py +++ b/aeon/networks/_ae_fcn.py @@ -1,6 +1,6 @@ """Auto-Encoder using Fully Convolutional Network (FCN).""" -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] import numpy as np diff --git a/aeon/networks/_ae_resnet.py b/aeon/networks/_ae_resnet.py index 3b817efe8e..4578dd8092 100644 --- a/aeon/networks/_ae_resnet.py +++ b/aeon/networks/_ae_resnet.py @@ -1,4 +1,7 @@ -"""Residual Network (ResNet) (minus the final output layer).""" +"""Auto-Encoder using Residual Network (AEResNetNetwork).""" + +__maintainer__ = ["hadifawaz1999"] + import numpy as np diff --git a/aeon/networks/_cnn.py b/aeon/networks/_cnn.py index cb40b8f373..aafb9875c4 100644 --- a/aeon/networks/_cnn.py +++ b/aeon/networks/_cnn.py @@ -1,10 +1,19 @@ -"""Time Convolutional Neural Network (CNN) (minus the final output layer).""" +"""Time Convolutional Neural Network (TimeCNNNetwork).""" -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] + +from deprecated.sphinx import deprecated from aeon.networks.base import BaseDeepLearningNetwork +# TODO: remove v0.12.0 +@deprecated( + version="0.10.0", + reason="CNNNetwork has been renamed to TimeCNNNetwork" + "and will be removed in 0.12.0.", + category=FutureWarning, +) class CNNNetwork(BaseDeepLearningNetwork): """Establish the network structure for a CNN. @@ -167,3 +176,167 @@ def build_network(self, input_shape, **kwargs): flatten_layer = tf.keras.layers.Flatten()(conv) return input_layer, flatten_layer + + +class TimeCNNNetwork(BaseDeepLearningNetwork): + """Establish the network structure for a CNN. + + Adapted from the implementation used in [1]_. 
+ + Parameters + ---------- + n_layers : int, default = 2 + The number of convolution layers in the network. + kernel_size : int or list of int, default = 7 + Kernel size of convolution layers, if not a list, the same kernel size is + used for all layers, len(list) should be n_layers. + n_filters : int or list of int, default = [6, 12] + Number of filters for each convolution layer, if not a list, the same + `n_filters` is used in all layers. + avg_pool_size : int or list of int, default = 3 + The size of the average pooling layer, if not a list, the same pooling + size is used for all convolution layers. + activation : str or list of str, default = "sigmoid" + Keras activation function used in the model for each layer, if not a list, + the same activation is used for all layers. + padding : str or list of str, default = "valid" + The method of padding in convolution layers, if not a list, the same padding + is used for all convolution layers. + strides : int or list of int, default = 1 + The strides of kernels in the convolution layers, if not a list, + the same strides are used for all layers. + dilation_rate : int or list of int, default = 1 + The dilation rate of the convolution layers, if not a list, the same dilation + rate is used all over the network. + use_bias : bool or list of bool, default = True + Condition on whether or not to use bias values for convolution layers, if not + a list, the same condition is used for all layers. + + Notes + ----- + Adapted from source code + https://github.com/hfawaz/dl-4-tsc/blob/master/classifiers/cnn.py + + References + ---------- + .. [1] Zhao et al. 
Convolutional neural networks for time series classification, + Journal of Systems Engineering and Electronics 28(1), 162--169, 2017 + """ + + def __init__( + self, + n_layers=2, + kernel_size=7, + n_filters=None, + avg_pool_size=3, + activation="sigmoid", + padding="valid", + strides=1, + dilation_rate=1, + use_bias=True, + ): + self.n_layers = n_layers + self.n_filters = n_filters + self.kernel_size = kernel_size + self.avg_pool_size = avg_pool_size + self.activation = activation + self.padding = padding + self.strides = strides + self.dilation_rate = dilation_rate + self.use_bias = use_bias + + super().__init__() + + def build_network(self, input_shape, **kwargs): + """ + Construct a network and return its input and output layers. + + Parameters + ---------- + input_shape : tuple + The shape of the data fed into the input layer. + + Returns + ------- + input_layer : a keras layer + output_layer : a keras layer + """ + import tensorflow as tf + + self._n_filters_ = [6, 12] if self.n_filters is None else self.n_filters + + if isinstance(self.kernel_size, list): + assert len(self.kernel_size) == self.n_layers + self._kernel_size = self.kernel_size + else: + self._kernel_size = [self.kernel_size] * self.n_layers + + if isinstance(self._n_filters_, list): + assert len(self._n_filters_) == self.n_layers + self._n_filters = self._n_filters_ + else: + self._n_filters = [self._n_filters_] * self.n_layers + + if isinstance(self.avg_pool_size, list): + assert len(self.avg_pool_size) == self.n_layers + self._avg_pool_size = self.avg_pool_size + else: + self._avg_pool_size = [self.avg_pool_size] * self.n_layers + + if isinstance(self.activation, list): + assert len(self.activation) == self.n_layers + self._activation = self.activation + else: + self._activation = [self.activation] * self.n_layers + + if isinstance(self.padding, list): + assert len(self.padding) == self.n_layers + self._padding = self.padding + else: + self._padding = [self.padding] * self.n_layers + + if 
isinstance(self.strides, list): + assert len(self.strides) == self.n_layers + self._strides = self.strides + else: + self._strides = [self.strides] * self.n_layers + + if isinstance(self.dilation_rate, list): + assert len(self.dilation_rate) == self.n_layers + self._dilation_rate = self.dilation_rate + else: + self._dilation_rate = [self.dilation_rate] * self.n_layers + + if isinstance(self.use_bias, list): + assert len(self.use_bias) == self.n_layers + self._use_bias = self.use_bias + else: + self._use_bias = [self.use_bias] * self.n_layers + + input_layer = tf.keras.layers.Input(input_shape) + + if input_shape[0] < 60: + self._padding = ["same"] * self.n_layers + + x = input_layer + + for i in range(self.n_layers): + conv = tf.keras.layers.Conv1D( + filters=self._n_filters[i], + kernel_size=self._kernel_size[i], + strides=self._strides[i], + padding=self._padding[i], + dilation_rate=self._dilation_rate[i], + activation=self._activation[i], + use_bias=self._use_bias[i], + )(x) + + conv = tf.keras.layers.AveragePooling1D(pool_size=self._avg_pool_size[i])( + conv + ) + + x = conv + + flatten_layer = tf.keras.layers.Flatten()(conv) + + return input_layer, flatten_layer diff --git a/aeon/networks/_encoder.py b/aeon/networks/_encoder.py index cc600aa932..f01c6e99cd 100644 --- a/aeon/networks/_encoder.py +++ b/aeon/networks/_encoder.py @@ -1,4 +1,4 @@ -"""Encoder Classifier.""" +"""Encoder Network (EncoderNetwork).""" __maintainer__ = ["hadifawaz1999"] diff --git a/aeon/networks/_fcn.py b/aeon/networks/_fcn.py index 9b6c10de96..fa4e3f763a 100644 --- a/aeon/networks/_fcn.py +++ b/aeon/networks/_fcn.py @@ -1,6 +1,7 @@ -"""Fully Convolutional Network (FCN) (minus the final output layer).""" +"""Fully Convolutional Network (FCNNetwork).""" + +__maintainer__ = ["hadifawaz1999"] -__maintainer__ = [] from aeon.networks.base import BaseDeepLearningNetwork diff --git a/aeon/networks/_inception.py b/aeon/networks/_inception.py index 1b209dd8dc..4c8abaa449 100644 --- 
a/aeon/networks/_inception.py +++ b/aeon/networks/_inception.py @@ -1,6 +1,6 @@ -"""Inception Network.""" +"""Inception Network (InceptionNetwork).""" -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] from aeon.networks.base import BaseDeepLearningNetwork diff --git a/aeon/networks/_lite.py b/aeon/networks/_lite.py index 64a5961ab4..df19fba0d0 100644 --- a/aeon/networks/_lite.py +++ b/aeon/networks/_lite.py @@ -1,6 +1,7 @@ -"""LITE Network.""" +"""LITE Network (LITENetwork).""" + +__maintainer__ = ["hadifawaz1999"] -__maintainer__ = [] from aeon.networks.base import BaseDeepLearningNetwork diff --git a/aeon/networks/_mlp.py b/aeon/networks/_mlp.py index 537fadce0b..84b1570ff7 100644 --- a/aeon/networks/_mlp.py +++ b/aeon/networks/_mlp.py @@ -1,6 +1,7 @@ -"""Multi Layer Perceptron (MLP) (minus the final output layer).""" +"""Multi Layer Perceptron Network (MLPNetwork).""" + +__maintainer__ = ["hadifawaz1999"] -__maintainer__ = [] from aeon.networks.base import BaseDeepLearningNetwork diff --git a/aeon/networks/_resnet.py b/aeon/networks/_resnet.py index 91aaf5b154..d1ab38883a 100644 --- a/aeon/networks/_resnet.py +++ b/aeon/networks/_resnet.py @@ -1,6 +1,7 @@ -"""Residual Network (ResNet) (minus the final output layer).""" +"""Residual Network (ResNetNetwork).""" + +__maintainer__ = ["hadifawaz1999"] -__maintainer__ = [] from aeon.networks.base import BaseDeepLearningNetwork diff --git a/aeon/networks/_tapnet.py b/aeon/networks/_tapnet.py index 4e06d942c3..720925f2d2 100644 --- a/aeon/networks/_tapnet.py +++ b/aeon/networks/_tapnet.py @@ -1,6 +1,6 @@ -"""Time Convolutional Neural Network (CNN) (minus the final output layer).""" +"""Time series Attentional Prototype Network (TapNetNetwork).""" -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] import math diff --git a/aeon/regression/deep_learning/__init__.py b/aeon/regression/deep_learning/__init__.py index 13077e96e8..eb69f1e0f0 100644 --- a/aeon/regression/deep_learning/__init__.py +++ 
b/aeon/regression/deep_learning/__init__.py @@ -1,7 +1,9 @@ """Deep learning based regressors.""" __all__ = [ + "BaseDeepRegressor", "CNNRegressor", + "TimeCNNRegressor", "FCNRegressor", "InceptionTimeRegressor", "IndividualInceptionRegressor", @@ -13,7 +15,7 @@ "MLPRegressor", ] -from aeon.regression.deep_learning._cnn import CNNRegressor +from aeon.regression.deep_learning._cnn import CNNRegressor, TimeCNNRegressor from aeon.regression.deep_learning._encoder import EncoderRegressor from aeon.regression.deep_learning._fcn import FCNRegressor from aeon.regression.deep_learning._inception_time import ( @@ -27,3 +29,4 @@ from aeon.regression.deep_learning._mlp import MLPRegressor from aeon.regression.deep_learning._resnet import ResNetRegressor from aeon.regression.deep_learning._tapnet import TapNetRegressor +from aeon.regression.deep_learning.base import BaseDeepRegressor diff --git a/aeon/regression/deep_learning/_cnn.py b/aeon/regression/deep_learning/_cnn.py index c5854b2438..988e823c63 100644 --- a/aeon/regression/deep_learning/_cnn.py +++ b/aeon/regression/deep_learning/_cnn.py @@ -1,19 +1,27 @@ -"""Time Convolutional Neural Network (CNN) for regression.""" +"""Time Convolutional Neural Network (TimeCNN) regressor.""" -__maintainer__ = [] -__all__ = ["CNNRegressor"] +__maintainer__ = ["hadifawaz1999"] +__all__ = ["CNNRegressor", "TimeCNNRegressor"] import gc import os import time from copy import deepcopy +from deprecated.sphinx import deprecated from sklearn.utils import check_random_state -from aeon.networks import CNNNetwork +from aeon.networks import CNNNetwork, TimeCNNNetwork from aeon.regression.deep_learning.base import BaseDeepRegressor +# TODO: remove v0.12.0 +@deprecated( + version="0.10.0", + reason="CNNRegressor has been renamed to TimeCNNRegressor" + "and will be removed in 0.12.0.", + category=FutureWarning, +) class CNNRegressor(BaseDeepRegressor): """Time Series Convolutional Neural Network (CNN). 
@@ -86,6 +94,8 @@ class CNNRegressor(BaseDeepRegressor): Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file + save_init_model : bool, default = False + Whether to save the initialization of the model. best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter @@ -94,6 +104,9 @@ class CNNRegressor(BaseDeepRegressor): The name of the file of the last model, if save_last_model is set to False, this parameter is discarded + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. Notes ----- @@ -133,8 +146,10 @@ def __init__( file_path="./", save_best_model=False, save_last_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", verbose=False, loss="mse", output_activation="linear", @@ -151,7 +166,9 @@ def __init__( self.file_path = file_path self.save_best_model = save_best_model self.save_last_model = save_last_model + self.save_init_model = save_init_model self.best_file_name = best_file_name + self.init_file_name = init_file_name self.strides = strides self.dilation_rate = dilation_rate self.callbacks = callbacks @@ -253,6 +270,337 @@ def _fit(self, X, y): self.input_shape = X.shape[1:] self.training_model_ = self.build_model(self.input_shape) + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + + if self.verbose: + self.training_model_.summary() + + self.file_name_ = ( + self.best_file_name if self.save_best_model else str(time.time_ns()) + ) + + if self.callbacks is None: + self.callbacks_ = [ + tf.keras.callbacks.ModelCheckpoint( + filepath=self.file_path + self.file_name_ + ".keras", + monitor="loss", + save_best_only=True, + ), + ] + else: + self.callbacks_ = self._get_model_checkpoint_callback( + 
callbacks=self.callbacks, + file_path=self.file_path, + file_name=self.file_name_, + ) + + self.history = self.training_model_.fit( + X, + y, + batch_size=self.batch_size, + epochs=self.n_epochs, + verbose=self.verbose, + callbacks=self.callbacks_, + ) + + try: + self.model_ = tf.keras.models.load_model( + self.file_path + self.file_name_ + ".keras", compile=False + ) + if not self.save_best_model: + os.remove(self.file_path + self.file_name_ + ".keras") + except FileNotFoundError: + self.model_ = deepcopy(self.training_model_) + + if self.save_last_model: + self.save_last_model_to_file(file_path=self.file_path) + + gc.collect() + return self + + @classmethod + def get_test_params(cls, parameter_set="default"): + """Return testing parameter settings for the estimator. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + For regressors, a "default" set of parameters should be provided for + general testing, and a "results_comparison" set for comparing against + previously recorded results if the general set does not produce suitable + probabilities to compare against. + + Returns + ------- + params : dict or list of dict, default={} + Parameters to create testing instances of the class. + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + `create_test_instance` uses the first (or only) dictionary in `params`. + """ + param = { + "n_epochs": 10, + "batch_size": 4, + "avg_pool_size": 4, + } + + return [param] + + +class TimeCNNRegressor(BaseDeepRegressor): + """Time Series Convolutional Neural Network (CNN). + + Adapted from the implementation used in [1]_. 
+ + Parameters + ---------- + n_layers : int, default = 2, + the number of convolution layers in the network + kernel_size : int or list of int, default = 7, + kernel size of convolution layers, if not a list, the same kernel size + is used for all layers, len(list) should be n_layers + n_filters : int or list of int, default = [6, 12], + number of filters for each convolution layer, if not a list, the same n_filters + is used in all layers. + avg_pool_size : int or list of int, default = 3, + the size of the average pooling layer, if not a list, the same + pooling size is used + for all convolution layers + output_activation : str, default = "linear", + the output activation for the regressor + activation : str or list of str, default = "sigmoid", + keras activation function used in the model for each layer, + if not a list, the same + activation is used for all layers + padding : str or list of str, default = 'valid', + the method of padding in convolution layers, if not a list, + the same padding is used + for all convolution layers + strides : int or list of int, default = 1, + the strides of kernels in the convolution layers, + if not a list, the same strides are used for all layers + dilation_rate : int or list of int, default = 1, + the dilation rate of the convolution layers, if not a list, + the same dilation rate is used all over the network + use_bias : bool or list of bool, default = True, + condition on whether or not to use bias values for convolution layers, + if not a list, the same condition is used for all layers + random_state : int, RandomState instance or None, default=None + If `int`, random_state is the seed used by the random number generator; + If `RandomState` instance, random_state is the random number generator; + If `None`, the random number generator is the `RandomState` instance used + by `np.random`. + Seeded random number generation can only be guaranteed on CPU processing, + GPU processing will be non-deterministic. 
+ n_epochs : int, default = 2000 + the number of epochs to train the model + batch_size : int, default = 16 + the number of samples per gradient update. + verbose : boolean, default = False + whether to output extra information + loss : string, default="mean_squared_error" + fit parameter for the keras model + optimizer : keras.optimizer, default=keras.optimizers.Adam(), + metrics : str or list of str, default="mean_squared_error" + The evaluation metrics to use during training. If + a single string metric is provided, it will be + used as the only metric. If a list of metrics are + provided, all will be used for evaluation. + callbacks : keras.callbacks, default=model_checkpoint to save best + model on training loss + file_path : file_path for the best model (if checkpoint is used as callback) + save_best_model : bool, default = False + Whether or not to save the best model, if the + modelcheckpoint callback is used by default, + this condition, if True, will prevent the + automatic deletion of the best saved model from + file and the user can choose the file name + save_last_model : bool, default = False + Whether or not to save the last model, last + epoch trained, using the base class method + save_last_model_to_file + save_init_model : bool, default = False + Whether to save the initialization of the model. + best_file_name : str, default = "best_model" + The name of the file of the best model, if + save_best_model is set to False, this parameter + is discarded + last_file_name : str, default = "last_model" + The name of the file of the last model, if + save_last_model is set to False, this parameter + is discarded + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. + + Notes + ----- + Adapted from the implementation from Fawaz et. al + https://github.com/hfawaz/dl-4-tsc/blob/master/classifiers/cnn.py + + References + ---------- + .. [1] Zhao et. 
al, Convolutional neural networks for time series classification, + Journal of Systems Engineering and Electronics, 28(1):2017. + + Examples + -------- + >>> from aeon.regression.deep_learning import TimeCNNRegressor + >>> from aeon.testing.data_generation import make_example_3d_numpy + >>> X, y = make_example_3d_numpy(n_cases=10, n_channels=1, n_timepoints=12, + ... return_y=True, regression_target=True, + ... random_state=0) + >>> rgs = TimeCNNRegressor(n_epochs=20, batch_size=4) # doctest: +SKIP + >>> rgs.fit(X, y) # doctest: +SKIP + TimeCNNRegressor(...) + """ + + def __init__( + self, + n_layers=2, + kernel_size=7, + n_filters=None, + avg_pool_size=3, + activation="sigmoid", + padding="valid", + strides=1, + dilation_rate=1, + n_epochs=2000, + batch_size=16, + callbacks=None, + file_path="./", + save_best_model=False, + save_last_model=False, + save_init_model=False, + best_file_name="best_model", + last_file_name="last_model", + init_file_name="init_model", + verbose=False, + loss="mse", + output_activation="linear", + metrics="mean_squared_error", + random_state=None, + use_bias=True, + optimizer=None, + ): + self.n_layers = n_layers + self.avg_pool_size = avg_pool_size + self.padding = padding + self.n_filters = n_filters + self.kernel_size = kernel_size + self.file_path = file_path + self.save_best_model = save_best_model + self.save_last_model = save_last_model + self.save_init_model = save_init_model + self.best_file_name = best_file_name + self.init_file_name = init_file_name + self.strides = strides + self.dilation_rate = dilation_rate + self.callbacks = callbacks + self.n_epochs = n_epochs + self.verbose = verbose + self.loss = loss + self.output_activation = output_activation + self.metrics = metrics + self.random_state = random_state + self.activation = activation + self.use_bias = use_bias + self.optimizer = optimizer + + self.history = None + + super().__init__( + batch_size=batch_size, + last_file_name=last_file_name, + ) + + self._network = 
TimeCNNNetwork( + n_layers=self.n_layers, + kernel_size=self.kernel_size, + n_filters=self.n_filters, + avg_pool_size=self.avg_pool_size, + activation=self.activation, + padding=self.padding, + strides=self.strides, + dilation_rate=self.dilation_rate, + use_bias=self.use_bias, + ) + + def build_model(self, input_shape, **kwargs): + """Construct a compiled, un-trained, keras model that is ready for training. + + In aeon, time series are stored in numpy arrays of shape (d,m), where d + is the number of dimensions, m is the series length. Keras/tensorflow assume + data is in shape (m,d). This method also assumes (m,d). Transpose should + happen in fit. + + Parameters + ---------- + input_shape : tuple + The shape of the data fed into the input layer, should be (m,d) + + Returns + ------- + output : a compiled Keras Model + """ + import numpy as np + import tensorflow as tf + from tensorflow import keras + + rng = check_random_state(self.random_state) + self.random_state_ = rng.randint(0, np.iinfo(np.int32).max) + tf.keras.utils.set_random_seed(self.random_state_) + input_layer, output_layer = self._network.build_network(input_shape, **kwargs) + + output_layer = keras.layers.Dense(units=1, activation=self.output_activation)( + output_layer + ) + + self.optimizer_ = ( + keras.optimizers.Adam() if self.optimizer is None else self.optimizer + ) + + model = keras.models.Model(inputs=input_layer, outputs=output_layer) + + model.compile( + loss=self.loss, + optimizer=self.optimizer_, + metrics=self._metrics, + ) + return model + + def _fit(self, X, y): + """Fit the regressor on the training set (X, y). + + Parameters + ---------- + X : np.ndarray + The training input samples of shape (n_cases, n_channels, n_timepoints). + y : np.ndarray + The training data target values of shape (n_cases,). + + Returns + ------- + self : object + """ + import tensorflow as tf + + # Transpose to conform to Keras input style. 
+ X = X.transpose(0, 2, 1) + + if isinstance(self.metrics, str): + self._metrics = [self.metrics] + else: + self._metrics = self.metrics + self.input_shape = X.shape[1:] + self.training_model_ = self.build_model(self.input_shape) + + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + if self.verbose: self.training_model_.summary() diff --git a/aeon/regression/deep_learning/_encoder.py b/aeon/regression/deep_learning/_encoder.py index be65733d0f..4b73047c05 100644 --- a/aeon/regression/deep_learning/_encoder.py +++ b/aeon/regression/deep_learning/_encoder.py @@ -1,6 +1,6 @@ """Encoder Regressor.""" -__author__ = ["AnonymousCodes911"] +__author__ = ["AnonymousCodes911", "hadifawaz1999"] __all__ = ["EncoderRegressor"] import gc @@ -54,6 +54,8 @@ class EncoderRegressor(BaseDeepRegressor): Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file. + save_init_model : bool, default = False + Whether to save the initialization of the model. best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter @@ -62,6 +64,9 @@ class EncoderRegressor(BaseDeepRegressor): The name of the file of the last model, if save_last_model is set to False, this parameter is discarded. + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. 
n_epochs: The number of times the entire training dataset will be passed forward and backward @@ -121,8 +126,10 @@ def __init__( file_path="./", save_best_model=False, save_last_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", verbose=False, loss="mean_squared_error", metrics="mean_squared_error", @@ -144,7 +151,9 @@ def __init__( self.file_path = file_path self.save_best_model = save_best_model self.save_last_model = save_last_model + self.save_init_model = save_init_model self.best_file_name = best_file_name + self.init_file_name = init_file_name self.n_epochs = n_epochs self.verbose = verbose self.loss = loss @@ -239,6 +248,9 @@ def _fit(self, X, y): self.input_shape = X.shape[1:] self.training_model_ = self.build_model(self.input_shape) + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + if self.verbose: self.training_model_.summary() diff --git a/aeon/regression/deep_learning/_fcn.py b/aeon/regression/deep_learning/_fcn.py index 7c136cae34..374feb5b93 100644 --- a/aeon/regression/deep_learning/_fcn.py +++ b/aeon/regression/deep_learning/_fcn.py @@ -1,6 +1,6 @@ -"""Fully Convolutional Network (FCN) for regression.""" +"""Fully Convolutional Network (FCN) regressor.""" -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] __all__ = ["FCNRegressor"] import gc @@ -75,6 +75,8 @@ class FCNRegressor(BaseDeepRegressor): Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file + save_init_model : bool, default = False + Whether to save the initialization of the model. 
best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter @@ -83,6 +85,9 @@ class FCNRegressor(BaseDeepRegressor): The name of the file of the last model, if save_last_model is set to False, this parameter is discarded + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. callbacks : keras.callbacks, default = None Notes @@ -119,8 +124,10 @@ def __init__( file_path="./", save_best_model=False, save_last_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", n_epochs=2000, batch_size=16, use_mini_batch_size=False, @@ -153,7 +160,9 @@ def __init__( self.file_path = file_path self.save_best_model = save_best_model self.save_last_model = save_last_model + self.save_init_model = save_init_model self.best_file_name = best_file_name + self.init_file_name = init_file_name self.history = None @@ -239,6 +248,9 @@ def _fit(self, X, y): self.input_shape = X.shape[1:] self.training_model_ = self.build_model(self.input_shape) + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + if self.verbose: self.training_model_.summary() diff --git a/aeon/regression/deep_learning/_inception_time.py b/aeon/regression/deep_learning/_inception_time.py index dc14749d01..a5bd459545 100644 --- a/aeon/regression/deep_learning/_inception_time.py +++ b/aeon/regression/deep_learning/_inception_time.py @@ -1,6 +1,6 @@ -"""InceptionTime regressor.""" +"""InceptionTime and Inception regressors.""" -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] __all__ = ["InceptionTimeRegressor"] import gc @@ -107,6 +107,8 @@ class InceptionTimeRegressor(BaseRegressor): Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file + save_init_model : bool, default 
= False + Whether to save the initialization of the model. best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter @@ -115,6 +117,9 @@ class InceptionTimeRegressor(BaseRegressor): The name of the file of the last model, if save_last_model is set to False, this parameter is discarded + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. random_state : int, RandomState instance or None, default=None If `int`, random_state is the seed used by the random number generator; If `RandomState` instance, random_state is the random number generator; @@ -187,8 +192,10 @@ def __init__( file_path="./", save_last_model=False, save_best_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", batch_size=64, use_mini_batch_size=False, n_epochs=1500, @@ -223,8 +230,10 @@ def __init__( self.save_last_model = save_last_model self.save_best_model = save_best_model + self.save_init_model = save_init_model self.best_file_name = best_file_name self.last_file_name = last_file_name + self.init_file_name = init_file_name self.callbacks = callbacks self.random_state = random_state @@ -275,8 +284,10 @@ def _fit(self, X, y): file_path=self.file_path, save_best_model=self.save_best_model, save_last_model=self.save_last_model, + save_init_model=self.save_init_model, best_file_name=self.best_file_name + str(n), last_file_name=self.last_file_name + str(n), + init_file_name=self.init_file_name + str(n), batch_size=self.batch_size, use_mini_batch_size=self.use_mini_batch_size, n_epochs=self.n_epochs, @@ -424,6 +435,8 @@ class IndividualInceptionRegressor(BaseDeepRegressor): Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file + save_init_model : bool, default = False + Whether to save the 
initialization of the model. best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter @@ -432,6 +445,9 @@ class IndividualInceptionRegressor(BaseDeepRegressor): The name of the file of the last model, if save_last_model is set to False, this parameter is discarded + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. random_state : int, RandomState instance or None, default=None If `int`, random_state is the seed used by the random number generator; If `RandomState` instance, random_state is the random number generator; @@ -492,8 +508,10 @@ def __init__( file_path="./", save_best_model=False, save_last_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", batch_size=64, use_mini_batch_size=False, n_epochs=1500, @@ -526,7 +544,9 @@ def __init__( self.save_best_model = save_best_model self.save_last_model = save_last_model + self.save_init_model = save_init_model self.best_file_name = best_file_name + self.init_file_name = init_file_name self.callbacks = callbacks self.random_state = random_state @@ -626,6 +646,9 @@ def _fit(self, X, y): mini_batch_size = self.batch_size self.training_model_ = self.build_model(self.input_shape_) + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + if self.verbose: self.training_model_.summary() diff --git a/aeon/regression/deep_learning/_lite_time.py b/aeon/regression/deep_learning/_lite_time.py index 8b2e32f200..5a2079df94 100644 --- a/aeon/regression/deep_learning/_lite_time.py +++ b/aeon/regression/deep_learning/_lite_time.py @@ -1,6 +1,6 @@ -"""LITETime Regressor.""" +"""LITETime and LITE regressors.""" -__author__ = ["aadya940"] +__author__ = ["aadya940", "hadifawaz1999"] __all__ = ["IndividualLITERegressor", 
"LITETimeRegressor"] import gc @@ -61,6 +61,8 @@ class LITETimeRegressor(BaseRegressor): Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file + save_init_model : bool, default = False + Whether to save the initialization of the model. best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter @@ -69,6 +71,9 @@ class LITETimeRegressor(BaseRegressor): The name of the file of the last model, if save_last_model is set to False, this parameter is discarded + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. random_state : int, RandomState instance or None, default=None If `int`, random_state is the seed used by the random number generator; If `RandomState` instance, random_state is the random number generator; @@ -122,8 +127,10 @@ def __init__( file_path="./", save_last_model=False, save_best_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", batch_size=64, use_mini_batch_size=False, n_epochs=1500, @@ -149,8 +156,10 @@ def __init__( self.save_last_model = save_last_model self.save_best_model = save_best_model + self.save_init_model = save_init_model self.best_file_name = best_file_name self.last_file_name = last_file_name + self.init_file_name = init_file_name self.callbacks = callbacks self.random_state = random_state @@ -188,8 +197,10 @@ def _fit(self, X, y): file_path=self.file_path, save_best_model=self.save_best_model, save_last_model=self.save_last_model, + save_init_model=self.save_init_model, best_file_name=self.best_file_name + str(n), last_file_name=self.last_file_name + str(n), + init_file_name=self.init_file_name + str(n), batch_size=self.batch_size, use_mini_batch_size=self.use_mini_batch_size, n_epochs=self.n_epochs, @@ -302,6 +313,8 @@ class 
IndividualLITERegressor(BaseDeepRegressor): Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file + save_init_model : bool, default = False + Whether to save the initialization of the model. best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter @@ -310,6 +323,9 @@ class IndividualLITERegressor(BaseDeepRegressor): The name of the file of the last model, if save_last_model is set to False, this parameter is discarded + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. random_state : int, RandomState instance or None, default=None If `int`, random_state is the seed used by the random number generator; If `RandomState` instance, random_state is the random number generator; @@ -354,8 +370,10 @@ def __init__( file_path="./", save_best_model=False, save_last_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", batch_size=64, use_mini_batch_size=False, n_epochs=1500, @@ -379,7 +397,9 @@ def __init__( self.save_best_model = save_best_model self.save_last_model = save_last_model + self.save_init_model = save_init_model self.best_file_name = best_file_name + self.init_file_name = init_file_name self.callbacks = callbacks self.random_state = random_state @@ -476,6 +496,9 @@ def _fit(self, X, y): mini_batch_size = self.batch_size self.training_model_ = self.build_model(self.input_shape) + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + if self.verbose: self.training_model_.summary() diff --git a/aeon/regression/deep_learning/_mlp.py b/aeon/regression/deep_learning/_mlp.py index a8883f3f12..cb9907fe7c 100644 --- a/aeon/regression/deep_learning/_mlp.py +++ b/aeon/regression/deep_learning/_mlp.py @@ -1,6 +1,6 @@ 
-"""Multi Layer Perceptron Network (MLP) for Regression.""" +"""Multi Layer Perceptron Network (MLP) regressor.""" -__author__ = ["Aadya-Chinubhai"] +__author__ = ["Aadya-Chinubhai", "hadifawaz1999"] __all__ = ["MLPRegressor"] import gc @@ -47,6 +47,8 @@ class MLPRegressor(BaseDeepRegressor): Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file + save_init_model : bool, default = False + Whether to save the initialization of the model. best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter @@ -55,6 +57,9 @@ class MLPRegressor(BaseDeepRegressor): The name of the file of the last model, if save_last_model is set to False, this parameter is discarded + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. random_state : int, RandomState instance or None, default=None If `int`, random_state is the seed used by the random number generator; If `RandomState` instance, random_state is the random number generator; @@ -100,8 +105,10 @@ def __init__( file_path="./", save_best_model=False, save_last_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", random_state=None, activation="relu", output_activation="linear", @@ -118,7 +125,9 @@ def __init__( self.file_path = file_path self.save_best_model = save_best_model self.save_last_model = save_last_model + self.save_init_model = save_init_model self.best_file_name = best_file_name + self.init_file_name = init_file_name self.optimizer = optimizer self.random_state = random_state self.output_activation = output_activation @@ -202,6 +211,9 @@ def _fit(self, X, y): self.training_model_ = self.build_model(self.input_shape) + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + 
if self.verbose: self.training_model_.summary() diff --git a/aeon/regression/deep_learning/_resnet.py b/aeon/regression/deep_learning/_resnet.py index b9a411891c..48f2d3c5f8 100644 --- a/aeon/regression/deep_learning/_resnet.py +++ b/aeon/regression/deep_learning/_resnet.py @@ -1,6 +1,6 @@ -"""Residual Network (ResNet) for regression.""" +"""Residual Network (ResNet) regressor.""" -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] __all__ = ["ResNetRegressor"] import gc @@ -82,6 +82,8 @@ class ResNetRegressor(BaseDeepRegressor): Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file + save_init_model : bool, default = False + Whether to save the initialization of the model. best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter @@ -90,6 +92,9 @@ class ResNetRegressor(BaseDeepRegressor): The name of the file of the last model, if save_last_model is set to False, this parameter is discarded + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. 
verbose : boolean, default = False whether to output extra information loss : string, default="mean_squared_error" @@ -147,8 +152,10 @@ def __init__( file_path="./", save_best_model=False, save_last_model=False, + save_init_model=False, best_file_name="best_model", last_file_name="last_model", + init_file_name="init_model", optimizer=None, ): self.n_residual_blocks = n_residual_blocks @@ -171,7 +178,9 @@ def __init__( self.file_path = file_path self.save_best_model = save_best_model self.save_last_model = save_last_model + self.save_init_model = save_init_model self.best_file_name = best_file_name + self.init_file_name = init_file_name self.optimizer = optimizer self.history = None @@ -261,6 +270,9 @@ def _fit(self, X, y): self.input_shape = X.shape[1:] self.training_model_ = self.build_model(self.input_shape) + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + if self.verbose: self.training_model_.summary() diff --git a/aeon/regression/deep_learning/_tapnet.py b/aeon/regression/deep_learning/_tapnet.py index 380d2ee533..4a03f02c66 100644 --- a/aeon/regression/deep_learning/_tapnet.py +++ b/aeon/regression/deep_learning/_tapnet.py @@ -1,6 +1,6 @@ -"""Time Convolutional Neural Network (CNN) for classification.""" +"""Time series Attentional Prototype Network (TapNet) regressor.""" -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] __all__ = [ "TapNetRegressor", ] diff --git a/aeon/regression/deep_learning/tests/test_deep_regressor_base.py b/aeon/regression/deep_learning/tests/test_deep_regressor_base.py index 000b39173d..9dad88f0b2 100644 --- a/aeon/regression/deep_learning/tests/test_deep_regressor_base.py +++ b/aeon/regression/deep_learning/tests/test_deep_regressor_base.py @@ -10,7 +10,7 @@ from aeon.testing.data_generation import make_example_2d_numpy_collection from aeon.utils.validation._dependencies import _check_soft_dependencies -__maintainer__ = [] +__maintainer__ = ["hadifawaz1999"] class 
_DummyDeepRegressor(BaseDeepRegressor): diff --git a/aeon/regression/deep_learning/tests/test_saving_loading_deep_learning_cls.py b/aeon/regression/deep_learning/tests/test_saving_loading_deep_learning_cls.py new file mode 100644 index 0000000000..736d99baf3 --- /dev/null +++ b/aeon/regression/deep_learning/tests/test_saving_loading_deep_learning_cls.py @@ -0,0 +1,79 @@ +"""Unit tests for regressors deep learners save/load functionalities.""" + +import inspect +import os +import tempfile +import time + +import pytest + +from aeon.regression import deep_learning +from aeon.testing.data_generation import make_example_3d_numpy +from aeon.utils.validation._dependencies import _check_soft_dependencies + +__maintainer__ = ["hadifawaz1999"] + + +_deep_rgs_classes = [ + member[1] for member in inspect.getmembers(deep_learning, inspect.isclass) +] + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="skip test if required soft dependency not available", +) +@pytest.mark.parametrize("deep_rgs", _deep_rgs_classes) +def test_saving_loading_deep_learning_rgs(deep_rgs): + """Test Deep Regressor saving.""" + with tempfile.TemporaryDirectory() as tmp: + if not ( + deep_rgs.__name__ + in [ + "BaseDeepRegressor", + "InceptionTimeRegressor", + "LITETimeRegressor", + "TapNetRegressor", + ] + ): + if tmp[-1] != "/": + tmp = tmp + "/" + curr_time = str(time.time_ns()) + last_file_name = curr_time + "last" + best_file_name = curr_time + "best" + init_file_name = curr_time + "init" + + X, y = make_example_3d_numpy() + + deep_rgs_train = deep_rgs( + n_epochs=2, + save_best_model=True, + save_last_model=True, + save_init_model=True, + best_file_name=best_file_name, + last_file_name=last_file_name, + init_file_name=init_file_name, + file_path=tmp, + ) + deep_rgs_train.fit(X, y) + + deep_rgs_best = deep_rgs() + deep_rgs_best.load_model( + model_path=os.path.join(tmp, best_file_name + ".keras"), + ) + ypred_best = deep_rgs_best.predict(X) + assert 
len(ypred_best) == len(y) + + deep_rgs_last = deep_rgs() + deep_rgs_last.load_model( + model_path=os.path.join(tmp, last_file_name + ".keras"), + ) + ypred_last = deep_rgs_last.predict(X) + assert len(ypred_last) == len(y) + + deep_rgs_init = deep_rgs() + deep_rgs_init.load_model( + model_path=os.path.join(tmp, init_file_name + ".keras"), + ) + ypred_init = deep_rgs_init.predict(X) + assert len(ypred_init) == len(y) diff --git a/docs/api_reference/classification.rst b/docs/api_reference/classification.rst index c87b62fa6b..89a6261c3b 100644 --- a/docs/api_reference/classification.rst +++ b/docs/api_reference/classification.rst @@ -32,7 +32,9 @@ Deep learning :toctree: auto_generated/ :template: class.rst + BaseDeepClassifier CNNClassifier + TimeCNNClassifier EncoderClassifier FCNClassifier InceptionTimeClassifier diff --git a/docs/api_reference/clustering.rst b/docs/api_reference/clustering.rst index 1361f8b1a4..737a585631 100644 --- a/docs/api_reference/clustering.rst +++ b/docs/api_reference/clustering.rst @@ -20,6 +20,7 @@ Deep learning BaseDeepClusterer AEFCNClusterer + AEResNetClusterer Clustering Algorithms --------------------- diff --git a/docs/api_reference/distances.rst b/docs/api_reference/distances.rst index ef79912320..c11ca70151 100644 --- a/docs/api_reference/distances.rst +++ b/docs/api_reference/distances.rst @@ -153,6 +153,9 @@ Shape Dynamic Time Warping (Shape DTW) shape_dtw_pairwise_distance shape_dtw_cost_matrix shape_dtw_alignment_path + _pad_ts_collection_edges + _pad_ts_edges + _transform_subsequences Squared ------- diff --git a/docs/api_reference/networks.rst b/docs/api_reference/networks.rst index 0a33613289..65b8aefa14 100644 --- a/docs/api_reference/networks.rst +++ b/docs/api_reference/networks.rst @@ -15,6 +15,7 @@ Deep learning networks BaseDeepLearningNetwork CNNNetwork + TimeCNNNetwork EncoderNetwork FCNNetwork InceptionNetwork @@ -22,3 +23,6 @@ Deep learning networks ResNetNetwork TapNetNetwork AEFCNNetwork + AEResNetNetwork + 
LITENetwork + AEBiGRUNetwork diff --git a/docs/api_reference/regression.rst b/docs/api_reference/regression.rst index 2001e5e4b4..bb0ff961c3 100644 --- a/docs/api_reference/regression.rst +++ b/docs/api_reference/regression.rst @@ -51,16 +51,18 @@ Deep learning :toctree: auto_generated/ :template: class.rst + BaseDeepRegressor CNNRegressor + TimeCNNRegressor EncoderRegressor FCNRegressor InceptionTimeRegressor IndividualLITERegressor IndividualInceptionRegressor LITETimeRegressor - LITETimeRegressor ResNetRegressor TapNetRegressor + MLPRegressor Distance-based -------------- diff --git a/docs/api_reference/visualisation.rst b/docs/api_reference/visualisation.rst index 1555d6bbaf..4d7be0feb8 100644 --- a/docs/api_reference/visualisation.rst +++ b/docs/api_reference/visualisation.rst @@ -31,3 +31,4 @@ Visualisation plot_time_series_with_change_points plot_time_series_with_profiles plot_cluster_algorithm + plot_network