Skip to content

Commit

Permalink
added utils for periods in duration
Browse files Browse the repository at this point in the history
  • Loading branch information
attila-balint-kul committed Nov 24, 2023
1 parent a5b7c20 commit 2013762
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/enfobench/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.2"
__version__ = "0.3.3"
17 changes: 17 additions & 0 deletions src/enfobench/evaluation/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import timedelta

import pandas as pd


Expand All @@ -19,6 +21,21 @@ def steps_in_horizon(horizon: pd.Timedelta, freq: str) -> int:
return int(periods)


def periods_in_duration(target: pd.DatetimeIndex, duration: timedelta | pd.Timedelta | str) -> int:
if isinstance(duration, timedelta):
duration = pd.Timedelta(duration)
elif isinstance(duration, str):
duration = pd.Timedelta(duration)

first_delta = target[1] - target[0]
last_delta = target[-1] - target[-2]
assert first_delta == last_delta, "Season length is not constant"

periods = duration / first_delta
assert periods.is_integer(), "Season length is not a multiple of the frequency"
return int(periods)


def create_forecast_index(history: pd.DataFrame, horizon: int) -> pd.DatetimeIndex:
"""Create time index for a forecast horizon.
Expand Down
55 changes: 55 additions & 0 deletions tests/test_evaluations/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import timedelta

import pandas as pd
import pytest

Expand Down Expand Up @@ -48,3 +50,56 @@ def test_create_perfect_forecasts_from_covariates(covariates):
assert "timestamp" in forecasts.columns
assert "cutoff_date" in forecasts.columns
assert all(col in forecasts.columns for col in covariates.columns)


@pytest.mark.parametrize(
"freq, duration, expected",
[
("15T", "15T", 1),
("15T", "1H", 4),
("15T", "1D", 96),
("15T", "2D", 2 * 96),
("15T", "1W", 7 * 96),
("15T", pd.Timedelta("15T"), 1),
("15T", pd.Timedelta("1H"), 4),
("15T", pd.Timedelta("1D"), 96),
("15T", pd.Timedelta("2D"), 2 * 96),
("15T", pd.Timedelta("1W"), 7 * 96),
("15T", timedelta(minutes=15), 1),
("15T", timedelta(hours=1), 4),
("15T", timedelta(days=1), 96),
("15T", timedelta(days=2), 2 * 96),
("15T", timedelta(weeks=1), 7 * 96),
("30T", "30T", 1),
("30T", "1H", 2),
("30T", "1D", 48),
("30T", "2D", 2 * 48),
("30T", "1W", 7 * 48),
("30T", pd.Timedelta("30T"), 1),
("30T", pd.Timedelta("1H"), 2),
("30T", pd.Timedelta("1D"), 48),
("30T", pd.Timedelta("2D"), 2 * 48),
("30T", pd.Timedelta("1W"), 7 * 48),
("30T", timedelta(minutes=30), 1),
("30T", timedelta(hours=1), 2),
("30T", timedelta(days=1), 48),
("30T", timedelta(days=2), 2 * 48),
("30T", timedelta(weeks=1), 7 * 48),
("1H", "1H", 1),
("1H", "1D", 24),
("1H", "2D", 2 * 24),
("1H", "1W", 7 * 24),
("1H", pd.Timedelta("1H"), 1),
("1H", pd.Timedelta("1D"), 24),
("1H", pd.Timedelta("2D"), 2 * 24),
("1H", pd.Timedelta("1W"), 7 * 24),
("1H", timedelta(hours=1), 1),
("1H", timedelta(days=1), 24),
("1H", timedelta(days=2), 2 * 24),
("1H", timedelta(weeks=1), 7 * 24),
],
)
def test_periods_in_duration(freq, duration, expected):
target = pd.date_range(start="2020-01-01", end="2020-02-01", freq=freq)

assert utils.periods_in_duration(target, duration) == expected

0 comments on commit 2013762

Please sign in to comment.