Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Improvements to `ForecastingModel`:
- Added parameter `clean: bool` to `ForecastingModel.save()` to store a cleaned version of the model (removes training data from global models, and Lightning Trainer-related parameters from torch models). [#2649](https://github.com/unit8co/darts/pull/2649) by [Jonas Blanc](https://github.com/jonasblanc).
- Added parameter `pl_trainer_kwargs` to `TorchForecastingModel.load()` to setup a new Lightning Trainer used to configure the model for downstream tasks (e.g. prediction). [#2649](https://github.com/unit8co/darts/pull/2649) by [Jonas Blanc](https://github.com/jonasblanc).
- Fixed a bug in `LightGBMModel`, `XGBModel`, and `CatBoostModel` which raised an error when calling `fit()` with `val_sample_weight`. [#2626](https://github.com/unit8co/darts/pull/2626) by [Kylin Schmidt](https://github.com/kylinschmidt).
- Improved the documentation of how `WindowedAnomalyScorer` extract the training data from the input series. [#2674](https://github.com/unit8co/darts/pull/2674) by [Dennis Bader](https://github.com/dennisbader).

**Fixed**
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def _add_val_set_to_kwargs(
val_weights = val_weights or None
else:
val_sets = [(val_samples, val_labels)]
val_weights = val_weight
val_weights = [val_weight]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks neat!


val_set_name, val_weight_name = self.val_set_params
return dict(kwargs, **{val_set_name: val_sets, val_weight_name: val_weights})
Expand Down
28 changes: 28 additions & 0 deletions darts/tests/models/forecasting/test_regression_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1867,6 +1867,34 @@ def test_not_enough_covariates(self, config):
future_covariates=future_covariates[: -26 + req_future_offset],
)

@pytest.mark.parametrize(
"config",
product(
[(XGBModel, xgb_test_params)]
+ ([(LightGBMModel, lgbm_test_params)] if lgbm_available else [])
+ ([(CatBoostModel, cb_test_params)] if cb_available else []),
[True, False],
),
)
def test_val_set_weights_runnability_trees(self, config):
"""Tests using weights in val set for single and multi series."""
(model_cls, model_kwargs), single_series = config
model = model_cls(lags=10, **model_kwargs)

series = tg.sine_timeseries(length=20)
weights = tg.linear_timeseries(length=20)
if not single_series:
series = [series] * 2
weights = [weights] * 2

model.fit(
series=series,
val_series=series,
sample_weight=weights,
val_sample_weight=weights,
)
_ = model.predict(1, series=series)

@pytest.mark.parametrize(
"config",
product(
Expand Down