From da505eb1d3700d1d2000164863fb016610eba20c Mon Sep 17 00:00:00 2001 From: "Simon P. Couch" Date: Mon, 15 Jul 2024 12:19:58 -0600 Subject: [PATCH] add method for `dbarts::bart()` (#65) --- DESCRIPTION | 1 + NAMESPACE | 1 + NEWS.md | 3 + R/bundle_bart.R | 60 +++++++++++++ man/bundle.Rd | 1 + man/bundle_bart.Rd | 110 ++++++++++++++++++++++++ man/bundle_caret.Rd | 1 + man/bundle_embed.Rd | 1 + man/bundle_h2o.Rd | 1 + man/bundle_keras.Rd | 1 + man/bundle_parsnip.Rd | 1 + man/bundle_recipe.Rd | 1 + man/bundle_stacks.Rd | 1 + man/bundle_torch.Rd | 1 + man/bundle_workflows.Rd | 1 + man/bundle_xgboost.Rd | 1 + tests/testthat/_snaps/bundle_bart.md | 18 ++++ tests/testthat/test_bundle_bart.R | 121 +++++++++++++++++++++++++++ 18 files changed, 325 insertions(+) create mode 100644 R/bundle_bart.R create mode 100644 man/bundle_bart.Rd create mode 100644 tests/testthat/_snaps/bundle_bart.md create mode 100644 tests/testthat/test_bundle_bart.R diff --git a/DESCRIPTION b/DESCRIPTION index f4e0975..a95e871 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -32,6 +32,7 @@ Suggests: callr, caret, covr, + dbarts, embed, h2o, keras, diff --git a/NAMESPACE b/NAMESPACE index eaf1256..11d19cf 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -4,6 +4,7 @@ S3method(bundle,H2OAutoML) S3method(bundle,H2OBinomialModel) S3method(bundle,H2OMultinomialModel) S3method(bundle,H2ORegressionModel) +S3method(bundle,bart) S3method(bundle,default) S3method(bundle,keras.engine.training.Model) S3method(bundle,luz_module_fitted) diff --git a/NEWS.md b/NEWS.md index d29100f..90300c7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,8 @@ # bundle (development version) +* Added bundle method for objects from `dbarts::bart()` and, by extension, + `parsnip::bart(engine = "dbarts")` (#64). + # bundle 0.1.1 * Fixed bundling of recipes steps situated inside of workflows. 
diff --git a/R/bundle_bart.R b/R/bundle_bart.R new file mode 100644 index 0000000..ad28861 --- /dev/null +++ b/R/bundle_bart.R @@ -0,0 +1,60 @@ +#' @templateVar class a `bart` +#' @template title_desc +#' +#' @templateVar outclass `bundled_bart` +#' @templateVar default . +#' @template return_bundle +#' @family bundlers +#' +#' @param x A `bart` object returned from [dbarts::bart()]. Notably, this ought +#' not to be the output of [parsnip::bart()]. +#' @template param_unused_dots +#' @rdname bundle_bart +#' @template butcher_details +#' @examplesIf rlang::is_installed(c("dbarts")) +#' # fit model and bundle ------------------------------------------------ +#' library(dbarts) +#' +#' mtcars$vs <- as.factor(mtcars$vs) +#' +#' set.seed(1) +#' fit <- dbarts::bart(mtcars[c("disp", "hp")], mtcars$vs, keeptrees = TRUE) +#' +#' fit_bundle <- bundle(fit) +#' +#' # then, after saveRDS + readRDS or passing to a new session ---------- +#' fit_unbundled <- unbundle(fit_bundle) +#' +#' fit_unbundled_preds <- predict(fit_unbundled, mtcars) +#' @aliases bundle.bart +#' @method bundle bart +#' @export +bundle.bart <- function(x, ...) { + rlang::check_installed("dbarts") + rlang::check_dots_empty() + + # `parsnip::bart()` and `dbarts::bart()` unfortunately both inherit from `bart` + if (inherits(x, "model_spec")) { + rlang::abort(c( + paste0("`x` should be the output of `dbarts::bart()`, not a model ", + "specification from `parsnip::bart()`."), + "To bundle `parsnip::bart()` output, train it with `parsnip::fit()` first." + )) + } + + if (is.null(x$fit)) { + rlang::abort(c( + "`x` can't be bundled.", + "`x` must have been fitted with argument `keeptrees = TRUE`." 
+ )) + } + + # "touch" the object's state (#64) + invisible(x$fit$state) + + bundle_constr( + object = x, + situate = situate_constr(identity), + desc_class = class(x)[1] + ) +} diff --git a/man/bundle.Rd b/man/bundle.Rd index b65be5e..18e304c 100644 --- a/man/bundle.Rd +++ b/man/bundle.Rd @@ -61,6 +61,7 @@ then re-loaded and \code{unbundle()}d in a new R session for use in prediction. \seealso{ Other bundlers: \code{\link{bundle.H2OAutoML}()}, +\code{\link{bundle.bart}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, diff --git a/man/bundle_bart.Rd b/man/bundle_bart.Rd new file mode 100644 index 0000000..70a8cb4 --- /dev/null +++ b/man/bundle_bart.Rd @@ -0,0 +1,110 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bundle_bart.R +\name{bundle.bart} +\alias{bundle.bart} +\title{Bundle a \code{bart} object} +\usage{ +\method{bundle}{bart}(x, ...) +} +\arguments{ +\item{x}{A \code{bart} object returned from \code{\link[dbarts:bart]{dbarts::bart()}}. Notably, this ought +not to be the output of \code{\link[parsnip:bart]{parsnip::bart()}}.} + +\item{...}{Not used in this bundler and included for compatibility with +the generic only. Additional arguments passed to this method will return +an error.} +} +\value{ +A bundle object with subclass \code{bundled_bart}. + +Bundles are a list subclass with two components: + +\item{object}{An R object. Gives the output of native serialization +methods from the model-supplying package, sometimes with additional +classes or attributes that aid portability. This is often +a \link[base:raw]{raw} object.} +\item{situate}{A function. The \code{situate()} function is defined when +\code{\link[=bundle]{bundle()}} is called, though is a loose analogue of an \code{\link[=unbundle]{unbundle()}} S3 +method for that object. 
Since the function is defined on \code{\link[=bundle]{bundle()}}, it +has access to references and dependency information that can +be saved alongside the \code{object} component. Calling \code{\link[=unbundle]{unbundle()}} on a +bundled object \code{x} calls \code{x$situate(x$object)}, returning the +unserialized version of \code{object}. \code{situate()} will also restore needed +references, such as server instances and environmental variables.} + +Bundles are R objects that represent a "standalone" version of their +analogous model object. Thus, bundles are ready for saving to a file; saving +with \code{\link[base:readRDS]{base::saveRDS()}} is our recommended serialization strategy for bundles, +unless documented otherwise for a specific method. + +To restore the original model object \code{x} in a new environment, load its +bundle with \code{\link[base:readRDS]{base::readRDS()}} and run \code{\link[=unbundle]{unbundle()}} on it. The output +of \code{\link[=unbundle]{unbundle()}} is a model object that is ready to \code{\link[=predict]{predict()}} on new data, +and other restored functionality (like plotting or summarizing) is supported +as a side effect only. + +The bundle package wraps native serialization methods from model-supplying +packages. Between versions, those model-supplying packages may change their +native serialization methods, possibly introducing problems with re-loading +objects serialized with previous package versions. The bundle package does +not provide checks for these sorts of changes, and ought to be used in +conjunction with tooling for managing and monitoring model environments +like \link[vetiver:vetiver-package]{vetiver} or \link[renv:renv-package]{renv}. + +See \code{vignette("bundle")} for more information on bundling and its motivation. +} +\description{ +Bundling a model prepares it to be saved to a file and later +restored for prediction in a new R session. 
See the 'Value' section for +more information on bundles and their usage. +} +\section{bundle and butcher}{ + +The \href{https://butcher.tidymodels.org/}{butcher} package allows you to remove +parts of a fitted model object that are not needed for prediction. + +This bundle method is compatible with pre-butchering. That is, for a +fitted model \code{x}, you can safely call: + +\if{html}{\out{
<div class="sourceCode">}}\preformatted{res <- + x \%>\% + butcher() \%>\% + bundle() +}\if{html}{\out{</div>
}} + +and predict with the output of \code{unbundle(res)} in a new R session. +} + +\examples{ +\dontshow{if (rlang::is_installed(c("dbarts"))) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +# fit model and bundle ------------------------------------------------ +library(dbarts) + +mtcars$vs <- as.factor(mtcars$vs) + +set.seed(1) +fit <- dbarts::bart(mtcars[c("disp", "hp")], mtcars$vs, keeptrees = TRUE) + +fit_bundle <- bundle(fit) + +# then, after saveRDS + readRDS or passing to a new session ---------- +fit_unbundled <- unbundle(fit_bundle) + +fit_unbundled_preds <- predict(fit_unbundled, mtcars) +\dontshow{\}) # examplesIf} +} +\seealso{ +Other bundlers: +\code{\link{bundle}()}, +\code{\link{bundle.H2OAutoML}()}, +\code{\link{bundle.keras.engine.training.Model}()}, +\code{\link{bundle.luz_module_fitted}()}, +\code{\link{bundle.model_fit}()}, +\code{\link{bundle.model_stack}()}, +\code{\link{bundle.recipe}()}, +\code{\link{bundle.step_umap}()}, +\code{\link{bundle.train}()}, +\code{\link{bundle.workflow}()}, +\code{\link{bundle.xgb.Booster}()} +} +\concept{bundlers} diff --git a/man/bundle_caret.Rd b/man/bundle_caret.Rd index f12f70d..7b4f6f1 100644 --- a/man/bundle_caret.Rd +++ b/man/bundle_caret.Rd @@ -109,6 +109,7 @@ mod_unbundled_preds <- predict(mod_unbundled, new_data = mtcars) Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, +\code{\link{bundle.bart}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, diff --git a/man/bundle_embed.Rd b/man/bundle_embed.Rd index 3bac855..f198aa9 100644 --- a/man/bundle_embed.Rd +++ b/man/bundle_embed.Rd @@ -103,6 +103,7 @@ This method wraps \code{\link[uwot:save_uwot]{uwot::save_uwot()}} and \code{\lin Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, +\code{\link{bundle.bart}()}, \code{\link{bundle.keras.engine.training.Model}()}, 
\code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, diff --git a/man/bundle_h2o.Rd b/man/bundle_h2o.Rd index 150b186..850ae56 100644 --- a/man/bundle_h2o.Rd +++ b/man/bundle_h2o.Rd @@ -109,6 +109,7 @@ These methods wrap \code{\link[h2o:h2o.save_mojo]{h2o::h2o.save_mojo()}} and Other bundlers: \code{\link{bundle}()}, +\code{\link{bundle.bart}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, diff --git a/man/bundle_keras.Rd b/man/bundle_keras.Rd index ed55875..0f86927 100644 --- a/man/bundle_keras.Rd +++ b/man/bundle_keras.Rd @@ -122,6 +122,7 @@ This method wraps \code{\link[keras:save_model_tf]{keras::save_model_tf()}} and Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, +\code{\link{bundle.bart}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, \code{\link{bundle.model_stack}()}, diff --git a/man/bundle_parsnip.Rd b/man/bundle_parsnip.Rd index a3c5a14..25d3655 100644 --- a/man/bundle_parsnip.Rd +++ b/man/bundle_parsnip.Rd @@ -107,6 +107,7 @@ mod_unbundled_preds <- predict(mod_unbundled, new_data = mtcars) Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, +\code{\link{bundle.bart}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_stack}()}, diff --git a/man/bundle_recipe.Rd b/man/bundle_recipe.Rd index 6a6a490..480f879 100644 --- a/man/bundle_recipe.Rd +++ b/man/bundle_recipe.Rd @@ -68,6 +68,7 @@ for more details on the bundling method for that object. 
Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, +\code{\link{bundle.bart}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, diff --git a/man/bundle_stacks.Rd b/man/bundle_stacks.Rd index 6ed6653..eb81e61 100644 --- a/man/bundle_stacks.Rd +++ b/man/bundle_stacks.Rd @@ -87,6 +87,7 @@ mod_unbundled <- unbundle(mod_bundle) Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, +\code{\link{bundle.bart}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, diff --git a/man/bundle_torch.Rd b/man/bundle_torch.Rd index 86c5603..e6c05a2 100644 --- a/man/bundle_torch.Rd +++ b/man/bundle_torch.Rd @@ -150,6 +150,7 @@ This method wraps \code{\link[luz:luz_save]{luz::luz_save()}} and \code{\link[lu Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, +\code{\link{bundle.bart}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.model_fit}()}, \code{\link{bundle.model_stack}()}, diff --git a/man/bundle_workflows.Rd b/man/bundle_workflows.Rd index 8209230..f34cd7d 100644 --- a/man/bundle_workflows.Rd +++ b/man/bundle_workflows.Rd @@ -114,6 +114,7 @@ mod_unbundled <- unbundle(mod_bundle) Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, +\code{\link{bundle.bart}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, diff --git a/man/bundle_xgboost.Rd b/man/bundle_xgboost.Rd index 02fa36e..3ebb759 100644 --- a/man/bundle_xgboost.Rd +++ b/man/bundle_xgboost.Rd @@ -104,6 +104,7 @@ and \code{\link[xgboost:xgb.load.raw]{xgboost::xgb.load.raw()}}. 
Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, +\code{\link{bundle.bart}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, diff --git a/tests/testthat/_snaps/bundle_bart.md b/tests/testthat/_snaps/bundle_bart.md new file mode 100644 index 0000000..2a68719 --- /dev/null +++ b/tests/testthat/_snaps/bundle_bart.md @@ -0,0 +1,18 @@ +# bundle.bart errors informatively with model_spec input (#64) + + Code + bundle(parsnip::bart()) + Condition + Error in `bundle()`: + ! `x` should be the output of `dbarts::bart()`, not a model specification from `parsnip::bart()`. + * To bundle `parsnip::bart()` output, train it with `parsnip::fit()` first. + +# bundle.bart errors informatively when `keeptrees = FALSE` (#64) + + Code + bundle(fit) + Condition + Error in `bundle()`: + ! `x` can't be bundled. + * `x` must have been fitted with argument `keeptrees = TRUE`. + diff --git a/tests/testthat/test_bundle_bart.R b/tests/testthat/test_bundle_bart.R new file mode 100644 index 0000000..ba9e5b0 --- /dev/null +++ b/tests/testthat/test_bundle_bart.R @@ -0,0 +1,121 @@ +test_that("bundling + unbundling bart fits", { + skip_if_not_installed("dbarts") + skip_if_not_installed("butcher") + + library(dbarts) + + # define a function to fit a model ------------------------------------------- + fit_model <- function() { + mtcars$vs <- as.factor(mtcars$vs) + + set.seed(1) + dbarts::bart( + mtcars[c("disp", "hp")], + mtcars$vs, + keeptrees = TRUE, + verbose = FALSE + ) + } + + # pass fit fn to a new session, fit, bundle, return bundle ------------------- + mod_bundle <- + callr::r( + function(fit_model) { + library(dbarts) + + mod <- fit_model() + + bundle::bundle(mod) + }, + args = list(fit_model = fit_model) + ) + + # pass the bundle to a new session, unbundle it, return predictions ---------- + mod_unbundled_preds <- + callr::r( + function(mod_bundle, test_data) { + library(dbarts) + + 
mod_unbundled <- bundle::unbundle(mod_bundle) + + set.seed(1) + predict(mod_unbundled, test_data) + }, + args = list( + mod_bundle = mod_bundle, + test_data = mtcars + ) + ) + + # pass fit fn to a new session, fit, butcher, bundle, return bundle ---------- + mod_butchered_bundle <- + callr::r( + function(fit_model) { + library(dbarts) + + mod <- fit_model() + + bundle::bundle(butcher::butcher(mod)) + }, + args = list(fit_model = fit_model) + ) + + # pass the bundle to a new session, unbundle it, return predictions ---------- + mod_butchered_unbundled_preds <- + callr::r( + function(mod_butchered_bundle, test_data) { + library(bundle) + + mod_butchered_unbundled <- unbundle(mod_butchered_bundle) + + set.seed(1) + predict(mod_butchered_unbundled, test_data) + }, + args = list( + mod_butchered_bundle = mod_butchered_bundle, + test_data = mtcars + ) + ) + + # run expectations ----------------------------------------------------------- + mod_fit <- fit_model() + set.seed(1) + mod_preds <- predict(mod_fit, mtcars) + + # check classes + expect_s3_class(mod_bundle, "bundled_bart") + expect_s3_class(unbundle(mod_bundle), "bart") + + # ensure that the situater function didn't bring along the whole model + expect_false("x" %in% names(environment(mod_bundle$situate))) + + # pass silly dots + expect_error(bundle(mod_fit, boop = "bop"), class = "rlib_error_dots") + + # compare predictions + expect_equal(mod_preds, mod_unbundled_preds) + expect_equal(mod_preds, mod_butchered_unbundled_preds) +}) + +test_that("bundle.bart errors informatively with model_spec input (#64)", { + skip_if_not_installed("parsnip") + + expect_snapshot(error = TRUE, bundle(parsnip::bart())) +}) + +test_that("bundle.bart errors informatively when `keeptrees = FALSE` (#64)", { + skip_if_not_installed("dbarts") + + mtcars$vs <- as.factor(mtcars$vs) + + set.seed(1) + fit <- + dbarts::bart( + mtcars[c("disp", "hp")], + mtcars$vs, + keeptrees = FALSE, + verbose = FALSE + ) + + expect_snapshot(error = TRUE, 
bundle(fit)) +})