Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds @returned_quantities macro #696

Open
wants to merge 57 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 47 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
5c746c4
Added `@returned_quantities` macro
torfjelde Oct 23, 2024
0b081b7
Added `@returned_quantities` to the docs
torfjelde Oct 23, 2024
dc699a5
Fixed names of doctests for `@returned_quantities`
torfjelde Oct 23, 2024
7067695
Update src/submodel_macro.jl
torfjelde Oct 24, 2024
8cb0796
Added `@prefix` macro which calls `prefix` with a `Val` argument to
torfjelde Oct 29, 2024
2d887c9
Convert the result of `prefix_expr` in `@prefix` into a `Sybmol`
torfjelde Oct 29, 2024
692cfff
Export `prefix` and `@prefix`
torfjelde Oct 29, 2024
32fd6b9
Updated docstring for `@returned_quantities`
torfjelde Oct 29, 2024
5478fb3
Fixed bug in `rand` for `Model` where it would duplicate the non-leaf
torfjelde Oct 29, 2024
5fe65b3
Merge remote-tracking branch 'origin/torfjelde/returned-quantities-ma…
torfjelde Oct 29, 2024
9e0730f
Update src/contexts.jl
torfjelde Oct 29, 2024
cc3af46
Added `prefix` and `@prefix` to docs
torfjelde Oct 29, 2024
720053a
removed the prefix=... syntax for `@returned_quantities`
torfjelde Oct 31, 2024
fe0403f
added deprecation.jl + deprecated `generated_quantities` in favour of…
torfjelde Oct 31, 2024
55b95a1
removed export of `prefix` and `generated_quantities` (the latter is
torfjelde Oct 31, 2024
34fb6bd
updated `DynamicPPLMCMCChainsExt` to define `returned_quantities`
torfjelde Oct 31, 2024
9a7e18f
updated docs
torfjelde Oct 31, 2024
7aef65b
Update docs/src/api.md
torfjelde Nov 1, 2024
5ee727b
improved docstring for `prefix` and `@prefix`
torfjelde Nov 6, 2024
d92141c
added `@returned_quantities` macro taking two arguments + removed
torfjelde Nov 6, 2024
64b519d
updated docs to reflect the new two-argument `@returned_quantities`
torfjelde Nov 6, 2024
1b48f65
added depwarn to `@submodel` macro
torfjelde Nov 6, 2024
db2102c
fixed reference
torfjelde Nov 6, 2024
da95aba
fixed reference to `@prefix` in `@returned_quantities` macro
torfjelde Nov 6, 2024
c8d567f
actually fixed doc references
torfjelde Nov 6, 2024
d477137
updated doctests for `@submodel` to include the depwarn + added
torfjelde Nov 8, 2024
4896793
Merge branch 'master' into torfjelde/returned-quantities-macro
torfjelde Nov 8, 2024
946fa6d
Merge branch 'master' into torfjelde/returned-quantities-macro
torfjelde Nov 15, 2024
bf35de4
added `to_sampleable` and limited `~` handling for submodels
torfjelde Nov 15, 2024
0f20624
added docs to `to_sampleable` + removed the unnecessary macro exports
torfjelde Nov 15, 2024
99d99b3
updated more docstrings
torfjelde Nov 15, 2024
0597b2a
added testing of deprecation warning of `@submodel` + replaced some
torfjelde Nov 15, 2024
0c6bada
Update test/compiler.jl
torfjelde Nov 15, 2024
5134ff7
renamed `returned_quantities` to `returned` as requested
torfjelde Nov 25, 2024
45451f7
removed redundant `SampleableModelWrapper` in favour of
torfjelde Nov 25, 2024
c00a9ae
updated tests + docstrings + warnings to use `returned`
torfjelde Nov 25, 2024
f0af1d5
updated docs
torfjelde Nov 25, 2024
1b231a9
formatting
torfjelde Nov 25, 2024
1faa627
Update src/model.jl
torfjelde Nov 25, 2024
92ac6b9
fix docs
torfjelde Nov 25, 2024
f73d1b0
Merge branch 'master' into torfjelde/returned-quantities-macro
torfjelde Nov 25, 2024
b7b2e1d
export `to_sampleable` and add to docs
torfjelde Nov 25, 2024
ed4bb76
fixed typo in warning
torfjelde Nov 25, 2024
36f02f6
removed unnecessary import in docstring
torfjelde Nov 25, 2024
98538c5
added docstring to `rand_like!!`
torfjelde Nov 25, 2024
d316306
fixed docstring for `returned(model)`
torfjelde Nov 25, 2024
0e05901
improvements to docstrings thanks to @penelopesym
torfjelde Nov 25, 2024
f073b25
added abstract type `Distributional` and concrete type `Sampleable`,
torfjelde Nov 26, 2024
2ec03c1
replaced usages of `returned` with `to_submodel`
torfjelde Nov 26, 2024
1f70dfc
formatting
torfjelde Nov 26, 2024
f645259
Merge remote-tracking branch 'origin/torfjelde/returned-quantities-ma…
torfjelde Nov 26, 2024
23355ea
Update docs/src/api.md
torfjelde Nov 27, 2024
0e82a60
removed export of `to_sampleable` since it currently has no purpose +
torfjelde Nov 27, 2024
b9017c4
formatting
torfjelde Nov 27, 2024
6e149a3
Merge remote-tracking branch 'origin/torfjelde/returned-quantities-ma…
torfjelde Nov 27, 2024
933e4ed
updated docstring for `condition` and `fix` to not use `@submdoel`
torfjelde Nov 27, 2024
4fc7b76
added `check_tilde_rhs` for `Sampleable`
torfjelde Nov 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,31 @@ These statements are rewritten by `@model` as calls of [internal functions](@ref
@model
```

One can nest models and call another model inside the model function with [`@submodel`](@ref).
One can nest models and call another model inside the model function with `left ~ returned(model)`.

```@docs
returned(::DynamicPPL.Model)
```

Note that a `[returned(::DynamicPPL.Model)](@ref)` is only sampleable; one cannot compute `logpdf` for its realizations.
This can be indicated using [`to_sampleable`](@ref) if the user wants to be explicit.

```@docs
to_sampleable
```

In the past, one would instead embed sub-models using [`@submodel`](@ref), which has been deprecated since the introduction of [`returned(model)`](@ref)

```@docs
@submodel
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
```

In the context of nesting models, it's also useful to prefix the variables in sub-models to avoid variable names clashing:

```@docs
prefix
```

### Type

A [`Model`](@ref) can be created by calling the model function, as defined by [`@model`](@ref).
Expand Down Expand Up @@ -118,10 +137,10 @@ It is possible to manually increase (or decrease) the accumulated log density fr
@addlogprob!
```

Return values of the model function for a collection of samples can be obtained with [`generated_quantities`](@ref).
Return values of the model function for a collection of samples can be obtained with [`returned(model, chain)`](@ref).

```@docs
generated_quantities
returned(::DynamicPPL.Model, ::NamedTuple)
```

For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). Similarly, the log-densities of the priors using
Expand Down
12 changes: 5 additions & 7 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
end

"""
generated_quantities(model::Model, chain::MCMCChains.Chains)
returned(model::Model, chain::MCMCChains.Chains)

Execute `model` for each of the samples in `chain` and return an array of the values
returned by the `model` for each sample.
Expand All @@ -63,12 +63,12 @@ m = demo(data)
chain = sample(m, alg, n)
# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
# from the posterior/`chain`:
generated_quantities(m, chain) # <= results in a `Vector` of returned values
returned(m, chain) # <= results in a `Vector` of returned values
# from `interesting_quantity(θ, x)`
```
## Concrete (and simple)
```julia
julia> using DynamicPPL, Turing
julia> using Turing

julia> @model function demo(xs)
s ~ InverseGamma(2, 3)
Expand All @@ -87,7 +87,7 @@ julia> model = demo(randn(10));

julia> chain = sample(model, MH(), 10);

julia> generated_quantities(model, chain)
julia> returned(model, chain)
10×1 Array{Tuple{Float64},2}:
(2.1964758025119338,)
(2.1964758025119338,)
Expand All @@ -101,9 +101,7 @@ julia> generated_quantities(model, chain)
(-0.16489786710222099,)
```
"""
function DynamicPPL.generated_quantities(
model::DynamicPPL.Model, chain_full::MCMCChains.Chains
)
function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains)
chain = MCMCChains.get_sections(chain_full, :parameters)
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
Expand Down
9 changes: 7 additions & 2 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ export AbstractVarInfo,
Model,
getmissings,
getargnames,
generated_quantities,
extract_priors,
values_as_in_model,
# Samplers
Expand Down Expand Up @@ -122,6 +121,9 @@ export AbstractVarInfo,
decondition,
fix,
unfix,
prefix,
returned,
to_sampleable,
# Convenience macros
@addlogprob!,
@submodel,
Expand All @@ -130,7 +132,8 @@ export AbstractVarInfo,
check_model_and_trace,
# Deprecated.
@logprob_str,
@prob_str
@prob_str,
generated_quantities

