diff --git a/DESCRIPTION b/DESCRIPTION
index ed9dc9a91..838722b7c 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,6 +1,6 @@
 Package: tune
 Title: Tidy Tuning Tools
-Version: 1.1.2.9014
+Version: 1.1.2.9015
 Authors@R: c(
     person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre"),
            comment = c(ORCID = "0000-0003-2402-136X")),
diff --git a/NAMESPACE b/NAMESPACE
index 360658f8b..fea2c4777 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -210,6 +210,7 @@ export(is_recipe)
 export(is_workflow)
 export(last_fit)
 export(load_pkgs)
+export(maybe_choose_eval_time)
 export(message_wrap)
 export(metrics_info)
 export(min_grid)
diff --git a/R/int_pctl.R b/R/int_pctl.R
index fb30852f4..ad1f95458 100644
--- a/R/int_pctl.R
+++ b/R/int_pctl.R
@@ -73,8 +73,24 @@ int_pctl.tune_results <- function(.data, metrics = NULL, times = 1001,
 
   rlang::check_dots_empty()
 
-  # check eval_time and set default when null
-  eval_time <- default_eval_time(eval_time, .data$.metrics[[1]])
+  if (is.null(metrics)) {
+    metrics <- .get_tune_metrics(.data)
+  } else {
+    if (!inherits(metrics, "metric_set")) {
+      cli::cli_abort("{.arg metrics} should be a metric set as generated by {.fun yardstick::metric_set}.")
+    }
+  }
+
+  if (is.null(eval_time)) {
+    eval_time <- .get_tune_eval_times(.data)
+    eval_time <- maybe_choose_eval_time(.data, metrics, eval_time)
+  } else {
+    eval_time <- unique(eval_time)
+    check_eval_time_in_tune_results(.data, eval_time)
+    # Are there at least a minimal number of evaluation times?
+    check_enough_eval_times(eval_time, metrics)
+  }
+
+  .data$.predictions <- filter_predictions_by_eval_time(.data$.predictions, eval_time)
 
   y_nm <- outcome_names(.data)
@@ -90,10 +106,19 @@
     }
   }
 
+  # TODO Changes in https://github.com/tidymodels/rsample/pull/465
+  # will affect how these computations are done since they will
+  # compute intervals for `terms` as well as any columns that begin
+  # with a period. This will simplify the code considerably for
+  # survival and non-survival models. We will make this version
+  # compatible with the future rsample version (but will refactor
+  # this code later).
+
   res <- purrr::map2(
     config_keys, sample.int(10000, p),
-    ~ boostrap_metrics_by_config(.x, .y, .data, metrics, times, allow_par, event_level)
+    ~ boostrap_metrics_by_config(.x, .y, .data, metrics, times, allow_par,
+                                 event_level, eval_time, alpha)
   ) %>%
     purrr::list_rbind() %>%
     dplyr::arrange(.config, .metric)
@@ -110,7 +135,8 @@ get_int_p_operator <- function(allow = TRUE) {
   res
 }
 
-boostrap_metrics_by_config <- function(config, seed, x, metrics, times, allow_par, event_level) {
+boostrap_metrics_by_config <- function(config, seed, x, metrics, times, allow_par,
+                                       event_level, eval_time, alpha) {
   y_nm <- outcome_names(x)
 
   preds <- collect_predictions(x, summarize = TRUE, parameters = config)
@@ -124,14 +150,13 @@ boostrap_metrics_by_config <- function(config, seed, x, metrics, times, allow_pa
       .errorhandling = "pass",
       .packages = c("tune", "rsample")
     ) %op% {
-      comp_metrics(rs$splits[[i]], y_nm, metrics, event_level)
+      comp_metrics(rs$splits[[i]], y_nm, metrics, event_level, eval_time)
     }
-
   if (any(grepl("survival", .get_tune_metric_names(x)))) {
     # compute by evaluation time
-    res <- int_pctl_dyn_surv(rs, allow_par)
+    res <- int_pctl_surv(rs, allow_par, alpha)
   } else {
-    res <- rsample::int_pctl(rs, .metrics)
+    res <- rsample::int_pctl(rs, .metrics, alpha = alpha)
   }
   res %>%
     dplyr::mutate(.estimator = "bootstrap") %>%
@@ -141,29 +166,39 @@ boostrap_metrics_by_config <- function(config, seed, x, metrics, times, allow_pa
     cbind(config)
 }
 
-# We have to do the analysis separately for each evaluation time.
-int_pctl_dyn_surv <- function(x, allow_par) {
-  `%op%` <- get_int_p_operator(allow_par)
-  times <- unique(x$.metrics[[1]]$.eval_time)
-  res <-
-    foreach::foreach(
-      i = seq_along(times),
-      .errorhandling = "pass",
-      .packages = c("purrr", "rsample", "dplyr")
-    ) %op% {
-      int_pctl_by_eval_time(times[i], x)
-    }
-  dplyr::bind_rows(res)
+fake_term <- function(x) {
+  x$term <- paste(x$term, format(1:nrow(x)))
+  x
 }
 
-int_pctl_by_eval_time <- function(time, x) {
-  times <- dplyr::tibble(.eval_time = time)
-  x$.metrics <- purrr::map(x$.metrics, ~ dplyr::inner_join(.x, times, by = ".eval_time"))
-  rsample::int_pctl(x, .metrics) %>%
-    dplyr::mutate(.eval_time = time) %>%
-    dplyr::relocate(.eval_time, .after = term)
-}
+# tests in extratests
+# nocov start
+int_pctl_surv <- function(x, allow_par, alpha) {
+
+  # int_pctl() expects terms to be unique. For (many) survival models, the
+  # metrics are a combination of the metric name and the evaluation time.
+  # We'll make a phony term value, run int_pctl(), then merge the original values
+  # back in.
+  met_key <- x$.metrics[[1]]
+  met_key$estimate <- NULL
+  met_key$old_term <- met_key$term
+  met_key$order <- 1:nrow(met_key)
+  met_key <- fake_term(met_key)
+  x$.metrics <- purrr::map(x$.metrics, ~ fake_term(.x))
+  res <- rsample::int_pctl(x, .metrics, alpha = alpha)
+
+  merge_keys <- c("term", grep("^\\.", names(res), value = TRUE))
+  merge_keys <- intersect(merge_keys, names(met_key))
+
+  res <- res %>%
+    dplyr::full_join(met_key, by = merge_keys) %>%
+    dplyr::arrange(order) %>%
+    dplyr::select(-term, -order) %>%
+    dplyr::rename(term = old_term) %>%
+    dplyr::relocate(term, dplyr::any_of(".eval_time"))
+}
+# nocov end
 
 # ------------------------------------------------------------------------------
 
@@ -184,7 +219,7 @@ get_configs <- function(x, parameters = NULL, as_list = TRUE) {
 }
 
 # Compute metrics for a specific configuration
-comp_metrics <- function(split, y, metrics, event_level) {
+comp_metrics <- function(split, y, metrics, event_level, eval_time) {
   dat <- rsample::analysis(split)
 
   info <- metrics_info(metrics)
@@ -196,7 +231,7 @@
       outcome_name = y,
       event_level = event_level,
       metrics_info = info,
-      eval_time = NA # TODO I don't think that this is used in the function
+      eval_time = eval_time
     )
 
   res %>%
diff --git a/R/metric-selection.R b/R/metric-selection.R
index 332c31628..1ad47dd62 100644
--- a/R/metric-selection.R
+++ b/R/metric-selection.R
@@ -27,6 +27,10 @@
 #' If a time is required and none is given, the first value in the vector
 #' originally given in the `eval_time` argument is used (with a warning).
 #'
+#' `maybe_choose_eval_time()` is for cases where multiple evaluation times are
+#' acceptable but you need to choose a good default. The "maybe" is because
+#' the function that would use `maybe_choose_eval_time()` can accept multiple
+#' metrics (like [autoplot()]).
 #' @keywords internal
 #' @export
 choose_metric <- function(x, metric, ..., call = rlang::caller_env()) {
@@ -71,22 +75,34 @@ check_metric_in_tune_results <- function(mtr_info, metric, call = rlang::caller_
   invisible(NULL)
 }
 
+check_enough_eval_times <- function(eval_time, metrics) {
+  num_times <- length(eval_time)
+  max_times_req <- req_eval_times(metrics)
+  cls <- tibble::as_tibble(metrics)$class
+  uni_cls <- sort(unique(cls))
+  if (max_times_req > num_times) {
+    cli::cli_abort("At least {max_times_req} evaluation time{?s} {?is/are}
+                    required for the metric type(s) requested: {.val {uni_cls}}.
+                    Only {num_times} unique time{?s} {?was/were} given.")
+  }
+
+  invisible(NULL)
+}
+
 contains_survival_metric <- function(mtr_info) {
   any(grepl("_survival", mtr_info$class))
 }
 
-
 # choose_eval_time() is called by show_best() and select_best()
 
 #' @rdname choose_metric
 #' @export
-choose_eval_time <- function(x, metric, eval_time = NULL, ..., call = rlang::caller_env()) {
-  rlang::check_dots_empty()
+choose_eval_time <- function(x, metric, eval_time = NULL, quietly = FALSE, call = rlang::caller_env()) {
   mtr_set <- .get_tune_metrics(x)
   mtr_info <- tibble::as_tibble(mtr_set)
 
   if (!contains_survival_metric(mtr_info)) {
-    if (!is.null(eval_time)) {
+    if (!is.null(eval_time) & !quietly) {
       cli::cli_warn("Evaluation times are only required when the model mode is
                      {.val censored regression} (and will be ignored).",
                     call = call)
@@ -97,7 +113,7 @@
   dyn_metric <- is_dyn(mtr_set, metric)
 
   # If we don't need an eval time but one is passed:
-  if (!dyn_metric & !is.null(eval_time)) {
+  if (!dyn_metric & !is.null(eval_time) & !quietly) {
     cli::cli_warn("An evaluation time is only required when a dynamic metric
                    is selected (and {.arg eval_time} will thus be ignored).",
@@ -135,6 +151,24 @@ check_eval_time_in_tune_results <- function(x, eval_time, call = rlang::caller_e
   invisible(NULL)
 }
 
+#' @rdname choose_metric
+#' @export
+maybe_choose_eval_time <- function(x, mtr_set, eval_time) {
+  mtr_info <- tibble::as_tibble(mtr_set)
+  if (any(grepl("integrated", mtr_info$metric))) {
+    return(.get_tune_eval_times(x))
+  }
+  eval_time <- purrr::map(mtr_info$metric, ~ choose_eval_time(x, .x, eval_time = eval_time, quietly = TRUE))
+  no_eval_time <- purrr::map_lgl(eval_time, is.null)
+  if (all(no_eval_time)) {
+    eval_time <- NULL
+  } else {
+    eval_time <- sort(unique(unlist(eval_time)))
+  }
+  eval_time
+}
+
+
 # ------------------------------------------------------------------------------
 
 #' @rdname choose_metric
@@ -289,15 +323,11 @@ check_eval_time_arg <- function(eval_time, mtr_set, call = rlang::caller_env())
   eval_time <- .filter_eval_time(eval_time)
 
   num_times <- length(eval_time)
   max_times_req <- req_eval_times(mtr_set)
 
-  if (max_times_req > num_times) {
-    cli::cli_abort("At least {max_times_req} evaluation time{?s} {?is/are}
-                   required for the metric type(s) requested: {.val {uni_cls}}.
-                   Only {num_times} unique time{?s} {?was/were} given.",
-                   call = call)
-  }
+  # Are there at least a minimal number of evaluation times?
+  check_enough_eval_times(eval_time, mtr_set)
 
   if (max_times_req == 0 & num_times > 0) {
     cli::cli_warn("Evaluation times are only required when dynamic or
diff --git a/man/choose_metric.Rd b/man/choose_metric.Rd
index d5a9c948b..064e32231 100644
--- a/man/choose_metric.Rd
+++ b/man/choose_metric.Rd
@@ -4,6 +4,7 @@
 \alias{choose_metric}
 \alias{check_metric_in_tune_results}
 \alias{choose_eval_time}
+\alias{maybe_choose_eval_time}
 \alias{first_metric}
 \alias{first_eval_time}
 \alias{.filter_perf_metrics}
@@ -15,7 +16,15 @@ choose_metric(x, metric, ..., call = rlang::caller_env())
 
 check_metric_in_tune_results(mtr_info, metric, call = rlang::caller_env())
 
-choose_eval_time(x, metric, eval_time = NULL, ..., call = rlang::caller_env())
+choose_eval_time(
+  x,
+  metric,
+  eval_time = NULL,
+  quietly = FALSE,
+  call = rlang::caller_env()
+)
+
+maybe_choose_eval_time(x, mtr_set, eval_time)
 
 first_metric(mtr_set)
 
@@ -70,5 +79,10 @@ to the function used to produce \code{x} (such as \code{\link[=tune_grid]{tune_g
 
 If a time is required and none is given, the first value in the vector
 originally given in the \code{eval_time} argument is used (with a warning).
+
+\code{maybe_choose_eval_time()} is for cases where multiple evaluation times are
+acceptable but you need to choose a good default. The "maybe" is because
+the function that would use \code{maybe_choose_eval_time()} can accept multiple
+metrics (like \code{\link[=autoplot]{autoplot()}}).
 }
 \keyword{internal}
diff --git a/tests/testthat/_snaps/censored-reg.md b/tests/testthat/_snaps/censored-reg.md
index 500b6a76f..d9a782439 100644
--- a/tests/testthat/_snaps/censored-reg.md
+++ b/tests/testthat/_snaps/censored-reg.md
@@ -3,7 +3,7 @@
     Code
       spec %>% tune_grid(Surv(time, status) ~ ., resamples = rs, metrics = mtr)
     Condition
-      Error in `tune_grid()`:
+      Error in `check_enough_eval_times()`:
       ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric". Only 0 unique times were given.
 
 ---
diff --git a/tests/testthat/_snaps/eval-time-args.md b/tests/testthat/_snaps/eval-time-args.md
index 5736356b7..96d375534 100644
--- a/tests/testthat/_snaps/eval-time-args.md
+++ b/tests/testthat/_snaps/eval-time-args.md
@@ -112,7 +112,7 @@
     Code
       check_eval_time_arg(NULL, met_dyn)
     Condition
-      Error:
+      Error in `check_enough_eval_times()`:
       ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric". Only 0 unique times were given.
 
 ---
@@ -120,7 +120,7 @@
     Code
       check_eval_time_arg(NULL, met_int)
     Condition
-      Error:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric". Only 0 unique times were given.
 
 ---
@@ -128,7 +128,7 @@
     Code
       check_eval_time_arg(NULL, met_stc_dyn)
     Condition
-      Error:
+      Error in `check_enough_eval_times()`:
       ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric" and "static_survival_metric". Only 0 unique times were given.
 
 ---
@@ -136,7 +136,7 @@
     Code
       check_eval_time_arg(NULL, met_stc_int)
     Condition
-      Error:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 0 unique times were given.
 
 ---
@@ -144,7 +144,7 @@
     Code
       check_eval_time_arg(NULL, met_dyn_stc)
     Condition
-      Error:
+      Error in `check_enough_eval_times()`:
       ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric" and "static_survival_metric". Only 0 unique times were given.
 
 ---
@@ -152,7 +152,7 @@
     Code
       check_eval_time_arg(NULL, met_dyn_int)
     Condition
-      Error:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 0 unique times were given.
 
 ---
@@ -160,7 +160,7 @@
     Code
       check_eval_time_arg(NULL, met_int_stc)
     Condition
-      Error:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 0 unique times were given.
 
 ---
@@ -168,7 +168,7 @@
     Code
       check_eval_time_arg(NULL, met_int_dyn)
     Condition
-      Error:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 0 unique times were given.
 
 ---
@@ -193,7 +193,7 @@
     Code
       check_eval_time_arg(2, met_int)
     Condition
-      Error:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric". Only 1 unique time was given.
 
 ---
@@ -208,7 +208,7 @@
     Code
       check_eval_time_arg(2, met_stc_int)
     Condition
-      Error:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 1 unique time was given.
 
 ---
@@ -223,7 +223,7 @@
     Code
       check_eval_time_arg(2, met_dyn_int)
     Condition
-      Error:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 1 unique time was given.
 
 ---
@@ -231,7 +231,7 @@
     Code
       check_eval_time_arg(2, met_int_stc)
     Condition
-      Error:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 1 unique time was given.
 
 ---
@@ -239,7 +239,7 @@
     Code
       check_eval_time_arg(2, met_int_dyn)
     Condition
-      Error:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 1 unique time was given.
 
 ---
@@ -318,7 +318,7 @@
     Code
       fit_resamples(wflow, rs, metrics = met_dyn)
     Condition
-      Error in `fit_resamples()`:
+      Error in `check_enough_eval_times()`:
       ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric". Only 0 unique times were given.
 
 ---
@@ -326,7 +326,7 @@
     Code
       fit_resamples(wflow, rs, metrics = met_int)
     Condition
-      Error in `fit_resamples()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric". Only 0 unique times were given.
 
 ---
@@ -334,7 +334,7 @@
     Code
       fit_resamples(wflow, rs, metrics = met_stc_dyn)
     Condition
-      Error in `fit_resamples()`:
+      Error in `check_enough_eval_times()`:
       ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric" and "static_survival_metric". Only 0 unique times were given.
 
 ---
@@ -342,7 +342,7 @@
     Code
       fit_resamples(wflow, rs, metrics = met_stc_int)
     Condition
-      Error in `fit_resamples()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 0 unique times were given.
 
 ---
@@ -350,7 +350,7 @@
     Code
      fit_resamples(wflow, rs, metrics = met_dyn_stc)
     Condition
-      Error in `fit_resamples()`:
+      Error in `check_enough_eval_times()`:
       ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric" and "static_survival_metric". Only 0 unique times were given.
 
 ---
@@ -358,7 +358,7 @@
     Code
       fit_resamples(wflow, rs, metrics = met_dyn_int)
     Condition
-      Error in `fit_resamples()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 0 unique times were given.
 
 ---
@@ -366,7 +366,7 @@
     Code
       fit_resamples(wflow, rs, metrics = met_int_stc)
     Condition
-      Error in `fit_resamples()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 0 unique times were given.
 
 ---
@@ -374,7 +374,7 @@
     Code
       fit_resamples(wflow, rs, metrics = met_int_dyn)
     Condition
-      Error in `fit_resamples()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 0 unique times were given.
 
 ---
@@ -395,7 +395,7 @@
     Code
       fit_resamples(wflow, rs, metrics = met_int, eval_time = 2)
     Condition
-      Error in `fit_resamples()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric". Only 1 unique time was given.
 
 ---
@@ -408,7 +408,7 @@
     Code
       fit_resamples(wflow, rs, metrics = met_stc_int, eval_time = 2)
     Condition
-      Error in `fit_resamples()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 1 unique time was given.
 
 ---
@@ -421,7 +421,7 @@
     Code
       fit_resamples(wflow, rs, metrics = met_dyn_int, eval_time = 2)
     Condition
-      Error in `fit_resamples()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 1 unique time was given.
 
 ---
@@ -429,7 +429,7 @@
     Code
       fit_resamples(wflow, rs, metrics = met_int_stc, eval_time = 2)
     Condition
-      Error in `fit_resamples()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 1 unique time was given.
 
 ---
@@ -437,7 +437,7 @@
     Code
       fit_resamples(wflow, rs, metrics = met_int_dyn, eval_time = 2)
     Condition
-      Error in `fit_resamples()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 1 unique time was given.
 
 ---
@@ -498,7 +498,7 @@
     Code
       tune_grid(wflow_tune, rs, metrics = met_dyn)
     Condition
-      Error in `tune_grid()`:
+      Error in `check_enough_eval_times()`:
       ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric". Only 0 unique times were given.
 
 ---
@@ -506,7 +506,7 @@
     Code
       tune_grid(wflow_tune, rs, metrics = met_int)
     Condition
-      Error in `tune_grid()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric". Only 0 unique times were given.
 
 ---
@@ -514,7 +514,7 @@
     Code
       tune_grid(wflow_tune, rs, metrics = met_stc_dyn)
     Condition
-      Error in `tune_grid()`:
+      Error in `check_enough_eval_times()`:
       ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric" and "static_survival_metric". Only 0 unique times were given.
 
 ---
@@ -522,7 +522,7 @@
     Code
       tune_grid(wflow_tune, rs, metrics = met_stc_int)
     Condition
-      Error in `tune_grid()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 0 unique times were given.
 
 ---
@@ -530,7 +530,7 @@
     Code
       tune_grid(wflow_tune, rs, metrics = met_dyn_stc)
     Condition
-      Error in `tune_grid()`:
+      Error in `check_enough_eval_times()`:
       ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric" and "static_survival_metric". Only 0 unique times were given.
 
 ---
@@ -538,7 +538,7 @@
     Code
       tune_grid(wflow_tune, rs, metrics = met_dyn_int)
     Condition
-      Error in `tune_grid()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 0 unique times were given.
 
 ---
@@ -546,7 +546,7 @@
     Code
       tune_grid(wflow_tune, rs, metrics = met_int_stc)
     Condition
-      Error in `tune_grid()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 0 unique times were given.
 
 ---
@@ -554,7 +554,7 @@
     Code
       tune_grid(wflow_tune, rs, metrics = met_int_dyn)
     Condition
-      Error in `tune_grid()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 0 unique times were given.
 
 ---
@@ -575,7 +575,7 @@
     Code
       tune_grid(wflow_tune, rs, metrics = met_int, eval_time = 2)
     Condition
-      Error in `tune_grid()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric". Only 1 unique time was given.
 
 ---
@@ -588,7 +588,7 @@
     Code
       tune_grid(wflow_tune, rs, metrics = met_stc_int, eval_time = 2)
     Condition
-      Error in `tune_grid()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 1 unique time was given.
 
 ---
@@ -601,7 +601,7 @@
     Code
       tune_grid(wflow_tune, rs, metrics = met_dyn_int, eval_time = 2)
     Condition
-      Error in `tune_grid()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 1 unique time was given.
 
 ---
@@ -609,7 +609,7 @@
     Code
       tune_grid(wflow_tune, rs, metrics = met_int_stc, eval_time = 2)
     Condition
-      Error in `tune_grid()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 1 unique time was given.
 
 ---
@@ -617,7 +617,7 @@
     Code
       tune_grid(wflow_tune, rs, metrics = met_int_dyn, eval_time = 2)
     Condition
-      Error in `tune_grid()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 1 unique time was given.
 
 ---
@@ -673,7 +673,7 @@
     Code
       last_fit(wflow, split, metrics = met_dyn)
     Condition
-      Error in `last_fit()`:
+      Error in `check_enough_eval_times()`:
       ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric". Only 0 unique times were given.
 
 ---
@@ -681,7 +681,7 @@
     Code
       last_fit(wflow, split, metrics = met_int)
     Condition
-      Error in `last_fit()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric". Only 0 unique times were given.
 
 ---
@@ -689,7 +689,7 @@
     Code
       last_fit(wflow, split, metrics = met_stc_dyn)
     Condition
-      Error in `last_fit()`:
+      Error in `check_enough_eval_times()`:
       ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric" and "static_survival_metric". Only 0 unique times were given.
 
 ---
@@ -697,7 +697,7 @@
     Code
       last_fit(wflow, split, metrics = met_stc_int)
     Condition
-      Error in `last_fit()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 0 unique times were given.
 
 ---
@@ -705,7 +705,7 @@
     Code
       last_fit(wflow, split, metrics = met_dyn_stc)
     Condition
-      Error in `last_fit()`:
+      Error in `check_enough_eval_times()`:
       ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric" and "static_survival_metric". Only 0 unique times were given.
 
 ---
@@ -713,7 +713,7 @@
     Code
       last_fit(wflow, split, metrics = met_dyn_int)
     Condition
-      Error in `last_fit()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 0 unique times were given.
 
 ---
@@ -721,7 +721,7 @@
     Code
       last_fit(wflow, split, metrics = met_int_stc)
     Condition
-      Error in `last_fit()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 0 unique times were given.
 
 ---
@@ -729,7 +729,7 @@
     Code
       last_fit(wflow, split, metrics = met_int_dyn)
     Condition
-      Error in `last_fit()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 0 unique times were given.
 
 ---
@@ -750,7 +750,7 @@
     Code
       last_fit(wflow, split, metrics = met_int, eval_time = 2)
     Condition
-      Error in `last_fit()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric". Only 1 unique time was given.
 
 ---
@@ -763,7 +763,7 @@
     Code
       last_fit(wflow, split, metrics = met_stc_int, eval_time = 2)
     Condition
-      Error in `last_fit()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 1 unique time was given.
 
 ---
@@ -776,7 +776,7 @@
     Code
       last_fit(wflow, split, metrics = met_dyn_int, eval_time = 2)
     Condition
-      Error in `last_fit()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 1 unique time was given.
 
 ---
@@ -784,7 +784,7 @@
     Code
       last_fit(wflow, split, metrics = met_int_stc, eval_time = 2)
     Condition
-      Error in `last_fit()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 1 unique time was given.
 
 ---
@@ -792,7 +792,7 @@
     Code
       last_fit(wflow, split, metrics = met_int_dyn, eval_time = 2)
     Condition
-      Error in `last_fit()`:
+      Error in `check_enough_eval_times()`:
       ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 1 unique time was given.
 
 ---
diff --git a/tests/testthat/_snaps/int_pctl.md b/tests/testthat/_snaps/int_pctl.md
index f5d4acd08..d6803c49c 100644
--- a/tests/testthat/_snaps/int_pctl.md
+++ b/tests/testthat/_snaps/int_pctl.md
@@ -1,7 +1,7 @@
 # percentile intervals - resamples only
 
     Code
-      int_res_1 <- int_pctl(lm_res, times = 200)
+      int_res_1 <- int_pctl(lm_res, times = 500)
     Condition
       Warning:
       Recommend at least 1000 non-missing bootstrap resamples for terms: `rmse`, `rsq`.
@@ -14,6 +14,14 @@
       Error in `int_pctl()`:
       ! `metrics` should be a metric set as generated by `yardstick::metric_set()`.
 
+---
+
+    Code
+      int_res_2 <- int_pctl(lm_res, times = 500, alpha = 0.25)
+    Condition
+      Warning:
+      Recommend at least 1000 non-missing bootstrap resamples for terms: `rmse`, `rsq`.
+
 # percentile intervals - last fit
 
     Code
@@ -30,11 +38,13 @@
       Warning:
       Recommend at least 1000 non-missing bootstrap resamples for term `mae`.
 
-# percentile intervals - tuning
+# percentile intervals - grid + bayes tuning
 
     Code
-      int_res_1 <- int_pctl(c5_res, eval_time = 2)
-    Condition
-      Warning:
-      The 'eval_time' argument is not needed for this data set.
+      int_res_1 <- int_pctl(c5_res)
+
+# percentile intervals - grid tuning with validation set
+
+    Code
+      int_res_1 <- int_pctl(c5_res)
 
diff --git a/tests/testthat/test-int_pctl.R b/tests/testthat/test-int_pctl.R
index da64e92ce..25c7b9a1f 100644
--- a/tests/testthat/test-int_pctl.R
+++ b/tests/testthat/test-int_pctl.R
@@ -23,12 +23,18 @@ test_that("percentile intervals - resamples only", {
       .upper = numeric(0),
       .config = character(0)
     )
-  expect_snapshot(int_res_1 <- int_pctl(lm_res, times = 200))
+  set.seed(1)
+  expect_snapshot(int_res_1 <- int_pctl(lm_res, times = 500))
   expect_equal(int_res_1[0,], template)
   expect_equal(nrow(int_res_1), 2)
 
   expect_snapshot(int_pctl(lm_res, times = 2000, metrics = "rmse"), error = TRUE)
 
+  # check to make sure that alpha works
+  set.seed(1)
+  expect_snapshot(int_res_2 <- int_pctl(lm_res, times = 500, alpha = .25))
+  expect_true(int_res_2$.lower[1] > int_res_1$.lower[1])
+  expect_true(int_res_2$.upper[1] < int_res_1$.upper[1])
 })
 
 
@@ -69,7 +75,7 @@ test_that("percentile intervals - last fit", {
 
 
 
-test_that("percentile intervals - tuning", {
+test_that("percentile intervals - grid + bayes tuning", {
   skip_if_not_installed("modeldata")
   skip_if_not_installed("C50")
   skip_if_not_installed("rsample", minimum_version = "1.1.1.9000")
@@ -79,7 +85,7 @@
 
   data("two_class_dat", package = "modeldata")
   set.seed(1)
-  cls_rs <- validation_split(two_class_dat)
+  cls_rs <- vfold_cv(two_class_dat)
 
   c5_res <-
     decision_tree(min_n = tune()) %>%
@@ -102,7 +108,7 @@
     min_n = numeric(0)
   )
 
-  expect_snapshot(int_res_1 <- int_pctl(c5_res, eval_time = 2))
+  expect_snapshot(int_res_1 <- int_pctl(c5_res))
   expect_equal(int_res_1[0,], template)
   expect_equal(nrow(int_res_1), 3)
 
@@ -137,7 +143,7 @@
   expect_equal(nrow(int_res_2), 4)
   set.seed(1)
   int_res_3 <- int_pctl(c5_bo_res, event_level = "second")
-  expect_true(all(int_res_3$.estimate > int_res_2$.estimate))
+  expect_true(all(int_res_3$.estimate < int_res_2$.estimate))
 
 # ------------------------------------------------------------------------------
 
@@ -167,3 +173,46 @@
 
   expect_equal(nrow(int_res_4), 4)
 })
+
+
+
+test_that("percentile intervals - grid tuning with validation set", {
+  skip_if_not_installed("modeldata")
+  skip_if_not_installed("C50")
+  skip_if_not_installed("rsample", minimum_version = "1.1.1.9000")
+  library(rsample)
+  library(parsnip)
+  library(yardstick)
+
+  data("two_class_dat", package = "modeldata")
+  set.seed(1)
+  cls_split <- initial_validation_split(two_class_dat, prop = c(.8, .15))
+  cls_rs <- validation_set(cls_split)
+
+  c5_res <-
+    decision_tree(min_n = tune()) %>%
+    set_engine("C5.0") %>%
+    set_mode("classification") %>%
+    tune_grid(
+      Class ~.,
+      resamples = cls_rs,
+      grid = dplyr::tibble(min_n = c(5, 20, 40)),
+      metrics = metric_set(sens),
+      control = control_grid(save_pred = TRUE)
+    )
+  template <- dplyr::tibble(
+    .metric = character(0),
+    .estimator = character(0),
+    .lower = numeric(0),
+    .estimate = numeric(0),
+    .upper = numeric(0),
+    .config = character(0),
+    min_n = numeric(0)
+  )
+
+  expect_snapshot(int_res_1 <- int_pctl(c5_res))
+  expect_equal(int_res_1[0,], template)
+  expect_equal(nrow(int_res_1), 3)
+
+})
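A note on the term-uniquifying trick in `int_pctl_surv()` above: `rsample::int_pctl()` requires unique `term` values, but survival results repeat one metric name per evaluation time. The standalone sketch below shows the effect in isolation; only `fake_term()` is taken from the diff, and the metric rows are invented for illustration:

    library(dplyr)

    # Invented metric rows: the same metric name repeats across evaluation
    # times, so `term` alone is not unique.
    mets <- tibble(
      term = c("brier_survival", "brier_survival", "roc_auc_survival"),
      .eval_time = c(1, 5, NA),
      estimate = c(0.21, 0.35, 0.80)
    )

    # From R/int_pctl.R above: append a fixed-width row index so that every
    # term value is unique before calling rsample::int_pctl().
    fake_term <- function(x) {
      x$term <- paste(x$term, format(1:nrow(x)))
      x
    }

    fake_term(mets)$term
    #> [1] "brier_survival 1"   "brier_survival 2"   "roc_auc_survival 3"

After `int_pctl()` runs, `int_pctl_surv()` joins its results back to the saved key (`old_term` plus `order`) so each interval is re-labeled with the real metric name and its `.eval_time`.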
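For reference, a minimal usage sketch of the user-facing changes (not part of the diff; `lm_res` stands in for any regression `tune_results` object fit with `save_pred = TRUE`, as in the tests above):

    library(tune)
    library(yardstick)

    # `alpha` is now passed through to rsample::int_pctl(): alpha = 0.25
    # yields 75% percentile intervals, narrower than the default 95%.
    set.seed(1)
    int_default <- int_pctl(lm_res, times = 500)
    int_narrow  <- int_pctl(lm_res, times = 500, alpha = 0.25)

    # `metrics` must be a metric set; a bare string such as "rmse" now errors.
    int_rmse <- int_pctl(lm_res, times = 500, metrics = metric_set(rmse))

For censored regression results, `eval_time = NULL` now defaults through `maybe_choose_eval_time()`, while user-supplied times are validated with `check_eval_time_in_tune_results()` and `check_enough_eval_times()`.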