Skip to content

Commit

Permalink
Merge branch 'master' into wct/mooncake-perf
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt authored Dec 4, 2024
2 parents 73fbf34 + 2252a9b commit b21af0b
Show file tree
Hide file tree
Showing 15 changed files with 459 additions and 123 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ jobs:

- uses: julia-actions/julia-processcoverage@v1

- uses: codecov/codecov-action@v4
- uses: codecov/codecov-action@v5
with:
file: lcov.info
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true

Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.31.0"
version = "0.31.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
38 changes: 30 additions & 8 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@ These statements are rewritten by `@model` as calls of [internal functions](@ref
@model
```

One can nest models and call another model inside the model function with [`@submodel`](@ref).

```@docs
@submodel
```

### Type

A [`Model`](@ref) can be created by calling the model function, as defined by [`@model`](@ref).
Expand Down Expand Up @@ -110,6 +104,34 @@ Similarly, we can [`unfix`](@ref) variables, i.e. return them to their original
unfix
```

## Models within models

One can include models and call another model inside the model function with `left ~ to_submodel(model)`.

```@docs
to_submodel
```

Note that a `[to_submodel](@ref)` is only sampleable; one cannot compute `logpdf` for its realizations.

In the past, one would instead embed sub-models using [`@submodel`](@ref), which has been deprecated since the introduction of [`to_submodel(model)`](@ref)

```@docs
@submodel
```

In the context of including models within models, it's also useful to prefix the variables in sub-models to avoid variable names clashing:

```@docs
prefix
```

Under the hood, [`to_submodel`](@ref) makes use of the following method to indicate that the model it's wrapping is a model over its return-values rather than something else

```@docs
returned(::Model)
```

## Utilities

It is possible to manually increase (or decrease) the accumulated log density from within a model function.
Expand All @@ -118,10 +140,10 @@ It is possible to manually increase (or decrease) the accumulated log density fr
@addlogprob!
```

Return values of the model function for a collection of samples can be obtained with [`generated_quantities`](@ref).
Return values of the model function for a collection of samples can be obtained with [`returned(model, chain)`](@ref).

```@docs
generated_quantities
returned(::DynamicPPL.Model, ::NamedTuple)
```

For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). Similarly, the log-densities of the priors using
Expand Down
12 changes: 5 additions & 7 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
end

"""
generated_quantities(model::Model, chain::MCMCChains.Chains)
returned(model::Model, chain::MCMCChains.Chains)
Execute `model` for each of the samples in `chain` and return an array of the values
returned by the `model` for each sample.
Expand All @@ -63,12 +63,12 @@ m = demo(data)
chain = sample(m, alg, n)
# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
# from the posterior/`chain`:
generated_quantities(m, chain) # <= results in a `Vector` of returned values
returned(m, chain) # <= results in a `Vector` of returned values
# from `interesting_quantity(θ, x)`
```
## Concrete (and simple)
```julia
julia> using DynamicPPL, Turing
julia> using Turing
julia> @model function demo(xs)
s ~ InverseGamma(2, 3)
Expand All @@ -87,7 +87,7 @@ julia> model = demo(randn(10));
julia> chain = sample(model, MH(), 10);
julia> generated_quantities(model, chain)
julia> returned(model, chain)
10×1 Array{Tuple{Float64},2}:
(2.1964758025119338,)
(2.1964758025119338,)
Expand All @@ -101,9 +101,7 @@ julia> generated_quantities(model, chain)
(-0.16489786710222099,)
```
"""
function DynamicPPL.generated_quantities(
model::DynamicPPL.Model, chain_full::MCMCChains.Chains
)
function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains)
chain = MCMCChains.get_sections(chain_full, :parameters)
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
Expand Down
9 changes: 7 additions & 2 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ export AbstractVarInfo,
Model,
getmissings,
getargnames,
generated_quantities,
extract_priors,
values_as_in_model,
# Samplers
Expand Down Expand Up @@ -122,6 +121,9 @@ export AbstractVarInfo,
decondition,
fix,
unfix,
prefix,
returned,
to_submodel,
# Convenience macros
@addlogprob!,
@submodel,
Expand All @@ -130,7 +132,8 @@ export AbstractVarInfo,
check_model_and_trace,
# Deprecated.
@logprob_str,
@prob_str
@prob_str,
generated_quantities

# Reexport
using Distributions: loglikelihood
Expand Down Expand Up @@ -196,6 +199,8 @@ include("values_as_in_model.jl")
include("debug_utils.jl")
using .DebugUtils

include("deprecated.jl")

if !isdefined(Base, :get_extension)
using Requires
end
Expand Down
5 changes: 5 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@ function check_tilde_rhs(@nospecialize(x))
end
check_tilde_rhs(x::Distribution) = x
check_tilde_rhs(x::AbstractArray{<:Distribution}) = x
check_tilde_rhs(x::ReturnedModelWrapper) = x
function check_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix}
model = check_tilde_rhs(x.model)
return Sampleable{typeof(model),AutoPrefix}(model)
end

"""
unwrap_right_vn(right, vn)
Expand Down
40 changes: 37 additions & 3 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,17 @@ By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log
probability of `vi` with the returned value.
"""
function tilde_assume!!(context, right, vn, vi)
value, logp, vi = tilde_assume(context, right, vn, vi)
return value, acclogp_assume!!(context, vi, logp)
return if is_rhs_model(right)
# Prefix the variables using the `vn`.
rand_like!!(
right,
should_auto_prefix(right) ? PrefixContext{Symbol(vn)}(context) : context,
vi,
)
else
value, logp, vi = tilde_assume(context, right, vn, vi)
value, acclogp_assume!!(context, vi, logp)
end
end

# observe
Expand Down Expand Up @@ -159,6 +168,11 @@ Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the informati
and indices; if needed, these can be accessed through this function, though.
"""
function tilde_observe!!(context, right, left, vname, vi)
is_rhs_model(right) && throw(
ArgumentError(
"`~` with a model on the right-hand side of an observe statement is not supported",
),
)
return tilde_observe!!(context, right, left, vi)
end

Expand All @@ -172,6 +186,11 @@ By default, calls `tilde_observe(context, right, left, vi)` and accumulates the
probability of `vi` with the returned value.
"""
function tilde_observe!!(context, right, left, vi)
is_rhs_model(right) && throw(
ArgumentError(
"`~` with a model on the right-hand side of an observe statement is not supported",
),
)
logp, vi = tilde_observe(context, right, left, vi)
return left, acclogp_observe!!(context, vi, logp)
end
Expand Down Expand Up @@ -321,8 +340,13 @@ model inputs), accumulate the log probability, and return the sampled value and
Falls back to `dot_tilde_assume(context, right, left, vn, vi)`.
"""
function dot_tilde_assume!!(context, right, left, vn, vi)
is_rhs_model(right) && throw(
ArgumentError(
"`.~` with a model on the right-hand side is not supported; please use `~`"
),
)
value, logp, vi = dot_tilde_assume(context, right, left, vn, vi)
return value, acclogp_assume!!(context, vi, logp), vi
return value, acclogp_assume!!(context, vi, logp)
end

# `dot_assume`
Expand Down Expand Up @@ -573,6 +597,11 @@ Falls back to `dot_tilde_observe!!(context, right, left, vi)` ignoring the infor
name and indices; if needed, these can be accessed through this function, though.
"""
function dot_tilde_observe!!(context, right, left, vn, vi)
is_rhs_model(right) && throw(
ArgumentError(
"`~` with a model on the right-hand side of an observe statement is not supported",
),
)
return dot_tilde_observe!!(context, right, left, vi)
end

Expand All @@ -585,6 +614,11 @@ probability, and return the observed value and updated `vi`.
Falls back to `dot_tilde_observe(context, right, left, vi)`.
"""
function dot_tilde_observe!!(context, right, left, vi)
is_rhs_model(right) && throw(
ArgumentError(
"`~` with a model on the right-hand side of an observe statement is not supported",
),
)
logp, vi = dot_tilde_observe(context, right, left, vi)
return left, acclogp_observe!!(context, vi, logp)
end
Expand Down
28 changes: 28 additions & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,34 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
end
end

"""
prefix(model::Model, x)
Return `model` but with all random variables prefixed by `x`.
If `x` is known at compile-time, use `Val{x}()` to avoid runtime overheads for prefixing.
# Examples
```jldoctest
julia> using DynamicPPL: prefix
julia> @model demo() = x ~ Dirac(1)
demo (generic function with 2 methods)
julia> rand(prefix(demo(), :my_prefix))
(var"my_prefix.x" = 1,)
julia> # One can also use `Val` to avoid runtime overheads.
rand(prefix(demo(), Val(:my_prefix)))
(var"my_prefix.x" = 1,)
```
"""
prefix(model::Model, x) = contextualize(model, PrefixContext{Symbol(x)}(model.context))
function prefix(model::Model, ::Val{x}) where {x}
return contextualize(model, PrefixContext{Symbol(x)}(model.context))
end

struct ConditionContext{Values,Ctx<:AbstractContext} <: AbstractContext
values::Values
context::Ctx
Expand Down
1 change: 1 addition & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
@deprecate generated_quantities(model, params) returned(model, params)
Loading

0 comments on commit b21af0b

Please sign in to comment.