-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ae4a8e6
commit 8c169de
Showing
3 changed files
with
75 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
FROM attilabalint/enfobench-models:base-u8darts-0.27.2 | ||
|
||
WORKDIR /usr/local/app | ||
|
||
COPY ./requirements.txt /usr/local/app/requirements.txt | ||
RUN pip install --no-cache-dir -r /usr/local/app/requirements.txt | ||
|
||
COPY ./src /usr/local/app/src | ||
|
||
ENV ENFOBENCH_MODEL_HISTORY="1D" | ||
|
||
EXPOSE 3000 | ||
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "3000"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
enfobench>=0.6.0,<0.7.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import os | ||
|
||
import pandas as pd | ||
from darts import TimeSeries | ||
from darts.models.forecasting.baselines import NaiveMovingAverage | ||
|
||
from enfobench import AuthorInfo, ForecasterType, ModelInfo | ||
from enfobench.evaluation.server import server_factory | ||
from enfobench.evaluation.utils import periods_in_duration | ||
|
||
|
||
class DartsNaiveMovingAverageModel: | ||
def __init__(self, history: str): | ||
self.history = history.upper() | ||
|
||
def info(self) -> ModelInfo: | ||
return ModelInfo( | ||
name=f"Darts.NaiveMovingAverage.{self.history}", | ||
authors=[AuthorInfo(name="Attila Balint", email="[email protected]")], | ||
type=ForecasterType.point, | ||
params={ | ||
"history": self.history, | ||
}, | ||
) | ||
|
||
def forecast( | ||
self, | ||
horizon: int, | ||
history: pd.DataFrame, | ||
past_covariates: pd.DataFrame | None = None, | ||
future_covariates: pd.DataFrame | None = None, | ||
metadata: dict | None = None, | ||
**kwargs, | ||
) -> pd.DataFrame: | ||
# Fill missing values | ||
history = history.fillna(history.y.mean()) | ||
|
||
# Create model | ||
periods = periods_in_duration(history.index, duration=self.history) | ||
model = NaiveMovingAverage(input_chunk_length=periods) | ||
|
||
# Fit model | ||
series = TimeSeries.from_dataframe(history, value_cols=["y"]) | ||
model.fit(series) | ||
|
||
# Make forecast | ||
pred = model.predict(horizon) | ||
|
||
# Postprocess forecast | ||
forecast = pred.pd_dataframe().rename(columns={"y": "yhat"}).fillna(history.y.mean()) | ||
return forecast | ||
|
||
|
||
# Load parameters | ||
history = os.getenv("ENFOBENCH_MODEL_HISTORY") | ||
|
||
# Instantiate your model | ||
model = DartsNaiveMovingAverageModel(history) | ||
|
||
# Create a forecast server by passing in your model | ||
app = server_factory(model) |