From f21ac6346e177c93cebb6a421d1d80727fd81b0d Mon Sep 17 00:00:00 2001 From: Attila Balint Date: Fri, 15 Dec 2023 10:27:07 +0100 Subject: [PATCH] added lightgbm model from darts --- .../Dockerfile | 1 + .../requirements.txt | 2 +- .../src/main.py | 41 ++++++++++++++----- 3 files changed, 33 insertions(+), 11 deletions(-) rename models/{dt-lightgbm-direct => dt-lightgbm}/Dockerfile (92%) rename models/{dt-lightgbm-direct => dt-lightgbm}/requirements.txt (50%) rename models/{dt-lightgbm-direct => dt-lightgbm}/src/main.py (51%) diff --git a/models/dt-lightgbm-direct/Dockerfile b/models/dt-lightgbm/Dockerfile similarity index 92% rename from models/dt-lightgbm-direct/Dockerfile rename to models/dt-lightgbm/Dockerfile index 5968e15..4bcdb7a 100644 --- a/models/dt-lightgbm-direct/Dockerfile +++ b/models/dt-lightgbm/Dockerfile @@ -12,6 +12,7 @@ RUN pip install --no-cache-dir -r /usr/local/app/requirements.txt COPY ./src /usr/local/app/src ENV ENFOBENCH_MODEL_SEASONALITY="1D" +ENV ENFOBENCH_MODEL_TYPE="Recursive" EXPOSE 3000 CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "3000"] diff --git a/models/dt-lightgbm-direct/requirements.txt b/models/dt-lightgbm/requirements.txt similarity index 50% rename from models/dt-lightgbm-direct/requirements.txt rename to models/dt-lightgbm/requirements.txt index 9febc19..a79bf86 100644 --- a/models/dt-lightgbm-direct/requirements.txt +++ b/models/dt-lightgbm/requirements.txt @@ -1,2 +1,2 @@ enfobench>=0.4.0,<0.5.0 -u8darts[notorch]==0.27.0 \ No newline at end of file +lightgbm==4.1.0 \ No newline at end of file diff --git a/models/dt-lightgbm-direct/src/main.py b/models/dt-lightgbm/src/main.py similarity index 51% rename from models/dt-lightgbm-direct/src/main.py rename to models/dt-lightgbm/src/main.py index 1818d08..e523765 100644 --- a/models/dt-lightgbm-direct/src/main.py +++ b/models/dt-lightgbm/src/main.py @@ -1,4 +1,5 @@ import os +from typing import Literal import pandas as pd from darts import TimeSeries @@ -10,15 +11,20 @@ class DartsLightGBMModel: - def __init__(self, seasonality: str): + def __init__(self, seasonality: str, model_type: Literal["DirectMultiModel", "DirectMultiOutput", "Recursive"]): self.seasonality = seasonality.upper() + self.model_type = model_type def info(self) -> ModelInfo: return ModelInfo( - name=f"Darts.LightGBM.Direct.{self.seasonality}", - authors=[AuthorInfo(name="Mohamad Khalil", email="coo17619@newcastle.ac.uk")], + name=f"Darts.LightGBM.{self.model_type}.{self.seasonality}", + authors=[ + AuthorInfo(name="Attila Balint", email="attila.balint@kuleuven.be"), + AuthorInfo(name="Mohamad Khalil", email="coo17619@newcastle.ac.uk"), + ], type=ForecasterType.point, params={ + "model_type": self.model_type, "seasonality": self.seasonality, }, ) @@ -36,11 +42,26 @@ def forecast( # Create model periods = periods_in_duration(history.index, duration=self.seasonality) - model = LightGBMModel( - lags=list(range(-periods, 0)), - output_chunk_length=horizon, - multi_models=False, - ) + if self.model_type == "Recursive": + model = LightGBMModel( + lags=list(range(-periods, 0)), + output_chunk_length=1, + ) + elif self.model_type == "DirectMultiOutput": + model = LightGBMModel( + lags=list(range(-periods, 0)), + output_chunk_length=horizon, + multi_models=False, + ) + elif self.model_type == "DirectMultiModel": + model = LightGBMModel( + lags=list(range(-periods, 0)), + output_chunk_length=horizon, + multi_models=True, + ) + else: + msg = f"Unknown model type {self.model_type}" + raise ValueError(msg) # Fit model series = TimeSeries.from_dataframe(history, value_cols=["y"]) @@ -51,15 +72,15 @@ def forecast( # Postprocess forecast forecast = pred.pd_dataframe().rename(columns={"y": "yhat"}).fillna(history.y.mean()) - return forecast # Load parameters seasonality = os.getenv("ENFOBENCH_MODEL_SEASONALITY") +model_type = os.getenv("ENFOBENCH_MODEL_TYPE") # Instantiate your model -model = DartsLightGBMModel(seasonality=seasonality) +model = DartsLightGBMModel(seasonality=seasonality, model_type=model_type) # Create a forecast server by passing in your model app = server_factory(model)