diff --git a/Makefile b/Makefile index a515646..83ea0d0 100644 --- a/Makefile +++ b/Makefile @@ -109,6 +109,20 @@ download-amazon-chronos: models/amazon-chronos/models/chronos-t5-tiny \ models/amazon-chronos/models/chronos-t5-large + +models/salesforce-moirai/models/moirai-1.0-R-small: + git clone https://huggingface.co/salesforce/moirai-1.0-R-small ./models/salesforce-moirai/models/moirai-1.0-R-small + +models/salesforce-moirai/models/moirai-1.0-R-base: + git clone https://huggingface.co/salesforce/moirai-1.0-R-base ./models/salesforce-moirai/models/moirai-1.0-R-base + +models/salesforce-moirai/models/moirai-1.0-R-large: + git clone https://huggingface.co/salesforce/moirai-1.0-R-large ./models/salesforce-moirai/models/moirai-1.0-R-large + +download-salesforce-moirai: models/salesforce-moirai/models/moirai-1.0-R-small \ + models/salesforce-moirai/models/moirai-1.0-R-base \ + models/salesforce-moirai/models/moirai-1.0-R-large + ################################################################################# # MODEL RULES # ################################################################################# @@ -148,11 +162,17 @@ base-image-amazon-chronos: docker build --build-arg CHRONOS_VERSION=$(CHRONOS_VERSION) -t $(DOCKER_HUB_REPOSITORY):base-amazon-chronos-$(CHRONOS_VERSION) ./docker/base/amazon-chronos +.PHONY: base-image-salesforce-moirai +base-image-salesforce-moirai: + docker build -t $(DOCKER_HUB_REPOSITORY):base-salesforce-moirai ./docker/base/salesforce-moirai + + ## Build base images base-images: base-image-darts \ base-image-sktime \ base-image-statsforecast \ - base-image-amazon-chronos + base-image-amazon-chronos \ + base-image-salesforce-moirai .PHONY: push-base-images @@ -162,6 +182,7 @@ push-base-images: docker push $(DOCKER_HUB_REPOSITORY):base-sktime-$(SKTIME_VERSION) docker push $(DOCKER_HUB_REPOSITORY):base-statsforecast-$(STATSFORECAST_VERSION) docker push $(DOCKER_HUB_REPOSITORY):base-amazon-chronos-$(CHRONOS_VERSION) + docker push $(DOCKER_HUB_REPOSITORY):base-salesforce-moirai .PHONY: image diff --git a/docker/base/salesforce-moirai/Dockerfile b/docker/base/salesforce-moirai/Dockerfile new file mode 100644 index 0000000..4e51b12 --- /dev/null +++ b/docker/base/salesforce-moirai/Dockerfile @@ -0,0 +1,8 @@ +FROM python:3.11-slim-bookworm + +RUN apt-get update && apt-get install -y --no-install-recommends git + +WORKDIR /usr/local/app + +RUN pip install --no-cache-dir --upgrade pip +RUN pip install --no-cache-dir git+https://github.com/SalesforceAIResearch/uni2ts.git diff --git a/models/salesforce-moirai/Dockerfile b/models/salesforce-moirai/Dockerfile new file mode 100644 index 0000000..36700b3 --- /dev/null +++ b/models/salesforce-moirai/Dockerfile @@ -0,0 +1,15 @@ +FROM attilabalint/enfobench-models:base-salesforce-moirai + +WORKDIR /usr/local/app + +COPY ./requirements.txt /usr/local/app/requirements.txt +RUN pip install --no-cache-dir -r /usr/local/app/requirements.txt + +COPY ./models /usr/local/app/models +COPY ./src /usr/local/app/src + +ENV ENFOBENCH_MODEL_NAME="moirai-1.0-R-small" +ENV ENFOBENCH_NUM_SAMPLES="1" + +EXPOSE 3000 +CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "3000", "--workers", "1"] diff --git a/models/salesforce-moirai/requirements.txt b/models/salesforce-moirai/requirements.txt new file mode 100644 index 0000000..7ad0e5b --- /dev/null +++ b/models/salesforce-moirai/requirements.txt @@ -0,0 +1,2 @@ + # Chronos repository +enfobench>=0.6.0,<0.7.0 \ No newline at end of file diff --git a/models/salesforce-moirai/src/main.py b/models/salesforce-moirai/src/main.py new file mode 100644 index 0000000..7d73df9 --- /dev/null +++ b/models/salesforce-moirai/src/main.py @@ -0,0 +1,89 @@ +import os +from pathlib import Path + +import pandas as pd +import torch +from gluonts.dataset.pandas import PandasDataset + +from enfobench import AuthorInfo, ForecasterType, ModelInfo +from enfobench.evaluation.server import server_factory +from enfobench.evaluation.utils import create_forecast_index +from uni2ts.model.moirai import MoiraiForecast + + +# Check for GPU availability +device = "cuda" if torch.cuda.is_available() else "cpu" +root_dir = Path(__file__).parent.parent + + +class SalesForceMoraiModel: + def __init__(self, model_name: str, num_samples: int): + self.model_name = model_name + self.num_samples = num_samples + self.size = model_name.split("-")[-1] + + def info(self) -> ModelInfo: + return ModelInfo( + name=f'Salesforce.Moirai-1.0-R.{self.size.capitalize()}', + authors=[ + AuthorInfo(name="Attila Balint", email="attila.balint@kuleuven.be"), + ], + type=ForecasterType.quantile, + params={ + "num_samples": self.num_samples, + }, + ) + + def forecast( + self, + horizon: int, + history: pd.DataFrame, + past_covariates: pd.DataFrame | None = None, + future_covariates: pd.DataFrame | None = None, + metadata: dict | None = None, + level: list[int] | None = None, + **kwargs, + ) -> pd.DataFrame: + # Fill missing values + history = history.fillna(history.y.mean()) + + # Convert into GluonTS dataset + ds = PandasDataset(dict(history)) + + model_dir = root_dir / "models" / self.model_name + if not model_dir.exists(): + raise FileNotFoundError( + f"Model directory for {self.model_name} was not found at {model_dir}, make sure it is downloaded." + ) + # Prepare pre-trained model + model = MoiraiForecast.load_from_checkpoint( + checkpoint_path=str(model_dir / 'model.ckpt'), + prediction_length=horizon, + context_length=len(history), + patch_size='auto', + num_samples=self.num_samples, + target_dim=1, + feat_dynamic_real_dim=ds.num_feat_dynamic_real, + past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real, + map_location=device, + ) + + # Make predictions + predictor = model.create_predictor(batch_size=32) + forecasts = next(predictor.predict(ds)) + data = {"yhat": forecasts.mean} # TODO: extend to quantiles + + # Postprocess forecast + index = create_forecast_index(history=history, horizon=horizon) + forecast = pd.DataFrame(index=index, data=data) + return forecast + + +model_name = os.getenv("ENFOBENCH_MODEL_NAME", "small") +num_samples = int(os.getenv("ENFOBENCH_NUM_SAMPLES", 1)) + +# Instantiate your model +model = SalesForceMoraiModel(model_name=model_name, num_samples=num_samples) + +# Create a forecast server by passing in your model +app = server_factory(model)