diff --git a/src/debug_utils.jl b/src/debug_utils.jl index dcd3fcc37..f486482a9 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -332,7 +332,9 @@ function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi) record_post_tilde_assume!(context, vn, right, value, logp, vi) return value, logp, vi end -function DynamicPPL.tilde_assume(rng, context::DebugContext, sampler, right, vn, vi) +function DynamicPPL.tilde_assume( + rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi +) record_pre_tilde_assume!(context, vn, right, vi) value, logp, vi = DynamicPPL.tilde_assume( rng, childcontext(context), sampler, right, vn, vi @@ -425,7 +427,7 @@ function DynamicPPL.dot_tilde_assume(context::DebugContext, right, left, vn, vi) end function DynamicPPL.dot_tilde_assume( - rng, context::DebugContext, sampler, right, left, vn, vi + rng::Random.AbstractRNG, context::DebugContext, sampler, right, left, vn, vi ) record_pre_dot_tilde_assume!(context, vn, left, right, vi) value, logp, vi = DynamicPPL.dot_tilde_assume( diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index ec663f5cc..93bb02d3b 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -71,16 +71,6 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} || throw(ValueError("Invalid NodeTrait: $node_trait")) - # The interface methods. - if node_trait isa DynamicPPL.IsParent - # `childcontext` and `setchildcontext` - # With new child context - childcontext_new = TestParentContext() - @test DynamicPPL.childcontext( - DynamicPPL.setchildcontext(context, childcontext_new) - ) == childcontext_new - end - # To see change, let's make sure we're using a different leaf context than the current. leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext PriorContext() @@ -90,11 +80,21 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == leafcontext_new - # Setting the child context to a leaf should now change the leafcontext accordingly. - context_with_new_leaf = DynamicPPL.setchildcontext(context, leafcontext_new) - @test DynamicPPL.setchildcontext(context_with_new_leaf) === - DynamicPPL.setleafcontext(context_with_new_leaf) === - leafcontext_new + # The interface methods. + if node_trait isa DynamicPPL.IsParent + # `childcontext` and `setchildcontext` + # With new child context + childcontext_new = TestParentContext() + @test DynamicPPL.childcontext( + DynamicPPL.setchildcontext(context, childcontext_new) + ) == childcontext_new + # Setting the child context to a leaf should now change the leafcontext + # accordingly. + context_with_new_leaf = DynamicPPL.setchildcontext(context, leafcontext_new) + @test DynamicPPL.childcontext(context_with_new_leaf) === + DynamicPPL.leafcontext(context_with_new_leaf) === + leafcontext_new + end # Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded). # The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it. diff --git a/test/debug_utils.jl b/test/debug_utils.jl index a7a9f7b71..d0ef27348 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -1,9 +1,9 @@ @testset "check_model" begin @testset "context interface" begin - # HACK: Require a model to instantiate it, so let's just grab one. - model = first(DynamicPPL.TestUtils.DEMO_MODELS) - context = DynamicPPL.DebugUtils.DebugContext(model) - DynamicPPL.TestUtils.test_context(context, model) + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + context = DynamicPPL.DebugUtils.DebugContext(model) + DynamicPPL.TestUtils.test_context(context, model) + end end @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS @@ -14,7 +14,7 @@ # Check that the trace contains all the variables in the model. varnames_in_trace = DynamicPPL.DebugUtils.varnames_in_trace(trace) for vn in DynamicPPL.TestUtils.varnames(model) - @test vn in varnames_in_trace + @test vn in varnames_in_traces end # Quick checks for `show` of trace.