Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
tvdboom committed Dec 25, 2023
1 parent fbccfb3 commit 36fff85
Show file tree
Hide file tree
Showing 19 changed files with 226 additions and 46 deletions.
14 changes: 11 additions & 3 deletions atom/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3014,17 +3014,25 @@ def _prediction(
called.
"""
Xt, yt = self.transform(X, y, verbose=verbose)
if y is not None or X is not None:
if isinstance(out := self.transform(X, y, verbose=verbose), tuple):
Xt, yt = out

Check notice on line 3019 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
elif X is not None:
Xt, yt = out, y

Check notice on line 3021 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
else:
Xt, yt = X, out

Check notice on line 3023 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
else:
Xt, yt = X, y

Check notice on line 3025 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

if method != "score":
fh = kwargs.get("fh")
if fh is not None and not isinstance(fh, ForecastingHorizon):
kwargs["fh"] = self.branch._get_rows(fh).index

Check notice on line 3030 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_rows of a class

if "y" in sign(func := getattr(self.estimator, method)):
return self.memory.cache(func)(fh=fh, y=yt, X=Xt, **kwargs)
return self.memory.cache(func)(y=yt, X=Xt, **kwargs)

Check warning on line 3033 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Fixture is not requested by test functions

Fixture 'self.memory.cache' is not requested by test functions or @pytest.mark.usefixtures marker
else:
return self.memory.cache(func)(fh=fh, X=Xt, **kwargs)
return self.memory.cache(func)(X=Xt, **kwargs)

Check warning on line 3035 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Fixture is not requested by test functions

Fixture 'self.memory.cache' is not requested by test functions or @pytest.mark.usefixtures marker
else:
if metric is None:
scorer = self._metric[0]
Expand Down
29 changes: 24 additions & 5 deletions atom/baserunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
Bool, DataFrame, FloatZeroToOneExc, HarmonicsSelector, Int, IntLargerOne,
MetricConstructor, Model, ModelSelector, ModelsSelector, Pandas,
RowSelector, Scalar, Seasonality, Segment, Sequence, Series,
TargetSelector, YSelector, dataframe_t, int_t, segment_t, sequence_t,
TargetSelector, YSelector, bool_t, dataframe_t, int_t, segment_t,
sequence_t,
)
from atom.utils.utils import (
ClassMap, DataContainer, Goal, SeasonalPeriod, Task, bk, check_is_fitted,
Expand Down Expand Up @@ -888,9 +889,18 @@ def _delete_models(self, models: str | Model | Sequence[str | Model]):
self._metric = ClassMap()

@crash
def available_models(self) -> pd.DataFrame:
def available_models(self, **kwargs) -> pd.DataFrame:
"""Give an overview of the available predefined models.
Parameters
----------
**kwargs
Filter the returned models providing any of the column as
keyword arguments, where the value is the desired filter,
e.g., `accepts_sparse=True`, to get all models that accept
sparse input or `supports_engines="cuml"` to get all models
that support the [cuML][] engine.
Returns
-------
pd.DataFrame
Expand All @@ -902,8 +912,8 @@ def available_models(self) -> pd.DataFrame:
- **estimator:** Name of the model's underlying estimator.
- **module:** The estimator's module.
- **handles_missing:** Whether the model can handle missing
(`NaN`) values without preprocessing. If False, consider using
the [Imputer][] class before training the models.
values without preprocessing. If False, consider using the
[Imputer][] class before training the models.
- **needs_scaling:** Whether the model requires feature scaling.
If True, [automated feature scaling][] is applied.
- **accepts_sparse:** Whether the model accepts [sparse input][sparse-datasets].
Expand All @@ -922,7 +932,16 @@ def available_models(self) -> pd.DataFrame:
for model in MODELS:
m = model(goal=self._goal, branches=self._branches)
if self._goal.name in m._estimators:

Check notice on line 934 in atom/baserunner.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _estimators of a class
rows.append(m.get_tags())
tags = m.get_tags()

for key, value in kwargs.items():
k = tags.get(key)
if isinstance(value, bool_t) and value is not bool(k):
break
elif isinstance(value, str) and not re.search(value, k, re.I):
break
else:
rows.append(tags)

return pd.DataFrame(rows)

Expand Down
12 changes: 6 additions & 6 deletions atom/models/ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


class ARIMA(ForecastModel):
"""Autoregressive Integrated Moving Average Model.
"""Autoregressive Integrated Moving Average.
Seasonal ARIMA models and exogenous input is supported, hence this
estimator is capable of fitting SARIMA, ARIMAX, and SARIMAX.
Expand Down Expand Up @@ -178,7 +178,7 @@ def _get_distributions(self) -> dict[str, BaseDistribution]:


class AutoARIMA(ForecastModel):
"""Automatic Autoregressive Integrated Moving Average Model.
"""Automatic Autoregressive Integrated Moving Average.
[ARIMA][] implementation that includes automated fitting of
(S)ARIMA(X) hyperparameters (p, d, q, P, D, Q). The AutoARIMA
Expand Down Expand Up @@ -649,7 +649,7 @@ def _get_distributions(self) -> dict[str, BaseDistribution]:


class MSTL(ForecastModel):
"""Multiple Seasonal-Trend decomposition using LOESS model.
"""Multiple Seasonal-Trend decomposition using LOESS.
The MSTL decomposes the time series in multiple seasonalities using
LOESS. Then forecasts the trend using a custom non-seasonal model
Expand Down Expand Up @@ -956,7 +956,7 @@ def _get_distributions() -> dict[str, BaseDistribution]:


class SARIMAX(ForecastModel):
"""Seasonal Autoregressive Integrated Moving Average with eXogenous factors.
"""Seasonal Autoregressive Integrated Moving Average.
SARIMAX stands for Seasonal Autoregressive Integrated Moving Average
with eXogenous factors. It extends [ARIMA][] by incorporating seasonal
Expand Down Expand Up @@ -1106,7 +1106,7 @@ def _get_distributions(self) -> dict[str, BaseDistribution]:


class STL(ForecastModel):
"""Seasonal-Trend decomposition using Loess.
"""Seasonal-Trend decomposition using LOESS.
STL is a technique commonly used for decomposing time series data
into components like trend, seasonality, and residuals.
Expand Down Expand Up @@ -1381,7 +1381,7 @@ def _get_distributions() -> dict[str, BaseDistribution]:


class VARMAX(ForecastModel):
"""Vector Autoregressive Moving-Average with exogenous variables.
"""Vector Autoregressive Moving-Average.
VARMAX is an extension of the [VAR][] model that incorporates not
only lagged values of the endogenous variables, but also includes
Expand Down
6 changes: 6 additions & 0 deletions atom/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,9 @@ def transform(
Transformed target column. Only returned if provided.
"""
if X is None and y is None:
raise ValueError("X and y cannot be both None.")

for _, _, transformer in self._iter(**kwargs):
with adjust_verbosity(transformer, self.verbose):
X, y = self._mem_transform(transformer, X, y)

Check notice on line 485 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
Expand Down Expand Up @@ -520,6 +523,9 @@ def inverse_transform(
Transformed target column. Only returned if provided.
"""
if X is None and y is None:
raise ValueError("X and y cannot be both None.")

for _, _, transformer in reversed(list(self._iter())):
with adjust_verbosity(transformer, self.verbose):
X, y = self._mem_transform(transformer, X, y, method="inverse_transform")

Check notice on line 531 in atom/pipeline.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
Expand Down
32 changes: 21 additions & 11 deletions atom/plots/predictionplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,10 +967,11 @@ def plot_feature_importance(
def plot_forecast(
self,
models: ModelsSelector = None,
fh: RowSelector | ForecastingHorizon = "test",
fh: RowSelector | ForecastingHorizon = "dataset",
X: XSelector | None = None,

Check notice on line 971 in atom/plots/predictionplot.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
target: TargetSelector = 0,
*,
plot_insample: Bool = False,
plot_interval: Bool = True,
title: str | dict[str, Any] | None = None,
legend: Legend | dict[str, Any] | None = "upper left",
Expand All @@ -988,7 +989,7 @@ def plot_forecast(
models: int, str, Model, segment, sequence or None, default=None
Models to plot. If None, all models are selected.
fh: hashable, segment, sequence, dataframe or [ForecastingHorizon][], default="test"
fh: hashable, segment, sequence, dataframe or [ForecastingHorizon][], default="dataset"
The [forecasting horizon][row-and-column-selection] for
which to plot the predictions.
Expand All @@ -999,6 +1000,10 @@ def plot_forecast(
target: int or str, default=0
Target column to look at. Only for [multivariate][] tasks.
plot_insample: bool, default=False
Whether to draw in-sample predictions (predictions on the training
set). Models that do not support this feature are silently skipped.
plot_interval: bool, default=True
Whether to plot prediction intervals together with the exact
predicted values. Models wihtout a `predict_interval` method
Expand Down Expand Up @@ -1040,7 +1045,7 @@ def plot_forecast(
--------
atom.plots:DataPlot.plot_distribution
atom.plots:DataPlot.plot_series
atom.plots:PredictionPlot.plot_roc
atom.plots:PredictionPlot.plot_errors
Examples
--------
Expand Down Expand Up @@ -1070,7 +1075,7 @@ def plot_forecast(
fh = self.branch._get_rows(fh).index

Check notice on line 1075 in atom/plots/predictionplot.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_rows of a class

if X is None:
X = self.branch.X.loc[fh]
X = self.branch._all.loc[fh]

Check notice on line 1078 in atom/plots/predictionplot.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

Check notice on line 1078 in atom/plots/predictionplot.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _all of a class
else:
X = self.transform(X)

Check notice on line 1080 in atom/plots/predictionplot.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

Expand All @@ -1083,9 +1088,12 @@ def plot_forecast(
if self.task.is_multioutput:
y_pred = y_pred[target_c]

if not plot_insample:
y_pred.loc[m.branch.train.index] = np.NaN

fig.add_trace(
self._draw_line(
x=self._get_plot_index(y_pred),
x=(x := self._get_plot_index(y_pred)),
y=y_pred,
mode="lines+markers",
parent=m.name,
Expand All @@ -1098,7 +1106,7 @@ def plot_forecast(
if plot_interval:
try:
y_pred = m.predict_interval(fh=fh, X=X)
except NotImplementedError:
except (AttributeError, NotImplementedError):
continue # Fails for some models like ES

if self.task.is_multioutput:
Expand All @@ -1107,10 +1115,13 @@ def plot_forecast(
else:
y = y_pred # Univariate

if not plot_insample:
y_pred.loc[m.branch.train.index] = np.NaN

fig.add_traces(
[
go.Scatter(
x=self._get_plot_index(y_pred),
x=x,
y=y.iloc[:, 1],
mode="lines",
line={"width": 1, "color": BasePlot._fig.get_elem(m.name)},
Expand All @@ -1121,7 +1132,7 @@ def plot_forecast(
yaxis=yaxis,
),
go.Scatter(
x=self._get_plot_index(y_pred),
x=x,
y=y.iloc[:, 0],
mode="lines",
line={"width": 1, "color": BasePlot._fig.get_elem(m.name)},
Expand All @@ -1139,12 +1150,11 @@ def plot_forecast(
# Draw original time series
fig.add_trace(
go.Scatter(
x=y_pred.index,
y=self.branch.dataset.loc[y_pred.index, target_c],
x=x,

Check warning on line 1153 in atom/plots/predictionplot.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unbound local variables

Local variable 'x' might be referenced before assignment
y=self.branch._all.loc[y_pred.index, target_c],

Check notice on line 1154 in atom/plots/predictionplot.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _all of a class

Check warning on line 1154 in atom/plots/predictionplot.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unbound local variables

Local variable 'y_pred' might be referenced before assignment
mode="lines+markers",
line={"width": 1, "color": "black", "dash": "dash"},
opacity=0.6,
layer="below",
showlegend=False,
xaxis=xaxis,
yaxis=yaxis,
Expand Down
18 changes: 16 additions & 2 deletions docs_sources/about.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,13 @@ core project contributors with a set of developer tools free of charge.
</a>
</div>
</div>
<div class="column">
<div class="logo">
<a href="../API/models/prophet" draggable="false">
<img src="../img/logos/prophet.png" alt="prophet" draggable="false">
</a>
</div>
</div>
<div class="column">
<div class="logo">
<a href="../user_guide/accelerating/#gpu-acceleration" draggable="false">
Expand All @@ -229,15 +236,22 @@ core project contributors with a set of developer tools free of charge.
</a>
</div>
</div>
</div>
<div class="row">
<div class="column">
<div class="logo">
<a href="../user_guide/models" draggable="false">
<img src="../img/logos/sklearn.png" alt="scikit-learn" draggable="false">
</a>
</div>
</div>
</div>
<div class="row">
<div class="column">
<div class="logo">
<a href="../user_guide/time-series" draggable="false">
<img src="../img/logos/sktime.png" alt="sktime" draggable="false">
</a>
</div>
</div>
<div class="column">
<div class="logo">
<a href="../API/models/xgb" draggable="false">
Expand Down
2 changes: 1 addition & 1 deletion docs_sources/api/ATOM/atomforecaster.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ of utility methods to handle the data and manage the pipeline.
- eda
- evaluate
- export_pipeline
- get_class_weight
- get_sample_weight
- get_seasonal_period
- inverse_transform
- load
- merge
Expand Down
1 change: 0 additions & 1 deletion docs_sources/api/ATOM/atomregressor.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ of utility methods to handle the data and manage the pipeline.
- eda
- evaluate
- export_pipeline
- get_class_weight
- get_sample_weight
- inverse_transform
- load
Expand Down
Binary file added docs_sources/img/logos/prophet.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs_sources/img/logos/sktime.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs_sources/license.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# MIT License
-------------

Copyright &copy; 2023 Mavs
Copyright &copy; 2019-2024 Mavs

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
7 changes: 6 additions & 1 deletion docs_sources/user_guide/data_management.md
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,12 @@ The check is performed in the order described hereunder:

Additionally, the forecast horizon (parameter `fh`) in [forecasting tasks][time-series]
can be selected much in the same way as `rows`, where the horizon is inferred
as the index of the row selection.
as the index of the row selection. Note that, contrary to sktime's API but for
consistency with the rest of ATOM's API, atom's fh starts with the training set,
i.e., selecting `#!python atom.nf.predict(fh=range(5))` forecasts the first 5
rows of the training set, not the test set. To get the same result as sktime, use
`#!python atom.nf.predict(fh=range(len(atom.test), len(atom.test) + 5))` or
`#!python atom.nf.predict(fh=atom.test.index[:5])` instead.


!!! info
Expand Down
7 changes: 6 additions & 1 deletion docs_sources/user_guide/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ per task, but can include:
- **fullname:** Name of the model's class.
- **estimator:** Name of the model's underlying estimator.
- **module:** The estimator's module.
- **handles_missing:** Whether the model can handle missing (`NaN`) values
- **handles_missing:** Whether the model can handle missing values
without preprocessing. If False, consider using the [Imputer][] class
before training the models.
- **needs_scaling:** Whether the model requires feature scaling. If True,
Expand All @@ -54,6 +54,11 @@ per task, but can include:
- **validation:** Whether the model has [in-training validation][].
- **supports_engines:** [Engines][estimator-acceleration] supported by the model.

To filter for specific tags, specify the column name with the desired value
in the arguments of `available_models`, e.g., `#!python atom.available_models(accepts_sparse=True)`
to get all models that accept sparse input or `#!python atom.available_models(supports_engines="cuml")`
to get all models that support the [cuML][] engine.


<br>

Expand Down
Loading

0 comments on commit 36fff85

Please sign in to comment.