From 5a43012383c448ad321f9251fbdabfef293c4429 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 31 Aug 2023 20:39:07 +0100 Subject: [PATCH 1/6] added impl of varname_and_value_leaves --- docs/src/api.md | 2 + src/DynamicPPL.jl | 2 +- src/utils.jl | 119 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 122 insertions(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index d2be26c56..6b9493985 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -255,6 +255,8 @@ DynamicPPL.reconstruct ```@docs DynamicPPL.unflatten DynamicPPL.tonamedtuple +DynamicPPL.varname_leaves +DynamicPPL.varname_and_value_leaves ``` #### `SimpleVarInfo` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 5faced372..d1a21530b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -15,7 +15,7 @@ using Setfield: Setfield using ZygoteRules: ZygoteRules using LogDensityProblems: LogDensityProblems -using LinearAlgebra: Cholesky +using LinearAlgebra: LinearAlgebra, Cholesky using DocStringExtensions diff --git a/src/utils.jl b/src/utils.jl index d28697127..f5a73e311 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -870,3 +870,122 @@ function varname_leaves(vn::VarName, val::NamedTuple) end return Iterators.flatten(iter) end + +""" + varname_and_value_leaves(vn::VarName, val) + +Return an iterator over all varname-value pairs that are represented by `vn` on `val`. + +# Examples +```jldoctest varname-and-value-leaves +julia> using DynamicPPL: varname_and_value_leaves + +julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2)) +(x[1], 1) +(x[2], 2) + +julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2)) +(x[1:2][1], 1) +(x[1:2][2], 2) + +julia> x = (y = 1, z = [[2.0], [3.0]]); + +julia> foreach(println, varname_and_value_leaves(@varname(x), x)) +(x.y, 1) +(x.z[1][1], 2.0) +(x.z[2][1], 3.0) +``` + +There are also some special handling for certain types: + +```jldoctest varname-and-value-leaves +julia> using LinearAlgebra + +julia> x = reshape(1:4, 2, 2); + +julia> # `LowerTriangular` + foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x))) +(x[1,1], 1) +(x[2,1], 2) +(x[2,2], 4) + +julia> # `UpperTriangular` + foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x))) +(x[1,1], 1) +(x[1,2], 3) +(x[2,2], 4) +``` +""" +function varname_and_value_leaves(vn::VarName, x) + return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x))) +end + +# Simple struct used to represent a varname-value pair even if we use +# something like `Iterators.flatten`. +struct Leaf{T} + value::T +end + +Leaf(xs...) = Leaf(xs) + +# Allows us to just use `Leaf` to "terminate" recursion in `Iterators.flatten`. +Base.iterate(leaf::Leaf) = leaf, leaf +Base.iterate(::Leaf, _) = nothing + +# Convenience. +value(leaf::Leaf) = leaf.value + +# Leaf-types. +varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)] +function varname_and_value_leaves_inner( + vn::VarName, val::AbstractArray{<:Union{Real,Missing}} +) + return ( + Leaf( + VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))), + val[I], + ) for I in CartesianIndices(val) + ) +end +# Containers. +function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray) + return Iterators.flatten( + varname_and_value_leaves_inner( + VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))), + val[I], + ) for I in CartesianIndices(val) + ) +end +function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple) + iter = Iterators.map(keys(val)) do sym + lens = DynamicPPL.Setfield.PropertyLens{sym}() + varname_and_value_leaves_inner(vn ∘ lens, get(val, lens)) + end + + return Iterators.flatten(iter) +end +# Special types. +function varname_and_value_leaves_inner(vn::VarName, x::Cholesky) + # TODO: Or do we use `PDMat` here? + return varname_and_value_leaves_inner(vn, x.UL) +end +function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular) + return ( + Leaf( + VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))), + x[I], + ) + # Iteration over the lower-triangular indices. + for I in CartesianIndices(x) if I[1] >= I[2] + ) +end +function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular) + return ( + Leaf( + VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))), + x[I], + ) + # Iteration over the upper-triangular indices. + for I in CartesianIndices(x) if I[1] <= I[2] + ) +end From 4367882d0bce9d74c13be1b0e42feaa3365acc1d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Sep 2023 11:09:10 +0100 Subject: [PATCH 2/6] added examples with cholesky to varname_and_value_leaves doctests --- src/utils.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index f5a73e311..98b3522b0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -914,6 +914,18 @@ julia> # `UpperTriangular` (x[1,1], 1) (x[1,2], 3) (x[2,2], 4) + +julia> # `Cholesky` with lower-triangular + foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0))) +(x[1,1], 1.0) +(x[2,1], 0.0) +(x[2,2], 1.0) + +julia> # `Cholesky` with upper-triangular + foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0))) +(x[1,1], 1.0) +(x[1,2], 0.0) +(x[2,2], 1.0) ``` """ function varname_and_value_leaves(vn::VarName, x) From 0909d3514969173d4addf44630040aaa5bb5b662 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Sep 2023 11:15:37 +0100 Subject: [PATCH 3/6] added more descriptive docstring of iterate for Leaf --- src/utils.jl | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 98b3522b0..2b2711668 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -940,8 +940,18 @@ end Leaf(xs...) = Leaf(xs) -# Allows us to just use `Leaf` to "terminate" recursion in `Iterators.flatten`. -Base.iterate(leaf::Leaf) = leaf, leaf +# Allow us to treat `Leaf` as an iterator containing a single element. +# Something like an `[x]` would also be an iterator with a single element, +# but when we call `flatten` on this, it would also iterate over `x`, +# unflattening that too. By making `Leaf` a single-element iterator, which +# returns itself, we can call `iterate` on this as many times as we like +# without causing any change. The result is that `Iterators.flatten` +# will _not_ unflatten `Leaf`s. +# Note that this is similar to how `Base.iterate` is implemented for `Real`:: +# +# julia> Base.iterate(1) +# (1, nothing) +Base.iterate(leaf::Leaf) = leaf, nothing Base.iterate(::Leaf, _) = nothing # Convenience. From 49bb2c55536832f06aa62c21044a2d5ca3305c66 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Sep 2023 11:18:40 +0100 Subject: [PATCH 4/6] added concrete example in comment of iterate for Leaf --- src/utils.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 2b2711668..39899607e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -949,8 +949,12 @@ Leaf(xs...) = Leaf(xs) # will _not_ unflatten `Leaf`s. # Note that this is similar to how `Base.iterate` is implemented for `Real`:: # -# julia> Base.iterate(1) +# julia> iterate(1) # (1, nothing) +# +# One immediate example where this becomes in our scenario is that we might +# have `missing` values in our data, which does _not_ have an `iterate` +# implemented. Calling `Iterators.flatten` on this would cause an error. Base.iterate(leaf::Leaf) = leaf, nothing Base.iterate(::Leaf, _) = nothing From b96fa2274a84d2bb5b98ad083b5d01c0414aa47a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Sep 2023 11:25:58 +0100 Subject: [PATCH 5/6] added small docstring to Leaf --- src/utils.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 39899607e..951989a45 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -932,8 +932,16 @@ function varname_and_value_leaves(vn::VarName, x) return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x))) end -# Simple struct used to represent a varname-value pair even if we use -# something like `Iterators.flatten`. + +""" + Leaf{T} + +A container that represents the leaf of a nested structure, implementing +`iterate` to return itself. + +This is particularly useful in conjunction with `Iterators.flatten` to +prevent flattening of nested structures. +""" struct Leaf{T} value::T end From ba667dce2462bceb60dd1bcdc4579b00568623f3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Sep 2023 12:00:44 +0100 Subject: [PATCH 6/6] Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/utils.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 951989a45..a1fb12788 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -932,7 +932,6 @@ function varname_and_value_leaves(vn::VarName, x) return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x))) end - """ Leaf{T}