Skip to content

Commit

Permalink
towards metadata routing
Browse files: browse the repository at this point in the history
  • Loading branch information
tvdboom committed Jan 21, 2024
1 parent 4757ddb commit d2277d1
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 63 deletions.
146 changes: 98 additions & 48 deletions atom/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from sklearn.base import clone
from sklearn.pipeline import Pipeline as SkPipeline
from sklearn.pipeline import _final_estimator_has

Check notice on line 17 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _final_estimator_has of a class
from sklearn.utils import _print_elapsed_time
from sklearn.utils import Bunch, _print_elapsed_time

Check notice on line 18 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _print_elapsed_time of a class
from sklearn.utils.metadata_routing import _raise_for_params, process_routing

Check notice on line 19 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _raise_for_params of a class
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import check_memory
from sktime.proba.normal import Normal
Expand All @@ -26,8 +27,8 @@
Verbose, XConstructor, YConstructor,
)
from atom.utils.utils import (
NotFittedError, adjust_verbosity, check_is_fitted, fit_one,
fit_transform_one, sign, transform_one, variable_return,
NotFittedError, adjust_verbosity, check_is_fitted, fit_transform_one, sign,
transform_one, variable_return,
)


Expand Down Expand Up @@ -110,7 +111,7 @@ class Pipeline(SkPipeline):
Examples
--------
```pycon
```pycon
from atom import ATOMClassifier
from sklearn.datasets import load_breast_cancer
Expand All @@ -130,7 +131,7 @@ class Pipeline(SkPipeline):
# Get the pipeline and make predictions
pl = atom.lr.export_pipeline()
print(pl.predict(X))
```
```
"""

Expand All @@ -141,7 +142,8 @@ def __init__(
memory: str | Memory | None = None,
verbose: Verbose | None = 0,
):
super().__init__(steps, memory=memory, verbose=verbose)
super().__init__(steps, memory=memory, verbose=False)
self._verbose = verbose

def __bool__(self):
"""Whether the pipeline has at least one estimator."""
Expand Down Expand Up @@ -263,7 +265,7 @@ def _fit(
self,
X: XConstructor | None = None,

Check notice on line 266 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: YConstructor | None = None,
**fit_params_steps,
routed_params: dict[str, Bunch] | None = None,
) -> tuple[DataFrame | None, Pandas | None]:
"""Get data transformed through the pipeline.
Expand All @@ -276,8 +278,8 @@ def _fit(
y: dict, sequence, dataframe or None, default=None
Target column corresponding to `X`.
**fit_params
Additional keyword arguments for the fit method.
routed_params: dict or None, default=None
Metadata parameters routed for the fit method.
Returns
-------
Expand Down Expand Up @@ -309,14 +311,16 @@ def _fit(
if hasattr(transformer, attr):
setattr(cloned, attr, getattr(transformer, attr))

with adjust_verbosity(cloned, self.verbose):
with adjust_verbosity(cloned, self._verbose):
# Fit or load the current estimator from cache
# Type ignore because routed_params is never None but
# the signature of _fit needs to comply with sklearn's
X, y, fitted_transformer = self._mem_fit(

Check notice on line 318 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
transformer=cloned,
X=X,
y=y,
message=self._log_message(step),
**fit_params_steps.get(name, {}),
routed_params=routed_params[name], # type: ignore[index]
)

# Replace the estimator of the step with the fitted
Expand All @@ -329,10 +333,14 @@ def fit(
self,
X: XConstructor | None = None,

Check notice on line 334 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: YConstructor | None = None,
**fit_params,
**params,
) -> Self:
"""Fit the pipeline.
Fit all the transformers one after the other and sequentially
transform the data. Finally, fit the transformed data using the
final estimator.
Parameters
----------
X: dataframe-like or None, default=None
Expand All @@ -342,23 +350,31 @@ def fit(
y: dict, sequence, dataframe or None, default=None
Target column corresponding to `X`.
**fit_params
Additional keyword arguments for the fit method.
**params
- If `enable_metadata_routing=False` (default):
Parameters passed to the `fit` method of each step,
where each parameter name is prefixed such that
parameter `p` for step `s` has key `s__p`.
- If `enable_metadata_routing=True`:
Parameters requested and accepted by steps. Each step
must have requested certain metadata for these parameters
to be forwarded to them.
Returns
-------
self
Estimator instance.
Pipeline with fitted steps.
"""
fit_params_steps = self._check_fit_params(**fit_params)
X, y = self._fit(X, y, **fit_params_steps)
routed_params = self._check_method_params(method="fit", props=params)
X, y = self._fit(X, y, routed_params)

Check notice on line 373 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
last_step = self._final_estimator
if last_step is not None and last_step != "passthrough":
with adjust_verbosity(last_step, self.verbose):
fit_params_last_step = fit_params_steps[self.steps[-1][0]]
fit_one(last_step, X, y, **fit_params_last_step)
if self._final_estimator is not None and self._final_estimator != "passthrough":
with adjust_verbosity(self._final_estimator, self._verbose):
self._final_estimator.fit(X, y, **routed_params[self.steps[-1][0]]["fit"])

return self

Expand All @@ -367,7 +383,7 @@ def fit_transform(
self,
X: XConstructor | None = None,

Check notice on line 384 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: YConstructor | None = None,
**fit_params,
**params,
) -> Pandas | tuple[DataFrame, Pandas]:
"""Fit the pipeline and transform the data.
Expand All @@ -388,8 +404,18 @@ def fit_transform(
y: dict, sequence, dataframe or None, default=None
Target column corresponding to `X`.
**fit_params
Additional keyword arguments for the fit method.
**params
- If `enable_metadata_routing=False` (default):
Parameters passed to the `fit` method of each step,
where each parameter name is prefixed such that
parameter `p` for step `s` has key `s__p`.
- If `enable_metadata_routing=True`:
Parameters requested and accepted by steps. Each step
must have requested certain metadata for these parameters
to be forwarded to them.
Returns
-------
Expand All @@ -400,8 +426,8 @@ def fit_transform(
Transformed target column. Only returned if provided.
"""
fit_params_steps = self._check_fit_params(**fit_params)
X, y = self._fit(X, y, **fit_params_steps)
routed_params = self._check_method_params(method="fit_transform", props=params)
X, y = self._fit(X, y, routed_params)

Check notice on line 430 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

# Don't clone when caching is disabled to preserve backward compatibility
if self.memory.location is None:
Expand All @@ -413,9 +439,8 @@ def fit_transform(
if last_step is None or last_step == "passthrough":
return variable_return(X, y)

with adjust_verbosity(last_step, self.verbose):
fit_params_last_step = fit_params_steps[self.steps[-1][0]]
X, y, _ = self._mem_fit(last_step, X, y, **fit_params_last_step)
with adjust_verbosity(last_step, self._verbose):
X, y, _ = self._mem_fit(last_step, X, y, routed_params[self.steps[-1][0]])

Check notice on line 443 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

return variable_return(X, y)

Expand All @@ -424,7 +449,7 @@ def transform(
self,
X: XConstructor | None = None,

Check notice on line 450 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: YConstructor | None = None,
**kwargs,
**params,
) -> Pandas | tuple[DataFrame, Pandas]:
"""Transform the data.
Expand All @@ -444,8 +469,10 @@ def transform(
y: dict, sequence, dataframe or None, default=None
Target column corresponding to `X`.
**kwargs
Additional keyword arguments for the `_iter` inner method.
**params
Parameters requested and accepted by steps. Each step must
have requested certain metadata for these parameters to be
forwarded to them.
Returns
-------
Expand All @@ -459,9 +486,17 @@ def transform(
if X is None and y is None:
raise ValueError("X and y cannot be both None.")

for _, _, transformer in self._iter(**kwargs):
with adjust_verbosity(transformer, self.verbose):
X, y = self._mem_transform(transformer, X, y)
_raise_for_params(params, self, "transform")

routed_params = process_routing(self, "transform", **params)
for _, name, transformer in self._iter():
with adjust_verbosity(transformer, self._verbose):
X, y = self._mem_transform(

Check notice on line 494 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
transformer=transformer,
X=X,
y=y,
**routed_params[name].transform,
)

return variable_return(X, y)

Expand All @@ -470,6 +505,7 @@ def inverse_transform(
self,
X: XConstructor | None = None,

Check notice on line 506 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: YConstructor | None = None,
**params,
) -> Pandas | tuple[DataFrame, Pandas]:
"""Inverse transform for each step in reverse order.
Expand All @@ -485,6 +521,11 @@ def inverse_transform(
y: dict, sequence, dataframe or None, default=None
Target column corresponding to `X`.
**params
Parameters requested and accepted by steps. Each step must
have requested certain metadata for these parameters to be
forwarded to them.
Returns
-------
dataframe
Expand All @@ -497,9 +538,18 @@ def inverse_transform(
if X is None and y is None:
raise ValueError("X and y cannot be both None.")

for _, _, transformer in reversed(list(self._iter())):
with adjust_verbosity(transformer, self.verbose):
X, y = self._mem_transform(transformer, X, y, method="inverse_transform")
_raise_for_params(params, self, "inverse_transform")

routed_params = process_routing(self, "inverse_transform", **params)
for _, name, transformer in reversed(list(self._iter())):
with adjust_verbosity(transformer, self._verbose):
X, y = self._mem_transform(

Check notice on line 546 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
transformer=transformer,
X=X,
y=y,
method="inverse_transform",
**routed_params[name].inverse_transform,
)

return variable_return(X, y)

Expand All @@ -522,7 +572,7 @@ def decision_function(self, X: XConstructor) -> np.ndarray:
"""
for _, _, transformer in self._iter(with_final=False):
with adjust_verbosity(transformer, self.verbose):
with adjust_verbosity(transformer, self._verbose):
X, _ = self._mem_transform(transformer, X)

Check notice on line 576 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

return self.steps[-1][1].decision_function(X)
Expand Down Expand Up @@ -564,7 +614,7 @@ def predict(
raise ValueError("X and fh cannot be both None.")

for _, _, transformer in self._iter(with_final=False):
with adjust_verbosity(transformer, self.verbose):
with adjust_verbosity(transformer, self._verbose):
X, _ = self._mem_transform(transformer, X)

Check notice on line 618 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

if "fh" in sign(self.steps[-1][1].predict):
Expand Down Expand Up @@ -604,7 +654,7 @@ def predict_interval(
"""
for _, _, transformer in self._iter(with_final=False):
with adjust_verbosity(transformer, self.verbose):
with adjust_verbosity(transformer, self._verbose):
X, y = self._mem_transform(transformer, X)

Check notice on line 658 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

return self.steps[-1][1].predict_interval(fh=fh, X=X, coverage=coverage)
Expand All @@ -626,7 +676,7 @@ def predict_log_proba(self, X: XConstructor) -> np.ndarray:
"""
for _, _, transformer in self._iter(with_final=False):
with adjust_verbosity(transformer, self.verbose):
with adjust_verbosity(transformer, self._verbose):
X, _ = self._mem_transform(transformer, X)

Check notice on line 680 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

return self.steps[-1][1].predict_log_proba(X)
Expand Down Expand Up @@ -670,7 +720,7 @@ def predict_proba(
raise ValueError("X and fh cannot be both None.")

for _, _, transformer in self._iter(with_final=False):
with adjust_verbosity(transformer, self.verbose):
with adjust_verbosity(transformer, self._verbose):
X, _ = self._mem_transform(transformer, X)

Check notice on line 724 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

if "fh" in sign(self.steps[-1][1].predict_proba):
Expand Down Expand Up @@ -711,7 +761,7 @@ def predict_quantiles(
"""
for _, _, transformer in self._iter(with_final=False):
with adjust_verbosity(transformer, self.verbose):
with adjust_verbosity(transformer, self._verbose):
X, y = self._mem_transform(transformer, X)

Check notice on line 765 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

return self.steps[-1][1].predict_quantiles(fh=fh, X=X, alpha=alpha)
Expand Down Expand Up @@ -740,7 +790,7 @@ def predict_residuals(
"""
for _, _, transformer in self._iter(with_final=False):
with adjust_verbosity(transformer, self.verbose):
with adjust_verbosity(transformer, self._verbose):
X, y = self._mem_transform(transformer, X, y)

Check notice on line 794 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

return self.steps[-1][1].predict_residuals(y=y, X=X)
Expand Down Expand Up @@ -775,7 +825,7 @@ def predict_var(
"""
for _, _, transformer in self._iter(with_final=False):
with adjust_verbosity(transformer, self.verbose):
with adjust_verbosity(transformer, self._verbose):
X, _ = self._mem_transform(transformer, X)

Check notice on line 829 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

return self.steps[-1][1].predict_var(fh=fh, X=X, cov=cov)
Expand Down Expand Up @@ -817,7 +867,7 @@ def score(
raise ValueError("X and y cannot be both None.")

for _, _, transformer in self._iter(with_final=False):
with adjust_verbosity(transformer, self.verbose):
with adjust_verbosity(transformer, self._verbose):
X, y = self._mem_transform(transformer, X, y)

Check notice on line 871 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

if "fh" in sign(self.steps[-1][1].score):
Expand Down
2 changes: 1 addition & 1 deletion atom/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class Estimator(Protocol):
def __init__(self, *args, **kwargs): ...
def get_params(self, *args, **kwargs): ...
def set_params(self, *args, **kwargs): ...
def fit(self, *args, **kwargs): ...


@runtime_checkable
Expand All @@ -135,7 +136,6 @@ def transform(self, *args, **kwargs): ...
class Predictor(Estimator, Protocol):
"""Protocol for sklearn-like predictors."""

def fit(self, *args, **kwargs): ...
def predict(self, *args, **kwargs): ...


Expand Down
Loading

0 comments on commit d2277d1

Please sign in to comment.