Skip to content

Commit f19f361

Browse files
Backports for 0.28 (#694)
* Allow empty subsets of VarInfos (#692) * Allow empty subsets of VarInfos * Run JuliaFormatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * For VarInfo, fix merge and allow push!!ing new Symbols (#690) * Fix treatment of gid in merge(::Metadata) * Allowing pushing new symbols to TypedVarInfo * Bump patch version to 0.30.1 --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 0683088 commit f19f361

File tree

4 files changed

+86
-34
lines changed

4 files changed

+86
-34
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.28.5"
3+
version = "0.28.6"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/simple_varinfo.jl

+6-11
Original file line numberDiff line numberDiff line change
@@ -429,22 +429,17 @@ function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
429429
return Accessors.@set varinfo.values = _subset(varinfo.values, vns)
430430
end
431431

432-
function _subset(x::AbstractDict, vns)
432+
function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName}
433433
vns_present = collect(keys(x))
434-
vns_found = mapreduce(vcat, vns) do vn
434+
vns_found = mapreduce(vcat, vns; init=VN[]) do vn
435435
return filter(Base.Fix1(subsumes, vn), vns_present)
436436
end
437-
438-
# NOTE: This `vns` to be subsume varnames explicitly present in `x`.
437+
C = ConstructionBase.constructorof(typeof(x))
439438
if isempty(vns_found)
440-
throw(
441-
ArgumentError(
442-
"Cannot subset `AbstractDict` with `VarName` which does not subsume any keys.",
443-
),
444-
)
439+
return C()
440+
else
441+
return C(vn => x[vn] for vn in vns_found)
445442
end
446-
C = ConstructionBase.constructorof(typeof(x))
447-
return C(vn => x[vn] for vn in vns_found)
448443
end
449444

450445
function _subset(x::NamedTuple, vns)

src/varinfo.jl

+47-22
Original file line numberDiff line numberDiff line change
@@ -264,20 +264,24 @@ function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName})
264264
return VarInfo(NamedTuple{syms}(metadatas), varinfo.logp, varinfo.num_produce)
265265
end
266266

