Skip to content

Commit

Permalink
Allow empty subsets of VarInfos
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Oct 16, 2024
1 parent 1d10278 commit 06fbe70
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 18 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.30"
version = "0.30.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
17 changes: 6 additions & 11 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -439,22 +439,17 @@ function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
return Accessors.@set varinfo.values = _subset(varinfo.values, vns)
end

function _subset(x::AbstractDict, vns)
function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName}
vns_present = collect(keys(x))
vns_found = mapreduce(vcat, vns) do vn
vns_found = mapreduce(vcat, vns; init=VN[]) do vn
return filter(Base.Fix1(subsumes, vn), vns_present)
end

# NOTE: This `vns` to be subsume varnames explicitly present in `x`.
C = ConstructionBase.constructorof(typeof(x))
if isempty(vns_found)
throw(
ArgumentError(
"Cannot subset `AbstractDict` with `VarName` which does not subsume any keys.",
),
)
return C()
else
return C(vn => x[vn] for vn in vns_found)
end
C = ConstructionBase.constructorof(typeof(x))
return C(vn => x[vn] for vn in vns_found)
end

function _subset(x::NamedTuple, vns)
Expand Down
16 changes: 10 additions & 6 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,20 +368,24 @@ function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName})
)
end

function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName})
function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where VN<:VarName
# TODO: Should we error if `vns` contains a variable that is not in `metadata`?
# For each `vn` in `vns`, get the variables subsumed by `vn`.
vns = mapreduce(vcat, vns_given) do vn
vns = mapreduce(vcat, vns_given; init=VN[]) do vn
filter(Base.Fix1(subsumes, vn), metadata.vns)
end
indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns)
indices = Dict(vn => i for (i, vn) in enumerate(vns))
indices = if isempty(vns)
Dict{VarName,Int}()
else
Dict(vn => i for (i, vn) in enumerate(vns))
end
# Construct new `vals` and `ranges`.
vals_original = metadata.vals
ranges_original = metadata.ranges
# Allocate the new `vals`. and `ranges`.
vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]))
ranges = similar(ranges_original)
vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]; init=0))
ranges = similar(ranges_original, length(vns))
# The new range `r` for `vns[i]` is offset by `offset` and
# has the same length as the original range `r_original`.
# The new `indices` (from above) ensures ordering according to `vns`.
Expand Down Expand Up @@ -415,7 +419,7 @@ function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName})
ranges,
vals,
metadata.dists[indices_for_vns],
metadata.gids,
metadata.gids[indices_for_vns],
metadata.orders[indices_for_vns],
flags,
)
Expand Down
6 changes: 6 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,12 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
else
vns_supported_standard
end

@testset ("$(convert(Vector{VarName}, vns_subset)) empty") for vns_subset in vns_supported
varinfo_subset = subset(varinfo, VarName[])
@test isempty(varinfo_subset)
end

@testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in
vns_supported
varinfo_subset = subset(varinfo, vns_subset)
Expand Down

0 comments on commit 06fbe70

Please sign in to comment.