@tjburch tjburch commented May 23, 2025

Closes #990

This PR implements support for variable fold weights in hyperparameter tuning. This is useful when folds have different numbers of observations and each fold should contribute to hyperparameter selection in proportion to its size.

The implementation adds two main functions: add_fold_weights() to attach custom weights to rset objects, and calculate_fold_weights() to automatically compute weights proportional to fold sizes. Weights are stored as .fold_weights attributes and should flow through the existing tuning pipeline.
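For illustration, here is a minimal base-R sketch of size-proportional weights. The helper name `fold_size_weights()` and the plain data frame standing in for an rset are assumptions made for this example; this is not the PR's `calculate_fold_weights()` or `add_fold_weights()`. Only the `.fold_weights` attribute name comes from the PR.

```r
# Hypothetical helper (not the PR's code): weights proportional to the
# number of assessment rows in each fold, normalized to sum to 1.
fold_size_weights <- function(n_assess) {
  stopifnot(is.numeric(n_assess), length(n_assess) > 0, all(n_assess > 0))
  n_assess / sum(n_assess)
}

# Three folds holding 10, 20, and 30 assessment rows
w <- fold_size_weights(c(10, 20, 30))

# A plain data frame stands in for an rset object here
folds <- data.frame(id = c("Fold1", "Fold2", "Fold3"))
attr(folds, ".fold_weights") <- w

attr(folds, ".fold_weights")
```

With these sizes the weights come out as 1/6, 1/3, and 1/2, so the largest fold counts three times as much as the smallest.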

Core changes are in estimate_tune_results() which now detects weights and uses weighted statistics (weighted mean, weighted standard deviation, effective sample size) when aggregating metrics. Implementation should be backwards compatible and non-breaking.
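As a rough sketch of what weighted aggregation can look like, the base-R helpers below compute a weighted mean, a weighted standard deviation, and an effective sample size. The function names and the specific formulas (a reliability-weights variance correction and Kish's effective sample size) are assumptions chosen for this example; the PR's internal helpers may differ.

```r
# Hypothetical helpers; names and formulas are assumptions, not the PR's code.
wt_mean <- function(x, w) sum(w * x) / sum(w)

wt_sd <- function(x, w) {
  # A single value has no spread to estimate
  if (length(x) < 2) return(NA_real_)
  m <- wt_mean(x, w)
  v1 <- sum(w)
  v2 <- sum(w^2)
  # Reduces to the usual n - 1 denominator when all weights are equal
  sqrt(sum(w * (x - m)^2) / (v1 - v2 / v1))
}

# Kish's effective sample size: equals length(w) for equal weights
n_eff <- function(w) sum(w)^2 / sum(w^2)

x <- c(0.80, 0.85, 0.90)        # made-up per-fold metric values
equal_w <- rep(1 / 3, 3)
unequal_w <- c(0.1, 0.3, 0.6)

all.equal(wt_mean(x, equal_w), mean(x))  # equal weights match the plain mean
all.equal(wt_sd(x, equal_w), sd(x))      # and the plain standard deviation
wt_mean(x, unequal_w)                    # pulled toward the last fold
n_eff(unequal_w)                         # fewer than 3 "effective" folds
```

Under equal weights these helpers reproduce the unweighted statistics, which is the property that makes a weighted code path backwards compatible.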

@topepo
Member

topepo commented Jun 4, 2025

Hey @tjburch. Thanks for the PR.

We're doing a pretty invasive update to this package that will take some time. We'll look at the PR after things are settled there, but it might be another 2-3 weeks.

Is this something time-sensitive for you?

@tjburch
Author

tjburch commented Jun 4, 2025

Nope. Just had some bandwidth staying awake on paternity leave. Review at your leisure, let me know if I can assist otherwise.

@topepo
Member

topepo commented Jun 4, 2025

> Just had some bandwidth staying awake on paternity leave

I estimate that 5% of all my work has been done while waiting at the bus stop or for some sort of practice to end 😄

@tjburch tjburch marked this pull request as draft July 23, 2025 02:01
@tjburch
Author

tjburch commented Jul 23, 2025

Got around to deploying this into a local project and running into some odd errors. Will circle back.

Contributor

@github-actions github-actions bot left a comment

Remaining comments which cannot be posted as a review comment to avoid GitHub Rate Limit

air

[air] reported by reviewdog 🐶

tune/R/utils.R

Line 393 in 27a5ab5


[air] reported by reviewdog 🐶

tune/R/utils.R

Line 399 in 27a5ab5

#'


[air] reported by reviewdog 🐶

tune/R/utils.R

Line 408 in 27a5ab5


[air] reported by reviewdog 🐶

tune/R/utils.R

Line 411 in 27a5ab5


[air] reported by reviewdog 🐶

tune/R/utils.R

Line 414 in 27a5ab5


[air] reported by reviewdog 🐶

tune/R/utils.R

Line 419 in 27a5ab5

#'


[air] reported by reviewdog 🐶

tune/R/utils.R

Line 427 in 27a5ab5


[air] reported by reviewdog 🐶

tune/R/utils.R

Line 430 in 27a5ab5


[air] reported by reviewdog 🐶

tune/R/utils.R

Line 436 in 27a5ab5

#'


[air] reported by reviewdog 🐶

tune/R/utils.R

Line 440 in 27a5ab5

#'


[air] reported by reviewdog 🐶

tune/R/utils.R

Line 466 in 27a5ab5


[air] reported by reviewdog 🐶

tune/R/utils.R

Line 481 in 27a5ab5


[air] reported by reviewdog 🐶

unweighted_results <- fit_resamples(simple_wflow, folds,
control = control_resamples(save_pred = FALSE))
weighted_results_equal <- fit_resamples(simple_wflow, weighted_folds_equal,
control = control_resamples(save_pred = FALSE))


[air] reported by reviewdog 🐶

expect_equal(unweighted_metrics$mean, weighted_metrics_equal$mean, tolerance = 1e-10)


[air] reported by reviewdog 🐶

unequal_weights <- c(0.1, 0.3, 0.6) # Higher weight on last fold


[air] reported by reviewdog 🐶

weighted_results_unequal <- fit_resamples(simple_wflow, weighted_folds_unequal,
control = control_resamples(save_pred = FALSE))


[air] reported by reviewdog 🐶

expect_false(all(abs(unweighted_metrics$mean - weighted_metrics_unequal$mean) < 1e-10))


[air] reported by reviewdog 🐶

expect_equal(sum(calculated_weights), 1) # Should sum to 1 now


[air] reported by reviewdog 🐶

skip_if_not_installed("parsnip")


[air] reported by reviewdog 🐶

tune_mod <- parsnip::linear_reg(penalty = tune()) %>%


[air] reported by reviewdog 🐶

weighted_tune_results <- tune_grid(tune_wflow, weighted_folds,
grid = simple_grid,
control = control_grid(save_pred = FALSE))


[air] reported by reviewdog 🐶

expect_s3_class(weighted_tune_results, "tune_results")


[air] reported by reviewdog 🐶

unweighted_tune_results <- tune_grid(tune_wflow, folds,
grid = simple_grid,
control = control_grid(save_pred = FALSE))


[air] reported by reviewdog 🐶

expect_false(all(abs(weighted_metrics$mean - unweighted_metrics$mean) < 1e-10))


[air] reported by reviewdog 🐶

if (rlang::is_installed(c("rsample", "parsnip", "yardstick", "workflows", "recipes", "kknn"))) {


[air] reported by reviewdog 🐶

c(1/6, 1/3, 1/2) # normalized to sum to 1


[air] reported by reviewdog 🐶

c(0.2, 0.3, 0.5) # already normalized to sum to 1


[air] reported by reviewdog 🐶

expect_true(is.na(tune:::.weighted_sd(c(1), c(1)))) # single value


[air] reported by reviewdog 🐶

expect_true(nrow(individual_metrics) >= 3) # At least one metric per fold


[air] reported by reviewdog 🐶

resamples = folds, # No weights


[air] reported by reviewdog 🐶

expected_weights <- weights / sum(weights) # normalized


[air] reported by reviewdog 🐶

splits1 <- rsample::make_splits(x = mtcars[1:20,], assessment = mtcars[21:32,])
splits2 <- rsample::make_splits(x = mtcars[1:15,], assessment = mtcars[16:32,])


@tjburch
Author

tjburch commented Sep 11, 2025

Sorry for the noise here - didn't realize Air was part of ci/cd now.

fwiw I've been using my fork here for a while in a production process and it seems fine. I'm going to dust it off and get it ready to go asap

@tjburch tjburch marked this pull request as ready for review September 12, 2025 12:21
@tjburch
Author

tjburch commented Sep 12, 2025

Alright, I think this is back ready for review.

@tjburch
Author

tjburch commented Oct 17, 2025

── Error ('test-checks.R:673:3'): fold weights with tune_grid ──────────────────
Error in `tune_grid(tune_wflow, weighted_folds, grid = simple_grid, control = control_grid(save_pred = FALSE))`: Package install is required for glmnet.

Looks like I'll need to replace the model here with something the CI tools have

@EmilHvitfeldt
Member

> Looks like I'll need to replace the model here to something the ci tools have

Yes, that is correct! {xgboost} is an engine you can use; it is used in other tests.

Member

@topepo topepo left a comment

Looks pretty solid! Thanks for the contribution and your patience through our busy season.

I think the main thing to do is to make a standalone file for weighted statistics.

It could be helpful for you to verify that your weighting functions get the same results as what is currently in recipes.

R/checks.R Outdated
#' @param x An rset object.
#' @return `NULL` invisibly, or error if weights are invalid.
#' @keywords internal
check_fold_weights <- function(x) {
Member

Should we check for equal weights too and return NULL if that is the case? I think that will be the case in some situations, and it would slightly increase efficiency to avoid work whenever possible.
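One way such a shortcut could look, sketched in base R. Both function names and the relative tolerance are hypothetical; nothing here is taken from `check_fold_weights()` itself.

```r
# Hypothetical helpers; not part of the PR.
# TRUE when all weights agree within a relative tolerance.
weights_are_equal <- function(w, tol = sqrt(.Machine$double.eps)) {
  diff(range(w)) <= tol * max(abs(w))
}

# Return NULL when weighting would change nothing, so that callers
# can stay on the existing unweighted code path.
usable_weights <- function(w) {
  if (is.null(w) || weights_are_equal(w)) NULL else w
}

usable_weights(rep(0.25, 4))       # NULL: equal weights, nothing to do
usable_weights(c(0.1, 0.3, 0.6))   # returned unchanged
```

Returning `NULL` for equal weights means downstream code only needs a single `is.null()` test to decide between the weighted and unweighted paths.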

R/utils.R Outdated

#' @export
#' @rdname fold_weights_utils
.weighted_mean <- function(x, w) {
Member

Why not use stats::weighted.mean()?
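For reference, a small check that base R's `stats::weighted.mean()` matches the hand-written formula and normalizes the weights itself. The metric values and weights are made up for illustration.

```r
x <- c(0.80, 0.85, 0.90)   # per-fold metric values (illustrative)
w <- c(10, 20, 30)         # unnormalized weights, e.g. fold sizes

by_hand <- sum(w * x) / sum(w)
built_in <- stats::weighted.mean(x, w)

all.equal(by_hand, built_in)
```

`stats::weighted.mean()` also accepts `na.rm = TRUE`, which a hand-rolled version would have to reimplement.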

Member

So looking further down, we do the same with the variance calculation. We do have an unexported API in recipes, wt_calcs(), that does a lot of this.

We can copy it over here (to avoid duplication) and think about moving that core code to hardhat so that we get it everywhere from one source.

Member

Erm actually, we should probably put those weighting functions in a "standalone file" here, and then recipes and anyone else can import that.

@EmilHvitfeldt do you have any thoughts on that?

Member

I would be happy to have a standalone file for weighted functions

Member

I'll amend my comments a third time. We don't export wt_calcs() but we do export the single statistic versions like recipes::averages() and recipes::variances() so maybe it's that simple.

})

test_that("fold weights integration test", {
skip_if_not_installed("rsample")
Member

These are all imported so no need for these skips

})

test_that("fold weights with tune_grid", {
skip_if_not_installed("rsample")
Member

Same as ☝️ etc

@topepo
Member

topepo commented Oct 17, 2025

The testing issue for:

namespace 'workflowsets' is not available and has been replaced

is a false alarm. That is fixed in #1104.

@topepo
Member

topepo commented Oct 17, 2025

@tjburch I think that I'd like to change "fold weights" to "resample weights" to be more general. Is the former common terminology for these?

@tjburch
Author

tjburch commented Oct 17, 2025

> @tjburch I think that I'd like to change "fold weights" to "resample weights" to be more general. Is the former common terminology for these?

I'm unaware if there is a common/formal terminology. This came out of a use-case, and that's just what I've been calling it :) . Resample weights is fine by me.

@topepo
Member

topepo commented Oct 17, 2025

I moved one of the tests away from glmnet to svm. @tjburch The glmnet package is a difficult thing to have as a formal dependency. They sometimes make breaking changes that affect us so we test them in a different non-CRAN repo.

@topepo
Member

topepo commented Oct 17, 2025

With printing in the rset objects, I think that we'll have to move this code into an rsample PR.

Most of the rsets have a superseding class (like v_fold) and the print method for that class happens first.

Using v_fold as an example, its print method removes the first two classes (v_fold and rset) so that the remaining printing is the default tibble printing. That's why there is not a rsample:::print.rset method.
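The dispatch behavior described above can be reproduced with toy classes in base R. The class names `demo_vfold` and `demo_rset` are stand-ins invented for this sketch, and a data frame replaces the tibble; the real rsample print methods differ in detail.

```r
# Toy classes standing in for v_fold / rset (names are hypothetical).
calls <- character()

print.demo_vfold <- function(x, ...) {
  calls <<- c(calls, "demo_vfold")
  class(x) <- class(x)[-(1:2)]  # drop "demo_vfold" and "demo_rset"
  print(x, ...)                 # now dispatches to print.data.frame
  invisible(x)
}

# Defined, but skipped for demo_vfold objects: the method above
# strips the class before any further dispatch happens.
print.demo_rset <- function(x, ...) {
  calls <<- c(calls, "demo_rset")
  NextMethod("print")
}

obj <- structure(
  data.frame(id = c("Fold1", "Fold2")),
  class = c("demo_vfold", "demo_rset", "data.frame")
)

print(obj)
calls  # "demo_vfold"
```

Because the first method removes both classes before printing again, a method registered on the middle class is bypassed, which is why a weight-aware print method would need to live alongside the subclass methods.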

I'm going to temporarily remove the print methods and make this comment an issue for rsample, and then Hannah (who is out of office next week 😿) can take a look.

#' @export
print.rset <- function(x, ...) {
  fold_weights <- attr(x, ".fold_weights")

  if (!is.null(fold_weights)) {
    # Create a tibble with fold weights as a column
    x_tbl <- tibble::as_tibble(x)
    x_tbl$fold_weight <- fold_weights
    print(x_tbl, ...)
  } else {
    # Use default behavior
    NextMethod("print")
  }
}

#' @export
print.manual_rset <- function(x, ...) {
  fold_weights <- attr(x, ".fold_weights")

  if (!is.null(fold_weights)) {
    # Create a tibble with fold weights as a column
    x_tbl <- tibble::as_tibble(x)
    x_tbl$fold_weight <- fold_weights
    print(x_tbl, ...)
  } else {
    # Use default behavior for manual_rset
    NextMethod("print")
  }
}

Development

Successfully merging this pull request may close these issues.

Variable Fold Weights

3 participants