Skip to content

Commit

Permalink
Clarification in MLE and LOO doc
Browse files Browse the repository at this point in the history
closes #1052
closes #1051
  • Loading branch information
jgabry committed Dec 13, 2024
1 parent 3ae9080 commit f01d2af
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 27 deletions.
42 changes: 29 additions & 13 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -1497,10 +1497,10 @@ CmdStanMCMC <- R6::R6Class(
#' and the \pkg{loo} package [vignettes](https://mc-stan.org/loo/articles/)
#' for details.
#'
#' @param variables (character vector) The name(s) of the variable(s) in the
#' Stan program containing the pointwise log-likelihood. The default is to
#' look for `"log_lik"`. This argument is passed to the
#' [`$draws()`][fit-method-draws] method.
#' @param variables (string) The name of the variable in the Stan program
#' containing the pointwise log-likelihood. The default is to look for
#' `"log_lik"`. This argument is passed to the [`$draws()`][fit-method-draws]
#' method.
#' @param r_eff (multiple options) How to handle the `r_eff` argument for `loo()`:
#' * `TRUE` (the default) will automatically call [loo::relative_eff.array()]
#' to compute the `r_eff` argument to pass to [loo::loo.array()].
Expand Down Expand Up @@ -1539,6 +1539,9 @@ CmdStanMCMC <- R6::R6Class(
#'
loo <- function(variables = "log_lik", r_eff = TRUE, moment_match = FALSE, ...) {
require_suggested_package("loo")
if (length(variables) != 1) {
stop("Only a single variable name is allowed for the 'variables' argument.", call. = FALSE)
}
LLarray <- self$draws(variables, format = "draws_array")
if (is.logical(r_eff)) {
if (isTRUE(r_eff)) {
Expand Down Expand Up @@ -1805,6 +1808,12 @@ CmdStanMCMC$set("public", name = "num_chains", value = num_chains)
#'
#' @description A `CmdStanMLE` object is the fitted model object returned by the
#' [`$optimize()`][model-method-optimize] method of a [`CmdStanModel`] object.
#' The name "MLE" may be somewhat misleading because the `$optimize()` method
#' can compute either a penalized maximum likelihood estimate or a maximum a
#' posteriori estimate, depending on the value of the `jacobian` argument when
#' the model is fit. Additionally, for models without constrained parameters,
#' the penalized MLE and the posterior mode are equivalent, as the Jacobian
#' adjustment has no effect.
#'
#' @section Methods: `CmdStanMLE` objects have the following associated methods,
#' all of which have their own (linked) documentation pages.
Expand All @@ -1814,7 +1823,7 @@ CmdStanMCMC$set("public", name = "num_chains", value = num_chains)
#' |**Method**|**Description**|
#' |:----------|:---------------|
#' [`draws()`][fit-method-draws] | Return the point estimate as a 1-row [`draws_matrix`][posterior::draws_matrix]. |
#' [`$mle()`][fit-method-mle] | Return the point estimate as a numeric vector. |
#' [`$mode()`][fit-method-mode] | Return the point estimate as a numeric vector. |
#' [`$lp()`][fit-method-lp] | Return the total log probability density (`target`). |
#' [`$init()`][fit-method-init] | Return user-specified initial values. |
#' [`$metadata()`][fit-method-metadata] | Return a list of metadata gathered from the CmdStan CSV files. |
Expand Down Expand Up @@ -1874,17 +1883,23 @@ CmdStanMLE <- R6::R6Class(
)
)

#' Extract (penalized) maximum likelihood estimate after optimization
#' Extract point estimate after optimization
#'
#' @name fit-method-mle
#' @aliases mle
#' @description The `$mle()` method is only available for [`CmdStanMLE`] objects.
#' It returns the penalized maximum likelihood estimate (posterior mode) as a
#' numeric vector with one element per variable. The returned vector does *not*
#' include `lp__`, the total log probability (`target`) accumulated in the
#' model block of the Stan program, which is available via the
#' [`$lp()`][fit-method-lp] method and also included in the
#' [`$draws()`][fit-method-draws] method.
#' @description The `$mle()` method is only available for [`CmdStanMLE`]
#' objects. It returns the point estimate as a numeric vector with one element
#' per variable. The returned vector does *not* include `lp__`, the total log
#' probability (`target`) accumulated in the model block of the Stan program,
#' which is available via the [`$lp()`][fit-method-lp] method and also
#' included in the [`$draws()`][fit-method-draws] method.
#'
#' The name `mle` may be somewhat misleading because the `$optimize()` method
#' can compute either a penalized maximum likelihood estimate or a maximum a
#' posteriori estimate, depending on the value of the `jacobian` argument when
#' the model is fit. Additionally, for models without constrained parameters,
#' the penalized MLE and the posterior mode are equivalent, as the Jacobian
#' adjustment has no effect.
#'
#' @param variables (character vector) The variables (parameters, transformed
#' parameters, and generated quantities) to include. If NULL (the default)
Expand All @@ -1909,6 +1924,7 @@ mle <- function(variables = NULL) {
}
CmdStanMLE$set("public", name = "mle", value = mle)


# CmdStanLaplace ---------------------------------------------------------------
#' CmdStanLaplace objects
#'
Expand Down
8 changes: 7 additions & 1 deletion man/CmdStanMLE.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions man/fit-method-loo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 14 additions & 8 deletions man/fit-method-mle.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/fit-method-save_object.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions tests/testthat/test-fit-mcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,11 @@ test_that("loo method works if log_lik is available", {
fit_bernoulli <- testing_fit("bernoulli_log_lik")
expect_s3_class(suppressWarnings(fit_bernoulli$loo(cores = 1, save_psis = TRUE)), "loo")
expect_s3_class(suppressWarnings(fit_bernoulli$loo(r_eff = FALSE)), "loo")

expect_error(
fit_bernoulli$loo(variables = c("log_lik", "beta")),
"Only a single variable name is allowed"
)
})

test_that("loo method works with moment-matching", {
Expand Down

0 comments on commit f01d2af

Please sign in to comment.