Skip to content

Commit

Permalink
first sktime (NaiveForecaster) model
Browse files Browse the repository at this point in the history
  • Loading branch information
attila-balint-kul committed Dec 1, 2023
1 parent 39bcf4c commit 056e921
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 0 deletions.
16 changes: 16 additions & 0 deletions models/st-naive/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
FROM python:3.11.6-slim-bookworm

WORKDIR /usr/local/app

COPY ./requirements.txt /usr/local/app/requirements.txt
RUN pip install --no-cache-dir --upgrade pip
RUN pip install --no-cache-dir -r /usr/local/app/requirements.txt

# Polars requires AVX2 CPU instructions, which are not available on the server
RUN pip uninstall -y polars
RUN pip install --no-cache-dir polars-lts-cpu

COPY ./src /usr/local/app/src

EXPOSE 3000
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "3000"]
2 changes: 2 additions & 0 deletions models/st-naive/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
enfobench>=0.3.5,<0.4.0
sktime==0.24.1
48 changes: 48 additions & 0 deletions models/st-naive/src/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pandas as pd
from sktime.forecasting.naive import NaiveForecaster

from enfobench import AuthorInfo, ForecasterType, ModelInfo
from enfobench.evaluation.server import server_factory
from enfobench.evaluation.utils import create_forecast_index


class NaiveForecasterModel:
def info(self) -> ModelInfo:
return ModelInfo(
name="Sktime.NaiveForecaster.Mean",
authors=[AuthorInfo(name="Attila Balint", email="[email protected]")],
type=ForecasterType.point,
params={
"strategy": "mean",
},
)

def forecast(
self,
horizon: int,
history: pd.DataFrame,
past_covariates: pd.DataFrame | None = None,
future_covariates: pd.DataFrame | None = None,
level: list[int] | None = None,
**kwargs,
) -> pd.DataFrame:
# Fill missing values
y = history.y.fillna(history.y.mean())

# Create model
model = NaiveForecaster(strategy="mean")

# Make forecast
index = create_forecast_index(history=history, horizon=horizon)
pred: pd.Series = model.fit_predict(y, fh=index, **kwargs)

# Postprocess forecast
forecast = pred.to_frame("yhat").fillna(y.mean())
return forecast


# Instantiate your model
model = NaiveForecasterModel()

# Create a forecast server by passing in your model
app = server_factory(model)

0 comments on commit 056e921

Please sign in to comment.