-
Notifications
You must be signed in to change notification settings - Fork 920
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix/val_sample_weight error for models inherited from RegressionModel #2626
Fix/val_sample_weight error for models inherited from RegressionModel #2626
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thank you for your contribution.
Can you please add unit-tests so that we can make sure the fix works as expected for all the situations?
CHANGELOG.md
Outdated
@@ -7,6 +7,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co | |||
|
|||
[Full Changelog](https://github.com/unit8co/darts/compare/0.31.0...master) | |||
|
|||
- Fix the bug in [#2579 ](https://github.com/unit8co/darts/issues/2579) that causes an error when `val_sample_weight` is set in the CatBoost and XGBoost models. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Fix the bug in [#2579 ](https://github.com/unit8co/darts/issues/2579) that causes an error when `val_sample_weight` is set in the CatBoost and XGBoost models. | |
- Fix a bug in `RegressionModel` when `val_sample_weight` is used with a single timeseries. [#2626](https://github.com/unit8co/darts/pull/2626) by [Kylin Schmidt](https://github.com/kylinschmidt). |
CHANGELOG.md
Outdated
@@ -1403,7 +1405,7 @@ ts: TimeSeries = AirPassengers().load() | |||
```python | |||
# Assuming a multivariate TimeSeries named series with 3 columns or variables. | |||
# To apply fn to columns with names '0' and '2': | |||
|
|||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you revert this change and the others below?
@@ -588,7 +588,7 @@ def _add_val_set_to_kwargs( | |||
val_weights = val_weights or None | |||
else: | |||
val_sets = [(val_samples, val_labels)] | |||
val_weights = val_weight | |||
val_weights = [val_weight] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks neat!
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #2626 +/- ##
==========================================
- Coverage 94.27% 94.19% -0.08%
==========================================
Files 141 141
Lines 15552 15552
==========================================
- Hits 14661 14649 -12
- Misses 891 903 +12 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @KylinSchmidt for this fix 🚀 I took the liberty to push some changes (unit tests, and the suggestions from @madtoinou ).
Everything looks great now, ready to merge 💯
Checklist before merging this PR:
Fix the bug in #2579 that causes an error when
val_sample_weight
is set in the CatBoost and XGBoost models.Summary
The CatBoost and XGBoost models, which inherit from the RegressionModel class, encounter an error when setting val_sample_weight. This issue is resolved by modifying the _add_val_set_to_kwargs method in RegressionModel, changing val_weights = val_weight to val_weights = [val_weight].
Other Information