diff --git a/models/nixtla-timegpt/src/main.py b/models/nixtla-timegpt/src/main.py index 905f995..f352414 100644 --- a/models/nixtla-timegpt/src/main.py +++ b/models/nixtla-timegpt/src/main.py @@ -44,15 +44,28 @@ def forecast( history = history.fillna(history.y.mean()) # Rate limit forecast requests - self.limiter.try_acquire(name=f"{history.index[-1]}") - # Make request - timegpt_fcst_df = self.client.forecast( - df=history, - h=horizon, - level=level, - model=self.model, - target_col="y", - ) + + # Retry a function call max 5 times + max_retries = 5 + n_tries = 0 + while n_tries < max_retries: + try: + self.limiter.try_acquire(name=f"{history.index[-1]}") + timegpt_fcst_df = self.client.forecast( + df=history, + h=horizon, + level=level, + model=self.model, + target_col="y", + ) + break + except Exception as e: + logger.exception(e) + n_tries += 1 + + if n_tries == max_retries: + msg = f"Could not make forecast after {max_retries} retries." + raise ValueError(msg) # post-process forecast forecast = timegpt_fcst_df.rename(columns={"TimeGPT": 'yhat'}) diff --git a/models/salesforce-moirai/src/main.py b/models/salesforce-moirai/src/main.py index 7d73df9..b9c5cc4 100644 --- a/models/salesforce-moirai/src/main.py +++ b/models/salesforce-moirai/src/main.py @@ -7,7 +7,7 @@ from enfobench import AuthorInfo, ForecasterType, ModelInfo from enfobench.evaluation.server import server_factory -from enfobench.evaluation.utils import create_forecast_index +from enfobench.evaluation.utils import create_forecast_index, periods_in_duration from uni2ts.model.moirai import MoiraiForecast @@ -17,20 +17,22 @@ class SalesForceMoraiModel: - def __init__(self, model_name: str, num_samples: int): + def __init__(self, model_name: str, num_samples: int, ctx_length: str | None = None): self.model_name = model_name self.num_samples = num_samples + self.ctx_length = ctx_length self.size = model_name.split("-")[-1] def info(self) -> ModelInfo: return ModelInfo( - name=f'Salesforce.Moirai-1.0-R.{self.size.capitalize()}', + name=f'Salesforce.Moirai-1.0-R.{self.size.capitalize()}{f".CTX{self.ctx_length}" if self.ctx_length else ""}', authors=[ AuthorInfo(name="Attila Balint", email="attila.balint@kuleuven.be"), ], type=ForecasterType.quantile, params={ "num_samples": self.num_samples, + "ctx_length": self.ctx_length, }, ) @@ -55,16 +57,22 @@ def forecast( raise FileNotFoundError( f"Model directory for {self.model_name} was not found at {model_dir}, make sure it is downloaded." ) + + if self.ctx_length is None: + ctx_length = len(history) + else: + ctx_length = min(periods_in_duration(history.index, duration=self.ctx_length), len(history)) + # Prepare pre-trained model model = MoiraiForecast.load_from_checkpoint( checkpoint_path=str(model_dir / 'model.ckpt'), prediction_length=horizon, - context_length=len(history), + context_length=ctx_length, patch_size='auto', num_samples=self.num_samples, target_dim=1, - feat_dynamic_real_dim=ds.num_feat_dynamic_real, - past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real, + feat_dynamic_real_dim=0, + past_feat_dynamic_real_dim=0, map_location=device, ) @@ -81,9 +89,10 @@ def forecast( model_name = os.getenv("ENFOBENCH_MODEL_NAME", "small") num_samples = int(os.getenv("ENFOBENCH_NUM_SAMPLES", 1)) +ctx_length = os.getenv("ENFOBENCH_CTX_LENGTH") # Instantiate your model -model = SalesForceMoraiModel(model_name=model_name, num_samples=num_samples) +model = SalesForceMoraiModel(model_name=model_name, num_samples=num_samples, ctx_length=ctx_length) # Create a forecast server by passing in your model app = server_factory(model)