Skip to content

Commit 21a56c8

Browse files
committed
exclude non-tunable engine arguments in tunable()
closes #1104
1 parent 6e8106e commit 21a56c8

File tree

2 files changed

+26
-35
lines changed

2 files changed

+26
-35
lines changed

R/tunable.R

+10-10
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ tunable.linear_reg <- function(x, ...) {
248248
} else if (x$engine == "brulee") {
249249
res <- add_engine_parameters(res, brulee_linear_engine_args)
250250
}
251-
res
251+
res[!vapply(res$call_info, is.null, logical(1)), ]
252252
}
253253

254254
#' @export
@@ -260,7 +260,7 @@ tunable.logistic_reg <- function(x, ...) {
260260
} else if (x$engine == "brulee") {
261261
res <- add_engine_parameters(res, brulee_logistic_engine_args)
262262
}
263-
res
263+
res[!vapply(res$call_info, is.null, logical(1)), ]
264264
}
265265

266266
#' @export
@@ -272,7 +272,7 @@ tunable.multinomial_reg <- function(x, ...) {
272272
} else if (x$engine == "brulee") {
273273
res <- add_engine_parameters(res, brulee_multinomial_engine_args)
274274
}
275-
res
275+
res[!vapply(res$call_info, is.null, logical(1)), ]
276276
}
277277

278278
#' @export
@@ -295,7 +295,7 @@ tunable.boost_tree <- function(x, ...) {
295295
res$call_info[res$name == "sample_size"] <-
296296
list(list(pkg = "dials", fun = "sample_prop"))
297297
}
298-
res
298+
res[!vapply(res$call_info, is.null, logical(1)), ]
299299
}
300300

301301
#' @export
@@ -310,7 +310,7 @@ tunable.rand_forest <- function(x, ...) {
310310
} else if (x$engine == "aorsf") {
311311
res <- add_engine_parameters(res, aorsf_engine_args)
312312
}
313-
res
313+
res[!vapply(res$call_info, is.null, logical(1)), ]
314314
}
315315

316316
#' @export
@@ -319,7 +319,7 @@ tunable.mars <- function(x, ...) {
319319
if (x$engine == "earth") {
320320
res <- add_engine_parameters(res, earth_engine_args)
321321
}
322-
res
322+
res[!vapply(res$call_info, is.null, logical(1)), ]
323323
}
324324

325325
#' @export
@@ -333,7 +333,7 @@ tunable.decision_tree <- function(x, ...) {
333333
partykit_engine_args %>%
334334
dplyr::mutate(component = "decision_tree"))
335335
}
336-
res
336+
res[!vapply(res$call_info, is.null, logical(1)), ]
337337
}
338338

339339
#' @export
@@ -343,7 +343,7 @@ tunable.svm_poly <- function(x, ...) {
343343
res$call_info[res$name == "degree"] <-
344344
list(list(pkg = "dials", fun = "prod_degree", range = c(1L, 3L)))
345345
}
346-
res
346+
res[!vapply(res$call_info, is.null, logical(1)), ]
347347
}
348348

349349

@@ -357,7 +357,7 @@ tunable.mlp <- function(x, ...) {
357357
res$call_info[res$name == "epochs"] <-
358358
list(list(pkg = "dials", fun = "epochs", range = c(5L, 500L)))
359359
}
360-
res
360+
res[!vapply(res$call_info, is.null, logical(1)), ]
361361
}
362362

363363
#' @export
@@ -366,7 +366,7 @@ tunable.survival_reg <- function(x, ...) {
366366
if (x$engine == "flexsurvspline") {
367367
res <- add_engine_parameters(res, flexsurvspline_engine_args)
368368
}
369-
res
369+
res[!vapply(res$call_info, is.null, logical(1)), ]
370370
}
371371

372372
# nocov end

tests/testthat/_snaps/tunable.md

+16-25
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,11 @@
4343
Code
4444
tunable(spec %>% set_engine("glmnet", dfmax = tune()))
4545
Output
46-
# A tibble: 3 x 5
46+
# A tibble: 2 x 5
4747
name call_info source component component_id
4848
<chr> <list> <chr> <chr> <chr>
4949
1 penalty <named list [2]> model_spec linear_reg main
5050
2 mixture <named list [3]> model_spec linear_reg main
51-
3 dfmax <NULL> model_spec linear_reg engine
5251

5352
# tunable.logistic_reg()
5453

@@ -95,12 +94,11 @@
9594
Code
9695
tunable(spec %>% set_engine("glmnet", dfmax = tune()))
9796
Output
98-
# A tibble: 3 x 5
97+
# A tibble: 2 x 5
9998
name call_info source component component_id
10099
<chr> <list> <chr> <chr> <chr>
101100
1 penalty <named list [2]> model_spec logistic_reg main
102101
2 mixture <named list [3]> model_spec logistic_reg main
103-
3 dfmax <NULL> model_spec logistic_reg engine
104102

