Skip to content

Commit

Permalink
Add block inclusive_scan with init value (block_aggregate included)
Browse files Browse the repository at this point in the history
  • Loading branch information
gonidelis committed May 20, 2024
1 parent 54cb32e commit f3142fa
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 0 deletions.
41 changes: 41 additions & 0 deletions cub/cub/block/block_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,12 @@ public:
InternalBlockScan(temp_storage).ExclusiveScan(input, output, initial_value, scan_op);
}

template <typename ScanOp>
_CCCL_DEVICE _CCCL_FORCEINLINE void InclusiveScan(T input, T& output, T initial_value, ScanOp scan_op)
{
InternalBlockScan(temp_storage).InclusiveScan(input, output, initial_value, scan_op);
}

//! @rst
//! Computes an exclusive block-wide prefix scan using the specified binary ``scan_op`` functor.
//! Each thread contributes one input element.
Expand Down Expand Up @@ -886,6 +892,13 @@ public:
InternalBlockScan(temp_storage).ExclusiveScan(input, output, initial_value, scan_op, block_aggregate);
}

template <typename ScanOp>
_CCCL_DEVICE _CCCL_FORCEINLINE void
InclusiveScan(T input, T& output, T initial_value, ScanOp scan_op, T& block_aggregate)
{
InternalBlockScan(temp_storage).InclusiveScan(input, output, initial_value, scan_op, block_aggregate);
}

//! @rst
//! Computes an exclusive block-wide prefix scan using the specified binary ``scan_op`` functor.
//! Each thread contributes one input element. The call-back functor ``block_prefix_callback_op`` is invoked by
Expand Down Expand Up @@ -1072,6 +1085,20 @@ public:
internal::ThreadScanExclusive(input, output, scan_op, thread_prefix);
}

template <int ITEMS_PER_THREAD, typename ScanOp>
_CCCL_DEVICE _CCCL_FORCEINLINE void
InclusiveScan(T (&input)[ITEMS_PER_THREAD], T (&output)[ITEMS_PER_THREAD], T initial_value, ScanOp scan_op)
{
// Reduce consecutive thread items in registers
T thread_prefix = internal::ThreadReduce(input, scan_op);

// Exclusive thread block-scan
ExclusiveScan(thread_prefix, thread_prefix, initial_value, scan_op);

// Exclusive scan in registers with prefix as seed
internal::ThreadScanInclusive(input, output, scan_op, thread_prefix);
}

//! @rst
//! Computes an exclusive block-wide prefix scan using the specified binary ``scan_op`` functor.
//! Each thread contributes an array of consecutive input elements.
Expand Down Expand Up @@ -1153,6 +1180,20 @@ public:
internal::ThreadScanExclusive(input, output, scan_op, thread_prefix);
}

template <int ITEMS_PER_THREAD, typename ScanOp>
_CCCL_DEVICE _CCCL_FORCEINLINE void InclusiveScan(
T (&input)[ITEMS_PER_THREAD], T (&output)[ITEMS_PER_THREAD], T initial_value, ScanOp scan_op, T& block_aggregate)
{
// Reduce consecutive thread items in registers
T thread_prefix = internal::ThreadReduce(input, scan_op);

// Exclusive thread block-scan
ExclusiveScan(thread_prefix, thread_prefix, initial_value, scan_op, block_aggregate);

// Exclusive scan in registers with prefix as seed
internal::ThreadScanInclusive(input, output, scan_op, thread_prefix);
}

//! @rst
//! Computes an exclusive block-wide prefix scan using the specified binary ``scan_op`` functor.
//! Each thread contributes an array of consecutive input elements.
Expand Down
132 changes: 132 additions & 0 deletions cub/test/catch2_test_block_scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,24 @@ struct sum_op_t
}
};

template <class T, scan_mode Mode>
struct min_init_value_op_t
{
T initial_value;
template <int ItemsPerThread, class BlockScanT>
__device__ void operator()(BlockScanT& scan, T (&thread_data)[ItemsPerThread]) const
{
_CCCL_IF_CONSTEXPR (Mode == scan_mode::exclusive)
{
scan.ExclusiveScan(thread_data, thread_data, initial_value, cub::Min{});
}
else
{
scan.InclusiveScan(thread_data, thread_data, initial_value, cub::Min{});
}
}
};

template <scan_mode Mode>
struct min_op_t
{
Expand All @@ -124,6 +142,36 @@ struct min_op_t
}
};

