Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/torfjelde/returned-quantities-ma…
Browse files Browse the repository at this point in the history
…cro' into torfjelde/returned-quantities-macro
  • Loading branch information
torfjelde committed Nov 29, 2024
2 parents ecb4737 + b95e7d5 commit 1e238ca
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ function assume(
else
r = init(rng, dist, sampler)
if istrans(vi)
f = to_linked_internal_transform(vi, dist)
f = to_linked_internal_transform(vi, vn, dist)
push!!(vi, vn, f(r), dist, sampler)
# By default `push!!` sets the transformed flag to `false`.
settrans!!(vi, true, vn)
Expand Down Expand Up @@ -425,7 +425,7 @@ end
# HACK: These methods are only used in the `get_and_set_val!` methods below.
# FIXME: Remove these.
function _link_broadcast_new(vi, vn, dist, r)
b = to_linked_internal_transform(vi, dist)
b = to_linked_internal_transform(vi, vn, dist)
return b(r)
end

Expand Down Expand Up @@ -516,7 +516,7 @@ function get_and_set_val!(
push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists, (spl,))
# NOTE: Need to add the correction.
# FIXME: This is not great.
acclogp_assume!!(vi, sum(logabsdetjac.(link_transform.(dists), r)))
acclogp!!(vi, sum(logabsdetjac.(link_transform.(dists), r)))
# `push!!` sets the trans-flag to `false` by default.
settrans!!.((vi,), true, vns)
else
Expand Down
43 changes: 43 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -770,4 +770,47 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
@test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m)
end
end

@testset "sampling from linked varinfo" begin
# `~`
@model function demo(n=1)
x = Vector(undef, n)
for i in eachindex(x)
x[i] ~ Exponential()
end
return x
end
model1 = demo(1)
varinfo1 = DynamicPPL.link!!(VarInfo(model1), model1)
# Sampling from `model2` should hit the `istrans(vi) == true` branches
# because all the existing variables are linked.
model2 = demo(2)
varinfo2 = last(
DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext())
)
for vn in [@varname(x[1]), @varname(x[2])]
@test DynamicPPL.istrans(varinfo2, vn)
end

# `.~`
@model function demo_dot(n=1)
x ~ Exponential()
if n > 1
y = Vector(undef, n - 1)
y .~ Exponential()
end
return x
end
model1 = demo_dot(1)
varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1)
# Sampling from `model2` should hit the `istrans(vi) == true` branches
# because all the existing variables are linked.
model2 = demo_dot(2)
varinfo2 = last(
DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext())
)
for vn in [@varname(x), @varname(y[1])]
@test DynamicPPL.istrans(varinfo2, vn)
end
end
end

0 comments on commit 1e238ca

Please sign in to comment.