Skip to content

Commit

Permalink
Try #360:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Jan 27, 2022
2 parents 8990bfb + 0b304db commit c5b8a7f
Show file tree
Hide file tree
Showing 9 changed files with 226 additions and 96 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.17.3"
version = "0.17.4"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
4 changes: 4 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,4 +176,8 @@ include("test_utils.jl")
@deprecate acclogp!(vi, logp) acclogp!!(vi, logp)
@deprecate resetlogp!(vi) resetlogp!!(vi)

@deprecate settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) settrans!!(
vi, trans, vn
)

end # module
75 changes: 49 additions & 26 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ end
function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
settrans!(vi, false, vn)
settrans!!(vi, false, vn)
end
return tilde_assume(PriorContext(), right, vn, vi)
end
Expand All @@ -64,15 +64,15 @@ function tilde_assume(
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
settrans!(vi, false, vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, PriorContext(), sampler, right, vn, vi)
end

function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
settrans!(vi, false, vn)
settrans!!(vi, false, vn)
end
return tilde_assume(LikelihoodContext(), right, vn, vi)
end
Expand All @@ -86,7 +86,7 @@ function tilde_assume(
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
settrans!(vi, false, vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi)
end
Expand Down Expand Up @@ -194,7 +194,9 @@ end

# fallback without sampler
function assume(dist::Distribution, vn::VarName, vi)
r = vi[vn]
# x = vi[vn]
r_raw = getindex_raw(vi, vn)
r = maybe_invlink(vi, vn, dist, r_raw)
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
end

Expand All @@ -211,16 +213,23 @@ function assume(
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
r = init(rng, dist, sampler)
vi[vn] = vectorize(dist, r)
settrans!(vi, false, vn)
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r))
setorder!(vi, vn, get_num_produce(vi))
else
r = vi[vn]
# Otherwise we just extract it.
# r = vi[vn]
r_raw = getindex_raw(vi, vn)
r = maybe_invlink(vi, vn, dist, r_raw)
end
else
r = init(rng, dist, sampler)
push!!(vi, vn, r, dist, sampler)
settrans!(vi, false, vn)
if istrans(vi)
push!!(vi, vn, link(dist, r), dist, sampler)
# By default `push!!` sets the transformed flag to `false`.
settrans!!(vi, true, vn)
else
push!!(vi, vn, r, dist, sampler)
end
end

return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
Expand Down Expand Up @@ -286,7 +295,7 @@ function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left,
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!.(Ref(vi), false, _vns)
settrans!!.(Ref(vi), false, _vns)
dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(LikelihoodContext(), right, left, vn, vi)
Expand All @@ -305,7 +314,7 @@ function dot_tilde_assume(
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!.(Ref(vi), false, _vns)
settrans!!.(Ref(vi), false, _vns)
dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi)
Expand All @@ -326,7 +335,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn,
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!.(Ref(vi), false, _vns)
settrans!!.(Ref(vi), false, _vns)
dot_tilde_assume(PriorContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(PriorContext(), right, left, vn, vi)
Expand All @@ -345,7 +354,7 @@ function dot_tilde_assume(
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!.(Ref(vi), false, _vns)
settrans!!.(Ref(vi), false, _vns)
dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi)
Expand Down Expand Up @@ -390,7 +399,9 @@ function dot_assume(
# m .~ Normal()
#
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
r = vi[vns]
# r = vi[vns]
r_raw = getindex_raw(vi, vns)
r = maybe_invlink(vi, vn, dist, r_raw)
lp = sum(zip(vns, eachcol(r))) do (vn, ri)
return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn))
end
Expand Down Expand Up @@ -423,7 +434,8 @@ function dot_assume(
# m .~ Normal()
#
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
r = reshape(vi[vec(vns)], size(vns))
r_raw = getindex_raw(vi, vec(vns))
r = reshape(maybe_invlink.(Ref(vi), vns, dists, r_raw), size(vns))
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
return r, lp, vi
end
Expand Down Expand Up @@ -462,19 +474,24 @@ function get_and_set_val!(
r = init(rng, dist, spl, n)
for i in 1:n
vn = vns[i]
vi[vn] = vectorize(dist, r[:, i])
settrans!(vi, false, vn)
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[:, i]))
setorder!(vi, vn, get_num_produce(vi))
end
else
r = vi[vns]
r_raw = getindex_raw(vi, vns)
r = maybe_invlink(vi, vns, dist, r_raw)
end
else
r = init(rng, dist, spl, n)
for i in 1:n
vn = vns[i]
push!!(vi, vn, r[:, i], dist, spl)
settrans!(vi, false, vn)
if istrans(vi)
push!!(vi, vn, maybe_link(vi, vn, dist, r[:, i]), dist, spl)
# `push!!` sets the trans-flag to `false` by default.
setttrans!!(vi, true, vn)
else
push!!(vi, vn, r[:, i], dist, spl)
end
end
end
return r
Expand All @@ -496,12 +513,13 @@ function get_and_set_val!(
for i in eachindex(vns)
vn = vns[i]
dist = dists isa AbstractArray ? dists[i] : dists
vi[vn] = vectorize(dist, r[i])
settrans!(vi, false, vn)
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[i]))
setorder!(vi, vn, get_num_produce(vi))
end
else
r = reshape(vi[vec(vns)], size(vns))
# r = reshape(vi[vec(vns)], size(vns))
r_raw = getindex_raw(vi, vec(vns))
r = maybe_invlink.(Ref(vi), vns, dists, reshape(r_raw, size(vns)))
end
else
f = (vn, dist) -> init(rng, dist, spl)
Expand All @@ -511,8 +529,13 @@ function get_and_set_val!(
# 1. Figure out the broadcast size and use a `foreach`.
# 2. Define an anonymous function which returns `nothing`, which
# we then broadcast. This will allocate a vector of `nothing` though.
push!!.(Ref(vi), vns, r, dists, Ref(spl))
settrans!.(Ref(vi), false, vns)
if istrans(vi)
push!!.(Ref(vi), vns, link.(Ref(vi), vns, dists, r), dists, Ref(spl))
# `push!!` sets the trans-flag to `false` by default.
settrans!!.(Ref(vi), true, vns)
else
push!!.(Ref(vi), vns, r, dists, Ref(spl))
end
end
return r
end
Expand Down
Loading

0 comments on commit c5b8a7f

Please sign in to comment.