diff --git a/NEWS.md b/NEWS.md index 3963a8c..a69dcb9 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # finetune (development version) +* Improved error message from `tune_sim_anneal()` when values in the supplied `param_info` do not encompass all values evaluated in the `initial` grid. This most often happens when a user mistakenly supplies different parameter sets to the function that generated the initial results and `tune_sim_anneal()`. + * Fixed bug where `tune_sim_anneal()` would fail when supplied parameters needing finalization. The function will now finalize needed parameter ranges internally (#39). * Fixed bug where packages specified in `control_race(pkgs)` were not actually loaded in `tune_race_anova()` (#74). diff --git a/R/sim_anneal_helpers.R b/R/sim_anneal_helpers.R index 494804b..fe5b00a 100644 --- a/R/sim_anneal_helpers.R +++ b/R/sim_anneal_helpers.R @@ -144,11 +144,30 @@ random_real_neighbor <- function(current, hist_values, pset, retain = 1, encode_set_backwards <- function(x, pset, ...) { pset <- pset[pset$id %in% names(x), ] + mapply(check_backwards_encode, pset$object, x, pset$id, + SIMPLIFY = FALSE, USE.NAMES = FALSE) new_vals <- purrr::map2(pset$object, x, dials::encode_unit, direction = "backward") names(new_vals) <- names(x) tibble::as_tibble(new_vals) } +check_backwards_encode <- function(x, value, id) { + if (!dials::has_unknowns(x)) { + compl <- value[!is.na(value)] + if (any(compl < 0) | any(compl > 1)) { + cli::cli_abort(c( + "!" = "The range for parameter {.val {noquote(id)}} used when \\ + generating initial results isn't compatible with the range \\ + supplied in {.arg param_info}.", + "i" = "Possible values of parameters in {.arg param_info} should \\ + encompass all values evaluated in the initial grid." + ), + call = rlang::call2("tune_sim_anneal()") + ) + } + } +} + sample_by_distance <- function(candidates, existing, retain, pset) { if (nrow(existing) > 0) { existing <- tune::encode_set(existing, pset, as_matrix = TRUE) diff --git a/tests/testthat/_snaps/sa-overall.md b/tests/testthat/_snaps/sa-overall.md index 0149e8d..86bc3f0 100644 --- a/tests/testthat/_snaps/sa-overall.md +++ b/tests/testthat/_snaps/sa-overall.md @@ -239,3 +239,18 @@ 9 - discard suboptimal roc_auc=0.84525 (+/-0.007793) 10 ( ) accept suboptimal roc_auc=0.84383 (+/-0.00773) +# incompatible parameter objects + + Code + res <- tune_sim_anneal(car_wflow, param_info = parameter_set_with_smaller_range, + resamples = car_folds, initial = tune_res_with_bigger_range, iter = 2) + Message + Optimizing rmse + + Condition + Error in `tune_sim_anneal()`: + ! The range for parameter mtry used when generating initial results isn't compatible with the range supplied in `param_info`. + i Possible values of parameters in `param_info` should encompass all values evaluated in the initial grid. + Message + x Optimization stopped prematurely; returning current results. + diff --git a/tests/testthat/test-sa-overall.R b/tests/testthat/test-sa-overall.R index 83af21f..d6751b0 100644 --- a/tests/testthat/test-sa-overall.R +++ b/tests/testthat/test-sa-overall.R @@ -127,6 +127,56 @@ test_that("unfinalized parameters", { }) }) +test_that("incompatible parameter objects", { + skip_on_cran() + + skip_if_not_installed("ranger") + skip_if_not_installed("modeldata") + skip_if_not_installed("rsample") + + rf_spec <- parsnip::rand_forest(mode = "regression", mtry = tune::tune()) + + set.seed(1) + grid_with_bigger_range <- + dials::grid_latin_hypercube(dials::mtry(range = c(1, 16))) + + set.seed(1) + car_folds <- rsample::vfold_cv(car_prices, v = 2) + + car_wflow <- workflows::workflow() %>% + workflows::add_formula(Price ~ .) %>% + workflows::add_model(rf_spec) + + set.seed(1) + tune_res_with_bigger_range <- tune::tune_grid( + car_wflow, + resamples = car_folds, + grid = grid_with_bigger_range + ) + + set.seed(1) + parameter_set_with_smaller_range <- + dials::parameters(dials::mtry(range = c(1, 5))) + + scrub_best <- function(lines) { + has_best <- grepl("Initial best", lines) + lines[has_best] <- "" + lines + } + + set.seed(1) + expect_snapshot(error = TRUE, transform = scrub_best, { + res <- + tune_sim_anneal( + car_wflow, + param_info = parameter_set_with_smaller_range, + resamples = car_folds, + initial = tune_res_with_bigger_range, + iter = 2 + ) + }) +}) + test_that("set event-level", { # See issue 40 skip_if_not_installed("rpart")