diff --git a/atom/api.py b/atom/api.py
index ec9bf477f..1aba8af50 100644
--- a/atom/api.py
+++ b/atom/api.py
@@ -119,8 +119,8 @@ def ATOMModel(
if acronym:
estimator_c.acronym = acronym
estimator_c.needs_scaling = needs_scaling
- estimator_c.native_multioutput = native_multioutput
estimator_c.native_multilabel = native_multilabel
+ estimator_c.native_multioutput = native_multioutput
estimator_c.validation = validation
return estimator_c
diff --git a/atom/basemodel.py b/atom/basemodel.py
index eca6ecb78..0194de833 100644
--- a/atom/basemodel.py
+++ b/atom/basemodel.py
@@ -459,10 +459,12 @@ def _get_est(self, params: dict[str, Any]) -> Predictor:
estimator = MultiOutputClassifier(estimator)
elif self.task.is_regression:
estimator = MultiOutputRegressor(estimator)
- elif self.task.is_forecast:
- if hasattr(self, "_estimators") and self._goal.name not in self._estimators:
- # Forecasting task with a regressor
- estimator = make_reduction(estimator)
+ elif self.task.is_forecast and self._goal.name not in self._estimators:
+ # Forecasting task with a regressor
+ if self.native_multioutput:
+ estimator = make_reduction(estimator, strategy="multioutput")
+ else:
+ estimator = make_reduction(estimator, strategy="recursive")
return self._inherit(estimator)
@@ -558,7 +560,6 @@ def _fit_estimator(
est_params_fit["fh"] = est_params_fit.get("fh", self.test.index)
estimator.fit(data[1], X=check_empty(data[0]), **est_params_fit)
-
else:
estimator.fit(*data, **est_params_fit)
@@ -674,7 +675,11 @@ def _get_pred(
y_true = y.loc[y.index.isin(self._all.index)]
if self.task.is_forecast:
- y_pred = self._prediction(fh=X.index, X=check_empty(X), verbose=0, method=attr)
+            try:
+                y_pred = self._prediction(fh=X.index, X=check_empty(X), verbose=0, method=attr)
+            except (ValueError, NotImplementedError):
+                # In-sample predictions aren't implemented for some models
+                y_pred = bk.Series([np.nan] * len(X), index=X.index)
else:
y_pred = self._prediction(X.index, verbose=0, method=attr)
@@ -2451,11 +2456,32 @@ def transform(
class ClassRegModel(BaseModel):
"""Classification and regression models."""
+ _prediction_methods = (
+ "predict",
+ "predict_proba",
+ "predict_log_proba",
+ "decision_function",
+ )
+
+    def __init__(self, *args, **kwargs):
+        """Assign prediction methods depending on the task.
+
+        Regressors can be used for forecast tasks, hence we need to
+        overwrite the default prediction methods.
+
+        NOTE(review): the methods are set on the class, not on the
+        instance, so the change applies to every instance of that
+        class -- confirm this is intended.
+
+        """
+        super().__init__(*args, **kwargs)
+
+        if self._goal is Goal.forecast:
+            # Replace the methods with ForecastModel's implementations
+            for method in ("_prediction", *ForecastModel._prediction_methods):
+                setattr(self.__class__, method, getattr(ForecastModel, method))
+
+ @crash
def get_tags(self) -> dict[str, Any]:
"""Get the model's tags.
Return class parameters that provide general information about
- the estimator's characteristics.
+ the model's characteristics.
Returns
-------
@@ -2467,14 +2493,14 @@ def get_tags(self) -> dict[str, Any]:
"acronym": self.acronym,
"fullname": self.fullname,
"estimator": self._est_class,
- "module": self._est_class.__module__.split(".")[0] + self._module,
- "handles_missing": self.handles_missing,
+ "module": self._est_class.__module__,
+ "handles_missing": getattr(self, "handles_missing", None),
"needs_scaling": self.needs_scaling,
- "accepts_sparse": self.accepts_sparse,
+ "accepts_sparse": getattr(self, "accepts_sparse", None),
"native_multilabel": self.native_multilabel,
"native_multioutput": self.native_multioutput,
"validation": self.validation,
- "supports_engines": ", ".join(self.supports_engines),
+ "supports_engines": ", ".join(getattr(self, "supports_engines", [])),
}
@overload
@@ -2885,11 +2911,21 @@ def score(
class ForecastModel(BaseModel):
"""Forecasting models."""
+ _prediction_methods = (
+ "predict",
+ "predict_interval",
+ "predict_proba",
+ "predict_quantiles",
+ "predict_residuals",
+ "predict_var",
+ )
+
+ @crash
def get_tags(self) -> dict[str, Any]:
"""Get the model's tags.
Return class parameters that provide general information about
- the estimator's characteristics.
+ the model's characteristics.
Returns
-------
@@ -2900,13 +2936,12 @@ def get_tags(self) -> dict[str, Any]:
return {
"acronym": self.acronym,
"fullname": self.fullname,
- "estimator": self._est_class.__name__,
- "module": self._est_class.__module__.split(".")[0] + self._module,
- "handles_missing": self.handles_missing,
- "in_sample_prediction": self.in_sample_prediction,
- "multiple_seasonality": self.multiple_seasonality,
- "native_multivariate": self.native_multivariate,
- "supports_engines": ", ".join(self.supports_engines),
+ "estimator": self._est_class,
+ "module": self._est_class.__module__,
+ "handles_missing": getattr(self, "handles_missing", None),
+ "multiple_seasonality": getattr(self, "multiple_seasonality", None),
+ "native_multioutput": self.native_multioutput,
+ "supports_engines": ", ".join(getattr(self, "supports_engines", [])),
}
@overload
diff --git a/atom/baserunner.py b/atom/baserunner.py
index 53aee3a83..4542f9121 100644
--- a/atom/baserunner.py
+++ b/atom/baserunner.py
@@ -43,7 +43,7 @@
TargetSelector, YSelector, dataframe_t, int_t, segment_t, sequence_t,
)
from atom.utils.utils import (
- ClassMap, DataContainer, SeasonalPeriod, Task, bk, check_is_fitted,
+ ClassMap, DataContainer, Goal, SeasonalPeriod, Task, bk, check_is_fitted,
composed, crash, divide, flt, get_cols, get_segment, get_versions,
has_task, lst, merge, method_to_log, n_cols,
)
@@ -705,7 +705,7 @@ def _has_data_sets(
X_test, y_test = self._check_input(arrays[1][0], arrays[1][1])
sets = _has_data_sets(X_train, y_train, X_test, y_test)
elif isinstance(arrays[1], (*int_t, str)) or n_cols(arrays[1]) == 1:
- if not self._goal.name == "forecast":
+ if self._goal is not Goal.forecast:
# X, y
sets = _no_data_sets(*self._check_input(arrays[0], arrays[1]))
else:
@@ -899,7 +899,7 @@ def available_models(self) -> pd.DataFrame:
- **acronym:** Model's acronym (used to call the model).
- **fullname:** Name of the model's class.
- - **estimator:** Class of the model's underlying estimator.
+ - **estimator:** Name of the model's underlying estimator.
- **module:** The estimator's module.
- **handles_missing:** Whether the model can handle missing
(`NaN`) values without preprocessing. If False, consider using
@@ -914,15 +914,13 @@ def available_models(self) -> pd.DataFrame:
for [multilabel][] tasks.
- **native_multioutput:** Whether the model has native support
for [multioutput tasks][].
- - **native_multivariate:** Whether the model has native support
- for [multivariate][] tasks.
- **validation:** Whether the model has [in-training validation][].
- **supports_engines:** Engines supported by the model.
"""
rows = []
for model in MODELS:
- m = model(goal=self._goal)
+ m = model(goal=self._goal, branches=self._branches)
if self._goal.name in m._estimators:
rows.append(m.get_tags())
diff --git a/atom/basetrainer.py b/atom/basetrainer.py
index affa94fea..d41c41783 100644
--- a/atom/basetrainer.py
+++ b/atom/basetrainer.py
@@ -22,7 +22,7 @@
from atom.baserunner import BaseRunner
from atom.branch import BranchManager
from atom.data_cleaning import BaseTransformer
-from atom.models import MODELS, CustomModel
+from atom.models import MODELS, create_custom_model
from atom.plots import RunnerPlot
from atom.utils.types import Model, sequence_t
from atom.utils.utils import (
@@ -168,60 +168,59 @@ def _prepare_parameters(self):
if isinstance(model, str):
for m in model.split("+"):
if m.startswith("!"):
- exc.append(m[1:])
- else:
- try:
- if len(name := m.split("_", 1)) > 1:
- name, tag = name[0].lower(), f"_{name[1]}"
- else:
- name, tag = name[0].lower(), ""
-
- cls = next(n for n in MODELS if n.acronym.lower() == name)
-
- except StopIteration:
+ exc.append(m[1:].lower())
+ continue
+
+ try:
+ if len(name := m.split("_", 1)) > 1:
+ name, tag = name[0].lower(), f"_{name[1]}"
+ else:
+ name, tag = name[0].lower(), ""
+
+ cls = next(n for n in MODELS if n.acronym.lower() == name)
+
+ except StopIteration:
+ raise ValueError(
+ f"Invalid value for the models parameter, got {m}. "
+ "Note that tags must be separated by an underscore. "
+ "Available model are:\n" +
+ "\n".join(
+ [
+ f" --> {m.__name__} ({m.acronym})"
+ for m in MODELS
+ if self._goal.name in m._estimators
+ ]
+ )
+ ) from None
+
+ # Check if libraries for non-sklearn models are available
+ dependencies = {
+ "BATS": "tbats",
+ "CatB": "catboost",
+ "LGB": "lightgbm",
+ "MSTL": "statsforecast",
+ "TBATS": "tbats",
+ "XGB": "xgboost",
+ }
+ if cls.acronym in dependencies:
+ check_dependency(dependencies[cls.acronym])
+
+ # Check if the model supports the task
+ if self._goal.name not in cls._estimators:
+ # Forecast task can use regression models
+ if self._goal is not Goal.forecast or "regression" not in cls._estimators:
raise ValueError(
- f"Invalid value for the models parameter, got {m}. "
- "Note that tags must be separated by an underscore. "
- "Available model are:\n" +
- "\n".join(
- [
- f" --> {m.__name__} ({m.acronym})"
- for m in MODELS
- if self._goal.name in m._estimators
- ]
- )
- ) from None
-
- # Check if libraries for non-sklearn models are available
- dependencies = {
- "BATS": "tbats",
- "CatB": "catboost",
- "LGB": "lightgbm",
- "MSTL": "statsforecast",
- "TBATS": "tbats",
- "XGB": "xgboost",
- }
- if cls.acronym in dependencies:
- check_dependency(dependencies[cls.acronym])
-
- # Check if the model supports the task
- if self._goal.name not in cls._estimators:
- # Forecast task can use regression models
- if self._goal.name == "forecast" and "regression" in cls._estimators:
- kwargs["goal"] = Goal.Regression
- else:
- raise ValueError(
- f"The {cls.__name__} model is not "
- f"available for {self.task.name} tasks!"
- )
-
- inc.append(cls(name=f"{cls.acronym}{tag}", **kwargs))
+ f"The {cls.__name__} model is not "
+ f"available for {self.task.name} tasks!"
+ )
+
+ inc.append(cls(name=f"{cls.acronym}{tag}", **kwargs))
elif isinstance(model, Model): # For new instances or reruns
inc.append(model)
else: # Model is a custom estimator
- inc.append(CustomModel(estimator=model, **kwargs))
+ inc.append(create_custom_model(estimator=model, **kwargs))
if inc and exc:
raise ValueError(
@@ -239,9 +238,8 @@ def _prepare_parameters(self):
self._models = ClassMap(*inc)
else:
self._models = ClassMap(
- model(**kwargs)
- for model in MODELS
- if self._goal.name in model._estimators and model.acronym not in exc
+ model(**kwargs) for model in MODELS
+ if self._goal.name in model._estimators and model.acronym.lower() not in exc
)
# Prepare est_params ======================================= >>
diff --git a/atom/basetransformer.py b/atom/basetransformer.py
index 95fbdf759..ae33a830c 100644
--- a/atom/basetransformer.py
+++ b/atom/basetransformer.py
@@ -372,13 +372,13 @@ def _inherit(self, obj: T_Estimator) -> T_Estimator:
"""
signature = sign(obj.__init__) # type: ignore[misc]
for p in ("n_jobs", "random_state"):
- if p in signature and obj.get_params()[p] == signature[p]._default:
- obj.set_params(**{p: getattr(self, p)})
+ if p in signature and getattr(obj, p, "") == signature[p]._default:
+ setattr(obj, p, getattr(self, p))
# Add seasonal period to the estimator
- if self._config.sp:
- if "sp" in signature and obj.get_params()["sp"] == signature["sp"]._default:
- obj.set_params(sp=self._config.sp)
+ if hasattr(self, "_config") and self._config.sp:
+ if "sp" in signature and getattr(obj, "sp", "") == signature["sp"]._default:
+ obj.sp = self._config.sp
return obj
diff --git a/atom/models/__init__.py b/atom/models/__init__.py
index 80ab0436f..186854757 100644
--- a/atom/models/__init__.py
+++ b/atom/models/__init__.py
@@ -17,7 +17,7 @@
QuadraticDiscriminantAnalysis, RadiusNearestNeighbors, RandomForest, Ridge,
StochasticGradientDescent, SupportVectorMachine, XGBoost,
)
-from atom.models.custom import CustomModel
+from atom.models.custom import create_custom_model
from atom.models.ensembles import Stacking, Voting
from atom.models.ts import (
ARIMA, BATS, ETS, MSTL, SARIMAX, STL, TBATS, VAR, VARMAX, AutoARIMA,
diff --git a/atom/models/custom.py b/atom/models/custom.py
index 4697f1d7f..c1ac08b7f 100644
--- a/atom/models/custom.py
+++ b/atom/models/custom.py
@@ -1,76 +1,105 @@
"""Automated Tool for Optimized Modeling (ATOM).
Author: Mavs
-Description: Module containing the CustomModel class.
+Description: Module containing the create_custom_model function.
"""
from typing import Any
-from atom.basemodel import ClassRegModel
+from atom.basemodel import BaseModel, ClassRegModel, ForecastModel
from atom.utils.types import Predictor
+from atom.utils.utils import Goal
-class CustomModel(ClassRegModel):
- """Model with estimator provided by user."""
-
- def __init__(self, **kwargs):
- # Assign the estimator and store the provided parameters
- if callable(est := kwargs.pop("estimator")):
- self._est = est
- self._params = {}
- else:
- self._est = est.__class__
- self._params = est.get_params()
-
- if hasattr(est, "name"):
- name = est.name
- else:
- from atom.models import MODELS
-
- # If no name is provided, use the name of the class
- name = self.fullname
- if len(n := list(filter(str.isupper, name))) >= 2 and n not in MODELS:
- name = "".join(n)
-
- self.acronym = getattr(est, "acronym", name)
- if not name.startswith(self.acronym):
- raise ValueError(
- f"The name ({name}) and acronym ({self.acronym}) of model "
- f"{self.fullname} do not match. The name should start with "
- f"the model's acronym."
- )
-
- self.handles_missing = getattr(est, "handles_missing", False)
- self.needs_scaling = getattr(est, "needs_scaling", False)
- self.native_multilabel = getattr(est, "native_multilabel", False)
- self.native_multioutput = getattr(est, "native_multioutput", False)
- self.validation = getattr(est, "validation", None)
-
- super().__init__(name=name, **kwargs)
-
- @property
- def fullname(self) -> str:
- """Return the estimator's class name."""
- return self._est_class.__name__
-
- @property
- def _est_class(self) -> type[Predictor]:
- """Return the estimator's class."""
- return self._est
-
- def _get_est(self, params: dict[str, Any]) -> Predictor:
- """Get the model's estimator with unpacked parameters.
-
- Parameters
- ----------
- params: dict
- Hyperparameters for the estimator.
-
- Returns
- -------
- Predictor
- Estimator instance.
+def create_custom_model(estimator: Predictor, **kwargs) -> BaseModel:
+    """Create and return a custom model wrapper.
+
+    Parameters
+    ----------
+    estimator: Predictor
+        Estimator to be used.
+
+    **kwargs
+        Additional keyword arguments for the model. Must contain
+        `goal`, which selects the base class of the wrapper.
+
+    Returns
+    -------
+    BaseModel
+        Instance of the dynamically created CustomModel class.
+
+    """
+    # Dynamically inherit from the appropriate base class
+    base_class = ForecastModel if kwargs["goal"] is Goal.forecast else ClassRegModel
+
+    class CustomModel(base_class):  # type: ignore[misc, valid-type]
+        """Model with estimator provided by user.
+
+        This class inherits dynamically from either ClassRegModel or
+        ForecastModel, depending on the current task.
         """
-        return super()._get_est(self._params | params)
+
+        def __init__(self, **kwargs):
+            # Assign the estimator and store the provided parameters
+            if callable(est := kwargs.pop("estimator")):
+                # A callable (e.g., a class) is stored as is, without parameters
+                self._est = est
+                self._params = {}
+            else:
+                # An instance: keep its class and its current parameters
+                self._est = est.__class__
+                self._params = est.get_params()
+
+            if hasattr(est, "name"):
+                name = est.name
+            else:
+                from atom.models import MODELS
+
+                # If no name is provided, use the name of the class
+                name = self.fullname
+                # Shorten to the capital letters when there are at least two.
+                # NOTE(review): `n` is a list, so `n not in MODELS` tests the
+                # list itself, not the joined acronym -- confirm this is intended
+                if len(n := list(filter(str.isupper, name))) >= 2 and n not in MODELS:
+                    name = "".join(n)
+
+            self.acronym = getattr(est, "acronym", name)
+            if not name.startswith(self.acronym):
+                raise ValueError(
+                    f"The name ({name}) and acronym ({self.acronym}) of model "
+                    f"{self.fullname} do not match. The name should start with "
+                    f"the model's acronym."
+                )
+
+            # Tags default to False/None when the estimator doesn't define them
+            self.needs_scaling = getattr(est, "needs_scaling", False)
+            self.native_multilabel = getattr(est, "native_multilabel", False)
+            self.native_multioutput = getattr(est, "native_multioutput", False)
+            self.validation = getattr(est, "validation", None)
+
+            super().__init__(name=name, **kwargs)
+
+            # Register the estimator's class name under the current goal
+            self._estimators = {self._goal.name: self._est_class.__name__}
+
+        @property
+        def fullname(self) -> str:
+            """Return the estimator's class name."""
+            return self._est_class.__name__
+
+        @property
+        def _est_class(self) -> type[Predictor]:
+            """Return the estimator's class."""
+            return self._est
+
+        def _get_est(self, params: dict[str, Any]) -> Predictor:
+            """Get the model's estimator with unpacked parameters.
+
+            Parameters
+            ----------
+            params: dict
+                Hyperparameters for the estimator.
+
+            Returns
+            -------
+            Predictor
+                Estimator instance.
+
+            """
+            # Values in `params` take precedence over the stored parameters
+            return super()._get_est(self._params | params)
+
+    return CustomModel(estimator=estimator, **kwargs)
diff --git a/atom/models/ensembles.py b/atom/models/ensembles.py
index 184761956..f051dbb17 100644
--- a/atom/models/ensembles.py
+++ b/atom/models/ensembles.py
@@ -9,7 +9,7 @@
from typing import Any, ClassVar
-from atom.basemodel import ClassRegModel
+from atom.basemodel import BaseModel, ClassRegModel
from atom.utils.types import Model, Predictor
from atom.utils.utils import sign
@@ -43,7 +43,7 @@ class Stacking(ClassRegModel):
def __init__(self, models: list[Model], **kwargs):
self._models = models
- kw_model = {k: v for k, v in kwargs.items() if k in sign(ClassRegModel.__init__)}
+ kw_model = {k: v for k, v in kwargs.items() if k in sign(BaseModel.__init__)}
super().__init__(**kw_model)
self._est_params = {k: v for k, v in kwargs.items() if k not in kw_model}
@@ -100,7 +100,7 @@ class Voting(ClassRegModel):
def __init__(self, models: list[Model], **kwargs):
self._models = models
- kw_model = {k: v for k, v in kwargs.items() if k in sign(ClassRegModel.__init__)}
+ kw_model = {k: v for k, v in kwargs.items() if k in sign(BaseModel.__init__)}
super().__init__(**kw_model)
self._est_params = {k: v for k, v in kwargs.items() if k not in kw_model}
diff --git a/atom/models/ts.py b/atom/models/ts.py
index 569141899..a8b3382be 100644
--- a/atom/models/ts.py
+++ b/atom/models/ts.py
@@ -17,6 +17,7 @@
from atom.basemodel import ForecastModel
from atom.utils.types import Predictor
+from atom.utils.utils import SeasonalPeriod
class ARIMA(ForecastModel):
@@ -79,7 +80,7 @@ class ARIMA(ForecastModel):
handles_missing = True
uses_exogenous = True
multiple_seasonality = False
- native_multivariate = False
+ native_multioutput = False
supports_engines = ("sktime",)
_module = "sktime.forecasting.arima"
@@ -226,7 +227,7 @@ class AutoARIMA(ForecastModel):
handles_missing = True
uses_exogenous = True
multiple_seasonality = False
- native_multivariate = False
+ native_multioutput = False
supports_engines = ("sktime",)
_module = "sktime.forecasting.arima"
@@ -301,7 +302,7 @@ class BATS(ForecastModel):
handles_missing = False
uses_exogenous = False
multiple_seasonality = False
- native_multivariate = False
+ native_multioutput = False
supports_engines = ("sktime",)
_module = "sktime.forecasting.bats"
@@ -381,7 +382,7 @@ class Croston(ForecastModel):
handles_missing = False
uses_exogenous = True
multiple_seasonality = False
- native_multivariate = False
+ native_multioutput = False
supports_engines = ("sktime",)
_module = "sktime.forecasting.croston"
@@ -418,7 +419,7 @@ class DynamicFactor(ForecastModel):
See Also
--------
atom.models:ExponentialSmoothing
- atom.models:LTS
+ atom.models:STL
atom.models:PolynomialTrend
Examples
@@ -440,7 +441,7 @@ class DynamicFactor(ForecastModel):
handles_missing = True
uses_exogenous = True
multiple_seasonality = False
- native_multivariate = True
+ native_multioutput = True
supports_engines = ("sktime",)
_module = "sktime.forecasting.dynamic_factor"
@@ -505,7 +506,7 @@ class ExponentialSmoothing(ForecastModel):
handles_missing = False
uses_exogenous = False
multiple_seasonality = False
- native_multivariate = False
+ native_multioutput = False
supports_engines = ("sktime",)
_module = "sktime.forecasting.exp_smoothing"
@@ -590,7 +591,7 @@ class ETS(ForecastModel):
handles_missing = True
uses_exogenous = False
multiple_seasonality = False
- native_multivariate = False
+ native_multioutput = False
supports_engines = ("sktime",)
_module = "sktime.forecasting.ets"
@@ -683,7 +684,7 @@ class MSTL(ForecastModel):
handles_missing = False
uses_exogenous = False
multiple_seasonality = True
- native_multivariate = False
+ native_multioutput = False
supports_engines = ("sktime",)
_module = "sktime.forecasting.statsforecast"
@@ -760,7 +761,7 @@ class NaiveForecaster(ForecastModel):
handles_missing = True
uses_exogenous = False
multiple_seasonality = False
- native_multivariate = False
+ native_multioutput = False
supports_engines = ("sktime",)
_module = "sktime.forecasting.naive"
@@ -814,7 +815,7 @@ class PolynomialTrend(ForecastModel):
handles_missing = False
uses_exogenous = False
multiple_seasonality = False
- native_multivariate = False
+ native_multioutput = False
supports_engines = ("sktime",)
_module = "sktime.forecasting.trend"
@@ -873,11 +874,44 @@ class Prophet(ForecastModel):
handles_missing = False
uses_exogenous = False
multiple_seasonality = True
- native_multivariate = False
+ native_multioutput = False
supports_engines = ("sktime",)
_module = "sktime.forecasting.fbprophet"
- _estimators: ClassVar[dict[str, str]] = {"forecast": "StatsForecastMSTL"}
+ _estimators: ClassVar[dict[str, str]] = {"forecast": "Prophet"}
+
+    def _get_est(self, params: dict[str, Any]) -> Predictor:
+        """Get the model's estimator with unpacked parameters.
+
+        Parameters
+        ----------
+        params: dict
+            Hyperparameters for the estimator.
+
+        Returns
+        -------
+        Predictor
+            Estimator instance.
+
+        """
+        # Prophet expects a DateTime index frequency. It stays None
+        # when no seasonal period is set or no frequency is found
+        freq = None
+        if self._config.sp:
+            try:
+                freq = next(
+                    n for n, m in SeasonalPeriod.__members__.items()
+                    if m.value == self._config.sp
+                )
+            except StopIteration:
+                # If not in mapping table, get from index. The index's
+                # freq attribute can exist but be None (no fixed
+                # frequency), so check the value instead of hasattr
+                index_freq = getattr(self.X_train.index, "freq", None)
+                if index_freq is not None:
+                    freq = index_freq.name
+
+        return super()._get_est({"freq": freq} | params)
@staticmethod
def _get_distributions() -> dict[str, BaseDistribution]:
@@ -945,7 +979,7 @@ class SARIMAX(ForecastModel):
handles_missing = False
uses_exogenous = True
multiple_seasonality = False
- native_multivariate = False
+ native_multioutput = False
supports_engines = ("sktime",)
_module = "sktime.forecasting.sarimax"
@@ -1081,7 +1115,7 @@ class STL(ForecastModel):
handles_missing = False
uses_exogenous = False
multiple_seasonality = False
- native_multivariate = False
+ native_multioutput = False
supports_engines = ("sktime",)
_module = "sktime.forecasting.trend"
@@ -1158,7 +1192,7 @@ class TBATS(ForecastModel):
handles_missing = False
uses_exogenous = False
multiple_seasonality = True
- native_multivariate = False
+ native_multioutput = False
supports_engines = ("sktime",)
_module = "sktime.forecasting.tbats"
@@ -1242,7 +1276,7 @@ class Theta(ForecastModel):
handles_missing = False
uses_exogenous = False
multiple_seasonality = False
- native_multivariate = False
+ native_multioutput = False
supports_engines = ("sktime",)
_module = "sktime.forecasting.theta"
@@ -1300,7 +1334,7 @@ class VAR(ForecastModel):
handles_missing = False
uses_exogenous = False
multiple_seasonality = False
- native_multivariate = True
+ native_multioutput = True
supports_engines = ("sktime",)
_module = "sktime.forecasting.var"
@@ -1359,7 +1393,7 @@ class VARMAX(ForecastModel):
handles_missing = False
uses_exogenous = True
multiple_seasonality = False
- native_multivariate = True
+ native_multioutput = True
supports_engines = ("sktime",)
_module = "sktime.forecasting.var"
diff --git a/atom/plots/predictionplot.py b/atom/plots/predictionplot.py
index fe4f7d038..628155365 100644
--- a/atom/plots/predictionplot.py
+++ b/atom/plots/predictionplot.py
@@ -565,7 +565,7 @@ def plot_det(
display=display,
)
- @available_if(has_task("regression"))
+ @available_if(has_task("!classification"))
@crash
def plot_errors(
self,
@@ -1071,32 +1071,36 @@ def plot_forecast(
xaxis, yaxis = BasePlot._fig.get_axes()
# Draw original time series
- for ds in ("train", "test"):
- fig.add_trace(
- go.Scatter(
- x=self._get_plot_index(getattr(self, ds)),
- y=getattr(self, ds)[target_c],
- mode="lines+markers",
- line={
- "width": 2,
- "color": "black",
- "dash": BasePlot._fig.get_elem(ds, "dash"),
- },
- opacity=0.6,
- name=ds,
- showlegend=False if models else BasePlot._fig.showlegend(ds, legend),
- xaxis=xaxis,
- yaxis=yaxis,
+ for ds in ("train", "test", "holdout"):
+ if getattr(self, ds) is not None:
+ fig.add_trace(
+ go.Scatter(
+ x=self._get_plot_index(getattr(self, ds)),
+ y=getattr(self, ds)[target_c],
+ mode="lines+markers",
+ line={
+ "width": 2,
+ "color": "black",
+ "dash": BasePlot._fig.get_elem(ds, "dash"),
+ },
+ opacity=0.6,
+ name=ds,
+ showlegend=False if models else BasePlot._fig.showlegend(ds, legend),
+ xaxis=xaxis,
+ yaxis=yaxis,
+ )
)
- )
# Draw predictions
for m in models_c:
- # TODO: Fix the way we get fh
if isinstance(fh, str):
- pass
+ # Get fh and corresponding X from data set
+ fh = self.branch._get_rows(fh).index
+ X = m.X.loc[fh]
+ elif X is not None:
+ X = m.transform(X)
- y_pred = m.predict(fh, X)
+ y_pred = m.predict(fh=fh, X=X)
if self.task.is_multioutput:
y_pred = y_pred[target_c]
@@ -1114,7 +1118,7 @@ def plot_forecast(
if plot_interval:
try:
- y_pred = m.predict_interval(fh, X)
+ y_pred = m.predict_interval(fh=fh, X=X)
except NotImplementedError:
continue # Fails for some models like ES
@@ -1570,6 +1574,7 @@ def plot_lift(
display=display,
)
+ @available_if(has_task("!forecast"))
@crash
def plot_parshap(
self,
@@ -1593,7 +1598,7 @@ def plot_parshap(
line) performed worse on the test set than on the training set.
If the estimator has a `scores_`, `feature_importances_` or
`coef_` attribute, its normalized values are shown in a color
- map.
+ map. This plot is not available for [forecast][time-series] tasks.
Parameters
----------
@@ -2785,7 +2790,7 @@ def plot_probabilities(
display=display,
)
- @available_if(has_task("regression"))
+ @available_if(has_task("!classification"))
@crash
def plot_residuals(
self,
diff --git a/atom/plots/shapplot.py b/atom/plots/shapplot.py
index 81ab3552a..0242f8788 100644
--- a/atom/plots/shapplot.py
+++ b/atom/plots/shapplot.py
@@ -16,13 +16,14 @@
import matplotlib.pyplot as plt
import shap
from beartype import beartype
+from sklearn.utils.metaestimators import available_if
from atom.plots.baseplot import BasePlot
from atom.utils.types import (
Bool, Int, IntLargerZero, Legend, ModelSelector, RowSelector,
TargetsSelector,
)
-from atom.utils.utils import check_canvas, crash
+from atom.utils.utils import check_canvas, crash, has_task
@beartype
@@ -36,6 +37,7 @@ class ShapPlot(BasePlot, metaclass=ABCMeta):
"""
+ @available_if(has_task("!forecast"))
@crash
def plot_shap_bar(
self,
@@ -150,6 +152,7 @@ class is always the positive one.
display=display,
)
+ @available_if(has_task("!forecast"))
@crash
def plot_shap_beeswarm(
self,
@@ -262,6 +265,7 @@ class is always the positive one.
display=display,
)
+ @available_if(has_task("!forecast"))
@crash
def plot_shap_decision(
self,
@@ -386,6 +390,7 @@ class is always the positive one.
display=display,
)
+ @available_if(has_task("!forecast"))
@crash
def plot_shap_force(
self,
@@ -517,6 +522,7 @@ class is always the positive one.
return None
+ @available_if(has_task("!forecast"))
@crash
def plot_shap_heatmap(
self,
@@ -633,6 +639,7 @@ class is always the positive one.
display=display,
)
+ @available_if(has_task("!forecast"))
@crash
def plot_shap_scatter(
self,
@@ -752,6 +759,7 @@ class is always the positive one.
display=display,
)
+ @available_if(has_task("!forecast"))
@crash
def plot_shap_waterfall(
self,
diff --git a/atom/utils/types.py b/atom/utils/types.py
index 5400f7cdc..a415473ce 100644
--- a/atom/utils/types.py
+++ b/atom/utils/types.py
@@ -186,7 +186,7 @@ def predict(self, *args, **kwargs) -> Pandas: ...
XSelector: TypeAlias = XTypes | Callable[..., XTypes]
YTypes: TypeAlias = dict[str, Any] | Sequence[Any] | XSelector
YSelector: TypeAlias = Int | str | YTypes
-FHSelector: TypeAlias = int | Sequence[Any] | ForecastingHorizon
+FHSelector: TypeAlias = Int | Sequence[Any] | ForecastingHorizon
# Return types for transform methods
TReturn: TypeAlias = np.ndarray | sps.spmatrix | Series | DataFrame
diff --git a/atom/utils/utils.py b/atom/utils/utils.py
index c2713de39..801c5c4b4 100644
--- a/atom/utils/utils.py
+++ b/atom/utils/utils.py
@@ -2715,24 +2715,6 @@ def check(runner: BaseRunner) -> bool:
return check
-def has_attr(attr: str) -> Callable:
- """Check that the instance has attribute `attr`.
-
- Parameters
- ----------
- attr: str
- Name of the attribute to check.
-
- """
-
- def check(runner: BaseRunner) -> bool:
- # Raise original `AttributeError` if `attr` does not exist
- getattr(runner, attr)
- return True
-
- return check
-
-
def estimator_has_attr(attr: str) -> Callable:
"""Check that the estimator has attribute `attr`.
diff --git a/docs_sources/scripts/autodocs.py b/docs_sources/scripts/autodocs.py
index 49e61c6a5..e66abb179 100644
--- a/docs_sources/scripts/autodocs.py
+++ b/docs_sources/scripts/autodocs.py
@@ -429,8 +429,6 @@ def get_tags(self) -> str:
text += " [native multilabel][multilabel]{ .md-tag }"
if getattr(self.obj, "native_multioutput", False):
text += " [native multioutput][multioutput-tasks]{ .md-tag }"
- if getattr(self.obj, "native_multivariate", False):
- text += " [native multivariate][multivariate]{ .md-tag }"
if getattr(self.obj, "validation", None):
text += " [in-training validation][]{ .md-tag }"
if any(engine not in ("sklearn", "sktime") for engine in self.obj.supports_engines):
diff --git a/docs_sources/user_guide/data_management.md b/docs_sources/user_guide/data_management.md
index e1715c78e..3ae14557c 100644
--- a/docs_sources/user_guide/data_management.md
+++ b/docs_sources/user_guide/data_management.md
@@ -206,14 +206,14 @@ for each sample.
Multivariate is the multioutput task for forecasting. In this case, we
try to forecast more than one time series at the same time.
-Although all forecasting models in ATOM support multivariate tasks, we
+Although all forecasting models in ATOM support multioutput tasks, we
differentiate two types of models:
-* The "native multivariate" models apply forecasts where every prediction
- of endogeneous (`y`) variables will depend on values of the other target
+* The "native multioutput" models apply forecasts where every prediction
+ of endogenous (`y`) variables will depend on values of the other target
columns.
* The rest of the models apply an estimator per column, meaning that forecasts
- will be made per endogeneous variable, and not be affected by other variables.
+ will be made per endogenous variable, and not be affected by other variables.
To access the column-wise estimators, use the estimator's `forecasters_`
parameter, which stores the fitted forecasters in a dataframe.
@@ -223,8 +223,8 @@ Read more about time series tasks [here][time-series].
Some models have native support for multioutput tasks. This means that
the original estimator is used to make predictions directly on all the
-target columns. Examples of such models are [KNearestNeighbors][],
-[RandomForest][] and [ExtraTrees][].
+target columns. See the [model selection][] section for how to get an
+overview of all models and their tags, including `native_multioutput`.
### Non-native multioutput models
@@ -246,7 +246,7 @@ meta-estimators are respectively:
!!! warning
Currently, scikit-learn metrics do not support multiclass-multioutput
classification tasks. In this case, ATOM calculates the mean of the
- selected metric over every individual target.
+ selected metric over every target.
!!! tip
* Set the `native_multilabel` or `native_multioutput` parameter in
diff --git a/docs_sources/user_guide/models.md b/docs_sources/user_guide/models.md
index 44709837e..75670bddf 100644
--- a/docs_sources/user_guide/models.md
+++ b/docs_sources/user_guide/models.md
@@ -38,7 +38,7 @@ per task, but can include:
- **acronym:** Model's acronym (used to call the model).
- **fullname:** Name of the model's class.
-- **estimator:** Class of the model's underlying estimator.
+- **estimator:** Name of the model's underlying estimator.
- **module:** The estimator's module.
- **handles_missing:** Whether the model can handle missing (`NaN`) values
without preprocessing. If False, consider using the [Imputer][] class
@@ -51,7 +51,6 @@ per task, but can include:
[seasonality period][seasonality].
- **native_multilabel:** Whether the model has native support for [multilabel][] tasks.
- **native_multioutput:** Whether the model has native support for [multioutput tasks][].
-- **native_multivariate:** Whether the model has native support for [multivariate][] tasks.
- **validation:** Whether the model has [in-training validation][].
- **supports_engines:** [Engines][estimator-acceleration] supported by the model.
diff --git a/docs_sources/user_guide/plots.md b/docs_sources/user_guide/plots.md
index 23490ebbe..b532abbd0 100644
--- a/docs_sources/user_guide/plots.md
+++ b/docs_sources/user_guide/plots.md
@@ -141,13 +141,12 @@ To avoid having to recalculate the values for every plot, ATOM stores
the shapley values internally after the first calculation, and access
them later when needed again.
-!!! note
- Since the plot figures are not made by ATOM, note the following:
-
+!!! warning
* It's not possible to draw multiple models in the same figure.
Selecting more than one model will raise an exception. To avoid
this, call the plot directly from a model, e.g., `#!python atom.lr.plot_shap_force()`.
* The returned plot is a matplotlib figure, not plotly's.
+ * SHAP plots aren't available for [forecast][time-series] tasks.
diff --git a/docs_sources/user_guide/time_series.md b/docs_sources/user_guide/time_series.md
index 577dd7c67..90b93d674 100644
--- a/docs_sources/user_guide/time_series.md
+++ b/docs_sources/user_guide/time_series.md
@@ -1,21 +1,7 @@
# Time series
-------------
-
-## Forecast
-
-
-
-
-## Time series classification
-
-
-
-
-## Time series regression
-
-
-
+Introduction
## Exogenous variables
@@ -29,3 +15,9 @@
effects, which are patterns that tend to recur at consistent intervals.
The same period is used for all columns in a [multivariate][] setting.
+
+
+
+## Forecasting with regressors
+
+In-sample predictions aren't available when forecasting with a regressor, so scores on the training set are `NaN`.
diff --git a/tests/test_basemodel.py b/tests/test_basemodel.py
index ed82a8385..3499bc219 100644
--- a/tests/test_basemodel.py
+++ b/tests/test_basemodel.py
@@ -53,9 +53,9 @@ def test_repr():
def test_dir():
"""Assert that __dir__ contains all the extra attributes."""
- atom = ATOMClassifier(X_bin, y_bin, random_state=1)
+ atom = ATOMRegressor(X_reg, y_reg, random_state=1)
atom.run("dummy")
- assert all(attr in dir(atom.dummy) for attr in ("y", "mean radius", "head"))
+ assert all(attr in dir(atom.dummy) for attr in ("y", "age", "head"))
def test_getattr():
diff --git a/tests/test_baserunner.py b/tests/test_baserunner.py
index 02eb0a419..ff1e12bf8 100644
--- a/tests/test_baserunner.py
+++ b/tests/test_baserunner.py
@@ -43,9 +43,9 @@ def test_getstate_and_setstate():
def test_dir():
"""Assert that __dir__ contains all the extra attributes."""
- atom = ATOMClassifier(X_bin, y_bin, random_state=1)
+ atom = ATOMRegressor(X_reg, y_reg, random_state=1)
atom.run("dummy")
- assert all(attr in dir(atom) for attr in ("X", "main", "mean radius", "dummy"))
+ assert all(attr in dir(atom) for attr in ("X", "main", "age", "dummy"))
def test_getattr_branch():
diff --git a/tests/test_models.py b/tests/test_models.py
index 72044a119..b22557f7c 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -106,7 +106,12 @@ def test_all_models_regression():
def test_all_models_forecast():
"""Assert that all models work with forecast."""
atom = ATOMForecaster(y_fc, random_state=2)
- atom.run(models=None, n_trials=5, errors="raise")
+ atom.run(
+ models="!DF",
+ n_trials=5,
+ ht_params={"catch": (Exception,)},
+ errors="raise",
+ )
@pytest.mark.skipif(machine() not in ("x86_64", "AMD64"), reason="Only x86 support")
diff --git a/tests/test_plots.py b/tests/test_plots.py
index 54b0b6115..f26191c04 100644
--- a/tests/test_plots.py
+++ b/tests/test_plots.py
@@ -22,8 +22,8 @@
from atom.utils.utils import NotFittedError
from .conftest import (
- X10, X10_str, X_bin, X_class, X_label, X_reg, X_sparse, X_text, y10, y_bin,
- y_class, y_fc, y_label, y_multiclass, y_reg,
+ X10, X10_str, X_bin, X_class, X_ex, X_label, X_reg, X_sparse, X_text, y10,
+ y_bin, y_class, y_ex, y_fc, y_label, y_multiclass, y_reg,
)
@@ -550,6 +550,14 @@ def test_plot_feature_importance():
atom.tree.plot_feature_importance(show=5, display=False)
+def test_plot_forecast():
+ """Assert that the plot_forecast method works."""
+ atom = ATOMForecaster(X_ex, y=y_ex, holdout_size=0.1, random_state=1)
+ atom.run(models=["NF", "ES"])
+ atom.plot_forecast(display=False)
+ atom.plot_forecast(fh=atom.holdout.index, X=atom.holdout, display=False)
+
+
def test_plot_gains():
"""Assert that the plot_gains method works."""
atom = ATOMClassifier(X_bin, y_bin, random_state=1)