Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing CI from #711 #729

Merged
merged 6 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,9 @@ function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi)
record_post_tilde_assume!(context, vn, right, value, logp, vi)
return value, logp, vi
end
function DynamicPPL.tilde_assume(rng, context::DebugContext, sampler, right, vn, vi)
function DynamicPPL.tilde_assume(
rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi
)
record_pre_tilde_assume!(context, vn, right, vi)
value, logp, vi = DynamicPPL.tilde_assume(
rng, childcontext(context), sampler, right, vn, vi
Expand Down Expand Up @@ -425,7 +427,7 @@ function DynamicPPL.dot_tilde_assume(context::DebugContext, right, left, vn, vi)
end

function DynamicPPL.dot_tilde_assume(
rng, context::DebugContext, sampler, right, left, vn, vi
rng::Random.AbstractRNG, context::DebugContext, sampler, right, left, vn, vi
)
record_pre_dot_tilde_assume!(context, vn, left, right, vi)
value, logp, vi = DynamicPPL.dot_tilde_assume(
Expand Down
30 changes: 15 additions & 15 deletions src/test_utils/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,6 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod
node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} ||
throw(ValueError("Invalid NodeTrait: $node_trait"))

# The interface methods.
if node_trait isa DynamicPPL.IsParent
# `childcontext` and `setchildcontext`
# With new child context
childcontext_new = TestParentContext()
@test DynamicPPL.childcontext(
DynamicPPL.setchildcontext(context, childcontext_new)
) == childcontext_new
end

# To see change, let's make sure we're using a different leaf context than the current.
leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext
PriorContext()
Expand All @@ -90,11 +80,21 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod
@test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) ==
leafcontext_new

# Setting the child context to a leaf should now change the leafcontext accordingly.
context_with_new_leaf = DynamicPPL.setchildcontext(context, leafcontext_new)
@test DynamicPPL.setchildcontext(context_with_new_leaf) ===
DynamicPPL.setleafcontext(context_with_new_leaf) ===
leafcontext_new
# The interface methods.
if node_trait isa DynamicPPL.IsParent
# `childcontext` and `setchildcontext`
# With new child context
childcontext_new = TestParentContext()
@test DynamicPPL.childcontext(
DynamicPPL.setchildcontext(context, childcontext_new)
) == childcontext_new
# Setting the child context to a leaf should now change the leafcontext
# accordingly.
context_with_new_leaf = DynamicPPL.setchildcontext(context, leafcontext_new)
@test DynamicPPL.childcontext(context_with_new_leaf) ===
DynamicPPL.leafcontext(context_with_new_leaf) ===
leafcontext_new
end

# Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded).
# The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it.
Expand Down
10 changes: 5 additions & 5 deletions test/debug_utils.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
@testset "check_model" begin
@testset "context interface" begin
# HACK: Require a model to instantiate it, so let's just grab one.
model = first(DynamicPPL.TestUtils.DEMO_MODELS)
context = DynamicPPL.DebugUtils.DebugContext(model)
DynamicPPL.TestUtils.test_context(context, model)
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
context = DynamicPPL.DebugUtils.DebugContext(model)
DynamicPPL.TestUtils.test_context(context, model)
end
end

@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
Expand All @@ -14,7 +14,7 @@
# Check that the trace contains all the variables in the model.
varnames_in_trace = DynamicPPL.DebugUtils.varnames_in_trace(trace)
for vn in DynamicPPL.TestUtils.varnames(model)
@test vn in varnames_in_trace
@test vn in varnames_in_traces
end

# Quick checks for `show` of trace.
Expand Down
Loading