
Commit

fixed issue with metrics evaluation where a group would not have any target values

attila-balint-kul committed Nov 18, 2023
1 parent 01b7702 commit cad79ec
Showing 6 changed files with 185 additions and 40 deletions.
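
To make the fix concrete before the diffs, here is a minimal sketch of the case this commit addresses, written against the `evaluate_metrics` API introduced below (the toy data is invented for illustration): a group whose target values are all missing now yields `NaN` metric values with a weight of `0.0` instead of breaking the evaluation.

```python
import numpy as np
import pandas as pd

from enfobench.evaluation import evaluate_metrics
from enfobench.evaluation.metrics import mean_absolute_error

# Toy cross-validation output: the second cutoff group has no target values at all.
forecasts = pd.DataFrame(
    {
        "cutoff_date": pd.to_datetime(["2020-01-01"] * 3 + ["2020-01-02"] * 3),
        "y": [1.0, 2.0, 3.0, np.nan, np.nan, np.nan],
        "yhat": [1.1, 1.9, 3.2, 1.0, 2.0, 3.0],
    }
)

metrics = evaluate_metrics(
    forecasts,
    metrics={"MAE": mean_absolute_error},
    groupby="cutoff_date",
)
# Expected: one row per cutoff_date; the group without targets gets MAE == NaN
# and weight == 0.0, while the fully observed group gets weight == 1.0.
print(metrics)
```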
10 changes: 5 additions & 5 deletions README.md
@@ -131,17 +131,17 @@ cv_results = cross_validate(
 The package also collects common metrics used in forecasting.
 
 ```python
-from enfobench.evaluation import evaluate_metrics_on_forecasts
+from enfobench.evaluation import evaluate_metrics
 
 from enfobench.evaluation.metrics import (
-    mean_bias_error,
-    mean_absolute_error,
-    mean_squared_error,
+    mean_bias_error,
+    mean_absolute_error,
+    mean_squared_error,
     root_mean_squared_error,
 )
 
 # Simply pass in the cross validation results and the metrics you want to evaluate.
-metrics = evaluate_metrics_on_forecasts(
+metrics = evaluate_metrics(
     cv_results,
     metrics={
         "mean_bias_error": mean_bias_error,
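
The README example above is cut off in the diff view. A plausible completion of the call, assuming the remaining dictionary entries simply mirror the imported metric names (only the `mean_bias_error` entry is visible in the hunk, and `cv_results` comes from the `cross_validate` call referenced in the hunk header), would be:

```python
from enfobench.evaluation import evaluate_metrics
from enfobench.evaluation.metrics import (
    mean_bias_error,
    mean_absolute_error,
    mean_squared_error,
    root_mean_squared_error,
)

# `cv_results` is assumed to be the output of the earlier `cross_validate` example.
metrics = evaluate_metrics(
    cv_results,
    metrics={
        "mean_bias_error": mean_bias_error,
        "mean_absolute_error": mean_absolute_error,
        "mean_squared_error": mean_squared_error,
        "root_mean_squared_error": root_mean_squared_error,
    },
)
# With no `groupby` argument, the new API returns a single-row DataFrame with one
# column per metric plus a "weight" column.
```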
2 changes: 1 addition & 1 deletion src/enfobench/__version__.py
@@ -1 +1 @@
-__version__ = "0.3.1"
+__version__ = "0.3.2"
4 changes: 2 additions & 2 deletions src/enfobench/evaluation/__init__.py
@@ -1,15 +1,15 @@
 from enfobench.evaluation.client import ForecastClient
 from enfobench.evaluation.evaluate import (
     cross_validate,
-    evaluate_metrics_on_forecasts,
+    evaluate_metrics,
 )
 from enfobench.evaluation.model import AuthorInfo, ForecasterType, Model, ModelInfo
 from enfobench.evaluation.protocols import Dataset
 
 __all__ = [
     "ForecastClient",
     "cross_validate",
-    "evaluate_metrics_on_forecasts",
+    "evaluate_metrics",
     "Dataset",
     "Model",
     "ModelInfo",
72 changes: 40 additions & 32 deletions src/enfobench/evaluation/evaluate.py
@@ -1,6 +1,8 @@
 import warnings
 from collections.abc import Callable
+from typing import Any
 
+import numpy as np
 import pandas as pd
 from tqdm import tqdm
 
@@ -10,8 +12,8 @@
 from enfobench.evaluation.utils import generate_cutoff_dates, steps_in_horizon
 
 
-def evaluate_metric_on_forecast(forecast: pd.DataFrame, metric: Callable) -> float:
-    """Evaluate a single metric on a single forecast.
+def _compute_metric(forecast: pd.DataFrame, metric: Callable) -> float:
+    """Compute a single metric value.
 
     Args:
         forecast: Forecast to evaluate.
@@ -20,59 +22,65 @@ def evaluate_metric_on_forecast(forecast: pd.DataFrame, metric: Callable) -> float:
     Returns:
         Metric value.
     """
-    _nonempty_df = forecast.dropna(subset=["y"])
-    metric_value = metric(_nonempty_df.y, _nonempty_df.yhat)
+    metric_value = metric(forecast.y, forecast.yhat)
     return metric_value
 
 
-def evaluate_metrics_on_forecast(forecast: pd.DataFrame, metrics: dict[str, Callable]) -> dict[str, float]:
-    """Evaluate multiple metrics on a single forecast.
+def _compute_metrics(forecast: pd.DataFrame, metrics: dict[str, Callable]) -> dict[str, float]:
+    """Compute multiple metric values.
 
     Args:
         forecast: Forecast to evaluate.
         metrics: Metric to evaluate.
 
     Returns:
-        Metric value.
+        Metrics dictionary with metric names as keys and metric values as values.
     """
-    metric_values = {
-        metric_name: evaluate_metric_on_forecast(forecast, metric) for metric_name, metric in metrics.items()
-    }
+    metric_values = {metric_name: _compute_metric(forecast, metric) for metric_name, metric in metrics.items()}
     return metric_values
 
 
-def evaluate_metric_on_forecasts(forecasts: pd.DataFrame, metric: Callable) -> pd.DataFrame:
-    """Evaluate a single metric on a set of forecasts made at different cutoff points.
-
-    Args:
-        forecasts: Forecasts to evaluate.
-        metric: Metric to evaluate.
-
-    Returns:
-        Metric values for each cutoff with their weight.
-    """
-    metrics = {
-        cutoff: evaluate_metric_on_forecast(group_df, metric) for cutoff, group_df in forecasts.groupby("cutoff_date")
-    }
-    metrics_df = pd.DataFrame.from_dict(metrics, orient="index", columns=["value"])
-    return metrics_df
+def _evaluate_group(forecasts: pd.DataFrame, metrics: dict[str, Callable], index: Any) -> pd.DataFrame:
+    clean_df = forecasts.dropna(subset=["y"])
+    if clean_df.empty:
+        ratio = 0.0
+        metrics = {metric_name: np.nan for metric_name in metrics}
+    else:
+        ratio = len(clean_df) / len(forecasts)
+        metrics = _compute_metrics(clean_df, metrics)
+    df = pd.DataFrame({**metrics, "weight": ratio}, index=[index])
+    return df
 
 
-def evaluate_metrics_on_forecasts(forecasts: pd.DataFrame, metrics: dict[str, Callable]) -> pd.DataFrame:
-    """Evaluate multiple metrics on a set of forecasts made at different cutoff points.
+def evaluate_metrics(
+    forecasts: pd.DataFrame,
+    metrics: dict[str, Callable],
+    *,
+    groupby: str | None = None,
+) -> pd.DataFrame:
+    """Evaluate multiple metrics on forecasts.
 
     Args:
         forecasts: Forecasts to evaluate.
         metrics: Metric to evaluate.
+        groupby: Column to group forecasts by. (Optional, if not provided, forecasts will not be grouped.)
 
     Returns:
         Metric values for each cutoff with their weight.
     """
-    metric_dfs = [
-        evaluate_metric_on_forecasts(forecasts, metric_func).rename(columns={"value": metric_name})
-        for metric_name, metric_func in metrics.items()
-    ]
-    metrics_df = pd.concat(metric_dfs, axis=1)
+    if groupby is None:
+        return _evaluate_group(forecasts, metrics, 0)
+
+    if groupby not in forecasts.columns:
+        msg = f"Groupby column {groupby} not found in forecasts."
+        raise ValueError(msg)
+
+    metrics_df = pd.concat(
+        [_evaluate_group(group_df, metrics, value) for value, group_df in tqdm(forecasts.groupby(groupby))]
+    )
+
+    metrics_df.rename_axis(groupby, inplace=True)
+    metrics_df.reset_index(inplace=True)
     return metrics_df
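
Two notes on the new `evaluate_metrics` above. The old `evaluate_metrics_on_forecasts` always returned one row per `cutoff_date`, whereas the replacement returns a single overall row unless `groupby="cutoff_date"` is passed, so callers migrating from 0.3.1 need to opt back into per-cutoff results explicitly. The `weight` column added by `_evaluate_group` is the share of rows in a group whose target is present (`0.0` when a group has no targets at all); the helper below is a hypothetical sketch, not part of the package, of how a caller might use it to aggregate per-group scores without letting empty groups propagate `NaN`.

```python
import numpy as np
import pandas as pd


def weighted_metric(metrics_df: pd.DataFrame, column: str) -> float:
    """Aggregate a per-group metric column using the 'weight' column.

    Hypothetical helper: groups whose targets were all missing carry NaN metrics
    and weight 0.0, so they are excluded from the aggregate.
    """
    valid = metrics_df[column].notna() & (metrics_df["weight"] > 0)
    if not valid.any():
        return float("nan")
    return float(np.average(metrics_df.loc[valid, column], weights=metrics_df.loc[valid, "weight"]))


# Example: overall_mae = weighted_metric(per_cutoff_metrics, "MAE")
```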
45 changes: 45 additions & 0 deletions tests/conftest.py
@@ -89,3 +89,48 @@ def external_forecasts() -> pd.DataFrame:
         forecasts.append(forecast)
 
     return pd.concat(forecasts, ignore_index=True)
+
+
+@pytest.fixture(scope="session")
+def clean_forecasts() -> pd.DataFrame:
+    cutoff_dates = pd.date_range("2020-01-01", "2021-01-01", freq="D")
+
+    forecasts = []
+    for cutoff_date in cutoff_dates:
+        index = pd.date_range(cutoff_date + pd.Timedelta(hours=1), cutoff_date + pd.Timedelta(hours=25), freq="H")
+        forecast = pd.DataFrame(
+            data={
+                "timestamp": index,
+                "yhat": np.random.random(len(index)),
+                "y": np.random.random(len(index)),
+            },
+        )
+        forecast["cutoff_date"] = cutoff_date
+        forecasts.append(forecast)
+
+    forecast_df = pd.concat(forecasts, ignore_index=True)
+    assert not forecast_df.isna().any(axis=None)
+    return forecast_df
+
+
+@pytest.fixture(scope="session")
+def forecasts_with_missing_values() -> pd.DataFrame:
+    cutoff_dates = pd.date_range("2020-01-01", "2021-01-01", freq="D")
+
+    forecasts = []
+    for cutoff_date in cutoff_dates:
+        index = pd.date_range(cutoff_date + pd.Timedelta(hours=1), cutoff_date + pd.Timedelta(hours=25), freq="H")
+        forecast = pd.DataFrame(
+            data={
+                "timestamp": index,
+                "yhat": np.random.random(len(index)),
+                "y": np.random.random(len(index)),
+            },
+        )
+        forecast["cutoff_date"] = cutoff_date
+        forecasts.append(forecast)
+
+    forecast_df = pd.concat(forecasts, ignore_index=True)
+    forecast_df.loc[forecast_df["y"] <= 0.3, "y"] = np.nan
+    assert forecast_df.isna().any(axis=None)
+    return forecast_df
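
Note that `forecasts_with_missing_values` blanks out targets at random (roughly 30% of rows), so in practice no cutoff group ends up with all of its targets missing, which is the exact situation named in the commit message. A hypothetical fixture for that case, sketched here and not part of this commit, could blank out every target in a single group:

```python
import numpy as np
import pandas as pd
import pytest


@pytest.fixture(scope="session")
def forecasts_with_empty_group() -> pd.DataFrame:
    # Hypothetical fixture: the last cutoff_date group has no target values at all.
    cutoff_dates = pd.date_range("2020-01-01", "2020-01-05", freq="D")

    forecasts = []
    for cutoff_date in cutoff_dates:
        index = pd.date_range(cutoff_date + pd.Timedelta(hours=1), cutoff_date + pd.Timedelta(hours=25), freq="H")
        forecast = pd.DataFrame(
            data={
                "timestamp": index,
                "yhat": np.random.random(len(index)),
                "y": np.random.random(len(index)),
            },
        )
        forecast["cutoff_date"] = cutoff_date
        forecasts.append(forecast)

    forecast_df = pd.concat(forecasts, ignore_index=True)
    # Drop every target value for the last group only.
    forecast_df.loc[forecast_df["cutoff_date"] == cutoff_dates[-1], "y"] = np.nan
    return forecast_df
```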
92 changes: 92 additions & 0 deletions tests/test_evaluations/test_evaluate.py
@@ -7,6 +7,8 @@
     ForecastClient,
     cross_validate,
 )
+from enfobench.evaluation.evaluate import _compute_metric, _compute_metrics, evaluate_metrics
+from enfobench.evaluation.metrics import mean_absolute_error, mean_squared_error
 from enfobench.evaluation.server import server_factory
 from enfobench.evaluation.utils import generate_cutoff_dates
 
@@ -148,3 +150,93 @@ def test_cross_validate_multivariate_via_server(model, target, covariates, exter
     assert "timestamp" in forecasts.columns
     assert "yhat" in forecasts.columns
     assert "cutoff_date" in forecasts.columns
+
+
+def test_compute_metric(clean_forecasts):
+    metric_value = _compute_metric(clean_forecasts, mean_absolute_error)
+
+    assert 0 < metric_value < 1
+
+
+def test_compute_metrics(clean_forecasts):
+    metric_values = _compute_metrics(clean_forecasts, {"MAE": mean_absolute_error, "MSE": mean_squared_error})
+
+    assert isinstance(metric_values, dict)
+    assert 0 < metric_values["MAE"] < 1
+    assert 0 < metric_values["MSE"] < 1
+
+
+def test_compute_metric_on_forecast_with_missing_values_raises_error(forecasts_with_missing_values):
+    with pytest.raises(ValueError):
+        _compute_metric(forecasts_with_missing_values, mean_absolute_error)
+
+
+def test_compute_metrics_on_forecast_with_missing_values_raises_error(forecasts_with_missing_values):
+    with pytest.raises(ValueError):
+        _compute_metrics(forecasts_with_missing_values, {"MAE": mean_absolute_error, "MSE": mean_squared_error})
+
+
+def test_evaluate_metrics_on_clean_forecasts(clean_forecasts):
+    metrics = evaluate_metrics(clean_forecasts, {"MAE": mean_absolute_error, "MSE": mean_squared_error})
+
+    assert isinstance(metrics, pd.DataFrame)
+    assert "MAE" in metrics.columns
+    assert "MSE" in metrics.columns
+    assert "weight" in metrics.columns
+    assert len(metrics) == 1
+    assert 0 < metrics["MAE"].iloc[0] < 1
+    assert 0 < metrics["MSE"].iloc[0] < 1
+    assert metrics["weight"].iloc[0] == 1
+
+
+def test_evaluate_metrics_on_forecasts_with_missing_values(forecasts_with_missing_values):
+    metrics = evaluate_metrics(forecasts_with_missing_values, {"MAE": mean_absolute_error, "MSE": mean_squared_error})
+
+    assert isinstance(metrics, pd.DataFrame)
+    assert "MAE" in metrics.columns
+    assert "MSE" in metrics.columns
+    assert "weight" in metrics.columns
+    assert len(metrics) == 1
+    assert 0 < metrics["MAE"].iloc[0] < 1
+    assert 0 < metrics["MSE"].iloc[0] < 1
+    assert pytest.approx(metrics["weight"].iloc[0], 0.1) == 1 - 0.3
+
+
+def test_evaluate_metrics_on_clean_forecasts_grouped_by(clean_forecasts):
+    metrics = evaluate_metrics(
+        clean_forecasts,
+        {"MAE": mean_absolute_error, "MSE": mean_squared_error},
+        groupby="cutoff_date",
+    )
+
+    grouped_values = clean_forecasts["cutoff_date"].unique()
+
+    assert isinstance(metrics, pd.DataFrame)
+    assert "MAE" in metrics.columns
+    assert "MSE" in metrics.columns
+    assert "weight" in metrics.columns
+    assert "cutoff_date" in metrics.columns
+    assert len(metrics) == len(grouped_values)
+    assert all(0 < metric < 1 for metric in metrics["MAE"])
+    assert all(0 < metric < 1 for metric in metrics["MSE"])
+    assert all(metric == 1 for metric in metrics["weight"])
+
+
+def test_evaluate_metrics_on_forecasts_with_missing_values_grouped_by(forecasts_with_missing_values):
+    metrics = evaluate_metrics(
+        forecasts_with_missing_values,
+        {"MAE": mean_absolute_error, "MSE": mean_squared_error},
+        groupby="cutoff_date",
+    )
+
+    grouped_values = forecasts_with_missing_values["cutoff_date"].unique()
+
+    assert isinstance(metrics, pd.DataFrame)
+    assert "MAE" in metrics.columns
+    assert "MSE" in metrics.columns
+    assert "weight" in metrics.columns
+    assert "cutoff_date" in metrics.columns
+    assert len(metrics) == len(grouped_values)
+    assert all(0 < metric < 1 for metric in metrics["MAE"])
+    assert all(0 < metric < 1 for metric in metrics["MSE"])
+    assert all(0 <= metric <= 1 for metric in metrics["weight"])
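
None of the new tests exercises a group whose targets are entirely missing. A hypothetical test for that case, assuming a fixture like the `forecasts_with_empty_group` sketched after the conftest.py diff above and the behaviour implied by `_evaluate_group`, might look like this:

```python
import numpy as np

from enfobench.evaluation.evaluate import evaluate_metrics
from enfobench.evaluation.metrics import mean_absolute_error


def test_evaluate_metrics_with_empty_group(forecasts_with_empty_group):
    # Hypothetical test: the group without any target values should yield NaN
    # metrics and a weight of 0.0 instead of raising.
    metrics = evaluate_metrics(
        forecasts_with_empty_group,
        {"MAE": mean_absolute_error},
        groupby="cutoff_date",
    )

    empty_group = forecasts_with_empty_group["cutoff_date"].max()
    row = metrics.loc[metrics["cutoff_date"] == empty_group]

    assert len(row) == 1
    assert np.isnan(row["MAE"].iloc[0])
    assert row["weight"].iloc[0] == 0.0
    assert metrics.drop(row.index)["weight"].eq(1.0).all()
```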