105103
# tunable.multinom_reg()
106104

@@ -244,7 +242,7 @@
244242
Code
245243
tunable(spec %>% set_engine("xgboost", feval = tune()))
246244
Output
247-
# A tibble: 9 x 5
245+
# A tibble: 8 x 5
248246
name call_info source component component_id
249247
<chr> <list> <chr> <chr> <chr>
250248
1 tree_depth <named list [2]> model_spec boost_tree main
@@ -255,7 +253,6 @@
255253
6 loss_reduction <named list [2]> model_spec boost_tree main
256254
7 sample_size <named list [2]> model_spec boost_tree main
257255
8 stop_iter <named list [2]> model_spec boost_tree main
258-
9 feval <NULL> model_spec boost_tree engine
259256

260257
# tunable.rand_forest()
261258

@@ -310,13 +307,12 @@
310307
Code
311308
tunable(spec %>% set_engine("ranger", min.bucket = tune()))
312309
Output
313-
# A tibble: 4 x 5
314-
name call_info source component component_id
315-
<chr> <list> <chr> <chr> <chr>
316-
1 mtry <named list [2]> model_spec rand_forest main
317-
2 trees <named list [2]> model_spec rand_forest main
318-
3 min_n <named list [2]> model_spec rand_forest main
319-
4 min.bucket <NULL> model_spec rand_forest engine
310+
# A tibble: 3 x 5
311+
name call_info source component component_id
312+
<chr> <list> <chr> <chr> <chr>
313+
1 mtry <named list [2]> model_spec rand_forest main
314+
2 trees <named list [2]> model_spec rand_forest main
315+
3 min_n <named list [2]> model_spec rand_forest main
320316

321317
# tunable.mars()
322318

@@ -347,13 +343,12 @@
347343
Code
348344
tunable(spec %>% set_engine("earth", minspan = tune()))
349345
Output
350-
# A tibble: 4 x 5
346+
# A tibble: 3 x 5
351347
name call_info source component component_id
352348
<chr> <list> <chr> <chr> <chr>
353349
1 num_terms <named list [3]> model_spec mars main
354350
2 prod_degree <named list [2]> model_spec mars main
355351
3 prune_method <named list [2]> model_spec mars main
356-
4 minspan <NULL> model_spec mars engine
357352

358353
# tunable.decision_tree()
359354

@@ -405,13 +400,12 @@
405400
Code
406401
tunable(spec %>% set_engine("rpart", parms = tune()))
407402
Output
408-
# A tibble: 4 x 5
403+
# A tibble: 3 x 5
409404
name call_info source component component_id
410405
<chr> <list> <chr> <chr> <chr>
411406
1 tree_depth <named list [2]> model_spec decision_tree main
412407
2 min_n <named list [2]> model_spec decision_tree main
413408
3 cost_complexity <named list [2]> model_spec decision_tree main
414-
4 parms <NULL> model_spec decision_tree engine
415409

416410
# tunable.svm_poly()
417411

@@ -444,14 +438,13 @@
444438
Code
445439
tunable(spec %>% set_engine("kernlab", tol = tune()))
446440
Output
447-
# A tibble: 5 x 5
441+
# A tibble: 4 x 5
448442
name call_info source component component_id
449443
<chr> <list> <chr> <chr> <chr>
450444
1 cost <named list [3]> model_spec svm_poly main
451445
2 degree <named list [3]> model_spec svm_poly main
452446
3 scale_factor <named list [2]> model_spec svm_poly main
453447
4 margin <named list [2]> model_spec svm_poly main
454-
5 tol <NULL> model_spec svm_poly engine
455448

456449
# tunable.mlp()
457450

@@ -511,15 +504,14 @@
511504
Code
512505
tunable(spec %>% set_engine("keras", ragged = tune()))
513506
Output
514-
# A tibble: 6 x 5
507+
# A tibble: 5 x 5
515508
name call_info source component component_id
516509
<chr> <list> <chr> <chr> <chr>
517510
1 hidden_units <named list [2]> model_spec mlp main
518511
2 penalty <named list [2]> model_spec mlp main
519512
3 dropout <named list [2]> model_spec mlp main
520513
4 epochs <named list [2]> model_spec mlp main
521514
5 activation <named list [2]> model_spec mlp main
522-
6 ragged <NULL> model_spec mlp engine
523515

524516
# tunable.survival_reg()
525517

@@ -544,8 +536,7 @@
544536
Code
545537
tunable(spec %>% set_engine("survival", parms = tune()))
546538
Output
547-
# A tibble: 1 x 5
548-
name call_info source component component_id
549-
<chr> <list> <chr> <chr> <chr>
550-
1 parms <NULL> model_spec survival_reg engine
539+
# A tibble: 0 x 5
540+
# i 5 variables: name <chr>, call_info <list>, source <chr>, component <chr>,
541+
# component_id <chr>
551542

0 commit comments

Comments
 (0)