diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f3c5171b0..18db6c7e6 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -257,7 +257,7 @@ function assume( else r = init(rng, dist, sampler) if istrans(vi) - f = to_linked_internal_transform(vi, dist) + f = to_linked_internal_transform(vi, vn, dist) push!!(vi, vn, f(r), dist, sampler) # By default `push!!` sets the transformed flag to `false`. settrans!!(vi, true, vn) @@ -500,7 +500,7 @@ end # HACK: These methods are only used in the `get_and_set_val!` methods below. # FIXME: Remove these. function _link_broadcast_new(vi, vn, dist, r) - b = to_linked_internal_transform(vi, dist) + b = to_linked_internal_transform(vi, vn, dist) return b(r) end @@ -591,7 +591,7 @@ function get_and_set_val!( push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists, (spl,)) # NOTE: Need to add the correction. # FIXME: This is not great. - acclogp_assume!!(vi, sum(logabsdetjac.(link_transform.(dists), r))) + acclogp!!(vi, sum(logabsdetjac.(link_transform.(dists), r))) # `push!!` sets the trans-flag to `false` by default. settrans!!.((vi,), true, vns) else