Added context length to Amazon Chronos to speed up inference
attilabalint committed Jun 6, 2024
1 parent 8f25953 commit 1854524
Showing 2 changed files with 14 additions and 7 deletions.
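The change makes the Chronos context window configurable: instead of always feeding the full history to the pipeline, the model can truncate it to the most recent slice, which shortens inference time. Below is a minimal sketch of the intended truncation step, assuming an hourly series and that periods_in_duration maps a duration string such as "7D" to a number of index steps; the helper's exact semantics and the "7D" value are illustrative assumptions, not taken from this commit.

import pandas as pd
import torch

# Hypothetical hourly history standing in for history.y in the model code below.
history = pd.Series(
    range(1_000),
    index=pd.date_range("2024-01-01", periods=1_000, freq="h"),
    name="y",
)

# Stand-in for periods_in_duration(history.index, duration="7D"), i.e. 168 steps on hourly data.
ctx_steps = min(7 * 24, len(history))

# Only the most recent ctx_steps observations are handed to Chronos as context.
context = torch.tensor(history.to_numpy()[-ctx_steps:], dtype=torch.float32)
print(context.shape)  # torch.Size([168])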
2 changes: 1 addition & 1 deletion docker/base/amazon-chronos/Dockerfile
@@ -1,4 +1,4 @@
-FROM python:3.11.6-slim-bookworm
+FROM python:3.11-slim-bookworm

RUN apt-get update && apt-get install -y --no-install-recommends git

19 changes: 13 additions & 6 deletions models/amazon-chronos/src/main.py
@@ -7,28 +7,29 @@

from enfobench import AuthorInfo, ForecasterType, ModelInfo
from enfobench.evaluation.server import server_factory
-from enfobench.evaluation.utils import create_forecast_index
+from enfobench.evaluation.utils import create_forecast_index, periods_in_duration

# Check for GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
root_dir = Path(__file__).parent.parent


class AmazonChronosModel:

-    def __init__(self, model_name: str, num_samples: int):
+    def __init__(self, model_name: str, num_samples: int, ctx_length: str | None = None):
        self.model_name = model_name
        self.num_samples = num_samples
+        self.ctx_length = ctx_length

    def info(self) -> ModelInfo:
        return ModelInfo(
-            name=f'Amazon.{".".join(map(str.capitalize, self.model_name.split("-")))}',
+            name=f'Amazon.{".".join(map(str.capitalize, self.model_name.split("-")))}{".CTX" + self.ctx_length if self.ctx_length else ""}',
            authors=[
                AuthorInfo(name="Attila Balint", email="[email protected]"),
            ],
            type=ForecasterType.quantile,
            params={
                "num_samples": self.num_samples,
+                "ctx_length": self.ctx_length,
            },
        )

@@ -58,7 +59,12 @@ def forecast(

        # context must be either a 1D tensor, a list of 1D tensors,
        # or a left-padded 2D tensor with batch as the first dimension
-        context = torch.tensor(history.y)
+        if self.ctx_length is None:
+            context = torch.tensor(history.y)
+        else:
+            ctx_length = min(periods_in_duration(history.index, duration=self.ctx_length), len(history))
+            context = torch.tensor(history.y[-ctx_length:])

        prediction_length = horizon
        forecasts = pipeline.predict(
            context,
@@ -78,9 +84,10 @@

model_name = os.getenv("ENFOBENCH_MODEL_NAME")
num_samples = int(os.getenv("ENFOBENCH_NUM_SAMPLES"))
+ctx_length = os.getenv("ENFOBENCH_CTX_LENGTH")

# Instantiate your model
-model = AmazonChronosModel(model_name=model_name, num_samples=num_samples)
+model = AmazonChronosModel(model_name=model_name, num_samples=num_samples, ctx_length=ctx_length)

# Create a forecast server by passing in your model
app = server_factory(model)
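With this in place, a run can opt into a shorter context through the new environment variable, for example when launching the server locally (a hypothetical sketch: the variable names come from the diff above, while the model name, sample count, and "7D" duration are example values only).

import os

# Set before main.py runs; when running the container these would be passed with docker run -e instead.
os.environ["ENFOBENCH_MODEL_NAME"] = "chronos-t5-small"  # example value
os.environ["ENFOBENCH_NUM_SAMPLES"] = "20"               # example value
os.environ["ENFOBENCH_CTX_LENGTH"] = "7D"                # new optional knob; leaving it unset keeps the full history

# With these values, info() would report the model as something like
# "Amazon.Chronos.T5.Small.CTX7D", and forecast() would pass Chronos only the
# last seven days of history as context.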
