Implement tcollect functionality. (#1)
* Implement `tcollect`, which essentially just calls `tmap` with a Generator input.

* Add test for `tcollect`.

* Also add array method for `tcollect`.

* Update runtests.jl

* Fix typo.

---------

Co-authored-by: Mason Protter <[email protected]>
RomeoV and MasonProtter authored Jan 29, 2024
1 parent 9c20ea2 commit 3c8f4c1
Showing 3 changed files with 31 additions and 2 deletions.
18 changes: 17 additions & 1 deletion src/ThreadsBasics.jl
@@ -4,7 +4,7 @@ module ThreadsBasics
using StableTasks: @spawn
using ChunkSplitters: chunks

export chunks, treduce, tmapreduce, treducemap, tmap, tmap!, tforeach
export chunks, treduce, tmapreduce, treducemap, tmap, tmap!, tforeach, tcollect

"""
tmapreduce(f, op, A::AbstractArray;
@@ -163,6 +163,22 @@ of `out[i] = f(A[i])` for each index `i` of `A` and `out`.
"""
function tmap! end

"""
tcollect(::Type{OutputType}, gen::Base.Generator{<:AbstractArray, F};
         nchunks::Int = 2 * nthreads(),
         split::Symbol = :batch,
         schedule::Symbol = :dynamic)
A multithreaded version of `Base.collect`. Essentially just calls `tmap` on the generator's function and its underlying iterator.
## Keyword arguments:
- `nchunks::Int` (default 2 * nthreads()) is passed to `ChunkSplitters.chunks` to inform it how many pieces of data should be worked on in parallel. Greater `nchunks` typically helps with [load balancing](https://en.wikipedia.org/wiki/Load_balancing_(computing)), but at the expense of creating more overhead.
- `split::Symbol` (default `:batch`) is passed to `ChunkSplitters.chunks` to inform it if the data chunks to be worked on should be contiguous (`:batch`) or shuffled (`:scatter`).
- `schedule::Symbol` either `:dynamic` or `:static` (default `:dynamic`), determines how the parallel portions of the calculation are scheduled. `:dynamic` scheduling is generally preferred since it is more flexible and better at load balancing, but `:static` scheduling can sometimes be more performant when the time it takes to complete a step of the calculation is highly uniform, and no other parallel functions are running at the same time.
"""
function tcollect end


include("implementation.jl")

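For reference, a minimal usage sketch of the new `tcollect` (not part of the diff). It assumes `ThreadsBasics` is loaded in a Julia session started with multiple threads; `sin` and the range `1:1_000` are purely illustrative, and the keyword arguments are forwarded as documented above.

```julia
using ThreadsBasics

# Collect a generator in parallel; same values as collect(sin(i) for i in 1:1_000),
# but the mapping work is distributed over threads via tmap.
ys = tcollect(Float64, (sin(i) for i in 1:1_000))

# Keyword arguments are forwarded, e.g. to control chunking and scheduling.
ys_static = tcollect(Float64, (sin(i) for i in 1:1_000); nchunks = 8, schedule = :static)
```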
9 changes: 8 additions & 1 deletion src/implementation.jl
@@ -1,6 +1,6 @@
module Implementation

import ThreadsBasics: treduce, tmapreduce, treducemap, tforeach, tmap, tmap!
import ThreadsBasics: treduce, tmapreduce, treducemap, tforeach, tmap, tmap!, tcollect

using ThreadsBasics: chunks, @spawn
using Base: @propagate_inbounds
@@ -67,5 +67,12 @@ end
out
end

#-------------------------------------------------------------

function tcollect(::Type{T}, gen::Base.Generator{<:AbstractArray, F}; kwargs...) where {T, F}
tmap(gen.f, T, gen.iter; kwargs...)
end
tcollect(::Type{T}, A; kwargs...) where {T} = tmap(identity, T, A; kwargs...)


end # module Implementation
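A small sketch (not part of the commit) of how the two new methods dispatch: the `Base.Generator` method unwraps `gen.f` and `gen.iter` and forwards them to `tmap`, while the second method is a catch-all that maps `identity`, giving a threaded `collect`. The names `square` and `xs` are illustrative assumptions.

```julia
using ThreadsBasics

square(x) = x^2
xs = 1:100

# Generator method: tcollect(T, (square(x) for x in xs)) forwards to tmap(square, T, xs).
a = tcollect(Int, (square(x) for x in xs))
b = tmap(square, Int, xs)
@assert a == b

# Fallback method: maps identity over the input, i.e. a threaded copy of the collection.
c = tcollect(Int, xs)
@assert c == collect(xs)
```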
6 changes: 6 additions & 0 deletions test/runtests.jl
@@ -23,8 +23,14 @@ using Test, ThreadsBasics

map_f_itr = map(f, itr)
@test all(tmap(f, Any, itr; kwargs...) .~ map_f_itr)
@test all(tcollect(Any, (f(x) for x in itr); kwargs...) .~ map_f_itr)
@test all(tcollect(Any, f.(itr); kwargs...) .~ map_f_itr)

RT = Core.Compiler.return_type(f, Tuple{eltype(itr)})

@test tmap(f, RT, itr; kwargs...) ~ map_f_itr
@test tcollect(RT, (f(x) for x in itr); kwargs...) ~ map_f_itr
@test tcollect(RT, f.(itr); kwargs...) ~ map_f_itr
end
end
end
