diff --git a/src/implementation.jl b/src/implementation.jl index d98185e9..7402060a 100644 --- a/src/implementation.jl +++ b/src/implementation.jl @@ -6,19 +6,16 @@ using ThreadsBasics: chunks, @spawn using Base: @propagate_inbounds using Base.Threads: nthreads, @threads - -struct NoInit end - function tmapreduce(f, op, A; - init=NoInit(), nchunks::Int = 2 * nthreads(), split::Symbol = :batch, schedule::Symbol =:dynamic, - outputtype::Type = Any,) + outputtype::Type = Any, + kwargs...) if schedule === :dynamic - _tmapreduce(f, op, A, outputtype, init, nchunks, split) + _tmapreduce(f, op, A, outputtype, nchunks, split; kwargs...) elseif schedule === :static - _tmapreduce_static(f, op, outputtype, A, init, nchunks, split) + _tmapreduce_static(f, op, outputtype, A, nchunks, split; kwargs...) else schedule_err(schedule) end @@ -27,24 +24,14 @@ end treducemap(op, f, A; kwargs...) = tmapreduce(f, op, A; kwargs...) -function _tmapreduce(f, op, A, ::Type{OutputType}, init, nchunks, split=:batch)::OutputType where {OutputType} - if init isa NoInit - kwargs = (;) - else - kwargs = (;init) - end +function _tmapreduce(f, op, A, ::Type{OutputType}, nchunks, split=:batch; kwargs...)::OutputType where {OutputType} tasks = map(chunks(A; n=nchunks, split)) do inds @spawn mapreduce(f, op, @view(A[inds]); kwargs...) end mapreduce(fetch, op, tasks) end -function _tmapreduce_static(f, op, ::Type{OutputType}, A, init, nchunks, split) where {OutputType} - if init isa NoInit - kwargs = (;) - else - kwargs = (;init) - end +function _tmapreduce_static(f, op, ::Type{OutputType}, A, nchunks, split; kwargs...) where {OutputType} results = Vector{OutputType}(undef, min(nchunks, length(A))) @threads :static for (ichunk, inds) ∈ enumerate(chunks(A; n=nchunks, split)) results[ichunk] = mapreduce(f, op, @view(A[inds]); kwargs...)