Skip to content

Commit

Permalink
Make thrust::sort use radix sort with more comparators (#1884)
Browse files Browse the repository at this point in the history
Newly included: ::cuda::std::less/greater and all transparent comparators.
  • Loading branch information
bernhardmgruber authored Jun 23, 2024
1 parent c7d13db commit bb44c7d
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 28 deletions.
19 changes: 19 additions & 0 deletions thrust/testing/cuda/sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,22 @@ void TestComparisonSortCudaStreams()
cudaStreamDestroy(s);
}
DECLARE_UNITTEST(TestComparisonSortCudaStreams);

template <typename T>
struct TestRadixSortDispatch
{
static_assert(thrust::cuda_cub::__smart_sort::can_use_primitive_sort<T, thrust::less<T>>::value, "");
static_assert(thrust::cuda_cub::__smart_sort::can_use_primitive_sort<T, thrust::greater<T>>::value, "");
static_assert(thrust::cuda_cub::__smart_sort::can_use_primitive_sort<T, ::cuda::std::less<T>>::value, "");
static_assert(thrust::cuda_cub::__smart_sort::can_use_primitive_sort<T, ::cuda::std::greater<T>>::value, "");

static_assert(thrust::cuda_cub::__smart_sort::can_use_primitive_sort<T, thrust::less<>>::value, "");
static_assert(thrust::cuda_cub::__smart_sort::can_use_primitive_sort<T, thrust::greater<>>::value, "");
static_assert(thrust::cuda_cub::__smart_sort::can_use_primitive_sort<T, ::cuda::std::less<>>::value, "");
static_assert(thrust::cuda_cub::__smart_sort::can_use_primitive_sort<T, ::cuda::std::greater<>>::value, "");

void operator()() const {}
};
// TODO(bgruber): use a single test case with a concatenated key list and a Cartesian product with the comparators
// Register the dispatch checks for the IntegralTypes and FloatingPointTypes key lists
// (presumably SimpleUnitTest instantiates TestRadixSortDispatch<T> for each listed type -- confirm in the test harness).
SimpleUnitTest<TestRadixSortDispatch, IntegralTypes> TestRadixSortDispatchIntegralInstance;
SimpleUnitTest<TestRadixSortDispatch, FloatingPointTypes> TestRadixSortDispatchFPInstance;
75 changes: 47 additions & 28 deletions thrust/thrust/system/cuda/detail/sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ template <class SORT_ITEMS, class Comparator>
struct dispatch;

// sort keys in ascending order
template <class K>
struct dispatch<thrust::detail::false_type, thrust::less<K>>
template <class KeyOrVoid>
struct dispatch<thrust::detail::false_type, thrust::less<KeyOrVoid>>
{
template <class Key, class Item, class Size>
THRUST_RUNTIME_FUNCTION static cudaError_t
Expand All @@ -182,8 +182,8 @@ struct dispatch<thrust::detail::false_type, thrust::less<K>>
}; // struct dispatch -- sort keys in ascending order;

// sort keys in descending order
template <class K>
struct dispatch<thrust::detail::false_type, thrust::greater<K>>
template <class KeyOrVoid>
struct dispatch<thrust::detail::false_type, thrust::greater<KeyOrVoid>>
{
template <class Key, class Item, class Size>
THRUST_RUNTIME_FUNCTION static cudaError_t
Expand All @@ -206,8 +206,8 @@ struct dispatch<thrust::detail::false_type, thrust::greater<K>>
}; // struct dispatch -- sort keys in descending order;

// sort pairs in ascending order
template <class K>
struct dispatch<thrust::detail::true_type, thrust::less<K>>
template <class KeyOrVoid>
struct dispatch<thrust::detail::true_type, thrust::less<KeyOrVoid>>
{
template <class Key, class Item, class Size>
THRUST_RUNTIME_FUNCTION static cudaError_t
Expand All @@ -231,8 +231,8 @@ struct dispatch<thrust::detail::true_type, thrust::less<K>>
}; // struct dispatch -- sort pairs in ascending order;

// sort pairs in descending order
template <class K>
struct dispatch<thrust::detail::true_type, thrust::greater<K>>
template <class KeyOrVoid>
struct dispatch<thrust::detail::true_type, thrust::greater<KeyOrVoid>>
{
template <class Key, class Item, class Size>
THRUST_RUNTIME_FUNCTION static cudaError_t
Expand All @@ -255,6 +255,14 @@ struct dispatch<thrust::detail::true_type, thrust::greater<K>>
}
}; // struct dispatch -- sort pairs in descending order;

// ::cuda::std::less reuses the thrust::less specialization by inheriting from
// it, so both comparator families share one implementation. KeyOrVoid is
// forwarded unchanged, which covers the transparent (void) form as well.
template <typename SORT_ITEMS, typename KeyOrVoid>
struct dispatch<SORT_ITEMS, ::cuda::std::less<KeyOrVoid>>
    : dispatch<SORT_ITEMS, thrust::less<KeyOrVoid>>
{
};

