Skip to content

Commit

Permalink
fixed dataset.utils bug
Browse files Browse the repository at this point in the history
  • Loading branch information
attila-balint-kul committed Nov 24, 2023
1 parent 2013762 commit ca6f3fd
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/enfobench/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.3"
__version__ = "0.3.4"
14 changes: 8 additions & 6 deletions src/enfobench/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def create_perfect_forecasts_from_covariates(
*,
horizon: pd.Timedelta,
step: pd.Timedelta,
**kwargs,
start: pd.Timestamp | None = None,
) -> pd.DataFrame:
"""Create forecasts from covariates.
Expand All @@ -19,29 +19,31 @@ def create_perfect_forecasts_from_covariates(
past_covariates: The external covariates.
horizon: The forecast horizon.
step: The step size between forecasts.
start: The start date of the forecast. If None, the first date of the covariates is used.
Returns:
The external forecast dataframe.
"""
start = kwargs.get("start", past_covariates.index[0])
start = start or past_covariates.index[0]
last_date = past_covariates.index[-1]

forecasts = []
while start + horizon <= last_date:
forecast = past_covariates.loc[(past_covariates.index > start) & (past_covariates.index <= start + horizon)]
forecast.insert(0, "cutoff_date", start)
forecast.rename_axis("timestamp", inplace=True)
forecast.reset_index(inplace=True)
forecast["cutoff_date"] = start.isoformat() # pd.concat fails if cutoff_date is a Timestamp

if len(forecast) == 0:
warnings.warn(
f"Covariates not found for {start} - {start + horizon}, cannot make forecast at step {start}",
UserWarning,
stacklevel=2,
)

forecasts.append(forecast)
else:
forecasts.append(forecast)
start += step

forecast_df = pd.concat(forecasts, ignore_index=True)
forecast_df = pd.concat(forecasts, ignore_index=False)
forecast_df["cutoff_date"] = pd.to_datetime(forecast_df["cutoff_date"]) # convert back to Timestamp
return forecast_df
8 changes: 6 additions & 2 deletions src/enfobench/evaluation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@ def periods_in_duration(target: pd.DatetimeIndex, duration: timedelta | pd.Timed

first_delta = target[1] - target[0]
last_delta = target[-1] - target[-2]
assert first_delta == last_delta, "Season length is not constant"
if first_delta != last_delta:
msg = f"Season length is not constant: '{first_delta}' != '{last_delta}'"
raise ValueError(msg)

periods = duration / first_delta
assert periods.is_integer(), "Season length is not a multiple of the frequency"
if not periods.is_integer():
msg = f"Season length '{duration}' is not a multiple of the frequency '{first_delta}'"
raise ValueError(msg)
return int(periods)


Expand Down
Empty file added tests/test_dataset/__init__.py
Empty file.
File renamed without changes.
26 changes: 26 additions & 0 deletions tests/test_dataset/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import numpy as np
import pandas as pd

from enfobench.dataset.utils import create_perfect_forecasts_from_covariates


def test_create_perfect_forecasts_from_covariates():
index = pd.date_range(start="2020-01-01", end="2020-10-02 13:54:00", freq="1H")
past_covariates = pd.DataFrame(
index=index,
data=np.random.rand(len(index), 2),
columns=["covariate_1", "covariate_2"],
)

future_covariates = create_perfect_forecasts_from_covariates(
past_covariates,
start=pd.Timestamp("2020-01-01"),
step=pd.Timedelta("1D"),
horizon=pd.Timedelta("7D"),
)

assert isinstance(future_covariates, pd.DataFrame)
assert "covariate_1" in future_covariates.columns
assert "covariate_2" in future_covariates.columns
assert "timestamp" in future_covariates.columns
assert "cutoff_date" in future_covariates.columns

0 comments on commit ca6f3fd

Please sign in to comment.