added cross validation method
attila-balint-kul committed Jun 22, 2023
1 parent e3817e8 commit dc6cc54
Showing 12 changed files with 288 additions and 110 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -33,6 +33,7 @@ dependencies = [
"python-multipart>=0.0.0,<1.0.0",
"pyarrow>=12.0.0,<13.0.0",
"requests>=2.26.0,<3.0.0",
"tqdm>=4.60.0,<5.0.0",
"uvicorn>=0.20.0,<1.0.0",
]
dynamic = ["version"]
2 changes: 1 addition & 1 deletion src/enfobench/__version__.py
@@ -1 +1 @@
-__version__ = "0.0.6"
+__version__ = "0.0.7"
6 changes: 6 additions & 0 deletions src/enfobench/evaluation/__init__.py
@@ -0,0 +1,6 @@
from ._evaluate import (
    evaluate_metric_on_forecast,
    evaluate_metric_on_forecasts,
    evaluate_metrics_on_forecast,
    evaluate_metrics_on_forecasts,
)
90 changes: 90 additions & 0 deletions src/enfobench/evaluation/_cross_validate.py
@@ -0,0 +1,90 @@
from __future__ import annotations

import pandas as pd
from tqdm import tqdm

from enfobench import Model, ForecastClient
from enfobench.utils import steps_in_horizon


def generate_cutoff_dates(
    start: pd.Timestamp,
    end: pd.Timestamp,
    horizon: pd.Timedelta,
    step: pd.Timedelta,
) -> list[pd.Timestamp]:
    """Generate cutoff dates for cross-validation.

    Parameters
    ----------
    start:
        Date of the first cutoff.
    end:
        End date of the available data.
    horizon:
        Forecast horizon.
    step:
        Step size between cutoff dates.
    """
    cutoff_dates = []

    cutoff = start
    while cutoff <= end - horizon:
        cutoff_dates.append(cutoff)
        cutoff += step

    if not cutoff_dates:
        raise ValueError("No dates for cross-validation")
    return cutoff_dates


def cross_validate(
    model: Model | ForecastClient,
    start: pd.Timestamp,
    horizon: pd.Timedelta,
    step: pd.Timedelta,
    y: pd.Series,
    level: list[int] | None = None,
    freq: str | None = None,
) -> pd.DataFrame:
    """Cross-validate a model.

    Parameters
    ----------
    model:
        Model to cross-validate.
    start:
        Date of the first cutoff.
    horizon:
        Forecast horizon.
    step:
        Step size between cutoff dates.
    y:
        Time series target values.
    level:
        Prediction intervals to compute.
        (Optional, if not provided, simple point forecasts are computed.)
    freq:
        Frequency of the time series.
        (Optional, if not provided, it is inferred from the time series index.)
    """
    cutoff_dates = generate_cutoff_dates(start, y.index[-1], horizon, step)
    horizon_length = steps_in_horizon(horizon, freq or y.index.inferred_freq)

    # Cross-validation
    forecasts = []
    for cutoff in tqdm(cutoff_dates):
        # Slice on the index so no data after the cutoff leaks into the history.
        history = y.loc[y.index <= cutoff]

        forecast = model.predict(
            horizon_length,
            y=history,
            level=level,
        )
        forecast = forecast.fillna(0)
        forecast["cutoff"] = cutoff
        forecasts.append(forecast)

    crossval_df = pd.concat(forecasts)
    return crossval_df
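A minimal usage sketch of the new helpers (the NaiveModel below is hypothetical and only illustrates the predict signature that cross_validate calls; the import path assumes the module layout in this commit, since cross_validate is not re-exported in the __init__ shown above):

import numpy as np
import pandas as pd

from enfobench.evaluation._cross_validate import cross_validate, generate_cutoff_dates

# Hourly target series covering two weeks.
index = pd.date_range("2023-01-01", periods=14 * 24, freq="H")
y = pd.Series(np.random.rand(len(index)), index=index, name="y")

# Cutoffs for a one-day horizon, rolled forward one day at a time.
cutoffs = generate_cutoff_dates(
    start=index[0] + pd.Timedelta(days=7),
    end=index[-1],
    horizon=pd.Timedelta(days=1),
    step=pd.Timedelta(days=1),
)

class NaiveModel:
    """Hypothetical model: repeats the last observed value over the horizon."""

    def predict(self, h: int, y: pd.Series, level=None) -> pd.DataFrame:
        ds = pd.date_range(y.index[-1], periods=h + 1, freq="H")[1:]
        return pd.DataFrame({"ds": ds, "yhat": y.iloc[-1]})

# One forecast frame per cutoff, stacked, each tagged with its 'cutoff' value.
crossval_df = cross_validate(
    NaiveModel(),
    start=index[0] + pd.Timedelta(days=7),
    horizon=pd.Timedelta(days=1),
    step=pd.Timedelta(days=1),
    y=y,
)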
93 changes: 93 additions & 0 deletions src/enfobench/evaluation/_evaluate.py
@@ -0,0 +1,93 @@
from typing import Callable

import pandas as pd


def evaluate_metric_on_forecast(forecast: pd.DataFrame, metric: Callable) -> float:
    """Evaluate a single metric on a single forecast.

    Parameters
    ----------
    forecast:
        Forecast to evaluate.
    metric:
        Metric to evaluate.

    Returns
    -------
    metric_value:
        Metric value.
    """
    _nonempty_df = forecast.dropna(subset=["y"])
    metric_value = metric(_nonempty_df.y, _nonempty_df.yhat)
    return metric_value


def evaluate_metrics_on_forecast(
    forecast: pd.DataFrame, metrics: dict[str, Callable]
) -> dict[str, float]:
    """Evaluate multiple metrics on a single forecast.

    Parameters
    ----------
    forecast:
        Forecast to evaluate.
    metrics:
        Metrics to evaluate.

    Returns
    -------
    metric_values:
        Mapping from metric name to metric value.
    """
    metric_values = {
        metric_name: evaluate_metric_on_forecast(forecast, metric)
        for metric_name, metric in metrics.items()
    }
    return metric_values


def evaluate_metric_on_forecasts(forecasts: pd.DataFrame, metric: Callable) -> pd.DataFrame:
    """Evaluate a single metric on a set of forecasts made at different cutoff points.

    Parameters
    ----------
    forecasts:
        Forecasts to evaluate.
    metric:
        Metric to evaluate.

    Returns
    -------
    metrics_df:
        Metric value for each cutoff.
    """
    metrics = {
        cutoff: evaluate_metric_on_forecast(group_df, metric)
        for cutoff, group_df in forecasts.groupby("cutoff")
    }
    metrics_df = pd.DataFrame.from_dict(metrics, orient="index", columns=["value"])
    return metrics_df


def evaluate_metrics_on_forecasts(
    forecasts: pd.DataFrame, metrics: dict[str, Callable]
) -> pd.DataFrame:
    """Evaluate multiple metrics on a set of forecasts made at different cutoff points.

    Parameters
    ----------
    forecasts:
        Forecasts to evaluate.
    metrics:
        Metrics to evaluate.

    Returns
    -------
    metrics_df:
        Metric values for each cutoff, one column per metric.
    """
    metric_dfs = [
        evaluate_metric_on_forecasts(forecasts, metric_func).rename(columns={"value": metric_name})
        for metric_name, metric_func in metrics.items()
    ]
    metrics_df = pd.concat(metric_dfs, axis=1)
    return metrics_df
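A sketch of how these helpers compose (toy numbers; in practice the forecasts frame would come from cross_validate, merged with the actuals so that both 'y' and 'yhat' columns are present, and it is assumed the metric callables accept the pandas Series they are handed):

import pandas as pd

from enfobench.evaluation import (
    evaluate_metric_on_forecasts,
    evaluate_metrics_on_forecasts,
)
from enfobench.evaluation.metrics import mean_absolute_error, root_mean_squared_error

forecasts = pd.DataFrame({
    "cutoff": ["2023-01-08"] * 3 + ["2023-01-09"] * 3,
    "y": [1.0, 2.0, 3.0, 2.0, 2.5, 3.5],
    "yhat": [1.1, 1.9, 3.2, 2.1, 2.4, 3.3],
})

# Single metric: one row per cutoff, a single 'value' column.
mae_per_cutoff = evaluate_metric_on_forecasts(forecasts, mean_absolute_error)

# Several metrics: one column per metric name, one row per cutoff.
summary = evaluate_metrics_on_forecasts(
    forecasts,
    {"MAE": mean_absolute_error, "RMSE": root_mean_squared_error},
)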
130 changes: 30 additions & 100 deletions src/enfobench/evaluation/metrics.py
@@ -5,6 +5,18 @@
from numpy import ndarray


def check_not_empty(*arrays: ndarray) -> None:
    """Check that none of the arrays is empty.

    Parameters
    ----------
    *arrays: list or tuple of input arrays.
        Objects that will be checked for emptiness.
    """
    if any(X.size == 0 for X in arrays):
        raise ValueError("Found empty array in inputs.")


def check_consistent_length(*arrays: ndarray) -> None:
    """Check that all arrays have consistent length.
@@ -41,6 +53,19 @@ def check_has_no_nan(*arrays: ndarray) -> None:
        )


def check_arrays(*arrays: ndarray) -> None:
    """Check that all arrays are valid.

    Parameters
    ----------
    *arrays : list or tuple of input arrays.
        Objects that will be checked for validity.
    """
    check_not_empty(*arrays)
    check_consistent_length(*arrays)
    check_has_no_nan(*arrays)


def mean_absolute_error(y_true: ndarray, y_pred: ndarray) -> float:
    """Mean absolute error regression loss.
@@ -51,8 +76,7 @@ def mean_absolute_error(y_true: ndarray, y_pred: ndarray) -> float:
     y_pred : array-like of shape (n_samples,)
         Estimated target values.
     """
-    check_consistent_length(y_true, y_pred)
-    check_has_no_nan(y_true, y_pred)
+    check_arrays(y_true, y_pred)
     return float(np.mean(np.abs(y_true - y_pred)))


@@ -66,8 +90,7 @@ def mean_bias_error(y_true: ndarray, y_pred: ndarray) -> float:
     y_pred : array-like of shape (n_samples,)
         Estimated target values.
     """
-    check_consistent_length(y_true, y_pred)
-    check_has_no_nan(y_true, y_pred)
+    check_arrays(y_true, y_pred)
     return float(np.mean(y_pred - y_true))


@@ -81,8 +104,7 @@ def mean_squared_error(y_true: ndarray, y_pred: ndarray) -> float:
     y_pred : array-like of shape (n_samples,)
         Estimated target values.
     """
-    check_consistent_length(y_true, y_pred)
-    check_has_no_nan(y_true, y_pred)
+    check_arrays(y_true, y_pred)
     return float(np.mean((y_true - y_pred) ** 2))


@@ -96,8 +118,7 @@ def root_mean_squared_error(y_true: ndarray, y_pred: ndarray) -> float:
     y_pred : array-like of shape (n_samples,)
         Estimated target values.
     """
-    check_consistent_length(y_true, y_pred)
-    check_has_no_nan(y_true, y_pred)
+    check_arrays(y_true, y_pred)
     return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))


@@ -111,98 +132,7 @@ def mean_absolute_percentage_error(y_true: ndarray, y_pred: ndarray) -> float:
     y_pred : array-like of shape (n_samples,)
         Estimated target values.
     """
-    check_consistent_length(y_true, y_pred)
-    check_has_no_nan(y_true, y_pred)
+    check_arrays(y_true, y_pred)
     if np.any(y_true == 0):
         raise ValueError("Found zero in true values. MAPE is undefined.")
     return float(100. * np.mean(np.abs((y_true - y_pred) / y_true)))


(The remaining deletions in this hunk removed evaluate_metric_on_forecast, evaluate_metrics_on_forecast, evaluate_metric_on_forecasts, and evaluate_metrics_on_forecasts from this module; identical definitions now live in src/enfobench/evaluation/_evaluate.py, shown above.)
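With the checks consolidated into check_arrays, every metric now also rejects empty inputs, not only mismatched lengths and NaNs. A small sketch of the resulting behavior:

import numpy as np

from enfobench.evaluation.metrics import check_arrays, mean_absolute_percentage_error

y_true = np.array([100.0, 110.0, 120.0])
y_pred = np.array([98.0, 112.0, 119.0])

check_arrays(y_true, y_pred)  # passes: non-empty, equal lengths, no NaNs
mape = mean_absolute_percentage_error(y_true, y_pred)  # ~1.55 (in percent)

# Each guard raises ValueError on bad input, for example:
# check_arrays(np.array([]))                          -> "Found empty array in inputs."
# mean_absolute_percentage_error(np.zeros(3), y_pred) -> zero in true values, MAPE undefined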
2 changes: 1 addition & 1 deletion src/enfobench/evaluation/protocols.py
@@ -28,7 +28,7 @@ class Model(Protocol):
     def info(self) -> ModelInfo:
         ...
 
-    def forecast(
+    def predict(
         self,
         h: int,
         y: pd.Series,
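Because the protocol method is renamed from forecast to predict, model implementations must now expose predict. A sketch of a conforming model (only the part of the signature visible in this diff is assumed; info() and any further keyword arguments are elided):

import pandas as pd

class SeasonalNaive:
    """Hypothetical model: repeats the last day of observations."""

    def predict(self, h: int, y: pd.Series, level=None) -> pd.DataFrame:
        season = y.iloc[-24:].to_numpy()  # assumes hourly data with daily seasonality
        ds = pd.date_range(y.index[-1], periods=h + 1, freq=y.index.inferred_freq or "H")[1:]
        return pd.DataFrame({"ds": ds, "yhat": [season[i % 24] for i in range(h)]})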
2 changes: 1 addition & 1 deletion src/enfobench/evaluation/server.py
@@ -41,7 +41,7 @@ async def predict(
y_df["ds"] = pd.to_datetime(y_df["ds"])
y = y_df.set_index("ds").y

forecast = model.forecast(
forecast = model.predict(
h=horizon,
y=y,
# X=X_df,
(The remaining changed files are not shown in this view.)