Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replacing tonamedtuple #526

Merged
merged 6 commits into from
Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ DynamicPPL.reconstruct
```@docs
DynamicPPL.unflatten
DynamicPPL.tonamedtuple
DynamicPPL.varname_leaves
DynamicPPL.varname_and_value_leaves
```

#### `SimpleVarInfo`
Expand Down
2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using Setfield: Setfield
using ZygoteRules: ZygoteRules
using LogDensityProblems: LogDensityProblems

using LinearAlgebra: Cholesky
using LinearAlgebra: LinearAlgebra, Cholesky

using DocStringExtensions

Expand Down
152 changes: 152 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -870,3 +870,155 @@ function varname_leaves(vn::VarName, val::NamedTuple)
end
return Iterators.flatten(iter)
end

"""
varname_and_value_leaves(vn::VarName, val)

Return an iterator over all varname-value pairs that are represented by `vn` on `val`.
Copy link
Member

@yebai yebai Sep 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Return an iterator over all varname-value pairs that are represented by `vn` on `val`.
Return an iterator over all (varname::VarName, value::Real) pairs represented by `vn` on `val`. Common types for `val`, including Array, NamedTuple, Real and Cholesky, are supported by default. Handling for new types can be added by overloading `varname_and_value_leaves` methods.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I specifically didn't add any documentation about this, as we'd then also have to add docs to varname_and_value_leaves_inner.

IMO this should be a user-facing method and thus should be "simple" in what it describes (since we are exporting it). People who would potentially be interested in overloading this can eaisly just look at the source code and immediately understand that they need to overload varnames_and_value_leaves_inner for their type.

Copy link
Member Author

@torfjelde torfjelde Sep 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And it doesn't always return (varname:::VarName, value::Real) btw; value could also be Missing.


# Examples
```jldoctest varname-and-value-leaves
julia> using DynamicPPL: varname_and_value_leaves

julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2))
(x[1], 1)
(x[2], 2)

julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2))
(x[1:2][1], 1)
(x[1:2][2], 2)

julia> x = (y = 1, z = [[2.0], [3.0]]);

julia> foreach(println, varname_and_value_leaves(@varname(x), x))
(x.y, 1)
(x.z[1][1], 2.0)
(x.z[2][1], 3.0)
```

There are also some special handling for certain types:

```jldoctest varname-and-value-leaves
julia> using LinearAlgebra

julia> x = reshape(1:4, 2, 2);

julia> # `LowerTriangular`
foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x)))
(x[1,1], 1)
(x[2,1], 2)
(x[2,2], 4)

julia> # `UpperTriangular`
foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x)))
(x[1,1], 1)
(x[1,2], 3)
(x[2,2], 4)

julia> # `Cholesky` with lower-triangular
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0)))
(x[1,1], 1.0)
(x[2,1], 0.0)
(x[2,2], 1.0)

julia> # `Cholesky` with upper-triangular
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0)))
(x[1,1], 1.0)
(x[1,2], 0.0)
(x[2,2], 1.0)
```
"""
function varname_and_value_leaves(vn::VarName, x)
return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x)))
end

"""
Leaf{T}

A container that represents the leaf of a nested structure, implementing
`iterate` to return itself.

This is particularly useful in conjunction with `Iterators.flatten` to
prevent flattening of nested structures.
"""
struct Leaf{T}
value::T
end

Leaf(xs...) = Leaf(xs)

# Allow us to treat `Leaf` as an iterator containing a single element.
# Something like an `[x]` would also be an iterator with a single element,
# but when we call `flatten` on this, it would also iterate over `x`,
# unflattening that too. By making `Leaf` a single-element iterator, which
# returns itself, we can call `iterate` on this as many times as we like
# without causing any change. The result is that `Iterators.flatten`
# will _not_ unflatten `Leaf`s.
# Note that this is similar to how `Base.iterate` is implemented for `Real`::
#
# julia> iterate(1)
# (1, nothing)
#
# One immediate example where this becomes in our scenario is that we might
# have `missing` values in our data, which does _not_ have an `iterate`
# implemented. Calling `Iterators.flatten` on this would cause an error.
Base.iterate(leaf::Leaf) = leaf, nothing
Base.iterate(::Leaf, _) = nothing

# Convenience.
value(leaf::Leaf) = leaf.value

# Leaf-types.
varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)]
function varname_and_value_leaves_inner(
vn::VarName, val::AbstractArray{<:Union{Real,Missing}}
)
return (
Leaf(
VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))),
val[I],
) for I in CartesianIndices(val)
)
end
# Containers.
function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray)
return Iterators.flatten(
varname_and_value_leaves_inner(
VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))),
val[I],
) for I in CartesianIndices(val)
)
end
function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple)
iter = Iterators.map(keys(val)) do sym
lens = DynamicPPL.Setfield.PropertyLens{sym}()
varname_and_value_leaves_inner(vn ∘ lens, get(val, lens))
end

return Iterators.flatten(iter)
end
# Special types.
function varname_and_value_leaves_inner(vn::VarName, x::Cholesky)
# TODO: Or do we use `PDMat` here?
return varname_and_value_leaves_inner(vn, x.UL)
end
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular)
return (
Leaf(
VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))),
x[I],
)
# Iteration over the lower-triangular indices.
for I in CartesianIndices(x) if I[1] >= I[2]
)
end
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular)
return (
Leaf(
VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))),
x[I],
)
# Iteration over the upper-triangular indices.
for I in CartesianIndices(x) if I[1] <= I[2]
)
end
Loading