Skip to content

Commit

Permalink
updated LR model to support recursive and direct models
Browse files Browse the repository at this point in the history
  • Loading branch information
attila-balint-kul committed Dec 15, 2023
1 parent 7c8f9dd commit 7c187b4
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 14 deletions.
12 changes: 11 additions & 1 deletion models/dt-four-theta/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,23 @@


class DartsFourThetaModel:
"""FourTheta model from Darts.
Args:
seasonality: The seasonality of the time series. E.g. "1D" for daily seasonality.
References:
https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html
"""
def __init__(self, seasonality: str):
self.seasonality = seasonality.upper()

def info(self) -> ModelInfo:
return ModelInfo(
name=f"Darts.FourTheta.{self.seasonality}.SM-A",
authors=[AuthorInfo(name="Attila Balint", email="[email protected]")],
authors=[
AuthorInfo(name="Attila Balint", email="[email protected]"),
],
type=ForecasterType.point,
params={
"seasonality": self.seasonality,
Expand Down
2 changes: 1 addition & 1 deletion models/dt-linear-regression/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ RUN pip install --no-cache-dir -r /usr/local/app/requirements.txt
COPY ./src /usr/local/app/src

ENV ENFOBENCH_MODEL_SEASONALITY="1D"
ENV ENFOBENCH_MODEL_DIRECT="0"
ENV ENFOBENCH_MODEL_TYPE="DirectMultiModel"

EXPOSE 3000
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "3000"]
43 changes: 31 additions & 12 deletions models/dt-linear-regression/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,25 @@
from enfobench import AuthorInfo, ForecasterType, ModelInfo
from enfobench.evaluation.server import server_factory
from enfobench.evaluation.utils import periods_in_duration
from typing import Literal


class DartsLinearRegressionModel:
def __init__(self, seasonality: str, direct: bool):
def __init__(self, seasonality: str, model_type: Literal['DirectMultiModel', 'DirectMultiOutput', 'Recursive']):
self.seasonality = seasonality.upper()
self.direct = direct
self.model_type = model_type

def info(self) -> ModelInfo:
return ModelInfo(
name=f"Darts.LinearRegression.{'Direct.' if self.direct else ''}{self.seasonality}",
name=f"Darts.LinearRegression.{self.model_type}.{self.seasonality}",
authors=[
AuthorInfo(name="Mohamad Khalil", email="[email protected]"),
AuthorInfo(name="Attila Balint", email="[email protected]"),
],
type=ForecasterType.point,
params={
"model_type": self.model_type,
"seasonality": self.seasonality,
"direct": self.direct,
},
)

Expand All @@ -42,12 +43,30 @@ def forecast(

# Create model
periods = periods_in_duration(history.index, duration=self.seasonality)
model = RegressionModel(
lags=list(range(-periods, 0)),
output_chunk_length=horizon,
model=LinearRegression(),
multi_models=not self.direct,
)
if self.model_type == 'Recursive':
model = RegressionModel(
model=LinearRegression(),
lags=list(range(-periods, 0)),
output_chunk_length=1,
multi_models=False,
)
elif self.model_type == 'DirectMultiOutput':
model = RegressionModel(
model=LinearRegression(),
lags=list(range(-periods, 0)),
output_chunk_length=horizon,
multi_models=False,
)
elif self.model_type == 'DirectMultiModel':
model = RegressionModel(
model=LinearRegression(),
lags=list(range(-periods, 0)),
output_chunk_length=horizon,
multi_models=True,
)
else:
msg = f"Unknown model type: {self.model_type}"
raise ValueError(msg)

# Fit model
series = TimeSeries.from_dataframe(history, value_cols=["y"])
Expand All @@ -63,10 +82,10 @@ def forecast(

# Load parameters
seasonality = os.getenv("ENFOBENCH_MODEL_SEASONALITY")
direct = bool(int(os.getenv("ENFOBENCH_MODEL_DIRECT")))
model_type = os.getenv("ENFOBENCH_MODEL_TYPE")

# Instantiate your model
model = DartsLinearRegressionModel(seasonality=seasonality, direct=direct)
model = DartsLinearRegressionModel(seasonality=seasonality, model_type=model_type)

# Create a forecast server by passing in your model
app = server_factory(model)

0 comments on commit 7c187b4

Please sign in to comment.