From 361c45e2430f1b7a2d46f89c10f79b3fba0e052a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 27 Nov 2024 22:50:18 +0100 Subject: [PATCH 1/4] fixed calls to `to_linked_internal_transform` --- src/context_implementations.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f3c5171b0..0434f8751 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -257,7 +257,7 @@ function assume( else r = init(rng, dist, sampler) if istrans(vi) - f = to_linked_internal_transform(vi, dist) + f = to_linked_internal_transform(vi, vn, dist) push!!(vi, vn, f(r), dist, sampler) # By default `push!!` sets the transformed flag to `false`. settrans!!(vi, true, vn) @@ -500,7 +500,7 @@ end # HACK: These methods are only used in the `get_and_set_val!` methods below. # FIXME: Remove these. function _link_broadcast_new(vi, vn, dist, r) - b = to_linked_internal_transform(vi, dist) + b = to_linked_internal_transform(vi, vn, dist) return b(r) end From 545cfabd68437f07d2cf08cc38885d8113176f6e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 28 Nov 2024 10:23:01 +0100 Subject: [PATCH 2/4] fixed incorrect call to `acclogp_assume!!` --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 0434f8751..18db6c7e6 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -591,7 +591,7 @@ function get_and_set_val!( push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists, (spl,)) # NOTE: Need to add the correction. # FIXME: This is not great. - acclogp_assume!!(vi, sum(logabsdetjac.(link_transform.(dists), r))) + acclogp!!(vi, sum(logabsdetjac.(link_transform.(dists), r))) # `push!!` sets the trans-flag to `false` by default. settrans!!.((vi,), true, vns) else From d93006bd25756ea3e37c3e1d7a12336fedf9aaa9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 28 Nov 2024 15:55:40 +0100 Subject: [PATCH 3/4] added test for the branch we were currently imssing --- test/varinfo.jl | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/test/varinfo.jl b/test/varinfo.jl index a2425ebc8..4c306599c 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -770,4 +770,43 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) end end + + @testset "sampling from linked varinfo" begin + # `~` + @model function demo(n=1) + x = Vector(undef, n) + for i in eachindex(x) + x[i] ~ Exponential() + end + return x + end + model1 = demo(1) + varinfo1 = DynamicPPL.link!!(VarInfo(model1), model1) + # Sampling from `model2` should hit the `istrans(vi) == true` branches + # because all the existing variables are linked. + model2 = demo(2) + varinfo2 = last(DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext())) + for vn in [@varname(x[1]), @varname(x[2])] + @test DynamicPPL.istrans(varinfo2, vn) + end + + # `.~` + @model function demo_dot(n=1) + x ~ Exponential() + if n > 1 + y = Vector(undef, n - 1) + y .~ Exponential() + end + return x + end + model1 = demo_dot(1) + varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1) + # Sampling from `model2` should hit the `istrans(vi) == true` branches + # because all the existing variables are linked. + model2 = demo_dot(2) + varinfo2 = last(DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext())) + for vn in [@varname(x), @varname(y[1])] + @test DynamicPPL.istrans(varinfo2, vn) + end + end end From 64ff18af48edd4a4f666f677d904df6cc9349a17 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 28 Nov 2024 16:18:04 +0100 Subject: [PATCH 4/4] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/varinfo.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index 4c306599c..c45fb47e0 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -785,7 +785,9 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) # Sampling from `model2` should hit the `istrans(vi) == true` branches # because all the existing variables are linked. model2 = demo(2) - varinfo2 = last(DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext())) + varinfo2 = last( + DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext()) + ) for vn in [@varname(x[1]), @varname(x[2])] @test DynamicPPL.istrans(varinfo2, vn) end @@ -804,7 +806,9 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) # Sampling from `model2` should hit the `istrans(vi) == true` branches # because all the existing variables are linked. model2 = demo_dot(2) - varinfo2 = last(DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext())) + varinfo2 = last( + DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext()) + ) for vn in [@varname(x), @varname(y[1])] @test DynamicPPL.istrans(varinfo2, vn) end