Skip to content

Commit

Permalink
stan_extract() written so can take either rstan or cmdstanr
Browse files Browse the repository at this point in the history
  • Loading branch information
n8thangreen committed Aug 24, 2024
1 parent 268aac5 commit e7e7fe4
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
1 change: 0 additions & 1 deletion R/default_prior_cure.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
default_prior_cure <- function(formula_cure,
prior_cure = list(),
bg_model = 2) {
browser()
nTx <- formula_cure$fe_nlevels[1]
n_groups <- formula_cure$n_groups
nvars <- formula_cure$nvars
Expand Down
26 changes: 25 additions & 1 deletion R/prep_S_joint_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ prep_S_joint_data <- function(bmcm_out) {
model_names <- bmcm_out$distns
n_tx <- bmcm_out$formula$cure$fe_nlevels[1]

stan_extract <- rstan::extract(bmcm_out$output)
stan_extract <- stan_extract(bmcm_out)

CI_probs <- c(0.025, 0.5, 0.975)

Expand Down Expand Up @@ -59,3 +59,27 @@ prep_S_joint_data <- function(bmcm_out) {

plot_dat
}


#
stan_extract <- function(bmcm_out, pattern = "") {
fit <- bmcm_out$output

if (inherits(fit, "stanfit")) {

samples <- rstan::extract(fit)
param_names <- grep(pattern, names(samples), value = TRUE)
extracted_params <- samples[param_names]

} else if (inherits(fit, "CmdStanMCMC")) {

samples <- fit$draws(format = "df")
param_names <- grep(pattern, names(samples), value = TRUE)
extracted_params <- samples[, param_names]
} else {
stop("Fit object must be of class 'stanfit' (rstan) or 'CmdStanMCMC' (cmdstanr).")
}

extracted_params
}

0 comments on commit e7e7fe4

Please sign in to comment.