diff --git a/c/parallel/src/radix_sort.cu b/c/parallel/src/radix_sort.cu index 757739e442e..4d00187eb52 100644 --- a/c/parallel/src/radix_sort.cu +++ b/c/parallel/src/radix_sort.cu @@ -10,7 +10,6 @@ #include #include -#include #include #include @@ -31,92 +30,6 @@ static_assert(std::is_same_v, OffsetT>, "O namespace radix_sort { -using namespace cub::detail::radix_sort_runtime_policies; - -struct radix_sort_runtime_tuning_policy -{ - RuntimeRadixSortHistogramAgentPolicy histogram; - RuntimeRadixSortExclusiveSumAgentPolicy exclusive_sum; - RuntimeRadixSortOnesweepAgentPolicy onesweep; - cub::detail::RuntimeScanAgentPolicy scan; - cub::detail::RuntimeRadixSortDownsweepAgentPolicy downsweep; - cub::detail::RuntimeRadixSortDownsweepAgentPolicy alt_downsweep; - RuntimeRadixSortUpsweepAgentPolicy upsweep; - RuntimeRadixSortUpsweepAgentPolicy alt_upsweep; - cub::detail::RuntimeRadixSortDownsweepAgentPolicy single_tile; - bool is_onesweep; - - auto Histogram() const - { - return histogram; - } - - auto ExclusiveSum() const - { - return exclusive_sum; - } - - auto Onesweep() const - { - return onesweep; - } - - auto Scan() const - { - return scan; - } - - auto Downsweep() const - { - return downsweep; - } - - auto AltDownsweep() const - { - return alt_downsweep; - } - - auto Upsweep() const - { - return upsweep; - } - - auto AltUpsweep() const - { - return alt_upsweep; - } - - auto SingleTile() const - { - return single_tile; - } - - bool IsOnesweep() const - { - return is_onesweep; - } - - template - CUB_RUNTIME_FUNCTION static constexpr int RadixBits(PolicyT policy) - { - return policy.RadixBits(); - } - - template - CUB_RUNTIME_FUNCTION static constexpr int BlockThreads(PolicyT policy) - { - return policy.BlockThreads(); - } - - using MaxPolicy = radix_sort_runtime_tuning_policy; - - template - cudaError_t Invoke(int, F& op) - { - return op.template Invoke(*this); - } -}; - std::string get_single_tile_kernel_name( std::string_view chained_policy_t, cccl_sort_order_t sort_order, @@ -290,12 +203,10 @@ CUresult cccl_device_radix_sort_build_ex( { const char* name = "test"; - const int cc = cc_major * 10 + cc_minor; const auto key_cpp = cccl_type_enum_to_name(input_keys_it.value_type.type); - const auto value_cpp = - input_values_it.type == cccl_iterator_kind_t::CCCL_POINTER && input_values_it.state == nullptr - ? "cub::NullType" - : cccl_type_enum_to_name(input_values_it.value_type.type); + const auto keys_only = + input_values_it.type == cccl_iterator_kind_t::CCCL_POINTER && input_values_it.state == nullptr; + const auto value_cpp = keys_only ? "cub::NullType" : cccl_type_enum_to_name(input_values_it.value_type.type); const std::string op_src = (decomposer.name == nullptr || (decomposer.name != nullptr && decomposer.name[0] == '\0')) ? "using op_wrapper = cub::detail::identity_decomposer_t;" @@ -305,8 +216,32 @@ CUresult cccl_device_radix_sort_build_ex( std::string offset_t; check(cccl_type_name_from_nvrtc(&offset_t)); - const auto policy_hub_expr = - std::format("cub::detail::radix_sort::policy_hub<{}, {}, {}>", key_cpp, value_cpp, offset_t); + // TODO(bgruber): generalize this somewhere + const auto key_type = [&] { + switch (input_keys_it.value_type.type) + { + case CCCL_FLOAT32: + return cub::detail::type_t::float32; + case CCCL_FLOAT64: + return cub::detail::type_t::float64; + default: + return cub::detail::type_t::other; + } + }(); + + const auto policy_sel = cub::detail::radix_sort::policy_selector{ + static_cast(input_keys_it.value_type.size), + // FIXME(bgruber): input_values_it.value_type.size is 4 when it represents cub::NullType, which is very odd + keys_only ? 0 : static_cast(input_values_it.value_type.size), + int{sizeof(OffsetT)}, + key_type}; + + // TODO(bgruber): drop this if tuning policies become formattable + std::stringstream policy_sel_str; + policy_sel_str << policy_sel(cuda::to_arch_id(cuda::compute_capability{cc_major, cc_minor})); + + auto policy_hub_expr = + std::format("cub::detail::radix_sort::policy_selector_from_types<{}, {}, {}>", key_cpp, value_cpp, offset_t); const std::string final_src = std::format( R"XXX( @@ -321,21 +256,18 @@ struct __align__({3}) values_storage_t {{ char data[{2}]; }}; {4} -using {5} = {6}::MaxPolicy; - -#include -__device__ consteval auto& policy_generator() {{ - return ptx_json::id() - = cub::detail::radix_sort::RadixSortPolicyWrapper<{5}::ActivePolicy>::EncodedPolicy(); -}} +using device_radix_sort_policy = {5}; +using namespace cub; +using namespace cub::detail::radix_sort; +static_assert(device_radix_sort_policy()(::cuda::arch_id{{CUB_PTX_ARCH / 10}}) == {6}, "Host generated and JIT compiled policy mismatch"); )XXX", input_keys_it.value_type.size, // 0 input_keys_it.value_type.alignment, // 1 input_values_it.value_type.size, // 2 input_values_it.value_type.alignment, // 3 op_src, // 4 - chained_policy_t, // 5 - policy_hub_expr); // 6 + policy_hub_expr, // 5 + policy_sel_str.view()); // 6 #if false // CCCL_DEBUGGING_SWITCH fflush(stderr); @@ -379,8 +311,8 @@ __device__ consteval auto& policy_generator() {{ ctk_path, "-rdc=true", "-dlto", + "-default-device", "-DCUB_DISABLE_CDP", - "-DCUB_ENABLE_POLICY_PTX_JSON", "-std=c++20"}; cccl::detail::extend_args_with_build_config(args, config); @@ -434,43 +366,13 @@ __device__ consteval auto& policy_generator() {{ &build_ptr->exclusive_sum_kernel, build_ptr->library, exclusive_sum_kernel_lowered_name.c_str())); check(cuLibraryGetKernel(&build_ptr->onesweep_kernel, build_ptr->library, onesweep_kernel_lowered_name.c_str())); - nlohmann::json runtime_policy = - cub::detail::ptx_json::parse("device_radix_sort_policy", {result.data.get(), result.size}); - - using namespace cub::detail::radix_sort_runtime_policies; - using cub::detail::RuntimeScanAgentPolicy; - auto single_tile_policy = - cub::detail::RuntimeRadixSortDownsweepAgentPolicy::from_json(runtime_policy, "SingleTilePolicy"); - auto onesweep_policy = RuntimeRadixSortOnesweepAgentPolicy::from_json(runtime_policy, "OnesweepPolicy"); - auto upsweep_policy = RuntimeRadixSortUpsweepAgentPolicy::from_json(runtime_policy, "UpsweepPolicy"); - auto alt_upsweep_policy = RuntimeRadixSortUpsweepAgentPolicy::from_json(runtime_policy, "AltUpsweepPolicy"); - auto downsweep_policy = - cub::detail::RuntimeRadixSortDownsweepAgentPolicy::from_json(runtime_policy, "DownsweepPolicy"); - auto alt_downsweep_policy = - cub::detail::RuntimeRadixSortDownsweepAgentPolicy::from_json(runtime_policy, "AltDownsweepPolicy"); - auto histogram_policy = RuntimeRadixSortHistogramAgentPolicy::from_json(runtime_policy, "HistogramPolicy"); - auto exclusive_sum_policy = - RuntimeRadixSortExclusiveSumAgentPolicy::from_json(runtime_policy, "ExclusiveSumPolicy"); - auto scan_policy = RuntimeScanAgentPolicy::from_json(runtime_policy, "ScanPolicy"); - auto is_onesweep = runtime_policy["Onesweep"].get(); - - build_ptr->cc = cc; + build_ptr->cc = cc_major * 10 + cc_minor; build_ptr->cubin = (void*) result.data.release(); build_ptr->cubin_size = result.size; build_ptr->key_type = input_keys_it.value_type; build_ptr->value_type = input_values_it.value_type; build_ptr->order = sort_order; - build_ptr->runtime_policy = new radix_sort::radix_sort_runtime_tuning_policy{ - histogram_policy, - exclusive_sum_policy, - onesweep_policy, - scan_policy, - downsweep_policy, - alt_downsweep_policy, - upsweep_policy, - alt_upsweep_policy, - single_tile_policy, - is_onesweep}; + build_ptr->runtime_policy = new cub::detail::radix_sort::policy_selector{policy_sel}; } catch (const std::exception& exc) { @@ -529,29 +431,20 @@ CUresult cccl_device_radix_sort_impl( cub::DoubleBuffer d_values_buffer( *static_cast(&val_arg_in), *static_cast(&val_arg_out)); - auto exec_status = cub::DispatchRadixSort< - Order, - indirect_arg_t, - indirect_arg_t, - OffsetT, - indirect_arg_t, - radix_sort::radix_sort_runtime_tuning_policy, - radix_sort::radix_sort_kernel_source, - cub::detail::CudaDriverLauncherFactory>:: - Dispatch( - d_temp_storage, - *temp_storage_bytes, - d_keys_buffer, - d_values_buffer, - num_items, - begin_bit, - end_bit, - is_overwrite_okay, - stream, - decomposer, - {build}, - cub::detail::CudaDriverLauncherFactory{cu_device, build.cc}, - *reinterpret_cast(build.runtime_policy)); + auto exec_status = cub::detail::radix_sort::dispatch( + d_temp_storage, + *temp_storage_bytes, + d_keys_buffer, + d_values_buffer, + num_items, + begin_bit, + end_bit, + is_overwrite_okay, + stream, + decomposer, + *static_cast(build.runtime_policy), + radix_sort::radix_sort_kernel_source{build}, + cub::detail::CudaDriverLauncherFactory{cu_device, build.cc}); *selector = d_keys_buffer.selector; error = static_cast(exec_status); @@ -649,8 +542,9 @@ CUresult cccl_device_radix_sort_cleanup(cccl_device_radix_sort_build_result_t* b return CUDA_ERROR_INVALID_VALUE; } + using namespace cub::detail::radix_sort; std::unique_ptr cubin(reinterpret_cast(build_ptr->cubin)); - std::unique_ptr runtime_policy(reinterpret_cast(build_ptr->runtime_policy)); + std::unique_ptr policy(static_cast(build_ptr->runtime_policy)); check(cuLibraryUnload(build_ptr->library)); } catch (const std::exception& exc) diff --git a/c/parallel/src/segmented_sort.cu b/c/parallel/src/segmented_sort.cu index d21309610bd..027ede3c221 100644 --- a/c/parallel/src/segmented_sort.cu +++ b/c/parallel/src/segmented_sort.cu @@ -582,6 +582,7 @@ CUresult cccl_device_segmented_sort_build_ex( ctk_path, "-rdc=true", "-dlto", + "-default-device", "-DCUB_DISABLE_CDP", "-std=c++20"}; @@ -696,6 +697,7 @@ __device__ consteval auto& three_way_partition_policy_generator() {{ ctk_path, "-rdc=true", "-dlto", + "-default-device", "-DCUB_DISABLE_CDP", "-DCUB_ENABLE_POLICY_PTX_JSON", "-std=c++20"}; diff --git a/cub/benchmarks/bench/radix_sort/keys.cu b/cub/benchmarks/bench/radix_sort/keys.cu index afd58a1acbe..f3d95d1642a 100644 --- a/cub/benchmarks/bench/radix_sort/keys.cu +++ b/cub/benchmarks/bench/radix_sort/keys.cu @@ -1,12 +1,6 @@ // SPDX-FileCopyrightText: Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. // SPDX-License-Identifier: BSD-3 -#include -#include - -#include -#include - #include // %//RANGE//% TUNE_RADIX_BITS bits 8:9:1 @@ -15,119 +9,22 @@ // %RANGE% TUNE_ITEMS_PER_THREAD ipt 7:24:1 // %RANGE% TUNE_THREADS_PER_BLOCK tpb 128:1024:32 -using value_t = cub::NullType; - -constexpr cub::SortOrder sort_order = cub::SortOrder::Ascending; -constexpr bool is_overwrite_ok = false; - -#if !TUNE_BASE -template -struct policy_hub_t -{ - static constexpr bool KEYS_ONLY = std::is_same_v; - - using DominantT = ::cuda::std::_If<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>; - - struct policy_t : cub::ChainedPolicy<300, policy_t, policy_t> - { - static constexpr int ONESWEEP_RADIX_BITS = TUNE_RADIX_BITS; - static constexpr bool ONESWEEP = true; - static constexpr bool OFFSET_64BIT = sizeof(OffsetT) == 8; - - // Onesweep policy - using OnesweepPolicy = cub::AgentRadixSortOnesweepPolicy< - TUNE_THREADS_PER_BLOCK, - TUNE_ITEMS_PER_THREAD, - DominantT, - 1, - cub::RADIX_RANK_MATCH_EARLY_COUNTS_ANY, - cub::BLOCK_SCAN_RAKING_MEMOIZE, - cub::RADIX_SORT_STORE_DIRECT, - ONESWEEP_RADIX_BITS>; - - // These kernels are launched once, no point in tuning at the moment - using HistogramPolicy = cub::AgentRadixSortHistogramPolicy<128, 16, 1, KeyT, ONESWEEP_RADIX_BITS>; - using ExclusiveSumPolicy = cub::AgentRadixSortExclusiveSumPolicy<256, ONESWEEP_RADIX_BITS>; - using ScanPolicy = - cub::AgentScanPolicy<512, - 23, - OffsetT, - cub::BLOCK_LOAD_WARP_TRANSPOSE, - cub::LOAD_DEFAULT, - cub::BLOCK_STORE_WARP_TRANSPOSE, - cub::BLOCK_SCAN_RAKING_MEMOIZE>; - - // No point in tuning - static constexpr int SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5; - - // No point in tuning single-tile policy - using SingleTilePolicy = cub::AgentRadixSortDownsweepPolicy< - 256, - 19, - DominantT, - cub::BLOCK_LOAD_DIRECT, - cub::LOAD_LDG, - cub::RADIX_RANK_MEMOIZE, - cub::BLOCK_SCAN_WARP_SCANS, - SINGLE_TILE_RADIX_BITS>; - }; - - using MaxPolicy = policy_t; -}; - -template -constexpr std::size_t max_onesweep_temp_storage_size() -{ - using portion_offset = int; - using onesweep_policy = typename policy_hub_t::policy_t::OnesweepPolicy; - using agent_radix_sort_onesweep_t = - cub::AgentRadixSortOnesweep; - - using hist_policy = typename policy_hub_t::policy_t::HistogramPolicy; - using hist_agent = cub::AgentRadixSortHistogram; - - return (::cuda::std::max) (sizeof(typename agent_radix_sort_onesweep_t::TempStorage), - sizeof(typename hist_agent::TempStorage)); -} - -template -constexpr std::size_t max_temp_storage_size() -{ - using policy_t = typename policy_hub_t::policy_t; - - static_assert(policy_t::ONESWEEP); - return max_onesweep_temp_storage_size(); -} - -template -constexpr bool fits_in_default_shared_memory() -{ - return max_temp_storage_size() < cub::detail::max_smem_per_block; -} -#else // TUNE_BASE -template -constexpr bool fits_in_default_shared_memory() -{ - return true; -} -#endif // TUNE_BASE +#include "policy_selector.h" template -void radix_sort_keys(std::integral_constant, nvbench::state& state, nvbench::type_list) +void radix_sort_keys(nvbench::state& state, nvbench::type_list) { using offset_t = cub::detail::choose_offset_t; - using key_t = T; - using dispatch_t = - cub::DispatchRadixSort -#endif // TUNE_BASE - >; + constexpr cub::SortOrder sort_order = cub::SortOrder::Ascending; + constexpr bool is_overwrite_ok = false; + using key_t = T; + using value_t = cub::NullType; + + if constexpr (!fits_in_default_shared_memory()) + { + return; + } constexpr int begin_bit = 0; constexpr int end_bit = sizeof(key_t) * 8; @@ -153,7 +50,7 @@ void radix_sort_keys(std::integral_constant, nvbench::state& state, // Allocate temporary storage: std::size_t temp_size{}; - dispatch_t::Dispatch( + cub::detail::radix_sort::dispatch( nullptr, temp_size, d_keys, @@ -162,16 +59,22 @@ void radix_sort_keys(std::integral_constant, nvbench::state& state, begin_bit, end_bit, is_overwrite_ok, - 0 /* stream */); + 0 /* stream */ +#if !TUNE_BASE + , + cub::detail::identity_decomposer_t{}, + policy_selector{} +#endif // !TUNE_BASE + ); - thrust::device_vector temp(temp_size); + thrust::device_vector temp(temp_size, thrust::no_init); auto* temp_storage = thrust::raw_pointer_cast(temp.data()); state.exec(nvbench::exec_tag::gpu | nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) { cub::DoubleBuffer keys = d_keys; cub::DoubleBuffer values = d_values; - dispatch_t::Dispatch( + cub::detail::radix_sort::dispatch( temp_storage, temp_size, keys, @@ -180,25 +83,16 @@ void radix_sort_keys(std::integral_constant, nvbench::state& state, begin_bit, end_bit, is_overwrite_ok, - launch.get_stream()); + launch.get_stream() +#if !TUNE_BASE + , + cub::detail::identity_decomposer_t{}, + policy_selector{} +#endif // !TUNE_BASE + ); }); } -template -void radix_sort_keys(std::integral_constant, nvbench::state&, nvbench::type_list) -{ - (void) sort_order; - (void) is_overwrite_ok; -} - -template -void radix_sort_keys(nvbench::state& state, nvbench::type_list tl) -{ - using offset_t = cub::detail::choose_offset_t; - - radix_sort_keys(std::integral_constant()>{}, state, tl); -} - NVBENCH_BENCH_TYPES(radix_sort_keys, NVBENCH_TYPE_AXES(fundamental_types, offset_types)) .set_name("base") .set_type_axes_names({"T{ct}", "OffsetT{ct}"}) diff --git a/cub/benchmarks/bench/radix_sort/pairs.cu b/cub/benchmarks/bench/radix_sort/pairs.cu index 3c535a307fb..74d3af8bfe3 100644 --- a/cub/benchmarks/bench/radix_sort/pairs.cu +++ b/cub/benchmarks/bench/radix_sort/pairs.cu @@ -2,10 +2,6 @@ // SPDX-License-Identifier: BSD-3 #include -#include - -#include -#include #include @@ -15,119 +11,22 @@ // %RANGE% TUNE_ITEMS_PER_THREAD ipt 7:24:1 // %RANGE% TUNE_THREADS_PER_BLOCK tpb 128:1024:32 -constexpr cub::SortOrder sort_order = cub::SortOrder::Ascending; -constexpr bool is_overwrite_ok = false; +#include "policy_selector.h" -#if !TUNE_BASE template -struct policy_hub_t +void radix_sort_values(nvbench::state& state, nvbench::type_list) { - static constexpr bool KEYS_ONLY = std::is_same_v; + using offset_t = cub::detail::choose_offset_t; - using DominantT = ::cuda::std::_If<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>; + constexpr cub::SortOrder sort_order = cub::SortOrder::Ascending; + constexpr bool is_overwrite_ok = false; + using key_t = KeyT; + using value_t = ValueT; - struct policy_t : cub::ChainedPolicy<300, policy_t, policy_t> + if constexpr (!fits_in_default_shared_memory()) { - static constexpr int ONESWEEP_RADIX_BITS = TUNE_RADIX_BITS; - static constexpr bool ONESWEEP = true; - static constexpr bool OFFSET_64BIT = sizeof(OffsetT) == 8; - - // Onesweep policy - using OnesweepPolicy = cub::AgentRadixSortOnesweepPolicy< - TUNE_THREADS_PER_BLOCK, - TUNE_ITEMS_PER_THREAD, - DominantT, - 1, - cub::RADIX_RANK_MATCH_EARLY_COUNTS_ANY, - cub::BLOCK_SCAN_RAKING_MEMOIZE, - cub::RADIX_SORT_STORE_DIRECT, - ONESWEEP_RADIX_BITS>; - - // These kernels are launched once, no point in tuning at the moment - using HistogramPolicy = cub::AgentRadixSortHistogramPolicy<128, 16, 1, KeyT, ONESWEEP_RADIX_BITS>; - using ExclusiveSumPolicy = cub::AgentRadixSortExclusiveSumPolicy<256, ONESWEEP_RADIX_BITS>; - using ScanPolicy = - cub::AgentScanPolicy<512, - 23, - OffsetT, - cub::BLOCK_LOAD_WARP_TRANSPOSE, - cub::LOAD_DEFAULT, - cub::BLOCK_STORE_WARP_TRANSPOSE, - cub::BLOCK_SCAN_RAKING_MEMOIZE>; - - // No point in tuning - static constexpr int SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5; - - // No point in tuning single-tile policy - using SingleTilePolicy = cub::AgentRadixSortDownsweepPolicy< - 256, - 19, - DominantT, - cub::BLOCK_LOAD_DIRECT, - cub::LOAD_LDG, - cub::RADIX_RANK_MEMOIZE, - cub::BLOCK_SCAN_WARP_SCANS, - SINGLE_TILE_RADIX_BITS>; - }; - - using MaxPolicy = policy_t; -}; - -template -constexpr std::size_t max_onesweep_temp_storage_size() -{ - using portion_offset = int; - using onesweep_policy = typename policy_hub_t::policy_t::OnesweepPolicy; - using agent_radix_sort_onesweep_t = - cub::AgentRadixSortOnesweep; - - using hist_policy = typename policy_hub_t::policy_t::HistogramPolicy; - using hist_agent = cub::AgentRadixSortHistogram; - - return (::cuda::std::max) (sizeof(typename agent_radix_sort_onesweep_t::TempStorage), - sizeof(typename hist_agent::TempStorage)); -} - -template -constexpr std::size_t max_temp_storage_size() -{ - using policy_t = typename policy_hub_t::policy_t; - - static_assert(policy_t::ONESWEEP); - return max_onesweep_temp_storage_size(); -} - -template -constexpr bool fits_in_default_shared_memory() -{ - return max_temp_storage_size() < cub::detail::max_smem_per_block; -} -#else // TUNE_BASE -template -constexpr bool fits_in_default_shared_memory() -{ - return true; -} -#endif // TUNE_BASE - -template -void radix_sort_values( - std::integral_constant, nvbench::state& state, nvbench::type_list) -{ - using offset_t = cub::detail::choose_offset_t; - - using key_t = KeyT; - using value_t = ValueT; - using dispatch_t = - cub::DispatchRadixSort -#endif // TUNE_BASE - >; + return; + } constexpr int begin_bit = 0; constexpr int end_bit = sizeof(key_t) * 8; @@ -158,7 +57,7 @@ void radix_sort_values( // Allocate temporary storage: std::size_t temp_size{}; - dispatch_t::Dispatch( + cub::detail::radix_sort::dispatch( nullptr, temp_size, d_keys, @@ -167,16 +66,22 @@ void radix_sort_values( begin_bit, end_bit, is_overwrite_ok, - 0 /* stream */); + 0 /* stream */ +#if !TUNE_BASE + , + cub::detail::identity_decomposer_t{}, + policy_selector{} +#endif // !TUNE_BASE + ); - thrust::device_vector temp(temp_size); + thrust::device_vector temp(temp_size, thrust::no_init); auto* temp_storage = thrust::raw_pointer_cast(temp.data()); state.exec(nvbench::exec_tag::gpu | nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) { cub::DoubleBuffer keys = d_keys; cub::DoubleBuffer values = d_values; - dispatch_t::Dispatch( + cub::detail::radix_sort::dispatch( temp_storage, temp_size, keys, @@ -185,25 +90,16 @@ void radix_sort_values( begin_bit, end_bit, is_overwrite_ok, - launch.get_stream()); + launch.get_stream() +#if !TUNE_BASE + , + cub::detail::identity_decomposer_t{}, + policy_selector{} +#endif // !TUNE_BASE + ); }); } -template -void radix_sort_values(std::integral_constant, nvbench::state&, nvbench::type_list) -{ - (void) sort_order; - (void) is_overwrite_ok; -} - -template -void radix_sort_values(nvbench::state& state, nvbench::type_list tl) -{ - using offset_t = cub::detail::choose_offset_t; - - radix_sort_values(std::integral_constant()>{}, state, tl); -} - #ifdef TUNE_KeyT using key_types = nvbench::type_list; #else // !defined(TUNE_KeyT) diff --git a/cub/benchmarks/bench/radix_sort/policy_selector.h b/cub/benchmarks/bench/radix_sort/policy_selector.h new file mode 100644 index 00000000000..1fec3c4a773 --- /dev/null +++ b/cub/benchmarks/bench/radix_sort/policy_selector.h @@ -0,0 +1,117 @@ +// SPDX-FileCopyrightText: Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. +// SPDX-License-Identifier: BSD-3 + +#include + +// %//RANGE//% TUNE_RADIX_BITS bits 8:9:1 +#define TUNE_RADIX_BITS 8 + +// %RANGE% TUNE_ITEMS_PER_THREAD ipt 7:24:1 +// %RANGE% TUNE_THREADS_PER_BLOCK tpb 128:1024:32 + +#if !TUNE_BASE +template +struct policy_selector +{ + using DominantT = ::cuda::std::_If<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>; + + _CCCL_API constexpr auto operator()(cuda::arch_id) const -> ::cub::detail::radix_sort::radix_sort_policy + { + const auto onesweep = [] { + const auto scaled = + cub::detail::scale_reg_bound(TUNE_THREADS_PER_BLOCK, TUNE_ITEMS_PER_THREAD, sizeof(DominantT)); + return radix_sort_onesweep_policy{ + scaled.block_threads, + scaled.items_per_thread, + 1, + ONESWEEP_RADIX_BITS, + cub::RADIX_RANK_MATCH_EARLY_COUNTS_ANY, + cub::BLOCK_SCAN_RAKING_MEMOIZE, + cub::RADIX_SORT_STORE_DIRECT}; + }(); + + // These kernels are launched once, no point in tuning at the moment + const auto histogram = radix_sort_histogram_policy{ + 128, 16, cub::detail::radix_sort::__scale_num_parts(1, sizeof(KeyT)), ONESWEEP_RADIX_BITS}; + const auto exclusive_sum = radix_sort_exclusive_sum_policy{256, ONESWEEP_RADIX_BITS}; + + const auto scan = [] { + const auto scaled = cub::detail::scale_mem_bound(512, 23, sizeof(OffsetT)); + return scan{scaled.block_threads, + scaled.items_per_thread, + cub::BLOCK_LOAD_WARP_TRANSPOSE, + cub::LOAD_DEFAULT, + cub::BLOCK_STORE_WARP_TRANSPOSE, + cub::BLOCK_SCAN_RAKING_MEMOIZE}; + }(); + + // No point in tuning + const int single_tile_radix_bits = (sizeof(KeyT) > 1) ? 6 : 5; + + // No point in tuning single-tile policy + const auto single_tile = [] { + const auto scaled = cub::detail::scale_reg_bound(256, 19, sizeof(DominantT)); + return cub::detail::radix_sort::radix_sort_downsweep_policy{ + scaled.block_threads, + scaled.items_per_thread, + single_tile_radix_bits, + cub::BLOCK_LOAD_DIRECT, + cub::LOAD_LDG, + cub::RADIX_RANK_MEMOIZE, + cub::BLOCK_SCAN_WARP_SCANS, + }; + }(); + + return radix_sort_policy{ + /* use_onesweep */ true, + /* onesweep_radix_bits */ TUNE_RADIX_BITS, + histogram, + exclusive_sum, + onesweep, + scan, + /* downsweep */ {}, + /* alt_downsweep */ {}, + /* upsweep */ {}, + /* alt_upsweep */ {}, + single_tile, + /* segmented not used */ {}, + /* alt_segmented not used */ {}}; + } +}; + +template +constexpr std::size_t max_onesweep_temp_storage_size() +{ + using portion_offset = int; + using onesweep_policy = typename policy_hub_t::policy_t::OnesweepPolicy; + using agent_radix_sort_onesweep_t = + cub::AgentRadixSortOnesweep; + + using hist_policy = typename policy_hub_t::policy_t::HistogramPolicy; + using hist_agent = cub::AgentRadixSortHistogram; + + return (::cuda::std::max) (sizeof(typename agent_radix_sort_onesweep_t::TempStorage), + sizeof(typename hist_agent::TempStorage)); +} + +template +constexpr std::size_t max_temp_storage_size() +{ + using policy_t = typename policy_hub_t::policy_t; + + static_assert(policy_t::ONESWEEP); + return max_onesweep_temp_storage_size(); +} + +template +constexpr bool fits_in_default_shared_memory() +{ + return max_temp_storage_size() < cub::detail::max_smem_per_block; +} +#else // TUNE_BASE +template +constexpr bool fits_in_default_shared_memory() +{ + return true; +} +#endif // TUNE_BASE diff --git a/cub/cub/agent/agent_radix_sort_histogram.cuh b/cub/cub/agent/agent_radix_sort_histogram.cuh index bd42742ce76..1c50370f991 100644 --- a/cub/cub/agent/agent_radix_sort_histogram.cuh +++ b/cub/cub/agent/agent_radix_sort_histogram.cuh @@ -31,21 +31,38 @@ #include #include #include +#include CUB_NAMESPACE_BEGIN +//! @param ComputeT If void, use NOMINAL_4B_NUM_PARTS directly for NUM_PARTS. Otherwise, perform scaling. template struct AgentRadixSortHistogramPolicy { static constexpr int BLOCK_THREADS = BlockThreads; static constexpr int ITEMS_PER_THREAD = ItemsPerThread; + + // need to discard sizeof(ComputeType) in case it's void + template + _CCCL_API static constexpr int num_parts_helper() + { + if constexpr (::cuda::std::is_void_v) + { + return NOMINAL_4B_NUM_PARTS; + } + else + { + return ::cuda::std::max(1, NOMINAL_4B_NUM_PARTS * 4 / ::cuda::std::max(int{sizeof(ComputeType)}, 4)); + } + } + /** NUM_PARTS is the number of private histograms (parts) each histogram is split * into. Each warp lane is assigned to a specific part based on the lane * ID. However, lanes with the same ID in different warp use the same private * histogram. This arrangement helps reduce the degree of conflicts in atomic * operations. */ - static constexpr int NUM_PARTS = - ::cuda::std::max(1, NOMINAL_4B_NUM_PARTS * 4 / ::cuda::std::max(int{sizeof(ComputeT)}, 4)); + static constexpr int NUM_PARTS = num_parts_helper(); + static constexpr int RADIX_BITS = RadixBits; }; diff --git a/cub/cub/agent/agent_radix_sort_upsweep.cuh b/cub/cub/agent/agent_radix_sort_upsweep.cuh index 68cd0f3f173..76ceaea1437 100644 --- a/cub/cub/agent/agent_radix_sort_upsweep.cuh +++ b/cub/cub/agent/agent_radix_sort_upsweep.cuh @@ -34,6 +34,7 @@ #endif #include +#include #include #include @@ -205,35 +206,6 @@ struct AgentRadixSortUpsweep int num_bits; DecomposerT decomposer; - //--------------------------------------------------------------------- - // Helper structure for templated iteration - //--------------------------------------------------------------------- - - // Iterate - template - struct Iterate - { - // BucketKeys - static _CCCL_DEVICE _CCCL_FORCEINLINE void - BucketKeys(AgentRadixSortUpsweep& cta, bit_ordered_type keys[KEYS_PER_THREAD]) - { - cta.Bucket(keys[COUNT]); - - // Next - Iterate::BucketKeys(cta, keys); - } - }; - - // Terminate - template - struct Iterate - { - // BucketKeys - static _CCCL_DEVICE _CCCL_FORCEINLINE void - BucketKeys(AgentRadixSortUpsweep& /*cta*/, bit_ordered_type /*keys*/[KEYS_PER_THREAD]) - {} - }; - //--------------------------------------------------------------------- // Utility methods //--------------------------------------------------------------------- @@ -258,6 +230,7 @@ struct AgentRadixSortUpsweep // Get row offset uint32_t row_offset = digit >> LOG_PACKING_RATIO; + _CCCL_ASSERT(row_offset < COUNTER_LANES, ""); // Increment counter temp_storage.thread_counters[row_offset][threadIdx.x][sub_counter]++; @@ -334,7 +307,9 @@ struct AgentRadixSortUpsweep __syncthreads(); // Bucket tile of keys - Iterate<0, KEYS_PER_THREAD>::BucketKeys(*this, keys); + cuda::static_for([&](auto ic) { + Bucket(keys[ic]); + }); } /** diff --git a/cub/cub/device/device_radix_sort.cuh b/cub/cub/device/device_radix_sort.cuh index da28decec62..cad0db90db1 100644 --- a/cub/cub/device/device_radix_sort.cuh +++ b/cub/cub/device/device_radix_sort.cuh @@ -120,23 +120,8 @@ CUB_NAMESPACE_BEGIN struct DeviceRadixSort { private: - template - CUB_RUNTIME_FUNCTION static cudaError_t custom_radix_sort( - ::cuda::std::false_type, - void* d_temp_storage, - size_t& temp_storage_bytes, - bool is_overwrite_okay, - DoubleBuffer& d_keys, - DoubleBuffer& d_values, - NumItemsT num_items, - DecomposerT decomposer, - int begin_bit, - int end_bit, - cudaStream_t stream); - template CUB_RUNTIME_FUNCTION static cudaError_t custom_radix_sort( - ::cuda::std::true_type, void* d_temp_storage, size_t& temp_storage_bytes, bool is_overwrite_okay, @@ -148,7 +133,7 @@ private: int end_bit, cudaStream_t stream) { - return DispatchRadixSort::Dispatch( + return detail::radix_sort::dispatch( d_temp_storage, temp_storage_bytes, d_keys, @@ -161,21 +146,8 @@ private: decomposer); } - template - CUB_RUNTIME_FUNCTION static cudaError_t custom_radix_sort( - ::cuda::std::false_type, - void* d_temp_storage, - size_t& temp_storage_bytes, - bool is_overwrite_okay, - DoubleBuffer& d_keys, - DoubleBuffer& d_values, - NumItemsT num_items, - DecomposerT decomposer, - cudaStream_t stream); - template CUB_RUNTIME_FUNCTION static cudaError_t custom_radix_sort( - ::cuda::std::true_type, void* d_temp_storage, size_t& temp_storage_bytes, bool is_overwrite_okay, @@ -188,18 +160,17 @@ private: constexpr int begin_bit = 0; const int end_bit = detail::radix::traits_t::default_end_bit(decomposer); - return DeviceRadixSort::custom_radix_sort( - ::cuda::std::true_type{}, + return detail::radix_sort::dispatch( d_temp_storage, temp_storage_bytes, - is_overwrite_okay, d_keys, d_values, num_items, - decomposer, begin_bit, end_bit, - stream); + is_overwrite_okay, + stream, + decomposer); } // Name reported for NVTX ranges @@ -341,7 +312,7 @@ public: DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values(const_cast(d_values_in), d_values_out); - return DispatchRadixSort::Dispatch( + return detail::radix_sort::dispatch( d_temp_storage, temp_storage_bytes, d_keys, @@ -490,18 +461,20 @@ public: DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values(const_cast(d_values_in), d_values_out); - return DeviceRadixSort::custom_radix_sort( - decomposer_check_t{}, - d_temp_storage, - temp_storage_bytes, - is_overwrite_okay, - d_keys, - d_values, - static_cast(num_items), - decomposer, - begin_bit, - end_bit, - stream); + if (decomposer_check_t::value) + { + return DeviceRadixSort::custom_radix_sort( + d_temp_storage, + temp_storage_bytes, + is_overwrite_okay, + d_keys, + d_values, + static_cast(num_items), + decomposer, + begin_bit, + end_bit, + stream); + } } //! @rst @@ -628,16 +601,18 @@ public: DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values(const_cast(d_values_in), d_values_out); - return DeviceRadixSort::custom_radix_sort( - decomposer_check_t{}, - d_temp_storage, - temp_storage_bytes, - is_overwrite_okay, - d_keys, - d_values, - static_cast(num_items), - decomposer, - stream); + if constexpr (decomposer_check_t::value) + { + return DeviceRadixSort::custom_radix_sort( + d_temp_storage, + temp_storage_bytes, + is_overwrite_okay, + d_keys, + d_values, + static_cast(num_items), + decomposer, + stream); + } } //! @rst @@ -767,8 +742,16 @@ public: constexpr bool is_overwrite_okay = true; - return DispatchRadixSort::Dispatch( - d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items, begin_bit, end_bit, is_overwrite_okay, stream); + return detail::radix_sort::dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + static_cast(num_items), + begin_bit, + end_bit, + is_overwrite_okay, + stream); } //! @rst @@ -891,16 +874,18 @@ public: constexpr bool is_overwrite_okay = true; - return DeviceRadixSort::custom_radix_sort( - decomposer_check_t{}, - d_temp_storage, - temp_storage_bytes, - is_overwrite_okay, - d_keys, - d_values, - static_cast(num_items), - decomposer, - stream); + if constexpr (decomposer_check_t::value) + { + return DeviceRadixSort::custom_radix_sort( + d_temp_storage, + temp_storage_bytes, + is_overwrite_okay, + d_keys, + d_values, + static_cast(num_items), + decomposer, + stream); + } } //! @rst @@ -1036,18 +1021,20 @@ public: constexpr bool is_overwrite_okay = true; - return DeviceRadixSort::custom_radix_sort( - decomposer_check_t{}, - d_temp_storage, - temp_storage_bytes, - is_overwrite_okay, - d_keys, - d_values, - static_cast(num_items), - decomposer, - begin_bit, - end_bit, - stream); + if constexpr (decomposer_check_t::value) + { + return DeviceRadixSort::custom_radix_sort( + d_temp_storage, + temp_storage_bytes, + is_overwrite_okay, + d_keys, + d_values, + static_cast(num_items), + decomposer, + begin_bit, + end_bit, + stream); + } } //! @rst @@ -1180,8 +1167,16 @@ public: DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values(const_cast(d_values_in), d_values_out); - return DispatchRadixSort::Dispatch( - d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items, begin_bit, end_bit, is_overwrite_okay, stream); + return detail::radix_sort::dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + static_cast(num_items), + begin_bit, + end_bit, + is_overwrite_okay, + stream); } //! @rst @@ -1323,18 +1318,20 @@ public: DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values(const_cast(d_values_in), d_values_out); - return DeviceRadixSort::custom_radix_sort( - decomposer_check_t{}, - d_temp_storage, - temp_storage_bytes, - is_overwrite_okay, - d_keys, - d_values, - static_cast(num_items), - decomposer, - begin_bit, - end_bit, - stream); + if constexpr (decomposer_check_t::value) + { + return DeviceRadixSort::custom_radix_sort( + d_temp_storage, + temp_storage_bytes, + is_overwrite_okay, + d_keys, + d_values, + static_cast(num_items), + decomposer, + begin_bit, + end_bit, + stream); + } } //! @rst @@ -1463,16 +1460,18 @@ public: DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values(const_cast(d_values_in), d_values_out); - return DeviceRadixSort::custom_radix_sort( - decomposer_check_t{}, - d_temp_storage, - temp_storage_bytes, - is_overwrite_okay, - d_keys, - d_values, - static_cast(num_items), - decomposer, - stream); + if constexpr (decomposer_check_t::value) + { + return DeviceRadixSort::custom_radix_sort( + d_temp_storage, + temp_storage_bytes, + is_overwrite_okay, + d_keys, + d_values, + static_cast(num_items), + decomposer, + stream); + } } //! @rst @@ -1602,8 +1601,16 @@ public: constexpr bool is_overwrite_okay = true; - return DispatchRadixSort::Dispatch( - d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items, begin_bit, end_bit, is_overwrite_okay, stream); + return detail::radix_sort::dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + static_cast(num_items), + begin_bit, + end_bit, + is_overwrite_okay, + stream); } //! @rst @@ -1727,16 +1734,18 @@ public: constexpr bool is_overwrite_okay = true; - return DeviceRadixSort::custom_radix_sort( - decomposer_check_t{}, - d_temp_storage, - temp_storage_bytes, - is_overwrite_okay, - d_keys, - d_values, - static_cast(num_items), - decomposer, - stream); + if constexpr (decomposer_check_t::value) + { + return DeviceRadixSort::custom_radix_sort( + d_temp_storage, + temp_storage_bytes, + is_overwrite_okay, + d_keys, + d_values, + static_cast(num_items), + decomposer, + stream); + } } //! @rst @@ -1873,18 +1882,20 @@ public: constexpr bool is_overwrite_okay = true; - return DeviceRadixSort::custom_radix_sort( - decomposer_check_t{}, - d_temp_storage, - temp_storage_bytes, - is_overwrite_okay, - d_keys, - d_values, - static_cast(num_items), - decomposer, - begin_bit, - end_bit, - stream); + if constexpr (decomposer_check_t::value) + { + return DeviceRadixSort::custom_radix_sort( + d_temp_storage, + temp_storage_bytes, + is_overwrite_okay, + d_keys, + d_values, + static_cast(num_items), + decomposer, + begin_bit, + end_bit, + stream); + } } //! @} @@ -2006,7 +2017,7 @@ public: // Null value type DoubleBuffer d_values; - return DispatchRadixSort::Dispatch( + return detail::radix_sort::dispatch( d_temp_storage, temp_storage_bytes, d_keys, @@ -2142,18 +2153,20 @@ public: DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values; - return DeviceRadixSort::custom_radix_sort( - decomposer_check_t{}, - d_temp_storage, - temp_storage_bytes, - is_overwrite_okay, - d_keys, - d_values, - static_cast(num_items), - decomposer, - begin_bit, - end_bit, - stream); + if constexpr (decomposer_check_t::value) + { + return DeviceRadixSort::custom_radix_sort( + d_temp_storage, + temp_storage_bytes, + is_overwrite_okay, + d_keys, + d_values, + static_cast(num_items), + decomposer, + begin_bit, + end_bit, + stream); + } } //! @rst @@ -2270,16 +2283,18 @@ public: DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values; - return DeviceRadixSort::custom_radix_sort( - decomposer_check_t{}, - d_temp_storage, - temp_storage_bytes, - is_overwrite_okay, - d_keys, - d_values, - static_cast(num_items), - decomposer, - stream); + if constexpr (decomposer_check_t::value) + { + return DeviceRadixSort::custom_radix_sort( + d_temp_storage, + temp_storage_bytes, + is_overwrite_okay, + d_keys, + d_values, + static_cast(num_items), + decomposer, + stream); + } } //! @rst @@ -2395,8 +2410,16 @@ public: // Null value type DoubleBuffer d_values; - return DispatchRadixSort::Dispatch( - d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items, begin_bit, end_bit, is_overwrite_okay, stream); + return detail::radix_sort::dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + static_cast(num_items), + begin_bit, + end_bit, + is_overwrite_okay, + stream); } //! @rst @@ -2507,16 +2530,18 @@ public: constexpr bool is_overwrite_okay = true; DoubleBuffer d_values; - return DeviceRadixSort::custom_radix_sort( - decomposer_check_t{}, - d_temp_storage, - temp_storage_bytes, - is_overwrite_okay, - d_keys, - d_values, - static_cast(num_items), - decomposer, - stream); + if constexpr (decomposer_check_t::value) + { + return DeviceRadixSort::custom_radix_sort( + d_temp_storage, + temp_storage_bytes, + is_overwrite_okay, + d_keys, + d_values, + static_cast(num_items), + decomposer, + stream); + } } //! @rst @@ -2640,18 +2665,20 @@ public: constexpr bool is_overwrite_okay = true; DoubleBuffer d_values; - return DeviceRadixSort::custom_radix_sort( - decomposer_check_t{}, - d_temp_storage, - temp_storage_bytes, - is_overwrite_okay, - d_keys, - d_values, - static_cast(num_items), - decomposer, - begin_bit, - end_bit, - stream); + if constexpr (decomposer_check_t::value) + { + return DeviceRadixSort::custom_radix_sort( + d_temp_storage, + temp_storage_bytes, + is_overwrite_okay, + d_keys, + d_values, + static_cast(num_items), + decomposer, + begin_bit, + end_bit, + stream); + } } //! @rst Sorts keys into descending order using :math:`\approx 2N` auxiliary storage. @@ -2767,8 +2794,16 @@ public: DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values; - return DispatchRadixSort::Dispatch( - d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items, begin_bit, end_bit, is_overwrite_okay, stream); + return detail::radix_sort::dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + static_cast(num_items), + begin_bit, + end_bit, + is_overwrite_okay, + stream); } //! @rst @@ -2896,18 +2931,20 @@ public: DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values; - return DeviceRadixSort::custom_radix_sort( - decomposer_check_t{}, - d_temp_storage, - temp_storage_bytes, - is_overwrite_okay, - d_keys, - d_values, - static_cast(num_items), - decomposer, - begin_bit, - end_bit, - stream); + if constexpr (decomposer_check_t::value) + { + return DeviceRadixSort::custom_radix_sort( + d_temp_storage, + temp_storage_bytes, + is_overwrite_okay, + d_keys, + d_values, + static_cast(num_items), + decomposer, + begin_bit, + end_bit, + stream); + } } //! @rst @@ -3022,16 +3059,18 @@ public: DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values; - return DeviceRadixSort::custom_radix_sort( - decomposer_check_t{}, - d_temp_storage, - temp_storage_bytes, - is_overwrite_okay, - d_keys, - d_values, - static_cast(num_items), - decomposer, - stream); + if constexpr (decomposer_check_t::value) + { + return DeviceRadixSort::custom_radix_sort( + d_temp_storage, + temp_storage_bytes, + is_overwrite_okay, + d_keys, + d_values, + static_cast(num_items), + decomposer, + stream); + } } //! @rst @@ -3146,8 +3185,16 @@ public: // Null value type DoubleBuffer d_values; - return DispatchRadixSort::Dispatch( - d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items, begin_bit, end_bit, is_overwrite_okay, stream); + return detail::radix_sort::dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + static_cast(num_items), + begin_bit, + end_bit, + is_overwrite_okay, + stream); } //! @rst @@ -3259,16 +3306,18 @@ public: constexpr bool is_overwrite_okay = true; DoubleBuffer d_values; - return DeviceRadixSort::custom_radix_sort( - decomposer_check_t{}, - d_temp_storage, - temp_storage_bytes, - is_overwrite_okay, - d_keys, - d_values, - static_cast(num_items), - decomposer, - stream); + if constexpr (decomposer_check_t::value) + { + return DeviceRadixSort::custom_radix_sort( + d_temp_storage, + temp_storage_bytes, + is_overwrite_okay, + d_keys, + d_values, + static_cast(num_items), + decomposer, + stream); + } } //! @rst @@ -3393,18 +3442,20 @@ public: constexpr bool is_overwrite_okay = true; DoubleBuffer d_values; - return DeviceRadixSort::custom_radix_sort( - decomposer_check_t{}, - d_temp_storage, - temp_storage_bytes, - is_overwrite_okay, - d_keys, - d_values, - static_cast(num_items), - decomposer, - begin_bit, - end_bit, - stream); + if constexpr (decomposer_check_t::value) + { + return DeviceRadixSort::custom_radix_sort( + d_temp_storage, + temp_storage_bytes, + is_overwrite_okay, + d_keys, + d_values, + static_cast(num_items), + decomposer, + begin_bit, + end_bit, + stream); + } } //! @} diff --git a/cub/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/cub/device/dispatch/dispatch_radix_sort.cuh index 132f6ef6dc5..8f9bb4fa698 100644 --- a/cub/cub/device/dispatch/dispatch_radix_sort.cuh +++ b/cub/cub/device/dispatch/dispatch_radix_sort.cuh @@ -20,6 +20,7 @@ # pragma system_header #endif // no system header +#include #include #include #include @@ -36,6 +37,10 @@ // TODO(bgruber): included for backward compatibility, remove in CCCL 4.0 #include +#if !_CCCL_COMPILER(NVRTC) && defined(CUB_DEBUG_LOG) +# include +#endif + // suppress warnings triggered by #pragma unroll: // "warning: loop not unrolled: the optimizer was unable to perform the requested transformation; the transformation // might be disabled or specified as part of an unsupported transformation ordering [-Wpass-failed=transform-warning]" @@ -46,34 +51,39 @@ CUB_NAMESPACE_BEGIN namespace detail::radix_sort { -template +template struct DeviceRadixSortKernelSource { + // PolicySelector must be stateless, so we can pass the type to the kernel + static_assert(::cuda::std::is_empty_v); + CUB_DEFINE_KERNEL_GETTER(RadixSortSingleTileKernel, - DeviceRadixSortSingleTileKernel); + DeviceRadixSortSingleTileKernel); CUB_DEFINE_KERNEL_GETTER(RadixSortUpsweepKernel, - DeviceRadixSortUpsweepKernel); + DeviceRadixSortUpsweepKernel); CUB_DEFINE_KERNEL_GETTER(RadixSortAltUpsweepKernel, - DeviceRadixSortUpsweepKernel); + DeviceRadixSortUpsweepKernel); - CUB_DEFINE_KERNEL_GETTER(DeviceRadixSortScanBinsKernel, RadixSortScanBinsKernel); + CUB_DEFINE_KERNEL_GETTER(DeviceRadixSortScanBinsKernel, RadixSortScanBinsKernel); - CUB_DEFINE_KERNEL_GETTER(RadixSortDownsweepKernel, - DeviceRadixSortDownsweepKernel); + CUB_DEFINE_KERNEL_GETTER( + RadixSortDownsweepKernel, + DeviceRadixSortDownsweepKernel); - CUB_DEFINE_KERNEL_GETTER(RadixSortAltDownsweepKernel, - DeviceRadixSortDownsweepKernel); + CUB_DEFINE_KERNEL_GETTER( + RadixSortAltDownsweepKernel, + DeviceRadixSortDownsweepKernel); CUB_DEFINE_KERNEL_GETTER(RadixSortHistogramKernel, - DeviceRadixSortHistogramKernel); + DeviceRadixSortHistogramKernel); - CUB_DEFINE_KERNEL_GETTER(RadixSortExclusiveSumKernel, DeviceRadixSortExclusiveSumKernel); + CUB_DEFINE_KERNEL_GETTER(RadixSortExclusiveSumKernel, DeviceRadixSortExclusiveSumKernel); CUB_DEFINE_KERNEL_GETTER( RadixSortOnesweepKernel, - DeviceRadixSortOnesweepKernel); + DeviceRadixSortOnesweepKernel); CUB_RUNTIME_FUNCTION static constexpr size_t KeySize() { @@ -85,6 +95,113 @@ struct DeviceRadixSortKernelSource return sizeof(ValueT); } }; + +// TODO(bgruber): remove in CCCL 4.0 when we drop the radix sort dispatcher after publishing the tuning API +template +_CCCL_API constexpr auto convert_policy() -> radix_sort_policy +{ + using active_policy = LegacyActivePolicy; + + auto convert_downsweep_policy = [](auto p) { + // MSVC will error if we put a [[no_discard]] on the parameter p above: + // C2187: syntax error: 'attribute specifier' was unexpected here + (void) p; + using p_t = decltype(p); + return radix_sort_downsweep_policy{ + p_t::BLOCK_THREADS, + p_t::ITEMS_PER_THREAD, + p_t::RADIX_BITS, + p_t::LOAD_ALGORITHM, + p_t::LOAD_MODIFIER, + p_t::RANK_ALGORITHM, + p_t::SCAN_ALGORITHM}; + }; + + const auto histogram = [] { + using p = typename active_policy::HistogramPolicy; + return radix_sort_histogram_policy{p::BLOCK_THREADS, p::ITEMS_PER_THREAD, p::NUM_PARTS, p::RADIX_BITS}; + }(); + + const auto exclusive_sum = [] { + using p = typename active_policy::ExclusiveSumPolicy; + return radix_sort_exclusive_sum_policy{p::BLOCK_THREADS, p::RADIX_BITS}; + }(); + + const auto onesweep = [] { + using p = typename active_policy::OnesweepPolicy; + return radix_sort_onesweep_policy{ + p::BLOCK_THREADS, + p::ITEMS_PER_THREAD, + p::RANK_NUM_PARTS, + p::RADIX_BITS, + p::RANK_ALGORITHM, + p::SCAN_ALGORITHM, + p::STORE_ALGORITHM}; + }(); + + const auto scan = [] { + using p = typename active_policy::ScanPolicy; + return scan_policy{ + p::BLOCK_THREADS, + p::ITEMS_PER_THREAD, + p::LOAD_ALGORITHM, + p::LOAD_MODIFIER, + p::STORE_ALGORITHM, + p::SCAN_ALGORITHM, + delay_constructor_policy_from_type}; + }(); + + const auto downsweep = convert_downsweep_policy(typename active_policy::DownsweepPolicy{}); + const auto alt_downsweep = convert_downsweep_policy(typename active_policy::AltDownsweepPolicy{}); + + const auto upsweep = [] { + using p = typename active_policy::UpsweepPolicy; + return radix_sort_upsweep_policy{p::BLOCK_THREADS, p::ITEMS_PER_THREAD, p::RADIX_BITS, p::LOAD_MODIFIER}; + }(); + + const auto alt_upsweep = [] { + using p = typename active_policy::AltUpsweepPolicy; + return radix_sort_upsweep_policy{p::BLOCK_THREADS, p::ITEMS_PER_THREAD, p::RADIX_BITS, p::LOAD_MODIFIER}; + }(); + + const auto single_tile = convert_downsweep_policy(typename active_policy::SingleTilePolicy{}); + const auto segmented = convert_downsweep_policy(typename active_policy::SegmentedPolicy{}); + const auto alt_segmented = convert_downsweep_policy(typename active_policy::AltSegmentedPolicy{}); + + return radix_sort_policy{ + active_policy::ONESWEEP, + active_policy::ONESWEEP_RADIX_BITS, + histogram, + exclusive_sum, + onesweep, + scan, + downsweep, + alt_downsweep, + upsweep, + alt_upsweep, + single_tile, + segmented, + alt_segmented}; +} + +// TODO(bgruber): remove in CCCL 4.0 when we drop the radix sort dispatcher after publishing the tuning API +template +CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE constexpr auto convert_policy(RadixSortPolicyWrapper policy) + -> radix_sort_policy +{ + return convert_policy(); +} + +// TODO(bgruber): remove in CCCL 4.0 when we drop the radix sort dispatcher after publishing the tuning API +template +struct policy_selector_from_hub +{ + // this is only called in device code + _CCCL_DEVICE_API constexpr auto operator()(::cuda::arch_id /*arch*/) const -> radix_sort_policy + { + return convert_policy(); + } +}; } // namespace detail::radix_sort /****************************************************************************** @@ -110,14 +227,20 @@ struct DeviceRadixSortKernelSource * Implementation detail, do not specify directly, requirements on the * content of this type are subject to breaking change. */ +// TODO(bgruber): deprecate when we make the tuning API public and remove in CCCL 4.0 template , - typename KernelSource = detail::radix_sort:: - DeviceRadixSortKernelSource, + typename KernelSource = detail::radix_sort::DeviceRadixSortKernelSource< + detail::radix_sort::policy_selector_from_hub, + Order, + KeyT, + ValueT, + OffsetT, + DecomposerT>, typename KernelLauncherFactory = CUB_DETAIL_DEFAULT_KERNEL_LAUNCHER_FACTORY> struct DispatchRadixSort { @@ -176,7 +299,7 @@ struct DispatchRadixSort // Constructor //------------------------------------------------------------------------------ - /// Constructor + // TODO(bgruber): deprecate when we make the tuning API public and remove in CCCL 4.0 CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE DispatchRadixSort( void* d_temp_storage, size_t& temp_storage_bytes, @@ -222,9 +345,18 @@ struct DispatchRadixSort * @param[in] single_tile_kernel * Kernel function pointer to parameterization of cub::DeviceRadixSortSingleTileKernel */ + // TODO(bgruber): deprecate when we make the tuning API public and remove in CCCL 4.0 template CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t InvokeSingleTile(SingleTileKernelT single_tile_kernel, ActivePolicyT policy = {}) + { + return __invoke_single_tile(single_tile_kernel, detail::radix_sort::convert_policy(policy).single_tile); + } + +private: + template + CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t + __invoke_single_tile(SingleTileKernelT single_tile_kernel, detail::radix_sort::radix_sort_downsweep_policy policy) { // Return if the caller is simply requesting the size of the storage allocation if (d_temp_storage == nullptr) @@ -233,21 +365,21 @@ struct DispatchRadixSort return cudaSuccess; } -// Log single_tile_kernel configuration + // Log single_tile_kernel configuration #ifdef CUB_DEBUG_LOG _CubLog("Invoking single_tile_kernel<<<%d, %d, 0, %lld>>>(), %d items per thread, %d SM occupancy, current bit " "%d, bit_grain %d\n", 1, - policy.SingleTile().BlockThreads(), + policy.block_threads, (long long) stream, - policy.SingleTile().ItemsPerThread(), + policy.items_per_thread, 1, begin_bit, - policy.RadixBits(policy.SingleTile())); + policy.radix_bits); #endif // Invoke upsweep_kernel with same grid size as downsweep_kernel - launcher_factory(1, policy.SingleTile().BlockThreads(), 0, stream) + launcher_factory(1, policy.block_threads, 0, stream) .doit(single_tile_kernel, d_keys.Current(), d_keys.Alternate(), @@ -277,6 +409,7 @@ struct DispatchRadixSort return cudaSuccess; } +public: //------------------------------------------------------------------------------ // Normal problem size invocation //------------------------------------------------------------------------------ @@ -284,6 +417,7 @@ struct DispatchRadixSort /** * Invoke a three-kernel sorting pass at the current bit. */ + // TODO(bgruber): deprecate when we make the tuning API public and remove in CCCL 4.0 template CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t InvokePass( const KeyT* d_keys_in, @@ -403,6 +537,7 @@ struct DispatchRadixSort return cudaSuccess; } + // TODO(bgruber): deprecate when we make the tuning API public and remove in CCCL 4.0 /// Pass configuration structure template struct PassConfig @@ -418,6 +553,7 @@ struct DispatchRadixSort int max_downsweep_grid_size; GridEvenShare even_share; + // TODO(bgruber): deprecate when we make the tuning API public and remove in CCCL 4.0 /// Initialize pass configuration template CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t InitPassConfig( @@ -433,23 +569,53 @@ struct DispatchRadixSort DownsweepPolicyT downsweep_policy = {}, KernelLauncherFactory launcher_factory = {}) { - this->upsweep_kernel = upsweep_kernel; - this->scan_kernel = scan_kernel; - this->downsweep_kernel = downsweep_kernel; - radix_bits = policy.RadixBits(downsweep_policy); + // FIXME(bgruber): we should actually convert upsweep_policy, scan_policy, and downsweep_policy, since they could + // be different from those inside policy. But this is already so far out of any supported scenario that I am + // willing to cut this corner. + const auto p = detail::radix_sort::convert_policy(policy); + __init_pass_config( + upsweep_kernel, + scan_kernel, + downsweep_kernel, + sm_count, + num_items, + p.downsweep.radix_bits, + p.upsweep, + p.scan, + p.downsweep, + launcher_factory); + return __init_pass_config(p); + } + + CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t __init_pass_config( + UpsweepKernelT upsweep_kern, + ScanKernelT scan_kern, + DownsweepKernelT downsweep_kern, + int sm_count, + OffsetT num_items, + int pass_radix_bits, + detail::radix_sort::radix_sort_upsweep_policy upsweep_policy, + detail::radix_sort::scan_policy scan_policy, + detail::radix_sort::radix_sort_downsweep_policy downsweep_policy, + KernelLauncherFactory launcher_factory) + { + this->upsweep_kernel = upsweep_kern; + this->scan_kernel = scan_kern; + this->downsweep_kernel = downsweep_kern; + this->radix_bits = pass_radix_bits; radix_digits = 1 << radix_bits; - if (const auto error = CubDebug(upsweep_config.Init(upsweep_kernel, upsweep_policy, launcher_factory))) + if (const auto error = CubDebug(upsweep_config.__init(upsweep_kernel, upsweep_policy, launcher_factory))) { return error; } - if (const auto error = CubDebug(scan_config.Init(scan_kernel, scan_policy, launcher_factory))) + if (const auto error = CubDebug(scan_config.__init(scan_kernel, scan_policy, launcher_factory))) { return error; } - if (const auto error = CubDebug(downsweep_config.Init(downsweep_kernel, downsweep_policy, launcher_factory))) + if (const auto error = CubDebug(downsweep_config.__init(downsweep_kernel, downsweep_policy, launcher_factory))) { return error; } @@ -463,18 +629,25 @@ struct DispatchRadixSort } }; + // TODO(bgruber): deprecate when we make the tuning API public and remove in CCCL 4.0 template CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t InvokeOnesweep(ActivePolicyT policy = {}) + { + return __invoke_onesweep(detail::radix_sort::convert_policy(policy)); + } + +private: + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t __invoke_onesweep(detail::radix_sort::radix_sort_policy policy) { // PortionOffsetT is used for offsets within a portion, and must be signed. using PortionOffsetT = int; using AtomicOffsetT = PortionOffsetT; // compute temporary storage size - const int RADIX_BITS = policy.RadixBits(policy.Onesweep()); + const int RADIX_BITS = policy.onesweep.radix_bits; const int RADIX_DIGITS = 1 << RADIX_BITS; - const int ONESWEEP_ITEMS_PER_THREAD = policy.Onesweep().ItemsPerThread(); - const int ONESWEEP_BLOCK_THREADS = policy.Onesweep().BlockThreads(); + const int ONESWEEP_ITEMS_PER_THREAD = policy.onesweep.items_per_thread; + const int ONESWEEP_BLOCK_THREADS = policy.onesweep.block_threads; const int ONESWEEP_TILE_ITEMS = ONESWEEP_ITEMS_PER_THREAD * ONESWEEP_BLOCK_THREADS; // portions handle inputs with >=2**30 elements, due to the way lookback works // for testing purposes, one portion is <= 2**28 elements @@ -542,7 +715,7 @@ struct DispatchRadixSort return error; } - const int HISTO_BLOCK_THREADS = policy.Histogram().BlockThreads(); + const int HISTO_BLOCK_THREADS = policy.histogram.block_threads; int histo_blocks_per_sm = 1; auto histogram_kernel = kernel_source.RadixSortHistogramKernel(); @@ -559,9 +732,9 @@ struct DispatchRadixSort histo_blocks_per_sm * num_sms, HISTO_BLOCK_THREADS, reinterpret_cast(stream), - policy.Histogram().ItemsPerThread(), + policy.histogram.items_per_thread, histo_blocks_per_sm, - policy.RadixBits(policy.Histogram())); + policy.histogram.radix_bits); #endif if (const auto error = CubDebug( @@ -577,7 +750,7 @@ struct DispatchRadixSort } // exclusive sums to determine starts - const int SCAN_BLOCK_THREADS = policy.BlockThreads(policy.ExclusiveSum()); + const int SCAN_BLOCK_THREADS = policy.exclusive_sum.block_threads; // log exclusive_sum_kernel configuration #ifdef CUB_DEBUG_LOG @@ -585,7 +758,7 @@ struct DispatchRadixSort num_passes, SCAN_BLOCK_THREADS, reinterpret_cast(stream), - policy.RadixBits(policy.ExclusiveSum())); + policy.exclusive_sum.radix_bits); #endif if (const auto error = CubDebug(launcher_factory(num_passes, SCAN_BLOCK_THREADS, 0, stream) @@ -630,7 +803,7 @@ struct DispatchRadixSort num_blocks, ONESWEEP_BLOCK_THREADS, reinterpret_cast(stream), - policy.Onesweep().ItemsPerThread(), + policy.onesweep.items_per_thread, current_bit, num_bits, static_cast(portion), @@ -680,6 +853,7 @@ struct DispatchRadixSort return cudaSuccess; } +public: /** * @brief Invocation (run multiple digit passes) * @@ -711,6 +885,7 @@ struct DispatchRadixSort * Alternate kernel function pointer to parameterization of * cub::DeviceRadixSortDownsweepKernel */ + // TODO(bgruber): deprecate when we make the tuning API public and remove in CCCL 4.0 template CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t InvokePasses( UpsweepKernelT upsweep_kernel, @@ -719,6 +894,25 @@ struct DispatchRadixSort DownsweepKernelT downsweep_kernel, DownsweepKernelT alt_downsweep_kernel, ActivePolicyT policy = {}) + { + return __invoke_passes( + upsweep_kernel, + alt_downsweep_kernel, + scan_kernel, + downsweep_kernel, + alt_downsweep_kernel, + detail::radix_sort::convert_policy(policy)); + } + +private: + template + CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t __invoke_passes( + UpsweepKernelT upsweep_kernel, + UpsweepKernelT alt_upsweep_kernel, + ScanKernelT scan_kernel, + DownsweepKernelT downsweep_kernel, + DownsweepKernelT alt_downsweep_kernel, + const detail::radix_sort::radix_sort_policy& policy) { // Get device ordinal int device_ordinal; @@ -736,33 +930,31 @@ struct DispatchRadixSort // Init regular and alternate-digit kernel configurations PassConfig pass_config, alt_pass_config; - if (const auto error = pass_config.InitPassConfig( + if (const auto error = pass_config.__init_pass_config( upsweep_kernel, scan_kernel, downsweep_kernel, - ptx_version, sm_count, num_items, - policy, - policy.Upsweep(), - policy.Scan(), - policy.Downsweep(), + policy.downsweep.radix_bits, + policy.upsweep, + policy.scan, + policy.downsweep, launcher_factory)) { return error; } - if (const auto error = alt_pass_config.InitPassConfig( + if (const auto error = alt_pass_config.__init_pass_config( alt_upsweep_kernel, scan_kernel, alt_downsweep_kernel, - ptx_version, sm_count, num_items, - policy, - policy.AltUpsweep(), - policy.Scan(), - policy.AltDownsweep(), + policy.alt_downsweep.radix_bits, + policy.alt_upsweep, + policy.scan, + policy.alt_downsweep, launcher_factory)) { return error; @@ -869,10 +1061,8 @@ struct DispatchRadixSort return cudaSuccess; } - //------------------------------------------------------------------------------ - // Chained policy invocation - //------------------------------------------------------------------------------ - +public: + // TODO(bgruber): deprecate when we make the tuning API public and remove in CCCL 4.0 CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t InvokeCopy() { // is_overwrite_okay == false here @@ -921,11 +1111,20 @@ struct DispatchRadixSort return cudaSuccess; } + // TODO(bgruber): deprecate when we make the tuning API public and remove in CCCL 4.0 /// Invocation template - CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke(ActivePolicyT policy = {}) + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke(ActivePolicyT = {}) { - auto wrapped_policy = detail::radix_sort::MakeRadixSortPolicyWrapper(policy); + return __invoke([] { + return detail::radix_sort::convert_policy(); + }); + } + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t __invoke(PolicyGetter policy_getter) + { + CUB_DETAIL_CONSTEXPR_ISH auto policy = policy_getter(); // Return if empty problem, or if no bits to sort and double-buffering is used if (num_items == 0 || (begin_bit == end_bit && is_overwrite_okay)) @@ -952,26 +1151,25 @@ struct DispatchRadixSort } // Force kernel code-generation in all compiler passes - if (num_items <= static_cast( - wrapped_policy.SingleTile().BlockThreads() * wrapped_policy.SingleTile().ItemsPerThread())) + if (num_items <= static_cast(policy.single_tile.block_threads * policy.single_tile.items_per_thread)) { // Small, single tile size - return InvokeSingleTile(kernel_source.RadixSortSingleTileKernel(), wrapped_policy); + return __invoke_single_tile(kernel_source.RadixSortSingleTileKernel(), policy.single_tile); } - if CUB_DETAIL_CONSTEXPR_ISH (wrapped_policy.IsOnesweep()) + if CUB_DETAIL_CONSTEXPR_ISH (policy.use_onesweep) { - return InvokeOnesweep(wrapped_policy); + return __invoke_onesweep(policy); } else { - return InvokePasses( + return __invoke_passes( kernel_source.RadixSortUpsweepKernel(), kernel_source.RadixSortAltUpsweepKernel(), kernel_source.DeviceRadixSortScanBinsKernel(), kernel_source.RadixSortDownsweepKernel(), kernel_source.RadixSortAltDownsweepKernel(), - wrapped_policy); + policy); } } @@ -1012,6 +1210,7 @@ struct DispatchRadixSort * @param[in] stream * CUDA stream to launch kernels within. Default is stream0. */ + // TODO(bgruber): deprecate when we make the tuning API public and remove in CCCL 4.0 template CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Dispatch( void* d_temp_storage, @@ -1061,6 +1260,69 @@ struct DispatchRadixSort } }; +namespace detail::radix_sort +{ +// not used, since we do not need to call Dispatch +struct fake_policy +{ + using MaxPolicy = void; +}; + +template , + typename KernelSource = DeviceRadixSortKernelSource, + typename KernelLauncherFactory = CUB_DETAIL_DEFAULT_KERNEL_LAUNCHER_FACTORY> +CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch( + void* d_temp_storage, + size_t& temp_storage_bytes, + DoubleBuffer& d_keys, + DoubleBuffer& d_values, + OffsetT num_items, + int begin_bit, + int end_bit, + bool is_overwrite_okay, + cudaStream_t stream, + DecomposerT decomposer = {}, + PolicySelector policy_selector = {}, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}) +{ + ::cuda::arch_id arch_id{}; + if (const auto error = CubDebug(launcher_factory.PtxArchId(arch_id))) + { + return error; + } + +#if !_CCCL_COMPILER(NVRTC) && defined(CUB_DEBUG_LOG) + NV_IF_TARGET(NV_IS_HOST, + (std::stringstream ss; ss << policy_selector(arch_id); + _CubLog("Dispatching DeviceReduce to arch %d with tuning: %s\n", (int) arch_id, ss.str().c_str());)) +#endif // !_CCCL_COMPILER(NVRTC) && defined(CUB_DEBUG_LOG) + + return dispatch_arch(policy_selector, arch_id, [&](auto policy_getter) { + return DispatchRadixSort{ + d_temp_storage, + temp_storage_bytes, + d_keys, + d_values, + static_cast(num_items), + begin_bit, + end_bit, + is_overwrite_okay, + stream, + -1 /* ptx_version, not used actually */, + decomposer, + kernel_source, + launcher_factory} + .__invoke(policy_getter); + }); +} +} // namespace detail::radix_sort + CUB_NAMESPACE_END _CCCL_DIAG_POP diff --git a/cub/cub/device/dispatch/kernels/kernel_radix_sort.cuh b/cub/cub/device/dispatch/kernels/kernel_radix_sort.cuh index cf09fb0bef4..0a5fcdc7125 100644 --- a/cub/cub/device/dispatch/kernels/kernel_radix_sort.cuh +++ b/cub/cub/device/dispatch/kernels/kernel_radix_sort.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -40,9 +41,6 @@ namespace detail::radix_sort * @brief Upsweep digit-counting kernel entry point (multi-block). * Computes privatized digit histograms, one per block. * - * @tparam ChainedPolicyT - * Chained tuning policy - * * @tparam ALT_DIGIT_BITS * Whether or not to use the alternate (lower-bits) policy * @@ -74,14 +72,14 @@ namespace detail::radix_sort * @param[in] even_share * Even-share descriptor for mapan equal number of tiles onto each thread block */ -template -__launch_bounds__(int((ALT_DIGIT_BITS) ? int(ChainedPolicyT::ActivePolicy::AltUpsweepPolicy::BLOCK_THREADS) - : int(ChainedPolicyT::ActivePolicy::UpsweepPolicy::BLOCK_THREADS))) +__launch_bounds__(int(ALT_DIGIT_BITS ? PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}).alt_upsweep.block_threads + : PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}).upsweep.block_threads)) CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceRadixSortUpsweepKernel( const KeyT* d_keys, OffsetT* d_spine, @@ -91,19 +89,23 @@ __launch_bounds__(int((ALT_DIGIT_BITS) ? int(ChainedPolicyT::ActivePolicy::AltUp GridEvenShare even_share, DecomposerT decomposer = {}) { - using ActiveUpsweepPolicyT = - ::cuda::std::_If; - - using ActiveDownsweepPolicyT = - ::cuda::std::_If; + static constexpr radix_sort_policy policy = PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}); + static constexpr radix_sort_upsweep_policy active_upsweep_policy = + ALT_DIGIT_BITS ? policy.alt_upsweep : policy.upsweep; + static constexpr radix_sort_downsweep_policy active_downsweep_policy = + ALT_DIGIT_BITS ? policy.alt_downsweep : policy.downsweep; static constexpr int TILE_ITEMS = - ::cuda::std::max(ActiveUpsweepPolicyT::BLOCK_THREADS * ActiveUpsweepPolicyT::ITEMS_PER_THREAD, - ActiveDownsweepPolicyT::BLOCK_THREADS * ActiveDownsweepPolicyT::ITEMS_PER_THREAD); + ::cuda::std::max(active_upsweep_policy.block_threads * active_upsweep_policy.items_per_thread, + active_downsweep_policy.block_threads * active_downsweep_policy.items_per_thread); + + using ActiveUpsweepPolicyT = + AgentRadixSortUpsweepPolicy>; // Parameterize AgentRadixSortUpsweep type for the current configuration using AgentRadixSortUpsweepT = @@ -129,9 +131,6 @@ __launch_bounds__(int((ALT_DIGIT_BITS) ? int(ChainedPolicyT::ActivePolicy::AltUp * @brief Spine scan kernel entry point (single-block). * Computes an exclusive prefix sum over the privatized digit histograms * - * @tparam ChainedPolicyT - * Chained tuning policy - * * @tparam OffsetT * Signed integer type for global offsets * @@ -142,19 +141,26 @@ __launch_bounds__(int((ALT_DIGIT_BITS) ? int(ChainedPolicyT::ActivePolicy::AltUp * @param[in] num_counts * Total number of bin-counts */ -template -__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanPolicy::BLOCK_THREADS), 1) +template +__launch_bounds__(PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}).scan.block_threads, 1) CUB_DETAIL_KERNEL_ATTRIBUTES void RadixSortScanBinsKernel(OffsetT* d_spine, int num_counts) { + static constexpr scan_policy policy = PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}).scan; + using ScanPolicy = AgentScanPolicy< + policy.block_threads, + policy.items_per_thread, + void, + policy.load_algorithm, + policy.load_modifier, + policy.store_algorithm, + policy.scan_algorithm, + NoScaling, + delay_constructor_t>; + // Parameterize the AgentScan type for the current configuration - using AgentScanT = - scan::AgentScan, - OffsetT, - OffsetT, - OffsetT>; + using AgentScanT = scan::AgentScan, OffsetT, OffsetT, OffsetT>; // Shared memory storage __shared__ typename AgentScanT::TempStorage temp_storage; @@ -182,9 +188,6 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanPolicy::BLOCK_THREADS), * @brief Downsweep pass kernel entry point (multi-block). * Scatters keys (and values) into corresponding bins for the current digit place. * - * @tparam ChainedPolicyT - * Chained tuning policy - * * @tparam ALT_DIGIT_BITS * Whether or not to use the alternate (lower-bits) policy * @@ -228,15 +231,15 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanPolicy::BLOCK_THREADS), * @param[in] even_share * Even-share descriptor for mapan equal number of tiles onto each thread block */ -template -__launch_bounds__(int((ALT_DIGIT_BITS) ? int(ChainedPolicyT::ActivePolicy::AltDownsweepPolicy::BLOCK_THREADS) - : int(ChainedPolicyT::ActivePolicy::DownsweepPolicy::BLOCK_THREADS))) +__launch_bounds__(int(ALT_DIGIT_BITS ? PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}).alt_downsweep.block_threads + : PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}).downsweep.block_threads)) CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceRadixSortDownsweepKernel( const KeyT* d_keys_in, KeyT* d_keys_out, @@ -249,19 +252,27 @@ __launch_bounds__(int((ALT_DIGIT_BITS) ? int(ChainedPolicyT::ActivePolicy::AltDo GridEvenShare even_share, DecomposerT decomposer = {}) { - using ActiveUpsweepPolicyT = - ::cuda::std::_If; + static constexpr radix_sort_policy policy = PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}); - using ActiveDownsweepPolicyT = - ::cuda::std::_If; + static constexpr radix_sort_upsweep_policy active_upsweep_policy = + ALT_DIGIT_BITS ? policy.alt_upsweep : policy.upsweep; + static constexpr radix_sort_downsweep_policy active_downsweep_policy = + ALT_DIGIT_BITS ? policy.alt_downsweep : policy.downsweep; static constexpr int TILE_ITEMS = - ::cuda::std::max(ActiveUpsweepPolicyT::BLOCK_THREADS * ActiveUpsweepPolicyT::ITEMS_PER_THREAD, - ActiveDownsweepPolicyT::BLOCK_THREADS * ActiveDownsweepPolicyT::ITEMS_PER_THREAD); + ::cuda::std::max(active_upsweep_policy.block_threads * active_upsweep_policy.items_per_thread, + active_downsweep_policy.block_threads * active_downsweep_policy.items_per_thread); + + using ActiveDownsweepPolicyT = AgentRadixSortDownsweepPolicy< + active_downsweep_policy.block_threads, + active_downsweep_policy.items_per_thread, + void, + active_downsweep_policy.load_algorithm, + active_downsweep_policy.load_modifier, + active_downsweep_policy.rank_algorithm, + active_downsweep_policy.scan_algorithm, + active_downsweep_policy.radix_bits, + NoScaling>; // Parameterize AgentRadixSortDownsweep type for the current configuration using AgentRadixSortDownsweepT = radix_sort:: @@ -283,9 +294,6 @@ __launch_bounds__(int((ALT_DIGIT_BITS) ? int(ChainedPolicyT::ActivePolicy::AltDo * @brief Single pass kernel entry point (single-block). * Fully sorts a tile of input. * - * @tparam ChainedPolicyT - * Chained tuning policy - * * @tparam SortOrder * Whether or not to use the alternate (lower-bits) policy * @@ -319,13 +327,13 @@ __launch_bounds__(int((ALT_DIGIT_BITS) ? int(ChainedPolicyT::ActivePolicy::AltDo * @param[in] end_bit * The past-the-end (most-significant) bit index needed for key comparison */ -template -__launch_bounds__(int(ChainedPolicyT::ActivePolicy::SingleTilePolicy::BLOCK_THREADS), 1) +__launch_bounds__(PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}).single_tile.block_threads, 1) CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceRadixSortSingleTileKernel( const KeyT* d_keys_in, KeyT* d_keys_out, @@ -337,9 +345,10 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::SingleTilePolicy::BLOCK_THRE DecomposerT decomposer = {}) { // Constants - static constexpr int BLOCK_THREADS = ChainedPolicyT::ActivePolicy::SingleTilePolicy::BLOCK_THREADS; - static constexpr int ITEMS_PER_THREAD = ChainedPolicyT::ActivePolicy::SingleTilePolicy::ITEMS_PER_THREAD; - static constexpr bool KEYS_ONLY = ::cuda::std::is_same_v; + static constexpr radix_sort_policy policy = PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}); + static constexpr int BLOCK_THREADS = policy.single_tile.block_threads; + static constexpr int ITEMS_PER_THREAD = policy.single_tile.items_per_thread; + static constexpr bool KEYS_ONLY = ::cuda::std::is_same_v; // BlockRadixSort type using BlockRadixSortT = @@ -347,17 +356,15 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::SingleTilePolicy::BLOCK_THRE BLOCK_THREADS, ITEMS_PER_THREAD, ValueT, - ChainedPolicyT::ActivePolicy::SingleTilePolicy::RADIX_BITS, - (ChainedPolicyT::ActivePolicy::SingleTilePolicy::RANK_ALGORITHM == RADIX_RANK_MEMOIZE), - ChainedPolicyT::ActivePolicy::SingleTilePolicy::SCAN_ALGORITHM>; + policy.single_tile.radix_bits, + (policy.single_tile.rank_algorithm == RADIX_RANK_MEMOIZE), + policy.single_tile.scan_algorithm>; // BlockLoad type (keys) - using BlockLoadKeys = - BlockLoad; + using BlockLoadKeys = BlockLoad; // BlockLoad type (values) - using BlockLoadValues = - BlockLoad; + using BlockLoadValues = BlockLoad; // Unsigned word for key bits using traits = detail::radix::traits_t; @@ -437,23 +444,31 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::SingleTilePolicy::BLOCK_THRE /** * Histogram kernel */ -template -CUB_DETAIL_KERNEL_ATTRIBUTES -__launch_bounds__(ChainedPolicyT::ActivePolicy::HistogramPolicy::BLOCK_THREADS) void DeviceRadixSortHistogramKernel( - OffsetT* d_bins_out, const KeyT* d_keys_in, OffsetT num_items, int start_bit, int end_bit, DecomposerT decomposer = {}) +CUB_DETAIL_KERNEL_ATTRIBUTES __launch_bounds__( + PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}) + .histogram.block_threads) void DeviceRadixSortHistogramKernel(OffsetT* d_bins_out, + const KeyT* d_keys_in, + OffsetT num_items, + int start_bit, + int end_bit, + DecomposerT decomposer = {}) { - using HistogramPolicyT = typename ChainedPolicyT::ActivePolicy::HistogramPolicy; + static constexpr radix_sort_histogram_policy policy = PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}).histogram; + + using HistogramPolicyT = + AgentRadixSortHistogramPolicy; using AgentT = AgentRadixSortHistogram; __shared__ typename AgentT::TempStorage temp_storage; AgentT agent(temp_storage, d_bins_out, d_keys_in, num_items, start_bit, end_bit, decomposer); agent.Process(); } -template -CUB_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(ChainedPolicyT::ActivePolicy::OnesweepPolicy::BLOCK_THREADS) +CUB_DETAIL_KERNEL_ATTRIBUTES void +__launch_bounds__(PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}).onesweep.block_threads) DeviceRadixSortOnesweepKernel( AtomicOffsetT* d_lookback, AtomicOffsetT* d_ctrs, @@ -476,7 +492,18 @@ CUB_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(ChainedPolicyT::ActivePolicy int num_bits, DecomposerT decomposer = {}) { - using OnesweepPolicyT = typename ChainedPolicyT::ActivePolicy::OnesweepPolicy; + static constexpr radix_sort_onesweep_policy policy = PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}).onesweep; + using OnesweepPolicyT = AgentRadixSortOnesweepPolicy< + policy.block_threads, + policy.items_per_thread, + void, + policy.rank_num_parts, + policy.rank_algorith, + policy.scan_algorithm, + policy.store_algorithm, + policy.radix_bits, + NoScaling>; + using AgentT = AgentRadixSortOnesweep +template CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceRadixSortExclusiveSumKernel(OffsetT* d_bins) { - using ExclusiveSumPolicyT = typename ChainedPolicyT::ActivePolicy::ExclusiveSumPolicy; - constexpr int RADIX_BITS = ExclusiveSumPolicyT::RADIX_BITS; + static constexpr radix_sort_exclusive_sum_policy policy = + PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}).exclusive_sum; + constexpr int RADIX_BITS = policy.radix_bits; constexpr int RADIX_DIGITS = 1 << RADIX_BITS; - constexpr int BLOCK_THREADS = ExclusiveSumPolicyT::BLOCK_THREADS; + constexpr int BLOCK_THREADS = policy.block_threads; constexpr int BINS_PER_THREAD = (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS; using BlockScan = cub::BlockScan; __shared__ typename BlockScan::TempStorage temp_storage; diff --git a/cub/cub/device/dispatch/tuning/tuning_radix_sort.cuh b/cub/cub/device/dispatch/tuning/tuning_radix_sort.cuh index 4b2a94cdca6..168753fb634 100644 --- a/cub/cub/device/dispatch/tuning/tuning_radix_sort.cuh +++ b/cub/cub/device/dispatch/tuning/tuning_radix_sort.cuh @@ -18,12 +18,456 @@ #include #include #include +#include #include -CUB_NAMESPACE_BEGIN +#include +#include + +#if !_CCCL_COMPILER(NVRTC) +# include +#endif +CUB_NAMESPACE_BEGIN namespace detail::radix_sort { +enum class delay_constructor_kind +{ + no_delay, + fixed_delay, + exponential_backoff, + exponential_backoff_jitter, + exponential_backoff_jitter_window, + exponential_backon_jitter_window, + exponential_backon_jitter, + exponential_backon +}; + +#if !_CCCL_COMPILER(NVRTC) +inline ::std::ostream& operator<<(::std::ostream& os, delay_constructor_kind kind) +{ + switch (kind) + { + case delay_constructor_kind::no_delay: + return os << "delay_constructor_kind::no_delay"; + case delay_constructor_kind::fixed_delay: + return os << "delay_constructor_kind::fixed_delay"; + case delay_constructor_kind::exponential_backoff: + return os << "delay_constructor_kind::exponential_backoff"; + case delay_constructor_kind::exponential_backoff_jitter: + return os << "delay_constructor_kind::exponential_backoff_jitter"; + case delay_constructor_kind::exponential_backoff_jitter_window: + return os << "delay_constructor_kind::exponential_backoff_jitter_window"; + case delay_constructor_kind::exponential_backon_jitter_window: + return os << "delay_constructor_kind::exponential_backon_jitter_window"; + case delay_constructor_kind::exponential_backon_jitter: + return os << "delay_constructor_kind::exponential_backon_jitter"; + case delay_constructor_kind::exponential_backon: + return os << "delay_constructor_kind::exponential_backon"; + default: + return os << "(kind) << ">"; + } +} +#endif // !_CCCL_COMPILER(NVRTC) + +struct delay_constructor_policy +{ + delay_constructor_kind kind; + unsigned int delay; + unsigned int l2_write_latency; + + _CCCL_API constexpr friend bool operator==(const delay_constructor_policy& lhs, const delay_constructor_policy& rhs) + { + return lhs.kind == rhs.kind && lhs.delay == rhs.delay && lhs.l2_write_latency == rhs.l2_write_latency; + } + + _CCCL_API constexpr friend bool operator!=(const delay_constructor_policy& lhs, const delay_constructor_policy& rhs) + { + return !(lhs == rhs); + } + +#if !_CCCL_COMPILER(NVRTC) + friend ::std::ostream& operator<<(::std::ostream& os, const delay_constructor_policy& p) + { + return os << "delay_constructor_policy { .kind = " << p.kind << ", .delay = " << p.delay + << ", .l2_write_latency = " << p.l2_write_latency << " }"; + } +#endif // !_CCCL_COMPILER(NVRTC) +}; + +template +inline constexpr auto delay_constructor_policy_from_type = 0; + +template +inline constexpr auto delay_constructor_policy_from_type> = + delay_constructor_policy{delay_constructor_kind::no_delay, 0, L2WriteLatency}; + +template +inline constexpr auto delay_constructor_policy_from_type> = + delay_constructor_policy{delay_constructor_kind::fixed_delay, Delay, L2WriteLatency}; + +template +inline constexpr auto delay_constructor_policy_from_type> = + delay_constructor_policy{delay_constructor_kind::exponential_backoff, Delay, L2WriteLatency}; + +template +inline constexpr auto + delay_constructor_policy_from_type> = + delay_constructor_policy{delay_constructor_kind::exponential_backoff_jitter, Delay, L2WriteLatency}; + +template +inline constexpr auto + delay_constructor_policy_from_type> = + delay_constructor_policy{delay_constructor_kind::exponential_backoff_jitter_window, Delay, L2WriteLatency}; + +template +inline constexpr auto + delay_constructor_policy_from_type> = + delay_constructor_policy{delay_constructor_kind::exponential_backon_jitter_window, Delay, L2WriteLatency}; + +template +inline constexpr auto delay_constructor_policy_from_type> = + delay_constructor_policy{delay_constructor_kind::exponential_backon_jitter, Delay, L2WriteLatency}; + +template +inline constexpr auto delay_constructor_policy_from_type> = + delay_constructor_policy{delay_constructor_kind::exponential_backon, Delay, L2WriteLatency}; + +// TODO(bgruber): this is modeled after , unify this +template +struct __delay_constructor_t_helper +{ +private: + using delay_constructors = ::cuda::std::__type_list< + detail::no_delay_constructor_t, + detail::fixed_delay_constructor_t, + detail::exponential_backoff_constructor_t, + detail::exponential_backoff_jitter_constructor_t, + detail::exponential_backoff_jitter_window_constructor_t, + detail::exponential_backon_jitter_window_constructor_t, + detail::exponential_backon_jitter_constructor_t, + detail::exponential_backon_constructor_t>; + +public: + using type = ::cuda::std::__type_at_c(Kind), delay_constructors>; +}; + +template +using delay_constructor_t = typename __delay_constructor_t_helper::type; + +struct radix_sort_histogram_policy +{ + int block_threads; + int items_per_thread; + int num_parts; + int radix_bits; + + _CCCL_API constexpr friend bool + operator==(const radix_sort_histogram_policy& lhs, const radix_sort_histogram_policy& rhs) + { + return lhs.block_threads == rhs.block_threads && lhs.items_per_thread == rhs.items_per_thread + && lhs.num_parts == rhs.num_parts && lhs.radix_bits == rhs.radix_bits; + } + + _CCCL_API constexpr friend bool + operator!=(const radix_sort_histogram_policy& lhs, const radix_sort_histogram_policy& rhs) + { + return !(lhs == rhs); + } + +#if !_CCCL_COMPILER(NVRTC) + friend ::std::ostream& operator<<(::std::ostream& os, const radix_sort_histogram_policy& p) + { + return os << "radix_sort_histogram_policy { .block_threads = " << p.block_threads << ", .items_per_thread = " + << p.items_per_thread << ", .num_parts = " << p.num_parts << ", .radix_bits = " << p.radix_bits << " }"; + } +#endif // !_CCCL_COMPILER(NVRTC) +}; + +struct radix_sort_exclusive_sum_policy +{ + int block_threads; + int radix_bits; + + _CCCL_API constexpr friend bool + operator==(const radix_sort_exclusive_sum_policy& lhs, const radix_sort_exclusive_sum_policy& rhs) + { + return lhs.block_threads == rhs.block_threads && lhs.radix_bits == rhs.radix_bits; + } + + _CCCL_API constexpr friend bool + operator!=(const radix_sort_exclusive_sum_policy& lhs, const radix_sort_exclusive_sum_policy& rhs) + { + return !(lhs == rhs); + } + +#if !_CCCL_COMPILER(NVRTC) + friend ::std::ostream& operator<<(::std::ostream& os, const radix_sort_exclusive_sum_policy& p) + { + return os << "radix_sort_exclusive_sum_policy { .block_threads = " << p.block_threads + << ", .radix_bits = " << p.radix_bits << " }"; + } +#endif // !_CCCL_COMPILER(NVRTC) +}; + +struct radix_sort_onesweep_policy +{ + int block_threads; + int items_per_thread; + int rank_num_parts; + int radix_bits; + RadixRankAlgorithm rank_algorith; + BlockScanAlgorithm scan_algorithm; + RadixSortStoreAlgorithm store_algorithm; + + _CCCL_API constexpr friend bool + operator==(const radix_sort_onesweep_policy& lhs, const radix_sort_onesweep_policy& rhs) + { + return lhs.block_threads == rhs.block_threads && lhs.items_per_thread == rhs.items_per_thread + && lhs.rank_num_parts == rhs.rank_num_parts && lhs.radix_bits == rhs.radix_bits + && lhs.rank_algorith == rhs.rank_algorith && lhs.scan_algorithm == rhs.scan_algorithm + && lhs.store_algorithm == rhs.store_algorithm; + } + + _CCCL_API constexpr friend bool + operator!=(const radix_sort_onesweep_policy& lhs, const radix_sort_onesweep_policy& rhs) + { + return !(lhs == rhs); + } + +#if !_CCCL_COMPILER(NVRTC) + friend ::std::ostream& operator<<(::std::ostream& os, const radix_sort_onesweep_policy& p) + { + return os + << "radix_sort_onesweep_policy { .block_threads = " << p.block_threads + << ", .items_per_thread = " << p.items_per_thread << ", .rank_num_parts = " << p.rank_num_parts + << ", .radix_bits = " << p.radix_bits << ", .rank_algorith = " << p.rank_algorith + << ", .scan_algorithm = " << p.scan_algorithm << ", .store_algorithm = " << p.store_algorithm << " }"; + } +#endif // !_CCCL_COMPILER(NVRTC) +}; + +_CCCL_API constexpr auto make_reg_scaled_radix_sort_onesweep_policy( + int nominal_4b_block_threads, + int nominal_4b_items_per_thread, + int compute_t_size, + int rank_num_parts, + int radix_bits, + RadixRankAlgorithm rank_algorith, + BlockScanAlgorithm scan_algorithm, + RadixSortStoreAlgorithm store_algorithm) -> radix_sort_onesweep_policy +{ + const auto scaled = scale_reg_bound(nominal_4b_block_threads, nominal_4b_items_per_thread, compute_t_size); + return radix_sort_onesweep_policy{ + scaled.block_threads, + scaled.items_per_thread, + rank_num_parts, + radix_bits, + rank_algorith, + scan_algorithm, + store_algorithm}; +} + +// TODO(bgruber): move this into the scan tuning header +struct scan_policy +{ + int block_threads; + int items_per_thread; + BlockLoadAlgorithm load_algorithm; + CacheLoadModifier load_modifier; + BlockStoreAlgorithm store_algorithm; + BlockScanAlgorithm scan_algorithm; + delay_constructor_policy delay_constructor; + + _CCCL_API constexpr friend bool operator==(const scan_policy& lhs, const scan_policy& rhs) + { + return lhs.block_threads == rhs.block_threads && lhs.items_per_thread == rhs.items_per_thread + && lhs.load_algorithm == rhs.load_algorithm && lhs.load_modifier == rhs.load_modifier + && lhs.store_algorithm == rhs.store_algorithm && lhs.scan_algorithm == rhs.scan_algorithm + && lhs.delay_constructor == rhs.delay_constructor; + } + + _CCCL_API constexpr friend bool operator!=(const scan_policy& lhs, const scan_policy& rhs) + { + return !(lhs == rhs); + } + +#if !_CCCL_COMPILER(NVRTC) + friend ::std::ostream& operator<<(::std::ostream& os, const scan_policy& p) + { + return os + << "scan_policy { .block_threads = " << p.block_threads << ", .items_per_thread = " << p.items_per_thread + << ", .load_algorithm = " << p.load_algorithm << ", .load_modifier = " << p.load_modifier + << ", .store_algorithm = " << p.store_algorithm << ", .scan_algorithm = " << p.scan_algorithm + << ", .delay_constructor = " << p.delay_constructor << " }"; + } +#endif // !_CCCL_COMPILER(NVRTC) +}; + +_CCCL_API constexpr auto make_mem_scaled_scan_policy( + int nominal_4b_block_threads, + int nominal_4b_items_per_thread, + int compute_t_size, + BlockLoadAlgorithm load_algorithm, + CacheLoadModifier load_modifier, + BlockStoreAlgorithm store_algorithm, + BlockScanAlgorithm scan_algorithm, + delay_constructor_policy delay_constructor = {delay_constructor_kind::fixed_delay, 350, 450}) -> scan_policy +{ + const auto scaled = scale_mem_bound(nominal_4b_block_threads, nominal_4b_items_per_thread, compute_t_size); + return scan_policy{ + scaled.block_threads, + scaled.items_per_thread, + load_algorithm, + load_modifier, + store_algorithm, + scan_algorithm, + delay_constructor}; +} + +struct radix_sort_downsweep_policy +{ + int block_threads; + int items_per_thread; + int radix_bits; + BlockLoadAlgorithm load_algorithm; + CacheLoadModifier load_modifier; + RadixRankAlgorithm rank_algorithm; + BlockScanAlgorithm scan_algorithm; + + _CCCL_API constexpr friend bool + operator==(const radix_sort_downsweep_policy& lhs, const radix_sort_downsweep_policy& rhs) + { + return lhs.block_threads == rhs.block_threads && lhs.items_per_thread == rhs.items_per_thread + && lhs.radix_bits == rhs.radix_bits && lhs.load_algorithm == rhs.load_algorithm + && lhs.load_modifier == rhs.load_modifier && lhs.rank_algorithm == rhs.rank_algorithm + && lhs.scan_algorithm == rhs.scan_algorithm; + } + + _CCCL_API constexpr friend bool + operator!=(const radix_sort_downsweep_policy& lhs, const radix_sort_downsweep_policy& rhs) + { + return !(lhs == rhs); + } + +#if !_CCCL_COMPILER(NVRTC) + friend ::std::ostream& operator<<(::std::ostream& os, const radix_sort_downsweep_policy& p) + { + return os + << "radix_sort_downsweep_policy { .block_threads = " << p.block_threads + << ", .items_per_thread = " << p.items_per_thread << ", .radix_bits = " << p.radix_bits + << ", .load_algorithm = " << p.load_algorithm << ", .load_modifier = " << p.load_modifier + << ", .rank_algorithm = " << p.rank_algorithm << ", .scan_algorithm = " << p.scan_algorithm << " }"; + } +#endif // !_CCCL_COMPILER(NVRTC) +}; + +_CCCL_API constexpr auto make_reg_scaled_radix_sort_downsweep_policy( + int nominal_4b_block_threads, + int nominal_4b_items_per_thread, + int compute_t_size, + int radix_bits, + BlockLoadAlgorithm load_algorithm, + CacheLoadModifier load_modifier, + RadixRankAlgorithm rank_algorithm, + BlockScanAlgorithm scan_algorithm) -> radix_sort_downsweep_policy +{ + const auto scaled = scale_reg_bound(nominal_4b_block_threads, nominal_4b_items_per_thread, compute_t_size); + return radix_sort_downsweep_policy{ + scaled.block_threads, + scaled.items_per_thread, + radix_bits, + load_algorithm, + load_modifier, + rank_algorithm, + scan_algorithm}; +} + +struct radix_sort_upsweep_policy +{ + int block_threads; + int items_per_thread; + int radix_bits; + CacheLoadModifier load_modifier; + + _CCCL_API constexpr friend bool operator==(const radix_sort_upsweep_policy& lhs, const radix_sort_upsweep_policy& rhs) + { + return lhs.block_threads == rhs.block_threads && lhs.items_per_thread == rhs.items_per_thread + && lhs.radix_bits == rhs.radix_bits && lhs.load_modifier == rhs.load_modifier; + } + + _CCCL_API constexpr friend bool operator!=(const radix_sort_upsweep_policy& lhs, const radix_sort_upsweep_policy& rhs) + { + return !(lhs == rhs); + } + +#if !_CCCL_COMPILER(NVRTC) + friend ::std::ostream& operator<<(::std::ostream& os, const radix_sort_upsweep_policy& p) + { + return os + << "radix_sort_upsweep_policy { .block_threads = " << p.block_threads << ", .items_per_thread = " + << p.items_per_thread << ", .radix_bits = " << p.radix_bits << ", .load_modifier = " << p.load_modifier << " }"; + } +#endif // !_CCCL_COMPILER(NVRTC) +}; + +_CCCL_API constexpr auto make_reg_scaled_radix_sort_upsweep_policy( + int nominal_4b_block_threads, + int nominal_4b_items_per_thread, + int compute_t_size, + int radix_bits, + CacheLoadModifier load_modifier) -> radix_sort_upsweep_policy +{ + const auto scaled = scale_reg_bound(nominal_4b_block_threads, nominal_4b_items_per_thread, compute_t_size); + return radix_sort_upsweep_policy{scaled.block_threads, scaled.items_per_thread, radix_bits, load_modifier}; +} + +struct radix_sort_policy +{ + bool use_onesweep; + int onesweep_radix_bits; + radix_sort_histogram_policy histogram; + radix_sort_exclusive_sum_policy exclusive_sum; + radix_sort_onesweep_policy onesweep; + scan_policy scan; + radix_sort_downsweep_policy downsweep; + radix_sort_downsweep_policy alt_downsweep; + radix_sort_upsweep_policy upsweep; + radix_sort_upsweep_policy alt_upsweep; + radix_sort_downsweep_policy single_tile; + // TODO(bgruber): move those over to segmented radix sort when we port it + radix_sort_downsweep_policy segmented; + radix_sort_downsweep_policy alt_segmented; + + _CCCL_API constexpr friend bool operator==(const radix_sort_policy& lhs, const radix_sort_policy& rhs) + { + return lhs.use_onesweep == rhs.use_onesweep && lhs.onesweep_radix_bits == rhs.onesweep_radix_bits + && lhs.histogram == rhs.histogram && lhs.exclusive_sum == rhs.exclusive_sum && lhs.onesweep == rhs.onesweep + && lhs.scan == rhs.scan && lhs.downsweep == rhs.downsweep && lhs.alt_downsweep == rhs.alt_downsweep + && lhs.upsweep == rhs.upsweep && lhs.alt_upsweep == rhs.alt_upsweep && lhs.single_tile == rhs.single_tile + && lhs.segmented == rhs.segmented && lhs.alt_segmented == rhs.alt_segmented; + } + + _CCCL_API constexpr friend bool operator!=(const radix_sort_policy& lhs, const radix_sort_policy& rhs) + { + return !(lhs == rhs); + } + +#if !_CCCL_COMPILER(NVRTC) + friend ::std::ostream& operator<<(::std::ostream& os, const radix_sort_policy& p) + { + return os + << "radix_sort_policy { .use_onesweep = " << p.use_onesweep + << ", .onesweep_radix_bits = " << p.onesweep_radix_bits << ", .histogram = " << p.histogram + << ", .exclusive_sum = " << p.exclusive_sum << ", .onesweep = " << p.onesweep << ", .scan = " << p.scan + << ", .downsweep = " << p.downsweep << ", .alt_downsweep = " << p.alt_downsweep << ", .upsweep = " << p.upsweep + << ", .alt_upsweep = " << p.alt_upsweep << ", .single_tile = " << p.single_tile + << ", .segmented = " << p.segmented << ", .alt_segmented = " << p.alt_segmented << " }"; + } +#endif // !_CCCL_COMPILER(NVRTC) +}; + +// TODO(bgruber): remove for CCCL 4.0 when we drop the public radix sort dispatcher // sm90 default template struct sm90_small_key_tuning @@ -65,6 +509,7 @@ template <> struct sm90_small_key_tuning<2, 16, 4> { static constexpr int thread template <> struct sm90_small_key_tuning<2, 16, 8> { static constexpr int threads = 576; static constexpr int items = 22; }; // clang-format on +// TODO(bgruber): remove for CCCL 4.0 when we drop the public radix sort dispatcher // sm100 default template struct sm100_small_key_tuning : sm90_small_key_tuning @@ -253,6 +698,271 @@ template struct sm100_small_key_tuning : sm template struct sm100_small_key_tuning : sm90_small_key_tuning<8, 16, 8> {}; // clang-format on +struct small_key_tuning_values +{ + int threads; + int items; +}; + +_CCCL_API constexpr auto get_sm90_tuning(int key_size, int value_size, int offset_size) -> small_key_tuning_values +{ + // keys + if (value_size == 0) + { + // clang-format off + if (key_size == 1 && offset_size == 4) return {512,19}; + if (key_size == 1 && offset_size == 8) return {512,19}; + if (key_size == 2 && offset_size == 4) return {512,19}; + if (key_size == 2 && offset_size == 8) return {512,19}; + // clang-format on + } + + // pairs 8:xx + if (key_size == 1) + { + // clang-format off + if (value_size == 1 && offset_size == 4) return {512, 15}; + if (value_size == 1 && offset_size == 8) return {448, 16}; + if (value_size == 2 && offset_size == 4) return {512, 17}; + if (value_size == 2 && offset_size == 8) return {512, 14}; + if (value_size == 4 && offset_size == 4) return {512, 17}; + if (value_size == 4 && offset_size == 8) return {512, 14}; + if (value_size == 8 && offset_size == 4) return {384, 23}; + if (value_size == 8 && offset_size == 8) return {384, 18}; + if (value_size == 16 && offset_size == 4) return {512, 22}; + if (value_size == 16 && offset_size == 8) return {512, 22}; + // clang-format on + } + + // pairs 16:xx + if (key_size == 2) + { + // clang-format off + if (value_size == 1 && offset_size == 4) return {384, 14}; + if (value_size == 1 && offset_size == 8) return {384, 16}; + if (value_size == 2 && offset_size == 4) return {384, 15}; + if (value_size == 2 && offset_size == 8) return {448, 16}; + if (value_size == 4 && offset_size == 4) return {512, 17}; + if (value_size == 4 && offset_size == 8) return {512, 12}; + if (value_size == 8 && offset_size == 4) return {384, 23}; + if (value_size == 8 && offset_size == 8) return {512, 23}; + if (value_size == 16 && offset_size == 4) return {512, 21}; + if (value_size == 16 && offset_size == 8) return {576, 22}; + // clang-format on + } + + return {384, 23}; +} + +_CCCL_API constexpr auto get_sm100_tuning(int key_size, int value_size, int offset_size, type_t key_type) + -> small_key_tuning_values +{ + // keys + if (value_size == 0) + { + if (offset_size == 4) + { + // clang-format off + + // if (key_size == 1) // same as previous tuning + + // // ipt_20.tpb_512 1.013282 0.967525 1.015764 1.047982 + // todo(@gonidelis): insignificant performance gain, need more runs. + if (key_size == 2) return small_key_tuning_values{512,20}; + + // ipt_20.tpb_512 1.089698 0.979276 1.079822 1.199378 + if (key_size == 4 && key_type == type_t::float32) return small_key_tuning_values{512,20}; + + // ipt_21.tpb_512 1.002873 0.994608 1.004196 1.019301 + // todo(@gonidelis): insignificant performance gain, need more runs. + if (key_size == 4) return small_key_tuning_values{512,21}; + + // ipt_18.tpb_288 1.049258 0.985085 1.042400 1.107771 + if (key_size == 8 && key_type == type_t::float64) return small_key_tuning_values{288,18}; + + // ipt_14.tpb_320 1.256020 1.000000 1.228182 1.486711 + if (key_size == 8) return small_key_tuning_values{320,14}; + + // if (key_size == 16) // same as previous tuning + + // clang-format on + } + else if (offset_size == 8) + { + // clang-format off + + // if (key_size == 1) // same as previous tuning + + // ipt_20.tpb_384 1.038445 1.015608 1.037620 1.068105 + if (key_size == 2) return small_key_tuning_values{384,20}; + + // ipt_20.tpb_512 1.021557 0.981437 1.018920 1.039977 + if (key_size == 4 && key_type == type_t::float32) return small_key_tuning_values{512,20}; + + // if (key_size == 4) // same as previous tuning + + // ipt_21.tpb_256 1.068590 0.986635 1.059704 1.144921 + if (key_size == 8 && key_type == type_t::float64) return small_key_tuning_values{256,21}; + + // ipt_18.tpb_320 1.248354 1.000000 1.220666 1.446929 + if (key_size == 8) return small_key_tuning_values{320,18}; + + // if (key_size == 16) // same as previous tuning + + // clang-format on + } + } + + // pairs 1-byte key + if (key_size == 1) + { + // clang-format off + + // if (value_size == 1 && offset_size == 4) // same as previous tuning + + // ipt_18.tpb_512 1.011463 0.978807 1.010106 1.024056 + // todo(@gonidelis): insignificant performance gain, need more runs. + if (value_size == 2 && offset_size == 4) return small_key_tuning_values{512,18}; + + // ipt_18.tpb_512 1.008207 0.980377 1.007132 1.022155 + // todo(@gonidelis): insignificant performance gain, need more runs. + if (value_size == 4 && offset_size == 4) return small_key_tuning_values{512,18}; + + // todo(@gonidelis): regresses for large problem sizes. + // if (value_size == 8 && offset_size == 4) return small_key_tuning_values{288,16}; + + // ipt_21.tpb_576 1.044274 0.979145 1.038723 1.072068 + // todo(@gonidelis): insignificant performance gain, need more runs. + if (value_size == 16 && offset_size == 4) return small_key_tuning_values{576,21}; + + // ipt_20.tpb_384 1.008881 0.968750 1.006846 1.026910 + // todo(@gonidelis): insignificant performance gain, need more runs. + if (value_size == 1 && offset_size == 8) return small_key_tuning_values{384,20}; + + // ipt_22.tpb_256 1.015597 0.966038 1.011167 1.045921 + if (value_size == 2 && offset_size == 8) return small_key_tuning_values{256,22}; + + // ipt_15.tpb_384 1.029730 0.972699 1.029066 1.067894 + if (value_size == 4 && offset_size == 8) return small_key_tuning_values{384,15}; + + // todo(@gonidelis): regresses for large problem sizes. + // if (value_size == 8 && offset_size == 8) return small_key_tuning_values{256,17}; + + // if (value_size == 16 && offset_size == 8) // same as previous tuning + + // clang-format on + } + + // pairs 2-byte key + if (key_size == 2) + { + // clang-format off + + // ipt_20.tpb_448 1.031929 0.936849 1.023411 1.075172 + if (value_size == 1 && offset_size == 4) return small_key_tuning_values{448,20}; + + // ipt_23.tpb_384 1.104683 0.939335 1.087342 1.234988 + if (value_size == 2 && offset_size == 4) return small_key_tuning_values{384,23}; + + // if (value_size == 4 && offset_size == 4) // same as previous tuning + + // todo(@gonidelis): regresses for large problem sizes. + // if (value_size == 8 && offset_size == 4) return small_key_tuning_values{256, 17}; + + // if (value_size == 16 && offset_size == 4) // same as previous tuning + + // ipt_15.tpb_384 1.093598 1.000000 1.088111 1.183369 + if (value_size == 1 && offset_size == 8) return small_key_tuning_values{384, 15}; + + // ipt_15.tpb_576 1.040476 1.000333 1.037060 1.084850 + if (value_size == 2 && offset_size == 8) return small_key_tuning_values{576, 15}; + + // ipt_18.tpb_512 1.096819 0.953488 1.082026 1.209533 + if (value_size == 4 && offset_size == 8) return small_key_tuning_values{512, 18}; + + // todo(@gonidelis): regresses for large problem sizes. + // if (value_size == 8 && offset_size == 8) return small_key_tuning_values{288, 16}; + + // if (value_size == 16 && offset_size == 8) // same as previous tuning + + // clang-format on + } + + // pairs 4-byte key + if (key_size == 4) + { + // clang-format off + + // ipt_21.tpb_416 1.237956 1.001909 1.210882 1.469981 + if (value_size == 1 && offset_size == 4) return small_key_tuning_values{416,21}; + + // ipt_17.tpb_512 1.022121 1.012346 1.022439 1.038524 + if (value_size == 2 && offset_size == 4) return small_key_tuning_values{512,17}; + + // ipt_20.tpb_448 1.012688 0.999531 1.011865 1.028513 + if (value_size == 4 && offset_size == 4) return small_key_tuning_values{448,20}; + + // ipt_15.tpb_384 1.006872 0.998651 1.008374 1.026118 + if (value_size == 8 && offset_size == 4) return small_key_tuning_values{384,15}; + + // if (value_size == 16 && offset_size == 4) // same as previous tuning + + // ipt_17.tpb_512 1.080000 0.927362 1.066211 1.172959 + if (value_size == 1 && offset_size == 8) return small_key_tuning_values{512,17}; + + // ipt_15.tpb_384 1.068529 1.000000 1.062277 1.135281 + if (value_size == 2 && offset_size == 8) return small_key_tuning_values{384,15}; + + // ipt_21.tpb_448 1.080642 0.927713 1.064758 1.191177 + if (value_size == 4 && offset_size == 8) return small_key_tuning_values{448,21}; + + // ipt_13.tpb_448 1.019046 0.991228 1.016971 1.039712 + if (value_size == 8 && offset_size == 8) return small_key_tuning_values{448,13}; + + // if (value_size == 16 && offset_size == 8) // same as previous tuning + + // clang-format on + } + + // pairs 8-byte key + if (key_size == 8) + { + // clang-format off + + // ipt_17.tpb_256 1.276445 1.025562 1.248511 1.496947 + if (value_size == 1 && offset_size == 4) return small_key_tuning_values{256, 17}; + + // ipt_12.tpb_352 1.128086 1.040000 1.117960 1.207254 + if (value_size == 2 && offset_size == 4) return small_key_tuning_values{352, 12}; + + // ipt_12.tpb_352 1.132699 1.040000 1.122676 1.207716 + if (value_size == 4 && offset_size == 4) return small_key_tuning_values{352, 12}; + + // ipt_18.tpb_256 1.266745 0.995432 1.237754 1.460538 + if (value_size == 8 && offset_size == 4) return small_key_tuning_values{256, 18}; + + // if (value_size == 16 && offset_size == 4) // same as previous tuning + + // ipt_15.tpb_384 1.007343 0.997656 1.006929 1.047208 + if (value_size == 1 && offset_size == 8) return small_key_tuning_values{384, 15}; + + // ipt_14.tpb_256 1.186477 1.012683 1.167150 1.332313 + if (value_size == 2 && offset_size == 8) return small_key_tuning_values{256, 14}; + + // ipt_21.tpb_256 1.220607 1.000239 1.196400 1.390471 + if (value_size == 4 && offset_size == 8) return small_key_tuning_values{256, 21}; + + // if (value_size == 8 && offset_size == 8) // same as previous tuning + + // if (value_size == 16 && offset_size == 8) // same as previous tuning + + // clang-format on + } + + return get_sm90_tuning(key_size, value_size, offset_size); +} + +// TODO(bgruber): remove when segmented radix sort is ported to the new tuning API template struct RadixSortPolicyWrapper : PolicyT { @@ -265,6 +975,7 @@ struct RadixSortPolicyWrapper : PolicyT using namespace radix_sort_runtime_policies; #endif +// TODO(bgruber): remove when segmented radix sort is ported to the new tuning API template struct RadixSortPolicyWrapper< StaticPolicyT, @@ -333,6 +1044,7 @@ struct RadixSortPolicyWrapper< #endif }; +// TODO(bgruber): remove when segmented radix sort is ported to the new tuning API template _CCCL_HOST_DEVICE RadixSortPolicyWrapper MakeRadixSortPolicyWrapper(PolicyT policy) { @@ -351,6 +1063,7 @@ _CCCL_HOST_DEVICE RadixSortPolicyWrapper MakeRadixSortPolicyWrapper(Pol * @tparam OffsetT * Signed integer type for global offsets */ +// TODO(bgruber): remove this in CCCL 4.0 when we remove the public radix sort dispatcher template struct policy_hub { @@ -1040,6 +1753,807 @@ struct policy_hub using MaxPolicy = Policy1000; }; + +[[nodiscard]] _CCCL_API constexpr int __scale_num_parts(int nominal_4b_num_parts, int compute_t_size) +{ + return ::cuda::std::max(1, nominal_4b_num_parts * 4 / ::cuda::std::max(compute_t_size, 4)); +} + +struct policy_selector +{ + int key_size; + int value_size; // when 0, indicates keys-only + int offset_size; + type_t key_type; + + // Whether this is a keys-only (or key-value) sort + [[nodiscard]] _CCCL_API constexpr int __keys_only() const + { + return value_size == 0; + } + + // Dominant-sized key/value type + [[nodiscard]] _CCCL_API constexpr int __dominant_size() const + { + return ::cuda::std::max(value_size, key_size); + } + + [[nodiscard]] _CCCL_API constexpr auto make_onsweep_small_key_policy(const small_key_tuning_values& tuning) const + -> radix_sort_policy + { + const int primary_radix_bits = (key_size > 1) ? 7 : 5; + const int single_tile_radix_bits = (key_size > 1) ? 6 : 5; + const int segmented_radix_bits = (key_size > 1) ? 6 : 5; + const int onesweep_radix_bits = 8; + + const auto histogram = radix_sort_histogram_policy{128, 16, __scale_num_parts(1, key_size), onesweep_radix_bits}; + + const auto exclusive_sum = radix_sort_exclusive_sum_policy{256, onesweep_radix_bits}; + + const bool offset_64bit = offset_size == 8; + const bool key_is_float = key_type == type_t::float32; + + const auto onesweep_policy_key32 = make_reg_scaled_radix_sort_onesweep_policy( + 384, + __keys_only() ? 20 - offset_64bit - key_is_float + : (value_size < 8 ? (offset_64bit ? 17 : 23) : (offset_64bit ? 29 : 30)), + __dominant_size(), + 1, + onesweep_radix_bits, + RADIX_RANK_MATCH_EARLY_COUNTS_ANY, + BLOCK_SCAN_RAKING_MEMOIZE, + RADIX_SORT_STORE_DIRECT); + + const auto onesweep_policy_key64 = make_reg_scaled_radix_sort_onesweep_policy( + 384, + value_size < 8 ? 30 : 24, + __dominant_size(), + 1, + onesweep_radix_bits, + RADIX_RANK_MATCH_EARLY_COUNTS_ANY, + BLOCK_SCAN_RAKING_MEMOIZE, + RADIX_SORT_STORE_DIRECT); + + const auto onesweep_large_key_policy = key_size == 4 ? onesweep_policy_key32 : onesweep_policy_key64; + + const auto onesweep_small_key_policy = make_reg_scaled_radix_sort_onesweep_policy( + tuning.threads, + tuning.items, + __dominant_size(), + 1, + 8, + RADIX_RANK_MATCH_EARLY_COUNTS_ANY, + BLOCK_SCAN_RAKING_MEMOIZE, + RADIX_SORT_STORE_DIRECT); + + const auto onesweep = key_size < 4 ? onesweep_small_key_policy : onesweep_large_key_policy; + + // The scan, downsweep and upsweep policies are never run on SM90+, but we have to include them to prevent a + // compilation error: When we compile e.g. for SM70 **and** SM90, the host compiler will reach calls to those + // kernels, and instantiate them on the host, which will reach into the policies below to set the launch bounds. The + // device compiler pass will also compile all kernels for SM70 **and** SM90, even though only the onesweep kernel is + // used on SM90. + + const auto scan = make_mem_scaled_scan_policy( + 512, + 23, + offset_size, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_STORE_WARP_TRANSPOSE, + BLOCK_SCAN_RAKING_MEMOIZE); + + const auto downsweep = make_reg_scaled_radix_sort_downsweep_policy( + 512, + 23, + __dominant_size(), + primary_radix_bits, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MATCH, + BLOCK_SCAN_WARP_SCANS); + + const auto alt_downsweep = make_reg_scaled_radix_sort_downsweep_policy( + (key_size > 1) ? 256 : 128, + 47, + __dominant_size(), + primary_radix_bits - 1, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + const auto upsweep = + make_reg_scaled_radix_sort_upsweep_policy(256, 23, __dominant_size(), primary_radix_bits, LOAD_DEFAULT); + + const auto alt_upsweep = + make_reg_scaled_radix_sort_upsweep_policy(256, 47, __dominant_size(), primary_radix_bits - 1, LOAD_DEFAULT); + + const auto single_tile = make_reg_scaled_radix_sort_downsweep_policy( + 256, + 19, + __dominant_size(), + single_tile_radix_bits, + BLOCK_LOAD_DIRECT, + LOAD_LDG, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + const auto segmented = make_reg_scaled_radix_sort_downsweep_policy( + 192, + 39, + __dominant_size(), + segmented_radix_bits, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + const auto alt_segmented = make_reg_scaled_radix_sort_downsweep_policy( + 384, + 11, + __dominant_size(), + segmented_radix_bits - 1, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + return radix_sort_policy{ + /* use_onesweep */ true, + onesweep_radix_bits, + histogram, + exclusive_sum, + onesweep, + scan, + downsweep, + alt_downsweep, + upsweep, + alt_upsweep, + single_tile, + segmented, + alt_segmented}; + } + + [[nodiscard]] _CCCL_API constexpr auto operator()(::cuda::arch_id arch) const -> radix_sort_policy + { + // TODO(bgruber): we should probably separate the segmented policies and move them somewhere else + + if (arch >= ::cuda::arch_id::sm_100) + { + return make_onsweep_small_key_policy(get_sm100_tuning(key_size, value_size, offset_size, key_type)); + } + + if (arch >= ::cuda::arch_id::sm_90) + { + return make_onsweep_small_key_policy(get_sm90_tuning(key_size, value_size, offset_size)); + } + + if (arch >= ::cuda::arch_id::sm_80) + { + const int primary_radix_bits = (key_size > 1) ? 7 : 5; + const int single_tile_radix_bits = (key_size > 1) ? 6 : 5; + const int segmented_radix_bits = (key_size > 1) ? 6 : 5; + const bool use_onesweep = key_size >= int{sizeof(uint32_t)}; + const int onesweep_radix_bits = 8; + const bool offset_64bit = offset_size == 8; + + const auto histogram = radix_sort_histogram_policy{128, 16, __scale_num_parts(1, key_size), onesweep_radix_bits}; + + const auto exclusive_sum = radix_sort_exclusive_sum_policy{256, onesweep_radix_bits}; + + const auto onesweep = make_reg_scaled_radix_sort_onesweep_policy( + 384, + offset_64bit && key_size == 4 && !__keys_only() ? 17 : 21, + __dominant_size(), + 1, + onesweep_radix_bits, + RADIX_RANK_MATCH_EARLY_COUNTS_ANY, + BLOCK_SCAN_RAKING_MEMOIZE, + RADIX_SORT_STORE_DIRECT); + + const auto scan = make_mem_scaled_scan_policy( + 512, + 23, + offset_size, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_STORE_WARP_TRANSPOSE, + BLOCK_SCAN_RAKING_MEMOIZE); + + const auto downsweep = make_reg_scaled_radix_sort_downsweep_policy( + 512, + 23, + __dominant_size(), + primary_radix_bits, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MATCH, + BLOCK_SCAN_WARP_SCANS); + + const auto alt_downsweep = make_reg_scaled_radix_sort_downsweep_policy( + (key_size > 1) ? 256 : 128, + 47, + __dominant_size(), + primary_radix_bits - 1, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + const auto upsweep = + make_reg_scaled_radix_sort_upsweep_policy(256, 23, __dominant_size(), primary_radix_bits, LOAD_DEFAULT); + + const auto alt_upsweep = + make_reg_scaled_radix_sort_upsweep_policy(256, 47, __dominant_size(), primary_radix_bits - 1, LOAD_DEFAULT); + + const auto single_tile = make_reg_scaled_radix_sort_downsweep_policy( + 256, + 19, + __dominant_size(), + single_tile_radix_bits, + BLOCK_LOAD_DIRECT, + LOAD_LDG, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + const auto segmented = make_reg_scaled_radix_sort_downsweep_policy( + 192, + 39, + __dominant_size(), + segmented_radix_bits, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + const auto alt_segmented = make_reg_scaled_radix_sort_downsweep_policy( + 384, + 11, + __dominant_size(), + segmented_radix_bits - 1, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + return radix_sort_policy{ + use_onesweep, + onesweep_radix_bits, + histogram, + exclusive_sum, + onesweep, + scan, + downsweep, + alt_downsweep, + upsweep, + alt_upsweep, + single_tile, + segmented, + alt_segmented}; + } + + if (arch >= ::cuda::arch_id::sm_70) + { + const int primary_radix_bits = (key_size > 1) ? 7 : 5; // 7.62B 32b keys/s (GV100) + const int single_tile_radix_bits = (key_size > 1) ? 6 : 5; + const int segmented_radix_bits = (key_size > 1) ? 6 : 5; // 8.7B 32b segmented keys/s (GV100) + const bool use_onesweep = key_size >= int{sizeof(uint32_t)}; // 15.8B 32b keys/s (V100-SXM2, 64M random keys) + const int onesweep_radix_bits = 8; + const bool offset_64bit = offset_size == 8; + + const auto histogram = radix_sort_histogram_policy{256, 8, __scale_num_parts(8, key_size), onesweep_radix_bits}; + + const auto exclusive_sum = radix_sort_exclusive_sum_policy{256, onesweep_radix_bits}; + + const auto onesweep = make_reg_scaled_radix_sort_onesweep_policy( + 256, + key_size == 4 && value_size == 4 ? 46 : 23, + __dominant_size(), + 4, + onesweep_radix_bits, + RADIX_RANK_MATCH_EARLY_COUNTS_ANY, + BLOCK_SCAN_WARP_SCANS, + RADIX_SORT_STORE_DIRECT); + + const auto scan = make_mem_scaled_scan_policy( + 512, + 23, + offset_size, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_STORE_WARP_TRANSPOSE, + BLOCK_SCAN_RAKING_MEMOIZE); + + const auto downsweep = make_reg_scaled_radix_sort_downsweep_policy( + 512, + 23, + __dominant_size(), + primary_radix_bits, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MATCH, + BLOCK_SCAN_WARP_SCANS); + + const auto alt_downsweep = make_reg_scaled_radix_sort_downsweep_policy( + (key_size > 1) ? 256 : 128, + offset_64bit ? 46 : 47, + __dominant_size(), + primary_radix_bits - 1, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_RAKING_MEMOIZE); + + const auto upsweep = + make_reg_scaled_radix_sort_upsweep_policy(256, 23, __dominant_size(), primary_radix_bits, LOAD_DEFAULT); + + const auto alt_upsweep = make_reg_scaled_radix_sort_upsweep_policy( + 256, offset_64bit ? 46 : 47, __dominant_size(), primary_radix_bits - 1, LOAD_DEFAULT); + + const auto single_tile = make_reg_scaled_radix_sort_downsweep_policy( + 256, + 19, + __dominant_size(), + single_tile_radix_bits, + BLOCK_LOAD_DIRECT, + LOAD_LDG, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + const auto segmented = make_reg_scaled_radix_sort_downsweep_policy( + 192, + 39, + __dominant_size(), + segmented_radix_bits, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + const auto alt_segmented = make_reg_scaled_radix_sort_downsweep_policy( + 384, + 11, + __dominant_size(), + segmented_radix_bits - 1, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + return radix_sort_policy{ + use_onesweep, + onesweep_radix_bits, + histogram, + exclusive_sum, + onesweep, + scan, + downsweep, + alt_downsweep, + upsweep, + alt_upsweep, + single_tile, + segmented, + alt_segmented}; + } + + if (static_cast(arch) >= 62) // TODO(bgruber): add ::cuda::arch_id::sm_62 + { + const int primary_radix_bits = 5; + const int alt_radix_bits = primary_radix_bits - 1; + const bool use_onesweep = key_size >= int{sizeof(uint32_t)}; + const int onesweep_radix_bits = 8; + + const auto histogram = radix_sort_histogram_policy{256, 8, __scale_num_parts(8, key_size), onesweep_radix_bits}; + + const auto exclusive_sum = radix_sort_exclusive_sum_policy{256, onesweep_radix_bits}; + + const auto onesweep = make_reg_scaled_radix_sort_onesweep_policy( + 256, + 30, + __dominant_size(), + 2, + onesweep_radix_bits, + RADIX_RANK_MATCH_EARLY_COUNTS_ANY, + BLOCK_SCAN_WARP_SCANS, + RADIX_SORT_STORE_DIRECT); + + const auto scan = make_mem_scaled_scan_policy( + 512, + 23, + offset_size, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_STORE_WARP_TRANSPOSE, + BLOCK_SCAN_RAKING_MEMOIZE); + + const auto downsweep = make_reg_scaled_radix_sort_downsweep_policy( + 256, + 16, + __dominant_size(), + primary_radix_bits, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_RAKING_MEMOIZE); + + const auto alt_downsweep = make_reg_scaled_radix_sort_downsweep_policy( + 256, + 16, + __dominant_size(), + alt_radix_bits, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_RAKING_MEMOIZE); + + const auto upsweep = radix_sort_upsweep_policy{ + downsweep.block_threads, downsweep.items_per_thread, downsweep.radix_bits, downsweep.load_modifier}; + + const auto alt_upsweep = radix_sort_upsweep_policy{ + alt_downsweep.block_threads, + alt_downsweep.items_per_thread, + alt_downsweep.radix_bits, + alt_downsweep.load_modifier}; + + const auto single_tile = make_reg_scaled_radix_sort_downsweep_policy( + 256, + 19, + __dominant_size(), + primary_radix_bits, + BLOCK_LOAD_DIRECT, + LOAD_LDG, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + const auto segmented = downsweep; + const auto alt_segmented = alt_downsweep; + + return radix_sort_policy{ + use_onesweep, + onesweep_radix_bits, + histogram, + exclusive_sum, + onesweep, + scan, + downsweep, + alt_downsweep, + upsweep, + alt_upsweep, + single_tile, + segmented, + alt_segmented}; + } + + if (arch >= ::cuda::arch_id::sm_61) + { + const int primary_radix_bits = (key_size > 1) ? 7 : 5; // 3.4B 32b keys/s, 1.83B 32b pairs/s (1080) + const int single_tile_radix_bits = (key_size > 1) ? 6 : 5; + const int segmented_radix_bits = (key_size > 1) ? 6 : 5; // 3.3B 32b segmented keys/s (1080) + const bool use_onesweep = key_size >= int{sizeof(uint32_t)}; // 10.0B 32b keys/s (GP100, 64M random keys) + const int onesweep_radix_bits = 8; + + const auto histogram = radix_sort_histogram_policy{256, 8, __scale_num_parts(8, key_size), onesweep_radix_bits}; + + const auto exclusive_sum = radix_sort_exclusive_sum_policy{256, onesweep_radix_bits}; + + const auto onesweep = make_reg_scaled_radix_sort_onesweep_policy( + 256, + 30, + __dominant_size(), + 2, + onesweep_radix_bits, + RADIX_RANK_MATCH_EARLY_COUNTS_ANY, + BLOCK_SCAN_WARP_SCANS, + RADIX_SORT_STORE_DIRECT); + + const auto scan = make_mem_scaled_scan_policy( + 512, + 23, + offset_size, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_STORE_WARP_TRANSPOSE, + BLOCK_SCAN_RAKING_MEMOIZE); + + const auto downsweep = make_reg_scaled_radix_sort_downsweep_policy( + 384, + 31, + __dominant_size(), + primary_radix_bits, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MATCH, + BLOCK_SCAN_RAKING_MEMOIZE); + + const auto alt_downsweep = make_reg_scaled_radix_sort_downsweep_policy( + 256, + 35, + __dominant_size(), + primary_radix_bits - 1, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_RAKING_MEMOIZE); + + const auto upsweep = + make_reg_scaled_radix_sort_upsweep_policy(128, 16, __dominant_size(), primary_radix_bits, LOAD_LDG); + + const auto alt_upsweep = + make_reg_scaled_radix_sort_upsweep_policy(128, 16, __dominant_size(), primary_radix_bits - 1, LOAD_LDG); + + const auto single_tile = make_reg_scaled_radix_sort_downsweep_policy( + 256, + 19, + __dominant_size(), + single_tile_radix_bits, + BLOCK_LOAD_DIRECT, + LOAD_LDG, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + const auto segmented = make_reg_scaled_radix_sort_downsweep_policy( + 192, + 39, + __dominant_size(), + segmented_radix_bits, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + const auto alt_segmented = make_reg_scaled_radix_sort_downsweep_policy( + 384, + 11, + __dominant_size(), + segmented_radix_bits - 1, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + return radix_sort_policy{ + use_onesweep, + onesweep_radix_bits, + histogram, + exclusive_sum, + onesweep, + scan, + downsweep, + alt_downsweep, + upsweep, + alt_upsweep, + single_tile, + segmented, + alt_segmented}; + } + + if (arch >= ::cuda::arch_id::sm_60) + { + const int primary_radix_bits = (key_size > 1) ? 7 : 5; // 6.9B 32b keys/s (Quadro P100) + const int single_tile_radix_bits = (key_size > 1) ? 6 : 5; + const int segmented_radix_bits = (key_size > 1) ? 6 : 5; // 5.9B 32b segmented keys/s (Quadro P100) + const bool use_onesweep = key_size >= int{sizeof(uint32_t)}; // 10.0B 32b keys/s (GP100, 64M random keys) + const int onesweep_radix_bits = 8; + const bool offset_64bit = (offset_size == 8); + + const auto histogram = radix_sort_histogram_policy{256, 8, __scale_num_parts(8, key_size), onesweep_radix_bits}; + + const auto exclusive_sum = radix_sort_exclusive_sum_policy{256, onesweep_radix_bits}; + + const auto onesweep = make_reg_scaled_radix_sort_onesweep_policy( + 256, + offset_64bit ? 29 : 30, + __dominant_size(), + 2, + onesweep_radix_bits, + RADIX_RANK_MATCH_EARLY_COUNTS_ANY, + BLOCK_SCAN_WARP_SCANS, + RADIX_SORT_STORE_DIRECT); + + const auto scan = make_mem_scaled_scan_policy( + 512, + 23, + offset_size, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_STORE_WARP_TRANSPOSE, + BLOCK_SCAN_RAKING_MEMOIZE); + + const auto downsweep = make_reg_scaled_radix_sort_downsweep_policy( + 256, + 25, + __dominant_size(), + primary_radix_bits, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MATCH, + BLOCK_SCAN_WARP_SCANS); + + const auto alt_downsweep = make_reg_scaled_radix_sort_downsweep_policy( + 192, + offset_64bit ? 32 : 39, + __dominant_size(), + primary_radix_bits - 1, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + const auto upsweep = radix_sort_upsweep_policy{ + downsweep.block_threads, downsweep.items_per_thread, downsweep.radix_bits, downsweep.load_modifier}; + + const auto alt_upsweep = radix_sort_upsweep_policy{ + alt_downsweep.block_threads, + alt_downsweep.items_per_thread, + alt_downsweep.radix_bits, + alt_downsweep.load_modifier}; + + const auto single_tile = make_reg_scaled_radix_sort_downsweep_policy( + 256, + 19, + __dominant_size(), + single_tile_radix_bits, + BLOCK_LOAD_DIRECT, + LOAD_LDG, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + const auto segmented = make_reg_scaled_radix_sort_downsweep_policy( + 192, + 39, + __dominant_size(), + segmented_radix_bits, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + const auto alt_segmented = make_reg_scaled_radix_sort_downsweep_policy( + 384, + 11, + __dominant_size(), + segmented_radix_bits - 1, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + return radix_sort_policy{ + use_onesweep, + onesweep_radix_bits, + histogram, + exclusive_sum, + onesweep, + scan, + downsweep, + alt_downsweep, + upsweep, + alt_upsweep, + single_tile, + segmented, + alt_segmented}; + } + + // SM50 + const int primary_radix_bits = (key_size > 1) ? 7 : 5; // 3.5B 32b keys/s, 1.92B 32b pairs/s (TitanX) + const int single_tile_radix_bits = (key_size > 1) ? 6 : 5; + const int segmented_radix_bits = (key_size > 1) ? 6 : 5; // 3.1B 32b segmented keys/s (TitanX) + const bool use_onesweep = false; + const int onesweep_radix_bits = 8; + + const auto histogram = radix_sort_histogram_policy{256, 8, __scale_num_parts(1, key_size), onesweep_radix_bits}; + + const auto exclusive_sum = radix_sort_exclusive_sum_policy{256, onesweep_radix_bits}; + + const auto onesweep = make_reg_scaled_radix_sort_onesweep_policy( + 256, + 21, + __dominant_size(), + 1, + onesweep_radix_bits, + RADIX_RANK_MATCH_EARLY_COUNTS_ANY, + BLOCK_SCAN_WARP_SCANS, + RADIX_SORT_STORE_DIRECT); + + const auto scan = make_mem_scaled_scan_policy( + 512, + 23, + offset_size, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_STORE_WARP_TRANSPOSE, + BLOCK_SCAN_RAKING_MEMOIZE); + + const auto downsweep = make_reg_scaled_radix_sort_downsweep_policy( + 160, + 39, + __dominant_size(), + primary_radix_bits, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_BASIC, + BLOCK_SCAN_WARP_SCANS); + + const auto alt_downsweep = make_reg_scaled_radix_sort_downsweep_policy( + 256, + 16, + __dominant_size(), + primary_radix_bits - 1, + BLOCK_LOAD_DIRECT, + LOAD_LDG, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_RAKING_MEMOIZE); + + const auto upsweep = radix_sort_upsweep_policy{ + downsweep.block_threads, downsweep.items_per_thread, downsweep.radix_bits, downsweep.load_modifier}; + + const auto alt_upsweep = radix_sort_upsweep_policy{ + alt_downsweep.block_threads, + alt_downsweep.items_per_thread, + alt_downsweep.radix_bits, + alt_downsweep.load_modifier}; + + const auto single_tile = make_reg_scaled_radix_sort_downsweep_policy( + 256, + 19, + __dominant_size(), + single_tile_radix_bits, + BLOCK_LOAD_DIRECT, + LOAD_LDG, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + const auto segmented = make_reg_scaled_radix_sort_downsweep_policy( + 192, + 31, + __dominant_size(), + segmented_radix_bits, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + const auto alt_segmented = make_reg_scaled_radix_sort_downsweep_policy( + 256, + 11, + __dominant_size(), + segmented_radix_bits - 1, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS); + + return radix_sort_policy{ + use_onesweep, + onesweep_radix_bits, + histogram, + exclusive_sum, + onesweep, + scan, + downsweep, + alt_downsweep, + upsweep, + alt_upsweep, + single_tile, + segmented, + alt_segmented}; + } +}; + +template +struct policy_selector_from_types +{ + [[nodiscard]] _CCCL_API constexpr auto operator()(cuda::arch_id arch) const -> radix_sort_policy + { + constexpr auto policies = policy_selector{ + int{sizeof(KeyT)}, + ::cuda::std::is_same_v ? 0 : int{sizeof(ValueT)}, + int{sizeof(OffsetT)}, + classify_type}; + return policies(arch); + } +}; } // namespace detail::radix_sort CUB_NAMESPACE_END diff --git a/cub/cub/util_device.cuh b/cub/cub/util_device.cuh index cbccb3a0933..85703ae385f 100644 --- a/cub/cub/util_device.cuh +++ b/cub/cub/util_device.cuh @@ -730,6 +730,8 @@ struct KernelConfig int tile_size{0}; int sm_occupancy{0}; + // TODO(bgruber): remove this function once all reduce and radix sort (segmented) algorithms have been ported to the + // new tuning API template @@ -741,6 +743,19 @@ struct KernelConfig tile_size = block_threads * items_per_thread; return launcher_factory.MaxSmOccupancy(sm_occupancy, kernel_ptr, block_threads); } + + // Using new tuning API conventions + template + CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t + __init(KernelPtrT kernel_ptr, AgentPolicyT agent_policy = {}, LauncherFactory launcher_factory = {}) + { + block_threads = agent_policy.block_threads; + items_per_thread = agent_policy.items_per_thread; + tile_size = block_threads * items_per_thread; + return launcher_factory.MaxSmOccupancy(sm_occupancy, kernel_ptr, block_threads); + } }; } // namespace detail #endif // !_CCCL_COMPILER(NVRTC) diff --git a/cub/test/catch2_test_device_radix_sort_custom_policy_hub.cu b/cub/test/catch2_test_device_radix_sort_custom_policy_hub.cu index a1ca7a97f57..bed84e26eb4 100644 --- a/cub/test/catch2_test_device_radix_sort_custom_policy_hub.cu +++ b/cub/test/catch2_test_device_radix_sort_custom_policy_hub.cu @@ -14,8 +14,7 @@ using namespace cub; template struct my_policy_hub { - static constexpr bool KEYS_ONLY = true; - using DominantT = KeyT; + using DominantT = KeyT; // from Policy500 of the CUB radix sort tunings struct MaxPolicy : ChainedPolicy<500, MaxPolicy, MaxPolicy>