Skip to content

Commit

Permalink
Bug fix in SimpleImputer (#2363)
Browse files Browse the repository at this point in the history
  • Loading branch information
notaryanramani authored Nov 17, 2024
1 parent a09ea8f commit cd3bc33
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
5 changes: 4 additions & 1 deletion aeon/transformations/collection/_impute.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ def _transform(
x,
)
else: # if strategy is a callable function
x = np.where(np.isnan(x), self.strategy(x), x)
n_channels = x.shape[0]
for i in range(n_channels):
nan_mask = np.isnan(x[i])
x[i] = np.where(nan_mask, self.strategy(x[i][nan_mask]), x[i])
Xt.append(x)

return Xt
Expand Down
20 changes: 20 additions & 0 deletions aeon/transformations/collection/tests/test_simple_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,23 @@ def test_valid_parameters():

with pytest.raises(ValueError):
imputer.fit_transform(X)


def test_callable():
"""Test SimpleImputer with callable strategy."""
X, _ = make_example_3d_numpy(
n_cases=10, n_channels=2, n_timepoints=50, random_state=42
)
X[2, 1, 10] = np.nan
X[5, 0, 20] = np.nan

def dummy_strategy(x):
return 0

imputer = SimpleImputer(strategy=dummy_strategy)
Xt = imputer.fit_transform(X)

assert not np.isnan(Xt).any()
assert Xt.shape == X.shape
assert np.allclose(Xt[2, 1, 10], 0)
assert np.allclose(Xt[5, 0, 20], 0)

0 comments on commit cd3bc33

Please sign in to comment.