added GasDemandDataset
attilabalint committed Jun 12, 2024
1 parent 1854524 commit 509ae94
Showing 10 changed files with 209 additions and 30 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -17,12 +17,14 @@ benchmark forecast models.
## Datasets

- **[Electricity demand](https://huggingface.co/datasets/EDS-lab/electricity-demand)**
- **[Gas demand](https://huggingface.co/datasets/EDS-lab/gas-demand)**
- **[PV generation](https://huggingface.co/datasets/EDS-lab/pv-generation)**


## Dashboards

- **[Electricity demand](https://huggingface.co/spaces/EDS-lab/EnFoBench-ElectricityDemand)**
- **[Gas demand](https://huggingface.co/spaces/EDS-lab/EnFoBench-GasDemand)**
- **[PV generation](https://huggingface.co/spaces/EDS-lab/EnFoBench-PVGeneration)**


13 changes: 9 additions & 4 deletions models/amazon-chronos/src/main.py
@@ -21,8 +21,13 @@ def __init__(self, model_name: str, num_samples: int, ctx_length: str | None = N
self.ctx_length = ctx_length

def info(self) -> ModelInfo:
name = (
"Amazon."
f'{".".join(map(str.capitalize, self.model_name.split("-")))}'
f'{".CTX" + self.ctx_length if self.ctx_length else ""}'
)
return ModelInfo(
name=f'Amazon.{".".join(map(str.capitalize, self.model_name.split("-")))}{".CTX" + self.ctx_length if self.ctx_length else ""}',
name=name,
authors=[
AuthorInfo(name="Attila Balint", email="[email protected]"),
],
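For reference, here is what the refactored name expression evaluates to; a quick illustrative check with hypothetical `model_name` and `ctx_length` values:

```python
model_name = "chronos-t5-small"  # hypothetical value of self.model_name
ctx_length = "512"               # hypothetical value of self.ctx_length

name = (
    "Amazon."
    f'{".".join(map(str.capitalize, model_name.split("-")))}'
    f'{".CTX" + ctx_length if ctx_length else ""}'
)
assert name == "Amazon.Chronos.T5.Small.CTX512"
```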
@@ -48,9 +53,9 @@ def forecast(

model_dir = root_dir / "models" / self.model_name
if not model_dir.exists():
raise FileNotFoundError(
f"Model directory for {self.model_name} was not found at {model_dir}, make sure it is downloaded."
)
msg = f"Model directory for {self.model_name} was not found at {model_dir}, make sure it is downloaded."
raise FileNotFoundError(msg)

pipeline = ChronosPipeline.from_pretrained(
model_dir,
device_map=device,
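Hoisting the error message into a variable before raising (here and in the Moirai wrapper below) presumably satisfies Ruff's flake8-errmsg rules (EM101/EM102), which flag string literals and f-strings passed directly to exception constructors.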
8 changes: 4 additions & 4 deletions models/nixtla-timegpt/src/main.py
@@ -3,7 +3,7 @@

import pandas as pd
from nixtla import NixtlaClient
from pyrate_limiter import Duration, Rate, Limiter, BucketFullException
from pyrate_limiter import Duration, Limiter, Rate

from enfobench import AuthorInfo, ForecasterType, ModelInfo
from enfobench.evaluation.server import server_factory
@@ -68,14 +68,14 @@ def forecast(
raise ValueError(msg)

# post-process forecast
forecast = timegpt_fcst_df.rename(columns={"TimeGPT": 'yhat'})
forecast['timestamp'] = pd.to_datetime(forecast.timestamp)
forecast = timegpt_fcst_df.rename(columns={"TimeGPT": "yhat"})
forecast["timestamp"] = pd.to_datetime(forecast.timestamp)
forecast = forecast.set_index("timestamp")
return forecast


api_key = os.getenv("NIXTLA_API_KEY")
long_horizon = bool(int(os.getenv("ENFOBENCH_MODEL_LONG_HORIZON", 0)))
long_horizon = bool(int(os.getenv("ENFOBENCH_MODEL_LONG_HORIZON", "0")))

# Instantiate your model
model = NixtlaTimeGPTModel(api_key=api_key, long_horizon=long_horizon)
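Quoting the env-var default matters for linting: Ruff's PLW1508 (invalid-envvar-default) flags non-string defaults passed to `os.getenv`, even though a bare `0` happens to work at runtime. A minimal sketch of the flag-parsing pattern, with a hypothetical setting:

```python
import os

# "0"/"1" are the string representations stored in the environment;
# bool(int(...)) converts them to a proper boolean flag.
os.environ["ENFOBENCH_MODEL_LONG_HORIZON"] = "1"  # hypothetical value
long_horizon = bool(int(os.getenv("ENFOBENCH_MODEL_LONG_HORIZON", "0")))
assert long_horizon is True
```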
21 changes: 12 additions & 9 deletions models/salesforce-moirai/src/main.py
@@ -4,12 +4,11 @@
import pandas as pd
import torch
from gluonts.dataset.pandas import PandasDataset
from uni2ts.model.moirai import MoiraiForecast

from enfobench import AuthorInfo, ForecasterType, ModelInfo
from enfobench.evaluation.server import server_factory
from enfobench.evaluation.utils import create_forecast_index, periods_in_duration
from uni2ts.model.moirai import MoiraiForecast


# Check for GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -24,8 +23,13 @@ def __init__(self, model_name: str, num_samples: int, ctx_length: str | None = N
self.size = model_name.split("-")[-1]

def info(self) -> ModelInfo:
name = (
"Salesforce.Moirai-1.0-R."
f'{self.size.capitalize()}'
f'{f".CTX{self.ctx_length}" if self.ctx_length else ""}'
)
return ModelInfo(
name=f'Salesforce.Moirai-1.0-R.{self.size.capitalize()}{f".CTX{self.ctx_length}" if self.ctx_length else ""}',
name=name,
authors=[
AuthorInfo(name="Attila Balint", email="[email protected]"),
],
@@ -54,9 +58,8 @@ def forecast(

model_dir = root_dir / "models" / self.model_name
if not model_dir.exists():
raise FileNotFoundError(
f"Model directory for {self.model_name} was not found at {model_dir}, make sure it is downloaded."
)
msg = f"Model directory for {self.model_name} was not found at {model_dir}, make sure it is downloaded."
raise FileNotFoundError(msg)

if self.ctx_length is None:
ctx_length = len(history)
@@ -65,10 +68,10 @@

# Prepare pre-trained model
model = MoiraiForecast.load_from_checkpoint(
checkpoint_path=str(model_dir / 'model.ckpt'),
checkpoint_path=str(model_dir / "model.ckpt"),
prediction_length=horizon,
context_length=ctx_length,
patch_size='auto',
patch_size="auto",
num_samples=self.num_samples,
target_dim=1,
feat_dynamic_real_dim=0,
Expand All @@ -88,7 +91,7 @@ def forecast(


model_name = os.getenv("ENFOBENCH_MODEL_NAME", "small")
num_samples = int(os.getenv("ENFOBENCH_NUM_SAMPLES", 1))
num_samples = int(os.getenv("ENFOBENCH_NUM_SAMPLES", "1"))
ctx_length = os.getenv("ENFOBENCH_CTX_LENGTH")

# Instantiate your model
15 changes: 9 additions & 6 deletions pyproject.toml
@@ -8,7 +8,7 @@ dynamic = ["version"]
description = "Energy forecast benchmarking toolkit."
readme = "README.md"
requires-python = ">=3.10"
license = "BSD-2-clause"
license = "BSD-3-clause"
keywords = [
"energy",
"forecasting",
@@ -82,6 +82,7 @@ dependencies = [
"pytest",
"pytest-cov",
"httpx",
"setuptools", # needed to import pkg_resources
]
[tool.hatch.envs.default.scripts]
test = "pytest {args:tests}"
@@ -100,12 +101,12 @@ dependencies = [
[tool.hatch.envs.lint.scripts]
typing = "mypy --install-types --non-interactive {args:src/enfobench tests}"
style = [
"ruff {args:.}",
"ruff check {args:.}",
"black --check --diff {args:.}",
]
fmt = [
"black {args:.}",
"ruff --fix {args:.}",
"ruff check --fix {args:.}",
"style",
]
all = [
@@ -130,6 +131,8 @@ skip-string-normalization = true
[tool.ruff]
target-version = "py310"
line-length = 120

[tool.ruff.lint]
select = [
"A",
"ARG",
@@ -172,13 +175,13 @@ unfixable = [
"F401",
]

[tool.ruff.isort]
[tool.ruff.lint.isort]
known-first-party = ["enfobench"]

[tool.ruff.flake8-tidy-imports]
[tool.ruff.lint.flake8-tidy-imports]
ban-relative-imports = "all"

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
# Tests can use magic values, assertions, and relative imports
"tests/**/*" = ["PLR2004", "S101", "TID252"]
"models/**/*" = ["ARG001", "ARG002", "FBT001"]
2 changes: 1 addition & 1 deletion src/enfobench/__version__.py
@@ -1 +1 @@
__version__ = "0.6.0"
__version__ = "0.6.1"
8 changes: 3 additions & 5 deletions src/enfobench/core/model.py
@@ -38,13 +38,12 @@ class ModelInfo:

name: str
authors: list[AuthorInfo]
type: ForecasterType # noqa: A003
type: ForecasterType
params: dict[str, Any] = field(default_factory=dict)


class Model(Protocol):
def info(self) -> ModelInfo:
...
def info(self) -> ModelInfo: ...

def forecast(
self,
@@ -55,5 +54,4 @@ def forecast(
metadata: dict | None = None,
level: list[int] | None = None,
**kwargs,
) -> pd.DataFrame:
...
) -> pd.DataFrame: ...
2 changes: 2 additions & 0 deletions src/enfobench/datasets/__init__.py
@@ -1,7 +1,9 @@
from enfobench.datasets.electricity_demand import ElectricityDemandDataset
from enfobench.datasets.gas_demand import GasDemandDataset
from enfobench.datasets.pv_generation import PVGenerationDataset

__all__ = [
"ElectricityDemandDataset",
"PVGenerationDataset",
"GasDemandDataset",
]
166 changes: 166 additions & 0 deletions src/enfobench/datasets/gas_demand.py
@@ -0,0 +1,166 @@
from typing import Any

import duckdb
import pandas as pd

from enfobench.core import Subset
from enfobench.datasets.base import DatasetBase

Metadata = dict[str, Any]


class MetadataSubset(Subset):
    """Metadata subset of the HuggingFace dataset containing all metadata about the meters.

    Args:
        file_path: The path to the subset file.
    """

    def list_unique_ids(self) -> list[str]:
        """Lists all unique ids."""
        query = """
            SELECT DISTINCT unique_id
            FROM read_parquet(?)
        """
        conn = duckdb.connect(":memory:")
        return conn.execute(query, parameters=[str(self.file_path)]).fetch_df().unique_id.tolist()

    def get_by_unique_id(self, unique_id: str) -> Metadata:
        """Returns the metadata for the given unique id.

        Args:
            unique_id: The unique id of the meter.
        """
        query = """
            SELECT *
            FROM read_parquet(?)
            WHERE unique_id = ?
        """
        conn = duckdb.connect(":memory:")
        df = conn.execute(query, parameters=[str(self.file_path), unique_id]).fetch_df()
        if df.empty:
            msg = f"Unique id '{unique_id}' was not found."
            raise KeyError(msg)
        return df.iloc[0].to_dict()


class WeatherSubset(Subset):
    """Weather subset of the HuggingFace dataset containing all weather data.

    Args:
        file_path: The path to the subset file.
    """

    def list_location_ids(self) -> list[str]:
        """Lists all location ids."""
        query = """
            SELECT DISTINCT location_id
            FROM read_parquet(?)
        """
        conn = duckdb.connect(":memory:")
        return conn.execute(query, parameters=[str(self.file_path)]).fetch_df().location_id.tolist()

    def get_by_location_id(self, location_id: str, columns: list[str] | None = None) -> pd.DataFrame:
        """Returns the weather data for the given location id.

        Args:
            location_id: The location id of the weather station.
            columns: The columns to return. If None, all columns are returned.
        """
        conn = duckdb.connect(":memory:")

        if columns:
            # Column names cannot be bound as query parameters, hence the f-string;
            # S608 is suppressed deliberately.
            query = f"""
                SELECT timestamp, {", ".join(columns)}
                FROM read_parquet(?)
                WHERE location_id = ?
            """  # noqa: S608
        else:
            query = """
                SELECT *
                FROM read_parquet(?)
                WHERE location_id = ?
            """
        df = conn.execute(query, parameters=[str(self.file_path), location_id]).fetch_df()
        if df.empty:
            msg = f"Location id '{location_id}' was not found."
            raise KeyError(msg)

        # Remove location_id and set timestamp as index
        df.drop(columns=["location_id"], inplace=True, errors="ignore")
        df.set_index("timestamp", inplace=True)
        return df


class DemandSubset(Subset):
    """Data subset of the HuggingFace dataset containing all gas demand data.

    Args:
        file_path: The path to the subset file.
    """

    def get_by_unique_id(self, unique_id: str) -> pd.DataFrame:
        """Returns the demand data for the given unique id.

        Args:
            unique_id: The unique id of the meter.
        """
        query = """
            SELECT *
            FROM read_parquet(?)
            WHERE unique_id = ?
        """
        conn = duckdb.connect(":memory:")
        df = conn.execute(query, parameters=[str(self.file_path), unique_id]).fetch_df()
        if df.empty:
            msg = f"Unique id '{unique_id}' was not found."
            raise KeyError(msg)

        # Remove unique_id and set timestamp as index
        df.drop(columns=["unique_id"], inplace=True, errors="ignore")
        df.set_index("timestamp", inplace=True)
        return df


class GasDemandDataset(DatasetBase):
    """GasDemandDataset class representing the HuggingFace dataset.

    This class is a collection of all subsets inside the HuggingFace dataset.
    It provides an easy way to access the different subsets.

    Args:
        directory: The directory where the HuggingFace dataset is located.
            This directory should contain all the subset files.
    """

    HUGGINGFACE_DATASET = "EDS-lab/gas-demand"
    SUBSETS = ("demand", "metadata", "weather")

    @property
    def metadata_subset(self) -> MetadataSubset:
        """Returns the metadata subset."""
        return MetadataSubset(self._get_subset_path("metadata"))

    @property
    def weather_subset(self) -> WeatherSubset:
        """Returns the weather subset."""
        return WeatherSubset(self._get_subset_path("weather"))

    @property
    def demand_subset(self) -> DemandSubset:
        """Returns the demand subset."""
        return DemandSubset(self._get_subset_path("demand"))

    def list_unique_ids(self) -> list[str]:
        return self.metadata_subset.list_unique_ids()

    def list_location_ids(self) -> list[str]:
        return self.weather_subset.list_location_ids()

    def get_data_by_unique_id(self, unique_id: str) -> tuple[pd.DataFrame, pd.DataFrame, Metadata]:
        metadata = self.metadata_subset.get_by_unique_id(unique_id)
        location_id = metadata["location_id"]

        demand = self.demand_subset.get_by_unique_id(unique_id)
        weather = self.weather_subset.get_by_location_id(location_id)
        return demand, weather, metadata
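A minimal usage sketch of the new dataset class; the local directory path is hypothetical, and the subset parquet files are assumed to have been downloaded from the EDS-lab/gas-demand HuggingFace dataset beforehand:

```python
from enfobench.datasets import GasDemandDataset

# Hypothetical local directory containing the downloaded subset files.
dataset = GasDemandDataset("data/gas-demand")

unique_ids = dataset.list_unique_ids()
demand, weather, metadata = dataset.get_data_by_unique_id(unique_ids[0])

print(metadata["location_id"])  # weather station linked to this meter
print(demand.head())            # gas demand series, indexed by timestamp
print(weather.head())           # weather series, indexed by timestamp
```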