Skip to content

Commit

Permalink
Move link!!/invlink!! tests back into DynamicPPL
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Nov 29, 2024
1 parent 356a787 commit 68afaad
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,57 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
end
end

@testset "link!! and invlink!!" begin
@model gdemo(x, y) = begin
s ~ InverseGamma(2, 3)
m ~ Uniform(0, 2)
x ~ Normal(m, sqrt(s))
y ~ Normal(m, sqrt(s))
end
model = gdemo(1.0, 2.0)

# Check that instantiating the model does not perform linking
vi = VarInfo()
meta = vi.metadata
model(vi, SampleFromUniform())
@test all(x -> !istrans(vi, x), meta.vns)

# Check that linking and invlinking set the `trans` flag accordingly
v = copy(meta.vals)
link!!(vi, model)
@test all(x -> istrans(vi, x), meta.vns)
invlink!!(vi, model)
@test all(x -> !istrans(vi, x), meta.vns)
@test meta.vals v atol = 1e-10

# Check that linking and invlinking preserves the values
vi = TypedVarInfo(vi)
meta = vi.metadata
@test all(x -> !istrans(vi, x), meta.s.vns)
@test all(x -> !istrans(vi, x), meta.m.vns)
v_s = copy(meta.s.vals)
v_m = copy(meta.m.vals)
link!!(vi, model)
@test all(x -> istrans(vi, x), meta.s.vns)
@test all(x -> istrans(vi, x), meta.m.vns)
invlink!!(vi, model)
@test all(x -> !istrans(vi, x), meta.s.vns)
@test all(x -> !istrans(vi, x), meta.m.vns)
@test meta.s.vals v_s atol = 1e-10
@test meta.m.vals v_m atol = 1e-10

# Transform only one variable (`s`) but not the others (`m`)
spl = DynamicPPL.Sampler(MySAlg(), model)
link!!(vi, spl, model)
@test all(x -> istrans(vi, x), meta.s.vns)
@test all(x -> !istrans(vi, x), meta.m.vns)
invlink!!(vi, spl, model)
@test all(x -> !istrans(vi, x), meta.s.vns)
@test all(x -> !istrans(vi, x), meta.m.vns)
@test meta.s.vals v_s atol = 1e-10
@test meta.m.vals v_m atol = 1e-10
end

@testset "istrans" begin
@model demo_constrained() = x ~ truncated(Normal(), 0, Inf)
model = demo_constrained()
Expand Down

0 comments on commit 68afaad

Please sign in to comment.