diff --git a/aeon/transformations/collection/__init__.py b/aeon/transformations/collection/__init__.py index 949260da1d..2f5cd4c4c5 100644 --- a/aeon/transformations/collection/__init__.py +++ b/aeon/transformations/collection/__init__.py @@ -19,6 +19,7 @@ "PeriodogramTransformer", "Resizer", "SlopeTransformer", + "SimpleImputer", "Truncator", "Tabularizer", ] @@ -28,6 +29,7 @@ from aeon.transformations.collection._downsample import DownsampleTransformer from aeon.transformations.collection._dwt import DWTTransformer from aeon.transformations.collection._hog1d import HOG1DTransformer +from aeon.transformations.collection._impute import SimpleImputer from aeon.transformations.collection._matrix_profile import MatrixProfile from aeon.transformations.collection._pad import Padder from aeon.transformations.collection._periodogram import PeriodogramTransformer diff --git a/aeon/transformations/collection/_impute.py b/aeon/transformations/collection/_impute.py new file mode 100644 index 0000000000..a145317820 --- /dev/null +++ b/aeon/transformations/collection/_impute.py @@ -0,0 +1,153 @@ +"""Time Series imputer.""" + +__maintainer__ = [] +__all__ = ["SimpleImputer"] + +from typing import Callable, Optional, Union + +import numpy as np +from scipy.stats import mode + +from aeon.transformations.collection.base import BaseCollectionTransformer + + +class SimpleImputer(BaseCollectionTransformer): + """Time series imputer. + + Transformer that imputes missing values in time series. Fill values are calculated + across series. + + Parameters + ---------- + strategy : str or Callable, default="mean" + The imputation strategy. + - if "mean", replace missing values using the mean. + - if "median", replace missing values using the median. + - if "constant", replace missing values with the fill_value. + - if "most frequent", replace missing values with the most frequent value. + - if Callable, a function that returns the value to replace + missing values with on each 1D array containing all + non-missing values of each series. + + fill_value : float or None, default=None + The value to replace missing values with. Only used when strategy is "constant". + """ + + _tags = { + "X_inner_type": ["np-list", "numpy3D"], + "fit_is_empty": True, + "capability:multivariate": True, + "capability:unequal_length": True, + "capability:missing_values": True, + "removes_missing_values": True, + } + + def __init__( + self, + strategy: Union[str, Callable] = "mean", + fill_value: Optional[float] = None, + ): + self.strategy = strategy + self.fill_value = fill_value + super().__init__() + + def _transform( + self, X: Union[np.ndarray, list[np.ndarray]], y=None + ) -> Union[np.ndarray, list[np.ndarray]]: + """ + Transform method to apply the SimpleImputer. + + Parameters + ---------- + X: np.ndarray or list + Collection to transform. Either a list of 2D arrays with shape + ``(n_channels, n_timepoints_i)`` or a single 3D array of shape + ``(n_cases, n_channels, n_timepoints)``. + y: None + Ignored. + + Returns + ------- + np.ndarray or list + """ + self._validate_parameters() + + if isinstance(X, np.ndarray): # if X is a 3D array + + if self.strategy == "mean": + X = np.where(np.isnan(X), np.nanmean(X, axis=-1, keepdims=True), X) + + elif self.strategy == "median": + X = np.where(np.isnan(X), np.nanmedian(X, axis=-1, keepdims=True), X) + + elif self.strategy == "constant": + X = np.where(np.isnan(X), self.fill_value, X) + + elif self.strategy == "most frequent": + X = np.where( + np.isnan(X), + mode(X, axis=-1, nan_policy="omit", keepdims=True).mode, + X, + ) + + else: # if strategy is a callable function + for i in range(X.shape[0]): + for j in range(X.shape[1]): + nan_mask = np.isnan(X[i, j]) + X[i, j] = np.where( + nan_mask, self.strategy(X[i, j][nan_mask]), X[i, j] + ) # applying callable function to each case without nan values + return X + + else: # if X is a list of 2D arrays + Xt = [] + + for x in X: + if self.strategy == "mean": + x = np.where(np.isnan(x), np.nanmean(x, axis=-1, keepdims=True), x) + elif self.strategy == "median": + x = np.where( + np.isnan(x), np.nanmedian(x, axis=-1, keepdims=True), x + ) + elif self.strategy == "constant": + x = np.where(np.isnan(x), self.fill_value, x) + elif self.strategy == "most frequent": + x = np.where( + np.isnan(x), + mode(x, axis=-1, nan_policy="omit", keepdims=True).mode, + x, + ) + else: # if strategy is a callable function + x = np.where(np.isnan(x), self.strategy(x), x) + Xt.append(x) + + return Xt + + def _validate_parameters(self): + """Validate the parameters.""" + if self.strategy not in [ + "mean", + "median", + "constant", + "most frequent", + ] and not callable(self.strategy): + raise ValueError( + "strategy must be 'mean', 'median', 'constant', 'most frequent'," + f" or a callable. Got {self.strategy}." + ) + + if self.strategy == "constant" and self.fill_value is None: + raise ValueError("fill_value must be provided when strategy is 'constant'.") + + @classmethod + def _get_test_params(cls, parameter_set="default"): + """Return testing parameter settings for the estimator. + + Returns + ------- + params : dict or list of dict, default={} + Parameters to create testing instances of the class. + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + """ + return {"strategy": "mean"} diff --git a/aeon/transformations/collection/tests/test_simple_imputer.py b/aeon/transformations/collection/tests/test_simple_imputer.py new file mode 100644 index 0000000000..d1bb168969 --- /dev/null +++ b/aeon/transformations/collection/tests/test_simple_imputer.py @@ -0,0 +1,118 @@ +"""Test for SimpleImputer.""" + +__maintainer__ = [] + +import numpy as np +import pytest + +from aeon.testing.data_generation import ( + make_example_3d_numpy, + make_example_3d_numpy_list, +) +from aeon.transformations.collection import SimpleImputer + + +def test_3d_numpy(): + """Test SimpleImputer with 3D numpy array.""" + X, _ = make_example_3d_numpy( + n_cases=10, n_channels=2, n_timepoints=50, random_state=42 + ) + X[2, 1, 10] = np.nan + X[5, 0, 20] = np.nan + + imputer = SimpleImputer(strategy="mean") + Xt = imputer.fit_transform(X) + + assert not np.isnan(Xt).any() + assert Xt.shape == X.shape + assert np.allclose(Xt[2, 1, 10], np.nanmean(X[2, 1, :])) + + +def test_2d_list(): + """Test SimpleImputer with 2D list.""" + X, _ = make_example_3d_numpy_list( + n_cases=5, + n_channels=2, + min_n_timepoints=50, + max_n_timepoints=70, + random_state=42, + ) + X[2][1, 10] = np.nan + X[4][0, 20] = np.nan + + imputer = SimpleImputer(strategy="mean") + Xt = imputer.fit_transform(X) + + assert all(not np.isnan(x).any() for x in Xt) # no NaNs in any of the arrays + assert Xt[2][1, 10] == np.nanmean(X[2][1, :]) + assert Xt[4][0, 20] == np.nanmean(X[4][0, :]) + + +def test_median(): + """Test SimpleImputer with median strategy.""" + X, _ = make_example_3d_numpy( + n_cases=10, n_channels=2, n_timepoints=50, random_state=42 + ) + X[2, 1, 10] = np.nan + X[5, 0, 20] = np.nan + + imputer = SimpleImputer(strategy="median") + Xt = imputer.fit_transform(X) + + assert not np.isnan(Xt).any() + assert Xt.shape == X.shape + assert np.allclose(Xt[2, 1, 10], np.nanmedian(X[2, 1, :])) + assert np.allclose(Xt[5, 0, 20], np.nanmedian(X[5, 0, :])) + + +def test_most_frequent(): + """Test SimpleImputer with most frequent strategy.""" + from scipy.stats import mode + + X, _ = make_example_3d_numpy( + n_cases=10, n_channels=2, n_timepoints=50, random_state=42 + ) + X[2, 1, 10] = np.nan + X[5, 0, 20] = np.nan + + imputer = SimpleImputer(strategy="most frequent") + Xt = imputer.fit_transform(X) + + assert not np.isnan(Xt).any() + assert Xt.shape == X.shape + assert np.allclose(Xt[2, 1, 10], mode(X[2, 1, :], nan_policy="omit").mode) + assert np.allclose(Xt[5, 0, 20], mode(X[5, 0, :], nan_policy="omit").mode) + + +def test_constant(): + """Test SimpleImputer with constant strategy.""" + X, _ = make_example_3d_numpy( + n_cases=10, n_channels=2, n_timepoints=50, random_state=42 + ) + X[2, 1, 10] = np.nan + X[5, 0, 20] = np.nan + + imputer = SimpleImputer(strategy="constant", fill_value=-1) + Xt = imputer.fit_transform(X) + + assert not np.isnan(Xt).any() + assert Xt.shape == X.shape + assert np.allclose(Xt[2, 1, 10], -1) + assert np.allclose(Xt[5, 0, 20], -1) + + +def test_valid_parameters(): + """Test SimpleImputer with valid parameters.""" + X, _ = make_example_3d_numpy( + n_cases=10, n_channels=2, n_timepoints=50, random_state=42 + ) + + imputer = SimpleImputer(strategy="constant") # no fill_value + + with pytest.raises(ValueError): + imputer.fit_transform(X) + + imputer = SimpleImputer(strategy="mode") # invalid strategy + + with pytest.raises(ValueError): + imputer.fit_transform(X) diff --git a/docs/api_reference/transformations.rst b/docs/api_reference/transformations.rst index 4d8908d00b..02ea16d5c8 100644 --- a/docs/api_reference/transformations.rst +++ b/docs/api_reference/transformations.rst @@ -41,6 +41,7 @@ Collection transformers PeriodogramTransformer Tabularizer Resizer + SimpleImputer SlopeTransformer Standardizer Truncator