@@ -124,6 +124,21 @@ lightgbm_engine_args <-
124124 component_id = " engine"
125125 )
126126
127+ catboost_engine_args <-
128+ tibble :: tibble(
129+ name = c(
130+ " max_leaves" ,
131+ " l2_leaf_reg"
132+ ),
133+ call_info = list (
134+ list (pkg = " dials" , fun = " num_leaves" ),
135+ list (pkg = " dials" , fun = " penalty" , range = c(- 4 , 1 ))
136+ ),
137+ source = " model_spec" ,
138+ component = " boost_tree" ,
139+ component_id = " engine"
140+ )
141+
127142ranger_engine_args <-
128143 tibble :: tibble(
129144 name = c(
@@ -345,19 +360,27 @@ tunable.boost_tree <- function(x, ...) {
345360 if (x $ engine == " xgboost" ) {
346361 res <- add_engine_parameters(res , xgboost_engine_args )
347362 res $ call_info [res $ name == " sample_size" ] <-
348- list (list (pkg = " dials" , fun = " sample_prop" ))
363+ list (list (pkg = " dials" , fun = " sample_prop" , range = c( 0.5 , 1.0 ) ))
349364 res $ call_info [res $ name == " learn_rate" ] <-
350365 list (list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 2 )))
351366 } else if (x $ engine == " C5.0" ) {
352367 res <- add_engine_parameters(res , c5_boost_engine_args )
353368 res $ call_info [res $ name == " trees" ] <-
354369 list (list (pkg = " dials" , fun = " trees" , range = c(1 , 100 )))
355370 res $ call_info [res $ name == " sample_size" ] <-
356- list (list (pkg = " dials" , fun = " sample_prop" ))
371+ list (list (pkg = " dials" , fun = " sample_prop" , range = c( 0.5 , 1.0 ) ))
357372 } else if (x $ engine == " lightgbm" ) {
358373 res <- add_engine_parameters(res , lightgbm_engine_args )
359374 res $ call_info [res $ name == " sample_size" ] <-
360- list (list (pkg = " dials" , fun = " sample_prop" ))
375+ list (list (pkg = " dials" , fun = " sample_prop" , range = c(0.5 , 1.0 )))
376+ res $ call_info [res $ name == " learn_rate" ] <-
377+ list (list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 2 )))
378+ } else if (x $ engine == " catboost" ) {
379+ res <- add_engine_parameters(res , catboost_engine_args )
380+ res $ call_info [res $ name == " learn_rate" ] <-
381+ list (list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 2 )))
382+ res $ call_info [res $ name == " sample_size" ] <-
383+ list (list (pkg = " dials" , fun = " sample_prop" , range = c(0.5 , 1.0 )))
361384 }
362385 res
363386}
0 commit comments