Skip to content

Commit

Permalink
Use onesweep in radix sort of U8 and U16
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed May 7, 2023
1 parent a06f5aa commit 0e2cfb2
Showing 1 changed file with 148 additions and 39 deletions.
187 changes: 148 additions & 39 deletions cub/device/dispatch/dispatch_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,52 @@ __global__ void DeviceRadixSortExclusiveSumKernel(OffsetT* d_bins)
}
}

namespace detail
{
namespace radix
{

// default
template <std::size_t KeySize, std::size_t ValueSize, std::size_t OffsetSize>
struct sm90_small_key_tuning
{
static constexpr int threads = 384;
static constexpr int items = 23;
};

// clang-format off

// keys
template <> struct sm90_small_key_tuning<1, 0, 4> { static constexpr int threads = 512; static constexpr int items = 19; };
template <> struct sm90_small_key_tuning<2, 0, 4> { static constexpr int threads = 512; static constexpr int items = 19; };

// pairs 8:xx
template <> struct sm90_small_key_tuning<1, 1, 4> { static constexpr int threads = 512; static constexpr int items = 15; };
template <> struct sm90_small_key_tuning<1, 1, 8> { static constexpr int threads = 448; static constexpr int items = 16; };
template <> struct sm90_small_key_tuning<1, 2, 4> { static constexpr int threads = 512; static constexpr int items = 17; };
template <> struct sm90_small_key_tuning<1, 2, 8> { static constexpr int threads = 512; static constexpr int items = 14; };
template <> struct sm90_small_key_tuning<1, 4, 4> { static constexpr int threads = 512; static constexpr int items = 17; };
template <> struct sm90_small_key_tuning<1, 4, 8> { static constexpr int threads = 512; static constexpr int items = 14; };
template <> struct sm90_small_key_tuning<1, 8, 4> { static constexpr int threads = 384; static constexpr int items = 23; };
template <> struct sm90_small_key_tuning<1, 8, 8> { static constexpr int threads = 384; static constexpr int items = 18; };
template <> struct sm90_small_key_tuning<1, 16, 4> { static constexpr int threads = 512; static constexpr int items = 22; };
template <> struct sm90_small_key_tuning<1, 16, 8> { static constexpr int threads = 512; static constexpr int items = 22; };

// pairs 16:xx
template <> struct sm90_small_key_tuning<2, 1, 4> { static constexpr int threads = 384; static constexpr int items = 14; };
template <> struct sm90_small_key_tuning<2, 1, 8> { static constexpr int threads = 384; static constexpr int items = 16; };
template <> struct sm90_small_key_tuning<2, 2, 4> { static constexpr int threads = 384; static constexpr int items = 15; };
template <> struct sm90_small_key_tuning<2, 2, 8> { static constexpr int threads = 448; static constexpr int items = 16; };
template <> struct sm90_small_key_tuning<2, 4, 4> { static constexpr int threads = 512; static constexpr int items = 17; };
template <> struct sm90_small_key_tuning<2, 4, 8> { static constexpr int threads = 512; static constexpr int items = 12; };
template <> struct sm90_small_key_tuning<2, 8, 4> { static constexpr int threads = 384; static constexpr int items = 23; };
template <> struct sm90_small_key_tuning<2, 8, 8> { static constexpr int threads = 512; static constexpr int items = 23; };
template <> struct sm90_small_key_tuning<2, 16, 4> { static constexpr int threads = 512; static constexpr int items = 21; };
template <> struct sm90_small_key_tuning<2, 16, 8> { static constexpr int threads = 576; static constexpr int items = 22; };
// clang-format on

} // namespace radix
} // namespace detail

