Skip to content

Commit

Permalink
Use value iterator's value type for partial sums in TBB reduce_by_key
Browse files Browse the repository at this point in the history
This allows us to get rid of partial_sum_type, which still uses the C++11-deprecated function object API ::result_type.
  • Loading branch information
bernhardmgruber committed Jul 17, 2024
1 parent be91914 commit 0de25a6
Showing 1 changed file with 9 additions and 24 deletions.
33 changes: 9 additions & 24 deletions thrust/thrust/system/tbb/detail/reduce_by_key.inl
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,9 @@ inline L divide_ri(const L x, const R y)
return (x + (y - 1)) / y;
}

template <typename InputIterator, typename BinaryFunction, typename SFINAE = void>
struct partial_sum_type
{
using type = typename thrust::iterator_value<InputIterator>::type;
};

template <typename InputIterator, typename BinaryFunction>
struct partial_sum_type<InputIterator, BinaryFunction, ::cuda::std::void_t<typename BinaryFunction::result_type>>
{
using type = typename BinaryFunction::result_type;
};

template <typename InputIterator1, typename InputIterator2, typename BinaryPredicate, typename BinaryFunction>
thrust::pair<InputIterator1,
thrust::pair<typename thrust::iterator_value<InputIterator1>::type,
typename partial_sum_type<InputIterator2, BinaryFunction>::type>>
thrust::pair<thrust::iterator_value_t<InputIterator1>, thrust::iterator_value_t<InputIterator2>>>
reduce_last_segment_backward(
InputIterator1 keys_first,
InputIterator1 keys_last,
Expand All @@ -90,8 +77,8 @@ reduce_last_segment_backward(
thrust::reverse_iterator<InputIterator1> keys_last_r(keys_first);
thrust::reverse_iterator<InputIterator2> values_first_r(values_first + n);

typename thrust::iterator_value<InputIterator1>::type result_key = *keys_first_r;
typename partial_sum_type<InputIterator2, BinaryFunction>::type result_value = *values_first_r;
thrust::iterator_value_t<InputIterator1> result_key = *keys_first_r;
thrust::iterator_value_t<InputIterator2> result_value = *values_first_r;

// consume the entirety of the first key's sequence
for (++keys_first_r, ++values_first_r; (keys_first_r != keys_last_r) && binary_pred(*keys_first_r, result_key);
Expand All @@ -111,8 +98,8 @@ template <typename InputIterator1,
typename BinaryFunction>
thrust::tuple<OutputIterator1,
OutputIterator2,
typename thrust::iterator_value<InputIterator1>::type,
typename partial_sum_type<InputIterator2, BinaryFunction>::type>
thrust::iterator_value_t<InputIterator1>,
thrust::iterator_value_t<InputIterator2>>
reduce_by_key_with_carry(
InputIterator1 keys_first,
InputIterator1 keys_last,
Expand All @@ -124,9 +111,7 @@ reduce_by_key_with_carry(
{
// first, consume the last sequence to produce the carry
// XXX is there an elegant way to pose this such that we don't need to default construct carry?
thrust::pair<typename thrust::iterator_value<InputIterator1>::type,
typename partial_sum_type<InputIterator2, BinaryFunction>::type>
carry;
thrust::pair<thrust::iterator_value_t<InputIterator1>, thrust::iterator_value_t<InputIterator2>> carry;

thrust::tie(keys_last, carry) =
reduce_last_segment_backward(keys_first, keys_last, values_first, binary_pred, binary_op);
Expand Down Expand Up @@ -215,8 +200,8 @@ struct serial_reduce_by_key_body
Iterator6 my_carry_result = carry_result + interval_idx;

// consume the rest of the interval with reduce_by_key
using key_type = typename thrust::iterator_value<Iterator1>::type;
using value_type = typename partial_sum_type<Iterator2, BinaryFunction>::type;
using key_type = thrust::iterator_value_t<Iterator1>;
using value_type = thrust::iterator_value_t<Iterator2>;

// XXX is there a way to pose this so that we don't require default construction of carry?
thrust::pair<key_type, value_type> carry;
Expand Down Expand Up @@ -360,7 +345,7 @@ thrust::pair<Iterator3, Iterator4> reduce_by_key(

// do a reduce_by_key serially in each thread
// the final interval never has a carry by definition, so don't reserve space for it
using carry_type = typename reduce_by_key_detail::partial_sum_type<Iterator2, BinaryFunction>::type;
using carry_type = thrust::iterator_value_t<Iterator2>;
thrust::detail::temporary_array<carry_type, DerivedPolicy> carries(0, exec, num_intervals - 1);

// force grainsize == 1 with simple_partioner()
Expand Down

0 comments on commit 0de25a6

Please sign in to comment.