From 4d249eff008a4b2b0224107121016d25730d2131 Mon Sep 17 00:00:00 2001 From: Attila Balint Date: Fri, 24 Nov 2023 15:02:10 +0100 Subject: [PATCH] sorted crossval dataframe --- src/enfobench/__version__.py | 2 +- src/enfobench/dataset/utils.py | 6 ++++-- src/enfobench/evaluation/evaluate.py | 1 + tests/test_dataset/test_utils.py | 1 + 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/enfobench/__version__.py b/src/enfobench/__version__.py index 334b899..a8d4557 100644 --- a/src/enfobench/__version__.py +++ b/src/enfobench/__version__.py @@ -1 +1 @@ -__version__ = "0.3.4" +__version__ = "0.3.5" diff --git a/src/enfobench/dataset/utils.py b/src/enfobench/dataset/utils.py index 1d7903e..6e83215 100644 --- a/src/enfobench/dataset/utils.py +++ b/src/enfobench/dataset/utils.py @@ -32,7 +32,8 @@ def create_perfect_forecasts_from_covariates( forecast = past_covariates.loc[(past_covariates.index > start) & (past_covariates.index <= start + horizon)] forecast.rename_axis("timestamp", inplace=True) forecast.reset_index(inplace=True) - forecast["cutoff_date"] = start.isoformat() # pd.concat fails if cutoff_date is a Timestamp + forecast["cutoff_date"] = pd.to_datetime(start, unit="ns") + forecast.set_index(["cutoff_date", "timestamp"], inplace=True) if len(forecast) == 0: warnings.warn( @@ -45,5 +46,6 @@ def create_perfect_forecasts_from_covariates( start += step forecast_df = pd.concat(forecasts, ignore_index=False) - forecast_df["cutoff_date"] = pd.to_datetime(forecast_df["cutoff_date"]) # convert back to Timestamp + forecast_df.reset_index(inplace=True) + forecast_df.sort_values(by=["cutoff_date", "timestamp"], inplace=True) return forecast_df diff --git a/src/enfobench/evaluation/evaluate.py b/src/enfobench/evaluation/evaluate.py index 35da01b..0fe219b 100644 --- a/src/enfobench/evaluation/evaluate.py +++ b/src/enfobench/evaluation/evaluate.py @@ -159,4 +159,5 @@ def cross_validate( # Merge the forecast with the target crossval_df = crossval_df.merge(dataset._target, left_on="timestamp", right_index=True) + crossval_df.sort_values(by=["cutoff_date", "timestamp"], inplace=True) return crossval_df diff --git a/tests/test_dataset/test_utils.py b/tests/test_dataset/test_utils.py index 4fafb3a..96bdf9e 100644 --- a/tests/test_dataset/test_utils.py +++ b/tests/test_dataset/test_utils.py @@ -11,6 +11,7 @@ def test_create_perfect_forecasts_from_covariates(): data=np.random.rand(len(index), 2), columns=["covariate_1", "covariate_2"], ) + past_covariates.drop(index[2], inplace=True) future_covariates = create_perfect_forecasts_from_covariates( past_covariates,