Skip to content

Commit

Permalink
proposed changes related to tidymodels/tune#818
Browse files Browse the repository at this point in the history
  • Loading branch information
‘topepo’ committed Jan 19, 2024
1 parent e35e756 commit fe3b16e
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 36 deletions.
87 changes: 54 additions & 33 deletions R/bootci.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,41 +67,46 @@ check_tidy <- function(x, std_col = FALSE) {
if (std_col) {
std_candidates <- colnames(x) %in% std_exp
std_candidates <- colnames(x)[std_candidates]
re_name <- list(std_err = std_candidates)
if (has_id) {
x <-
dplyr::select(x, term, estimate, id, tidyselect::one_of(std_candidates)) %>%
mutate(id = (id == "Apparent")) %>%
setNames(c("term", "estimate", "orig", "std_err"))
dplyr::select(x, term, estimate, id, tidyselect::one_of(std_candidates),
dplyr::starts_with(".")) %>%
mutate(orig = (id == "Apparent")) %>%
dplyr::rename(!!!re_name)
} else {
x <-
dplyr::select(x, term, estimate, tidyselect::one_of(std_candidates)) %>%
setNames(c("term", "estimate", "std_err"))
dplyr::select(x, term, estimate, tidyselect::one_of(std_candidates),
dplyr::starts_with(".")) %>%
dplyr::rename(!!!re_name)
}
} else {
if (has_id) {
x <-
dplyr::select(x, term, estimate, id) %>%
dplyr::select(x, term, estimate, id, dplyr::starts_with(".")) %>%
mutate(orig = (id == "Apparent")) %>%
dplyr::select(-id)
} else {
x <- dplyr::select(x, term, estimate)
x <- dplyr::select(x, term, estimate, dplyr::starts_with("."))
}
}

x
}


