Skip to content

Commit

Permalink
Allow init argument for sum (#306)
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat authored Oct 10, 2024
1 parent 93c7e3b commit 381a59d
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/dispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

abstract type AbstractMutable end

function Base.sum(a::AbstractArray{<:AbstractMutable})
return operate(sum, a)
function Base.sum(a::AbstractArray{<:AbstractMutable}; kwargs...)
return operate(sum, a; kwargs...)
end

# When doing `x'y` where the elements of `x` and/or `y` are arrays, redirecting
Expand Down
13 changes: 6 additions & 7 deletions src/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,10 @@ function fused_map_reduce(op::F, args::Vararg{Any,N}) where {F<:Function,N}
return accumulator
end

function operate(::typeof(sum), a::AbstractArray)
return mapreduce(
identity,
add!!,
a;
init = zero(promote_operation(+, eltype(a), eltype(a))),
)
function operate(
::typeof(sum),
a::AbstractArray;
init = zero(promote_operation(+, eltype(a), eltype(a))),
)
return mapreduce(identity, add!!, a; init)
end
25 changes: 25 additions & 0 deletions test/dispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,28 @@ end
end
end
end

function non_mutable_sum_pr306(x)
y = zero(eltype(x))
for xi in x
y += xi
end
return y
end

@testset "sum_with_init" begin
x = convert(Vector{DummyBigInt}, 1:100)
# compilation
@allocated sum(x)
@allocated sum(x; init = DummyBigInt(0))
@allocated non_mutable_sum_pr306(x)
# now test actual allocations
no_init = @allocated sum(x)
with_init = @allocated sum(x; init = DummyBigInt(0))
no_ma = @allocated non_mutable_sum_pr306(x)
# There's an additional 16 bytes for kwarg version. Upper bound by 40 to be
# safe between Julia versions
@test with_init <= no_init + 40
# MA is at least 10-times better than no MA for this example
@test 10 * with_init < no_ma
end

0 comments on commit 381a59d

Please sign in to comment.