bug fixes for percentile intervals on survival metrics #818

Merged (11 commits, Jan 24, 2024)
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,6 +1,6 @@
Package: tune
Title: Tidy Tuning Tools
Version: 1.1.2.9014
Version: 1.1.2.9015
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0003-2402-136X")),
1 change: 1 addition & 0 deletions NAMESPACE
@@ -210,6 +210,7 @@ export(is_recipe)
export(is_workflow)
export(last_fit)
export(load_pkgs)
export(maybe_choose_eval_time)
Member:
Why are we exporting this?

Member Author:
Mostly future-proofing; I can see us needing this a lot inside of tune but also maybe finetune or future extensions. There's no harm in exporting and we've re-exported a fair number of functions that were originally internal.

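A minimal usage sketch (hypothetical, not from this PR): assuming `res` is a `tune_results` object from a censored regression model and that the developer helpers referenced in this diff are available, an extension package could pick default evaluation times like this:

    library(tune)
    mtrs  <- yardstick::metric_set(yardstick::brier_survival)
    times <- .get_tune_eval_times(res)        # candidate times stored with the results
    maybe_choose_eval_time(res, mtrs, times)  # returns NULL or a sensible default set
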
export(message_wrap)
export(metrics_info)
export(min_grid)
95 changes: 65 additions & 30 deletions R/int_pctl.R
@@ -73,8 +73,24 @@ int_pctl.tune_results <- function(.data, metrics = NULL, times = 1001,

rlang::check_dots_empty()

# check eval_time and set default when null
eval_time <- default_eval_time(eval_time, .data$.metrics[[1]])
if (is.null(metrics)) {
metrics <- .get_tune_metrics(.data)
} else {
if (!inherits(metrics, "metric_set")) {
cli::cli_abort("{.arg metrics} should be a metric set as generated by {.fun yardstick::metric_set}.")
}
}

if (is.null(eval_time)) {
eval_time <- .get_tune_eval_times(.data)
eval_time <- maybe_choose_eval_time(.data, metrics, eval_time)
} else {
eval_time <- unique(eval_time)
check_eval_time_in_tune_results(.data, eval_time)
# Are there at least a minimal number of evaluation times?
check_enough_eval_times(eval_time, metrics)
}

.data$.predictions <- filter_predictions_by_eval_time(.data$.predictions, eval_time)

y_nm <- outcome_names(.data)
@@ -90,10 +106,19 @@
}
}

# TODO Changes in https://github.com/tidymodels/rsample/pull/465
# will affect how these computations are done since they will
# compute intervals for `terms` as well as any columns that begin
# with a period. This will simplify the code considerably for
# survival and non-survival models. We will make this version
# compatible with the future rsample version (but will refactor
# this code later).

res <-
purrr::map2(
config_keys, sample.int(10000, p),
~ boostrap_metrics_by_config(.x, .y, .data, metrics, times, allow_par, event_level)
~ boostrap_metrics_by_config(.x, .y, .data, metrics, times, allow_par,
event_level, eval_time, alpha)
) %>%
purrr::list_rbind() %>%
dplyr::arrange(.config, .metric)
@@ -110,7 +135,8 @@ get_int_p_operator <- function(allow = TRUE) {
res
}

boostrap_metrics_by_config <- function(config, seed, x, metrics, times, allow_par, event_level) {
boostrap_metrics_by_config <- function(config, seed, x, metrics, times, allow_par,
event_level, eval_time, alpha) {
y_nm <- outcome_names(x)
preds <- collect_predictions(x, summarize = TRUE, parameters = config)

@@ -124,14 +150,13 @@ boostrap_metrics_by_config <- function(config, seed, x, metrics, times, allow_pa
.errorhandling = "pass",
.packages = c("tune", "rsample")
) %op% {
comp_metrics(rs$splits[[i]], y_nm, metrics, event_level)
comp_metrics(rs$splits[[i]], y_nm, metrics, event_level, eval_time)
}

if (any(grepl("survival", .get_tune_metric_names(x)))) {
# compute by evaluation time
res <- int_pctl_dyn_surv(rs, allow_par)
res <- int_pctl_surv(rs, allow_par, alpha)
} else {
res <- rsample::int_pctl(rs, .metrics)
res <- rsample::int_pctl(rs, .metrics, alpha = alpha)
}
res %>%
dplyr::mutate(.estimator = "bootstrap") %>%
@@ -141,29 +166,39 @@ boostrap_metrics_by_config <- function(config, seed, x, metrics, times, allow_pa
cbind(config)
}

# We have to do the analysis separately for each evaluation time.
int_pctl_dyn_surv <- function(x, allow_par) {
`%op%` <- get_int_p_operator(allow_par)
times <- unique(x$.metrics[[1]]$.eval_time)
res <-
foreach::foreach(
i = seq_along(times),
.errorhandling = "pass",
.packages = c("purrr", "rsample", "dplyr")
) %op% {
int_pctl_by_eval_time(times[i], x)
}
dplyr::bind_rows(res)
fake_term <- function(x) {
x$term <- paste(x$term, format(1:nrow(x)))
x
}

int_pctl_by_eval_time <- function(time, x) {
times <- dplyr::tibble(.eval_time = time)
x$.metrics <- purrr::map(x$.metrics, ~ dplyr::inner_join(.x, times, by = ".eval_time"))
rsample::int_pctl(x, .metrics) %>%
dplyr::mutate(.eval_time = time) %>%
dplyr::relocate(.eval_time, .after = term)
}
# tests in extratests
# nocov start
int_pctl_surv <- function(x, allow_par, alpha) {

# int_pctl() expects terms to be unique. For (many) survival models, the
# metrics are a combination of the metric name and the evaluation time.
# We'll make a phony term value, run int_pctl(), then merge the original values
# back in.
met_key <- x$.metrics[[1]]
met_key$estimate <- NULL
met_key$old_term <- met_key$term
met_key$order <- 1:nrow(met_key)
met_key <- fake_term(met_key)

x$.metrics <- purrr::map(x$.metrics, ~ fake_term(.x))
res <- rsample::int_pctl(x, .metrics, alpha = alpha)

merge_keys <- c("term", grep("^\\.", names(res), value = TRUE))
merge_keys <- intersect(merge_keys, names(met_key))

res <- res %>%
dplyr::full_join(met_key, by = merge_keys) %>%
dplyr::arrange(order) %>%
dplyr::select(-term, -order) %>%
dplyr::rename(term = old_term) %>%
dplyr::relocate(term, dplyr::any_of(".eval_time"))
}
# nocov end
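# Illustrative sketch (not part of this change): fake_term() makes duplicated
# survival terms unique by appending a row index, so rsample::int_pctl() can
# treat each metric/eval-time combination as its own term. For example, on a
# hypothetical two-row metric tibble:
#   tibble::tibble(term = c("brier_survival", "brier_survival"),
#                  estimate = c(0.1, 0.2)) %>% fake_term()
#   # term becomes "brier_survival 1" and "brier_survival 2"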

# ------------------------------------------------------------------------------

@@ -184,7 +219,7 @@ get_configs <- function(x, parameters = NULL, as_list = TRUE) {
}

# Compute metrics for a specific configuration
comp_metrics <- function(split, y, metrics, event_level) {
comp_metrics <- function(split, y, metrics, event_level, eval_time) {
dat <- rsample::analysis(split)
info <- metrics_info(metrics)

@@ -196,7 +231,7 @@ comp_metrics <- function(split, y, metrics, event_level) {
outcome_name = y,
event_level = event_level,
metrics_info = info,
eval_time = NA # TODO I don't think that this is used in the function
eval_time = eval_time
)

res %>%
53 changes: 41 additions & 12 deletions R/metric-selection.R
@@ -27,6 +27,10 @@
#' If a time is required and none is given, the first value in the vector
#' originally given in the `eval_time` argument is used (with a warning).
#'
#' `maybe_choose_eval_time()` is for cases where multiple evaluation times are
#' acceptable but a good default needs to be chosen. The "maybe" is because
#' the function that calls `maybe_choose_eval_time()` can accept multiple
#' metrics (as [autoplot()] does).
#' @keywords internal
#' @export
choose_metric <- function(x, metric, ..., call = rlang::caller_env()) {
@@ -71,22 +75,34 @@ check_metric_in_tune_results <- function(mtr_info, metric, call = rlang::caller_
invisible(NULL)
}

check_enough_eval_times <- function(eval_time, metrics) {
num_times <- length(eval_time)
max_times_req <- req_eval_times(metrics)
cls <- tibble::as_tibble(metrics)$class
uni_cls <- sort(unique(cls))
if (max_times_req > num_times) {
cli::cli_abort("At least {max_times_req} evaluation time{?s} {?is/are}
required for the metric type(s) requested: {.val {uni_cls}}.
Only {num_times} unique time{?s} {?was/were} given.")
}

invisible(NULL)
}

contains_survival_metric <- function(mtr_info) {
any(grepl("_survival", mtr_info$class))
}


# choose_eval_time() is called by show_best() and select_best()
#' @rdname choose_metric
#' @export
choose_eval_time <- function(x, metric, eval_time = NULL, ..., call = rlang::caller_env()) {
rlang::check_dots_empty()
choose_eval_time <- function(x, metric, eval_time = NULL, quietly = FALSE, call = rlang::caller_env()) {

mtr_set <- .get_tune_metrics(x)
mtr_info <- tibble::as_tibble(mtr_set)

if (!contains_survival_metric(mtr_info)) {
if (!is.null(eval_time)) {
if (!is.null(eval_time) & !quietly) {
cli::cli_warn("Evaluation times are only required when the model
mode is {.val censored regression} (and will be ignored).",
call = call)
Expand All @@ -97,7 +113,7 @@ choose_eval_time <- function(x, metric, eval_time = NULL, ..., call = rlang::cal
dyn_metric <- is_dyn(mtr_set, metric)

# If we don't need an eval time but one is passed:
if (!dyn_metric & !is.null(eval_time)) {
if (!dyn_metric & !is.null(eval_time) & !quietly) {
cli::cli_warn("An evaluation time is only required when a dynamic
metric is selected (and {.arg eval_time} will thus be
ignored).",
@@ -135,6 +151,24 @@ check_eval_time_in_tune_results <- function(x, eval_time, call = rlang::caller_e
invisible(NULL)
}

#' @rdname choose_metric
#' @export
maybe_choose_eval_time <- function(x, mtr_set, eval_time) {
mtr_info <- tibble::as_tibble(mtr_set)
if (any(grepl("integrated", mtr_info$metric))) {
return(.get_tune_eval_times(x))
}
eval_time <- purrr::map(mtr_info$metric, ~ choose_eval_time(x, .x, eval_time = eval_time, quietly = TRUE))
no_eval_time <- purrr::map_lgl(eval_time, is.null)
if (all(no_eval_time)) {
eval_time <- NULL
} else {
eval_time <- sort(unique(unlist(eval_time)))
}
eval_time
}


# ------------------------------------------------------------------------------

#' @rdname choose_metric
@@ -289,15 +323,10 @@ check_eval_time_arg <- function(eval_time, mtr_set, call = rlang::caller_env())
eval_time <- .filter_eval_time(eval_time)

num_times <- length(eval_time)

max_times_req <- req_eval_times(mtr_set)

if (max_times_req > num_times) {
cli::cli_abort("At least {max_times_req} evaluation time{?s} {?is/are}
required for the metric type(s) requested: {.val {uni_cls}}.
Only {num_times} unique time{?s} {?was/were} given.",
call = call)
}
# Are there at least a minimal number of evaluation times?
check_enough_eval_times(eval_time, mtr_set)

if (max_times_req == 0 & num_times > 0) {
cli::cli_warn("Evaluation times are only required when dynamic or
16 changes: 15 additions & 1 deletion man/choose_metric.Rd

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion tests/testthat/_snaps/censored-reg.md
@@ -3,7 +3,7 @@
Code
spec %>% tune_grid(Surv(time, status) ~ ., resamples = rs, metrics = mtr)
Condition
Error in `tune_grid()`:
Error in `check_enough_eval_times()`:
! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric". Only 0 unique times were given.

---