diff --git a/aeon/transformations/collection/_impute.py b/aeon/transformations/collection/_impute.py index a145317820..26792b6dc8 100644 --- a/aeon/transformations/collection/_impute.py +++ b/aeon/transformations/collection/_impute.py @@ -118,7 +118,10 @@ def _transform( x, ) else: # if strategy is a callable function - x = np.where(np.isnan(x), self.strategy(x), x) + n_channels = x.shape[0] + for i in range(n_channels): + nan_mask = np.isnan(x[i]) + x[i] = np.where(nan_mask, self.strategy(x[i][nan_mask]), x[i]) Xt.append(x) return Xt diff --git a/aeon/transformations/collection/tests/test_simple_imputer.py b/aeon/transformations/collection/tests/test_simple_imputer.py index d1bb168969..f1e4fb62a2 100644 --- a/aeon/transformations/collection/tests/test_simple_imputer.py +++ b/aeon/transformations/collection/tests/test_simple_imputer.py @@ -116,3 +116,23 @@ def test_valid_parameters(): with pytest.raises(ValueError): imputer.fit_transform(X) + + +def test_callable(): + """Test SimpleImputer with callable strategy.""" + X, _ = make_example_3d_numpy( + n_cases=10, n_channels=2, n_timepoints=50, random_state=42 + ) + X[2, 1, 10] = np.nan + X[5, 0, 20] = np.nan + + def dummy_strategy(x): + return 0 + + imputer = SimpleImputer(strategy=dummy_strategy) + Xt = imputer.fit_transform(X) + + assert not np.isnan(Xt).any() + assert Xt.shape == X.shape + assert np.allclose(Xt[2, 1, 10], 0) + assert np.allclose(Xt[5, 0, 20], 0)