Skip to content

Commit

Permalink
fix plot_forecast
Browse files (browse the repository at this point in the history)
  • Loading branch information
tvdboom committed Dec 21, 2023
1 parent 3efd838 commit 26c05d0
Show file tree
Hide file tree
Showing 22 changed files with 339 additions and 249 deletions.
2 changes: 1 addition & 1 deletion atom/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ def ATOMModel(
if acronym:
estimator_c.acronym = acronym
estimator_c.needs_scaling = needs_scaling
estimator_c.native_multioutput = native_multioutput
estimator_c.native_multilabel = native_multilabel
estimator_c.native_multioutput = native_multioutput
estimator_c.validation = validation

return estimator_c
Expand Down
73 changes: 54 additions & 19 deletions atom/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,10 +459,12 @@ def _get_est(self, params: dict[str, Any]) -> Predictor:
estimator = MultiOutputClassifier(estimator)
elif self.task.is_regression:
estimator = MultiOutputRegressor(estimator)
elif self.task.is_forecast:
if hasattr(self, "_estimators") and self._goal.name not in self._estimators:
# Forecasting task with a regressor
estimator = make_reduction(estimator)
elif self.task.is_forecast and self._goal.name not in self._estimators:
# Forecasting task with a regressor
if self.native_multioutput:
estimator = make_reduction(estimator, strategy="multioutput")
else:
estimator = make_reduction(estimator, strategy="recursive")

return self._inherit(estimator)

Expand Down Expand Up @@ -558,7 +560,6 @@ def _fit_estimator(
est_params_fit["fh"] = est_params_fit.get("fh", self.test.index)

estimator.fit(data[1], X=check_empty(data[0]), **est_params_fit)

else:
estimator.fit(*data, **est_params_fit)

Expand Down Expand Up @@ -674,7 +675,11 @@ def _get_pred(
y_true = y.loc[y.index.isin(self._all.index)]

if self.task.is_forecast:
y_pred = self._prediction(fh=X.index, X=check_empty(X), verbose=0, method=attr)
try:
y_pred = self._prediction(fh=X.index, X=check_empty(X), verbose=0, method=attr)
except (ValueError, NotImplementedError):
# In-sample predictions aren't implemented for some models
y_pred = bk.Series([np.NaN] * len(X), index=X.index)
else:
y_pred = self._prediction(X.index, verbose=0, method=attr)

Expand Down Expand Up @@ -2451,11 +2456,32 @@ def transform(
class ClassRegModel(BaseModel):
"""Classification and regression models."""

_prediction_methods = (
"predict",
"predict_proba",
"predict_log_proba",
"decision_function",
)

def __init__(self, *args, **kwargs):
"""Assign prediction methods depending on the task.
Regressors can be used for forecast tasks, hence we need to
overwrite the default prediction methods.
"""
super().__init__(*args, **kwargs)

if self._goal is Goal.forecast:
for method in ("_prediction", *ForecastModel._prediction_methods):

Check notice on line 2476 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _prediction_methods of a class
setattr(self.__class__, method, getattr(ForecastModel, method))

@crash
def get_tags(self) -> dict[str, Any]:
"""Get the model's tags.
Return class parameters that provide general information about
the estimator's characteristics.
the model's characteristics.
Returns
-------
Expand All @@ -2467,14 +2493,14 @@ def get_tags(self) -> dict[str, Any]:
"acronym": self.acronym,
"fullname": self.fullname,
"estimator": self._est_class,
"module": self._est_class.__module__.split(".")[0] + self._module,
"handles_missing": self.handles_missing,
"module": self._est_class.__module__,
"handles_missing": getattr(self, "handles_missing", None),
"needs_scaling": self.needs_scaling,
"accepts_sparse": self.accepts_sparse,
"accepts_sparse": getattr(self, "accepts_sparse", None),
"native_multilabel": self.native_multilabel,
"native_multioutput": self.native_multioutput,
"validation": self.validation,
"supports_engines": ", ".join(self.supports_engines),
"supports_engines": ", ".join(getattr(self, "supports_engines", [])),
}

@overload
Expand Down Expand Up @@ -2885,11 +2911,21 @@ def score(
class ForecastModel(BaseModel):
"""Forecasting models."""

_prediction_methods = (
"predict",
"predict_interval",
"predict_proba",
"predict_quantiles",
"predict_residuals",
"predict_var",
)

@crash
def get_tags(self) -> dict[str, Any]:
"""Get the model's tags.
Return class parameters that provide general information about
the estimator's characteristics.
the model's characteristics.
Returns
-------
Expand All @@ -2900,13 +2936,12 @@ def get_tags(self) -> dict[str, Any]:
return {
"acronym": self.acronym,
"fullname": self.fullname,
"estimator": self._est_class.__name__,
"module": self._est_class.__module__.split(".")[0] + self._module,
"handles_missing": self.handles_missing,
"in_sample_prediction": self.in_sample_prediction,
"multiple_seasonality": self.multiple_seasonality,
"native_multivariate": self.native_multivariate,
"supports_engines": ", ".join(self.supports_engines),
"estimator": self._est_class,
"module": self._est_class.__module__,
"handles_missing": getattr(self, "handles_missing", None),
"multiple_seasonality": getattr(self, "multiple_seasonality", None),
"native_multioutput": self.native_multioutput,
"supports_engines": ", ".join(getattr(self, "supports_engines", [])),
}

@overload
Expand Down
10 changes: 4 additions & 6 deletions atom/baserunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
TargetSelector, YSelector, dataframe_t, int_t, segment_t, sequence_t,
)
from atom.utils.utils import (
ClassMap, DataContainer, SeasonalPeriod, Task, bk, check_is_fitted,
ClassMap, DataContainer, Goal, SeasonalPeriod, Task, bk, check_is_fitted,
composed, crash, divide, flt, get_cols, get_segment, get_versions,
has_task, lst, merge, method_to_log, n_cols,
)
Expand Down Expand Up @@ -705,7 +705,7 @@ def _has_data_sets(
X_test, y_test = self._check_input(arrays[1][0], arrays[1][1])

Check notice on line 705 in atom/baserunner.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
sets = _has_data_sets(X_train, y_train, X_test, y_test)
elif isinstance(arrays[1], (*int_t, str)) or n_cols(arrays[1]) == 1:
if not self._goal.name == "forecast":
if self._goal is not Goal.forecast:
# X, y
sets = _no_data_sets(*self._check_input(arrays[0], arrays[1]))
else:
Expand Down Expand Up @@ -899,7 +899,7 @@ def available_models(self) -> pd.DataFrame:
- **acronym:** Model's acronym (used to call the model).
- **fullname:** Name of the model's class.
- **estimator:** Class of the model's underlying estimator.
- **estimator:** Name of the model's underlying estimator.
- **module:** The estimator's module.
- **handles_missing:** Whether the model can handle missing
(`NaN`) values without preprocessing. If False, consider using
Expand All @@ -914,15 +914,13 @@ def available_models(self) -> pd.DataFrame:
for [multilabel][] tasks.
- **native_multioutput:** Whether the model has native support
for [multioutput tasks][].
- **native_multivariate:** Whether the model has native support
for [multivariate][] tasks.
- **validation:** Whether the model has [in-training validation][].
- **supports_engines:** Engines supported by the model.
"""
rows = []
for model in MODELS:
m = model(goal=self._goal)
m = model(goal=self._goal, branches=self._branches)
if self._goal.name in m._estimators:

Check notice on line 924 in atom/baserunner.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _estimators of a class
rows.append(m.get_tags())

Expand Down
102 changes: 50 additions & 52 deletions atom/basetrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from atom.baserunner import BaseRunner
from atom.branch import BranchManager
from atom.data_cleaning import BaseTransformer
from atom.models import MODELS, CustomModel
from atom.models import MODELS, create_custom_model
from atom.plots import RunnerPlot
from atom.utils.types import Model, sequence_t
from atom.utils.utils import (
Expand Down Expand Up @@ -168,60 +168,59 @@ def _prepare_parameters(self):
if isinstance(model, str):
for m in model.split("+"):
if m.startswith("!"):
exc.append(m[1:])
else:
try:
if len(name := m.split("_", 1)) > 1:
name, tag = name[0].lower(), f"_{name[1]}"
else:
name, tag = name[0].lower(), ""

cls = next(n for n in MODELS if n.acronym.lower() == name)

except StopIteration:
exc.append(m[1:].lower())
continue

try:
if len(name := m.split("_", 1)) > 1:
name, tag = name[0].lower(), f"_{name[1]}"
else:
name, tag = name[0].lower(), ""

cls = next(n for n in MODELS if n.acronym.lower() == name)

except StopIteration:
raise ValueError(
f"Invalid value for the models parameter, got {m}. "
"Note that tags must be separated by an underscore. "
"Available model are:\n" +
"\n".join(
[
f" --> {m.__name__} ({m.acronym})"
for m in MODELS
if self._goal.name in m._estimators

Check notice on line 191 in atom/basetrainer.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _estimators of a class
]
)
) from None

# Check if libraries for non-sklearn models are available
dependencies = {
"BATS": "tbats",
"CatB": "catboost",
"LGB": "lightgbm",
"MSTL": "statsforecast",
"TBATS": "tbats",
"XGB": "xgboost",
}
if cls.acronym in dependencies:
check_dependency(dependencies[cls.acronym])

# Check if the model supports the task
if self._goal.name not in cls._estimators:
# Forecast task can use regression models
if self._goal is not Goal.forecast or "regression" not in cls._estimators:
raise ValueError(
f"Invalid value for the models parameter, got {m}. "
"Note that tags must be separated by an underscore. "
"Available model are:\n" +
"\n".join(
[
f" --> {m.__name__} ({m.acronym})"
for m in MODELS
if self._goal.name in m._estimators
]
)
) from None

# Check if libraries for non-sklearn models are available
dependencies = {
"BATS": "tbats",
"CatB": "catboost",
"LGB": "lightgbm",
"MSTL": "statsforecast",
"TBATS": "tbats",
"XGB": "xgboost",
}
if cls.acronym in dependencies:
check_dependency(dependencies[cls.acronym])

# Check if the model supports the task
if self._goal.name not in cls._estimators:
# Forecast task can use regression models
if self._goal.name == "forecast" and "regression" in cls._estimators:
kwargs["goal"] = Goal.Regression
else:
raise ValueError(
f"The {cls.__name__} model is not "
f"available for {self.task.name} tasks!"
)

inc.append(cls(name=f"{cls.acronym}{tag}", **kwargs))
f"The {cls.__name__} model is not "
f"available for {self.task.name} tasks!"
)

inc.append(cls(name=f"{cls.acronym}{tag}", **kwargs))

elif isinstance(model, Model): # For new instances or reruns
inc.append(model)

else: # Model is a custom estimator
inc.append(CustomModel(estimator=model, **kwargs))
inc.append(create_custom_model(estimator=model, **kwargs))

if inc and exc:
raise ValueError(
Expand All @@ -239,9 +238,8 @@ def _prepare_parameters(self):
self._models = ClassMap(*inc)
else:
self._models = ClassMap(
model(**kwargs)
for model in MODELS
if self._goal.name in model._estimators and model.acronym not in exc
model(**kwargs) for model in MODELS
if self._goal.name in model._estimators and model.acronym.lower() not in exc

Check notice on line 242 in atom/basetrainer.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _estimators of a class
)

# Prepare est_params ======================================= >>
Expand Down
10 changes: 5 additions & 5 deletions atom/basetransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,13 +372,13 @@ def _inherit(self, obj: T_Estimator) -> T_Estimator:
"""
signature = sign(obj.__init__) # type: ignore[misc]
for p in ("n_jobs", "random_state"):
if p in signature and obj.get_params()[p] == signature[p]._default:
obj.set_params(**{p: getattr(self, p)})
if p in signature and getattr(obj, p, "<!>") == signature[p]._default:

Check notice on line 375 in atom/basetransformer.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _default of a class
setattr(obj, p, getattr(self, p))

# Add seasonal period to the estimator
if self._config.sp:
if "sp" in signature and obj.get_params()["sp"] == signature["sp"]._default:
obj.set_params(sp=self._config.sp)
if hasattr(self, "_config") and self._config.sp:
if "sp" in signature and getattr(obj, "sp", "<!>") == signature["sp"]._default:

Check notice on line 380 in atom/basetransformer.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _default of a class
obj.sp = self._config.sp

return obj

Expand Down
2 changes: 1 addition & 1 deletion atom/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
QuadraticDiscriminantAnalysis, RadiusNearestNeighbors, RandomForest, Ridge,
StochasticGradientDescent, SupportVectorMachine, XGBoost,
)
from atom.models.custom import CustomModel
from atom.models.custom import create_custom_model
from atom.models.ensembles import Stacking, Voting
from atom.models.ts import (
ARIMA, BATS, ETS, MSTL, SARIMAX, STL, TBATS, VAR, VARMAX, AutoARIMA,
Expand Down
Loading

0 comments on commit 26c05d0

Please sign in to comment.