From 1110d303ecb9d2c6cb499bf31972bf03da00d917 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 29 Nov 2024 11:40:24 +0100 Subject: [PATCH] Fixed incorrect calls to `to_linked_internal_transform` (#726) * fixed calls to `to_linked_internal_transform` * fixed incorrect call to `acclogp_assume!!` * added test for the branch we were currently imssing * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 6 ++--- test/varinfo.jl | 43 ++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a0e11b65b..489c64c57 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -219,7 +219,7 @@ function assume( else r = init(rng, dist, sampler) if istrans(vi) - f = to_linked_internal_transform(vi, dist) + f = to_linked_internal_transform(vi, vn, dist) push!!(vi, vn, f(r), dist, sampler) # By default `push!!` sets the transformed flag to `false`. settrans!!(vi, true, vn) @@ -401,7 +401,7 @@ end # HACK: These methods are only used in the `get_and_set_val!` methods below. # FIXME: Remove these. function _link_broadcast_new(vi, vn, dist, r) - b = to_linked_internal_transform(vi, dist) + b = to_linked_internal_transform(vi, vn, dist) return b(r) end @@ -492,7 +492,7 @@ function get_and_set_val!( push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists, (spl,)) # NOTE: Need to add the correction. # FIXME: This is not great. - acclogp_assume!!(vi, sum(logabsdetjac.(link_transform.(dists), r))) + acclogp!!(vi, sum(logabsdetjac.(link_transform.(dists), r))) # `push!!` sets the trans-flag to `false` by default. settrans!!.((vi,), true, vns) else diff --git a/test/varinfo.jl b/test/varinfo.jl index a2425ebc8..c45fb47e0 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -770,4 +770,47 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) end end + + @testset "sampling from linked varinfo" begin + # `~` + @model function demo(n=1) + x = Vector(undef, n) + for i in eachindex(x) + x[i] ~ Exponential() + end + return x + end + model1 = demo(1) + varinfo1 = DynamicPPL.link!!(VarInfo(model1), model1) + # Sampling from `model2` should hit the `istrans(vi) == true` branches + # because all the existing variables are linked. + model2 = demo(2) + varinfo2 = last( + DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext()) + ) + for vn in [@varname(x[1]), @varname(x[2])] + @test DynamicPPL.istrans(varinfo2, vn) + end + + # `.~` + @model function demo_dot(n=1) + x ~ Exponential() + if n > 1 + y = Vector(undef, n - 1) + y .~ Exponential() + end + return x + end + model1 = demo_dot(1) + varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1) + # Sampling from `model2` should hit the `istrans(vi) == true` branches + # because all the existing variables are linked. + model2 = demo_dot(2) + varinfo2 = last( + DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext()) + ) + for vn in [@varname(x), @varname(y[1])] + @test DynamicPPL.istrans(varinfo2, vn) + end + end end