Skip to content

Commit

Permalink
separated the getrange version which returns the range of the vecto
Browse files Browse the repository at this point in the history
representaiton rather than the internal representaiton into
`vector_getrange` to make its function explicit
  • Loading branch information
torfjelde committed Dec 5, 2024
1 parent f500c23 commit 8afe681
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 19 deletions.
4 changes: 4 additions & 0 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<:
return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vns)
end

vector_length(vi::ThreadSafeVarInfo) = vector_length(vi.varinfo)
vector_getrange(vi::ThreadSafeVarInfo) = vector_getrange(vi.varinfo)
vector_getranges(vi::ThreadSafeVarInfo) = vector_getranges(vi.varinfo)

function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler)
return set_retained_vns_del_by_spl!(vi.varinfo, spl)
end
Expand Down
55 changes: 38 additions & 17 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,11 @@ end
VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...)


"""
vector_length(varinfo::VarInfo)
Return the length of the vector representation of `varinfo`.
"""
vector_length(varinfo::VarInfo) = length(varinfo.metadata)
vector_length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata)
vector_length(md::Metadata) = sum(length, md.ranges)
Expand Down Expand Up @@ -615,21 +620,6 @@ getidx(md::Metadata, vn::VarName) = md.idcs[vn]
Return the index range of `vn` in the metadata of `vi`.
"""
getrange(vi::VarInfo, vn::VarName) = getrange(getmetadata(vi, vn), vn)
# For `TypedVarInfo` it's more difficult since we need to keep track of the offset.
# TOOD: Should we unroll this using `@generated`?
function getrange(vi::TypedVarInfo, vn::VarName)
offset = 0
for md in values(vi.metadata)
# First, we need to check if `vn` is in `md`.
# In this case, we can just return the corresponding range + offset.
haskey(md, vn) && return getrange(md, vn) .+ offset
# Otherwise, we need to get the cumulative length of the ranges in `md`
# and add it to the offset.
offset += sum(length, md.ranges)
end
# If we reach this point, `vn` is not in `vi.metadata`.
throw(KeyError(vn))
end
getrange(md::Metadata, vn::VarName) = md.ranges[getidx(md, vn)]

"""
Expand All @@ -648,8 +638,38 @@ Return the indices of `vns` in the metadata of `vi` corresponding to `vn`.
function getranges(vi::VarInfo, vns::Vector{<:VarName})
return map(Base.Fix1(getrange, vi), vns)
end
# A more efficient version for `TypedVarInfo`.
function getranges(varinfo::DynamicPPL.TypedVarInfo, vns::Vector{<:DynamicPPL.VarName})

"""
vector_getrange(varinfo::VarInfo, varname::VarName)
Return the range corresponding to `varname` in the vector representation of `varinfo`.
"""
vector_getrange(vi::VarInfo, vn::VarName) = getrange(getmetadata(vi, vn), vn)
function vector_getrange(vi::TypedVarInfo, vn::VarName)
offset = 0
for md in values(vi.metadata)
# First, we need to check if `vn` is in `md`.
# In this case, we can just return the corresponding range + offset.
haskey(md, vn) && return vector_getrange(md, vn) .+ offset
# Otherwise, we need to get the cumulative length of the ranges in `md`
# and add it to the offset.
offset += sum(length, md.ranges)
end
# If we reach this point, `vn` is not in `vi.metadata`.
throw(KeyError(vn))
end
vector_getrange(md::Metadata, vn::VarName) = getrange(md, vn)

"""
vector_getranges(varinfo::VarInfo, varnames::Vector{<:VarName})
Return the range corresponding to `varname` in the vector representation of `varinfo`.
"""
function vector_getranges(varinfo::VarInfo, varname::Vector{<:VarName})
return map(Base.Fix1(vector_getrange, varinfo), varname)
end
# Specialized version for `TypedVarInfo`.
function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName})
# TODO: Does it help if we _don't_ convert to a vector here?
metadatas = collect(values(varinfo.metadata))
# Extract the offsets.
Expand All @@ -672,6 +692,7 @@ function getranges(varinfo::DynamicPPL.TypedVarInfo, vns::Vector{<:DynamicPPL.Va
return ranges
end


"""
getdist(vi::VarInfo, vn::VarName)
Expand Down
4 changes: 2 additions & 2 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)

# NOTE: It is not yet clear if this is something we want from all varinfo types.
# Hence, we only test the `VarInfo` types here.
@testset "getranges for `VarInfo`" begin
@testset "vector_getranges for `VarInfo`" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
vns = DynamicPPL.TestUtils.varnames(model)
nt = DynamicPPL.TestUtils.rand_prior_true(model)
Expand All @@ -829,7 +829,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
# Let's just check all the subsets of `vns`.
@testset "$(convert(Vector{Any},vns_subset))" for vns_subset in
combinations(vns)
ranges = DynamicPPL.getranges(varinfo, vns_subset)
ranges = DynamicPPL.vector_getranges(varinfo, vns_subset)
@test length(ranges) == length(vns_subset)
for (r, vn) in zip(ranges, vns_subset)
@test x[r] == DynamicPPL.tovec(varinfo[vn])
Expand Down

0 comments on commit 8afe681

Please sign in to comment.