diff --git a/thrust/thrust/detail/scan.inl b/thrust/thrust/detail/scan.inl index 93d0c0f5b9..976d5baa81 100644 --- a/thrust/thrust/detail/scan.inl +++ b/thrust/thrust/detail/scan.inl @@ -18,6 +18,9 @@ #include +#include "cuda/std/__functional/invoke.h" +#include "cuda/std/__iterator/iterator_traits.h" + #if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) # pragma GCC system_header #elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) @@ -33,23 +36,6 @@ #include #include -template -struct is_callable_with_input -{ -private: - using value_type = typename std::iterator_traits::value_type; - - template - static auto test(int) - -> decltype(std::declval()(std::declval(), std::declval()), std::true_type()); - - template - static auto test(...) -> std::false_type; - -public: - static constexpr bool value = decltype(test(0))::value; -}; - THRUST_NAMESPACE_BEGIN _CCCL_EXEC_CHECK_DISABLE @@ -66,13 +52,16 @@ _CCCL_HOST_DEVICE OutputIterator inclusive_scan( _CCCL_EXEC_CHECK_DISABLE template -_CCCL_HOST_DEVICE - typename std::enable_if::value, OutputIterator>::type - inclusive_scan(const thrust::detail::execution_policy_base& exec, - InputIterator first, - InputIterator last, - OutputIterator result, - AssociativeOperator binary_op) +_CCCL_HOST_DEVICE typename std::enable_if< + ::cuda::std::__invokable::value_type, + typename ::cuda::std::iterator_traits::value_type>::value, + OutputIterator>::type +inclusive_scan(const thrust::detail::execution_policy_base& exec, + InputIterator first, + InputIterator last, + OutputIterator result, + AssociativeOperator binary_op) { using thrust::system::detail::generic::inclusive_scan; return inclusive_scan(thrust::detail::derived_cast(thrust::detail::strip_const(exec)), first, last, result, binary_op); @@ -80,7 +69,11 @@ _CCCL_HOST_DEVICE _CCCL_EXEC_CHECK_DISABLE template -_CCCL_HOST_DEVICE typename std::enable_if::value, OutputIterator>::type +_CCCL_HOST_DEVICE typename std::enable_if< + !::cuda::std::__invokable::value_type, + typename ::cuda::std::iterator_traits::value_type>::value, + OutputIterator>::type inclusive_scan(const thrust::detail::execution_policy_base& exec, InputIterator first, InputIterator last, diff --git a/thrust/thrust/scan.h b/thrust/thrust/scan.h index 272c15135e..445f101d96 100644 --- a/thrust/thrust/scan.h +++ b/thrust/thrust/scan.h @@ -31,6 +31,9 @@ #endif // no system header #include +#include "cuda/std/__functional/invoke.h" +#include "cuda/std/__iterator/iterator_traits.h" + THRUST_NAMESPACE_BEGIN /*! \addtogroup algorithms @@ -205,34 +208,24 @@ OutputIterator inclusive_scan(InputIterator first, InputIterator last, OutputIte * \see https://en.cppreference.com/w/cpp/algorithm/partial_sum */ -template -struct is_callable_with_input -{ -private: - using value_type = typename std::iterator_traits::value_type; - - template - static auto test(int) - -> decltype(std::declval()(std::declval(), std::declval()), std::true_type()); - - template - static auto test(...) -> std::false_type; - -public: - static constexpr bool value = decltype(test(0))::value; -}; - template -_CCCL_HOST_DEVICE - typename std::enable_if::value, OutputIterator>::type - inclusive_scan(const thrust::detail::execution_policy_base& exec, - InputIterator first, - InputIterator last, - OutputIterator result, - AssociativeOperator binary_op); +_CCCL_HOST_DEVICE typename std::enable_if< + ::cuda::std::__invokable::value_type, + typename ::cuda::std::iterator_traits::value_type>::value, + OutputIterator>::type +inclusive_scan(const thrust::detail::execution_policy_base& exec, + InputIterator first, + InputIterator last, + OutputIterator result, + AssociativeOperator binary_op); template -_CCCL_HOST_DEVICE typename std::enable_if::value, OutputIterator>::type +_CCCL_HOST_DEVICE typename std::enable_if< + !::cuda::std::__invokable::value_type, + typename ::cuda::std::iterator_traits::value_type>::value, + OutputIterator>::type inclusive_scan(const thrust::detail::execution_policy_base& exec, InputIterator first, InputIterator last, diff --git a/thrust/thrust/system/cuda/detail/async/inclusive_scan.h b/thrust/thrust/system/cuda/detail/async/inclusive_scan.h index 46abdd38f9..e7385959a9 100644 --- a/thrust/thrust/system/cuda/detail/async/inclusive_scan.h +++ b/thrust/thrust/system/cuda/detail/async/inclusive_scan.h @@ -144,7 +144,7 @@ unique_eager_event async_inclusive_scan_n( OutputIt, BinaryOp, InputValueT, - thrust::detail::int32_t, + std::int32_t, InitialValueType, cub::DeviceScanPolicy, ForceInclusive>; @@ -153,7 +153,7 @@ unique_eager_event async_inclusive_scan_n( OutputIt, BinaryOp, InputValueT, - thrust::detail::int64_t, + std::int64_t, InitialValueType, cub::DeviceScanPolicy, ForceInclusive>; @@ -180,7 +180,7 @@ unique_eager_event async_inclusive_scan_n( } // Allocate temporary storage. - auto content = uninitialized_allocate_unique_n(device_alloc, tmp_size); + auto content = uninitialized_allocate_unique_n(device_alloc, tmp_size); void* const tmp_ptr = raw_pointer_cast(content.get()); // Set up stream with dependencies. diff --git a/thrust/thrust/system/cuda/detail/scan.h b/thrust/thrust/system/cuda/detail/scan.h index cdbe6abb40..61a0ce0eaf 100644 --- a/thrust/thrust/system/cuda/detail/scan.h +++ b/thrust/thrust/system/cuda/detail/scan.h @@ -122,7 +122,7 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl( OutputIt, ScanOp, InputValueT, - thrust::detail::int32_t, + std::int32_t, InitValueT, cub::DeviceScanPolicy, ForceInclusive>; @@ -131,7 +131,7 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl( OutputIt, ScanOp, InputValueT, - thrust::detail::int64_t, + std::int64_t, InitValueT, cub::DeviceScanPolicy, ForceInclusive>; @@ -157,7 +157,7 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl( // Run scan: { // Allocate temporary storage: - thrust::detail::temporary_array tmp{policy, tmp_size}; + thrust::detail::temporary_array tmp{policy, tmp_size}; THRUST_INDEX_TYPE_DISPATCH2( status, Dispatch32::Dispatch, diff --git a/thrust/thrust/system/tbb/detail/scan.h b/thrust/thrust/system/tbb/detail/scan.h index b0f005cb38..18a3ec3ee2 100644 --- a/thrust/thrust/system/tbb/detail/scan.h +++ b/thrust/thrust/system/tbb/detail/scan.h @@ -43,6 +43,10 @@ template +OutputIterator +inclusive_scan(tag, InputIterator first, InputIterator last, OutputIterator result, T init, BinaryFunction binary_op); + template OutputIterator exclusive_scan(tag, InputIterator first, InputIterator last, OutputIterator result, T init, BinaryFunction binary_op); diff --git a/thrust/thrust/system/tbb/detail/scan.inl b/thrust/thrust/system/tbb/detail/scan.inl index b2db92bf57..25ef1112c5 100644 --- a/thrust/thrust/system/tbb/detail/scan.inl +++ b/thrust/thrust/system/tbb/detail/scan.inl @@ -47,6 +47,100 @@ namespace detail namespace scan_detail { +template +struct inclusive_body_init +{ + InputIterator input; + OutputIterator output; + thrust::detail::wrapped_function binary_op; + ValueType sum; + bool first_call; + + inclusive_body_init(InputIterator input, OutputIterator output, BinaryFunction binary_op, ValueType init) + : input(input) + , output(output) + , binary_op(binary_op) + , sum(init) + , first_call(true) + {} + + inclusive_body_init(inclusive_body_init& b, ::tbb::split) + : input(b.input) + , output(b.output) + , binary_op(b.binary_op) + , sum(b.sum) + , first_call(true) + {} + + template + void operator()(const ::tbb::blocked_range& r, ::tbb::pre_scan_tag) + { + InputIterator iter = input + r.begin(); + + ValueType temp = *iter; + + ++iter; + + for (Size i = r.begin() + 1; i != r.end(); ++i, ++iter) + { + temp = binary_op(temp, *iter); + } + + if (first_call) + { + sum = temp; + } + else + { + sum = binary_op(sum, temp); + } + + first_call = false; + } + + template + void operator()(const ::tbb::blocked_range& r, ::tbb::final_scan_tag) + { + InputIterator iter1 = input + r.begin(); + OutputIterator iter2 = output + r.begin(); + + if (first_call) + { + *iter2 = sum = binary_op(*iter1, sum); + ++iter1; + ++iter2; + for (Size i = r.begin() + 1; i != r.end(); ++i, ++iter1, ++iter2) + { + *iter2 = sum = binary_op(sum, *iter1); + } + } + else + { + for (Size i = r.begin(); i != r.end(); ++i, ++iter1, ++iter2) + { + *iter2 = sum = binary_op(sum, *iter1); + } + } + + first_call = false; + } + + void reverse_join(inclusive_body_init& b) + { + // Only accumulate this functor's partial sum if this functor has been + // called at least once -- otherwise we'll over-count the initial value. + if (!first_call) + { + sum = binary_op(b.sum, sum); + } + } + + void assign(inclusive_body_init& b) + { + sum = b.sum; + } +}; + template struct inclusive_body { @@ -56,11 +150,11 @@ struct inclusive_body ValueType sum; bool first_call; - inclusive_body(InputIterator input, OutputIterator output, BinaryFunction binary_op, ValueType dummy) + inclusive_body(InputIterator input, OutputIterator output, BinaryFunction binary_op, ValueType init) : input(input) , output(output) , binary_op(binary_op) - , sum(dummy) + , sum(init) , first_call(true) {} @@ -250,6 +344,30 @@ inclusive_scan(tag, InputIterator first, InputIterator last, OutputIterator resu return result; } +template +OutputIterator inclusive_scan( + tag, InputIterator first, InputIterator last, OutputIterator result, InitialValueType init, BinaryFunction binary_op) +{ + using namespace thrust::detail; + + // Use the input iterator's value type per https://wg21.link/P0571 + using ValueType = InitialValueType; + + using Size = typename thrust::iterator_difference::type; + Size n = thrust::distance(first, last); + + if (n != 0) + { + typedef typename scan_detail::inclusive_body_init Body; + Body scan_body(first, result, binary_op, init); + ::tbb::parallel_scan(::tbb::blocked_range(0, n), scan_body); + } + + thrust::advance(result, n); + + return result; +} + template OutputIterator exclusive_scan( tag, InputIterator first, InputIterator last, OutputIterator result, InitialValueType init, BinaryFunction binary_op)