diff --git a/atom/atom.py b/atom/atom.py
index 619202bfd..81d957a2e 100644
--- a/atom/atom.py
+++ b/atom/atom.py
@@ -1699,14 +1699,15 @@ def impute(
Impute or remove missing values according to the selected
strategy. Also removes rows and columns with too many missing
- values. Use the `missing` attribute to customize what are
- considered "missing values".
+ values.
See the [Imputer][] class for a description of the parameters.
!!! tip
- Use the [nans][self-nans] attribute to check the amount of
- missing values per column.
+ - Use the [nans][self-nans] attribute to check the amount of
+ missing values per column.
+ - Use the [`missing`][self-missing] attribute to customize
+ what are considered "missing values".
"""
columns = kwargs.pop("columns", None)
diff --git a/atom/basemodel.py b/atom/basemodel.py
index 71dd22a63..0e246a025 100644
--- a/atom/basemodel.py
+++ b/atom/basemodel.py
@@ -309,7 +309,7 @@ def fullname(self) -> str:
"""Return the model's class name."""
return self.__class__.__name__
- @property
+ @cached_property
def _est_class(self) -> type[Predictor]:
"""Return the estimator's class (not instance).
@@ -698,8 +698,8 @@ def _get_pred(
method=method_caller,
)
- except ValueError as ex:
- # Fails for models that don't allow in-sample predictions
+ except (ValueError, NotImplementedError) as ex:
+ # Can fail for models that don't allow in-sample predictions
self._log(
f"Failed to get predictions for model {self.name} "
f"on rows {rows}. Returning NaN. Exception: {ex}.", 3
diff --git a/atom/basetransformer.py b/atom/basetransformer.py
index bc63f9c64..befc53a3b 100644
--- a/atom/basetransformer.py
+++ b/atom/basetransformer.py
@@ -32,6 +32,7 @@
from joblib.memory import Memory
from pandas._typing import Axes
from ray.util.joblib import register_ray
+from sklearn.base import OneToOneFeatureMixin
from sklearn.utils.validation import check_memory
from atom.utils.types import (
@@ -40,7 +41,9 @@
Pandas, Sequence, Severity, Verbose, Warnings, XSelector, YSelector,
bool_t, dataframe_t, int_t, sequence_t,
)
-from atom.utils.utils import crash, flt, lst, n_cols, to_df, to_pandas
+from atom.utils.utils import (
+ crash, flt, lst, n_cols, to_df, to_pandas, wrap_fit,
+)
T_Estimator = TypeVar("T_Estimator", bound=Estimator)
@@ -373,6 +376,11 @@ def _inherit(self, obj: T_Estimator, fixed: tuple[str, ...] = ()) -> T_Estimator
to that of this instance. If `obj` is a meta-estimator, it
also adjusts the parameters of the base estimator.
+ Additionally, the `fit` method of non-sklearn objects is wrapped
+ to always add the `n_features_in_` and `feature_names_in_`
+ attributes, and the `get-feature_names_out` method is added to
+ transformers that don't have it already.
+
Parameters
----------
obj: Estimator
@@ -398,6 +406,12 @@ def _inherit(self, obj: T_Estimator, fixed: tuple[str, ...] = ()) -> T_Estimator
else:
obj.set_params(**{p: lst(self._config.sp.sp)[0]})
+ if hasattr(obj, "fit") and "sklearn" not in obj.__module__:
+ obj.__class__.fit = wrap_fit(obj.__class__.fit) # type: ignore[method-assign]
+ if hasattr(obj, "transform") and not hasattr(obj, "get_feature_names_out"):
+ # We assume here that the transformer does not create nor remove columns
+ obj.__class__.get_feature_names_out = OneToOneFeatureMixin.get_feature_names_out
+
return obj
def _get_est_class(self, name: str, module: str) -> type[Estimator]:
diff --git a/atom/data_cleaning.py b/atom/data_cleaning.py
index be9d92eb2..a95f46055 100644
--- a/atom/data_cleaning.py
+++ b/atom/data_cleaning.py
@@ -46,6 +46,7 @@
from sktime.transformations.series.detrend import (
ConditionalDeseasonalizer, Deseasonalizer, Detrender,
)
+from sktime.transformations.series.impute import Imputer as sktimeImputer
from typing_extensions import Self
from atom.basetransformer import BaseTransformer
@@ -53,11 +54,11 @@
from atom.utils.patches import wrap_method_output
from atom.utils.types import (
Bins, Bool, CategoricalStrats, DataFrame, DiscretizerStrats, Engine,
- Estimator, FloatLargerZero, IntLargerEqualZero, IntLargerTwo,
- IntLargerZero, NJobs, NormalizerStrats, NumericalStrats, Pandas, Predictor,
- PrunerStrats, Scalar, ScalerStrats, SeasonalityModels, Sequence, Series,
- Transformer, Verbose, XSelector, YSelector, dataframe_t, sequence_t,
- series_t,
+ EngineTuple, Estimator, FloatLargerZero, Int, IntLargerEqualZero,
+ IntLargerTwo, IntLargerZero, NJobs, NormalizerStrats, NumericalStrats,
+ Pandas, Predictor, PrunerStrats, Scalar, ScalerStrats, SeasonalityModels,
+ Sequence, Series, Transformer, Verbose, XSelector, YSelector, dataframe_t,
+ sequence_t, series_t,
)
from atom.utils.utils import (
Goal, bk, check_is_fitted, composed, crash, get_col_order, get_cols, it,
@@ -92,6 +93,17 @@ def __init_subclass__(cls, **kwargs):
with patch("sklearn.utils._set_output._wrap_method_output", wrap_method_output):
super().__init_subclass__(**kwargs)
+ def __repr__(self, N_CHAR_MAX: Int = 700) -> str:
+ """Drop named tuples if default parameters from string representation."""
+ out = super().__repr__(N_CHAR_MAX)
+
+ # Remove default engine for cleaner representation
+ if hasattr(self, "engine") and self.engine == EngineTuple():
+ out = re.sub(r"engine=EngineTuple\(data='numpy', estimator='sklearn'\)", "", out)
+ out = re.sub(r"((?<=\(),\s|,\s(?=\))|,\s(?=,\s))", "", out) # Drop comma-spaces
+
+ return out
+
def __sklearn_clone__(self: T) -> T:
"""Wrap cloning method to attach internal attributes."""
cloned = _clone_parametrized(self)
@@ -1521,10 +1533,6 @@ def get_labels(col: str, bins: Sequence[Scalar]) -> tuple[str, ...]:
return labels
- Xt, yt = self._check_input(X, y)
- self._check_feature_names(Xt, reset=True)
- self._check_n_features(Xt, reset=True)
-
self._estimators: dict[str, Estimator] = {}
self._labels: dict[str, Sequence[str]] = {}
@@ -1548,7 +1556,7 @@ def get_labels(col: str, bins: Sequence[Scalar]) -> tuple[str, ...]:
raise ValueError(
"Invalid value for the bins parameter. The length of the "
"bins does not match the length of the columns, got len"
- f"(bins)={len(bins_c)} and len(columns)={Xt.shape[1]}."
+ f"(bins)={len(bins_c)} and len(columns)={X.shape[1]}."
) from None
else:
bins_x = bins_c
@@ -1566,7 +1574,7 @@ def get_labels(col: str, bins: Sequence[Scalar]) -> tuple[str, ...]:
encode="ordinal",
strategy=self.strategy,
**kwargs,
- ).fit(Xt[[col]])
+ ).fit(X[[col]])
# Save labels for transform method
self._labels[col] = get_labels(
@@ -1592,7 +1600,7 @@ def get_labels(col: str, bins: Sequence[Scalar]) -> tuple[str, ...]:
self._estimators[col] = FunctionTransformer(
func=bk.cut,
kw_args={"bins": bins_c, "labels": get_labels(col, bins_c)},
- ).fit(Xt[[col]])
+ ).fit(X[[col]])
return self
@@ -2021,7 +2029,7 @@ class Imputer(TransformerMixin, _SetOutputMixin):
Impute or remove missing values according to the selected strategy.
Also removes rows and columns with too many missing values. Use
- the `missing` attribute to customize what are considered "missing
+ the `missing_` attribute to customize what are considered "missing
values".
This class can be accessed from atom through the [impute]
@@ -2036,9 +2044,18 @@ class Imputer(TransformerMixin, _SetOutputMixin):
- "drop": Drop rows containing missing values.
- "mean": Impute with mean of column.
- "median": Impute with median of column.
+ - "most_frequent": Impute with the most frequent value.
- "knn": Impute using a K-Nearest Neighbors approach.
- "iterative": Impute using a multivariate imputer.
- - "most_frequent": Impute with the most frequent value.
+ - "drift": Impute values using a [PolynomialTrend][] model.
+ - "linear": Impute using linear interpolation.
+ - "nearest": Impute with nearest value.
+ - "bfill": Impute by using the next valid observation to fill
+ the gap.
+ - "ffill": Impute by propagating the last valid observation
+ to next valid.
+ - "random": Impute with random values between the min and max
+ of column.
- int or float: Impute with provided numerical value.
strat_cat: str, default="drop"
@@ -2263,6 +2280,15 @@ def fit(self, X: DataFrame, y: Pandas | None = None) -> Self:
num_imputer = IterativeImputer(random_state=self.random_state)
elif self.strat_num == "drop":
num_imputer = "passthrough"
+ else:
+ # Inherit sklearn's attributes and methods
+ num_imputer = self._inherit(
+ sktimeImputer(
+ method=self.strat_num,
+ missing_values=[pd.NA],
+ random_state=self.random_state,
+ )
+ )
else:
num_imputer = SimpleImputer(
missing_values=pd.NA,
@@ -2401,8 +2427,7 @@ def transform(
if name not in self._estimator.feature_names_in_:
self._log(
f" --> Dropping feature {name}. Contains {nans} "
- f"({nans * 100 // len(X)}%) missing values.",
- 2,
+ f"({nans * 100 // len(X)}%) missing values.", 2,
)
X = X.drop(columns=name)
continue
@@ -2411,34 +2436,34 @@ def transform(
if not isinstance(self.strat_num, str):
self._log(
f" --> Imputing {nans} missing values with "
- f"number '{self.strat_num}' in feature {name}.",
- 2,
+ f"number '{self.strat_num}' in column {name}.", 2,
)
elif self.strat_num in ("knn", "iterative"):
self._log(
f" --> Imputing {nans} missing values using "
- f"the {self.strat_num} imputer in feature {name}.",
- 2,
+ f"the {self.strat_num} imputer in column {name}.", 2,
+ )
+ elif self.strat_num in ("mean", "median", "most_frequent"):
+ self._log(
+ f" --> Imputing {nans} missing values with {self.strat_num} "
+ f"({np.round(get_stat(num_imputer, name), 2)}) in column "
+ f"{name}.", 2,
)
- elif self.strat_num != "drop": # mean, median or most_frequent
+ else:
self._log(
f" --> Imputing {nans} missing values with {self.strat_num} "
- f"({np.round(get_stat(num_imputer, name), 2)}) in feature "
- f"{name}.",
- 2,
+ f"in column {name}.", 2,
)
elif self.strat_cat != "drop" and name in cat_imputer.feature_names_in_:
if self.strat_cat == "most_frequent":
self._log(
f" --> Imputing {nans} missing values with most_frequent "
- f"({get_stat(cat_imputer, name)}) in feature {name}.",
- 2,
+ f"({get_stat(cat_imputer, name)}) in column {name}.", 2,
)
elif self.strat_cat != "drop":
self._log(
f" --> Imputing {nans} missing values with value "
- f"'{self.strat_cat}' in feature {name}.",
- 2,
+ f"'{self.strat_cat}' in column {name}.", 2,
)
Xt = self._estimator.transform(X)
@@ -2969,8 +2994,7 @@ def transform(
cond = np.abs(z_scores) > self.max_sigma
objective = objective.mask(cond, self.method)
self._log(
- f" --> Replacing {cond.sum()} outlier values with {self.method}.",
- 2,
+ f" --> Replacing {cond.sum()} outlier values with {self.method}.", 2,
)
elif self.method.lower() == "minmax":
@@ -2992,8 +3016,7 @@ def transform(
self._log(
f" --> Replacing {counts} outlier values "
- "with the min or max of the column.",
- 2,
+ "with the min or max of the column.", 2,
)
elif self.method.lower() == "drop":
@@ -3002,8 +3025,7 @@ def transform(
if len(lst(self.strategy)) > 1:
self._log(
f" --> The zscore strategy detected "
- f"{len(mask) - sum(mask)} outliers.",
- 2,
+ f"{len(mask) - sum(mask)} outliers.", 2,
)
else:
@@ -3013,8 +3035,7 @@ def transform(
if len(lst(self.strategy)) > 1:
self._log(
f" --> The {estimator.__class__.__name__} "
- f"detected {len(mask) - sum(mask)} outliers.",
- 2,
+ f"detected {len(mask) - sum(mask)} outliers.", 2,
)
# Add the estimator as attribute to the instance
diff --git a/atom/feature_engineering.py b/atom/feature_engineering.py
index 1921f97d8..fe67480f5 100644
--- a/atom/feature_engineering.py
+++ b/atom/feature_engineering.py
@@ -1540,8 +1540,7 @@ def transform(self, X: DataFrame, y: Pandas | None = None) -> DataFrame:
self._log(
f" --> Dropping feature {column} "
f"(score: {self.univariate_.scores_[n]:.2f} "
- f"p-value: {self.univariate_.pvalues_[n]:.2f}).",
- 2,
+ f"p-value: {self.univariate_.pvalues_[n]:.2f}).", 2,
)
X = X.drop(columns=column)
@@ -1570,8 +1569,7 @@ def transform(self, X: DataFrame, y: Pandas | None = None) -> DataFrame:
if hasattr(self._estimator, "ranking_"):
self._log(
f" --> Dropping feature {column} "
- f"(rank {self._estimator.ranking_[n]}).",
- 2,
+ f"(rank {self._estimator.ranking_[n]}).", 2,
)
else:
self._log(f" --> Dropping feature {column}.", 2)
diff --git a/atom/plots/dataplot.py b/atom/plots/dataplot.py
index 98453e2dc..df5e90c49 100644
--- a/atom/plots/dataplot.py
+++ b/atom/plots/dataplot.py
@@ -54,8 +54,9 @@ class DataPlot(BasePlot, metaclass=ABCMeta):
def plot_acf(
self,
columns: ColumnSelector | None = None,
- nlags: IntLargerZero | None = None,
*,
+ nlags: IntLargerZero | None = None,
+ plot_interval: Bool = True,
title: str | dict[str, Any] | None = None,
legend: Legend | dict[str, Any] | None = "upper right",
figsize: tuple[IntLargerZero, IntLargerZero] | None = None,
@@ -83,6 +84,9 @@ def plot_acf(
returned value includes lag 0 (i.e., 1), so the size of the
vector is `(nlags + 1,)`.
+ plot_interval: bool, default=True
+ Whether to plot the 95% confidence interval.
+
title: str, dict or None, default=None
Title for the plot.
@@ -174,35 +178,35 @@ def plot_acf(
yaxis=yaxis,
)
- # Add error bands
- fig.add_traces(
- [
- go.Scatter(
- x=x,
- y=np.subtract(conf[:, 1], corr),
- mode="lines",
- line={"width": 1, "color": BasePlot._fig.get_elem(col)},
- hovertemplate="%{y}upper bound",
- legendgroup=col,
- showlegend=False,
- xaxis=xaxis,
- yaxis=yaxis,
- ),
- go.Scatter(
- x=x,
- y=np.subtract(conf[:, 0], corr),
- mode="lines",
- line={"width": 1, "color": BasePlot._fig.get_elem(col)},
- fill="tonexty",
- fillcolor=f"rgba({BasePlot._fig.get_elem(col)[4:-1]}, 0.2)",
- hovertemplate="%{y}lower bound",
- legendgroup=col,
- showlegend=False,
- xaxis=xaxis,
- yaxis=yaxis,
- ),
- ]
- )
+ if plot_interval:
+ fig.add_traces(
+ [
+ go.Scatter(
+ x=x,
+ y=conf[:, 1] - corr,
+ mode="lines",
+ line={"width": 1, "color": BasePlot._fig.get_elem(col)},
+ hovertemplate="%{y}upper bound",
+ legendgroup=col,
+ showlegend=False,
+ xaxis=xaxis,
+ yaxis=yaxis,
+ ),
+ go.Scatter(
+ x=x,
+ y=conf[:, 0] - corr,
+ mode="lines",
+ line={"width": 1, "color": BasePlot._fig.get_elem(col)},
+ fill="tonexty",
+ fillcolor=f"rgba({BasePlot._fig.get_elem(col)[4:-1]}, 0.2)",
+ hovertemplate="%{y}lower bound",
+ legendgroup=col,
+ showlegend=False,
+ xaxis=xaxis,
+ yaxis=yaxis,
+ ),
+ ]
+ )
fig.update_yaxes(zerolinecolor="black")
fig.update_layout({"hovermode": "x unified"})
@@ -227,8 +231,9 @@ def plot_ccf(
self,
columns: ColumnSelector = 0,
target: TargetSelector = 0,
- nlags: IntLargerZero | None = None,
*,
+ nlags: IntLargerZero | None = None,
+ plot_interval: Bool = False,
title: str | dict[str, Any] | None = None,
legend: Legend | dict[str, Any] | None = "upper right",
figsize: tuple[IntLargerZero, IntLargerZero] = (900, 600),
@@ -261,6 +266,9 @@ def plot_ccf(
returned value includes lag 0 (i.e., 1), so the size of the
vector is `(nlags + 1,)`.
+ plot_interval: bool, default=False
+ Whether to plot the 95% confidence interval.
+
title: str, dict or None, default=None
Title for the plot.
@@ -360,35 +368,35 @@ def plot_ccf(
yaxis=yaxis,
)
- # Add error bands
- fig.add_traces(
- [
- go.Scatter(
- x=x,
- y=conf[:, 1],
- mode="lines",
- line={"width": 1, "color": BasePlot._fig.get_elem(col)},
- hovertemplate="%{y}upper bound",
- legendgroup=col,
- showlegend=False,
- xaxis=xaxis,
- yaxis=yaxis,
- ),
- go.Scatter(
- x=x,
- y=conf[:, 0],
- mode="lines",
- line={"width": 1, "color": BasePlot._fig.get_elem(col)},
- fill="tonexty",
- fillcolor=f"rgba({BasePlot._fig.get_elem(col)[4:-1]}, 0.2)",
- hovertemplate="%{y}lower bound",
- legendgroup=col,
- showlegend=False,
- xaxis=xaxis,
- yaxis=yaxis,
- ),
- ]
- )
+ if plot_interval:
+ fig.add_traces(
+ [
+ go.Scatter(
+ x=x,
+ y=conf[:, 1] - corr,
+ mode="lines",
+ line={"width": 1, "color": BasePlot._fig.get_elem(col)},
+ hovertemplate="%{y}upper bound",
+ legendgroup=col,
+ showlegend=False,
+ xaxis=xaxis,
+ yaxis=yaxis,
+ ),
+ go.Scatter(
+ x=x,
+ y=conf[:, 0] - corr,
+ mode="lines",
+ line={"width": 1, "color": BasePlot._fig.get_elem(col)},
+ fill="tonexty",
+ fillcolor=f"rgba({BasePlot._fig.get_elem(col)[4:-1]}, 0.2)",
+ hovertemplate="%{y}lower bound",
+ legendgroup=col,
+ showlegend=False,
+ xaxis=xaxis,
+ yaxis=yaxis,
+ ),
+ ]
+ )
fig.update_yaxes(zerolinecolor="black")
fig.update_layout({"hovermode": "x unified"})
@@ -1292,9 +1300,10 @@ def get_text(column: Series) -> Series:
def plot_pacf(
self,
columns: ColumnSelector | None = None,
+ *,
nlags: IntLargerZero | None = None,
method: PACFMethods = "ywadjusted",
- *,
+ plot_interval: Bool = True,
title: str | dict[str, Any] | None = None,
legend: Legend | dict[str, Any] | None = "upper right",
figsize: tuple[IntLargerZero, IntLargerZero] | None = None,
@@ -1306,9 +1315,8 @@ def plot_pacf(
The partial autocorrelation function (PACF) measures the
correlation between a time series and lagged versions of
itself. It's useful, for example, to identify the order of
- an autoregressive model. The transparent band represents
- the 95% confidence interval.This plot is only available
- for [forecast][time-series] tasks.
+ an autoregressive model. This plot is only available for
+ [forecast][time-series] tasks.
Parameters
----------
@@ -1327,8 +1335,8 @@ def plot_pacf(
- "yw" or "ywadjusted": Yule-Walker with sample-size
adjustment in denominator for acovf.
- - "ywm" or "ywmle": Yule-Walker without adjustment.
- - "ols" : Regression of time series on lags of it and on
+ - "ywm" or "ywmle": Yule-Walker without an adjustment.
+ - "ols": Regression of time series on lags of it and on
constant.
- "ols-inefficient": Regression of time series on lags using
a single common sample to estimate all pacf coefficients.
@@ -1338,7 +1346,10 @@ def plot_pacf(
correction.
- "ldb" or "ldbiased": Levinson-Durbin recursion without bias
correction.
- - "burg": Burg"s partial autocorrelation estimator.
+ - "burg": Burg"s partial autocorrelation estimator.
+
+ plot_interval: bool, default=True
+ Whether to plot the 95% confidence interval.
title: str, dict or None, default=None
Title for the plot.
@@ -1431,35 +1442,35 @@ def plot_pacf(
yaxis=yaxis,
)
- # Add error bands
- fig.add_traces(
- [
- go.Scatter(
- x=x,
- y=np.subtract(conf[:, 1], corr),
- mode="lines",
- line={"width": 1, "color": BasePlot._fig.get_elem(col)},
- hovertemplate="%{y}upper bound",
- legendgroup=col,
- showlegend=False,
- xaxis=xaxis,
- yaxis=yaxis,
- ),
- go.Scatter(
- x=x,
- y=np.subtract(conf[:, 0], corr),
- mode="lines",
- line={"width": 1, "color": BasePlot._fig.get_elem(col)},
- fill="tonexty",
- fillcolor=f"rgba({BasePlot._fig.get_elem(col)[4:-1]}, 0.2)",
- hovertemplate="%{y}lower bound",
- legendgroup=col,
- showlegend=False,
- xaxis=xaxis,
- yaxis=yaxis,
- ),
- ]
- )
+ if plot_interval:
+ fig.add_traces(
+ [
+ go.Scatter(
+ x=x,
+ y=conf[:, 1] - corr,
+ mode="lines",
+ line={"width": 1, "color": BasePlot._fig.get_elem(col)},
+ hovertemplate="%{y}upper bound",
+ legendgroup=col,
+ showlegend=False,
+ xaxis=xaxis,
+ yaxis=yaxis,
+ ),
+ go.Scatter(
+ x=x,
+ y=conf[:, 0] - corr,
+ mode="lines",
+ line={"width": 1, "color": BasePlot._fig.get_elem(col)},
+ fill="tonexty",
+ fillcolor=f"rgba({BasePlot._fig.get_elem(col)[4:-1]}, 0.2)",
+ hovertemplate="%{y}lower bound",
+ legendgroup=col,
+ showlegend=False,
+ xaxis=xaxis,
+ yaxis=yaxis,
+ ),
+ ]
+ )
fig.update_yaxes(zerolinecolor="black")
fig.update_layout({"hovermode": "x unified"})
@@ -2005,6 +2016,7 @@ def plot_relationships(
def plot_rfecv(
self,
*,
+ plot_interval: Bool = True,
title: str | dict[str, Any] | None = None,
legend: Legend | dict[str, Any] | None = "upper right",
figsize: tuple[IntLargerZero, IntLargerZero] = (900, 600),
@@ -2019,6 +2031,9 @@ def plot_rfecv(
Parameters
----------
+ plot_interval: bool, default=True
+ Whether to plot the 1-sigma confidence interval.
+
title: str, dict or None, default=None
Title for the plot.
@@ -2112,35 +2127,35 @@ def plot_rfecv(
yaxis=yaxis,
)
- # Add error bands
- fig.add_traces(
- [
- go.Scatter(
- x=x,
- y=np.add(mean, std),
- mode="lines",
- line={"width": 1, "color": BasePlot._fig.get_elem(ylabel)},
- hovertemplate="%{y}upper bound",
- legendgroup=ylabel,
- showlegend=False,
- xaxis=xaxis,
- yaxis=yaxis,
- ),
- go.Scatter(
- x=x,
- y=np.subtract(mean, std),
- mode="lines",
- line={"width": 1, "color": BasePlot._fig.get_elem(ylabel)},
- fill="tonexty",
- fillcolor=f"rgba{BasePlot._fig.get_elem(ylabel)[3:-1]}, 0.2)",
- hovertemplate="%{y}lower bound",
- legendgroup=ylabel,
- showlegend=False,
- xaxis=xaxis,
- yaxis=yaxis,
- ),
- ]
- )
+ if plot_interval:
+ fig.add_traces(
+ [
+ go.Scatter(
+ x=x,
+ y=mean + std,
+ mode="lines",
+ line={"width": 1, "color": BasePlot._fig.get_elem(ylabel)},
+ hovertemplate="%{y}upper bound",
+ legendgroup=ylabel,
+ showlegend=False,
+ xaxis=xaxis,
+ yaxis=yaxis,
+ ),
+ go.Scatter(
+ x=x,
+ y=mean - std,
+ mode="lines",
+ line={"width": 1, "color": BasePlot._fig.get_elem(ylabel)},
+ fill="tonexty",
+ fillcolor=f"rgba{BasePlot._fig.get_elem(ylabel)[3:-1]}, 0.2)",
+ hovertemplate="%{y}lower bound",
+ legendgroup=ylabel,
+ showlegend=False,
+ xaxis=xaxis,
+ yaxis=yaxis,
+ ),
+ ]
+ )
fig.update_layout({"hovermode": "x unified"})
diff --git a/atom/utils/types.py b/atom/utils/types.py
index d3aaa7f31..43bc2ff6d 100644
--- a/atom/utils/types.py
+++ b/atom/utils/types.py
@@ -253,7 +253,20 @@ def predict(self, *args, **kwargs) -> Pandas: ...
Verbose: TypeAlias = Literal[0, 1, 2]
# Data cleaning parameters
-NumericalStrats: TypeAlias = Literal["drop", "mean", "median", "knn", "iterative", "most_frequent"]
+NumericalStrats: TypeAlias = Literal[
+ "drop",
+ "mean",
+ "median",
+ "knn",
+ "iterative",
+ "most_frequent",
+ "drift",
+ "linear",
+ "nearest",
+ "bfill",
+ "ffill",
+ "random",
+]
CategoricalStrats: TypeAlias = Literal["drop", "most_frequent"]
DiscretizerStrats: TypeAlias = Literal["uniform", "quantile", "kmeans", "custom"]
Bins: TypeAlias = IntLargerOne | Sequence[Scalar] | dict[str, IntLargerOne | Sequence[Scalar]]
diff --git a/atom/utils/utils.py b/atom/utils/utils.py
index ac2ed7689..e794b55e1 100644
--- a/atom/utils/utils.py
+++ b/atom/utils/utils.py
@@ -43,6 +43,7 @@
from pandas._typing import Axes, Dtype, DtypeArg
from pandas.api.types import is_numeric_dtype
from shap import Explainer, Explanation
+from sklearn.base import BaseEstimator
from sklearn.metrics import (
confusion_matrix, get_scorer, get_scorer_names, make_scorer,
matthews_corrcoef,
@@ -2524,6 +2525,8 @@ def prepare_df(out: TReturn, og: DataFrame) -> DataFrame:
name=flt(getattr(transformer, "target_names_in_", None)),
)
+ use_y = True
+
kwargs: dict[str, Any] = {}
inc = list(getattr(transformer, "_cols", getattr(Xt, "columns", [])))
if "X" in (params := sign(getattr(transformer, method))):
@@ -2534,13 +2537,15 @@ def prepare_df(out: TReturn, og: DataFrame) -> DataFrame:
if len(kwargs) == 0:
if yt is not None and hasattr(transformer, "_cols"):
kwargs["X"] = to_df(yt)[inc]
+ use_y = False
elif params["X"].default != Parameter.empty:
kwargs["X"] = params["X"].default # Fill X with default
else:
return Xt, yt # If X is needed, skip the transformer
if "y" in params:
- if yt is not None:
+ # We skip `y` when already added to `X`
+ if yt is not None and use_y:
kwargs["y"] = yt
elif "X" not in params:
return Xt, yt # If y is None and no X in transformer, skip the transformer
@@ -2763,6 +2768,29 @@ def wrapper(*args, **kwargs) -> Any:
return wrapper
+def wrap_fit(f: Callable) -> Callable:
+ """Wrap the fit method of estimators to add custom attributes.
+
+ Wrapper to add the `feature_names_in_` and `n_features_in_`
+ attributes to an arbitrary estimator during `fit`. Used for
+ classes that don't have these attributes (e.g., from cuml or
+ sktime).
+
+ """
+
+ @wraps(f)
+ def wrapped(self, X, *args, **kwargs):
+ out = f(self, X, *args, **kwargs)
+
+ # We add the attributes after running the function
+ # to avoid deleting them with .reset() calls
+ BaseEstimator._check_feature_names(self, X, reset=True)
+ BaseEstimator._check_n_features(self, X, reset=True)
+ return out
+
+ return wrapped
+
+
def wrap_transformer_methods(f: Callable) -> Callable:
"""Wrap transformer methods with shared code.
diff --git a/docs_sources/changelog/v6.x.x.md b/docs_sources/changelog/v6.x.x.md
index a0b32c8a7..2b475d0cf 100644
--- a/docs_sources/changelog/v6.x.x.md
+++ b/docs_sources/changelog/v6.x.x.md
@@ -32,6 +32,8 @@
**:rocket: Enhancements**
* Transformations only on `y` are now accepted, e.g., `atom.scale(columns=-1)`.
+* The [Imputer][] class has many more strategies for numerical columns designed
+ for time series.
* Full support for [pandas nullable dtypes](https://pandas.pydata.org/docs/user_guide/integer_na.html).
* The dataset can now be provided as callable.
* The [FeatureExtractor][] class can extract features from the dataset's index.
diff --git a/docs_sources/examples/accelerating_sklearnex.ipynb b/docs_sources/examples/accelerating_sklearnex.ipynb
index 57c70bf29..b1a5abf4c 100644
--- a/docs_sources/examples/accelerating_sklearnex.ipynb
+++ b/docs_sources/examples/accelerating_sklearnex.ipynb
@@ -300,26 +300,26 @@
"text": [
"Fitting Imputer...\n",
"Imputing missing values...\n",
- " --> Dropping 637 samples due to missing values in feature MinTemp.\n",
- " --> Dropping 322 samples due to missing values in feature MaxTemp.\n",
- " --> Dropping 1406 samples due to missing values in feature Rainfall.\n",
- " --> Dropping 60843 samples due to missing values in feature Evaporation.\n",
- " --> Dropping 67816 samples due to missing values in feature Sunshine.\n",
- " --> Dropping 9330 samples due to missing values in feature WindGustDir.\n",
- " --> Dropping 9270 samples due to missing values in feature WindGustSpeed.\n",
- " --> Dropping 10013 samples due to missing values in feature WindDir9am.\n",
- " --> Dropping 3778 samples due to missing values in feature WindDir3pm.\n",
- " --> Dropping 1348 samples due to missing values in feature WindSpeed9am.\n",
- " --> Dropping 2630 samples due to missing values in feature WindSpeed3pm.\n",
- " --> Dropping 1774 samples due to missing values in feature Humidity9am.\n",
- " --> Dropping 3610 samples due to missing values in feature Humidity3pm.\n",
- " --> Dropping 14014 samples due to missing values in feature Pressure9am.\n",
- " --> Dropping 13981 samples due to missing values in feature Pressure3pm.\n",
- " --> Dropping 53657 samples due to missing values in feature Cloud9am.\n",
- " --> Dropping 57094 samples due to missing values in feature Cloud3pm.\n",
- " --> Dropping 904 samples due to missing values in feature Temp9am.\n",
- " --> Dropping 2726 samples due to missing values in feature Temp3pm.\n",
- " --> Dropping 1406 samples due to missing values in feature RainToday.\n",
+ " --> Dropping 637 samples due to missing values in column MinTemp.\n",
+ " --> Dropping 322 samples due to missing values in column MaxTemp.\n",
+ " --> Dropping 1406 samples due to missing values in column Rainfall.\n",
+ " --> Dropping 60843 samples due to missing values in column Evaporation.\n",
+ " --> Dropping 67816 samples due to missing values in column Sunshine.\n",
+ " --> Dropping 9330 samples due to missing values in column WindGustDir.\n",
+ " --> Dropping 9270 samples due to missing values in column WindGustSpeed.\n",
+ " --> Dropping 10013 samples due to missing values in column WindDir9am.\n",
+ " --> Dropping 3778 samples due to missing values in column WindDir3pm.\n",
+ " --> Dropping 1348 samples due to missing values in column WindSpeed9am.\n",
+ " --> Dropping 2630 samples due to missing values in column WindSpeed3pm.\n",
+ " --> Dropping 1774 samples due to missing values in column Humidity9am.\n",
+ " --> Dropping 3610 samples due to missing values in column Humidity3pm.\n",
+ " --> Dropping 14014 samples due to missing values in column Pressure9am.\n",
+ " --> Dropping 13981 samples due to missing values in column Pressure3pm.\n",
+ " --> Dropping 53657 samples due to missing values in column Cloud9am.\n",
+ " --> Dropping 57094 samples due to missing values in column Cloud3pm.\n",
+ " --> Dropping 904 samples due to missing values in column Temp9am.\n",
+ " --> Dropping 2726 samples due to missing values in column Temp3pm.\n",
+ " --> Dropping 1406 samples due to missing values in column RainToday.\n",
"Fitting Encoder...\n",
"Encoding categorical columns...\n",
" --> Target-encoding feature Location. Contains 26 classes.\n",
diff --git a/docs_sources/examples/binary_classification.ipynb b/docs_sources/examples/binary_classification.ipynb
index 61e92c70f..6b260e7cc 100644
--- a/docs_sources/examples/binary_classification.ipynb
+++ b/docs_sources/examples/binary_classification.ipynb
@@ -306,26 +306,26 @@
"Fitting Imputer...\n",
"Imputing missing values...\n",
" --> Dropping 7 samples for containing more than 16 missing values.\n",
- " --> Imputing 23 missing values with median (11.9) in feature MinTemp.\n",
- " --> Imputing 10 missing values with median (22.6) in feature MaxTemp.\n",
- " --> Imputing 72 missing values with median (0.0) in feature Rainfall.\n",
- " --> Imputing 3059 missing values with median (4.6) in feature Evaporation.\n",
- " --> Imputing 3382 missing values with median (8.5) in feature Sunshine.\n",
- " --> Dropping 467 samples due to missing values in feature WindGustDir.\n",
- " --> Imputing 466 missing values with median (39.0) in feature WindGustSpeed.\n",
- " --> Dropping 479 samples due to missing values in feature WindDir9am.\n",
- " --> Dropping 165 samples due to missing values in feature WindDir3pm.\n",
- " --> Imputing 53 missing values with median (13.0) in feature WindSpeed9am.\n",
- " --> Imputing 115 missing values with median (17.0) in feature WindSpeed3pm.\n",
- " --> Imputing 72 missing values with median (70.0) in feature Humidity9am.\n",
- " --> Imputing 164 missing values with median (52.0) in feature Humidity3pm.\n",
- " --> Imputing 699 missing values with median (1017.7) in feature Pressure9am.\n",
- " --> Imputing 699 missing values with median (1015.4) in feature Pressure3pm.\n",
- " --> Imputing 2698 missing values with median (5.0) in feature Cloud9am.\n",
- " --> Imputing 2903 missing values with median (5.0) in feature Cloud3pm.\n",
- " --> Imputing 32 missing values with median (16.7) in feature Temp9am.\n",
- " --> Imputing 116 missing values with median (21.1) in feature Temp3pm.\n",
- " --> Dropping 72 samples due to missing values in feature RainToday.\n"
+ " --> Imputing 23 missing values with median (11.9) in column MinTemp.\n",
+ " --> Imputing 10 missing values with median (22.6) in column MaxTemp.\n",
+ " --> Imputing 72 missing values with median (0.0) in column Rainfall.\n",
+ " --> Imputing 3059 missing values with median (4.6) in column Evaporation.\n",
+ " --> Imputing 3382 missing values with median (8.5) in column Sunshine.\n",
+ " --> Dropping 467 samples due to missing values in column WindGustDir.\n",
+ " --> Imputing 466 missing values with median (39.0) in column WindGustSpeed.\n",
+ " --> Dropping 479 samples due to missing values in column WindDir9am.\n",
+ " --> Dropping 165 samples due to missing values in column WindDir3pm.\n",
+ " --> Imputing 53 missing values with median (13.0) in column WindSpeed9am.\n",
+ " --> Imputing 115 missing values with median (17.0) in column WindSpeed3pm.\n",
+ " --> Imputing 72 missing values with median (70.0) in column Humidity9am.\n",
+ " --> Imputing 164 missing values with median (52.0) in column Humidity3pm.\n",
+ " --> Imputing 699 missing values with median (1017.7) in column Pressure9am.\n",
+ " --> Imputing 699 missing values with median (1015.4) in column Pressure3pm.\n",
+ " --> Imputing 2698 missing values with median (5.0) in column Cloud9am.\n",
+ " --> Imputing 2903 missing values with median (5.0) in column Cloud3pm.\n",
+ " --> Imputing 32 missing values with median (16.7) in column Temp9am.\n",
+ " --> Imputing 116 missing values with median (21.1) in column Temp3pm.\n",
+ " --> Dropping 72 samples due to missing values in column RainToday.\n"
]
}
],
diff --git a/docs_sources/examples/feature_engineering.ipynb b/docs_sources/examples/feature_engineering.ipynb
index ff309fd2d..5bd9e658d 100644
--- a/docs_sources/examples/feature_engineering.ipynb
+++ b/docs_sources/examples/feature_engineering.ipynb
@@ -397,7 +397,7 @@
"text": [
"Fitting Imputer...\n",
"Imputing missing values...\n",
- " --> Imputing 12 missing values using the KNN imputer in feature NATURAL_LOGARITHM(Temp3pm).\n"
+ " --> Imputing 12 missing values using the KNN imputer in column NATURAL_LOGARITHM(Temp3pm).\n"
]
}
],
diff --git a/docs_sources/examples/getting_started.ipynb b/docs_sources/examples/getting_started.ipynb
index 0ffe3fe8b..ee2ac20a5 100644
--- a/docs_sources/examples/getting_started.ipynb
+++ b/docs_sources/examples/getting_started.ipynb
@@ -71,26 +71,26 @@
"text": [
"Fitting Imputer...\n",
"Imputing missing values...\n",
- " --> Imputing 8 missing values with median (11.6) in feature MinTemp.\n",
- " --> Imputing 2 missing values with median (22.3) in feature MaxTemp.\n",
- " --> Imputing 12 missing values with median (0.0) in feature Rainfall.\n",
- " --> Imputing 425 missing values with median (4.8) in feature Evaporation.\n",
- " --> Imputing 480 missing values with median (8.55) in feature Sunshine.\n",
- " --> Imputing 59 missing values with most_frequent (N) in feature WindGustDir.\n",
- " --> Imputing 59 missing values with median (37.0) in feature WindGustSpeed.\n",
- " --> Imputing 90 missing values with most_frequent (N) in feature WindDir9am.\n",
- " --> Imputing 28 missing values with most_frequent (SW) in feature WindDir3pm.\n",
- " --> Imputing 10 missing values with median (13.0) in feature WindSpeed9am.\n",
- " --> Imputing 19 missing values with median (17.0) in feature WindSpeed3pm.\n",
- " --> Imputing 17 missing values with median (70.0) in feature Humidity9am.\n",
- " --> Imputing 31 missing values with median (51.0) in feature Humidity3pm.\n",
- " --> Imputing 89 missing values with median (1017.8) in feature Pressure9am.\n",
- " --> Imputing 87 missing values with median (1015.2) in feature Pressure3pm.\n",
- " --> Imputing 383 missing values with median (5.0) in feature Cloud9am.\n",
- " --> Imputing 412 missing values with median (5.0) in feature Cloud3pm.\n",
- " --> Imputing 11 missing values with median (16.5) in feature Temp9am.\n",
- " --> Imputing 26 missing values with median (20.7) in feature Temp3pm.\n",
- " --> Imputing 12 missing values with most_frequent (No) in feature RainToday.\n",
+ " --> Imputing 8 missing values with median (11.6) in column MinTemp.\n",
+ " --> Imputing 2 missing values with median (22.3) in column MaxTemp.\n",
+ " --> Imputing 12 missing values with median (0.0) in column Rainfall.\n",
+ " --> Imputing 425 missing values with median (4.8) in column Evaporation.\n",
+ " --> Imputing 480 missing values with median (8.55) in column Sunshine.\n",
+ " --> Imputing 59 missing values with most_frequent (N) in column WindGustDir.\n",
+ " --> Imputing 59 missing values with median (37.0) in column WindGustSpeed.\n",
+ " --> Imputing 90 missing values with most_frequent (N) in column WindDir9am.\n",
+ " --> Imputing 28 missing values with most_frequent (SW) in column WindDir3pm.\n",
+ " --> Imputing 10 missing values with median (13.0) in column WindSpeed9am.\n",
+ " --> Imputing 19 missing values with median (17.0) in column WindSpeed3pm.\n",
+ " --> Imputing 17 missing values with median (70.0) in column Humidity9am.\n",
+ " --> Imputing 31 missing values with median (51.0) in column Humidity3pm.\n",
+ " --> Imputing 89 missing values with median (1017.8) in column Pressure9am.\n",
+ " --> Imputing 87 missing values with median (1015.2) in column Pressure3pm.\n",
+ " --> Imputing 383 missing values with median (5.0) in column Cloud9am.\n",
+ " --> Imputing 412 missing values with median (5.0) in column Cloud3pm.\n",
+ " --> Imputing 11 missing values with median (16.5) in column Temp9am.\n",
+ " --> Imputing 26 missing values with median (20.7) in column Temp3pm.\n",
+ " --> Imputing 12 missing values with most_frequent (No) in column RainToday.\n",
"Fitting Encoder...\n",
"Encoding categorical columns...\n",
" --> Target-encoding feature Location. Contains 49 classes.\n",
diff --git a/docs_sources/examples/holdout_set.ipynb b/docs_sources/examples/holdout_set.ipynb
index ffe21b751..68065a975 100644
--- a/docs_sources/examples/holdout_set.ipynb
+++ b/docs_sources/examples/holdout_set.ipynb
@@ -323,26 +323,26 @@
"text": [
"Fitting Imputer...\n",
"Imputing missing values...\n",
- " --> Dropping 258 samples due to missing values in feature MinTemp.\n",
- " --> Dropping 127 samples due to missing values in feature MaxTemp.\n",
- " --> Dropping 553 samples due to missing values in feature Rainfall.\n",
- " --> Dropping 24308 samples due to missing values in feature Evaporation.\n",
- " --> Dropping 27187 samples due to missing values in feature Sunshine.\n",
- " --> Dropping 3739 samples due to missing values in feature WindGustDir.\n",
- " --> Dropping 3712 samples due to missing values in feature WindGustSpeed.\n",
- " --> Dropping 3995 samples due to missing values in feature WindDir9am.\n",
- " --> Dropping 1508 samples due to missing values in feature WindDir3pm.\n",
- " --> Dropping 539 samples due to missing values in feature WindSpeed9am.\n",
- " --> Dropping 1077 samples due to missing values in feature WindSpeed3pm.\n",
- " --> Dropping 706 samples due to missing values in feature Humidity9am.\n",
- " --> Dropping 1447 samples due to missing values in feature Humidity3pm.\n",
- " --> Dropping 5610 samples due to missing values in feature Pressure9am.\n",
- " --> Dropping 5591 samples due to missing values in feature Pressure3pm.\n",
- " --> Dropping 21520 samples due to missing values in feature Cloud9am.\n",
- " --> Dropping 22921 samples due to missing values in feature Cloud3pm.\n",
- " --> Dropping 365 samples due to missing values in feature Temp9am.\n",
- " --> Dropping 1106 samples due to missing values in feature Temp3pm.\n",
- " --> Dropping 553 samples due to missing values in feature RainToday.\n",
+ " --> Dropping 258 samples due to missing values in column MinTemp.\n",
+ " --> Dropping 127 samples due to missing values in column MaxTemp.\n",
+ " --> Dropping 553 samples due to missing values in column Rainfall.\n",
+ " --> Dropping 24308 samples due to missing values in column Evaporation.\n",
+ " --> Dropping 27187 samples due to missing values in column Sunshine.\n",
+ " --> Dropping 3739 samples due to missing values in column WindGustDir.\n",
+ " --> Dropping 3712 samples due to missing values in column WindGustSpeed.\n",
+ " --> Dropping 3995 samples due to missing values in column WindDir9am.\n",
+ " --> Dropping 1508 samples due to missing values in column WindDir3pm.\n",
+ " --> Dropping 539 samples due to missing values in column WindSpeed9am.\n",
+ " --> Dropping 1077 samples due to missing values in column WindSpeed3pm.\n",
+ " --> Dropping 706 samples due to missing values in column Humidity9am.\n",
+ " --> Dropping 1447 samples due to missing values in column Humidity3pm.\n",
+ " --> Dropping 5610 samples due to missing values in column Pressure9am.\n",
+ " --> Dropping 5591 samples due to missing values in column Pressure3pm.\n",
+ " --> Dropping 21520 samples due to missing values in column Cloud9am.\n",
+ " --> Dropping 22921 samples due to missing values in column Cloud3pm.\n",
+ " --> Dropping 365 samples due to missing values in column Temp9am.\n",
+ " --> Dropping 1106 samples due to missing values in column Temp3pm.\n",
+ " --> Dropping 553 samples due to missing values in column RainToday.\n",
"Fitting Encoder...\n",
"Encoding categorical columns...\n",
" --> Target-encoding feature Location. Contains 26 classes.\n",
diff --git a/docs_sources/examples/train_sizing.ipynb b/docs_sources/examples/train_sizing.ipynb
index ec75bff9d..1d106f059 100644
--- a/docs_sources/examples/train_sizing.ipynb
+++ b/docs_sources/examples/train_sizing.ipynb
@@ -287,26 +287,26 @@
"Fitting Imputer...\n",
"Imputing missing values...\n",
" --> Dropping 161 samples for containing more than 16 missing values.\n",
- " --> Imputing 481 missing values with median (12.0) in feature MinTemp.\n",
- " --> Imputing 265 missing values with median (22.6) in feature MaxTemp.\n",
- " --> Imputing 1354 missing values with median (0.0) in feature Rainfall.\n",
- " --> Imputing 60682 missing values with median (4.8) in feature Evaporation.\n",
- " --> Imputing 67659 missing values with median (8.4) in feature Sunshine.\n",
- " --> Imputing 9187 missing values with most_frequent (W) in feature WindGustDir.\n",
- " --> Imputing 9127 missing values with median (39.0) in feature WindGustSpeed.\n",
- " --> Imputing 9852 missing values with most_frequent (N) in feature WindDir9am.\n",
- " --> Imputing 3617 missing values with most_frequent (SE) in feature WindDir3pm.\n",
- " --> Imputing 1187 missing values with median (13.0) in feature WindSpeed9am.\n",
- " --> Imputing 2469 missing values with median (19.0) in feature WindSpeed3pm.\n",
- " --> Imputing 1613 missing values with median (70.0) in feature Humidity9am.\n",
- " --> Imputing 3449 missing values with median (52.0) in feature Humidity3pm.\n",
- " --> Imputing 13863 missing values with median (1017.6) in feature Pressure9am.\n",
- " --> Imputing 13830 missing values with median (1015.2) in feature Pressure3pm.\n",
- " --> Imputing 53496 missing values with median (5.0) in feature Cloud9am.\n",
- " --> Imputing 56933 missing values with median (5.0) in feature Cloud3pm.\n",
- " --> Imputing 743 missing values with median (16.7) in feature Temp9am.\n",
- " --> Imputing 2565 missing values with median (21.1) in feature Temp3pm.\n",
- " --> Imputing 1354 missing values with most_frequent (No) in feature RainToday.\n",
+ " --> Imputing 481 missing values with median (12.0) in column MinTemp.\n",
+ " --> Imputing 265 missing values with median (22.6) in column MaxTemp.\n",
+ " --> Imputing 1354 missing values with median (0.0) in column Rainfall.\n",
+ " --> Imputing 60682 missing values with median (4.8) in column Evaporation.\n",
+ " --> Imputing 67659 missing values with median (8.4) in column Sunshine.\n",
+ " --> Imputing 9187 missing values with most_frequent (W) in column WindGustDir.\n",
+ " --> Imputing 9127 missing values with median (39.0) in column WindGustSpeed.\n",
+ " --> Imputing 9852 missing values with most_frequent (N) in column WindDir9am.\n",
+ " --> Imputing 3617 missing values with most_frequent (SE) in column WindDir3pm.\n",
+ " --> Imputing 1187 missing values with median (13.0) in column WindSpeed9am.\n",
+ " --> Imputing 2469 missing values with median (19.0) in column WindSpeed3pm.\n",
+ " --> Imputing 1613 missing values with median (70.0) in column Humidity9am.\n",
+ " --> Imputing 3449 missing values with median (52.0) in column Humidity3pm.\n",
+ " --> Imputing 13863 missing values with median (1017.6) in column Pressure9am.\n",
+ " --> Imputing 13830 missing values with median (1015.2) in column Pressure3pm.\n",
+ " --> Imputing 53496 missing values with median (5.0) in column Cloud9am.\n",
+ " --> Imputing 56933 missing values with median (5.0) in column Cloud3pm.\n",
+ " --> Imputing 743 missing values with median (16.7) in column Temp9am.\n",
+ " --> Imputing 2565 missing values with median (21.1) in column Temp3pm.\n",
+ " --> Imputing 1354 missing values with most_frequent (No) in column RainToday.\n",
"Fitting Encoder...\n",
"Encoding categorical columns...\n",
" --> Target-encoding feature Location. Contains 49 classes.\n",
diff --git a/examples/accelerating_sklearnex.ipynb b/examples/accelerating_sklearnex.ipynb
index 57c70bf29..b1a5abf4c 100644
--- a/examples/accelerating_sklearnex.ipynb
+++ b/examples/accelerating_sklearnex.ipynb
@@ -300,26 +300,26 @@
"text": [
"Fitting Imputer...\n",
"Imputing missing values...\n",
- " --> Dropping 637 samples due to missing values in feature MinTemp.\n",
- " --> Dropping 322 samples due to missing values in feature MaxTemp.\n",
- " --> Dropping 1406 samples due to missing values in feature Rainfall.\n",
- " --> Dropping 60843 samples due to missing values in feature Evaporation.\n",
- " --> Dropping 67816 samples due to missing values in feature Sunshine.\n",
- " --> Dropping 9330 samples due to missing values in feature WindGustDir.\n",
- " --> Dropping 9270 samples due to missing values in feature WindGustSpeed.\n",
- " --> Dropping 10013 samples due to missing values in feature WindDir9am.\n",
- " --> Dropping 3778 samples due to missing values in feature WindDir3pm.\n",
- " --> Dropping 1348 samples due to missing values in feature WindSpeed9am.\n",
- " --> Dropping 2630 samples due to missing values in feature WindSpeed3pm.\n",
- " --> Dropping 1774 samples due to missing values in feature Humidity9am.\n",
- " --> Dropping 3610 samples due to missing values in feature Humidity3pm.\n",
- " --> Dropping 14014 samples due to missing values in feature Pressure9am.\n",
- " --> Dropping 13981 samples due to missing values in feature Pressure3pm.\n",
- " --> Dropping 53657 samples due to missing values in feature Cloud9am.\n",
- " --> Dropping 57094 samples due to missing values in feature Cloud3pm.\n",
- " --> Dropping 904 samples due to missing values in feature Temp9am.\n",
- " --> Dropping 2726 samples due to missing values in feature Temp3pm.\n",
- " --> Dropping 1406 samples due to missing values in feature RainToday.\n",
+ " --> Dropping 637 samples due to missing values in column MinTemp.\n",
+ " --> Dropping 322 samples due to missing values in column MaxTemp.\n",
+ " --> Dropping 1406 samples due to missing values in column Rainfall.\n",
+ " --> Dropping 60843 samples due to missing values in column Evaporation.\n",
+ " --> Dropping 67816 samples due to missing values in column Sunshine.\n",
+ " --> Dropping 9330 samples due to missing values in column WindGustDir.\n",
+ " --> Dropping 9270 samples due to missing values in column WindGustSpeed.\n",
+ " --> Dropping 10013 samples due to missing values in column WindDir9am.\n",
+ " --> Dropping 3778 samples due to missing values in column WindDir3pm.\n",
+ " --> Dropping 1348 samples due to missing values in column WindSpeed9am.\n",
+ " --> Dropping 2630 samples due to missing values in column WindSpeed3pm.\n",
+ " --> Dropping 1774 samples due to missing values in column Humidity9am.\n",
+ " --> Dropping 3610 samples due to missing values in column Humidity3pm.\n",
+ " --> Dropping 14014 samples due to missing values in column Pressure9am.\n",
+ " --> Dropping 13981 samples due to missing values in column Pressure3pm.\n",
+ " --> Dropping 53657 samples due to missing values in column Cloud9am.\n",
+ " --> Dropping 57094 samples due to missing values in column Cloud3pm.\n",
+ " --> Dropping 904 samples due to missing values in column Temp9am.\n",
+ " --> Dropping 2726 samples due to missing values in column Temp3pm.\n",
+ " --> Dropping 1406 samples due to missing values in column RainToday.\n",
"Fitting Encoder...\n",
"Encoding categorical columns...\n",
" --> Target-encoding feature Location. Contains 26 classes.\n",
diff --git a/examples/binary_classification.ipynb b/examples/binary_classification.ipynb
index 61e92c70f..6b260e7cc 100644
--- a/examples/binary_classification.ipynb
+++ b/examples/binary_classification.ipynb
@@ -306,26 +306,26 @@
"Fitting Imputer...\n",
"Imputing missing values...\n",
" --> Dropping 7 samples for containing more than 16 missing values.\n",
- " --> Imputing 23 missing values with median (11.9) in feature MinTemp.\n",
- " --> Imputing 10 missing values with median (22.6) in feature MaxTemp.\n",
- " --> Imputing 72 missing values with median (0.0) in feature Rainfall.\n",
- " --> Imputing 3059 missing values with median (4.6) in feature Evaporation.\n",
- " --> Imputing 3382 missing values with median (8.5) in feature Sunshine.\n",
- " --> Dropping 467 samples due to missing values in feature WindGustDir.\n",
- " --> Imputing 466 missing values with median (39.0) in feature WindGustSpeed.\n",
- " --> Dropping 479 samples due to missing values in feature WindDir9am.\n",
- " --> Dropping 165 samples due to missing values in feature WindDir3pm.\n",
- " --> Imputing 53 missing values with median (13.0) in feature WindSpeed9am.\n",
- " --> Imputing 115 missing values with median (17.0) in feature WindSpeed3pm.\n",
- " --> Imputing 72 missing values with median (70.0) in feature Humidity9am.\n",
- " --> Imputing 164 missing values with median (52.0) in feature Humidity3pm.\n",
- " --> Imputing 699 missing values with median (1017.7) in feature Pressure9am.\n",
- " --> Imputing 699 missing values with median (1015.4) in feature Pressure3pm.\n",
- " --> Imputing 2698 missing values with median (5.0) in feature Cloud9am.\n",
- " --> Imputing 2903 missing values with median (5.0) in feature Cloud3pm.\n",
- " --> Imputing 32 missing values with median (16.7) in feature Temp9am.\n",
- " --> Imputing 116 missing values with median (21.1) in feature Temp3pm.\n",
- " --> Dropping 72 samples due to missing values in feature RainToday.\n"
+ " --> Imputing 23 missing values with median (11.9) in column MinTemp.\n",
+ " --> Imputing 10 missing values with median (22.6) in column MaxTemp.\n",
+ " --> Imputing 72 missing values with median (0.0) in column Rainfall.\n",
+ " --> Imputing 3059 missing values with median (4.6) in column Evaporation.\n",
+ " --> Imputing 3382 missing values with median (8.5) in column Sunshine.\n",
+ " --> Dropping 467 samples due to missing values in column WindGustDir.\n",
+ " --> Imputing 466 missing values with median (39.0) in column WindGustSpeed.\n",
+ " --> Dropping 479 samples due to missing values in column WindDir9am.\n",
+ " --> Dropping 165 samples due to missing values in column WindDir3pm.\n",
+ " --> Imputing 53 missing values with median (13.0) in column WindSpeed9am.\n",
+ " --> Imputing 115 missing values with median (17.0) in column WindSpeed3pm.\n",
+ " --> Imputing 72 missing values with median (70.0) in column Humidity9am.\n",
+ " --> Imputing 164 missing values with median (52.0) in column Humidity3pm.\n",
+ " --> Imputing 699 missing values with median (1017.7) in column Pressure9am.\n",
+ " --> Imputing 699 missing values with median (1015.4) in column Pressure3pm.\n",
+ " --> Imputing 2698 missing values with median (5.0) in column Cloud9am.\n",
+ " --> Imputing 2903 missing values with median (5.0) in column Cloud3pm.\n",
+ " --> Imputing 32 missing values with median (16.7) in column Temp9am.\n",
+ " --> Imputing 116 missing values with median (21.1) in column Temp3pm.\n",
+ " --> Dropping 72 samples due to missing values in column RainToday.\n"
]
}
],
diff --git a/examples/feature_engineering.ipynb b/examples/feature_engineering.ipynb
index ff309fd2d..5bd9e658d 100644
--- a/examples/feature_engineering.ipynb
+++ b/examples/feature_engineering.ipynb
@@ -397,7 +397,7 @@
"text": [
"Fitting Imputer...\n",
"Imputing missing values...\n",
- " --> Imputing 12 missing values using the KNN imputer in feature NATURAL_LOGARITHM(Temp3pm).\n"
+ " --> Imputing 12 missing values using the KNN imputer in column NATURAL_LOGARITHM(Temp3pm).\n"
]
}
],
diff --git a/examples/getting_started.ipynb b/examples/getting_started.ipynb
index 0ffe3fe8b..ee2ac20a5 100644
--- a/examples/getting_started.ipynb
+++ b/examples/getting_started.ipynb
@@ -71,26 +71,26 @@
"text": [
"Fitting Imputer...\n",
"Imputing missing values...\n",
- " --> Imputing 8 missing values with median (11.6) in feature MinTemp.\n",
- " --> Imputing 2 missing values with median (22.3) in feature MaxTemp.\n",
- " --> Imputing 12 missing values with median (0.0) in feature Rainfall.\n",
- " --> Imputing 425 missing values with median (4.8) in feature Evaporation.\n",
- " --> Imputing 480 missing values with median (8.55) in feature Sunshine.\n",
- " --> Imputing 59 missing values with most_frequent (N) in feature WindGustDir.\n",
- " --> Imputing 59 missing values with median (37.0) in feature WindGustSpeed.\n",
- " --> Imputing 90 missing values with most_frequent (N) in feature WindDir9am.\n",
- " --> Imputing 28 missing values with most_frequent (SW) in feature WindDir3pm.\n",
- " --> Imputing 10 missing values with median (13.0) in feature WindSpeed9am.\n",
- " --> Imputing 19 missing values with median (17.0) in feature WindSpeed3pm.\n",
- " --> Imputing 17 missing values with median (70.0) in feature Humidity9am.\n",
- " --> Imputing 31 missing values with median (51.0) in feature Humidity3pm.\n",
- " --> Imputing 89 missing values with median (1017.8) in feature Pressure9am.\n",
- " --> Imputing 87 missing values with median (1015.2) in feature Pressure3pm.\n",
- " --> Imputing 383 missing values with median (5.0) in feature Cloud9am.\n",
- " --> Imputing 412 missing values with median (5.0) in feature Cloud3pm.\n",
- " --> Imputing 11 missing values with median (16.5) in feature Temp9am.\n",
- " --> Imputing 26 missing values with median (20.7) in feature Temp3pm.\n",
- " --> Imputing 12 missing values with most_frequent (No) in feature RainToday.\n",
+ " --> Imputing 8 missing values with median (11.6) in column MinTemp.\n",
+ " --> Imputing 2 missing values with median (22.3) in column MaxTemp.\n",
+ " --> Imputing 12 missing values with median (0.0) in column Rainfall.\n",
+ " --> Imputing 425 missing values with median (4.8) in column Evaporation.\n",
+ " --> Imputing 480 missing values with median (8.55) in column Sunshine.\n",
+ " --> Imputing 59 missing values with most_frequent (N) in column WindGustDir.\n",
+ " --> Imputing 59 missing values with median (37.0) in column WindGustSpeed.\n",
+ " --> Imputing 90 missing values with most_frequent (N) in column WindDir9am.\n",
+ " --> Imputing 28 missing values with most_frequent (SW) in column WindDir3pm.\n",
+ " --> Imputing 10 missing values with median (13.0) in column WindSpeed9am.\n",
+ " --> Imputing 19 missing values with median (17.0) in column WindSpeed3pm.\n",
+ " --> Imputing 17 missing values with median (70.0) in column Humidity9am.\n",
+ " --> Imputing 31 missing values with median (51.0) in column Humidity3pm.\n",
+ " --> Imputing 89 missing values with median (1017.8) in column Pressure9am.\n",
+ " --> Imputing 87 missing values with median (1015.2) in column Pressure3pm.\n",
+ " --> Imputing 383 missing values with median (5.0) in column Cloud9am.\n",
+ " --> Imputing 412 missing values with median (5.0) in column Cloud3pm.\n",
+ " --> Imputing 11 missing values with median (16.5) in column Temp9am.\n",
+ " --> Imputing 26 missing values with median (20.7) in column Temp3pm.\n",
+ " --> Imputing 12 missing values with most_frequent (No) in column RainToday.\n",
"Fitting Encoder...\n",
"Encoding categorical columns...\n",
" --> Target-encoding feature Location. Contains 49 classes.\n",
diff --git a/examples/holdout_set.ipynb b/examples/holdout_set.ipynb
index ffe21b751..68065a975 100644
--- a/examples/holdout_set.ipynb
+++ b/examples/holdout_set.ipynb
@@ -323,26 +323,26 @@
"text": [
"Fitting Imputer...\n",
"Imputing missing values...\n",
- " --> Dropping 258 samples due to missing values in feature MinTemp.\n",
- " --> Dropping 127 samples due to missing values in feature MaxTemp.\n",
- " --> Dropping 553 samples due to missing values in feature Rainfall.\n",
- " --> Dropping 24308 samples due to missing values in feature Evaporation.\n",
- " --> Dropping 27187 samples due to missing values in feature Sunshine.\n",
- " --> Dropping 3739 samples due to missing values in feature WindGustDir.\n",
- " --> Dropping 3712 samples due to missing values in feature WindGustSpeed.\n",
- " --> Dropping 3995 samples due to missing values in feature WindDir9am.\n",
- " --> Dropping 1508 samples due to missing values in feature WindDir3pm.\n",
- " --> Dropping 539 samples due to missing values in feature WindSpeed9am.\n",
- " --> Dropping 1077 samples due to missing values in feature WindSpeed3pm.\n",
- " --> Dropping 706 samples due to missing values in feature Humidity9am.\n",
- " --> Dropping 1447 samples due to missing values in feature Humidity3pm.\n",
- " --> Dropping 5610 samples due to missing values in feature Pressure9am.\n",
- " --> Dropping 5591 samples due to missing values in feature Pressure3pm.\n",
- " --> Dropping 21520 samples due to missing values in feature Cloud9am.\n",
- " --> Dropping 22921 samples due to missing values in feature Cloud3pm.\n",
- " --> Dropping 365 samples due to missing values in feature Temp9am.\n",
- " --> Dropping 1106 samples due to missing values in feature Temp3pm.\n",
- " --> Dropping 553 samples due to missing values in feature RainToday.\n",
+ " --> Dropping 258 samples due to missing values in column MinTemp.\n",
+ " --> Dropping 127 samples due to missing values in column MaxTemp.\n",
+ " --> Dropping 553 samples due to missing values in column Rainfall.\n",
+ " --> Dropping 24308 samples due to missing values in column Evaporation.\n",
+ " --> Dropping 27187 samples due to missing values in column Sunshine.\n",
+ " --> Dropping 3739 samples due to missing values in column WindGustDir.\n",
+ " --> Dropping 3712 samples due to missing values in column WindGustSpeed.\n",
+ " --> Dropping 3995 samples due to missing values in column WindDir9am.\n",
+ " --> Dropping 1508 samples due to missing values in column WindDir3pm.\n",
+ " --> Dropping 539 samples due to missing values in column WindSpeed9am.\n",
+ " --> Dropping 1077 samples due to missing values in column WindSpeed3pm.\n",
+ " --> Dropping 706 samples due to missing values in column Humidity9am.\n",
+ " --> Dropping 1447 samples due to missing values in column Humidity3pm.\n",
+ " --> Dropping 5610 samples due to missing values in column Pressure9am.\n",
+ " --> Dropping 5591 samples due to missing values in column Pressure3pm.\n",
+ " --> Dropping 21520 samples due to missing values in column Cloud9am.\n",
+ " --> Dropping 22921 samples due to missing values in column Cloud3pm.\n",
+ " --> Dropping 365 samples due to missing values in column Temp9am.\n",
+ " --> Dropping 1106 samples due to missing values in column Temp3pm.\n",
+ " --> Dropping 553 samples due to missing values in column RainToday.\n",
"Fitting Encoder...\n",
"Encoding categorical columns...\n",
" --> Target-encoding feature Location. Contains 26 classes.\n",
diff --git a/examples/train_sizing.ipynb b/examples/train_sizing.ipynb
index ec75bff9d..1d106f059 100644
--- a/examples/train_sizing.ipynb
+++ b/examples/train_sizing.ipynb
@@ -287,26 +287,26 @@
"Fitting Imputer...\n",
"Imputing missing values...\n",
" --> Dropping 161 samples for containing more than 16 missing values.\n",
- " --> Imputing 481 missing values with median (12.0) in feature MinTemp.\n",
- " --> Imputing 265 missing values with median (22.6) in feature MaxTemp.\n",
- " --> Imputing 1354 missing values with median (0.0) in feature Rainfall.\n",
- " --> Imputing 60682 missing values with median (4.8) in feature Evaporation.\n",
- " --> Imputing 67659 missing values with median (8.4) in feature Sunshine.\n",
- " --> Imputing 9187 missing values with most_frequent (W) in feature WindGustDir.\n",
- " --> Imputing 9127 missing values with median (39.0) in feature WindGustSpeed.\n",
- " --> Imputing 9852 missing values with most_frequent (N) in feature WindDir9am.\n",
- " --> Imputing 3617 missing values with most_frequent (SE) in feature WindDir3pm.\n",
- " --> Imputing 1187 missing values with median (13.0) in feature WindSpeed9am.\n",
- " --> Imputing 2469 missing values with median (19.0) in feature WindSpeed3pm.\n",
- " --> Imputing 1613 missing values with median (70.0) in feature Humidity9am.\n",
- " --> Imputing 3449 missing values with median (52.0) in feature Humidity3pm.\n",
- " --> Imputing 13863 missing values with median (1017.6) in feature Pressure9am.\n",
- " --> Imputing 13830 missing values with median (1015.2) in feature Pressure3pm.\n",
- " --> Imputing 53496 missing values with median (5.0) in feature Cloud9am.\n",
- " --> Imputing 56933 missing values with median (5.0) in feature Cloud3pm.\n",
- " --> Imputing 743 missing values with median (16.7) in feature Temp9am.\n",
- " --> Imputing 2565 missing values with median (21.1) in feature Temp3pm.\n",
- " --> Imputing 1354 missing values with most_frequent (No) in feature RainToday.\n",
+ " --> Imputing 481 missing values with median (12.0) in column MinTemp.\n",
+ " --> Imputing 265 missing values with median (22.6) in column MaxTemp.\n",
+ " --> Imputing 1354 missing values with median (0.0) in column Rainfall.\n",
+ " --> Imputing 60682 missing values with median (4.8) in column Evaporation.\n",
+ " --> Imputing 67659 missing values with median (8.4) in column Sunshine.\n",
+ " --> Imputing 9187 missing values with most_frequent (W) in column WindGustDir.\n",
+ " --> Imputing 9127 missing values with median (39.0) in column WindGustSpeed.\n",
+ " --> Imputing 9852 missing values with most_frequent (N) in column WindDir9am.\n",
+ " --> Imputing 3617 missing values with most_frequent (SE) in column WindDir3pm.\n",
+ " --> Imputing 1187 missing values with median (13.0) in column WindSpeed9am.\n",
+ " --> Imputing 2469 missing values with median (19.0) in column WindSpeed3pm.\n",
+ " --> Imputing 1613 missing values with median (70.0) in column Humidity9am.\n",
+ " --> Imputing 3449 missing values with median (52.0) in column Humidity3pm.\n",
+ " --> Imputing 13863 missing values with median (1017.6) in column Pressure9am.\n",
+ " --> Imputing 13830 missing values with median (1015.2) in column Pressure3pm.\n",
+ " --> Imputing 53496 missing values with median (5.0) in column Cloud9am.\n",
+ " --> Imputing 56933 missing values with median (5.0) in column Cloud3pm.\n",
+ " --> Imputing 743 missing values with median (16.7) in column Temp9am.\n",
+ " --> Imputing 2565 missing values with median (21.1) in column Temp3pm.\n",
+ " --> Imputing 1354 missing values with most_frequent (No) in column RainToday.\n",
"Fitting Encoder...\n",
"Encoding categorical columns...\n",
" --> Target-encoding feature Location. Contains 49 classes.\n",
diff --git a/tests/test_basetransformer.py b/tests/test_basetransformer.py
index 7bc3a179c..9ddea11af 100644
--- a/tests/test_basetransformer.py
+++ b/tests/test_basetransformer.py
@@ -21,6 +21,7 @@
from sklearn.naive_bayes import GaussianNB
from sklearnex import get_config
from sklearnex.svm import SVC
+from sktime.transformations.series.impute import Imputer
from atom import ATOMClassifier, ATOMForecaster
from atom.basetransformer import BaseTransformer
@@ -214,6 +215,20 @@ def test_inherit_sp():
assert atom.tbats.estimator.get_params()["sp"] == [12, 24] # Multiple seasonality
+def test_inherit_attributes_and_methods():
+ """Assert that sklearn attributes and methods are added to the estimator."""
+ imputer = Imputer().fit(y_fc)
+ assert not hasattr(imputer, "feature_names_in_")
+ assert not hasattr(imputer, "n_features_in_")
+ assert not hasattr(imputer, "get_feature_names_out")
+
+ imputer = BaseTransformer(random_state=1)._inherit(imputer)
+ imputer.fit(pd.DataFrame(y_fc))
+ assert hasattr(imputer, "feature_names_in_")
+ assert hasattr(imputer, "n_features_in_")
+ assert hasattr(imputer, "get_feature_names_out")
+
+
# Test _get_est_class ============================================== >>
@pytest.mark.skipif(machine() not in ("x86_64", "AMD64"), reason="Only x86 support")
diff --git a/tests/test_data_cleaning.py b/tests/test_data_cleaning.py
index c0252b45d..61e4240dd 100644
--- a/tests/test_data_cleaning.py
+++ b/tests/test_data_cleaning.py
@@ -19,6 +19,7 @@
Balancer, Cleaner, Decomposer, Discretizer, Encoder, Imputer, Normalizer,
Pruner, Scaler,
)
+from atom.utils.types import NumericalStrats
from atom.utils.utils import NotFittedError, check_scaling, to_df
from .conftest import (
@@ -29,6 +30,15 @@
# Test TransformerMixin ============================================ >>
+def test_repr():
+ """Assert that __repr__ hides the default engine."""
+ assert str(Cleaner(engine="pyarrow")).startswith("Cleaner(engine=EngineTuple")
+ assert str(Cleaner()) == "Cleaner()"
+ assert str(Cleaner(device="gpu")) == "Cleaner(device='gpu')"
+ assert str(Cleaner(verbose=2)) == "Cleaner(verbose=2)"
+ assert str(Cleaner(device="gpu", verbose=2)) == "Cleaner(device='gpu', verbose=2)"
+
+
def test_clone():
"""Assert that cloning the transformer keeps internal attributes."""
pruner = Pruner().fit(X_bin)
@@ -474,7 +484,7 @@ def test_imputing_all_missing_values_categorical(missing):
X = [[missing, "a", "a"], ["b", "c", missing], ["b", "a", "c"], ["c", "a", "a"]]
y = [1, 1, 0, 0]
imputer = Imputer(strat_cat="most_frequent")
- X, y = imputer.fit_transform(X, y)
+ X, _ = imputer.fit_transform(X, y)
assert X.isna().sum().sum() == 0
@@ -491,7 +501,7 @@ def test_rows_too_many_nans(max_nan_rows, random):
max_nan_rows=max_nan_rows,
)
X, y = imputer.fit_transform(X, y)
- assert len(X) == 569 # Original size
+ assert len(X) == len(y) == 569 # Original size
assert X.isna().sum().sum() == 0
@@ -515,7 +525,7 @@ def test_cols_too_many_nans(max_nan_cols):
def test_imputing_numeric_drop():
"""Assert that imputing drop for numerical values works."""
imputer = Imputer(strat_num="drop")
- X, y = imputer.fit_transform(X10_nan, y10)
+ X, _ = imputer.fit_transform(X10_nan, y10)
assert len(X) == 8
assert X.isna().sum().sum() == 0
@@ -523,55 +533,23 @@ def test_imputing_numeric_drop():
def test_imputing_numeric_number():
"""Assert that imputing a number for numerical values works."""
imputer = Imputer(strat_num=3.2)
- X, y = imputer.fit_transform(X10_nan, y10)
+ X, _ = imputer.fit_transform(X10_nan, y10)
assert X.iloc[0, 0] == 3.2
assert X.isna().sum().sum() == 0
-def test_imputing_numeric_mean():
- """Assert that imputing the mean for numerical values works."""
- imputer = Imputer(strat_num="mean")
- X, y = imputer.fit_transform(X10_nan, y10)
- assert X.iloc[0, 0] == pytest.approx(2.577778, rel=1e-6, abs=1e-12)
- assert X.isna().sum().sum() == 0
-
-
-def test_imputing_numeric_median():
- """Assert that imputing the median for numerical values works."""
- imputer = Imputer(strat_num="median")
- X, y = imputer.fit_transform(X10_nan, y10)
- assert X.iloc[0, 0] == 3
- assert X.isna().sum().sum() == 0
-
-
-def test_imputing_numeric_knn():
- """Assert that imputing numerical values with KNNImputer works."""
- imputer = Imputer(strat_num="knn", random_state=1)
- X, y = imputer.fit_transform(X10_nan, y10)
- assert X.iloc[0, 0] == 3.04
- assert X.isna().sum().sum() == 0
-
-
-def test_imputing_numeric_iterative():
- """Assert that imputing numerical values with IterativeImputer works."""
- imputer = Imputer(strat_num="iterative")
- X, y = imputer.fit_transform(X10_nan, y10)
- assert X.iloc[0, 0] == pytest.approx(2.577836, rel=1e-6, abs=1e-12)
- assert X.isna().sum().sum() == 0
-
-
-def test_imputing_numeric_most_frequent():
- """Assert that imputing the most_frequent for numerical values works."""
- imputer = Imputer(strat_num="most_frequent")
- X, y = imputer.fit_transform(X10_nan, y10)
- assert X.iloc[0, 0] == 3
+@pytest.mark.parametrize("strat_num", NumericalStrats.__args__)
+def test_imputing_numeric(strat_num):
+ """Assert that imputing numerical columns works."""
+ imputer = Imputer(strat_num=strat_num)
+ X, _ = imputer.fit_transform(X10_nan, y10)
assert X.isna().sum().sum() == 0
def test_imputing_non_numeric_string():
"""Assert that imputing a string for non-numerical values works."""
imputer = Imputer(strat_cat="missing")
- X, y = imputer.fit_transform(X10_sn, y10)
+ X, _ = imputer.fit_transform(X10_sn, y10)
assert X.iloc[0, 2] == "missing"
assert X.isna().sum().sum() == 0
@@ -579,7 +557,7 @@ def test_imputing_non_numeric_string():
def test_imputing_non_numeric_drop():
"""Assert that the drop strategy for non-numerical works."""
imputer = Imputer(strat_cat="drop")
- X, y = imputer.fit_transform(X10_sn, y10)
+ X, _ = imputer.fit_transform(X10_sn, y10)
assert len(X) == 9
assert X.isna().sum().sum() == 0
@@ -587,7 +565,7 @@ def test_imputing_non_numeric_drop():
def test_imputing_non_numeric_most_frequent():
"""Assert that the most_frequent strategy for non-numerical works."""
imputer = Imputer(strat_cat="most_frequent")
- X, y = imputer.fit_transform(X10_sn, y10)
+ X, _ = imputer.fit_transform(X10_sn, y10)
assert X.iloc[0, 2] == "d"
assert X.isna().sum().sum() == 0
diff --git a/tests/test_plots.py b/tests/test_plots.py
index c4191b647..57e3cc066 100644
--- a/tests/test_plots.py
+++ b/tests/test_plots.py
@@ -23,7 +23,7 @@
from .conftest import (
X10, X10_str, X_bin, X_class, X_ex, X_label, X_reg, X_sparse, X_text, y10,
- y_bin, y_class, y_fc, y_label, y_multiclass, y_reg,
+ y_bin, y_class, y_ex, y_fc, y_label, y_multiclass, y_reg,
)
@@ -282,6 +282,16 @@ def test_plot_acf(columns):
atom.plot_acf(columns=columns, display=False)
+def test_plot_ccf():
+ """Assert that the plot_ccf method works."""
+ atom = ATOMForecaster(y_fc, random_state=1)
+ with pytest.raises(ValueError, match=".*requires at least two columns.*"):
+ atom.plot_ccf(display=False)
+
+ atom = ATOMForecaster(X_ex, y=y_ex, random_state=1)
+ atom.plot_ccf(plot_interval=True, display=False)
+
+
@pytest.mark.parametrize("show", [10, None])
def test_plot_components(show):
"""Assert that the plot_components method works."""
@@ -316,6 +326,13 @@ def test_plot_distribution():
atom.plot_distribution(columns=[0, 1], distributions="pearson3", display=False)
+@pytest.mark.parametrize("columns", [None, -1])
+def test_plot_fft(columns):
+ """Assert that the plot_fft method works."""
+ atom = ATOMForecaster(y_fc, random_state=1)
+ atom.plot_fft(columns=columns, display=False)
+
+
@pytest.mark.parametrize("ngram", [1, 2, 3, 4])
def test_plot_ngrams(ngram):
"""Assert that the plot_ngrams method works."""
@@ -345,6 +362,13 @@ def test_plot_pca(X):
atom.plot_pca(display=False)
+@pytest.mark.parametrize("columns", [None, -1])
+def test_plot_periodogram(columns):
+ """Assert that the plot_periodogram method works."""
+ atom = ATOMForecaster(y_fc, random_state=1)
+ atom.plot_periodogram(columns=columns, display=False)
+
+
def test_plot_qq():
"""Assert that the plot_qq method works."""
atom = ATOMClassifier(X_bin, y_bin, random_state=1)