Skip to content

Commit

Permalink
Address review notes
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed May 20, 2023
1 parent 38f2e32 commit d1dcb77
Show file tree
Hide file tree
Showing 17 changed files with 155 additions and 90 deletions.
2 changes: 1 addition & 1 deletion cub/agent/agent_radix_sort_downsweep.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ template <
typename KeyT, ///< KeyT type
typename ValueT, ///< ValueT type
typename OffsetT, ///< Signed integer type for global offsets
typename DecomposerT = detail::fundamental_decomposer_t>
typename DecomposerT = detail::identity_decomposer_t>
struct AgentRadixSortDownsweep
{
//---------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion cub/agent/agent_radix_sort_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ template <
bool IS_DESCENDING,
typename KeyT,
typename OffsetT,
typename DecomposerT = detail::fundamental_decomposer_t>
typename DecomposerT = detail::identity_decomposer_t>
struct AgentRadixSortHistogram
{
// constants
Expand Down
2 changes: 1 addition & 1 deletion cub/agent/agent_radix_sort_onesweep.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ template <
typename ValueT,
typename OffsetT,
typename PortionOffsetT,
typename DecomposerT = detail::fundamental_decomposer_t>
typename DecomposerT = detail::identity_decomposer_t>
struct AgentRadixSortOnesweep
{
// constants
Expand Down
2 changes: 1 addition & 1 deletion cub/agent/agent_radix_sort_upsweep.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ template <
typename AgentRadixSortUpsweepPolicy, ///< Parameterized AgentRadixSortUpsweepPolicy tuning policy type
typename KeyT, ///< KeyT type
typename OffsetT,
typename DecomposerT = detail::fundamental_decomposer_t> ///< Signed integer type for global offsets
typename DecomposerT = detail::identity_decomposer_t> ///< Signed integer type for global offsets
struct AgentRadixSortUpsweep
{

Expand Down
2 changes: 1 addition & 1 deletion cub/agent/agent_segmented_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ template <bool IS_DESCENDING,
typename KeyT,
typename ValueT,
typename OffsetT,
typename DecomposerT = detail::fundamental_decomposer_t>
typename DecomposerT = detail::identity_decomposer_t>
struct AgentSegmentedRadixSort
{
OffsetT num_items;
Expand Down
4 changes: 2 additions & 2 deletions cub/agent/agent_sub_warp_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ class AgentSubWarpSort

// Segmented sort doesn't support custom types at the moment.
bit_ordered_type default_key_bits = IS_DESCENDING
? traits::min_raw_binary_key(detail::fundamental_decomposer_t{})
: traits::max_raw_binary_key(detail::fundamental_decomposer_t{});
? traits::min_raw_binary_key(detail::identity_decomposer_t{})
: traits::max_raw_binary_key(detail::identity_decomposer_t{});
return reinterpret_cast<KeyT &>(default_key_bits);
}

Expand Down
4 changes: 2 additions & 2 deletions cub/block/block_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ private:
{}

/// Sort blocked arrangement
template <int DESCENDING, int KEYS_ONLY, class DecomposerT = detail::fundamental_decomposer_t>
template <int DESCENDING, int KEYS_ONLY, class DecomposerT = detail::identity_decomposer_t>
__device__ __forceinline__ void SortBlocked(
KeyT (&keys)[ITEMS_PER_THREAD], ///< Keys to sort
ValueT (&values)[ITEMS_PER_THREAD], ///< Values to sort
Expand Down Expand Up @@ -427,7 +427,7 @@ public:
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document

/// Sort blocked -> striped arrangement
template <int DESCENDING, int KEYS_ONLY, class DecomposerT = detail::fundamental_decomposer_t>
template <int DESCENDING, int KEYS_ONLY, class DecomposerT = detail::identity_decomposer_t>
__device__ __forceinline__ void SortBlockedToStriped(
KeyT (&keys)[ITEMS_PER_THREAD], ///< Keys to sort
ValueT (&values)[ITEMS_PER_THREAD], ///< Values to sort
Expand Down
32 changes: 16 additions & 16 deletions cub/block/radix_rank_sort_operations.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ using all_t = //
logic_helper_t<Bs...>, //
logic_helper_t<true_t<Bs>::value...>>;

struct fundamental_decomposer_t
struct identity_decomposer_t
{
template <class T>
__host__ __device__ T& operator()(T &key) const
Expand Down Expand Up @@ -242,13 +242,13 @@ struct bit_ordered_conversion_policy_t
{
using bit_ordered_type = typename Traits<T>::UnsignedBits;

static __host__ __device__ bit_ordered_type to_bit_ordered(detail::fundamental_decomposer_t,
static __host__ __device__ bit_ordered_type to_bit_ordered(detail::identity_decomposer_t,
bit_ordered_type val)
{
return Traits<T>::TwiddleIn(val);
}

static __host__ __device__ bit_ordered_type from_bit_ordered(detail::fundamental_decomposer_t,
static __host__ __device__ bit_ordered_type from_bit_ordered(detail::identity_decomposer_t,
bit_ordered_type val)
{
return Traits<T>::TwiddleOut(val);
Expand All @@ -260,7 +260,7 @@ struct bit_ordered_inversion_policy_t
{
using bit_ordered_type = typename Traits<T>::UnsignedBits;

static __host__ __device__ bit_ordered_type inverse(detail::fundamental_decomposer_t,
static __host__ __device__ bit_ordered_type inverse(detail::identity_decomposer_t,
bit_ordered_type val)
{
return ~val;
Expand All @@ -277,25 +277,25 @@ struct traits_t
template <class FundamentalExtractorT, class /* DecomposerT */>
using digit_extractor_t = FundamentalExtractorT;

static __host__ __device__ bit_ordered_type min_raw_binary_key(detail::fundamental_decomposer_t)
static __host__ __device__ bit_ordered_type min_raw_binary_key(detail::identity_decomposer_t)
{
return Traits<T>::LOWEST_KEY;
}

static __host__ __device__ bit_ordered_type max_raw_binary_key(detail::fundamental_decomposer_t)
static __host__ __device__ bit_ordered_type max_raw_binary_key(detail::identity_decomposer_t)
{
return Traits<T>::MAX_KEY;
}

static __host__ __device__ int default_end_bit(detail::fundamental_decomposer_t)
static __host__ __device__ int default_end_bit(detail::identity_decomposer_t)
{
return sizeof(T) * 8;
}

template <class FundamentalExtractorT>
static __host__
__device__ digit_extractor_t<FundamentalExtractorT, detail::fundamental_decomposer_t>
digit_extractor(int begin_bit, int num_bits, detail::fundamental_decomposer_t)
__device__ digit_extractor_t<FundamentalExtractorT, detail::identity_decomposer_t>
digit_extractor(int begin_bit, int num_bits, detail::identity_decomposer_t)
{
return FundamentalExtractorT(begin_bit, num_bits);
}
Expand All @@ -312,7 +312,7 @@ struct min_raw_binary_key_f
using traits = traits_t<typename ::cuda::std::remove_cv<T>::type>;
using bit_ordered_type = typename traits::bit_ordered_type;
reinterpret_cast<bit_ordered_type &>(field) =
traits::min_raw_binary_key(detail::fundamental_decomposer_t{});
traits::min_raw_binary_key(detail::identity_decomposer_t{});
}
};

Expand All @@ -333,7 +333,7 @@ struct max_raw_binary_key_f
using traits = traits_t<typename ::cuda::std::remove_cv<T>::type>;
using bit_ordered_type = typename traits::bit_ordered_type;
reinterpret_cast<bit_ordered_type &>(field) =
traits::max_raw_binary_key(detail::fundamental_decomposer_t{});
traits::max_raw_binary_key(detail::identity_decomposer_t{});
}
};

Expand All @@ -356,7 +356,7 @@ struct to_bit_ordered_f
using bit_ordered_conversion = typename traits::bit_ordered_conversion_policy;

auto &ordered_field = reinterpret_cast<bit_ordered_type &>(field);
ordered_field = bit_ordered_conversion::to_bit_ordered(detail::fundamental_decomposer_t{},
ordered_field = bit_ordered_conversion::to_bit_ordered(detail::identity_decomposer_t{},
ordered_field);
}
};
Expand All @@ -380,7 +380,7 @@ struct from_bit_ordered_f
using bit_ordered_conversion = typename traits::bit_ordered_conversion_policy;

auto &ordered_field = reinterpret_cast<bit_ordered_type &>(field);
ordered_field = bit_ordered_conversion::from_bit_ordered(detail::fundamental_decomposer_t{},
ordered_field = bit_ordered_conversion::from_bit_ordered(detail::identity_decomposer_t{},
ordered_field);
}
};
Expand Down Expand Up @@ -597,7 +597,7 @@ private:
using bit_ordered_inversion_policy = typename traits::bit_ordered_inversion_policy;

public:
template <class DecomposerT = detail::fundamental_decomposer_t>
template <class DecomposerT = detail::identity_decomposer_t>
static __host__ __device__ __forceinline__ //
bit_ordered_type
In(bit_ordered_type key, DecomposerT decomposer = {})
Expand All @@ -610,7 +610,7 @@ public:
return key;
}

template <class DecomposerT = detail::fundamental_decomposer_t>
template <class DecomposerT = detail::identity_decomposer_t>
static __host__ __device__ __forceinline__ //
bit_ordered_type
Out(bit_ordered_type key, DecomposerT decomposer = {})
Expand All @@ -623,7 +623,7 @@ public:
return key;
}

template <class DecomposerT = detail::fundamental_decomposer_t>
template <class DecomposerT = detail::identity_decomposer_t>
static __host__ __device__ __forceinline__ //
bit_ordered_type
DefaultKey(DecomposerT decomposer = {})
Expand Down
64 changes: 48 additions & 16 deletions cub/device/device_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,9 @@ public:
using offset_t = typename detail::ChooseOffsetT<NumItemsT>::Type;
using decomposer_check_t = detail::radix::decomposer_check_t<KeyT, DecomposerT>;

static_assert(decomposer_check_t::value, "DecomposerT must be a functor");
static_assert(decomposer_check_t::value,
"DecomposerT must be a callable object returning a tuple of references to "
"arithmetic types");

// We cast away const-ness, but will *not* write to these arrays.
// `DispatchRadixSort::Dispatch` will allocate temporary storage and
Expand Down Expand Up @@ -639,7 +641,9 @@ public:
using offset_t = typename detail::ChooseOffsetT<NumItemsT>::Type;
using decomposer_check_t = detail::radix::decomposer_check_t<KeyT, DecomposerT>;

static_assert(decomposer_check_t::value, "DecomposerT must be a functor");
static_assert(decomposer_check_t::value,
"DecomposerT must be a callable object returning a tuple of references to "
"arithmetic types");

// We cast away const-ness, but will *not* write to these arrays.
// `DispatchRadixSort::Dispatch` will allocate temporary storage and
Expand Down Expand Up @@ -941,7 +945,9 @@ public:
using offset_t = typename detail::ChooseOffsetT<NumItemsT>::Type;
using decomposer_check_t = detail::radix::decomposer_check_t<KeyT, DecomposerT>;

static_assert(decomposer_check_t::value, "DecomposerT must be a functor");
static_assert(decomposer_check_t::value,
"DecomposerT must be a callable object returning a tuple of references to "
"arithmetic types");

constexpr bool is_overwrite_okay = true;
constexpr bool is_descending = false;
Expand Down Expand Up @@ -1082,7 +1088,9 @@ public:
using offset_t = typename detail::ChooseOffsetT<NumItemsT>::Type;
using decomposer_check_t = detail::radix::decomposer_check_t<KeyT, DecomposerT>;

static_assert(decomposer_check_t::value, "DecomposerT must be a functor");
static_assert(decomposer_check_t::value,
"DecomposerT must be a callable object returning a tuple of references to "
"arithmetic types");

constexpr bool is_overwrite_okay = true;
constexpr bool is_descending = false;
Expand Down Expand Up @@ -1395,7 +1403,9 @@ public:
using offset_t = typename detail::ChooseOffsetT<NumItemsT>::Type;
using decomposer_check_t = detail::radix::decomposer_check_t<KeyT, DecomposerT>;

static_assert(decomposer_check_t::value, "DecomposerT must be a functor");
static_assert(decomposer_check_t::value,
"DecomposerT must be a callable object returning a tuple of references to "
"arithmetic types");

// We cast away const-ness, but will *not* write to these arrays.
// `DispatchRadixSort::Dispatch` will allocate temporary storage and
Expand Down Expand Up @@ -1530,7 +1540,9 @@ public:
using offset_t = typename detail::ChooseOffsetT<NumItemsT>::Type;
using decomposer_check_t = detail::radix::decomposer_check_t<KeyT, DecomposerT>;

static_assert(decomposer_check_t::value, "DecomposerT must be a functor");
static_assert(decomposer_check_t::value,
"DecomposerT must be a callable object returning a tuple of references to "
"arithmetic types");

// We cast away const-ness, but will *not* write to these arrays.
// `DispatchRadixSort::Dispatch` will allocate temporary storage and
Expand Down Expand Up @@ -1827,7 +1839,9 @@ public:
using offset_t = typename detail::ChooseOffsetT<NumItemsT>::Type;
using decomposer_check_t = detail::radix::decomposer_check_t<KeyT, DecomposerT>;

static_assert(decomposer_check_t::value, "DecomposerT must be a functor");
static_assert(decomposer_check_t::value,
"DecomposerT must be a callable object returning a tuple of references to "
"arithmetic types");

constexpr bool is_overwrite_okay = true;
constexpr bool is_descending = true;
Expand Down Expand Up @@ -1968,7 +1982,9 @@ public:
using offset_t = typename detail::ChooseOffsetT<NumItemsT>::Type;
using decomposer_check_t = detail::radix::decomposer_check_t<KeyT, DecomposerT>;

static_assert(decomposer_check_t::value, "DecomposerT must be a functor");
static_assert(decomposer_check_t::value,
"DecomposerT must be a callable object returning a tuple of references to "
"arithmetic types");

constexpr bool is_overwrite_okay = true;
constexpr bool is_descending = true;
Expand Down Expand Up @@ -2232,7 +2248,9 @@ public:
using offset_t = typename detail::ChooseOffsetT<NumItemsT>::Type;
using decomposer_check_t = detail::radix::decomposer_check_t<KeyT, DecomposerT>;

static_assert(decomposer_check_t::value, "DecomposerT must be a functor");
static_assert(decomposer_check_t::value,
"DecomposerT must be a callable object returning a tuple of references to "
"arithmetic types");

// We cast away const-ness, but will *not* write to these arrays.
// `DispatchRadixSort::Dispatch` will allocate temporary storage and
Expand Down Expand Up @@ -2356,7 +2374,9 @@ public:
using offset_t = typename detail::ChooseOffsetT<NumItemsT>::Type;
using decomposer_check_t = detail::radix::decomposer_check_t<KeyT, DecomposerT>;

static_assert(decomposer_check_t::value, "DecomposerT must be a functor");
static_assert(decomposer_check_t::value,
"DecomposerT must be a callable object returning a tuple of references to "
"arithmetic types");

// We cast away const-ness, but will *not* write to these arrays.
// `DispatchRadixSort::Dispatch` will allocate temporary storage and
Expand Down Expand Up @@ -2656,7 +2676,9 @@ public:
using offset_t = typename detail::ChooseOffsetT<NumItemsT>::Type;
using decomposer_check_t = detail::radix::decomposer_check_t<KeyT, DecomposerT>;

static_assert(decomposer_check_t::value, "DecomposerT must be a functor");
static_assert(decomposer_check_t::value,
"DecomposerT must be a callable object returning a tuple of references to "
"arithmetic types");

constexpr bool is_overwrite_okay = true;
constexpr bool is_descending = false;
Expand Down Expand Up @@ -2785,7 +2807,9 @@ public:
using offset_t = typename detail::ChooseOffsetT<NumItemsT>::Type;
using decomposer_check_t = detail::radix::decomposer_check_t<KeyT, DecomposerT>;

static_assert(decomposer_check_t::value, "DecomposerT must be a functor");
static_assert(decomposer_check_t::value,
"DecomposerT must be a callable object returning a tuple of references to "
"arithmetic types");

constexpr bool is_overwrite_okay = true;
constexpr bool is_descending = false;
Expand Down Expand Up @@ -3065,7 +3089,9 @@ public:
using offset_t = typename detail::ChooseOffsetT<NumItemsT>::Type;
using decomposer_check_t = detail::radix::decomposer_check_t<KeyT, DecomposerT>;

static_assert(decomposer_check_t::value, "DecomposerT must be a functor");
static_assert(decomposer_check_t::value,
"DecomposerT must be a callable object returning a tuple of references to "
"arithmetic types");

// We cast away const-ness, but will *not* write to these arrays.
// `DispatchRadixSort::Dispatch` will allocate temporary storage and
Expand Down Expand Up @@ -3186,7 +3212,9 @@ public:
using offset_t = typename detail::ChooseOffsetT<NumItemsT>::Type;
using decomposer_check_t = detail::radix::decomposer_check_t<KeyT, DecomposerT>;

static_assert(decomposer_check_t::value, "DecomposerT must be a functor");
static_assert(decomposer_check_t::value,
"DecomposerT must be a callable object returning a tuple of references to "
"arithmetic types");

// We cast away const-ness, but will *not* write to these arrays.
// `DispatchRadixSort::Dispatch` will allocate temporary storage and
Expand Down Expand Up @@ -3453,7 +3481,9 @@ public:
using offset_t = typename detail::ChooseOffsetT<NumItemsT>::Type;
using decomposer_check_t = detail::radix::decomposer_check_t<KeyT, DecomposerT>;

static_assert(decomposer_check_t::value, "DecomposerT must be a functor");
static_assert(decomposer_check_t::value,
"DecomposerT must be a callable object returning a tuple of references to "
"arithmetic types");

constexpr bool is_overwrite_okay = true;
constexpr bool is_descending = true;
Expand Down Expand Up @@ -3582,7 +3612,9 @@ public:
using offset_t = typename detail::ChooseOffsetT<NumItemsT>::Type;
using decomposer_check_t = detail::radix::decomposer_check_t<KeyT, DecomposerT>;

static_assert(decomposer_check_t::value, "DecomposerT must be a functor");
static_assert(decomposer_check_t::value,
"DecomposerT must be a callable object returning a tuple of references to "
"arithmetic types");

constexpr bool is_overwrite_okay = true;
constexpr bool is_descending = true;
Expand Down
Loading

0 comments on commit d1dcb77

Please sign in to comment.