Skip to content

Commit

Permalink
renamed returned_quantities to returned as requested
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Nov 25, 2024
1 parent 0c6bada commit 5134ff7
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 19 deletions.
4 changes: 2 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ It is possible to manually increase (or decrease) the accumulated log density fr
@addlogprob!
```

Return values of the model function for a collection of samples can be obtained with [`returned_quantities`](@ref).
Return values of the model function for a collection of samples can be obtained with [`returned(model, chain)`](@ref).

```@docs
returned_quantities
returned(model, chain)
```

For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). Similarly, the log-densities of the priors using
Expand Down
8 changes: 4 additions & 4 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
end

"""
returned_quantities(model::Model, chain::MCMCChains.Chains)
returned(model::Model, chain::MCMCChains.Chains)
Execute `model` for each of the samples in `chain` and return an array of the values
returned by the `model` for each sample.
Expand All @@ -63,7 +63,7 @@ m = demo(data)
chain = sample(m, alg, n)
# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
# from the posterior/`chain`:
returned_quantities(m, chain) # <= results in a `Vector` of returned values
returned(m, chain) # <= results in a `Vector` of returned values
# from `interesting_quantity(θ, x)`
```
## Concrete (and simple)
Expand All @@ -87,7 +87,7 @@ julia> model = demo(randn(10));
julia> chain = sample(model, MH(), 10);
julia> DynamicPPL.returned_quantities(model, chain)
julia> returned(model, chain)
10×1 Array{Tuple{Float64},2}:
(2.1964758025119338,)
(2.1964758025119338,)
Expand All @@ -101,7 +101,7 @@ julia> DynamicPPL.returned_quantities(model, chain)
(-0.16489786710222099,)
```
"""
function DynamicPPL.returned_quantities(
function DynamicPPL.returned(
model::DynamicPPL.Model, chain_full::MCMCChains.Chains
)
chain = MCMCChains.get_sections(chain_full, :parameters)
Expand Down
5 changes: 3 additions & 2 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ export AbstractVarInfo,
fix,
unfix,
prefix,
returned_quantities,
returned,
# Convenience macros
@addlogprob!,
@submodel,
Expand All @@ -132,7 +132,8 @@ export AbstractVarInfo,
to_sampleable,
# Deprecated.
@logprob_str,
@prob_str
@prob_str,
generated_quantities

# Reexport
using Distributions: loglikelihood
Expand Down
2 changes: 1 addition & 1 deletion src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -1 +1 @@
@deprecate generated_quantities returned_quantities
@deprecate generated_quantities(model, params) returned(model, params)
18 changes: 9 additions & 9 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1206,9 +1206,9 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
end

"""
returned_quantities(model::Model, parameters::NamedTuple)
returned_quantities(model::Model, values, keys)
returned_quantities(model::Model, values, keys)
returned(model::Model, parameters::NamedTuple)
returned(model::Model, values, keys)
returned(model::Model, values, keys)
Execute `model` with variables `keys` set to `values` and return the values returned by the `model`.
Expand All @@ -1218,7 +1218,7 @@ If a `NamedTuple` is given, `keys=keys(parameters)` and `values=values(parameter
```jldoctest
julia> using DynamicPPL, Distributions
julia> using DynamicPPL: returned_quantities
julia> using DynamicPPL: returned
julia> @model function demo(xs)
s ~ InverseGamma(2, 3)
Expand All @@ -1235,18 +1235,18 @@ julia> model = demo(randn(10));
julia> parameters = (; s = 1.0, m_shifted=10.0);
julia> returned_quantities(model, parameters)
julia> returned(model, parameters)
(0.0,)
julia> returned_quantities(model, values(parameters), keys(parameters))
julia> returned(model, values(parameters), keys(parameters))
(0.0,)
```
"""
function returned_quantities(model::Model, parameters::NamedTuple)
function returned(model::Model, parameters::NamedTuple)
fixed_model = fix(model, parameters)
return fixed_model()
end

function returned_quantities(model::Model, values, keys)
return returned_quantities(model, NamedTuple{keys}(values))
function returned(model::Model, values, keys)
return returned(model, NamedTuple{keys}(values))
end
2 changes: 1 addition & 1 deletion src/submodel_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ function prefix_submodel_context(prefix::Bool, ctx)
return ctx
end

const SUBMODEL_DEPWARN_MSG = "`@submodel model` and `@submodel prefix=... model` are deprecated, use `left ~ to_sampleable(model)` and `left ~ to_sampleable(prefix(model, ...))`, respectively, instead."
const SUBMODEL_DEPWARN_MSG = "`@submodel model` and `@submodel prefix=... model` are deprecated, use `left ~ to_sampleable(returned(model))` and `left ~ to_sampleable(returned(prefix(model, ...)))`, respectively, instead."

function submodel(prefix_expr, expr, ctx=esc(:__context__))
prefix_left, prefix = getargs_assignment(prefix_expr)
Expand Down

0 comments on commit 5134ff7

Please sign in to comment.