diff --git a/test/varinfo.jl b/test/varinfo.jl index a2425ebc8..4c306599c 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -770,4 +770,43 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) end end + + @testset "sampling from linked varinfo" begin + # `~` + @model function demo(n=1) + x = Vector(undef, n) + for i in eachindex(x) + x[i] ~ Exponential() + end + return x + end + model1 = demo(1) + varinfo1 = DynamicPPL.link!!(VarInfo(model1), model1) + # Sampling from `model2` should hit the `istrans(vi) == true` branches + # because all the existing variables are linked. + model2 = demo(2) + varinfo2 = last(DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext())) + for vn in [@varname(x[1]), @varname(x[2])] + @test DynamicPPL.istrans(varinfo2, vn) + end + + # `.~` + @model function demo_dot(n=1) + x ~ Exponential() + if n > 1 + y = Vector(undef, n - 1) + y .~ Exponential() + end + return x + end + model1 = demo_dot(1) + varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1) + # Sampling from `model2` should hit the `istrans(vi) == true` branches + # because all the existing variables are linked. + model2 = demo_dot(2) + varinfo2 = last(DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext())) + for vn in [@varname(x), @varname(y[1])] + @test DynamicPPL.istrans(varinfo2, vn) + end + end end