Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get_feature_names_out for EstimatorTransformer #539

Closed
wants to merge 32 commits into from
Closed
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
22da776
Pass along arbitrary parameters to fit `EstimatorTransformer`
CarloLepelaars Sep 12, 2022
955a0ee
Remove *args option from `EstimatorTransformer.fit()`
CarloLepelaars Sep 12, 2022
8d93361
Setup test for passing additional arguments in `EstimatorTransformer.…
CarloLepelaars Sep 12, 2022
082ce5f
Test if `EstimatorTransformer` fit+transform is the same with sample_…
CarloLepelaars Sep 12, 2022
237edb4
`EstimatorTransformer` test_kwargs comments
CarloLepelaars Sep 12, 2022
3eff767
Use array to test passing of `sample_weight` in `EstimatorTransformer`
CarloLepelaars Sep 14, 2022
8ef9f70
Use more simple `LinearRegression` in `test_kwargs`
CarloLepelaars Sep 14, 2022
ae7b061
Update tests/test_meta/test_estimatortransformer.py
CarloLepelaars Sep 19, 2022
f06a7af
Update tests/test_meta/test_estimatortransformer.py
CarloLepelaars Sep 19, 2022
b0ca1a0
Use unittest.Mock to check if fit method works with added kwargs
CarloLepelaars Sep 19, 2022
15e25fc
Merge branch 'main' into main
CarloLepelaars Sep 26, 2022
c471311
Working solution to test `EstimatorTransformer.fit` with added kwargs
CarloLepelaars Sep 26, 2022
9af3a6d
Fix Python3.7 issue with `Mock().call_args` for non-keyword args.
CarloLepelaars Sep 26, 2022
50b4b06
Simplify `test_kwargs` so passing of `kwargs` is tested.
CarloLepelaars Sep 26, 2022
2102583
Remove redundant whitespace at bottom of tests file
CarloLepelaars Sep 26, 2022
c7df2aa
Fix Python3.7 issue for `Mock().call_args`
CarloLepelaars Sep 27, 2022
aa526aa
Merge branch 'koaning:main' into main
CarloLepelaars Sep 27, 2022
999c197
PoC for `get_feature_names_out` for `EstimatorTransformer`
CarloLepelaars Sep 27, 2022
78016af
Refine `get_feature_names_out` for `EstimatorTransformer`. Tests for …
CarloLepelaars Sep 27, 2022
ea9627a
Custom `check_is_fitted` requirements.
CarloLepelaars Sep 27, 2022
4325348
Remove redundant imports
CarloLepelaars Sep 27, 2022
7404a7f
Remove redundant check in `__sklearn_.is_fitted`
CarloLepelaars Sep 27, 2022
7978e1c
Clean up tests for `EstimatorTransformer`
CarloLepelaars Sep 28, 2022
b9706fd
Merge branch 'main' into feature/meta-feature-names-out
CarloLepelaars Oct 6, 2022
e8f1d19
New lines in docstrings
CarloLepelaars Oct 6, 2022
eb177d3
Merge branch 'main' into feature/meta-feature-names-out
CarloLepelaars Nov 1, 2022
03f0f75
Check for `EstimatorTransformer.estimator_` attribute in `__sklearn_i…
CarloLepelaars Nov 4, 2022
ccad977
Merge branch 'main' into feature/meta-feature-names-out
CarloLepelaars Nov 4, 2022
601594a
Comment to clarify `__sklearn_is_fitted`
CarloLepelaars Nov 4, 2022
875c11c
Add link to clarify `__sklearn_is_fitted`
CarloLepelaars Nov 6, 2022
4eb3597
Add link in comment to clarify `__sklearn_is_fitted`
CarloLepelaars Nov 6, 2022
fa1e9b2
Merge branch 'main' into feature/meta-feature-names-out
CarloLepelaars Apr 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion sklego/meta/estimator_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def fit(self, X, y, **kwargs):
"""Fits the estimator"""
X, y = check_X_y(X, y, estimator=self, dtype=FLOAT_DTYPES, multi_output=True)
self.multi_output_ = len(y.shape) > 1
self.output_len_ = y.shape[1] if self.multi_output_ else 1
self.estimator_ = clone(self.estimator)
self.estimator_.fit(X, y, **kwargs)
return self
Expand All @@ -38,6 +39,34 @@ def transform(self, X):
Returns array of shape `(X.shape[0], )` if estimator is not multi output.
For multi output estimators an array of shape `(X.shape[0], y.shape[1])` is returned.
"""
check_is_fitted(self, "estimator_")
CarloLepelaars marked this conversation as resolved.
Show resolved Hide resolved
check_is_fitted(self)
output = getattr(self.estimator_, self.predict_func)(X)
return output if self.multi_output_ else output.reshape(-1, 1)

def get_feature_names_out(self, feature_names_out=None) -> list:
"""
Defines descriptive names for each output of the (fitted) estimator.
:param feature_names_out: Redundant parameter for which the contents are ignored in this function.
CarloLepelaars marked this conversation as resolved.
Show resolved Hide resolved
feature_names_out is defined here because EstimatorTransformer can be part of a larger complex pipeline.
Some components may depend on defined feature_names_out and some not, but it is passed to all components
in the pipeline if `Pipeline.get_feature_names_out` is called. feature_names_out is therefore necessary
to define here to avoid `TypeError`s when used within a scikit-learn `Pipeline` object.
:return: List of descriptive names for each output variable from the fitted estimator.
"""
check_is_fitted(self)
estimator_name_lower = self.estimator_.__class__.__name__.lower()
if self.multi_output_:
feature_names = [f"{estimator_name_lower}_{i}" for i in range(self.output_len_)]
else:
feature_names = [estimator_name_lower]
return feature_names

def __sklearn_is_fitted(self) -> bool:
"""
Custom additional requirements that need to be satisfied to pass check_is_fitted.
:return: Boolean indicating if the additional requirements
for determining check_is_fitted are satisfied.
"""
has_fit_attr = all(hasattr(self, attr) for attr in ["multi_output_", "output_len_"])
CarloLepelaars marked this conversation as resolved.
Show resolved Hide resolved
check_is_fitted(self.estimator_)
return has_fit_attr
56 changes: 55 additions & 1 deletion tests/test_meta/test_estimatortransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.utils import check_X_y

from sklearn.exceptions import NotFittedError
from sklego.common import flatten
from sklego.meta import EstimatorTransformer
from tests.conftest import transformer_checks, general_checks
Expand Down Expand Up @@ -111,3 +111,57 @@ def test_kwargs(patched_clone, random_xy_dataset_clf):
np.testing.assert_array_equal(
sample_weights, estimator.fit.call_args[1]['sample_weight']
)


def test_get_feature_names_out(random_xy_dataset_clf):
X, y = random_xy_dataset_clf
pipeline = EstimatorTransformer(LinearRegression())

# We shouldn't be able to call get_feature_names_out before estimator is fitted.
with pytest.raises(NotFittedError):
pipeline.get_feature_names_out()

pipeline.fit(X, y)
feature_names = pipeline.get_feature_names_out()

estimator_name_lower = pipeline.estimator.__class__.__name__.lower()
expected_feature_names = [estimator_name_lower]
np.testing.assert_array_equal(feature_names, expected_feature_names)


def test_get_feature_names_out_multitarget(random_xy_dataset_multitarget):
X, y = random_xy_dataset_multitarget
pipeline = EstimatorTransformer(LinearRegression())

pipeline.fit(X, y)
feature_names = pipeline.get_feature_names_out()

estimator_name_lower = pipeline.estimator.__class__.__name__.lower()
expected_feature_names = [f"{estimator_name_lower}_{i}" for i in range(pipeline.output_len_)]
np.testing.assert_array_equal(feature_names, expected_feature_names)


def test_get_feature_names_out_featureunion(random_xy_dataset_clf):
X, y = random_xy_dataset_clf
pipeline = Pipeline(
[
(
"ml_features",
FeatureUnion(
[
("model_1", EstimatorTransformer(LinearRegression())),
("model_2", EstimatorTransformer(Ridge())),
]
),
)
]
)

# We shouldn't be able to call get_feature_names_out before estimator is fitted.
with pytest.raises(NotFittedError):
pipeline.get_feature_names_out()

pipeline.fit(X, y)
feature_names = pipeline.get_feature_names_out()
expected_feature_names = ["model_1__linearregression", "model_2__ridge"]
np.testing.assert_array_equal(feature_names, expected_feature_names)