Skip to content

Commit

Permalink
tried to tidy by making several helper functions. Really should restr…
Browse files Browse the repository at this point in the history
…ucture the whole code but ended up with this by adding more functionality incrementally.
  • Loading branch information
n8thangreen committed Aug 12, 2024
1 parent 910d476 commit caf1940
Showing 1 changed file with 71 additions and 34 deletions.
105 changes: 71 additions & 34 deletions R/bmcm_stan.R
Original file line number Diff line number Diff line change
Expand Up @@ -163,34 +163,31 @@ bmcm_stan <- function(input_data,
step_size = 0.05)
)

if (use_cmdstanr) {
dots <- cmdstanr_dots
} else {
dots <- rstan_dots
}

##############
# fit model

use_precompiled_model <- !is.na(precompiled_model_path)

if (!use_precompiled_model) {
if (read_stan_code) {
model_code <- readr::read_file(here::here("data/stan_model_code.stan"))
} else {
model_code <- create_stancode(distns)
}

if (use_cmdstanr) {
model_path <- cmdstanr::write_stan_file(model_code,
dir = ".", basename = model_name)
precompiled_model <- cmdstanr::cmdstan_model(stan_file = model_path)
} else {
precompiled_model <- rstan::stan_model(model_code = model_code,
model_name = model_name)
}

if (use_precompiled_model) {
precompiled_model <-
load_precompiled_model(
use_cmdstanr = use_cmdstanr,
path = precompiled_model_path)
} else {
if (use_cmdstanr) {
# may be automatically saved to cmdstan_path()
precompiled_model <- cmdstanr::cmdstan_model(exe_file = precompiled_model_path)
} else {
precompiled_model <- readRDS(precompiled_model_path)
}
model_code <-
get_model_code(read_stan_code, distns)

precompiled_model <-
compile_model(
use_cmdstanr = use_cmdstanr,
model_code = model_code,
model_name = model_name)
}

# for testing
Expand All @@ -200,18 +197,9 @@ bmcm_stan <- function(input_data,

res <- list()

if (use_cmdstanr) {
res$output <- do.call(
precompiled_model$sample,
args = c(stan_inputs, cmdstanr_dots))
res$stan_dots <- cmdstanr_dots
} else {
res$output <- do.call(
rstan::sampling,
args = c(list(object = precompiled_model), stan_inputs, rstan_dots))
res$stan_dots <- rstan_dots
}

res$output <- perform_sampling(use_cmdstanr, precompiled_model,
stan_inputs, dots)
res$dots <- dots
res$call <- call
res$distns <- distns
res$inputs <- stan_inputs
Expand All @@ -223,3 +211,52 @@ bmcm_stan <- function(input_data,
return(res)
}

#
get_model_code <- function(read_stan_code, distns) {
if (read_stan_code) {
return(readr::read_file(here::here("data/stan_model_code.stan")))
} else {
return(create_stancode(distns))
}
}

#
compile_model <- function(use_cmdstanr, model_code, model_name) {
if (use_cmdstanr) {
model_path <-
cmdstanr::write_stan_file(
model_code, dir = ".", basename = model_name)
return(cmdstanr::cmdstan_model(stan_file = model_path))
} else {
return(rstan::stan_model(model_code = model_code, model_name = model_name))
}
}

#
load_precompiled_model <- function(use_cmdstanr, path) {
if (use_cmdstanr) {
return(cmdstanr::cmdstan_model(exe_file = path))
} else {
return(readRDS(path))
}
}

#
perform_sampling <- function(use_cmdstanr, precompiled_model,
stan_inputs, dots) {
if (use_cmdstanr) {
output <- do.call(
precompiled_model$sample,
args = c(stan_inputs, dots)
)
} else {
output <- do.call(
rstan::sampling,
args = c(list(object = precompiled_model),
stan_inputs, dots)
)
}

output
}

0 comments on commit caf1940

Please sign in to comment.