From 549d9b150078eaffaf91f324d439c133e4314303 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 29 Aug 2023 22:35:00 +0100 Subject: [PATCH] Fix for `LKJCholesky` (#521) * simplification of vectorize and make use of non-dist version in SimpleVarInfo * added special reconstruct for LKJCholeksy * make use of vectorize in setval! for VarInfo * added tests * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed test_setval! not working when the true value is not a vector * okay now we actually fixed the test_setval! * Update test/test_util.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update Project.toml (#522) --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- Project.toml | 2 +- src/simple_varinfo.jl | 2 +- src/utils.jl | 18 +++++++++++++----- src/varinfo.jl | 7 ++++++- test/linking.jl | 39 +++++++++++++++++++++++++++++++++++++-- test/test_util.jl | 13 +++++++++---- test/turing/runtests.jl | 1 + 7 files changed, 68 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index c5b7a0241..a20e3546a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.13" +version = "0.23.14" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b3ffcec8d..025b4aad7 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -544,7 +544,7 @@ values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T} isempty(vi) && return T[] - return mapreduce(v -> vec([v;]), vcat, values(vi.values)) + return mapreduce(vectorize, vcat, values(vi.values)) end function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict} return ConstructionBase.constructorof(D)(zip(keys(vi), values(vi.values))) diff --git a/src/utils.jl b/src/utils.jl index 9a0c9c2b2..d28697127 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -209,11 +209,10 @@ invlink_transform(dist) = inverse(link_transform(dist)) # Helper functions for vectorize/reconstruct values # ##################################################### -vectorize(d, r) = vec(r) -vectorize(d::UnivariateDistribution, r::Real) = [r] -vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r) -vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r)) -vectorize(d::Distribution{CholeskyVariate}, r::Cholesky) = copy(vec(r.UL)) +vectorize(d, r) = vectorize(r) +vectorize(r::Real) = [r] +vectorize(r::AbstractArray{<:Real}) = copy(vec(r)) +vectorize(r::Cholesky) = copy(vec(r.UL)) # NOTE: # We cannot use reconstruct{T} because val is always Vector{Real} then T will be Real. @@ -237,6 +236,15 @@ reconstruct(::UnivariateDistribution, val::Real) = val reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val) reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val) reconstruct(::Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector) = copy(val) + +function reconstruct(dist::LKJCholesky, val::AbstractVector{<:Real}) + return reconstruct(dist, reshape(val, size(dist))) +end +function reconstruct(dist::LKJCholesky, val::AbstractMatrix{<:Real}) + return Cholesky(val, dist.uplo, 0) +end +reconstruct(::LKJCholesky, val::Cholesky) = val + function reconstruct( ::Inverse{Bijectors.VecCholeskyBijector}, ::LKJCholesky, val::AbstractVector ) diff --git a/src/varinfo.jl b/src/varinfo.jl index c1ccc34b9..60b6e93c1 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -325,7 +325,12 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`. The values may or may not be transformed to Euclidean space. """ setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn) -setval!(md::Metadata, val, vn::VarName) = md.vals[getrange(md, vn)] = [val;] +function setval!(md::Metadata, val::AbstractVector, vn::VarName) + return md.vals[getrange(md, vn)] = val +end +function setval!(md::Metadata, val, vn::VarName) + return md.vals[getrange(md, vn)] = vectorize(getdist(md, vn), val) +end """ getval(vi::VarInfo, vns::Vector{<:VarName}) diff --git a/test/linking.jl b/test/linking.jl index bb0081780..c9c0c318f 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -91,8 +91,42 @@ end end end + @testset "LKJCholesky" begin + @testset "uplo=$uplo" for uplo in ['L', 'U'] + @model demo_lkj(d) = x ~ LKJCholesky(d, 1.0, uplo) + @testset "d=$d" for d in [2, 3, 5] + model = demo_lkj(d) + dist = LKJCholesky(d, 1.0, uplo) + values_original = rand(model) + vis = DynamicPPL.TestUtils.setup_varinfos( + model, values_original, (@varname(x),) + ) + @testset "$(short_varinfo_name(vi))" for vi in vis + val = vi[@varname(x), dist] + # Ensure that `reconstruct` works as intended. + @test val isa Cholesky + @test val.uplo == uplo + + @test length(vi[:]) == d^2 + lp = logpdf(dist, val) + lp_model = logjoint(model, vi) + @test lp_model ≈ lp + # Linked. + vi_linked = DynamicPPL.link!!(deepcopy(vi), model) + @test length(vi_linked[:]) == d * (d - 1) ÷ 2 + # Should now include the log-absdet-jacobian correction. + @test !(getlogp(vi_linked) ≈ lp) + # Invlinked. + vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model) + @test length(vi_invlinked[:]) == d^2 + @test getlogp(vi_invlinked) ≈ lp + end + end + end + end + # Related: https://github.com/TuringLang/DynamicPPL.jl/issues/504 - @testset "dirichlet" begin + @testset "Dirichlet" begin @model demo_dirichlet(d::Int) = x ~ Dirichlet(d, 1.0) @testset "d=$d" for d in [2, 3, 5] model = demo_dirichlet(d) @@ -100,7 +134,8 @@ end @testset "$(short_varinfo_name(vi))" for vi in vis lp = logpdf(Dirichlet(d, 1.0), vi[:]) @test length(vi[:]) == d - @test getlogp(vi) ≈ lp + lp_model = logjoint(model, vi) + @test lp_model ≈ lp # Linked. vi_linked = DynamicPPL.link!!(deepcopy(vi), model) @test length(vi_linked[:]) == d - 1 diff --git a/test/test_util.jl b/test/test_util.jl index 892f7221a..31296f79a 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -61,13 +61,18 @@ function test_setval!(model, chain; sample_idx=1, chain_idx=1) nt = DynamicPPL.tonamedtuple(var_info) for (k, (vals, names)) in pairs(nt) for (n, v) in zip(names, vals) - chain_val = if Symbol(n) ∉ keys(chain) + if Symbol(n) ∉ keys(chain) # Assume it's a group - vec(MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]) + chain_val = vec( + MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] + ) + v_true = vec(v) else - chain[sample_idx, n, chain_idx] + chain_val = chain[sample_idx, n, chain_idx] + v_true = v end - @test v == chain_val + + @test v_true == chain_val end end end diff --git a/test/turing/runtests.jl b/test/turing/runtests.jl index 7d53cb4db..2c1d5085d 100644 --- a/test/turing/runtests.jl +++ b/test/turing/runtests.jl @@ -10,6 +10,7 @@ setprogress!(false) Random.seed!(100) # load test utilities +include(joinpath(pathof(DynamicPPL), "..", "..", "test", "test_util.jl")) include(joinpath(pathof(Turing), "..", "..", "test", "test_utils", "numerical_tests.jl")) @testset "Turing" begin