Skip to content

Commit 27ba772

Browse files
authored
For VarInfo, fix merge and allow push!!ing new Symbols (#690)
* Fix treatment of gid in merge(::Metadata) * Allowing pushing new symbols to TypedVarInfo * Bump patch version to 0.30.1
1 parent 1d10278 commit 27ba772

File tree

3 files changed

+55
-8
lines changed

3 files changed

+55
-8
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.30"
3+
version = "0.30.1"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/varinfo.jl

+29-7
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
490490
ranges = Vector{UnitRange{Int}}()
491491
vals = T[]
492492
dists = D[]
493-
gids = metadata_right.gids # NOTE: giving precedence to `metadata_right`
493+
gids = Set{Selector}[]
494494
orders = Int[]
495495
flags = Dict{String,BitVector}()
496496
# Initialize the `flags`.
@@ -520,6 +520,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
520520
dist_right = getdist(metadata_right, vn)
521521
# Give precedence to `metadata_right`.
522522
push!(dists, dist_right)
523+
gid = metadata_right.gids[getidx(metadata_right, vn)]
524+
push!(gids, gid)
523525
# `orders`: giving precedence to `metadata_right`
524526
push!(orders, getorder(metadata_right, vn))
525527
# `flags`
@@ -539,6 +541,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
539541
# `dists`
540542
dist_left = getdist(metadata_left, vn)
541543
push!(dists, dist_left)
544+
gid = metadata_left.gids[getidx(metadata_left, vn)]
545+
push!(gids, gid)
542546
# `orders`
543547
push!(orders, getorder(metadata_left, vn))
544548
# `flags`
@@ -557,6 +561,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
557561
# `dists`
558562
dist_right = getdist(metadata_right, vn)
559563
push!(dists, dist_right)
564+
gid = metadata_right.gids[getidx(metadata_right, vn)]
565+
push!(gids, gid)
560566
# `orders`
561567
push!(orders, getorder(metadata_right, vn))
562568
# `flags`
@@ -1826,14 +1832,31 @@ function BangBang.push!!(
18261832
vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
18271833
)
18281834
if vi isa UntypedVarInfo
1829-
@assert ~(vn in keys(vi)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
1835+
@assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
18301836
elseif vi isa TypedVarInfo
1831-
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
1837+
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
1838+
end
1839+
1840+
sym = getsym(vn)
1841+
if vi isa TypedVarInfo && ~haskey(vi.metadata, sym)
1842+
# The NamedTuple doesn't have an entry for this variable, let's add one.
1843+
val = tovec(r)
1844+
md = Metadata(
1845+
Dict(vn => 1),
1846+
[vn],
1847+
[1:length(val)],
1848+
val,
1849+
[dist],
1850+
[gidset],
1851+
[get_num_produce(vi)],
1852+
Dict{String,BitVector}("trans" => [false], "del" => [false]),
1853+
)
1854+
vi = Accessors.@set vi.metadata[sym] = md
1855+
else
1856+
meta = getmetadata(vi, vn)
1857+
push!(meta, vn, r, dist, gidset, get_num_produce(vi))
18321858
end
18331859

1834-
meta = getmetadata(vi, vn)
1835-
push!(meta, vn, r, dist, gidset, get_num_produce(vi))
1836-
18371860
return vi
18381861
end
18391862

@@ -1864,7 +1887,6 @@ function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce)
18641887
push!(meta.orders, num_produce)
18651888
push!(meta.flags["del"], false)
18661889
push!(meta.flags["trans"], false)
1867-
18681890
return meta
18691891
end
18701892

test/varinfo.jl

+25
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
154154
test_varinfo!(vi)
155155
test_varinfo!(empty!!(TypedVarInfo(vi)))
156156
end
157+
158+
@testset "push!! to TypedVarInfo" begin
159+
vn_x = @varname x
160+
vn_y = @varname y
161+
untyped_vi = VarInfo()
162+
untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector())
163+
typed_vi = TypedVarInfo(untyped_vi)
164+
typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector())
165+
@test typed_vi[vn_x] == 1.0
166+
@test typed_vi[vn_y] == 2.0
167+
end
168+
157169
@testset "setgid!" begin
158170
vi = VarInfo(DynamicPPL.Metadata())
159171
meta = vi.metadata
@@ -694,6 +706,19 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
694706
@test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)]
695707
@test DynamicPPL.istrans(varinfo_merged, @varname(x))
696708
end
709+
710+
# The below used to error, testing to avoid regression.
711+
@testset "merge gids" begin
712+
gidset_left = Set([Selector(1)])
713+
vi_left = VarInfo()
714+
vi_left = push!!(vi_left, @varname(x), 1.0, Normal(), gidset_left)
715+
gidset_right = Set([Selector(2)])
716+
vi_right = VarInfo()
717+
vi_right = push!!(vi_right, @varname(y), 2.0, Normal(), gidset_right)
718+
varinfo_merged = merge(vi_left, vi_right)
719+
@test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left
720+
@test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right
721+
end
697722
end
698723

699724
@testset "VarInfo with selectors" begin

0 commit comments

Comments
 (0)