bug fixes for percentile intervals on survival metrics (#818)
* add maybe_choose_eval_time

* re-write survival bits

* update/fix tests

* nocov

* added additional test

* version bump

* changes based on reviewer feedback

* updates to work with tidymodels/rsample#465

* update snapshots

---------

Co-authored-by: topepo <[email protected]>
topepo and topepo authored Jan 24, 2024
1 parent 4181d91 commit 59612cf
Showing 9 changed files with 246 additions and 108 deletions.
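
In practical terms, the change targets calls like the sketch below, where percentile intervals are requested for dynamic survival metrics at one or more evaluation times. This is an illustrative sketch rather than part of the commit: `surv_res` stands in for a tune_results object from a censored regression workflow tuned with dynamic survival metrics at the same evaluation times, and the argument values are only examples.

library(tune)
library(rsample)  # provides the int_pctl() generic

# surv_res: hypothetical tune_results from tuning a censored regression
# workflow with dynamic survival metrics evaluated at times 1, 5, and 10
set.seed(403)
intervals <- int_pctl(
  surv_res,
  times = 1001,             # bootstrap resamples of the predictions
  eval_time = c(1, 5, 10),  # must be among the evaluation times used in tuning
  alpha = 0.05              # 95% percentile intervals
)
intervals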
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,6 +1,6 @@
Package: tune
Title: Tidy Tuning Tools
Version: 1.1.2.9014
Version: 1.1.2.9015
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0003-2402-136X")),
1 change: 1 addition & 0 deletions NAMESPACE
@@ -210,6 +210,7 @@ export(is_recipe)
export(is_workflow)
export(last_fit)
export(load_pkgs)
export(maybe_choose_eval_time)
export(message_wrap)
export(metrics_info)
export(min_grid)
95 changes: 65 additions & 30 deletions R/int_pctl.R
@@ -73,8 +73,24 @@ int_pctl.tune_results <- function(.data, metrics = NULL, times = 1001,

rlang::check_dots_empty()

# check eval_time and set default when null
eval_time <- default_eval_time(eval_time, .data$.metrics[[1]])
if (is.null(metrics)) {
metrics <- .get_tune_metrics(.data)
} else {
if (!inherits(metrics, "metric_set")) {
cli::cli_abort("{.arg metrics} should be a metric set as generated by {.fun yardstick::metric_set}.")
}
}

if (is.null(eval_time)) {
eval_time <- .get_tune_eval_times(.data)
eval_time <- maybe_choose_eval_time(.data, metrics, eval_time)
} else {
eval_time <- unique(eval_time)
check_eval_time_in_tune_results(.data, eval_time)
# Are there at least a minimal number of evaluation times?
check_enough_eval_times(eval_time, metrics)
}

.data$.predictions <- filter_predictions_by_eval_time(.data$.predictions, eval_time)

y_nm <- outcome_names(.data)
@@ -90,10 +106,19 @@
}
}

# TODO Changes in https://github.com/tidymodels/rsample/pull/465
# will affect how these computations are done since they will
# compute intervals for `terms` as well as any columns that begin
# with a period. This will simplify the code considerably for
# survival and non-survival models. We will make this version
# compatible with the future rsample version (but will refactor
# this code later).

res <-
purrr::map2(
config_keys, sample.int(10000, p),
~ boostrap_metrics_by_config(.x, .y, .data, metrics, times, allow_par, event_level)
~ boostrap_metrics_by_config(.x, .y, .data, metrics, times, allow_par,
event_level, eval_time, alpha)
) %>%
purrr::list_rbind() %>%
dplyr::arrange(.config, .metric)
@@ -110,7 +135,8 @@ get_int_p_operator <- function(allow = TRUE) {
res
}

boostrap_metrics_by_config <- function(config, seed, x, metrics, times, allow_par, event_level) {
boostrap_metrics_by_config <- function(config, seed, x, metrics, times, allow_par,
event_level, eval_time, alpha) {
y_nm <- outcome_names(x)
preds <- collect_predictions(x, summarize = TRUE, parameters = config)

@@ -124,14 +150,13 @@ boostrap_metrics_by_config <- function(config, seed, x, metrics, times, allow_par,
.errorhandling = "pass",
.packages = c("tune", "rsample")
) %op% {
comp_metrics(rs$splits[[i]], y_nm, metrics, event_level)
comp_metrics(rs$splits[[i]], y_nm, metrics, event_level, eval_time)
}

if (any(grepl("survival", .get_tune_metric_names(x)))) {
# compute by evaluation time
res <- int_pctl_dyn_surv(rs, allow_par)
res <- int_pctl_surv(rs, allow_par, alpha)
} else {
res <- rsample::int_pctl(rs, .metrics)
res <- rsample::int_pctl(rs, .metrics, alpha = alpha)
}
res %>%
dplyr::mutate(.estimator = "bootstrap") %>%
@@ -141,29 +166,39 @@ boostrap_metrics_by_config <- function(config, seed, x, metrics, times, allow_par,
cbind(config)
}

# We have to do the analysis separately for each evaluation time.
int_pctl_dyn_surv <- function(x, allow_par) {
`%op%` <- get_int_p_operator(allow_par)
times <- unique(x$.metrics[[1]]$.eval_time)
res <-
foreach::foreach(
i = seq_along(times),
.errorhandling = "pass",
.packages = c("purrr", "rsample", "dplyr")
) %op% {
int_pctl_by_eval_time(times[i], x)
}
dplyr::bind_rows(res)
fake_term <- function(x) {
x$term <- paste(x$term, format(1:nrow(x)))
x
}

int_pctl_by_eval_time <- function(time, x) {
times <- dplyr::tibble(.eval_time = time)
x$.metrics <- purrr::map(x$.metrics, ~ dplyr::inner_join(.x, times, by = ".eval_time"))
rsample::int_pctl(x, .metrics) %>%
dplyr::mutate(.eval_time = time) %>%
dplyr::relocate(.eval_time, .after = term)
}
# tests in extratests
# nocov start
int_pctl_surv <- function(x, allow_par, alpha) {

# int_pctl() expects terms to be unique. For (many) survival models, the
# metrics are a combination of the metric name and the evaluation time.
# We'll make a phony term value, run int_pctl(), then merge the original values
# back in.
met_key <- x$.metrics[[1]]
met_key$estimate <- NULL
met_key$old_term <- met_key$term
met_key$order <- 1:nrow(met_key)
met_key <- fake_term(met_key)

x$.metrics <- purrr::map(x$.metrics, ~ fake_term(.x))
res <- rsample::int_pctl(x, .metrics, alpha = alpha)

merge_keys <- c("term", grep("^\\.", names(res), value = TRUE))
merge_keys <- intersect(merge_keys, names(met_key))

res <- res %>%
dplyr::full_join(met_key, by = merge_keys) %>%
dplyr::arrange(order) %>%
dplyr::select(-term, -order) %>%
dplyr::rename(term = old_term) %>%
dplyr::relocate(term, dplyr::any_of(".eval_time"))
}
# nocov end
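
As a side note for readers (not part of the commit): the snippet below shows how `fake_term()` makes each metric/evaluation-time row unique so that `rsample::int_pctl()` can bootstrap every combination as its own term. The toy tibble and its values are made up.

library(tibble)

fake_term <- function(x) {
  x$term <- paste(x$term, format(1:nrow(x)))
  x
}

# toy metric rows with a duplicated term name across two evaluation times
mets <- tibble(
  term = c("brier_survival", "brier_survival"),
  .eval_time = c(5, 10),
  estimate = c(0.12, 0.21)
)

fake_term(mets)
# term becomes "brier_survival 1" and "brier_survival 2", so each
# evaluation time is treated as a distinct term by int_pctl()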

# ------------------------------------------------------------------------------

@@ -184,7 +219,7 @@ get_configs <- function(x, parameters = NULL, as_list = TRUE) {
}

# Compute metrics for a specific configuration
comp_metrics <- function(split, y, metrics, event_level) {
comp_metrics <- function(split, y, metrics, event_level, eval_time) {
dat <- rsample::analysis(split)
info <- metrics_info(metrics)

@@ -196,7 +231,7 @@ comp_metrics <- function(split, y, metrics, event_level) {
outcome_name = y,
event_level = event_level,
metrics_info = info,
eval_time = NA # TODO I don't think that this is used in the function
eval_time = eval_time
)

res %>%
53 changes: 41 additions & 12 deletions R/metric-selection.R
@@ -27,6 +27,10 @@
#' If a time is required and none is given, the first value in the vector
#' originally given in the `eval_time` argument is used (with a warning).
#'
#' `maybe_choose_eval_time()` is for cases where multiple evaluation times are
#' acceptable but a sensible default needs to be chosen. The "maybe" is because
#' the functions that use `maybe_choose_eval_time()` (such as [autoplot()]) can
#' accept multiple metrics, so an evaluation time may or may not be needed.
#' @keywords internal
#' @export
choose_metric <- function(x, metric, ..., call = rlang::caller_env()) {
@@ -71,22 +75,34 @@ check_metric_in_tune_results <- function(mtr_info, metric, call = rlang::caller_env()) {
invisible(NULL)
}

check_enough_eval_times <- function(eval_time, metrics) {
num_times <- length(eval_time)
max_times_req <- req_eval_times(metrics)
cls <- tibble::as_tibble(metrics)$class
uni_cls <- sort(unique(cls))
if (max_times_req > num_times) {
cli::cli_abort("At least {max_times_req} evaluation time{?s} {?is/are}
required for the metric type(s) requested: {.val {uni_cls}}.
Only {num_times} unique time{?s} {?was/were} given.")
}

invisible(NULL)
}

contains_survival_metric <- function(mtr_info) {
any(grepl("_survival", mtr_info$class))
}


# choose_eval_time() is called by show_best() and select_best()
#' @rdname choose_metric
#' @export
choose_eval_time <- function(x, metric, eval_time = NULL, ..., call = rlang::caller_env()) {
rlang::check_dots_empty()
choose_eval_time <- function(x, metric, eval_time = NULL, quietly = FALSE, call = rlang::caller_env()) {

mtr_set <- .get_tune_metrics(x)
mtr_info <- tibble::as_tibble(mtr_set)

if (!contains_survival_metric(mtr_info)) {
if (!is.null(eval_time)) {
if (!is.null(eval_time) & !quietly) {
cli::cli_warn("Evaluation times are only required when the model
mode is {.val censored regression} (and will be ignored).",
call = call)
@@ -97,7 +113,7 @@ choose_eval_time <- function(x, metric, eval_time = NULL, ..., call = rlang::caller_env()) {
dyn_metric <- is_dyn(mtr_set, metric)

# If we don't need an eval time but one is passed:
if (!dyn_metric & !is.null(eval_time)) {
if (!dyn_metric & !is.null(eval_time) & !quietly) {
cli::cli_warn("An evaluation time is only required when a dynamic
metric is selected (and {.arg eval_time} will thus be
ignored).",
Expand Down Expand Up @@ -135,6 +151,24 @@ check_eval_time_in_tune_results <- function(x, eval_time, call = rlang::caller_e
invisible(NULL)
}

#' @rdname choose_metric
#' @export
maybe_choose_eval_time <- function(x, mtr_set, eval_time) {
mtr_info <- tibble::as_tibble(mtr_set)
if (any(grepl("integrated", mtr_info$metric))) {
return(.get_tune_eval_times(x))
}
eval_time <- purrr::map(mtr_info$metric, ~ choose_eval_time(x, .x, eval_time = eval_time, quietly = TRUE))
no_eval_time <- purrr::map_lgl(eval_time, is.null)
if (all(no_eval_time)) {
eval_time <- NULL
} else {
eval_time <- sort(unique(unlist(eval_time)))
}
eval_time
}
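
A rough usage sketch (the objects here are hypothetical, not from the commit): given tuning results from a censored regression model, a metric set, and a candidate vector of evaluation times, `maybe_choose_eval_time()` returns a suitable set of times, or `NULL` if none of the metrics requires one.

library(tune)
library(yardstick)

# surv_res: hypothetical tune_results from a censored regression tuning run
mtrs <- metric_set(brier_survival, concordance_survival)
maybe_choose_eval_time(surv_res, mtrs, eval_time = c(1, 5, 10))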


# ------------------------------------------------------------------------------

#' @rdname choose_metric
@@ -289,15 +323,10 @@ check_eval_time_arg <- function(eval_time, mtr_set, call = rlang::caller_env()) {
eval_time <- .filter_eval_time(eval_time)

num_times <- length(eval_time)

max_times_req <- req_eval_times(mtr_set)

if (max_times_req > num_times) {
cli::cli_abort("At least {max_times_req} evaluation time{?s} {?is/are}
required for the metric type(s) requested: {.val {uni_cls}}.
Only {num_times} unique time{?s} {?was/were} given.",
call = call)
}
# Are there at least a minimal number of evaluation times?
check_enough_eval_times(eval_time, mtr_set)

if (max_times_req == 0 & num_times > 0) {
cli::cli_warn("Evaluation times are only required when dynamic or
16 changes: 15 additions & 1 deletion man/choose_metric.Rd

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion tests/testthat/_snaps/censored-reg.md
@@ -3,7 +3,7 @@
Code
spec %>% tune_grid(Surv(time, status) ~ ., resamples = rs, metrics = mtr)
Condition
Error in `tune_grid()`:
Error in `check_enough_eval_times()`:
! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric". Only 0 unique times were given.

---