Commit
fixed cutoff-date issue and added more detailed notebooks
attilabalint committed Jun 18, 2024
1 parent f75b16d commit d3f7944
Showing 13 changed files with 48,501 additions and 481 deletions.
24,876 changes: 24,747 additions & 129 deletions notebooks/01. Univariate.ipynb

Large diffs are not rendered by default.

23,948 changes: 23,643 additions & 305 deletions notebooks/02. Multivariate.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/enfobench/__version__.py
@@ -1 +1 @@
__version__ = "0.6.1"
__version__ = "0.7.0"
15 changes: 12 additions & 3 deletions src/enfobench/core/dataset.py
@@ -148,18 +148,24 @@ def _check_cutoff_in_rage(self, cutoff_date: pd.Timestamp):
def get_history(self, cutoff_date: pd.Timestamp) -> pd.DataFrame:
"""Returns the history of the target variable up to the cutoff date.
The cutoff date is the timestamp when the forecast is made,
therefore it is included in the history.
Args:
cutoff_date: The cutoff date.
Returns:
The history of the target variable up to the cutoff date.
"""
self._check_cutoff_in_rage(cutoff_date)
-return self._target[self._target.index < cutoff_date]
+return self._target[self._target.index <= cutoff_date]

def get_past_covariates(self, cutoff_date: pd.Timestamp) -> pd.DataFrame | None:
"""Returns the past covariates for the cutoff date.
The cutoff date is the timestamp when the forecast is made,
therefore it is included in the past covariates.
Args:
cutoff_date: The cutoff date.
@@ -170,11 +176,14 @@ def get_past_covariates(self, cutoff_date: pd.Timestamp) -> pd.DataFrame | None:
return None

self._check_cutoff_in_rage(cutoff_date)
-return self._past_covariates[self._past_covariates.index < cutoff_date]
+return self._past_covariates[self._past_covariates.index <= cutoff_date]

def get_future_covariates(self, cutoff_date: pd.Timestamp) -> pd.DataFrame | None:
"""Returns the future covariates for the cutoff date.
The cutoff date is the timestamp when the forecast is made,
therefore it is not included in the future covariates.
Args:
cutoff_date: The cutoff date.
@@ -191,7 +200,7 @@ def get_future_covariates(self, cutoff_date: pd.Timestamp) -> pd.DataFrame | None:

future_covariates = self._future_covariates[
(self._future_covariates.cutoff_date == last_past_cutoff_date)
& (self._future_covariates.timestamp >= cutoff_date)
& (self._future_covariates.timestamp > cutoff_date)
]
future_covariates.set_index("timestamp", inplace=True)
return future_covariates
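
After this change the cutoff date itself belongs to the history and the past covariates, while the future covariates start strictly after it. A minimal pandas sketch of the new boundary semantics (toy data, not the enfobench Dataset class):

import pandas as pd

index = pd.date_range("2020-01-01", periods=48, freq="30T")
target = pd.DataFrame({"y": range(48)}, index=index)
cutoff_date = pd.Timestamp("2020-01-01 12:00:00")

history = target[target.index <= cutoff_date]  # the cutoff timestamp is now part of the history
future = target[target.index > cutoff_date]    # forecastable timestamps start after the cutoff

assert history.index[-1] == cutoff_date
assert future.index[0] > cutoff_date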
2 changes: 2 additions & 0 deletions src/enfobench/datasets/electricity_demand.py
@@ -163,4 +163,6 @@ def get_data_by_unique_id(self, unique_id: str) -> tuple[pd.DataFrame, pd.DataFr

demand = self.demand_subset.get_by_unique_id(unique_id)
weather = self.weather_subset.get_by_location_id(location_id)
# Filter weather data to match the period of the demand data
weather = weather.loc[demand.index[0] - pd.Timedelta("7 days") : demand.index[-1] + pd.Timedelta("7 days")]
return demand, weather, metadata
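
The added slice pads the weather frame by a week on each side of the demand period, presumably so that models still have some past and future weather around the demand window. A rough pandas illustration of that label-based padding (hypothetical frames, not the dataset classes):

import pandas as pd

demand = pd.DataFrame({"y": 1.0}, index=pd.date_range("2020-02-01", "2020-03-01", freq="1H"))
weather = pd.DataFrame({"temperature": 10.0}, index=pd.date_range("2020-01-01", "2020-04-01", freq="1H"))

padded = weather.loc[demand.index[0] - pd.Timedelta("7 days") : demand.index[-1] + pd.Timedelta("7 days")]

assert padded.index[0] == demand.index[0] - pd.Timedelta("7 days")
assert padded.index[-1] == demand.index[-1] + pd.Timedelta("7 days")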
11 changes: 10 additions & 1 deletion src/enfobench/datasets/utils.py
@@ -25,11 +25,20 @@ def create_perfect_forecasts_from_covariates(
The external forecast dataframe.
"""
start = start or past_covariates.index[0]

if start < past_covariates.index[0]:
msg = f"start={start} must be after the start of the past_covariates={past_covariates.index[0]}"
raise ValueError(msg)

if start > past_covariates.index[-1]:
msg = f"start={start} must be before the end of the past_covariates={past_covariates.index[-1]}"
raise ValueError(msg)

last_date = past_covariates.index[-1]

forecasts = []
while start + horizon <= last_date:
-forecast = past_covariates.loc[(past_covariates.index >= start) & (past_covariates.index < start + horizon)]
+forecast = past_covariates.loc[start : start + horizon]
forecast.rename_axis("timestamp", inplace=True)
forecast.reset_index(inplace=True)
forecast.insert(0, "cutoff_date", pd.to_datetime(start, unit="ns"))
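One detail worth noting: the new past_covariates.loc[start : start + horizon] is a label slice, and label slices on a DatetimeIndex include both endpoints, while the old boolean mask excluded start + horizon. A small sketch of the difference (toy frame, not the library helper):

import pandas as pd

covariates = pd.DataFrame({"temp": range(48)}, index=pd.date_range("2020-01-01", periods=48, freq="1H"))
start = pd.Timestamp("2020-01-01 00:00:00")
horizon = pd.Timedelta("24 hours")

masked = covariates.loc[(covariates.index >= start) & (covariates.index < start + horizon)]
sliced = covariates.loc[start : start + horizon]

assert len(masked) == 24  # end point excluded by the old mask
assert len(sliced) == 25  # label slicing keeps both endpoints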
12 changes: 10 additions & 2 deletions src/enfobench/evaluation/evaluate.py
@@ -105,15 +105,23 @@ def cross_validate(
level: Prediction intervals to compute. (Optional, if not provided, simple point forecasts will be computed.)
"""
if start_date <= dataset.target_available_since:
msg = f"Start date must be after the start of the dataset: {start_date} <= {dataset.target_available_since}."
msg = f"start_date={start_date} must be after the start of the dataset={dataset.target_available_since}"
raise ValueError(msg)

if start_date >= dataset.target_available_until:
msg = f"start_date={start_date} must be before the end of the dataset={dataset.target_available_until}"
raise ValueError(msg)

initial_training_data = start_date - dataset.target_available_since
if initial_training_data < pd.Timedelta("7 days"):
warnings.warn("Initial training data is less than 7 days.", stacklevel=2)

if end_date < dataset.target_available_since:
msg = f"end_date={end_date} must be after the start of the dataset={dataset.target_available_since}"
raise ValueError(msg)

if end_date > dataset.target_available_until:
msg = f"End date must be before the end of the dataset: {end_date} > {dataset.target_available_until}."
msg = f"end_date={end_date} must be before the end of the dataset={dataset.target_available_until}"
raise ValueError(msg)

cutoff_dates = generate_cutoff_dates(start_date, end_date, horizon, step)
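Taken together, the checks now require the start date to lie strictly inside the period for which the target is available and the end date to lie within it, and they still only warn (rather than error) when fewer than seven days of history precede the first cutoff. A plain sketch of those conditions with illustrative timestamps (not the library's validation code itself):

import pandas as pd

target_available_since = pd.Timestamp("2020-01-01")
target_available_until = pd.Timestamp("2021-01-01")
start_date = pd.Timestamp("2020-02-01 10:00:00")
end_date = pd.Timestamp("2020-04-01")

# start_date must fall strictly inside the available target period
assert target_available_since < start_date < target_available_until
# end_date must not leave the available target period
assert target_available_since <= end_date <= target_available_until
# less than 7 days of initial training data would only trigger a warning
assert start_date - target_available_since >= pd.Timedelta("7 days")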
11 changes: 10 additions & 1 deletion src/enfobench/evaluation/utils.py
@@ -22,6 +22,15 @@ def steps_in_horizon(horizon: pd.Timedelta, freq: str) -> int:


def periods_in_duration(target: pd.DatetimeIndex, duration: timedelta | pd.Timedelta | str) -> int:
"""Return the number of periods in a given duration.
Args:
target: The target variable.
duration: The duration of the season in a timedelta format.
Returns:
The period count.
"""
if isinstance(duration, timedelta):
duration = pd.Timedelta(duration)
elif isinstance(duration, str):
@@ -41,7 +50,7 @@ def periods_in_duration(target: pd.DatetimeIndex, duration: timedelta | pd.Timedelta | str) -> int:


def create_forecast_index(history: pd.DataFrame, horizon: int) -> pd.DatetimeIndex:
"""Create time index for a forecast horizon.
"""Creates a DatetimeIndex for a forecast horizon.
Args:
history: The history of the time series.
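periods_in_duration counts how many steps of the target's frequency fit into a duration, and, per the updated test_create_forecast_index further down, create_forecast_index now yields timestamps strictly after the cutoff. A rough sketch of the underlying arithmetic in plain pandas (not the helpers themselves):

import pandas as pd

history = pd.DataFrame({"y": 0.0}, index=pd.date_range("2020-01-01", periods=96, freq="30T"))

# how many 30-minute periods fit into a one-day season
periods = int(pd.Timedelta("1 day") / pd.Timedelta("30min"))
assert periods == 48

# a forecast index for a 10-step horizon, beginning one step after the last observation
forecast_index = pd.date_range(start=history.index[-1] + history.index.freq, periods=10, freq=history.index.freq)
assert (forecast_index > history.index[-1]).all()
assert len(forecast_index) == 10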
53 changes: 42 additions & 11 deletions src/enfobench/visualization/demand.py
@@ -1,4 +1,4 @@
-from datetime import timedelta
+from datetime import time, timedelta

import matplotlib.pyplot as plt
import numpy as np
@@ -14,7 +14,10 @@
from matplotlib.ticker import LinearLocator, MultipleLocator
from statsmodels.graphics import tsaplots
except ImportError as e:
msg = f"Missing optional dependency '{e.name}'. Use pip or conda to install it."
msg = (
f"Missing optional dependency '{e.name}'. Use pip or conda to install it. "
"Alternatively you can enfobench[visualization] to install all dependencies for plotting."
)
raise ImportError(msg) from e


@@ -49,6 +52,23 @@ def plot_monthly_box(
fig.suptitle("")

# Set the labels
ax.set_xlabel("Month")
ax.set_xticklabels(
[
"Jan",
"Feb",
"Mar",
"Apr",
"May",
"Jun",
"Jul",
"Aug",
"Sep",
"Oct",
"Nov",
"Dec",
]
)
ax.set_ylabel("Energy (kWh)")
ax.set_title("Demand distribution by month", fontsize="large", loc="left")
return fig, ax
@@ -85,6 +105,8 @@ def plot_weekly_box(
fig.suptitle("")

# Set the labels
ax.set_xlabel("Day of the week")
ax.set_xticklabels(["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"])
ax.set_ylabel("Energy (kWh)")
ax.set_title("Demand distribution by day of week", loc="left", fontsize="large")
return fig, ax
@@ -121,6 +143,8 @@ def plot_daily_box(
fig.suptitle("")

# Set the labels
ax.set_xlabel("Hour of the day")
ax.set_xticklabels([f"{time(hour=h):%H:%M}" for h in range(24)], rotation=90)
ax.set_ylabel("Energy (kWh)")
ax.set_title("Demand distribution by hour", loc="left", fontsize="large")
return fig, ax
@@ -134,14 +158,13 @@ def plot_histogram(
"""Plot a histogram of demand data.
Args:
data: Demand data.
figsize: Figure size.
n_bins: Number of bins for the histogram.
Returns:
fig: Figure object.
ax: Axes object.
"""
# define the energy intervals to use for the histogram
bins = np.linspace(0, data.y.max(), n_bins + 1)
@@ -198,7 +221,7 @@ def plot_heatmap(
fig, ax = plt.subplots(1, 1, figsize=figsize)
ax.set_title("Demand heatmap", loc="left", fontsize="large")

-sns.heatmap(data_hm, ax=ax, cbar_kws={"label": "Energy (kWh)"}, **kwargs)
+sns.heatmap(data_hm, ax=ax, cmap="YlGnBu_r", cbar_kws={"label": "Energy (kWh)"}, **kwargs)
return fig, ax


@@ -284,7 +307,7 @@ def plot_data_quality(data: pd.DataFrame, figsize: tuple[float, float] = (12, 5)
fig, ax = plt.subplots(1, 1, figsize=figsize)
ax.set_title("Data quality of demand data", loc="left", fontsize="large")

cmap = plt.get_cmap("Reds", 4)
cmap = plt.get_cmap("YlGnBu_r", 4)
sns.heatmap(
data_hm,
ax=ax,
@@ -314,6 +337,10 @@ def plot_acf(
fig: Figure object.
ax: Axes object.
"""
if data["y"].isna().any():
msg = "The data contains missing values. Make sure to handle them before plotting the ACF."
raise ValueError(msg)

fig, ax = plt.subplots(1, 1, figsize=figsize)

periods = periods_in_duration(data.index, duration=lags)
@@ -356,6 +383,10 @@ def plot_pacf(
fig: Figure object.
ax: Axes object.
"""
if data["y"].isna().any():
msg = "The data contains missing values. Make sure to handle them before plotting the ACF."
raise ValueError(msg)

fig, ax = plt.subplots(1, 1, figsize=figsize)

periods = periods_in_duration(data.index, duration=lags)
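plot_acf and plot_pacf now refuse series with gaps, since autocorrelation estimates on data containing NaNs are unreliable or fail outright in statsmodels. One way to close small gaps before plotting, shown as a sketch (interpolation is just one option; how to treat gaps is up to the user):

import pandas as pd

data = pd.DataFrame(
    {"y": [1.0, 2.0, None, 4.0, 5.0]},
    index=pd.date_range("2020-01-01", periods=5, freq="1H"),
)

if data["y"].isna().any():
    # fill short gaps by time-based interpolation before calling plot_acf / plot_pacf
    data["y"] = data["y"].interpolate(method="time")

assert not data["y"].isna().any()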
6 changes: 4 additions & 2 deletions tests/test_dataset/test_dataset.py
@@ -17,7 +17,8 @@ def test_get_history(helpers, resolution):
assert isinstance(history.index, pd.DatetimeIndex)
assert "y" in history.columns
assert len(history.columns) == 1
-assert (history.index < cutoff_date).all()
+assert cutoff_date in history.index
+assert (history.index <= cutoff_date).all()


@pytest.mark.parametrize("resolution", ["15T", "30T", "1H"])
@@ -39,7 +40,8 @@ def test_get_past_covariates(helpers, resolution):
assert isinstance(past_cov.index, pd.DatetimeIndex)
assert "a" in past_cov.columns
assert "b" in past_cov.columns
-assert (past_cov.index < cutoff_date).all()
+assert cutoff_date in past_cov.index
+assert (past_cov.index <= cutoff_date).all()


@pytest.mark.parametrize("resolution", ["6H", "1D"])
13 changes: 0 additions & 13 deletions tests/test_dataset/test_utils.py
@@ -25,16 +25,3 @@ def test_create_perfect_forecasts_from_covariates():
assert "covariate_2" in future_covariates.columns
assert "timestamp" in future_covariates.columns
assert "cutoff_date" in future_covariates.columns


-# def test_create_perfect_forecasts_from_covariates(covariates):
-#     forecasts = enfobench.dataset.utils.create_perfect_forecasts_from_covariates(
-#         covariates,
-#         horizon=pd.Timedelta("7 days"),
-#         step=pd.Timedelta("1D"),
-#     )
-#
-#     assert isinstance(forecasts, pd.DataFrame)
-#     assert "timestamp" in forecasts.columns
-#     assert "cutoff_date" in forecasts.columns
-#     assert all(col in forecasts.columns for col in covariates.columns)
30 changes: 19 additions & 11 deletions tests/test_evaluations/test_evaluate.py
@@ -12,24 +12,32 @@


def test_cross_validate_univariate_locally(helpers, model):
-dataset = helpers.generate_univariate_dataset(start="2020-01-01", end="2020-03-01", freq="30T")
+dataset = helpers.generate_univariate_dataset(start="2020-01-01", end="2021-01-01", freq="30T")

-forecasts = cross_validate(
-    model=model,
-    dataset=dataset,
-    start_date=pd.Timestamp("2020-02-01 10:00:00"),
-    end_date=pd.Timestamp("2020-03-01"),
-    horizon=pd.Timedelta("38 hours"),
-    step=pd.Timedelta("1 day"),
-)
+start_date = pd.Timestamp("2020-02-01 10:00:00")
+end_date = pd.Timestamp("2020-04-01")
+horizon = pd.Timedelta("38 hours")
+step = pd.Timedelta("1 day")
+
+forecasts = cross_validate(model, dataset, start_date=start_date, end_date=end_date, horizon=horizon, step=step)

assert isinstance(forecasts, pd.DataFrame)
assert "timestamp" in forecasts.columns
assert "y" in forecasts.columns
assert "yhat" in forecasts.columns

assert "timestamp" in forecasts.columns
assert (forecasts.timestamp > start_date).all()
assert forecasts.timestamp.iloc[-1] == end_date

assert "cutoff_date" in forecasts.columns
assert (start_date.time() == forecasts.cutoff_date.dt.time).all()

for cutoff_date, forecast in forecasts.groupby("cutoff_date"):
assert len(forecast) == 38 * 2 # 38 hours with half-hour series
assert (forecast.timestamp > cutoff_date).all()

assert list(forecasts.cutoff_date.unique()) == list(
pd.date_range(start="2020-02-01 10:00:00", end="2020-02-28 10:00:00", freq="1D", inclusive="both")
pd.date_range(start="2020-02-01 10:00:00", end="2020-03-30 10:00:00", freq="1D", inclusive="both")
)


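The new expected end of the cutoff range follows from the horizon: with a one-day step, the last cutoff is the latest 10:00 timestamp whose full 38-hour horizon still ends on or before the end date. A quick check of that arithmetic, mirroring the values used in the test (this reconstructs the test's expectation, not necessarily generate_cutoff_dates itself):

import pandas as pd

start_date = pd.Timestamp("2020-02-01 10:00:00")
end_date = pd.Timestamp("2020-04-01")
horizon = pd.Timedelta("38 hours")
step = pd.Timedelta("1 day")

cutoffs = pd.date_range(start=start_date, end=end_date - horizon, freq=step)

assert cutoffs[-1] == pd.Timestamp("2020-03-30 10:00:00")  # matches the expected last cutoff
assert cutoffs[-1] + horizon <= end_date                   # the last horizon still fits before end_date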
3 changes: 1 addition & 2 deletions tests/test_evaluations/test_utils.py
@@ -170,5 +170,4 @@ def test_create_forecast_index(helpers, freq, horizon):
assert isinstance(index, pd.DatetimeIndex)
assert index.freq == history.index.freq
assert len(index) == horizon
-assert index[0] == cutoff_date
-assert all(idx >= cutoff_date for idx in index)
+assert (index > cutoff_date).all()
