From 52ec07d6f8d779324790377a95d07253e0a6827d Mon Sep 17 00:00:00 2001 From: David Kleindienst Date: Tue, 16 Jul 2024 08:37:18 +0200 Subject: [PATCH 1/7] Pass kwargs in FutureCovariatesLocalForecastingModels where meaningful (cherry picked from commit e80fe3fae4617033a6a4cde77afcd40c3072db33) --- darts/logging.py | 8 ++- darts/models/forecasting/auto_arima.py | 11 ++- darts/models/forecasting/croston.py | 4 +- darts/models/forecasting/forecasting_model.py | 12 +++- darts/models/forecasting/prophet_model.py | 67 ++++++++++++------- darts/models/forecasting/sf_auto_arima.py | 4 +- darts/models/forecasting/sf_auto_ets.py | 4 +- 7 files changed, 74 insertions(+), 36 deletions(-) diff --git a/darts/logging.py b/darts/logging.py index d52ea7e83c..a459f60170 100644 --- a/darts/logging.py +++ b/darts/logging.py @@ -192,7 +192,9 @@ def __exit__(self, *_): os.close(fd) -def execute_and_suppress_output(function, logger, suppression_threshold_level, *args): +def execute_and_suppress_output( + function, logger, suppression_threshold_level, *args, **kwargs +): """ This function conditionally executes the given function with the given arguments based on whether the current level of 'logger' is below, above or equal to @@ -207,9 +209,9 @@ def execute_and_suppress_output(function, logger, suppression_threshold_level, * """ if logger.level >= suppression_threshold_level: with SuppressStdoutStderr(): - return_value = function(*args) + return_value = function(*args, **kwargs) else: - return_value = function(*args) + return_value = function(*args, **kwargs) return return_value diff --git a/darts/models/forecasting/auto_arima.py b/darts/models/forecasting/auto_arima.py index a4f30e01d4..315a38793e 100644 --- a/darts/models/forecasting/auto_arima.py +++ b/darts/models/forecasting/auto_arima.py @@ -95,12 +95,19 @@ def encode_year(idx): def supports_multivariate(self) -> bool: return False - def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None): + def _fit( + self, + series: TimeSeries, + future_covariates: Optional[TimeSeries] = None, + **fit_kwargs + ): super()._fit(series, future_covariates) self._assert_univariate(series) series = self.training_series self.model.fit( - series.values(), X=future_covariates.values() if future_covariates else None + series.values(), + X=future_covariates.values() if future_covariates else None, + **fit_kwargs ) return self diff --git a/darts/models/forecasting/croston.py b/darts/models/forecasting/croston.py index 0a5f239728..df2e1758fe 100644 --- a/darts/models/forecasting/croston.py +++ b/darts/models/forecasting/croston.py @@ -123,7 +123,9 @@ def encode_year(idx): def supports_multivariate(self) -> bool: return False - def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None): + def _fit( + self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None, **_ + ): super()._fit(series, future_covariates) self._assert_univariate(series) series = self.training_series diff --git a/darts/models/forecasting/forecasting_model.py b/darts/models/forecasting/forecasting_model.py index 5a5dc7a738..321f68f51d 100644 --- a/darts/models/forecasting/forecasting_model.py +++ b/darts/models/forecasting/forecasting_model.py @@ -2902,7 +2902,12 @@ class FutureCovariatesLocalForecastingModel(LocalForecastingModel, ABC): All implementations must implement the :func:`_fit()` and :func:`_predict()` methods. """ - def fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None): + def fit( + self, + series: TimeSeries, + future_covariates: Optional[TimeSeries] = None, + **fit_kwargs, + ): """Fit/train the model on the (single) provided series. Optionally, a future covariates series can be provided as well. @@ -2915,6 +2920,9 @@ def fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None A time series of future-known covariates. This time series will not be forecasted, but can be used by some models as an input. It must contain at least the same time steps/indices as the target `series`. If it is longer than necessary, it will be automatically trimmed. + fit_kwargs + Optional keyword arguments that will be passed to the fit function of the underlying model if supported + by the underlying model. Returns ------- @@ -2946,7 +2954,7 @@ def fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None super().fit(series) - return self._fit(series, future_covariates=future_covariates) + return self._fit(series, future_covariates=future_covariates, **fit_kwargs) @abstractmethod def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None): diff --git a/darts/models/forecasting/prophet_model.py b/darts/models/forecasting/prophet_model.py index 77790c2843..e3ff2a8a5f 100644 --- a/darts/models/forecasting/prophet_model.py +++ b/darts/models/forecasting/prophet_model.py @@ -203,7 +203,12 @@ def encode_year(idx): # Use 0 as default value self._floor = 0 - def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None): + def _fit( + self, + series: TimeSeries, + future_covariates: Optional[TimeSeries] = None, + **fit_kwargs, + ): super()._fit(series, future_covariates) self._assert_univariate(series) series = self.training_series @@ -249,10 +254,10 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non if self.suppress_stdout_stderr: self._execute_and_suppress_output( - self.model.fit, logger, logging.WARNING, fit_df + self.model.fit, logger, logging.WARNING, fit_df, **fit_kwargs ) else: - self.model.fit(fit_df) + self.model.fit(fit_df, **fit_kwargs) return self @@ -348,11 +353,13 @@ def _check_seasonality_conditions( condition_name = attributes["condition_name"] if condition_name is not None: if condition_name not in future_covariates_columns: - invalid_conditional_seasonalities.append(( - seasonality_name, - condition_name, - "column missing", - )) + invalid_conditional_seasonalities.append( + ( + seasonality_name, + condition_name, + "column missing", + ) + ) continue if ( not future_covariates[condition_name] @@ -360,11 +367,13 @@ def _check_seasonality_conditions( .isin([True, False]) .all() ): - invalid_conditional_seasonalities.append(( - seasonality_name, - condition_name, - "invalid values", - )) + invalid_conditional_seasonalities.append( + ( + seasonality_name, + condition_name, + "invalid values", + ) + ) continue conditional_seasonality_covariates.append(condition_name) @@ -597,19 +606,23 @@ def _freq_to_days(freq: str) -> float: seconds_per_day = 86400 days = 0 - if freq in ["A", "BA", "Y", "BY", "RE"] or freq.startswith(( - "A", - "BA", - "Y", - "BY", - "RE", - )): # year + if freq in ["A", "BA", "Y", "BY", "RE"] or freq.startswith( + ( + "A", + "BA", + "Y", + "BY", + "RE", + ) + ): # year days = 365.25 - elif freq in ["Q", "BQ", "REQ"] or freq.startswith(( - "Q", - "BQ", - "REQ", - )): # quarter + elif freq in ["Q", "BQ", "REQ"] or freq.startswith( + ( + "Q", + "BQ", + "REQ", + ) + ): # quarter days = 3 * 30.4375 elif freq in [ "M", @@ -618,7 +631,9 @@ def _freq_to_days(freq: str) -> float: "SM", "LWOM", "WOM", - ] or freq.startswith(("M", "BME", "BS", "CBM", "SM", "LWOM", "WOM")): # month + ] or freq.startswith( + ("M", "BME", "BS", "CBM", "SM", "LWOM", "WOM") + ): # month days = 30.4375 elif freq == "W" or freq.startswith("W-"): # week days = 7.0 diff --git a/darts/models/forecasting/sf_auto_arima.py b/darts/models/forecasting/sf_auto_arima.py index cd8569aede..b376bd2739 100644 --- a/darts/models/forecasting/sf_auto_arima.py +++ b/darts/models/forecasting/sf_auto_arima.py @@ -89,7 +89,9 @@ def encode_year(idx): super().__init__(add_encoders=add_encoders) self.model = SFAutoARIMA(*autoarima_args, **autoarima_kwargs) - def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None): + def _fit( + self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None, **_ + ): super()._fit(series, future_covariates) self._assert_univariate(series) series = self.training_series diff --git a/darts/models/forecasting/sf_auto_ets.py b/darts/models/forecasting/sf_auto_ets.py index 95572c42fe..5a0c9a3302 100644 --- a/darts/models/forecasting/sf_auto_ets.py +++ b/darts/models/forecasting/sf_auto_ets.py @@ -95,7 +95,9 @@ def encode_year(idx): self.model = SFAutoETS(*autoets_args, **autoets_kwargs) self._linreg = None - def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None): + def _fit( + self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None, **_ + ): super()._fit(series, future_covariates) self._assert_univariate(series) series = self.training_series From ac6283d0cd9472510fdb3fce7db72bcc7d245042 Mon Sep 17 00:00:00 2001 From: David Kleindienst Date: Tue, 16 Jul 2024 09:13:52 +0200 Subject: [PATCH 2/7] Adapt ExponentialSmoothing __init__ and fit signatures (cherry picked from commit 0a7b9fe8a1dc78fb9ef14a87cbd5e152e109eb2e) --- .../forecasting/exponential_smoothing.py | 34 ++++++++++++------- .../forecasting/test_exponential_smoothing.py | 9 ++--- .../test_local_forecasting_models.py | 14 ++++---- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/darts/models/forecasting/exponential_smoothing.py b/darts/models/forecasting/exponential_smoothing.py index 9f0e0495e7..ff39e8dcc6 100644 --- a/darts/models/forecasting/exponential_smoothing.py +++ b/darts/models/forecasting/exponential_smoothing.py @@ -3,7 +3,7 @@ --------------------- """ -from typing import Any, Dict, Optional +from typing import Optional import numpy as np import statsmodels.tsa.holtwinters as hw @@ -24,8 +24,7 @@ def __init__( seasonal: Optional[SeasonalityMode] = SeasonalityMode.ADDITIVE, seasonal_periods: Optional[int] = None, random_state: int = 0, - kwargs: Optional[Dict[str, Any]] = None, - **fit_kwargs, + **kwargs, ): """Exponential Smoothing @@ -66,11 +65,6 @@ def __init__( :func:`statsmodels.tsa.holtwinters.ExponentialSmoothing()`. See `the documentation `_. - fit_kwargs - Some optional keyword arguments that will be used to call - :func:`statsmodels.tsa.holtwinters.ExponentialSmoothing.fit()`. - See `the documentation - `_. Examples -------- @@ -96,12 +90,28 @@ def __init__( self.seasonal = seasonal self.infer_seasonal_periods = seasonal_periods is None self.seasonal_periods = seasonal_periods - self.constructor_kwargs = dict() if kwargs is None else kwargs - self.fit_kwargs = fit_kwargs + self.constructor_kwargs = kwargs self.model = None np.random.seed(random_state) - def fit(self, series: TimeSeries): + def fit(self, series: TimeSeries, **fit_kwargs): + """Fit/train the model on the (single) provided series. + + Parameters + ---------- + series + The model will be trained to forecast this time series. + fit_kwargs + Some optional keyword arguments that will be used to call + :func:`statsmodels.tsa.holtwinters.ExponentialSmoothing.fit()`. + See `the documentation + `_. + + Returns + ------- + self + Fitted model. + """ super().fit(series) self._assert_univariate(series) series = self.training_series @@ -128,7 +138,7 @@ def fit(self, series: TimeSeries): dates=series.time_index if series.has_datetime_index else None, **self.constructor_kwargs, ) - hw_results = hw_model.fit(**self.fit_kwargs) + hw_results = hw_model.fit(**fit_kwargs) self.model = hw_results if self.infer_seasonal_periods: diff --git a/darts/tests/models/forecasting/test_exponential_smoothing.py b/darts/tests/models/forecasting/test_exponential_smoothing.py index 45903fa548..d7370ef6ba 100644 --- a/darts/tests/models/forecasting/test_exponential_smoothing.py +++ b/darts/tests/models/forecasting/test_exponential_smoothing.py @@ -53,7 +53,7 @@ def test_constructor_kwargs(self): "initial_trend": 0.2, "initial_seasonal": np.arange(1, 25), } - model = ExponentialSmoothing(kwargs=constructor_kwargs) + model = ExponentialSmoothing(**constructor_kwargs) model.fit(self.series) # must be checked separately, name is not consistent np.testing.assert_array_almost_equal( @@ -70,12 +70,10 @@ def test_fit_kwargs(self): # using default optimization method model = ExponentialSmoothing() model.fit(self.series) - assert model.fit_kwargs == {} pred = model.predict(n=2) model_bis = ExponentialSmoothing() model_bis.fit(self.series) - assert model_bis.fit_kwargs == {} pred_bis = model_bis.predict(n=2) # two methods with the same parameters should yield the same forecasts @@ -83,9 +81,8 @@ def test_fit_kwargs(self): np.testing.assert_array_almost_equal(pred.values(), pred_bis.values()) # change optimization method - model_ls = ExponentialSmoothing(method="least_squares") - model_ls.fit(self.series) - assert model_ls.fit_kwargs == {"method": "least_squares"} + model_ls = ExponentialSmoothing() + model_ls.fit(self.series, method="least_squares") pred_ls = model_ls.predict(n=2) # forecasts should be slightly different diff --git a/darts/tests/models/forecasting/test_local_forecasting_models.py b/darts/tests/models/forecasting/test_local_forecasting_models.py index b9d0bf5084..5b2315ed7d 100644 --- a/darts/tests/models/forecasting/test_local_forecasting_models.py +++ b/darts/tests/models/forecasting/test_local_forecasting_models.py @@ -164,11 +164,13 @@ def test_save_load_model(self, tmpdir_module, model): assert os.path.exists(p) assert ( - len([ - p - for p in os.listdir(tmpdir_module) - if p.startswith(type(model).__name__) - ]) + len( + [ + p + for p in os.listdir(tmpdir_module) + if p.startswith(type(model).__name__) + ] + ) == len(full_model_paths) + 1 ) @@ -647,7 +649,7 @@ def test_model_str_call(self, config): ( ExponentialSmoothing(), "ExponentialSmoothing(trend=ModelMode.ADDITIVE, damped=False, seasonal=SeasonalityMode.ADDITIVE, " - + "seasonal_periods=None, random_state=0, kwargs=None)", + + "seasonal_periods=None, random_state=0)", ), # no params changed ( ARIMA(1, 1, 1), From d148ecf235590347e15e9c10ef64485828e06137 Mon Sep 17 00:00:00 2001 From: David Kleindienst Date: Tue, 16 Jul 2024 09:39:17 +0200 Subject: [PATCH 3/7] Update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 10eb276588..9a7c130041 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Improved** - Added `IQRDetector`, that allows to detect anomalies using the interquartile range algorithm. [#2441] by [Igor Urbanik](https://github.com/u8-igor). - Made README's forecasting model support table more colorblind-friendly. [#2433](https://github.com/unit8co/darts/pull/2433) +- Allowed passing of kwargs to the `fit` functions of `Prophet` and `AutoARIMA` +- 🔴 Restructured the signatures of `ExponentialSmoothing` `__init__` and `fit` functions so that the passing of additional parameters is consistent with other models + - Keyword arguments to be passed to the underlying model's constructor must now be passed as keyword arguments instead of a `dict` to the `ExponentialSmoothing` constructor + - Keyword arguments to be passed to the underlying model's `fit` function must now be passed to the `ExponentialSmoothing.fit` function instead of the constructor **Fixed** - Fixed a bug when using `historical_forecasts()` with a pre-trained `RegressionModel` that has no target lags `lags=None` but uses static covariates. [#2426](https://github.com/unit8co/darts/pull/2426) by [Dennis Bader](https://github.com/dennisbader). From 939498aa84e4c896aaa3fbfbd34ba3f9b0c85020 Mon Sep 17 00:00:00 2001 From: David Kleindienst Date: Tue, 16 Jul 2024 10:30:09 +0200 Subject: [PATCH 4/7] Linting --- CHANGELOG.md | 2 +- darts/models/forecasting/auto_arima.py | 4 +- darts/models/forecasting/prophet_model.py | 56 ++++++++----------- .../test_local_forecasting_models.py | 12 ++-- 4 files changed, 31 insertions(+), 43 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a7c130041..cc52c495af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - Made README's forecasting model support table more colorblind-friendly. [#2433](https://github.com/unit8co/darts/pull/2433) - Allowed passing of kwargs to the `fit` functions of `Prophet` and `AutoARIMA` - 🔴 Restructured the signatures of `ExponentialSmoothing` `__init__` and `fit` functions so that the passing of additional parameters is consistent with other models - - Keyword arguments to be passed to the underlying model's constructor must now be passed as keyword arguments instead of a `dict` to the `ExponentialSmoothing` constructor + - Keyword arguments to be passed to the underlying model's constructor must now be passed as keyword arguments instead of a `dict` to the `ExponentialSmoothing` constructor - Keyword arguments to be passed to the underlying model's `fit` function must now be passed to the `ExponentialSmoothing.fit` function instead of the constructor **Fixed** diff --git a/darts/models/forecasting/auto_arima.py b/darts/models/forecasting/auto_arima.py index 315a38793e..1f06c2c110 100644 --- a/darts/models/forecasting/auto_arima.py +++ b/darts/models/forecasting/auto_arima.py @@ -99,7 +99,7 @@ def _fit( self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None, - **fit_kwargs + **fit_kwargs, ): super()._fit(series, future_covariates) self._assert_univariate(series) @@ -107,7 +107,7 @@ def _fit( self.model.fit( series.values(), X=future_covariates.values() if future_covariates else None, - **fit_kwargs + **fit_kwargs, ) return self diff --git a/darts/models/forecasting/prophet_model.py b/darts/models/forecasting/prophet_model.py index e3ff2a8a5f..dc0db0aecd 100644 --- a/darts/models/forecasting/prophet_model.py +++ b/darts/models/forecasting/prophet_model.py @@ -353,13 +353,11 @@ def _check_seasonality_conditions( condition_name = attributes["condition_name"] if condition_name is not None: if condition_name not in future_covariates_columns: - invalid_conditional_seasonalities.append( - ( - seasonality_name, - condition_name, - "column missing", - ) - ) + invalid_conditional_seasonalities.append(( + seasonality_name, + condition_name, + "column missing", + )) continue if ( not future_covariates[condition_name] @@ -367,13 +365,11 @@ def _check_seasonality_conditions( .isin([True, False]) .all() ): - invalid_conditional_seasonalities.append( - ( - seasonality_name, - condition_name, - "invalid values", - ) - ) + invalid_conditional_seasonalities.append(( + seasonality_name, + condition_name, + "invalid values", + )) continue conditional_seasonality_covariates.append(condition_name) @@ -606,23 +602,19 @@ def _freq_to_days(freq: str) -> float: seconds_per_day = 86400 days = 0 - if freq in ["A", "BA", "Y", "BY", "RE"] or freq.startswith( - ( - "A", - "BA", - "Y", - "BY", - "RE", - ) - ): # year + if freq in ["A", "BA", "Y", "BY", "RE"] or freq.startswith(( + "A", + "BA", + "Y", + "BY", + "RE", + )): # year days = 365.25 - elif freq in ["Q", "BQ", "REQ"] or freq.startswith( - ( - "Q", - "BQ", - "REQ", - ) - ): # quarter + elif freq in ["Q", "BQ", "REQ"] or freq.startswith(( + "Q", + "BQ", + "REQ", + )): # quarter days = 3 * 30.4375 elif freq in [ "M", @@ -631,9 +623,7 @@ def _freq_to_days(freq: str) -> float: "SM", "LWOM", "WOM", - ] or freq.startswith( - ("M", "BME", "BS", "CBM", "SM", "LWOM", "WOM") - ): # month + ] or freq.startswith(("M", "BME", "BS", "CBM", "SM", "LWOM", "WOM")): # month days = 30.4375 elif freq == "W" or freq.startswith("W-"): # week days = 7.0 diff --git a/darts/tests/models/forecasting/test_local_forecasting_models.py b/darts/tests/models/forecasting/test_local_forecasting_models.py index 5b2315ed7d..58109691f6 100644 --- a/darts/tests/models/forecasting/test_local_forecasting_models.py +++ b/darts/tests/models/forecasting/test_local_forecasting_models.py @@ -164,13 +164,11 @@ def test_save_load_model(self, tmpdir_module, model): assert os.path.exists(p) assert ( - len( - [ - p - for p in os.listdir(tmpdir_module) - if p.startswith(type(model).__name__) - ] - ) + len([ + p + for p in os.listdir(tmpdir_module) + if p.startswith(type(model).__name__) + ]) == len(full_model_paths) + 1 ) From 697cf7aab754d7c1fc528006947f75434f0d32a8 Mon Sep 17 00:00:00 2001 From: David Kleindienst Date: Fri, 19 Jul 2024 14:15:31 +0200 Subject: [PATCH 5/7] Rename to kwargs and pass to underlying model everywhere --- darts/models/forecasting/auto_arima.py | 4 ++-- darts/models/forecasting/croston.py | 6 +++++- darts/models/forecasting/exponential_smoothing.py | 6 +++--- darts/models/forecasting/forecasting_model.py | 6 +++--- darts/models/forecasting/prophet_model.py | 6 +++--- darts/models/forecasting/sf_auto_arima.py | 6 +++++- darts/models/forecasting/sf_auto_ets.py | 9 +++++---- 7 files changed, 26 insertions(+), 17 deletions(-) diff --git a/darts/models/forecasting/auto_arima.py b/darts/models/forecasting/auto_arima.py index 1f06c2c110..911c309914 100644 --- a/darts/models/forecasting/auto_arima.py +++ b/darts/models/forecasting/auto_arima.py @@ -99,7 +99,7 @@ def _fit( self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None, - **fit_kwargs, + **kwargs, ): super()._fit(series, future_covariates) self._assert_univariate(series) @@ -107,7 +107,7 @@ def _fit( self.model.fit( series.values(), X=future_covariates.values() if future_covariates else None, - **fit_kwargs, + **kwargs, ) return self diff --git a/darts/models/forecasting/croston.py b/darts/models/forecasting/croston.py index df2e1758fe..bddfc95923 100644 --- a/darts/models/forecasting/croston.py +++ b/darts/models/forecasting/croston.py @@ -124,7 +124,10 @@ def supports_multivariate(self) -> bool: return False def _fit( - self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None, **_ + self, + series: TimeSeries, + future_covariates: Optional[TimeSeries] = None, + **kwargs, ): super()._fit(series, future_covariates) self._assert_univariate(series) @@ -137,6 +140,7 @@ def _fit( if future_covariates is not None else None ), + **kwargs, ) return self diff --git a/darts/models/forecasting/exponential_smoothing.py b/darts/models/forecasting/exponential_smoothing.py index ff39e8dcc6..75887e30b5 100644 --- a/darts/models/forecasting/exponential_smoothing.py +++ b/darts/models/forecasting/exponential_smoothing.py @@ -94,14 +94,14 @@ def __init__( self.model = None np.random.seed(random_state) - def fit(self, series: TimeSeries, **fit_kwargs): + def fit(self, series: TimeSeries, **kwargs): """Fit/train the model on the (single) provided series. Parameters ---------- series The model will be trained to forecast this time series. - fit_kwargs + kwargs Some optional keyword arguments that will be used to call :func:`statsmodels.tsa.holtwinters.ExponentialSmoothing.fit()`. See `the documentation @@ -138,7 +138,7 @@ def fit(self, series: TimeSeries, **fit_kwargs): dates=series.time_index if series.has_datetime_index else None, **self.constructor_kwargs, ) - hw_results = hw_model.fit(**fit_kwargs) + hw_results = hw_model.fit(**kwargs) self.model = hw_results if self.infer_seasonal_periods: diff --git a/darts/models/forecasting/forecasting_model.py b/darts/models/forecasting/forecasting_model.py index 321f68f51d..4345084137 100644 --- a/darts/models/forecasting/forecasting_model.py +++ b/darts/models/forecasting/forecasting_model.py @@ -2906,7 +2906,7 @@ def fit( self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None, - **fit_kwargs, + **kwargs, ): """Fit/train the model on the (single) provided series. @@ -2920,7 +2920,7 @@ def fit( A time series of future-known covariates. This time series will not be forecasted, but can be used by some models as an input. It must contain at least the same time steps/indices as the target `series`. If it is longer than necessary, it will be automatically trimmed. - fit_kwargs + kwargs Optional keyword arguments that will be passed to the fit function of the underlying model if supported by the underlying model. @@ -2954,7 +2954,7 @@ def fit( super().fit(series) - return self._fit(series, future_covariates=future_covariates, **fit_kwargs) + return self._fit(series, future_covariates=future_covariates, **kwargs) @abstractmethod def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None): diff --git a/darts/models/forecasting/prophet_model.py b/darts/models/forecasting/prophet_model.py index dc0db0aecd..e50ca59c5b 100644 --- a/darts/models/forecasting/prophet_model.py +++ b/darts/models/forecasting/prophet_model.py @@ -207,7 +207,7 @@ def _fit( self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None, - **fit_kwargs, + **kwargs, ): super()._fit(series, future_covariates) self._assert_univariate(series) @@ -254,10 +254,10 @@ def _fit( if self.suppress_stdout_stderr: self._execute_and_suppress_output( - self.model.fit, logger, logging.WARNING, fit_df, **fit_kwargs + self.model.fit, logger, logging.WARNING, fit_df, **kwargs ) else: - self.model.fit(fit_df, **fit_kwargs) + self.model.fit(fit_df, **kwargs) return self diff --git a/darts/models/forecasting/sf_auto_arima.py b/darts/models/forecasting/sf_auto_arima.py index b376bd2739..3302a8941c 100644 --- a/darts/models/forecasting/sf_auto_arima.py +++ b/darts/models/forecasting/sf_auto_arima.py @@ -90,7 +90,10 @@ def encode_year(idx): self.model = SFAutoARIMA(*autoarima_args, **autoarima_kwargs) def _fit( - self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None, **_ + self, + series: TimeSeries, + future_covariates: Optional[TimeSeries] = None, + **kwargs, ): super()._fit(series, future_covariates) self._assert_univariate(series) @@ -98,6 +101,7 @@ def _fit( self.model.fit( series.values(copy=False).flatten(), X=future_covariates.values(copy=False) if future_covariates else None, + **kwargs, ) return self diff --git a/darts/models/forecasting/sf_auto_ets.py b/darts/models/forecasting/sf_auto_ets.py index 5a0c9a3302..e51aeb60fe 100644 --- a/darts/models/forecasting/sf_auto_ets.py +++ b/darts/models/forecasting/sf_auto_ets.py @@ -96,7 +96,10 @@ def encode_year(idx): self._linreg = None def _fit( - self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None, **_ + self, + series: TimeSeries, + future_covariates: Optional[TimeSeries] = None, + **kwargs, ): super()._fit(series, future_covariates) self._assert_univariate(series) @@ -118,9 +121,7 @@ def _fit( else: target = series - self.model.fit( - target.values(copy=False).flatten(), - ) + self.model.fit(target.values(copy=False).flatten(), **kwargs) return self def _predict( From 616a44e9b8808f3325919bafc2e6843905676c32 Mon Sep 17 00:00:00 2001 From: David Kleindienst Date: Fri, 19 Jul 2024 14:46:20 +0200 Subject: [PATCH 6/7] Correct docstring --- darts/models/forecasting/forecasting_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/darts/models/forecasting/forecasting_model.py b/darts/models/forecasting/forecasting_model.py index 4345084137..db8207ee8e 100644 --- a/darts/models/forecasting/forecasting_model.py +++ b/darts/models/forecasting/forecasting_model.py @@ -2921,8 +2921,7 @@ def fit( some models as an input. It must contain at least the same time steps/indices as the target `series`. If it is longer than necessary, it will be automatically trimmed. kwargs - Optional keyword arguments that will be passed to the fit function of the underlying model if supported - by the underlying model. + Optional keyword arguments that will be passed to the fit function of the underlying model. Returns ------- From bc3b4fd9a354d6efaab6bf551b805b38e6ec379b Mon Sep 17 00:00:00 2001 From: David Kleindienst Date: Fri, 19 Jul 2024 14:59:55 +0200 Subject: [PATCH 7/7] Add **kwargs to FutureCovariatesLocalForecastingModel._fit function signature --- darts/models/forecasting/forecasting_model.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/darts/models/forecasting/forecasting_model.py b/darts/models/forecasting/forecasting_model.py index db8207ee8e..45cc25066c 100644 --- a/darts/models/forecasting/forecasting_model.py +++ b/darts/models/forecasting/forecasting_model.py @@ -2956,7 +2956,12 @@ def fit( return self._fit(series, future_covariates=future_covariates, **kwargs) @abstractmethod - def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None): + def _fit( + self, + series: TimeSeries, + future_covariates: Optional[TimeSeries] = None, + **kwargs, + ): """Fits/trains the model on the provided series. DualCovariatesModels must implement the fit logic in this method. """