diff --git a/R/boost_tree.R b/R/boost_tree.R index cce9d1ccd..248fba8a0 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -432,11 +432,19 @@ as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "fir # Split data m <- floor(n * (1 - validation)) + 1 trn_index <- sample(1:n, size = max(m, 2)) - val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], missing = NA) + val_info_list <- list(label = y[-trn_index]) + if (!is.null(weights) && inherits(weights, "hardhat_frequency_weights")) { + #Only pass weights to internal validation set if they are frequency weights + weights <- as.integer(weights) + val_info_list$weight <- weights[-trn_index] + } + + val_data <- xgboost::xgb.DMatrix(x[-trn_index,], info = val_info_list, missing = NA) watch_list <- list(validation = val_data) info_list <- list(label = y[trn_index]) if (!is.null(weights)) { + weights <- weights_to_numeric(weights, spec = list(engine = "xgboost")) info_list$weight <- weights[trn_index] } dat <- xgboost::xgb.DMatrix(x[trn_index,], missing = NA, info = info_list) @@ -445,6 +453,7 @@ as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "fir } else { info_list <- list(label = y) if (!is.null(weights)) { + weights <- weights_to_numeric(weights, spec = list(engine = "xgboost")) info_list$weight <- weights } dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list) diff --git a/R/fit.R b/R/fit.R index 6cda2e2c0..a43660cb2 100644 --- a/R/fit.R +++ b/R/fit.R @@ -266,7 +266,13 @@ fit_xy.model_spec <- eval_env <- rlang::env() eval_env$x <- x eval_env$y <- y - eval_env$weights <- weights_to_numeric(case_weights, object) + + if(object$engine == "xgboost" && !is.null(case_weights)){ + # Pass as raw to preserve weight type e.g. frequency, importance + eval_env$weights <- case_weights + } else { + eval_env$weights <- weights_to_numeric(case_weights, object) + } # TODO case weights: pass in eval_env not individual elements fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object) diff --git a/tests/testthat/test_boost_tree_xgboost.R b/tests/testthat/test_boost_tree_xgboost.R index 5adde2957..f63d09a24 100644 --- a/tests/testthat/test_boost_tree_xgboost.R +++ b/tests/testthat/test_boost_tree_xgboost.R @@ -360,7 +360,8 @@ test_that('xgboost data conversion', { mtcar_x <- mtcars[, -1] mtcar_mat <- as.matrix(mtcar_x) mtcar_smat <- Matrix::Matrix(mtcar_mat, sparse = TRUE) - wts <- 1:32 + wts <- hardhat::importance_weights(1:32) + freq_wts <- hardhat::frequency_weights(1:32) expect_error(from_df <- parsnip:::as_xgb_data(mtcar_x, mtcars$mpg), regexp = NA) expect_true(inherits(from_df$data, "xgb.DMatrix")) @@ -403,11 +404,16 @@ test_that('xgboost data conversion', { # case weights added expect_error(wted <- parsnip:::as_xgb_data(mtcar_x, mtcars$mpg, weights = wts), regexp = NA) - expect_equal(wts, xgboost::getinfo(wted$data, "weight")) + expect_equal(as.numeric(wts), xgboost::getinfo(wted$data, "weight")) expect_error(wted_val <- parsnip:::as_xgb_data(mtcar_x, mtcars$mpg, weights = wts, validation = 1/4), regexp = NA) expect_true(all(xgboost::getinfo(wted_val$data, "weight") %in% wts)) expect_null(xgboost::getinfo(wted_val$watchlist$validation, "weight")) + # check that freq weights are passed to internal validation set + set.seed(1) + expect_error(val_freq_wts<-parsnip:::as_xgb_data(mtcar_smat, mtcars$mpg, weights = freq_wts, validation = 1/10), regexp = NA) + expect_true(all(xgboost::getinfo(val_freq_wts$watchlist$validation, "weight") %in% c(3,17,26))) + }) @@ -419,7 +425,8 @@ test_that('xgboost data and sparse matrices', { mtcar_x <- mtcars[, -1] mtcar_mat <- as.matrix(mtcar_x) mtcar_smat <- Matrix::Matrix(mtcar_mat, sparse = TRUE) - wts <- 1:32 + wts <- hardhat::importance_weights(1:32) + freq_wts <- hardhat::frequency_weights(1:32) xgb_spec <- boost_tree(trees = 10) %>% @@ -443,11 +450,16 @@ test_that('xgboost data and sparse matrices', { # case weights added expect_error(wted <- parsnip:::as_xgb_data(mtcar_smat, mtcars$mpg, weights = wts), regexp = NA) - expect_equal(wts, xgboost::getinfo(wted$data, "weight")) + expect_equal(as.numeric(wts), xgboost::getinfo(wted$data, "weight")) expect_error(wted_val <- parsnip:::as_xgb_data(mtcar_smat, mtcars$mpg, weights = wts, validation = 1/4), regexp = NA) expect_true(all(xgboost::getinfo(wted_val$data, "weight") %in% wts)) expect_null(xgboost::getinfo(wted_val$watchlist$validation, "weight")) + # check that freq weights are passed to internal validation set + set.seed(1) + expect_error(val_freq_wts<-parsnip:::as_xgb_data(mtcar_smat, mtcars$mpg, weights = freq_wts, validation = 1/10), regexp = NA) + expect_true(all(xgboost::getinfo(val_freq_wts$watchlist$validation, "weight") %in% c(3,17,26))) + })