Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some clean up of contexts #711

Merged
merged 20 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
7488238
fixed incorrect implementation of `dot_tilde_assume` for `PrefixContext`
torfjelde Nov 1, 2024
1d211c5
removed `vars` field from `PriorContext` and `LikelihoodContext` as
torfjelde Nov 1, 2024
fcccfd2
Merge branch 'master' into torfjelde/context-cleanup
penelopeysm Nov 2, 2024
f978287
replaced `NoDist` with `nodist`
torfjelde Nov 5, 2024
ae84112
fixed method ambiguity issue
torfjelde Nov 5, 2024
881e28d
added missing `Distributions.rand!` definition for `NoDist`
torfjelde Nov 5, 2024
1f6fc62
added more elaborate testing of evaluation of contexts
torfjelde Nov 5, 2024
55e24e9
added `DynamicPPL.TestUtils.test_context` for testing contexts and re…
torfjelde Nov 5, 2024
e5b7e44
added proper testing for PrefixContext of all demo models
torfjelde Nov 5, 2024
e1c8fd1
formatting
torfjelde Nov 5, 2024
23bc877
Merge branch 'master' into torfjelde/context-cleanup
torfjelde Nov 6, 2024
decaa78
Merge branch 'master' into torfjelde/context-cleanup
torfjelde Nov 8, 2024
b67f51b
added some dropped tests
torfjelde Nov 11, 2024
1da4c6e
Update src/test_utils.jl
torfjelde Nov 11, 2024
b551356
Update src/test_utils.jl
torfjelde Nov 11, 2024
1b91a72
Merge branch 'master' into torfjelde/context-cleanup
penelopeysm Nov 11, 2024
247f754
Update test/debug_utils.jl
torfjelde Nov 12, 2024
ebbda53
bump patch version
torfjelde Nov 12, 2024
8d41d7b
Merge branch 'master' into torfjelde/context-cleanup
torfjelde Nov 25, 2024
527ef0c
Merge remote-tracking branch 'origin/master' into torfjelde/context-c…
penelopeysm Nov 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 7 additions & 106 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,49 +77,11 @@ function tilde_assume(
return tilde_assume(rng, childcontext(context), args...)
end

function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(PriorContext(), right, vn, vi)
end
function tilde_assume(
rng::Random.AbstractRNG, context::PriorContext{<:NamedTuple}, sampler, right, vn, vi
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, PriorContext(), sampler, right, vn, vi)
end

function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(LikelihoodContext(), right, vn, vi)
end
function tilde_assume(
rng::Random.AbstractRNG,
context::LikelihoodContext{<:NamedTuple},
sampler,
right,
vn,
vi,
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi)
end
function tilde_assume(::LikelihoodContext, right, vn, vi)
return assume(NoDist(right), vn, vi)
return assume(nodist(right), vn, vi)
end
function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi)
return assume(rng, sampler, NoDist(right), vn, vi)
return assume(rng, sampler, nodist(right), vn, vi)
end

function tilde_assume(context::PrefixContext, right, vn, vi)
Expand Down Expand Up @@ -328,37 +290,6 @@ function dot_tilde_assume(
end

# `LikelihoodContext`
function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, vn, vi)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(LikelihoodContext(), right, left, vn, vi)
end
end
function dot_tilde_assume(
rng::Random.AbstractRNG,
context::LikelihoodContext{<:NamedTuple},
sampler,
right,
left,
vn,
vi,
)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi)
end
end

function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi)
return dot_assume(nodist(right), left, vn, vi)
end
Expand All @@ -368,46 +299,16 @@ function dot_tilde_assume(
return dot_assume(rng, sampler, nodist(right), vn, left, vi)
end

# `PriorContext`
function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, vi)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(PriorContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(PriorContext(), right, left, vn, vi)
end
end
function dot_tilde_assume(
rng::Random.AbstractRNG,
context::PriorContext{<:NamedTuple},
sampler,
right,
left,
vn,
vi,
)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi)
end
end

# `PrefixContext`
function dot_tilde_assume(context::PrefixContext, right, left, vn, vi)
return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), vi)
return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi)
end

function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, vi)
function dot_tilde_assume(
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi
)
return dot_tilde_assume(
rng, context.context, sampler, right, prefix.(Ref(context), vn), vi
rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi
)
end

Expand Down
29 changes: 8 additions & 21 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ DefaultContext()
julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior

