Skip to content

Commit

Permalink
added lightgbm model from darts
Browse files Browse the repository at this point in the history
  • Loading branch information
attila-balint-kul committed Dec 15, 2023
1 parent 7c187b4 commit f21ac63
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ RUN pip install --no-cache-dir -r /usr/local/app/requirements.txt
COPY ./src /usr/local/app/src

ENV ENFOBENCH_MODEL_SEASONALITY="1D"
ENV ENFOBENCH_MODEL_TYPE="Recursive"

EXPOSE 3000
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "3000"]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
enfobench>=0.4.0,<0.5.0
u8darts[notorch]==0.27.0
lightgbm==4.1.0
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import Literal

import pandas as pd
from darts import TimeSeries
Expand All @@ -10,15 +11,20 @@


class DartsLightGBMModel:
def __init__(self, seasonality: str):
def __init__(self, seasonality: str, model_type: Literal["DirectMultiModel", "DirectMultiOutput", "Recursive"]):
self.seasonality = seasonality.upper()
self.model_type = model_type

def info(self) -> ModelInfo:
return ModelInfo(
name=f"Darts.LightGBM.Direct.{self.seasonality}",
authors=[AuthorInfo(name="Mohamad Khalil", email="[email protected]")],
name=f"Darts.LightGBM.{self.model_type}.{self.seasonality}",
authors=[
AuthorInfo(name="Attila Balint", email="[email protected]"),
AuthorInfo(name="Mohamad Khalil", email="[email protected]"),
],
type=ForecasterType.point,
params={
"model_type": self.model_type,
"seasonality": self.seasonality,
},
)
Expand All @@ -36,11 +42,26 @@ def forecast(

# Create model
periods = periods_in_duration(history.index, duration=self.seasonality)
model = LightGBMModel(
lags=list(range(-periods, 0)),
output_chunk_length=horizon,
multi_models=False,
)
if self.model_type == "Recursive":
model = LightGBMModel(
lags=list(range(-periods, 0)),
output_chunk_length=1,
)
elif self.model_type == "DirectMultiOutput":
model = LightGBMModel(
lags=list(range(-periods, 0)),
output_chunk_length=horizon,
multi_models=False,
)
elif self.model_type == "DirectMultiModel":
model = LightGBMModel(
lags=list(range(-periods, 0)),
output_chunk_length=horizon,
multi_models=True,
)
else:
msg = f"Unknown model type {self.model_type}"
raise ValueError(msg)

# Fit model
series = TimeSeries.from_dataframe(history, value_cols=["y"])
Expand All @@ -51,15 +72,15 @@ def forecast(

# Postprocess forecast
forecast = pred.pd_dataframe().rename(columns={"y": "yhat"}).fillna(history.y.mean())

return forecast


# Load parameters
seasonality = os.getenv("ENFOBENCH_MODEL_SEASONALITY")
model_type = os.getenv("ENFOBENCH_MODEL_TYPE")

# Instantiate your model
model = DartsLightGBMModel(seasonality=seasonality)
model = DartsLightGBMModel(seasonality=seasonality, model_type=model_type)

# Create a forecast server by passing in your model
app = server_factory(model)

0 comments on commit f21ac63

Please sign in to comment.