// ::cuda::std::greater reuses the thrust::greater specialization by inheriting
// from it, so both comparator families share one implementation. KeyOrVoid is
// forwarded unchanged, which covers the transparent (void) form as well.
template <typename SORT_ITEMS, typename KeyOrVoid>
struct dispatch<SORT_ITEMS, ::cuda::std::greater<KeyOrVoid>>
    : dispatch<SORT_ITEMS, thrust::greater<KeyOrVoid>>
{
};

template <typename SORT_ITEMS, typename Derived, typename Key, typename Item, typename Size, typename CompareOp>
THRUST_RUNTIME_FUNCTION void radix_sort(execution_policy<Derived>& policy, Key* keys, Item* items, Size count, CompareOp)
{
Expand Down Expand Up @@ -312,32 +320,43 @@ THRUST_RUNTIME_FUNCTION void radix_sort(execution_policy<Derived>& policy, Key*
namespace __smart_sort
{

// TODO(bgruber): we can drop thrust::less etc. when they truly alias to the ::cuda::std ones
// Trait deciding whether a (Key, CompareOp) pair qualifies for the primitive sort path.
// It combines two conditions through _And:
//   1. Key is an arithmetic type, and
//   2. CompareOp is exactly thrust::less<Key> or thrust::greater<Key>.
// Only these exact comparator types match; e.g. the transparent thrust::less<void>
// is a different type from thrust::less<Key> and is therefore not accepted here.
template <class Key, class CompareOp>
struct can_use_primitive_sort
: ::cuda::std::_And<::cuda::std::is_arithmetic<Key>,
::cuda::std::disjunction<::cuda::std::is_same<CompareOp, thrust::less<Key>>,
::cuda::std::is_same<CompareOp, thrust::greater<Key>>>>
{};

// SFINAE helper: derives from enable_if on can_use_primitive_sort evaluated for
// the iterator's value type and CompareOp, so the nested `type` is available
// only when that trait is true.
template <typename Iterator, typename CompareOp>
struct enable_if_primitive_sort
    : ::cuda::std::enable_if<
        can_use_primitive_sort<typename iterator_value<Iterator>::type, CompareOp>::value>
{
};

// SFINAE helper, counterpart of enable_if_primitive_sort: derives from
// disable_if on the same can_use_primitive_sort condition (iterator's value
// type plus CompareOp).
template <typename Iterator, typename CompareOp>
struct enable_if_comparison_sort
    : thrust::detail::disable_if<
        can_use_primitive_sort<typename iterator_value<Iterator>::type, CompareOp>::value>
{
};

template <class SORT_ITEMS, class STABLE, class Policy, class KeysIt, class ItemsIt, class CompareOp>
THRUST_RUNTIME_FUNCTION typename enable_if_comparison_sort<KeysIt, CompareOp>::type
using can_use_primitive_sort = ::cuda::std::integral_constant<
bool,
::cuda::std::is_arithmetic<Key>::value
&& (::cuda::std::is_same<CompareOp, thrust::less<Key>>::value
|| ::cuda::std::is_same<CompareOp, ::cuda::std::less<Key>>::value
|| ::cuda::std::is_same<CompareOp, thrust::less<void>>::value
|| ::cuda::std::is_same<CompareOp, ::cuda::std::less<void>>::value
|| ::cuda::std::is_same<CompareOp, thrust::greater<Key>>::value
|| ::cuda::std::is_same<CompareOp, ::cuda::std::greater<Key>>::value
|| ::cuda::std::is_same<CompareOp, thrust::greater<void>>::value
|| ::cuda::std::is_same<CompareOp, ::cuda::std::greater<void>>::value)>;

// Comparison-sort overload of smart_sort.
// The trailing non-type template parameter constrains this overload to the case where
// can_use_primitive_sort is false for the key iterator's value type and CompareOp.
// SORT_ITEMS and STABLE are forwarded as explicit template arguments to merge_sort;
// all function arguments are passed through unchanged.
template <
class SORT_ITEMS,
class STABLE,
class Policy,
class KeysIt,
class ItemsIt,
class CompareOp,
::cuda::std::__enable_if_t<!can_use_primitive_sort<typename iterator_value<KeysIt>::type, CompareOp>::value, int> = 0>
THRUST_RUNTIME_FUNCTION void
smart_sort(Policy& policy, KeysIt keys_first, KeysIt keys_last, ItemsIt items_first, CompareOp compare_op)
{
// Delegate the whole job to the merge sort implementation.
__merge_sort::merge_sort<SORT_ITEMS, STABLE>(policy, keys_first, keys_last, items_first, compare_op);
}

template <class SORT_ITEMS, class STABLE, class Policy, class KeysIt, class ItemsIt, class CompareOp>
THRUST_RUNTIME_FUNCTION typename enable_if_primitive_sort<KeysIt, CompareOp>::type smart_sort(
template <
class SORT_ITEMS,
class /*STABLE*/,
class Policy,
class KeysIt,
class ItemsIt,
class CompareOp,
::cuda::std::__enable_if_t<can_use_primitive_sort<typename iterator_value<KeysIt>::type, CompareOp>::value, int> = 0>
THRUST_RUNTIME_FUNCTION void smart_sort(
execution_policy<Policy>& policy, KeysIt keys_first, KeysIt keys_last, ItemsIt items_first, CompareOp compare_op)
{
// ensure sequences have trivial iterators
Expand Down

0 comments on commit bb44c7d

Please sign in to comment.