Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Possible Improvements to FixedContext #710

Closed
wants to merge 10 commits into from
11 changes: 11 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,17 @@ function unwrap_right_left_vns(
left::AbstractArray,
vn::VarName,
)
# Need to check that we don't end up double-counting log-probabilities.
combined_axes = Broadcast.combine_axes(left, right)
if prod(length, combined_axes) > length(left)
throw(
ArgumentError(
"a `.~` statement cannot result in a broadcasted expression with more elements than the left-hand side",
),
)
end

# Extract the sub-varnames.
vns = map(CartesianIndices(left)) do i
return Accessors.IndexLens(Tuple(i)) ∘ vn
end
Expand Down
173 changes: 70 additions & 103 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,44 +77,6 @@ function tilde_assume(
return tilde_assume(rng, childcontext(context), args...)
end

function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(PriorContext(), right, vn, vi)
end
function tilde_assume(
rng::Random.AbstractRNG, context::PriorContext{<:NamedTuple}, sampler, right, vn, vi
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, PriorContext(), sampler, right, vn, vi)
end

function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(LikelihoodContext(), right, vn, vi)
end
function tilde_assume(
rng::Random.AbstractRNG,
context::LikelihoodContext{<:NamedTuple},
sampler,
right,
vn,
vi,
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi)
end
function tilde_assume(::LikelihoodContext, right, vn, vi)
return assume(NoDist(right), vn, vi)
end
Expand Down Expand Up @@ -328,37 +290,6 @@ function dot_tilde_assume(
end

# `LikelihoodContext`
function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, vn, vi)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(LikelihoodContext(), right, left, vn, vi)
end
end
function dot_tilde_assume(
rng::Random.AbstractRNG,
context::LikelihoodContext{<:NamedTuple},
sampler,
right,
left,
vn,
vi,
)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi)
end
end

function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi)
return dot_assume(nodist(right), left, vn, vi)
end
Expand All @@ -368,47 +299,83 @@ function dot_tilde_assume(
return dot_assume(rng, sampler, nodist(right), vn, left, vi)
end

# `PriorContext`
function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, vi)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(PriorContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(PriorContext(), right, left, vn, vi)
end
# `PrefixContext`
function dot_tilde_assume(context::PrefixContext, right, left, vn, vi)
return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi)
end

function dot_tilde_assume(
rng::Random.AbstractRNG,
context::PriorContext{<:NamedTuple},
sampler,
right,
left,
vn,
vi,
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi
)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi)
end
return dot_tilde_assume(
rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi
)
end

# `PrefixContext`
function dot_tilde_assume(context::PrefixContext, right, left, vn, vi)
return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), vi)
# `FixedContext`
function dot_tilde_assume(context::FixedContext, right, left, vns, vi)
if !has_fixed_symbol(context, first(vns))
# Defer to `childcontext`.
return tilde_assume(childcontext(context), right, left, vns, vi)
end

# If we're reached here, then we didn't hit the initial `getfixed` call in the model body.
# We _might_ also have some of the variables fixed, but not all.
logp = 0
# TODO(torfjelde): Add a check to see if the `Symbol` of `vns` exists in `FixedContext`.
# If the `Symbol` is not present, we can just skip this check completely. Such a check can
# then be compiled away in cases where the `Symbol` is not present.
left_bc = Broadcast.broadcastable(left)
right_bc = Broadcast.broadcastable(right)
for I_left in Iterators.product(Broadcast.broadcast_axes(left_bc)...)
for I_right in Iterators.product(Broadcast.broadcast_axes(right_bc)...)
vn = vns[I_left...]
if hasfixed(context, vn)
left[I_left...] = getfixed(context, vn)
else
# Defer to `tilde_assume`.
left[I_left...], logp_inner, vi = tilde_assume(
childcontext(context), right_bc[I_right...], vn, vi
)
logp += logp_inner
end
end
end

return left, logp, vi
end

