Skip to content

Commit bc27d15

Browse files
committed
Fix BlockScan accumulator type handling
1 parent ba21911 commit bc27d15

File tree

4 files changed: 77 additions and 17 deletions.

cub/cub/block/block_reduce.cuh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,8 @@ public:
443443
_CCCL_DEVICE _CCCL_FORCEINLINE T Reduce(T (&inputs)[ITEMS_PER_THREAD], ReductionOp reduction_op)
444444
{
445445
// Reduce partials
446-
T partial = cub::ThreadReduce(inputs, reduction_op);
446+
T partial =
447+
cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(inputs)>, ReductionOp, T, T>(inputs, reduction_op);
447448
return Reduce(partial, reduction_op);
448449
}
449450

@@ -601,7 +602,8 @@ public:
601602
_CCCL_DEVICE _CCCL_FORCEINLINE T Sum(T (&inputs)[ITEMS_PER_THREAD])
602603
{
603604
// Reduce partials
604-
T partial = cub::ThreadReduce(inputs, ::cuda::std::plus<>{});
605+
T partial = cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(inputs)>, ::cuda::std::plus<>, T, T>(
606+
inputs, ::cuda::std::plus<>{});
605607
return Sum(partial);
606608
}
607609

cub/cub/block/block_scan.cuh

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,7 @@ public:
950950
ExclusiveScan(T (&input)[ITEMS_PER_THREAD], T (&output)[ITEMS_PER_THREAD], T initial_value, ScanOp scan_op)
951951
{
952952
// Reduce consecutive thread items in registers
953-
T thread_prefix = cub::ThreadReduce(input, scan_op);
953+
T thread_prefix = cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(input)>, ScanOp, T, T>(input, scan_op);
954954

955955
// Exclusive thread block-scan
956956
ExclusiveScan(thread_prefix, thread_prefix, initial_value, scan_op);
@@ -1037,7 +1037,7 @@ public:
10371037
T (&input)[ITEMS_PER_THREAD], T (&output)[ITEMS_PER_THREAD], T initial_value, ScanOp scan_op, T& block_aggregate)
10381038
{
10391039
// Reduce consecutive thread items in registers
1040-
T thread_prefix = cub::ThreadReduce(input, scan_op);
1040+
T thread_prefix = cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(input)>, ScanOp, T, T>(input, scan_op);
10411041

10421042
// Exclusive thread block-scan
10431043
ExclusiveScan(thread_prefix, thread_prefix, initial_value, scan_op, block_aggregate);
@@ -1121,7 +1121,7 @@ public:
11211121
BlockPrefixCallbackOp& block_prefix_callback_op)
11221122
{
11231123
// Reduce consecutive thread items in registers
1124-
T thread_prefix = cub::ThreadReduce(input, scan_op);
1124+
T thread_prefix = cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(input)>, ScanOp, T, T>(input, scan_op);
11251125

11261126
// Exclusive thread block-scan
11271127
ExclusiveScan(thread_prefix, thread_prefix, scan_op, block_prefix_callback_op);
@@ -1231,7 +1231,8 @@ public:
12311231
ExclusiveScan(T (&input)[ITEMS_PER_THREAD], T (&output)[ITEMS_PER_THREAD], ScanOp scan_op)
12321232
{
12331233
// Reduce consecutive thread items in registers
1234-
T thread_partial = cub::ThreadReduce(input, scan_op);
1234+
T thread_partial =
1235+
cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(input)>, ScanOp, T, T>(input, scan_op);
12351236

12361237
// Exclusive thread block-scan
12371238
ExclusiveScan(thread_partial, thread_partial, scan_op);
@@ -1275,7 +1276,8 @@ public:
12751276
ExclusiveScan(T (&input)[ITEMS_PER_THREAD], T (&output)[ITEMS_PER_THREAD], ScanOp scan_op, T& block_aggregate)
12761277
{
12771278
// Reduce consecutive thread items in registers
1278-
T thread_partial = cub::ThreadReduce(input, scan_op);
1279+
T thread_partial =
1280+
cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(input)>, ScanOp, T, T>(input, scan_op);
12791281

12801282
// Exclusive thread block-scan
12811283
ExclusiveScan(thread_partial, thread_partial, scan_op, block_aggregate);
@@ -1524,7 +1526,8 @@ public:
15241526
{
15251527
// Reduce consecutive thread items in registers
15261528
::cuda::std::plus<> scan_op;
1527-
T thread_prefix = cub::ThreadReduce(input, scan_op);
1529+
T thread_prefix =
1530+
cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(input)>, decltype(scan_op), T, T>(input, scan_op);
15281531

15291532
// Exclusive thread block-scan
15301533
ExclusiveSum(thread_prefix, thread_prefix);
@@ -1601,7 +1604,8 @@ public:
16011604
{
16021605
// Reduce consecutive thread items in registers
16031606
::cuda::std::plus<> scan_op;
1604-
T thread_prefix = cub::ThreadReduce(input, scan_op);
1607+
T thread_prefix =
1608+
cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(input)>, decltype(scan_op), T, T>(input, scan_op);
16051609

16061610
// Exclusive thread block-scan
16071611
ExclusiveSum(thread_prefix, thread_prefix, block_aggregate);
@@ -1682,7 +1686,8 @@ public:
16821686
{
16831687
// Reduce consecutive thread items in registers
16841688
::cuda::std::plus<> scan_op;
1685-
T thread_prefix = cub::ThreadReduce(input, scan_op);
1689+
T thread_prefix =
1690+
cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(input)>, decltype(scan_op), T, T>(input, scan_op);
16861691

16871692
// Exclusive thread block-scan
16881693
ExclusiveSum(thread_prefix, thread_prefix, block_prefix_callback_op);
@@ -1954,7 +1959,8 @@ public:
19541959
else
19551960
{
19561961
// Reduce consecutive thread items in registers
1957-
T thread_prefix = cub::ThreadReduce(input, scan_op);
1962+
T thread_prefix =
1963+
cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(input)>, decltype(scan_op), T, T>(input, scan_op);
19581964

19591965
// Exclusive thread block-scan
19601966
ExclusiveScan(thread_prefix, thread_prefix, scan_op);
@@ -2011,7 +2017,7 @@ public:
20112017
InclusiveScan(T (&input)[ITEMS_PER_THREAD], T (&output)[ITEMS_PER_THREAD], T initial_value, ScanOp scan_op)
20122018
{
20132019
// Reduce consecutive thread items in registers
2014-
T thread_prefix = cub::ThreadReduce(input, scan_op);
2020+
T thread_prefix = cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(input)>, ScanOp, T, T>(input, scan_op);
20152021

20162022
// Exclusive thread block-scan
20172023
ExclusiveScan(thread_prefix, thread_prefix, initial_value, scan_op);
@@ -2093,7 +2099,8 @@ public:
20932099
else
20942100
{
20952101
// Reduce consecutive thread items in registers
2096-
T thread_prefix = cub::ThreadReduce(input, scan_op);
2102+
T thread_prefix =
2103+
cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(input)>, ScanOp, T, T>(input, scan_op);
20972104

20982105
// Exclusive thread block-scan (with no initial value)
20992106
ExclusiveScan(thread_prefix, thread_prefix, scan_op, block_aggregate);
@@ -2160,7 +2167,7 @@ public:
21602167
T (&input)[ITEMS_PER_THREAD], T (&output)[ITEMS_PER_THREAD], T initial_value, ScanOp scan_op, T& block_aggregate)
21612168
{
21622169
// Reduce consecutive thread items in registers
2163-
T thread_prefix = cub::ThreadReduce(input, scan_op);
2170+
T thread_prefix = cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(input)>, ScanOp, T, T>(input, scan_op);
21642171

21652172
// Exclusive thread block-scan
21662173
ExclusiveScan(thread_prefix, thread_prefix, initial_value, scan_op, block_aggregate);
@@ -2295,7 +2302,8 @@ public:
22952302
else
22962303
{
22972304
// Reduce consecutive thread items in registers
2298-
T thread_prefix = cub::ThreadReduce(input, scan_op);
2305+
T thread_prefix =
2306+
cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(input)>, ScanOp, T, T>(input, scan_op);
22992307

23002308
// Exclusive thread block-scan
23012309
ExclusiveScan(thread_prefix, thread_prefix, scan_op, block_prefix_callback_op);

cub/cub/block/specializations/block_reduce_raking_commutative_only.cuh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,8 @@ struct BlockReduceRakingCommutativeOnly
175175
// Raking reduction in grid
176176
T* raking_segment = BlockRakingLayout::RakingPtr(temp_storage.default_storage.raking_grid, linear_tid);
177177
auto span = ::cuda::std::span<T, SEGMENT_LENGTH>(raking_segment, SEGMENT_LENGTH);
178-
partial = cub::ThreadReduce(span, ::cuda::std::plus<>{}, partial);
178+
partial = cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(span)>, ::cuda::std::plus<>, T, T>(
179+
span, ::cuda::std::plus<>{}, partial);
179180

180181
// Warp reduction
181182
partial = WarpReduce(temp_storage.default_storage.warp_storage).Sum(partial);
@@ -223,7 +224,8 @@ struct BlockReduceRakingCommutativeOnly
223224
// Raking reduction in grid
224225
T* raking_segment = BlockRakingLayout::RakingPtr(temp_storage.default_storage.raking_grid, linear_tid);
225226
auto span = ::cuda::std::span<T, SEGMENT_LENGTH>(raking_segment, SEGMENT_LENGTH);
226-
partial = cub::ThreadReduce(span, reduction_op, partial);
227+
partial = cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(span)>, ReductionOp, T, T>(
228+
span, reduction_op, partial);
227229

228230
// Warp reduction
229231
partial = WarpReduce(temp_storage.default_storage.warp_storage).Reduce(partial, reduction_op);

cub/test/catch2_test_block_scan.cu

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,25 @@ struct min_prefix_op_t
332332
}
333333
};
334334

335+
// Test-only binary operator whose return type (long long) is wider than its
// int operands.
struct wider_returning_op
336+
{
337+
// Combines two ints. The addition is carried out on the int operands; only
// the resulting sum is converted to long long for the return value.
__host__ __device__ long long operator()(int lhs, int rhs) const
338+
{
339+
return static_cast<long long>(lhs + rhs);
340+
}
341+
342+
// Deleted overload: a call whose left operand is a long long selects this
// overload and therefore fails to compile instead of silently converting.
// NOTE(review): presumably this is what makes the test detect code that
// accumulates in the operator's return type rather than in int - confirm.
__host__ __device__ long long operator()(long long, int) const = delete;
343+
};
344+
345+
// Device-side action object that runs an exclusive scan with
// wider_returning_op on one thread's items.
struct exclusive_wider_scan_op_t
346+
{
347+
// scan:        scan object providing ExclusiveScan (assumed to be a
//              cub::BlockScan specialization - TODO confirm against block_scan helper).
// thread_data: this thread's ItemsPerThread int items; used as both the input
//              and the output array, so the scan result overwrites the input.
template <int ItemsPerThread, class BlockScanT>
348+
__device__ void operator()(BlockScanT& scan, int (&thread_data)[ItemsPerThread]) const
349+
{
350+
// Third argument 0 is the initial value; the operator is default-constructed.
scan.ExclusiveScan(thread_data, thread_data, 0, wider_returning_op{});
351+
}
352+
};
353+
335354
template <class T, class ScanOpT>
336355
T host_scan(scan_mode mode, c2h::host_vector<T>& result, ScanOpT scan_op, T initial_value = T{})
337356
{
@@ -373,6 +392,35 @@ T host_scan(scan_mode mode, c2h::host_vector<T>& result, ScanOpT scan_op, T init
373392
// %PARAM% ALGO_TYPE alg 0:1:2
374393
// %PARAM% TEST_MODE mode 0:1
375394

395+
// Regression test: an exclusive block scan over int items must work with
// wider_returning_op, whose call operator returns long long.
// NOTE(review): only the ExclusiveScan overload taking an initial value is
// exercised here, and only with BLOCK_SCAN_RAKING.
C2H_TEST("Block scan handles operators returning wider type", "[scan][block]")
396+
{
397+
using type = int;
398+
constexpr int items_per_thread = 3;
399+
constexpr int block_dim_x = 64;
400+
constexpr int block_dim_y = 1;
401+
constexpr int block_dim_z = 1;
402+
constexpr cub::BlockScanAlgorithm algorithm = cub::BlockScanAlgorithm::BLOCK_SCAN_RAKING;
403+
// 64 * 1 * 1 * 3 = 192 items in total.
constexpr int tile_size = block_dim_x * block_dim_y * block_dim_z * items_per_thread;
404+
405+
// Fill the input with the repeating pattern -3, -2, ..., 3.
c2h::host_vector<type> h_in(tile_size);
406+
for (int i = 0; i < tile_size; ++i)
407+
{
408+
h_in[i] = static_cast<type>((i % 7) - 3);
409+
}
410+
411+
c2h::device_vector<type> d_in = h_in;
412+
c2h::device_vector<type> d_out(tile_size);
413+
414+
// Host reference: start from a copy of the input; host_scan receives the
// vector by reference (presumably overwriting it with the exclusive scan
// result, initial value 0 - verify against host_scan).
c2h::host_vector<type> h_expected = h_in;
415+
host_scan(scan_mode::exclusive, h_expected, wider_returning_op{}, type{0});
416+
417+
// Device run through the file's block_scan helper (defined outside this
// view; presumably launches the kernel and applies the action per thread).
block_scan<algorithm, items_per_thread, block_dim_x, block_dim_y, block_dim_z, type>(
418+
d_in, d_out, exclusive_wider_scan_op_t{});
419+
420+
// Copy the device output back and require an exact match with the reference.
c2h::host_vector<type> h_out = d_out;
421+
REQUIRE(h_out == h_expected);
422+
}
423+
376424
using types = c2h::type_list<std::uint8_t, std::uint16_t, std::int32_t, std::int64_t>;
377425
// FIXME(bgruber): uchar3 fails the test, see #3835
378426
using vec_types = c2h::type_list<

0 commit comments

Comments
 (0)