Skip to content

Commit 01f26f9

Browse files
authored
Merge pull request #19 from andrjohns/simplify-constrains
Optimise constraint handling, centralise sources
2 parents c2ab023 + 0bfa199 commit 01f26f9

22 files changed

+421
-478
lines changed

Diff for: DESCRIPTION

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Description: Allows for the estimation of parameters for 'R' functions using the
2020
License: MIT + file LICENSE
2121
Encoding: UTF-8
2222
Roxygen: list(markdown = TRUE)
23-
RoxygenNote: 7.2.3
23+
RoxygenNote: 7.3.1
2424
NeedsCompilation: yes
2525
UseLTO: true
2626
SystemRequirements: GNU make

Diff for: R/StanEstimators-package.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@
1313
#' @importFrom stats setNames runif
1414
#' @importFrom methods new
1515
#'
16-
NULL
16+
"_PACKAGE"

Diff for: R/cpp_exports.R

+8-8
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ call_stan_impl <- function(options_vector, input_list) {
1010
invisible(NULL)
1111
}
1212

13-
parse_csv <- function(filename) {
14-
parsed <- .Call(`parse_csv_`, filename)
13+
parse_csv <- function(filename, lower = -Inf, upper = Inf) {
14+
parsed <- .Call(`parse_csv_`, filename, lower, upper)
1515
parsed$metadata <- parsed$metadata[unique(names(parsed$metadata))]
1616
parsed
1717
}
@@ -40,17 +40,17 @@ hessian_impl <- function(model_ptr, upars, jacobian = TRUE) {
4040
.Call(`hessian_`, model_ptr, upars, jacobian)
4141
}
4242

43-
unconstrain_variables_impl <- function(model_ptr, variables) {
44-
.Call(`unconstrain_variables_`, model_ptr, variables)
43+
unconstrain_variables_impl <- function(variables, lb, ub) {
44+
.Call(`unconstrain_variables_`, variables, lb, ub)
4545
}
4646

47-
unconstrain_draws_impl <- function(model_ptr, draws, match_format = TRUE) {
47+
unconstrain_draws_impl <- function(draws, lb, ub, match_format = TRUE) {
4848
draws_matrix <- posterior::as_draws_matrix(draws)
4949
par_cols <- grep("^par", colnames(draws_matrix))
5050
if (length(par_cols) == 0) {
5151
stop("No parameter columns found in draws object", call. = FALSE)
5252
}
53-
unconstrained_variables <- .Call(`unconstrain_draws_`, model_ptr, draws_matrix[, par_cols])
53+
unconstrained_variables <- .Call(`unconstrain_draws_`, draws_matrix[, par_cols], lb, ub)
5454
draws_matrix[, par_cols] <- unconstrained_variables
5555
if (match_format) {
5656
match_draws_format(draws, draws_matrix)
@@ -59,8 +59,8 @@ unconstrain_draws_impl <- function(model_ptr, draws, match_format = TRUE) {
5959
}
6060
}
6161

62-
constrain_variables_impl <- function(model_ptr, upars) {
63-
.Call(`constrain_variables_`, model_ptr, upars)
62+
constrain_variables_impl <- function(upars, lb, ub) {
63+
.Call(`constrain_variables_`, upars, lb, ub)
6464
}
6565

6666
lub_constrain <- function(x, lb, ub) {

Diff for: R/laplace.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ stan_laplace <- function(fn, par_inits = NULL, n_pars = NULL, additional_args =
121121

122122
call_stan(args, inputs, quiet)
123123

124-
parsed <- parse_csv(inputs$output_filepath)
124+
parsed <- parse_csv(inputs$output_filepath, lower=inputs$lower, upper=inputs$upper)
125125
estimates <- setNames(data.frame(parsed$samples), parsed$header)
126126

127127
methods::new("StanLaplace",

Diff for: R/model_methods.R

+5-7
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,7 @@ setMethod("hessian", "StanBase",
120120
#' @aliases unconstrain_variables,StanBase,StanBase-method
121121
setMethod("unconstrain_variables", "StanBase",
122122
function(stan_object, variables) {
123-
unconstrain_variables_impl(stan_object@model_methods$model_pointer,
124-
variables)
123+
unconstrain_variables_impl(variables, stan_object@lower_bounds, stan_object@upper_bounds)
125124
}
126125
)
127126

@@ -132,7 +131,7 @@ setMethod("unconstrain_draws", "StanBase",
132131
if (is.null(draws)) {
133132
draws <- stan_object@draws
134133
}
135-
unconstrain_draws_impl(stan_object@model_methods$model_pointer, draws)
134+
unconstrain_draws_impl(draws, stan_object@lower_bounds, stan_object@upper_bounds)
136135
}
137136
)
138137

@@ -142,9 +141,9 @@ setMethod("unconstrain_draws", "StanOptimize",
142141
function(stan_object, draws) {
143142
if (is.null(draws)) {
144143
variables <- stan_object@estimates
145-
unconstrain_draws_impl(stan_object@model_methods$model_pointer, stan_object@estimates, match_format = FALSE)
144+
unconstrain_draws_impl(stan_object@estimates, stan_object@lower_bounds, stan_object@upper_bounds, match_format = FALSE)
146145
} else {
147-
unconstrain_draws_impl(stan_object@model_methods$model_pointer, draws)
146+
unconstrain_draws_impl(draws, stan_object@lower_bounds, stan_object@upper_bounds)
148147
}
149148
}
150149
)
@@ -153,8 +152,7 @@ setMethod("unconstrain_draws", "StanOptimize",
153152
#' @aliases constrain_variables,StanBase,StanBase-method
154153
setMethod("constrain_variables", "StanBase",
155154
function(stan_object, unconstrained_variables) {
156-
constrain_variables_impl(stan_object@model_methods$model_pointer,
157-
unconstrained_variables)
155+
constrain_variables_impl(unconstrained_variables, stan_object@lower_bounds, stan_object@upper_bounds)
158156
}
159157
)
160158

Diff for: R/optimize.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ stan_optimize <- function(fn, par_inits = NULL, n_pars = NULL, additional_args =
113113
output_args = output)
114114
call_stan(args, inputs, quiet)
115115

116-
parsed <- parse_csv(inputs$output_filepath)
116+
parsed <- parse_csv(inputs$output_filepath, lower=inputs$lower, upper=inputs$upper)
117117

118118
methods::new("StanOptimize",
119119
metadata = parsed$metadata,

Diff for: R/pathfinder.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ stan_pathfinder <- function(fn, par_inits = NULL, n_pars = NULL, additional_args
114114

115115
call_stan(args, inputs, quiet)
116116

117-
parsed <- parse_csv(inputs$output_filepath)
117+
parsed <- parse_csv(inputs$output_filepath, lower=inputs$lower, upper=inputs$upper)
118118

119119
methods::new("StanPathfinder",
120120
metadata = parsed$metadata,

Diff for: R/sample.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ stan_sample <- function(fn, par_inits = NULL, n_pars = NULL, additional_args = l
237237
}
238238

239239
all_samples <- lapply(output_files, function(filepath) {
240-
parse_csv(filepath)
240+
parse_csv(filepath, lower=inputs$lower, upper=inputs$upper)
241241
})
242242
draw_names <- all_samples[[1]]$header
243243
metadata <- all_samples[[1]]$metadata

Diff for: R/variational.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ stan_variational <- function(fn, par_inits = NULL, n_pars = NULL, additional_arg
110110

111111
call_stan(args, inputs, quiet)
112112

113-
parsed <- parse_csv(inputs$output_filepath)
113+
parsed <- parse_csv(inputs$output_filepath, lower=inputs$lower, upper=inputs$upper)
114114
estimates <- setNames(data.frame(parsed$samples), parsed$header)
115115
methods::new("StanVariational",
116116
metadata = parsed$metadata,

Diff for: inst/include/estimator/estimator.hpp

+31-50
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,18 @@ using namespace stan::math;
66
stan::math::profile_map profiles__;
77
static constexpr std::array<const char*, 13> locations_array__ =
88
{" (found before start of program)",
9-
" (in 'include/estimator/estimator.stan', line 15, column 2 to column 61)",
10-
" (in 'include/estimator/estimator.stan', line 19, column 2 to column 95)",
11-
" (in 'include/estimator/estimator.stan', line 6, column 2 to column 12)",
12-
" (in 'include/estimator/estimator.stan', line 7, column 2 to column 18)",
13-
" (in 'include/estimator/estimator.stan', line 8, column 2 to column 16)",
14-
" (in 'include/estimator/estimator.stan', line 9, column 8 to column 13)",
15-
" (in 'include/estimator/estimator.stan', line 9, column 2 to column 32)",
16-
" (in 'include/estimator/estimator.stan', line 10, column 9 to column 14)",
17-
" (in 'include/estimator/estimator.stan', line 10, column 2 to column 29)",
18-
" (in 'include/estimator/estimator.stan', line 11, column 9 to column 14)",
19-
" (in 'include/estimator/estimator.stan', line 11, column 2 to column 29)",
20-
" (in 'include/estimator/estimator.stan', line 15, column 49 to column 54)"};
9+
" (in 'inst/include/estimator/estimator.stan', line 15, column 2 to column 21)",
10+
" (in 'inst/include/estimator/estimator.stan', line 19, column 2 to column 95)",
11+
" (in 'inst/include/estimator/estimator.stan', line 6, column 2 to column 12)",
12+
" (in 'inst/include/estimator/estimator.stan', line 7, column 2 to column 18)",
13+
" (in 'inst/include/estimator/estimator.stan', line 8, column 2 to column 16)",
14+
" (in 'inst/include/estimator/estimator.stan', line 9, column 8 to column 13)",
15+
" (in 'inst/include/estimator/estimator.stan', line 9, column 2 to column 32)",
16+
" (in 'inst/include/estimator/estimator.stan', line 10, column 9 to column 14)",
17+
" (in 'inst/include/estimator/estimator.stan', line 10, column 2 to column 29)",
18+
" (in 'inst/include/estimator/estimator.stan', line 11, column 9 to column 14)",
19+
" (in 'inst/include/estimator/estimator.stan', line 11, column 2 to column 29)",
20+
" (in 'inst/include/estimator/estimator.stan', line 15, column 9 to column 14)"};
2121
class estimator_model final : public model_base_crtp<estimator_model> {
2222
private:
2323
int Npars;
@@ -48,7 +48,7 @@ class estimator_model final : public model_base_crtp<estimator_model> {
4848
// suppress unused var warning
4949
(void) DUMMY_VAR__;
5050
try {
51-
int pos__;
51+
int pos__ = std::numeric_limits<int>::min();
5252
pos__ = 1;
5353
current_statement__ = 3;
5454
context__.validate_dims("data initialization", "Npars", "int",
@@ -130,7 +130,7 @@ class estimator_model final : public model_base_crtp<estimator_model> {
130130
}
131131
inline std::vector<std::string> model_compile_info() const noexcept {
132132
return std::vector<std::string>{"stanc_version = stanc3 v2.35.0",
133-
"stancflags = --O1 --allow-undefined"};
133+
"stancflags = --allow-undefined"};
134134
}
135135
// Base log prob
136136
template <bool propto__, bool jacobian__, typename VecR, typename VecI,
@@ -156,20 +156,13 @@ class estimator_model final : public model_base_crtp<estimator_model> {
156156
// suppress unused var warning
157157
(void) function__;
158158
try {
159-
Eigen::Matrix<local_scalar_t__,-1,1> pars;
160-
current_statement__ = 1;
161-
pars = in__.template read_constrain_lub<
162-
Eigen::Matrix<local_scalar_t__,-1,1>,
163-
jacobian__>(lower_bounds, upper_bounds, lp__, Npars);
164-
current_statement__ = 1;
165-
stan::math::check_matching_dims("constraint", "pars", pars, "lower",
166-
lower_bounds);
159+
Eigen::Matrix<local_scalar_t__,-1,1> pars =
160+
Eigen::Matrix<local_scalar_t__,-1,1>::Constant(Npars, DUMMY_VAR__);
167161
current_statement__ = 1;
168-
stan::math::check_matching_dims("constraint", "pars", pars, "upper",
169-
upper_bounds);
162+
pars = in__.template read<Eigen::Matrix<local_scalar_t__,-1,1>>(Npars);
170163
{
171164
current_statement__ = 2;
172-
lp_accum__.add(r_function(pars, finite_diff, no_bounds, bounds_types,
165+
lp_accum__.add(r_function<jacobian__>(pars, finite_diff, no_bounds, bounds_types,
173166
lower_bounds, upper_bounds, pstream__));
174167
}
175168
} catch (const std::exception& e) {
@@ -202,20 +195,13 @@ class estimator_model final : public model_base_crtp<estimator_model> {
202195
// suppress unused var warning
203196
(void) function__;
204197
try {
205-
Eigen::Matrix<local_scalar_t__,-1,1> pars;
206-
current_statement__ = 1;
207-
pars = in__.template read_constrain_lub<
208-
Eigen::Matrix<local_scalar_t__,-1,1>,
209-
jacobian__>(lower_bounds, upper_bounds, lp__, Npars);
210-
current_statement__ = 1;
211-
stan::math::check_matching_dims("constraint", "pars", pars, "lower",
212-
lower_bounds);
198+
Eigen::Matrix<local_scalar_t__,-1,1> pars =
199+
Eigen::Matrix<local_scalar_t__,-1,1>::Constant(Npars, DUMMY_VAR__);
213200
current_statement__ = 1;
214-
stan::math::check_matching_dims("constraint", "pars", pars, "upper",
215-
upper_bounds);
201+
pars = in__.template read<Eigen::Matrix<local_scalar_t__,-1,1>>(Npars);
216202
{
217203
current_statement__ = 2;
218-
lp_accum__.add(r_function(pars, finite_diff, no_bounds, bounds_types,
204+
lp_accum__.add(r_function<jacobian__>(pars, finite_diff, no_bounds, bounds_types,
219205
lower_bounds, upper_bounds, pstream__));
220206
}
221207
} catch (const std::exception& e) {
@@ -259,17 +245,11 @@ class estimator_model final : public model_base_crtp<estimator_model> {
259245
// suppress unused var warning
260246
(void) function__;
261247
try {
262-
Eigen::Matrix<double,-1,1> pars;
263-
current_statement__ = 1;
264-
pars = in__.template read_constrain_lub<
265-
Eigen::Matrix<local_scalar_t__,-1,1>,
266-
jacobian__>(lower_bounds, upper_bounds, lp__, Npars);
248+
Eigen::Matrix<double,-1,1> pars =
249+
Eigen::Matrix<double,-1,1>::Constant(Npars,
250+
std::numeric_limits<double>::quiet_NaN());
267251
current_statement__ = 1;
268-
stan::math::check_matching_dims("constraint", "pars", pars, "lower",
269-
lower_bounds);
270-
current_statement__ = 1;
271-
stan::math::check_matching_dims("constraint", "pars", pars, "upper",
272-
upper_bounds);
252+
pars = in__.template read<Eigen::Matrix<local_scalar_t__,-1,1>>(Npars);
273253
out__.write(pars);
274254
if (stan::math::logical_negation(
275255
(stan::math::primitive_value(emit_transformed_parameters__) ||
@@ -299,12 +279,13 @@ class estimator_model final : public model_base_crtp<estimator_model> {
299279
// suppress unused var warning
300280
(void) DUMMY_VAR__;
301281
try {
302-
Eigen::Matrix<local_scalar_t__,-1,1> pars;
282+
Eigen::Matrix<local_scalar_t__,-1,1> pars =
283+
Eigen::Matrix<local_scalar_t__,-1,1>::Constant(Npars, DUMMY_VAR__);
303284
current_statement__ = 1;
304285
stan::model::assign(pars,
305286
in__.read<Eigen::Matrix<local_scalar_t__,-1,1>>(Npars),
306287
"assigning variable pars");
307-
out__.write_free_lub(lower_bounds, upper_bounds, pars);
288+
out__.write(pars);
308289
} catch (const std::exception& e) {
309290
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
310291
}
@@ -325,7 +306,7 @@ class estimator_model final : public model_base_crtp<estimator_model> {
325306
current_statement__ = 1;
326307
context__.validate_dims("parameter initialization", "pars", "double",
327308
std::vector<size_t>{static_cast<size_t>(Npars)});
328-
int pos__;
309+
int pos__ = std::numeric_limits<int>::min();
329310
pos__ = 1;
330311
Eigen::Matrix<local_scalar_t__,-1,1> pars =
331312
Eigen::Matrix<local_scalar_t__,-1,1>::Constant(Npars, DUMMY_VAR__);
@@ -340,7 +321,7 @@ class estimator_model final : public model_base_crtp<estimator_model> {
340321
pos__ = (pos__ + 1);
341322
}
342323
}
343-
out__.write_free_lub(lower_bounds, upper_bounds, pars);
324+
out__.write(pars);
344325
} catch (const std::exception& e) {
345326
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
346327
}

Diff for: inst/include/estimator/estimator.stan

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ data {
1212
}
1313

1414
parameters {
15-
vector<lower=lower_bounds, upper=upper_bounds>[Npars] pars;
15+
vector[Npars] pars;
1616
}
1717

1818
model {

0 commit comments

Comments
 (0)