# Reexport
using Distributions: loglikelihood
Expand Down Expand Up @@ -196,6 +199,8 @@ include("values_as_in_model.jl")
include("debug_utils.jl")
using .DebugUtils

include("deprecated.jl")

if !isdefined(Base, :get_extension)
using Requires
end
Expand Down
1 change: 1 addition & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ function check_tilde_rhs(@nospecialize(x))
end
check_tilde_rhs(x::Distribution) = x
check_tilde_rhs(x::AbstractArray{<:Distribution}) = x
check_tilde_rhs(x::ReturnedModelWrapper) = x

"""
unwrap_right_vn(right, vn)
Expand Down
36 changes: 32 additions & 4 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,12 @@ By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log
probability of `vi` with the returned value.
"""
function tilde_assume!!(context, right, vn, vi)
value, logp, vi = tilde_assume(context, right, vn, vi)
return value, acclogp_assume!!(context, vi, logp)
return if is_rhs_model(right)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be generalized as we desire, e.g. if want to do something special with lantent(model), we can overload this to be true and then overload rand_like!!

rand_like!!(right, context, vi)
else
value, logp, vi = tilde_assume(context, right, vn, vi)
value, acclogp_assume!!(context, vi, logp)
end
end

# observe
Expand Down Expand Up @@ -197,6 +201,11 @@ Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the informati
and indices; if needed, these can be accessed through this function, though.
"""
function tilde_observe!!(context, right, left, vname, vi)
is_rhs_model(right) && throw(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once we want "more" things to be allowed on right, we can easily deal with this by just generalizing is_rhs_model.

ArgumentError(
"`~` with a model on the right-hand side of an observe statement is not supported",
),
)
return tilde_observe!!(context, right, left, vi)
end

Expand All @@ -210,6 +219,11 @@ By default, calls `tilde_observe(context, right, left, vi)` and accumulates the
probability of `vi` with the returned value.
"""
function tilde_observe!!(context, right, left, vi)
is_rhs_model(right) && throw(
ArgumentError(
"`~` with a model on the right-hand side of an observe statement is not supported",
),
)
logp, vi = tilde_observe(context, right, left, vi)
return left, acclogp_observe!!(context, vi, logp)
end
Expand Down Expand Up @@ -420,8 +434,12 @@ model inputs), accumulate the log probability, and return the sampled value and
Falls back to `dot_tilde_assume(context, right, left, vn, vi)`.
"""
function dot_tilde_assume!!(context, right, left, vn, vi)
value, logp, vi = dot_tilde_assume(context, right, left, vn, vi)
return value, acclogp_assume!!(context, vi, logp), vi
return if is_rhs_model(right)
rand_like!!(right, context, vi)
else
value, logp, vi = dot_tilde_assume(context, right, left, vn, vi)
value, acclogp_assume!!(context, vi, logp)
end
end

# `dot_assume`
Expand Down Expand Up @@ -672,6 +690,11 @@ Falls back to `dot_tilde_observe!!(context, right, left, vi)` ignoring the infor
name and indices; if needed, these can be accessed through this function, though.
"""
function dot_tilde_observe!!(context, right, left, vn, vi)
is_rhs_model(right) && throw(
ArgumentError(
"`~` with a model on the right-hand side of an observe statement is not supported",
),
)
return dot_tilde_observe!!(context, right, left, vi)
end

