Skip to content

Commit

Permalink
fix type stability of kwarg handling
Browse files Browse the repository at this point in the history
  • Loading branch information
MasonProtter committed Jan 28, 2024
1 parent b839853 commit 9fa0e32
Showing 1 changed file with 6 additions and 19 deletions.
25 changes: 6 additions & 19 deletions src/implementation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,16 @@ using ThreadsBasics: chunks, @spawn
using Base: @propagate_inbounds
using Base.Threads: nthreads, @threads


struct NoInit end

function tmapreduce(f, op, A;
init=NoInit(),
nchunks::Int = 2 * nthreads(),
split::Symbol = :batch,
schedule::Symbol =:dynamic,
outputtype::Type = Any,)
outputtype::Type = Any,
kwargs...)
if schedule === :dynamic
_tmapreduce(f, op, A, outputtype, init, nchunks, split)
_tmapreduce(f, op, A, outputtype, nchunks, split; kwargs...)
elseif schedule === :static
_tmapreduce_static(f, op, outputtype, A, init, nchunks, split)
_tmapreduce_static(f, op, outputtype, A, nchunks, split; kwargs...)
else
schedule_err(schedule)
end
Expand All @@ -27,24 +24,14 @@ end

treducemap(op, f, A; kwargs...) = tmapreduce(f, op, A; kwargs...)

function _tmapreduce(f, op, A, ::Type{OutputType}, init, nchunks, split=:batch)::OutputType where {OutputType}
if init isa NoInit
kwargs = (;)
else
kwargs = (;init)
end
function _tmapreduce(f, op, A, ::Type{OutputType}, nchunks, split=:batch; kwargs...)::OutputType where {OutputType}
tasks = map(chunks(A; n=nchunks, split)) do inds
@spawn mapreduce(f, op, @view(A[inds]); kwargs...)
end
mapreduce(fetch, op, tasks)
end

function _tmapreduce_static(f, op, ::Type{OutputType}, A, init, nchunks, split) where {OutputType}
if init isa NoInit
kwargs = (;)
else
kwargs = (;init)
end
function _tmapreduce_static(f, op, ::Type{OutputType}, A, nchunks, split; kwargs...) where {OutputType}
results = Vector{OutputType}(undef, min(nchunks, length(A)))
@threads :static for (ichunk, inds) enumerate(chunks(A; n=nchunks, split))
results[ichunk] = mapreduce(f, op, @view(A[inds]); kwargs...)
Expand Down

0 comments on commit 9fa0e32

Please sign in to comment.