diff --git a/test/varinfo.jl b/test/varinfo.jl index 7a67a6383..792fb7d10 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -387,6 +387,57 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end end + @testset "link!! and invlink!!" begin + @model gdemo(x, y) = begin + s ~ InverseGamma(2, 3) + m ~ Uniform(0, 2) + x ~ Normal(m, sqrt(s)) + y ~ Normal(m, sqrt(s)) + end + model = gdemo(1.0, 2.0) + + # Check that instantiating the model does not perform linking + vi = VarInfo() + meta = vi.metadata + model(vi, SampleFromUniform()) + @test all(x -> !istrans(vi, x), meta.vns) + + # Check that linking and invlinking set the `trans` flag accordingly + v = copy(meta.vals) + link!!(vi, model) + @test all(x -> istrans(vi, x), meta.vns) + invlink!!(vi, model) + @test all(x -> !istrans(vi, x), meta.vns) + @test meta.vals ≈ v atol = 1e-10 + + # Check that linking and invlinking preserves the values + vi = TypedVarInfo(vi) + meta = vi.metadata + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + v_s = copy(meta.s.vals) + v_m = copy(meta.m.vals) + link!!(vi, model) + @test all(x -> istrans(vi, x), meta.s.vns) + @test all(x -> istrans(vi, x), meta.m.vns) + invlink!!(vi, model) + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + @test meta.s.vals ≈ v_s atol = 1e-10 + @test meta.m.vals ≈ v_m atol = 1e-10 + + # Transform only one variable (`s`) but not the others (`m`) + spl = DynamicPPL.Sampler(MySAlg(), model) + link!!(vi, spl, model) + @test all(x -> istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + invlink!!(vi, spl, model) + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + @test meta.s.vals ≈ v_s atol = 1e-10 + @test meta.m.vals ≈ v_m atol = 1e-10 + end + @testset "istrans" begin @model demo_constrained() = x ~ truncated(Normal(), 0, Inf) model = demo_constrained()