Skip to content

Commit

Permalink
removed redundant SampleableModelWrapper in favour of
Browse files Browse the repository at this point in the history
`ReturnedModelWrapper` + introduced `rand_like!!` to hide explicit
calls to `_evaluate!!`
  • Loading branch information
torfjelde committed Nov 25, 2024
1 parent 5134ff7 commit 45451f7
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 177 deletions.
194 changes: 21 additions & 173 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,152 +1,5 @@
const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)

struct SampleableModelWrapper{M}
model::M
end

"""
to_sampleable(model::Model)
Return a wrapper around `model` which indicates that this model can only be sampled from.
This is mainly meant to be used on the right-hand side of a `~` operator to indicate that
the model can be sampled from but not necessarily evaluated for its log density.
!!! warning
Note that other operations that one typically associate with expressions of the form `left ~ right`
such as [`condition`](@ref) or [`fix`](@ref), will also not work with `to_sampleable`.
!!! warning
It's generally recommended to use [`prefix(::Model, input)`](@ref) when working with submodels
to ensure that the variables in `model` are unique and do not clash with other variables in the
parent model or in other submodels.
# Examples
## Simple example
```jldoctest submodel-to-sampleable; setup=:(using Distributions)
julia> @model function demo1(x)
x ~ Normal()
return 1 + abs(x)
end;
julia> @model function demo2(x, y)
a ~ to_sampleable(demo1(x))
return y ~ Uniform(0, a)
end;
```
When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled:
```jldoctest submodel-to-sampleable
julia> vi = VarInfo(demo2(missing, 0.4));
julia> @varname(x) in keys(vi)
true
```
Variable `a` is not tracked since it can be computed from the random variable `x` that was
tracked when running `demo1`:
```jldoctest submodel-to-sampleable
julia> @varname(a) in keys(vi)
false
```
We can check that the log joint probability of the model accumulated in `vi` is correct:
```jldoctest submodel-to-sampleable
julia> x = vi[@varname(x)];
julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4)
true
```
## With prefixing
```jldoctest submodel-to-sampleable-prefix; setup=:(using Distributions)
julia> @model function demo1(x)
x ~ Normal()
return 1 + abs(x)
end;
julia> @model function demo2(x, y, z)
a ~ to_sampleable(prefix(demo1(x), :sub1))
b ~ to_sampleable(prefix(demo1(y), :sub2))
return z ~ Uniform(-a, b)
end;
```
When we sample from the model `demo2(missing, missing, 0.4)` random variables `sub1.x` and
`sub2.x` will be sampled:
```jldoctest submodel-to-sampleable-prefix
julia> vi = VarInfo(demo2(missing, missing, 0.4));
julia> @varname(var"sub1.x") in keys(vi)
true
julia> @varname(var"sub2.x") in keys(vi)
true
```
Variables `a` and `b` are not tracked since they can be computed from the random variables `sub1.x` and
`sub2.x` that were tracked when running `demo1`:
```jldoctest submodel-to-sampleable-prefix
julia> @varname(a) in keys(vi)
false
julia> @varname(b) in keys(vi)
false
```
We can check that the log joint probability of the model accumulated in `vi` is correct:
```jldoctest submodel-to-sampleable-prefix
julia> sub1_x = vi[@varname(var"sub1.x")];
julia> sub2_x = vi[@varname(var"sub2.x")];
julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x);
julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4);
julia> getlogp(vi) ≈ logprior + loglikelihood
true
```
## Different ways of setting the prefix
```jldoctest submodel-to-sampleable-prefix-alts; setup=:(using DynamicPPL, Distributions)
julia> @model inner() = x ~ Normal()
inner (generic function with 2 methods)
julia> # When `prefix` is unspecified, no prefix is used.
@model submodel_noprefix() = a ~ to_sampleable(inner())
submodel_noprefix (generic function with 2 methods)
julia> @varname(x) in keys(VarInfo(submodel_noprefix()))
true
julia> # Using a static string.
@model submodel_prefix_string() = a ~ to_sampleable(prefix(inner(), "my prefix"))
submodel_prefix_string (generic function with 2 methods)
julia> @varname(var"my prefix.x") in keys(VarInfo(submodel_prefix_string()))
true
julia> # Using string interpolation.
@model submodel_prefix_interpolation() = a = to_sampleable(prefix(inner(), "\$(nameof(inner()))"))
submodel_prefix_interpolation (generic function with 2 methods)
julia> @varname(var"inner.x") in keys(VarInfo(submodel_prefix_interpolation()))
true
julia> # Or using some arbitrary expression.
@model submodel_prefix_expr() = a ~ to_sampleable(prefix(inner(), 1 + 2))
submodel_prefix_expr (generic function with 2 methods)
julia> @varname(var"3.x") in keys(VarInfo(submodel_prefix_expr()))
true
```
"""
to_sampleable(model::Model) = SampleableModelWrapper(model)

