-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1854524
commit 509ae94
Showing
10 changed files
with
209 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,8 +21,13 @@ def __init__(self, model_name: str, num_samples: int, ctx_length: str | None = N | |
self.ctx_length = ctx_length | ||
|
||
def info(self) -> ModelInfo: | ||
name = ( | ||
"Amazon." | ||
f'{".".join(map(str.capitalize, self.model_name.split("-")))}' | ||
f'{".CTX" + self.ctx_length if self.ctx_length else ""}' | ||
) | ||
return ModelInfo( | ||
name=f'Amazon.{".".join(map(str.capitalize, self.model_name.split("-")))}{".CTX" + self.ctx_length if self.ctx_length else ""}', | ||
name=name, | ||
authors=[ | ||
AuthorInfo(name="Attila Balint", email="[email protected]"), | ||
], | ||
|
@@ -48,9 +53,9 @@ def forecast( | |
|
||
model_dir = root_dir / "models" / self.model_name | ||
if not model_dir.exists(): | ||
raise FileNotFoundError( | ||
f"Model directory for {self.model_name} was not found at {model_dir}, make sure it is downloaded." | ||
) | ||
msg = f"Model directory for {self.model_name} was not found at {model_dir}, make sure it is downloaded." | ||
raise FileNotFoundError(msg) | ||
|
||
pipeline = ChronosPipeline.from_pretrained( | ||
model_dir, | ||
device_map=device, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,12 +4,11 @@ | |
import pandas as pd | ||
import torch | ||
from gluonts.dataset.pandas import PandasDataset | ||
from uni2ts.model.moirai import MoiraiForecast | ||
|
||
from enfobench import AuthorInfo, ForecasterType, ModelInfo | ||
from enfobench.evaluation.server import server_factory | ||
from enfobench.evaluation.utils import create_forecast_index, periods_in_duration | ||
from uni2ts.model.moirai import MoiraiForecast | ||
|
||
|
||
# Check for GPU availability | ||
device = "cuda" if torch.cuda.is_available() else "cpu" | ||
|
@@ -24,8 +23,13 @@ def __init__(self, model_name: str, num_samples: int, ctx_length: str | None = N | |
self.size = model_name.split("-")[-1] | ||
|
||
def info(self) -> ModelInfo: | ||
name = ( | ||
"Salesforce.Moirai-1.0-R." | ||
f'{self.size.capitalize()}' | ||
f'{f".CTX{self.ctx_length}" if self.ctx_length else ""}' | ||
) | ||
return ModelInfo( | ||
name=f'Salesforce.Moirai-1.0-R.{self.size.capitalize()}{f".CTX{self.ctx_length}" if self.ctx_length else ""}', | ||
name=name, | ||
authors=[ | ||
AuthorInfo(name="Attila Balint", email="[email protected]"), | ||
], | ||
|
@@ -54,9 +58,8 @@ def forecast( | |
|
||
model_dir = root_dir / "models" / self.model_name | ||
if not model_dir.exists(): | ||
raise FileNotFoundError( | ||
f"Model directory for {self.model_name} was not found at {model_dir}, make sure it is downloaded." | ||
) | ||
msg = f"Model directory for {self.model_name} was not found at {model_dir}, make sure it is downloaded." | ||
raise FileNotFoundError(msg) | ||
|
||
if self.ctx_length is None: | ||
ctx_length = len(history) | ||
|
@@ -65,10 +68,10 @@ def forecast( | |
|
||
# Prepare pre-trained model | ||
model = MoiraiForecast.load_from_checkpoint( | ||
checkpoint_path=str(model_dir / 'model.ckpt'), | ||
checkpoint_path=str(model_dir / "model.ckpt"), | ||
prediction_length=horizon, | ||
context_length=ctx_length, | ||
patch_size='auto', | ||
patch_size="auto", | ||
num_samples=self.num_samples, | ||
target_dim=1, | ||
feat_dynamic_real_dim=0, | ||
|
@@ -88,7 +91,7 @@ def forecast( | |
|
||
|
||
model_name = os.getenv("ENFOBENCH_MODEL_NAME", "small") | ||
num_samples = int(os.getenv("ENFOBENCH_NUM_SAMPLES", 1)) | ||
num_samples = int(os.getenv("ENFOBENCH_NUM_SAMPLES", "1")) | ||
ctx_length = os.getenv("ENFOBENCH_CTX_LENGTH") | ||
|
||
# Instantiate your model | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = "0.6.0" | ||
__version__ = "0.6.1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
from enfobench.datasets.electricity_demand import ElectricityDemandDataset | ||
from enfobench.datasets.gas_demand import GasDemandDataset | ||
from enfobench.datasets.pv_generation import PVGenerationDataset | ||
|
||
__all__ = [ | ||
"ElectricityDemandDataset", | ||
"PVGenerationDataset", | ||
"GasDemandDataset", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
from typing import Any | ||
|
||
import duckdb | ||
import pandas as pd | ||
|
||
from enfobench.core import Subset | ||
from enfobench.datasets.base import DatasetBase | ||
|
||
Metadata = dict[str, Any] | ||
|
||
|
||
class MetadataSubset(Subset):
    """Metadata subset of the HuggingFace dataset containing all metadata about the meters.

    Args:
        file_path: The path to the subset file.
    """

    def list_unique_ids(self) -> list[str]:
        """Lists all unique ids.

        Returns:
            The distinct values of the ``unique_id`` column of the subset file.
        """
        query = """
            SELECT DISTINCT unique_id
            FROM read_parquet(?)
        """
        # The connection is only needed for this single query, so close it
        # deterministically instead of leaving it to the garbage collector.
        with duckdb.connect(":memory:") as conn:
            df = conn.execute(query, parameters=[str(self.file_path)]).fetch_df()
        return df.unique_id.tolist()

    def get_by_unique_id(self, unique_id: str) -> Metadata:
        """Returns the metadata for the given unique id.

        Args:
            unique_id: The unique id of the meter.

        Returns:
            The first matching row as a dictionary keyed by column name.

        Raises:
            KeyError: If no row with the given unique id exists.
        """
        query = """
            SELECT *
            FROM read_parquet(?)
            WHERE unique_id = ?
        """
        with duckdb.connect(":memory:") as conn:
            df = conn.execute(query, parameters=[str(self.file_path), unique_id]).fetch_df()
        if df.empty:
            msg = f"Unique id '{unique_id}' was not found."
            raise KeyError(msg)
        return df.iloc[0].to_dict()
|
||
|
||
class WeatherSubset(Subset):
    """Weather subset of the HuggingFace dataset containing all weather data.

    Args:
        file_path: The path to the subset file.
    """

    def list_location_ids(self) -> list[str]:
        """Lists all location ids.

        Returns:
            The distinct values of the ``location_id`` column of the subset file.
        """
        query = """
            SELECT DISTINCT location_id
            FROM read_parquet(?)
        """
        # The connection is only needed for this single query, so close it
        # deterministically instead of leaving it to the garbage collector.
        with duckdb.connect(":memory:") as conn:
            df = conn.execute(query, parameters=[str(self.file_path)]).fetch_df()
        return df.location_id.tolist()

    def get_by_location_id(self, location_id: str, columns: list[str] | None = None) -> pd.DataFrame:
        """Returns the weather data for the given location id.

        Args:
            location_id: The location id of the weather station.
            columns: The columns to return in addition to ``timestamp``.
                If None or empty, all columns are returned.

        Returns:
            The matching rows indexed by ``timestamp``, without the
            ``location_id`` column.

        Raises:
            KeyError: If no row with the given location id exists.
        """
        if columns:
            # NOTE(review): column names are interpolated verbatim into the SQL
            # text (they cannot be bound as parameters), so they must come from
            # trusted code, never from end-user input.
            query = f"""
                SELECT timestamp, {", ".join(columns)}
                FROM read_parquet(?)
                WHERE location_id = ?
            """  # noqa: S608
        else:
            query = """
                SELECT *
                FROM read_parquet(?)
                WHERE location_id = ?
            """

        with duckdb.connect(":memory:") as conn:
            df = conn.execute(query, parameters=[str(self.file_path), location_id]).fetch_df()
        if df.empty:
            msg = f"Location id '{location_id}' was not found."
            raise KeyError(msg)

        # Remove location_id (absent when explicit columns were selected, hence
        # errors="ignore") and set timestamp as index.
        return df.drop(columns=["location_id"], errors="ignore").set_index("timestamp")
|
||
|
||
class DemandSubset(Subset):
    """Data subset of the HuggingFace dataset containing all gas demand data.

    Args:
        file_path: The path to the subset file.
    """

    def get_by_unique_id(self, unique_id: str) -> pd.DataFrame:
        """Returns the demand data for the given unique id.

        Args:
            unique_id: The unique id of the meter.

        Returns:
            The matching rows indexed by ``timestamp``, without the
            ``unique_id`` column.

        Raises:
            KeyError: If no row with the given unique id exists.
        """
        query = """
            SELECT *
            FROM read_parquet(?)
            WHERE unique_id = ?
        """
        # The connection is only needed for this single query, so close it
        # deterministically instead of leaving it to the garbage collector.
        with duckdb.connect(":memory:") as conn:
            df = conn.execute(query, parameters=[str(self.file_path), unique_id]).fetch_df()
        if df.empty:
            msg = f"Unique id '{unique_id}' was not found."
            raise KeyError(msg)

        # Remove unique_id and set timestamp as index
        return df.drop(columns=["unique_id"], errors="ignore").set_index("timestamp")
|
||
|
||
class GasDemandDataset(DatasetBase):
    """Entry point to the subsets of the HuggingFace gas demand dataset.

    Bundles the demand, metadata and weather subsets and exposes convenience
    accessors that combine them.

    Args:
        directory: The directory where the HuggingFace dataset is located.
            This directory should contain all the subset files.
    """

    HUGGINGFACE_DATASET = "EDS-lab/gas-demand"
    SUBSETS = ("demand", "metadata", "weather")

    @property
    def metadata_subset(self) -> MetadataSubset:
        """Returns the metadata subset."""
        subset_path = self._get_subset_path("metadata")
        return MetadataSubset(subset_path)

    @property
    def weather_subset(self) -> WeatherSubset:
        """Returns the weather subset."""
        subset_path = self._get_subset_path("weather")
        return WeatherSubset(subset_path)

    @property
    def demand_subset(self) -> DemandSubset:
        """Returns the demand subset."""
        subset_path = self._get_subset_path("demand")
        return DemandSubset(subset_path)

    def list_unique_ids(self) -> list[str]:
        """Lists all unique ids found in the metadata subset."""
        meters = self.metadata_subset
        return meters.list_unique_ids()

    def list_location_ids(self) -> list[str]:
        """Lists all location ids found in the weather subset."""
        stations = self.weather_subset
        return stations.list_location_ids()

    def get_data_by_unique_id(self, unique_id: str) -> tuple[pd.DataFrame, pd.DataFrame, Metadata]:
        """Returns the demand, weather and metadata belonging to a unique id.

        The weather data is looked up through the ``location_id`` entry of the
        meter's metadata.

        Args:
            unique_id: The unique id of the meter.

        Returns:
            A ``(demand, weather, metadata)`` tuple.
        """
        # The metadata is fetched first because it holds the weather location key.
        meter_metadata = self.metadata_subset.get_by_unique_id(unique_id)
        weather_location = meter_metadata["location_id"]

        demand_frame = self.demand_subset.get_by_unique_id(unique_id)
        weather_frame = self.weather_subset.get_by_location_id(weather_location)
        return demand_frame, weather_frame, meter_metadata
Oops, something went wrong.