Skip to content

Commit

Permalink
added plot_series
Browse files Browse the repository at this point in the history
  • Loading branch information
tvdboom committed Dec 22, 2023
1 parent 26c05d0 commit fbccfb3
Show file tree
Hide file tree
Showing 15 changed files with 280 additions and 121 deletions.
2 changes: 1 addition & 1 deletion atom/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -2078,7 +2078,7 @@ def _run(self, trainer: BaseRunner):
self._delete_models(model.name)
self._log(
f"Consecutive runs of model {model.name}. "
"The former model has been overwritten.", 1,
"The former model has been overwritten.", 3,
)

self._models.extend(trainer._models)

Check notice on line 2084 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _models of a class
Expand Down
90 changes: 48 additions & 42 deletions atom/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
)
from sklearn.utils import resample
from sklearn.utils.metaestimators import available_if
from sktime.forecasting.base import ForecastingHorizon
from sktime.forecasting.compose import make_reduction
from sktime.proba.normal import Normal
from sktime.split import SingleWindowSplitter
Expand All @@ -61,18 +62,18 @@
from atom.plots import RunnerPlot
from atom.utils.constants import DF_ATTRS
from atom.utils.types import (
HT, Backend, Bool, DataFrame, Engine, FHSelector, Float, FloatZeroToOneExc,
Index, Int, IntLargerEqualZero, MetricConstructor, MetricFunction, NJobs,
Pandas, PredictionMethods, PredictionMethodsTS, Predictor, RowSelector,
Scalar, Scorer, Sequence, Stages, TargetSelector, Verbose, Warnings,
XSelector, YSelector, dataframe_t, float_t, int_t,
HT, Backend, Bool, DataFrame, Engine, Float, FloatZeroToOneExc, Index, Int,
IntLargerEqualZero, MetricConstructor, MetricFunction, NJobs, Pandas,
PredictionMethods, PredictionMethodsTS, Predictor, RowSelector, Scalar,
Scorer, Sequence, Stages, TargetSelector, Verbose, Warnings, XSelector,
YSelector, dataframe_t, float_t, int_t,
)
from atom.utils.utils import (
ClassMap, DataConfig, Goal, PlotCallback, ShapExplanation, Task,
TrialsCallback, adjust_verbosity, bk, cache, check_dependency, check_empty,
check_is_fitted, check_scaling, composed, crash, estimator_has_attr,
fit_and_score, flt, get_cols, get_custom_scorer, has_task, it, lst, merge,
method_to_log, rnd, sign, time_to_str, to_df, to_pandas, to_series,
method_to_log, rnd, sign, time_to_str, to_pandas,
)


