From 14b4b10a3ec9a0a2d61823a68cffd524a9e6dd7c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 15 Apr 2024 12:29:48 +0100 Subject: [PATCH 1/9] fix for linking higher dimensional `Dirichlet` --- src/abstract_varinfo.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index bd7e8d8fb..a12de6154 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -765,6 +765,16 @@ function with_logabsdet_jacobian_and_reconstruct(f, dist, x) return with_logabsdet_jacobian(f, x_recon) end +# NOTE: Necessary to handle product distributions of `Dirichlet` and similar. +function with_logabsdet_jacobian_and_reconstruct( + f::Bijectors.Inverse{<:Bijectors.SimplexBijector}, dist, y +) + (d, ns...) = size(dist) + yreshaped = reshape(y, d - 1, ns...) + x, logjac = with_logabsdet_jacobian(f, yreshaped) + return x, logjac +end + # TODO: Once `(inv)link` isn't used heavily in `getindex(vi, vn)`, we can # just use `first ∘ with_logabsdet_jacobian` to reduce the maintenance burden. # NOTE: `reconstruct` is no-op if `val` is already of correct shape. From 16daff94217437d38e5ef9cb11da2fd3899c9768 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 15 Apr 2024 12:30:11 +0100 Subject: [PATCH 2/9] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index cbad0d688..d85df8388 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.24.9" +version = "0.24.10" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 7a0841a2a0bc3e32b0d5d48ac9a4ca4df79d0b04 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 15 Apr 2024 12:30:32 +0100 Subject: [PATCH 3/9] bump Bijectors compat entry --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d85df8388..b172bcaee 100644 --- a/Project.toml +++ b/Project.toml @@ -29,7 +29,7 @@ ADTypes = "0.2" AbstractMCMC = "5" AbstractPPL = "0.7" BangBang = "0.3" -Bijectors = "0.13" +Bijectors = "0.13.9" ChainRulesCore = "1" Compat = "4" ConstructionBase = "1.5.4" From 393854cdcd180d400bb2c029ea9e2d2cf902667c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 15 Apr 2024 12:52:42 +0100 Subject: [PATCH 4/9] added tests for high-dim Dirichlet --- test/linking.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test/linking.jl b/test/linking.jl index 9103fff67..7b3f1cfb8 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -174,4 +174,25 @@ end end end end + + # Related: https://github.com/TuringLang/Turing.jl/issues/2190 + @testset "High-dim Dirichlet" begin + @model function demo_highdim_dirichlet(ns...) + return x ~ filldist(Dirichlet(ones(2)), ns...) + end + @testset "ns=$ns" for ns in [(3,), (3, 4), (3, 4, 5)] + model = demo_highdim_dirichlet(ns...) + example_values = rand(NamedTuple, model) + vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),)) + @testset "$(short_varinfo_name(vi))" for vi in vis + # Linked. + vi_linked = if mutable + DynamicPPL.link!!(deepcopy(vi), model) + else + DynamicPPL.link(vi, model) + end + @test length(vi_linked[:]) == prod(ns) + end + end + end end From 462613a17dc382faf95adc0ffdba8705360c168d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 15 Apr 2024 12:58:24 +0100 Subject: [PATCH 5/9] added `reconstruct` for `ArrayLikeVariate` --- src/utils.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index b447fed53..1457677be 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -237,6 +237,11 @@ reconstruct(f, dist, val) = reconstruct(dist, val) reconstruct(::UnivariateDistribution, val::Real) = val reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val) reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val) +function reconstruct( + ::Distribution{ArrayLikeVariate{N}}, val::AbstractArray{<:Real,N} +) where {N} + return copy(val) +end reconstruct(::Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector) = copy(val) function reconstruct(dist::LKJCholesky, val::AbstractVector{<:Real}) From 38fc5fdc684450bde753a8c6f9993a25b3be0c13 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 15 Apr 2024 12:59:09 +0100 Subject: [PATCH 6/9] don't test functionality that relies on https://github.com/TuringLang/DistributionsAD.jl/pull/264 yet --- test/linking.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/linking.jl b/test/linking.jl index 7b3f1cfb8..620dbd118 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -180,7 +180,10 @@ end @model function demo_highdim_dirichlet(ns...) return x ~ filldist(Dirichlet(ones(2)), ns...) end - @testset "ns=$ns" for ns in [(3,), (3, 4), (3, 4, 5)] + @testset "ns=$ns" for ns in [ + (3,), + # (3, 4), (3, 4, 5) + ] model = demo_highdim_dirichlet(ns...) example_values = rand(NamedTuple, model) vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),)) From 274ec7ea24df74d1ec958d3790a1e212888ae03c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Apr 2024 19:59:21 +0100 Subject: [PATCH 7/9] Update test/linking.jl --- test/linking.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/linking.jl b/test/linking.jl index 620dbd118..33d7ba304 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -182,7 +182,8 @@ end end @testset "ns=$ns" for ns in [ (3,), - # (3, 4), (3, 4, 5) + (3, 4), + (3, 4, 5) ] model = demo_highdim_dirichlet(ns...) example_values = rand(NamedTuple, model) From 5df12cd384b761aff64be752431aade93525d24e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 18 Apr 2024 10:34:19 +0100 Subject: [PATCH 8/9] Update test/linking.jl --- test/linking.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/linking.jl b/test/linking.jl index 33d7ba304..620dbd118 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -182,8 +182,7 @@ end end @testset "ns=$ns" for ns in [ (3,), - (3, 4), - (3, 4, 5) + # (3, 4), (3, 4, 5) ] model = demo_highdim_dirichlet(ns...) example_values = rand(NamedTuple, model) From bf3f4c1163146dcd6ab37d6a641c40efbd83ac3b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 18 Apr 2024 10:36:33 +0100 Subject: [PATCH 9/9] Update test/linking.jl --- test/linking.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/linking.jl b/test/linking.jl index 620dbd118..06f6fb6d6 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -182,6 +182,7 @@ end end @testset "ns=$ns" for ns in [ (3,), + # TODO: Uncomment once we have https://github.com/TuringLang/Bijectors.jl/pull/304 # (3, 4), (3, 4, 5) ] model = demo_highdim_dirichlet(ns...)