/******************************************************************************
* Policy
Expand Down Expand Up @@ -959,53 +1005,116 @@ struct DeviceRadixSortPolicy
PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 7 : 5,
SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5,
SEGMENTED_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5,
ONESWEEP = sizeof(KeyT) >= sizeof(uint32_t),
ONESWEEP = true,
ONESWEEP_RADIX_BITS = 8,
OFFSET_64BIT = sizeof(OffsetT) == 8 ? 1 : 0,
FLOAT_KEYS = std::is_same<KeyT, float>::value ? 1 : 0,
};

// Histogram policy
typedef AgentRadixSortHistogramPolicy <128, 16, 1, KeyT, ONESWEEP_RADIX_BITS> HistogramPolicy;

// Exclusive sum policy
typedef AgentRadixSortExclusiveSumPolicy <256, ONESWEEP_RADIX_BITS> ExclusiveSumPolicy;

typedef AgentRadixSortOnesweepPolicy <384,
KEYS_ONLY ? 20 - OFFSET_64BIT - FLOAT_KEYS :
(sizeof(ValueT) < 8 ? (OFFSET_64BIT ? 17 : 23) : (OFFSET_64BIT ? 29 : 30)),
DominantT, 1, RADIX_RANK_MATCH_EARLY_COUNTS_ANY, BLOCK_SCAN_RAKING_MEMOIZE,
RADIX_SORT_STORE_DIRECT, ONESWEEP_RADIX_BITS> OnesweepPolicyKey32;

typedef AgentRadixSortOnesweepPolicy <384, sizeof(ValueT) < 8 ? 30 : 24, DominantT, 1,
RADIX_RANK_MATCH_EARLY_COUNTS_ANY, BLOCK_SCAN_RAKING_MEMOIZE,
RADIX_SORT_STORE_DIRECT, ONESWEEP_RADIX_BITS> OnesweepPolicyKey64;

typedef typename std::conditional<sizeof(KeyT) == 4,
OnesweepPolicyKey32, OnesweepPolicyKey64>::type OnesweepPolicy;

// ScanPolicy
typedef AgentScanPolicy <512, 23, OffsetT, BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, BLOCK_STORE_WARP_TRANSPOSE, BLOCK_SCAN_RAKING_MEMOIZE> ScanPolicy;

// Downsweep policies
typedef AgentRadixSortDownsweepPolicy <512, 23, DominantT, BLOCK_LOAD_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_MATCH, BLOCK_SCAN_WARP_SCANS, PRIMARY_RADIX_BITS> DownsweepPolicy;
typedef AgentRadixSortDownsweepPolicy <(sizeof(KeyT) > 1) ? 256 : 128, 47, DominantT, BLOCK_LOAD_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_MEMOIZE, BLOCK_SCAN_WARP_SCANS, PRIMARY_RADIX_BITS - 1> AltDownsweepPolicy;

// Upsweep policies
typedef AgentRadixSortUpsweepPolicy <256, 23, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS> UpsweepPolicy;
typedef AgentRadixSortUpsweepPolicy <256, 47, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS - 1> AltUpsweepPolicy;

// Single-tile policy
typedef AgentRadixSortDownsweepPolicy <256, 19, DominantT, BLOCK_LOAD_DIRECT, LOAD_LDG, RADIX_RANK_MEMOIZE, BLOCK_SCAN_WARP_SCANS, SINGLE_TILE_RADIX_BITS> SingleTilePolicy;

// Segmented policies
typedef AgentRadixSortDownsweepPolicy <192, 39, DominantT, BLOCK_LOAD_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_MEMOIZE, BLOCK_SCAN_WARP_SCANS, SEGMENTED_RADIX_BITS> SegmentedPolicy;
typedef AgentRadixSortDownsweepPolicy <384, 11, DominantT, BLOCK_LOAD_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_MEMOIZE, BLOCK_SCAN_WARP_SCANS, SEGMENTED_RADIX_BITS - 1> AltSegmentedPolicy;
using HistogramPolicy =
AgentRadixSortHistogramPolicy<128, 16, 1, KeyT, ONESWEEP_RADIX_BITS>;
using ExclusiveSumPolicy = AgentRadixSortExclusiveSumPolicy<256, ONESWEEP_RADIX_BITS>;

using OnesweepPolicyKey32 =
AgentRadixSortOnesweepPolicy<384,
KEYS_ONLY ? 20 - OFFSET_64BIT - FLOAT_KEYS
: (sizeof(ValueT) < 8 ? (OFFSET_64BIT ? 17 : 23)
: (OFFSET_64BIT ? 29 : 30)),
DominantT,
1,
RADIX_RANK_MATCH_EARLY_COUNTS_ANY,
BLOCK_SCAN_RAKING_MEMOIZE,
RADIX_SORT_STORE_DIRECT,
ONESWEEP_RADIX_BITS>;

using OnesweepPolicyKey64 = AgentRadixSortOnesweepPolicy<384,
sizeof(ValueT) < 8 ? 30 : 24,
DominantT,
1,
RADIX_RANK_MATCH_EARLY_COUNTS_ANY,
BLOCK_SCAN_RAKING_MEMOIZE,
RADIX_SORT_STORE_DIRECT,
ONESWEEP_RADIX_BITS>;

using OnesweepLargeKeyPolicy = //
cub::detail::conditional_t<sizeof(KeyT) == 4, OnesweepPolicyKey32, OnesweepPolicyKey64>;

using OnesweepSmallKeyPolicySizes = //
detail::radix::sm90_small_key_tuning<sizeof(KeyT), KEYS_ONLY ? 0 : sizeof(ValueT), sizeof(OffsetT)>;
using OnesweepSmallKeyPolicy = AgentRadixSortOnesweepPolicy<OnesweepSmallKeyPolicySizes::threads,
OnesweepSmallKeyPolicySizes::items,
DominantT,
1,
RADIX_RANK_MATCH_EARLY_COUNTS_ANY,
BLOCK_SCAN_RAKING_MEMOIZE,
RADIX_SORT_STORE_DIRECT,
8>;
using OnesweepPolicy = //
cub::detail::conditional_t<sizeof(KeyT) < 4, //
OnesweepSmallKeyPolicy, //
OnesweepLargeKeyPolicy>;

using ScanPolicy = AgentScanPolicy<512,
23,
OffsetT,
BLOCK_LOAD_WARP_TRANSPOSE,
LOAD_DEFAULT,
BLOCK_STORE_WARP_TRANSPOSE,
BLOCK_SCAN_RAKING_MEMOIZE>;

using DownsweepPolicy = AgentRadixSortDownsweepPolicy<512,
23,
DominantT,
BLOCK_LOAD_TRANSPOSE,
LOAD_DEFAULT,
RADIX_RANK_MATCH,
BLOCK_SCAN_WARP_SCANS,
PRIMARY_RADIX_BITS>;

using AltDownsweepPolicy = AgentRadixSortDownsweepPolicy<(sizeof(KeyT) > 1) ? 256 : 128,
47,
DominantT,
BLOCK_LOAD_TRANSPOSE,
LOAD_DEFAULT,
RADIX_RANK_MEMOIZE,
BLOCK_SCAN_WARP_SCANS,
PRIMARY_RADIX_BITS - 1>;

using UpsweepPolicy =
AgentRadixSortUpsweepPolicy<256, 23, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS>;
using AltUpsweepPolicy =
AgentRadixSortUpsweepPolicy<256, 47, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS - 1>;

using SingleTilePolicy = AgentRadixSortDownsweepPolicy<256,
19,
DominantT,
BLOCK_LOAD_DIRECT,
LOAD_LDG,
RADIX_RANK_MEMOIZE,
BLOCK_SCAN_WARP_SCANS,
SINGLE_TILE_RADIX_BITS>;

using SegmentedPolicy = AgentRadixSortDownsweepPolicy<192,
39,
DominantT,
BLOCK_LOAD_TRANSPOSE,
LOAD_DEFAULT,
RADIX_RANK_MEMOIZE,
BLOCK_SCAN_WARP_SCANS,
SEGMENTED_RADIX_BITS>;

using AltSegmentedPolicy = AgentRadixSortDownsweepPolicy<384,
11,
DominantT,
BLOCK_LOAD_TRANSPOSE,
LOAD_DEFAULT,
RADIX_RANK_MEMOIZE,
BLOCK_SCAN_WARP_SCANS,
SEGMENTED_RADIX_BITS - 1>;
};


/// MaxPolicy
typedef Policy900 MaxPolicy;
using MaxPolicy = Policy900;
};


Expand Down

0 comments on commit 0e2cfb2

Please sign in to comment.