Skip to content

Commit

Permalink
add eight schools examples
Browse files Browse the repository at this point in the history
  • Loading branch information
jgabry committed Jun 15, 2020
1 parent bec67fd commit f2b484f
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 26 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export(cmdstan_version)
export(cmdstanr_example)
export(install_cmdstan)
export(num_threads)
export(print_example_program)
export(read_sample_csv)
export(rebuild_cmdstan)
export(set_cmdstan_path)
Expand Down
65 changes: 49 additions & 16 deletions R/example.R
Original file line number Diff line number Diff line change
@@ -1,28 +1,52 @@
#' Fit models for use in examples
#'
#' @export
#' @param example Name of the example. Currently only `"logistic"` is available.
#' * `logistic`: logistic regression with parameters `alpha` (intercept) and
#' `beta` (vector of regression coefficients).
#' @param example Name of the example. The currently available examples are:
#' * `"logistic"`: logistic regression with intercept and 3 predictors.
#' * `"schools"`: the so-called "eight schools" model, a hierarchical
#' meta-analysis. Fitting this model will result in warnings about
#' divergences.
#' * `"schools_ncp"`: non-centered parameterization eight schools model that
#' fixes the problem with divergences.
#'
#' To print the Stan code for a given example use `print_example_program()`.
#'
#' @param method Which fitting method should be used? The default is the
#' `"sample"` method (MCMC).
#' @param ... Arguments passed to the chosen `method`.
#' @param quiet If `TRUE` (the default) then fitting the model is wrapped in
#' [utils::capture.output()].
#'
#' @return The fitted model object returned by the selected `method`.
#' @return
#' The fitted model object returned by the selected `method`.
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example(chains = 2, save_warmup = TRUE)
#' fit_mcmc$summary()
#' print_example_program("logistic")
#' fit_logistic_mcmc <- cmdstanr_example("logistic", chains = 2)
#' fit_logistic_mcmc$summary()
#'
#' fit_logistic_optim <- cmdstanr_example("logistic", method = "optimize")
#' fit_logistic_optim$summary()
#'
#' fit_logistic_vb <- cmdstanr_example("logistic", method = "variational")
#' fit_logistic_vb$summary()
#'
#' print_example_program("schools")
#' fit_schools_mcmc <- cmdstanr_example("schools")
#'
#' print_example_program("schools_ncp")
#' fit_schools_mcmc <- cmdstanr_example("schools_ncp")
#'
#' fit_optim <- cmdstanr_example(method = "optimize")
#' fit_optim$summary()
#' # optimization fails for hierarchical model
#' cmdstanr_example("schools", "optimize", quiet = FALSE)
#' }
#'
cmdstanr_example <-
function(example = "logistic",
function(example = c("logistic", "schools", "schools_ncp"),
method = c("sample", "optimize", "variational"),
...) {
...,
quiet = TRUE) {

example <- match.arg(example)
method <- match.arg(method)
Expand All @@ -35,13 +59,22 @@ cmdstanr_example <-
file.copy(system.file(example_program, package = "cmdstanr"), tmp)
}
mod <- cmdstan_model(tmp)
data_file <- system.file(example_data, package = "cmdstanr")

out <- utils::capture.output(
fit <- mod[[method]](
data = system.file(example_data, package = "cmdstanr"),
...
)
)
if (quiet) {
out <- utils::capture.output(fit <- mod[[method]](data = data_file, ...))
} else {
fit <- mod[[method]](data = data_file, ...)
}
fit
}

#' @rdname cmdstanr_example
#' @export
print_example_program <-
function(example = c("logistic", "schools", "schools_ncp")) {
example <- match.arg(example)
code <- readLines(system.file(paste0(example, ".stan"), package = "cmdstanr"))
cat(code, sep = "\n")
}

5 changes: 5 additions & 0 deletions inst/schools.data.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"J": 8,
"y": [28, 8, -3, 7, -1, 1, 18, 12],
"sigma": [15, 10, 16, 11, 9, 11, 10, 18]
}
16 changes: 16 additions & 0 deletions inst/schools.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
data {
int<lower=1> J;
vector<lower=0>[J] sigma;
vector[J] y;
}
parameters {
real mu;
real<lower=0> tau;
vector[J] theta;
}
model {
target += normal_lpdf(tau | 0, 10);
target += normal_lpdf(mu | 0, 10);
target += normal_lpdf(theta | mu, tau);
target += normal_lpdf(y | theta, sigma);
}
5 changes: 5 additions & 0 deletions inst/schools_ncp.data.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"J": 8,
"y": [28, 8, -3, 7, -1, 1, 18, 12],
"sigma": [15, 10, 16, 11, 9, 11, 10, 18]
}
19 changes: 19 additions & 0 deletions inst/schools_ncp.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
data {
int<lower=1> J;
vector<lower=0>[J] sigma;
vector[J] y;
}
parameters {
real mu;
real<lower=0> tau;
vector[J] theta_raw;
}
transformed parameters {
vector[J] theta = mu + tau * theta_raw;
}
model {
target += normal_lpdf(tau | 0, 10);
target += normal_lpdf(mu | 0, 10);
target += normal_lpdf(theta_raw | 0, 1);
target += normal_lpdf(y | theta, sigma);
}
46 changes: 36 additions & 10 deletions man/cmdstanr_example.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions tests/testthat/test-example.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,7 @@ test_that("cmdstanr_example works", {

fit_vb <- cmdstanr_example("logistic", method = "variational")
checkmate::expect_r6(fit_vb, "CmdStanVB")

expect_output(print_example_program("schools"), "vector[J] theta", fixed=TRUE)
expect_output(print_example_program("schools_ncp"), "vector[J] theta_raw", fixed=TRUE)
})

0 comments on commit f2b484f

Please sign in to comment.