Skip to content

Commit

Permalink
Add thrust::inclusive_scan tbb with init value
Browse files Browse the repository at this point in the history
  • Loading branch information
gonidelis committed Jul 19, 2024
1 parent 8135f65 commit ad28931
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 58 deletions.
43 changes: 18 additions & 25 deletions thrust/thrust/detail/scan.inl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

#include <thrust/detail/config.h>

#include "cuda/std/__functional/invoke.h"
#include "cuda/std/__iterator/iterator_traits.h"

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
Expand All @@ -33,23 +36,6 @@
#include <thrust/system/detail/generic/scan_by_key.h>
#include <thrust/system/detail/generic/select_system.h>

template <typename T, typename InputIterator>
struct is_callable_with_input
{
private:
using value_type = typename std::iterator_traits<InputIterator>::value_type;

template <typename U>
static auto test(int)
-> decltype(std::declval<U>()(std::declval<value_type>(), std::declval<value_type>()), std::true_type());

template <typename>
static auto test(...) -> std::false_type;

public:
static constexpr bool value = decltype(test<T>(0))::value;
};

THRUST_NAMESPACE_BEGIN

_CCCL_EXEC_CHECK_DISABLE
Expand All @@ -66,21 +52,28 @@ _CCCL_HOST_DEVICE OutputIterator inclusive_scan(

_CCCL_EXEC_CHECK_DISABLE
template <typename DerivedPolicy, typename InputIterator, typename OutputIterator, typename AssociativeOperator>
_CCCL_HOST_DEVICE
typename std::enable_if<is_callable_with_input<AssociativeOperator, InputIterator>::value, OutputIterator>::type
inclusive_scan(const thrust::detail::execution_policy_base<DerivedPolicy>& exec,
InputIterator first,
InputIterator last,
OutputIterator result,
AssociativeOperator binary_op)
_CCCL_HOST_DEVICE typename std::enable_if<
::cuda::std::__invokable<AssociativeOperator,
typename ::cuda::std::iterator_traits<InputIterator>::value_type,
typename ::cuda::std::iterator_traits<InputIterator>::value_type>::value,
OutputIterator>::type
inclusive_scan(const thrust::detail::execution_policy_base<DerivedPolicy>& exec,
InputIterator first,
InputIterator last,
OutputIterator result,
AssociativeOperator binary_op)
{
using thrust::system::detail::generic::inclusive_scan;
return inclusive_scan(thrust::detail::derived_cast(thrust::detail::strip_const(exec)), first, last, result, binary_op);
} // end inclusive_scan()

_CCCL_EXEC_CHECK_DISABLE
template <typename DerivedPolicy, typename InputIterator, typename OutputIterator, typename T>
_CCCL_HOST_DEVICE typename std::enable_if<!is_callable_with_input<T, InputIterator>::value, OutputIterator>::type
_CCCL_HOST_DEVICE typename std::enable_if<
!::cuda::std::__invokable<T,
typename ::cuda::std::iterator_traits<InputIterator>::value_type,
typename ::cuda::std::iterator_traits<InputIterator>::value_type>::value,
OutputIterator>::type
inclusive_scan(const thrust::detail::execution_policy_base<DerivedPolicy>& exec,
InputIterator first,
InputIterator last,
Expand Down
43 changes: 18 additions & 25 deletions thrust/thrust/scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
#endif // no system header
#include <thrust/detail/execution_policy.h>

#include "cuda/std/__functional/invoke.h"
#include "cuda/std/__iterator/iterator_traits.h"

THRUST_NAMESPACE_BEGIN

/*! \addtogroup algorithms
Expand Down Expand Up @@ -205,34 +208,24 @@ OutputIterator inclusive_scan(InputIterator first, InputIterator last, OutputIte
* \see https://en.cppreference.com/w/cpp/algorithm/partial_sum
*/

template <typename T, typename InputIterator>
struct is_callable_with_input
{
private:
using value_type = typename std::iterator_traits<InputIterator>::value_type;

template <typename U>
static auto test(int)
-> decltype(std::declval<U>()(std::declval<value_type>(), std::declval<value_type>()), std::true_type());

template <typename>
static auto test(...) -> std::false_type;

public:
static constexpr bool value = decltype(test<T>(0))::value;
};

template <typename DerivedPolicy, typename InputIterator, typename OutputIterator, typename AssociativeOperator>
_CCCL_HOST_DEVICE
typename std::enable_if<is_callable_with_input<AssociativeOperator, InputIterator>::value, OutputIterator>::type
inclusive_scan(const thrust::detail::execution_policy_base<DerivedPolicy>& exec,
InputIterator first,
InputIterator last,
OutputIterator result,
AssociativeOperator binary_op);
_CCCL_HOST_DEVICE typename std::enable_if<
::cuda::std::__invokable<AssociativeOperator,
typename ::cuda::std::iterator_traits<InputIterator>::value_type,
typename ::cuda::std::iterator_traits<InputIterator>::value_type>::value,
OutputIterator>::type
inclusive_scan(const thrust::detail::execution_policy_base<DerivedPolicy>& exec,
InputIterator first,
InputIterator last,
OutputIterator result,
AssociativeOperator binary_op);

template <typename DerivedPolicy, typename InputIterator, typename OutputIterator, typename T>
_CCCL_HOST_DEVICE typename std::enable_if<!is_callable_with_input<T, InputIterator>::value, OutputIterator>::type
_CCCL_HOST_DEVICE typename std::enable_if<
!::cuda::std::__invokable<T,
typename ::cuda::std::iterator_traits<InputIterator>::value_type,
typename ::cuda::std::iterator_traits<InputIterator>::value_type>::value,
OutputIterator>::type
inclusive_scan(const thrust::detail::execution_policy_base<DerivedPolicy>& exec,
InputIterator first,
InputIterator last,
Expand Down
6 changes: 3 additions & 3 deletions thrust/thrust/system/cuda/detail/async/inclusive_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ unique_eager_event async_inclusive_scan_n(
OutputIt,
BinaryOp,
InputValueT,
thrust::detail::int32_t,
std::int32_t,
InitialValueType,
cub::DeviceScanPolicy<AccumT, BinaryOp>,
ForceInclusive>;
Expand All @@ -153,7 +153,7 @@ unique_eager_event async_inclusive_scan_n(
OutputIt,
BinaryOp,
InputValueT,
thrust::detail::int64_t,
std::int64_t,
InitialValueType,
cub::DeviceScanPolicy<AccumT, BinaryOp>,
ForceInclusive>;
Expand All @@ -180,7 +180,7 @@ unique_eager_event async_inclusive_scan_n(
}

// Allocate temporary storage.
auto content = uninitialized_allocate_unique_n<thrust::detail::uint8_t>(device_alloc, tmp_size);
auto content = uninitialized_allocate_unique_n<std::uint8_t>(device_alloc, tmp_size);
void* const tmp_ptr = raw_pointer_cast(content.get());

// Set up stream with dependencies.
Expand Down
6 changes: 3 additions & 3 deletions thrust/thrust/system/cuda/detail/scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
OutputIt,
ScanOp,
InputValueT,
thrust::detail::int32_t,
std::int32_t,
InitValueT,
cub::DeviceScanPolicy<AccumT, ScanOp>,
ForceInclusive>;
Expand All @@ -131,7 +131,7 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
OutputIt,
ScanOp,
InputValueT,
thrust::detail::int64_t,
std::int64_t,
InitValueT,
cub::DeviceScanPolicy<AccumT, ScanOp>,
ForceInclusive>;
Expand All @@ -157,7 +157,7 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
// Run scan:
{
// Allocate temporary storage:
thrust::detail::temporary_array<thrust::detail::uint8_t, Derived> tmp{policy, tmp_size};
thrust::detail::temporary_array<std::uint8_t, Derived> tmp{policy, tmp_size};
THRUST_INDEX_TYPE_DISPATCH2(
status,
Dispatch32::Dispatch,
Expand Down
4 changes: 4 additions & 0 deletions thrust/thrust/system/tbb/detail/scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ template <typename InputIterator, typename OutputIterator, typename BinaryFuncti
OutputIterator
inclusive_scan(tag, InputIterator first, InputIterator last, OutputIterator result, BinaryFunction binary_op);

template <typename InputIterator, typename OutputIterator, typename T, typename BinaryFunction>
OutputIterator
inclusive_scan(tag, InputIterator first, InputIterator last, OutputIterator result, T init, BinaryFunction binary_op);

template <typename InputIterator, typename OutputIterator, typename T, typename BinaryFunction>
OutputIterator
exclusive_scan(tag, InputIterator first, InputIterator last, OutputIterator result, T init, BinaryFunction binary_op);
Expand Down
122 changes: 120 additions & 2 deletions thrust/thrust/system/tbb/detail/scan.inl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,100 @@ namespace detail
namespace scan_detail
{

template <typename InputIterator, typename OutputIterator, typename BinaryFunction, typename ValueType>
struct inclusive_body_init
{
InputIterator input;
OutputIterator output;
thrust::detail::wrapped_function<BinaryFunction, ValueType> binary_op;
ValueType sum;
bool first_call;

inclusive_body_init(InputIterator input, OutputIterator output, BinaryFunction binary_op, ValueType init)
: input(input)
, output(output)
, binary_op(binary_op)
, sum(init)
, first_call(true)
{}

inclusive_body_init(inclusive_body_init& b, ::tbb::split)
: input(b.input)
, output(b.output)
, binary_op(b.binary_op)
, sum(b.sum)
, first_call(true)
{}

template <typename Size>
void operator()(const ::tbb::blocked_range<Size>& r, ::tbb::pre_scan_tag)
{
InputIterator iter = input + r.begin();

ValueType temp = *iter;

++iter;

for (Size i = r.begin() + 1; i != r.end(); ++i, ++iter)
{
temp = binary_op(temp, *iter);
}

if (first_call)
{
sum = temp;
}
else
{
sum = binary_op(sum, temp);
}

first_call = false;
}

template <typename Size>
void operator()(const ::tbb::blocked_range<Size>& r, ::tbb::final_scan_tag)
{
InputIterator iter1 = input + r.begin();
OutputIterator iter2 = output + r.begin();

if (first_call)
{
*iter2 = sum = binary_op(*iter1, sum);
++iter1;
++iter2;
for (Size i = r.begin() + 1; i != r.end(); ++i, ++iter1, ++iter2)
{
*iter2 = sum = binary_op(sum, *iter1);
}
}
else
{
for (Size i = r.begin(); i != r.end(); ++i, ++iter1, ++iter2)
{
*iter2 = sum = binary_op(sum, *iter1);
}
}

first_call = false;
}

void reverse_join(inclusive_body_init& b)
{
// Only accumulate this functor's partial sum if this functor has been
// called at least once -- otherwise we'll over-count the initial value.
if (!first_call)
{
sum = binary_op(b.sum, sum);
}
}

void assign(inclusive_body_init& b)
{
sum = b.sum;
}
};

template <typename InputIterator, typename OutputIterator, typename BinaryFunction, typename ValueType>
struct inclusive_body
{
Expand All @@ -56,11 +150,11 @@ struct inclusive_body
ValueType sum;
bool first_call;

inclusive_body(InputIterator input, OutputIterator output, BinaryFunction binary_op, ValueType dummy)
inclusive_body(InputIterator input, OutputIterator output, BinaryFunction binary_op, ValueType init)
: input(input)
, output(output)
, binary_op(binary_op)
, sum(dummy)
, sum(init)
, first_call(true)
{}

Expand Down Expand Up @@ -250,6 +344,30 @@ inclusive_scan(tag, InputIterator first, InputIterator last, OutputIterator resu
return result;
}

template <typename InputIterator, typename OutputIterator, typename InitialValueType, typename BinaryFunction>
OutputIterator inclusive_scan(
tag, InputIterator first, InputIterator last, OutputIterator result, InitialValueType init, BinaryFunction binary_op)
{
using namespace thrust::detail;

// Use the input iterator's value type per https://wg21.link/P0571
using ValueType = InitialValueType;

using Size = typename thrust::iterator_difference<InputIterator>::type;
Size n = thrust::distance(first, last);

if (n != 0)
{
typedef typename scan_detail::inclusive_body_init<InputIterator, OutputIterator, BinaryFunction, ValueType> Body;
Body scan_body(first, result, binary_op, init);
::tbb::parallel_scan(::tbb::blocked_range<Size>(0, n), scan_body);
}

thrust::advance(result, n);

return result;
}

template <typename InputIterator, typename OutputIterator, typename InitialValueType, typename BinaryFunction>
OutputIterator exclusive_scan(
tag, InputIterator first, InputIterator last, OutputIterator result, InitialValueType init, BinaryFunction binary_op)
Expand Down

0 comments on commit ad28931

Please sign in to comment.