-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e3817e8
commit dc6cc54
Showing
12 changed files
with
288 additions
and
110 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = "0.0.6" | ||
__version__ = "0.0.7" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from ._evaluate import ( | ||
evaluate_metric_on_forecast, | ||
evaluate_metric_on_forecasts, | ||
evaluate_metrics_on_forecast, | ||
evaluate_metrics_on_forecasts, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from __future__ import annotations | ||
|
||
import pandas as pd | ||
from tqdm import tqdm | ||
|
||
from enfobench import Model, ForecastClient | ||
from enfobench.utils import steps_in_horizon | ||
|
||
|
||
def generate_cutoff_dates(
    start: pd.Timestamp,
    end: pd.Timestamp,
    horizon: pd.Timedelta,
    step: pd.Timedelta,
) -> list[pd.Timestamp]:
    """Generate cutoff dates for cross-validation.

    Cutoffs begin at ``start`` and advance by ``step`` for as long as a full
    ``horizon`` still fits before ``end`` (i.e. ``cutoff <= end - horizon``).

    Parameters
    ----------
    start:
        Start date of the time series; also the first candidate cutoff.
    end:
        End date of the time series.
    horizon:
        Forecast horizon.
    step:
        Step size between cutoff dates. Must be strictly positive.

    Returns
    -------
    cutoff_dates:
        Cutoff dates in increasing order.

    Raises
    ------
    ValueError
        If ``step`` is zero or negative, or if no cutoff date fits between
        ``start`` and ``end - horizon``.
    """
    if step <= pd.Timedelta(0):
        # A non-positive step never moves the cutoff past the last admissible
        # date, so the loop below would not terminate.
        raise ValueError("Step between cutoff dates must be positive")

    last_cutoff = end - horizon

    cutoff_dates = []
    cutoff = start
    while cutoff <= last_cutoff:
        cutoff_dates.append(cutoff)
        cutoff += step

    if not cutoff_dates:
        raise ValueError("No dates for cross-validation")
    return cutoff_dates
|
||
|
||
def cross_validate(
    model: Model | ForecastClient,
    start: pd.Timestamp,
    horizon: pd.Timedelta,
    step: pd.Timedelta,
    y: pd.Series,
    level: list[int] | None = None,
    freq: str | None = None,
) -> pd.DataFrame:
    """Cross-validate a model.

    For every cutoff date produced by ``generate_cutoff_dates`` the model is
    asked to predict ``horizon`` ahead using only the observations available
    up to and including that cutoff.

    Parameters
    ----------
    model:
        Model to cross-validate. Its ``predict`` is called as
        ``predict(horizon_length, y=history, level=level)``.
    start:
        Start date of the time series.
    horizon:
        Forecast horizon.
    step:
        Step size between cutoff dates.
    y:
        Time series target values. The index supplies the timestamps: its
        last entry is used as the end date and its inferred frequency is the
        fallback for ``freq``.
    level:
        Prediction intervals to compute.
        (Optional, if not provided, simple point forecasts will be computed.)
    freq:
        Frequency of the time series.
        (Optional, if not provided, it will be inferred from the time series index.)

    Returns
    -------
    crossval_df:
        Concatenation of the forecasts of all cutoffs, with missing values
        replaced by 0 and a ``cutoff`` column identifying the cutoff each row
        was produced for.
    """
    cutoff_dates = generate_cutoff_dates(start, y.index[-1], horizon, step)
    horizon_length = steps_in_horizon(horizon, freq or y.index.inferred_freq)

    # Older callers may still pass a frame carrying explicit "ds"/"y" columns;
    # keep slicing those by column. A Series (the documented input) has no
    # such column and must be sliced by its index instead.
    has_ds_column = isinstance(y, pd.DataFrame) and "ds" in y.columns

    # Cross-validation
    forecasts = []
    for cutoff in tqdm(cutoff_dates):
        # make sure that there is no data leakage
        if has_ds_column:
            history = y.loc[y["ds"] <= cutoff, ["ds", "y"]]
        else:
            history = y.loc[y.index <= cutoff]

        forecast = model.predict(
            horizon_length,
            y=history,
            level=level,
        )
        # NOTE(review): presumably predict returns a DataFrame; fillna returns
        # a new object, so the cutoff column is not added to the model's own
        # return value.
        forecast = forecast.fillna(0)
        forecast["cutoff"] = cutoff
        forecasts.append(forecast)

    crossval_df = pd.concat(forecasts)
    return crossval_df
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from typing import Callable | ||
|
||
import pandas as pd | ||
|
||
|
||
def evaluate_metric_on_forecast(forecast: pd.DataFrame, metric: Callable) -> float:
    """Score one forecast with one metric.

    Rows whose ``y`` value is missing are excluded before the metric is
    called.

    Parameters:
    -----------
    forecast:
        Forecast to evaluate; its ``y`` and ``yhat`` columns are used.
    metric:
        Callable invoked as ``metric(y, yhat)``.

    Returns:
    --------
    metric_value:
        Whatever ``metric`` returns for the rows with a known ``y``.
    """
    # Keep only the rows where the target is known.
    observed = forecast.loc[forecast["y"].notna()]
    return metric(observed.y, observed.yhat)
|
||
|
||
def evaluate_metrics_on_forecast(
    forecast: pd.DataFrame,
    metrics: dict[str, Callable],
) -> dict[str, float]:
    """Score one forecast with several metrics.

    Parameters:
    -----------
    forecast:
        Forecast to evaluate.
    metrics:
        Mapping from metric name to metric callable.

    Returns:
    --------
    metric_values:
        Mapping from metric name to the value of that metric on ``forecast``.
    """
    results = {}
    for name, func in metrics.items():
        results[name] = evaluate_metric_on_forecast(forecast, func)
    return results
|
||
|
||
def evaluate_metric_on_forecasts(forecasts: pd.DataFrame, metric: Callable) -> pd.DataFrame:
    """Score forecasts made at different cutoff points with one metric.

    Parameters:
    -----------
    forecasts:
        Forecasts to evaluate; rows are grouped by their ``cutoff`` column.
    metric:
        Metric to evaluate on each group.

    Returns:
    --------
    metrics_df:
        One row per cutoff (the cutoff is the index) with the metric result
        in a single ``value`` column.
    """
    per_cutoff = {}
    for cutoff, cutoff_forecast in forecasts.groupby('cutoff'):
        per_cutoff[cutoff] = evaluate_metric_on_forecast(cutoff_forecast, metric)

    return pd.DataFrame.from_dict(per_cutoff, orient='index', columns=['value'])
|
||
|
||
def evaluate_metrics_on_forecasts(
    forecasts: pd.DataFrame,
    metrics: dict[str, Callable],
) -> pd.DataFrame:
    """Score forecasts made at different cutoff points with several metrics.

    Parameters:
    -----------
    forecasts:
        Forecasts to evaluate.
    metrics:
        Mapping from metric name to metric callable.

    Returns:
    --------
    metrics_df:
        Per-cutoff metric values, one column per metric named after its key
        in ``metrics``.
    """
    named_frames = []
    for name, func in metrics.items():
        single_metric_df = evaluate_metric_on_forecasts(forecasts, func)
        # Each per-metric frame has a generic 'value' column; label it.
        named_frames.append(single_metric_df.rename(columns={'value': name}))

    # Place the per-metric columns side by side, aligned on the cutoff index.
    return pd.concat(named_frames, axis=1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.