267-
function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName})
267+
function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:VarName}
268268
# TODO: Should we error if `vns` contains a variable that is not in `metadata`?
269269
# For each `vn` in `vns`, get the variables subsumed by `vn`.
270-
vns = mapreduce(vcat, vns_given) do vn
270+
vns = mapreduce(vcat, vns_given; init=VN[]) do vn
271271
filter(Base.Fix1(subsumes, vn), metadata.vns)
272272
end
273273
indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns)
274-
indices = Dict(vn => i for (i, vn) in enumerate(vns))
274+
indices = if isempty(vns)
275+
Dict{VarName,Int}()
276+
else
277+
Dict(vn => i for (i, vn) in enumerate(vns))
278+
end
275279
# Construct new `vals` and `ranges`.
276280
vals_original = metadata.vals
277281
ranges_original = metadata.ranges
278282
# Allocate the new `vals`. and `ranges`.
279-
vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]))
280-
ranges = similar(ranges_original)
283+
vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]; init=0))
284+
ranges = similar(ranges_original, length(vns))
281285
# The new range `r` for `vns[i]` is offset by `offset` and
282286
# has the same length as the original range `r_original`.
283287
# The new `indices` (from above) ensures ordering according to `vns`.
@@ -311,7 +315,7 @@ function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName})
311315
ranges,
312316
vals,
313317
metadata.dists[indices_for_vns],
314-
metadata.gids,
318+
metadata.gids[indices_for_vns],
315319
metadata.orders[indices_for_vns],
316320
flags,
317321
)
@@ -382,7 +386,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
382386
ranges = Vector{UnitRange{Int}}()
383387
vals = T[]
384388
dists = D[]
385-
gids = metadata_right.gids # NOTE: giving precedence to `metadata_right`
389+
gids = Set{Selector}[]
386390
orders = Int[]
387391
flags = Dict{String,BitVector}()
388392
# Initialize the `flags`.
@@ -412,6 +416,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
412416
dist_right = getdist(metadata_right, vn)
413417
# Give precedence to `metadata_right`.
414418
push!(dists, dist_right)
419+
gid = metadata_right.gids[getidx(metadata_right, vn)]
420+
push!(gids, gid)
415421
# `orders`: giving precedence to `metadata_right`
416422
push!(orders, getorder(metadata_right, vn))
417423
# `flags`
@@ -431,6 +437,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
431437
# `dists`
432438
dist_left = getdist(metadata_left, vn)
433439
push!(dists, dist_left)
440+
gid = metadata_left.gids[getidx(metadata_left, vn)]
441+
push!(gids, gid)
434442
# `orders`
435443
push!(orders, getorder(metadata_left, vn))
436444
# `flags`
@@ -449,6 +457,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
449457
# `dists`
450458
dist_right = getdist(metadata_right, vn)
451459
push!(dists, dist_right)
460+
gid = metadata_right.gids[getidx(metadata_right, vn)]
461+
push!(gids, gid)
452462
# `orders`
453463
push!(orders, getorder(metadata_right, vn))
454464
# `flags`
@@ -1594,25 +1604,40 @@ function BangBang.push!!(
15941604
vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
15951605
)
15961606
if vi isa UntypedVarInfo
1597-
@assert ~(vn in keys(vi)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
1607+
@assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
15981608
elseif vi isa TypedVarInfo
1599-
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
1609+
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
16001610
end
16011611

16021612
val = vectorize(dist, r)
1603-
1604-
meta = getmetadata(vi, vn)
1605-
meta.idcs[vn] = length(meta.idcs) + 1
1606-
push!(meta.vns, vn)
1607-
l = length(meta.vals)
1608-
n = length(val)
1609-
push!(meta.ranges, (l + 1):(l + n))
1610-
append!(meta.vals, val)
1611-
push!(meta.dists, dist)
1612-
push!(meta.gids, gidset)
1613-
push!(meta.orders, get_num_produce(vi))
1614-
push!(meta.flags["del"], false)
1615-
push!(meta.flags["trans"], false)
1613+
sym = getsym(vn)
1614+
if vi isa TypedVarInfo && ~haskey(vi.metadata, sym)
1615+
# The NamedTuple doesn't have an entry for this variable, let's add one.
1616+
md = Metadata(
1617+
Dict(vn => 1),
1618+
[vn],
1619+
[1:length(val)],
1620+
val,
1621+
[dist],
1622+
[gidset],
1623+
[get_num_produce(vi)],
1624+
Dict{String,BitVector}("trans" => [false], "del" => [false]),
1625+
)
1626+
vi = Accessors.@set vi.metadata[sym] = md
1627+
else
1628+
meta = getmetadata(vi, vn)
1629+
meta.idcs[vn] = length(meta.idcs) + 1
1630+
push!(meta.vns, vn)
1631+
l = length(meta.vals)
1632+
n = length(val)
1633+
push!(meta.ranges, (l + 1):(l + n))
1634+
append!(meta.vals, val)
1635+
push!(meta.dists, dist)
1636+
push!(meta.gids, gidset)
1637+
push!(meta.orders, get_num_produce(vi))
1638+
push!(meta.flags["del"], false)
1639+
push!(meta.flags["trans"], false)
1640+
end
16161641

16171642
return vi
16181643
end

test/varinfo.jl

+32
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
145145
test_varinfo!(vi)
146146
test_varinfo!(empty!!(TypedVarInfo(vi)))
147147
end
148+
149+
@testset "push!! to TypedVarInfo" begin
150+
vn_x = @varname x
151+
vn_y = @varname y
152+
untyped_vi = VarInfo()
153+
untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector())
154+
typed_vi = TypedVarInfo(untyped_vi)
155+
typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector())
156+
@test typed_vi[vn_x] == 1.0
157+
@test typed_vi[vn_y] == 2.0
158+
end
159+
148160
@testset "setgid!" begin
149161
vi = VarInfo()
150162
meta = vi.metadata
@@ -511,6 +523,13 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
511523
else
512524
vns_supported_standard
513525
end
526+
527+
@testset ("$(convert(Vector{VarName}, vns_subset)) empty") for vns_subset in
528+
vns_supported
529+
varinfo_subset = subset(varinfo, VarName[])
530+
@test isempty(varinfo_subset)
531+
end
532+
514533
@testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in
515534
vns_supported
516535
varinfo_subset = subset(varinfo, vns_subset)
@@ -638,6 +657,19 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
638657
@test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)]
639658
@test DynamicPPL.getdist(varinfo_merged, @varname(x)) isa Normal
640659
end
660+
661+
# The below used to error, testing to avoid regression.
662+
@testset "merge gids" begin
663+
gidset_left = Set([Selector(1)])
664+
vi_left = VarInfo()
665+
vi_left = push!!(vi_left, @varname(x), 1.0, Normal(), gidset_left)
666+
gidset_right = Set([Selector(2)])
667+
vi_right = VarInfo()
668+
vi_right = push!!(vi_right, @varname(y), 2.0, Normal(), gidset_right)
669+
varinfo_merged = merge(vi_left, vi_right)
670+
@test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left
671+
@test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right
672+
end
641673
end
642674

643675
@testset "VarInfo with selectors" begin

0 commit comments

Comments
 (0)