get_p0 <- function(x, alpha = 0.05) {
get_p0 <- function(x, alpha = 0.05, groups) {
group_sym <- rlang::syms(groups)

orig <- x %>%
group_by(term) %>%
group_by(!!!group_sym) %>%
dplyr::filter(orig) %>%
dplyr::select(term, theta_0 = estimate) %>%
dplyr::select(!!!group_sym, theta_0 = estimate) %>%
ungroup()
x %>%
dplyr::filter(!orig) %>%
inner_join(orig, by = "term") %>%
group_by(term) %>%
inner_join(orig, by = groups) %>%
group_by(!!!group_sym) %>%
summarize(p0 = mean(estimate <= theta_0, na.rm = TRUE)) %>%
mutate(
Z0 = stats::qnorm(p0),
Expand Down Expand Up @@ -181,9 +186,10 @@ pctl_single <- function(stats, alpha = 0.05) {
#' @param statistics An unquoted column name or `dplyr` selector that identifies
#' a single column in the data set containing the individual bootstrap
#' estimates. This must be a list column of tidy tibbles (with columns
#' `term` and `estimate`). For t-intervals, a
#' standard tidy column (usually called `std.err`) is required.
#' See the examples below.
#' `term` and `estimate`). Optionally, these tibbles can include columns whose
#' names begin with a period; intervals are then created for each combination
#' of those columns and the `term` column. For t-intervals, a tidy standard
#' error column (usually called `std.err`) is also required. See the examples
#' below.
#' @param alpha Level of significance.
#' @param .fn A function to calculate the statistic of interest. The
#' function should take an `rsplit` as the first argument and the `...` are
Expand Down Expand Up @@ -216,6 +222,8 @@ pctl_single <- function(stats, alpha = 0.05) {
#' library(purrr)
#' library(tibble)
#'
#' # ------------------------------------------------------------------------------
#'
#' lm_est <- function(split, ...) {
#' lm(mpg ~ disp + hp, data = analysis(split)) %>%
#' tidy()
Expand All @@ -230,6 +238,8 @@ pctl_single <- function(stats, alpha = 0.05) {
#' int_t(car_rs, results)
#' int_bca(car_rs, results, .fn = lm_est)
#'
#' # ------------------------------------------------------------------------------
#'
#' # putting results into a tidy format
#' rank_corr <- function(split) {
#' dat <- analysis(split)
Expand Down Expand Up @@ -272,8 +282,11 @@ int_pctl.bootstraps <- function(.data, statistics, alpha = 0.05, ...) {

check_num_resamples(stats, B = 1000)

stat_groups <- c("term", grep("^\\.", names(stats), value = TRUE))
stat_groups <- rlang::syms(stat_groups)

vals <- stats %>%
dplyr::group_by(term) %>%
dplyr::group_by(!!!stat_groups) %>%
dplyr::do(pctl_single(.$estimate, alpha = alpha)) %>%
dplyr::ungroup()
vals
Expand Down Expand Up @@ -351,9 +364,10 @@ int_t.bootstraps <- function(.data, statistics, alpha = 0.05, ...) {

check_num_resamples(stats, B = 500)

vals <-
stats %>%
dplyr::group_by(term) %>%
stat_groups <- c("term", grep("^\\.", names(stats), value = TRUE))
stat_groups <- rlang::syms(stat_groups)
vals <- stats %>%
dplyr::group_by(!!!stat_groups) %>%
dplyr::do(t_single(.$estimate, .$std_err, .$orig, alpha = alpha)) %>%
dplyr::ungroup()
vals
Expand All @@ -369,8 +383,11 @@ bca_calc <- function(stats, orig_data, alpha = 0.05, .fn, ...) {
rlang::abort("All statistics have missing values.")
}

stat_groups_chr <- c("term", grep("^\\.", names(stats), value = TRUE))
stat_groups_sym <- rlang::syms(stat_groups_chr)

### Estimating Z0 bias-correction
bias_corr_stats <- get_p0(stats, alpha = alpha)
bias_corr_stats <- get_p0(stats, alpha = alpha, groups = stat_groups_chr)

# need the original data frame here
loo_rs <- loo_cv(orig_data)
Expand All @@ -388,16 +405,16 @@ bca_calc <- function(stats, orig_data, alpha = 0.05, .fn, ...) {

loo_estimate <-
loo_res %>%
dplyr::group_by(term) %>%
dplyr::group_by(!!!stat_groups_sym) %>%
dplyr::summarize(loo = mean(estimate, na.rm = TRUE)) %>%
dplyr::inner_join(loo_res, by = "term", multiple = "all") %>%
dplyr::group_by(term) %>%
dplyr::inner_join(loo_res, by = stat_groups_chr, multiple = "all") %>%
dplyr::group_by(!!!stat_groups_sym) %>%
dplyr::summarize(
cubed = sum((loo - estimate)^3),
squared = sum((loo - estimate)^2)
) %>%
dplyr::ungroup() %>%
dplyr::inner_join(bias_corr_stats, by = "term") %>%
dplyr::inner_join(bias_corr_stats, by = stat_groups_chr) %>%
dplyr::mutate(
a = cubed / (6 * (squared^(3 / 2))),
Zu = (Z0 + Za) / (1 - a * (Z0 + Za)) + Z0,
Expand All @@ -408,21 +425,25 @@ bca_calc <- function(stats, orig_data, alpha = 0.05, .fn, ...) {

terms <- loo_estimate$term
stats <- stats %>% dplyr::filter(!orig)
for (i in seq_along(terms)) {
tmp <- new_stats(stats$estimate[stats$term == terms[i]],
lo = loo_estimate$lo[i],
hi = loo_estimate$hi[i]
)
tmp$term <- terms[i]

keys <- stats %>% dplyr::distinct(!!!stat_groups_sym)
for (i in 1:nrow(keys)) {
tmp_stats <- dplyr::inner_join(stats, keys[i,], by = stat_groups_chr)
tmp_loo <- dplyr::inner_join(loo_estimate, keys[i,], by = stat_groups_chr)

tmp <- new_stats(tmp_stats$estimate,
lo = tmp_loo$lo,
hi = tmp_loo$hi)
tmp <- dplyr::bind_cols(tmp, keys[i,])
if (i == 1) {
ci_bca <- tmp
} else {
ci_bca <- bind_rows(ci_bca, tmp)
ci_bca <- dplyr::bind_rows(ci_bca, tmp)
}
}
ci_bca <-
ci_bca %>%
dplyr::select(term, .lower, .estimate, .upper) %>%
dplyr::select(!!!stat_groups_sym, .lower, .estimate, .upper) %>%
dplyr::mutate(
.alpha = alpha,
.method = "BCa"
Expand All @@ -449,7 +470,7 @@ int_bca.bootstraps <- function(.data, statistics, alpha = 0.05, .fn, ...) {
if (length(column_name) != 1) {
rlang::abort(stat_fmt_err)
}
stats <- .data %>% dplyr::select(!!column_name, id)
stats <- .data %>% dplyr::select(!!column_name, id, dplyr::starts_with("."))
stats <- check_tidy(stats)

check_num_resamples(stats, B = 1000)
Expand Down
11 changes: 8 additions & 3 deletions man/int_pctl.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit fe3b16e

Please sign in to comment.