Skip to content

Commit

Permalink
Merge in main
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Nov 30, 2024
2 parents 99532e0 + 0a39979 commit d73668d
Show file tree
Hide file tree
Showing 10 changed files with 191 additions and 261 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.30.6"
version = "0.31.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -49,7 +49,7 @@ AbstractMCMC = "5"
AbstractPPL = "0.8.4, 0.9"
Accessors = "0.1"
BangBang = "0.4.1"
Bijectors = "0.13.18, 0.14"
Bijectors = "0.13.18, 0.14, 0.15"
ChainRulesCore = "1"
Compat = "4"
ConstructionBase = "1.5.4"
Expand Down
119 changes: 10 additions & 109 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,49 +77,11 @@ function tilde_assume(
return tilde_assume(rng, childcontext(context), args...)
end

function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(PriorContext(), right, vn, vi)
end
function tilde_assume(
rng::Random.AbstractRNG, context::PriorContext{<:NamedTuple}, sampler, right, vn, vi
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, PriorContext(), sampler, right, vn, vi)
end

function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(LikelihoodContext(), right, vn, vi)
end
function tilde_assume(
rng::Random.AbstractRNG,
context::LikelihoodContext{<:NamedTuple},
sampler,
right,
vn,
vi,
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi)
end
function tilde_assume(::LikelihoodContext, right, vn, vi)
return assume(NoDist(right), vn, vi)
return assume(nodist(right), vn, vi)
end
function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi)
return assume(rng, sampler, NoDist(right), vn, vi)
return assume(rng, sampler, nodist(right), vn, vi)
end

function tilde_assume(context::PrefixContext, right, vn, vi)
Expand Down Expand Up @@ -257,7 +219,7 @@ function assume(
else
r = init(rng, dist, sampler)
if istrans(vi)
f = to_linked_internal_transform(vi, dist)
f = to_linked_internal_transform(vi, vn, dist)
push!!(vi, vn, f(r), dist, sampler)
# By default `push!!` sets the transformed flag to `false`.
settrans!!(vi, true, vn)
Expand Down Expand Up @@ -328,37 +290,6 @@ function dot_tilde_assume(
end

# `LikelihoodContext`
function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, vn, vi)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(LikelihoodContext(), right, left, vn, vi)
end
end
function dot_tilde_assume(
rng::Random.AbstractRNG,
context::LikelihoodContext{<:NamedTuple},
sampler,
right,
left,
vn,
vi,
)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi)
end
end

function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi)
return dot_assume(nodist(right), left, vn, vi)
end
Expand All @@ -368,46 +299,16 @@ function dot_tilde_assume(
return dot_assume(rng, sampler, nodist(right), vn, left, vi)
end

# `PriorContext`
function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, vi)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(PriorContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(PriorContext(), right, left, vn, vi)
end
end
function dot_tilde_assume(
rng::Random.AbstractRNG,
context::PriorContext{<:NamedTuple},
sampler,
right,
left,
vn,
vi,
)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi)
end
end

# `PrefixContext`
function dot_tilde_assume(context::PrefixContext, right, left, vn, vi)
return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), vi)
return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi)
end

function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, vi)
function dot_tilde_assume(
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi
)
return dot_tilde_assume(
rng, context.context, sampler, right, prefix.(Ref(context), vn), vi
rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi
)
end

Expand Down Expand Up @@ -500,7 +401,7 @@ end
# HACK: These methods are only used in the `get_and_set_val!` methods below.
# FIXME: Remove these.
function _link_broadcast_new(vi, vn, dist, r)
b = to_linked_internal_transform(vi, dist)
b = to_linked_internal_transform(vi, vn, dist)
return b(r)
end