template <class T, scan_mode Mode>
struct min_init_value_aggregate_op_t
{
int m_target_thread_id;
T initial_value;
T* m_d_block_aggregate;

template <int ItemsPerThread, class BlockScanT>
__device__ void operator()(BlockScanT& scan, T (&thread_data)[ItemsPerThread]) const
{
T block_aggregate{};

_CCCL_IF_CONSTEXPR (Mode == scan_mode::exclusive)
{
scan.ExclusiveScan(thread_data, thread_data, initial_value, cub::Min{}, block_aggregate);
}
else
{
scan.InclusiveScan(thread_data, thread_data, initial_value, cub::Min{}, block_aggregate);
}

const int tid = cub::RowMajorTid(blockDim.x, blockDim.y, blockDim.z);

if (tid == m_target_thread_id)
{
*m_d_block_aggregate = block_aggregate;
}
}
};

template <class T, scan_mode Mode>
struct sum_aggregate_op_t
{
Expand Down Expand Up @@ -465,6 +513,90 @@ CUB_TEST("Block scan supports custom scan op", "[scan][block]", algorithm, modes
REQUIRE(h_out == d_out);
}

CUB_TEST("Block custom op scan works with initial value", "[scan][block]", algorithm, modes, block_dim_yz)
{
constexpr int items_per_thread = 3;
constexpr int block_dim_x = 64;
constexpr int block_dim_y = c2h::get<2, TestType>::value;
constexpr int block_dim_z = block_dim_y;
constexpr int threads_in_block = block_dim_x * block_dim_y * block_dim_z;
constexpr int tile_size = items_per_thread * threads_in_block;
constexpr cub::BlockScanAlgorithm algorithm = c2h::get<0, TestType>::value;
constexpr scan_mode mode = c2h::get<1, TestType>::value;

using type = int;

c2h::device_vector<type> d_out(tile_size);
c2h::device_vector<type> d_in(tile_size);
c2h::gen(CUB_SEED(10), d_in);
d_in[0] = INT_MIN;

const type initial_value = static_cast<type>(GENERATE_COPY(take(2, random(0, tile_size))));

const int target_thread_id = GENERATE_COPY(take(2, random(0, threads_in_block - 1)));

block_scan<algorithm, items_per_thread, block_dim_x, block_dim_y, block_dim_z>(
d_in, d_out, min_init_value_op_t<type, mode>{initial_value});

c2h::host_vector<type> h_out = d_in;
host_scan(
mode,
h_out,
[](type l, type r) {
return std::min(l, r);
},
initial_value);

REQUIRE(h_out == d_out);
}

CUB_TEST("Block custom op scan with initial value returns valid block aggregate",
"[scan][block]",
algorithm,
modes,
block_dim_yz)
{
constexpr int items_per_thread = 3;
constexpr int block_dim_x = 64;
constexpr int block_dim_y = c2h::get<2, TestType>::value;
constexpr int block_dim_z = block_dim_y;
constexpr int threads_in_block = block_dim_x * block_dim_y * block_dim_z;
constexpr int tile_size = items_per_thread * threads_in_block;
constexpr cub::BlockScanAlgorithm algorithm = c2h::get<0, TestType>::value;
constexpr scan_mode mode = c2h::get<1, TestType>::value;

using type = int;

c2h::device_vector<type> d_out(tile_size);
c2h::device_vector<type> d_in(tile_size);
c2h::gen(CUB_SEED(10), d_in);
d_in[0] = INT_MIN;

const type initial_value = static_cast<type>(GENERATE_COPY(take(2, random(0, tile_size))));

const int target_thread_id = GENERATE_COPY(take(2, random(0, threads_in_block - 1)));

c2h::device_vector<type> d_block_aggregate(1);

block_scan<algorithm, items_per_thread, block_dim_x, block_dim_y, block_dim_z>(
d_in,
d_out,
min_init_value_aggregate_op_t<type, mode>{
target_thread_id, initial_value, thrust::raw_pointer_cast(d_block_aggregate.data())});

c2h::host_vector<type> h_out = d_in;
type h_block_aggregate = host_scan(
mode,
h_out,
[](type l, type r) {
return std::min(l, r);
},
initial_value);

REQUIRE(h_out == d_out);
REQUIRE(h_block_aggregate == d_block_aggregate[0]);
}

CUB_TEST("Block scan supports prefix op and custom scan op", "[scan][block]", algorithm, modes, block_dim_yz)
{
constexpr int items_per_thread = 3;
Expand Down

0 comments on commit f3142fa

Please sign in to comment.