Expand All @@ -684,6 +707,11 @@ probability, and return the observed value and updated `vi`.
Falls back to `dot_tilde_observe(context, right, left, vi)`.
"""
function dot_tilde_observe!!(context, right, left, vi)
is_rhs_model(right) && throw(
ArgumentError(
"`~` with a model on the right-hand side of an observe statement is not supported",
),
)
logp, vi = dot_tilde_observe(context, right, left, vi)
return left, acclogp_observe!!(context, vi, logp)
end
Expand Down
28 changes: 28 additions & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,34 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
end
end

"""
prefix(model::Model, x)

Return `model` but with all random variables prefixed by `x`.

If `x` is known at compile-time, use `Val{x}()` to avoid runtime overheads for prefixing.

# Examples

```jldoctest
julia> using DynamicPPL: prefix

julia> @model demo() = x ~ Dirac(1)
demo (generic function with 2 methods)

julia> rand(prefix(demo(), :my_prefix))
(var"my_prefix.x" = 1,)

julia> # One can also use `Val` to avoid runtime overheads.
rand(prefix(demo(), Val(:my_prefix)))
(var"my_prefix.x" = 1,)
```
"""
prefix(model::Model, x) = contextualize(model, PrefixContext{Symbol(x)}(model.context))
function prefix(model::Model, ::Val{x}) where {x}
return contextualize(model, PrefixContext{Symbol(x)}(model.context))
end

struct ConditionContext{Values,Ctx<:AbstractContext} <: AbstractContext
values::Values
context::Ctx
Expand Down
1 change: 1 addition & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
@deprecate generated_quantities(model, params) returned(model, params)
Loading
Loading