Skip to content

Commit 86589b1

Browse files
updated tunables for boosting (#1306)
Co-authored-by: Emil Hvitfeldt <[email protected]>
1 parent b6112b2 commit 86589b1

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# parsnip (development version)
22

3+
* Updates to some boosting tuning parameter information: (#1306)
4+
- lightgbm and catboost have smaller default ranges for the learning rate: from -3 to -1/2 in log10 units (about 0.001 to 0.316 on the original scale).
5+
- lightgbm, xgboost, catboost, and C5.0 have smaller default ranges for the sampling proportion: 0.5 to 1.0.
6+
- catboost engine arguments were added for `max_leaves` and `l2_leaf_reg`.
7+
38
* Enable generalized random forest (`grf`) models for classification, regression, and quantile regression modes. (#1288)
49

510
* `surv_reg()` is now defunct and will error if called. Please use `survival_reg()` instead (#1206).

R/tunable.R

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,21 @@ lightgbm_engine_args <-
124124
component_id = "engine"
125125
)
126126

127+
# Engine-specific tunable parameters for the "catboost" engine of boost_tree.
# One row per engine argument: `name` is the argument name and `call_info`
# names the dials function (plus any range override) used to build it.
#   - max_leaves  -> dials::num_leaves (no range override given here)
#   - l2_leaf_reg -> dials::penalty with range c(-4, 1)
#     NOTE(review): range is presumably in log10 units -- confirm against dials.
catboost_engine_args <-
128+
tibble::tibble(
129+
name = c(
130+
"max_leaves",
131+
"l2_leaf_reg"
132+
),
133+
call_info = list(
134+
list(pkg = "dials", fun = "num_leaves"),
135+
list(pkg = "dials", fun = "penalty", range = c(-4, 1))
136+
),
137+
source = "model_spec",
138+
component = "boost_tree",
139+
component_id = "engine"
140+
)
141+
127142
ranger_engine_args <-
128143
tibble::tibble(
129144
name = c(
@@ -345,19 +360,27 @@ tunable.boost_tree <- function(x, ...) {
345360
if (x$engine == "xgboost") {
346361
res <- add_engine_parameters(res, xgboost_engine_args)
347362
res$call_info[res$name == "sample_size"] <-
348-
list(list(pkg = "dials", fun = "sample_prop"))
363+
list(list(pkg = "dials", fun = "sample_prop", range = c(0.5, 1.0)))
349364
res$call_info[res$name == "learn_rate"] <-
350365
list(list(pkg = "dials", fun = "learn_rate", range = c(-3, -1 / 2)))
351366
} else if (x$engine == "C5.0") {
352367
res <- add_engine_parameters(res, c5_boost_engine_args)
353368
res$call_info[res$name == "trees"] <-
354369
list(list(pkg = "dials", fun = "trees", range = c(1, 100)))
355370
res$call_info[res$name == "sample_size"] <-
356-
list(list(pkg = "dials", fun = "sample_prop"))
371+
list(list(pkg = "dials", fun = "sample_prop", range = c(0.5, 1.0)))
357372
} else if (x$engine == "lightgbm") {
358373
res <- add_engine_parameters(res, lightgbm_engine_args)
359374
res$call_info[res$name == "sample_size"] <-
360-
list(list(pkg = "dials", fun = "sample_prop"))
375+
list(list(pkg = "dials", fun = "sample_prop", range = c(0.5, 1.0)))
376+
res$call_info[res$name == "learn_rate"] <-
377+
list(list(pkg = "dials", fun = "learn_rate", range = c(-3, -1 / 2)))
378+
} else if (x$engine == "catboost") {
379+
res <- add_engine_parameters(res, catboost_engine_args)
380+
res$call_info[res$name == "learn_rate"] <-
381+
list(list(pkg = "dials", fun = "learn_rate", range = c(-3, -1 / 2)))
382+
res$call_info[res$name == "sample_size"] <-
383+
list(list(pkg = "dials", fun = "sample_prop", range = c(0.5, 1.0)))
361384
}
362385
res
363386
}

0 commit comments

Comments
 (0)