diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b7631c293..f3d5ed68a 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -183,6 +183,8 @@ include("contexts/conditionfix.jl") # Must come after contexts/prefix.jl include("model.jl") include("sampler.jl") include("varname.jl") +include("varnamedtuple.jl") +using .VarNamedTuples: VarNamedTuple include("distribution_wrappers.jl") include("submodel.jl") include("varnamedvector.jl") diff --git a/src/varinfo.jl b/src/varinfo.jl index 417766653..1f5413db4 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -154,6 +154,7 @@ const NTVarInfo = VarInfo{<:NamedTuple} const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} } +const TupleVarInfo = VarInfo{<:VarNamedTuple} function Base.:(==)(vi1::VarInfo, vi2::VarInfo) return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs) @@ -356,6 +357,28 @@ function typed_vector_varinfo( return typed_vector_varinfo(Random.default_rng(), model, init_strategy) end +function make_leaf_metadata((r, dist), optic) + md = Metadata() + vn = VarName{:_}(optic) + push!(md, vn, r, dist) + return md +end + +function tuple_varinfo() + metadata = VarNamedTuple((;), make_leaf_metadata) + return VarInfo(metadata, copy(default_accumulators())) +end +function tuple_varinfo( + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), +) + return last(init!!(rng, model, tuple_varinfo(), init_strategy)) +end +function tuple_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return tuple_varinfo(Random.default_rng(), model, init_strategy) +end + """ vector_length(varinfo::VarInfo) @@ -639,6 +662,9 @@ Return the metadata in `vi` that belongs to `vn`. """ getmetadata(vi::VarInfo, vn::VarName) = vi.metadata getmetadata(vi::NTVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn)) +function getmetadata(vi::TupleVarInfo, vn::VarName) + return getindex(vi.metadata, remove_trailing_index(vn)) +end """ getidx(vi::VarInfo, vn::VarName) @@ -744,6 +770,10 @@ end Return the distribution from which `vn` was sampled in `vi`. """ getdist(vi::VarInfo, vn::VarName) = getdist(getmetadata(vi, vn), vn) +function getdist(vi::TupleVarInfo, vn::VarName) + main_vn, optic = split_trailing_index(vn) + return getdist(getindex(vi.metadata, main_vn), VarName{:_}(optic)) +end getdist(md::Metadata, vn::VarName) = md.dists[getidx(md, vn)] # TODO(mhauru) Remove this once the old Gibbs sampler stuff is gone. function getdist(::VarNamedVector, ::VarName) @@ -782,6 +812,10 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`. The values may or may not be transformed to Euclidean space. """ setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn) +function setval!(vi::TupleVarInfo, val, vn::VarName) + main_vn, optic = split_trailing_index(vn) + return setval!(getindex(vi.metadata, main_vn), VarName{:_}(optic)) +end function setval!(md::Metadata, val::AbstractVector, vn::VarName) return md.vals[getrange(md, vn)] = val end @@ -1579,6 +1613,7 @@ function Base.haskey(vi::NTVarInfo, vn::VarName) end return any(md_haskey) end +Base.haskey(vi::TupleVarInfo, vn::VarName) = haskey(vi.metadata, vn) function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) lines = Tuple{String,Any}[ @@ -1673,6 +1708,25 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) return vi end +function BangBang.push!!(vi::TupleVarInfo, vn::VarName, r, dist::Distribution) + @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TupleVarInfo with dist=$dist" + return VarInfo(setindex!!(vi.metadata, (r, dist), vn), vi.accs) +end + +# TODO(mhauru) Implement properly +function is_transformed(vi::TupleVarInfo, vn::VarName) + return false +end + +function getindex(vi::TupleVarInfo, vn::VarName) + main_vn, optic = split_trailing_index(vn) + return getindex(getindex(vi.metadata, main_vn), VarName{:_}(optic)) +end +function getindex_internal(vi::TupleVarInfo, vn::VarName) + main_vn, optic = split_trailing_index(vn) + return getindex_internal(getindex(vi.metadata, main_vn), VarName{:_}(optic)) +end + function Base.push!(vi::UntypedVectorVarInfo, vn::VarName, val, args...) push!(getmetadata(vi, vn), vn, val, args...) return vi diff --git a/src/varname.jl b/src/varname.jl index 3eb1f2460..687427f6e 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -41,3 +41,24 @@ Possibly existing indices of `varname` are neglected. ) where {s,missings,_F,_a,_T} return s in missings end + +function remove_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic} + return if Optic === typeof(identity) + vn + elseif Optic isa IndexLens + VarName{sym}() + else + prefix(remove_trailing_index(unprefix(vn, sym)), sym) + end +end + +function split_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic} + return if Optic === typeof(identity) + (vn, identity) + elseif Optic isa IndexLens + (VarName{sym}(), Optic.index) + else + (prefix, index) = split_trailing_index(unprefix(vn, sym)) + (prefix(prefix, sym), index) + end +end