From 509ae946b1e118105445d5ea6607468663602411 Mon Sep 17 00:00:00 2001
From: attilabalint
Date: Wed, 12 Jun 2024 13:04:39 +0200
Subject: [PATCH] added GasDemandDataset

---
 README.md                            |   2 +
 models/amazon-chronos/src/main.py    |  13 ++-
 models/nixtla-timegpt/src/main.py    |   8 +-
 models/salesforce-moirai/src/main.py |  21 ++--
 pyproject.toml                       |  15 ++-
 src/enfobench/__version__.py         |   2 +-
 src/enfobench/core/model.py          |   8 +-
 src/enfobench/datasets/__init__.py   |   2 +
 src/enfobench/datasets/gas_demand.py | 166 +++++++++++++++++++++++++++
 src/enfobench/evaluation/server.py   |   2 +-
 10 files changed, 209 insertions(+), 30 deletions(-)
 create mode 100644 src/enfobench/datasets/gas_demand.py

diff --git a/README.md b/README.md
index 08653ee..f83e47a 100644
--- a/README.md
+++ b/README.md
@@ -17,12 +17,14 @@ benchmark forecast models.
 ## Datasets
 
 - **[Electricity demand](https://huggingface.co/datasets/EDS-lab/electricity-demand)**
+- **[Gas demand](https://huggingface.co/datasets/EDS-lab/gas-demand)**
 - **[PV generation](https://huggingface.co/datasets/EDS-lab/pv-generation)**
 
 ## Dashboards
 
 - **[Electricity demand](https://huggingface.co/spaces/EDS-lab/EnFoBench-ElectricityDemand)**
+- **[Gas demand](https://huggingface.co/spaces/EDS-lab/EnFoBench-GasDemand)**
 - **[PV generation](https://huggingface.co/spaces/EDS-lab/EnFoBench-PVGeneration)**
 
diff --git a/models/amazon-chronos/src/main.py b/models/amazon-chronos/src/main.py
index 4215334..298c911 100644
--- a/models/amazon-chronos/src/main.py
+++ b/models/amazon-chronos/src/main.py
@@ -21,8 +21,13 @@ def __init__(self, model_name: str, num_samples: int, ctx_length: str | None = N
         self.ctx_length = ctx_length
 
     def info(self) -> ModelInfo:
+        name = (
+            "Amazon."
+            f'{".".join(map(str.capitalize, self.model_name.split("-")))}'
+            f'{".CTX" + self.ctx_length if self.ctx_length else ""}'
+        )
         return ModelInfo(
-            name=f'Amazon.{".".join(map(str.capitalize, self.model_name.split("-")))}{".CTX" + self.ctx_length if self.ctx_length else ""}',
+            name=name,
             authors=[
                 AuthorInfo(name="Attila Balint", email="attila.balint@kuleuven.be"),
             ],
@@ -48,9 +53,9 @@ def forecast(
         model_dir = root_dir / "models" / self.model_name
 
         if not model_dir.exists():
-            raise FileNotFoundError(
-                f"Model directory for {self.model_name} was not found at {model_dir}, make sure it is downloaded."
-            )
+            msg = f"Model directory for {self.model_name} was not found at {model_dir}, make sure it is downloaded."
+            raise FileNotFoundError(msg)
+
         pipeline = ChronosPipeline.from_pretrained(
             model_dir,
             device_map=device,
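
The info() refactor above only splits an overlong f-string; the resulting model name is unchanged. For illustration, with hypothetical values model_name="chronos-t5-small" and ctx_length="512" (example values, not part of the patch), the same expression yields:

    model_name = "chronos-t5-small"  # hypothetical example value
    ctx_length = "512"               # hypothetical example value
    name = (
        "Amazon."
        f'{".".join(map(str.capitalize, model_name.split("-")))}'
        f'{".CTX" + ctx_length if ctx_length else ""}'
    )
    print(name)  # -> Amazon.Chronos.T5.Small.CTX512
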
diff --git a/models/nixtla-timegpt/src/main.py b/models/nixtla-timegpt/src/main.py
index f352414..5cdf501 100644
--- a/models/nixtla-timegpt/src/main.py
+++ b/models/nixtla-timegpt/src/main.py
@@ -3,7 +3,7 @@
 
 import pandas as pd
 from nixtla import NixtlaClient
-from pyrate_limiter import Duration, Rate, Limiter, BucketFullException
+from pyrate_limiter import Duration, Limiter, Rate
 
 from enfobench import AuthorInfo, ForecasterType, ModelInfo
 from enfobench.evaluation.server import server_factory
@@ -68,14 +68,14 @@ def forecast(
             raise ValueError(msg)
 
         # post-process forecast
-        forecast = timegpt_fcst_df.rename(columns={"TimeGPT": 'yhat'})
-        forecast['timestamp'] = pd.to_datetime(forecast.timestamp)
+        forecast = timegpt_fcst_df.rename(columns={"TimeGPT": "yhat"})
+        forecast["timestamp"] = pd.to_datetime(forecast.timestamp)
         forecast = forecast.set_index("timestamp")
         return forecast
 
 
 api_key = os.getenv("NIXTLA_API_KEY")
-long_horizon = bool(int(os.getenv("ENFOBENCH_MODEL_LONG_HORIZON", 0)))
+long_horizon = bool(int(os.getenv("ENFOBENCH_MODEL_LONG_HORIZON", "0")))
 
 # Instantiate your model
 model = NixtlaTimeGPTModel(api_key=api_key, long_horizon=long_horizon)
diff --git a/models/salesforce-moirai/src/main.py b/models/salesforce-moirai/src/main.py
index b9c5cc4..f4b0083 100644
--- a/models/salesforce-moirai/src/main.py
+++ b/models/salesforce-moirai/src/main.py
@@ -4,12 +4,11 @@
 import pandas as pd
 import torch
 from gluonts.dataset.pandas import PandasDataset
+from uni2ts.model.moirai import MoiraiForecast
 
 from enfobench import AuthorInfo, ForecasterType, ModelInfo
 from enfobench.evaluation.server import server_factory
 from enfobench.evaluation.utils import create_forecast_index, periods_in_duration
-from uni2ts.model.moirai import MoiraiForecast
-
 
 # Check for GPU availability
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -24,8 +23,13 @@ def __init__(self, model_name: str, num_samples: int, ctx_length: str | None = N
         self.size = model_name.split("-")[-1]
 
     def info(self) -> ModelInfo:
+        name = (
+            "Salesforce.Moirai-1.0-R."
+            f'{self.size.capitalize()}'
+            f'{f".CTX{self.ctx_length}" if self.ctx_length else ""}'
+        )
         return ModelInfo(
-            name=f'Salesforce.Moirai-1.0-R.{self.size.capitalize()}{f".CTX{self.ctx_length}" if self.ctx_length else ""}',
+            name=name,
             authors=[
                 AuthorInfo(name="Attila Balint", email="attila.balint@kuleuven.be"),
             ],
@@ -54,9 +58,8 @@ def forecast(
         model_dir = root_dir / "models" / self.model_name
 
         if not model_dir.exists():
-            raise FileNotFoundError(
-                f"Model directory for {self.model_name} was not found at {model_dir}, make sure it is downloaded."
-            )
+            msg = f"Model directory for {self.model_name} was not found at {model_dir}, make sure it is downloaded."
+            raise FileNotFoundError(msg)
 
         if self.ctx_length is None:
             ctx_length = len(history)
@@ -65,10 +68,10 @@ def forecast(
 
         # Prepare pre-trained model
         model = MoiraiForecast.load_from_checkpoint(
-            checkpoint_path=str(model_dir / 'model.ckpt'),
+            checkpoint_path=str(model_dir / "model.ckpt"),
             prediction_length=horizon,
             context_length=ctx_length,
-            patch_size='auto',
+            patch_size="auto",
             num_samples=self.num_samples,
             target_dim=1,
             feat_dynamic_real_dim=0,
@@ -88,7 +91,7 @@ def forecast(
 
 
 model_name = os.getenv("ENFOBENCH_MODEL_NAME", "small")
-num_samples = int(os.getenv("ENFOBENCH_NUM_SAMPLES", 1))
+num_samples = int(os.getenv("ENFOBENCH_NUM_SAMPLES", "1"))
 ctx_length = os.getenv("ENFOBENCH_CTX_LENGTH")
 
 # Instantiate your model
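
The quoted environment defaults above ("1", "0") are not cosmetic: os.getenv returns its default verbatim, so a non-string default means the call can return either int or str depending on whether the variable is set; linters flag this (e.g. pylint's invalid-envvar-default, PLW1508 in ruff). The now-consistent pattern, shown standalone:

    import os

    # Give os.getenv a str default and convert explicitly afterwards.
    num_samples = int(os.getenv("ENFOBENCH_NUM_SAMPLES", "1"))
    long_horizon = bool(int(os.getenv("ENFOBENCH_MODEL_LONG_HORIZON", "0")))  # "0" -> False, "1" -> True
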
diff --git a/pyproject.toml b/pyproject.toml
index 0704c12..61c15b2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ dynamic = ["version"]
 description = "Energy forecast benchmarking toolkit."
 readme = "README.md"
 requires-python = ">=3.10"
-license = "BSD-2-clause"
+license = "BSD-3-clause"
 keywords = [
     "energy",
     "forecasting",
@@ -82,6 +82,7 @@ dependencies = [
     "pytest",
     "pytest-cov",
     "httpx",
+    "setuptools",  # needed to import pkg_resources
 ]
 [tool.hatch.envs.default.scripts]
 test = "pytest {args:tests}"
@@ -100,12 +101,12 @@ dependencies = [
 [tool.hatch.envs.lint.scripts]
 typing = "mypy --install-types --non-interactive {args:src/enfobench tests}"
 style = [
-    "ruff {args:.}",
+    "ruff check {args:.}",
     "black --check --diff {args:.}",
 ]
 fmt = [
     "black {args:.}",
-    "ruff --fix {args:.}",
+    "ruff check --fix {args:.}",
     "style",
 ]
 all = [
@@ -130,6 +131,8 @@ skip-string-normalization = true
 [tool.ruff]
 target-version = "py310"
 line-length = 120
+
+[tool.ruff.lint]
 select = [
     "A",
     "ARG",
@@ -172,13 +175,13 @@ unfixable = [
     "F401",
 ]
 
-[tool.ruff.isort]
+[tool.ruff.lint.isort]
 known-first-party = ["enfobench"]
 
-[tool.ruff.flake8-tidy-imports]
+[tool.ruff.lint.flake8-tidy-imports]
 ban-relative-imports = "all"
 
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
 # Tests can use magic values, assertions, and relative imports
 "tests/**/*" = ["PLR2004", "S101", "TID252"]
 "models/**/*" = ["ARG001", "ARG002", "FBT001"]
diff --git a/src/enfobench/__version__.py b/src/enfobench/__version__.py
index 906d362..43c4ab0 100644
--- a/src/enfobench/__version__.py
+++ b/src/enfobench/__version__.py
@@ -1 +1 @@
-__version__ = "0.6.0"
+__version__ = "0.6.1"
diff --git a/src/enfobench/core/model.py b/src/enfobench/core/model.py
index 72dfa73..94e385b 100644
--- a/src/enfobench/core/model.py
+++ b/src/enfobench/core/model.py
@@ -38,13 +38,12 @@ class ModelInfo:
 
     name: str
     authors: list[AuthorInfo]
-    type: ForecasterType  # noqa: A003
+    type: ForecasterType
    params: dict[str, Any] = field(default_factory=dict)
 
 
 class Model(Protocol):
-    def info(self) -> ModelInfo:
-        ...
+    def info(self) -> ModelInfo: ...
 
     def forecast(
         self,
@@ -55,5 +54,4 @@ def forecast(
         metadata: dict | None = None,
         level: list[int] | None = None,
         **kwargs,
-    ) -> pd.DataFrame:
-        ...
+    ) -> pd.DataFrame: ...
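
Model in core/model.py is a typing.Protocol, so the stub-body reformatting above is purely cosmetic: forecasters satisfy the interface structurally, without inheriting from it. A minimal sketch of a conforming class follows; the class itself, the "y" history column, the regular DatetimeIndex, and ForecasterType.point are illustrative assumptions, not part of this patch:

    import pandas as pd

    from enfobench import AuthorInfo, ForecasterType, ModelInfo

    class LastValueModel:  # hypothetical example; satisfies Model structurally
        def info(self) -> ModelInfo:
            return ModelInfo(
                name="Example.LastValue",
                authors=[AuthorInfo(name="Jane Doe", email="jane@example.com")],
                type=ForecasterType.point,
                params={},
            )

        def forecast(self, horizon, history, past_covariates=None, future_covariates=None,
                     metadata=None, level=None, **kwargs) -> pd.DataFrame:
            # Naive forecast: repeat the last observed value over the horizon.
            # Assumes history has a regular DatetimeIndex and a "y" column.
            index = pd.date_range(history.index[-1], periods=horizon + 1, freq=history.index.freq)[1:]
            return pd.DataFrame({"yhat": history.y.iloc[-1]}, index=index)
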
diff --git a/src/enfobench/datasets/__init__.py b/src/enfobench/datasets/__init__.py
index a168ee3..99a4439 100644
--- a/src/enfobench/datasets/__init__.py
+++ b/src/enfobench/datasets/__init__.py
@@ -1,7 +1,9 @@
 from enfobench.datasets.electricity_demand import ElectricityDemandDataset
+from enfobench.datasets.gas_demand import GasDemandDataset
 from enfobench.datasets.pv_generation import PVGenerationDataset
 
 __all__ = [
     "ElectricityDemandDataset",
     "PVGenerationDataset",
+    "GasDemandDataset",
 ]
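
The gas_demand.py module added below follows the same subset layout as the existing dataset modules: each subset is a parquet file queried in place through DuckDB with positional parameters, so no data is loaded until a specific meter or location is requested. The core access pattern as a standalone sketch (the file path and id are placeholder values):

    import duckdb

    query = """
        SELECT *
        FROM read_parquet(?)
        WHERE unique_id = ?
    """
    conn = duckdb.connect(":memory:")
    # read_parquet scans the file; only rows matching the id are materialized
    df = conn.execute(query, parameters=["demand.parquet", "meter-001"]).fetch_df()
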
diff --git a/src/enfobench/datasets/gas_demand.py b/src/enfobench/datasets/gas_demand.py
new file mode 100644
index 0000000..9016aca
--- /dev/null
+++ b/src/enfobench/datasets/gas_demand.py
@@ -0,0 +1,166 @@
+from typing import Any
+
+import duckdb
+import pandas as pd
+
+from enfobench.core import Subset
+from enfobench.datasets.base import DatasetBase
+
+Metadata = dict[str, Any]
+
+
+class MetadataSubset(Subset):
+    """Metadata subset of the HuggingFace dataset containing all metadata about the meters.
+
+    Args:
+        file_path: The path to the subset file.
+    """
+
+    def list_unique_ids(self) -> list[str]:
+        """Lists all unique ids."""
+        query = """
+            SELECT DISTINCT unique_id
+            FROM read_parquet(?)
+        """
+        conn = duckdb.connect(":memory:")
+        return conn.execute(query, parameters=[str(self.file_path)]).fetch_df().unique_id.tolist()
+
+    def get_by_unique_id(self, unique_id: str) -> Metadata:
+        """Returns the metadata for the given unique id.
+
+        Args:
+            unique_id: The unique id of the meter.
+        """
+        query = """
+            SELECT *
+            FROM read_parquet(?)
+            WHERE unique_id = ?
+        """
+        conn = duckdb.connect(":memory:")
+        df = conn.execute(query, parameters=[str(self.file_path), unique_id]).fetch_df()
+        if df.empty:
+            msg = f"Unique id '{unique_id}' was not found."
+            raise KeyError(msg)
+        return df.iloc[0].to_dict()
+
+
+class WeatherSubset(Subset):
+    """Weather subset of the HuggingFace dataset containing all weather data.
+
+    Args:
+        file_path: The path to the subset file.
+    """
+
+    def list_location_ids(self) -> list[str]:
+        """Lists all location ids."""
+        query = """
+            SELECT DISTINCT location_id
+            FROM read_parquet(?)
+        """
+        conn = duckdb.connect(":memory:")
+        return conn.execute(query, parameters=[str(self.file_path)]).fetch_df().location_id.tolist()
+
+    def get_by_location_id(self, location_id: str, columns: list[str] | None = None) -> pd.DataFrame:
+        """Returns the weather data for the given location id.
+
+        Args:
+            location_id: The location id of the weather station.
+            columns: The columns to return. If None, all columns are returned.
+        """
+        conn = duckdb.connect(":memory:")
+
+        if columns:
+            query = f"""
+                SELECT timestamp, {", ".join(columns)}
+                FROM read_parquet(?)
+                WHERE location_id = ?
+            """  # noqa: S608
+        else:
+            query = """
+                SELECT *
+                FROM read_parquet(?)
+                WHERE location_id = ?
+            """
+        df = conn.execute(query, parameters=[str(self.file_path), location_id]).fetch_df()
+        if df.empty:
+            msg = f"Location id '{location_id}' was not found."
+            raise KeyError(msg)
+
+        # Remove location_id and set timestamp as index
+        df.drop(columns=["location_id"], inplace=True, errors="ignore")
+        df.set_index("timestamp", inplace=True)
+        return df
+
+
+class DemandSubset(Subset):
+    """Data subset of the HuggingFace dataset containing all gas demand data.
+
+    Args:
+        file_path: The path to the subset file.
+    """
+
+    def get_by_unique_id(self, unique_id: str) -> pd.DataFrame:
+        """Returns the demand data for the given unique id.
+
+        Args:
+            unique_id: The unique id of the meter.
+        """
+        query = """
+            SELECT *
+            FROM read_parquet(?)
+            WHERE unique_id = ?
+        """
+        conn = duckdb.connect(":memory:")
+        df = conn.execute(query, parameters=[str(self.file_path), unique_id]).fetch_df()
+        if df.empty:
+            msg = f"Unique id '{unique_id}' was not found."
+            raise KeyError(msg)
+
+        # Remove unique_id and set timestamp as index
+        df.drop(columns=["unique_id"], inplace=True, errors="ignore")
+        df.set_index("timestamp", inplace=True)
+        return df
+
+
+class GasDemandDataset(DatasetBase):
+    """GasDemandDataset class representing the HuggingFace dataset.
+
+    This class is a collection of all subsets inside the HuggingFace dataset.
+    It provides an easy way to access the different subsets.
+
+    Args:
+        directory: The directory where the HuggingFace dataset is located.
+            This directory should contain all the subset files.
+    """
+
+    HUGGINGFACE_DATASET = "EDS-lab/gas-demand"
+    SUBSETS = ("demand", "metadata", "weather")
+
+    @property
+    def metadata_subset(self) -> MetadataSubset:
+        """Returns the metadata subset."""
+        return MetadataSubset(self._get_subset_path("metadata"))
+
+    @property
+    def weather_subset(self) -> WeatherSubset:
+        """Returns the weather subset."""
+        return WeatherSubset(self._get_subset_path("weather"))
+
+    @property
+    def demand_subset(self) -> DemandSubset:
+        """Returns the demand subset."""
+        return DemandSubset(self._get_subset_path("demand"))
+
+    def list_unique_ids(self) -> list[str]:
+        return self.metadata_subset.list_unique_ids()
+
+    def list_location_ids(self) -> list[str]:
+        return self.weather_subset.list_location_ids()
+
+    def get_data_by_unique_id(self, unique_id: str) -> tuple[pd.DataFrame, pd.DataFrame, Metadata]:
+        metadata = self.metadata_subset.get_by_unique_id(unique_id)
+        location_id = metadata["location_id"]
+
+        demand = self.demand_subset.get_by_unique_id(unique_id)
+        weather = self.weather_subset.get_by_location_id(location_id)
+        return demand, weather, metadata
diff --git a/src/enfobench/evaluation/server.py b/src/enfobench/evaluation/server.py
index fcc5d73..e3fa74c 100644
--- a/src/enfobench/evaluation/server.py
+++ b/src/enfobench/evaluation/server.py
@@ -38,7 +38,7 @@ class ModelInfo(BaseModel):
 
     name: str
     authors: list[AuthorInfo]
-    type: ForecasterType  # noqa: A003
+    type: ForecasterType
     params: dict[str, Any]
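
A minimal usage sketch of the new GasDemandDataset (assuming the subset parquet files of EDS-lab/gas-demand have already been downloaded into a local directory; the path and the choice of the first meter id are placeholders):

    from enfobench.datasets import GasDemandDataset

    ds = GasDemandDataset(directory="path/to/gas-demand")  # directory with the downloaded subset files
    unique_ids = ds.list_unique_ids()

    # demand and weather are returned timestamp-indexed; metadata is a plain dict
    demand, weather, metadata = ds.get_data_by_unique_id(unique_ids[0])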