Skip to content

Commit

Permalink
sorted crossval dataframe
Browse files Browse the repository at this point in the history
  • Loading branch information
attila-balint-kul committed Nov 24, 2023
1 parent ca6f3fd commit 4d249ef
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/enfobench/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.4"
__version__ = "0.3.5"
6 changes: 4 additions & 2 deletions src/enfobench/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def create_perfect_forecasts_from_covariates(
forecast = past_covariates.loc[(past_covariates.index > start) & (past_covariates.index <= start + horizon)]
forecast.rename_axis("timestamp", inplace=True)
forecast.reset_index(inplace=True)
forecast["cutoff_date"] = start.isoformat() # pd.concat fails if cutoff_date is a Timestamp
forecast["cutoff_date"] = pd.to_datetime(start, unit="ns")
forecast.set_index(["cutoff_date", "timestamp"], inplace=True)

if len(forecast) == 0:
warnings.warn(
Expand All @@ -45,5 +46,6 @@ def create_perfect_forecasts_from_covariates(
start += step

forecast_df = pd.concat(forecasts, ignore_index=False)
forecast_df["cutoff_date"] = pd.to_datetime(forecast_df["cutoff_date"]) # convert back to Timestamp
forecast_df.reset_index(inplace=True)
forecast_df.sort_values(by=["cutoff_date", "timestamp"], inplace=True)
return forecast_df
1 change: 1 addition & 0 deletions src/enfobench/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,5 @@ def cross_validate(

# Merge the forecast with the target
crossval_df = crossval_df.merge(dataset._target, left_on="timestamp", right_index=True)
crossval_df.sort_values(by=["cutoff_date", "timestamp"], inplace=True)
return crossval_df
1 change: 1 addition & 0 deletions tests/test_dataset/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def test_create_perfect_forecasts_from_covariates():
data=np.random.rand(len(index), 2),
columns=["covariate_1", "covariate_2"],
)
past_covariates.drop(index[2], inplace=True)

future_covariates = create_perfect_forecasts_from_covariates(
past_covariates,
Expand Down

0 comments on commit 4d249ef

Please sign in to comment.