Skip to content

Commit

Permalink
Address feedback from code review hour
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Jul 22, 2024
1 parent bf5cce6 commit fc71d1d
Showing 1 changed file with 26 additions and 13 deletions.
39 changes: 26 additions & 13 deletions thrust/thrust/system/tbb/detail/reduce_by_key.inl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cub/detail/type_traits.cuh>

#include <thrust/detail/minmax.h>
#include <thrust/detail/range/tail_flags.h>
#include <thrust/detail/seq.h>
Expand All @@ -39,6 +36,7 @@
#include <thrust/system/tbb/detail/reduce_by_key.h>
#include <thrust/system/tbb/detail/reduce_intervals.h>

#include <cuda/std/__functional/invoke.h>
#include <cuda/std/__type_traits/void_t.h>

#include <cassert>
Expand All @@ -63,13 +61,27 @@ inline L divide_ri(const L x, const R y)
return (x + (y - 1)) / y;
}

// Alias for the accumulator type of a reduction: forwards to cub::detail::accumulator_t with
// BinaryFunction and, for both operand slots, the value type of InputIterator.
// NOTE(review): the parameters are declared in (BinaryFunction, InputIterator) order, but the
// visible uses spell partial_sum_type<InputIterator2, BinaryFunction>, i.e. (iterator, function)
// order -- confirm which order is intended.
template <typename BinaryFunction, typename InputIterator>
using partial_sum_type = cub::detail::
accumulator_t<BinaryFunction, thrust::iterator_value_t<InputIterator>, thrust::iterator_value_t<InputIterator>>;
// Trait naming the type used for partial sums and carries in this file's reduce_by_key.
//
// Primary template: `type` is the decayed type that ::cuda::std::__invoke_of reports for calling
// BinaryFunction with two arguments of InputIterator's value type.
//
// The third parameter defaults to void and is never spelled by callers; it exists so that a
// partial specialization can be selected through void_t-style detection.
template <typename InputIterator, typename BinaryFunction, typename SFINAE = void>
struct partial_sum_type
{
// value type of the input sequence; used for both operands of BinaryFunction
using value_t = thrust::iterator_value_t<InputIterator>;
// TODO(bgruber): unify with CUB's accumulator_t
// decay the invoke result so that `type` is never a reference or cv-qualified type
using type = ::cuda::std::__decay_t<typename ::cuda::std::__invoke_of<BinaryFunction, value_t, value_t>::type>;
};

// Specialization selected when BinaryFunction declares a nested ::result_type: that type is used
// as-is instead of deducing the result of invoking the function.
// gevtushenko preferred to detect a ::result_type on the function first
// NOTE(review): the deprecation push/pop macros bracket the specialization, presumably because
// naming ::result_type can trigger deprecation warnings for some function objects -- confirm.
_CCCL_SUPPRESS_DEPRECATED_PUSH
template <typename InputIterator, typename BinaryFunction>
struct partial_sum_type<InputIterator, BinaryFunction, ::cuda::std::void_t<typename BinaryFunction::result_type>>
{
// take the function's self-declared result type verbatim (no decay applied here)
using type = typename BinaryFunction::result_type;
};
_CCCL_SUPPRESS_DEPRECATED_POP

template <typename InputIterator1, typename InputIterator2, typename BinaryPredicate, typename BinaryFunction>
thrust::pair<InputIterator1,
thrust::pair<thrust::iterator_value_t<InputIterator1>, partial_sum_type<InputIterator2, BinaryFunction>>>
thrust::pair<thrust::iterator_value_t<InputIterator1>,
typename partial_sum_type<InputIterator2, BinaryFunction>::type>>
reduce_last_segment_backward(
InputIterator1 keys_first,
InputIterator1 keys_last,
Expand All @@ -84,8 +96,8 @@ reduce_last_segment_backward(
thrust::reverse_iterator<InputIterator1> keys_last_r(keys_first);
thrust::reverse_iterator<InputIterator2> values_first_r(values_first + n);

thrust::iterator_value_t<InputIterator1> result_key = *keys_first_r;
partial_sum_type<InputIterator2, BinaryFunction> result_value = *values_first_r;
thrust::iterator_value_t<InputIterator1> result_key = *keys_first_r;
typename partial_sum_type<InputIterator2, BinaryFunction>::type result_value = *values_first_r;

// consume the entirety of the first key's sequence
for (++keys_first_r, ++values_first_r; (keys_first_r != keys_last_r) && binary_pred(*keys_first_r, result_key);
Expand All @@ -106,7 +118,7 @@ template <typename InputIterator1,
thrust::tuple<OutputIterator1,
OutputIterator2,
thrust::iterator_value_t<InputIterator1>,
partial_sum_type<InputIterator2, BinaryFunction>>
typename partial_sum_type<InputIterator2, BinaryFunction>::type>
reduce_by_key_with_carry(
InputIterator1 keys_first,
InputIterator1 keys_last,
Expand All @@ -118,7 +130,8 @@ reduce_by_key_with_carry(
{
// first, consume the last sequence to produce the carry
// XXX is there an elegant way to pose this such that we don't need to default construct carry?
thrust::pair<thrust::iterator_value_t<InputIterator1>, partial_sum_type<InputIterator2, BinaryFunction>> carry;
thrust::pair<thrust::iterator_value_t<InputIterator1>, typename partial_sum_type<InputIterator2, BinaryFunction>::type>
carry;

thrust::tie(keys_last, carry) =
reduce_last_segment_backward(keys_first, keys_last, values_first, binary_pred, binary_op);
Expand Down Expand Up @@ -208,7 +221,7 @@ struct serial_reduce_by_key_body

// consume the rest of the interval with reduce_by_key
using key_type = thrust::iterator_value_t<Iterator1>;
using value_type = partial_sum_type<Iterator2, BinaryFunction>;
using value_type = typename partial_sum_type<Iterator2, BinaryFunction>::type;

// XXX is there a way to pose this so that we don't require default construction of carry?
thrust::pair<key_type, value_type> carry;
Expand Down Expand Up @@ -352,7 +365,7 @@ thrust::pair<Iterator3, Iterator4> reduce_by_key(

// do a reduce_by_key serially in each thread
// the final interval never has a carry by definition, so don't reserve space for it
using carry_type = partial_sum_type<Iterator2, BinaryFunction>;
using carry_type = typename reduce_by_key_detail::partial_sum_type<Iterator2, BinaryFunction>::type;
thrust::detail::temporary_array<carry_type, DerivedPolicy> carries(0, exec, num_intervals - 1);

// force grainsize == 1 with simple_partitioner()
Expand Down

0 comments on commit fc71d1d

Please sign in to comment.