Skip to content

Commit

Permalink
Fixing CI from #711 (#729)
Browse files Browse the repository at this point in the history
* Fix wrong function being called

* Don't test setchildcontext on leaf contexts

* Update src/test_utils/contexts.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fixed method ambiguities for DebugContext

* test the context interface for DebugContext on multiple models

* Update src/debug_utils.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 28, 2024
1 parent d635a17 commit 2344689
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 22 deletions.
6 changes: 4 additions & 2 deletions src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,9 @@ function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi)
record_post_tilde_assume!(context, vn, right, value, logp, vi)
return value, logp, vi
end
function DynamicPPL.tilde_assume(rng, context::DebugContext, sampler, right, vn, vi)
function DynamicPPL.tilde_assume(
rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi
)
record_pre_tilde_assume!(context, vn, right, vi)
value, logp, vi = DynamicPPL.tilde_assume(
rng, childcontext(context), sampler, right, vn, vi
Expand Down Expand Up @@ -425,7 +427,7 @@ function DynamicPPL.dot_tilde_assume(context::DebugContext, right, left, vn, vi)
end

function DynamicPPL.dot_tilde_assume(
rng, context::DebugContext, sampler, right, left, vn, vi
rng::Random.AbstractRNG, context::DebugContext, sampler, right, left, vn, vi
)
record_pre_dot_tilde_assume!(context, vn, left, right, vi)
value, logp, vi = DynamicPPL.dot_tilde_assume(
Expand Down
30 changes: 15 additions & 15 deletions src/test_utils/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,6 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod
node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} ||
throw(ValueError("Invalid NodeTrait: $node_trait"))

# The interface methods.
if node_trait isa DynamicPPL.IsParent
# `childcontext` and `setchildcontext`
# With new child context
childcontext_new = TestParentContext()
@test DynamicPPL.childcontext(
DynamicPPL.setchildcontext(context, childcontext_new)
) == childcontext_new
end

# To see change, let's make sure we're using a different leaf context than the current.
leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext
PriorContext()
Expand All @@ -90,11 +80,21 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod
@test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) ==
leafcontext_new

# Setting the child context to a leaf should now change the leafcontext accordingly.
context_with_new_leaf = DynamicPPL.setchildcontext(context, leafcontext_new)
@test DynamicPPL.setchildcontext(context_with_new_leaf) ===
DynamicPPL.setleafcontext(context_with_new_leaf) ===
leafcontext_new
# The interface methods.
if node_trait isa DynamicPPL.IsParent
# `childcontext` and `setchildcontext`
# With new child context
childcontext_new = TestParentContext()
@test DynamicPPL.childcontext(
DynamicPPL.setchildcontext(context, childcontext_new)
) == childcontext_new
# Setting the child context to a leaf should now change the leafcontext
# accordingly.
context_with_new_leaf = DynamicPPL.setchildcontext(context, leafcontext_new)
@test DynamicPPL.childcontext(context_with_new_leaf) ===
DynamicPPL.leafcontext(context_with_new_leaf) ===
leafcontext_new
end

# Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded).
# The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it.
Expand Down
10 changes: 5 additions & 5 deletions test/debug_utils.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
@testset "check_model" begin
@testset "context interface" begin
# HACK: Require a model to instantiate it, so let's just grab one.
model = first(DynamicPPL.TestUtils.DEMO_MODELS)
context = DynamicPPL.DebugUtils.DebugContext(model)
DynamicPPL.TestUtils.test_context(context, model)
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
context = DynamicPPL.DebugUtils.DebugContext(model)
DynamicPPL.TestUtils.test_context(context, model)
end
end

@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
Expand All @@ -14,7 +14,7 @@
# Check that the trace contains all the variables in the model.
varnames_in_trace = DynamicPPL.DebugUtils.varnames_in_trace(trace)
for vn in DynamicPPL.TestUtils.varnames(model)
@test vn in varnames_in_trace
@test vn in varnames_in_traces
end

# Quick checks for `show` of trace.
Expand Down

0 comments on commit 2344689

Please sign in to comment.