diff --git a/CHANGELOG.md b/CHANGELOG.md index d5021e30d0..7c90850d04 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - Added hyperparameters controlling the hidden layer sizes for the feature encoders in `TiDEModel`. [#2408](https://github.com/unit8co/darts/issues/2408) by [eschibli](https://github.com/eschibli). - Made README's forecasting model support table more colorblind-friendly. [#2433](https://github.com/unit8co/darts/pull/2433) - Updated the Ray Tune Hyperparameter Optimization example in the [user guide](https://unit8co.github.io/darts/userguide/hyperparameter_optimization.html) to work with the latest `ray` versions (`>=2.31.0`). [#2459](https://github.com/unit8co/darts/pull/2459) by [He Weilin](https://github.com/cnhwl). +- Allowed passing of kwargs to the `fit` functions of `Prophet` and `AutoARIMA` +- 🔴 Restructured the signatures of `ExponentialSmoothing` `__init__` and `fit` functions so that the passing of additional parameters is consistent with other models + - Keyword arguments to be passed to the underlying model's constructor must now be passed as keyword arguments instead of a `dict` to the `ExponentialSmoothing` constructor + - Keyword arguments to be passed to the underlying model's `fit` function must now be passed to the `ExponentialSmoothing.fit` function instead of the constructor **Fixed** diff --git a/darts/logging.py b/darts/logging.py index d52ea7e83c..a459f60170 100644 --- a/darts/logging.py +++ b/darts/logging.py @@ -192,7 +192,9 @@ def __exit__(self, *_): os.close(fd) -def execute_and_suppress_output(function, logger, suppression_threshold_level, *args): +def execute_and_suppress_output( + function, logger, suppression_threshold_level, *args, **kwargs +): """ This function conditionally executes the given function with the given arguments based on whether the current level of 'logger' is below, above or equal to @@ -207,9 +209,9 @@ def execute_and_suppress_output(function, logger, suppression_threshold_level, * """ if logger.level >= suppression_threshold_level: with SuppressStdoutStderr(): - return_value = function(*args) + return_value = function(*args, **kwargs) else: - return_value = function(*args) + return_value = function(*args, **kwargs) return return_value diff --git a/darts/models/forecasting/auto_arima.py b/darts/models/forecasting/auto_arima.py index a4f30e01d4..911c309914 100644 --- a/darts/models/forecasting/auto_arima.py +++ b/darts/models/forecasting/auto_arima.py @@ -95,12 +95,19 @@ def encode_year(idx): def supports_multivariate(self) -> bool: return False - def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None): + def _fit( + self, + series: TimeSeries, + future_covariates: Optional[TimeSeries] = None, + **kwargs, + ): super()._fit(series, future_covariates) self._assert_univariate(series) series = self.training_series self.model.fit( - series.values(), X=future_covariates.values() if future_covariates else None + series.values(), + X=future_covariates.values() if future_covariates else None, + **kwargs, ) return self diff --git a/darts/models/forecasting/croston.py b/darts/models/forecasting/croston.py index 0a5f239728..bddfc95923 100644 --- a/darts/models/forecasting/croston.py +++ b/darts/models/forecasting/croston.py @@ -123,7 +123,12 @@ def encode_year(idx): def supports_multivariate(self) -> bool: return False - def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None): + def _fit( + self, + series: TimeSeries, + future_covariates: Optional[TimeSeries] = None, + **kwargs, + ): super()._fit(series, future_covariates) self._assert_univariate(series) series = self.training_series @@ -135,6 +140,7 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non if future_covariates is not None else None ), + **kwargs, ) return self diff --git a/darts/models/forecasting/exponential_smoothing.py b/darts/models/forecasting/exponential_smoothing.py index 9f0e0495e7..75887e30b5 100644 --- a/darts/models/forecasting/exponential_smoothing.py +++ b/darts/models/forecasting/exponential_smoothing.py @@ -3,7 +3,7 @@ --------------------- """ -from typing import Any, Dict, Optional +from typing import Optional import numpy as np import statsmodels.tsa.holtwinters as hw @@ -24,8 +24,7 @@ def __init__( seasonal: Optional[SeasonalityMode] = SeasonalityMode.ADDITIVE, seasonal_periods: Optional[int] = None, random_state: int = 0, - kwargs: Optional[Dict[str, Any]] = None, - **fit_kwargs, + **kwargs, ): """Exponential Smoothing @@ -66,11 +65,6 @@ def __init__( :func:`statsmodels.tsa.holtwinters.ExponentialSmoothing()`. See `the documentation `_. - fit_kwargs - Some optional keyword arguments that will be used to call - :func:`statsmodels.tsa.holtwinters.ExponentialSmoothing.fit()`. - See `the documentation - `_. Examples -------- @@ -96,12 +90,28 @@ def __init__( self.seasonal = seasonal self.infer_seasonal_periods = seasonal_periods is None self.seasonal_periods = seasonal_periods - self.constructor_kwargs = dict() if kwargs is None else kwargs - self.fit_kwargs = fit_kwargs + self.constructor_kwargs = kwargs self.model = None np.random.seed(random_state) - def fit(self, series: TimeSeries): + def fit(self, series: TimeSeries, **kwargs): + """Fit/train the model on the (single) provided series. + + Parameters + ---------- + series + The model will be trained to forecast this time series. + kwargs + Some optional keyword arguments that will be used to call + :func:`statsmodels.tsa.holtwinters.ExponentialSmoothing.fit()`. + See `the documentation + `_. + + Returns + ------- + self + Fitted model. + """ super().fit(series) self._assert_univariate(series) series = self.training_series @@ -128,7 +138,7 @@ def fit(self, series: TimeSeries): dates=series.time_index if series.has_datetime_index else None, **self.constructor_kwargs, ) - hw_results = hw_model.fit(**self.fit_kwargs) + hw_results = hw_model.fit(**kwargs) self.model = hw_results if self.infer_seasonal_periods: diff --git a/darts/models/forecasting/forecasting_model.py b/darts/models/forecasting/forecasting_model.py index 5a5dc7a738..45cc25066c 100644 --- a/darts/models/forecasting/forecasting_model.py +++ b/darts/models/forecasting/forecasting_model.py @@ -2902,7 +2902,12 @@ class FutureCovariatesLocalForecastingModel(LocalForecastingModel, ABC): All implementations must implement the :func:`_fit()` and :func:`_predict()` methods. """ - def fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None): + def fit( + self, + series: TimeSeries, + future_covariates: Optional[TimeSeries] = None, + **kwargs, + ): """Fit/train the model on the (single) provided series. Optionally, a future covariates series can be provided as well. @@ -2915,6 +2920,8 @@ def fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None A time series of future-known covariates. This time series will not be forecasted, but can be used by some models as an input. It must contain at least the same time steps/indices as the target `series`. If it is longer than necessary, it will be automatically trimmed. + kwargs + Optional keyword arguments that will be passed to the fit function of the underlying model. Returns ------- @@ -2946,10 +2953,15 @@ def fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None super().fit(series) - return self._fit(series, future_covariates=future_covariates) + return self._fit(series, future_covariates=future_covariates, **kwargs) @abstractmethod - def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None): + def _fit( + self, + series: TimeSeries, + future_covariates: Optional[TimeSeries] = None, + **kwargs, + ): """Fits/trains the model on the provided series. DualCovariatesModels must implement the fit logic in this method. """ diff --git a/darts/models/forecasting/prophet_model.py b/darts/models/forecasting/prophet_model.py index 77790c2843..e50ca59c5b 100644 --- a/darts/models/forecasting/prophet_model.py +++ b/darts/models/forecasting/prophet_model.py @@ -203,7 +203,12 @@ def encode_year(idx): # Use 0 as default value self._floor = 0 - def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None): + def _fit( + self, + series: TimeSeries, + future_covariates: Optional[TimeSeries] = None, + **kwargs, + ): super()._fit(series, future_covariates) self._assert_univariate(series) series = self.training_series @@ -249,10 +254,10 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non if self.suppress_stdout_stderr: self._execute_and_suppress_output( - self.model.fit, logger, logging.WARNING, fit_df + self.model.fit, logger, logging.WARNING, fit_df, **kwargs ) else: - self.model.fit(fit_df) + self.model.fit(fit_df, **kwargs) return self diff --git a/darts/models/forecasting/sf_auto_arima.py b/darts/models/forecasting/sf_auto_arima.py index cd8569aede..3302a8941c 100644 --- a/darts/models/forecasting/sf_auto_arima.py +++ b/darts/models/forecasting/sf_auto_arima.py @@ -89,13 +89,19 @@ def encode_year(idx): super().__init__(add_encoders=add_encoders) self.model = SFAutoARIMA(*autoarima_args, **autoarima_kwargs) - def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None): + def _fit( + self, + series: TimeSeries, + future_covariates: Optional[TimeSeries] = None, + **kwargs, + ): super()._fit(series, future_covariates) self._assert_univariate(series) series = self.training_series self.model.fit( series.values(copy=False).flatten(), X=future_covariates.values(copy=False) if future_covariates else None, + **kwargs, ) return self diff --git a/darts/models/forecasting/sf_auto_ets.py b/darts/models/forecasting/sf_auto_ets.py index 95572c42fe..e51aeb60fe 100644 --- a/darts/models/forecasting/sf_auto_ets.py +++ b/darts/models/forecasting/sf_auto_ets.py @@ -95,7 +95,12 @@ def encode_year(idx): self.model = SFAutoETS(*autoets_args, **autoets_kwargs) self._linreg = None - def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None): + def _fit( + self, + series: TimeSeries, + future_covariates: Optional[TimeSeries] = None, + **kwargs, + ): super()._fit(series, future_covariates) self._assert_univariate(series) series = self.training_series @@ -116,9 +121,7 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non else: target = series - self.model.fit( - target.values(copy=False).flatten(), - ) + self.model.fit(target.values(copy=False).flatten(), **kwargs) return self def _predict( diff --git a/darts/tests/models/forecasting/test_exponential_smoothing.py b/darts/tests/models/forecasting/test_exponential_smoothing.py index 45903fa548..d7370ef6ba 100644 --- a/darts/tests/models/forecasting/test_exponential_smoothing.py +++ b/darts/tests/models/forecasting/test_exponential_smoothing.py @@ -53,7 +53,7 @@ def test_constructor_kwargs(self): "initial_trend": 0.2, "initial_seasonal": np.arange(1, 25), } - model = ExponentialSmoothing(kwargs=constructor_kwargs) + model = ExponentialSmoothing(**constructor_kwargs) model.fit(self.series) # must be checked separately, name is not consistent np.testing.assert_array_almost_equal( @@ -70,12 +70,10 @@ def test_fit_kwargs(self): # using default optimization method model = ExponentialSmoothing() model.fit(self.series) - assert model.fit_kwargs == {} pred = model.predict(n=2) model_bis = ExponentialSmoothing() model_bis.fit(self.series) - assert model_bis.fit_kwargs == {} pred_bis = model_bis.predict(n=2) # two methods with the same parameters should yield the same forecasts @@ -83,9 +81,8 @@ def test_fit_kwargs(self): np.testing.assert_array_almost_equal(pred.values(), pred_bis.values()) # change optimization method - model_ls = ExponentialSmoothing(method="least_squares") - model_ls.fit(self.series) - assert model_ls.fit_kwargs == {"method": "least_squares"} + model_ls = ExponentialSmoothing() + model_ls.fit(self.series, method="least_squares") pred_ls = model_ls.predict(n=2) # forecasts should be slightly different diff --git a/darts/tests/models/forecasting/test_local_forecasting_models.py b/darts/tests/models/forecasting/test_local_forecasting_models.py index b9d0bf5084..58109691f6 100644 --- a/darts/tests/models/forecasting/test_local_forecasting_models.py +++ b/darts/tests/models/forecasting/test_local_forecasting_models.py @@ -647,7 +647,7 @@ def test_model_str_call(self, config): ( ExponentialSmoothing(), "ExponentialSmoothing(trend=ModelMode.ADDITIVE, damped=False, seasonal=SeasonalityMode.ADDITIVE, " - + "seasonal_periods=None, random_state=0, kwargs=None)", + + "seasonal_periods=None, random_state=0)", ), # no params changed ( ARIMA(1, 1, 1),