diff --git a/docs/src/api.md b/docs/src/api.md index d2be26c56..6b9493985 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -255,6 +255,8 @@ DynamicPPL.reconstruct ```@docs DynamicPPL.unflatten DynamicPPL.tonamedtuple +DynamicPPL.varname_leaves +DynamicPPL.varname_and_value_leaves ``` #### `SimpleVarInfo` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 5faced372..d1a21530b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -15,7 +15,7 @@ using Setfield: Setfield using ZygoteRules: ZygoteRules using LogDensityProblems: LogDensityProblems -using LinearAlgebra: Cholesky +using LinearAlgebra: LinearAlgebra, Cholesky using DocStringExtensions diff --git a/src/utils.jl b/src/utils.jl index d28697127..a1fb12788 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -870,3 +870,155 @@ function varname_leaves(vn::VarName, val::NamedTuple) end return Iterators.flatten(iter) end + +""" + varname_and_value_leaves(vn::VarName, val) + +Return an iterator over all varname-value pairs that are represented by `vn` on `val`. 
+ +# Examples +```jldoctest varname-and-value-leaves +julia> using DynamicPPL: varname_and_value_leaves + +julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2)) +(x[1], 1) +(x[2], 2) + +julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2)) +(x[1:2][1], 1) +(x[1:2][2], 2) + +julia> x = (y = 1, z = [[2.0], [3.0]]); + +julia> foreach(println, varname_and_value_leaves(@varname(x), x)) +(x.y, 1) +(x.z[1][1], 2.0) +(x.z[2][1], 3.0) +``` + +There is also some special handling for certain types: + +```jldoctest varname-and-value-leaves +julia> using LinearAlgebra + +julia> x = reshape(1:4, 2, 2); + +julia> # `LowerTriangular` + foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x))) +(x[1,1], 1) +(x[2,1], 2) +(x[2,2], 4) + +julia> # `UpperTriangular` + foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x))) +(x[1,1], 1) +(x[1,2], 3) +(x[2,2], 4) + +julia> # `Cholesky` with lower-triangular + foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0))) +(x[1,1], 1.0) +(x[2,1], 0.0) +(x[2,2], 1.0) + +julia> # `Cholesky` with upper-triangular + foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0))) +(x[1,1], 1.0) +(x[1,2], 0.0) +(x[2,2], 1.0) +``` +""" +function varname_and_value_leaves(vn::VarName, x) + return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x))) +end + +""" + Leaf{T} + +A container that represents the leaf of a nested structure, implementing +`iterate` to return itself. + +This is particularly useful in conjunction with `Iterators.flatten` to +prevent flattening of nested structures. +""" +struct Leaf{T} + value::T +end + +Leaf(xs...) = Leaf(xs) + +# Allow us to treat `Leaf` as an iterator containing a single element. +# Something like an `[x]` would also be an iterator with a single element, +# but when we call `flatten` on this, it would also iterate over `x`, +# flattening that too. 
By making `Leaf` a single-element iterator, which +# returns itself, we can call `iterate` on this as many times as we like +# without causing any change. The result is that `Iterators.flatten` +# will _not_ flatten `Leaf`s. +# Note that this is similar to how `Base.iterate` is implemented for `Real`: +# +# julia> iterate(1) +# (1, nothing) +# +# One immediate example where this becomes relevant in our scenario is that we might +# have `missing` values in our data, which does _not_ have an `iterate` +# implemented. Calling `Iterators.flatten` on this would cause an error. +Base.iterate(leaf::Leaf) = leaf, nothing +Base.iterate(::Leaf, _) = nothing + +# Convenience. +value(leaf::Leaf) = leaf.value + +# Leaf-types. +varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)] +function varname_and_value_leaves_inner( + vn::VarName, val::AbstractArray{<:Union{Real,Missing}} +) + return ( + Leaf( + VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))), + val[I], + ) for I in CartesianIndices(val) + ) +end +# Containers. +function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray) + return Iterators.flatten( + varname_and_value_leaves_inner( + VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))), + val[I], + ) for I in CartesianIndices(val) + ) +end +function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple) + iter = Iterators.map(keys(val)) do sym + lens = DynamicPPL.Setfield.PropertyLens{sym}() + varname_and_value_leaves_inner(vn ∘ lens, get(val, lens)) + end + + return Iterators.flatten(iter) +end +# Special types. +function varname_and_value_leaves_inner(vn::VarName, x::Cholesky) + # TODO: Or do we use `PDMat` here? 
+ return varname_and_value_leaves_inner(vn, x.UL) +end +function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular) + return ( + Leaf( + VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))), + x[I], + ) + # Iteration over the lower-triangular indices. + for I in CartesianIndices(x) if I[1] >= I[2] + ) +end +function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular) + return ( + Leaf( + VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))), + x[I], + ) + # Iteration over the upper-triangular indices. + for I in CartesianIndices(x) if I[1] <= I[2] + ) +end