diff --git a/Project.toml b/Project.toml index 08ca184bf..e9c88fa9f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.18" +version = "0.23.19" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -8,6 +8,7 @@ AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -21,6 +22,12 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" +[weakdeps] +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" + +[extensions] +DynamicPPLMCMCChainsExt = ["MCMCChains"] + [compat] AbstractMCMC = "2, 3.0, 4" AbstractPPL = "0.6" @@ -28,6 +35,7 @@ BangBang = "0.3" Bijectors = "0.13" ChainRulesCore = "0.9.7, 0.10, 1" ConstructionBase = "1.5.4" +Compat = "4" Distributions = "0.23.8, 0.24, 0.25" DocStringExtensions = "0.8, 0.9" LogDensityProblems = "2" @@ -39,11 +47,5 @@ Setfield = "0.7.1, 0.8, 1" ZygoteRules = "0.2" julia = "1.6" -[extensions] -DynamicPPLMCMCChainsExt = ["MCMCChains"] - [extras] MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" - -[weakdeps] -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 042931ebb..8e3a778ad 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -3,6 +3,7 @@ module DynamicPPL using AbstractMCMC: AbstractSampler, AbstractChains using AbstractPPL using Bijectors +using Compat using Distributions using OrderedCollections: OrderedDict diff --git a/src/varinfo.jl b/src/varinfo.jl index ddb4caffb..3e7dc119f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -902,33 +902,59 @@ function _inner_transform!(vi::VarInfo, vn::VarName, dist, f) return vi end +# HACK: We need `SampleFromPrior` to result in ALL values which are in need +# of a transformation to be transformed. `_getvns` will by default return +# an empty iterable for `SampleFromPrior`, so we need to override it here. +# This is quite hacky, but seems safer than changing the behavior of `_getvns`. +_getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl) +_getvns_link(varinfo::UntypedVarInfo, spl::SampleFromPrior) = nothing +function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) + return map(Returns(nothing), varinfo.metadata) +end + function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) - return _link(varinfo) + return _link(varinfo, spl) end -function _link(varinfo::UntypedVarInfo) +function _link(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _link_metadata!(varinfo, varinfo.metadata), + _link_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) end -function _link(varinfo::TypedVarInfo) +function _link(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) - md = map(Base.Fix1(_link_metadata!, varinfo), varinfo.metadata) - # TODO: Update logp, etc. + md = _link_metadata_namedtuple!( + varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) + ) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -function _link_metadata!(varinfo::VarInfo, metadata::Metadata) +@generated function _link_metadata_namedtuple!( + varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space} +) where {names,space} + vals = Expr(:tuple) + for f in names + if inspace(f, space) || length(space) == 0 + push!(vals.args, :(_link_metadata!(varinfo, metadata.$f, vns.$f))) + else + push!(vals.args, :(metadata.$f)) + end + end + + return :(NamedTuple{$names}($vals)) +end +function _link_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn # Return early if we're already in unconstrained space. - if istrans(varinfo, vn) + # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. + if istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] end @@ -972,32 +998,49 @@ end function invlink( ::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model ) - return _invlink(varinfo) + return _invlink(varinfo, spl) end -function _invlink(varinfo::UntypedVarInfo) +function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _invlink_metadata!(varinfo, varinfo.metadata), + _invlink_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) end -function _invlink(varinfo::TypedVarInfo) +function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) - md = map(Base.Fix1(_invlink_metadata!, varinfo), varinfo.metadata) - # TODO: Update logp, etc. + md = _invlink_metadata_namedtuple!( + varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) + ) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata) +@generated function _invlink_metadata_namedtuple!( + varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space} +) where {names,space} + vals = Expr(:tuple) + for f in names + if inspace(f, space) || length(space) == 0 + push!(vals.args, :(_invlink_metadata!(varinfo, metadata.$f, vns.$f))) + else + push!(vals.args, :(metadata.$f)) + end + end + + return :(NamedTuple{$names}($vals)) +end +function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn - # Return early if we're already in constrained space. - if !istrans(varinfo, vn) + # Return early if we're already in constrained space OR if we're not + # supposed to touch this `vn`, e.g. when `vn` does not belong to the current sampler. + # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. + if !istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] end diff --git a/test/varinfo.jl b/test/varinfo.jl index 598ea7814..7f96c071e 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,3 +1,7 @@ +# A simple "algorithm" which only has `s` variables in its space. +struct MySAlg end +DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) + @testset "varinfo.jl" begin @testset "TypedVarInfo" begin @model gdemo(x, y) = begin @@ -421,4 +425,42 @@ end end end + + @testset "VarInfo with selectors" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + varinfo = VarInfo(model) + selector = DynamicPPL.Selector() + spl = Sampler(MySAlg(), model, selector) + + vns = DynamicPPL.TestUtils.varnames(model) + vns_s = filter(vn -> DynamicPPL.getsym(vn) === :s, vns) + vns_m = filter(vn -> DynamicPPL.getsym(vn) === :m, vns) + for vn in vns_s + DynamicPPL.updategid!(varinfo, vn, spl) + end + + # Should only get the variables subsumed by `@varname(s)`. + @test varinfo[spl] == + mapreduce(Base.Fix1(DynamicPPL.getval, varinfo), vcat, vns_s) + + # `link` + varinfo_linked = DynamicPPL.link(varinfo, spl, model) + # `s` variables should be linked + @test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s) + # `m` variables should NOT be linked + @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) + # And `varinfo` should be unchanged + @test all(!Base.Fix1(DynamicPPL.istrans, varinfo), vns) + + # `invlink` + varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, spl, model) + # `s` variables should no longer be linked + @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_s) + # `m` variables should still not be linked + @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_m) + # And `varinfo_linked` should be unchanged + @test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s) + @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) + end + end end