Skip to content

Commit

Permalink
amazon-chronos v0.1
Browse files Browse the repository at this point in the history
  • Loading branch information
attilabalint committed Apr 25, 2024
1 parent 8c169de commit 36b594f
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 5 deletions.
27 changes: 22 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# GLOBALS #
#################################################################################

PROJECT_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))
PROJECT_NAME = energy-forecast-benchmark-toolkit
PACKAGE_NAME = enfobench
PYTHON_INTERPRETER ?= python3
Expand All @@ -14,16 +13,16 @@ PYTHON_INTERPRETER ?= python3
## Create python virtual environment
venv/bin/python:
( \
$(PYTHON_INTERPRETER) -m venv $(PROJECT_DIR)/venv; \
source $(PROJECT_DIR)/venv/bin/activate; \
$(PYTHON_INTERPRETER) -m venv ./venv; \
source ./venv/bin/activate; \
pip install --upgrade pip; \
)

.PHONY: install
## Install project dependencies
install: venv/bin/python
(\
source $(PROJECT_DIR)/venv/bin/activate; \
source ./venv/bin/activate; \
pip install -e .; \
)

Expand Down Expand Up @@ -83,6 +82,16 @@ publish-test: build
hatch publish --repo test


#################################################################################
# CLONING RULES #
#################################################################################

.PHONY: download-amazon-chronos
download-amazon-chronos:
git clone https://huggingface.co/amazon/chronos-t5-tiny ./models/amazon-chronos/models/chronos-t5-tiny
git clone https://huggingface.co/amazon/chronos-t5-small ./models/amazon-chronos/models/chronos-t5-small


#################################################################################
# MODEL RULES #
#################################################################################
Expand All @@ -96,6 +105,7 @@ DEFAULT_PORT := 3000
DARTS_VERSION := 0.27.2
SKTIME_VERSION := 0.26.1
STATSFORECAST_VERSION := 1.5.0
CHRONOS_VERSION := 1.1.0


.PHONY: base-image-darts
Expand All @@ -116,10 +126,16 @@ base-image-statsforecast:
docker build --build-arg STATSFORECAST_VERSION=$(STATSFORECAST_VERSION) -t $(DOCKER_HUB_REPOSITORY):base-statsforecast-$(STATSFORECAST_VERSION) ./docker/base/statsforecast


.PHONY: base-image-amazon-chronos
base-image-amazon-chronos:
docker build --build-arg CHRONOS_VERSION=$(CHRONOS_VERSION) -t $(DOCKER_HUB_REPOSITORY):base-amazon-chronos-$(CHRONOS_VERSION) ./docker/base/amazon-chronos


## Build base images
base-images: base-image-darts \
base-image-sktime \
base-image-statsforecast
base-image-statsforecast \
base-image-amazon-chronos


.PHONY: push-base-images
Expand All @@ -128,6 +144,7 @@ push-base-images:
docker push $(DOCKER_HUB_REPOSITORY):base-u8darts-$(DARTS_VERSION)
docker push $(DOCKER_HUB_REPOSITORY):base-sktime-$(SKTIME_VERSION)
docker push $(DOCKER_HUB_REPOSITORY):base-statsforecast-$(STATSFORECAST_VERSION)
docker push $(DOCKER_HUB_REPOSITORY):base-amazon-chronos-$(CHRONOS_VERSION)


.PHONY: image
Expand Down
10 changes: 10 additions & 0 deletions docker/base/amazon-chronos/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
FROM python:3.11.6-slim-bookworm

RUN apt-get update && apt-get install -y --no-install-recommends git

WORKDIR /usr/local/app

ARG CHRONOS_VERSION

RUN pip install --no-cache-dir --upgrade pip
RUN pip install --no-cache-dir git+https://github.com/amazon-science/chronos-forecasting.git@v$CHRONOS_VERSION
15 changes: 15 additions & 0 deletions models/amazon-chronos/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
FROM attilabalint/enfobench-models:base-amazon-chronos-1.1.0

WORKDIR /usr/local/app

COPY ./requirements.txt /usr/local/app/requirements.txt
RUN pip install --no-cache-dir -r /usr/local/app/requirements.txt

COPY ./models /usr/local/app/models
COPY ./src /usr/local/app/src

ENV MODEL_NAME="chronos-t5-tiny"
ENV NUM_SAMPLES="20"

EXPOSE 3000
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "3000"]
2 changes: 2 additions & 0 deletions models/amazon-chronos/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Chronos repository
enfobench>=0.6.0,<0.7.0
Empty file.
86 changes: 86 additions & 0 deletions models/amazon-chronos/src/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import os
from pathlib import Path

import pandas as pd
import torch
from chronos import ChronosPipeline

from enfobench import AuthorInfo, ForecasterType, ModelInfo
from enfobench.evaluation.server import server_factory
from enfobench.evaluation.utils import create_forecast_index

# Check for GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
root_dir = Path(__file__).parent.parent


class AmazonChronosModel:

def __init__(self, model_name: str, num_samples: int):
self.model_name = model_name
self.num_samples = num_samples

def info(self) -> ModelInfo:
return ModelInfo(
name=f'Amazon.{".".join(map(str.capitalize, self.model_name.split("-")))}',
authors=[
AuthorInfo(name="Attila Balint", email="[email protected]"),
],
type=ForecasterType.quantile,
params={
"num_samples": self.num_samples,
},
)

def forecast(
self,
horizon: int,
history: pd.DataFrame,
past_covariates: pd.DataFrame | None = None,
future_covariates: pd.DataFrame | None = None,
metadata: dict | None = None,
level: list[int] | None = None,
**kwargs,
) -> pd.DataFrame:
# Fill missing values
history = history.fillna(history.y.mean())

model_dir = root_dir / "models" / self.model_name
if not model_dir.exists():
raise FileNotFoundError(
f"Model directory for {self.model_name} was not found at {model_dir}, make sure it is downloaded."
)
pipeline = ChronosPipeline.from_pretrained(
model_dir,
device_map=device,
torch_dtype=torch.bfloat16,
)

# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
context = torch.tensor(history.y)
prediction_length = horizon
forecasts = pipeline.predict(
context,
prediction_length,
num_samples=self.num_samples,
limit_prediction_length=False,
) # forecast shape: [num_series, num_samples, prediction_length]
data = {"yhat": forecasts.mean(dim=1)[0]}
# for lvl in level:
# data[f"q{lvl}"] = forecasts.quantile(lvl / 100, dim=1)[0] # TODO: extend to quantiles

# Postprocess forecast
index = create_forecast_index(history=history, horizon=horizon)
forecast = pd.DataFrame(index=index, data=data)
return forecast


model_name = os.getenv("MODEL_NAME")
num_samples = int(os.getenv("NUM_SAMPLES"))

# Instantiate your model
model = AmazonChronosModel(model_name=model_name, num_samples=num_samples)

# Create a forecast server by passing in your model
app = server_factory(model)

0 comments on commit 36b594f

Please sign in to comment.