Skip to content

Commit

Permalink
Use _copy_impl! instead of custom function
Browse files Browse the repository at this point in the history
  • Loading branch information
nalimilan committed Oct 6, 2018
1 parent e811fa6 commit c7b6896
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 44 deletions.
13 changes: 11 additions & 2 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,9 @@ emptymutable(itr, ::Type{U}) where {U} = Vector{U}()

## from general iterable to any array

function copyto!(dest::AbstractArray, src)
copyto!(dest::AbstractArray, src) = _copyto_impl!(dest, src, true)

function _copyto_impl!(dest::AbstractArray, src, allowshorter::Bool)
destiter = eachindex(dest)
y = iterate(destiter)
for x in src
Expand All @@ -649,6 +651,9 @@ function copyto!(dest::AbstractArray, src)
dest[y[1]] = x
y = iterate(destiter, y[2])
end
if !allowshorter && y !== nothing
throw(ArgumentError(string("source has fewer elements than destination")))
end
return dest
end

Expand Down Expand Up @@ -720,8 +725,12 @@ end
## copy between abstract arrays - generally more efficient
## since a single index variable can be used.

copyto!(dest::AbstractArray, src::AbstractArray) =
function _copyto_impl!(dest::AbstractArray, src::AbstractArray, allowshorter::Bool)
if !allowshorter && length(src) < length(dest)
throw(ArgumentError("source has fewer elements than destination"))
end
copyto!(IndexStyle(dest), dest, IndexStyle(src), src)
end

