Skip to content

Commit

Permalink
Cleanup CUB util_type.cuh (#1863)
Browse files Browse the repository at this point in the history
* Cleanup CUB util_type.cuh
* Replace all uses of __conditional_t in CUB if _If, which does not need to instantiate the type not selected.
  • Loading branch information
bernhardmgruber committed Jul 2, 2024
1 parent 91b78d8 commit 327420b
Show file tree
Hide file tree
Showing 40 changed files with 289 additions and 321 deletions.
4 changes: 3 additions & 1 deletion cub/benchmarks/bench/radix_sort/keys.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

#include <cub/device/device_radix_sort.cuh>

#include <cuda/std/type_traits>

#include <nvbench_helper.cuh>

// %//RANGE//% TUNE_RADIX_BITS bits 8:9:1
Expand All @@ -46,7 +48,7 @@ struct policy_hub_t
{
static constexpr bool KEYS_ONLY = std::is_same<ValueT, cub::NullType>::value;

using DominantT = cub::detail::conditional_t<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>;
using DominantT = ::cuda::std::_If<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>;

struct policy_t : cub::ChainedPolicy<300, policy_t, policy_t>
{
Expand Down
4 changes: 3 additions & 1 deletion cub/benchmarks/bench/radix_sort/pairs.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

#include <cub/device/device_radix_sort.cuh>

#include <cuda/std/type_traits>

#include <nvbench_helper.cuh>

// %//RANGE//% TUNE_RADIX_BITS bits 8:9:1
Expand All @@ -44,7 +46,7 @@ struct policy_hub_t
{
static constexpr bool KEYS_ONLY = std::is_same<ValueT, cub::NullType>::value;

using DominantT = cub::detail::conditional_t<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>;
using DominantT = ::cuda::std::_If<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>;

struct policy_t : cub::ChainedPolicy<300, policy_t, policy_t>
{
Expand Down
8 changes: 5 additions & 3 deletions cub/cub/agent/agent_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
#include <cub/iterator/cache_modified_input_iterator.cuh>
#include <cub/util_type.cuh>

#include <cuda/std/type_traits>

#include <iterator>

CUB_NAMESPACE_BEGIN
Expand Down Expand Up @@ -225,9 +227,9 @@ struct AgentHistogram
// Wrap the native input pointer with CacheModifiedInputIterator
// or directly use the supplied input iterator type
using WrappedSampleIteratorT =
cub::detail::conditional_t<std::is_pointer<SampleIteratorT>::value,
CacheModifiedInputIterator<LOAD_MODIFIER, SampleT, OffsetT>,
SampleIteratorT>;
::cuda::std::_If<std::is_pointer<SampleIteratorT>::value,
CacheModifiedInputIterator<LOAD_MODIFIER, SampleT, OffsetT>,
SampleIteratorT>;

/// Pixel input iterator type (for applying cache modifier)
using WrappedPixelIteratorT = CacheModifiedInputIterator<LOAD_MODIFIER, PixelT, OffsetT>;
Expand Down
6 changes: 4 additions & 2 deletions cub/cub/agent/agent_radix_sort_onesweep.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>

#include <cuda/std/type_traits>

CUB_NAMESPACE_BEGIN

/** \brief cub::RadixSortStoreAlgorithm enumerates different algorithms to write
Expand Down Expand Up @@ -146,10 +148,10 @@ struct AgentRadixSortOnesweep
|| RANK_ALGORITHM == RADIX_RANK_MATCH_EARLY_COUNTS_ATOMIC_OR,
"for onesweep agent, the ranking algorithm must warp-strided key arrangement");

using BlockRadixRankT = cub::detail::conditional_t<
using BlockRadixRankT = ::cuda::std::_If<
RANK_ALGORITHM == RADIX_RANK_MATCH_EARLY_COUNTS_ATOMIC_OR,
BlockRadixRankMatchEarlyCounts<BLOCK_THREADS, RADIX_BITS, false, SCAN_ALGORITHM, WARP_MATCH_ATOMIC_OR, RANK_NUM_PARTS>,
cub::detail::conditional_t<
::cuda::std::_If<
RANK_ALGORITHM == RADIX_RANK_MATCH,
BlockRadixRankMatch<BLOCK_THREADS, RADIX_BITS, false, SCAN_ALGORITHM>,
BlockRadixRankMatchEarlyCounts<BLOCK_THREADS, RADIX_BITS, false, SCAN_ALGORITHM, WARP_MATCH_ANY, RANK_NUM_PARTS>>>;
Expand Down
8 changes: 5 additions & 3 deletions cub/cub/agent/agent_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
#include <cub/iterator/cache_modified_input_iterator.cuh>
#include <cub/util_type.cuh>

#include <cuda/std/type_traits>

#include <iterator>

_CCCL_SUPPRESS_DEPRECATED_PUSH
Expand Down Expand Up @@ -145,9 +147,9 @@ struct AgentReduce
// Wrap the native input pointer with CacheModifiedInputIterator
// or directly use the supplied input iterator type
using WrappedInputIteratorT =
cub::detail::conditional_t<std::is_pointer<InputIteratorT>::value,
CacheModifiedInputIterator<AgentReducePolicy::LOAD_MODIFIER, InputT, OffsetT>,
InputIteratorT>;
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
CacheModifiedInputIterator<AgentReducePolicy::LOAD_MODIFIER, InputT, OffsetT>,
InputIteratorT>;

/// Constants
static constexpr int BLOCK_THREADS = AgentReducePolicy::BLOCK_THREADS;
Expand Down
20 changes: 11 additions & 9 deletions cub/cub/agent/agent_reduce_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
#include <cub/iterator/cache_modified_input_iterator.cuh>
#include <cub/iterator/constant_input_iterator.cuh>

#include <cuda/std/type_traits>

#include <iterator>

CUB_NAMESPACE_BEGIN
Expand Down Expand Up @@ -225,27 +227,27 @@ struct AgentReduceByKey
// CacheModifiedValuesInputIterator or directly use the supplied input
// iterator type
using WrappedKeysInputIteratorT =
cub::detail::conditional_t<std::is_pointer<KeysInputIteratorT>::value,
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, KeyInputT, OffsetT>,
KeysInputIteratorT>;
::cuda::std::_If<std::is_pointer<KeysInputIteratorT>::value,
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, KeyInputT, OffsetT>,
KeysInputIteratorT>;

// Cache-modified Input iterator wrapper type (for applying cache modifier)
// for values Wrap the native input pointer with
// CacheModifiedValuesInputIterator or directly use the supplied input
// iterator type
using WrappedValuesInputIteratorT =
cub::detail::conditional_t<std::is_pointer<ValuesInputIteratorT>::value,
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
ValuesInputIteratorT>;
::cuda::std::_If<std::is_pointer<ValuesInputIteratorT>::value,
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
ValuesInputIteratorT>;

// Cache-modified Input iterator wrapper type (for applying cache modifier)
// for fixup values Wrap the native input pointer with
// CacheModifiedValuesInputIterator or directly use the supplied input
// iterator type
using WrappedFixupInputIteratorT =
cub::detail::conditional_t<std::is_pointer<AggregatesOutputIteratorT>::value,
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
AggregatesOutputIteratorT>;
::cuda::std::_If<std::is_pointer<AggregatesOutputIteratorT>::value,
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
AggregatesOutputIteratorT>;

// Reduce-value-by-segment scan operator
using ReduceBySegmentOpT = ReduceBySegmentOp<ReductionOpT>;
Expand Down
10 changes: 6 additions & 4 deletions cub/cub/agent/agent_rle.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
#include <cub/iterator/cache_modified_input_iterator.cuh>
#include <cub/iterator/constant_input_iterator.cuh>

#include <cuda/std/type_traits>

#include <iterator>

CUB_NAMESPACE_BEGIN
Expand Down Expand Up @@ -231,9 +233,9 @@ struct AgentRle
// Wrap the native input pointer with CacheModifiedVLengthnputIterator
// Directly use the supplied input iterator type
using WrappedInputIteratorT =
cub::detail::conditional_t<std::is_pointer<InputIteratorT>::value,
CacheModifiedInputIterator<AgentRlePolicyT::LOAD_MODIFIER, T, OffsetT>,
InputIteratorT>;
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
CacheModifiedInputIterator<AgentRlePolicyT::LOAD_MODIFIER, T, OffsetT>,
InputIteratorT>;

// Parameterized BlockLoad type for data
using BlockLoadT =
Expand All @@ -257,7 +259,7 @@ struct AgentRle
using WarpExchangePairs = WarpExchange<LengthOffsetPair, ITEMS_PER_THREAD>;

using WarpExchangePairsStorage =
cub::detail::conditional_t<STORE_WARP_TIME_SLICING, typename WarpExchangePairs::TempStorage, NullType>;
::cuda::std::_If<STORE_WARP_TIME_SLICING, typename WarpExchangePairs::TempStorage, NullType>;

using WarpExchangeOffsets = WarpExchange<OffsetT, ITEMS_PER_THREAD>;
using WarpExchangeLengths = WarpExchange<LengthT, ITEMS_PER_THREAD>;
Expand Down
8 changes: 5 additions & 3 deletions cub/cub/agent/agent_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
#include <cub/grid/grid_queue.cuh>
#include <cub/iterator/cache_modified_input_iterator.cuh>

#include <cuda/std/type_traits>

#include <iterator>

CUB_NAMESPACE_BEGIN
Expand Down Expand Up @@ -157,9 +159,9 @@ struct AgentScan
// Wrap the native input pointer with CacheModifiedInputIterator
// or directly use the supplied input iterator type
using WrappedInputIteratorT =
cub::detail::conditional_t<std::is_pointer<InputIteratorT>::value,
CacheModifiedInputIterator<AgentScanPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
InputIteratorT>;
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
CacheModifiedInputIterator<AgentScanPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
InputIteratorT>;

// Constants
enum
Expand Down
14 changes: 8 additions & 6 deletions cub/cub/agent/agent_scan_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
#include <cub/iterator/cache_modified_input_iterator.cuh>
#include <cub/util_type.cuh>

#include <cuda/std/type_traits>

#include <iterator>

CUB_NAMESPACE_BEGIN
Expand Down Expand Up @@ -152,14 +154,14 @@ struct AgentScanByKey
static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD;

using WrappedKeysInputIteratorT =
cub::detail::conditional_t<std::is_pointer<KeysInputIteratorT>::value,
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, KeyT, OffsetT>,
KeysInputIteratorT>;
::cuda::std::_If<std::is_pointer<KeysInputIteratorT>::value,
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, KeyT, OffsetT>,
KeysInputIteratorT>;

using WrappedValuesInputIteratorT =
cub::detail::conditional_t<std::is_pointer<ValuesInputIteratorT>::value,
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
ValuesInputIteratorT>;
::cuda::std::_If<std::is_pointer<ValuesInputIteratorT>::value,
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
ValuesInputIteratorT>;

using BlockLoadKeysT = BlockLoad<KeyT, BLOCK_THREADS, ITEMS_PER_THREAD, AgentScanByKeyPolicyT::LOAD_ALGORITHM>;

Expand Down
16 changes: 9 additions & 7 deletions cub/cub/agent/agent_segment_fixup.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
#include <cub/iterator/cache_modified_input_iterator.cuh>
#include <cub/iterator/constant_input_iterator.cuh>

#include <cuda/std/type_traits>

#include <iterator>

CUB_NAMESPACE_BEGIN
Expand Down Expand Up @@ -171,18 +173,18 @@ struct AgentSegmentFixup
// Cache-modified Input iterator wrapper type (for applying cache modifier) for keys
// Wrap the native input pointer with CacheModifiedValuesInputIterator
// or directly use the supplied input iterator type
using WrappedPairsInputIteratorT = cub::detail::conditional_t<
std::is_pointer<PairsInputIteratorT>::value,
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, KeyValuePairT, OffsetT>,
PairsInputIteratorT>;
using WrappedPairsInputIteratorT =
::cuda::std::_If<std::is_pointer<PairsInputIteratorT>::value,
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, KeyValuePairT, OffsetT>,
PairsInputIteratorT>;

// Cache-modified Input iterator wrapper type (for applying cache modifier) for fixup values
// Wrap the native input pointer with CacheModifiedValuesInputIterator
// or directly use the supplied input iterator type
using WrappedFixupInputIteratorT =
cub::detail::conditional_t<std::is_pointer<AggregatesOutputIteratorT>::value,
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, ValueT, OffsetT>,
AggregatesOutputIteratorT>;
::cuda::std::_If<std::is_pointer<AggregatesOutputIteratorT>::value,
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, ValueT, OffsetT>,
AggregatesOutputIteratorT>;

// Reduce-value-by-segment scan operator
using ReduceBySegmentOpT = ReduceByKeyOp<cub::Sum>;
Expand Down
12 changes: 6 additions & 6 deletions cub/cub/agent/agent_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -219,17 +219,17 @@ struct AgentSelectIf
// Wrap the native input pointer with CacheModifiedValuesInputIterator
// or directly use the supplied input iterator type
using WrappedInputIteratorT =
cub::detail::conditional_t<::cuda::std::is_pointer<InputIteratorT>::value,
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
InputIteratorT>;
::cuda::std::_If<::cuda::std::is_pointer<InputIteratorT>::value,
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
InputIteratorT>;

// Cache-modified Input iterator wrapper type (for applying cache modifier) for values
// Wrap the native input pointer with CacheModifiedValuesInputIterator
// or directly use the supplied input iterator type
using WrappedFlagsInputIteratorT =
cub::detail::conditional_t<::cuda::std::is_pointer<FlagsInputIteratorT>::value,
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, FlagT, OffsetT>,
FlagsInputIteratorT>;
::cuda::std::_If<::cuda::std::is_pointer<FlagsInputIteratorT>::value,
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, FlagT, OffsetT>,
FlagsInputIteratorT>;

// Parameterized BlockLoad type for input data
using BlockLoadT = BlockLoad<InputT, BLOCK_THREADS, ITEMS_PER_THREAD, AgentSelectIfPolicyT::LOAD_ALGORITHM>;
Expand Down
4 changes: 3 additions & 1 deletion cub/cub/agent/agent_spmv_orig.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
#include <cub/thread/thread_search.cuh>
#include <cub/util_type.cuh>

#include <cuda/std/type_traits>

#include <iterator>

CUB_NAMESPACE_BEGIN
Expand Down Expand Up @@ -264,7 +266,7 @@ struct AgentSpmv
{
// Value type to pair with index type OffsetT
// (NullType if loading values directly during merge)
using MergeValueT = cub::detail::conditional_t<AgentSpmvPolicyT::DIRECT_LOAD_NONZEROS, NullType, ValueT>;
using MergeValueT = ::cuda::std::_If<AgentSpmvPolicyT::DIRECT_LOAD_NONZEROS, NullType, ValueT>;

OffsetT row_end_offset;
MergeValueT nonzero;
Expand Down
6 changes: 3 additions & 3 deletions cub/cub/agent/agent_three_way_partition.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,9 @@ struct AgentThreeWayPartition
static constexpr int TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD;

using WrappedInputIteratorT =
cub::detail::conditional_t<std::is_pointer<InputIteratorT>::value,
cub::CacheModifiedInputIterator<PolicyT::LOAD_MODIFIER, InputT, OffsetT>,
InputIteratorT>;
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
cub::CacheModifiedInputIterator<PolicyT::LOAD_MODIFIER, InputT, OffsetT>,
InputIteratorT>;

// Parameterized BlockLoad type for input data
using BlockLoadT = cub::BlockLoad<InputT, BLOCK_THREADS, ITEMS_PER_THREAD, PolicyT::LOAD_ALGORITHM>;
Expand Down
Loading

0 comments on commit 327420b

Please sign in to comment.