Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions cub/cub/thread/thread_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,27 @@ template <typename Input,
/// Internal namespace (to prevent ADL mishaps between static functions when mixing different CUB installations)
namespace detail
{

// Invokes `reduction_op` on (lhs, rhs), preferring to pass both operands as
// `PreferredT` (the accumulator type at the call sites in this file) when the
// operator accepts that signature. Otherwise it falls back, in order, to the
// mixed (PreferredT, ValueT) and (ValueT, PreferredT) signatures, and finally
// to (ValueT, ValueT). This lets user-supplied operators that are only
// callable on the input value type still participate in reductions that
// accumulate in a wider type.
template <typename PreferredT, typename ValueT, typename ReductionOp, typename LhsT, typename RhsT>
_CCCL_DEVICE _CCCL_FORCEINLINE auto thread_reduce_apply(const ReductionOp& reduction_op, LhsT&& lhs, RhsT&& rhs)
{
  constexpr bool accepts_both_preferred = ::cuda::std::is_invocable_v<const ReductionOp&, PreferredT, PreferredT>;
  constexpr bool accepts_lhs_preferred  = ::cuda::std::is_invocable_v<const ReductionOp&, PreferredT, ValueT>;
  constexpr bool accepts_rhs_preferred  = ::cuda::std::is_invocable_v<const ReductionOp&, ValueT, PreferredT>;

  if constexpr (accepts_both_preferred)
  {
    return reduction_op(static_cast<PreferredT>(lhs), static_cast<PreferredT>(rhs));
  }
  else if constexpr (accepts_lhs_preferred)
  {
    return reduction_op(static_cast<PreferredT>(lhs), static_cast<ValueT>(rhs));
  }
  else if constexpr (accepts_rhs_preferred)
  {
    return reduction_op(static_cast<ValueT>(lhs), static_cast<PreferredT>(rhs));
  }
  else
  {
    // Last resort: narrow both operands to the input value type.
    return reduction_op(static_cast<ValueT>(lhs), static_cast<ValueT>(rhs));
  }
}
/***********************************************************************************************************************
* Enable SIMD/Tree reduction heuristics (Trait)
**********************************************************************************************************************/
Expand Down Expand Up @@ -274,11 +295,12 @@ inline constexpr bool enable_ternary_reduction_sm50_v =
/// Sequential per-thread reduction over a statically-sized input: seeds the
/// accumulator with input[0], then folds input[1..static_size_v<Input>) into
/// it left-to-right. The fold goes through thread_reduce_apply so operators
/// that are not invocable on the widened accumulator type (AccumT) fall back
/// to the iterator's value type. Loop is fully unrolled; requires
/// static_size_v<Input> >= 1.
template <typename AccumT, typename Input, typename ReductionOp>
[[nodiscard]] _CCCL_DEVICE _CCCL_FORCEINLINE AccumT ThreadReduceSequential(const Input& input, ReductionOp reduction_op)
{
  auto retval      = static_cast<AccumT>(input[0]);
  using value_type = ::cuda::std::iter_value_t<Input>;
  _CCCL_PRAGMA_UNROLL_FULL()
  for (int i = 1; i < static_size_v<Input>; ++i)
  {
    retval = thread_reduce_apply<AccumT, value_type>(reduction_op, retval, input[i]);
  }
  return retval;
}
Expand All @@ -288,13 +310,14 @@ template <typename AccumT, typename Input, typename ReductionOp>
{
constexpr auto length = static_size_v<Input>;
auto array = cub::detail::to_array<AccumT>(input);
using value_type = ::cuda::std::iter_value_t<Input>;
_CCCL_PRAGMA_UNROLL_FULL()
for (int i = 1; i < length; i *= 2)
{
_CCCL_PRAGMA_UNROLL_FULL()
for (int j = 0; j + i < length; j += i * 2)
{
array[j] = reduction_op(array[j], array[j + i]);
array[j] = thread_reduce_apply<AccumT, value_type>(reduction_op, array[j], array[j + i]);
}
}
return array[0];
Expand All @@ -305,14 +328,16 @@ template <typename AccumT, typename Input, typename ReductionOp>
{
constexpr auto length = static_size_v<Input>;
auto array = cub::detail::to_array<AccumT>(input);
using value_type = ::cuda::std::iter_value_t<Input>;
_CCCL_PRAGMA_UNROLL_FULL()
for (int i = 1; i < length; i *= 3)
{
_CCCL_PRAGMA_UNROLL_FULL()
for (int j = 0; j + i < length; j += i * 3)
{
auto value = reduction_op(array[j], array[j + i]);
array[j] = (j + i * 2 < length) ? reduction_op(value, array[j + i * 2]) : value;
auto value = thread_reduce_apply<AccumT, value_type>(reduction_op, array[j], array[j + i]);
array[j] =
(j + i * 2 < length) ? thread_reduce_apply<AccumT, value_type>(reduction_op, value, array[j + i * 2]) : value;
}
}
return array[0];
Expand All @@ -322,13 +347,14 @@ template <typename AccumT, typename Input, typename ReductionOp>
[[nodiscard]] _CCCL_DEVICE _CCCL_FORCEINLINE AccumT
ThreadReduceSequentialPartial(const Input& input, ReductionOp reduction_op, int valid_items)
{
auto retval = static_cast<AccumT>(input[0]);
auto retval = static_cast<AccumT>(input[0]);
using value_type = ::cuda::std::iter_value_t<Input>;
_CCCL_PRAGMA_UNROLL_FULL()
for (int i = 1; i < static_size_v<Input>; ++i)
{
if (i < valid_items)
{
retval = reduction_op(retval, input[i]);
retval = thread_reduce_apply<AccumT, value_type>(reduction_op, retval, input[i]);
}
}
return retval;
Expand Down
48 changes: 48 additions & 0 deletions cub/test/catch2_test_block_scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,25 @@ struct min_prefix_op_t
}
};

// Binary operator whose result type (long long) is wider than its operand
// type (int). The (long long, int) overload is explicitly deleted as a
// compile-time tripwire: any scan/reduce implementation that feeds the
// widened result back into the operator unconverted will fail to compile.
struct wider_returning_op
{
  __host__ __device__ long long operator()(int a, int b) const
  {
    const int sum = a + b;
    return static_cast<long long>(sum);
  }

  __host__ __device__ long long operator()(long long, int) const = delete;
};

// Device-side launch functor: performs an in-place exclusive block-wide scan
// of the per-thread items with wider_returning_op and an initial value of 0.
struct exclusive_wider_scan_op_t
{
  template <int ItemsPerThread, class BlockScanT>
  __device__ void operator()(BlockScanT& scan, int (&thread_data)[ItemsPerThread]) const
  {
    wider_returning_op scan_op{};
    scan.ExclusiveScan(thread_data, thread_data, 0, scan_op);
  }
};

template <class T, class ScanOpT>
T host_scan(scan_mode mode, c2h::host_vector<T>& result, ScanOpT scan_op, T initial_value = T{})
{
Expand Down Expand Up @@ -349,6 +368,35 @@ T host_scan(scan_mode mode, c2h::host_vector<T>& result, ScanOpT scan_op, T init
// %PARAM% ALGO_TYPE alg 0:1:2
// %PARAM% TEST_MODE mode 0:1

C2H_TEST("Block scan handles operators returning wider type", "[scan][block]")
{
  using type                     = int;
  constexpr int items_per_thread = 3;
  constexpr int block_dim_x      = 64;
  constexpr int block_dim_y      = 1;
  constexpr int block_dim_z      = 1;
  constexpr cub::BlockScanAlgorithm algorithm = cub::BlockScanAlgorithm::BLOCK_SCAN_RAKING;
  constexpr int tile_size = items_per_thread * block_dim_x * block_dim_y * block_dim_z;

  // Deterministic, sign-varying input: values cycle through [-3, 3].
  c2h::host_vector<type> input(tile_size);
  for (int idx = 0; idx < tile_size; ++idx)
  {
    input[idx] = static_cast<type>((idx % 7) - 3);
  }

  c2h::device_vector<type> d_input = input;
  c2h::device_vector<type> d_output(tile_size);

  // Host reference: exclusive scan using the same wider-returning operator.
  c2h::host_vector<type> reference = input;
  host_scan(scan_mode::exclusive, reference, wider_returning_op{}, type{0});

  block_scan<algorithm, items_per_thread, block_dim_x, block_dim_y, block_dim_z, type>(
    d_input, d_output, exclusive_wider_scan_op_t{});

  c2h::host_vector<type> output = d_output;
  REQUIRE(output == reference);
}

using types = c2h::type_list<std::uint8_t, std::uint16_t, std::int32_t, std::int64_t>;
// FIXME(bgruber): uchar3 fails the test, see #3835
using vec_types = c2h::type_list<
Expand Down