diff --git a/Makefile b/Makefile
index 9b6376a..dd94061 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,4 @@
-.PHONY: install clean format lint tests build publish publish-test
+.PHONY: install clean lint style format test build publish publish-test
 
 #################################################################################
 # GLOBALS                                                                       #
 #################################################################################
@@ -7,45 +7,55 @@
 PROJECT_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))
 PROJECT_NAME = energy-forecat-benchmark-toolkit
 PACKAGE_NAME = enfobench
+PYTHON_INTERPRETER = python3
 
 #################################################################################
 # COMMANDS                                                                      #
 #################################################################################
 
+## Create python virtual environment
+venv/bin/python:
+	( \
+		$(PYTHON_INTERPRETER) -m venv $(PROJECT_DIR)/venv; \
+		. $(PROJECT_DIR)/venv/bin/activate; \
+		pip install --upgrade pip; \
+	)
+
 ## Install project dependencies
-install:
-	pip install -U pip
-	pip install -e ."[test,dev]"
-	mypy --install-types
+install: venv/bin/python
+	( \
+		. $(PROJECT_DIR)/venv/bin/activate; \
+		pip install -e .; \
+	)
 
 ## Delete all compiled Python files
 clean:
 	find . -type f -name "*.py[co]" -delete
 	find . -type d -name "__pycache__" -delete
 
+## Lint using ruff, mypy, black, and isort
+lint:
+	hatch run lint:all
+
+
+## Check style using ruff, black, and isort
+style:
+	hatch run lint:style
+
 ## Format using black
 format:
-	ruff src tests --fix
-	black src tests
-	isort src tests
-
-## Lint using ruff, mypy, black, and isort
-lint: format
-	mypy src
-	ruff src tests
-	black src tests --check
-	isort src tests --check-only
+	hatch run lint:fmt
 
 ## Run pytest with coverage
-tests:
-	pytest src tests
+test:
+	hatch run cov
 
 #################################################################################
 # PROJECT RULES                                                                 #
 #################################################################################
 
 ## Build source distribution and wheel
-build: lint tests
+build: style
 	hatch build
 
 ## Upload source distribution and wheel to PyPI
diff --git a/README.md b/README.md
index aaae72b..c0b384f 100644
--- a/README.md
+++ b/README.md
@@ -32,22 +32,65 @@ Load your own data and create a dataset.
 
 ```python
 import pandas as pd
-from enfobench.evaluation import Dataset
+from enfobench.dataset import Dataset
 
-# Load your dataset and make sure that the timestamp column in named 'ds' and the target values named 'y'
+# Load your data
 data = pd.read_csv("../path/to/your/data.csv", parse_dates=['timestamp'], index_col='timestamp')
-covariates = data.drop(columns=['target_column'])
+
+# Create a target DataFrame that has a pd.DatetimeIndex and a column named 'y'
+target = data.loc[:, ['target_column']].rename(columns={'target_column': 'y'})
+
+# Select the covariate columns to use as past covariates. This DataFrame must also have a pd.DatetimeIndex
+past_covariates = data.loc[:, ['covariate_1', 'covariate_2']]
+
+# As it can be challenging to access historical forecasts to use as future covariates,
+# the package also has a helper function to create perfect historical forecasts from the past covariates.
+from enfobench.dataset.utils import create_perfect_forecasts_from_covariates
+
+# The example below creates simulated perfect historical forecasts with a horizon of 24 hours and a step of 1 day.
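+# (As shown in enfobench.dataset.utils below, each simulated forecast run is stamped with a
+# 'cutoff_date' column, so the result can be passed to Dataset as future covariates.)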
+future_covariates = create_perfect_forecasts_from_covariates(
+    past_covariates,
+    horizon=pd.Timedelta("24 hours"),
+    step=pd.Timedelta("1 day"),
+)
 
 dataset = Dataset(
-    target=data['target_column'],
-    covariates=covariates,
+    target=target,
+    past_covariates=past_covariates,
+    future_covariates=future_covariates,
+)
+```
+
+The package integrates with the HuggingFace Dataset ['attila-balint-kul/electricity-demand'](https://huggingface.co/datasets/attila-balint-kul/electricity-demand).
+To use this, just download all the files from the data folder to your computer.
+
+```python
+from enfobench.dataset import Dataset, DemandDataset
+
+# Load the dataset from the folder that you downloaded the files to.
+ds = DemandDataset("/path/to/the/dataset/folder/that/contains/all/subsets")
+
+# List all meter ids
+ds.metadata_subset.list_unique_ids()
+
+# Get dataset for a specific meter id
+target, past_covariates, metadata = ds.get_data_by_unique_id("unique_id_of_the_meter")
+
+# Create a dataset
+dataset = Dataset(
+    target=target,
+    past_covariates=past_covariates,
+    future_covariates=None,
+    metadata=metadata
 )
 ```
+
 You can perform a cross validation on any model locally that adheres to the `enfobench.Model` protocol.
 
 ```python
 import MyModel
+import pandas as pd
 from enfobench.evaluation import cross_validate
 
 # Import your model and instantiate it
@@ -64,9 +107,11 @@ cv_results = cross_validate(
 )
 ```
 
-You can use the same crossvalidation interface with your model served behind an API.
+You can use the same cross-validation interface with your model served behind an API.
+To make this simple, both a client and a server are provided.
 
 ```python
+import pandas as pd
 from enfobench.evaluation import cross_validate, ForecastClient
 
 # Import your model and instantiate it
@@ -83,20 +128,21 @@ cv_results = cross_validate(
 )
 ```
 
-The package also collects common metrics for you that you can quickly evaluate on your results.
+The package also collects common metrics used in forecasting.
 
 ```python
 from enfobench.evaluation import evaluate_metrics_on_forecasts
 from enfobench.evaluation.metrics import (
-    mean_bias_error, mean_absolute_error, mean_squared_error, root_mean_squared_error,
+    mean_bias_error,
+    mean_absolute_error,
+    mean_squared_error,
+    root_mean_squared_error,
 )
 
-# Merge the cross validation results with the original data
-forecasts = cv_results.merge(dataset.target, on="ds", how="left")
-
+# Simply pass in the cross-validation results and the metrics you want to evaluate.
 metrics = evaluate_metrics_on_forecasts(
-    forecasts,
+    cv_results,
     metrics={
         "mean_bias_error": mean_bias_error,
         "mean_absolute_error": mean_absolute_error,
@@ -106,6 +152,19 @@ metrics = evaluate_metrics_on_forecasts(
 )
 ```
 
+To serve your model behind an API, you can use the built-in server factory.
+
+```python
+import uvicorn
+from enfobench.evaluation.server import server_factory
+
+model = MyModel()
+
+# Create a server that serves your model
+server = server_factory(model)
+uvicorn.run(server, port=3000)
+```
+
 ## Contributing
 
 Contributions and feedback are welcome! For major changes, please open an issue first to discuss
@@ -121,4 +180,4 @@ Submit a pull request describing your changes.
 
 ## License
 
-BSD 3-Clause License
+BSD 2-Clause License
diff --git a/pyproject.toml b/pyproject.toml
index a7f8edc..16a34c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,57 +4,51 @@ build-backend = "hatchling.build"
 
 [project]
 name = "enfobench"
+dynamic = ["version"]
 description = "Energy forecast benchmarking toolkit."
+readme = "README.md" +requires-python = ">=3.10" +license = "BSD-2-clause" +keywords = [ + "energy", + "forecasting", + "benchmarking", +] authors = [ { author = "Attila Balint", email = "attila.balint@kuleuven.be" }, ] -readme = "README.md" packages = [{ include = "enfobench", from = "src" }] -requires-python = ">=3.7" classifiers = [ - "Development Status :: 3 - Alpha", + "Development Status :: 4 - Beta", "Intended Audience :: Science/Research", "Natural Language :: English", "Operating System :: OS Independent", "License :: OSI Approved :: BSD License", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", "Topic :: Scientific/Engineering", ] dependencies = [ + "duckdb<=1.0.0", "fastapi>=0.68.0,<1.0.0", - "pandas>=1.3.0,<2.0.0", - "pydantic>=1.0.0,<2.0.0", - "python-multipart>=0.0.0,<1.0.0", - "pyarrow>=12.0.0,<13.0.0", - "requests>=2.26.0,<3.0.0", - "tqdm>=4.60.0,<5.0.0", - "uvicorn>=0.20.0,<1.0.0", + "pandas>=1.5", + "pydantic>=2.0.0", + "python-multipart>=0.0.0", + "pyarrow>=13.0.0", + "requests>=2.26.0", + "tqdm>=4.60.0", + "uvicorn>=0.20.0", ] -dynamic = ["version"] - -[project.urls] -"Homepage" = "https://github.com/attila-balint-kul/energy-forecast-benchmark-toolkit" -[project.optional-dependencies] -test = [ - "black==23.3.0", - "isort==5.12.0", - "httpx==0.24.1", - "mypy==1.4.0", - "ruff==0.0.274", - "pytest==7.3.2", -] -dev = [ - "twine>=4.0.0,<5.0.0", - "pre-commit>=3.0.0,<4.0.0", -] +[project.urls] +Documentation = "https://github.com/attila-baline-kul/energy-forecast-benchmark-toolkit#readme" +Issues = "https://github.com/attila-baline-kul/energy-forecast-benchmark-toolkit/issues" +Source = "https://github.com/attila-balint-kul/energy-forecast-benchmark-toolkit" [tool.hatch.version] path = "src/enfobench/__version__.py" @@ -75,72 +69,116 @@ only-include = [ ] -[tool.black] -line-length = 99 -target-version = ["py310", ] +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", + "pytest-cov", + "httpx", +] +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +cov = "pytest --cov=enfobench --cov-report=term-missing {args:tests}" + +[[tool.hatch.envs.all.matrix]] +python = ["3.10", "3.11"] +[tool.hatch.envs.lint] +detached = true +dependencies = [ + "black>=23.1.0", + "mypy>=1.6.0", + "ruff>=0.1.5", +] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive {args:src/enfobench tests}" +style = [ + "ruff {args:.}", + "black --check --diff {args:.}", +] +fmt = [ + "black {args:.}", + "ruff --fix {args:.}", + "style", +] +all = [ + "style", + "typing", +] + + +[tool.black] +target-version = ["py310"] +line-length = 120 +skip-string-normalization = true [tool.ruff] target-version = "py310" -line-length = 99 +line-length = 120 select = [ - "E", # pycodestyle errors - "W", # pycodestyle warnings - "F", # pyflakes - # "I", # isort - "C", # flake8-comprehensions - "B", # flake8-bugbear + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "FBT", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", ] ignore = [ - "E501", # line too long, handled by black - "B008", # do 
not perform function calls in argument defaults - "C901", # too complex + # Allow non-abstract empty methods in abstract base classes + "B027", + # Allow boolean positional values in function calls, like `dict.get(... True)` + "FBT003", + # Ignore checks for possible passwords + "S105", "S106", "S107", + # Ignore complexity + "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", ] - -[tool.ruff.per-file-ignores] -"__init__.py" = ["F401"] # Don't touch unused imports [tool.ruff.isort] known-first-party = ["enfobench"] -[tool.ruff.flake8-quotes] -inline-quotes = "single" - [tool.ruff.flake8-tidy-imports] ban-relative-imports = "all" -[tool.mypy] -python_version = "3.10" -files = ["src/enfobench"] -disallow_untyped_defs = false -disallow_incomplete_defs = false -follow_imports = "normal" -ignore_missing_imports = true -pretty = true -show_column_numbers = true -show_error_codes = true -warn_no_return = false -warn_unused_ignores = true - -[[tool.mypy.overrides]] -module = "tests.*" -ignore_missing_imports = true -check_untyped_defs = true - -[tool.isort] -profile = "black" - -[tool.pytest.ini_options] -testpaths = ["tests"] -filterwarnings = ["error"] +[tool.ruff.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] +source_pkgs = ["enfobench", "tests"] branch = true -source_pkgs = [ - "enfobench", - "tests", +parallel = true +omit = [ + "src/enfobench/__about__.py", ] +[tool.coverage.paths] +enfobench = ["src/enfobench"] +tests = ["tests"] + [tool.coverage.report] exclude_lines = [ "no cov", diff --git a/src/enfobench/__init__.py b/src/enfobench/__init__.py index e69de29..a82d40d 100644 --- a/src/enfobench/__init__.py +++ b/src/enfobench/__init__.py @@ -0,0 +1,10 @@ +from enfobench.dataset import Dataset +from enfobench.evaluation import AuthorInfo, ForecasterType, Model, ModelInfo + +__all__ = [ + "Model", + "ModelInfo", + "AuthorInfo", + "ForecasterType", + "Dataset", +] diff --git a/src/enfobench/__version__.py b/src/enfobench/__version__.py index b5fdc75..493f741 100644 --- a/src/enfobench/__version__.py +++ b/src/enfobench/__version__.py @@ -1 +1 @@ -__version__ = "0.2.2" +__version__ = "0.3.0" diff --git a/src/enfobench/dataset/__init__.py b/src/enfobench/dataset/__init__.py new file mode 100644 index 0000000..1760efd --- /dev/null +++ b/src/enfobench/dataset/__init__.py @@ -0,0 +1,9 @@ +from enfobench.dataset._dataset import Dataset, DemandDataset, DemandSubset, MetadataSubset, WeatherSubset + +__all__ = [ + "Dataset", + "DemandDataset", + "DemandSubset", + "MetadataSubset", + "WeatherSubset", +] diff --git a/src/enfobench/dataset/_dataset.py b/src/enfobench/dataset/_dataset.py new file mode 100644 index 0000000..12d33aa --- /dev/null +++ b/src/enfobench/dataset/_dataset.py @@ -0,0 +1,393 @@ +import logging +from abc import ABCMeta +from dataclasses import dataclass +from pathlib import Path + +import duckdb +import pandas as pd + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True, slots=True) +class Metadata: + unique_id: str + location_id: str + latitude: float + longitude: float + building_type: str + + +class Subset(metaclass=ABCMeta): # noqa: B024 + """Subset class representing one of the subset of the HuggingFace dataset. + + Args: + file_path: The path to the subset file. 
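+
+    Example:
+        A minimal sketch using one of the concrete subclasses (the file path is illustrative):
+
+        >>> subset = DemandSubset("data/demand.parquet")
+        >>> df = subset.read()  # loads the whole parquet subset into a DataFrame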
+ """ + + def __init__(self, file_path: Path | str) -> None: + file_path = Path(file_path).resolve() + if not file_path.is_file() or not file_path.exists(): + msg = "Please provide the existing file where the subset is located." + raise ValueError(msg) + self.file_path = file_path + + def __repr__(self): + return f"{self.__class__.__name__}(file_path={self.file_path})" + + def read(self) -> pd.DataFrame: + """Read the subset from the file.""" + return pd.read_parquet(self.file_path) + + +class MetadataSubset(Subset): + """Metadata subset of the HuggingFace dataset containing all metadata about the meters. + + Args: + file_path: The path to the subset file. + """ + + def list_unique_ids(self) -> list[str]: + """Lists all unique ids.""" + query = """ + SELECT DISTINCT unique_id + FROM read_parquet(?) + """ + conn = duckdb.connect(":memory:") + return conn.execute(query, parameters=[str(self.file_path)]).fetch_df().unique_id.tolist() + + def get_by_unique_id(self, unique_id: str) -> Metadata: + """Returns the metadata for the given unique id. + + Args: + unique_id: The unique id of the meter. + """ + query = """ + SELECT * + FROM read_parquet(?) + WHERE unique_id = ? + """ + conn = duckdb.connect(":memory:") + df = conn.execute(query, parameters=[str(self.file_path), unique_id]).fetch_df() + if df.empty: + msg = f"Unique id '{unique_id}' was not found." + raise KeyError(msg) + return Metadata(**df.to_dict(orient="records")[0]) + + +class WeatherSubset(Subset): + """Weather subset of the HuggingFace dataset containing all weather data. + + Args: + file_path: The path to the subset file. + """ + + def list_location_ids(self) -> list[str]: + """Lists all location ids.""" + query = """ + SELECT DISTINCT location_id + FROM read_parquet(?) + """ + conn = duckdb.connect(":memory:") + return conn.execute(query, parameters=[str(self.file_path)]).fetch_df().location_id.tolist() + + def get_by_location_id(self, location_id: str, columns: list[str] | None = None) -> pd.DataFrame: + """Returns the weather data for the given location id. + + Args: + location_id: The location id of the weather station. + columns: The columns to return. If None, all columns are returned. + """ + conn = duckdb.connect(":memory:") + + if columns: + query = f""" + SELECT timestamp, {", ".join(columns)} + FROM read_parquet(?) + WHERE location_id = ? + """ # noqa: S608 + else: + query = """ + SELECT * + FROM read_parquet(?) + WHERE location_id = ? + """ + df = conn.execute(query, parameters=[str(self.file_path), location_id]).fetch_df() + if df.empty: + msg = f"Location id '{location_id}' was not found." + raise KeyError(msg) + + # Remove location_id and set timestamp as index + df.drop(columns=["location_id"], inplace=True, errors="ignore") + df.set_index("timestamp", inplace=True) + return df + + +class DemandSubset(Subset): + """Demand subset of the HuggingFace dataset containing all demand data. + + Args: + file_path: The path to the subset file. + """ + + def get_by_unique_id(self, unique_id: str): + """Returns the demand data for the given unique id. + + Args: + unique_id: The unique id of the meter. + """ + query = """ + SELECT * + FROM read_parquet(?) + WHERE unique_id = ? + """ + conn = duckdb.connect(":memory:") + df = conn.execute(query, parameters=[str(self.file_path), unique_id]).fetch_df() + if df.empty: + msg = f"Unique id '{unique_id}' was not found." 
+ raise KeyError(msg) + + # Remove unique_id and set timestamp as index + df.drop(columns=["unique_id"], inplace=True, errors="ignore") + df.set_index("timestamp", inplace=True) + return df + + +class DemandDataset: + """DemandDataset class representing the HuggingFace dataset. + + This class is a collection of all subsets inside HuggingFace dataset. + It provides an easy way to access the different subsets. + + Args: + directory: The directory where the HuggingFace dataset is located. + This directory should contain all the subset files. + """ + + HUGGINGFACE_DATASET = "attila-balint-kul/electricity-demand" + SUBSETS = ("demand", "metadata", "weather") + + def __init__(self, directory: Path | str) -> None: + directory = Path(directory).resolve() + if not directory.is_dir() or not directory.exists(): + msg = f"Please provide the existing directory where the '{self.HUGGINGFACE_DATASET}' dataset is located." + raise ValueError(msg) + self.directory = directory.resolve() + + def __repr__(self) -> str: + return f"DemandDataset(directory={self.directory})" + + def _check_for_valid_subset(self, subset: str): + if subset not in self.SUBSETS: + msg = f"Please provide a valid subset. Available subsets: {self.SUBSETS}" + raise ValueError(msg) + + @property + def metadata_subset(self) -> MetadataSubset: + """Returns the metadata subset.""" + return MetadataSubset(self._get_subset_path("metadata")) + + @property + def weather_subset(self) -> WeatherSubset: + """Returns the weather subset.""" + return WeatherSubset(self._get_subset_path("weather")) + + @property + def demand_subset(self) -> DemandSubset: + """Returns the demand subset.""" + return DemandSubset(self._get_subset_path("demand")) + + def get_subset(self, subset: str) -> Subset: + """Returns the selected subset.""" + self._check_for_valid_subset(subset) + if subset == "metadata": + return self.metadata_subset + elif subset == "weather": + return self.weather_subset + elif subset == "demand": + return self.demand_subset + msg = f"Please provide a valid subset. Available subsets: {self.SUBSETS}" + raise ValueError(msg) + + def _get_subset_path(self, subset: str) -> Path: + filepath = self.directory / f"{subset}.parquet" + if not filepath.exists(): + msg = ( + f"There is no {subset} in the directory. " + f"Make sure to download all subsets from the HuggingFace dataset: {self.HUGGINGFACE_DATASET}." + ) + raise ValueError(msg) + return self.directory / f"{subset}.parquet" + + def read_subset(self, subset: str) -> pd.DataFrame: + """Reads the selected subset.""" + return self.get_subset(subset).read() + + def list_unique_ids(self) -> list[str]: + return self.metadata_subset.list_unique_ids() + + def list_location_ids(self) -> list[str]: + return self.weather_subset.list_location_ids() + + def get_data_by_unique_id(self, unique_id: str) -> tuple[pd.DataFrame | pd.DataFrame | Metadata]: + metadata = self.metadata_subset.get_by_unique_id(unique_id) + + demand = self.demand_subset.get_by_unique_id(unique_id) + weather = self.weather_subset.get_by_location_id(metadata.location_id) + return demand, weather, metadata + + +class Dataset: + """Dataset class representing a collection of data required for forecasting task. + + Args: + target: The target variable. + past_covariates: The past covariates. + future_covariates: The future covariates. + metadata: The metadata. 
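+
+    Example:
+        A minimal sketch, assuming `target` is a DataFrame with a pd.DatetimeIndex
+        and a single column named 'y':
+
+        >>> ds = Dataset(target=target, past_covariates=None, future_covariates=None)
+        >>> ds.target_available_since  # first timestamp of the target series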
+ """ + + def __init__( + self, + target: pd.DataFrame, + past_covariates: pd.DataFrame | None = None, + future_covariates: pd.DataFrame | None = None, + metadata: Metadata | None = None, + ): + self._target = self._check_target(target.copy()) + self._first_available_target_date: pd.Timestamp = self._target.index[0] + self._last_available_target_date: pd.Timestamp = self._target.index[-1] + + self._past_covariates = ( + self._check_past_covariates(past_covariates.copy()) if past_covariates is not None else None + ) + + self._future_covariates = ( + self._check_external_forecasts(future_covariates.copy()) if future_covariates is not None else None + ) + self.metadata = metadata + + @property + def target_available_since(self) -> pd.Timestamp: + """Returns the first available target date.""" + return self._first_available_target_date + + @property + def target_available_until(self) -> pd.Timestamp: + """Returns the last available target date.""" + return self._last_available_target_date + + @property + def target_freq(self) -> str: + """Returns the frequency of the target.""" + return self._target.index.inferred_freq + + @staticmethod + def _check_target(y: pd.DataFrame) -> pd.DataFrame: + if isinstance(y, pd.Series): + logger.warning("Target is a Series, converting to DataFrame.") + y = y.to_frame("y") + if not isinstance(y.index, pd.DatetimeIndex): + msg = f"Target dataframe must have DatetimeIndex, have {y.index.__class__.__name__}." + raise ValueError(msg) + y.rename_axis("timestamp", inplace=True) + y.sort_index(inplace=True) + return y + + def _check_past_covariates(self, X: pd.DataFrame) -> pd.DataFrame: # noqa: N803 + if not isinstance(X.index, pd.DatetimeIndex): + msg = f"Past covariates must have DatetimeIndex, have {X.index.__class__.__name__}." + raise ValueError(msg) + + X.rename_axis("timestamp", inplace=True) + X.sort_index(inplace=True) + + if X.index[0] > self._first_available_target_date: + msg = "Covariates must be provided for the full target timeframe, covariates start after target values." + raise ValueError(msg) + + if X.index[-1] < self._last_available_target_date: + msg = "Covariates must be provided for the full target timeframe, covariates end before target values." + raise ValueError(msg) + return X + + def _check_external_forecasts(self, X: pd.DataFrame) -> pd.DataFrame: # noqa: N803 + first_forecast_date = X.cutoff_date.min() + last_forecast_date = X.cutoff_date.max() + last_forecast_end_date = X[X.cutoff_date == last_forecast_date].timestamp.max() + + if first_forecast_date > self._first_available_target_date: + msg = ( + "External forecasts must be provided for the full target timeframe, " + "forecasts start after target values." + ) + raise ValueError(msg) + + if last_forecast_end_date < self._last_available_target_date: + msg = ( + "External forecasts must be provided for the full target timeframe, " + "forecasts end before target values." + ) + raise ValueError(msg) + + return X + + def _check_cutoff_in_rage(self, cutoff_date: pd.Timestamp): + if cutoff_date < self._first_available_target_date: + msg = f"Cutoff date is before the start date: {cutoff_date} < {self._first_available_target_date}." + raise IndexError(msg) + + if cutoff_date > self._last_available_target_date: + msg = f"Cutoff date is after the end date: {cutoff_date} > {self._last_available_target_date}." + raise IndexError(msg) + + def get_history(self, cutoff_date: pd.Timestamp) -> pd.DataFrame: + """Returns the history of the target variable up to the cutoff date. 
+ + Args: + cutoff_date: The cutoff date. + + Returns: + The history of the target variable up to the cutoff date. + """ + self._check_cutoff_in_rage(cutoff_date) + return self._target[self._target.index <= cutoff_date] + + def get_past_covariates(self, cutoff_date: pd.Timestamp) -> pd.DataFrame | None: + """Returns the past covariates for the cutoff date. + + Args: + cutoff_date: The cutoff date. + + Returns: + The past covariates up until the cutoff date. + """ + if self._past_covariates is None: + return None + + self._check_cutoff_in_rage(cutoff_date) + return self._past_covariates[self._past_covariates.index <= cutoff_date] + + def get_future_covariates(self, cutoff_date: pd.Timestamp) -> pd.DataFrame | None: + """Returns the future covariates for the cutoff date. + + Args: + cutoff_date: The cutoff date. + + Returns: + The last external forecasts made before the cutoff date. + """ + if self._future_covariates is None: + return None + + self._check_cutoff_in_rage(cutoff_date) + last_past_cutoff_date = self._future_covariates.cutoff_date[ + self._future_covariates.cutoff_date <= cutoff_date + ].max() + + future_covariates = self._future_covariates[ + (self._future_covariates.cutoff_date == last_past_cutoff_date) + & (self._future_covariates.timestamp > cutoff_date) + ] + future_covariates.set_index("timestamp", inplace=True) + return future_covariates diff --git a/src/enfobench/dataset/utils.py b/src/enfobench/dataset/utils.py new file mode 100644 index 0000000..d236962 --- /dev/null +++ b/src/enfobench/dataset/utils.py @@ -0,0 +1,47 @@ +import warnings + +import pandas as pd + + +def create_perfect_forecasts_from_covariates( + past_covariates: pd.DataFrame, + *, + horizon: pd.Timedelta, + step: pd.Timedelta, + **kwargs, +) -> pd.DataFrame: + """Create forecasts from covariates. + + Sometimes external forecasts are not available for the entire horizon. This function creates + external forecast dataframe from external covariates as a perfect forecast. + + Args: + past_covariates: The external covariates. + horizon: The forecast horizon. + step: The step size between forecasts. + + Returns: + The external forecast dataframe. 
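+
+    Example:
+        An illustrative call, assuming `past_covariates` has an hourly pd.DatetimeIndex:
+
+        >>> future_covariates = create_perfect_forecasts_from_covariates(
+        ...     past_covariates,
+        ...     horizon=pd.Timedelta("24 hours"),
+        ...     step=pd.Timedelta("1 day"),
+        ... )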
+ """ + start = kwargs.get("start", past_covariates.index[0]) + last_date = past_covariates.index[-1] + + forecasts = [] + while start + horizon <= last_date: + forecast = past_covariates.loc[(past_covariates.index > start) & (past_covariates.index <= start + horizon)] + forecast.insert(0, "cutoff_date", start) + forecast.rename_axis("timestamp", inplace=True) + forecast.reset_index(inplace=True) + + if len(forecast) == 0: + warnings.warn( + f"Covariates not found for {start} - {start + horizon}, cannot make forecast at step {start}", + UserWarning, + stacklevel=2, + ) + + forecasts.append(forecast) + start += step + + forecast_df = pd.concat(forecasts, ignore_index=True) + return forecast_df diff --git a/src/enfobench/evaluation/__init__.py b/src/enfobench/evaluation/__init__.py index f11fe91..77a8f4a 100644 --- a/src/enfobench/evaluation/__init__.py +++ b/src/enfobench/evaluation/__init__.py @@ -1,10 +1,18 @@ -from .client import ForecastClient -from .evaluate import ( +from enfobench.evaluation.client import ForecastClient +from enfobench.evaluation.evaluate import ( cross_validate, - evaluate_metric_on_forecast, - evaluate_metric_on_forecasts, - evaluate_metrics_on_forecast, evaluate_metrics_on_forecasts, - generate_cutoff_dates, ) -from .protocols import Dataset, EnvironmentInfo, ForecasterType, Model, ModelInfo +from enfobench.evaluation.model import AuthorInfo, ForecasterType, Model, ModelInfo +from enfobench.evaluation.protocols import Dataset + +__all__ = [ + "ForecastClient", + "cross_validate", + "evaluate_metrics_on_forecasts", + "Dataset", + "Model", + "ModelInfo", + "AuthorInfo", + "ForecasterType", +] diff --git a/src/enfobench/evaluation/client.py b/src/enfobench/evaluation/client.py index c21464b..321b068 100644 --- a/src/enfobench/evaluation/client.py +++ b/src/enfobench/evaluation/client.py @@ -1,36 +1,35 @@ import io -from typing import Dict, List, Optional, Union +from http import HTTPStatus import pandas as pd import requests -from enfobench.evaluation.protocols import EnvironmentInfo, ModelInfo +from enfobench.evaluation.model import ModelInfo +from enfobench.evaluation.server import EnvironmentInfo def to_buffer(df: pd.DataFrame) -> io.BytesIO: buffer = io.BytesIO() - df.to_parquet(buffer, index=False) + df.to_parquet(buffer, index=True) buffer.seek(0) return buffer class ForecastClient: - def __init__( - self, host: str = "localhost", port: int = 3000, secure: bool = False, client=None - ): - self.base_url = f"{'https' if secure else 'http'}://{host}:{port}" + def __init__(self, host: str = "localhost", port: int = 3000, *, use_https: bool = False, client=None): + self.base_url = f"{'https' if use_https else 'http'}://{host}:{port}" self._session = requests.Session() if client is None else client def info(self) -> ModelInfo: response = self._session.get(f"{self.base_url}/info") - if response.status_code != 200: + if response.status_code != HTTPStatus.OK: response.raise_for_status() return ModelInfo(**response.json()) def environment(self) -> EnvironmentInfo: response = self._session.get(f"{self.base_url}/environment") - if response.status_code != 200: + if response.status_code != HTTPStatus.OK: response.raise_for_status() return EnvironmentInfo(**response.json()) @@ -39,11 +38,11 @@ def forecast( self, horizon: int, history: pd.DataFrame, - past_covariates: Optional[pd.DataFrame] = None, - future_covariates: Optional[pd.DataFrame] = None, - level: Optional[List[int]] = None, + past_covariates: pd.DataFrame | None = None, + future_covariates: pd.DataFrame | None = 
None, + level: list[int] | None = None, ) -> pd.DataFrame: - params: Dict[str, Union[int, List[int]]] = { + params: dict[str, int | list[int]] = { "horizon": horizon, } if level is not None: @@ -62,9 +61,10 @@ def forecast( params=params, files=files, ) - if response.status_code != 200: + if response.status_code != HTTPStatus.OK: response.raise_for_status() df = pd.DataFrame.from_records(response.json()["forecast"]) - df["ds"] = pd.to_datetime(df["ds"]) + df["timestamp"] = pd.to_datetime(df["timestamp"]) + df.set_index("timestamp", inplace=True) return df diff --git a/src/enfobench/evaluation/evaluate.py b/src/enfobench/evaluation/evaluate.py index 92f94f8..783cbe6 100644 --- a/src/enfobench/evaluation/evaluate.py +++ b/src/enfobench/evaluation/evaluate.py @@ -1,27 +1,23 @@ import warnings -from typing import Callable, Dict, List, Optional, Union +from collections.abc import Callable import pandas as pd from tqdm import tqdm +from enfobench.dataset import Dataset from enfobench.evaluation.client import ForecastClient -from enfobench.evaluation.protocols import Dataset, Model -from enfobench.evaluation.utils import steps_in_horizon +from enfobench.evaluation.model import Model +from enfobench.evaluation.utils import generate_cutoff_dates, steps_in_horizon def evaluate_metric_on_forecast(forecast: pd.DataFrame, metric: Callable) -> float: """Evaluate a single metric on a single forecast. - Parameters: - ----------- - forecast: - Forecast to evaluate. - metric: - Metric to evaluate. + Args: + forecast: Forecast to evaluate. + metric: Metric to evaluate. Returns: - -------- - metric_value: Metric value. """ _nonempty_df = forecast.dropna(subset=["y"]) @@ -29,26 +25,18 @@ def evaluate_metric_on_forecast(forecast: pd.DataFrame, metric: Callable) -> flo return metric_value -def evaluate_metrics_on_forecast( - forecast: pd.DataFrame, metrics: Dict[str, Callable] -) -> Dict[str, float]: +def evaluate_metrics_on_forecast(forecast: pd.DataFrame, metrics: dict[str, Callable]) -> dict[str, float]: """Evaluate multiple metrics on a single forecast. - Parameters: - ----------- - forecast: - Forecast to evaluate. - metrics: - Metric to evaluate. + Args: + forecast: Forecast to evaluate. + metrics: Metric to evaluate. Returns: - -------- - metric_value: Metric value. """ metric_values = { - metric_name: evaluate_metric_on_forecast(forecast, metric) - for metric_name, metric in metrics.items() + metric_name: evaluate_metric_on_forecast(forecast, metric) for metric_name, metric in metrics.items() } return metric_values @@ -56,41 +44,28 @@ def evaluate_metrics_on_forecast( def evaluate_metric_on_forecasts(forecasts: pd.DataFrame, metric: Callable) -> pd.DataFrame: """Evaluate a single metric on a set of forecasts made at different cutoff points. - Parameters: - ----------- - forecasts: - Forecasts to evaluate. - metric: - Metric to evaluate. + Args: + forecasts: Forecasts to evaluate. + metric: Metric to evaluate. Returns: - -------- - metrics_df: Metric values for each cutoff with their weight. 
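+
+    Example:
+        A sketch, assuming `forecasts` holds cross-validation results with
+        'cutoff_date', 'y' and 'yhat' columns:
+
+        >>> mae_per_cutoff = evaluate_metric_on_forecasts(forecasts, mean_absolute_error)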
""" metrics = { - cutoff: evaluate_metric_on_forecast(group_df, metric) - for cutoff, group_df in forecasts.groupby("cutoff_date") + cutoff: evaluate_metric_on_forecast(group_df, metric) for cutoff, group_df in forecasts.groupby("cutoff_date") } metrics_df = pd.DataFrame.from_dict(metrics, orient="index", columns=["value"]) return metrics_df -def evaluate_metrics_on_forecasts( - forecasts: pd.DataFrame, metrics: Dict[str, Callable] -) -> pd.DataFrame: +def evaluate_metrics_on_forecasts(forecasts: pd.DataFrame, metrics: dict[str, Callable]) -> pd.DataFrame: """Evaluate multiple metrics on a set of forecasts made at different cutoff points. - Parameters: - ----------- - forecasts: - Forecasts to evaluate. - metrics: - Metric to evaluate. + Args: + forecasts: Forecasts to evaluate. + metrics: Metric to evaluate. Returns: - -------- - metrics_df: Metric values for each cutoff with their weight. """ metric_dfs = [ @@ -101,108 +76,41 @@ def evaluate_metrics_on_forecasts( return metrics_df -def generate_cutoff_dates( - start_date: pd.Timestamp, - end_date: pd.Timestamp, - horizon: pd.Timedelta, - step: pd.Timedelta, -) -> List[pd.Timestamp]: - """Generate cutoff dates for cross-validation between two dates. - - The cutoff dates are separated by a fixed step size and the last cutoff date is a horizon away from the end date. - - Parameters - ---------- - start_date: - Start date of the time series. - end_date: - End date of the time series. - horizon: - Forecast horizon. - step: - Step size between cutoff dates. - - Examples - -------- - >>> generate_cutoff_dates( - ... start_date=pd.Timestamp("2020-01-01"), - ... end_date=pd.Timestamp("2020-01-05"), - ... horizon=pd.Timedelta("2 days"), - ... step=pd.Timedelta("1 day"), - ... ) - [ - Timestamp('2020-01-01 00:00:00'), - Timestamp('2020-01-02 00:00:00'), - Timestamp('2020-01-03 00:00:00'), - ] - """ - if horizon <= pd.Timedelta(0): - raise ValueError("Horizon must be positive.") - - if step <= pd.Timedelta(0): - raise ValueError("Step must be positive.") - - if horizon > end_date - start_date: - raise ValueError( - f"Horizon is longer than the evaluation period: {horizon} > {end_date - start_date}." - ) - - if end_date <= start_date: - raise ValueError("End date must be after the starting date.") - - cutoff_dates = [] - - cutoff = start_date - while cutoff <= end_date - horizon: - cutoff_dates.append(cutoff) - cutoff += step - - if not cutoff_dates: - raise ValueError("No dates for cross-validation") - return cutoff_dates - - def cross_validate( - model: Union[Model, ForecastClient], + model: Model | ForecastClient, dataset: Dataset, + *, start_date: pd.Timestamp, end_date: pd.Timestamp, horizon: pd.Timedelta, step: pd.Timedelta, - level: Optional[List[int]] = None, + level: list[int] | None = None, ) -> pd.DataFrame: """Cross-validate a model. - Parameters - ---------- - model: - Model to cross-validate. - dataset: - Dataset to cross-validate on. - start_date: - Start date of the time series. - end_date: - End date of the time series. - horizon: - Forecast horizon. - step: - Step size between cutoff dates. - level: - Prediction intervals to compute. - (Optional, if not provided, simple point forecasts will be computed.) + Args: + model: Model to cross-validate. + dataset: Dataset to cross-validate on. + start_date: Start date of the time series. + end_date: End date of the time series. + horizon: Forecast horizon. + step: Step size between cutoff dates. + level: Prediction intervals to compute. 
(Optional, if not provided, simple point forecasts will be computed.) """ - if start_date <= dataset.start_date: - raise ValueError("Start date must be after the start of the target values.") + if start_date <= dataset.target_available_since: + msg = f"Start date must be after the start of the dataset: {start_date} <= {dataset.target_available_since}." + raise ValueError(msg) - initial_training_data = start_date - dataset.start_date + initial_training_data = start_date - dataset.target_available_since if initial_training_data < pd.Timedelta("7 days"): warnings.warn("Initial training data is less than 7 days.", stacklevel=2) - if end_date > dataset.end_date: - raise ValueError("End date is beyond the target values.") + if end_date > dataset.target_available_until: + msg = f"End date must be before the end of the dataset: {end_date} > {dataset.target_available_until}." + raise ValueError(msg) cutoff_dates = generate_cutoff_dates(start_date, end_date, horizon, step) - horizon_length = steps_in_horizon(horizon, dataset.freq) + horizon_length = steps_in_horizon(horizon, dataset.target_freq) forecasts = [] for cutoff_date in tqdm(cutoff_dates): @@ -217,10 +125,30 @@ def cross_validate( future_covariates=future_covariates, level=level, ) - # TODO: validate forecast df with pandera schema - forecast = forecast.fillna(0) - forecast["cutoff_date"] = cutoff_date + if not isinstance(forecast, pd.DataFrame) or not isinstance(forecast.index, pd.DatetimeIndex): + msg = ( + f"Forecast must be a DataFrame with a DatetimeIndex, " + f"got {type(forecast)} with index {type(forecast.index)}." + ) + raise ValueError(msg) + + forecast_contains_nans = forecast.isna().any(axis=None) + if forecast_contains_nans: + msg = "Forecast contains NaNs, make sure to fill in missing values." + raise ValueError(msg) + + if len(forecast) != horizon_length: + msg = f"Forecast does not match the requested horizon length {horizon_length}, got {len(forecast)}." + raise ValueError(msg) + + forecast.rename_axis("timestamp", inplace=True) + forecast.reset_index(inplace=True) + forecast["cutoff_date"] = pd.to_datetime(cutoff_date, unit="ns") + forecast.set_index(["cutoff_date", "timestamp"], inplace=True) forecasts.append(forecast) - crossval_df = pd.concat(forecasts) + crossval_df = pd.concat(forecasts).reset_index() + + # Merge the forecast with the target + crossval_df = crossval_df.merge(dataset._target, left_on="timestamp", right_index=True) return crossval_df diff --git a/src/enfobench/evaluation/metrics.py b/src/enfobench/evaluation/metrics.py index a24315d..5bd8959 100644 --- a/src/enfobench/evaluation/metrics.py +++ b/src/enfobench/evaluation/metrics.py @@ -5,13 +5,12 @@ def check_not_empty(*arrays: ndarray) -> None: """Check that none of the arrays are not empty. - Parameters - ---------- - *arrays: list or tuple of input arrays. - Objects that will be checked for emptiness. + Args: + *arrays: Objects that will be checked for emptiness. """ - if any(X.size == 0 for X in arrays): - raise ValueError("Found empty array in inputs.") + if any(array.size == 0 for array in arrays): + msg = "Found empty array in inputs." + raise ValueError(msg) def check_consistent_length(*arrays: ndarray) -> None: @@ -19,40 +18,37 @@ def check_consistent_length(*arrays: ndarray) -> None: Checks whether all input arrays have the same length. - Parameters - ---------- - *arrays : list or tuple of input arrays. - Objects that will be checked for consistent length. + Args: + *arrays: Objects that will be checked for consistent length. 
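+
+    Example:
+        An illustrative check that passes silently because both arrays have length 2:
+
+        >>> check_consistent_length(np.array([1, 2]), np.array([3, 4]))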
""" - if any(X.ndim != 1 for X in arrays): - raise ValueError("Found multi dimensional array in inputs.") + if any(array.ndim != 1 for array in arrays): + mag = "Found multi dimensional array in inputs." + raise ValueError(mag) - lengths = [len(X) for X in arrays] + lengths = [len(array) for array in arrays] uniques = np.unique(lengths) if len(uniques) > 1: - raise ValueError(f"Found input variables with inconsistent numbers of samples: {lengths}") + msg = f"Found input variables with inconsistent numbers of samples: {lengths}" + raise ValueError(msg) def check_has_no_nan(*arrays: ndarray) -> None: """Check that all arrays have no NaNs. - Parameters - ---------- - *arrays : list or tuple of input arrays. - Objects that will be checked for NaNs. + Args: + *arrays: Objects that will be checked for NaNs. """ - for X in arrays: - if np.isnan(X).any(): - raise ValueError(f"Found NaNs in input variables: {X}") + for array in arrays: + if np.isnan(array).any(): + msg = f"Found NaNs in input variables: {array.__repr__()}" + raise ValueError(msg) def check_arrays(*arrays: ndarray) -> None: """Check that all arrays are valid. - Parameters - ---------- - *arrays : list or tuple of input arrays. - Objects that will be checked for validity. + Args: + *arrays: Objects that will be checked for validity. """ check_not_empty(*arrays) check_consistent_length(*arrays) @@ -62,12 +58,9 @@ def check_arrays(*arrays: ndarray) -> None: def mean_absolute_error(y_true: ndarray, y_pred: ndarray) -> float: """Mean absolute error regression loss. - Parameters: - ----------- - y_true : array-like of shape (n_samples,) - Ground truth (correct) target values. - y_pred : array-like of shape (n_samples,) - Estimated target values. + Args: + y_true: Ground truth (correct) target values. + y_pred: Estimated target values. """ check_arrays(y_true, y_pred) return float(np.mean(np.abs(y_true - y_pred))) @@ -76,12 +69,9 @@ def mean_absolute_error(y_true: ndarray, y_pred: ndarray) -> float: def mean_bias_error(y_true: ndarray, y_pred: ndarray) -> float: """Mean bias error regression loss. - Parameters: - ----------- - y_true : array-like of shape (n_samples,) - Ground truth (correct) target values. - y_pred : array-like of shape (n_samples,) - Estimated target values. + Args: + y_true: Ground truth (correct) target values. + y_pred: Estimated target values. """ check_arrays(y_true, y_pred) return float(np.mean(y_pred - y_true)) @@ -90,12 +80,9 @@ def mean_bias_error(y_true: ndarray, y_pred: ndarray) -> float: def mean_squared_error(y_true: ndarray, y_pred: ndarray) -> float: """Mean squared error regression loss. - Parameters: - ----------- - y_true : array-like of shape (n_samples,) - Ground truth (correct) target values. - y_pred : array-like of shape (n_samples,) - Estimated target values. + Args: + y_true: Ground truth (correct) target values. + y_pred: Estimated target values. """ check_arrays(y_true, y_pred) return float(np.mean((y_true - y_pred) ** 2)) @@ -104,12 +91,9 @@ def mean_squared_error(y_true: ndarray, y_pred: ndarray) -> float: def root_mean_squared_error(y_true: ndarray, y_pred: ndarray) -> float: """Root mean squared error regression loss. - Parameters: - ----------- - y_true : array-like of shape (n_samples,) - Ground truth (correct) target values. - y_pred : array-like of shape (n_samples,) - Estimated target values. + Args: + y_true: Ground truth (correct) target values. + y_pred: Estimated target values. 
""" check_arrays(y_true, y_pred) return float(np.sqrt(np.mean((y_true - y_pred) ** 2))) @@ -118,14 +102,12 @@ def root_mean_squared_error(y_true: ndarray, y_pred: ndarray) -> float: def mean_absolute_percentage_error(y_true: ndarray, y_pred: ndarray) -> float: """Mean absolute percentage error regression loss. - Parameters: - ----------- - y_true : array-like of shape (n_samples,) - Ground truth (correct) target values. - y_pred : array-like of shape (n_samples,) - Estimated target values. + Args: + y_true: Ground truth (correct) target values. + y_pred: Estimated target values. """ check_arrays(y_true, y_pred) if np.any(y_true == 0): - raise ValueError("Found zero in true values. MAPE is undefined.") + msg = "Found zero in true values. MAPE is undefined." + raise ValueError(msg) return float(100.0 * np.mean(np.abs((y_true - y_pred) / y_true))) diff --git a/src/enfobench/evaluation/model.py b/src/enfobench/evaluation/model.py new file mode 100644 index 0000000..bbf90af --- /dev/null +++ b/src/enfobench/evaluation/model.py @@ -0,0 +1,58 @@ +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Protocol + +import pandas as pd + + +class ForecasterType(str, Enum): + point = "point" + quantile = "quantile" + density = "density" + ensemble = "ensemble" + + +@dataclass +class AuthorInfo: + """Author information. + + Attributes: + name: Name of the author. + email: Email of the author. + """ + + name: str + email: str | None = None + + +@dataclass +class ModelInfo: + """Model information. + + Attributes: + name: Name of the model. + authors: List of authors. + type: Type of the model. + params: Parameters of the model. + """ + + name: str + authors: list[AuthorInfo] + type: ForecasterType # noqa: A003 + params: dict[str, Any] = field(default_factory=dict) + + +class Model(Protocol): + def info(self) -> ModelInfo: + ... + + def forecast( + self, + horizon: int, + history: pd.DataFrame, + past_covariates: pd.DataFrame | None = None, + future_covariates: pd.DataFrame | None = None, + level: list[int] | None = None, + **kwargs, + ) -> pd.DataFrame: + ... diff --git a/src/enfobench/evaluation/protocols.py b/src/enfobench/evaluation/protocols.py index 3a67d6f..194d0af 100644 --- a/src/enfobench/evaluation/protocols.py +++ b/src/enfobench/evaluation/protocols.py @@ -1,197 +1,16 @@ -from enum import Enum -from typing import Any, Dict, List, Optional, Protocol - -import pandas as pd -from pydantic import BaseModel - - -class ForecasterType(str, Enum): - point = "point" - quantile = "quantile" - density = "density" - ensemble = "ensemble" - - -class ModelInfo(BaseModel): - """Model information. - - Args - ---- - name: - Name of the model. - type: - Type of the model. - params: - Parameters of the model. - """ - - name: str - type: ForecasterType - params: Dict[str, Any] - - -class EnvironmentInfo(BaseModel): - packages: Dict[str, str] - - -class Model(Protocol): - def info(self) -> ModelInfo: - ... - - def forecast( - self, - horizon: int, - history: pd.DataFrame, - past_covariates: Optional[pd.DataFrame] = None, - future_covariates: Optional[pd.DataFrame] = None, - level: Optional[List[int]] = None, - **kwargs, - ) -> pd.DataFrame: - ... 
-
-
-class Dataset:
-    def __init__(
-        self,
-        target: pd.Series,
-        covariates: Optional[pd.DataFrame] = None,
-        external_forecasts: Optional[pd.DataFrame] = None,
-        freq: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-    ):
-        self.freq = freq or target.index.inferred_freq
-        self.target = self._check_target(target.copy())
-        if self.freq is None:
-            raise ValueError("Frequency of the target time series cannot be inferred.")
-
-        self.start_date = self.target["ds"].iloc[0]
-        self.end_date = self.target["ds"].iloc[-1]
-
-        self.covariates = (
-            self._check_covariates(covariates.copy()) if covariates is not None else None
-        )
-        self.external_forecasts = (
-            self._check_external_forecasts(external_forecasts.copy())
-            if external_forecasts is not None
-            else None
-        )
-        self.metadata = metadata or {}
-
-    def _check_target(self, y: pd.Series) -> pd.DataFrame:
-        # TODO: replace manual checks with pandera schema
-        if not isinstance(y.index, pd.DatetimeIndex):
-            raise ValueError("Index of y must be a DatetimeIndex")
-        y.rename_axis("ds", inplace=True)
-        y.sort_index(inplace=True)
-        y = y.to_frame("y").reset_index()
-        return y
-
-    def _check_covariates(self, X: pd.DataFrame) -> pd.DataFrame:
-        # TODO: replace manual checks with pandera schema
-        if not isinstance(X.index, pd.DatetimeIndex):
-            raise ValueError("Index of X must be a DatetimeIndex")
-
-        X.rename_axis("ds", inplace=True)
-        X.sort_index(inplace=True)
-
-        if X.index[0] > self.start_date:
-            raise ValueError(
-                "Covariates must be provided for the full target timeframe, covariates start after target values."
-            )
-
-        if X.index[-1] < self.end_date:
-            raise ValueError(
-                "Covariates must be provided for the full target timeframe, covariates end before target values."
-            )
-
-        X.reset_index(inplace=True)
-        return X
-
-    def _check_external_forecasts(self, X: pd.DataFrame) -> pd.DataFrame:
-        # TODO: replace manual checks with pandera schema
-        first_forecast_date = X.cutoff_date.min()
-        last_forecast_date = X.cutoff_date.max()
-        last_forecast_end_date = X[X.cutoff_date == last_forecast_date].ds.max()
-
-        if first_forecast_date > self.start_date:
-            raise ValueError(
-                "External forecasts must be provided for the full target timeframe, "
-                "forecasts start after target values."
-            )
-
-        if last_forecast_end_date < self.end_date:
-            raise ValueError(
-                "External forecasts must be provided for the full target timeframe, "
-                "forecasts end before target values."
-            )
-
-        return X
-
-    def _check_cutoff_in_rage(self, cutoff_date: pd.Timestamp):
-        if cutoff_date < self.start_date:
-            raise IndexError(
-                f"Cutoff date is before the start date: {cutoff_date} < {self.start_date}."
-            )
-
-        if cutoff_date > self.end_date:
-            raise IndexError(
-                f"Cutoff date is after the end date: {cutoff_date} > {self.end_date}."
-            )
-
-    def get_history(self, cutoff_date: pd.Timestamp) -> pd.DataFrame:
-        """Returns the history of the target variable up to the cutoff date.
-
-        Parameters
-        ----------
-        cutoff_date : pd.Timestamp
-            The cutoff date.
-
-        Returns
-        -------
-        The history of the target variable up to the cutoff date.
-        """
-        self._check_cutoff_in_rage(cutoff_date)
-        return self.target[self.target.ds <= cutoff_date]
-
-    def get_past_covariates(self, cutoff_date: pd.Timestamp) -> Optional[pd.DataFrame]:
-        """Returns the past covariates for the cutoff date.
-
-        Parameters
-        ----------
-        cutoff_date : pd.Timestamp
-            The cutoff date.
-
-        Returns
-        -------
-        The past covariates up until the cutoff date.
-
+    def __init_subclass__(cls, **kwargs):
+        """Raise a DeprecationWarning on subclassing."""
+        msg = f"""
+        {cls.__name__} is deprecated.
+        Import it as 'from enfobench.dataset import Dataset' instead.
         """
-        if self.covariates is None:
-            return None
-
-        self._check_cutoff_in_rage(cutoff_date)
-        return self.covariates[self.covariates.ds <= cutoff_date]
-
-    def get_future_covariates(self, cutoff_date: pd.Timestamp) -> Optional[pd.DataFrame]:
-        """Returns the future covariates for the cutoff date.
+        raise DeprecationWarning(msg)
 
-        Parameters
-        ----------
-        cutoff_date : pd.Timestamp
-            The cutoff date.
-
-        Returns
-        -------
-        The last external forecasts made before the cutoff date.
+    def __init__(self, *args, **kwargs):  # noqa
+        """Raise a DeprecationWarning on initialization."""
+        msg = f"""
+        {self.__class__.__name__} is deprecated.
+        Import it as 'from enfobench.dataset import Dataset' instead.
         """
-        if self.external_forecasts is None:
-            return None
-
-        self._check_cutoff_in_rage(cutoff_date)
-        last_past_cutoff_date = self.external_forecasts.ds[
-            self.external_forecasts.ds <= cutoff_date
-        ].max()
-        return self.external_forecasts[
-            (self.external_forecasts.cutoff_date == last_past_cutoff_date)
-            & (self.external_forecasts.ds > cutoff_date)
-        ]
+        raise DeprecationWarning(msg)
diff --git a/src/enfobench/evaluation/server.py b/src/enfobench/evaluation/server.py
index 4150aeb..a39f334 100644
--- a/src/enfobench/evaluation/server.py
+++ b/src/enfobench/evaluation/server.py
@@ -1,21 +1,62 @@
 import io
-from typing import Annotated, List, Optional
+import sys
+from typing import Annotated, Any
 
 import pandas as pd
 import pkg_resources
 from fastapi import FastAPI, File, Query
 from fastapi.encoders import jsonable_encoder
 from fastapi.responses import JSONResponse
+from pydantic import BaseModel
+from starlette.responses import RedirectResponse
 
-from enfobench.evaluation.protocols import EnvironmentInfo, Model, ModelInfo
+from enfobench.evaluation.model import ForecasterType, Model
+
+
+class AuthorInfo(BaseModel):
+    """Author information.
+
+    Attributes:
+        name: Name of the author.
+        email: Email of the author.
+    """
+
+    name: str
+    email: str | None = None
+
+
+class ModelInfo(BaseModel):
+    """Model information.
+
+    Attributes:
+        name: Name of the model.
+        authors: List of authors.
+        type: Type of the model.
+        params: Parameters of the model.
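+
+    Example:
+        An illustrative payload (all values are made up):
+
+        >>> info = ModelInfo(
+        ...     name="MyModel",
+        ...     authors=[AuthorInfo(name="Jane Doe")],
+        ...     type=ForecasterType.point,
+        ...     params={},
+        ... )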
+ """ + + name: str + authors: list[AuthorInfo] + type: ForecasterType # noqa: A003 + params: dict[str, Any] + + +class EnvironmentInfo(BaseModel): + python: str + packages: dict[str, str] def server_factory(model: Model) -> FastAPI: app = FastAPI() environment = EnvironmentInfo( - packages={package.key: package.version for package in pkg_resources.working_set} + python=sys.version, + packages={package.key: package.version for package in pkg_resources.working_set}, ) + @app.get("/", include_in_schema=False) + async def index(): + return RedirectResponse(url="/docs") + @app.get("/info", response_model=ModelInfo) async def model_info(): """Return model information.""" @@ -30,19 +71,13 @@ async def environment_info(): async def forecast( horizon: int, history: Annotated[bytes, File()], - past_covariates: Annotated[Optional[bytes], File()] = None, - future_covariates: Annotated[Optional[bytes], File()] = None, - level: Optional[List[int]] = Query(None), + past_covariates: Annotated[bytes | None, File()] = None, + future_covariates: Annotated[bytes | None, File()] = None, + level: list[int] | None = Query(None), # noqa: B008 ): history_df = pd.read_parquet(io.BytesIO(history)) - past_covariates_df = ( - pd.read_parquet(io.BytesIO(past_covariates)) if past_covariates is not None else None - ) - future_covariates_df = ( - pd.read_parquet(io.BytesIO(future_covariates)) - if future_covariates is not None - else None - ) + past_covariates_df = pd.read_parquet(io.BytesIO(past_covariates)) if past_covariates is not None else None + future_covariates_df = pd.read_parquet(io.BytesIO(future_covariates)) if future_covariates is not None else None forecast_df = model.forecast( horizon=horizon, @@ -52,6 +87,8 @@ async def forecast( level=level, ) forecast_df.fillna(0, inplace=True) + forecast_df.rename_axis("timestamp", inplace=True) + forecast_df.reset_index(inplace=True) response = { "forecast": jsonable_encoder(forecast_df.to_dict(orient="records")), diff --git a/src/enfobench/evaluation/utils.py b/src/enfobench/evaluation/utils.py index 6a8f9a4..2bad24c 100644 --- a/src/enfobench/evaluation/utils.py +++ b/src/enfobench/evaluation/utils.py @@ -1,45 +1,36 @@ -import warnings - import pandas as pd def steps_in_horizon(horizon: pd.Timedelta, freq: str) -> int: """Return the number of steps in a given horizon. - Parameters - ---------- - horizon: - The horizon to be split into steps. - freq: - The frequency of the horizon. + Args: + horizon: The horizon to be split into steps. + freq: The frequency of the horizon. - Returns - ------- + Returns: The number of steps in the horizon. """ freq = "1" + freq if not freq[0].isdigit() else freq periods = horizon / pd.Timedelta(freq) if not periods.is_integer(): - raise ValueError("Horizon is not a multiple of the frequency") + msg = f"Horizon {horizon} is not a multiple of the frequency {freq}" + raise ValueError(msg) return int(periods) def create_forecast_index(history: pd.DataFrame, horizon: int) -> pd.DatetimeIndex: """Create time index for a forecast horizon. - Parameters - ---------- - history: - The history of the time series. - horizon: - The forecast horizon. + Args: + history: The history of the time series. + horizon: The forecast horizon. - Returns - ------- + Returns: The time index for the forecast horizon. 
""" - last_date = history["ds"].iloc[-1] - inferred_freq = history["ds"].dt.freq + last_date = history.index[-1] + inferred_freq = history.index.inferred_freq freq = "1" + inferred_freq if not inferred_freq[0].isdigit() else inferred_freq return pd.date_range( start=last_date + pd.Timedelta(freq), @@ -48,55 +39,60 @@ def create_forecast_index(history: pd.DataFrame, horizon: int) -> pd.DatetimeInd ) -def create_perfect_forecasts_from_covariates( - covariates: pd.DataFrame, +def generate_cutoff_dates( + start_date: pd.Timestamp, + end_date: pd.Timestamp, horizon: pd.Timedelta, step: pd.Timedelta, - **kwargs, -) -> pd.DataFrame: - """Create forecasts from covariates. - - Sometimes external forecasts are not available for the entire horizon. This function creates - external forecast dataframe from external covariates as a perfect forecast. - - Parameters - ---------- - covariates: - The external covariates. - horizon: - The forecast horizon. - step: - The step size between forecasts. - - Returns - ------- - The external forecast dataframe. +) -> list[pd.Timestamp]: + """Generate cutoff dates for cross-validation between two dates. + + The cutoff dates are separated by a fixed step size and the last cutoff date is a horizon away from the end date. + + Args: + start_date: Start date of the time series. + end_date: End date of the time series. + horizon: Forecast horizon. + step: Step size between cutoff dates. + + Examples + -------- + >>> generate_cutoff_dates( + ... start_date=pd.Timestamp("2020-01-01"), + ... end_date=pd.Timestamp("2020-01-05"), + ... horizon=pd.Timedelta("2 days"), + ... step=pd.Timedelta("1 day"), + ... ) + [ + Timestamp('2020-01-01 00:00:00'), + Timestamp('2020-01-02 00:00:00'), + Timestamp('2020-01-03 00:00:00'), + ] """ - if kwargs.get("start") is not None: - start = kwargs.get("start") - else: - start = covariates.index[0] - - last_date = covariates.index[-1] - - forecasts = [] - while start + horizon <= last_date: - forecast = covariates.loc[ - (covariates.index > start) & (covariates.index <= start + horizon) - ] - forecast.insert(0, "cutoff_date", start) - forecast.rename_axis("ds", inplace=True) - forecast.reset_index(inplace=True) - - if len(forecast) == 0: - warnings.warn( - f"Covariates not found for {start} - {start + horizon}, cannot make forecast at step {start}", - UserWarning, - stacklevel=2, - ) - - forecasts.append(forecast) - start += step - - forecast_df = pd.concat(forecasts, ignore_index=True) - return forecast_df + if horizon <= pd.Timedelta(0): + msg = f"Horizon must be positive, got {horizon}." + raise ValueError(msg) + + if step <= pd.Timedelta(0): + msg = f"Step must be positive, got {step}." + raise ValueError(msg) + + if horizon > end_date - start_date: + msg = f"Horizon is longer than the evaluation period: {horizon} > {end_date - start_date}." + raise ValueError(msg) + + if end_date <= start_date: + msg = f"End date must be after the starting date, got {end_date} <= {start_date}." + raise ValueError(msg) + + cutoff_dates = [] + + cutoff = start_date + while cutoff <= end_date - horizon: + cutoff_dates.append(cutoff) + cutoff += step + + if not cutoff_dates: + msg = f"No cutoff dates between {start_date} and {end_date} with horizon {horizon} and step {step}." 
diff --git a/tests/conftest.py b/tests/conftest.py
index 20b98ef..df07889 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,7 +2,7 @@
 import pandas as pd
 import pytest
 
-from enfobench.evaluation import ForecasterType, ModelInfo
+from enfobench import AuthorInfo, ForecasterType, ModelInfo
 from enfobench.evaluation.utils import create_forecast_index
 
 
@@ -13,6 +13,9 @@ def __init__(self, param1: int):
     def info(self) -> ModelInfo:
         return ModelInfo(
             name="TestModel",
+            authors=[
+                AuthorInfo(name="Author 1", email="author-1@institution.org"),
+            ],
             type=ForecasterType.point,
             params={
                 "param1": 1,
@@ -23,17 +26,18 @@ def forecast(
         self,
         horizon: int,
         history,
-        past_covariates=None,
-        future_covariates=None,
-        level=None,
-        **kwargs,
+        past_covariates=None,  # noqa: ARG002
+        future_covariates=None,  # noqa: ARG002
+        level=None,  # noqa: ARG002
+        **kwargs,  # noqa: ARG002
     ):
         index = create_forecast_index(history, horizon)
+        y_hat = np.full(horizon, fill_value=history["y"].mean()) + self.param1
         return pd.DataFrame(
+            index=index,
             data={
-                "ds": index,
-                "yhat": np.full(horizon, fill_value=history["y"].mean()) + self.param1,
-            }
+                "yhat": y_hat,
+            },
         )
 
 
@@ -45,7 +49,12 @@ def model():
 @pytest.fixture(scope="session")
 def target() -> pd.DataFrame:
     index = pd.date_range("2020-01-01", "2020-02-01", freq="30T")
-    y = pd.Series(np.random.random(len(index)), index=index)
+    y = pd.DataFrame(
+        index=index,
+        data={
+            "y": np.random.random(len(index)),
+        },
+    )
     return y
 
 
@@ -68,12 +77,10 @@ def external_forecasts() -> pd.DataFrame:
     forecasts = []
     for cutoff_date in cutoff_dates:
-        index = pd.date_range(
-            cutoff_date + pd.Timedelta(hours=1), cutoff_date + pd.Timedelta(days=7), freq="H"
-        )
+        index = pd.date_range(cutoff_date + pd.Timedelta(hours=1), cutoff_date + pd.Timedelta(days=7), freq="H")
         forecast = pd.DataFrame(
             data={
-                "ds": index,
+                "timestamp": index,
                 "forecast_1": np.random.random(len(index)),
                 "forecast_2": np.random.random(len(index)),
             },
diff --git a/tests/test_evaluations/test_dataset.py b/tests/test_evaluations/test_dataset.py
index c7dc1c3..8ee1b24 100644
--- a/tests/test_evaluations/test_dataset.py
+++ b/tests/test_evaluations/test_dataset.py
@@ -1,8 +1,9 @@
 from random import randrange
 
 import pandas as pd
+import pytest
 
-from enfobench.evaluation.protocols import Dataset
+from enfobench.dataset import Dataset
 
 
 def random_date(start: pd.Timestamp, end: pd.Timestamp, resolution: int = 1) -> pd.Timestamp:
@@ -12,114 +13,49 @@ def random_date(start: pd.Timestamp, end: pd.Timestamp, resolution: int = 1) ->
     """
     delta = end - start
     int_delta = int(delta.total_seconds())
-    random_second = randrange(0, int_delta, resolution)
+    random_second = randrange(0, int_delta, resolution)  # noqa: S311
    return start + pd.Timedelta(seconds=random_second)
 
 
-def test_univariate_second(target):
+@pytest.mark.parametrize("resolution", [1, 60, 900, 3600])
+def test_get_history(target, resolution):
     ds = Dataset(target=target)
+    cutoff_date = random_date(ds._first_available_target_date, ds._last_available_target_date, resolution=resolution)
 
-    cutoff_date = random_date(ds.start_date, ds.end_date)
-    print(cutoff_date)
     history = ds.get_history(cutoff_date)
-    assert (history.ds <= cutoff_date).all()
+    assert history.index.name == "timestamp"
+    assert isinstance(history.index, pd.DatetimeIndex)
+    assert "y" in history.columns
+    assert len(history.columns) == 1
+    assert (history.index <= cutoff_date).all()
 
 
-def test_univariate_minute(target):
-    ds = Dataset(target=target)
-    cutoff_date = 
random_date(ds.start_date, ds.end_date, resolution=60) - print(cutoff_date) - history = ds.get_history(cutoff_date) - assert (history.ds <= cutoff_date).all() - - -def test_univariate_quarter(target): - ds = Dataset(target=target) - cutoff_date = random_date(ds.start_date, ds.end_date, resolution=900) - print(cutoff_date) - history = ds.get_history(cutoff_date) - assert (history.ds <= cutoff_date).all() - - -def test_univariate_hour(target): - ds = Dataset(target=target) - cutoff_date = random_date(ds.start_date, ds.end_date, resolution=3600) - print(cutoff_date) - history = ds.get_history(cutoff_date) - assert (history.ds <= cutoff_date).all() - - -def test_multivariate_second(target, covariates): - ds = Dataset(target=target, covariates=covariates) - - cutoff_date = random_date(ds.start_date, ds.end_date) - print(cutoff_date) - past_cov = ds.get_past_covariates(cutoff_date) - assert (past_cov.ds <= cutoff_date).all() - - -def test_multivariate_minute(target, covariates): - ds = Dataset(target=target, covariates=covariates) - - cutoff_date = random_date(ds.start_date, ds.end_date, resolution=60) - print(cutoff_date) - past_cov = ds.get_past_covariates(cutoff_date) - assert (past_cov.ds <= cutoff_date).all() +@pytest.mark.parametrize("resolution", [1, 60, 900, 3600]) +def test_get_past_covariates(target, covariates, resolution): + ds = Dataset(target=target, past_covariates=covariates) + cutoff_date = random_date(ds._first_available_target_date, ds._last_available_target_date, resolution=resolution) -def test_multivariate_quarter(target, covariates): - ds = Dataset(target=target, covariates=covariates) - - cutoff_date = random_date(ds.start_date, ds.end_date, resolution=900) - print(cutoff_date) - past_cov = ds.get_past_covariates(cutoff_date) - assert (past_cov.ds <= cutoff_date).all() - - -def test_multivariate_hour(target, covariates): - ds = Dataset(target=target, covariates=covariates) - - cutoff_date = random_date(ds.start_date, ds.end_date, resolution=3600) - print(cutoff_date) past_cov = ds.get_past_covariates(cutoff_date) - assert (past_cov.ds <= cutoff_date).all() - - -def test_external_forecasts_second(target, covariates, external_forecasts): - ds = Dataset(target=target, covariates=covariates, external_forecasts=external_forecasts) - - cutoff_date = random_date(ds.start_date, ds.end_date) - print(cutoff_date) - future_cov = ds.get_future_covariates(cutoff_date) - assert (future_cov.cutoff_date <= cutoff_date).all() - assert (future_cov.ds > cutoff_date).all() - - -def test_external_forecasts_minute(target, covariates, external_forecasts): - ds = Dataset(target=target, covariates=covariates, external_forecasts=external_forecasts) - cutoff_date = random_date(ds.start_date, ds.end_date, resolution=60) - print(cutoff_date) - future_cov = ds.get_future_covariates(cutoff_date) - assert (future_cov.cutoff_date <= cutoff_date).all() - assert (future_cov.ds > cutoff_date).all() + assert past_cov.index.name == "timestamp" + assert isinstance(past_cov.index, pd.DatetimeIndex) + for col in covariates.columns: + assert col in past_cov.columns + assert (past_cov.index <= cutoff_date).all() -def test_external_forecasts_quarter(target, covariates, external_forecasts): - ds = Dataset(target=target, covariates=covariates, external_forecasts=external_forecasts) +@pytest.mark.parametrize("resolution", [1, 60, 900, 3600]) +def test_get_future_covariates(target, covariates, external_forecasts, resolution): + ds = Dataset(target=target, past_covariates=covariates, 
future_covariates=external_forecasts) + cutoff_date = random_date(ds._first_available_target_date, ds._last_available_target_date, resolution=resolution) - cutoff_date = random_date(ds.start_date, ds.end_date, resolution=900) - print(cutoff_date) future_cov = ds.get_future_covariates(cutoff_date) - assert (future_cov.cutoff_date <= cutoff_date).all() - assert (future_cov.ds > cutoff_date).all() - -def test_external_forecasts_hour(target, covariates, external_forecasts): - ds = Dataset(target=target, covariates=covariates, external_forecasts=external_forecasts) - - cutoff_date = random_date(ds.start_date, ds.end_date, resolution=3600) - print(cutoff_date) - future_cov = ds.get_future_covariates(cutoff_date) + assert future_cov.index.name == "timestamp" + assert isinstance(future_cov.index, pd.DatetimeIndex) + for col in external_forecasts.columns: + if col not in ["timestamp", "cutoff_date"]: + assert col in future_cov.columns assert (future_cov.cutoff_date <= cutoff_date).all() - assert (future_cov.ds > cutoff_date).all() + assert (future_cov.index > cutoff_date).all() diff --git a/tests/test_evaluations/test_evaluate.py b/tests/test_evaluations/test_evaluate.py index 558e73e..bdc14bc 100644 --- a/tests/test_evaluations/test_evaluate.py +++ b/tests/test_evaluations/test_evaluate.py @@ -2,13 +2,13 @@ import pytest from starlette.testclient import TestClient +from enfobench.dataset import Dataset from enfobench.evaluation import ( - Dataset, ForecastClient, cross_validate, - generate_cutoff_dates, ) from enfobench.evaluation.server import server_factory +from enfobench.evaluation.utils import generate_cutoff_dates @pytest.mark.parametrize( @@ -89,7 +89,7 @@ def test_cross_validate_univariate_locally(model, target): ) assert isinstance(forecasts, pd.DataFrame) - assert "ds" in forecasts.columns + assert "timestamp" in forecasts.columns assert "yhat" in forecasts.columns assert "cutoff_date" in forecasts.columns @@ -109,7 +109,7 @@ def test_cross_validate_univariate_via_server(model, target): ) assert isinstance(forecasts, pd.DataFrame) - assert "ds" in forecasts.columns + assert "timestamp" in forecasts.columns assert "yhat" in forecasts.columns assert "cutoff_date" in forecasts.columns @@ -125,7 +125,7 @@ def test_cross_validate_multivariate_locally(model, target, covariates, external ) assert isinstance(forecasts, pd.DataFrame) - assert "ds" in forecasts.columns + assert "timestamp" in forecasts.columns assert "yhat" in forecasts.columns assert "cutoff_date" in forecasts.columns @@ -145,6 +145,6 @@ def test_cross_validate_multivariate_via_server(model, target, covariates, exter ) assert isinstance(forecasts, pd.DataFrame) - assert "ds" in forecasts.columns + assert "timestamp" in forecasts.columns assert "yhat" in forecasts.columns assert "cutoff_date" in forecasts.columns diff --git a/tests/test_evaluations/test_metrics.py b/tests/test_evaluations/test_metrics.py index 42d0074..d6d51ba 100644 --- a/tests/test_evaluations/test_metrics.py +++ b/tests/test_evaluations/test_metrics.py @@ -42,13 +42,25 @@ def test_metric_raises_with_empty_array(metric): assert metric(np.array([]), np.array([1, 2, 3])) == 0 -def test_mean_absolute_error(): - assert mean_absolute_error(np.array([1, 2, 3]), np.array([1, 2, 3])) == 0.0 - assert mean_absolute_error(np.array([1, 2, 3]), np.array([2, 3, 4])) == 1.0 - assert mean_absolute_error(np.array([1, 2, 3]), np.array([0, 1, 2])) == 1.0 +@pytest.mark.parametrize( + "a,b,value", + [ + (np.array([1, 2, 3]), np.array([1, 2, 3]), 0.0), + (np.array([1, 2, 3]), 
np.array([2, 3, 4]), 1.0), + (np.array([1, 2, 3]), np.array([0, 1, 2]), 1.0), + ], +) +def test_mean_absolute_error(a, b, value): + assert mean_absolute_error(a, b) == value -def test_mean_bias_error(): - assert mean_bias_error(np.array([1, 2, 3]), np.array([1, 2, 3])) == 0.0 - assert mean_bias_error(np.array([1, 2, 3]), np.array([2, 3, 4])) == 1.0 - assert mean_bias_error(np.array([1, 2, 3]), np.array([0, 1, 2])) == -1.0 +@pytest.mark.parametrize( + "a,b,value", + [ + (np.array([1, 2, 3]), np.array([1, 2, 3]), 0.0), + (np.array([1, 2, 3]), np.array([2, 3, 4]), 1.0), + (np.array([1, 2, 3]), np.array([0, 1, 2]), -1.0), + ], +) +def test_mean_bias_error(a, b, value): + assert mean_bias_error(a, b) == value diff --git a/tests/test_evaluations/test_server.py b/tests/test_evaluations/test_server.py index ea58e28..f5994fc 100644 --- a/tests/test_evaluations/test_server.py +++ b/tests/test_evaluations/test_server.py @@ -1,3 +1,5 @@ +from http import HTTPStatus + import numpy as np import pandas as pd import pytest @@ -17,9 +19,15 @@ def forecast_client(model): def test_info_endpoint(forecast_client): response = forecast_client.get("/info") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK assert response.json() == { "name": "TestModel", + "authors": [ + { + "name": "Author 1", + "email": "author-1@institution.org", + } + ], "type": "point", "params": { "param1": 1, @@ -30,7 +38,7 @@ def test_info_endpoint(forecast_client): def test_environment_endpoint(forecast_client): response = forecast_client.get("/environment") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK assert "packages" in response.json() for package_name, package_version in response.json()["packages"].items(): assert isinstance(package_name, str) @@ -41,10 +49,10 @@ def test_forecast_endpoint(forecast_client): horizon = 24 target_index = pd.date_range(start="2020-01-01", end="2021-01-01", freq="1H") history_df = pd.DataFrame( + index=target_index, data={ "y": np.random.normal(size=len(target_index)), - "ds": target_index, - } + }, ) response = forecast_client.post( @@ -56,16 +64,13 @@ def test_forecast_endpoint(forecast_client): "history": to_buffer(history_df), }, ) - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK assert "forecast" in response.json() forecast = response.json()["forecast"] assert isinstance(forecast, list) and len(forecast) == horizon for forecast_item in forecast: - assert "ds" in forecast_item and pd.Timestamp(forecast_item["ds"]) + assert "timestamp" in forecast_item and pd.Timestamp(forecast_item["timestamp"]) assert "yhat" in forecast_item and isinstance( forecast_item["yhat"], - ( - float, - int, - ), + float | int, ) diff --git a/tests/test_evaluations/test_utils.py b/tests/test_evaluations/test_utils.py index c3a790b..aa3a278 100644 --- a/tests/test_evaluations/test_utils.py +++ b/tests/test_evaluations/test_utils.py @@ -1,6 +1,7 @@ import pandas as pd import pytest +import enfobench.dataset.utils from enfobench.evaluation import utils @@ -24,9 +25,9 @@ def test_steps_in_horizon_raises_with_non_multiple_horizon(): def test_create_forecast_index(target): - history = target.to_frame("y").rename_axis("ds").reset_index() + history = target horizon = 96 - last_date = history["ds"].iloc[-1] + last_date = history.index[-1] index = utils.create_forecast_index(history, horizon) @@ -37,13 +38,13 @@ def test_create_forecast_index(target): def test_create_perfect_forecasts_from_covariates(covariates): - forecasts = 
utils.create_perfect_forecasts_from_covariates( + forecasts = enfobench.dataset.utils.create_perfect_forecasts_from_covariates( covariates, horizon=pd.Timedelta("7 days"), step=pd.Timedelta("1D"), ) assert isinstance(forecasts, pd.DataFrame) - assert "ds" in forecasts.columns + assert "timestamp" in forecasts.columns assert "cutoff_date" in forecasts.columns assert all(col in forecasts.columns for col in covariates.columns)
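Taken together, these changes replace the `ds` column convention with a `pd.DatetimeIndex` named `timestamp` across the server, utilities, and tests. Downstream code that still holds old-style frames can migrate with a one-liner; `migrate_frame` below is a hypothetical helper for illustration, not part of the package:

```python
import pandas as pd


def migrate_frame(df: pd.DataFrame) -> pd.DataFrame:
    # Move the old "ds" column into a DatetimeIndex named "timestamp".
    return df.set_index("ds").rename_axis("timestamp")


old_style = pd.DataFrame(
    data={
        "ds": pd.date_range("2020-01-01", periods=3, freq="1H"),
        "y": [1.0, 2.0, 3.0],
    }
)

new_style = migrate_frame(old_style)
assert new_style.index.name == "timestamp"
assert list(new_style.columns) == ["y"]
```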