diff --git a/Project.toml b/Project.toml
index d85df8388..b172bcaee 100644
--- a/Project.toml
+++ b/Project.toml
@@ -29,7 +29,7 @@ ADTypes = "0.2"
 AbstractMCMC = "5"
 AbstractPPL = "0.7"
 BangBang = "0.3"
-Bijectors = "0.13"
+Bijectors = "0.13.9"
 ChainRulesCore = "1"
 Compat = "4"
 ConstructionBase = "1.5.4"
diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl
index bd7e8d8fb..a12de6154 100644
--- a/src/abstract_varinfo.jl
+++ b/src/abstract_varinfo.jl
@@ -765,6 +765,18 @@ function with_logabsdet_jacobian_and_reconstruct(f, dist, x)
     return with_logabsdet_jacobian(f, x_recon)
 end
 
+# NOTE: Necessary to handle product distributions of `Dirichlet` and similar.
+# With `(d, ns...) = size(dist)`, the input `y` is first reshaped to
+# `(d - 1, ns...)`; the inverse simplex transform `f` is then applied to it.
+function with_logabsdet_jacobian_and_reconstruct(
+    f::Bijectors.Inverse{<:Bijectors.SimplexBijector}, dist, y
+)
+    (simplex_dim, batch_dims...) = size(dist)
+    y_unconstrained = reshape(y, simplex_dim - 1, batch_dims...)
+    x_constrained, logdetjac = with_logabsdet_jacobian(f, y_unconstrained)
+    return x_constrained, logdetjac
+end
+
 # TODO: Once `(inv)link` isn't used heavily in `getindex(vi, vn)`, we can
 # just use `first ∘ with_logabsdet_jacobian` to reduce the maintenance burden.
 # NOTE: `reconstruct` is no-op if `val` is already of correct shape.
diff --git a/src/utils.jl b/src/utils.jl
index b447fed53..1457677be 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -237,6 +237,13 @@ reconstruct(f, dist, val) = reconstruct(dist, val)
 reconstruct(::UnivariateDistribution, val::Real) = val
 reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val)
 reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val)
+# Array-variate distributions of any dimensionality `N`: a value that already has
+# `N` dimensions is returned as a copy.
+function reconstruct(
+    ::Distribution{ArrayLikeVariate{N}}, val::AbstractArray{<:Real,N}
+) where {N}
+    return copy(val)
+end
 reconstruct(::Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector) = copy(val)
 
 function reconstruct(dist::LKJCholesky, val::AbstractVector{<:Real})
diff --git a/test/linking.jl b/test/linking.jl
index 9103fff67..06f6fb6d6 100644
--- a/test/linking.jl
+++ b/test/linking.jl
@@ -174,4 +174,33 @@ end
             end
         end
     end
+
+    # Exercises linking of a `filldist` of `Dirichlet` distributions.
+    # Related: https://github.com/TuringLang/Turing.jl/issues/2190
+    @testset "High-dim Dirichlet" begin
+        @model function demo_highdim_dirichlet(ns...)
+            return x ~ filldist(Dirichlet(ones(2)), ns...)
+        end
+        batch_sizes = [
+            (3,),
+            # TODO: Uncomment once we have https://github.com/TuringLang/Bijectors.jl/pull/304
+            # (3, 4), (3, 4, 5)
+        ]
+        @testset "ns=$ns" for ns in batch_sizes
+            model = demo_highdim_dirichlet(ns...)
+            vis = DynamicPPL.TestUtils.setup_varinfos(
+                model, rand(NamedTuple, model), (@varname(x),)
+            )
+            @testset "$(short_varinfo_name(vi))" for vi in vis
+                # Use the mutating `link!!` on a deep copy when `mutable`, else `link`.
+                linked = if mutable
+                    DynamicPPL.link!!(deepcopy(vi), model)
+                else
+                    DynamicPPL.link(vi, model)
+                end
+                # Expects one linked entry per `Dirichlet(ones(2))` component.
+                @test length(linked[:]) == prod(ns)
+            end
+        end
+    end
 end