add get_feature_names_out
tvdboom committed Jan 28, 2024
1 parent b14a15a commit 22bb8c6
Showing 121 changed files with 365 additions and 236 deletions.
12 changes: 8 additions & 4 deletions atom/baserunner.py
@@ -57,7 +57,7 @@ class BaseRunner(BaseTracker, metaclass=ABCMeta):
"""

def __getstate__(self) -> dict[str, Any]:
"""Require to store an extra attribute with the package versions."""
"""Require storing an extra attribute with the package versions."""
return self.__dict__ | {"_versions": get_versions(self._models)}

def __setstate__(self, state: dict[str, Any]):
@@ -149,6 +149,10 @@ def __getitem__(self, item: Int | str | list) -> Any:
else:
return self.dataset[item] # Get subset of dataset

def __sklearn_is_fitted__(self) -> bool:
"""Return fitted when there are trained models."""
return bool(self._models)

# Utility properties =========================================== >>

@cached_property
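The `__sklearn_is_fitted__` hook added above is why the later hunks can replace `check_is_fitted(self, attributes="_models")` with plain `check_is_fitted(self)`. Below is a minimal sketch of the mechanism using scikit-learn's own `check_is_fitted` on a toy class (ATOM imports its own `check_is_fitted` from `atom.utils.utils`, so the real call may behave slightly differently); `TinyRunner` and its `fit` method are hypothetical stand-ins.

```python
# Minimal sketch, not ATOM code: scikit-learn's check_is_fitted consults the
# __sklearn_is_fitted__ hook instead of probing for fitted attributes.
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted


class TinyRunner:
    """Toy stand-in for BaseRunner; `_models` is a plain list here."""

    def __init__(self):
        self._models = []

    def fit(self, X=None, y=None):  # check_is_fitted requires a fit method
        self._models.append("trained model placeholder")
        return self

    def __sklearn_is_fitted__(self) -> bool:
        # Fitted as soon as at least one trained model is stored
        return bool(self._models)


runner = TinyRunner()
try:
    check_is_fitted(runner)          # no attributes="_models" needed anymore
except NotFittedError:
    print("not fitted yet")

check_is_fitted(runner.fit())        # passes silently once a model is present
```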
@@ -1022,7 +1026,7 @@ def evaluate(
Scores of the models.
"""
check_is_fitted(self, attributes="_models")
check_is_fitted(self)

return pd.DataFrame([m.evaluate(metric, rows, threshold=threshold) for m in self._models])

@@ -1380,7 +1384,7 @@ def stacking(
parameter, e.g., `atom.stacking(final_estimator="LR")`.
"""
check_is_fitted(self, attributes="_models")
check_is_fitted(self)
models_c = self._get_models(models, ensembles=False, branch=self.branch)

if len(models_c) < 2:
@@ -1465,7 +1469,7 @@ def voting(
- For forecast tasks: [EnsembleForecaster][].
"""
check_is_fitted(self, attributes="_models")
check_is_fitted(self)
models_c = self._get_models(models, ensembles=False, branch=self.branch)

if len(models_c) < 2:
4 changes: 2 additions & 2 deletions atom/basetrainer.py
@@ -67,8 +67,8 @@ def __init__(
self.parallel = parallel
self.errors = errors

self._models = lst(models) if models is not None else []
self._metric = lst(metric) if metric is not None else []
self._models = lst(models) if models is not None else ClassMap()
self._metric = lst(metric) if metric is not None else ClassMap()

self._config = DataConfig()
self._branches = BranchManager(memory=self.memory)
144 changes: 111 additions & 33 deletions atom/data_cleaning.py
@@ -12,7 +12,8 @@
from collections.abc import Hashable
from logging import Logger
from pathlib import Path
from typing import Any, Literal
from typing import Any, Literal, TypeVar
from unittest.mock import patch

import numpy as np
import pandas as pd
@@ -34,15 +35,20 @@
TomekLinks,
)
from scipy.stats import zscore
from sklearn.base import BaseEstimator, _clone_parametrized
from sklearn.base import (
BaseEstimator, OneToOneFeatureMixin, _clone_parametrized,
)
from sklearn.compose import ColumnTransformer
from sklearn.experimental import enable_iterative_imputer # noqa: F401
from sklearn.impute import IterativeImputer, KNNImputer
from sklearn.utils._set_output import _SetOutputMixin
from sklearn.utils.validation import _check_feature_names_in
from sktime.transformations.series.detrend import Deseasonalizer, Detrender
from typing_extensions import Self

from atom.basetransformer import BaseTransformer
from atom.utils.constants import CAT_TYPES, DEFAULT_MISSING
from atom.utils.patches import wrap_method_output
from atom.utils.types import (
Bins, Bool, CategoricalStrats, DataFrame, DiscretizerStrats, Engine,
Estimator, FloatLargerZero, IntLargerEqualZero, IntLargerTwo,
@@ -52,12 +58,15 @@
series_t,
)
from atom.utils.utils import (
Goal, bk, composed, crash, get_col_order, get_cols, it, lst, merge,
method_to_log, n_cols, replace_missing, sign, to_df, to_series,
variable_return, wrap_methods,
Goal, bk, check_is_fitted, composed, crash, get_col_order, get_cols, it,
lst, merge, method_to_log, n_cols, replace_missing, sign, to_df, to_series,
variable_return, wrap_transformer_methods,
)


T = TypeVar("T", bound=Transformer)


@beartype
class TransformerMixin(BaseEstimator, BaseTransformer):
"""Mixin class for all transformers in ATOM.
@@ -74,12 +83,14 @@ class TransformerMixin(BaseEstimator, BaseTransformer):

def __init_subclass__(cls, **kwargs):
"""Wrap transformer methods to apply data and fit check."""
super().__init_subclass__(**kwargs)

for k in ("fit", "transform", "inverse_transform"):
setattr(cls, k, wrap_methods(getattr(cls, k)))
setattr(cls, k, wrap_transformer_methods(getattr(cls, k)))

# Patch to avoid errors for transformers that allow passing only y
with patch("sklearn.utils._set_output._wrap_method_output", wrap_method_output):
super().__init_subclass__(**kwargs)

def __sklearn_clone__(self):
def __sklearn_clone__(self: T) -> T:
"""Wrap cloning method to attach internal attributes."""
cloned = _clone_parametrized(self)

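The `with patch(...)` block added to `__init_subclass__` above swaps scikit-learn's output-wrapping helper only for the duration of `super().__init_subclass__()`, so subclasses that accept only `y` register without errors. A generic sketch of that pattern with hypothetical names (it patches `math.sqrt` rather than `_wrap_method_output`, purely to show the mechanics):

```python
# Generic sketch, not ATOM code: unittest.mock.patch swaps a callable only
# while subclass creation runs, then restores the original on exit.
import math
from unittest.mock import patch


class Shape:
    def __init_subclass__(cls, **kwargs):
        # While this block runs, math.sqrt is replaced; it is restored on exit
        with patch("math.sqrt", lambda x: 0.0):
            super().__init_subclass__(**kwargs)
            cls.sqrt_seen_at_creation = math.sqrt(4)


class Square(Shape):
    pass


print(Square.sqrt_seen_at_creation)  # 0.0 -> the patched function was used
print(math.sqrt(4))                  # 2.0 -> the original was restored
```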
@@ -89,6 +100,7 @@ def __sklearn_clone__(self):

return cloned

@composed(crash, method_to_log)
def fit(
self,
X: DataFrame | None = None,

@@ -216,7 +228,7 @@ def inverse_transform(


@beartype
class Balancer(TransformerMixin):
class Balancer(TransformerMixin, OneToOneFeatureMixin, _SetOutputMixin):
"""Balance the number of samples per class in the target column.
When oversampling, the newly created samples have an increasing
@@ -551,7 +563,7 @@ def log_changes(y):


@beartype
class Cleaner(TransformerMixin):
class Cleaner(TransformerMixin, _SetOutputMixin):
"""Applies standard data cleaning steps on a dataset.
Use the parameters to choose which transformations to perform.
@@ -750,13 +762,17 @@ def fit(self, X: DataFrame | None = None, y: Pandas | None = None) -> Self:
"""
self.mapping_: dict[str, Any] = {}
self._drop_cols = []
self._estimators = {}

if not hasattr(self, "missing_"):
self.missing_ = DEFAULT_MISSING

self._log("Fitting Cleaner...", 1)

if X is not None and self.drop_dtypes is not None:
self._drop_cols = list(X.select_dtypes(include=list(self.drop_dtypes)).columns)

if y is not None:
if isinstance(y, series_t):
self.target_names_in_ = np.array([y.name])
@@ -787,6 +803,32 @@ def fit(self, X: DataFrame | None = None, y: Pandas | None = None) -> Self:

return self

def get_feature_names_out(self, input_features: Sequence[str] | None = None) -> list[str]:
"""Get output feature names for transformation.
Parameters
----------
input_features: sequence or None, default=None
Only used to validate feature names with the names seen in
`fit`.
Returns
-------
np.ndarray
Transformed feature names.
"""
check_is_fitted(self, attributes="feature_names_in_")
_check_feature_names_in(self, input_features)

columns = [col for col in self.feature_names_in_ if col not in self._drop_cols]

if self.drop_chars:
# Drop prohibited chars from column names
columns = [re.sub(self.drop_chars, "", str(c)) for c in columns]

return columns

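A hedged usage sketch of the new `Cleaner.get_feature_names_out`: the parameter values, the dtype string, and the expected output below are illustrative and should be checked against the Cleaner documentation for your ATOM version.

```python
# Illustrative only: exercise Cleaner.get_feature_names_out after fitting.
import pandas as pd
from atom.data_cleaning import Cleaner

X = pd.DataFrame({
    "age": [23, 31, 47],
    "job title": ["engineer", "nurse", "teacher"],   # name contains a space
    "joined": pd.to_datetime(["2020-01-01", "2021-06-15", "2022-03-09"]),
})

# drop_chars removes prohibited characters from column names and drop_dtypes
# drops columns of the listed dtypes (both attributes appear in the diff above)
cleaner = Cleaner(drop_chars=r"\s", drop_dtypes=["datetime64[ns]"]).fit(X)

# Expected along the lines of ['age', 'jobtitle']: the datetime column is
# dropped and whitespace is stripped from the remaining names
print(cleaner.get_feature_names_out())
```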
@composed(crash, method_to_log)
def transform(
self,
@@ -829,18 +871,15 @@ def transform(
X = replace_missing(X, self.missing_)

for name, column in X.items():
dtype = column.dtype.name

# Drop features with an invalid data type
if dtype in lst(self.drop_dtypes):
if name in self._drop_cols:
self._log(
f" --> Dropping feature {name} for having a prohibited type: {dtype}.",
2,
f" --> Dropping feature {name} for "
f"having type: {column.dtype.name}.", 2
)
X = X.drop(columns=name)
continue

elif dtype in CAT_TYPES:
elif column.dtype.name in CAT_TYPES:
if self.strip_categorical:
# Strip strings from blank spaces
X[name] = column.apply(
@@ -977,7 +1016,7 @@ def inverse_transform(


@beartype
class Decomposer(TransformerMixin):
class Decomposer(TransformerMixin, OneToOneFeatureMixin, _SetOutputMixin):
"""Detrend and deseasonalize the time series.
This class does two things:
@@ -1216,7 +1255,7 @@ def inverse_transform(self, X: DataFrame, y: Pandas | None = None) -> DataFrame:


@beartype
class Discretizer(TransformerMixin):
class Discretizer(TransformerMixin, OneToOneFeatureMixin, _SetOutputMixin):
"""Bin continuous data into intervals.
For each feature, the bin edges are computed during fit and,
@@ -1559,7 +1598,7 @@ def transform(self, X: DataFrame, y: Pandas | None = None) -> DataFrame:


@beartype
class Encoder(TransformerMixin):
class Encoder(TransformerMixin, _SetOutputMixin):
"""Perform encoding of categorical features.
The encoding type depends on the number of classes in the column:
@@ -1869,6 +1908,29 @@ def fit(self, X: DataFrame, y: Pandas | None = None) -> Self:

return self

def get_feature_names_out(self, input_features: Sequence[str] | None = None) -> list[str]:
"""Get output feature names for transformation.
Parameters
----------
input_features: sequence or None, default=None
Only used to validate feature names with the names seen in
`fit`.
Returns
-------
np.ndarray
Transformed feature names.
"""
check_is_fitted(self, attributes="feature_names_in_")
_check_feature_names_in(self, input_features)

# Drop _nan columns (since missing values are propagated)
cols = [c for c in self._estimator.get_feature_names_out() if not c.endswith("_nan")]

return get_col_order(cols, self.feature_names_in_, self._estimator.feature_names_in_)

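The `_nan` filter above keeps the indicator columns that encoders create for missing values out of the output. A rough illustration with plain scikit-learn's `OneHotEncoder` (ATOM's `Encoder` takes its column names from the fitted `self._estimator`, a `ColumnTransformer` around category-encoders, so the real names differ); `sparse_output` and NaN-as-category behaviour assume a recent scikit-learn.

```python
# Not ATOM code: shows where "_nan" columns come from and how they are dropped.
import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

X = pd.DataFrame({"city": ["Paris", np.nan, "Rome"]})

# Recent scikit-learn treats NaN as its own category when one-hot encoding
ohe = OneHotEncoder(sparse_output=False).fit(X)
names = list(ohe.get_feature_names_out())
print(names)  # ['city_Paris', 'city_Rome', 'city_nan']

# Mirror the diff: drop the indicator columns created for missing values
print([c for c in names if not c.endswith("_nan")])
```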
@composed(crash, method_to_log)
def transform(self, X: DataFrame, y: Pandas | None = None) -> DataFrame:
"""Encode the data.
@@ -1916,14 +1978,11 @@ def transform(self, X: DataFrame, y: Pandas | None = None) -> DataFrame:

Xt = self._estimator.transform(X)

# Drop _nan columns (since missing values are propagated)
Xt = Xt.loc[:, ~Xt.columns.str.endswith("_nan")]

return Xt[get_col_order(Xt, X.columns.tolist(), self._estimator.feature_names_in_)]
return Xt[self.get_feature_names_out()]


@beartype
class Imputer(TransformerMixin):
class Imputer(TransformerMixin, _SetOutputMixin):
"""Handle missing values in the data.
Impute or remove missing values according to the selected strategy.
@@ -2203,6 +2262,25 @@ def fit(self, X: DataFrame, y: Pandas | None = None) -> Self:

return self

def get_feature_names_out(self, input_features: Sequence[str] | None = None) -> list[str]:
"""Get output feature names for transformation.
Parameters
----------
input_features: sequence or None, default=None
Only used to validate feature names with the names seen in
`fit`.
Returns
-------
np.ndarray
Transformed feature names.
"""
check_is_fitted(self, attributes="feature_names_in_")
_check_feature_names_in(self, input_features)
return [c for c in self.feature_names_in_ if c in self._estimator.get_feature_names_out()]

@composed(crash, method_to_log)
def transform(
self,
@@ -2329,20 +2407,20 @@ def transform(
2,
)

X = self._estimator.transform(X)
Xt = self._estimator.transform(X)

# Make y consistent with X
if y is not None:
y = y[y.index.isin(X.index)]
y = y[y.index.isin(Xt.index)]

# Reorder columns to original order
X = X[[col for col in self.feature_names_in_ if col in X.columns]]
Xt = Xt[self.get_feature_names_out()]

return variable_return(X, y)
return variable_return(Xt, y)

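A tiny sketch of the y-alignment step in the hunk above: rows that the imputer removed from `Xt` are dropped from `y` by matching on the index (toy data, not ATOM code).

```python
import pandas as pd

Xt = pd.DataFrame({"a": [1.0, 3.0]}, index=[0, 2])           # row 1 was removed
y = pd.Series([10, 11, 12], index=[0, 1, 2], name="target")

y = y[y.index.isin(Xt.index)]  # keep only the rows still present in Xt
print(y.index.tolist())        # [0, 2]
```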

@beartype
class Normalizer(TransformerMixin):
class Normalizer(TransformerMixin, OneToOneFeatureMixin, _SetOutputMixin):
"""Transform the data to follow a Normal/Gaussian distribution.
This transformation is useful for modeling issues related to
@@ -2604,7 +2682,7 @@ def inverse_transform(self, X: DataFrame, y: Pandas | None = None) -> DataFrame:


@beartype
class Pruner(TransformerMixin):
class Pruner(TransformerMixin, OneToOneFeatureMixin, _SetOutputMixin):
"""Prune outliers from the data.
Replace or remove outliers. The definition of outlier depends
@@ -2930,7 +3008,7 @@ def transform(


@beartype
class Scaler(TransformerMixin):
class Scaler(TransformerMixin, OneToOneFeatureMixin, _SetOutputMixin):
"""Scale the data.
Apply one of sklearn's scaling strategies. Categorical columns
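Several classes in this file (Balancer, Decomposer, Discretizer, Normalizer, Pruner, Scaler) now also inherit sklearn's `OneToOneFeatureMixin` and `_SetOutputMixin`. A toy sketch, not ATOM code, of what those two mixins contribute to a transformer whose output columns map one-to-one onto its input columns:

```python
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, OneToOneFeatureMixin, TransformerMixin


class TimesTwo(OneToOneFeatureMixin, TransformerMixin, BaseEstimator):
    """Output columns map one-to-one onto input columns."""

    def fit(self, X, y=None):
        X = pd.DataFrame(X)
        self.n_features_in_ = X.shape[1]
        self.feature_names_in_ = np.asarray(X.columns, dtype=object)
        return self

    def transform(self, X):
        return np.asarray(X, dtype=float) * 2


X = pd.DataFrame({"a": [1.0, 2.0], "b": [3.0, 4.0]})
t = TimesTwo().fit(X)

print(t.get_feature_names_out())  # ['a' 'b'], supplied by OneToOneFeatureMixin
print(type(t.transform(X)))       # plain ndarray by default

# sklearn's TransformerMixin already bundles _SetOutputMixin, so set_output can
# wrap the ndarray back into a DataFrame with the original column names.
t.set_output(transform="pandas")
print(t.transform(X))
```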