Skip to content

Commit

Permalink
link and invlink should correctly work with Selector and thus `…
Browse files Browse the repository at this point in the history
…Gibbs` (#542)

* link and invlink should correctly work with Selector etc.

* more fixes to link and invlink

* formatting

* added simple tests for usage of selectors

* bumped patch version

* fied typos

* added missing _getvns_link for UntypedVarInfo

* simplify `_getvns_link` for TypedVarInfo

* Update src/varinfo.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added Compat as dep so we can make use of certain features, e.g. Returns

* forgot using Compat

* Apply suggestions from code review

Co-authored-by: Hong Ge <[email protected]>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
3 people authored Oct 10, 2023
1 parent d204fcb commit 0289358
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 24 deletions.
16 changes: 9 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.18"
version = "0.23.19"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand All @@ -21,13 +22,20 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[compat]
AbstractMCMC = "2, 3.0, 4"
AbstractPPL = "0.6"
BangBang = "0.3"
Bijectors = "0.13"
ChainRulesCore = "0.9.7, 0.10, 1"
ConstructionBase = "1.5.4"
Compat = "4"
Distributions = "0.23.8, 0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
LogDensityProblems = "2"
Expand All @@ -39,11 +47,5 @@ Setfield = "0.7.1, 0.8, 1"
ZygoteRules = "0.2"
julia = "1.6"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[extras]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module DynamicPPL
using AbstractMCMC: AbstractSampler, AbstractChains
using AbstractPPL
using Bijectors
using Compat
using Distributions
using OrderedCollections: OrderedDict

Expand Down
77 changes: 60 additions & 17 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -902,33 +902,59 @@ function _inner_transform!(vi::VarInfo, vn::VarName, dist, f)
return vi
end

# HACK: We need `SampleFromPrior` to result in ALL values which are in need
# of a transformation to be transformed. `_getvns` will by default return
# an empty iterable for `SampleFromPrior`, so we need to override it here.
# This is quite hacky, but seems safer than changing the behavior of `_getvns`.
_getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl)
_getvns_link(varinfo::UntypedVarInfo, spl::SampleFromPrior) = nothing
function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior)
return map(Returns(nothing), varinfo.metadata)
end

function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model)
return _link(varinfo)
return _link(varinfo, spl)
end

function _link(varinfo::UntypedVarInfo)
function _link(varinfo::UntypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
return VarInfo(
_link_metadata!(varinfo, varinfo.metadata),
_link_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)),
Base.Ref(getlogp(varinfo)),
Ref(get_num_produce(varinfo)),
)
end

function _link(varinfo::TypedVarInfo)
function _link(varinfo::TypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
md = map(Base.Fix1(_link_metadata!, varinfo), varinfo.metadata)
# TODO: Update logp, etc.
md = _link_metadata_namedtuple!(
varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl))
)
return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
end

function _link_metadata!(varinfo::VarInfo, metadata::Metadata)
@generated function _link_metadata_namedtuple!(
varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space}
) where {names,space}
vals = Expr(:tuple)
for f in names
if inspace(f, space) || length(space) == 0
push!(vals.args, :(_link_metadata!(varinfo, metadata.$f, vns.$f)))
else
push!(vals.args, :(metadata.$f))
end
end

return :(NamedTuple{$names}($vals))
end
function _link_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns)
vns = metadata.vns

# Construct the new transformed values, and keep track of their lengths.
vals_new = map(vns) do vn
# Return early if we're already in unconstrained space.
if istrans(varinfo, vn)
# HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check.
if istrans(varinfo, vn) || (target_vns !== nothing && vn target_vns)
return metadata.vals[getrange(metadata, vn)]
end

Expand Down Expand Up @@ -972,32 +998,49 @@ end
function invlink(
::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model
)
return _invlink(varinfo)
return _invlink(varinfo, spl)
end

function _invlink(varinfo::UntypedVarInfo)
function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
return VarInfo(
_invlink_metadata!(varinfo, varinfo.metadata),
_invlink_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)),
Base.Ref(getlogp(varinfo)),
Ref(get_num_produce(varinfo)),
)
end

function _invlink(varinfo::TypedVarInfo)
function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
md = map(Base.Fix1(_invlink_metadata!, varinfo), varinfo.metadata)
# TODO: Update logp, etc.
md = _invlink_metadata_namedtuple!(
varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl))
)
return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
end

function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata)
@generated function _invlink_metadata_namedtuple!(
varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space}
) where {names,space}
vals = Expr(:tuple)
for f in names
if inspace(f, space) || length(space) == 0
push!(vals.args, :(_invlink_metadata!(varinfo, metadata.$f, vns.$f)))
else
push!(vals.args, :(metadata.$f))
end
end

return :(NamedTuple{$names}($vals))
end
function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns)
vns = metadata.vns

# Construct the new transformed values, and keep track of their lengths.
vals_new = map(vns) do vn
# Return early if we're already in constrained space.
if !istrans(varinfo, vn)
# Return early if we're already in constrained space OR if we're not
# supposed to touch this `vn`, e.g. when `vn` does not belong to the current sampler.
# HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check.
if !istrans(varinfo, vn) || (target_vns !== nothing && vn target_vns)
return metadata.vals[getrange(metadata, vn)]
end

Expand Down
42 changes: 42 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# A simple "algorithm" which only has `s` variables in its space.
struct MySAlg end
DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)

@testset "varinfo.jl" begin
@testset "TypedVarInfo" begin
@model gdemo(x, y) = begin
Expand Down Expand Up @@ -421,4 +425,42 @@
end
end
end

@testset "VarInfo with selectors" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
varinfo = VarInfo(model)
selector = DynamicPPL.Selector()
spl = Sampler(MySAlg(), model, selector)

vns = DynamicPPL.TestUtils.varnames(model)
vns_s = filter(vn -> DynamicPPL.getsym(vn) === :s, vns)
vns_m = filter(vn -> DynamicPPL.getsym(vn) === :m, vns)
for vn in vns_s
DynamicPPL.updategid!(varinfo, vn, spl)
end

# Should only get the variables subsumed by `@varname(s)`.
@test varinfo[spl] ==
mapreduce(Base.Fix1(DynamicPPL.getval, varinfo), vcat, vns_s)

# `link`
varinfo_linked = DynamicPPL.link(varinfo, spl, model)
# `s` variables should be linked
@test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s)
# `m` variables should NOT be linked
@test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m)
# And `varinfo` should be unchanged
@test all(!Base.Fix1(DynamicPPL.istrans, varinfo), vns)

# `invlink`
varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, spl, model)
# `s` variables should no longer be linked
@test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_s)
# `m` variables should still not be linked
@test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_m)
# And `varinfo_linked` should be unchanged
@test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s)
@test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m)
end
end
end

0 comments on commit 0289358

Please sign in to comment.