-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implemented multivariate forecast benchmarking
- Loading branch information
1 parent
76a66e4
commit f874419
Showing
12 changed files
with
167 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = "0.1.1" | ||
__version__ = "0.2.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import warnings | ||
|
||
import pandas as pd | ||
|
||
|
||
def steps_in_horizon(horizon: pd.Timedelta, freq: str) -> int:
    """Return the number of steps in a given horizon.

    Parameters
    ----------
    horizon:
        The horizon to be split into steps.
    freq:
        The frequency of the horizon, given as a fixed-duration string that
        ``pd.Timedelta`` can parse. A leading multiplier is optional; when
        it is missing a multiplier of 1 is assumed.

    Returns
    -------
    The number of steps in the horizon.

    Raises
    ------
    ValueError
        If ``freq`` is empty, is not a positive duration, or if ``horizon``
        is not an exact multiple of ``freq``.
    """
    if not freq:
        raise ValueError("Frequency must be a non-empty string")

    # pd.Timedelta needs an explicit multiplier, so prefix bare units with "1".
    step = pd.Timedelta(freq if freq[0].isdigit() else "1" + freq)
    zero = pd.Timedelta(0)
    if step <= zero:
        raise ValueError("Frequency must be a positive duration")

    # divmod on Timedeltas is exact integer arithmetic. True division goes
    # through floats and can silently round a non-multiple horizon to a whole
    # number once the horizon exceeds float precision.
    periods, remainder = divmod(pd.Timedelta(horizon), step)
    if remainder != zero:
        raise ValueError("Horizon is not a multiple of the frequency")
    return int(periods)
|
||
|
||
def create_forecast_index(history: pd.DataFrame, horizon: int) -> pd.DatetimeIndex:
    """Create time index for a forecast horizon.

    The frequency is inferred from the ``ds`` column of ``history`` and the
    returned index starts one period after the last timestamp in that column.

    Parameters
    ----------
    history:
        The history of the time series. Must contain a datetime column
        named ``ds``.
    horizon:
        The forecast horizon, as a number of periods.

    Returns
    -------
    The time index for the forecast horizon.

    Raises
    ------
    ValueError
        If ``history`` has no rows or no regular frequency can be inferred
        from its ``ds`` column.
    """
    timestamps = history["ds"]
    if timestamps.empty:
        raise ValueError("History must contain at least one timestamp")

    inferred_freq = timestamps.dt.freq
    if inferred_freq is None:
        raise ValueError("Could not infer a regular frequency from the 'ds' column of history")

    # A DateOffset handles both fixed and calendar-based frequencies
    # (e.g. month start), which pd.Timedelta cannot represent.
    offset = pd.tseries.frequencies.to_offset(inferred_freq)
    return pd.date_range(
        start=timestamps.iloc[-1] + offset,
        periods=horizon,
        freq=offset,
    )
|
||
|
||
def create_perfect_forecasts_from_covariates(
    covariates: pd.DataFrame,
    horizon: pd.Timedelta,
    step: pd.Timedelta,
    **kwargs,
) -> pd.DataFrame:
    """Create forecasts from covariates.

    Sometimes external forecasts are not available for the entire horizon. This function creates
    external forecast dataframe from external covariates as a perfect forecast.

    Starting at the first cutoff, every cutoff date ``c`` yields the covariate rows with
    timestamps in ``(c, c + horizon]``; the cutoff then advances by ``step`` for as long as
    ``c + horizon`` does not exceed the last timestamp of ``covariates``.

    Parameters
    ----------
    covariates:
        The external covariates, indexed by timestamp.
    horizon:
        The forecast horizon.
    step:
        The step size between forecasts. Must be positive.
    **kwargs:
        ``start`` may be passed to set the first cutoff date; when it is missing or None the
        first timestamp of ``covariates`` is used. Other keyword arguments are ignored.

    Returns
    -------
    The external forecast dataframe with columns ``ds``, ``cutoff_date`` and all covariate
    columns.

    Raises
    ------
    ValueError
        If ``step`` is not positive, if ``covariates`` has no rows, or if no forecast window
        with data fits between the first cutoff and the last covariate timestamp.
    """
    # A zero or negative step would never move the cutoff past the end of the data.
    if pd.Timedelta(step) <= pd.Timedelta(0):
        raise ValueError("Step must be a positive duration")
    if len(covariates.index) == 0:
        raise ValueError("Covariates must contain at least one row")

    start = kwargs.get("start")
    if start is None:
        start = covariates.index[0]

    last_date = covariates.index[-1]

    forecasts = []
    while start + horizon <= last_date:
        in_window = (covariates.index > start) & (covariates.index <= start + horizon)

        if in_window.any():
            # Build a fresh frame instead of mutating the selection in place.
            forecast = covariates.loc[in_window].rename_axis("ds").reset_index()
            forecast.insert(1, "cutoff_date", start)
            forecasts.append(forecast)
        else:
            warnings.warn(
                f"Covariates not found for {start} - {start + horizon}, cannot make forecast at step {start}",
                UserWarning,
                stacklevel=2,
            )

        start += step

    if not forecasts:
        raise ValueError("No forecasts could be created: covariates do not cover a full horizon")

    return pd.concat(forecasts, ignore_index=True)
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import pandas as pd | ||
import pytest | ||
|
||
from enfobench.evaluation import utils | ||
|
||
|
||
@pytest.mark.parametrize(
    "horizon, freq, expected",
    [
        ("1 day", "15T", 96),
        ("1 day", "1H", 24),
        ("7 days", "1H", 7 * 24),
        ("1D", "1D", 1),
        ("1H", "1H", 1),
    ],
)
def test_steps_in_horizon(horizon, freq, expected):
    """Check the step count returned for each horizon/frequency pair."""
    parsed_horizon = pd.Timedelta(horizon)
    n_steps = utils.steps_in_horizon(parsed_horizon, freq)
    assert n_steps == expected
|
||
|
||
def test_steps_in_horizon_raises_with_non_multiple_horizon():
    """A horizon that is not a whole number of steps must raise ValueError."""
    uneven_horizon = pd.Timedelta("36 minutes")
    with pytest.raises(ValueError):
        utils.steps_in_horizon(uneven_horizon, "15T")
|
||
|
||
def test_create_forecast_index(target):
    """The forecast index has the requested length and frequency and lies after the history."""
    horizon = 96
    history = target.to_frame("y")
    history = history.rename_axis("ds")
    history = history.reset_index()
    last_date = history["ds"].iloc[-1]

    index = utils.create_forecast_index(history, horizon)

    assert isinstance(index, pd.DatetimeIndex)
    assert index.freq == target.index.freq
    assert len(index) == horizon
    # Every generated timestamp must be strictly after the end of the history.
    assert (index > last_date).all()
|
||
|
||
def test_create_perfect_forecasts_from_covariates(covariates):
    """The result is a DataFrame with ds, cutoff_date and every covariate column."""
    one_week = pd.Timedelta("7 days")
    one_day = pd.Timedelta("1D")

    forecasts = utils.create_perfect_forecasts_from_covariates(
        covariates,
        horizon=one_week,
        step=one_day,
    )

    assert isinstance(forecasts, pd.DataFrame)
    for required in ("ds", "cutoff_date"):
        assert required in forecasts.columns
    for col in covariates.columns:
        assert col in forecasts.columns