Skip to content

Commit

Permalink
extract function into get_stan_defaults()
Browse files Browse the repository at this point in the history
  • Loading branch information
n8thangreen committed Aug 12, 2024
1 parent 9e7d740 commit cc36f80
Showing 1 changed file with 38 additions and 38 deletions.
76 changes: 38 additions & 38 deletions R/bmcm_stan.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,42 +132,7 @@ bmcm_stan <- function(input_data,

model_name <- paste0("bmcm_stan_", glue::glue_collapse(distns, sep = "_"))

# default sampler parameters
#TODO: cmdstanr and rstan are slightly different
# harmonize automatically
rstan_dots <-
modifyList(
dots,
list(warmup = 100,
iter = 500,
thin = 1,
chains = 1,
control = list(adapt_delta = 0.99,
max_treedepth = 100,
stepsize = 0.05),
include = TRUE,
open_progress = TRUE)
# verbose = TRUE)
)

cmdstanr_dots <-
modifyList(
dots,
list(iter_warmup = 100,
iter_sampling = 500,
save_warmup = FALSE,
thin = 1,
chains = 1,
adapt_delta = 0.99,
max_treedepth = 100,
step_size = 0.05)
)

if (use_cmdstanr) {
dots <- cmdstanr_dots
} else {
dots <- rstan_dots
}
mcmc_params <- get_stan_defaults(use_cmdstanr, dots)

##############
# fit model
Expand Down Expand Up @@ -198,8 +163,8 @@ bmcm_stan <- function(input_data,
res <- list()

res$output <- perform_sampling(use_cmdstanr, precompiled_model,
stan_inputs, dots)
res$dots <- dots
stan_inputs, mcmc_params)
res$mcmc_params <- mcmc_params
res$call <- call
res$distns <- distns
res$inputs <- stan_inputs
Expand Down Expand Up @@ -263,3 +228,38 @@ perform_sampling <- function(use_cmdstanr, precompiled_model,
output
}

# default sampler parameters
#TODO: cmdstanr and rstan are slightly different
# harmonize automatically
get_stan_defaults <- function(use_cmdstanr, dots) {
if (use_cmdstanr) {
return(
modifyList(
dots,
list(iter_warmup = 100,
iter_sampling = 500,
save_warmup = FALSE,
thin = 1,
chains = 1,
adapt_delta = 0.99,
max_treedepth = 100,
step_size = 0.05)
))
} else {
return(
modifyList(
dots,
list(warmup = 100,
iter = 500,
thin = 1,
chains = 1,
control = list(
adapt_delta = 0.99,
max_treedepth = 100,
stepsize = 0.05),
include = TRUE,
open_progress = TRUE)
# verbose = TRUE)
))
}
}

0 comments on commit cc36f80

Please sign in to comment.