Expand Down Expand Up @@ -540,7 +541,7 @@ def _fit_estimator(
# Multi-objective optimization doesn't support pruning
if trial and len(self._metric) == 1:
trial.report(
self._score_from_est(self._metric[0], estimator, *validation),
value=float(self._score_from_est(self._metric[0], estimator, *validation)),
step=step,
)

Expand Down Expand Up @@ -726,15 +727,16 @@ def _score_from_est(
"""
if self.task.is_forecast:
# Sktime uses signature estimator.predict(fh, X)
y_pred = to_series(estimator.predict(y.index, check_empty(X)), index=y.index)
return self._score_from_pred(scorer, y, y_pred, **kwargs)
elif self.task is Task.multiclass_multioutput_classification:
# Calculate predictions with shape=(n_samples, n_targets)
y_pred = to_df(estimator.predict(X), index=y.index, columns=y.columns)
return self._score_from_pred(scorer, y, y_pred, **kwargs)
y_pred = estimator.predict(fh=y.index, X=check_empty(X))
else:
return scorer(estimator, X, y, **kwargs)
y_pred = to_pandas(
data=estimator.predict(X),
index=y.index,
columns=getattr(y, "columns", None),
name=getattr(y, "name", None),
)

return self._score_from_pred(scorer, y, y_pred, **kwargs)

def _score_from_pred(
self,
Expand Down Expand Up @@ -3012,13 +3014,17 @@ def _prediction(
called.
"""
Xt, yt = X, y # self.transform(X, y, verbose=verbose) TODO: Fix pipeline ts
Xt, yt = self.transform(X, y, verbose=verbose)

Check notice on line 3017 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

if method != "score":
fh = kwargs.get("fh")
if fh is not None and not isinstance(fh, ForecastingHorizon):
kwargs["fh"] = self.branch._get_rows(fh).index

Check notice on line 3022 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_rows of a class

if "y" in sign(func := getattr(self.estimator, method)):
return self.memory.cache(func)(y=yt, X=Xt, **kwargs)
return self.memory.cache(func)(fh=fh, y=yt, X=Xt, **kwargs)

Check warning on line 3025 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Fixture is not requested by test functions

Fixture 'self.memory.cache' is not requested by test functions or @pytest.mark.usefixtures marker
else:
return self.memory.cache(func)(X=Xt, **kwargs)
return self.memory.cache(func)(fh=fh, X=Xt, **kwargs)

Check warning on line 3027 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Fixture is not requested by test functions

Fixture 'self.memory.cache' is not requested by test functions or @pytest.mark.usefixtures marker
else:
if metric is None:
scorer = self._metric[0]
Expand All @@ -3031,7 +3037,7 @@ def _prediction(
@composed(crash, method_to_log, beartype)
def predict(
self,
fh: FHSelector,
fh: RowSelector | ForecastingHorizon,
X: XSelector | None = None,

Check notice on line 3041 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
*,
verbose: Int | None = None,
Expand All @@ -3046,9 +3052,9 @@ def predict(
Parameters
----------
fh: int, range, sequence or [ForecastingHorizon][]
The forecasting horizon encoding the time stamps to
forecast at.
fh: hashable, segment, sequence, dataframe or [ForecastingHorizon][]
The [forecasting horizon][row-and-column-selection] encoding
the time stamps to forecast at.
X: hashable, segment, sequence, dataframe-like or None, default=None
Exogenous time series corresponding to `fh`.
Expand All @@ -3070,7 +3076,7 @@ def predict(
@composed(crash, method_to_log, beartype)
def predict_interval(
self,
fh: FHSelector,
fh: RowSelector | ForecastingHorizon,
X: XSelector | None = None,

Check notice on line 3080 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
*,
coverage: Float | Sequence[Float] = 0.9,
Expand All @@ -3086,9 +3092,9 @@ def predict_interval(
Parameters
----------
fh: int, range, sequence or [ForecastingHorizon][]
The forecasting horizon encoding the time stamps to
forecast at.
fh: hashable, segment, sequence, dataframe or [ForecastingHorizon][]
The [forecasting horizon][row-and-column-selection] encoding
the time stamps to forecast at.
X: hashable, segment, sequence, dataframe-like or None, default=None
Exogenous time series corresponding to `fh`.
Expand Down Expand Up @@ -3119,7 +3125,7 @@ def predict_interval(
@composed(crash, method_to_log, beartype)
def predict_proba(
self,
fh: FHSelector,
fh: RowSelector | ForecastingHorizon,
X: XSelector | None = None,

Check notice on line 3129 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
*,
marginal: Bool = True,
Expand All @@ -3135,9 +3141,9 @@ def predict_proba(
Parameters
----------
fh: int, range, sequence or [ForecastingHorizon][]
The forecasting horizon encoding the time stamps to
forecast at.
fh: hashable, segment, sequence, dataframe or [ForecastingHorizon][]
The [forecasting horizon][row-and-column-selection] encoding
the time stamps to forecast at.
X: hashable, segment, sequence, dataframe-like or None, default=None
Exogenous time series corresponding to `fh`.
Expand Down Expand Up @@ -3167,7 +3173,7 @@ def predict_proba(
@composed(crash, method_to_log, beartype)
def predict_quantiles(
self,
fh: FHSelector,
fh: RowSelector | ForecastingHorizon,
X: XSelector | None = None,

Check notice on line 3177 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
*,
alpha: Float | Sequence[Float] = (0.05, 0.95),
Expand All @@ -3183,9 +3189,9 @@ def predict_quantiles(
Parameters
----------
fh: int, range, sequence or [ForecastingHorizon][]
The forecasting horizon encoding the time stamps to
forecast at.
fh: hashable, segment, sequence, dataframe or [ForecastingHorizon][]
The [forecasting horizon][row-and-column-selection] encoding
the time stamps to forecast at.
X: hashable, segment, sequence, dataframe-like or None, default=None
Exogenous time series corresponding to `fh`.
Expand Down Expand Up @@ -3256,7 +3262,7 @@ def predict_residuals(
@composed(crash, method_to_log, beartype)
def predict_var(
self,
fh: RowSelector | FHSelector,
fh: RowSelector | ForecastingHorizon,
X: XSelector | None = None,

Check notice on line 3266 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
*,
cov: Bool = False,
Expand All @@ -3272,9 +3278,9 @@ def predict_var(
Parameters
----------
fh: int, range, sequence or [ForecastingHorizon][]
The forecasting horizon encoding the time stamps to
forecast at.
fh: hashable, segment, sequence, dataframe or [ForecastingHorizon][]
The [forecasting horizon][row-and-column-selection] encoding
the time stamps to forecast at.
X: hashable, segment, sequence, dataframe-like or None, default=None
Exogenous time series corresponding to `fh`.
Expand Down Expand Up @@ -3308,7 +3314,7 @@ def score(
self,
y: RowSelector | YSelector,
X: XSelector | None = None,

Check notice on line 3316 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
fh: FHSelector | None = None,
fh: RowSelector | ForecastingHorizon | None = None,
*,
metric: str | MetricFunction | Scorer | None = None,
verbose: Int | None = None,
Expand All @@ -3334,9 +3340,9 @@ def score(
X: hashable, segment, sequence, dataframe-like or None, default=None
Exogenous time series corresponding to `fh`.
fh: int, sequence or [ForecastingHorizon][] or None, default=None
The forecasting horizon encoding the time stamps to
forecast at.
fh: hashable, segment, sequence, dataframe, [ForecastingHorizon][] or None, default=None
The [forecasting horizon][row-and-column-selection] encoding
the time stamps to forecast at.
metric: str, func, scorer or None, default=None
Metric to calculate. Choose from any of sklearn's scorers,
Expand Down
4 changes: 2 additions & 2 deletions atom/baserunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,15 @@ def holdout(self) -> DataFrame | None:
@property
def models(self) -> str | list[str] | None:
"""Name of the model(s)."""
if isinstance(self._models, ClassMap):
if self._models:
return flt(self._models.keys())
else:
return None

@property
def metric(self) -> str | list[str] | None:
"""Name of the metric(s)."""
if isinstance(self._metric, ClassMap):
if self._metric:
return flt(self._metric.keys())
else:
return None
Expand Down
2 changes: 1 addition & 1 deletion atom/basetransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def _inherit(self, obj: T_Estimator) -> T_Estimator:
# Add seasonal period to the estimator
if hasattr(self, "_config") and self._config.sp:
if "sp" in signature and getattr(obj, "sp", "<!>") == signature["sp"]._default:

Check notice on line 380 in atom/basetransformer.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _default of a class
obj.sp = self._config.sp
obj.sp = self._config.sp if self.multiple_seasonality else flt(self._config.sp)

return obj

Expand Down
24 changes: 24 additions & 0 deletions atom/models/ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,30 @@ def _get_est(self, params: dict[str, Any]) -> Predictor:
"""
return super()._get_est({"season_length": self._config.sp or 1} | params)

def _get_parameters(self, trial: Trial) -> dict:
"""Get the trial's hyperparameters.
Parameters
----------
trial: [Trial][]
Current trial.
Returns
-------
dict
Trial's hyperparameters.
"""
params = super()._get_parameters(trial)

# MSTL has stl_kwargs, that takes a dict of hyperparameters
if "stl_kwargs" in self._est_params:
new_params = {}
else:
new_params = {"stl_kwargs": params}

return new_params

@staticmethod
def _get_distributions() -> dict[str, BaseDistribution]:
"""Get the predefined hyperparameter distributions.
Expand Down
26 changes: 15 additions & 11 deletions atom/plots/baseplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,17 +418,21 @@ def _draw_line(
"""
return go.Scatter(
line={
"width": self.line_width,
"color": BasePlot._fig.get_elem(parent),
"dash": BasePlot._fig.get_elem(child, "dash"),
},
marker={
"symbol": BasePlot._fig.get_elem(child, "marker"),
"size": self.marker_size,
"color": BasePlot._fig.get_elem(parent),
"line": {"width": 1, "color": "rgba(255, 255, 255, 0.9)"},
},
line=kwargs.pop(
"line", {
"width": self.line_width,
"color": BasePlot._fig.get_elem(parent),
"dash": BasePlot._fig.get_elem(child, "dash"),
}
),
marker=kwargs.pop(
"marker", {
"symbol": BasePlot._fig.get_elem(child, "marker"),
"size": self.marker_size,
"color": BasePlot._fig.get_elem(parent),
"line": {"width": 1, "color": "rgba(255, 255, 255, 0.9)"},
}
),
hovertemplate=kwargs.pop(
"hovertemplate",
f"(%{{x}}, %{{y}})<extra>{parent}{f' - {child}' if child else ''}</extra>",
Expand Down
Loading

0 comments on commit fbccfb3

Please sign in to comment.