Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change to quantile argument to quantile levels #1208

Merged
merged 5 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@

* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).

## Breaking Change

* For quantile prediction, the `predict()` argument has been changed from `quantile` to `quantile_levels` for consistency. This does not affect models with mode `"quantile regression"`.
hfrick marked this conversation as resolved.
Show resolved Hide resolved
* The quantile regression prediction type was disabled for the deprecated `surv_reg()` model.

# parsnip 1.2.1

Expand Down
2 changes: 1 addition & 1 deletion R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())

# ----------------------------------------------------------------------------

other_args <- c("interval", "level", "std_error", "quantile",
other_args <- c("interval", "level", "std_error", "quantile_levels",
"time", "eval_time", "increasing")
is_pred_arg <- names(the_dots) %in% other_args
if (any(!is_pred_arg)) {
Expand Down
23 changes: 18 additions & 5 deletions R/predict_quantile.R
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
#' @keywords internal
#' @rdname other_predict
#' @param quantile A vector of numbers between 0 and 1 for the quantile being
#' predicted.
#' @param quantile_levels A vector of values between zero and one for the
#' quantile to be predicted. If the model has a `"censored regression"` mode,
#' this value should be `NULL`. For other modes, the default is `(1:9)/10`.
#' @inheritParams predict.model_fit
#' @method predict_quantile model_fit
#' @export predict_quantile.model_fit
#' @export
predict_quantile.model_fit <- function(object,
new_data,
quantile = (1:9)/10,
quantile_levels = NULL,
interval = "none",
level = 0.95,
...) {
Expand All @@ -20,15 +21,27 @@ predict_quantile.model_fit <- function(object,
return(NULL)
}

if (object$spec$mode == "quantile regression") {
if (!is.null(quantile_levels)) {
cli::cli_abort("When the mode is {.val quantile regression},
{.arg quantile_levels} are specified by {.fn set_mode}.")
}
} else {
if (is.null(quantile_levels)) {
quantile_levels <- (1:9)/10
topepo marked this conversation as resolved.
Show resolved Hide resolved
}
hardhat::check_quantile_levels(quantile_levels)
# Pass some extra arguments to be used in post-processor
object$quantile_levels <- quantile_levels
}

new_data <- prepare_data(object, new_data)

# preprocess data
if (!is.null(object$spec$method$pred$quantile$pre)) {
new_data <- object$spec$method$pred$quantile$pre(new_data, object)
}

# Pass some extra arguments to be used in post-processor
object$spec$method$pred$quantile$args$p <- quantile
pred_call <- make_pred_call(object$spec$method$pred$quantile)

res <- eval_tidy(pred_call)
Expand Down
38 changes: 0 additions & 38 deletions R/surv_reg_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,25 +59,6 @@ set_pred(
)
)

set_pred(
model = "surv_reg",
eng = "flexsurv",
mode = "regression",
type = "quantile",
value = list(
pre = NULL,
post = flexsurv_quant,
func = c(fun = "summary"),
args =
list(
object = expr(object$fit),
newdata = expr(new_data),
type = "quantile",
quantiles = expr(quantile)
)
)
)

# ------------------------------------------------------------------------------

set_model_engine("surv_reg", mode = "regression", eng = "survival")
Expand Down Expand Up @@ -133,22 +114,3 @@ set_pred(
)
)
)

set_pred(
model = "surv_reg",
eng = "survival",
mode = "regression",
type = "quantile",
value = list(
pre = NULL,
post = survreg_quant,
func = c(fun = "predict"),
args =
list(
object = expr(object$fit),
newdata = expr(new_data),
type = "quantile",
p = expr(quantile)
)
)
)
7 changes: 4 additions & 3 deletions man/other_predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/set_args.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions tests/testthat/_snaps/linear_reg_quantreg.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# linear quantile regression via quantreg - multiple quantiles

Code
ten_quant_pred <- predict(ten_quant, new_data = sac_test, quantile_levels = (0:
9) / 9)
Condition
Error in `predict_quantile()`:
! When the mode is "quantile regression", `quantile_levels` are specified by `set_mode()`.

5 changes: 5 additions & 0 deletions tests/testthat/test-linear_reg_quantreg.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ test_that('linear quantile regression via quantreg - multiple quantiles', {
expect_named(ten_quant_df, c(".pred_quantile", ".quantile_levels", ".row"))
expect_true(nrow(ten_quant_df) == nrow(sac_test) * 10)

expect_snapshot(
ten_quant_pred <- predict(ten_quant, new_data = sac_test, quantile_levels = (0:9)/9),
error = TRUE
)

###

ten_quant_one_row <- predict(ten_quant, new_data = sac_test[1,])
Expand Down
14 changes: 1 addition & 13 deletions tests/testthat/test-surv_reg_survreg.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ complete_form <- survival::Surv(time) ~ group
# ------------------------------------------------------------------------------

test_that('survival execution', {
skip_on_travis()

rlang::local_options(lifecycle_verbosity = "quiet")
surv_basic <- surv_reg() %>% set_engine("survival")
surv_lnorm <- surv_reg(dist = "lognormal") %>% set_engine("survival")
Expand Down Expand Up @@ -46,7 +44,7 @@ test_that('survival execution', {
})

test_that('survival prediction', {
skip_on_travis()
skip_if_not_installed("survival")

rlang::local_options(lifecycle_verbosity = "quiet")
surv_basic <- surv_reg() %>% set_engine("survival")
Expand All @@ -61,16 +59,6 @@ test_that('survival prediction', {
exp_pred <- predict(extract_fit_engine(res), head(lung))
exp_pred <- tibble(.pred = unname(exp_pred))
expect_equal(exp_pred, predict(res, head(lung)))

exp_quant <- predict(extract_fit_engine(res), head(lung), p = (2:4)/5, type = "quantile")
exp_quant <-
apply(exp_quant, 1, function(x)
tibble(.pred = x, .quantile = (2:4) / 5))
exp_quant <- tibble(.pred = exp_quant)
obs_quant <- predict(res, head(lung), type = "quantile", quantile = (2:4)/5)

expect_equal(as.data.frame(exp_quant), as.data.frame(obs_quant))

})


Loading