Add inclusive_scan with initial value support (warp/block) #1749

Merged: 28 commits, Jun 11, 2024
2f93334
Add warp inclusive_scan with init
gonidelis May 15, 2024
66bcb0a
Add docs for warp scan inclusive scan
gonidelis May 16, 2024
9c33ead
Fix docs separator pitfall
gonidelis May 17, 2024
9c6da03
Warp scan docs and readability fixes
gonidelis May 20, 2024
917b496
Add block inclusive_scan with init value (block_aggregate included)
gonidelis May 20, 2024
30988ea
Add docs for block scan inclusive scan
gonidelis May 20, 2024
fba1605
Add block scan raking and warp scans implementation support for initi…
gonidelis May 22, 2024
87584bd
Input cannot be const for Block::InclusiveScan
gonidelis May 22, 2024
e41ccea
Add tests for value based block scan APIs
gonidelis May 23, 2024
fbf080b
Fix block InclusiveScan value-based implementation internals
gonidelis May 29, 2024
a29693e
Refine blockScan unit tests
gonidelis May 29, 2024
402d435
Add value based API tests
gonidelis May 31, 2024
6b8c673
Revert "Input cannot be const for Block::InclusiveScan"
gonidelis May 31, 2024
bd3e8b6
Revert "Add value based API tests"
gonidelis May 31, 2024
52b1a0a
Revert "Fix block InclusiveScan value-based implementation internals"
gonidelis May 31, 2024
76379d7
Revert "Add tests for value based block scan APIs"
gonidelis May 31, 2024
632a86a
Revert "Revert "Input cannot be const for Block::InclusiveScan""
gonidelis May 31, 2024
ef07f8f
Remove value-based kernels from off branch
gonidelis May 31, 2024
45182ff
Add array based API tests for block inclusive scan
gonidelis May 31, 2024
403f4a1
Add device inclusive scan with init value
gonidelis Jun 1, 2024
156e804
Revert "Add device inclusive scan with init value"
gonidelis Jun 1, 2024
1bd8ffa
Resolve CI issues on block_scan_api test
gonidelis Jun 3, 2024
8394a55
Remove more value based APIs
gonidelis Jun 5, 2024
f7b7651
Remove even more value based APIs, complement docs with details
gonidelis Jun 5, 2024
c9131ac
Resolve reviews, add warp_scan API test
gonidelis Jun 7, 2024
9582835
Add warp_aggregate API test
gonidelis Jun 7, 2024
e41e85a
Improve warp_scan api test
gonidelis Jun 10, 2024
690571a
Resolve final reviews
gonidelis Jun 10, 2024
120 changes: 120 additions & 0 deletions cub/cub/block/block_scan.cuh
@@ -2242,6 +2242,63 @@ public:
}
}

//! @rst
//! Computes an inclusive block-wide prefix scan using the specified binary ``scan_op`` functor.
//! Each thread contributes an array of consecutive input elements.
//!
//! - Supports non-commutative scan operators.
//! - @blocked
//! - @granularity
//! - @smemreuse
//!
//! Snippet
//! +++++++
//!
//! The code snippet below illustrates an inclusive prefix max scan of 128 integer items that
//! are partitioned in a :ref:`blocked arrangement <flexible-data-arrangement>` across 64 threads
//! where each thread owns 2 consecutive items.
//!
//! .. literalinclude:: ../../test/catch2_test_block_scan_api.cu
//!
//! :language: c++
//! :dedent:
//! :start-after: example-begin inclusive-scan-array-init-value
//! :end-before: example-end inclusive-scan-array-init-value
//!
//!
//! @endrst
//!
//! @tparam ITEMS_PER_THREAD
//! **[inferred]** The number of consecutive items partitioned onto each thread.
//!
//! @tparam ScanOp
//! **[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)`
//!
//! @param[in] input
//! Calling thread's input items
//!
//! @param[out] output
//! Calling thread's output items (may be aliased to `input`)
//!
//! @param[in] initial_value
//! Initial value to seed the inclusive scan (uniform across block)
//!
//! @param[in] scan_op
//! Binary scan functor
template <int ITEMS_PER_THREAD, typename ScanOp>
_CCCL_DEVICE _CCCL_FORCEINLINE void
InclusiveScan(T (&input)[ITEMS_PER_THREAD], T (&output)[ITEMS_PER_THREAD], T initial_value, ScanOp scan_op)
{
// Reduce consecutive thread items in registers
T thread_prefix = internal::ThreadReduce(input, scan_op);

// Exclusive thread block-scan
ExclusiveScan(thread_prefix, thread_prefix, initial_value, scan_op);

// Inclusive scan in registers with prefix as seed
internal::ThreadScanInclusive(input, output, scan_op, thread_prefix);
}
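The three steps in this overload (per-thread reduction, exclusive scan of the per-thread aggregates seeded with `initial_value`, then a seeded inclusive scan inside each thread) can be sketched sequentially on the host. This is only a model of the semantics, not CUB code: `model_block_inclusive_scan`, `num_threads` and `items_per_thread` are hypothetical names introduced for illustration, and a plain loop stands in for the cooperative block-wide scan.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side model of the three steps above; not part of CUB.
// `num_threads` and `items_per_thread` stand in for the block size and ITEMS_PER_THREAD.
template <typename T, typename ScanOp>
std::vector<T> model_block_inclusive_scan(
  const std::vector<T>& input, int num_threads, int items_per_thread, T initial_value, ScanOp scan_op)
{
  assert(input.size() == static_cast<std::size_t>(num_threads) * items_per_thread);

  // Step 1: each "thread" reduces its own consecutive items.
  std::vector<T> thread_aggregate;
  for (int t = 0; t < num_threads; ++t)
  {
    T agg = input[t * items_per_thread];
    for (int i = 1; i < items_per_thread; ++i)
    {
      agg = scan_op(agg, input[t * items_per_thread + i]);
    }
    thread_aggregate.push_back(agg);
  }

  // Step 2: exclusive scan over the per-thread aggregates, seeded with initial_value.
  std::vector<T> thread_prefix;
  T running = initial_value;
  for (int t = 0; t < num_threads; ++t)
  {
    thread_prefix.push_back(running);
    running = scan_op(running, thread_aggregate[t]);
  }

  // Step 3: inclusive scan inside each thread, seeded with that thread's prefix.
  std::vector<T> output;
  for (int t = 0; t < num_threads; ++t)
  {
    T acc = thread_prefix[t];
    for (int i = 0; i < items_per_thread; ++i)
    {
      acc = scan_op(acc, input[t * items_per_thread + i]);
      output.push_back(acc);
    }
  }
  return output;
}
```

Because the seed always enters on the left-hand side of `scan_op` and items are combined in index order, the model gives the same result as one flat seeded inclusive scan over all items, which is what the "supports non-commutative scan operators" note relies on.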

//! @rst
//! Computes an inclusive block-wide prefix scan using the specified binary ``scan_op`` functor.
//! Each thread contributes an array of consecutive input elements. Also provides every thread
@@ -2325,6 +2382,69 @@ public:
}
}

//! @rst
//! Computes an inclusive block-wide prefix scan using the specified binary ``scan_op`` functor.
//! Each thread contributes an array of consecutive input elements. Also provides every thread
//! with the block-wide ``block_aggregate`` of all inputs.
//!
//! - Supports non-commutative scan operators.
//! - @blocked
//! - @granularity
//! - @smemreuse
//!
//! Snippet
//! +++++++
//!
//! The code snippet below illustrates an inclusive prefix max scan of 128 integer items that
//! are partitioned in a :ref:`blocked arrangement <flexible-data-arrangement>` across 64 threads
//! where each thread owns 2 consecutive items.
//!
//! .. literalinclude:: ../../test/catch2_test_block_scan_api.cu
//!
//! :language: c++
//! :dedent:
//! :start-after: example-begin inclusive-scan-array-aggregate-init-value
//! :end-before: example-end inclusive-scan-array-aggregate-init-value
//!
//! The value ``126`` will be stored in ``block_aggregate`` for all threads.
//!
//! @endrst
//!
//! @tparam ITEMS_PER_THREAD
//! **[inferred]** The number of consecutive items partitioned onto each thread.
//!
//! @tparam ScanOp
//! **[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)`
//!
//! @param[in] input
//! Calling thread's input items
//!
//! @param[out] output
//! Calling thread's output items (may be aliased to `input`)
//!
//! @param[in] initial_value
//! Initial value to seed the inclusive scan (uniform across block). It is not taken
//! into account for block_aggregate.
//!
//! @param[in] scan_op
//! Binary scan functor
//!
//! @param[out] block_aggregate
//! Block-wide aggregate reduction of input items
template <int ITEMS_PER_THREAD, typename ScanOp>
_CCCL_DEVICE _CCCL_FORCEINLINE void InclusiveScan(
T (&input)[ITEMS_PER_THREAD], T (&output)[ITEMS_PER_THREAD], T initial_value, ScanOp scan_op, T& block_aggregate)
{
// Reduce consecutive thread items in registers
T thread_prefix = internal::ThreadReduce(input, scan_op);

// Exclusive thread block-scan
ExclusiveScan(thread_prefix, thread_prefix, initial_value, scan_op, block_aggregate);

// Inclusive scan in registers with prefix as seed
internal::ThreadScanInclusive(input, output, scan_op, thread_prefix);
}
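The documentation for this overload states that `initial_value` seeds every output but is not taken into account for `block_aggregate`. A flat host-side reference makes that distinction visible; string concatenation is used here as a non-commutative operator so the ordering is easy to read. `ScanWithAggregate` and `reference_scan_with_aggregate` are hypothetical names for illustration only and do not exist in CUB.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical result type for this sketch; not part of CUB.
template <typename T>
struct ScanWithAggregate
{
  std::vector<T> output; // seeded inclusive scan of the inputs
  T block_aggregate;     // reduction of the inputs only, without the seed
};

// Flat sequential reference for the documented semantics (an assumption-level model).
template <typename T, typename ScanOp>
ScanWithAggregate<T> reference_scan_with_aggregate(
  const std::vector<T>& input, const T& initial_value, ScanOp scan_op)
{
  assert(!input.empty());

  ScanWithAggregate<T> result{{}, input[0]};
  T seeded = initial_value; // running value that includes the seed
  for (std::size_t i = 0; i < input.size(); ++i)
  {
    seeded = scan_op(seeded, input[i]);
    result.output.push_back(seeded);
    if (i > 0)
    {
      // The aggregate combines the inputs only; the seed never enters it.
      result.block_aggregate = scan_op(result.block_aggregate, input[i]);
    }
  }
  return result;
}
```

With inputs `a, b, c, d`, seed `^` and concatenation, the last output in this model is `^abcd` while the aggregate is `abcd`.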

//! @rst
//! Computes an inclusive block-wide prefix scan using the specified binary ``scan_op`` functor.
//! Each thread contributes an array of consecutive input elements.
113 changes: 112 additions & 1 deletion cub/cub/warp/warp_scan.cuh
@@ -475,14 +475,65 @@ public:
//! @param[out] inclusive_output
//! Calling thread's output item. May be aliased with `input`
//!
- //! @param[in] can_op
+ //! @param[in] scan_op
//! Binary scan operator
template <typename ScanOp>
_CCCL_DEVICE _CCCL_FORCEINLINE void InclusiveScan(T input, T& inclusive_output, ScanOp scan_op)
{
InternalWarpScan(temp_storage).InclusiveScan(input, inclusive_output, scan_op);
}

//! @rst
//! Computes an inclusive prefix scan using the specified binary scan functor across the
//! calling warp.
//!
//! * @smemwarpreuse
//!
//! Snippet
//! +++++++
//!
//! The code snippet below illustrates four concurrent warp-wide inclusive prefix max scans
//! within a block of 128 threads (one scan per 32-thread warp).
//!
//! .. literalinclude:: ../../../cub/test/catch2_test_warp_scan_api.cu
//! :language: c++
//! :dedent:
//! :start-after: example-begin inclusive-warp-scan-init-value
//! :end-before: example-end inclusive-warp-scan-init-value
//!
//! Suppose the set of input ``thread_data`` in the first warp is
//! ``{0, 1, 2, 3, ..., 31}``, in the second warp is ``{1, 2, 3, 4, ..., 32}`` etc.
//! The corresponding output ``thread_data`` for a max operation with an initial value of ``3``
//! in the first warp would be ``{3, 3, 3, 3, ..., 31}``, the output for the second warp would be
//! ``{3, 3, 3, 4, ..., 32}``, etc.
//! @endrst
//!
//! @tparam ScanOp
//! **[inferred]** Binary scan operator type having member
//! `T operator()(const T &a, const T &b)`
//!
//! @param[in] input
//! Calling thread's input item
//!
//! @param[out] inclusive_output
//! Calling thread's output item. May be aliased with `input`
//!
//! @param[in] initial_value
//! Initial value to seed the inclusive scan (uniform across warp)
//!
//! @param[in] scan_op
//! Binary scan operator
template <typename ScanOp>
_CCCL_DEVICE _CCCL_FORCEINLINE void InclusiveScan(T input, T& inclusive_output, T initial_value, ScanOp scan_op)
{
InternalWarpScan internal(temp_storage);

T exclusive_output;
internal.InclusiveScan(input, inclusive_output, scan_op);

internal.Update(input, inclusive_output, exclusive_output, scan_op, initial_value, Int2Type<IS_INTEGER>());
}
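This overload first runs the plain inclusive scan and then lets `Update` fold `initial_value` into each lane's result. A host-side sketch of that two-phase shape, under the assumption that `Update` applies `scan_op(initial_value, inclusive)` per lane, reproduces the numbers in the documentation comment. `model_warp_inclusive_scan` and `kWarpThreads` are hypothetical names for illustration, not CUB identifiers.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

constexpr std::size_t kWarpThreads = 32; // assumed logical warp size for this sketch

// Hypothetical model: every group of kWarpThreads inputs is scanned independently.
template <typename T, typename ScanOp>
std::vector<T> model_warp_inclusive_scan(const std::vector<T>& block_input, T initial_value, ScanOp scan_op)
{
  assert(!block_input.empty() && block_input.size() % kWarpThreads == 0);

  std::vector<T> output;
  for (std::size_t warp_begin = 0; warp_begin < block_input.size(); warp_begin += kWarpThreads)
  {
    // Phase 1: unseeded inclusive scan over the lanes of this warp.
    T inclusive = block_input[warp_begin];
    for (std::size_t lane = 0; lane < kWarpThreads; ++lane)
    {
      if (lane > 0)
      {
        inclusive = scan_op(inclusive, block_input[warp_begin + lane]);
      }
      // Phase 2: fold the seed in on the left of each lane's result.
      output.push_back(scan_op(initial_value, inclusive));
    }
  }
  return output;
}
```

For a max operation with seed `3`, inputs `{0, 1, 2, 3, ..., 31}` in the first warp give `{3, 3, 3, 3, ..., 31}` and inputs `{1, 2, 3, 4, ..., 32}` in the second give `{3, 3, 3, 4, ..., 32}`, matching the values quoted in the comment.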

//! @rst
//! Computes an inclusive prefix scan using the specified binary scan functor across the
//! calling warp. Also provides every thread with the warp-wide ``warp_aggregate`` of
@@ -544,6 +595,66 @@ public:
InternalWarpScan(temp_storage).InclusiveScan(input, inclusive_output, scan_op, warp_aggregate);
}

//! @rst
//! Computes an inclusive prefix scan using the specified binary scan functor across the
//! calling warp. Also provides every thread with the warp-wide ``warp_aggregate`` of
//! all inputs.
//!
//! * @smemwarpreuse
//!
//! Snippet
//! +++++++
//!
//! The code snippet below illustrates four concurrent warp-wide inclusive prefix sum scans
//! within a block of 128 threads (one scan per warp).
//!
//! .. literalinclude:: ../../../cub/test/catch2_test_warp_scan_api.cu
//! :language: c++
//! :dedent:
//! :start-after: example-begin inclusive-warp-scan-init-value-aggregate
//! :end-before: example-end inclusive-warp-scan-init-value-aggregate
//!
//! Suppose the set of input ``thread_data`` across the block of threads is
//! ``{1, 1, 1, 1, ..., 1}``. For initial value equal to 3, the corresponding output
//! ``thread_data`` for a sum operation in the first warp would be
//! ``{4, 5, 6, 7, ..., 35}``, the output for the second warp would be
//! ``{4, 5, 6, 7, ..., 35}``, etc. Furthermore, ``warp_aggregate`` would be assigned
//! ``32`` for threads in each warp.
//! @endrst
//!
//! @tparam ScanOp
//! **[inferred]** Binary scan operator type having member
//! `T operator()(const T &a, const T &b)`
//! @param[in] input
//! Calling thread's input item
//!
//! @param[out] inclusive_output
//! Calling thread's output item. May be aliased with ``input``
//!
//! @param[in] initial_value
//! Initial value to seed the inclusive scan (uniform across warp). It is not taken
//! into account for warp_aggregate.
//!
//! @param[in] scan_op
//! Binary scan operator
//!
//! @param[out] warp_aggregate
//! Warp-wide aggregate reduction of input items.
template <typename ScanOp>
_CCCL_DEVICE _CCCL_FORCEINLINE void
InclusiveScan(T input, T& inclusive_output, T initial_value, ScanOp scan_op, T& warp_aggregate)
{
InternalWarpScan internal(temp_storage);

// Perform the inclusive scan operation
internal.InclusiveScan(input, inclusive_output, scan_op);

// Update the inclusive_output and warp_aggregate using the Update function
T exclusive_output;
internal.Update(
input, inclusive_output, exclusive_output, warp_aggregate, scan_op, initial_value, Int2Type<IS_INTEGER>());
}
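The documented example for this overload (all inputs `1`, initial value `3`, sum operator) can be checked with a small single-warp reference built on `std::partial_sum`. It assumes the aggregate is taken from the unseeded scan before the seed is folded in, which is how the doc comment describes `warp_aggregate`; `WarpScanResult` and `reference_warp_scan` are hypothetical names and not part of CUB.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical result type for this sketch; not part of CUB.
struct WarpScanResult
{
  std::vector<int> inclusive_output;
  int warp_aggregate;
};

// Single-warp reference: scan first, read the aggregate, then apply the seed.
template <typename ScanOp>
WarpScanResult reference_warp_scan(const std::vector<int>& warp_input, int initial_value, ScanOp scan_op)
{
  assert(!warp_input.empty());

  WarpScanResult result;
  result.inclusive_output.resize(warp_input.size());

  // Unseeded inclusive scan across the lanes.
  std::partial_sum(warp_input.begin(), warp_input.end(), result.inclusive_output.begin(), scan_op);

  // The last lane of the unseeded scan is the reduction of the inputs alone.
  result.warp_aggregate = result.inclusive_output.back();

  // Fold the seed into every lane's output; the aggregate is left untouched.
  for (int& value : result.inclusive_output)
  {
    value = scan_op(initial_value, value);
  }
  return result;
}
```

For 32 lanes of `1` with seed `3` and addition, this yields outputs `4` through `35` and an aggregate of `32`, in line with the values quoted in the comment.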

//! @} end member group
//! @name Exclusive prefix scans
//! @{