Skip to content

Commit

Permalink
added visualizations and changed cutoff date to be the first forecast…
Browse files Browse the repository at this point in the history
… date, rather than the last measured date
  • Loading branch information
attila-balint-kul committed Dec 7, 2023
1 parent d0bde3f commit 32641a9
Show file tree
Hide file tree
Showing 14 changed files with 390 additions and 396 deletions.
2 changes: 1 addition & 1 deletion models/sf-auto-ces/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _forecast(self, y: pd.Series, level: list[int] | None = None) -> pd.DataFram
pred = model.forecast(y=y.values, h=periods_in_7_days, level=level)

# Create index for forecast
index = create_forecast_index(history=y.to_frame('y'), horizon=periods_in_7_days)
index = create_forecast_index(history=y.to_frame("y"), horizon=periods_in_7_days)

# Postprocess forecast
self._last_prediction = pd.DataFrame(index=index, data=pred).rename(columns={"mean": "yhat"}).fillna(y.mean())
Expand Down
10 changes: 4 additions & 6 deletions models/sf-mstl/src/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os

import pandas as pd
from statsforecast.models import MSTL

Expand All @@ -11,7 +9,7 @@
class MSTLModel:
def info(self) -> ModelInfo:
return ModelInfo(
name=f"Statsforecast.MSTL.1D.7D",
name="Statsforecast.MSTL.1D.7D",
authors=[AuthorInfo(name="Attila Balint", email="[email protected]")],
type=ForecasterType.quantile,
params={
Expand All @@ -35,9 +33,9 @@ def forecast(
y = history.y.fillna(history.y.mean())

# Create model
periods_in_1D = periods_in_duration(y.index, duration='1D')
periods_in_7D = periods_in_duration(y.index, duration='7D')
model = MSTL(season_length=[periods_in_1D, periods_in_7D])
periods_in_one_day = periods_in_duration(y.index, duration="1D")
periods_in_one_week = periods_in_duration(y.index, duration="7D")
model = MSTL(season_length=[periods_in_one_day, periods_in_one_week])

# Make forecast
pred = model.forecast(y=y.values, h=horizon, level=level, **kwargs)
Expand Down
6 changes: 3 additions & 3 deletions src/enfobench/dataset/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def get_history(self, cutoff_date: pd.Timestamp) -> pd.DataFrame:
The history of the target variable up to the cutoff date.
"""
self._check_cutoff_in_rage(cutoff_date)
return self._target[self._target.index <= cutoff_date]
return self._target[self._target.index < cutoff_date]

def get_past_covariates(self, cutoff_date: pd.Timestamp) -> pd.DataFrame | None:
"""Returns the past covariates for the cutoff date.
Expand All @@ -366,7 +366,7 @@ def get_past_covariates(self, cutoff_date: pd.Timestamp) -> pd.DataFrame | None:
return None

self._check_cutoff_in_rage(cutoff_date)
return self._past_covariates[self._past_covariates.index <= cutoff_date]
return self._past_covariates[self._past_covariates.index < cutoff_date]

def get_future_covariates(self, cutoff_date: pd.Timestamp) -> pd.DataFrame | None:
"""Returns the future covariates for the cutoff date.
Expand All @@ -387,7 +387,7 @@ def get_future_covariates(self, cutoff_date: pd.Timestamp) -> pd.DataFrame | Non

future_covariates = self._future_covariates[
(self._future_covariates.cutoff_date == last_past_cutoff_date)
& (self._future_covariates.timestamp > cutoff_date)
& (self._future_covariates.timestamp >= cutoff_date)
]
future_covariates.set_index("timestamp", inplace=True)
return future_covariates
4 changes: 2 additions & 2 deletions src/enfobench/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def create_perfect_forecasts_from_covariates(

forecasts = []
while start + horizon <= last_date:
forecast = past_covariates.loc[(past_covariates.index > start) & (past_covariates.index <= start + horizon)]
forecast = past_covariates.loc[(past_covariates.index >= start) & (past_covariates.index < start + horizon)]
forecast.rename_axis("timestamp", inplace=True)
forecast.reset_index(inplace=True)
forecast["cutoff_date"] = pd.to_datetime(start, unit="ns")
forecast.insert(0, "cutoff_date", pd.to_datetime(start, unit="ns"))
forecast.set_index(["cutoff_date", "timestamp"], inplace=True)

if len(forecast) == 0:
Expand Down
7 changes: 6 additions & 1 deletion src/enfobench/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from enfobench.dataset import Dataset
from enfobench.evaluation.client import ForecastClient
from enfobench.evaluation.model import Model
from enfobench.evaluation.utils import generate_cutoff_dates, steps_in_horizon
from enfobench.evaluation.utils import generate_cutoff_dates, steps_in_horizon, create_forecast_index


def _compute_metric(forecast: pd.DataFrame, metric: Callable) -> float:
Expand Down Expand Up @@ -140,6 +140,11 @@ def cross_validate(
)
raise ValueError(msg)

expected_index = create_forecast_index(history, horizon_length)
if not (expected_index == forecast.index).all():
msg = "Forecast index does not match the expected index."
raise ValueError(msg)

forecast_contains_nans = forecast.isna().any(axis=None)
if forecast_contains_nans:
msg = "Forecast contains NaNs, make sure to fill in missing values."
Expand Down
16 changes: 10 additions & 6 deletions src/enfobench/evaluation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,20 @@ def periods_in_duration(target: pd.DatetimeIndex, duration: timedelta | pd.Timed
duration = pd.Timedelta(duration)
elif isinstance(duration, str):
duration = pd.Timedelta(duration)
elif isinstance(duration, pd.Timedelta):
pass
else:
msg = f"Duration must be one of [pd.Timedelta, timedelta, str], got {type(duration)}"
raise ValueError(msg)

first_delta = target[1] - target[0]
last_delta = target[-1] - target[-2]
if first_delta != last_delta:
msg = f"Season length is not constant: '{first_delta}' != '{last_delta}'"
if len(target.diff()[1:].unique()) != 1:
msg = f"Multiple frequencies found: '{[td for td in target.diff()[1:].unique()]}'"
raise ValueError(msg)

periods = duration / first_delta
delta_t = target.diff()[1:].unique()[0]
periods = duration / target.diff()[1:].unique()[0]
if not periods.is_integer():
msg = f"Season length '{duration}' is not a multiple of the frequency '{first_delta}'"
msg = f"Season length '{duration}' is not a multiple of the frequency '{delta_t}'"
raise ValueError(msg)
return int(periods)

Expand Down
Loading

0 comments on commit 32641a9

Please sign in to comment.