diff --git a/atom/atom.py b/atom/atom.py index 619202bfd..81d957a2e 100644 --- a/atom/atom.py +++ b/atom/atom.py @@ -1699,14 +1699,15 @@ def impute( Impute or remove missing values according to the selected strategy. Also removes rows and columns with too many missing - values. Use the `missing` attribute to customize what are - considered "missing values". + values. See the [Imputer][] class for a description of the parameters. !!! tip - Use the [nans][self-nans] attribute to check the amount of - missing values per column. + - Use the [nans][self-nans] attribute to check the amount of + missing values per column. + - Use the [`missing`][self-missing] attribute to customize + what are considered "missing values". """ columns = kwargs.pop("columns", None) diff --git a/atom/basemodel.py b/atom/basemodel.py index 71dd22a63..0e246a025 100644 --- a/atom/basemodel.py +++ b/atom/basemodel.py @@ -309,7 +309,7 @@ def fullname(self) -> str: """Return the model's class name.""" return self.__class__.__name__ - @property + @cached_property def _est_class(self) -> type[Predictor]: """Return the estimator's class (not instance). @@ -698,8 +698,8 @@ def _get_pred( method=method_caller, ) - except ValueError as ex: - # Fails for models that don't allow in-sample predictions + except (ValueError, NotImplementedError) as ex: + # Can fail for models that don't allow in-sample predictions self._log( f"Failed to get predictions for model {self.name} " f"on rows {rows}. Returning NaN. Exception: {ex}.", 3 diff --git a/atom/basetransformer.py b/atom/basetransformer.py index bc63f9c64..befc53a3b 100644 --- a/atom/basetransformer.py +++ b/atom/basetransformer.py @@ -32,6 +32,7 @@ from joblib.memory import Memory from pandas._typing import Axes from ray.util.joblib import register_ray +from sklearn.base import OneToOneFeatureMixin from sklearn.utils.validation import check_memory from atom.utils.types import ( @@ -40,7 +41,9 @@ Pandas, Sequence, Severity, Verbose, Warnings, XSelector, YSelector, bool_t, dataframe_t, int_t, sequence_t, ) -from atom.utils.utils import crash, flt, lst, n_cols, to_df, to_pandas +from atom.utils.utils import ( + crash, flt, lst, n_cols, to_df, to_pandas, wrap_fit, +) T_Estimator = TypeVar("T_Estimator", bound=Estimator) @@ -373,6 +376,11 @@ def _inherit(self, obj: T_Estimator, fixed: tuple[str, ...] = ()) -> T_Estimator to that of this instance. If `obj` is a meta-estimator, it also adjusts the parameters of the base estimator. + Additionally, the `fit` method of non-sklearn objects is wrapped + to always add the `n_features_in_` and `feature_names_in_` + attributes, and the `get-feature_names_out` method is added to + transformers that don't have it already. + Parameters ---------- obj: Estimator @@ -398,6 +406,12 @@ def _inherit(self, obj: T_Estimator, fixed: tuple[str, ...] = ()) -> T_Estimator else: obj.set_params(**{p: lst(self._config.sp.sp)[0]}) + if hasattr(obj, "fit") and "sklearn" not in obj.__module__: + obj.__class__.fit = wrap_fit(obj.__class__.fit) # type: ignore[method-assign] + if hasattr(obj, "transform") and not hasattr(obj, "get_feature_names_out"): + # We assume here that the transformer does not create nor remove columns + obj.__class__.get_feature_names_out = OneToOneFeatureMixin.get_feature_names_out + return obj def _get_est_class(self, name: str, module: str) -> type[Estimator]: diff --git a/atom/data_cleaning.py b/atom/data_cleaning.py index be9d92eb2..a95f46055 100644 --- a/atom/data_cleaning.py +++ b/atom/data_cleaning.py @@ -46,6 +46,7 @@ from sktime.transformations.series.detrend import ( ConditionalDeseasonalizer, Deseasonalizer, Detrender, ) +from sktime.transformations.series.impute import Imputer as sktimeImputer from typing_extensions import Self from atom.basetransformer import BaseTransformer @@ -53,11 +54,11 @@ from atom.utils.patches import wrap_method_output from atom.utils.types import ( Bins, Bool, CategoricalStrats, DataFrame, DiscretizerStrats, Engine, - Estimator, FloatLargerZero, IntLargerEqualZero, IntLargerTwo, - IntLargerZero, NJobs, NormalizerStrats, NumericalStrats, Pandas, Predictor, - PrunerStrats, Scalar, ScalerStrats, SeasonalityModels, Sequence, Series, - Transformer, Verbose, XSelector, YSelector, dataframe_t, sequence_t, - series_t, + EngineTuple, Estimator, FloatLargerZero, Int, IntLargerEqualZero, + IntLargerTwo, IntLargerZero, NJobs, NormalizerStrats, NumericalStrats, + Pandas, Predictor, PrunerStrats, Scalar, ScalerStrats, SeasonalityModels, + Sequence, Series, Transformer, Verbose, XSelector, YSelector, dataframe_t, + sequence_t, series_t, ) from atom.utils.utils import ( Goal, bk, check_is_fitted, composed, crash, get_col_order, get_cols, it, @@ -92,6 +93,17 @@ def __init_subclass__(cls, **kwargs): with patch("sklearn.utils._set_output._wrap_method_output", wrap_method_output): super().__init_subclass__(**kwargs) + def __repr__(self, N_CHAR_MAX: Int = 700) -> str: + """Drop named tuples if default parameters from string representation.""" + out = super().__repr__(N_CHAR_MAX) + + # Remove default engine for cleaner representation + if hasattr(self, "engine") and self.engine == EngineTuple(): + out = re.sub(r"engine=EngineTuple\(data='numpy', estimator='sklearn'\)", "", out) + out = re.sub(r"((?<=\(),\s|,\s(?=\))|,\s(?=,\s))", "", out) # Drop comma-spaces + + return out + def __sklearn_clone__(self: T) -> T: """Wrap cloning method to attach internal attributes.""" cloned = _clone_parametrized(self) @@ -1521,10 +1533,6 @@ def get_labels(col: str, bins: Sequence[Scalar]) -> tuple[str, ...]: return labels - Xt, yt = self._check_input(X, y) - self._check_feature_names(Xt, reset=True) - self._check_n_features(Xt, reset=True) - self._estimators: dict[str, Estimator] = {} self._labels: dict[str, Sequence[str]] = {} @@ -1548,7 +1556,7 @@ def get_labels(col: str, bins: Sequence[Scalar]) -> tuple[str, ...]: raise ValueError( "Invalid value for the bins parameter. The length of the " "bins does not match the length of the columns, got len" - f"(bins)={len(bins_c)} and len(columns)={Xt.shape[1]}." + f"(bins)={len(bins_c)} and len(columns)={X.shape[1]}." ) from None else: bins_x = bins_c @@ -1566,7 +1574,7 @@ def get_labels(col: str, bins: Sequence[Scalar]) -> tuple[str, ...]: encode="ordinal", strategy=self.strategy, **kwargs, - ).fit(Xt[[col]]) + ).fit(X[[col]]) # Save labels for transform method self._labels[col] = get_labels( @@ -1592,7 +1600,7 @@ def get_labels(col: str, bins: Sequence[Scalar]) -> tuple[str, ...]: self._estimators[col] = FunctionTransformer( func=bk.cut, kw_args={"bins": bins_c, "labels": get_labels(col, bins_c)}, - ).fit(Xt[[col]]) + ).fit(X[[col]]) return self @@ -2021,7 +2029,7 @@ class Imputer(TransformerMixin, _SetOutputMixin): Impute or remove missing values according to the selected strategy. Also removes rows and columns with too many missing values. Use - the `missing` attribute to customize what are considered "missing + the `missing_` attribute to customize what are considered "missing values". This class can be accessed from atom through the [impute] @@ -2036,9 +2044,18 @@ class Imputer(TransformerMixin, _SetOutputMixin): - "drop": Drop rows containing missing values. - "mean": Impute with mean of column. - "median": Impute with median of column. + - "most_frequent": Impute with the most frequent value. - "knn": Impute using a K-Nearest Neighbors approach. - "iterative": Impute using a multivariate imputer. - - "most_frequent": Impute with the most frequent value. + - "drift": Impute values using a [PolynomialTrend][] model. + - "linear": Impute using linear interpolation. + - "nearest": Impute with nearest value. + - "bfill": Impute by using the next valid observation to fill + the gap. + - "ffill": Impute by propagating the last valid observation + to next valid. + - "random": Impute with random values between the min and max + of column. - int or float: Impute with provided numerical value. strat_cat: str, default="drop" @@ -2263,6 +2280,15 @@ def fit(self, X: DataFrame, y: Pandas | None = None) -> Self: num_imputer = IterativeImputer(random_state=self.random_state) elif self.strat_num == "drop": num_imputer = "passthrough" + else: + # Inherit sklearn's attributes and methods + num_imputer = self._inherit( + sktimeImputer( + method=self.strat_num, + missing_values=[pd.NA], + random_state=self.random_state, + ) + ) else: num_imputer = SimpleImputer( missing_values=pd.NA, @@ -2401,8 +2427,7 @@ def transform( if name not in self._estimator.feature_names_in_: self._log( f" --> Dropping feature {name}. Contains {nans} " - f"({nans * 100 // len(X)}%) missing values.", - 2, + f"({nans * 100 // len(X)}%) missing values.", 2, ) X = X.drop(columns=name) continue @@ -2411,34 +2436,34 @@ def transform( if not isinstance(self.strat_num, str): self._log( f" --> Imputing {nans} missing values with " - f"number '{self.strat_num}' in feature {name}.", - 2, + f"number '{self.strat_num}' in column {name}.", 2, ) elif self.strat_num in ("knn", "iterative"): self._log( f" --> Imputing {nans} missing values using " - f"the {self.strat_num} imputer in feature {name}.", - 2, + f"the {self.strat_num} imputer in column {name}.", 2, + ) + elif self.strat_num in ("mean", "median", "most_frequent"): + self._log( + f" --> Imputing {nans} missing values with {self.strat_num} " + f"({np.round(get_stat(num_imputer, name), 2)}) in column " + f"{name}.", 2, ) - elif self.strat_num != "drop": # mean, median or most_frequent + else: self._log( f" --> Imputing {nans} missing values with {self.strat_num} " - f"({np.round(get_stat(num_imputer, name), 2)}) in feature " - f"{name}.", - 2, + f"in column {name}.", 2, ) elif self.strat_cat != "drop" and name in cat_imputer.feature_names_in_: if self.strat_cat == "most_frequent": self._log( f" --> Imputing {nans} missing values with most_frequent " - f"({get_stat(cat_imputer, name)}) in feature {name}.", - 2, + f"({get_stat(cat_imputer, name)}) in column {name}.", 2, ) elif self.strat_cat != "drop": self._log( f" --> Imputing {nans} missing values with value " - f"'{self.strat_cat}' in feature {name}.", - 2, + f"'{self.strat_cat}' in column {name}.", 2, ) Xt = self._estimator.transform(X) @@ -2969,8 +2994,7 @@ def transform( cond = np.abs(z_scores) > self.max_sigma objective = objective.mask(cond, self.method) self._log( - f" --> Replacing {cond.sum()} outlier values with {self.method}.", - 2, + f" --> Replacing {cond.sum()} outlier values with {self.method}.", 2, ) elif self.method.lower() == "minmax": @@ -2992,8 +3016,7 @@ def transform( self._log( f" --> Replacing {counts} outlier values " - "with the min or max of the column.", - 2, + "with the min or max of the column.", 2, ) elif self.method.lower() == "drop": @@ -3002,8 +3025,7 @@ def transform( if len(lst(self.strategy)) > 1: self._log( f" --> The zscore strategy detected " - f"{len(mask) - sum(mask)} outliers.", - 2, + f"{len(mask) - sum(mask)} outliers.", 2, ) else: @@ -3013,8 +3035,7 @@ def transform( if len(lst(self.strategy)) > 1: self._log( f" --> The {estimator.__class__.__name__} " - f"detected {len(mask) - sum(mask)} outliers.", - 2, + f"detected {len(mask) - sum(mask)} outliers.", 2, ) # Add the estimator as attribute to the instance diff --git a/atom/feature_engineering.py b/atom/feature_engineering.py index 1921f97d8..fe67480f5 100644 --- a/atom/feature_engineering.py +++ b/atom/feature_engineering.py @@ -1540,8 +1540,7 @@ def transform(self, X: DataFrame, y: Pandas | None = None) -> DataFrame: self._log( f" --> Dropping feature {column} " f"(score: {self.univariate_.scores_[n]:.2f} " - f"p-value: {self.univariate_.pvalues_[n]:.2f}).", - 2, + f"p-value: {self.univariate_.pvalues_[n]:.2f}).", 2, ) X = X.drop(columns=column) @@ -1570,8 +1569,7 @@ def transform(self, X: DataFrame, y: Pandas | None = None) -> DataFrame: if hasattr(self._estimator, "ranking_"): self._log( f" --> Dropping feature {column} " - f"(rank {self._estimator.ranking_[n]}).", - 2, + f"(rank {self._estimator.ranking_[n]}).", 2, ) else: self._log(f" --> Dropping feature {column}.", 2) diff --git a/atom/plots/dataplot.py b/atom/plots/dataplot.py index 98453e2dc..df5e90c49 100644 --- a/atom/plots/dataplot.py +++ b/atom/plots/dataplot.py @@ -54,8 +54,9 @@ class DataPlot(BasePlot, metaclass=ABCMeta): def plot_acf( self, columns: ColumnSelector | None = None, - nlags: IntLargerZero | None = None, *, + nlags: IntLargerZero | None = None, + plot_interval: Bool = True, title: str | dict[str, Any] | None = None, legend: Legend | dict[str, Any] | None = "upper right", figsize: tuple[IntLargerZero, IntLargerZero] | None = None, @@ -83,6 +84,9 @@ def plot_acf( returned value includes lag 0 (i.e., 1), so the size of the vector is `(nlags + 1,)`. + plot_interval: bool, default=True + Whether to plot the 95% confidence interval. + title: str, dict or None, default=None Title for the plot. @@ -174,35 +178,35 @@ def plot_acf( yaxis=yaxis, ) - # Add error bands - fig.add_traces( - [ - go.Scatter( - x=x, - y=np.subtract(conf[:, 1], corr), - mode="lines", - line={"width": 1, "color": BasePlot._fig.get_elem(col)}, - hovertemplate="%{y}upper bound", - legendgroup=col, - showlegend=False, - xaxis=xaxis, - yaxis=yaxis, - ), - go.Scatter( - x=x, - y=np.subtract(conf[:, 0], corr), - mode="lines", - line={"width": 1, "color": BasePlot._fig.get_elem(col)}, - fill="tonexty", - fillcolor=f"rgba({BasePlot._fig.get_elem(col)[4:-1]}, 0.2)", - hovertemplate="%{y}lower bound", - legendgroup=col, - showlegend=False, - xaxis=xaxis, - yaxis=yaxis, - ), - ] - ) + if plot_interval: + fig.add_traces( + [ + go.Scatter( + x=x, + y=conf[:, 1] - corr, + mode="lines", + line={"width": 1, "color": BasePlot._fig.get_elem(col)}, + hovertemplate="%{y}upper bound", + legendgroup=col, + showlegend=False, + xaxis=xaxis, + yaxis=yaxis, + ), + go.Scatter( + x=x, + y=conf[:, 0] - corr, + mode="lines", + line={"width": 1, "color": BasePlot._fig.get_elem(col)}, + fill="tonexty", + fillcolor=f"rgba({BasePlot._fig.get_elem(col)[4:-1]}, 0.2)", + hovertemplate="%{y}lower bound", + legendgroup=col, + showlegend=False, + xaxis=xaxis, + yaxis=yaxis, + ), + ] + ) fig.update_yaxes(zerolinecolor="black") fig.update_layout({"hovermode": "x unified"}) @@ -227,8 +231,9 @@ def plot_ccf( self, columns: ColumnSelector = 0, target: TargetSelector = 0, - nlags: IntLargerZero | None = None, *, + nlags: IntLargerZero | None = None, + plot_interval: Bool = False, title: str | dict[str, Any] | None = None, legend: Legend | dict[str, Any] | None = "upper right", figsize: tuple[IntLargerZero, IntLargerZero] = (900, 600), @@ -261,6 +266,9 @@ def plot_ccf( returned value includes lag 0 (i.e., 1), so the size of the vector is `(nlags + 1,)`. + plot_interval: bool, default=False + Whether to plot the 95% confidence interval. + title: str, dict or None, default=None Title for the plot. @@ -360,35 +368,35 @@ def plot_ccf( yaxis=yaxis, ) - # Add error bands - fig.add_traces( - [ - go.Scatter( - x=x, - y=conf[:, 1], - mode="lines", - line={"width": 1, "color": BasePlot._fig.get_elem(col)}, - hovertemplate="%{y}upper bound", - legendgroup=col, - showlegend=False, - xaxis=xaxis, - yaxis=yaxis, - ), - go.Scatter( - x=x, - y=conf[:, 0], - mode="lines", - line={"width": 1, "color": BasePlot._fig.get_elem(col)}, - fill="tonexty", - fillcolor=f"rgba({BasePlot._fig.get_elem(col)[4:-1]}, 0.2)", - hovertemplate="%{y}lower bound", - legendgroup=col, - showlegend=False, - xaxis=xaxis, - yaxis=yaxis, - ), - ] - ) + if plot_interval: + fig.add_traces( + [ + go.Scatter( + x=x, + y=conf[:, 1] - corr, + mode="lines", + line={"width": 1, "color": BasePlot._fig.get_elem(col)}, + hovertemplate="%{y}upper bound", + legendgroup=col, + showlegend=False, + xaxis=xaxis, + yaxis=yaxis, + ), + go.Scatter( + x=x, + y=conf[:, 0] - corr, + mode="lines", + line={"width": 1, "color": BasePlot._fig.get_elem(col)}, + fill="tonexty", + fillcolor=f"rgba({BasePlot._fig.get_elem(col)[4:-1]}, 0.2)", + hovertemplate="%{y}lower bound", + legendgroup=col, + showlegend=False, + xaxis=xaxis, + yaxis=yaxis, + ), + ] + ) fig.update_yaxes(zerolinecolor="black") fig.update_layout({"hovermode": "x unified"}) @@ -1292,9 +1300,10 @@ def get_text(column: Series) -> Series: def plot_pacf( self, columns: ColumnSelector | None = None, + *, nlags: IntLargerZero | None = None, method: PACFMethods = "ywadjusted", - *, + plot_interval: Bool = True, title: str | dict[str, Any] | None = None, legend: Legend | dict[str, Any] | None = "upper right", figsize: tuple[IntLargerZero, IntLargerZero] | None = None, @@ -1306,9 +1315,8 @@ def plot_pacf( The partial autocorrelation function (PACF) measures the correlation between a time series and lagged versions of itself. It's useful, for example, to identify the order of - an autoregressive model. The transparent band represents - the 95% confidence interval.This plot is only available - for [forecast][time-series] tasks. + an autoregressive model. This plot is only available for + [forecast][time-series] tasks. Parameters ---------- @@ -1327,8 +1335,8 @@ def plot_pacf( - "yw" or "ywadjusted": Yule-Walker with sample-size adjustment in denominator for acovf. - - "ywm" or "ywmle": Yule-Walker without adjustment. - - "ols" : Regression of time series on lags of it and on + - "ywm" or "ywmle": Yule-Walker without an adjustment. + - "ols": Regression of time series on lags of it and on constant. - "ols-inefficient": Regression of time series on lags using a single common sample to estimate all pacf coefficients. @@ -1338,7 +1346,10 @@ def plot_pacf( correction. - "ldb" or "ldbiased": Levinson-Durbin recursion without bias correction. - - "burg": Burg"s partial autocorrelation estimator. + - "burg": Burg"s partial autocorrelation estimator. + + plot_interval: bool, default=True + Whether to plot the 95% confidence interval. title: str, dict or None, default=None Title for the plot. @@ -1431,35 +1442,35 @@ def plot_pacf( yaxis=yaxis, ) - # Add error bands - fig.add_traces( - [ - go.Scatter( - x=x, - y=np.subtract(conf[:, 1], corr), - mode="lines", - line={"width": 1, "color": BasePlot._fig.get_elem(col)}, - hovertemplate="%{y}upper bound", - legendgroup=col, - showlegend=False, - xaxis=xaxis, - yaxis=yaxis, - ), - go.Scatter( - x=x, - y=np.subtract(conf[:, 0], corr), - mode="lines", - line={"width": 1, "color": BasePlot._fig.get_elem(col)}, - fill="tonexty", - fillcolor=f"rgba({BasePlot._fig.get_elem(col)[4:-1]}, 0.2)", - hovertemplate="%{y}lower bound", - legendgroup=col, - showlegend=False, - xaxis=xaxis, - yaxis=yaxis, - ), - ] - ) + if plot_interval: + fig.add_traces( + [ + go.Scatter( + x=x, + y=conf[:, 1] - corr, + mode="lines", + line={"width": 1, "color": BasePlot._fig.get_elem(col)}, + hovertemplate="%{y}upper bound", + legendgroup=col, + showlegend=False, + xaxis=xaxis, + yaxis=yaxis, + ), + go.Scatter( + x=x, + y=conf[:, 0] - corr, + mode="lines", + line={"width": 1, "color": BasePlot._fig.get_elem(col)}, + fill="tonexty", + fillcolor=f"rgba({BasePlot._fig.get_elem(col)[4:-1]}, 0.2)", + hovertemplate="%{y}lower bound", + legendgroup=col, + showlegend=False, + xaxis=xaxis, + yaxis=yaxis, + ), + ] + ) fig.update_yaxes(zerolinecolor="black") fig.update_layout({"hovermode": "x unified"}) @@ -2005,6 +2016,7 @@ def plot_relationships( def plot_rfecv( self, *, + plot_interval: Bool = True, title: str | dict[str, Any] | None = None, legend: Legend | dict[str, Any] | None = "upper right", figsize: tuple[IntLargerZero, IntLargerZero] = (900, 600), @@ -2019,6 +2031,9 @@ def plot_rfecv( Parameters ---------- + plot_interval: bool, default=True + Whether to plot the 1-sigma confidence interval. + title: str, dict or None, default=None Title for the plot. @@ -2112,35 +2127,35 @@ def plot_rfecv( yaxis=yaxis, ) - # Add error bands - fig.add_traces( - [ - go.Scatter( - x=x, - y=np.add(mean, std), - mode="lines", - line={"width": 1, "color": BasePlot._fig.get_elem(ylabel)}, - hovertemplate="%{y}upper bound", - legendgroup=ylabel, - showlegend=False, - xaxis=xaxis, - yaxis=yaxis, - ), - go.Scatter( - x=x, - y=np.subtract(mean, std), - mode="lines", - line={"width": 1, "color": BasePlot._fig.get_elem(ylabel)}, - fill="tonexty", - fillcolor=f"rgba{BasePlot._fig.get_elem(ylabel)[3:-1]}, 0.2)", - hovertemplate="%{y}lower bound", - legendgroup=ylabel, - showlegend=False, - xaxis=xaxis, - yaxis=yaxis, - ), - ] - ) + if plot_interval: + fig.add_traces( + [ + go.Scatter( + x=x, + y=mean + std, + mode="lines", + line={"width": 1, "color": BasePlot._fig.get_elem(ylabel)}, + hovertemplate="%{y}upper bound", + legendgroup=ylabel, + showlegend=False, + xaxis=xaxis, + yaxis=yaxis, + ), + go.Scatter( + x=x, + y=mean - std, + mode="lines", + line={"width": 1, "color": BasePlot._fig.get_elem(ylabel)}, + fill="tonexty", + fillcolor=f"rgba{BasePlot._fig.get_elem(ylabel)[3:-1]}, 0.2)", + hovertemplate="%{y}lower bound", + legendgroup=ylabel, + showlegend=False, + xaxis=xaxis, + yaxis=yaxis, + ), + ] + ) fig.update_layout({"hovermode": "x unified"}) diff --git a/atom/utils/types.py b/atom/utils/types.py index d3aaa7f31..43bc2ff6d 100644 --- a/atom/utils/types.py +++ b/atom/utils/types.py @@ -253,7 +253,20 @@ def predict(self, *args, **kwargs) -> Pandas: ... Verbose: TypeAlias = Literal[0, 1, 2] # Data cleaning parameters -NumericalStrats: TypeAlias = Literal["drop", "mean", "median", "knn", "iterative", "most_frequent"] +NumericalStrats: TypeAlias = Literal[ + "drop", + "mean", + "median", + "knn", + "iterative", + "most_frequent", + "drift", + "linear", + "nearest", + "bfill", + "ffill", + "random", +] CategoricalStrats: TypeAlias = Literal["drop", "most_frequent"] DiscretizerStrats: TypeAlias = Literal["uniform", "quantile", "kmeans", "custom"] Bins: TypeAlias = IntLargerOne | Sequence[Scalar] | dict[str, IntLargerOne | Sequence[Scalar]] diff --git a/atom/utils/utils.py b/atom/utils/utils.py index ac2ed7689..e794b55e1 100644 --- a/atom/utils/utils.py +++ b/atom/utils/utils.py @@ -43,6 +43,7 @@ from pandas._typing import Axes, Dtype, DtypeArg from pandas.api.types import is_numeric_dtype from shap import Explainer, Explanation +from sklearn.base import BaseEstimator from sklearn.metrics import ( confusion_matrix, get_scorer, get_scorer_names, make_scorer, matthews_corrcoef, @@ -2524,6 +2525,8 @@ def prepare_df(out: TReturn, og: DataFrame) -> DataFrame: name=flt(getattr(transformer, "target_names_in_", None)), ) + use_y = True + kwargs: dict[str, Any] = {} inc = list(getattr(transformer, "_cols", getattr(Xt, "columns", []))) if "X" in (params := sign(getattr(transformer, method))): @@ -2534,13 +2537,15 @@ def prepare_df(out: TReturn, og: DataFrame) -> DataFrame: if len(kwargs) == 0: if yt is not None and hasattr(transformer, "_cols"): kwargs["X"] = to_df(yt)[inc] + use_y = False elif params["X"].default != Parameter.empty: kwargs["X"] = params["X"].default # Fill X with default else: return Xt, yt # If X is needed, skip the transformer if "y" in params: - if yt is not None: + # We skip `y` when already added to `X` + if yt is not None and use_y: kwargs["y"] = yt elif "X" not in params: return Xt, yt # If y is None and no X in transformer, skip the transformer @@ -2763,6 +2768,29 @@ def wrapper(*args, **kwargs) -> Any: return wrapper +def wrap_fit(f: Callable) -> Callable: + """Wrap the fit method of estimators to add custom attributes. + + Wrapper to add the `feature_names_in_` and `n_features_in_` + attributes to an arbitrary estimator during `fit`. Used for + classes that don't have these attributes (e.g., from cuml or + sktime). + + """ + + @wraps(f) + def wrapped(self, X, *args, **kwargs): + out = f(self, X, *args, **kwargs) + + # We add the attributes after running the function + # to avoid deleting them with .reset() calls + BaseEstimator._check_feature_names(self, X, reset=True) + BaseEstimator._check_n_features(self, X, reset=True) + return out + + return wrapped + + def wrap_transformer_methods(f: Callable) -> Callable: """Wrap transformer methods with shared code. diff --git a/docs_sources/changelog/v6.x.x.md b/docs_sources/changelog/v6.x.x.md index a0b32c8a7..2b475d0cf 100644 --- a/docs_sources/changelog/v6.x.x.md +++ b/docs_sources/changelog/v6.x.x.md @@ -32,6 +32,8 @@ **:rocket: Enhancements** * Transformations only on `y` are now accepted, e.g., `atom.scale(columns=-1)`. +* The [Imputer][] class has many more strategies for numerical columns designed + for time series. * Full support for [pandas nullable dtypes](https://pandas.pydata.org/docs/user_guide/integer_na.html). * The dataset can now be provided as callable. * The [FeatureExtractor][] class can extract features from the dataset's index. diff --git a/docs_sources/examples/accelerating_sklearnex.ipynb b/docs_sources/examples/accelerating_sklearnex.ipynb index 57c70bf29..b1a5abf4c 100644 --- a/docs_sources/examples/accelerating_sklearnex.ipynb +++ b/docs_sources/examples/accelerating_sklearnex.ipynb @@ -300,26 +300,26 @@ "text": [ "Fitting Imputer...\n", "Imputing missing values...\n", - " --> Dropping 637 samples due to missing values in feature MinTemp.\n", - " --> Dropping 322 samples due to missing values in feature MaxTemp.\n", - " --> Dropping 1406 samples due to missing values in feature Rainfall.\n", - " --> Dropping 60843 samples due to missing values in feature Evaporation.\n", - " --> Dropping 67816 samples due to missing values in feature Sunshine.\n", - " --> Dropping 9330 samples due to missing values in feature WindGustDir.\n", - " --> Dropping 9270 samples due to missing values in feature WindGustSpeed.\n", - " --> Dropping 10013 samples due to missing values in feature WindDir9am.\n", - " --> Dropping 3778 samples due to missing values in feature WindDir3pm.\n", - " --> Dropping 1348 samples due to missing values in feature WindSpeed9am.\n", - " --> Dropping 2630 samples due to missing values in feature WindSpeed3pm.\n", - " --> Dropping 1774 samples due to missing values in feature Humidity9am.\n", - " --> Dropping 3610 samples due to missing values in feature Humidity3pm.\n", - " --> Dropping 14014 samples due to missing values in feature Pressure9am.\n", - " --> Dropping 13981 samples due to missing values in feature Pressure3pm.\n", - " --> Dropping 53657 samples due to missing values in feature Cloud9am.\n", - " --> Dropping 57094 samples due to missing values in feature Cloud3pm.\n", - " --> Dropping 904 samples due to missing values in feature Temp9am.\n", - " --> Dropping 2726 samples due to missing values in feature Temp3pm.\n", - " --> Dropping 1406 samples due to missing values in feature RainToday.\n", + " --> Dropping 637 samples due to missing values in column MinTemp.\n", + " --> Dropping 322 samples due to missing values in column MaxTemp.\n", + " --> Dropping 1406 samples due to missing values in column Rainfall.\n", + " --> Dropping 60843 samples due to missing values in column Evaporation.\n", + " --> Dropping 67816 samples due to missing values in column Sunshine.\n", + " --> Dropping 9330 samples due to missing values in column WindGustDir.\n", + " --> Dropping 9270 samples due to missing values in column WindGustSpeed.\n", + " --> Dropping 10013 samples due to missing values in column WindDir9am.\n", + " --> Dropping 3778 samples due to missing values in column WindDir3pm.\n", + " --> Dropping 1348 samples due to missing values in column WindSpeed9am.\n", + " --> Dropping 2630 samples due to missing values in column WindSpeed3pm.\n", + " --> Dropping 1774 samples due to missing values in column Humidity9am.\n", + " --> Dropping 3610 samples due to missing values in column Humidity3pm.\n", + " --> Dropping 14014 samples due to missing values in column Pressure9am.\n", + " --> Dropping 13981 samples due to missing values in column Pressure3pm.\n", + " --> Dropping 53657 samples due to missing values in column Cloud9am.\n", + " --> Dropping 57094 samples due to missing values in column Cloud3pm.\n", + " --> Dropping 904 samples due to missing values in column Temp9am.\n", + " --> Dropping 2726 samples due to missing values in column Temp3pm.\n", + " --> Dropping 1406 samples due to missing values in column RainToday.\n", "Fitting Encoder...\n", "Encoding categorical columns...\n", " --> Target-encoding feature Location. Contains 26 classes.\n", diff --git a/docs_sources/examples/binary_classification.ipynb b/docs_sources/examples/binary_classification.ipynb index 61e92c70f..6b260e7cc 100644 --- a/docs_sources/examples/binary_classification.ipynb +++ b/docs_sources/examples/binary_classification.ipynb @@ -306,26 +306,26 @@ "Fitting Imputer...\n", "Imputing missing values...\n", " --> Dropping 7 samples for containing more than 16 missing values.\n", - " --> Imputing 23 missing values with median (11.9) in feature MinTemp.\n", - " --> Imputing 10 missing values with median (22.6) in feature MaxTemp.\n", - " --> Imputing 72 missing values with median (0.0) in feature Rainfall.\n", - " --> Imputing 3059 missing values with median (4.6) in feature Evaporation.\n", - " --> Imputing 3382 missing values with median (8.5) in feature Sunshine.\n", - " --> Dropping 467 samples due to missing values in feature WindGustDir.\n", - " --> Imputing 466 missing values with median (39.0) in feature WindGustSpeed.\n", - " --> Dropping 479 samples due to missing values in feature WindDir9am.\n", - " --> Dropping 165 samples due to missing values in feature WindDir3pm.\n", - " --> Imputing 53 missing values with median (13.0) in feature WindSpeed9am.\n", - " --> Imputing 115 missing values with median (17.0) in feature WindSpeed3pm.\n", - " --> Imputing 72 missing values with median (70.0) in feature Humidity9am.\n", - " --> Imputing 164 missing values with median (52.0) in feature Humidity3pm.\n", - " --> Imputing 699 missing values with median (1017.7) in feature Pressure9am.\n", - " --> Imputing 699 missing values with median (1015.4) in feature Pressure3pm.\n", - " --> Imputing 2698 missing values with median (5.0) in feature Cloud9am.\n", - " --> Imputing 2903 missing values with median (5.0) in feature Cloud3pm.\n", - " --> Imputing 32 missing values with median (16.7) in feature Temp9am.\n", - " --> Imputing 116 missing values with median (21.1) in feature Temp3pm.\n", - " --> Dropping 72 samples due to missing values in feature RainToday.\n" + " --> Imputing 23 missing values with median (11.9) in column MinTemp.\n", + " --> Imputing 10 missing values with median (22.6) in column MaxTemp.\n", + " --> Imputing 72 missing values with median (0.0) in column Rainfall.\n", + " --> Imputing 3059 missing values with median (4.6) in column Evaporation.\n", + " --> Imputing 3382 missing values with median (8.5) in column Sunshine.\n", + " --> Dropping 467 samples due to missing values in column WindGustDir.\n", + " --> Imputing 466 missing values with median (39.0) in column WindGustSpeed.\n", + " --> Dropping 479 samples due to missing values in column WindDir9am.\n", + " --> Dropping 165 samples due to missing values in column WindDir3pm.\n", + " --> Imputing 53 missing values with median (13.0) in column WindSpeed9am.\n", + " --> Imputing 115 missing values with median (17.0) in column WindSpeed3pm.\n", + " --> Imputing 72 missing values with median (70.0) in column Humidity9am.\n", + " --> Imputing 164 missing values with median (52.0) in column Humidity3pm.\n", + " --> Imputing 699 missing values with median (1017.7) in column Pressure9am.\n", + " --> Imputing 699 missing values with median (1015.4) in column Pressure3pm.\n", + " --> Imputing 2698 missing values with median (5.0) in column Cloud9am.\n", + " --> Imputing 2903 missing values with median (5.0) in column Cloud3pm.\n", + " --> Imputing 32 missing values with median (16.7) in column Temp9am.\n", + " --> Imputing 116 missing values with median (21.1) in column Temp3pm.\n", + " --> Dropping 72 samples due to missing values in column RainToday.\n" ] } ], diff --git a/docs_sources/examples/feature_engineering.ipynb b/docs_sources/examples/feature_engineering.ipynb index ff309fd2d..5bd9e658d 100644 --- a/docs_sources/examples/feature_engineering.ipynb +++ b/docs_sources/examples/feature_engineering.ipynb @@ -397,7 +397,7 @@ "text": [ "Fitting Imputer...\n", "Imputing missing values...\n", - " --> Imputing 12 missing values using the KNN imputer in feature NATURAL_LOGARITHM(Temp3pm).\n" + " --> Imputing 12 missing values using the KNN imputer in column NATURAL_LOGARITHM(Temp3pm).\n" ] } ], diff --git a/docs_sources/examples/getting_started.ipynb b/docs_sources/examples/getting_started.ipynb index 0ffe3fe8b..ee2ac20a5 100644 --- a/docs_sources/examples/getting_started.ipynb +++ b/docs_sources/examples/getting_started.ipynb @@ -71,26 +71,26 @@ "text": [ "Fitting Imputer...\n", "Imputing missing values...\n", - " --> Imputing 8 missing values with median (11.6) in feature MinTemp.\n", - " --> Imputing 2 missing values with median (22.3) in feature MaxTemp.\n", - " --> Imputing 12 missing values with median (0.0) in feature Rainfall.\n", - " --> Imputing 425 missing values with median (4.8) in feature Evaporation.\n", - " --> Imputing 480 missing values with median (8.55) in feature Sunshine.\n", - " --> Imputing 59 missing values with most_frequent (N) in feature WindGustDir.\n", - " --> Imputing 59 missing values with median (37.0) in feature WindGustSpeed.\n", - " --> Imputing 90 missing values with most_frequent (N) in feature WindDir9am.\n", - " --> Imputing 28 missing values with most_frequent (SW) in feature WindDir3pm.\n", - " --> Imputing 10 missing values with median (13.0) in feature WindSpeed9am.\n", - " --> Imputing 19 missing values with median (17.0) in feature WindSpeed3pm.\n", - " --> Imputing 17 missing values with median (70.0) in feature Humidity9am.\n", - " --> Imputing 31 missing values with median (51.0) in feature Humidity3pm.\n", - " --> Imputing 89 missing values with median (1017.8) in feature Pressure9am.\n", - " --> Imputing 87 missing values with median (1015.2) in feature Pressure3pm.\n", - " --> Imputing 383 missing values with median (5.0) in feature Cloud9am.\n", - " --> Imputing 412 missing values with median (5.0) in feature Cloud3pm.\n", - " --> Imputing 11 missing values with median (16.5) in feature Temp9am.\n", - " --> Imputing 26 missing values with median (20.7) in feature Temp3pm.\n", - " --> Imputing 12 missing values with most_frequent (No) in feature RainToday.\n", + " --> Imputing 8 missing values with median (11.6) in column MinTemp.\n", + " --> Imputing 2 missing values with median (22.3) in column MaxTemp.\n", + " --> Imputing 12 missing values with median (0.0) in column Rainfall.\n", + " --> Imputing 425 missing values with median (4.8) in column Evaporation.\n", + " --> Imputing 480 missing values with median (8.55) in column Sunshine.\n", + " --> Imputing 59 missing values with most_frequent (N) in column WindGustDir.\n", + " --> Imputing 59 missing values with median (37.0) in column WindGustSpeed.\n", + " --> Imputing 90 missing values with most_frequent (N) in column WindDir9am.\n", + " --> Imputing 28 missing values with most_frequent (SW) in column WindDir3pm.\n", + " --> Imputing 10 missing values with median (13.0) in column WindSpeed9am.\n", + " --> Imputing 19 missing values with median (17.0) in column WindSpeed3pm.\n", + " --> Imputing 17 missing values with median (70.0) in column Humidity9am.\n", + " --> Imputing 31 missing values with median (51.0) in column Humidity3pm.\n", + " --> Imputing 89 missing values with median (1017.8) in column Pressure9am.\n", + " --> Imputing 87 missing values with median (1015.2) in column Pressure3pm.\n", + " --> Imputing 383 missing values with median (5.0) in column Cloud9am.\n", + " --> Imputing 412 missing values with median (5.0) in column Cloud3pm.\n", + " --> Imputing 11 missing values with median (16.5) in column Temp9am.\n", + " --> Imputing 26 missing values with median (20.7) in column Temp3pm.\n", + " --> Imputing 12 missing values with most_frequent (No) in column RainToday.\n", "Fitting Encoder...\n", "Encoding categorical columns...\n", " --> Target-encoding feature Location. Contains 49 classes.\n", diff --git a/docs_sources/examples/holdout_set.ipynb b/docs_sources/examples/holdout_set.ipynb index ffe21b751..68065a975 100644 --- a/docs_sources/examples/holdout_set.ipynb +++ b/docs_sources/examples/holdout_set.ipynb @@ -323,26 +323,26 @@ "text": [ "Fitting Imputer...\n", "Imputing missing values...\n", - " --> Dropping 258 samples due to missing values in feature MinTemp.\n", - " --> Dropping 127 samples due to missing values in feature MaxTemp.\n", - " --> Dropping 553 samples due to missing values in feature Rainfall.\n", - " --> Dropping 24308 samples due to missing values in feature Evaporation.\n", - " --> Dropping 27187 samples due to missing values in feature Sunshine.\n", - " --> Dropping 3739 samples due to missing values in feature WindGustDir.\n", - " --> Dropping 3712 samples due to missing values in feature WindGustSpeed.\n", - " --> Dropping 3995 samples due to missing values in feature WindDir9am.\n", - " --> Dropping 1508 samples due to missing values in feature WindDir3pm.\n", - " --> Dropping 539 samples due to missing values in feature WindSpeed9am.\n", - " --> Dropping 1077 samples due to missing values in feature WindSpeed3pm.\n", - " --> Dropping 706 samples due to missing values in feature Humidity9am.\n", - " --> Dropping 1447 samples due to missing values in feature Humidity3pm.\n", - " --> Dropping 5610 samples due to missing values in feature Pressure9am.\n", - " --> Dropping 5591 samples due to missing values in feature Pressure3pm.\n", - " --> Dropping 21520 samples due to missing values in feature Cloud9am.\n", - " --> Dropping 22921 samples due to missing values in feature Cloud3pm.\n", - " --> Dropping 365 samples due to missing values in feature Temp9am.\n", - " --> Dropping 1106 samples due to missing values in feature Temp3pm.\n", - " --> Dropping 553 samples due to missing values in feature RainToday.\n", + " --> Dropping 258 samples due to missing values in column MinTemp.\n", + " --> Dropping 127 samples due to missing values in column MaxTemp.\n", + " --> Dropping 553 samples due to missing values in column Rainfall.\n", + " --> Dropping 24308 samples due to missing values in column Evaporation.\n", + " --> Dropping 27187 samples due to missing values in column Sunshine.\n", + " --> Dropping 3739 samples due to missing values in column WindGustDir.\n", + " --> Dropping 3712 samples due to missing values in column WindGustSpeed.\n", + " --> Dropping 3995 samples due to missing values in column WindDir9am.\n", + " --> Dropping 1508 samples due to missing values in column WindDir3pm.\n", + " --> Dropping 539 samples due to missing values in column WindSpeed9am.\n", + " --> Dropping 1077 samples due to missing values in column WindSpeed3pm.\n", + " --> Dropping 706 samples due to missing values in column Humidity9am.\n", + " --> Dropping 1447 samples due to missing values in column Humidity3pm.\n", + " --> Dropping 5610 samples due to missing values in column Pressure9am.\n", + " --> Dropping 5591 samples due to missing values in column Pressure3pm.\n", + " --> Dropping 21520 samples due to missing values in column Cloud9am.\n", + " --> Dropping 22921 samples due to missing values in column Cloud3pm.\n", + " --> Dropping 365 samples due to missing values in column Temp9am.\n", + " --> Dropping 1106 samples due to missing values in column Temp3pm.\n", + " --> Dropping 553 samples due to missing values in column RainToday.\n", "Fitting Encoder...\n", "Encoding categorical columns...\n", " --> Target-encoding feature Location. Contains 26 classes.\n", diff --git a/docs_sources/examples/train_sizing.ipynb b/docs_sources/examples/train_sizing.ipynb index ec75bff9d..1d106f059 100644 --- a/docs_sources/examples/train_sizing.ipynb +++ b/docs_sources/examples/train_sizing.ipynb @@ -287,26 +287,26 @@ "Fitting Imputer...\n", "Imputing missing values...\n", " --> Dropping 161 samples for containing more than 16 missing values.\n", - " --> Imputing 481 missing values with median (12.0) in feature MinTemp.\n", - " --> Imputing 265 missing values with median (22.6) in feature MaxTemp.\n", - " --> Imputing 1354 missing values with median (0.0) in feature Rainfall.\n", - " --> Imputing 60682 missing values with median (4.8) in feature Evaporation.\n", - " --> Imputing 67659 missing values with median (8.4) in feature Sunshine.\n", - " --> Imputing 9187 missing values with most_frequent (W) in feature WindGustDir.\n", - " --> Imputing 9127 missing values with median (39.0) in feature WindGustSpeed.\n", - " --> Imputing 9852 missing values with most_frequent (N) in feature WindDir9am.\n", - " --> Imputing 3617 missing values with most_frequent (SE) in feature WindDir3pm.\n", - " --> Imputing 1187 missing values with median (13.0) in feature WindSpeed9am.\n", - " --> Imputing 2469 missing values with median (19.0) in feature WindSpeed3pm.\n", - " --> Imputing 1613 missing values with median (70.0) in feature Humidity9am.\n", - " --> Imputing 3449 missing values with median (52.0) in feature Humidity3pm.\n", - " --> Imputing 13863 missing values with median (1017.6) in feature Pressure9am.\n", - " --> Imputing 13830 missing values with median (1015.2) in feature Pressure3pm.\n", - " --> Imputing 53496 missing values with median (5.0) in feature Cloud9am.\n", - " --> Imputing 56933 missing values with median (5.0) in feature Cloud3pm.\n", - " --> Imputing 743 missing values with median (16.7) in feature Temp9am.\n", - " --> Imputing 2565 missing values with median (21.1) in feature Temp3pm.\n", - " --> Imputing 1354 missing values with most_frequent (No) in feature RainToday.\n", + " --> Imputing 481 missing values with median (12.0) in column MinTemp.\n", + " --> Imputing 265 missing values with median (22.6) in column MaxTemp.\n", + " --> Imputing 1354 missing values with median (0.0) in column Rainfall.\n", + " --> Imputing 60682 missing values with median (4.8) in column Evaporation.\n", + " --> Imputing 67659 missing values with median (8.4) in column Sunshine.\n", + " --> Imputing 9187 missing values with most_frequent (W) in column WindGustDir.\n", + " --> Imputing 9127 missing values with median (39.0) in column WindGustSpeed.\n", + " --> Imputing 9852 missing values with most_frequent (N) in column WindDir9am.\n", + " --> Imputing 3617 missing values with most_frequent (SE) in column WindDir3pm.\n", + " --> Imputing 1187 missing values with median (13.0) in column WindSpeed9am.\n", + " --> Imputing 2469 missing values with median (19.0) in column WindSpeed3pm.\n", + " --> Imputing 1613 missing values with median (70.0) in column Humidity9am.\n", + " --> Imputing 3449 missing values with median (52.0) in column Humidity3pm.\n", + " --> Imputing 13863 missing values with median (1017.6) in column Pressure9am.\n", + " --> Imputing 13830 missing values with median (1015.2) in column Pressure3pm.\n", + " --> Imputing 53496 missing values with median (5.0) in column Cloud9am.\n", + " --> Imputing 56933 missing values with median (5.0) in column Cloud3pm.\n", + " --> Imputing 743 missing values with median (16.7) in column Temp9am.\n", + " --> Imputing 2565 missing values with median (21.1) in column Temp3pm.\n", + " --> Imputing 1354 missing values with most_frequent (No) in column RainToday.\n", "Fitting Encoder...\n", "Encoding categorical columns...\n", " --> Target-encoding feature Location. Contains 49 classes.\n", diff --git a/examples/accelerating_sklearnex.ipynb b/examples/accelerating_sklearnex.ipynb index 57c70bf29..b1a5abf4c 100644 --- a/examples/accelerating_sklearnex.ipynb +++ b/examples/accelerating_sklearnex.ipynb @@ -300,26 +300,26 @@ "text": [ "Fitting Imputer...\n", "Imputing missing values...\n", - " --> Dropping 637 samples due to missing values in feature MinTemp.\n", - " --> Dropping 322 samples due to missing values in feature MaxTemp.\n", - " --> Dropping 1406 samples due to missing values in feature Rainfall.\n", - " --> Dropping 60843 samples due to missing values in feature Evaporation.\n", - " --> Dropping 67816 samples due to missing values in feature Sunshine.\n", - " --> Dropping 9330 samples due to missing values in feature WindGustDir.\n", - " --> Dropping 9270 samples due to missing values in feature WindGustSpeed.\n", - " --> Dropping 10013 samples due to missing values in feature WindDir9am.\n", - " --> Dropping 3778 samples due to missing values in feature WindDir3pm.\n", - " --> Dropping 1348 samples due to missing values in feature WindSpeed9am.\n", - " --> Dropping 2630 samples due to missing values in feature WindSpeed3pm.\n", - " --> Dropping 1774 samples due to missing values in feature Humidity9am.\n", - " --> Dropping 3610 samples due to missing values in feature Humidity3pm.\n", - " --> Dropping 14014 samples due to missing values in feature Pressure9am.\n", - " --> Dropping 13981 samples due to missing values in feature Pressure3pm.\n", - " --> Dropping 53657 samples due to missing values in feature Cloud9am.\n", - " --> Dropping 57094 samples due to missing values in feature Cloud3pm.\n", - " --> Dropping 904 samples due to missing values in feature Temp9am.\n", - " --> Dropping 2726 samples due to missing values in feature Temp3pm.\n", - " --> Dropping 1406 samples due to missing values in feature RainToday.\n", + " --> Dropping 637 samples due to missing values in column MinTemp.\n", + " --> Dropping 322 samples due to missing values in column MaxTemp.\n", + " --> Dropping 1406 samples due to missing values in column Rainfall.\n", + " --> Dropping 60843 samples due to missing values in column Evaporation.\n", + " --> Dropping 67816 samples due to missing values in column Sunshine.\n", + " --> Dropping 9330 samples due to missing values in column WindGustDir.\n", + " --> Dropping 9270 samples due to missing values in column WindGustSpeed.\n", + " --> Dropping 10013 samples due to missing values in column WindDir9am.\n", + " --> Dropping 3778 samples due to missing values in column WindDir3pm.\n", + " --> Dropping 1348 samples due to missing values in column WindSpeed9am.\n", + " --> Dropping 2630 samples due to missing values in column WindSpeed3pm.\n", + " --> Dropping 1774 samples due to missing values in column Humidity9am.\n", + " --> Dropping 3610 samples due to missing values in column Humidity3pm.\n", + " --> Dropping 14014 samples due to missing values in column Pressure9am.\n", + " --> Dropping 13981 samples due to missing values in column Pressure3pm.\n", + " --> Dropping 53657 samples due to missing values in column Cloud9am.\n", + " --> Dropping 57094 samples due to missing values in column Cloud3pm.\n", + " --> Dropping 904 samples due to missing values in column Temp9am.\n", + " --> Dropping 2726 samples due to missing values in column Temp3pm.\n", + " --> Dropping 1406 samples due to missing values in column RainToday.\n", "Fitting Encoder...\n", "Encoding categorical columns...\n", " --> Target-encoding feature Location. Contains 26 classes.\n", diff --git a/examples/binary_classification.ipynb b/examples/binary_classification.ipynb index 61e92c70f..6b260e7cc 100644 --- a/examples/binary_classification.ipynb +++ b/examples/binary_classification.ipynb @@ -306,26 +306,26 @@ "Fitting Imputer...\n", "Imputing missing values...\n", " --> Dropping 7 samples for containing more than 16 missing values.\n", - " --> Imputing 23 missing values with median (11.9) in feature MinTemp.\n", - " --> Imputing 10 missing values with median (22.6) in feature MaxTemp.\n", - " --> Imputing 72 missing values with median (0.0) in feature Rainfall.\n", - " --> Imputing 3059 missing values with median (4.6) in feature Evaporation.\n", - " --> Imputing 3382 missing values with median (8.5) in feature Sunshine.\n", - " --> Dropping 467 samples due to missing values in feature WindGustDir.\n", - " --> Imputing 466 missing values with median (39.0) in feature WindGustSpeed.\n", - " --> Dropping 479 samples due to missing values in feature WindDir9am.\n", - " --> Dropping 165 samples due to missing values in feature WindDir3pm.\n", - " --> Imputing 53 missing values with median (13.0) in feature WindSpeed9am.\n", - " --> Imputing 115 missing values with median (17.0) in feature WindSpeed3pm.\n", - " --> Imputing 72 missing values with median (70.0) in feature Humidity9am.\n", - " --> Imputing 164 missing values with median (52.0) in feature Humidity3pm.\n", - " --> Imputing 699 missing values with median (1017.7) in feature Pressure9am.\n", - " --> Imputing 699 missing values with median (1015.4) in feature Pressure3pm.\n", - " --> Imputing 2698 missing values with median (5.0) in feature Cloud9am.\n", - " --> Imputing 2903 missing values with median (5.0) in feature Cloud3pm.\n", - " --> Imputing 32 missing values with median (16.7) in feature Temp9am.\n", - " --> Imputing 116 missing values with median (21.1) in feature Temp3pm.\n", - " --> Dropping 72 samples due to missing values in feature RainToday.\n" + " --> Imputing 23 missing values with median (11.9) in column MinTemp.\n", + " --> Imputing 10 missing values with median (22.6) in column MaxTemp.\n", + " --> Imputing 72 missing values with median (0.0) in column Rainfall.\n", + " --> Imputing 3059 missing values with median (4.6) in column Evaporation.\n", + " --> Imputing 3382 missing values with median (8.5) in column Sunshine.\n", + " --> Dropping 467 samples due to missing values in column WindGustDir.\n", + " --> Imputing 466 missing values with median (39.0) in column WindGustSpeed.\n", + " --> Dropping 479 samples due to missing values in column WindDir9am.\n", + " --> Dropping 165 samples due to missing values in column WindDir3pm.\n", + " --> Imputing 53 missing values with median (13.0) in column WindSpeed9am.\n", + " --> Imputing 115 missing values with median (17.0) in column WindSpeed3pm.\n", + " --> Imputing 72 missing values with median (70.0) in column Humidity9am.\n", + " --> Imputing 164 missing values with median (52.0) in column Humidity3pm.\n", + " --> Imputing 699 missing values with median (1017.7) in column Pressure9am.\n", + " --> Imputing 699 missing values with median (1015.4) in column Pressure3pm.\n", + " --> Imputing 2698 missing values with median (5.0) in column Cloud9am.\n", + " --> Imputing 2903 missing values with median (5.0) in column Cloud3pm.\n", + " --> Imputing 32 missing values with median (16.7) in column Temp9am.\n", + " --> Imputing 116 missing values with median (21.1) in column Temp3pm.\n", + " --> Dropping 72 samples due to missing values in column RainToday.\n" ] } ], diff --git a/examples/feature_engineering.ipynb b/examples/feature_engineering.ipynb index ff309fd2d..5bd9e658d 100644 --- a/examples/feature_engineering.ipynb +++ b/examples/feature_engineering.ipynb @@ -397,7 +397,7 @@ "text": [ "Fitting Imputer...\n", "Imputing missing values...\n", - " --> Imputing 12 missing values using the KNN imputer in feature NATURAL_LOGARITHM(Temp3pm).\n" + " --> Imputing 12 missing values using the KNN imputer in column NATURAL_LOGARITHM(Temp3pm).\n" ] } ], diff --git a/examples/getting_started.ipynb b/examples/getting_started.ipynb index 0ffe3fe8b..ee2ac20a5 100644 --- a/examples/getting_started.ipynb +++ b/examples/getting_started.ipynb @@ -71,26 +71,26 @@ "text": [ "Fitting Imputer...\n", "Imputing missing values...\n", - " --> Imputing 8 missing values with median (11.6) in feature MinTemp.\n", - " --> Imputing 2 missing values with median (22.3) in feature MaxTemp.\n", - " --> Imputing 12 missing values with median (0.0) in feature Rainfall.\n", - " --> Imputing 425 missing values with median (4.8) in feature Evaporation.\n", - " --> Imputing 480 missing values with median (8.55) in feature Sunshine.\n", - " --> Imputing 59 missing values with most_frequent (N) in feature WindGustDir.\n", - " --> Imputing 59 missing values with median (37.0) in feature WindGustSpeed.\n", - " --> Imputing 90 missing values with most_frequent (N) in feature WindDir9am.\n", - " --> Imputing 28 missing values with most_frequent (SW) in feature WindDir3pm.\n", - " --> Imputing 10 missing values with median (13.0) in feature WindSpeed9am.\n", - " --> Imputing 19 missing values with median (17.0) in feature WindSpeed3pm.\n", - " --> Imputing 17 missing values with median (70.0) in feature Humidity9am.\n", - " --> Imputing 31 missing values with median (51.0) in feature Humidity3pm.\n", - " --> Imputing 89 missing values with median (1017.8) in feature Pressure9am.\n", - " --> Imputing 87 missing values with median (1015.2) in feature Pressure3pm.\n", - " --> Imputing 383 missing values with median (5.0) in feature Cloud9am.\n", - " --> Imputing 412 missing values with median (5.0) in feature Cloud3pm.\n", - " --> Imputing 11 missing values with median (16.5) in feature Temp9am.\n", - " --> Imputing 26 missing values with median (20.7) in feature Temp3pm.\n", - " --> Imputing 12 missing values with most_frequent (No) in feature RainToday.\n", + " --> Imputing 8 missing values with median (11.6) in column MinTemp.\n", + " --> Imputing 2 missing values with median (22.3) in column MaxTemp.\n", + " --> Imputing 12 missing values with median (0.0) in column Rainfall.\n", + " --> Imputing 425 missing values with median (4.8) in column Evaporation.\n", + " --> Imputing 480 missing values with median (8.55) in column Sunshine.\n", + " --> Imputing 59 missing values with most_frequent (N) in column WindGustDir.\n", + " --> Imputing 59 missing values with median (37.0) in column WindGustSpeed.\n", + " --> Imputing 90 missing values with most_frequent (N) in column WindDir9am.\n", + " --> Imputing 28 missing values with most_frequent (SW) in column WindDir3pm.\n", + " --> Imputing 10 missing values with median (13.0) in column WindSpeed9am.\n", + " --> Imputing 19 missing values with median (17.0) in column WindSpeed3pm.\n", + " --> Imputing 17 missing values with median (70.0) in column Humidity9am.\n", + " --> Imputing 31 missing values with median (51.0) in column Humidity3pm.\n", + " --> Imputing 89 missing values with median (1017.8) in column Pressure9am.\n", + " --> Imputing 87 missing values with median (1015.2) in column Pressure3pm.\n", + " --> Imputing 383 missing values with median (5.0) in column Cloud9am.\n", + " --> Imputing 412 missing values with median (5.0) in column Cloud3pm.\n", + " --> Imputing 11 missing values with median (16.5) in column Temp9am.\n", + " --> Imputing 26 missing values with median (20.7) in column Temp3pm.\n", + " --> Imputing 12 missing values with most_frequent (No) in column RainToday.\n", "Fitting Encoder...\n", "Encoding categorical columns...\n", " --> Target-encoding feature Location. Contains 49 classes.\n", diff --git a/examples/holdout_set.ipynb b/examples/holdout_set.ipynb index ffe21b751..68065a975 100644 --- a/examples/holdout_set.ipynb +++ b/examples/holdout_set.ipynb @@ -323,26 +323,26 @@ "text": [ "Fitting Imputer...\n", "Imputing missing values...\n", - " --> Dropping 258 samples due to missing values in feature MinTemp.\n", - " --> Dropping 127 samples due to missing values in feature MaxTemp.\n", - " --> Dropping 553 samples due to missing values in feature Rainfall.\n", - " --> Dropping 24308 samples due to missing values in feature Evaporation.\n", - " --> Dropping 27187 samples due to missing values in feature Sunshine.\n", - " --> Dropping 3739 samples due to missing values in feature WindGustDir.\n", - " --> Dropping 3712 samples due to missing values in feature WindGustSpeed.\n", - " --> Dropping 3995 samples due to missing values in feature WindDir9am.\n", - " --> Dropping 1508 samples due to missing values in feature WindDir3pm.\n", - " --> Dropping 539 samples due to missing values in feature WindSpeed9am.\n", - " --> Dropping 1077 samples due to missing values in feature WindSpeed3pm.\n", - " --> Dropping 706 samples due to missing values in feature Humidity9am.\n", - " --> Dropping 1447 samples due to missing values in feature Humidity3pm.\n", - " --> Dropping 5610 samples due to missing values in feature Pressure9am.\n", - " --> Dropping 5591 samples due to missing values in feature Pressure3pm.\n", - " --> Dropping 21520 samples due to missing values in feature Cloud9am.\n", - " --> Dropping 22921 samples due to missing values in feature Cloud3pm.\n", - " --> Dropping 365 samples due to missing values in feature Temp9am.\n", - " --> Dropping 1106 samples due to missing values in feature Temp3pm.\n", - " --> Dropping 553 samples due to missing values in feature RainToday.\n", + " --> Dropping 258 samples due to missing values in column MinTemp.\n", + " --> Dropping 127 samples due to missing values in column MaxTemp.\n", + " --> Dropping 553 samples due to missing values in column Rainfall.\n", + " --> Dropping 24308 samples due to missing values in column Evaporation.\n", + " --> Dropping 27187 samples due to missing values in column Sunshine.\n", + " --> Dropping 3739 samples due to missing values in column WindGustDir.\n", + " --> Dropping 3712 samples due to missing values in column WindGustSpeed.\n", + " --> Dropping 3995 samples due to missing values in column WindDir9am.\n", + " --> Dropping 1508 samples due to missing values in column WindDir3pm.\n", + " --> Dropping 539 samples due to missing values in column WindSpeed9am.\n", + " --> Dropping 1077 samples due to missing values in column WindSpeed3pm.\n", + " --> Dropping 706 samples due to missing values in column Humidity9am.\n", + " --> Dropping 1447 samples due to missing values in column Humidity3pm.\n", + " --> Dropping 5610 samples due to missing values in column Pressure9am.\n", + " --> Dropping 5591 samples due to missing values in column Pressure3pm.\n", + " --> Dropping 21520 samples due to missing values in column Cloud9am.\n", + " --> Dropping 22921 samples due to missing values in column Cloud3pm.\n", + " --> Dropping 365 samples due to missing values in column Temp9am.\n", + " --> Dropping 1106 samples due to missing values in column Temp3pm.\n", + " --> Dropping 553 samples due to missing values in column RainToday.\n", "Fitting Encoder...\n", "Encoding categorical columns...\n", " --> Target-encoding feature Location. Contains 26 classes.\n", diff --git a/examples/train_sizing.ipynb b/examples/train_sizing.ipynb index ec75bff9d..1d106f059 100644 --- a/examples/train_sizing.ipynb +++ b/examples/train_sizing.ipynb @@ -287,26 +287,26 @@ "Fitting Imputer...\n", "Imputing missing values...\n", " --> Dropping 161 samples for containing more than 16 missing values.\n", - " --> Imputing 481 missing values with median (12.0) in feature MinTemp.\n", - " --> Imputing 265 missing values with median (22.6) in feature MaxTemp.\n", - " --> Imputing 1354 missing values with median (0.0) in feature Rainfall.\n", - " --> Imputing 60682 missing values with median (4.8) in feature Evaporation.\n", - " --> Imputing 67659 missing values with median (8.4) in feature Sunshine.\n", - " --> Imputing 9187 missing values with most_frequent (W) in feature WindGustDir.\n", - " --> Imputing 9127 missing values with median (39.0) in feature WindGustSpeed.\n", - " --> Imputing 9852 missing values with most_frequent (N) in feature WindDir9am.\n", - " --> Imputing 3617 missing values with most_frequent (SE) in feature WindDir3pm.\n", - " --> Imputing 1187 missing values with median (13.0) in feature WindSpeed9am.\n", - " --> Imputing 2469 missing values with median (19.0) in feature WindSpeed3pm.\n", - " --> Imputing 1613 missing values with median (70.0) in feature Humidity9am.\n", - " --> Imputing 3449 missing values with median (52.0) in feature Humidity3pm.\n", - " --> Imputing 13863 missing values with median (1017.6) in feature Pressure9am.\n", - " --> Imputing 13830 missing values with median (1015.2) in feature Pressure3pm.\n", - " --> Imputing 53496 missing values with median (5.0) in feature Cloud9am.\n", - " --> Imputing 56933 missing values with median (5.0) in feature Cloud3pm.\n", - " --> Imputing 743 missing values with median (16.7) in feature Temp9am.\n", - " --> Imputing 2565 missing values with median (21.1) in feature Temp3pm.\n", - " --> Imputing 1354 missing values with most_frequent (No) in feature RainToday.\n", + " --> Imputing 481 missing values with median (12.0) in column MinTemp.\n", + " --> Imputing 265 missing values with median (22.6) in column MaxTemp.\n", + " --> Imputing 1354 missing values with median (0.0) in column Rainfall.\n", + " --> Imputing 60682 missing values with median (4.8) in column Evaporation.\n", + " --> Imputing 67659 missing values with median (8.4) in column Sunshine.\n", + " --> Imputing 9187 missing values with most_frequent (W) in column WindGustDir.\n", + " --> Imputing 9127 missing values with median (39.0) in column WindGustSpeed.\n", + " --> Imputing 9852 missing values with most_frequent (N) in column WindDir9am.\n", + " --> Imputing 3617 missing values with most_frequent (SE) in column WindDir3pm.\n", + " --> Imputing 1187 missing values with median (13.0) in column WindSpeed9am.\n", + " --> Imputing 2469 missing values with median (19.0) in column WindSpeed3pm.\n", + " --> Imputing 1613 missing values with median (70.0) in column Humidity9am.\n", + " --> Imputing 3449 missing values with median (52.0) in column Humidity3pm.\n", + " --> Imputing 13863 missing values with median (1017.6) in column Pressure9am.\n", + " --> Imputing 13830 missing values with median (1015.2) in column Pressure3pm.\n", + " --> Imputing 53496 missing values with median (5.0) in column Cloud9am.\n", + " --> Imputing 56933 missing values with median (5.0) in column Cloud3pm.\n", + " --> Imputing 743 missing values with median (16.7) in column Temp9am.\n", + " --> Imputing 2565 missing values with median (21.1) in column Temp3pm.\n", + " --> Imputing 1354 missing values with most_frequent (No) in column RainToday.\n", "Fitting Encoder...\n", "Encoding categorical columns...\n", " --> Target-encoding feature Location. Contains 49 classes.\n", diff --git a/tests/test_basetransformer.py b/tests/test_basetransformer.py index 7bc3a179c..9ddea11af 100644 --- a/tests/test_basetransformer.py +++ b/tests/test_basetransformer.py @@ -21,6 +21,7 @@ from sklearn.naive_bayes import GaussianNB from sklearnex import get_config from sklearnex.svm import SVC +from sktime.transformations.series.impute import Imputer from atom import ATOMClassifier, ATOMForecaster from atom.basetransformer import BaseTransformer @@ -214,6 +215,20 @@ def test_inherit_sp(): assert atom.tbats.estimator.get_params()["sp"] == [12, 24] # Multiple seasonality +def test_inherit_attributes_and_methods(): + """Assert that sklearn attributes and methods are added to the estimator.""" + imputer = Imputer().fit(y_fc) + assert not hasattr(imputer, "feature_names_in_") + assert not hasattr(imputer, "n_features_in_") + assert not hasattr(imputer, "get_feature_names_out") + + imputer = BaseTransformer(random_state=1)._inherit(imputer) + imputer.fit(pd.DataFrame(y_fc)) + assert hasattr(imputer, "feature_names_in_") + assert hasattr(imputer, "n_features_in_") + assert hasattr(imputer, "get_feature_names_out") + + # Test _get_est_class ============================================== >> @pytest.mark.skipif(machine() not in ("x86_64", "AMD64"), reason="Only x86 support") diff --git a/tests/test_data_cleaning.py b/tests/test_data_cleaning.py index c0252b45d..61e4240dd 100644 --- a/tests/test_data_cleaning.py +++ b/tests/test_data_cleaning.py @@ -19,6 +19,7 @@ Balancer, Cleaner, Decomposer, Discretizer, Encoder, Imputer, Normalizer, Pruner, Scaler, ) +from atom.utils.types import NumericalStrats from atom.utils.utils import NotFittedError, check_scaling, to_df from .conftest import ( @@ -29,6 +30,15 @@ # Test TransformerMixin ============================================ >> +def test_repr(): + """Assert that __repr__ hides the default engine.""" + assert str(Cleaner(engine="pyarrow")).startswith("Cleaner(engine=EngineTuple") + assert str(Cleaner()) == "Cleaner()" + assert str(Cleaner(device="gpu")) == "Cleaner(device='gpu')" + assert str(Cleaner(verbose=2)) == "Cleaner(verbose=2)" + assert str(Cleaner(device="gpu", verbose=2)) == "Cleaner(device='gpu', verbose=2)" + + def test_clone(): """Assert that cloning the transformer keeps internal attributes.""" pruner = Pruner().fit(X_bin) @@ -474,7 +484,7 @@ def test_imputing_all_missing_values_categorical(missing): X = [[missing, "a", "a"], ["b", "c", missing], ["b", "a", "c"], ["c", "a", "a"]] y = [1, 1, 0, 0] imputer = Imputer(strat_cat="most_frequent") - X, y = imputer.fit_transform(X, y) + X, _ = imputer.fit_transform(X, y) assert X.isna().sum().sum() == 0 @@ -491,7 +501,7 @@ def test_rows_too_many_nans(max_nan_rows, random): max_nan_rows=max_nan_rows, ) X, y = imputer.fit_transform(X, y) - assert len(X) == 569 # Original size + assert len(X) == len(y) == 569 # Original size assert X.isna().sum().sum() == 0 @@ -515,7 +525,7 @@ def test_cols_too_many_nans(max_nan_cols): def test_imputing_numeric_drop(): """Assert that imputing drop for numerical values works.""" imputer = Imputer(strat_num="drop") - X, y = imputer.fit_transform(X10_nan, y10) + X, _ = imputer.fit_transform(X10_nan, y10) assert len(X) == 8 assert X.isna().sum().sum() == 0 @@ -523,55 +533,23 @@ def test_imputing_numeric_drop(): def test_imputing_numeric_number(): """Assert that imputing a number for numerical values works.""" imputer = Imputer(strat_num=3.2) - X, y = imputer.fit_transform(X10_nan, y10) + X, _ = imputer.fit_transform(X10_nan, y10) assert X.iloc[0, 0] == 3.2 assert X.isna().sum().sum() == 0 -def test_imputing_numeric_mean(): - """Assert that imputing the mean for numerical values works.""" - imputer = Imputer(strat_num="mean") - X, y = imputer.fit_transform(X10_nan, y10) - assert X.iloc[0, 0] == pytest.approx(2.577778, rel=1e-6, abs=1e-12) - assert X.isna().sum().sum() == 0 - - -def test_imputing_numeric_median(): - """Assert that imputing the median for numerical values works.""" - imputer = Imputer(strat_num="median") - X, y = imputer.fit_transform(X10_nan, y10) - assert X.iloc[0, 0] == 3 - assert X.isna().sum().sum() == 0 - - -def test_imputing_numeric_knn(): - """Assert that imputing numerical values with KNNImputer works.""" - imputer = Imputer(strat_num="knn", random_state=1) - X, y = imputer.fit_transform(X10_nan, y10) - assert X.iloc[0, 0] == 3.04 - assert X.isna().sum().sum() == 0 - - -def test_imputing_numeric_iterative(): - """Assert that imputing numerical values with IterativeImputer works.""" - imputer = Imputer(strat_num="iterative") - X, y = imputer.fit_transform(X10_nan, y10) - assert X.iloc[0, 0] == pytest.approx(2.577836, rel=1e-6, abs=1e-12) - assert X.isna().sum().sum() == 0 - - -def test_imputing_numeric_most_frequent(): - """Assert that imputing the most_frequent for numerical values works.""" - imputer = Imputer(strat_num="most_frequent") - X, y = imputer.fit_transform(X10_nan, y10) - assert X.iloc[0, 0] == 3 +@pytest.mark.parametrize("strat_num", NumericalStrats.__args__) +def test_imputing_numeric(strat_num): + """Assert that imputing numerical columns works.""" + imputer = Imputer(strat_num=strat_num) + X, _ = imputer.fit_transform(X10_nan, y10) assert X.isna().sum().sum() == 0 def test_imputing_non_numeric_string(): """Assert that imputing a string for non-numerical values works.""" imputer = Imputer(strat_cat="missing") - X, y = imputer.fit_transform(X10_sn, y10) + X, _ = imputer.fit_transform(X10_sn, y10) assert X.iloc[0, 2] == "missing" assert X.isna().sum().sum() == 0 @@ -579,7 +557,7 @@ def test_imputing_non_numeric_string(): def test_imputing_non_numeric_drop(): """Assert that the drop strategy for non-numerical works.""" imputer = Imputer(strat_cat="drop") - X, y = imputer.fit_transform(X10_sn, y10) + X, _ = imputer.fit_transform(X10_sn, y10) assert len(X) == 9 assert X.isna().sum().sum() == 0 @@ -587,7 +565,7 @@ def test_imputing_non_numeric_drop(): def test_imputing_non_numeric_most_frequent(): """Assert that the most_frequent strategy for non-numerical works.""" imputer = Imputer(strat_cat="most_frequent") - X, y = imputer.fit_transform(X10_sn, y10) + X, _ = imputer.fit_transform(X10_sn, y10) assert X.iloc[0, 2] == "d" assert X.isna().sum().sum() == 0 diff --git a/tests/test_plots.py b/tests/test_plots.py index c4191b647..57e3cc066 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -23,7 +23,7 @@ from .conftest import ( X10, X10_str, X_bin, X_class, X_ex, X_label, X_reg, X_sparse, X_text, y10, - y_bin, y_class, y_fc, y_label, y_multiclass, y_reg, + y_bin, y_class, y_ex, y_fc, y_label, y_multiclass, y_reg, ) @@ -282,6 +282,16 @@ def test_plot_acf(columns): atom.plot_acf(columns=columns, display=False) +def test_plot_ccf(): + """Assert that the plot_ccf method works.""" + atom = ATOMForecaster(y_fc, random_state=1) + with pytest.raises(ValueError, match=".*requires at least two columns.*"): + atom.plot_ccf(display=False) + + atom = ATOMForecaster(X_ex, y=y_ex, random_state=1) + atom.plot_ccf(plot_interval=True, display=False) + + @pytest.mark.parametrize("show", [10, None]) def test_plot_components(show): """Assert that the plot_components method works.""" @@ -316,6 +326,13 @@ def test_plot_distribution(): atom.plot_distribution(columns=[0, 1], distributions="pearson3", display=False) +@pytest.mark.parametrize("columns", [None, -1]) +def test_plot_fft(columns): + """Assert that the plot_fft method works.""" + atom = ATOMForecaster(y_fc, random_state=1) + atom.plot_fft(columns=columns, display=False) + + @pytest.mark.parametrize("ngram", [1, 2, 3, 4]) def test_plot_ngrams(ngram): """Assert that the plot_ngrams method works.""" @@ -345,6 +362,13 @@ def test_plot_pca(X): atom.plot_pca(display=False) +@pytest.mark.parametrize("columns", [None, -1]) +def test_plot_periodogram(columns): + """Assert that the plot_periodogram method works.""" + atom = ATOMForecaster(y_fc, random_state=1) + atom.plot_periodogram(columns=columns, display=False) + + def test_plot_qq(): """Assert that the plot_qq method works.""" atom = ATOMClassifier(X_bin, y_bin, random_state=1)