From 4de2a0146191304edc7a294b86eaea53daa61716 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 7 Oct 2023 23:47:34 +0100 Subject: [PATCH 01/20] link and invlink should correctly work with Selector etc. --- src/varinfo.jl | 58 +++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index ddb4caffb..08b986f74 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -903,25 +903,44 @@ function _inner_transform!(vi::VarInfo, vn::VarName, dist, f) end function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) - return _link(varinfo) + return _link(varinfo, spl) end -function _link(varinfo::UntypedVarInfo) +function _link(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _link_metadata!(varinfo, varinfo.metadata), + _link_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) end -function _link(varinfo::TypedVarInfo) +function _link(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) - md = map(Base.Fix1(_link_metadata!, varinfo), varinfo.metadata) + md = _link_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))) # TODO: Update logp, etc. return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end +@generated function _link_metadata!( + varinfo::VarInfo, + metadata::NamedTuple{names}, + ::Val{space} +) where {names,space} + vals = Expr(:tuple) + for f in names + if inspace(f, space) || length(space) == 0 + push!( + expr.args, + :(_link_metadata!(varinfo, metadata.$f)) + ) + else + push!(vals.args, :(metadata.$f)) + end + end + + return :(NamedTuple{$names}($vals)) +end function _link_metadata!(varinfo::VarInfo, metadata::Metadata) vns = metadata.vns @@ -972,25 +991,44 @@ end function invlink( ::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model ) - return _invlink(varinfo) + return _invlink(varinfo, spl) end -function _invlink(varinfo::UntypedVarInfo) +function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _invlink_metadata!(varinfo, varinfo.metadata), + _invlink_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) end -function _invlink(varinfo::TypedVarInfo) +function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) - md = map(Base.Fix1(_invlink_metadata!, varinfo), varinfo.metadata) + md = _invlink_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))) # TODO: Update logp, etc. return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end +@generated function _invlink_metadata!( + varinfo::VarInfo, + metadata::NamedTuple{names}, + ::Val{space} +) where {names,space} + vals = Expr(:tuple) + for f in names + if inspace(f, space) || length(space) == 0 + push!( + expr.args, + :(_invlink_metadata!(varinfo, metadata.$f)) + ) + else + push!(vals.args, :(metadata.$f)) + end + end + + return :(NamedTuple{$names}($vals)) +end function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata) vns = metadata.vns From 1e4d9f19ee4df18af801fc50bac7e58b56f5f94a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 00:20:02 +0100 Subject: [PATCH 02/20] more fixes to link and invlink --- src/varinfo.jl | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 08b986f74..d56d5b22d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -909,7 +909,7 @@ end function _link(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _link_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))), + _link_metadata!(varinfo, varinfo.metadata, _getvns(spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) @@ -917,22 +917,22 @@ end function _link(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) - md = _link_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))) - # TODO: Update logp, etc. + md = _link_metadata_namedtuple!(varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl))) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -@generated function _link_metadata!( +@generated function _link_metadata_namedtuple!( varinfo::VarInfo, metadata::NamedTuple{names}, + vns::NamedTuple, ::Val{space} ) where {names,space} vals = Expr(:tuple) for f in names if inspace(f, space) || length(space) == 0 push!( - expr.args, - :(_link_metadata!(varinfo, metadata.$f)) + vals.args, + :(_link_metadata!(varinfo, metadata.$f, vns.$f)) ) else push!(vals.args, :(metadata.$f)) @@ -941,13 +941,13 @@ end return :(NamedTuple{$names}($vals)) end -function _link_metadata!(varinfo::VarInfo, metadata::Metadata) +function _link_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn # Return early if we're already in unconstrained space. - if istrans(varinfo, vn) + if istrans(varinfo, vn) || vn ∉ target_vns return metadata.vals[getrange(metadata, vn)] end @@ -997,7 +997,7 @@ end function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _invlink_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))), + _invlink_metadata!(varinfo, varinfo.metadata, _getvns(spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) @@ -1005,22 +1005,22 @@ end function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) - md = _invlink_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))) - # TODO: Update logp, etc. + md = _invlink_metadata_namedtuple!(varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl))) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -@generated function _invlink_metadata!( +@generated function _invlink_metadata_namedtuple!( varinfo::VarInfo, metadata::NamedTuple{names}, + vns::NamedTuple, ::Val{space} ) where {names,space} vals = Expr(:tuple) for f in names if inspace(f, space) || length(space) == 0 push!( - expr.args, - :(_invlink_metadata!(varinfo, metadata.$f)) + vals.args, + :(_invlink_metadata!(varinfo, metadata.$f, vns.$f)) ) else push!(vals.args, :(metadata.$f)) @@ -1029,13 +1029,14 @@ end return :(NamedTuple{$names}($vals)) end -function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata) +function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn - # Return early if we're already in constrained space. - if !istrans(varinfo, vn) + # Return early if we're already in constrained space OR if we're not + # supposed to touch this `vn`. + if !istrans(varinfo, vn) || vn ∉ target_vns return metadata.vals[getrange(metadata, vn)] end From bf4fcc66f949b1e0597cc0311d96f12332aa4f4c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 00:20:44 +0100 Subject: [PATCH 03/20] formatting --- src/varinfo.jl | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index d56d5b22d..dedab5f97 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -917,23 +917,19 @@ end function _link(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) - md = _link_metadata_namedtuple!(varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl))) + md = _link_metadata_namedtuple!( + varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl)) + ) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end @generated function _link_metadata_namedtuple!( - varinfo::VarInfo, - metadata::NamedTuple{names}, - vns::NamedTuple, - ::Val{space} + varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space} ) where {names,space} vals = Expr(:tuple) for f in names if inspace(f, space) || length(space) == 0 - push!( - vals.args, - :(_link_metadata!(varinfo, metadata.$f, vns.$f)) - ) + push!(vals.args, :(_link_metadata!(varinfo, metadata.$f, vns.$f))) else push!(vals.args, :(metadata.$f)) end @@ -1005,23 +1001,19 @@ end function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) - md = _invlink_metadata_namedtuple!(varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl))) + md = _invlink_metadata_namedtuple!( + varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl)) + ) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end @generated function _invlink_metadata_namedtuple!( - varinfo::VarInfo, - metadata::NamedTuple{names}, - vns::NamedTuple, - ::Val{space} + varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space} ) where {names,space} vals = Expr(:tuple) for f in names if inspace(f, space) || length(space) == 0 - push!( - vals.args, - :(_invlink_metadata!(varinfo, metadata.$f, vns.$f)) - ) + push!(vals.args, :(_invlink_metadata!(varinfo, metadata.$f, vns.$f))) else push!(vals.args, :(metadata.$f)) end From f1fde0bbee3ae9213479cb78c0b4803a8ec17bc5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 00:35:17 +0100 Subject: [PATCH 04/20] added simple tests for usage of selectors --- test/varinfo.jl | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/test/varinfo.jl b/test/varinfo.jl index 598ea7814..7f96c071e 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,3 +1,7 @@ +# A simple "algorithm" which only has `s` variables in its space. +struct MySAlg end +DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) + @testset "varinfo.jl" begin @testset "TypedVarInfo" begin @model gdemo(x, y) = begin @@ -421,4 +425,42 @@ end end end + + @testset "VarInfo with selectors" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + varinfo = VarInfo(model) + selector = DynamicPPL.Selector() + spl = Sampler(MySAlg(), model, selector) + + vns = DynamicPPL.TestUtils.varnames(model) + vns_s = filter(vn -> DynamicPPL.getsym(vn) === :s, vns) + vns_m = filter(vn -> DynamicPPL.getsym(vn) === :m, vns) + for vn in vns_s + DynamicPPL.updategid!(varinfo, vn, spl) + end + + # Should only get the variables subsumed by `@varname(s)`. + @test varinfo[spl] == + mapreduce(Base.Fix1(DynamicPPL.getval, varinfo), vcat, vns_s) + + # `link` + varinfo_linked = DynamicPPL.link(varinfo, spl, model) + # `s` variables should be linked + @test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s) + # `m` variables should NOT be linked + @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) + # And `varinfo` should be unchanged + @test all(!Base.Fix1(DynamicPPL.istrans, varinfo), vns) + + # `invlink` + varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, spl, model) + # `s` variables should no longer be linked + @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_s) + # `m` variables should still not be linked + @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_m) + # And `varinfo_linked` should be unchanged + @test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s) + @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) + end + end end From 67770d32a917fd066695814c1cd0bff943c1fb2b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 00:35:37 +0100 Subject: [PATCH 05/20] bumped patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 08ca184bf..c9805dadb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.18" +version = "0.23.19" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 4bf5f7c7b33fb6fe6fbb28a2aded5fe56b85446f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 00:36:21 +0100 Subject: [PATCH 06/20] fied typos --- src/varinfo.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index dedab5f97..3e9d0c204 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -909,7 +909,7 @@ end function _link(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _link_metadata!(varinfo, varinfo.metadata, _getvns(spl)), + _link_metadata!(varinfo, varinfo.metadata, _getvns(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) @@ -993,7 +993,7 @@ end function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _invlink_metadata!(varinfo, varinfo.metadata, _getvns(spl)), + _invlink_metadata!(varinfo, varinfo.metadata, _getvns(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) From d15df29c74dba13655241102a428437f86e54756 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 00:16:15 +0100 Subject: [PATCH 07/20] added missing _getvns_link for UntypedVarInfo --- src/varinfo.jl | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 3e9d0c204..600577796 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -902,6 +902,17 @@ function _inner_transform!(vi::VarInfo, vn::VarName, dist, f) return vi end +# HACK: We need `SampleFromPrior` to result in ALL values which are in need +# of a transformation to be transformed. `_getvns` will by default return +# an empty iterable for `SampleFromPrior`, so we need to override it here. +# This is quite hacky, but seems safer than changing the behavior of `_getvns`. +_getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl) +_getvns_link(varinfo::UntypedVarInfo, spl::SampleFromPrior) = nothing +_getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) = map( + Base.Returns(nothing), + _getvns(varinfo, spl) +) + function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) return _link(varinfo, spl) end @@ -909,7 +920,7 @@ end function _link(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _link_metadata!(varinfo, varinfo.metadata, _getvns(varinfo, spl)), + _link_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) @@ -918,7 +929,7 @@ end function _link(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) md = _link_metadata_namedtuple!( - varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl)) + varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) ) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end @@ -943,7 +954,8 @@ function _link_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn # Return early if we're already in unconstrained space. - if istrans(varinfo, vn) || vn ∉ target_vns + # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. + if istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] end @@ -993,7 +1005,7 @@ end function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _invlink_metadata!(varinfo, varinfo.metadata, _getvns(varinfo, spl)), + _invlink_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) @@ -1002,7 +1014,7 @@ end function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) md = _invlink_metadata_namedtuple!( - varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl)) + varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) ) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end @@ -1028,7 +1040,8 @@ function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) vals_new = map(vns) do vn # Return early if we're already in constrained space OR if we're not # supposed to touch this `vn`. - if !istrans(varinfo, vn) || vn ∉ target_vns + # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. + if !istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] end From 203aeb370cc3ee725214dbacb9da1e3475f9397d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 00:17:12 +0100 Subject: [PATCH 08/20] simplify `_getvns_link` for TypedVarInfo --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 600577796..a00bcf0ee 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -910,7 +910,7 @@ _getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl) _getvns_link(varinfo::UntypedVarInfo, spl::SampleFromPrior) = nothing _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) = map( Base.Returns(nothing), - _getvns(varinfo, spl) + varinfo.metadata ) function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) From b170c0b3da309798b6aa2563665d6ff4a547db33 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 00:21:23 +0100 Subject: [PATCH 09/20] Update src/varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/varinfo.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index a00bcf0ee..0fca2e128 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -908,10 +908,9 @@ end # This is quite hacky, but seems safer than changing the behavior of `_getvns`. _getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl) _getvns_link(varinfo::UntypedVarInfo, spl::SampleFromPrior) = nothing -_getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) = map( - Base.Returns(nothing), - varinfo.metadata -) +function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) + return map(Base.Returns(nothing), varinfo.metadata) +end function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) return _link(varinfo, spl) From 1086731c5d7fbeb90b9b534351223da3fd7ba867 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 10:33:30 +0100 Subject: [PATCH 10/20] added Compat as dep so we can make use of certain features, e.g. Returns --- Project.toml | 14 ++++++++------ src/varinfo.jl | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index c9805dadb..e9c88fa9f 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -21,6 +22,12 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" +[weakdeps] +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" + +[extensions] +DynamicPPLMCMCChainsExt = ["MCMCChains"] + [compat] AbstractMCMC = "2, 3.0, 4" AbstractPPL = "0.6" @@ -28,6 +35,7 @@ BangBang = "0.3" Bijectors = "0.13" ChainRulesCore = "0.9.7, 0.10, 1" ConstructionBase = "1.5.4" +Compat = "4" Distributions = "0.23.8, 0.24, 0.25" DocStringExtensions = "0.8, 0.9" LogDensityProblems = "2" @@ -39,11 +47,5 @@ Setfield = "0.7.1, 0.8, 1" ZygoteRules = "0.2" julia = "1.6" -[extensions] -DynamicPPLMCMCChainsExt = ["MCMCChains"] - [extras] MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" - -[weakdeps] -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" diff --git a/src/varinfo.jl b/src/varinfo.jl index 0fca2e128..dec1a96a4 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -909,7 +909,7 @@ end _getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl) _getvns_link(varinfo::UntypedVarInfo, spl::SampleFromPrior) = nothing function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) - return map(Base.Returns(nothing), varinfo.metadata) + return map(Returns(nothing), varinfo.metadata) end function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) From 69055b24099a9d018b1509787efda6032c16672a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 17:38:54 +0100 Subject: [PATCH 11/20] added `subset` which can extract a subset of the varinfo --- src/DynamicPPL.jl | 1 + src/varinfo.jl | 60 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 042931ebb..4a326a7e8 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -47,6 +47,7 @@ export AbstractVarInfo, SimpleVarInfo, push!!, empty!!, + subset, getlogp, setlogp!!, acclogp!!, diff --git a/src/varinfo.jl b/src/varinfo.jl index dec1a96a4..03c4302c3 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -236,6 +236,66 @@ else _tail(nt::NamedTuple) = Base.tail(nt) end +# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert +# the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which +# might result in a `Vector{Any}`. +""" + subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) + +Subset a `varinfo` to only contain the variables `vns`. +""" +function subset(varinfo::UntypedVarInfo, vns::AbstractVector{<:VarName}) + metadata = subset(varinfo.metadata, vns) + return VarInfo(metadata, varinfo.logp, varinfo.num_produce) +end + +function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName{sym}}) where {sym} + # If all the variables are using the same symbol, then we can just extract that field from the metadata. + metadata = subset(getfield(varinfo.metadata, sym), vns) + return VarInfo(NamedTuple{(sym,)}(tuple(metadata)), varinfo.logp, varinfo.num_produce) +end + +function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName}) + syms = Tuple(unique(map(getsym, vns))) + metadatas = map(syms) do sym + subset(getfield(varinfo.metadata, sym), filter(==(sym) ∘ getsym, vns)) + end + + return VarInfo(NamedTuple{syms}(metadatas), varinfo.logp, varinfo.num_produce) +end + +""" + subset(metadata::Metadata, vns::AbstractVector{<:VarName}) + +Subset a `metadata` to only contain the variables `vns`. +""" +function subset(metadata::DynamicPPL.Metadata, vns::AbstractVector{<:VarName}) + # TODO: Should we error if `vns` contains a variable that is not in `metadata`? + indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns) + indices = Dict(vn => i for (i, vn) in enumerate(vns)) + # HACK: maintaining consistency between `vals` and `ranges` in scenarios where + # `vns = [@varname(x[2])]` and `metadata` contains `x[1]` and `x[2]` is difficult. + # There are two options: + # 1. Keep ranges as they are and simply `copy` the full `vals`. + # 2. Adjust the ranges to be consistent with the `vals`. + # We choose option 1 for now, though this feels quite hacky. + ranges = metadata.ranges[indices_for_vns] + # vals = mapreduce(Base.Fix1(getindex, metadata.vals), vcat, ranges) + vals = copy(metadata.vals) + + flags = Dict(k => v[indices_for_vns] for (k, v) in metadata.flags) + return Metadata( + indices, + vns, + ranges, + vals, + metadata.dists[indices_for_vns], + metadata.gids, + metadata.orders[indices_for_vns], + flags, + ) +end + const VarView = Union{Int,UnitRange,Vector{Int}} """ From 266c2d6876e536fd7749469fe2499833984c0845 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 17:40:18 +0100 Subject: [PATCH 12/20] added testing of `subset` for `VarInfo` --- test/varinfo.jl | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/test/varinfo.jl b/test/varinfo.jl index 7f96c071e..f9037961b 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -463,4 +463,48 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) end end + + @testset "subset" begin + @model function demo_subsetting_varinfo(::Type{TV}=Vector{Float64}) where {TV} + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x = TV(undef, 2) + x[1] ~ Normal(m, sqrt(s)) + x[2] ~ Normal(m, sqrt(s)) + end + model = demo_subsetting_varinfo() + + @testset "$(short_varinfo_name(varinfo))" for varinfo in [ + VarInfo(model), + last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())) + ] + + # All variables. + @test isempty(setdiff(keys(varinfo), [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])])) + + @testset "$(convert(Vector{VarName}, vns))" for vns in [ + [@varname(s)], + [@varname(m)], + [@varname(x[1])], + [@varname(x[2])], + [@varname(s), @varname(m)], + [@varname(s), @varname(x[1])], + [@varname(s), @varname(x[2])], + [@varname(m), @varname(x[1])], + [@varname(m), @varname(x[2])], + [@varname(x[1]), @varname(x[2])], + [@varname(s), @varname(m), @varname(x[1])], + [@varname(s), @varname(m), @varname(x[2])], + [@varname(s), @varname(x[1]), @varname(x[2])], + [@varname(m), @varname(x[1]), @varname(x[2])], + [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], + ] + varinfo_subset = subset(varinfo, vns) + # Should now only contain the variables in `vns`. + @test isempty(setdiff(keys(varinfo_subset), vns)) + # Values should be the same. + @test [varinfo_subset[vn] for vn in vns] == [varinfo[vn] for vn in vns] + end + end + end end From 859edfb407347c011bd6c0ae306220c974f08afa Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 17:43:06 +0100 Subject: [PATCH 13/20] formatting --- test/varinfo.jl | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index f9037961b..66a5ae4fd 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -471,16 +471,21 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) x = TV(undef, 2) x[1] ~ Normal(m, sqrt(s)) x[2] ~ Normal(m, sqrt(s)) + return nothing end model = demo_subsetting_varinfo() @testset "$(short_varinfo_name(varinfo))" for varinfo in [ - VarInfo(model), - last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())) - ] + VarInfo(model), last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())) + ] # All variables. - @test isempty(setdiff(keys(varinfo), [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])])) + @test isempty( + setdiff( + keys(varinfo), + [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], + ), + ) @testset "$(convert(Vector{VarName}, vns))" for vns in [ [@varname(s)], @@ -498,7 +503,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) [@varname(s), @varname(x[1]), @varname(x[2])], [@varname(m), @varname(x[1]), @varname(x[2])], [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], - ] + ] varinfo_subset = subset(varinfo, vns) # Should now only contain the variables in `vns`. @test isempty(setdiff(keys(varinfo_subset), vns)) From e2e1db7359eda06680d7bd3d00d196a78cc66ba3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 20:16:43 +0100 Subject: [PATCH 14/20] added implementation of `merge` for `VarInfo` and tests for it --- src/varinfo.jl | 173 +++++++++++++++++++++++++++++++++++++++++++++++- test/varinfo.jl | 40 +++++++++++ 2 files changed, 211 insertions(+), 2 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 03c4302c3..35ca1ef72 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -269,7 +269,7 @@ end Subset a `metadata` to only contain the variables `vns`. """ -function subset(metadata::DynamicPPL.Metadata, vns::AbstractVector{<:VarName}) +function subset(metadata::Metadata, vns::AbstractVector{<:VarName}) # TODO: Should we error if `vns` contains a variable that is not in `metadata`? indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns) indices = Dict(vn => i for (i, vn) in enumerate(vns)) @@ -279,6 +279,7 @@ function subset(metadata::DynamicPPL.Metadata, vns::AbstractVector{<:VarName}) # 1. Keep ranges as they are and simply `copy` the full `vals`. # 2. Adjust the ranges to be consistent with the `vals`. # We choose option 1 for now, though this feels quite hacky. + # TODO: Only pick the subset of `vals` needed. ranges = metadata.ranges[indices_for_vns] # vals = mapreduce(Base.Fix1(getindex, metadata.vals), vcat, ranges) vals = copy(metadata.vals) @@ -296,6 +297,163 @@ function subset(metadata::DynamicPPL.Metadata, vns::AbstractVector{<:VarName}) ) end +""" + merge(varinfo_left::VarInfo, varinfo_right::VarInfo) + +Merge two `VarInfo` instances into one, giving precedence to `varinfo_right` when reasonable. +""" +Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) =_merge(varinfo_left, varinfo_right) +Base.merge(varinfo_left::TypedVarInfo, varinfo_right::TypedVarInfo) =_merge(varinfo_left, varinfo_right) + +function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) + metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) + lp = getlogp(varinfo_left) + getlogp(varinfo_right) + # TODO: Is this really the way we want to combine `num_produce`? + num_produce = varinfo_left.num_produce[] + varinfo_right.num_produce[] + return VarInfo(metadata, Ref(lp), Ref(num_produce)) +end + +function merge_metadata( + metadata_left::NamedTuple{names_left}, + metadata_right::NamedTuple{names_right} +) where {names_left, names_right} + # TODO: Improve this. Maybe make `@generated`? + metadata = map(names_left) do sym + if sym in names_right + merge_metadata(getfield(metadata_left, sym), getfield(metadata_right, sym)) + else + getfield(metadata_left, sym) + end + end + names_right_only = filter(∉(names_left), names_right) + metadata_right_only = map(Tuple(names_right_only)) do sym + if !(sym in names_left) + getfield(metadata_right, sym) + end + end + + return NamedTuple{(names_left..., names_right_only...)}(tuple(metadata..., metadata_right_only...)) +end + +function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) + # Extract the varnames. + vns_left = metadata_left.vns + vns_right = metadata_right.vns + vns_both = union(vns_left, vns_right) + + # Determine `eltype` of `vals`. + T_left = eltype(metadata_left.vals) + T_right = eltype(metadata_right.vals) + T = promote_type(T_left, T_right) + # TODO: Is this necessary? + if !(T <: Real) + T = Real + end + + # Determine `eltype` of `dists`. + D_left = eltype(metadata_left.dists) + D_right = eltype(metadata_right.dists) + D = promote_type(D_left, D_right) + # TODO: Is this necessary? + if !(D <: Distribution) + D = Distribution + end + + # Initialize required fields for `metadata`. + vns = VarName[] + idcs = Dict{VarName, Int}() + ranges = Vector{UnitRange{Int}}() + vals = T[] + dists = D[] + gids = metadata_right.gids # NOTE: giving precedence to `metadata_right` + orders = Int[] + flags = Dict{String, BitVector}() + # Initialize the `flags`. + for k in union(keys(metadata_left.flags), keys(metadata_right.flags)) + flags[k] = BitVector() + end + + # Range offset. + offset = 0 + + for (idx, vn) in enumerate(vns_both) + # `idcs` + idcs[vn] = idx + # `vns` + push!(vns, vn) + if vn in vns_left && vn in vns_right + # `vals`: only valid if they're the length. + vals_left = getval(metadata_left, vn) + vals_right = getval(metadata_right, vn) + @assert length(vals_left) == length(vals_right) + append!(vals, vals_right) + # `ranges` + r = (offset + 1):(offset + length(vals_left)) + push!(ranges, r) + offset = r[end] + # `dists`: only valid if they're the same. + dists_left = getdist(metadata_left, vn) + dists_right = getdist(metadata_right, vn) + @assert dists_left == dists_right + push!(dists, dists_left) + # `orders`: giving precedence to `metadata_right` + push!(orders, getorder(metadata_right, vn)) + # `flags` + for k in keys(flags) + # Using `metadata_right`; should we? + push!(flags[k], is_flagged(metadata_right, vn, k)) + end + elseif vn in vns_left + # Just extract the metadata from `metadata_left`. + # `vals` + vals_left = getval(metadata_left, vn) + append!(vals, vals_left) + # `ranges` + r = (offset + 1):(offset + length(vals_left)) + push!(ranges, r) + offset = r[end] + # `dists` + dists_left = getdist(metadata_left, vn) + push!(dists, dists_left) + # `orders` + push!(orders, getorder(metadata_left, vn)) + # `flags` + for k in keys(flags) + push!(flags[k], is_flagged(metadata_left, vn, k)) + end + else + # Just extract the metadata from `metadata_right`. + # `vals` + vals_right = getvals(metadata_right, vn) + append!(vals, vals_right) + # `ranges` + r = (offset + 1):(offset + length(vals_right)) + push!(ranges, r) + offset = r[end] + # `dists` + dists_right = getdist(metadata_right, vn) + push!(dists, dists_right) + # `orders` + push!(orders, getorder(metadata_right, vn)) + # `flags` + for k in keys(flags) + push!(flags[k], is_flagged(metadata_right, vn, k)) + end + end + end + + return Metadata( + idcs, + vns, + ranges, + vals, + dists, + gids, + orders, + flags, + ) +end + const VarView = Union{Int,UnitRange,Vector{Int}} """ @@ -1434,6 +1592,16 @@ function setorder!(vi::VarInfo, vn::VarName, index::Int) return vi end +""" + getorder(vi::VarInfo, vn::VarName) + +Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements +run before sampling `vn`. +""" +getorder(vi::VarInfo, vn::VarName) = getorder(getmetadata(vi, vn), vn) +getorder(metadata::Metadata, vn::VarName) = metadata.orders[getidx(metadata, vn)] + + ####################################### # Rand & replaying method for VarInfo # ####################################### @@ -1444,8 +1612,9 @@ end Check whether `vn` has a true value for `flag` in `vi`. """ function is_flagged(vi::VarInfo, vn::VarName, flag::String) - return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] + return is_flagged(getmetadata(vi, vn), vn, flag) end +is_flagged(metadata::Metadata, vn::VarName, flag::String) = metadata.flags[flag][getidx(metadata, vn)] """ unset_flag!(vi::VarInfo, vn::VarName, flag::String) diff --git a/test/varinfo.jl b/test/varinfo.jl index 66a5ae4fd..f05d73986 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -512,4 +512,44 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end end end + + @testset "merge" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(short_varinfo_name(varinfo))" for varinfo in [ + VarInfo(model), + last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())) + ] + + vns = DynamicPPL.TestUtils.varnames(model) + + @testset "with itself" begin + # Merging itself should be a no-op. + varinfo_merged = merge(varinfo, varinfo) + vns_merged = keys(varinfo_merged) + # Should be equivalent. + @test union(vns_merged, vns) == intersect(vns_merged, vns) + # Values should be the same. + @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + end + + @testset "with empty" begin + # Merging with an empty `VarInfo` should be a no-op. + varinfo_merged = merge(varinfo, empty!!(deepcopy(varinfo))) + vns_merged = keys(varinfo_merged) + # Should be equivalent. + @test union(vns_merged, vns) == intersect(vns_merged, vns) + # Values should be the same. + @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + end + + @testset "with different value" begin + x = DynamicPPL.TestUtils.rand(model) + varinfo_changed = DynamicPPL.TestUtils.update_values!!(deepcopy(varinfo), x, vns) + # After `merge`, we should have the same values as `x`. + varinfo_merged = merge(varinfo, varinfo_changed) + DynamicPPL.TestUtils.test_values(varinfo_merged, x, vns) + end + end + end + end end From a90f2b95bedcd64d63cdea46be4e87642c2fd3bb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 20:19:53 +0100 Subject: [PATCH 15/20] more tests --- test/varinfo.jl | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index f05d73986..a77bb6a13 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -474,6 +474,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) return nothing end model = demo_subsetting_varinfo() + vns = [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])] @testset "$(short_varinfo_name(varinfo))" for varinfo in [ VarInfo(model), last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())) @@ -483,11 +484,11 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test isempty( setdiff( keys(varinfo), - [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], + vns, ), ) - @testset "$(convert(Vector{VarName}, vns))" for vns in [ + @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in [ [@varname(s)], [@varname(m)], [@varname(x[1])], @@ -504,11 +505,19 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) [@varname(m), @varname(x[1]), @varname(x[2])], [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], ] - varinfo_subset = subset(varinfo, vns) - # Should now only contain the variables in `vns`. - @test isempty(setdiff(keys(varinfo_subset), vns)) + varinfo_subset = subset(varinfo, vns_subset) + # Should now only contain the variables in `vns_subset`. + @test isempty(setdiff(keys(varinfo_subset), vns_subset)) # Values should be the same. - @test [varinfo_subset[vn] for vn in vns] == [varinfo[vn] for vn in vns] + @test [varinfo_subset[vn] for vn in vns_subset] == [varinfo[vn] for vn in vns_subset] + + # `merge` with the original. + varinfo_merged = merge(varinfo, varinfo_subset) + vns_merged = keys(varinfo_merged) + # Should be equivalent. + @test union(vns_merged, vns) == intersect(vns_merged, vns) + # Values should be the same. + @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] end end end @@ -519,9 +528,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) VarInfo(model), last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())) ] - vns = DynamicPPL.TestUtils.varnames(model) - @testset "with itself" begin # Merging itself should be a no-op. varinfo_merged = merge(varinfo, varinfo) From 1d8d50794e022311c0fa321092cf7766ff15d60a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 20:20:57 +0100 Subject: [PATCH 16/20] formatting --- src/varinfo.jl | 36 ++++++++++++++++-------------------- test/varinfo.jl | 13 +++++-------- 2 files changed, 21 insertions(+), 28 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 35ca1ef72..8b95277df 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -302,8 +302,11 @@ end Merge two `VarInfo` instances into one, giving precedence to `varinfo_right` when reasonable. """ -Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) =_merge(varinfo_left, varinfo_right) -Base.merge(varinfo_left::TypedVarInfo, varinfo_right::TypedVarInfo) =_merge(varinfo_left, varinfo_right) +Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) = + _merge(varinfo_left, varinfo_right) +function Base.merge(varinfo_left::TypedVarInfo, varinfo_right::TypedVarInfo) + return _merge(varinfo_left, varinfo_right) +end function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) @@ -314,9 +317,8 @@ function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) end function merge_metadata( - metadata_left::NamedTuple{names_left}, - metadata_right::NamedTuple{names_right} -) where {names_left, names_right} + metadata_left::NamedTuple{names_left}, metadata_right::NamedTuple{names_right} +) where {names_left,names_right} # TODO: Improve this. Maybe make `@generated`? metadata = map(names_left) do sym if sym in names_right @@ -332,7 +334,9 @@ function merge_metadata( end end - return NamedTuple{(names_left..., names_right_only...)}(tuple(metadata..., metadata_right_only...)) + return NamedTuple{(names_left..., names_right_only...)}( + tuple(metadata..., metadata_right_only...) + ) end function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) @@ -361,13 +365,13 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) # Initialize required fields for `metadata`. vns = VarName[] - idcs = Dict{VarName, Int}() + idcs = Dict{VarName,Int}() ranges = Vector{UnitRange{Int}}() vals = T[] dists = D[] gids = metadata_right.gids # NOTE: giving precedence to `metadata_right` orders = Int[] - flags = Dict{String, BitVector}() + flags = Dict{String,BitVector}() # Initialize the `flags`. for k in union(keys(metadata_left.flags), keys(metadata_right.flags)) flags[k] = BitVector() @@ -442,16 +446,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) end end - return Metadata( - idcs, - vns, - ranges, - vals, - dists, - gids, - orders, - flags, - ) + return Metadata(idcs, vns, ranges, vals, dists, gids, orders, flags) end const VarView = Union{Int,UnitRange,Vector{Int}} @@ -1601,7 +1596,6 @@ run before sampling `vn`. getorder(vi::VarInfo, vn::VarName) = getorder(getmetadata(vi, vn), vn) getorder(metadata::Metadata, vn::VarName) = metadata.orders[getidx(metadata, vn)] - ####################################### # Rand & replaying method for VarInfo # ####################################### @@ -1614,7 +1608,9 @@ Check whether `vn` has a true value for `flag` in `vi`. function is_flagged(vi::VarInfo, vn::VarName, flag::String) return is_flagged(getmetadata(vi, vn), vn, flag) end -is_flagged(metadata::Metadata, vn::VarName, flag::String) = metadata.flags[flag][getidx(metadata, vn)] +function is_flagged(metadata::Metadata, vn::VarName, flag::String) + return metadata.flags[flag][getidx(metadata, vn)] +end """ unset_flag!(vi::VarInfo, vn::VarName, flag::String) diff --git a/test/varinfo.jl b/test/varinfo.jl index a77bb6a13..20e9b9823 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -481,12 +481,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) ] # All variables. - @test isempty( - setdiff( - keys(varinfo), - vns, - ), - ) + @test isempty(setdiff(keys(varinfo), vns)) @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in [ [@varname(s)], @@ -526,7 +521,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS @testset "$(short_varinfo_name(varinfo))" for varinfo in [ VarInfo(model), - last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())) + last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())), ] vns = DynamicPPL.TestUtils.varnames(model) @testset "with itself" begin @@ -551,7 +546,9 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @testset "with different value" begin x = DynamicPPL.TestUtils.rand(model) - varinfo_changed = DynamicPPL.TestUtils.update_values!!(deepcopy(varinfo), x, vns) + varinfo_changed = DynamicPPL.TestUtils.update_values!!( + deepcopy(varinfo), x, vns + ) # After `merge`, we should have the same values as `x`. varinfo_merged = merge(varinfo, varinfo_changed) DynamicPPL.TestUtils.test_values(varinfo_merged, x, vns) From 2882fbcbd899cd95079aaec387d5e71eaedef8f0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 11:27:13 +0100 Subject: [PATCH 17/20] improved merge_metadata for NamedTuple inputs --- src/varinfo.jl | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 8b95277df..df30fb607 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -302,8 +302,9 @@ end Merge two `VarInfo` instances into one, giving precedence to `varinfo_right` when reasonable. """ -Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) = - _merge(varinfo_left, varinfo_right) +function Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) + return _merge(varinfo_left, varinfo_right) +end function Base.merge(varinfo_left::TypedVarInfo, varinfo_right::TypedVarInfo) return _merge(varinfo_left, varinfo_right) end @@ -316,27 +317,31 @@ function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) return VarInfo(metadata, Ref(lp), Ref(num_produce)) end -function merge_metadata( +@generated function merge_metadata( metadata_left::NamedTuple{names_left}, metadata_right::NamedTuple{names_right} ) where {names_left,names_right} - # TODO: Improve this. Maybe make `@generated`? - metadata = map(names_left) do sym + names = Expr(:tuple) + vals = Expr(:tuple) + # Loop over `names_left` first because we want to preserve the order of the variables. + for sym in names_left + push!(names.args, QuoteNode(sym)) if sym in names_right - merge_metadata(getfield(metadata_left, sym), getfield(metadata_right, sym)) + push!( + vals.args, + :(merge_metadata(metadata_left.$sym, metadata_right.$sym)) + ) else - getfield(metadata_left, sym) + push!(vals.args, :(metadata_left.$sym)) end end + # Loop over remaining variables in `names_right`. names_right_only = filter(∉(names_left), names_right) - metadata_right_only = map(Tuple(names_right_only)) do sym - if !(sym in names_left) - getfield(metadata_right, sym) - end + for sym in names_right_only + push!(names.args, QuoteNode(sym)) + push!(vals.args, :(metadata_right.$sym)) end - return NamedTuple{(names_left..., names_right_only...)}( - tuple(metadata..., metadata_right_only...) - ) + return :(NamedTuple{$names}($vals)) end function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) From e281e1f82859b2b42298b3aa293ed7db31a97497 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 12:35:28 +0100 Subject: [PATCH 18/20] added proper handling of the `vals` in `subset` --- src/varinfo.jl | 51 +++++++++++++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index df30fb607..0d30caa0e 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -264,25 +264,41 @@ function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName}) return VarInfo(NamedTuple{syms}(metadatas), varinfo.logp, varinfo.num_produce) end -""" - subset(metadata::Metadata, vns::AbstractVector{<:VarName}) - -Subset a `metadata` to only contain the variables `vns`. -""" function subset(metadata::Metadata, vns::AbstractVector{<:VarName}) # TODO: Should we error if `vns` contains a variable that is not in `metadata`? indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns) indices = Dict(vn => i for (i, vn) in enumerate(vns)) - # HACK: maintaining consistency between `vals` and `ranges` in scenarios where - # `vns = [@varname(x[2])]` and `metadata` contains `x[1]` and `x[2]` is difficult. - # There are two options: - # 1. Keep ranges as they are and simply `copy` the full `vals`. - # 2. Adjust the ranges to be consistent with the `vals`. - # We choose option 1 for now, though this feels quite hacky. - # TODO: Only pick the subset of `vals` needed. - ranges = metadata.ranges[indices_for_vns] - # vals = mapreduce(Base.Fix1(getindex, metadata.vals), vcat, ranges) - vals = copy(metadata.vals) + # Construct new `vals` and `ranges`. + vals_original = metadata.vals + ranges_original = metadata.ranges + # Allocate the new `vals`. and `ranges`. + vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns])) + ranges = similar(ranges_original) + # The new range `r` for `vns[i]` is offset by `offset` and + # has the same length as the original range `r_original`. + # The new `indices` (from above) ensures ordering according to `vns`. + # NOTE: This means that the order of the variables in `vns` defines the order + # in the resulting `varinfo`! This can have performance implications, e.g. + # if in the model we have something like + # + # for i = 1:N + # x[i] ~ Normal() + # end + # + # and we then we do + # + # subset(varinfo, [@varname(x[i]) for i in shuffle(keys(varinfo))]) + # + # the resulting `varinfo` will have `vals` ordered differently from the + # original `varinfo`, which can have performance implications. + offset = 0 + for (idx, idx_original) in enumerate(indices_for_vns) + r_original = ranges_original[idx_original] + r = (offset + 1):(offset + length(r_original)) + vals[r] = vals_original[r_original] + ranges[idx] = r + offset = r[end] + end flags = Dict(k => v[indices_for_vns] for (k, v) in metadata.flags) return Metadata( @@ -302,10 +318,7 @@ end Merge two `VarInfo` instances into one, giving precedence to `varinfo_right` when reasonable. """ -function Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) - return _merge(varinfo_left, varinfo_right) -end -function Base.merge(varinfo_left::TypedVarInfo, varinfo_right::TypedVarInfo) +function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) return _merge(varinfo_left, varinfo_right) end From f54abd290c349b5b16e9343b4a368de8e8780dbd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 12:35:50 +0100 Subject: [PATCH 19/20] added docs for `subset` and `merge` --- src/varinfo.jl | 110 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index 0d30caa0e..eb7dd081c 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -243,6 +243,116 @@ end subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) Subset a `varinfo` to only contain the variables `vns`. + +!!! warning + The ordering of the variables in the resulting `varinfo` will _not_ + necessarily follow the ordering of the variables in `varinfo`. + Hence care must be taken, in particular when used in conjunction with + other methods which uses the vector-representation of `varinfo`, e.g. + `getindex(varinfo, sampler)` + +# Examples +```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL) +julia> @model function demo() + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x = Vector{Float64}(undef, 2) + x[1] ~ Normal(m, sqrt(s)) + x[2] ~ Normal(m, sqrt(s)) + end +demo (generic function with 2 methods) + +julia> model = demo(); + +julia> varinfo = VarInfo(model); + +julia> keys(varinfo) +4-element Vector{VarName}: + s + m + x[1] + x[2] + +julia> for (i, vn) in enumerate(keys(varinfo)) + varinfo[vn] = i + end + +julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] +4-element Vector{Float64}: + 1.0 + 2.0 + 3.0 + 4.0 + +julia> # Extract one with only `m`. + varinfo_subset1 = subset(varinfo, [@varname(m),]); + + +julia> keys(varinfo_subset1) +1-element Vector{VarName{:m, Setfield.IdentityLens}}: + m + +julia> varinfo_subset1[@varname(m)] +2.0 + +julia> # Extract one with both `s` and `x[2]`. + varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]); + +julia> keys(varinfo_subset2) +2-element Vector{VarName}: + s + x[2] + +julia> varinfo_subset2[[@varname(s), @varname(x[2])]] +2-element Vector{Float64}: + 1.0 + 4.0 +``` + +`subset` is particularly useful when combined with [`merge(varinfo_left::VarInfo, varinfo_right::VarInfo)`](@ref) + +```jldoctest varinfo-subset +julia> # Merge the two. + varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2); + +julia> keys(varinfo_subset_merged) +3-element Vector{VarName}: + m + s + x[2] + +julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]] +3-element Vector{Float64}: + 1.0 + 2.0 + 4.0 + +julia> # Merge the two with the original. + varinfo_merged = merge(varinfo, varinfo_subset_merged); + +julia> keys(varinfo_merged) +4-element Vector{VarName}: + s + m + x[1] + x[2] + +julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] +4-element Vector{Float64}: + 1.0 + 2.0 + 3.0 + 4.0 +``` + +# Notes + +## Type-stability + +!!! warning + This function is only type-stable when `vns` contains only varnames + with the same symbol. For exmaple, `[@varname(m[1]), @varname(m[2])]` will + be type-stable, but `[@varname(m[1]), @varname(x)]` will not be. """ function subset(varinfo::UntypedVarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) From cb61b1c2257f7fbf791cf8a6c885a8872ce52224 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 12:36:05 +0100 Subject: [PATCH 20/20] added `subset` and `merge` to documentation --- docs/src/api.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index ddd119816..47a92c07b 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -255,6 +255,8 @@ DynamicPPL.reconstruct #### Utils ```@docs +Base.merge(::VarInfo, ::VarInfo) +DynamicPPL.subset DynamicPPL.unflatten DynamicPPL.tonamedtuple DynamicPPL.varname_leaves