Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 54 additions & 160 deletions c/parallel/src/radix_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

#include <cub/detail/choose_offset.cuh>
#include <cub/detail/launcher/cuda_driver.cuh>
#include <cub/detail/ptx-json-parser.cuh>
#include <cub/device/device_radix_sort.cuh>

#include <format>
Expand All @@ -31,92 +30,6 @@ static_assert(std::is_same_v<cub::detail::choose_offset_t<OffsetT>, OffsetT>, "O

namespace radix_sort
{
using namespace cub::detail::radix_sort_runtime_policies;

struct radix_sort_runtime_tuning_policy
{
RuntimeRadixSortHistogramAgentPolicy histogram;
RuntimeRadixSortExclusiveSumAgentPolicy exclusive_sum;
RuntimeRadixSortOnesweepAgentPolicy onesweep;
cub::detail::RuntimeScanAgentPolicy scan;
cub::detail::RuntimeRadixSortDownsweepAgentPolicy downsweep;
cub::detail::RuntimeRadixSortDownsweepAgentPolicy alt_downsweep;
RuntimeRadixSortUpsweepAgentPolicy upsweep;
RuntimeRadixSortUpsweepAgentPolicy alt_upsweep;
cub::detail::RuntimeRadixSortDownsweepAgentPolicy single_tile;
bool is_onesweep;

auto Histogram() const
{
return histogram;
}

auto ExclusiveSum() const
{
return exclusive_sum;
}

auto Onesweep() const
{
return onesweep;
}

auto Scan() const
{
return scan;
}

auto Downsweep() const
{
return downsweep;
}

auto AltDownsweep() const
{
return alt_downsweep;
}

auto Upsweep() const
{
return upsweep;
}

auto AltUpsweep() const
{
return alt_upsweep;
}

auto SingleTile() const
{
return single_tile;
}

bool IsOnesweep() const
{
return is_onesweep;
}

template <typename PolicyT>
CUB_RUNTIME_FUNCTION static constexpr int RadixBits(PolicyT policy)
{
return policy.RadixBits();
}

template <typename PolicyT>
CUB_RUNTIME_FUNCTION static constexpr int BlockThreads(PolicyT policy)
{
return policy.BlockThreads();
}

using MaxPolicy = radix_sort_runtime_tuning_policy;

template <typename F>
cudaError_t Invoke(int, F& op)
{
return op.template Invoke<radix_sort_runtime_tuning_policy>(*this);
}
};

std::string get_single_tile_kernel_name(
std::string_view chained_policy_t,
cccl_sort_order_t sort_order,
Expand Down Expand Up @@ -290,12 +203,10 @@ CUresult cccl_device_radix_sort_build_ex(
{
const char* name = "test";

const int cc = cc_major * 10 + cc_minor;
const auto key_cpp = cccl_type_enum_to_name(input_keys_it.value_type.type);
const auto value_cpp =
input_values_it.type == cccl_iterator_kind_t::CCCL_POINTER && input_values_it.state == nullptr
? "cub::NullType"
: cccl_type_enum_to_name(input_values_it.value_type.type);
const auto keys_only =
input_values_it.type == cccl_iterator_kind_t::CCCL_POINTER && input_values_it.state == nullptr;
const auto value_cpp = keys_only ? "cub::NullType" : cccl_type_enum_to_name(input_values_it.value_type.type);
const std::string op_src =
(decomposer.name == nullptr || (decomposer.name != nullptr && decomposer.name[0] == '\0'))
? "using op_wrapper = cub::detail::identity_decomposer_t;"
Expand All @@ -305,8 +216,32 @@ CUresult cccl_device_radix_sort_build_ex(
std::string offset_t;
check(cccl_type_name_from_nvrtc<OffsetT>(&offset_t));

const auto policy_hub_expr =
std::format("cub::detail::radix_sort::policy_hub<{}, {}, {}>", key_cpp, value_cpp, offset_t);
// TODO(bgruber): generalize this somewhere
const auto key_type = [&] {
switch (input_keys_it.value_type.type)
{
case CCCL_FLOAT32:
return cub::detail::type_t::float32;
case CCCL_FLOAT64:
return cub::detail::type_t::float64;
default:
return cub::detail::type_t::other;
}
}();

const auto policy_sel = cub::detail::radix_sort::policy_selector{
static_cast<int>(input_keys_it.value_type.size),
// FIXME(bgruber): input_values_it.value_type.size is 4 when it represents cub::NullType, which is very odd
keys_only ? 0 : static_cast<int>(input_values_it.value_type.size),
int{sizeof(OffsetT)},
key_type};

// TODO(bgruber): drop this if tuning policies become formattable
std::stringstream policy_sel_str;
policy_sel_str << policy_sel(cuda::to_arch_id(cuda::compute_capability{cc_major, cc_minor}));

auto policy_hub_expr =
std::format("cub::detail::radix_sort::policy_selector_from_types<{}, {}, {}>", key_cpp, value_cpp, offset_t);

const std::string final_src = std::format(
R"XXX(
Expand All @@ -321,21 +256,18 @@ struct __align__({3}) values_storage_t {{
char data[{2}];
}};
{4}
using {5} = {6}::MaxPolicy;

#include <cub/detail/ptx-json/json.cuh>
__device__ consteval auto& policy_generator() {{
return ptx_json::id<ptx_json::string("device_radix_sort_policy")>()
= cub::detail::radix_sort::RadixSortPolicyWrapper<{5}::ActivePolicy>::EncodedPolicy();
}}
using device_radix_sort_policy = {5};
using namespace cub;
using namespace cub::detail::radix_sort;
static_assert(device_radix_sort_policy()(::cuda::arch_id{{CUB_PTX_ARCH / 10}}) == {6}, "Host generated and JIT compiled policy mismatch");
)XXX",
input_keys_it.value_type.size, // 0
input_keys_it.value_type.alignment, // 1
input_values_it.value_type.size, // 2
input_values_it.value_type.alignment, // 3
op_src, // 4
chained_policy_t, // 5
policy_hub_expr); // 6
policy_hub_expr, // 5
policy_sel_str.view()); // 6

#if false // CCCL_DEBUGGING_SWITCH
fflush(stderr);
Expand Down Expand Up @@ -379,8 +311,8 @@ __device__ consteval auto& policy_generator() {{
ctk_path,
"-rdc=true",
"-dlto",
"-default-device",
"-DCUB_DISABLE_CDP",
"-DCUB_ENABLE_POLICY_PTX_JSON",
"-std=c++20"};

cccl::detail::extend_args_with_build_config(args, config);
Expand Down Expand Up @@ -434,43 +366,13 @@ __device__ consteval auto& policy_generator() {{
&build_ptr->exclusive_sum_kernel, build_ptr->library, exclusive_sum_kernel_lowered_name.c_str()));
check(cuLibraryGetKernel(&build_ptr->onesweep_kernel, build_ptr->library, onesweep_kernel_lowered_name.c_str()));

nlohmann::json runtime_policy =
cub::detail::ptx_json::parse("device_radix_sort_policy", {result.data.get(), result.size});

using namespace cub::detail::radix_sort_runtime_policies;
using cub::detail::RuntimeScanAgentPolicy;
auto single_tile_policy =
cub::detail::RuntimeRadixSortDownsweepAgentPolicy::from_json(runtime_policy, "SingleTilePolicy");
auto onesweep_policy = RuntimeRadixSortOnesweepAgentPolicy::from_json(runtime_policy, "OnesweepPolicy");
auto upsweep_policy = RuntimeRadixSortUpsweepAgentPolicy::from_json(runtime_policy, "UpsweepPolicy");
auto alt_upsweep_policy = RuntimeRadixSortUpsweepAgentPolicy::from_json(runtime_policy, "AltUpsweepPolicy");
auto downsweep_policy =
cub::detail::RuntimeRadixSortDownsweepAgentPolicy::from_json(runtime_policy, "DownsweepPolicy");
auto alt_downsweep_policy =
cub::detail::RuntimeRadixSortDownsweepAgentPolicy::from_json(runtime_policy, "AltDownsweepPolicy");
auto histogram_policy = RuntimeRadixSortHistogramAgentPolicy::from_json(runtime_policy, "HistogramPolicy");
auto exclusive_sum_policy =
RuntimeRadixSortExclusiveSumAgentPolicy::from_json(runtime_policy, "ExclusiveSumPolicy");
auto scan_policy = RuntimeScanAgentPolicy::from_json(runtime_policy, "ScanPolicy");
auto is_onesweep = runtime_policy["Onesweep"].get<bool>();

build_ptr->cc = cc;
build_ptr->cc = cc_major * 10 + cc_minor;
build_ptr->cubin = (void*) result.data.release();
build_ptr->cubin_size = result.size;
build_ptr->key_type = input_keys_it.value_type;
build_ptr->value_type = input_values_it.value_type;
build_ptr->order = sort_order;
build_ptr->runtime_policy = new radix_sort::radix_sort_runtime_tuning_policy{
histogram_policy,
exclusive_sum_policy,
onesweep_policy,
scan_policy,
downsweep_policy,
alt_downsweep_policy,
upsweep_policy,
alt_upsweep_policy,
single_tile_policy,
is_onesweep};
build_ptr->runtime_policy = new cub::detail::radix_sort::policy_selector{policy_sel};
}
catch (const std::exception& exc)
{
Expand Down Expand Up @@ -529,29 +431,20 @@ CUresult cccl_device_radix_sort_impl(
cub::DoubleBuffer<indirect_arg_t> d_values_buffer(
*static_cast<indirect_arg_t**>(&val_arg_in), *static_cast<indirect_arg_t**>(&val_arg_out));

auto exec_status = cub::DispatchRadixSort<
Order,
indirect_arg_t,
indirect_arg_t,
OffsetT,
indirect_arg_t,
radix_sort::radix_sort_runtime_tuning_policy,
radix_sort::radix_sort_kernel_source,
cub::detail::CudaDriverLauncherFactory>::
Dispatch(
d_temp_storage,
*temp_storage_bytes,
d_keys_buffer,
d_values_buffer,
num_items,
begin_bit,
end_bit,
is_overwrite_okay,
stream,
decomposer,
{build},
cub::detail::CudaDriverLauncherFactory{cu_device, build.cc},
*reinterpret_cast<radix_sort::radix_sort_runtime_tuning_policy*>(build.runtime_policy));
auto exec_status = cub::detail::radix_sort::dispatch<Order>(
d_temp_storage,
*temp_storage_bytes,
d_keys_buffer,
d_values_buffer,
num_items,
begin_bit,
end_bit,
is_overwrite_okay,
stream,
decomposer,
*static_cast<cub::detail::radix_sort::policy_selector*>(build.runtime_policy),
radix_sort::radix_sort_kernel_source{build},
cub::detail::CudaDriverLauncherFactory{cu_device, build.cc});

*selector = d_keys_buffer.selector;
error = static_cast<CUresult>(exec_status);
Expand Down Expand Up @@ -649,8 +542,9 @@ CUresult cccl_device_radix_sort_cleanup(cccl_device_radix_sort_build_result_t* b
return CUDA_ERROR_INVALID_VALUE;
}

using namespace cub::detail::radix_sort;
std::unique_ptr<char[]> cubin(reinterpret_cast<char*>(build_ptr->cubin));
std::unique_ptr<char[]> runtime_policy(reinterpret_cast<char*>(build_ptr->runtime_policy));
std::unique_ptr<policy_selector> policy(static_cast<policy_selector*>(build_ptr->runtime_policy));
check(cuLibraryUnload(build_ptr->library));
}
catch (const std::exception& exc)
Expand Down
2 changes: 2 additions & 0 deletions c/parallel/src/segmented_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ CUresult cccl_device_segmented_sort_build_ex(
ctk_path,
"-rdc=true",
"-dlto",
"-default-device",
"-DCUB_DISABLE_CDP",
"-std=c++20"};

Expand Down Expand Up @@ -696,6 +697,7 @@ __device__ consteval auto& three_way_partition_policy_generator() {{
ctk_path,
"-rdc=true",
"-dlto",
"-default-device",
"-DCUB_DISABLE_CDP",
"-DCUB_ENABLE_POLICY_PTX_JSON",
"-std=c++20"};
Expand Down
Loading
Loading