Expand Down Expand Up @@ -591,7 +492,7 @@ function get_and_set_val!(
push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists, (spl,))
# NOTE: Need to add the correction.
# FIXME: This is not great.
acclogp_assume!!(vi, sum(logabsdetjac.(link_transform.(dists), r)))
acclogp!!(vi, sum(logabsdetjac.(link_transform.(dists), r)))
# `push!!` sets the trans-flag to `false` by default.
settrans!!.((vi,), true, vns)
else
Expand Down
29 changes: 8 additions & 21 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ DefaultContext()
julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior
julia> DynamicPPL.childcontext(ctx_prior)
PriorContext{Nothing}(nothing)
PriorContext()
```
"""
setchildcontext
Expand Down Expand Up @@ -97,7 +97,7 @@ ParentContext(ParentContext(DefaultContext()))
julia> # Replace the leaf context with another leaf.
leafcontext(setleafcontext(ctx, PriorContext()))
PriorContext{Nothing}(nothing)
PriorContext()
julia> # Append another parent context.
setleafcontext(ctx, ParentContext(DefaultContext()))
Expand Down Expand Up @@ -195,32 +195,19 @@ struct DefaultContext <: AbstractContext end
NodeTrait(context::DefaultContext) = IsLeaf()

"""
struct PriorContext{Tvars} <: AbstractContext
vars::Tvars
end
PriorContext <: AbstractContext
The `PriorContext` enables the computation of the log prior of the parameters `vars` when
running the model.
A leaf context resulting in the exclusion of likelihood terms when running the model.
"""
struct PriorContext{Tvars} <: AbstractContext
vars::Tvars
end
PriorContext() = PriorContext(nothing)
struct PriorContext <: AbstractContext end
NodeTrait(context::PriorContext) = IsLeaf()

"""
struct LikelihoodContext{Tvars} <: AbstractContext
vars::Tvars
end
LikelihoodContext <: AbstractContext
The `LikelihoodContext` enables the computation of the log likelihood of the parameters when
running the model. `vars` can be used to evaluate the log likelihood for specific values
of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default.
A leaf context resulting in the exclusion of prior terms when running the model.
"""
struct LikelihoodContext{Tvars} <: AbstractContext
vars::Tvars
end
LikelihoodContext() = LikelihoodContext(nothing)
struct LikelihoodContext <: AbstractContext end
NodeTrait(context::LikelihoodContext) = IsLeaf()

"""
Expand Down
6 changes: 4 additions & 2 deletions src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,9 @@ function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi)
record_post_tilde_assume!(context, vn, right, value, logp, vi)
return value, logp, vi
end
function DynamicPPL.tilde_assume(rng, context::DebugContext, sampler, right, vn, vi)
function DynamicPPL.tilde_assume(
rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi
)
record_pre_tilde_assume!(context, vn, right, vi)
value, logp, vi = DynamicPPL.tilde_assume(
rng, childcontext(context), sampler, right, vn, vi
Expand Down Expand Up @@ -425,7 +427,7 @@ function DynamicPPL.dot_tilde_assume(context::DebugContext, right, left, vn, vi)
end

function DynamicPPL.dot_tilde_assume(
rng, context::DebugContext, sampler, right, left, vn, vi
rng::Random.AbstractRNG, context::DebugContext, sampler, right, left, vn, vi
)
record_pre_dot_tilde_assume!(context, vn, left, right, vi)
value, logp, vi = DynamicPPL.dot_tilde_assume(
Expand Down
8 changes: 8 additions & 0 deletions src/distribution_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ Base.length(dist::NoDist) = Base.length(dist.dist)
Base.size(dist::NoDist) = Base.size(dist.dist)

Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist)
# NOTE(torfjelde): Need this to avoid stack overflow.
function Distributions.rand!(
rng::Random.AbstractRNG,
d::NoDist{Distributions.ArrayLikeVariate{N}},
x::AbstractArray{<:Real,N},
) where {N}
return Distributions.rand!(rng, d.dist, x)
end
Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0
Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})
Expand Down
86 changes: 68 additions & 18 deletions src/test_utils/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,6 @@
#
# Utilities for testing contexts.

"""
test_context_interface(context)
Test that `context` implements the `AbstractContext` interface.
"""
function test_context_interface(context)
# Is a subtype of `AbstractContext`.
@test context isa DynamicPPL.AbstractContext
# Should implement `NodeTrait.`
@test DynamicPPL.NodeTrait(context) isa Union{DynamicPPL.IsParent,DynamicPPL.IsLeaf}
# If it's a parent.
if DynamicPPL.NodeTrait(context) == DynamicPPL.IsParent
# Should implement `childcontext` and `setchildcontext`
@test DynamicPPL.setchildcontext(context, DynamicPPL.childcontext(context)) ==
context
end
end

"""
Context that multiplies each log-prior by mod
used to test whether varwise_logpriors respects child-context.
Expand Down Expand Up @@ -60,3 +42,71 @@ function DynamicPPL.dot_tilde_observe(
logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi)
return logp * context.mod, vi
end

# Dummy context to test nested behaviors.
struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext
context::C
end
TestParentContext() = TestParentContext(DefaultContext())
DynamicPPL.NodeTrait(::TestParentContext) = DynamicPPL.IsParent()
DynamicPPL.childcontext(context::TestParentContext) = context.context
DynamicPPL.setchildcontext(::TestParentContext, child) = TestParentContext(child)
function Base.show(io::IO, c::TestParentContext)
return print(io, "TestParentContext(", DynamicPPL.childcontext(c), ")")
end

"""
test_context(context::AbstractContext, model::Model)
Test that `context` correctly implements the `AbstractContext` interface for `model`.
This method ensures that `context`
- Correctly implements the `AbstractContext` interface.
- Correctly implements the tilde-pipeline.
"""
function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model)
# `NodeTrait`.
node_trait = DynamicPPL.NodeTrait(context)
# Throw error immediately if it it's missing a `NodeTrait` implementation.
node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} ||
throw(ValueError("Invalid NodeTrait: $node_trait"))

# To see change, let's make sure we're using a different leaf context than the current.
leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext
PriorContext()
else
DefaultContext()
end
@test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) ==
leafcontext_new

# The interface methods.
if node_trait isa DynamicPPL.IsParent
# `childcontext` and `setchildcontext`
# With new child context
childcontext_new = TestParentContext()
@test DynamicPPL.childcontext(
DynamicPPL.setchildcontext(context, childcontext_new)
) == childcontext_new
# Setting the child context to a leaf should now change the leafcontext
# accordingly.
context_with_new_leaf = DynamicPPL.setchildcontext(context, leafcontext_new)
@test DynamicPPL.childcontext(context_with_new_leaf) ===
DynamicPPL.leafcontext(context_with_new_leaf) ===
leafcontext_new
end

# Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded).
# The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it.
# NOTE(torfjelde): Need to sample with the untyped varinfo _using_ the context, since the
# context might alter which variables are present, their names, etc., e.g. `PrefixContext`.
# TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos.
# Untyped varinfo.
varinfo_untyped = DynamicPPL.VarInfo()
@test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true)
@test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true)
# Typed varinfo.
varinfo_typed = DynamicPPL.TypedVarInfo(varinfo_untyped)
@test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true)
@test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true)
end
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.8.4, 0.9"
Accessors = "0.1"
Bijectors = "0.13.9, 0.14"
Bijectors = "0.13.9, 0.14, 0.15"
Combinatorics = "1"
Compat = "4.3.0"
Distributions = "0.25"
Expand Down
Loading

0 comments on commit d73668d

Please sign in to comment.