Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ include("contexts/conditionfix.jl") # Must come after contexts/prefix.jl
include("model.jl")
include("sampler.jl")
include("varname.jl")
include("varnamedtuple.jl")
using .VarNamedTuples: VarNamedTuple
include("distribution_wrappers.jl")
include("submodel.jl")
include("varnamedvector.jl")
Expand Down
54 changes: 54 additions & 0 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ const NTVarInfo = VarInfo{<:NamedTuple}
const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{
VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}}
}
const TupleVarInfo = VarInfo{<:VarNamedTuple}

function Base.:(==)(vi1::VarInfo, vi2::VarInfo)
return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs)
Expand Down Expand Up @@ -356,6 +357,28 @@ function typed_vector_varinfo(
return typed_vector_varinfo(Random.default_rng(), model, init_strategy)
end

function make_leaf_metadata((r, dist), optic)
md = Metadata()
vn = VarName{:_}(optic)
push!(md, vn, r, dist)
return md
end

function tuple_varinfo()
metadata = VarNamedTuple((;), make_leaf_metadata)
return VarInfo(metadata, copy(default_accumulators()))
end
function tuple_varinfo(
rng::Random.AbstractRNG,
model::Model,
init_strategy::AbstractInitStrategy=InitFromPrior(),
)
return last(init!!(rng, model, tuple_varinfo(), init_strategy))
end
function tuple_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
return tuple_varinfo(Random.default_rng(), model, init_strategy)
end

"""
vector_length(varinfo::VarInfo)

Expand Down Expand Up @@ -639,6 +662,9 @@ Return the metadata in `vi` that belongs to `vn`.
"""
getmetadata(vi::VarInfo, vn::VarName) = vi.metadata
getmetadata(vi::NTVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn))
function getmetadata(vi::TupleVarInfo, vn::VarName)
return getindex(vi.metadata, remove_trailing_index(vn))
end

"""
getidx(vi::VarInfo, vn::VarName)
Expand Down Expand Up @@ -744,6 +770,10 @@ end
Return the distribution from which `vn` was sampled in `vi`.
"""
getdist(vi::VarInfo, vn::VarName) = getdist(getmetadata(vi, vn), vn)
function getdist(vi::TupleVarInfo, vn::VarName)
main_vn, optic = split_trailing_index(vn)
return getdist(getindex(vi.metadata, main_vn), VarName{:_}(optic))
end
getdist(md::Metadata, vn::VarName) = md.dists[getidx(md, vn)]
# TODO(mhauru) Remove this once the old Gibbs sampler stuff is gone.
function getdist(::VarNamedVector, ::VarName)
Expand Down Expand Up @@ -782,6 +812,10 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`.
The values may or may not be transformed to Euclidean space.
"""
setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn)
function setval!(vi::TupleVarInfo, val, vn::VarName)
main_vn, optic = split_trailing_index(vn)
return setval!(getindex(vi.metadata, main_vn), VarName{:_}(optic))
end
function setval!(md::Metadata, val::AbstractVector, vn::VarName)
return md.vals[getrange(md, vn)] = val
end
Expand Down Expand Up @@ -1579,6 +1613,7 @@ function Base.haskey(vi::NTVarInfo, vn::VarName)
end
return any(md_haskey)
end
Base.haskey(vi::TupleVarInfo, vn::VarName) = haskey(vi.metadata, vn)

function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo)
lines = Tuple{String,Any}[
Expand Down Expand Up @@ -1673,6 +1708,25 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution)
return vi
end

function BangBang.push!!(vi::TupleVarInfo, vn::VarName, r, dist::Distribution)
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TupleVarInfo with dist=$dist"
return VarInfo(setindex!!(vi.metadata, (r, dist), vn), vi.accs)
end

# TODO(mhauru) Implement properly
function is_transformed(vi::TupleVarInfo, vn::VarName)
return false
end

function getindex(vi::TupleVarInfo, vn::VarName)
main_vn, optic = split_trailing_index(vn)
return getindex(getindex(vi.metadata, main_vn), VarName{:_}(optic))
end
function getindex_internal(vi::TupleVarInfo, vn::VarName)
main_vn, optic = split_trailing_index(vn)
return getindex_internal(getindex(vi.metadata, main_vn), VarName{:_}(optic))
end

function Base.push!(vi::UntypedVectorVarInfo, vn::VarName, val, args...)
push!(getmetadata(vi, vn), vn, val, args...)
return vi
Expand Down
21 changes: 21 additions & 0 deletions src/varname.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,24 @@ Possibly existing indices of `varname` are neglected.
) where {s,missings,_F,_a,_T}
return s in missings
end

function remove_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic}
return if Optic === typeof(identity)
vn
elseif Optic isa IndexLens
VarName{sym}()
else
prefix(remove_trailing_index(unprefix(vn, sym)), sym)
end
end

function split_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic}
return if Optic === typeof(identity)
(vn, identity)
elseif Optic isa IndexLens
(VarName{sym}(), Optic.index)
else
(prefix, index) = split_trailing_index(unprefix(vn, sym))
(prefix(prefix, sym), index)
end
end
Loading