function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, vi)
return dot_tilde_assume(
rng, context.context, sampler, right, prefix.(Ref(context), vn), vi
)
function dot_tilde_assume(
rng::Random.AbstractRNG, context::FixedContext, sampler, right, left, vns, vi
)
if !has_fixed_symbol(context, first(vns))
# Defer to `childcontext`.
return tilde_assume(rng, childcontext(context), sampler, right, left, vns, vi)
end
# If we're reached here, then we didn't hit the initial `getfixed` call in the model body.
# So we need to check each of the vns.
logp = 0
# TODO(torfjelde): Add a check to see if the `Symbol` of `vns` exists in `FixedContext`.
# If the `Symbol` is not present, we can just skip this check completely. Such a check can
# then be compiled away in cases where the `Symbol` is not present.
left_bc = Broadcast.broadcastable(left)
right_bc = Broadcast.broadcastable(right)
for I_left in Iterators.product(Broadcast.broadcast_axes(left_bc)...)
for I_right in Iterators.product(Broadcast.broadcast_axes(right_bc)...)
vn = vns[I_left...]
if hasfixed(context, vn)
left[I_left...] = getfixed(context, vn)
else
# Defer to `tilde_assume`.
left[I_left...], logp_inner, vi = tilde_assume(
rng, childcontext(context), sampler, right_bc[I_right...], vn, vi
)
logp += logp_inner
end
end
end

return left, logp, vi
end

"""
Expand Down
36 changes: 15 additions & 21 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ DefaultContext()
julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior

julia> DynamicPPL.childcontext(ctx_prior)
PriorContext{Nothing}(nothing)
PriorContext()
```
"""
setchildcontext
Expand Down Expand Up @@ -97,7 +97,7 @@ ParentContext(ParentContext(DefaultContext()))

julia> # Replace the leaf context with another leaf.
leafcontext(setleafcontext(ctx, PriorContext()))
PriorContext{Nothing}(nothing)
PriorContext()

julia> # Append another parent context.
setleafcontext(ctx, ParentContext(DefaultContext()))
Expand Down Expand Up @@ -195,32 +195,19 @@ struct DefaultContext <: AbstractContext end
NodeTrait(context::DefaultContext) = IsLeaf()

"""
struct PriorContext{Tvars} <: AbstractContext
vars::Tvars
end
PriorContext <: AbstractContext

The `PriorContext` enables the computation of the log prior of the parameters `vars` when
running the model.
A leaf context resulting in the exclusion of likelihood terms when running the model.
"""
struct PriorContext{Tvars} <: AbstractContext
vars::Tvars
end
PriorContext() = PriorContext(nothing)
struct PriorContext <: AbstractContext end
NodeTrait(context::PriorContext) = IsLeaf()

"""
struct LikelihoodContext{Tvars} <: AbstractContext
vars::Tvars
end
LikelihoodContext <: AbstractContext

The `LikelihoodContext` enables the computation of the log likelihood of the parameters when
running the model. `vars` can be used to evaluate the log likelihood for specific values
of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default.
A leaf context resulting in the exclusion of prior terms when running the model.
"""
struct LikelihoodContext{Tvars} <: AbstractContext
vars::Tvars
end
LikelihoodContext() = LikelihoodContext(nothing)
struct LikelihoodContext <: AbstractContext end
NodeTrait(context::LikelihoodContext) = IsLeaf()

"""
Expand Down Expand Up @@ -514,6 +501,13 @@ NodeTrait(::FixedContext) = IsParent()
childcontext(context::FixedContext) = context.context
setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child)

has_fixed_symbol(context::FixedContext, vn::VarName) = has_symbol(context.values, vn)

has_symbol(d::AbstractDict, vn::VarName) = haskey(d, vn)
@generated function has_symbol(::NamedTuple{names}, ::VarName{sym}) where {names,sym}
return sym in names
end

"""
hasfixed(context::AbstractContext, vn::VarName)

Expand Down
9 changes: 9 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -729,4 +729,13 @@ module Issue537 end
res = model()
@test res == (a=1, b=1, c=2, d=2, t=DynamicPPL.TypeWrap{Int}())
end

@testset "invalid .~ expressions" begin
@model function demo_with_invalid_dot_tilde()
m = Matrix{Float64}(undef, 1, 2)
return m .~ [Normal(); Normal()]
end

@test_throws ArgumentError demo_with_invalid_dot_tilde()()
end
end
Loading