From 73bd1367e83ac0925572a4064a29e2bb3143194c Mon Sep 17 00:00:00 2001 From: attilabalint Date: Fri, 21 Jun 2024 19:24:17 +0200 Subject: [PATCH] fixed issue with history slicing --- src/enfobench/__version__.py | 2 +- src/enfobench/core/dataset.py | 14 ++++++++------ tests/test_dataset/test_dataset.py | 4 ++-- tests/test_evaluations/test_evaluate.py | 7 ++++--- tests/test_evaluations/test_utils.py | 3 ++- 5 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/enfobench/__version__.py b/src/enfobench/__version__.py index 49e0fc1..a5f830a 100644 --- a/src/enfobench/__version__.py +++ b/src/enfobench/__version__.py @@ -1 +1 @@ -__version__ = "0.7.0" +__version__ = "0.7.1" diff --git a/src/enfobench/core/dataset.py b/src/enfobench/core/dataset.py index 4205450..1adc3c5 100644 --- a/src/enfobench/core/dataset.py +++ b/src/enfobench/core/dataset.py @@ -149,7 +149,7 @@ def get_history(self, cutoff_date: pd.Timestamp) -> pd.DataFrame: """Returns the history of the target variable up to the cutoff date. The cutoff date is the timestamp when the forecast is made, - therefore it is included in the history. + therefore the cutoff_date is not included in the history. Args: cutoff_date: The cutoff date. @@ -158,13 +158,14 @@ def get_history(self, cutoff_date: pd.Timestamp) -> pd.DataFrame: The history of the target variable up to the cutoff date. """ self._check_cutoff_in_rage(cutoff_date) - return self._target[self._target.index <= cutoff_date] + return self._target[self._target.index < cutoff_date] def get_past_covariates(self, cutoff_date: pd.Timestamp) -> pd.DataFrame | None: """Returns the past covariates for the cutoff date. - The cutoff date is the timestamp when the forecast is made, - therefore it is included in the past covariates. + The cutoff date is the timestamp when the forecast is made. + As the covariates are weather parameters measured at the indicated timestamp, + the cutoff_date is included in the past covariates. Args: cutoff_date: The cutoff date. @@ -181,8 +182,9 @@ def get_past_covariates(self, cutoff_date: pd.Timestamp) -> pd.DataFrame | None: def get_future_covariates(self, cutoff_date: pd.Timestamp) -> pd.DataFrame | None: """Returns the future covariates for the cutoff date. - The cutoff date is the timestamp when the forecast is made, - therefore it is not included in the future covariates. + The cutoff date is the timestamp when the forecast is made. + As the covariates are weather parameters measured at the indicated timestamp, + the cutoff_date is not included in the future covariates. Args: cutoff_date: The cutoff date. diff --git a/tests/test_dataset/test_dataset.py b/tests/test_dataset/test_dataset.py index 5820a9b..b6f1177 100644 --- a/tests/test_dataset/test_dataset.py +++ b/tests/test_dataset/test_dataset.py @@ -17,8 +17,8 @@ def test_get_history(helpers, resolution): assert isinstance(history.index, pd.DatetimeIndex) assert "y" in history.columns assert len(history.columns) == 1 - assert cutoff_date in history.index - assert (history.index <= cutoff_date).all() + assert cutoff_date not in history.index + assert (history.index < cutoff_date).all() @pytest.mark.parametrize("resolution", ["15T", "30T", "1H"]) diff --git a/tests/test_evaluations/test_evaluate.py b/tests/test_evaluations/test_evaluate.py index c2892c4..3fdf0b2 100644 --- a/tests/test_evaluations/test_evaluate.py +++ b/tests/test_evaluations/test_evaluate.py @@ -26,15 +26,16 @@ def test_cross_validate_univariate_locally(helpers, model): assert "yhat" in forecasts.columns assert "timestamp" in forecasts.columns - assert (forecasts.timestamp > start_date).all() - assert forecasts.timestamp.iloc[-1] == end_date + assert (forecasts.timestamp >= start_date).all() + assert forecasts.timestamp.iloc[0] == start_date assert "cutoff_date" in forecasts.columns assert (start_date.time() == forecasts.cutoff_date.dt.time).all() for cutoff_date, forecast in forecasts.groupby("cutoff_date"): assert len(forecast) == 38 * 2 # 38 hours with half-hour series - assert (forecast.timestamp > cutoff_date).all() + assert forecast.timestamp.iloc[0] == cutoff_date + assert (forecast.timestamp >= cutoff_date).all() assert list(forecasts.cutoff_date.unique()) == list( pd.date_range(start="2020-02-01 10:00:00", end="2020-03-30 10:00:00", freq="1D", inclusive="both") diff --git a/tests/test_evaluations/test_utils.py b/tests/test_evaluations/test_utils.py index 3fc9f0e..c9d8d4e 100644 --- a/tests/test_evaluations/test_utils.py +++ b/tests/test_evaluations/test_utils.py @@ -170,4 +170,5 @@ def test_create_forecast_index(helpers, freq, horizon): assert isinstance(index, pd.DatetimeIndex) assert index.freq == history.index.freq assert len(index) == horizon - assert (index > cutoff_date).all() + assert index[0] == cutoff_date + assert (index >= cutoff_date).all()