Skip to content

Commit

Permalink
added darts.NaiveMovingAverage
Browse files Browse the repository at this point in the history
  • Loading branch information
attila-balint-kul committed Apr 5, 2024
1 parent ae4a8e6 commit 8c169de
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 0 deletions.
13 changes: 13 additions & 0 deletions models/dt-naive-moving-avg/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
FROM attilabalint/enfobench-models:base-u8darts-0.27.2

WORKDIR /usr/local/app

COPY ./requirements.txt /usr/local/app/requirements.txt
RUN pip install --no-cache-dir -r /usr/local/app/requirements.txt

COPY ./src /usr/local/app/src

ENV ENFOBENCH_MODEL_HISTORY="1D"

EXPOSE 3000
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "3000"]
1 change: 1 addition & 0 deletions models/dt-naive-moving-avg/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
enfobench>=0.6.0,<0.7.0
61 changes: 61 additions & 0 deletions models/dt-naive-moving-avg/src/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os

import pandas as pd
from darts import TimeSeries
from darts.models.forecasting.baselines import NaiveMovingAverage

from enfobench import AuthorInfo, ForecasterType, ModelInfo
from enfobench.evaluation.server import server_factory
from enfobench.evaluation.utils import periods_in_duration


class DartsNaiveMovingAverageModel:
def __init__(self, history: str):
self.history = history.upper()

def info(self) -> ModelInfo:
return ModelInfo(
name=f"Darts.NaiveMovingAverage.{self.history}",
authors=[AuthorInfo(name="Attila Balint", email="[email protected]")],
type=ForecasterType.point,
params={
"history": self.history,
},
)

def forecast(
self,
horizon: int,
history: pd.DataFrame,
past_covariates: pd.DataFrame | None = None,
future_covariates: pd.DataFrame | None = None,
metadata: dict | None = None,
**kwargs,
) -> pd.DataFrame:
# Fill missing values
history = history.fillna(history.y.mean())

# Create model
periods = periods_in_duration(history.index, duration=self.history)
model = NaiveMovingAverage(input_chunk_length=periods)

# Fit model
series = TimeSeries.from_dataframe(history, value_cols=["y"])
model.fit(series)

# Make forecast
pred = model.predict(horizon)

# Postprocess forecast
forecast = pred.pd_dataframe().rename(columns={"y": "yhat"}).fillna(history.y.mean())
return forecast


# Load parameters
history = os.getenv("ENFOBENCH_MODEL_HISTORY")

# Instantiate your model
model = DartsNaiveMovingAverageModel(history)

# Create a forecast server by passing in your model
app = server_factory(model)

0 comments on commit 8c169de

Please sign in to comment.