diff --git a/aeon/classification/deep_learning/_mlp.py b/aeon/classification/deep_learning/_mlp.py
index 6768ba48bd..37b4f80444 100644
--- a/aeon/classification/deep_learning/_mlp.py
+++ b/aeon/classification/deep_learning/_mlp.py
@@ -21,6 +21,17 @@ class MLPClassifier(BaseDeepClassifier):
 
     Parameters
     ----------
+    n_layers : int, optional (default=3)
+        The number of dense layers in the MLP.
+    n_units : Union[int, List[int]], optional (default=200)
+        Number of units in each dense layer.
+    activation : Union[str, List[str]], optional (default='relu')
+        Activation function(s) for each dense layer.
+    dropout_rate : Union[float, List[Union[int, float]]], optional (default=None)
+        Dropout rate(s) for each dense layer. If None, a rate of 0.1 is used for
+        the first layer and 0.2 for the rest. Rates must be in the interval [0, 1].
+    dropout_last : float, optional (default=None)
+        The dropout rate of the last layer. If None, a rate of 0.3 is used.
     use_bias : bool, default = True
         Condition on whether or not to use bias values for dense layers.
     n_epochs : int, default = 2000
@@ -76,10 +87,6 @@ class MLPClassifier(BaseDeepClassifier):
         a single string metric is provided, it will be used as the only metric.
         If a list of metrics are provided, all will be used for evaluation.
-    activation : string or a tf callable, default="sigmoid"
-        Activation function used in the output linear layer.
-        List of available activation functions:
-        https://keras.io/api/layers/activations/
 
     Notes
     -----
@@ -104,6 +111,11 @@ class MLPClassifier(BaseDeepClassifier):
 
     def __init__(
         self,
+        n_layers=3,
+        n_units=200,
+        activation="relu",
+        dropout_rate=None,
+        dropout_last=None,
         use_bias=True,
         n_epochs=2000,
         batch_size=16,
@@ -120,16 +132,19 @@ def __init__(
         last_file_name="last_model",
         init_file_name="init_model",
         random_state=None,
-        activation="sigmoid",
         optimizer=None,
     ):
+        self.n_layers = n_layers
+        self.n_units = n_units
+        self.activation = activation
+        self.dropout_rate = dropout_rate
+        self.dropout_last = dropout_last
         self.callbacks = callbacks
         self.n_epochs = n_epochs
         self.verbose = verbose
         self.loss = loss
         self.metrics = metrics
         self.use_mini_batch_size = use_mini_batch_size
-        self.activation = activation
         self.use_bias = use_bias
         self.file_path = file_path
         self.save_best_model = save_best_model
@@ -147,7 +162,14 @@ def __init__(
             last_file_name=last_file_name,
         )
 
-        self._network = MLPNetwork(use_bias=self.use_bias)
+        self._network = MLPNetwork(
+            n_layers=self.n_layers,
+            n_units=self.n_units,
+            activation=self.activation,
+            dropout_rate=self.dropout_rate,
+            dropout_last=self.dropout_last,
+            use_bias=self.use_bias,
+        )
 
     def build_model(self, input_shape, n_classes, **kwargs):
         """Construct a compiled, un-trained, keras model that is ready for training.
diff --git a/aeon/networks/_mlp.py b/aeon/networks/_mlp.py
index 4cf81b5e1f..cba7b89019 100644
--- a/aeon/networks/_mlp.py
+++ b/aeon/networks/_mlp.py
@@ -3,6 +3,10 @@
 
 __maintainer__ = ["hadifawaz1999"]
 
+import typing
+
+import numpy as np
+
 from aeon.networks.base import BaseDeepLearningNetwork
 
 
@@ -13,6 +17,17 @@ class MLPNetwork(BaseDeepLearningNetwork):
 
     Parameters
     ----------
+    n_layers : int, optional (default=3)
+        The number of dense layers in the MLP.
+    n_units : Union[int, List[int]], optional (default=200)
+        Number of units in each dense layer.
+    activation : Union[str, List[str]], optional (default='relu')
+        Activation function(s) for each dense layer.
+    dropout_rate : Union[float, List[Union[int, float]]], optional (default=None)
+        Dropout rate(s) for each dense layer. If None, a rate of 0.1 is used for
+        the first layer and 0.2 for the rest. Rates must be in the interval [0, 1].
+    dropout_last : float, optional (default=None)
+        The dropout rate of the last layer. If None, a rate of 0.3 is used.
     use_bias : bool, default = True
         Condition on whether or not to use bias values for dense layers.
 
@@ -35,12 +50,22 @@ class MLPNetwork(BaseDeepLearningNetwork):
 
     def __init__(
         self,
-        use_bias=True,
+        n_layers: int = 3,
+        n_units: typing.Union[int, list[int]] = 200,
+        activation: typing.Union[str, list[str]] = "relu",
+        dropout_rate: typing.Union[float, list[float]] = None,
+        dropout_last: float = None,
+        use_bias: bool = True,
     ):
-        self.use_bias = use_bias
-        super().__init__()
+        self.n_layers = n_layers
+        self.n_units = n_units
+        self.activation = activation
+        self.dropout_rate = dropout_rate
+        self.dropout_last = dropout_last
+        self.use_bias = use_bias
+        super().__init__()
 
     def build_network(self, input_shape, **kwargs):
         """Construct a network and return its input and output layers.
 
@@ -54,27 +79,77 @@ def build_network(self, input_shape, **kwargs):
         input_layer : a keras layer
         output_layer : a keras layer
         """
+        if isinstance(self.activation, str):
+            self._activation = [self.activation] * self.n_layers
+        elif isinstance(self.activation, list):
+            assert (
+                len(self.activation) == self.n_layers
+            ), "There should be an `activation` function associated with each layer."
+            assert all(
+                isinstance(a, str) for a in self.activation
+            ), "Activation must be a list of strings."
+            self._activation = self.activation
+
+        if self.dropout_rate is None:
+            self._dropout_rate = [0.1]
+            self._dropout_rate.extend([0.2] * (self.n_layers - 1))
+            assert np.all(0 <= np.array(self._dropout_rate)) and np.all(
+                np.array(self._dropout_rate) <= 1
+            ), "Dropout rate(s) should be in the interval [0, 1]."
+        elif isinstance(self.dropout_rate, (int, float)):
+            self._dropout_rate = [float(self.dropout_rate)] * self.n_layers
+            assert np.all(0 <= np.array(self._dropout_rate)) and np.all(
+                np.array(self._dropout_rate) <= 1
+            ), "Dropout rate(s) should be in the interval [0, 1]."
+        elif isinstance(self.dropout_rate, list):
+            assert (
+                len(self.dropout_rate) == self.n_layers
+            ), "There should be a `dropout_rate` associated with each layer."
+            assert all(
+                isinstance(d, (int, float)) for d in self.dropout_rate
+            ), "Dropout rates must be int or float."
+            self._dropout_rate = [float(d) for d in self.dropout_rate]
+            assert np.all(0 <= np.array(self._dropout_rate)) and np.all(
+                np.array(self._dropout_rate) <= 1
+            ), "Dropout rate(s) should be in the interval [0, 1]."
+
+        if isinstance(self.n_units, int):
+            self._n_units = [self.n_units] * self.n_layers
+        elif isinstance(self.n_units, list):
+            assert all(
+                isinstance(u, int) for u in self.n_units
+            ), "`n_units` must be int for all layers."
+            assert (
+                len(self.n_units) == self.n_layers
+            ), "`n_units` length must match number of layers."
+            self._n_units = self.n_units
+
+        if self.dropout_last is None:
+            self._dropout_last = 0.3
+        else:
+            assert isinstance(
+                self.dropout_last, (int, float)
+            ), "a float is expected in the `dropout_last` argument."
+            assert (
+                0 <= self.dropout_last <= 1
+            ), "`dropout_last` argument must be a number in the interval [0, 1]."
+            self._dropout_last = self.dropout_last
+
         from tensorflow import keras
 
-        # flattened because multivariate should be on same axis
         input_layer = keras.layers.Input(input_shape)
         input_layer_flattened = keras.layers.Flatten()(input_layer)
 
-        layer_1 = keras.layers.Dropout(0.1)(input_layer_flattened)
-        layer_1 = keras.layers.Dense(500, activation="relu", use_bias=self.use_bias)(
-            layer_1
-        )
-
-        layer_2 = keras.layers.Dropout(0.2)(layer_1)
-        layer_2 = keras.layers.Dense(500, activation="relu", use_bias=self.use_bias)(
-            layer_2
-        )
+        x = input_layer_flattened
 
-        layer_3 = keras.layers.Dropout(0.2)(layer_2)
-        layer_3 = keras.layers.Dense(500, activation="relu", use_bias=self.use_bias)(
-            layer_3
-        )
+        for idx in range(0, self.n_layers):
+            x = keras.layers.Dropout(self._dropout_rate[idx])(x)
+            x = keras.layers.Dense(
+                self._n_units[idx],
+                activation=self._activation[idx],
+                use_bias=self.use_bias,
+            )(x)
 
-        output_layer = keras.layers.Dropout(0.3)(layer_3)
+        output_layer = keras.layers.Dropout(self._dropout_last)(x)
 
         return input_layer, output_layer
diff --git a/aeon/regression/deep_learning/_mlp.py b/aeon/regression/deep_learning/_mlp.py
index 9a616238c6..073da4db62 100644
--- a/aeon/regression/deep_learning/_mlp.py
+++ b/aeon/regression/deep_learning/_mlp.py
@@ -21,6 +21,17 @@ class MLPRegressor(BaseDeepRegressor):
 
     Parameters
     ----------
+    n_layers : int, optional (default=3)
+        The number of dense layers in the MLP.
+    n_units : Union[int, List[int]], optional (default=200)
+        Number of units in each dense layer.
+    activation : Union[str, List[str]], optional (default='relu')
+        Activation function(s) for each dense layer.
+    dropout_rate : Union[float, List[Union[int, float]]], optional (default=None)
+        Dropout rate(s) for each dense layer. If None, a rate of 0.1 is used for
+        the first layer and 0.2 for the rest. Rates must be in the interval [0, 1].
+    dropout_last : float, optional (default=None)
+        The dropout rate of the last layer. If None, a rate of 0.3 is used.
     use_bias : bool, default = True
         Condition on whether or not to use bias values for dense layers.
     n_epochs : int, default = 2000
@@ -72,10 +83,6 @@ class MLPRegressor(BaseDeepRegressor):
         by `np.random`.
         Seeded random number generation can only be guaranteed on CPU processing,
        GPU processing will be non-deterministic.
-    activation : string or a tf callable, default="relu"
-        Activation function used in the output linear layer.
-        List of available activation functions:
-        https://keras.io/api/layers/activations/
     output_activation : str = "linear"
         Activation for the last layer in a Regressor.
     optimizer : keras.optimizer, default = tf.keras.optimizers.Adam()
@@ -100,6 +107,11 @@ class MLPRegressor(BaseDeepRegressor):
 
     def __init__(
         self,
+        n_layers=3,
+        n_units=200,
+        activation="relu",
+        dropout_rate=None,
+        dropout_last=None,
         use_bias=True,
         n_epochs=2000,
         batch_size=16,
@@ -115,16 +127,19 @@ def __init__(
         last_file_name="last_model",
         init_file_name="init_model",
         random_state=None,
-        activation="relu",
         output_activation="linear",
         optimizer=None,
     ):
+        self.n_layers = n_layers
+        self.n_units = n_units
+        self.activation = activation
+        self.dropout_rate = dropout_rate
+        self.dropout_last = dropout_last
         self.callbacks = callbacks
         self.n_epochs = n_epochs
         self.verbose = verbose
         self.loss = loss
         self.metrics = metrics
-        self.activation = activation
         self.use_bias = use_bias
         self.file_path = file_path
         self.save_best_model = save_best_model
@@ -143,7 +158,14 @@ def __init__(
             last_file_name=last_file_name,
         )
 
-        self._network = MLPNetwork(use_bias=self.use_bias)
+        self._network = MLPNetwork(
+            n_layers=self.n_layers,
+            n_units=self.n_units,
+            activation=self.activation,
+            dropout_rate=self.dropout_rate,
+            dropout_last=self.dropout_last,
+            use_bias=self.use_bias,
+        )
 
     def build_model(self, input_shape, **kwargs):
         """Construct a compiled, un-trained, keras model that is ready for training.
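Usage sketch for the new arguments (illustrative only; it assumes the patched classes above, and the data shapes, epoch count, and per-layer values are arbitrary choices for the example). Scalar arguments are broadcast across `n_layers`, while list arguments must supply one entry per layer:

    import numpy as np

    from aeon.classification.deep_learning import MLPClassifier
    from aeon.networks import MLPNetwork

    # Toy data: 20 univariate series of length 50 with two class labels.
    X = np.random.random((20, 1, 50))
    y = np.array([0, 1] * 10)

    # Scalars apply the same setting to every hidden layer.
    clf = MLPClassifier(n_layers=3, n_units=200, activation="relu", n_epochs=2)
    clf.fit(X, y)

    # Lists configure each hidden layer individually.
    net = MLPNetwork(
        n_layers=3,
        n_units=[256, 128, 64],
        activation=["relu", "relu", "tanh"],
        dropout_rate=[0.1, 0.2, 0.2],
        dropout_last=0.3,
    )
    input_layer, output_layer = net.build_network(input_shape=(50, 1))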