function copyto!(::IndexStyle, dest::AbstractArray, ::IndexStyle, src::AbstractArray)
destinds, srcinds = LinearIndices(dest), LinearIndices(src)
Expand Down
34 changes: 12 additions & 22 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,12 @@ function copyto!(dest::Array{T}, doffs::Integer, src::Array{T}, soffs::Integer,
unsafe_copyto!(dest, doffs, src, soffs, n)
end

copyto!(dest::Array{T}, src::Array{T}) where {T} = copyto!(dest, 1, src, 1, length(src))
function _copyto_impl!(dest::Array{T}, src::Array{T}, allowshorter::Bool) where {T}
if !allowshorter && length(src) < length(dest)
throw(ArgumentError("source has fewer elements than destination"))
end
copyto!(dest, 1, src, 1, length(src))
end

# N.B: The generic definition in multidimensional.jl covers, this, this is just here
# for bootstrapping purposes.
Expand Down Expand Up @@ -517,25 +522,10 @@ julia> collect(Float64, 1:2:5)
"""
collect(::Type{T}, itr) where {T} = _collect(T, itr, IteratorSize(itr))

function copyto_check_length!(dest::Array, src)
len = length(dest)
i = 0
for x in src
i == len &&
throw(ErrorException("iterator returned more elements than its declared length"))
i += 1
@inbounds dest[i] = x
end
if i < len
throw(ErrorException("iterator returned fewer elements than its declared length"))
end
return dest
end

_collect(::Type{T}, itr, isz::HasLength) where {T} =
copyto_check_length!(Vector{T}(undef, Int(length(itr)::Integer)), itr)
_copyto_impl!(Vector{T}(undef, Int(length(itr)::Integer)), itr, false)
_collect(::Type{T}, itr, isz::HasShape) where {T} =
copyto_check_length!(similar(Array{T}, axes(itr)), itr)
_copyto_impl!(similar(Array{T}, axes(itr)), itr, false)
function _collect(::Type{T}, itr, isz::SizeUnknown) where T
a = Vector{T}()
for x in itr
Expand Down Expand Up @@ -578,7 +568,7 @@ collect(A::AbstractArray) = _collect_indices(axes(A), A)
collect_similar(cont, itr) = _collect(cont, itr, IteratorEltype(itr), IteratorSize(itr))

_collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) =
copyto_check_length!(_similar_for(cont, eltype(itr), itr, isz), itr)
_copyto_impl!(_similar_for(cont, eltype(itr), itr, isz), itr, false)

function _collect(cont, itr, ::HasEltype, isz::SizeUnknown)
a = _similar_for(cont, eltype(itr), itr, isz)
Expand Down Expand Up @@ -636,7 +626,7 @@ function collect(itr::Generator)
y = iterate(itr)
if y === nothing
if isa(isz, Union{HasLength, HasShape}) && length(itr) != 0
throw(ErrorException("iterator returned fewer elements than its declared length"))
throw(ArgumentError("iterator returned fewer elements than its declared length"))
end
return _array_for(et, itr.iter, isz)
end
Expand Down Expand Up @@ -689,9 +679,9 @@ function collect_to!(dest::AbstractArray{T}, itr, offs, st) where T
end
lastidx = lastindex(dest)
i-1 < lastidx &&
throw(ErrorException("iterator returned fewer elements than its declared length"))
throw(ArgumentError("iterator returned fewer elements than its declared length"))
i-1 > lastidx &&
throw(ErrorException("iterator returned more elements than its declared length"))
throw(ArgumentError("iterator returned more elements than its declared length"))
return dest
end

Expand Down
6 changes: 5 additions & 1 deletion base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -815,8 +815,12 @@ julia> y
"""
copyto!(dest, src)

function copyto!(dest::AbstractArray{T,N}, src::AbstractArray{T,N}) where {T,N}
function _copyto_impl!(dest::AbstractArray{T,N}, src::AbstractArray{T,N},
allowshorter::Bool) where {T,N}
checkbounds(dest, axes(src)...)
if !allowshorter && length(src) < length(dest)
throw(ArgumentError("source has fewer elements than destination"))
end
src′ = unalias(dest, src)
for I in eachindex(IndexStyle(src′,dest), src′)
@inbounds dest[I] = src′[I]
Expand Down
36 changes: 17 additions & 19 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2525,29 +2525,27 @@ Base.view(::T25958, args...) = args
@test t[end,end,end] == @view(t[end,end,end]) == @views t[end,end,end]
end

# Iterator with declared length too large
struct InvalidIter1 end
Base.length(::InvalidIter1) = 2
Base.iterate(::InvalidIter1, i=1) = i > 1 ? nothing : (i, (i + 1))
# Iterator with declared length too small
struct InvalidIter2 end
Base.length(::InvalidIter2) = 2
Base.iterate(::InvalidIter2, i=1) = i > 3 ? nothing : (i, (i + 1))
@testset "collect on iterator with incorrect length" begin
# Iterator with declared length too large
struct InvalidIter1 end
Base.length(::InvalidIter1) = 2
Base.iterate(::InvalidIter1, i=1) = i > 1 ? nothing : (i, (i + 1))

@test_throws ErrorException collect(InvalidIter1())
@test_throws ErrorException collect(Any, InvalidIter1())
@test_throws ErrorException collect(Int, InvalidIter1())
@test_throws ErrorException [x for x in InvalidIter1()]
# Should also throw ErrorException
@test_throws ArgumentError collect(InvalidIter1())
@test_throws ArgumentError collect(Any, InvalidIter1())
@test_throws ArgumentError collect(Int, InvalidIter1())
@test_throws ArgumentError [x for x in InvalidIter1()]
# Should also throw ArgumentError
@test_broken length(Int[x for x in InvalidIter1()]) != 2

# Iterator with declared length too small
struct InvalidIter2 end
Base.length(::InvalidIter2) = 2
Base.iterate(::InvalidIter2, i=1) = i > 3 ? nothing : (i, (i + 1))

@test_throws ErrorException collect(InvalidIter2())
@test_throws ErrorException collect(Any, InvalidIter2())
@test_throws ErrorException collect(Int, InvalidIter2())
@test_throws ArgumentError collect(InvalidIter2())
@test_throws ArgumentError collect(Any, InvalidIter2())
@test_throws ArgumentError collect(Int, InvalidIter2())
# These cases cannot be tested without writing to invalid memory
# unless the function checked bounds on each iteration (#29458)
# @test_throws ErrorException [x for x in InvalidIter2()]
# @test_throws ErrorException Int[x for x in InvalidIter2()]
# @test_broken length(Int[x for x in InvalidIter2()]) != 2
end

0 comments on commit c7b6896

Please sign in to comment.