From d804ef159c9efb030ab2aab7dd9aa7f33a38bc27 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 14 Oct 2024 17:13:50 +0100 Subject: [PATCH] Allowing pushing new symbols to TypedVarInfo --- src/varinfo.jl | 28 ++++++++++++++++++++++------ test/varinfo.jl | 12 ++++++++++++ 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 4b229d828..13674555f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1832,14 +1832,31 @@ function BangBang.push!!( vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) if vi isa UntypedVarInfo - @assert ~(vn in keys(vi)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" + @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" elseif vi isa TypedVarInfo - @assert ~(haskey(vi, vn)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" + @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" + end + + sym = getsym(vn) + if vi isa TypedVarInfo && ~haskey(vi.metadata, sym) + # The NamedTuple doesn't have an entry for this variable, let's add one. + val = tovec(r) + md = Metadata( + Dict(vn => 1), + [vn], + [1:length(val)], + val, + [dist], + [gidset], + [get_num_produce(vi)], + Dict{String,BitVector}("trans" => [false], "del" => [false]), + ) + vi = Accessors.@set vi.metadata[sym] = md + else + meta = getmetadata(vi, vn) + push!(meta, vn, r, dist, gidset, get_num_produce(vi)) end - meta = getmetadata(vi, vn) - push!(meta, vn, r, dist, gidset, get_num_produce(vi)) - return vi end @@ -1870,7 +1887,6 @@ function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce) push!(meta.orders, num_produce) push!(meta.flags["del"], false) push!(meta.flags["trans"], false) - return meta end diff --git a/test/varinfo.jl b/test/varinfo.jl index 88439425a..308f6d5b7 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -154,6 +154,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) test_varinfo!(vi) test_varinfo!(empty!!(TypedVarInfo(vi))) end + + @testset "push!! to TypedVarInfo" begin + vn_x = @varname x + vn_y = @varname y + untyped_vi = VarInfo() + untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector()) + typed_vi = TypedVarInfo(untyped_vi) + typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector()) + @test typed_vi[vn_x] == 1.0 + @test typed_vi[vn_y] == 2.0 + end + @testset "setgid!" begin vi = VarInfo(DynamicPPL.Metadata()) meta = vi.metadata