From 1854524c224d5e1620c6fda966c37bebfc06a488 Mon Sep 17 00:00:00 2001
From: attilabalint
Date: Thu, 6 Jun 2024 19:31:13 +0200
Subject: [PATCH] added context length to amazon chronos to speed up inference

---
 docker/base/amazon-chronos/Dockerfile |  2 +-
 models/amazon-chronos/src/main.py     | 19 +++++++++++++------
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/docker/base/amazon-chronos/Dockerfile b/docker/base/amazon-chronos/Dockerfile
index 716fe69..b950873 100644
--- a/docker/base/amazon-chronos/Dockerfile
+++ b/docker/base/amazon-chronos/Dockerfile
@@ -1,4 +1,4 @@
-FROM python:3.11.6-slim-bookworm
+FROM python:3.11-slim-bookworm
 
 RUN apt-get update && apt-get install -y --no-install-recommends git
 
diff --git a/models/amazon-chronos/src/main.py b/models/amazon-chronos/src/main.py
index 5c259cd..4215334 100644
--- a/models/amazon-chronos/src/main.py
+++ b/models/amazon-chronos/src/main.py
@@ -7,7 +7,7 @@
 
 from enfobench import AuthorInfo, ForecasterType, ModelInfo
 from enfobench.evaluation.server import server_factory
-from enfobench.evaluation.utils import create_forecast_index
+from enfobench.evaluation.utils import create_forecast_index, periods_in_duration
 
 # Check for GPU availability
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -15,20 +15,21 @@
 
 
 class AmazonChronosModel:
-
-    def __init__(self, model_name: str, num_samples: int):
+    def __init__(self, model_name: str, num_samples: int, ctx_length: str | None = None):
         self.model_name = model_name
         self.num_samples = num_samples
+        self.ctx_length = ctx_length
 
     def info(self) -> ModelInfo:
         return ModelInfo(
-            name=f'Amazon.{".".join(map(str.capitalize, self.model_name.split("-")))}',
+            name=f'Amazon.{".".join(map(str.capitalize, self.model_name.split("-")))}{".CTX" + self.ctx_length if self.ctx_length else ""}',
             authors=[
                 AuthorInfo(name="Attila Balint", email="attila.balint@kuleuven.be"),
             ],
             type=ForecasterType.quantile,
             params={
                 "num_samples": self.num_samples,
+                "ctx_length": self.ctx_length,
             },
         )
 
@@ -58,7 +59,12 @@ def forecast(
 
         # context must be either a 1D tensor, a list of 1D tensors,
         # or a left-padded 2D tensor with batch as the first dimension
-        context = torch.tensor(history.y)
+        if self.ctx_length is None:
+            context = torch.tensor(history.y)
+        else:
+            ctx_length = min(periods_in_duration(history.index, duration=self.ctx_length), len(history))
+            context = torch.tensor(history.y[-ctx_length:])
+
         prediction_length = horizon
         forecasts = pipeline.predict(
             context,
@@ -78,9 +84,10 @@
 
 model_name = os.getenv("ENFOBENCH_MODEL_NAME")
 num_samples = int(os.getenv("ENFOBENCH_NUM_SAMPLES"))
+ctx_length = os.getenv("ENFOBENCH_CTX_LENGTH")
 
 # Instantiate your model
-model = AmazonChronosModel(model_name=model_name, num_samples=num_samples)
+model = AmazonChronosModel(model_name=model_name, num_samples=num_samples, ctx_length=ctx_length)
 
 # Create a forecast server by passing in your model
 app = server_factory(model)
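
Not part of the patch above: a minimal, self-contained sketch of what the new `ctx_length` option does. The history is truncated to the last `ctx_length` worth of observations before being handed to Chronos, which shortens the context the pipeline has to process. `periods_in_duration_sketch` is a hypothetical stand-in for enfobench's `periods_in_duration`, assuming it returns the number of index steps covered by the duration string.

```python
import pandas as pd
import torch


def periods_in_duration_sketch(index: pd.DatetimeIndex, duration: str) -> int:
    """Assumed behaviour of enfobench's periods_in_duration:
    how many index steps fit into the given duration string."""
    step = index[1] - index[0]
    return int(pd.Timedelta(duration) / step)


# One week of hourly history (168 points).
history = pd.DataFrame(
    {"y": range(168)},
    index=pd.date_range("2024-01-01", periods=168, freq="h"),
)

ctx_length = "2D"  # e.g. ENFOBENCH_CTX_LENGTH=2D (assumed value format)
n = min(periods_in_duration_sketch(history.index, ctx_length), len(history))
context = torch.tensor(history.y[-n:].to_numpy())

print(context.shape)  # torch.Size([48]) -> only the last two days are used as context
```

With the patch applied, the equivalent behaviour would be enabled by setting, for example, `ENFOBENCH_CTX_LENGTH=2D` on the container; leaving the variable unset keeps the original full-history behaviour.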