Skip to content

Commit

Permalink
salesforce moirai models
Browse files Browse the repository at this point in the history
  • Loading branch information
attilabalint committed Apr 26, 2024
1 parent 47c411a commit e00eb3b
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 1 deletion.
23 changes: 22 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,20 @@ download-amazon-chronos: models/amazon-chronos/models/chronos-t5-tiny \
models/amazon-chronos/models/chronos-t5-large



models/salesforce-moirai/models/moirai-1.0-R-small:
git clone https://huggingface.co/salesforce/moirai-1.0-R-small ./models/salesforce-moirai/models/moirai-1.0-R-small

models/salesforce-moirai/models/moirai-1.0-R-base:
git clone https://huggingface.co/salesforce/moirai-1.0-R-base ./models/salesforce-moirai/models/moirai-1.0-R-base

models/salesforce-moirai/models/moirai-1.0-R-large:
git clone https://huggingface.co/salesforce/moirai-1.0-R-large ./models/salesforce-moirai/models/moirai-1.0-R-large

download-salesforce-moirai: models/salesforce-moirai/models/moirai-1.0-R-small \
models/salesforce-moirai/models/moirai-1.0-R-base \
models/salesforce-moirai/models/moirai-1.0-R-large

#################################################################################
# MODEL RULES #
#################################################################################
Expand Down Expand Up @@ -148,11 +162,17 @@ base-image-amazon-chronos:
docker build --build-arg CHRONOS_VERSION=$(CHRONOS_VERSION) -t $(DOCKER_HUB_REPOSITORY):base-amazon-chronos-$(CHRONOS_VERSION) ./docker/base/amazon-chronos


.PHONY: base-image-salesforce-moirai
base-image-salesforce-moirai:
docker build -t $(DOCKER_HUB_REPOSITORY):base-salesforce-moirai ./docker/base/salesforce-moirai


## Build base images
base-images: base-image-darts \
base-image-sktime \
base-image-statsforecast \
base-image-amazon-chronos
base-image-amazon-chronos \
base-image-salesforce-moirai


.PHONY: push-base-images
Expand All @@ -162,6 +182,7 @@ push-base-images:
docker push $(DOCKER_HUB_REPOSITORY):base-sktime-$(SKTIME_VERSION)
docker push $(DOCKER_HUB_REPOSITORY):base-statsforecast-$(STATSFORECAST_VERSION)
docker push $(DOCKER_HUB_REPOSITORY):base-amazon-chronos-$(CHRONOS_VERSION)
docker push $(DOCKER_HUB_REPOSITORY):base-salesforce-moirai


.PHONY: image
Expand Down
8 changes: 8 additions & 0 deletions docker/base/salesforce-moirai/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
FROM python:3.11-slim-bookworm

RUN apt-get update && apt-get install -y --no-install-recommends git

WORKDIR /usr/local/app

RUN pip install --no-cache-dir --upgrade pip
RUN pip install --no-cache-dir git+https://github.com/SalesforceAIResearch/uni2ts.git
15 changes: 15 additions & 0 deletions models/salesforce-moirai/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
FROM attilabalint/enfobench-models:base-salesforce-moirai

WORKDIR /usr/local/app

COPY ./requirements.txt /usr/local/app/requirements.txt
RUN pip install --no-cache-dir -r /usr/local/app/requirements.txt

COPY ./models /usr/local/app/models
COPY ./src /usr/local/app/src

ENV ENFOBENCH_MODEL_NAME="moirai-1.0-R-small"
ENV ENFOBENCH_NUM_SAMPLES="1"

EXPOSE 3000
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "3000", "--workers", "1"]
2 changes: 2 additions & 0 deletions models/salesforce-moirai/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Chronos repository
enfobench>=0.6.0,<0.7.0
89 changes: 89 additions & 0 deletions models/salesforce-moirai/src/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import os
from pathlib import Path

import pandas as pd
import torch
from gluonts.dataset.pandas import PandasDataset

from enfobench import AuthorInfo, ForecasterType, ModelInfo
from enfobench.evaluation.server import server_factory
from enfobench.evaluation.utils import create_forecast_index
from uni2ts.model.moirai import MoiraiForecast


# Check for GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
root_dir = Path(__file__).parent.parent


class SalesForceMoraiModel:
def __init__(self, model_name: str, num_samples: int):
self.model_name = model_name
self.num_samples = num_samples
self.size = model_name.split("-")[-1]

def info(self) -> ModelInfo:
return ModelInfo(
name=f'Salesforce.Moirai-1.0-R.{self.size.capitalize()}',
authors=[
AuthorInfo(name="Attila Balint", email="[email protected]"),
],
type=ForecasterType.quantile,
params={
"num_samples": self.num_samples,
},
)

def forecast(
self,
horizon: int,
history: pd.DataFrame,
past_covariates: pd.DataFrame | None = None,
future_covariates: pd.DataFrame | None = None,
metadata: dict | None = None,
level: list[int] | None = None,
**kwargs,
) -> pd.DataFrame:
# Fill missing values
history = history.fillna(history.y.mean())

# Convert into GluonTS dataset
ds = PandasDataset(dict(history))

model_dir = root_dir / "models" / self.model_name
if not model_dir.exists():
raise FileNotFoundError(
f"Model directory for {self.model_name} was not found at {model_dir}, make sure it is downloaded."
)
# Prepare pre-trained model
model = MoiraiForecast.load_from_checkpoint(
checkpoint_path=str(model_dir / 'model.ckpt'),
prediction_length=horizon,
context_length=len(history),
patch_size='auto',
num_samples=self.num_samples,
target_dim=1,
feat_dynamic_real_dim=ds.num_feat_dynamic_real,
past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
map_location=device,
)

# Make predictions
predictor = model.create_predictor(batch_size=32)
forecasts = next(predictor.predict(ds))
data = {"yhat": forecasts.mean} # TODO: extend to quantiles

# Postprocess forecast
index = create_forecast_index(history=history, horizon=horizon)
forecast = pd.DataFrame(index=index, data=data)
return forecast


model_name = os.getenv("ENFOBENCH_MODEL_NAME", "small")
num_samples = int(os.getenv("ENFOBENCH_NUM_SAMPLES", 1))

# Instantiate your model
model = SalesForceMoraiModel(model_name=model_name, num_samples=num_samples)

# Create a forecast server by passing in your model
app = server_factory(model)

0 comments on commit e00eb3b

Please sign in to comment.