Skip to content

Commit

Permalink
added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Aug 24, 2023
1 parent 1924746 commit 9b7504e
Showing 1 changed file with 37 additions and 2 deletions.
39 changes: 37 additions & 2 deletions test/linking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,51 @@ end
end
end

@testset "LKJCholesky" begin
@testset "uplo=$uplo" for uplo in ['L', 'U']
@model demo_lkj(d) = x ~ LKJCholesky(d, 1.0, uplo)
@testset "d=$d" for d in [2, 3, 5]
model = demo_lkj(d)
dist = LKJCholesky(d, 1.0, uplo)
values_original = rand(model)
vis = DynamicPPL.TestUtils.setup_varinfos(
model, values_original, (@varname(x),)
)
@testset "$(short_varinfo_name(vi))" for vi in vis
val = vi[@varname(x), dist]
# Ensure that `reconstruct` works as intended.
@test val isa Cholesky
@test val.uplo == uplo

@test length(vi[:]) == d^2
lp = logpdf(dist, val)
lp_model = logjoint(model, vi)
@test lp_model lp
# Linked.
vi_linked = DynamicPPL.link!!(deepcopy(vi), model)
@test length(vi_linked[:]) == d * (d - 1) ÷ 2
# Should now include the log-absdet-jacobian correction.
@test !(getlogp(vi_linked) lp)
# Invlinked.
vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model)
@test length(vi_invlinked[:]) == d^2
@test getlogp(vi_invlinked) lp
end
end
end
end

# Related: https://github.com/TuringLang/DynamicPPL.jl/issues/504
@testset "dirichlet" begin
@testset "Dirichlet" begin
@model demo_dirichlet(d::Int) = x ~ Dirichlet(d, 1.0)
@testset "d=$d" for d in [2, 3, 5]
model = demo_dirichlet(d)
vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),))
@testset "$(short_varinfo_name(vi))" for vi in vis
lp = logpdf(Dirichlet(d, 1.0), vi[:])
@test length(vi[:]) == d
@test getlogp(vi) lp
lp_model = logjoint(model, vi)
@test lp_model lp
# Linked.
vi_linked = DynamicPPL.link!!(deepcopy(vi), model)
@test length(vi_linked[:]) == d - 1
Expand Down

0 comments on commit 9b7504e

Please sign in to comment.