Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

link and invlink should correctly work with Selector and thus Gibbs #542

Merged
merged 12 commits into from
Oct 10, 2023
16 changes: 9 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.18"
version = "0.23.19"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand All @@ -21,13 +22,20 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[compat]
AbstractMCMC = "2, 3.0, 4"
AbstractPPL = "0.6"
BangBang = "0.3"
Bijectors = "0.13"
ChainRulesCore = "0.9.7, 0.10, 1"
ConstructionBase = "1.5.4"
Compat = "4"
Distributions = "0.23.8, 0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
LogDensityProblems = "2"
Expand All @@ -39,11 +47,5 @@ Setfield = "0.7.1, 0.8, 1"
ZygoteRules = "0.2"
julia = "1.6"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[extras]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module DynamicPPL
using AbstractMCMC: AbstractSampler, AbstractChains
using AbstractPPL
using Bijectors
using Compat
using Distributions
using OrderedCollections: OrderedDict

Expand Down
77 changes: 60 additions & 17 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -902,33 +902,59 @@ function _inner_transform!(vi::VarInfo, vn::VarName, dist, f)
return vi
end

# HACK: With `SampleFromPrior` we want EVERY value that needs transforming to
# be transformed. By default `_getvns` returns an empty iterable for
# `SampleFromPrior`, so `link`/`invlink` go through `_getvns_link` instead and
# the `SampleFromPrior` case is overridden by dedicated methods.
# This is hacky, but seems safer than changing the behavior of `_getvns`.
#
# Generic fallback: every other sampler simply defers to `_getvns`.
function _getvns_link(varinfo::VarInfo, spl::AbstractSampler)
    return _getvns(varinfo, spl)
end
yebai marked this conversation as resolved.
Show resolved Hide resolved
# `SampleFromPrior` with an untyped varinfo: return the `nothing` sentinel.
# `_link_metadata!` / `_invlink_metadata!` skip the `target_vns` membership
# check when they receive `nothing`.
function _getvns_link(varinfo::UntypedVarInfo, spl::SampleFromPrior)
    return nothing
end

# `SampleFromPrior` with a typed varinfo: mirror the layout of
# `varinfo.metadata`, replacing every entry with the `nothing` sentinel so that
# each per-symbol metadata is paired with `nothing`.
function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior)
    metadata = varinfo.metadata
    return map(Returns(nothing), metadata)
end

function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model)
return _link(varinfo)
return _link(varinfo, spl)
end

function _link(varinfo::UntypedVarInfo)
function _link(varinfo::UntypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
return VarInfo(
_link_metadata!(varinfo, varinfo.metadata),
_link_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)),
Base.Ref(getlogp(varinfo)),
Ref(get_num_produce(varinfo)),
)
end

function _link(varinfo::TypedVarInfo)
function _link(varinfo::TypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
md = map(Base.Fix1(_link_metadata!, varinfo), varinfo.metadata)
# TODO: Update logp, etc.
md = _link_metadata_namedtuple!(
varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl))
)
return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
end

function _link_metadata!(varinfo::VarInfo, metadata::Metadata)
# Build, at generation time, a `NamedTuple` with the same field names as
# `metadata`. A field is passed through `_link_metadata!` (together with the
# matching entry of `vns`) when its symbol is in `space` or when `space` is
# empty; every other field is forwarded as-is.
@generated function _link_metadata_namedtuple!(
    varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space}
) where {names,space}
    entries = Expr(:tuple)
    for sym in names
        # NOTE: keep this a plain loop with no closures or comprehensions,
        # since this code runs inside a generated-function body.
        entry = if inspace(sym, space) || length(space) == 0
            :(_link_metadata!(varinfo, metadata.$sym, vns.$sym))
        else
            :(metadata.$sym)
        end
        push!(entries.args, entry)
    end

    return :(NamedTuple{$names}($entries))
end
function _link_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns)
vns = metadata.vns

# Construct the new transformed values, and keep track of their lengths.
vals_new = map(vns) do vn
# Return early if we're already in unconstrained space.
if istrans(varinfo, vn)
# HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not entirely clear to me why this is a HACK.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe HACK isn't quite the right word, but it's a somewhat ugly design IMO. And the overall approach of "defining a new _getvns which does something slightly different only for a particular sampler" is a hack.

if istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns)
return metadata.vals[getrange(metadata, vn)]
end

Expand Down Expand Up @@ -972,32 +998,49 @@ end
function invlink(
::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model
)
return _invlink(varinfo)
return _invlink(varinfo, spl)
end

function _invlink(varinfo::UntypedVarInfo)
function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
return VarInfo(
_invlink_metadata!(varinfo, varinfo.metadata),
_invlink_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)),
Base.Ref(getlogp(varinfo)),
Ref(get_num_produce(varinfo)),
)
end

function _invlink(varinfo::TypedVarInfo)
function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
md = map(Base.Fix1(_invlink_metadata!, varinfo), varinfo.metadata)
# TODO: Update logp, etc.
md = _invlink_metadata_namedtuple!(
varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl))
)
return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
end

function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata)
@generated function _invlink_metadata_namedtuple!(
varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space}
) where {names,space}
vals = Expr(:tuple)
for f in names
if inspace(f, space) || length(space) == 0
Copy link
Member

@yebai yebai Oct 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if inspace(f, space) || length(space) == 0
# we select all variables in `varinfo` if `space = nothing`,
if inspace(f, space) || length(space) == 0

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

space is never nothing though. At "best" it's an empty tuple. Remember, space !== vns. The scenario with vns === nothing only comes into play in the next call.

push!(vals.args, :(_invlink_metadata!(varinfo, metadata.$f, vns.$f)))
else
push!(vals.args, :(metadata.$f))
end
end

return :(NamedTuple{$names}($vals))
end
function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns)
vns = metadata.vns

# Construct the new transformed values, and keep track of their lengths.
vals_new = map(vns) do vn
# Return early if we're already in constrained space.
if !istrans(varinfo, vn)
# Return early if we're already in constrained space OR if we're not
# supposed to touch this `vn`, e.g. when `vn` does not belong to the current sampler.
# HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check.
if !istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns)
return metadata.vals[getrange(metadata, vn)]
end

Expand Down
42 changes: 42 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Minimal stand-in "algorithm" for the selector tests: a sampler wrapping it
# reports a space containing only the symbol `:s`.
struct MySAlg end

function DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg})
    return (:s,)
end

@testset "varinfo.jl" begin
@testset "TypedVarInfo" begin
@model gdemo(x, y) = begin
Expand Down Expand Up @@ -421,4 +425,42 @@
end
end
end

# Checks that `link` / `invlink` called with a sampler only transform the
# variables belonging to that sampler (here: the `s` variables, see `MySAlg`),
# leave all other variables alone, and do not mutate their input varinfo.
@testset "VarInfo with selectors" begin
    @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
        varinfo = VarInfo(model)
        selector = DynamicPPL.Selector()
        spl = Sampler(MySAlg(), model, selector)

        vns = DynamicPPL.TestUtils.varnames(model)
        vns_s = filter(vn -> DynamicPPL.getsym(vn) === :s, vns)
        vns_m = filter(vn -> DynamicPPL.getsym(vn) === :m, vns)
        # The `all` checks below would pass vacuously on an empty collection,
        # so make sure there is something to check in both groups.
        @test !isempty(vns_s)
        @test !isempty(vns_m)

        # Associate the `s` variables with `spl`.
        for vn in vns_s
            DynamicPPL.updategid!(varinfo, vn, spl)
        end

        # Should only get the variables subsumed by `@varname(s)`.
        @test varinfo[spl] ==
            mapreduce(Base.Fix1(DynamicPPL.getval, varinfo), vcat, vns_s)

        # `link`
        varinfo_linked = DynamicPPL.link(varinfo, spl, model)
        # EVERY `s` variable should be linked (`any` would accept a partial link).
        @test all(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s)
        # NO `m` variable should be linked.
        @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m)
        # And `varinfo` should be unchanged
        @test all(!Base.Fix1(DynamicPPL.istrans, varinfo), vns)

        # `invlink`
        varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, spl, model)
        # `s` variables should no longer be linked
        @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_s)
        # `m` variables should still not be linked
        @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_m)
        # And `varinfo_linked` should be unchanged: all `s` still linked,
        # all `m` still not linked.
        @test all(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s)
        @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m)
    end
end
end