Skip to content

Commit

Permalink
fixed issue with history slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
attilabalint committed Jun 21, 2024
1 parent cc7be56 commit 73bd136
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/enfobench/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.0"
__version__ = "0.7.1"
14 changes: 8 additions & 6 deletions src/enfobench/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def get_history(self, cutoff_date: pd.Timestamp) -> pd.DataFrame:
"""Returns the history of the target variable up to the cutoff date.
The cutoff date is the timestamp when the forecast is made,
therefore it is included in the history.
therefore the cutoff_date is not included in the history.
Args:
cutoff_date: The cutoff date.
Expand All @@ -158,13 +158,14 @@ def get_history(self, cutoff_date: pd.Timestamp) -> pd.DataFrame:
The history of the target variable up to the cutoff date.
"""
self._check_cutoff_in_rage(cutoff_date)
return self._target[self._target.index <= cutoff_date]
return self._target[self._target.index < cutoff_date]

def get_past_covariates(self, cutoff_date: pd.Timestamp) -> pd.DataFrame | None:
"""Returns the past covariates for the cutoff date.
The cutoff date is the timestamp when the forecast is made,
therefore it is included in the past covariates.
The cutoff date is the timestamp when the forecast is made.
As the covariates are weather parameters measured at the indicated timestamp,
the cutoff_date is included in the past covariates.
Args:
cutoff_date: The cutoff date.
Expand All @@ -181,8 +182,9 @@ def get_past_covariates(self, cutoff_date: pd.Timestamp) -> pd.DataFrame | None:
def get_future_covariates(self, cutoff_date: pd.Timestamp) -> pd.DataFrame | None:
"""Returns the future covariates for the cutoff date.
The cutoff date is the timestamp when the forecast is made,
therefore it is not included in the future covariates.
The cutoff date is the timestamp when the forecast is made.
As the covariates are weather parameters measured at the indicated timestamp,
the cutoff_date is not included in the future covariates.
Args:
cutoff_date: The cutoff date.
Expand Down
4 changes: 2 additions & 2 deletions tests/test_dataset/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def test_get_history(helpers, resolution):
assert isinstance(history.index, pd.DatetimeIndex)
assert "y" in history.columns
assert len(history.columns) == 1
assert cutoff_date in history.index
assert (history.index <= cutoff_date).all()
assert cutoff_date not in history.index
assert (history.index < cutoff_date).all()


@pytest.mark.parametrize("resolution", ["15T", "30T", "1H"])
Expand Down
7 changes: 4 additions & 3 deletions tests/test_evaluations/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,16 @@ def test_cross_validate_univariate_locally(helpers, model):
assert "yhat" in forecasts.columns

assert "timestamp" in forecasts.columns
assert (forecasts.timestamp > start_date).all()
assert forecasts.timestamp.iloc[-1] == end_date
assert (forecasts.timestamp >= start_date).all()
assert forecasts.timestamp.iloc[0] == start_date

assert "cutoff_date" in forecasts.columns
assert (start_date.time() == forecasts.cutoff_date.dt.time).all()

for cutoff_date, forecast in forecasts.groupby("cutoff_date"):
assert len(forecast) == 38 * 2 # 38 hours with half-hour series
assert (forecast.timestamp > cutoff_date).all()
assert forecast.timestamp.iloc[0] == cutoff_date
assert (forecast.timestamp >= cutoff_date).all()

assert list(forecasts.cutoff_date.unique()) == list(
pd.date_range(start="2020-02-01 10:00:00", end="2020-03-30 10:00:00", freq="1D", inclusive="both")
Expand Down
3 changes: 2 additions & 1 deletion tests/test_evaluations/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,5 @@ def test_create_forecast_index(helpers, freq, horizon):
assert isinstance(index, pd.DatetimeIndex)
assert index.freq == history.index.freq
assert len(index) == horizon
assert (index > cutoff_date).all()
assert index[0] == cutoff_date
assert (index >= cutoff_date).all()

0 comments on commit 73bd136

Please sign in to comment.