diff --git a/cub/cub/thread/thread_reduce.cuh b/cub/cub/thread/thread_reduce.cuh
index 10fd51b814d..19d7aad86df 100644
--- a/cub/cub/thread/thread_reduce.cuh
+++ b/cub/cub/thread/thread_reduce.cuh
@@ -155,6 +155,27 @@
+template <typename ValueT, typename ReductionOp, typename L, typename R>
+_CCCL_DEVICE _CCCL_FORCEINLINE auto thread_reduce_apply(const ReductionOp& reduction_op, L&& lhs, R&& rhs)
+{
+  if constexpr (::cuda::std::is_invocable_v<const ReductionOp&, L&&, R&&>)
+  {
+    return reduction_op(static_cast<L&&>(lhs), static_cast<R&&>(rhs));
+  }
+  else if constexpr (::cuda::std::is_invocable_v<const ReductionOp&, ValueT, R&&>)
+  {
+    return reduction_op(static_cast<ValueT>(lhs), static_cast<R&&>(rhs));
+  }
+  else if constexpr (::cuda::std::is_invocable_v<const ReductionOp&, L&&, ValueT>)
+  {
+    return reduction_op(static_cast<L&&>(lhs), static_cast<ValueT>(rhs));
+  }
+  else
+  {
+    return reduction_op(static_cast<ValueT>(lhs), static_cast<ValueT>(rhs));
+  }
+}
+
 /***********************************************************************************************************************
  * Enable SIMD/Tree reduction heuristics (Trait)
  **********************************************************************************************************************/
@@ -274,11 +295,12 @@ inline constexpr bool enable_ternary_reduction_sm50_v =
 template <typename AccumT, typename Input, typename ReductionOp>
 [[nodiscard]] _CCCL_DEVICE _CCCL_FORCEINLINE AccumT ThreadReduceSequential(const Input& input, ReductionOp reduction_op)
 {
-  auto retval = static_cast<AccumT>(input[0]);
+  auto retval      = static_cast<AccumT>(input[0]);
+  using value_type = ::cuda::std::iter_value_t<Input>;
   _CCCL_PRAGMA_UNROLL_FULL()
   for (int i = 1; i < static_size_v<Input>; ++i)
   {
-    retval = reduction_op(retval, input[i]);
+    retval = thread_reduce_apply<value_type>(reduction_op, retval, input[i]);
   }
   return retval;
 }
@@ -288,13 +310,14 @@ template <typename AccumT, typename Input, typename ReductionOp>
 {
   constexpr auto length = static_size_v<Input>;
   auto array            = cub::detail::to_array(input);
+  using value_type      = ::cuda::std::iter_value_t<Input>;
   _CCCL_PRAGMA_UNROLL_FULL()
   for (int i = 1; i < length; i *= 2)
   {
     _CCCL_PRAGMA_UNROLL_FULL()
     for (int j = 0; j + i < length; j += i * 2)
     {
-      array[j] = reduction_op(array[j], array[j + i]);
+      array[j] = thread_reduce_apply<value_type>(reduction_op, array[j], array[j + i]);
     }
   }
   return array[0];
@@ -305,14 +328,16 @@ template <typename AccumT, typename Input, typename ReductionOp>
 {
   constexpr auto length = static_size_v<Input>;
   auto array            = cub::detail::to_array(input);
+  using value_type      = ::cuda::std::iter_value_t<Input>;
   _CCCL_PRAGMA_UNROLL_FULL()
   for (int i = 1; i < length; i *= 3)
   {
     _CCCL_PRAGMA_UNROLL_FULL()
     for (int j = 0; j + i < length; j += i * 3)
     {
-      auto value = reduction_op(array[j], array[j + i]);
-      array[j]   = (j + i * 2 < length) ? reduction_op(value, array[j + i * 2]) : value;
+      auto value = thread_reduce_apply<value_type>(reduction_op, array[j], array[j + i]);
+      array[j]   =
+        (j + i * 2 < length) ? thread_reduce_apply<value_type>(reduction_op, value, array[j + i * 2]) : value;
     }
   }
   return array[0];
@@ -322,13 +347,14 @@ template <typename AccumT, typename Input, typename ReductionOp>
 [[nodiscard]] _CCCL_DEVICE _CCCL_FORCEINLINE AccumT
 ThreadReduceSequentialPartial(const Input& input, ReductionOp reduction_op, int valid_items)
 {
-  auto retval = static_cast<AccumT>(input[0]);
+  auto retval      = static_cast<AccumT>(input[0]);
+  using value_type = ::cuda::std::iter_value_t<Input>;
   _CCCL_PRAGMA_UNROLL_FULL()
   for (int i = 1; i < static_size_v<Input>; ++i)
   {
     if (i < valid_items)
     {
-      retval = reduction_op(retval, input[i]);
+      retval = thread_reduce_apply<value_type>(reduction_op, retval, input[i]);
     }
   }
   return retval;
diff --git a/cub/test/catch2_test_block_scan.cu b/cub/test/catch2_test_block_scan.cu
index 35d648cc454..65e17fe9101 100644
--- a/cub/test/catch2_test_block_scan.cu
+++ b/cub/test/catch2_test_block_scan.cu
@@ -308,6 +308,25 @@ struct min_prefix_op_t
   }
 };
 
+struct wider_returning_op
+{
+  __host__ __device__ long long operator()(int lhs, int rhs) const
+  {
+    return static_cast<long long>(lhs + rhs);
+  }
+
+  __host__ __device__ long long operator()(long long, int) const = delete;
+};
+
+struct exclusive_wider_scan_op_t
+{
+  template <typename BlockScanT, int ItemsPerThread>
+  __device__ void operator()(BlockScanT& scan, int (&thread_data)[ItemsPerThread]) const
+  {
+    scan.ExclusiveScan(thread_data, thread_data, 0, wider_returning_op{});
+  }
+};
+
 template <typename T, typename ScanOpT>
 T host_scan(scan_mode mode, c2h::host_vector<T>& result, ScanOpT scan_op, T initial_value = T{})
 {
@@ -349,6 +368,35 @@
 // %PARAM% ALGO_TYPE alg 0:1:2
 // %PARAM% TEST_MODE mode 0:1
 
+C2H_TEST("Block scan handles operators returning wider type", "[scan][block]")
+{
+  using type                                  = int;
+  constexpr int items_per_thread              = 3;
+  constexpr int block_dim_x                   = 64;
+  constexpr int block_dim_y                   = 1;
+  constexpr int block_dim_z                   = 1;
+  constexpr cub::BlockScanAlgorithm algorithm = cub::BlockScanAlgorithm::BLOCK_SCAN_RAKING;
+  constexpr int tile_size                     = block_dim_x * block_dim_y * block_dim_z * items_per_thread;
+
+  c2h::host_vector<type> h_in(tile_size);
+  for (int i = 0; i < tile_size; ++i)
+  {
+    h_in[i] = static_cast<type>((i % 7) - 3);
+  }
+
+  c2h::device_vector<type> d_in = h_in;
+  c2h::device_vector<type> d_out(tile_size);
+
+  c2h::host_vector<type> h_expected = h_in;
+  host_scan(scan_mode::exclusive, h_expected, wider_returning_op{}, type{0});
+
+  block_scan<algorithm, block_dim_x, block_dim_y, block_dim_z, items_per_thread>(
+    d_in, d_out, exclusive_wider_scan_op_t{});
+
+  c2h::host_vector<type> h_out = d_out;
+  REQUIRE(h_out == h_expected);
+}
+
 using types = c2h::type_list;
 
 // FIXME(bgruber): uchar3 fails the test, see #3835
 using vec_types = c2h::type_list<
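
For reviewers: the dispatch performed by the new thread_reduce_apply helper can be exercised in isolation. Below is a minimal host-only sketch of the technique, not part of the diff; the helper name apply, the use of std:: traits in place of ::cuda::std::, and the exact is_invocable_v checks are assumptions for illustration.

#include <type_traits>
#include <utility>

// Mirrors wider_returning_op from the test: combining two ints yields a
// long long, and the (long long, int) overload is deleted so that a silent
// narrowing re-invocation fails to compile.
struct wider_returning_op
{
  long long operator()(int lhs, int rhs) const
  {
    return static_cast<long long>(lhs + rhs);
  }
  long long operator()(long long, int) const = delete;
};

// Hypothetical host-only stand-in for thread_reduce_apply: try the arguments
// as-is, then fall back to casting either side (or both) to the value type.
template <typename ValueT, typename ReductionOp, typename L, typename R>
auto apply(const ReductionOp& op, L&& lhs, R&& rhs)
{
  if constexpr (std::is_invocable_v<const ReductionOp&, L&&, R&&>)
  {
    return op(std::forward<L>(lhs), std::forward<R>(rhs));
  }
  else if constexpr (std::is_invocable_v<const ReductionOp&, ValueT, R&&>)
  {
    return op(static_cast<ValueT>(lhs), std::forward<R>(rhs));
  }
  else if constexpr (std::is_invocable_v<const ReductionOp&, L&&, ValueT>)
  {
    return op(std::forward<L>(lhs), static_cast<ValueT>(rhs));
  }
  else
  {
    return op(static_cast<ValueT>(lhs), static_cast<ValueT>(rhs));
  }
}

int main()
{
  wider_returning_op op;
  long long acc = op(1, 2); // first combine: (int, int) -> long long
  // op(acc, 3) would select the deleted (long long, int) overload and fail
  // to compile; the dispatch instead casts acc back down to int:
  acc = apply<int>(op, acc, 3);
  return acc == 6 ? 0 : 1; // exits 0: 1 + 2 = 3, then 3 + 3 = 6
}

The deleted overload is what forces the fallback path to be taken once the accumulator has widened, which is exactly the regression the new C2H_TEST guards against.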