Skip to content

Commit 60a8c8c

Browse files
committed
define __broadcast ourselves
1 parent b4de96b commit 60a8c8c

File tree

1 file changed

+27
-2
lines changed

1 file changed

+27
-2
lines changed

ext/StructArraysStaticArraysExt.jl

+27-2
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArray
3333

3434
# Broadcast overload
3535
using StaticArrays: StaticArrayStyle, similar_type, Size, SOneTo
36-
using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype, __broadcast
36+
using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype
3737
using StructArrays: isnonemptystructtype
38-
using Base.Broadcast: Broadcasted
38+
using Base.Broadcast: Broadcasted, _broadcast_getindex
3939

4040
# StaticArrayStyle has no similar defined.
4141
# Overload `try_struct_copy` instead.
@@ -79,4 +79,29 @@ end
7979
end
8080
end
8181

82+
# The `__broadcast` kernal is copied from `StaticArrays.jl`.
83+
# see https://github.com/JuliaArrays/StaticArrays.jl/blob/master/src/broadcast.jl
84+
@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
85+
sizes = [sz.parameters[1] for sz s.parameters]
86+
87+
indices = CartesianIndices(newsize)
88+
exprs = similar(indices, Expr)
89+
for (j, current_ind) enumerate(indices)
90+
exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes))
91+
exprs[j] = :(f($(exprs_vals...)))
92+
end
93+
94+
return quote
95+
Base.@_inline_meta
96+
return tuple($(exprs...))
97+
end
98+
end
99+
100+
broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex) = return :(_broadcast_getindex(a[$i], $I))
101+
function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex)
102+
li = LinearIndices(oldsize)
103+
ind = _broadcast_getindex(li, newindex)
104+
return :(a[$i][$ind])
105+
end
106+
82107
end

0 commit comments

Comments
 (0)