-
Notifications
You must be signed in to change notification settings - Fork 128
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH] Implemented SimpleImputer (#2347)
* Implemented SimpleImputer * fix: using np.str_ instead of np.unicode_ * fix: removed np.unicode_ * update: changed documentation * - changed "fit_is_empty" tag - calculating fill values on cases instead of whole channel - removed string support * documentation update
- Loading branch information
1 parent
0c528af
commit a09ea8f
Showing
4 changed files
with
274 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
"""Time Series imputer.""" | ||
|
||
__maintainer__ = [] | ||
__all__ = ["SimpleImputer"] | ||
|
||
from typing import Callable, Optional, Union | ||
|
||
import numpy as np | ||
from scipy.stats import mode | ||
|
||
from aeon.transformations.collection.base import BaseCollectionTransformer | ||
|
||
|
||
class SimpleImputer(BaseCollectionTransformer):
    """Time series imputer.

    Transformer that imputes missing values (NaN) in a collection of time series.
    Fill values are calculated separately for every channel of every case, along
    the time axis, using only the non-missing values of that series.

    Parameters
    ----------
    strategy : str or Callable, default="mean"
        The imputation strategy.

        - if "mean", replace missing values using the mean.
        - if "median", replace missing values using the median.
        - if "constant", replace missing values with the fill_value.
        - if "most frequent", replace missing values with the most frequent value.
        - if Callable, a function that returns the value to replace
          missing values with on each 1D array containing all
          non-missing values of each series.
    fill_value : float or None, default=None
        The value to replace missing values with. Only used when strategy is
        "constant".
    """

    _tags = {
        "X_inner_type": ["np-list", "numpy3D"],
        "fit_is_empty": True,
        "capability:multivariate": True,
        "capability:unequal_length": True,
        "capability:missing_values": True,
        "removes_missing_values": True,
    }

    def __init__(
        self,
        strategy: Union[str, Callable] = "mean",
        fill_value: Optional[float] = None,
    ):
        self.strategy = strategy
        self.fill_value = fill_value
        super().__init__()

    def _transform(
        self, X: Union[np.ndarray, list[np.ndarray]], y=None
    ) -> Union[np.ndarray, list[np.ndarray]]:
        """
        Transform method to apply the SimpleImputer.

        Parameters
        ----------
        X: np.ndarray or list
            Collection to transform. Either a list of 2D arrays with shape
            ``(n_channels, n_timepoints_i)`` or a single 3D array of shape
            ``(n_cases, n_channels, n_timepoints)``.
        y: None
            Ignored.

        Returns
        -------
        np.ndarray or list
            The imputed collection, in the same container type and with the same
            shapes as ``X``. The input is never modified in place.

        Raises
        ------
        ValueError
            If ``strategy`` is not a supported string or a callable, or if
            ``strategy="constant"`` is used without a ``fill_value``.
        """
        self._validate_parameters()

        if isinstance(X, np.ndarray):  # if X is a 3D array
            return self._impute(X)

        # if X is a list of 2D arrays: time is still the last axis of each item,
        # so the very same routine applies case by case
        return [self._impute(x) for x in X]

    def _impute(self, x: np.ndarray) -> np.ndarray:
        """Return a copy of ``x`` with NaNs filled along the last (time) axis."""
        nan_mask = np.isnan(x)

        if callable(self.strategy):
            return self._impute_with_callable(x, nan_mask)

        if self.strategy == "mean":
            fill = np.nanmean(x, axis=-1, keepdims=True)
        elif self.strategy == "median":
            fill = np.nanmedian(x, axis=-1, keepdims=True)
        elif self.strategy == "constant":
            fill = self.fill_value
        else:  # "most frequent"; other values were rejected during validation
            fill = mode(x, axis=-1, nan_policy="omit", keepdims=True).mode

        # keepdims=True makes the per-series statistic broadcast over time
        return np.where(nan_mask, fill, x)

    def _impute_with_callable(self, x: np.ndarray, nan_mask: np.ndarray) -> np.ndarray:
        """Fill each 1D series using ``strategy`` applied to its observed values."""
        xt = x.copy()  # never write into the caller's data
        # iterate over every leading index, i.e. (channel,) for a 2D case and
        # (case, channel) for a 3D collection
        for idx in np.ndindex(x.shape[:-1]):
            series_mask = nan_mask[idx]
            if series_mask.any():
                # the callable only ever sees the non-missing values (~mask)
                xt[idx][series_mask] = self.strategy(x[idx][~series_mask])
        return xt

    def _validate_parameters(self):
        """Validate the parameters.

        Raises
        ------
        ValueError
            If ``strategy`` is neither one of the supported strings nor a
            callable, or if ``strategy`` is "constant" and ``fill_value`` is None.
        """
        valid_strategies = ("mean", "median", "constant", "most frequent")
        if not callable(self.strategy) and self.strategy not in valid_strategies:
            raise ValueError(
                "strategy must be 'mean', 'median', 'constant', 'most frequent',"
                f" or a callable. Got {self.strategy}."
            )

        if self.strategy == "constant" and self.fill_value is None:
            raise ValueError("fill_value must be provided when strategy is 'constant'.")

    @classmethod
    def _get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return. Not used here; the
            same parameters are returned for every value.

        Returns
        -------
        params : dict or list of dict, default={}
            Parameters to create testing instances of the class.
            Each dict are parameters to construct an "interesting" test instance, i.e.,
            `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
        """
        return {"strategy": "mean"}
118 changes: 118 additions & 0 deletions
118
aeon/transformations/collection/tests/test_simple_imputer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
"""Test for SimpleImputer.""" | ||
|
||
__maintainer__ = [] | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from aeon.testing.data_generation import ( | ||
make_example_3d_numpy, | ||
make_example_3d_numpy_list, | ||
) | ||
from aeon.transformations.collection import SimpleImputer | ||
|
||
|
||
def test_3d_numpy():
    """Test SimpleImputer with 3D numpy array.

    Two NaNs are injected into different cases/channels; both must be replaced
    by the mean of their own series and all observed values must be untouched.
    """
    X, _ = make_example_3d_numpy(
        n_cases=10, n_channels=2, n_timepoints=50, random_state=42
    )
    X[2, 1, 10] = np.nan
    X[5, 0, 20] = np.nan

    imputer = SimpleImputer(strategy="mean")
    Xt = imputer.fit_transform(X)

    assert not np.isnan(Xt).any()
    assert Xt.shape == X.shape
    assert np.allclose(Xt[2, 1, 10], np.nanmean(X[2, 1, :]))
    # the second injected gap was previously never verified
    assert np.allclose(Xt[5, 0, 20], np.nanmean(X[5, 0, :]))

    # imputation must leave every observed value as it was
    observed = ~np.isnan(X)
    assert np.array_equal(Xt[observed], X[observed])
|
||
|
||
def test_2d_list():
    """Test SimpleImputer with a list of 2D numpy arrays of unequal length."""
    X, _ = make_example_3d_numpy_list(
        n_cases=5,
        n_channels=2,
        min_n_timepoints=50,
        max_n_timepoints=70,
        random_state=42,
    )
    X[2][1, 10] = np.nan
    X[4][0, 20] = np.nan

    imputer = SimpleImputer(strategy="mean")
    Xt = imputer.fit_transform(X)

    # the collection keeps its size and every case keeps its own shape
    assert len(Xt) == len(X)
    assert all(xt.shape == x.shape for xt, x in zip(Xt, X))

    assert all(not np.isnan(x).any() for x in Xt)  # no NaNs in any of the arrays
    # compare with a tolerance: exact float equality on computed means is fragile
    assert np.isclose(Xt[2][1, 10], np.nanmean(X[2][1, :]))
    assert np.isclose(Xt[4][0, 20], np.nanmean(X[4][0, :]))
|
||
|
||
def test_median():
    """Check that the "median" strategy fills each gap with its series median."""
    data, _ = make_example_3d_numpy(
        n_cases=10, n_channels=2, n_timepoints=50, random_state=42
    )
    # (case, channel, timepoint) positions that are blanked out
    gaps = ((2, 1, 10), (5, 0, 20))
    for gap in gaps:
        data[gap] = np.nan

    transformed = SimpleImputer(strategy="median").fit_transform(data)

    assert not np.isnan(transformed).any()
    assert transformed.shape == data.shape
    for case, channel, timepoint in gaps:
        expected = np.nanmedian(data[case, channel, :])
        assert np.allclose(transformed[case, channel, timepoint], expected)
|
||
|
||
def test_most_frequent():
    """Check that "most frequent" fills each gap with the mode of its series."""
    from scipy.stats import mode

    data, _ = make_example_3d_numpy(
        n_cases=10, n_channels=2, n_timepoints=50, random_state=42
    )
    # (case, channel, timepoint) positions that are blanked out
    gaps = ((2, 1, 10), (5, 0, 20))
    for gap in gaps:
        data[gap] = np.nan

    transformed = SimpleImputer(strategy="most frequent").fit_transform(data)

    assert not np.isnan(transformed).any()
    assert transformed.shape == data.shape
    for case, channel, timepoint in gaps:
        expected = mode(data[case, channel, :], nan_policy="omit").mode
        assert np.allclose(transformed[case, channel, timepoint], expected)
|
||
|
||
def test_constant():
    """Check that the "constant" strategy writes fill_value into every gap."""
    data, _ = make_example_3d_numpy(
        n_cases=10, n_channels=2, n_timepoints=50, random_state=42
    )
    # (case, channel, timepoint) positions that are blanked out
    gaps = ((2, 1, 10), (5, 0, 20))
    for gap in gaps:
        data[gap] = np.nan

    transformed = SimpleImputer(strategy="constant", fill_value=-1).fit_transform(
        data
    )

    assert not np.isnan(transformed).any()
    assert transformed.shape == data.shape
    for gap in gaps:
        assert np.allclose(transformed[gap], -1)
|
||
|
||
def test_valid_parameters():
    """Test that SimpleImputer rejects invalid parameter combinations.

    Covers the "constant" strategy without a fill_value and an unknown
    strategy name. Each case is matched against its own error message so that
    an unrelated ValueError cannot make the test pass by accident.
    """
    X, _ = make_example_3d_numpy(
        n_cases=10, n_channels=2, n_timepoints=50, random_state=42
    )

    imputer = SimpleImputer(strategy="constant")  # no fill_value

    with pytest.raises(ValueError, match="fill_value must be provided"):
        imputer.fit_transform(X)

    imputer = SimpleImputer(strategy="mode")  # invalid strategy

    with pytest.raises(ValueError, match="strategy must be"):
        imputer.fit_transform(X)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters