From f3fbce0932030372d0a7bca1325f78792583395d Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Thu, 28 Mar 2024 15:25:57 +0000 Subject: [PATCH] Move to `Accessors.jl` from `Setfield.jl` (#91) * move to Accessors * fix tests * fix test error * add compat * version bump * Update src/varname.jl Co-authored-by: Tor Erlend Fjelde * Update src/varname.jl * remove type piracy in `show` function * fix print behavior * removed composition of a varname to a lens * update doc * remove `Setfield` * add some type stability tests and additional doctests * fix test error * Update src/varname.jl Co-authored-by: Tor Erlend Fjelde * copy functions from Setfield and recover the interpolation abilities * fix some comments --------- Co-authored-by: Tor Erlend Fjelde --- Project.toml | 7 +- src/AbstractPPL.jl | 2 +- src/varname.jl | 379 +++++++++++++++++++++++++++------------------ test/Project.toml | 3 +- test/varname.jl | 39 +++-- 5 files changed, 262 insertions(+), 168 deletions(-) diff --git a/Project.toml b/Project.toml index ba73c43..becfdbe 100644 --- a/Project.toml +++ b/Project.toml @@ -3,17 +3,18 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.7.1" +version = "0.8.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" +MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] AbstractMCMC = "2, 3, 4, 5" +Accessors = "0.1" DensityInterface = "0.4" Random = "1.6" -Setfield = "0.8.2, 1" julia = "~1.6.6, 1.7.3" diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index 1be28de..775576a 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -3,7 +3,7 @@ module AbstractPPL # VarName export VarName, getsym, - getlens, + getoptic, inspace, subsumes, subsumedby, diff --git a/src/varname.jl b/src/varname.jl index a0267e3..a3ca1dc 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -1,70 +1,70 @@ -using Setfield -using Setfield: PropertyLens, ComposedLens, IdentityLens, IndexLens, DynamicIndexLens +using Accessors +using Accessors: ComposedOptic, PropertyLens, IndexLens, DynamicIndexLens +using MacroTools + +const ALLOWED_OPTICS = Union{typeof(identity),PropertyLens,IndexLens,ComposedOptic} """ - VarName{sym}(lens::Lens=IdentityLens()) + VarName{sym}(optic=identity) -A variable identifier for a symbol `sym` and lens `lens`. +A variable identifier for a symbol `sym` and optic `optic`. The Julia variable in the model corresponding to `sym` can refer to a single value or to a hierarchical array structure of univariate, multivariate or matrix variables. The field `lens` stores the indices requires to access the random variable from the Julia variable indicated by `sym` -as a tuple of tuples. Each element of the tuple thereby contains the indices of one lens +as a tuple of tuples. Each element of the tuple thereby contains the indices of one optic operation. -`VarName`s can be manually constructed using the `VarName{sym}(lens)` constructor, or from an -lens expression through the [`@varname`](@ref) convenience macro. +`VarName`s can be manually constructed using the `VarName{sym}(optic)` constructor, or from an +optic expression through the [`@varname`](@ref) convenience macro. # Examples -```jldoctest; setup=:(using Setfield) -julia> vn = VarName{:x}(Setfield.IndexLens((Colon(), 1)) ∘ Setfield.IndexLens((2, ))) -x[:,1][2] +```jldoctest; setup=:(using Accessors) +julia> vn = VarName{:x}(Accessors.IndexLens((Colon(), 1)) ⨟ Accessors.IndexLens((2, ))) +x[:, 1][2] -julia> getlens(vn) -(@lens _[Colon(), 1][2]) +julia> getoptic(vn) +(@o _[Colon(), 1][2]) julia> @varname x[:, 1][1+1] -x[:,1][2] +x[:, 1][2] ``` """ -struct VarName{sym,T<:Lens} - lens::T +struct VarName{sym,T} + optic::T - function VarName{sym}(lens=IdentityLens()) where {sym} - # TODO: Should we completely disallow or just `@warn` of limited support? - if !is_static_lens(lens) - error("attempted to construct `VarName` with dynamic lens of type $(nameof(typeof(lens)))") + function VarName{sym}(optic=identity) where {sym} + if !is_static_optic(typeof(optic)) + throw(ArgumentError("attempted to construct `VarName` with unsupported optic of type $(nameof(typeof(optic)))")) end - return new{sym,typeof(lens)}(lens) + return new{sym,typeof(optic)}(optic) end end """ - is_static_lens(l::Lens) - -Return `true` if `l` does not require runtime information to be resolved. + is_static_optic(l) -In particular it returns `false` for `Setfield.DynamicLens` and `Setfield.FunctionLens`. +Return `true` if `l` is one or a composition of `identity`, `PropertyLens`, and `IndexLens`; `false` if `l` is +one or a composition of `DynamicIndexLens`; and undefined otherwise. """ -is_static_lens(l::Lens) = is_static_lens(typeof(l)) -is_static_lens(::Type{<:Lens}) = false -is_static_lens(::Type{<:Union{PropertyLens, IndexLens, IdentityLens}}) = true -function is_static_lens(::Type{ComposedLens{LO, LI}}) where {LO, LI} - return is_static_lens(LO) && is_static_lens(LI) +is_static_optic(::Type{<:Union{typeof(identity),PropertyLens,IndexLens}}) = true +function is_static_optic(::Type{ComposedOptic{LO,LI}}) where {LO,LI} + return is_static_optic(LO) && is_static_optic(LI) end +is_static_optic(::Type{<:DynamicIndexLens}) = false # A bit of backwards compatibility. -VarName{sym}(indexing::Tuple) where {sym} = VarName{sym}(tupleindex2lens(indexing)) +VarName{sym}(indexing::Tuple) where {sym} = VarName{sym}(tupleindex2optic(indexing)) """ - VarName(vn::VarName, lens::Lens) + VarName(vn::VarName, optic) VarName(vn::VarName, indexing::Tuple) -Return a copy of `vn` with a new index `lens`/`indexing`. +Return a copy of `vn` with a new index `optic`/`indexing`. -```jldoctest; setup=:(using Setfield) -julia> VarName(@varname(x[1][2:3]), Setfield.IndexLens((2,))) +```jldoctest; setup=:(using Accessors) +julia> VarName(@varname(x[1][2:3]), Accessors.IndexLens((2,))) x[2] julia> VarName(@varname(x[1][2:3]), ((2,),)) @@ -74,16 +74,16 @@ julia> VarName(@varname(x[1][2:3])) x ``` """ -VarName(vn::VarName, lens::Lens = IdentityLens()) = VarName{getsym(vn)}(lens) +VarName(vn::VarName, optic=identity) = VarName{getsym(vn)}(optic) function VarName(vn::VarName, indexing::Tuple) - return VarName{getsym(vn)}(tupleindex2lens(indexing)) + return VarName{getsym(vn)}(tupleindex2optic(indexing)) end -tupleindex2lens(indexing::Tuple{}) = IdentityLens() -tupleindex2lens(indexing::Tuple{<:Tuple}) = IndexLens(first(indexing)) -function tupleindex2lens(indexing::Tuple) - return IndexLens(first(indexing)) ∘ tupleindex2lens(indexing[2:end]) +tupleindex2optic(indexing::Tuple{}) = identity +tupleindex2optic(indexing::Tuple{<:Tuple}) = IndexLens(first(indexing)) # TODO: rest? +function tupleindex2optic(indexing::Tuple) + return IndexLens(first(indexing)) ∘ tupleindex2optic(indexing[2:end]) end """ @@ -104,70 +104,90 @@ julia> getsym(@varname(y)) getsym(vn::VarName{sym}) where {sym} = sym """ - getlens(vn::VarName) + getoptic(vn::VarName) -Return the lens of the Julia variable used to generate `vn`. +Return the optic of the Julia variable used to generate `vn`. ## Examples ```jldoctest -julia> getlens(@varname(x[1][2:3])) -(@lens _[1][2:3]) +julia> getoptic(@varname(x[1][2:3])) +(@o _[1][2:3]) -julia> getlens(@varname(y)) -(@lens _) +julia> getoptic(@varname(y)) +identity (generic function with 1 method) ``` """ -getlens(vn::VarName) = vn.lens - +getoptic(vn::VarName) = vn.optic """ get(obj, vn::VarName{sym}) -Alias for `get(obj, PropertyLens{sym}() ∘ getlens(vn))`. +Alias for `getoptic(vn)(obj)`. + +# Example + +```jldoctest; setup = :(nt = (a = 1, b = (c = [1, 2, 3],)); name = :nt) +julia> get(nt, @varname(nt.a)) +1 + +julia> get(nt, @varname(nt.b.c[1])) +1 + +julia> get(nt, @varname(\$name.b.c[1])) +1 +``` """ -function Setfield.get(obj, vn::VarName{sym}) where {sym} - return Setfield.get(obj, PropertyLens{sym}() ∘ getlens(vn)) +function Base.get(obj, vn::VarName{sym}) where {sym} + return getoptic(vn)(obj) end """ set(obj, vn::VarName{sym}, value) -Alias for `set(obj, PropertyLens{sym}() ∘ getlens(vn), value)`. +Alias for `set(obj, PropertyLens{sym}() ⨟ getoptic(vn), value)`. + +# Example + +```jldoctest; setup = :(using AbstractPPL: Accessors; nt = (a = 1, b = (c = [1, 2, 3],)); name = :nt) +julia> Accessors.set(nt, @varname(a), 10) +(a = 10, b = (c = [1, 2, 3],)) + +julia> Accessors.set(nt, @varname(b.c[1]), 10) +(a = 1, b = (c = [10, 2, 3],)) +``` """ -function Setfield.set(obj, vn::VarName{sym}, value) where {sym} - return Setfield.set(obj, PropertyLens{sym}() ∘ getlens(vn), value) +function Accessors.set(obj, vn::VarName{sym}, value) where {sym} + return Accessors.set(obj, PropertyLens{sym}() ⨟ getoptic(vn), value) end -Base.hash(vn::VarName, h::UInt) = hash((getsym(vn), getlens(vn)), h) +Base.hash(vn::VarName, h::UInt) = hash((getsym(vn), getoptic(vn)), h) function Base.:(==)(x::VarName, y::VarName) - return getsym(x) == getsym(y) && getlens(x) == getlens(y) + return getsym(x) == getsym(y) && getoptic(x) == getoptic(y) end -# Allow compositions with lenses. -function Base.:∘(vn::VarName{sym,<:Lens}, lens::Lens) where {sym} - return VarName{sym}(getlens(vn) ∘ lens) -end - -function Base.show(io::IO, vn::VarName{<:Any,<:Lens}) - # No need to check `Setfield.has_atlens_support` since - # `VarName` does not allow dynamic lenses. +function Base.show(io::IO, vn::VarName{sym,T}) where {sym,T} print(io, getsym(vn)) - _print_application(io, getlens(vn)) + _show_optic(io, getoptic(vn)) end -# This is all just to allow to convert `Colon()` into `:`. -_print_application(io::IO, l::Lens) = Setfield.print_application(io, l) -function _print_application(io::IO, l::ComposedLens) - _print_application(io, l.outer) - _print_application(io, l.inner) +# modified from https://github.com/JuliaObjects/Accessors.jl/blob/01528a81fdf17c07436e1f3d99119d3f635e4c26/src/sugar.jl#L502 +function _show_optic(io::IO, optic) + opts = Accessors.deopcompose(optic) + inner = Iterators.takewhile(x -> applicable(_shortstring, "", x), opts) + outer = Iterators.dropwhile(x -> applicable(_shortstring, "", x), opts) + if !isempty(outer) + show(io, opcompose(outer...)) + print(io, " ∘ ") + end + shortstr = reduce(_shortstring, inner; init="") + print(io, shortstr) end -_print_application(io::IO, l::IndexLens) = - print(io, "[", join(map(prettify_index, l.indices), ","), "]") -# This is a bit weird but whatever. We're almost always going to -# `concretize` anyways. -_print_application(io::IO, l::DynamicIndexLens) = print(io, l, "(_)") + +_shortstring(prev, o::IndexLens) = "$prev[$(join(map(prettify_index, o.indices), ", "))]" +_shortstring(prev, ::typeof(identity)) = "$prev" +_shortstring(prev, o) = Accessors._shortstring(prev, o) prettify_index(x) = repr(x) prettify_index(::Colon) = ":" @@ -175,7 +195,7 @@ prettify_index(::Colon) = ":" """ Symbol(vn::VarName) -Return a `Symbol` represenation of the variable identifier `VarName`. +Return a `Symbol` representation of the variable identifier `VarName`. # Examples ```jldoctest @@ -266,7 +286,7 @@ Currently _not_ supported are: - Trailing ones: `x[2, 1]` does not subsume `x[2]` for a vector `x` """ function subsumes(u::VarName, v::VarName) - return getsym(u) == getsym(v) && subsumes(u.lens, v.lens) + return getsym(u) == getsym(v) && subsumes(getoptic(u), getoptic(v)) end # Idea behind `subsumes` for `Lens` is that we traverse the two lenses in parallel, @@ -274,20 +294,20 @@ end # `PropertyLens{:a}` and `PropertyLens{:b}` we immediately know that they do not subsume # each other since at the same level/depth they access different properties. # E.g. `x`, `x[1]`, i.e. `u` is always subsumed by `t` -subsumes(::IdentityLens, ::IdentityLens) = true -subsumes(::IdentityLens, ::Lens) = true -subsumes(::Lens, ::IdentityLens) = false +subsumes(::typeof(identity), ::typeof(identity)) = true +subsumes(::typeof(identity), ::ALLOWED_OPTICS) = true +subsumes(::ALLOWED_OPTICS, ::typeof(identity)) = false -subsumes(t::ComposedLens, u::ComposedLens) = +subsumes(t::ComposedOptic, u::ComposedOptic) = subsumes(t.outer, u.outer) && subsumes(t.inner, u.inner) # If `t` is still a composed lens, then there is no way it can subsume `u` since `u` is a # leaf of the "lens-tree". -subsumes(t::ComposedLens, u::PropertyLens) = false +subsumes(t::ComposedOptic, u::PropertyLens) = false # Here we need to check if `u.outer` (i.e. the next lens to be applied from `u`) is # subsumed by `t`, since this would mean that the rest of the composition is also subsumed # by `t`. -subsumes(t::PropertyLens, u::ComposedLens) = subsumes(t, u.outer) +subsumes(t::PropertyLens, u::ComposedOptic) = subsumes(t, u.inner) # For `PropertyLens` either they have the same `name` and thus they are indeed the same. subsumes(t::PropertyLens{name}, u::PropertyLens{name}) where {name} = true @@ -299,8 +319,8 @@ subsumes(t::PropertyLens, u::PropertyLens) = false # FIXME: Does not correctly handle cases such as `subsumes(x, x[:])` # (but neither did old implementation). subsumes( - t::Union{IndexLens,ComposedLens{<:IndexLens}}, - u::Union{IndexLens,ComposedLens{<:IndexLens}} + t::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}}, + u::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}} ) = subsumes_indices(t, u) @@ -317,7 +337,7 @@ const ≍ = uncomparable # Therefore we must recurse until we reach something that is NOT # indexing, and then consider the sequence of indices leading up to this. """ - subsumes_indices(t::Lens, u::Lens) + subsumes_indices(t, u) Return `true` if the indexing represented by `t` subsumes `u`. @@ -326,12 +346,12 @@ e.g. `_[1][2].a[2]` and `_[1][2].a`. In such a scenario we do the following: 1. Combine `[1][2]` into a `Tuple` of indices using [`combine_indices`](@ref). 2. Do the same for `[1][2]`. 3. Compare the two tuples from (1) and (2) using `subsumes_indices`. -4. Since we're still undecided, we call `subsume(@lens(_.a[2]), @lens(_.a))` +4. Since we're still undecided, we call `subsume(@o(_.a[2]), @o(_.a))` which then returns `false`. # Example -```jldoctest; setup=:(using Setfield; using AbstractPPL: subsumes_indices) -julia> t = @lens(_[1].a); u = @lens(_[1]); +```jldoctest; setup=:(using Accessors; using AbstractPPL: subsumes_indices) +julia> t = @o(_[1].a); u = @o(_[1]); julia> subsumes_indices(t, u) false @@ -339,22 +359,22 @@ false julia> subsumes_indices(u, t) true -julia> # `IdentityLens` subsumes all. - subsumes_indices(@lens(_), t) +julia> # `identity` subsumes all. + subsumes_indices(identity, t) true -julia> # None subsumes `IdentityLens`. - subsumes_indices(t, @lens(_)) +julia> # None subsumes `identity`. + subsumes_indices(t, identity) false -julia> AbstractPPL.subsumes(@lens(_[1][2].a[2]), @lens(_[1][2].a)) +julia> AbstractPPL.subsumes(@o(_[1][2].a[2]), @o(_[1][2].a)) false -julia> AbstractPPL.subsumes(@lens(_[1][2].a), @lens(_[1][2].a[2])) +julia> AbstractPPL.subsumes(@o(_[1][2].a), @o(_[1][2].a[2])) true ``` """ -function subsumes_indices(t::Lens, u::Lens) +function subsumes_indices(t::ALLOWED_OPTICS, u::ALLOWED_OPTICS) t_indices, t_next = combine_indices(t) u_indices, u_next = combine_indices(u) @@ -378,18 +398,18 @@ function subsumes_indices(t::Lens, u::Lens) end """ - combine_indices(lens) + combine_indices(optic) Return sequential indexing into a single `Tuple` of indices, e.g. `x[:][1][2]` becomes `((Colon(), ), (1, ), (2, ))`. The result is compatible with [`subsumes_indices`](@ref) for `Tuple` input. """ -combine_indices(lens::Lens) = (), lens -combine_indices(lens::IndexLens) = (lens.indices,), nothing -function combine_indices(lens::ComposedLens{<:IndexLens}) - indices, next = combine_indices(lens.inner) - return (lens.outer.indices, indices...), next +combine_indices(optic::ALLOWED_OPTICS) = (), optic +combine_indices(optic::IndexLens) = (optic.indices,), nothing +function combine_indices(optic::ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}) + indices, next = combine_indices(optic.outer) + return (optic.inner.indices, indices...), next end """ @@ -427,11 +447,11 @@ subsumes_index(i, j) = i == j An indexing object wrapping the range of a `Base.Slice` object representing the concrete indices a `:` indicates. Behaves the same, but prints differently, namely, still as `:`. """ -struct ConcretizedSlice{T, R} <: AbstractVector{T} +struct ConcretizedSlice{T,R} <: AbstractVector{T} range::R end -ConcretizedSlice(s::Base.Slice{R}) where {R} = ConcretizedSlice{eltype(s.indices), R}(s.indices) +ConcretizedSlice(s::Base.Slice{R}) where {R} = ConcretizedSlice{eltype(s.indices),R}(s.indices) Base.show(io::IO, s::ConcretizedSlice) = print(io, ":") Base.show(io::IO, ::MIME"text/plain", s::ConcretizedSlice) = print(io, "ConcretizedSlice(", s.range, ")") @@ -459,9 +479,8 @@ reconcretize_index(original_index, lowered_index) = lowered_index reconcretize_index(original_index::Colon, lowered_index::Base.Slice) = ConcretizedSlice(lowered_index) - """ - concretize(l::Lens, x) + concretize(l, x) Return `l` instantiated on `x`, i.e. any information related to the runtime shape of `x` is evaluated. This concerns `begin`, `end`, and `:` slices. @@ -470,12 +489,12 @@ Basically, every index is converted to a concrete value using `Base.to_index` on slices are only converted to `ConcretizedSlice` (as opposed to `Base.Slice{Base.OneTo}`), to keep the result close to the original indexing. """ -concretize(I::Lens, x) = I +concretize(I::ALLOWED_OPTICS, x) = I concretize(I::DynamicIndexLens, x) = concretize(IndexLens(I.f(x)), x) concretize(I::IndexLens, x) = IndexLens(reconcretize_index.(I.indices, to_indices(x, I.indices))) -function concretize(I::ComposedLens, x) - x_inner = get(x, I.outer) # TODO: get view here - return ComposedLens(concretize(I.outer, x), concretize(I.inner, x_inner)) +function concretize(I::ComposedOptic, x) + x_inner = I.inner(x) # TODO: get view here + return ComposedOptic(concretize(I.outer, x_inner), concretize(I.inner, x)) end """ @@ -485,11 +504,11 @@ Return `vn` concretized on `x`, i.e. any information related to the runtime shap evaluated. This concerns `begin`, `end`, and `:` slices. # Examples -```jldoctest; setup=:(using Setfield) +```jldoctest; setup=:(using Accessors) julia> x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], ); -julia> getlens(@varname(x.a[1:end, end][:], true)) # concrete=true required for @varname -(@lens _.a[1:3, 2][:]) +julia> getoptic(@varname(x.a[1:end, end][:], true)) # concrete=true required for @varname +(@o _.a[1:3, 2][:]) julia> y = zeros(10, 10); @@ -497,11 +516,11 @@ julia> @varname(y[:], true) y[:] julia> # The underlying value is conretized, though: - AbstractPPL.getlens(AbstractPPL.concretize(@varname(y[:]), y)).indices[1] + AbstractPPL.getoptic(AbstractPPL.concretize(@varname(y[:]), y)).indices[1] ConcretizedSlice(Base.OneTo(100)) ``` """ -concretize(vn::VarName, x) = VarName(vn, concretize(getlens(vn), x)) +concretize(vn::VarName, x) = VarName(vn, concretize(getoptic(vn), x)) """ @varname(expr, concretize=false) @@ -521,7 +540,7 @@ concretized as `VarName` only supports non-dynamic indexing as determined by julia> x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], ); julia> @varname(x.a[1:end, end][:], true) -x.a[1:3,2][:] +x.a[1:3, 2][:] julia> @varname(x.a[end], false) # disable concretization ERROR: LoadError: Variable name `x.a[end]` is dynamic and requires concretization! @@ -542,46 +561,46 @@ julia> # Potentially surprising behaviour, but this is equivalent to what Base d ### General indexing -Under the hood Setfield.jl's `Lens` are used for the indexing: +Under the hood `optic`s are used for the indexing: ```jldoctest -julia> getlens(@varname(x)) -(@lens _) +julia> getoptic(@varname(x)) +identity (generic function with 1 method) -julia> getlens(@varname(x[1])) -(@lens _[1]) +julia> getoptic(@varname(x[1])) +(@o _[1]) -julia> getlens(@varname(x[:, 1])) -(@lens _[Colon(), 1]) +julia> getoptic(@varname(x[:, 1])) +(@o _[Colon(), 1]) -julia> getlens(@varname(x[:, 1][2])) -(@lens _[Colon(), 1][2]) +julia> getoptic(@varname(x[:, 1][2])) +(@o _[Colon(), 1][2]) -julia> getlens(@varname(x[1,2][1+5][45][3])) -(@lens _[1, 2][6][45][3]) +julia> getoptic(@varname(x[1,2][1+5][45][3])) +(@o _[1, 2][6][45][3]) ``` This also means that we support property access: ```jldoctest -julia> getlens(@varname(x.a)) -(@lens _.a) +julia> getoptic(@varname(x.a)) +(@o _.a) -julia> getlens(@varname(x.a[1])) -(@lens _.a[1]) +julia> getoptic(@varname(x.a[1])) +(@o _.a[1]) -julia> x = (a = [(b = rand(2), )], ); getlens(@varname(x.a[1].b[end], true)) -(@lens _.a[1].b[2]) +julia> x = (a = [(b = rand(2), )], ); getoptic(@varname(x.a[1].b[end], true)) +(@o _.a[1].b[2]) ``` -Interpolation can be used for names (the base name as well as property names). Variables within -indices are always evaluated in the calling scope, in the same manner as `Setfield` does: +Interpolation can be used for variable names, or array name, but not the lhs of a `.` expression. +Variables within indices are always evaluated in the calling scope. ```jldoctest julia> name, i = :a, 10; julia> @varname(x.\$name[i, i+1]) -x.a[10,11] +x.a[10, 11] julia> @varname(\$name) a @@ -595,47 +614,43 @@ a.x[1] julia> @varname(b.\$name.x[1]) b.a.x[1] ``` - -!!! compat "Julia 1.5" - Using `begin` in an indexing expression to refer to the first index requires at least - Julia 1.5. """ -macro varname(expr::Union{Expr,Symbol}, concretize::Bool=Setfield.need_dynamic_lens(expr)) +macro varname(expr::Union{Expr,Symbol}, concretize::Bool=Accessors.need_dynamic_optic(expr)) return varname(expr, concretize) end varname(sym::Symbol) = :($(AbstractPPL.VarName){$(QuoteNode(sym))}()) varname(sym::Symbol, _) = varname(sym) -function varname(expr::Expr, concretize=Setfield.need_dynamic_lens(expr)) +function varname(expr::Expr, concretize=Accessors.need_dynamic_optic(expr)) if Meta.isexpr(expr, :ref) || Meta.isexpr(expr, :.) # Split into object/base symbol and lens. - sym_escaped, lens = Setfield.parse_obj_lens(expr) + sym_escaped, optics = _parse_obj_optic(expr) # Setfield.jl escapes the return symbol, so we need to unescape # to call `QuoteNode` on it. sym = drop_escape(sym_escaped) # This is to handle interpolated heads -- Setfield treats them differently: - # julia> Setfield.parse_obj_lens(@q $name.a) - # (:($(Expr(:escape, :_))), :((Setfield.compose)($(Expr(:escape, :name)), (Setfield.PropertyLens){:a}()))) - # julia> Setfield.parse_obj_lens(@q x.a) - # (:($(Expr(:escape, :x))), :((Setfield.compose)((Setfield.PropertyLens){:a}()))) + # julia> AbstractPPL._parse_obj_optics(Meta.parse("\$name.a")) + # (:($(Expr(:escape, :_))), (:($(Expr(:escape, :name))), :((PropertyLens){:a}()))) + # julia> AbstractPPL._parse_obj_optic(:(x.a)) + # (:($(Expr(:escape, :x))), :(Accessors.opticcompose((PropertyLens){:a}()))) if sym != :_ sym = QuoteNode(sym) else - sym = lens.args[2] - lens = Expr(:call, lens.args[1], lens.args[3:end]...) + sym = optics.args[2] + optics = Expr(:call, optics.args[1], optics.args[3:end]...) end if concretize return :( $(AbstractPPL.VarName){$sym}( - $(AbstractPPL.concretize)($lens, $sym_escaped) + $(AbstractPPL.concretize)($optics, $sym_escaped) ) ) - elseif Setfield.need_dynamic_lens(expr) + elseif Accessors.need_dynamic_optic(expr) error("Variable name `$(expr)` is dynamic and requires concretization!") else - :($(AbstractPPL.VarName){$sym}($lens)) + return :($(AbstractPPL.VarName){$sym}($optics)) end elseif Meta.isexpr(expr, :$, 1) return :($(AbstractPPL.VarName){$(esc(expr.args[1]))}()) @@ -650,6 +665,66 @@ function drop_escape(expr::Expr) return Expr(expr.head, map(x -> drop_escape(x), expr.args)...) end +function _parse_obj_optic(ex) + obj, optics = _parse_obj_optics(ex) + optic = Expr(:call, :(Accessors.opticcompose), optics...) + obj, optic +end + +# Accessors doesn't have the same support for interpolation, so copy and modify Setfield's parsing functions +is_interpolation(x) = x isa Expr && x.head == :$ + +function _parse_obj_optics_composite(lensexprs::Vector) + if isempty(lensexprs) + return esc(:_), () + else + obj, outermostlens = _parse_obj_optics(lensexprs[1]) + innerlenses = map(lensexprs[2:end]) do innerex + o, lens = _parse_obj_optics(innerex) + @assert o == esc(:_) + lens + end + return obj, (outermostlens, innerlenses...) + end +end + +function _parse_obj_optics(ex) + if @capture(ex, ∘(opticsexprs__)) + return _parse_obj_optics_composite(opticsexprs) + elseif is_interpolation(ex) + @assert length(ex.args) == 1 + return esc(:_), (esc(ex.args[1]),) + elseif @capture(ex, front_[indices__]) + obj, frontoptics = _parse_obj_optics(front) + if any(Accessors.need_dynamic_optic, indices) + @gensym collection + indices = Accessors.replace_underscore.(indices, collection) + dims = length(indices) == 1 ? nothing : 1:length(indices) + lindices = esc.(Accessors.lower_index.(collection, indices, dims)) + optics = :($(Accessors.DynamicIndexLens)($(esc(collection)) -> ($(lindices...),))) + else + index = esc(Expr(:tuple, indices...)) + optics = :($(Accessors.IndexLens)($index)) + end + elseif @capture(ex, front_.property_) + obj, frontoptics = _parse_obj_optics(front) + if property isa Union{Symbol,String} + optics = :($(Accessors.PropertyLens){$(QuoteNode(property))}()) + elseif is_interpolation(property) + optics = :($(Accessors.PropertyLens){$(esc(property.args[1]))}()) + else + throw(ArgumentError( + string("Error while parsing :($ex). Second argument to `getproperty` can only be", + "a `Symbol` or `String` literal, received `$property` instead.") + )) + end + else + obj = esc(ex) + return obj, () + end + obj, tuple(frontoptics..., optics) +end + """ @vsym(expr) diff --git a/test/Project.toml b/test/Project.toml index 21903a1..adeb617 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,14 +1,13 @@ [deps] +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Documenter = "0.26.3, 0.27" InvertedIndices = "1" OffsetArrays = "1" -Setfield = "0.7.1, 0.8, 1" julia = "1" diff --git a/test/varname.jl b/test/varname.jl index d9b52db..d5cb4f7 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -1,9 +1,11 @@ +using Accessors using InvertedIndices using OffsetArrays -using Setfield using AbstractPPL: ⊑, ⊒, ⋢, ⋣, ≍ +using AbstractPPL: Accessors +using AbstractPPL.Accessors: IndexLens, PropertyLens macro test_strict_subsumption(x, y) quote @@ -21,20 +23,18 @@ end @test @varname(A[:, 1][1+1]) == @varname(A[:, 1][2]) @test(@varname(A[:, 1][2]) == - VarName{:A}(@lens(_[:, 1]) ∘ @lens(_[2])) == - VarName{:A}(@lens(_[:, 1])) ∘ @lens(_[2]) == - VarName{:A}() ∘ @lens(_[:, 1]) ∘ @lens(_[2])) + VarName{:A}(@o(_[:, 1]) ⨟ @o(_[2]))) # concretization y = zeros(10, 10) x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], ); @test @varname(y[begin, i], true) == @varname(y[1, 1:10]) - @test @varname(y[:], true) == @varname(y[1:100]) - @test @varname(y[:, begin], true) == @varname(y[1:10, 1]) - @test getlens(AbstractPPL.concretize(@varname(y[:]), y)).indices[1] === + @test get(y, @varname(y[:], true)) == get(y, @varname(y[1:100])) + @test get(y, @varname(y[:, begin], true)) == get(y, @varname(y[1:10, 1])) + @test getoptic(AbstractPPL.concretize(@varname(y[:]), y)).indices[1] === AbstractPPL.ConcretizedSlice(to_indices(y, (:,))[1]) - @test @varname(x.a[1:end, end][:], true) == @varname(x.a[1:3,2][1:3]) + @test get(x, @varname(x.a[1:end, end][:], true)) == get(x, @varname(x.a[1:3,2][1:3])) end @testset "subsumption with standard indexing" begin @@ -83,10 +83,29 @@ end @testset "non-standard indexing" begin A = rand(10, 10) - @test @varname(A[1, Not(3)], true) == @varname(A[1, [1, 2, 4, 5, 6, 7, 8, 9, 10]]) + @test get(A, @varname(A[1, Not(3)], true)) == get(A, @varname(A[1, [1, 2, 4, 5, 6, 7, 8, 9, 10]])) B = OffsetArray(A, -5, -5) # indices -4:5×-4:5 - @test @varname(B[1, :], true) == @varname(B[1, -4:5]) + @test collect(get(B, @varname(B[1, :], true))) == collect(get(B, @varname(B[1, -4:5]))) end + + @testset "type stability" begin + @inferred VarName{:a}() + @inferred VarName{:a}(IndexLens(1)) + @inferred VarName{:a}(IndexLens(1, 2)) + @inferred VarName{:a}(PropertyLens(:b)) + @inferred VarName{:a}(Accessors.opcompose(IndexLens(1), PropertyLens(:b))) + + a = [1, 2, 3] + @inferred get(a, @varname(a[1])) + + b = (a=[1, 2, 3],) + @inferred get(b, @varname(b.a[1])) + @inferred Accessors.set(b, @varname(a[1]), 10) + + c = (b=(a=[1, 2, 3],),) + @inferred get(c, @varname(c.b.a[1])) + @inferred Accessors.set(c, @varname(b.a[1]), 10) + end end