Skip to content

Commit c7b9554

Browse files
committed
Fix tests, implement copyto!
1 parent 3519be9 commit c7b9554

File tree

4 files changed

+68
-25
lines changed

4 files changed

+68
-25
lines changed

ext/LazyArraysStaticArraysExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module LazyArraysStaticArraysExt
22

33
using LazyArrays
4-
using LazyArrays: LazyArrayStyle
4+
using LazyArrays: AbstractLazyArrayStyle
55
using StaticArrays
66
using StaticArrays: StaticArrayStyle
77

@@ -10,6 +10,6 @@ function LazyArrays._vcat_layout_broadcasted((Ahead,Atail)::Tuple{SVector{M},Any
1010
Vcat(op.(Ahead,Bhead), op.(Atail,Btail))
1111
end
1212

13-
Base.BroadcastStyle(L::LazyArrayStyle{N}, ::StaticArrayStyle{N}) where N = L
13+
Base.BroadcastStyle(L::AbstractLazyArrayStyle{N}, ::StaticArrayStyle{N}) where N = L
1414

1515
end

src/cache.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,8 @@ MemoryLayout(C::Type{CachedArray{T,N,DAT,ARR}}) where {T,N,DAT,ARR} = cachedlayo
334334
######
335335

336336
struct CachedArrayStyle{N} <: AbstractLazyArrayStyle{N} end
337+
CachedArrayStyle(::Val{N}) where N = CachedArrayStyle{N}()
338+
CachedArrayStyle{M}(::Val{N}) where {N,M} = CachedArrayStyle{N}()
337339

338340
BroadcastStyle(::Type{<:AbstractCachedArray{<:Any,N}}) where N = CachedArrayStyle{N}()
339341
BroadcastStyle(::Type{<:SubArray{<:Any,N,<:AbstractCachedArray{<:Any,M}}}) where {N,M} = CachedArrayStyle{M}()
@@ -381,7 +383,36 @@ for op in (:*, :\, :+, :-)
381383
@eval layout_broadcasted(::ZerosLayout, ::CachedLayout, ::typeof($op), a::AbstractVector, b::AbstractVector) = broadcast(DefaultArrayStyle{1}(), $op, a, b)
382384
end
383385

386+
function resize_bcargs(bc::Broadcasted{<:CachedArrayStyle}, dest)
387+
rsz_args = let len = length(dest)
388+
map(bc.args) do arg
389+
resizedata!(arg, len)
390+
iscached = arg isa AbstractCachedArray || (arg isa SubArray && parent(arg) isa AbstractCachedArray)
391+
iscached ? cacheddata(arg) : arg
392+
end
393+
end
394+
return broadcasted(bc.f, rsz_args...)
395+
end
384396

397+
function similar(bc::Broadcasted{<:CachedArrayStyle}, ::Type{T}) where T
398+
return CachedArray(zeros(T, axes(bc)))
399+
end
400+
401+
function copyto!(dest::AbstractArray, bc::Broadcasted{<:CachedArrayStyle})
402+
#=
403+
Without flatten, we were observing some stack overflows in some cases for nested broadcasts, e.g.
404+
using SemiclassicalOrthogonalPolynomials, ClassicalOrthogonalPolynomials
405+
Q = Normalized(Legendre())
406+
P = SemiclassicalOrthogonalPolynomials.RaisedOP(Q, 1)
407+
A, = ClassicalOrthogonalPolynomials.recurrencecoefficients(Q)
408+
d = -inv(A[1] * SemiclassicalOrthogonalPolynomials._p0(Q) * P.ℓ[1])
409+
κ = d * SemiclassicalOrthogonalPolynomials.normalizationconstant(1, P)
410+
κ[1:2]
411+
leads to a stack overflow.
412+
=#
413+
rsz_bc = resize_bcargs(Base.Broadcast.flatten(bc), dest)
414+
copyto!(dest, rsz_bc)
415+
end
385416

386417
###
387418
# norm

test/cachetests.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,6 @@ using Infinities
459459
@test a[1:5] == zeros(5)
460460
end
461461

462-
463462
@testset "Issue #327" begin
464463
A = cache(Zeros((1:5, OneToInf())))
465464
B = cache(Zeros((1:5, OneToInf())))
@@ -487,6 +486,40 @@ using Infinities
487486
B[5, 7] = 3.4
488487
@test A == B
489488
end
489+
490+
@testset "copyto! with CachedArrayStyle" begin
491+
a = Accumulate(*, 1:5);
492+
b = BroadcastVector(*, 2, a);
493+
dest = Vector{Int}(undef, 3)
494+
src = view(b, 1:3)
495+
bc = LazyArrays._broadcastarray2broadcasted(src);
496+
@test similar(bc, Float32) == cache(zeros(Float32, 3)) && similar(bc, Float32) isa CachedArray{Float32}
497+
@test a.datasize == (1,)
498+
@inferred LazyArrays.resize_bcargs(bc, dest);
499+
@test a.datasize == (3,)
500+
dest = Vector{Int}(undef, 1)
501+
src = view(b, 5:5);
502+
bc = LazyArrays._broadcastarray2broadcasted(src);
503+
@inferred LazyArrays.resize_bcargs(bc, dest);
504+
@test a.datasize == (5,)
505+
506+
a = Accumulate(*, 1:5); # reset to test different resizing
507+
b = BroadcastVector(*, 2, a);
508+
dest = Vector{Int}(undef, 4)
509+
src = view(b,2:5)
510+
bc = LazyArrays._broadcastarray2broadcasted(src);
511+
rbc = LazyArrays.resize_bcargs(bc, dest);
512+
@test Base.Broadcast.BroadcastStyle(typeof(rbc)) == Base.Broadcast.DefaultArrayStyle{1}()
513+
@test rbc.f === bc.f
514+
@test rbc.args == (2, a[2:5])
515+
516+
a = Accumulate(*, 1:5); # reset to ensure copyto! is working as intended
517+
b = BroadcastVector(*, 2, a);
518+
dest = Vector{Int}(undef, 3);
519+
src = view(b,2:4);
520+
copyto!(dest, src)
521+
@test dest == [4,12,48]
522+
end
490523
end
491524

492525
end # module

test/runtests.jl

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -379,27 +379,6 @@ end
379379
@test a[end] prod(1 .+ (1:10_000_000).^(-2.0))
380380
@test LazyArrays.AccumulateAbstractVector(*, 1:5) == Accumulate(*, 1:5)
381381
@test LazyArrays.AccumulateAbstractVector(*, 1:5) isa LazyArrays.AccumulateAbstractVector
382-
383-
@testset "Broadcasted Cached" begin
384-
a = Accumulate(*, 1:5)
385-
b = BroadcastVector(*, 2, a);
386-
387-
dest = Vector{Int}(undef, 3)
388-
copyto!(dest, view(b,1:3))
389-
390-
# lets step through the copyto! to reduce to MWE
391-
bc = LazyArrays._broadcastarray2broadcasted(view(b,1:3))
392-
# this is equivalent to
393-
v = view(a,1:3)
394-
bc = broadcasted(*, 2, v)
395-
396-
copyto!(dest, bc)
397-
398-
399-
# what we want:
400-
resizedata!(v, length(dest))
401-
copyto!(dest, broadcasted(*, 2, LazyArrays.cacheddata(v)))
402-
end
403382
end
404383
end
405384

@@ -465,4 +444,4 @@ end
465444

466445
include("blocktests.jl")
467446
include("bandedtests.jl")
468-
include("blockbandedtests.jl")
447+
include("blockbandedtests.jl")

0 commit comments

Comments
 (0)