"""
need_concretize(expr)
Expand Down Expand Up @@ -325,6 +178,7 @@ function check_tilde_rhs(@nospecialize(x))
end
check_tilde_rhs(x::Distribution) = x
check_tilde_rhs(x::AbstractArray{<:Distribution}) = x
check_tilde_rhs(x::ReturnedModelWrapper) = x

"""
unwrap_right_vn(right, vn)
Expand Down Expand Up @@ -574,34 +428,28 @@ function generate_tilde(left, right)
# more selective with our escape. Until that's the case, we remove them all.
return quote
$dist = $right

if $dist isa $(SampleableModelWrapper)
$left, __varinfo__ = $(_evaluate!!)($dist.model, __varinfo__, __context__)
$left
$vn = $(DynamicPPL.resolve_varnames)(
$(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist
)
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $(DynamicPPL.isfixed(left, vn))
$left = $(DynamicPPL.getfixed_nested)(__context__, $vn)
elseif $isassumption
$(generate_tilde_assume(left, dist, vn))
else
$vn = $(DynamicPPL.resolve_varnames)(
$(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist
)
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $(DynamicPPL.isfixed(left, vn))
$left = $(DynamicPPL.getfixed_nested)(__context__, $vn)
elseif $isassumption
$(generate_tilde_assume(left, dist, vn))
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left = $(DynamicPPL.getconditioned_nested)(__context__, $vn)
end

$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($dist),
$(maybe_view(left)),
$vn,
__varinfo__,
)
$value
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left = $(DynamicPPL.getconditioned_nested)(__context__, $vn)
end

$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($dist),
$(maybe_view(left)),
$vn,
__varinfo__,
)
$value
end
end
end
Expand Down
20 changes: 16 additions & 4 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,12 @@ By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log
probability of `vi` with the returned value.
"""
function tilde_assume!!(context, right, vn, vi)
value, logp, vi = tilde_assume(context, right, vn, vi)
return value, acclogp_assume!!(context, vi, logp)
return if is_rhs_model(right)
rand_like!!(right, context, vi)
else
value, logp, vi = tilde_assume(context, right, vn, vi)
value, acclogp_assume!!(context, vi, logp)
end
end

# observe
Expand Down Expand Up @@ -197,6 +201,7 @@ Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the informati
and indices; if needed, these can be accessed through this function, though.
"""
function tilde_observe!!(context, right, left, vname, vi)
is_rhs_model(right) && throw(ArgumentError("`~` with a model on the right-hand side of an observe statement is not supported"))
return tilde_observe!!(context, right, left, vi)
end

Expand All @@ -210,6 +215,7 @@ By default, calls `tilde_observe(context, right, left, vi)` and accumulates the
probability of `vi` with the returned value.
"""
function tilde_observe!!(context, right, left, vi)
is_rhs_model(right) && throw(ArgumentError("`~` with a model on the right-hand side of an observe statement is not supported"))
logp, vi = tilde_observe(context, right, left, vi)
return left, acclogp_observe!!(context, vi, logp)
end
Expand Down Expand Up @@ -420,8 +426,12 @@ model inputs), accumulate the log probability, and return the sampled value and
Falls back to `dot_tilde_assume(context, right, left, vn, vi)`.
"""
function dot_tilde_assume!!(context, right, left, vn, vi)
value, logp, vi = dot_tilde_assume(context, right, left, vn, vi)
return value, acclogp_assume!!(context, vi, logp), vi
return if is_rhs_model(right)
rand_like!!(right, context, vi)
else
value, logp, vi = dot_tilde_assume(context, right, left, vn, vi)
value, acclogp_assume!!(context, vi, logp)
end
end

# `dot_assume`
Expand Down Expand Up @@ -672,6 +682,7 @@ Falls back to `dot_tilde_observe!!(context, right, left, vi)` ignoring the infor
name and indices; if needed, these can be accessed through this function, though.
"""
function dot_tilde_observe!!(context, right, left, vn, vi)
is_rhs_model(right) && throw(ArgumentError("`~` with a model on the right-hand side of an observe statement is not supported"))
return dot_tilde_observe!!(context, right, left, vi)
end

Expand All @@ -684,6 +695,7 @@ probability, and return the observed value and updated `vi`.
Falls back to `dot_tilde_observe(context, right, left, vi)`.
"""
function dot_tilde_observe!!(context, right, left, vi)
is_rhs_model(right) && throw(ArgumentError("`~` with a model on the right-hand side of an observe statement is not supported"))
logp, vi = dot_tilde_observe(context, right, left, vi)
return left, acclogp_observe!!(context, vi, logp)
end
Expand Down
Loading

0 comments on commit 45451f7

Please sign in to comment.