diff --git a/Project.toml b/Project.toml index 08ca184bf..e9c88fa9f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.18" +version = "0.23.19" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -8,6 +8,7 @@ AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -21,6 +22,12 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" +[weakdeps] +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" + +[extensions] +DynamicPPLMCMCChainsExt = ["MCMCChains"] + [compat] AbstractMCMC = "2, 3.0, 4" AbstractPPL = "0.6" @@ -28,6 +35,7 @@ BangBang = "0.3" Bijectors = "0.13" ChainRulesCore = "0.9.7, 0.10, 1" ConstructionBase = "1.5.4" +Compat = "4" Distributions = "0.23.8, 0.24, 0.25" DocStringExtensions = "0.8, 0.9" LogDensityProblems = "2" @@ -39,11 +47,5 @@ Setfield = "0.7.1, 0.8, 1" ZygoteRules = "0.2" julia = "1.6" -[extensions] -DynamicPPLMCMCChainsExt = ["MCMCChains"] - [extras] MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" - -[weakdeps] -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" diff --git a/docs/src/api.md b/docs/src/api.md index ddd119816..47a92c07b 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -255,6 +255,8 @@ DynamicPPL.reconstruct #### Utils ```@docs +Base.merge(::VarInfo, ::VarInfo) +DynamicPPL.subset DynamicPPL.unflatten DynamicPPL.tonamedtuple DynamicPPL.varname_leaves diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 042931ebb..4a326a7e8 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -47,6 +47,7 @@ export AbstractVarInfo, SimpleVarInfo, push!!, empty!!, + subset, getlogp, setlogp!!, acclogp!!, diff --git a/src/varinfo.jl b/src/varinfo.jl index ddb4caffb..eb7dd081c 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -236,6 +236,347 @@ else _tail(nt::NamedTuple) = Base.tail(nt) end +# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert +# the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which +# might result in a `Vector{Any}`. +""" + subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) + +Subset a `varinfo` to only contain the variables `vns`. + +!!! warning + The ordering of the variables in the resulting `varinfo` will _not_ + necessarily follow the ordering of the variables in `varinfo`. + Hence care must be taken, in particular when used in conjunction with + other methods which uses the vector-representation of `varinfo`, e.g. + `getindex(varinfo, sampler)` + +# Examples +```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL) +julia> @model function demo() + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x = Vector{Float64}(undef, 2) + x[1] ~ Normal(m, sqrt(s)) + x[2] ~ Normal(m, sqrt(s)) + end +demo (generic function with 2 methods) + +julia> model = demo(); + +julia> varinfo = VarInfo(model); + +julia> keys(varinfo) +4-element Vector{VarName}: + s + m + x[1] + x[2] + +julia> for (i, vn) in enumerate(keys(varinfo)) + varinfo[vn] = i + end + +julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] +4-element Vector{Float64}: + 1.0 + 2.0 + 3.0 + 4.0 + +julia> # Extract one with only `m`. + varinfo_subset1 = subset(varinfo, [@varname(m),]); + + +julia> keys(varinfo_subset1) +1-element Vector{VarName{:m, Setfield.IdentityLens}}: + m + +julia> varinfo_subset1[@varname(m)] +2.0 + +julia> # Extract one with both `s` and `x[2]`. + varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]); + +julia> keys(varinfo_subset2) +2-element Vector{VarName}: + s + x[2] + +julia> varinfo_subset2[[@varname(s), @varname(x[2])]] +2-element Vector{Float64}: + 1.0 + 4.0 +``` + +`subset` is particularly useful when combined with [`merge(varinfo_left::VarInfo, varinfo_right::VarInfo)`](@ref) + +```jldoctest varinfo-subset +julia> # Merge the two. + varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2); + +julia> keys(varinfo_subset_merged) +3-element Vector{VarName}: + m + s + x[2] + +julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]] +3-element Vector{Float64}: + 1.0 + 2.0 + 4.0 + +julia> # Merge the two with the original. + varinfo_merged = merge(varinfo, varinfo_subset_merged); + +julia> keys(varinfo_merged) +4-element Vector{VarName}: + s + m + x[1] + x[2] + +julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] +4-element Vector{Float64}: + 1.0 + 2.0 + 3.0 + 4.0 +``` + +# Notes + +## Type-stability + +!!! warning + This function is only type-stable when `vns` contains only varnames + with the same symbol. For exmaple, `[@varname(m[1]), @varname(m[2])]` will + be type-stable, but `[@varname(m[1]), @varname(x)]` will not be. +""" +function subset(varinfo::UntypedVarInfo, vns::AbstractVector{<:VarName}) + metadata = subset(varinfo.metadata, vns) + return VarInfo(metadata, varinfo.logp, varinfo.num_produce) +end + +function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName{sym}}) where {sym} + # If all the variables are using the same symbol, then we can just extract that field from the metadata. + metadata = subset(getfield(varinfo.metadata, sym), vns) + return VarInfo(NamedTuple{(sym,)}(tuple(metadata)), varinfo.logp, varinfo.num_produce) +end + +function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName}) + syms = Tuple(unique(map(getsym, vns))) + metadatas = map(syms) do sym + subset(getfield(varinfo.metadata, sym), filter(==(sym) ∘ getsym, vns)) + end + + return VarInfo(NamedTuple{syms}(metadatas), varinfo.logp, varinfo.num_produce) +end + +function subset(metadata::Metadata, vns::AbstractVector{<:VarName}) + # TODO: Should we error if `vns` contains a variable that is not in `metadata`? + indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns) + indices = Dict(vn => i for (i, vn) in enumerate(vns)) + # Construct new `vals` and `ranges`. + vals_original = metadata.vals + ranges_original = metadata.ranges + # Allocate the new `vals`. and `ranges`. + vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns])) + ranges = similar(ranges_original) + # The new range `r` for `vns[i]` is offset by `offset` and + # has the same length as the original range `r_original`. + # The new `indices` (from above) ensures ordering according to `vns`. + # NOTE: This means that the order of the variables in `vns` defines the order + # in the resulting `varinfo`! This can have performance implications, e.g. + # if in the model we have something like + # + # for i = 1:N + # x[i] ~ Normal() + # end + # + # and we then we do + # + # subset(varinfo, [@varname(x[i]) for i in shuffle(keys(varinfo))]) + # + # the resulting `varinfo` will have `vals` ordered differently from the + # original `varinfo`, which can have performance implications. + offset = 0 + for (idx, idx_original) in enumerate(indices_for_vns) + r_original = ranges_original[idx_original] + r = (offset + 1):(offset + length(r_original)) + vals[r] = vals_original[r_original] + ranges[idx] = r + offset = r[end] + end + + flags = Dict(k => v[indices_for_vns] for (k, v) in metadata.flags) + return Metadata( + indices, + vns, + ranges, + vals, + metadata.dists[indices_for_vns], + metadata.gids, + metadata.orders[indices_for_vns], + flags, + ) +end + +""" + merge(varinfo_left::VarInfo, varinfo_right::VarInfo) + +Merge two `VarInfo` instances into one, giving precedence to `varinfo_right` when reasonable. +""" +function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) + return _merge(varinfo_left, varinfo_right) +end + +function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) + metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) + lp = getlogp(varinfo_left) + getlogp(varinfo_right) + # TODO: Is this really the way we want to combine `num_produce`? + num_produce = varinfo_left.num_produce[] + varinfo_right.num_produce[] + return VarInfo(metadata, Ref(lp), Ref(num_produce)) +end + +@generated function merge_metadata( + metadata_left::NamedTuple{names_left}, metadata_right::NamedTuple{names_right} +) where {names_left,names_right} + names = Expr(:tuple) + vals = Expr(:tuple) + # Loop over `names_left` first because we want to preserve the order of the variables. + for sym in names_left + push!(names.args, QuoteNode(sym)) + if sym in names_right + push!( + vals.args, + :(merge_metadata(metadata_left.$sym, metadata_right.$sym)) + ) + else + push!(vals.args, :(metadata_left.$sym)) + end + end + # Loop over remaining variables in `names_right`. + names_right_only = filter(∉(names_left), names_right) + for sym in names_right_only + push!(names.args, QuoteNode(sym)) + push!(vals.args, :(metadata_right.$sym)) + end + + return :(NamedTuple{$names}($vals)) +end + +function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) + # Extract the varnames. + vns_left = metadata_left.vns + vns_right = metadata_right.vns + vns_both = union(vns_left, vns_right) + + # Determine `eltype` of `vals`. + T_left = eltype(metadata_left.vals) + T_right = eltype(metadata_right.vals) + T = promote_type(T_left, T_right) + # TODO: Is this necessary? + if !(T <: Real) + T = Real + end + + # Determine `eltype` of `dists`. + D_left = eltype(metadata_left.dists) + D_right = eltype(metadata_right.dists) + D = promote_type(D_left, D_right) + # TODO: Is this necessary? + if !(D <: Distribution) + D = Distribution + end + + # Initialize required fields for `metadata`. + vns = VarName[] + idcs = Dict{VarName,Int}() + ranges = Vector{UnitRange{Int}}() + vals = T[] + dists = D[] + gids = metadata_right.gids # NOTE: giving precedence to `metadata_right` + orders = Int[] + flags = Dict{String,BitVector}() + # Initialize the `flags`. + for k in union(keys(metadata_left.flags), keys(metadata_right.flags)) + flags[k] = BitVector() + end + + # Range offset. + offset = 0 + + for (idx, vn) in enumerate(vns_both) + # `idcs` + idcs[vn] = idx + # `vns` + push!(vns, vn) + if vn in vns_left && vn in vns_right + # `vals`: only valid if they're the length. + vals_left = getval(metadata_left, vn) + vals_right = getval(metadata_right, vn) + @assert length(vals_left) == length(vals_right) + append!(vals, vals_right) + # `ranges` + r = (offset + 1):(offset + length(vals_left)) + push!(ranges, r) + offset = r[end] + # `dists`: only valid if they're the same. + dists_left = getdist(metadata_left, vn) + dists_right = getdist(metadata_right, vn) + @assert dists_left == dists_right + push!(dists, dists_left) + # `orders`: giving precedence to `metadata_right` + push!(orders, getorder(metadata_right, vn)) + # `flags` + for k in keys(flags) + # Using `metadata_right`; should we? + push!(flags[k], is_flagged(metadata_right, vn, k)) + end + elseif vn in vns_left + # Just extract the metadata from `metadata_left`. + # `vals` + vals_left = getval(metadata_left, vn) + append!(vals, vals_left) + # `ranges` + r = (offset + 1):(offset + length(vals_left)) + push!(ranges, r) + offset = r[end] + # `dists` + dists_left = getdist(metadata_left, vn) + push!(dists, dists_left) + # `orders` + push!(orders, getorder(metadata_left, vn)) + # `flags` + for k in keys(flags) + push!(flags[k], is_flagged(metadata_left, vn, k)) + end + else + # Just extract the metadata from `metadata_right`. + # `vals` + vals_right = getvals(metadata_right, vn) + append!(vals, vals_right) + # `ranges` + r = (offset + 1):(offset + length(vals_right)) + push!(ranges, r) + offset = r[end] + # `dists` + dists_right = getdist(metadata_right, vn) + push!(dists, dists_right) + # `orders` + push!(orders, getorder(metadata_right, vn)) + # `flags` + for k in keys(flags) + push!(flags[k], is_flagged(metadata_right, vn, k)) + end + end + end + + return Metadata(idcs, vns, ranges, vals, dists, gids, orders, flags) +end + const VarView = Union{Int,UnitRange,Vector{Int}} """ @@ -902,33 +1243,59 @@ function _inner_transform!(vi::VarInfo, vn::VarName, dist, f) return vi end +# HACK: We need `SampleFromPrior` to result in ALL values which are in need +# of a transformation to be transformed. `_getvns` will by default return +# an empty iterable for `SampleFromPrior`, so we need to override it here. +# This is quite hacky, but seems safer than changing the behavior of `_getvns`. +_getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl) +_getvns_link(varinfo::UntypedVarInfo, spl::SampleFromPrior) = nothing +function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) + return map(Returns(nothing), varinfo.metadata) +end + function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) - return _link(varinfo) + return _link(varinfo, spl) end -function _link(varinfo::UntypedVarInfo) +function _link(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _link_metadata!(varinfo, varinfo.metadata), + _link_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) end -function _link(varinfo::TypedVarInfo) +function _link(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) - md = map(Base.Fix1(_link_metadata!, varinfo), varinfo.metadata) - # TODO: Update logp, etc. + md = _link_metadata_namedtuple!( + varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) + ) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -function _link_metadata!(varinfo::VarInfo, metadata::Metadata) +@generated function _link_metadata_namedtuple!( + varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space} +) where {names,space} + vals = Expr(:tuple) + for f in names + if inspace(f, space) || length(space) == 0 + push!(vals.args, :(_link_metadata!(varinfo, metadata.$f, vns.$f))) + else + push!(vals.args, :(metadata.$f)) + end + end + + return :(NamedTuple{$names}($vals)) +end +function _link_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn # Return early if we're already in unconstrained space. - if istrans(varinfo, vn) + # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. + if istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] end @@ -972,32 +1339,49 @@ end function invlink( ::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model ) - return _invlink(varinfo) + return _invlink(varinfo, spl) end -function _invlink(varinfo::UntypedVarInfo) +function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _invlink_metadata!(varinfo, varinfo.metadata), + _invlink_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) end -function _invlink(varinfo::TypedVarInfo) +function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) - md = map(Base.Fix1(_invlink_metadata!, varinfo), varinfo.metadata) - # TODO: Update logp, etc. + md = _invlink_metadata_namedtuple!( + varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) + ) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata) +@generated function _invlink_metadata_namedtuple!( + varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space} +) where {names,space} + vals = Expr(:tuple) + for f in names + if inspace(f, space) || length(space) == 0 + push!(vals.args, :(_invlink_metadata!(varinfo, metadata.$f, vns.$f))) + else + push!(vals.args, :(metadata.$f)) + end + end + + return :(NamedTuple{$names}($vals)) +end +function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn - # Return early if we're already in constrained space. - if !istrans(varinfo, vn) + # Return early if we're already in constrained space OR if we're not + # supposed to touch this `vn`. + # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. + if !istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] end @@ -1331,6 +1715,15 @@ function setorder!(vi::VarInfo, vn::VarName, index::Int) return vi end +""" + getorder(vi::VarInfo, vn::VarName) + +Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements +run before sampling `vn`. +""" +getorder(vi::VarInfo, vn::VarName) = getorder(getmetadata(vi, vn), vn) +getorder(metadata::Metadata, vn::VarName) = metadata.orders[getidx(metadata, vn)] + ####################################### # Rand & replaying method for VarInfo # ####################################### @@ -1341,7 +1734,10 @@ end Check whether `vn` has a true value for `flag` in `vi`. """ function is_flagged(vi::VarInfo, vn::VarName, flag::String) - return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] + return is_flagged(getmetadata(vi, vn), vn, flag) +end +function is_flagged(metadata::Metadata, vn::VarName, flag::String) + return metadata.flags[flag][getidx(metadata, vn)] end """ diff --git a/test/varinfo.jl b/test/varinfo.jl index 598ea7814..20e9b9823 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,3 +1,7 @@ +# A simple "algorithm" which only has `s` variables in its space. +struct MySAlg end +DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) + @testset "varinfo.jl" begin @testset "TypedVarInfo" begin @model gdemo(x, y) = begin @@ -421,4 +425,135 @@ end end end + + @testset "VarInfo with selectors" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + varinfo = VarInfo(model) + selector = DynamicPPL.Selector() + spl = Sampler(MySAlg(), model, selector) + + vns = DynamicPPL.TestUtils.varnames(model) + vns_s = filter(vn -> DynamicPPL.getsym(vn) === :s, vns) + vns_m = filter(vn -> DynamicPPL.getsym(vn) === :m, vns) + for vn in vns_s + DynamicPPL.updategid!(varinfo, vn, spl) + end + + # Should only get the variables subsumed by `@varname(s)`. + @test varinfo[spl] == + mapreduce(Base.Fix1(DynamicPPL.getval, varinfo), vcat, vns_s) + + # `link` + varinfo_linked = DynamicPPL.link(varinfo, spl, model) + # `s` variables should be linked + @test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s) + # `m` variables should NOT be linked + @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) + # And `varinfo` should be unchanged + @test all(!Base.Fix1(DynamicPPL.istrans, varinfo), vns) + + # `invlink` + varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, spl, model) + # `s` variables should no longer be linked + @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_s) + # `m` variables should still not be linked + @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_m) + # And `varinfo_linked` should be unchanged + @test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s) + @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) + end + end + + @testset "subset" begin + @model function demo_subsetting_varinfo(::Type{TV}=Vector{Float64}) where {TV} + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x = TV(undef, 2) + x[1] ~ Normal(m, sqrt(s)) + x[2] ~ Normal(m, sqrt(s)) + return nothing + end + model = demo_subsetting_varinfo() + vns = [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])] + + @testset "$(short_varinfo_name(varinfo))" for varinfo in [ + VarInfo(model), last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())) + ] + + # All variables. + @test isempty(setdiff(keys(varinfo), vns)) + + @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in [ + [@varname(s)], + [@varname(m)], + [@varname(x[1])], + [@varname(x[2])], + [@varname(s), @varname(m)], + [@varname(s), @varname(x[1])], + [@varname(s), @varname(x[2])], + [@varname(m), @varname(x[1])], + [@varname(m), @varname(x[2])], + [@varname(x[1]), @varname(x[2])], + [@varname(s), @varname(m), @varname(x[1])], + [@varname(s), @varname(m), @varname(x[2])], + [@varname(s), @varname(x[1]), @varname(x[2])], + [@varname(m), @varname(x[1]), @varname(x[2])], + [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], + ] + varinfo_subset = subset(varinfo, vns_subset) + # Should now only contain the variables in `vns_subset`. + @test isempty(setdiff(keys(varinfo_subset), vns_subset)) + # Values should be the same. + @test [varinfo_subset[vn] for vn in vns_subset] == [varinfo[vn] for vn in vns_subset] + + # `merge` with the original. + varinfo_merged = merge(varinfo, varinfo_subset) + vns_merged = keys(varinfo_merged) + # Should be equivalent. + @test union(vns_merged, vns) == intersect(vns_merged, vns) + # Values should be the same. + @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + end + end + end + + @testset "merge" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(short_varinfo_name(varinfo))" for varinfo in [ + VarInfo(model), + last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())), + ] + vns = DynamicPPL.TestUtils.varnames(model) + @testset "with itself" begin + # Merging itself should be a no-op. + varinfo_merged = merge(varinfo, varinfo) + vns_merged = keys(varinfo_merged) + # Should be equivalent. + @test union(vns_merged, vns) == intersect(vns_merged, vns) + # Values should be the same. + @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + end + + @testset "with empty" begin + # Merging with an empty `VarInfo` should be a no-op. + varinfo_merged = merge(varinfo, empty!!(deepcopy(varinfo))) + vns_merged = keys(varinfo_merged) + # Should be equivalent. + @test union(vns_merged, vns) == intersect(vns_merged, vns) + # Values should be the same. + @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + end + + @testset "with different value" begin + x = DynamicPPL.TestUtils.rand(model) + varinfo_changed = DynamicPPL.TestUtils.update_values!!( + deepcopy(varinfo), x, vns + ) + # After `merge`, we should have the same values as `x`. + varinfo_merged = merge(varinfo, varinfo_changed) + DynamicPPL.TestUtils.test_values(varinfo_merged, x, vns) + end + end + end + end end