Skip to content

Commit

Permalink
Depreciate@submodel l ~ m in favour of l ~ to_submodel(m); rename…
Browse files Browse the repository at this point in the history
… `generated_quantities` to `returned` (#696)

* Added `@returned_quantities` macro

* Added `@returned_quantities` to the docs

* Fixed names of doctests for `@returned_quantities`

* Update src/submodel_macro.jl

Co-authored-by: Xianda Sun <[email protected]>

* Added `@prefix` macro which calls `prefix` with a `Val` argument to
make things easier to basic users

* Convert the result of `prefix_expr` in `@prefix` into a `Sybmol`
before wrapping in `Val`

* Export `prefix` and `@prefix`

* Updated docstring for `@returned_quantities`

* Fixed bug in `rand` for `Model` where it would duplicate the non-leaf
contexts in `model.context`

* Update src/contexts.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Added `prefix` and `@prefix` to docs

* removed the prefix=... syntax for `@returned_quantities`

* added deprecation.jl + deprecated `generated_quantities` in favour of `returned_quantities`

* removed export of `prefix` and `generated_quantities` (the latter is
exported by the deprecation macro)

* updated `DynamicPPLMCMCChainsExt` to define `returned_quantities`

* updated docs

* Update docs/src/api.md

Co-authored-by: Hong Ge <[email protected]>

* improved docstring for `prefix` and `@prefix`

* added `@returned_quantities` macro taking two arguments + removed
`returned_quantities` from exported functions

* updated docs to reflect the new two-argument `@returned_quantities`

* added depwarn to `@submodel` macro

* fixed reference

* fixed reference to `@prefix` in `@returned_quantities` macro

* actually fixed doc references

* updated doctests for `@submodel` to include the depwarn + added
warning regarding deprecation of `@submodel`

* added `to_sampleable` and limited `~` handling for submodels

* added docs to `to_sampleable` + removed the unnecessary macro exports
that we no longer need

* updated more docstrings

* added testing of deprecation warning of `@submodel` + replaced some
usages in tests (though we don't support some of these so we cant' do
that yet)

* Update test/compiler.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* renamed `returned_quantities` to `returned` as requested

* removed redundant `SampleableModelWrapper` in favour of
`ReturnedModelWrapper` + introduced `rand_like!!` to hide explicit
calls to `_evaluate!!`

* updated tests + docstrings + warnings to use `returned`

* updated docs

* formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix docs

* export `to_sampleable` and add to docs

* fixed typo in warning

* removed unnecessary import in docstring

* added docstring to `rand_like!!`

* fixed docstring for `returned(model)`

* improvements to docstrings thanks to @penelopesym

Co-authored-by: Penelope Yong <[email protected]>

* added abstract type `Distributional` and concrete type `Sampleable`,
in addition to method `to_submodel`

* replaced usages of `returned` with `to_submodel`

* formatting

* Update docs/src/api.md

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* removed export of `to_sampleable` since it currently has no purpose +
fixed docs for `returned`

* formatting

* updated docstring for `condition` and `fix` to not use `@submdoel`

* added `check_tilde_rhs` for `Sampleable`

* let the field of sampleable determine whether it works or not

* add automatic prefixing of submodels + remove support for dot-tilde
since this is ambigious in this case

* added automatic prefixing for sub-models involved in `~` statements

* updated depwarn for `@submodel` and tests

* formatting

* updated docstrings

* updated docs

* added more depwarns to the doctests to see if that helps (though I
don't understand why this is needed for Documenter.jl)

* forgot one

* replaced usage of `generated_quantities` with `returned`

* foxed docstring for `to_submodel`

* patch version bump

---------

Co-authored-by: Xianda Sun <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: Penelope Yong <[email protected]>
  • Loading branch information
5 people authored Dec 4, 2024
1 parent 82842bc commit 2252a9b
Show file tree
Hide file tree
Showing 13 changed files with 456 additions and 120 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.31.0"
version = "0.31.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
38 changes: 30 additions & 8 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@ These statements are rewritten by `@model` as calls of [internal functions](@ref
@model
```

One can nest models and call another model inside the model function with [`@submodel`](@ref).

```@docs
@submodel
```

### Type

A [`Model`](@ref) can be created by calling the model function, as defined by [`@model`](@ref).
Expand Down Expand Up @@ -110,6 +104,34 @@ Similarly, we can [`unfix`](@ref) variables, i.e. return them to their original
unfix
```

## Models within models

One can include models and call another model inside the model function with `left ~ to_submodel(model)`.

```@docs
to_submodel
```

Note that a `[to_submodel](@ref)` is only sampleable; one cannot compute `logpdf` for its realizations.

In the past, one would instead embed sub-models using [`@submodel`](@ref), which has been deprecated since the introduction of [`to_submodel(model)`](@ref)

```@docs
@submodel
```

In the context of including models within models, it's also useful to prefix the variables in sub-models to avoid variable names clashing:

```@docs
prefix
```

Under the hood, [`to_submodel`](@ref) makes use of the following method to indicate that the model it's wrapping is a model over its return-values rather than something else

```@docs
returned(::Model)
```

## Utilities

It is possible to manually increase (or decrease) the accumulated log density from within a model function.
Expand All @@ -118,10 +140,10 @@ It is possible to manually increase (or decrease) the accumulated log density fr
@addlogprob!
```

Return values of the model function for a collection of samples can be obtained with [`generated_quantities`](@ref).
Return values of the model function for a collection of samples can be obtained with [`returned(model, chain)`](@ref).

```@docs
generated_quantities
returned(::DynamicPPL.Model, ::NamedTuple)
```

For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). Similarly, the log-densities of the priors using
Expand Down
12 changes: 5 additions & 7 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
end

"""
generated_quantities(model::Model, chain::MCMCChains.Chains)
returned(model::Model, chain::MCMCChains.Chains)
Execute `model` for each of the samples in `chain` and return an array of the values
returned by the `model` for each sample.
Expand All @@ -63,12 +63,12 @@ m = demo(data)
chain = sample(m, alg, n)
# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
# from the posterior/`chain`:
generated_quantities(m, chain) # <= results in a `Vector` of returned values
returned(m, chain) # <= results in a `Vector` of returned values
# from `interesting_quantity(θ, x)`
```
## Concrete (and simple)
```julia
julia> using DynamicPPL, Turing
julia> using Turing
julia> @model function demo(xs)
s ~ InverseGamma(2, 3)
Expand All @@ -87,7 +87,7 @@ julia> model = demo(randn(10));
julia> chain = sample(model, MH(), 10);
julia> generated_quantities(model, chain)
julia> returned(model, chain)
10×1 Array{Tuple{Float64},2}:
(2.1964758025119338,)
(2.1964758025119338,)
Expand All @@ -101,9 +101,7 @@ julia> generated_quantities(model, chain)
(-0.16489786710222099,)
```
"""
function DynamicPPL.generated_quantities(
model::DynamicPPL.Model, chain_full::MCMCChains.Chains
)
function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains)
chain = MCMCChains.get_sections(chain_full, :parameters)
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
Expand Down
9 changes: 7 additions & 2 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ export AbstractVarInfo,
Model,
getmissings,
getargnames,
generated_quantities,
extract_priors,
values_as_in_model,
# Samplers
Expand Down Expand Up @@ -122,6 +121,9 @@ export AbstractVarInfo,
decondition,
fix,
unfix,
prefix,
returned,
to_submodel,
# Convenience macros
@addlogprob!,
@submodel,
Expand All @@ -130,7 +132,8 @@ export AbstractVarInfo,
check_model_and_trace,
# Deprecated.
@logprob_str,
@prob_str
@prob_str,
generated_quantities

# Reexport
using Distributions: loglikelihood
Expand Down Expand Up @@ -196,6 +199,8 @@ include("values_as_in_model.jl")
include("debug_utils.jl")
using .DebugUtils

include("deprecated.jl")

if !isdefined(Base, :get_extension)
using Requires
end
Expand Down
5 changes: 5 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@ function check_tilde_rhs(@nospecialize(x))
end
check_tilde_rhs(x::Distribution) = x
check_tilde_rhs(x::AbstractArray{<:Distribution}) = x
check_tilde_rhs(x::ReturnedModelWrapper) = x
function check_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix}
model = check_tilde_rhs(x.model)
return Sampleable{typeof(model),AutoPrefix}(model)
end

"""
unwrap_right_vn(right, vn)
Expand Down
40 changes: 37 additions & 3 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,17 @@ By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log
probability of `vi` with the returned value.
"""
function tilde_assume!!(context, right, vn, vi)
value, logp, vi = tilde_assume(context, right, vn, vi)
return value, acclogp_assume!!(context, vi, logp)
return if is_rhs_model(right)
# Prefix the variables using the `vn`.
rand_like!!(
right,
should_auto_prefix(right) ? PrefixContext{Symbol(vn)}(context) : context,
vi,
)
else
value, logp, vi = tilde_assume(context, right, vn, vi)
value, acclogp_assume!!(context, vi, logp)
end
end

# observe
Expand Down Expand Up @@ -159,6 +168,11 @@ Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the informati
and indices; if needed, these can be accessed through this function, though.
"""
function tilde_observe!!(context, right, left, vname, vi)
is_rhs_model(right) && throw(
ArgumentError(
"`~` with a model on the right-hand side of an observe statement is not supported",
),
)
return tilde_observe!!(context, right, left, vi)
end

Expand All @@ -172,6 +186,11 @@ By default, calls `tilde_observe(context, right, left, vi)` and accumulates the
probability of `vi` with the returned value.
"""
function tilde_observe!!(context, right, left, vi)
is_rhs_model(right) && throw(
ArgumentError(
"`~` with a model on the right-hand side of an observe statement is not supported",
),
)
logp, vi = tilde_observe(context, right, left, vi)
return left, acclogp_observe!!(context, vi, logp)
end
Expand Down Expand Up @@ -321,8 +340,13 @@ model inputs), accumulate the log probability, and return the sampled value and
Falls back to `dot_tilde_assume(context, right, left, vn, vi)`.
"""
function dot_tilde_assume!!(context, right, left, vn, vi)
is_rhs_model(right) && throw(
ArgumentError(
"`.~` with a model on the right-hand side is not supported; please use `~`"
),
)
value, logp, vi = dot_tilde_assume(context, right, left, vn, vi)
return value, acclogp_assume!!(context, vi, logp), vi
return value, acclogp_assume!!(context, vi, logp)
end

# `dot_assume`
Expand Down Expand Up @@ -573,6 +597,11 @@ Falls back to `dot_tilde_observe!!(context, right, left, vi)` ignoring the infor
name and indices; if needed, these can be accessed through this function, though.
"""
function dot_tilde_observe!!(context, right, left, vn, vi)
is_rhs_model(right) && throw(
ArgumentError(
"`~` with a model on the right-hand side of an observe statement is not supported",
),
)
return dot_tilde_observe!!(context, right, left, vi)
end

Expand All @@ -585,6 +614,11 @@ probability, and return the observed value and updated `vi`.
Falls back to `dot_tilde_observe(context, right, left, vi)`.
"""
function dot_tilde_observe!!(context, right, left, vi)
is_rhs_model(right) && throw(
ArgumentError(
"`~` with a model on the right-hand side of an observe statement is not supported",
),
)
logp, vi = dot_tilde_observe(context, right, left, vi)
return left, acclogp_observe!!(context, vi, logp)
end
Expand Down
28 changes: 28 additions & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,34 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
end
end

"""
prefix(model::Model, x)
Return `model` but with all random variables prefixed by `x`.
If `x` is known at compile-time, use `Val{x}()` to avoid runtime overheads for prefixing.
# Examples
```jldoctest
julia> using DynamicPPL: prefix
julia> @model demo() = x ~ Dirac(1)
demo (generic function with 2 methods)
julia> rand(prefix(demo(), :my_prefix))
(var"my_prefix.x" = 1,)
julia> # One can also use `Val` to avoid runtime overheads.
rand(prefix(demo(), Val(:my_prefix)))
(var"my_prefix.x" = 1,)
```
"""
prefix(model::Model, x) = contextualize(model, PrefixContext{Symbol(x)}(model.context))
function prefix(model::Model, ::Val{x}) where {x}
return contextualize(model, PrefixContext{Symbol(x)}(model.context))
end

struct ConditionContext{Values,Ctx<:AbstractContext} <: AbstractContext
values::Values
context::Ctx
Expand Down
1 change: 1 addition & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
@deprecate generated_quantities(model, params) returned(model, params)
Loading

2 comments on commit 2252a9b

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/120668

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.31.1 -m "<description of version>" 2252a9b6012da8e2ac56353770a0f848f6874357
git push origin v0.31.1

Please sign in to comment.