diff --git a/Project.toml b/Project.toml index f995d7359..eab8c362c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.30" +version = "0.30.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 88f892a72..b6a84238e 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -439,22 +439,17 @@ function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) return Accessors.@set varinfo.values = _subset(varinfo.values, vns) end -function _subset(x::AbstractDict, vns) +function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName} vns_present = collect(keys(x)) - vns_found = mapreduce(vcat, vns) do vn + vns_found = mapreduce(vcat, vns; init=VN[]) do vn return filter(Base.Fix1(subsumes, vn), vns_present) end - - # NOTE: This `vns` to be subsume varnames explicitly present in `x`. + C = ConstructionBase.constructorof(typeof(x)) if isempty(vns_found) - throw( - ArgumentError( - "Cannot subset `AbstractDict` with `VarName` which does not subsume any keys.", - ), - ) + return C() + else + return C(vn => x[vn] for vn in vns_found) end - C = ConstructionBase.constructorof(typeof(x)) - return C(vn => x[vn] for vn in vns_found) end function _subset(x::NamedTuple, vns) diff --git a/src/varinfo.jl b/src/varinfo.jl index 8727796bc..70bb4dc76 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -368,20 +368,24 @@ function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName}) ) end -function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName}) +function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where VN<:VarName # TODO: Should we error if `vns` contains a variable that is not in `metadata`? # For each `vn` in `vns`, get the variables subsumed by `vn`. - vns = mapreduce(vcat, vns_given) do vn + vns = mapreduce(vcat, vns_given; init=VN[]) do vn filter(Base.Fix1(subsumes, vn), metadata.vns) end indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns) - indices = Dict(vn => i for (i, vn) in enumerate(vns)) + indices = if isempty(vns) + Dict{VarName,Int}() + else + Dict(vn => i for (i, vn) in enumerate(vns)) + end # Construct new `vals` and `ranges`. vals_original = metadata.vals ranges_original = metadata.ranges # Allocate the new `vals`. and `ranges`. - vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns])) - ranges = similar(ranges_original) + vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]; init=0)) + ranges = similar(ranges_original, length(vns)) # The new range `r` for `vns[i]` is offset by `offset` and # has the same length as the original range `r_original`. # The new `indices` (from above) ensures ordering according to `vns`. @@ -415,7 +419,7 @@ function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName}) ranges, vals, metadata.dists[indices_for_vns], - metadata.gids, + metadata.gids[indices_for_vns], metadata.orders[indices_for_vns], flags, ) diff --git a/test/varinfo.jl b/test/varinfo.jl index 65f849dda..e45cb2e8f 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -566,6 +566,12 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) else vns_supported_standard end + + @testset ("$(convert(Vector{VarName}, vns_subset)) empty") for vns_subset in vns_supported + varinfo_subset = subset(varinfo, VarName[]) + @test isempty(varinfo_subset) + end + @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in vns_supported varinfo_subset = subset(varinfo, vns_subset)