julia> DynamicPPL.childcontext(ctx_prior)
PriorContext{Nothing}(nothing)
PriorContext()
```
"""
setchildcontext
Expand Down Expand Up @@ -97,7 +97,7 @@ ParentContext(ParentContext(DefaultContext()))

julia> # Replace the leaf context with another leaf.
leafcontext(setleafcontext(ctx, PriorContext()))
PriorContext{Nothing}(nothing)
PriorContext()

julia> # Append another parent context.
setleafcontext(ctx, ParentContext(DefaultContext()))
Expand Down Expand Up @@ -195,32 +195,19 @@ struct DefaultContext <: AbstractContext end
NodeTrait(context::DefaultContext) = IsLeaf()

"""
struct PriorContext{Tvars} <: AbstractContext
vars::Tvars
end
PriorContext <: AbstractContext

The `PriorContext` enables the computation of the log prior of the parameters `vars` when
running the model.
A leaf context resulting in the exclusion of likelihood terms when running the model.
"""
struct PriorContext{Tvars} <: AbstractContext
vars::Tvars
end
PriorContext() = PriorContext(nothing)
struct PriorContext <: AbstractContext end
NodeTrait(context::PriorContext) = IsLeaf()

"""
struct LikelihoodContext{Tvars} <: AbstractContext
vars::Tvars
end
LikelihoodContext <: AbstractContext

The `LikelihoodContext` enables the computation of the log likelihood of the parameters when
running the model. `vars` can be used to evaluate the log likelihood for specific values
of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default.
A leaf context resulting in the exclusion of prior terms when running the model.
"""
struct LikelihoodContext{Tvars} <: AbstractContext
vars::Tvars
end
LikelihoodContext() = LikelihoodContext(nothing)
struct LikelihoodContext <: AbstractContext end
NodeTrait(context::LikelihoodContext) = IsLeaf()

"""
Expand Down
8 changes: 8 additions & 0 deletions src/distribution_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ Base.length(dist::NoDist) = Base.length(dist.dist)
Base.size(dist::NoDist) = Base.size(dist.dist)

Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist)
# NOTE(torfjelde): Need this to avoid stack overflow.
function Distributions.rand!(
rng::Random.AbstractRNG,
d::NoDist{Distributions.ArrayLikeVariate{N}},
x::AbstractArray{<:Real,N},
) where {N}
return Distributions.rand!(rng, d.dist, x)
end
Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0
Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})
Expand Down
87 changes: 69 additions & 18 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1039,24 +1039,7 @@ function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...;
return test_sampler_on_demo_models(sampler, args...; kwargs...)
end

"""
test_context_interface(context)

Test that `context` implements the `AbstractContext` interface.
"""
function test_context_interface(context)
# Is a subtype of `AbstractContext`.
@test context isa DynamicPPL.AbstractContext
# Should implement `NodeTrait.`
@test DynamicPPL.NodeTrait(context) isa Union{DynamicPPL.IsParent,DynamicPPL.IsLeaf}
# If it's a parent.
if DynamicPPL.NodeTrait(context) == DynamicPPL.IsParent
# Should implement `childcontext` and `setchildcontext`
@test DynamicPPL.setchildcontext(context, DynamicPPL.childcontext(context)) ==
context
end
end

# Testing for contexts.
"""
Context that multiplies each log-prior by mod
used to test whether varwise_logpriors respects child-context.
Expand Down Expand Up @@ -1097,4 +1080,72 @@ function DynamicPPL.dot_tilde_observe(
return logp * context.mod, vi
end

# Dummy context to test nested behaviors.
struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext
context::C
end
TestParentContext() = TestParentContext(DefaultContext())
DynamicPPL.NodeTrait(::TestParentContext) = DynamicPPL.IsParent()
DynamicPPL.childcontext(context::TestParentContext) = context.context
DynamicPPL.setchildcontext(::TestParentContext, child) = TestParentContext(child)
function Base.show(io::IO, c::TestParentContext)
return print(io, "TestParentContext(", DynamicPPL.childcontext(c), ")")
end

"""
test_context(context::AbstractContext, model::Model)

Test that `context` correctly implements the `AbstractContext` interface for `model`.

This method ensures that `context`
- Correctly implements the `AbstractContext` interface.
- Correctly implements the tilde-pipeline.
"""
function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model)
# `NodeTrait`.
node_trait = DynamicPPL.NodeTrait(context)
# Throw error immediately if it it's missing a `NodeTrait` implementation.
node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} ||
throw(ValueError("Invalid NodeTrait: $node_trait"))

# The interface methods.
if node_trait isa DynamicPPL.IsParent
# `childcontext` and `setchildcontext`
# With new child context
childcontext_new = TestParentContext()
@test DynamicPPL.childcontext(
DynamicPPL.setchildcontext(context, childcontext_new)
) == childcontext_new
end

# To see change, let's make sure we're using a different leaf context than the current.
leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext
PriorContext()
else
DefaultContext()
end
@test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) ==
leafcontext_new

# Setting the child context to a leaf should now change the leafcontext accordingly.
context_with_new_leaf = DynamicPPL.setchildcontext(context, leafcontext_new)
@test childcontext(context_with_new_leaf) ===
leafcontext(context_with_new_leaf) ===
leafcontext_new

# Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded).
# The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it.
# NOTE(torfjelde): Need to sample with the untyped varinfo _using_ the context, since the
# context might alter which variables are present, their names, etc., e.g. `PrefixContext`.
# TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos.
# Untyped varinfo.
varinfo_untyped = DynamicPPL.VarInfo()
@test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true)
@test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true)
# Typed varinfo.
varinfo_typed = DynamicPPL.TypedVarInfo(varinfo_untyped)
@test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true)
@test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true)
end

end
Loading
Loading