Skip to content

Commit

Permalink
added test for the branch we were currently imssing
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Nov 28, 2024
1 parent 545cfab commit d93006b
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -770,4 +770,43 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
@test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m)
end
end

@testset "sampling from linked varinfo" begin
# `~`
@model function demo(n=1)
x = Vector(undef, n)
for i in eachindex(x)
x[i] ~ Exponential()
end
return x
end
model1 = demo(1)
varinfo1 = DynamicPPL.link!!(VarInfo(model1), model1)
# Sampling from `model2` should hit the `istrans(vi) == true` branches
# because all the existing variables are linked.
model2 = demo(2)
varinfo2 = last(DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext()))
for vn in [@varname(x[1]), @varname(x[2])]
@test DynamicPPL.istrans(varinfo2, vn)
end

# `.~`
@model function demo_dot(n=1)
x ~ Exponential()
if n > 1
y = Vector(undef, n - 1)
y .~ Exponential()
end
return x
end
model1 = demo_dot(1)
varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1)
# Sampling from `model2` should hit the `istrans(vi) == true` branches
# because all the existing variables are linked.
model2 = demo_dot(2)
varinfo2 = last(DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext()))
for vn in [@varname(x), @varname(y[1])]
@test DynamicPPL.istrans(varinfo2, vn)
end
end
end

0 comments on commit d93006b

Please sign in to comment.