cleanup and bugfixes
attila-balint-kul committed Jun 22, 2023
1 parent dc6cc54 commit 88d7082
Showing 7 changed files with 30 additions and 33 deletions.
src/enfobench/__version__.py (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
__version__ = "0.0.7"
__version__ = "0.0.8"
src/enfobench/evaluation/__init__.py (1 change: 1 addition & 0 deletions)
@@ -1,3 +1,4 @@
+from ._cross_validate import cross_validate
 from ._evaluate import (
     evaluate_metric_on_forecast,
     evaluate_metric_on_forecasts,
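With cross_validate now re-exported from the evaluation package (the added line above), callers can import it directly from enfobench.evaluation. A minimal import sketch; only these two paths are confirmed by this commit (the second by tests/test_metrics.py below), the rest of the public surface is not shown:

    # Imports implied by the re-export above and by the test file in this commit.
    from enfobench.evaluation import cross_validate
    from enfobench.evaluation.metrics import mean_absolute_error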
src/enfobench/evaluation/_cross_validate.py (7 changes: 4 additions & 3 deletions)
@@ -3,7 +3,8 @@
 import pandas as pd
 from tqdm import tqdm

-from enfobench import Model, ForecastClient
+from enfobench.evaluation.client import ForecastClient
+from enfobench.evaluation.protocols import Model
 from enfobench.utils import steps_in_horizon


@@ -75,15 +76,15 @@ def cross_validate(
     forecasts = []
     for cutoff in tqdm(cutoff_dates):
         # make sure that there is no data leakage
-        history = y.loc[y.ds <= cutoff, ["ds", 'y']]
+        history = y.loc[y.ds <= cutoff, ["ds", "y"]]

         forecast = model.predict(
             horizon_length,
             y=history,
             level=level,
         )
         forecast = forecast.fillna(0)
-        forecast['cutoff'] = cutoff
+        forecast["cutoff"] = cutoff
         forecasts.append(forecast)

     crossval_df = pd.concat(forecasts)
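The loop above fixes the expected data layout: y is a DataFrame with "ds" (timestamps) and "y" (observed values), only history up to each cutoff is passed to the model, and every forecast frame is tagged with a "cutoff" column before concatenation. A rough usage sketch follows; cross_validate's full signature sits outside the shown hunk, so the keyword arguments here (y, horizon_length, level) are assumptions based on the variables it uses internally:

    import pandas as pd

    from enfobench.evaluation import cross_validate
    from enfobench.evaluation.protocols import Model

    def run_backtest(model: Model, y: pd.DataFrame) -> pd.DataFrame:
        # y must carry "ds" (timestamps) and "y" (values), matching the
        # history slice y.loc[y.ds <= cutoff, ["ds", "y"]] taken above.
        crossval_df = cross_validate(
            model,
            y=y,                 # assumed keyword, mirrors the internal variable
            horizon_length=24,   # assumed keyword, steps to forecast per cutoff
            level=[90],          # assumed keyword, prediction-interval levels
        )
        # One forecast frame per cutoff, concatenated and tagged with "cutoff".
        return crossval_df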
src/enfobench/evaluation/_evaluate.py (18 changes: 10 additions & 8 deletions)
@@ -18,13 +18,14 @@ def evaluate_metric_on_forecast(forecast: pd.DataFrame, metric: Callable) -> flo
         metric_value:
             Metric value.
     """
-    _nonempty_df = forecast.dropna(subset=['y'])
+    _nonempty_df = forecast.dropna(subset=["y"])
     metric_value = metric(_nonempty_df.y, _nonempty_df.yhat)
     return metric_value


-def evaluate_metrics_on_forecast(forecast: pd.DataFrame, metrics: dict[str, Callable]) -> dict[
-    str, float]:
+def evaluate_metrics_on_forecast(
+    forecast: pd.DataFrame, metrics: dict[str, Callable]
+) -> dict[str, float]:
     """Evaluate multiple metrics on a single forecast.

     Parameters:
@@ -63,14 +64,15 @@ def evaluate_metric_on_forecasts(forecasts: pd.DataFrame, metric: Callable) -> p
"""
metrics = {
cutoff: evaluate_metric_on_forecast(group_df, metric)
for cutoff, group_df in forecasts.groupby('cutoff')
for cutoff, group_df in forecasts.groupby("cutoff")
}
metrics_df = pd.DataFrame.from_dict(metrics, orient='index', columns=['value'])
metrics_df = pd.DataFrame.from_dict(metrics, orient="index", columns=["value"])
return metrics_df


def evaluate_metrics_on_forecasts(forecasts: pd.DataFrame,
metrics: dict[str, Callable]) -> pd.DataFrame:
def evaluate_metrics_on_forecasts(
forecasts: pd.DataFrame, metrics: dict[str, Callable]
) -> pd.DataFrame:
"""Evaluate multiple metrics on a set of forecasts made at different cutoff points.
Parameters:
@@ -86,7 +88,7 @@ def evaluate_metrics_on_forecasts(forecasts: pd.DataFrame,
             Metric values for each cutoff with their weight.
     """
     metric_dfs = [
-        evaluate_metric_on_forecasts(forecasts, metric_func).rename(columns={'value': metric_name})
+        evaluate_metric_on_forecasts(forecasts, metric_func).rename(columns={"value": metric_name})
         for metric_name, metric_func in metrics.items()
     ]
     metrics_df = pd.concat(metric_dfs, axis=1)
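With the signatures reformatted onto readable headers, the evaluation flow is easy to trace: group the cross-validation frame by "cutoff", score each group's y against yhat, and concatenate one column per metric. A small compositional sketch with made-up data; the import from the private _evaluate module is simply the file shown above (the package __init__ may also re-export these helpers):

    import pandas as pd

    from enfobench.evaluation._evaluate import evaluate_metrics_on_forecasts
    from enfobench.evaluation.metrics import mean_absolute_error, root_mean_squared_error

    # Illustrative cross-validation output: rows carry the observed value (y),
    # the prediction (yhat), and the cutoff the forecast was issued at.
    forecasts = pd.DataFrame(
        {
            "cutoff": ["2023-01-01", "2023-01-01", "2023-01-02", "2023-01-02"],
            "y": [1.0, 2.0, 3.0, 4.0],
            "yhat": [1.5, 2.5, 2.0, 5.0],
        }
    )

    # One row per cutoff, one column per metric name.
    metrics_df = evaluate_metrics_on_forecasts(
        forecasts,
        metrics={"MAE": mean_absolute_error, "RMSE": root_mean_squared_error},
    )
    print(metrics_df)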
src/enfobench/evaluation/client.py (2 changes: 1 addition & 1 deletion)
@@ -61,5 +61,5 @@ def predict(
         response.raise_for_status()

         df = pd.DataFrame.from_records(response.json()["forecast"])
-        df['ds'] = pd.to_datetime(df['ds'])
+        df["ds"] = pd.to_datetime(df["ds"])
         return df
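The change above ensures the "ds" column returned by the forecast service is parsed into real timestamps instead of strings. A hedged usage sketch; only the tail of ForecastClient.predict is visible here, so the constructor argument and predict's parameters are assumptions modeled on the model.predict(horizon, y=history) call in the cross-validation loop:

    import pandas as pd

    from enfobench.evaluation.client import ForecastClient

    # Hypothetical endpoint and argument names; not confirmed by this diff.
    client = ForecastClient(url="http://localhost:3000")

    history = pd.DataFrame(
        {
            "ds": pd.date_range("2023-01-01", periods=48, freq="H"),
            "y": range(48),
        }
    )
    forecast = client.predict(24, y=history)

    # After this commit the returned "ds" column is a datetime dtype.
    assert pd.api.types.is_datetime64_any_dtype(forecast["ds"])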
src/enfobench/evaluation/metrics.py (17 changes: 5 additions & 12 deletions)
@@ -1,7 +1,4 @@
-from typing import Callable
-
 import numpy as np
-import pandas as pd
 from numpy import ndarray


@@ -13,7 +10,7 @@ def check_not_empty(*arrays: ndarray) -> None:
         *arrays: list or tuple of input arrays.
             Objects that will be checked for emptiness.
     """
-    if any([X.size == 0 for X in arrays]):
+    if any(X.size == 0 for X in arrays):
         raise ValueError("Found empty array in inputs.")


@@ -27,15 +24,13 @@ def check_consistent_length(*arrays: ndarray) -> None:
         *arrays : list or tuple of input arrays.
             Objects that will be checked for consistent length.
     """
-    if any([X.ndim != 1 for X in arrays]):
+    if any(X.ndim != 1 for X in arrays):
         raise ValueError("Found multi dimensional array in inputs.")

     lengths = [len(X) for X in arrays]
     uniques = np.unique(lengths)
     if len(uniques) > 1:
-        raise ValueError(
-            f"Found input variables with inconsistent numbers of samples: {lengths}"
-        )
+        raise ValueError(f"Found input variables with inconsistent numbers of samples: {lengths}")


 def check_has_no_nan(*arrays: ndarray) -> None:
@@ -48,9 +43,7 @@ def check_has_no_nan(*arrays: ndarray) -> None:
"""
for X in arrays:
if np.isnan(X).any():
raise ValueError(
f"Found NaNs in input variables: {X}"
)
raise ValueError(f"Found NaNs in input variables: {X}")


def check_arrays(*arrays: ndarray) -> None:
@@ -135,4 +128,4 @@ def mean_absolute_percentage_error(y_true: ndarray, y_pred: ndarray) -> float:
     check_arrays(y_true, y_pred)
     if np.any(y_true == 0):
         raise ValueError("Found zero in true values. MAPE is undefined.")
-    return float(100. * np.mean(np.abs((y_true - y_pred) / y_true)))
+    return float(100.0 * np.mean(np.abs((y_true - y_pred) / y_true)))
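The tightened helpers make the metric contract explicit: inputs must be non-empty, one-dimensional, equal-length, and NaN-free, and MAPE additionally rejects zeros in the true values instead of returning an infinite score. A quick illustration with made-up numbers:

    import numpy as np

    from enfobench.evaluation.metrics import mean_absolute_percentage_error

    y_true = np.array([100.0, 200.0, 400.0])
    y_pred = np.array([110.0, 180.0, 400.0])

    # mean(|y_true - y_pred| / y_true) * 100 = mean(0.10, 0.10, 0.00) * 100 ≈ 6.67
    print(mean_absolute_percentage_error(y_true, y_pred))

    # A zero in y_true raises instead of silently producing inf.
    try:
        mean_absolute_percentage_error(np.array([0.0, 1.0]), np.array([1.0, 1.0]))
    except ValueError as err:
        print(err)  # Found zero in true values. MAPE is undefined.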
tests/test_metrics.py (16 changes: 8 additions & 8 deletions)
@@ -3,10 +3,10 @@

 from enfobench.evaluation.metrics import (
     mean_absolute_error,
+    mean_absolute_percentage_error,
     mean_bias_error,
-    root_mean_squared_error,
     mean_squared_error,
-    mean_absolute_percentage_error
+    root_mean_squared_error,
 )

 all_metrics = [
@@ -43,12 +43,12 @@ def test_metric_raises_with_empty_array(metric):


 def test_mean_absolute_error():
-    assert mean_absolute_error(np.array([1, 2, 3]), np.array([1, 2, 3])) == 0.
-    assert mean_absolute_error(np.array([1, 2, 3]), np.array([2, 3, 4])) == 1.
-    assert mean_absolute_error(np.array([1, 2, 3]), np.array([0, 1, 2])) == 1.
+    assert mean_absolute_error(np.array([1, 2, 3]), np.array([1, 2, 3])) == 0.0
+    assert mean_absolute_error(np.array([1, 2, 3]), np.array([2, 3, 4])) == 1.0
+    assert mean_absolute_error(np.array([1, 2, 3]), np.array([0, 1, 2])) == 1.0


 def test_mean_bias_error():
-    assert mean_bias_error(np.array([1, 2, 3]), np.array([1, 2, 3])) == 0.
-    assert mean_bias_error(np.array([1, 2, 3]), np.array([2, 3, 4])) == 1.
-    assert mean_bias_error(np.array([1, 2, 3]), np.array([0, 1, 2])) == -1.
+    assert mean_bias_error(np.array([1, 2, 3]), np.array([1, 2, 3])) == 0.0
+    assert mean_bias_error(np.array([1, 2, 3]), np.array([2, 3, 4])) == 1.0
+    assert mean_bias_error(np.array([1, 2, 3]), np.array([0, 1, 2])) == -1.0
