Skip to content

Commit

Permalink
Merge branch 'master' into compathelper/new_version/2024-11-29-00-10-…
Browse files Browse the repository at this point in the history
…54-349-01708615278
  • Loading branch information
mhauru authored Nov 29, 2024
2 parents 25d9ef5 + 1110d30 commit 3d10037
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ function assume(
else
r = init(rng, dist, sampler)
if istrans(vi)
f = to_linked_internal_transform(vi, dist)
f = to_linked_internal_transform(vi, vn, dist)
push!!(vi, vn, f(r), dist, sampler)
# By default `push!!` sets the transformed flag to `false`.
settrans!!(vi, true, vn)
Expand Down Expand Up @@ -401,7 +401,7 @@ end
# HACK: These methods are only used in the `get_and_set_val!` methods below.
# FIXME: Remove these.
function _link_broadcast_new(vi, vn, dist, r)
b = to_linked_internal_transform(vi, dist)
b = to_linked_internal_transform(vi, vn, dist)
return b(r)
end

Expand Down Expand Up @@ -492,7 +492,7 @@ function get_and_set_val!(
push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists, (spl,))
# NOTE: Need to add the correction.
# FIXME: This is not great.
acclogp_assume!!(vi, sum(logabsdetjac.(link_transform.(dists), r)))
acclogp!!(vi, sum(logabsdetjac.(link_transform.(dists), r)))
# `push!!` sets the trans-flag to `false` by default.
settrans!!.((vi,), true, vns)
else
Expand Down
43 changes: 43 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -770,4 +770,47 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
@test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m)
end
end

@testset "sampling from linked varinfo" begin
# `~`
@model function demo(n=1)
x = Vector(undef, n)
for i in eachindex(x)
x[i] ~ Exponential()
end
return x
end
model1 = demo(1)
varinfo1 = DynamicPPL.link!!(VarInfo(model1), model1)
# Sampling from `model2` should hit the `istrans(vi) == true` branches
# because all the existing variables are linked.
model2 = demo(2)
varinfo2 = last(
DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext())
)
for vn in [@varname(x[1]), @varname(x[2])]
@test DynamicPPL.istrans(varinfo2, vn)
end

# `.~`
@model function demo_dot(n=1)
x ~ Exponential()
if n > 1
y = Vector(undef, n - 1)
y .~ Exponential()
end
return x
end
model1 = demo_dot(1)
varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1)
# Sampling from `model2` should hit the `istrans(vi) == true` branches
# because all the existing variables are linked.
model2 = demo_dot(2)
varinfo2 = last(
DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext())
)
for vn in [@varname(x), @varname(y[1])]
@test DynamicPPL.istrans(varinfo2, vn)
end
end
end

0 comments on commit 3d10037

Please sign in to comment.