Skip to content

Commit

Permalink
Merge pull request #1174 from tidymodels/sparse-matrix-fit-error
Browse files Browse the repository at this point in the history
Make sure all sparse data errors look nice
  • Loading branch information
EmilHvitfeldt authored Sep 6, 2024
2 parents 474152f + 236a39b commit 0b09e78
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 12 deletions.
4 changes: 4 additions & 0 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ fit.model_spec <-
cli::cli_abort(msg)
}

if (is_sparse_matrix(data)) {
data <- sparsevctrs::coerce_to_sparse_tibble(data)
}

dots <- quos(...)

if (length(possible_engines(object)) == 0) {
Expand Down
16 changes: 11 additions & 5 deletions R/sparsevctrs.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
to_sparse_data_frame <- function(x, object) {
if (methods::is(x, "sparseMatrix")) {
to_sparse_data_frame <- function(x, object, call = rlang::caller_env()) {
if (is_sparse_matrix(x)) {
if (allow_sparse(object)) {
x <- sparsevctrs::coerce_to_sparse_data_frame(x)
} else {
Expand All @@ -8,8 +8,10 @@ to_sparse_data_frame <- function(x, object) {
}

cli::cli_abort(
"{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with
engine {.code {object$engine}} doesn't accept that.")
"{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with
engine {.val {object$engine}} doesn't accept that.",
call = call
)
}
} else if (is.data.frame(x)) {
x <- materialize_sparse_tibble(x, object, "x")
Expand All @@ -21,6 +23,10 @@ is_sparse_tibble <- function(x) {
any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))
}

is_sparse_matrix <- function(x) {
methods::is(x, "sparseMatrix")
}

materialize_sparse_tibble <- function(x, object, input) {
if (is_sparse_tibble(x) && (!allow_sparse(object))) {
if (inherits(object, "model_fit")) {
Expand All @@ -29,7 +35,7 @@ materialize_sparse_tibble <- function(x, object, input) {

cli::cli_warn(
"{.arg {input}} is a sparse tibble, but {.fn {class(object)[1]}} with
engine {.code {object$engine}} doesn't accept that. Converting to
engine {.val {object$engine}} doesn't accept that. Converting to
non-sparse."
)
for (i in seq_along(ncol(x))) {
Expand Down
22 changes: 15 additions & 7 deletions tests/testthat/_snaps/sparsevctrs.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,47 @@
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
Condition
Warning:
`data` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.
`data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.

# sparse matrix can be passed to `fit()

Code
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
Condition
Warning:
`data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.

# sparse tibble can be passed to `fit_xy()

Code
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1])
Condition
Warning:
`x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.
`x` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.

# sparse matrices can be passed to `fit_xy()

Code
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1])
Condition
Error in `to_sparse_data_frame()`:
! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that.
Error in `fit_xy()`:
! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that.

# sparse tibble can be passed to `predict()

Code
preds <- predict(lm_fit, sparse_mtcars)
Condition
Warning:
`x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.
`x` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.

# sparse matrices can be passed to `predict()

Code
predict(lm_fit, sparse_mtcars)
Condition
Error in `to_sparse_data_frame()`:
! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that.
Error in `predict()`:
! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that.

# to_sparse_data_frame() is used correctly

Expand Down
22 changes: 22 additions & 0 deletions tests/testthat/test-sparsevctrs.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,28 @@ test_that("sparse tibble can be passed to `fit()", {
)
})

test_that("sparse matrix can be passed to `fit()", {
skip_if_not_installed("xgboost")

hotel_data <- sparse_hotel_rates()

spec <- boost_tree() %>%
set_mode("regression") %>%
set_engine("xgboost")

expect_no_error(
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data)
)

spec <- linear_reg() %>%
set_mode("regression") %>%
set_engine("lm")

expect_snapshot(
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
)
})

test_that("sparse tibble can be passed to `fit_xy()", {
skip_if_not_installed("xgboost")

Expand Down

0 comments on commit 0b09e78

Please sign in to comment.