bug fixes for percentile intervals on survival metrics #818

Merged (11 commits, Jan 24, 2024)
Changes from 5 commits
1 change: 1 addition & 0 deletions NAMESPACE
@@ -210,6 +210,7 @@ export(is_recipe)
export(is_workflow)
export(last_fit)
export(load_pkgs)
export(maybe_choose_eval_time)
Comment (Member):
Why are we exporting this?

Reply (Member Author):
Mostly future-proofing; I can see us needing this a lot inside of tune but also maybe finetune or future extensions. There's no harm in exporting and we've re-exported a fair number of functions that were originally internal.

export(message_wrap)
export(metrics_info)
export(min_grid)
87 changes: 57 additions & 30 deletions R/int_pctl.R
@@ -73,8 +73,27 @@ int_pctl.tune_results <- function(.data, metrics = NULL, times = 1001,

rlang::check_dots_empty()

# check eval_time and set default when null
eval_time <- default_eval_time(eval_time, .data$.metrics[[1]])
if (is.null(metrics)) {
metrics <- .get_tune_metrics(.data)
}

if (is.null(eval_time)) {
eval_time <- .get_tune_eval_times(.data)
eval_time <- maybe_choose_eval_time(.data, metrics, eval_time)
} else {
eval_time <- unique(eval_time)
check_eval_time_in_tune_results(.data, eval_time)
num_times <- length(eval_time)
max_times_req <- req_eval_times(metrics)
cls <- tibble::as_tibble(metrics)$class
uni_cls <- sort(unique(cls))
if (max_times_req > num_times) {
cli::cli_abort("At least {max_times_req} evaluation time{?s} {?is/are}
required for the metric type(s) requested: {.val {uni_cls}}.
Only {num_times} unique time{?s} {?was/were} given.")
}
}

.data$.predictions <- filter_predictions_by_eval_time(.data$.predictions, eval_time)

y_nm <- outcome_names(.data)
@@ -89,7 +108,8 @@ int_pctl.tune_results <- function(.data, metrics = NULL, times = 1001,
res <-
purrr::map2(
config_keys, sample.int(10000, p),
~ boostrap_metrics_by_config(.x, .y, .data, metrics, times, allow_par, event_level)
~ boostrap_metrics_by_config(.x, .y, .data, metrics, times, allow_par,
event_level, eval_time, alpha)
) %>%
purrr::list_rbind() %>%
dplyr::arrange(.config, .metric)
@@ -106,7 +126,8 @@ get_int_p_operator <- function(allow = TRUE) {
res
}

boostrap_metrics_by_config <- function(config, seed, x, metrics, times, allow_par, event_level) {
boostrap_metrics_by_config <- function(config, seed, x, metrics, times, allow_par,
event_level, eval_time, alpha) {
y_nm <- outcome_names(x)
preds <- collect_predictions(x, summarize = TRUE, parameters = config)

@@ -120,14 +141,13 @@ boostrap_metrics_by_config <- function(config, seed, x, metrics, times, allow_pa
.errorhandling = "pass",
.packages = c("tune", "rsample")
) %op% {
comp_metrics(rs$splits[[i]], y_nm, metrics, event_level)
comp_metrics(rs$splits[[i]], y_nm, metrics, event_level, eval_time)
}

if (any(grepl("survival", .get_tune_metric_names(x)))) {
# compute by evaluation time
res <- int_pctl_dyn_surv(rs, allow_par)
res <- int_pctl_surv(rs, allow_par, alpha)
} else {
res <- rsample::int_pctl(rs, .metrics)
res <- rsample::int_pctl(rs, .metrics, alpha = alpha)
}
res %>%
dplyr::mutate(.estimator = "bootstrap") %>%
@@ -137,29 +157,36 @@ boostrap_metrics_by_config <- function(config, seed, x, metrics, times, allow_pa
cbind(config)
}

# We have to do the analysis separately for each evaluation time.
int_pctl_dyn_surv <- function(x, allow_par) {
`%op%` <- get_int_p_operator(allow_par)
times <- unique(x$.metrics[[1]]$.eval_time)
res <-
foreach::foreach(
i = seq_along(times),
.errorhandling = "pass",
.packages = c("purrr", "rsample", "dplyr")
) %op% {
int_pctl_by_eval_time(times[i], x)
}
dplyr::bind_rows(res)
fake_term <- function(x) {
x$term <- paste(x$term, format(1:nrow(x)))
x
}

int_pctl_by_eval_time <- function(time, x) {
times <- dplyr::tibble(.eval_time = time)
x$.metrics <- purrr::map(x$.metrics, ~ dplyr::inner_join(.x, times, by = ".eval_time"))
rsample::int_pctl(x, .metrics) %>%
dplyr::mutate(.eval_time = time) %>%
dplyr::relocate(.eval_time, .after = term)
}
# tests in extratests
# nocov start
int_pctl_surv <- function(x, allow_par, alpha) {
`%op%` <- get_int_p_operator(allow_par)

# int_pctl() expects terms to be unique. For (many) survival models, the
# metrics are a combination of the metric name and the evaluation time.
# We'll make a phony term value, run int_pctl(), then merge the original values
# back in.
met_key <- x$.metrics[[1]]
met_key$estimate <- NULL
met_key$old_term <- met_key$term
met_key$order <- 1:nrow(met_key)
met_key <- fake_term(met_key)

x$.metrics <- purrr::map(x$.metrics, ~ fake_term(.x))
res <-
rsample::int_pctl(x, .metrics, alpha = alpha) %>%
dplyr::full_join(met_key, by = "term") %>%
dplyr::arrange(order) %>%
dplyr::select(-term, -order) %>%
dplyr::rename(term = old_term) %>%
dplyr::relocate(term, dplyr::any_of(".eval_time"))
}
# nocov end
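
Aside, not part of the diff: a minimal sketch of the phony-term trick described in the comment above. The names `mets` and `fake_term_demo()` are illustrative only; the point is that appending a row index makes each metric/evaluation-time combination a unique `term` for `rsample::int_pctl()`, while a key keeps the original values so they can be joined back in afterwards.

```r
library(tibble)
library(dplyr)

# Hypothetical per-bootstrap metric results: two dynamic survival metrics,
# each computed at two evaluation times, so `term` alone is not unique.
mets <- tibble(
  term       = c("brier_survival", "brier_survival",
                 "roc_auc_survival", "roc_auc_survival"),
  .eval_time = c(1, 5, 1, 5),
  estimate   = c(0.10, 0.18, 0.85, 0.80)
)

# Same idea as fake_term() above: paste on a row index so terms are unique.
fake_term_demo <- function(x) {
  x$term <- paste(x$term, format(1:nrow(x)))
  x
}

# A key that remembers the original term and row order; after int_pctl()
# runs on the unique terms, joining on `term` restores the metric names
# and evaluation times.
met_key <- mets %>%
  dplyr::select(-estimate) %>%
  dplyr::mutate(old_term = term, order = dplyr::row_number()) %>%
  fake_term_demo()

met_key$term
#> "brier_survival 1" "brier_survival 2" "roc_auc_survival 3" "roc_auc_survival 4"
```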

# ------------------------------------------------------------------------------

@@ -180,7 +207,7 @@ get_configs <- function(x, parameters = NULL, as_list = TRUE) {
}

# Compute metrics for a specific configuration
comp_metrics <- function(split, y, metrics, event_level) {
comp_metrics <- function(split, y, metrics, event_level, eval_time) {
dat <- rsample::analysis(split)
info <- metrics_info(metrics)

@@ -192,7 +219,7 @@ comp_metrics <- function(split, y, metrics, event_level) {
outcome_name = y,
event_level = event_level,
metrics_info = info,
eval_time = NA # TODO I don't think that this is used in the function
eval_time = eval_time
)

res %>%
29 changes: 25 additions & 4 deletions R/metric-selection.R
@@ -26,6 +26,10 @@
#' If a time is required and none is given, the first value in the vector
#' originally given in the `eval_time` argument is used (with a warning).
#'
#' `maybe_choose_eval_time()` is for cases where multiple evaluation times are
#' acceptable but you need to choose a good default. The "maybe" is because
#' the function that would use `maybe_choose_eval_time()` can accept multiple
#' metrics (like [autoplot()]).
#' @keywords internal
#' @export
choose_metric <- function(x, metric, ..., call = rlang::caller_env()) {
@@ -78,14 +82,13 @@ contains_survival_metric <- function(mtr_info) {
# choose_eval_time() is called by show_best() and select_best()
#' @rdname choose_metric
#' @export
choose_eval_time <- function(x, metric, eval_time = NULL, ..., call = rlang::caller_env()) {
rlang::check_dots_empty()
choose_eval_time <- function(x, metric, eval_time = NULL, quietly = FALSE, call = rlang::caller_env()) {

mtr_set <- .get_tune_metrics(x)
mtr_info <- tibble::as_tibble(mtr_set)

if (!contains_survival_metric(mtr_info)) {
if (!is.null(eval_time)) {
if (!is.null(eval_time) & !quietly) {
cli::cli_warn("Evaluation times are only required when the model
mode is {.val censored regression} (and will be ignored).",
call = call)
@@ -96,7 +99,7 @@ choose_eval_time <- function(x, metric, eval_time = NULL, ..., call = rlang::cal
dyn_metric <- is_dyn(mtr_set, metric)

# If we don't need an eval time but one is passed:
if (!dyn_metric & !is.null(eval_time)) {
if (!dyn_metric & !is.null(eval_time) & !quietly) {
cli::cli_warn("An evaluation time is only required when a dynamic
metric is selected (and {.arg eval_time} will thus be
ignored).",
@@ -134,6 +137,24 @@ check_eval_time_in_tune_results <- function(x, eval_time, call = rlang::caller_e
invisible(NULL)
}

#' @rdname choose_metric
#' @export
maybe_choose_eval_time <- function(x, mtr_set, eval_time) {
mtr_info <- tibble::as_tibble(mtr_set)
if (any(grepl("integrated", mtr_info$metric))) {
return(.get_tune_eval_times(x))
}
eval_time <- purrr::map(mtr_info$metric, ~ choose_eval_time(x, .x, eval_time = eval_time, quietly = TRUE))
no_eval_time <- purrr::map_lgl(eval_time, is.null)
if (all(no_eval_time)) {
eval_time <- NULL
} else {
eval_time <- sort(unique(unlist(eval_time)))
}
eval_time
}
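
Aside, not part of the diff: a rough usage sketch of the selection logic above. Here `surv_res` is assumed to be an existing `tune_results` object from a censored-regression tuning run with `eval_time = c(1, 5, 10)`; since that object is not built here, the `maybe_choose_eval_time()` calls are shown as comments.

```r
library(yardstick)

# Only a dynamic metric: a single default time is chosen (per the docs above,
# the first value originally given to the eval_time argument, here 1).
dyn_metrics <- metric_set(brier_survival)
# maybe_choose_eval_time(surv_res, dyn_metrics, eval_time = NULL)

# Any integrated metric present: the full vector of originally supplied
# evaluation times is returned, since integration needs all of them.
int_metrics <- metric_set(brier_survival_integrated)
# maybe_choose_eval_time(surv_res, int_metrics, eval_time = NULL)

# Only a static metric: no evaluation time is needed, so the result is NULL.
stat_metrics <- metric_set(concordance_survival)
# maybe_choose_eval_time(surv_res, stat_metrics, eval_time = NULL)
```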


# ------------------------------------------------------------------------------

#' @rdname choose_metric
16 changes: 15 additions & 1 deletion man/choose_metric.Rd

Some generated files are not rendered by default.

22 changes: 16 additions & 6 deletions tests/testthat/_snaps/int_pctl.md
@@ -1,7 +1,15 @@
# percentile intervals - resamples only

Code
int_res_1 <- int_pctl(lm_res, times = 200)
int_res_1 <- int_pctl(lm_res, times = 500)
Condition
Warning:
Recommend at least 1000 non-missing bootstrap resamples for terms: `rmse`, `rsq`.

---

Code
int_res_2 <- int_pctl(lm_res, times = 500, alpha = 0.25)
Condition
Warning:
Recommend at least 1000 non-missing bootstrap resamples for terms: `rmse`, `rsq`.
@@ -22,11 +30,13 @@
Warning:
Recommend at least 1000 non-missing bootstrap resamples for term `mae`.

# percentile intervals - tuning
# percentile intervals - grid + bayes tuning

Code
int_res_1 <- int_pctl(c5_res, eval_time = 2)
Condition
Warning:
The 'eval_time' argument is not needed for this data set.
int_res_1 <- int_pctl(c5_res)

# percentile intervals - grid tuning with validation set

Code
int_res_1 <- int_pctl(c5_res)

60 changes: 55 additions & 5 deletions tests/testthat/test-int_pctl.R
@@ -23,10 +23,17 @@ test_that("percentile intervals - resamples only", {
.upper = numeric(0),
.config = character(0)
)
expect_snapshot(int_res_1 <- int_pctl(lm_res, times = 200))
set.seed(1)
expect_snapshot(int_res_1 <- int_pctl(lm_res, times = 500))
expect_equal(int_res_1[0,], template)
expect_equal(nrow(int_res_1), 2)

# check to make sure that alpha works
set.seed(1)
expect_snapshot(int_res_2 <- int_pctl(lm_res, times = 500, alpha = .25))
expect_true(int_res_2$.lower[1] > int_res_1$.lower[1])
expect_true(int_res_2$.upper[1] < int_res_1$.upper[1])

})


@@ -67,7 +74,7 @@ test_that("percentile intervals - last fit", {



test_that("percentile intervals - tuning", {
test_that("percentile intervals - grid + bayes tuning", {
skip_if_not_installed("modeldata")
skip_if_not_installed("C50")
skip_if_not_installed("rsample", minimum_version = "1.1.1.9000")
@@ -77,7 +84,7 @@ test_that("percentile intervals - tuning", {

data("two_class_dat", package = "modeldata")
set.seed(1)
cls_rs <- validation_split(two_class_dat)
cls_rs <- vfold_cv(two_class_dat)

c5_res <-
decision_tree(min_n = tune()) %>%
@@ -100,7 +107,7 @@ test_that("percentile intervals - tuning", {
min_n = numeric(0)
)

expect_snapshot(int_res_1 <- int_pctl(c5_res, eval_time = 2))
expect_snapshot(int_res_1 <- int_pctl(c5_res))
expect_equal(int_res_1[0,], template)
expect_equal(nrow(int_res_1), 3)

Expand Down Expand Up @@ -135,7 +142,7 @@ test_that("percentile intervals - tuning", {
expect_equal(nrow(int_res_2), 4)
set.seed(1)
int_res_3 <- int_pctl(c5_bo_res, event_level = "second")
expect_true(all(int_res_3$.estimate > int_res_2$.estimate))
expect_true(all(int_res_3$.estimate < int_res_2$.estimate))

# ------------------------------------------------------------------------------

@@ -165,3 +172,46 @@ test_that("percentile intervals - tuning", {
expect_equal(nrow(int_res_4), 4)
})




test_that("percentile intervals - grid tuning with validation set", {
skip_if_not_installed("modeldata")
skip_if_not_installed("C50")
skip_if_not_installed("rsample", minimum_version = "1.1.1.9000")
library(rsample)
library(parsnip)
library(yardstick)

data("two_class_dat", package = "modeldata")
set.seed(1)
cls_split <- initial_validation_split(two_class_dat, prop = c(.8, .15))
cls_rs <- validation_set(cls_split)

c5_res <-
decision_tree(min_n = tune()) %>%
set_engine("C5.0") %>%
set_mode("classification") %>%
tune_grid(
Class ~.,
resamples = cls_rs,
grid = dplyr::tibble(min_n = c(5, 20, 40)),
metrics = metric_set(sens),
control = control_grid(save_pred = TRUE)
)
template <- dplyr::tibble(
.metric = character(0),
.estimator = character(0),
.lower = numeric(0),
.estimate = numeric(0),
.upper = numeric(0),
.config = character(0),
min_n = numeric(0)
)

expect_snapshot(int_res_1 <- int_pctl(c5_res))
expect_equal(int_res_1[0,], template)
expect_equal(nrow(int_res_1), 3)

})
