Skip to content

Commit

Permalink
added context length to moirai to speed up inference
Browse files Browse the repository at this point in the history
  • Loading branch information
attilabalint committed Jun 5, 2024
1 parent e65e739 commit 8f25953
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 16 deletions.
31 changes: 22 additions & 9 deletions models/nixtla-timegpt/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,28 @@ def forecast(
history = history.fillna(history.y.mean())

# Rate limit forecast requests
self.limiter.try_acquire(name=f"{history.index[-1]}")
# Make request
timegpt_fcst_df = self.client.forecast(
df=history,
h=horizon,
level=level,
model=self.model,
target_col="y",
)

# Retry a function call max 5 times
max_retries = 5
n_tries = 0
while n_tries < max_retries:
try:
self.limiter.try_acquire(name=f"{history.index[-1]}")
timegpt_fcst_df = self.client.forecast(
df=history,
h=horizon,
level=level,
model=self.model,
target_col="y",
)
break
except Exception as e:
logger.exception(e)
n_tries += 1

if n_tries == max_retries:
msg = f"Could not make forecast after {max_retries} retries."
raise ValueError(msg)

# post-process forecast
forecast = timegpt_fcst_df.rename(columns={"TimeGPT": 'yhat'})
Expand Down
23 changes: 16 additions & 7 deletions models/salesforce-moirai/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from enfobench import AuthorInfo, ForecasterType, ModelInfo
from enfobench.evaluation.server import server_factory
from enfobench.evaluation.utils import create_forecast_index
from enfobench.evaluation.utils import create_forecast_index, periods_in_duration
from uni2ts.model.moirai import MoiraiForecast


Expand All @@ -17,20 +17,22 @@


class SalesForceMoraiModel:
def __init__(self, model_name: str, num_samples: int):
def __init__(self, model_name: str, num_samples: int, ctx_length: str | None = None):
self.model_name = model_name
self.num_samples = num_samples
self.ctx_length = ctx_length
self.size = model_name.split("-")[-1]

def info(self) -> ModelInfo:
return ModelInfo(
name=f'Salesforce.Moirai-1.0-R.{self.size.capitalize()}',
name=f'Salesforce.Moirai-1.0-R.{self.size.capitalize()}{f".CTX{self.ctx_length}" if self.ctx_length else ""}',
authors=[
AuthorInfo(name="Attila Balint", email="[email protected]"),
],
type=ForecasterType.quantile,
params={
"num_samples": self.num_samples,
"ctx_length": self.ctx_length,
},
)

Expand All @@ -55,16 +57,22 @@ def forecast(
raise FileNotFoundError(
f"Model directory for {self.model_name} was not found at {model_dir}, make sure it is downloaded."
)

if self.ctx_length is None:
ctx_length = len(history)
else:
ctx_length = min(periods_in_duration(history.index, duration=self.ctx_length), len(history))

# Prepare pre-trained model
model = MoiraiForecast.load_from_checkpoint(
checkpoint_path=str(model_dir / 'model.ckpt'),
prediction_length=horizon,
context_length=len(history),
context_length=ctx_length,
patch_size='auto',
num_samples=self.num_samples,
target_dim=1,
feat_dynamic_real_dim=ds.num_feat_dynamic_real,
past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
feat_dynamic_real_dim=0,
past_feat_dynamic_real_dim=0,
map_location=device,
)

Expand All @@ -81,9 +89,10 @@ def forecast(

model_name = os.getenv("ENFOBENCH_MODEL_NAME", "small")
num_samples = int(os.getenv("ENFOBENCH_NUM_SAMPLES", 1))
ctx_length = os.getenv("ENFOBENCH_CTX_LENGTH")

# Instantiate your model
model = SalesForceMoraiModel(model_name=model_name, num_samples=num_samples)
model = SalesForceMoraiModel(model_name=model_name, num_samples=num_samples, ctx_length=ctx_length)

# Create a forecast server by passing in your model
app = server_factory(model)

0 comments on commit 8f25953

Please sign in to comment.