diff --git a/c/parallel/include/cccl/c/binary_search.h b/c/parallel/include/cccl/c/binary_search.h new file mode 100644 index 00000000000..93a5b7f8070 --- /dev/null +++ b/c/parallel/include/cccl/c/binary_search.h @@ -0,0 +1,76 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDA Experimental in CUDA C++ Core Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#ifndef CCCL_C_EXPERIMENTAL +# error "C exposure is experimental and subject to change. Define CCCL_C_EXPERIMENTAL to acknowledge this notice." +#endif // !CCCL_C_EXPERIMENTAL + +#include +#include + +#include +#include + +CCCL_C_EXTERN_C_BEGIN + +typedef struct cccl_device_binary_search_build_result_t +{ + int cc; + void* cubin; + size_t cubin_size; + CUlibrary library; + CUkernel kernel; +} cccl_device_binary_search_build_result_t; + +CCCL_C_API CUresult cccl_device_binary_search_build( + cccl_device_binary_search_build_result_t* build, + cccl_binary_search_mode_t mode, + cccl_iterator_t d_data, + cccl_iterator_t d_values, + cccl_iterator_t d_out, + cccl_op_t op, + int cc_major, + int cc_minor, + const char* cub_path, + const char* thrust_path, + const char* libcudacxx_path, + const char* ctk_path); + +// Extended version with build configuration +CCCL_C_API CUresult cccl_device_binary_search_build_ex( + cccl_device_binary_search_build_result_t* build, + cccl_binary_search_mode_t mode, + cccl_iterator_t d_data, + cccl_iterator_t d_values, + cccl_iterator_t d_out, + cccl_op_t op, + int cc_major, + int cc_minor, + const char* cub_path, + const char* thrust_path, + const char* libcudacxx_path, + const char* ctk_path, + cccl_build_config* config); + +CCCL_C_API CUresult cccl_device_binary_search( + cccl_device_binary_search_build_result_t build, + cccl_iterator_t d_data, + uint64_t num_items, + cccl_iterator_t d_values, + uint64_t num_values, + cccl_iterator_t d_out, + cccl_op_t op, + CUstream stream); + +CCCL_C_API CUresult cccl_device_binary_search_cleanup(cccl_device_binary_search_build_result_t* bld_ptr); + +CCCL_C_EXTERN_C_END diff --git a/c/parallel/include/cccl/c/types.h b/c/parallel/include/cccl/c/types.h index 88ab3e8cda1..5a2c0de23b9 100644 --- a/c/parallel/include/cccl/c/types.h +++ b/c/parallel/include/cccl/c/types.h @@ -165,4 +165,10 @@ typedef enum cccl_determinism_t CCCL_GPU_TO_GPU = 2, } cccl_determinism_t; +typedef enum cccl_binary_search_mode_t +{ + CCCL_BINARY_SEARCH_LOWER_BOUND = 0, + CCCL_BINARY_SEARCH_UPPER_BOUND = 1, +} cccl_binary_search_mode_t; + CCCL_C_EXTERN_C_END diff --git a/c/parallel/src/binary_search.cu b/c/parallel/src/binary_search.cu new file mode 100644 index 00000000000..90e876cab68 --- /dev/null +++ b/c/parallel/src/binary_search.cu @@ -0,0 +1,358 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDA Experimental in CUDA C++ Core Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct op_wrapper; +struct device_reduce_policy; + +using OffsetT = unsigned long long; +static_assert(std::is_same_v, OffsetT>, "OffsetT must be size_t"); + +static cudaError_t Invoke( + indirect_arg_t d_in, + size_t num_items, + indirect_arg_t d_values, + size_t num_values, + indirect_arg_t d_out, + cccl_op_t op, + int /*cc*/, + CUfunction kernel, + CUstream stream) +{ + cudaError error = cudaSuccess; + + if (num_values == 0) + { + return error; + } + + void* args[] = {&d_in, &num_items, &d_values, &num_values, &d_out, &op}; + + const unsigned int thread_count = 256; + const size_t items_per_block = 512; + const size_t block_sz = cuda::ceil_div(num_values, items_per_block); + + if (block_sz > std::numeric_limits::max()) + { + return cudaErrorInvalidValue; + } + const unsigned int block_count = static_cast(block_sz); + + check(cuLaunchKernel(kernel, block_count, 1, 1, thread_count, 1, 1, 0, stream, args, 0)); + + // Check for failure to launch + error = CubDebug(cudaPeekAtLastError()); + + return error; +} + +struct binary_search_data_iterator_tag; +struct binary_search_values_iterator_tag; +struct binary_search_output_iterator_tag; +struct binary_search_op_tag; + +CUresult cccl_device_binary_search_build_ex( + cccl_device_binary_search_build_result_t* build_ptr, + cccl_binary_search_mode_t mode, + cccl_iterator_t d_data, + cccl_iterator_t d_values, + cccl_iterator_t d_out, + cccl_op_t op, + int cc_major, + int cc_minor, + const char* cub_path, + const char* thrust_path, + const char* libcudacxx_path, + const char* ctk_path, + cccl_build_config* config) +{ + CUresult error = CUDA_SUCCESS; + + try + { + if (d_data.type == cccl_iterator_kind_t::CCCL_ITERATOR) + { + throw std::runtime_error(std::string("Iterators are unsupported in for_each currently")); + } + + const char* name = "test"; + + const int cc = cc_major * 10 + cc_minor; + + auto [d_data_it_name, d_data_it_src] = + get_specialization(template_id(), d_data); + auto [d_values_it_name, d_values_it_src] = + get_specialization(template_id(), d_values); + auto [d_out_it_name, d_out_it_src] = get_specialization( + template_id(), d_out, d_out.value_type); + auto [op_name, op_src] = + get_specialization(template_id(), op, d_data.value_type); + + const std::string mode_t = [&] { + switch (mode) + { + case CCCL_BINARY_SEARCH_LOWER_BOUND: + return "cub::detail::find::lower_bound"; + case CCCL_BINARY_SEARCH_UPPER_BOUND: + return "cub::detail::find::upper_bound"; + } + throw std::runtime_error(std::format("Invalid binary search mode ({})", static_cast(mode))); + }(); + + const std::string src = std::format( + R"XXX( +#include +#include +#include + +{11} + +struct __align__({10}) storage_t {{ + char data[{9}]; +}}; + +{0} +{2} +{4} +{6} + +using policy_dim_t = cub::detail::for_each::policy_t<256, 2>; +using OffsetT = cuda::std::size_t; + +struct device_for_policy +{{ + struct ActivePolicy + {{ + using for_policy_t = policy_dim_t; + }}; +}}; + +CUB_DETAIL_KERNEL_ATTRIBUTES +__launch_bounds__(device_for_policy::ActivePolicy::for_policy_t::block_threads) + void binary_search_kernel({1} d_data, OffsetT num_data, {3} d_values, OffsetT num_values, {5} d_out, {7} op) +{{ + auto d_out_typed = [&] {{ + constexpr auto out_is_ptr = cuda::std::is_pointer_v; + constexpr auto out_matches_items = cuda::std::is_same_v; + constexpr auto need_cast = out_is_ptr && !out_matches_items; + + if constexpr (need_cast) {{ + static_assert(sizeof(decltype(*d_out)) == sizeof(decltype(d_data)), ""); + static_assert(alignof(decltype(*d_out)) == alignof(decltype(d_data)), ""); + return reinterpret_cast<{1} *>(d_out); + }} + else {{ + return d_out; + }} + }}(); + + auto input_it = cuda::make_zip_iterator(d_values, d_out_typed); + auto comp_wrapper = cub::detail::find::make_comp_wrapper<{8}>(d_data, d_data + num_data, op); + auto agent_op = [&comp_wrapper, &input_it](OffsetT index) {{ + comp_wrapper(input_it[index]); + }}; + + using active_policy_t = device_for_policy::ActivePolicy::for_policy_t; + using agent_t = cub::detail::for_each::agent_block_striped_t; + + constexpr auto block_threads = active_policy_t::block_threads; + constexpr auto items_per_tile = active_policy_t::items_per_thread * block_threads; + + const auto tile_base = static_cast(blockIdx.x) * items_per_tile; + const auto num_remaining = num_values - tile_base; + const auto items_in_tile = static_cast(num_remaining < items_per_tile ? num_remaining : items_per_tile); + + if (items_in_tile == items_per_tile) + {{ + agent_t{{tile_base, agent_op}}.template consume_tile(items_per_tile, block_threads); + }} + else + {{ + agent_t{{tile_base, agent_op}}.template consume_tile(items_in_tile, block_threads); + }} +}} +)XXX", + d_data_it_src, + d_data_it_name, + d_values_it_src, + d_values_it_name, + d_out_it_src, + d_out_it_name, + op_src, + op_name, + mode_t, + d_out.value_type.size, + d_out.value_type.alignment, + jit_template_header_contents); + + const std::string arch = std::format("-arch=sm_{0}{1}", cc_major, cc_minor); + + std::vector args = { + arch.c_str(), + cub_path, + thrust_path, + libcudacxx_path, + ctk_path, + "-std=c++20", + "-rdc=true", + "-dlto", + "-DCUB_DISABLE_CDP"}; + + cccl::detail::extend_args_with_build_config(args, config); + + constexpr size_t num_lto_args = 2; + const char* lopts[num_lto_args] = {"-lto", arch.c_str()}; + + std::string lowered_name; + + // Collect all LTO-IRs to be linked + nvrtc_linkable_list linkable_list; + nvrtc_linkable_list_appender appender{linkable_list}; + + appender.append_operation(op); + + // Add iterator definitions if present + for (const auto& it_type : {d_data, d_values, d_out}) + { + if (cccl_iterator_kind_t::CCCL_ITERATOR == it_type.type) + { + appender.append_operation(it_type.advance); + appender.append_operation(it_type.dereference); + } + } + + nvrtc_link_result result = + begin_linking_nvrtc_program(num_lto_args, lopts) + ->add_program(nvrtc_translation_unit{src, name}) + ->add_expression({"binary_search_kernel"}) + ->compile_program({args.data(), args.size()}) + ->get_name({"binary_search_kernel", lowered_name}) + ->link_program() + ->add_link_list(linkable_list) + ->finalize_program(); + + cuLibraryLoadData(&build_ptr->library, result.data.get(), nullptr, nullptr, 0, nullptr, nullptr, 0); + check(cuLibraryGetKernel(&build_ptr->kernel, build_ptr->library, lowered_name.c_str())); + + build_ptr->cc = cc; + build_ptr->cubin = (void*) result.data.release(); + build_ptr->cubin_size = result.size; + } + catch (...) + { + error = CUDA_ERROR_UNKNOWN; + } + return error; +} + +CUresult cccl_device_binary_search( + cccl_device_binary_search_build_result_t build, + cccl_iterator_t d_data, + uint64_t num_items, + cccl_iterator_t d_values, + uint64_t num_values, + cccl_iterator_t d_out, + cccl_op_t op, + CUstream stream) +{ + bool pushed = false; + CUresult error = CUDA_SUCCESS; + + try + { + pushed = try_push_context(); + auto exec_status = + Invoke(d_data, num_items, d_values, num_values, d_out, op, build.cc, (CUfunction) build.kernel, stream); + error = static_cast(exec_status); + } + catch (...) + { + error = CUDA_ERROR_UNKNOWN; + } + + if (pushed) + { + CUcontext dummy; + cuCtxPopCurrent(&dummy); + } + + return error; +} + +CUresult cccl_device_binary_search_build( + cccl_device_binary_search_build_result_t* build, + cccl_binary_search_mode_t mode, + cccl_iterator_t d_data, + cccl_iterator_t d_values, + cccl_iterator_t d_out, + cccl_op_t op, + int cc_major, + int cc_minor, + const char* cub_path, + const char* thrust_path, + const char* libcudacxx_path, + const char* ctk_path) +{ + return cccl_device_binary_search_build_ex( + build, + mode, + d_data, + d_values, + d_out, + op, + cc_major, + cc_minor, + cub_path, + thrust_path, + libcudacxx_path, + ctk_path, + nullptr); +} + +CUresult cccl_device_binary_search_cleanup(cccl_device_binary_search_build_result_t* build_ptr) +{ + try + { + if (build_ptr == nullptr) + { + return CUDA_ERROR_INVALID_VALUE; + } + + std::unique_ptr cubin(reinterpret_cast(build_ptr->cubin)); + check(cuLibraryUnload(build_ptr->library)); + } + catch (...) + { + return CUDA_ERROR_UNKNOWN; + } + + return CUDA_SUCCESS; +} diff --git a/c/parallel/src/jit_templates/templates/operation.h b/c/parallel/src/jit_templates/templates/operation.h index 111bf7411de..9c1657f7a42 100644 --- a/c/parallel/src/jit_templates/templates/operation.h +++ b/c/parallel/src/jit_templates/templates/operation.h @@ -337,3 +337,18 @@ struct binary_user_operation_traits } #endif }; + +struct binary_user_predicate_traits +{ + static const constexpr auto name = "binary_user_predicate_traits::type"; + template + using type = user_operation_traits::type{}, ValueT, ValueT>; + +#ifndef _CCCL_C_PARALLEL_JIT_TEMPLATES_PREPROCESS + template + static cuda::std::optional special(cccl_op_t operation, cccl_type_info arg_t) + { + return user_operation_traits::special(operation, arg_t, arg_t, arg_t); + } +#endif +}; diff --git a/c/parallel/test/test_binary_search.cpp b/c/parallel/test/test_binary_search.cpp new file mode 100644 index 00000000000..4d05c4eb9a9 --- /dev/null +++ b/c/parallel/test/test_binary_search.cpp @@ -0,0 +1,193 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDA Experimental in CUDA C++ Core Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// + +#include + +#include + +#include "algorithm_execution.h" +#include "build_result_caching.h" +#include "test_util.h" +#include + +using BuildResultT = cccl_device_binary_search_build_result_t; + +struct binary_search_cleanup +{ + CUresult operator()(BuildResultT* build_data) const noexcept + { + return cccl_device_binary_search_cleanup(build_data); + } +}; + +static std::string mode_as_key(cccl_binary_search_mode_t mode) +{ + switch (mode) + { + case cccl_binary_search_mode_t::CCCL_BINARY_SEARCH_LOWER_BOUND: + return "LOWER"; + case cccl_binary_search_mode_t::CCCL_BINARY_SEARCH_UPPER_BOUND: + return "UPPER"; + } + + throw std::runtime_error("Invalid binary search mode"); +} + +template +std::optional make_binary_search_key(bool inclusive, cccl_binary_search_mode_t mode) +{ + const std::string parts[] = {KeyBuilder::type_as_key(), KeyBuilder::bool_as_key(inclusive), mode_as_key(mode)}; + return KeyBuilder::join(parts); +} + +using binary_search_deleter = BuildResultDeleter; +using binary_search_build_cache_t = build_cache_t>; + +template +auto& get_cache() +{ + return fixture::get_or_create().get_value(); +} + +struct binary_search_build +{ + CUresult operator()( + BuildResultT* build_ptr, + cccl_binary_search_mode_t mode, + cccl_iterator_t data, + uint64_t, + cccl_iterator_t values, + uint64_t, + cccl_iterator_t out, + cccl_op_t op, + int cc_major, + int cc_minor, + const char* cub_path, + const char* thrust_path, + const char* libcudacxx_path, + const char* ctk_path) const noexcept + { + return cccl_device_binary_search_build( + build_ptr, mode, data, values, out, op, cc_major, cc_minor, cub_path, thrust_path, libcudacxx_path, ctk_path); + } + + static constexpr bool should_check_sass(int) + { + return false; + } +}; + +struct binary_search_run +{ + template + CUresult operator()(BuildResultT build, void*, std::size_t*, cccl_binary_search_mode_t, Ts... args) const noexcept + { + return cccl_device_binary_search(build, args...); + } +}; + +template +struct binary_search_wrapper +{ + static const constexpr auto mode = Mode; + + template + void operator()( + cccl_iterator_t data, + uint64_t num_items, + cccl_iterator_t values, + uint64_t num_values, + cccl_iterator_t output, + cccl_op_t op, + std::optional& cache, + const std::optional& lookup_key) const + { + AlgorithmExecute( + cache, lookup_key, mode, data, num_items, values, num_values, output, op); + } +}; + +using lower_bound = binary_search_wrapper; +using upper_bound = binary_search_wrapper; + +// ============== +// Test section +// ============== + +using integral_types = c2h::type_list; + +struct std_lower_bound_t +{ + template + RangeIteratorT operator()(RangeIteratorT first, RangeIteratorT last, const T& value, CompareOpT comp) const + { + return std::lower_bound(first, last, value, comp); + } +} std_lower_bound; + +struct std_upper_bound_t +{ + template + RangeIteratorT operator()(RangeIteratorT first, RangeIteratorT last, const T& value, CompareOpT comp) const + { + return std::upper_bound(first, last, value, comp); + } +} std_upper_bound; + +template +void test_vectorized(Variant variant, HostVariant host_variant) +{ + const std::size_t num_items = GENERATE(0, 43, take(4, random(1 << 12, 1 << 16))); + operation_t op = make_operation("op", get_merge_sort_op(get_type_info().type)); + + const std::vector target_values = generate(num_items / 100); + std::vector data = generate(num_items); + std::copy(target_values.begin(), target_values.end(), data.begin()); + std::sort(data.begin(), data.end()); + const std::vector output(target_values.size(), nullptr); + + pointer_t target_values_ptr(target_values); + pointer_t data_ptr(data); + pointer_t output_ptr(output); + + auto& build_cache = get_cache(); + const auto& test_key = make_binary_search_key(true, Variant::mode); + + variant(data_ptr, num_items, target_values_ptr, target_values.size(), output_ptr, op, build_cache, test_key); + + std::vector results(output_ptr); + std::vector expected(target_values.size(), nullptr); + + std::vector offsets(target_values.size(), 0); + std::vector expected_offsets(target_values.size(), 0); + + for (auto i = 0u; i < target_values.size(); ++i) + { + offsets[i] = results[i] - data_ptr.ptr; + expected_offsets[i] = + host_variant(data.data(), data.data() + num_items, target_values[i], std::less<>()) - data.data(); + } + + CHECK(expected_offsets == offsets); +} + +struct BinarySearch_IntegralTypes_LowerBound_Fixture_Tag; +C2H_TEST("DeviceFind::LowerBound works", "[find][device][binary-search]", integral_types) +{ + using value_type = c2h::get<0, TestType>; + test_vectorized(lower_bound{}, std_lower_bound); +} + +struct BinarySearch_IntegralTypes_UpperBound_Fixture_Tag; +C2H_TEST("DeviceFind::UpperBound works", "[find][device][binary-search]", integral_types) +{ + using value_type = c2h::get<0, TestType>; + test_vectorized(upper_bound{}, std_upper_bound); +} diff --git a/cub/benchmarks/bench/find_if/base.cu b/cub/benchmarks/bench/find_if/base.cu index 16cd8684c02..fc4deaef478 100644 --- a/cub/benchmarks/bench/find_if/base.cu +++ b/cub/benchmarks/bench/find_if/base.cu @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include +#include #include #include diff --git a/cub/cub/cub.cuh b/cub/cub/cub.cuh index 0563f5c0fe0..66d7ca6434c 100644 --- a/cub/cub/cub.cuh +++ b/cub/cub/cub.cuh @@ -44,6 +44,7 @@ // Device #include #include +#include #include #include #include diff --git a/cub/cub/detail/binary_search_helpers.cuh b/cub/cub/detail/binary_search_helpers.cuh new file mode 100644 index 00000000000..0d60b3ae52d --- /dev/null +++ b/cub/cub/detail/binary_search_helpers.cuh @@ -0,0 +1,65 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#pragma once + +#include + +#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) +# pragma GCC system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) +# pragma clang system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) +# pragma system_header +#endif // no system header + +#include +#include +#include + +CUB_NAMESPACE_BEGIN + +namespace detail::find +{ +template +struct comp_wrapper_t +{ + RangeIteratorT first; + RangeIteratorT last; + CompareOpT op; + + template + _CCCL_DEVICE _CCCL_FORCEINLINE void operator()(::cuda::std::tuple args) const + { + ::cuda::std::get<1>(args) = Mode::Invoke(first, last, ::cuda::std::get<0>(args), op); + } +}; + +template +_CCCL_HOST_DEVICE auto make_comp_wrapper(RangeIteratorT first, RangeIteratorT last, CompareOpT comp) +{ + return comp_wrapper_t{first, last, comp}; +} + +struct lower_bound +{ + template + _CCCL_DEVICE _CCCL_FORCEINLINE static RangeIteratorT + Invoke(RangeIteratorT first, RangeIteratorT last, const T& value, CompareOpT comp) + { + return ::cuda::std::lower_bound(first, last, value, comp); + } +}; + +struct upper_bound +{ + template + _CCCL_DEVICE _CCCL_FORCEINLINE static RangeIteratorT + Invoke(RangeIteratorT first, RangeIteratorT last, const T& value, CompareOpT comp) + { + return ::cuda::std::upper_bound(first, last, value, comp); + } +}; +} // namespace detail::find + +CUB_NAMESPACE_END diff --git a/cub/cub/device/device_find.cuh b/cub/cub/device/device_find.cuh new file mode 100644 index 00000000000..74d1c0cc727 --- /dev/null +++ b/cub/cub/device/device_find.cuh @@ -0,0 +1,275 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#pragma once + +#include + +#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) +# pragma GCC system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) +# pragma clang system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) +# pragma system_header +#endif // no system header + +#include +#include +#include +#include +#include + +#include + +CUB_NAMESPACE_BEGIN + +struct DeviceFind +{ + //! @rst + //! Finds the first element in the input sequence that satisfies the given predicate. + //! + //! - The search terminates at the first element where the predicate evaluates to true. + //! - The index of the found element is written to ``d_out``. + //! - If no element satisfies the predicate, ``num_items`` is written to ``d_out``. + //! - The range ``[d_out, d_out + 1)`` shall not overlap ``[d_in, d_in + num_items)`` in any way. + //! - @devicestorage + //! + //! Snippet + //! ========================================================================== + //! + //! The code snippet below illustrates the finding of the first element that satisfies the predicate. + //! + //! .. literalinclude:: ../../../cub/test/catch2_test_device_find_if_api.cu + //! :language: c++ + //! :dedent: + //! :start-after: example-begin find-if-predicate + //! :end-before: example-end find-if-predicate + //! + //! .. literalinclude:: ../../../cub/test/catch2_test_device_find_if_api.cu + //! :language: c++ + //! :dedent: + //! :start-after: example-begin device-find-if + //! :end-before: example-end device-find-if + //! @endrst + //! + //! @tparam InputIteratorT + //! **[inferred]** Random-access input iterator type for reading input items @iterator + //! + //! @tparam OutputIteratorT + //! **[inferred]** Random-access output iterator type for writing the result index @iterator + //! + //! @tparam ScanOpT + //! **[inferred]** Unary predicate functor type having member `bool operator()(const T &a)` + //! + //! @tparam NumItemsT + //! **[inferred]** An integral type representing the number of input elements + //! + //! @param[in] d_temp_storage + //! Device-accessible allocation of temporary storage. When `nullptr`, the + //! required allocation size is written to `temp_storage_bytes` and no work is done. + //! + //! @param[in,out] temp_storage_bytes + //! Reference to size in bytes of `d_temp_storage` allocation + //! + //! @param[in] d_in + //! Random-access iterator to the input sequence of data items + //! + //! @param[out] d_out + //! Random-access iterator to the output location for the index of the found element + //! + //! @param[in] scan_op + //! Unary predicate functor for determining whether an element satisfies the search condition + //! + //! @param[in] num_items + //! Total number of input items (i.e., the length of `d_in`) + //! + //! @param[in] stream + //! @rst + //! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. + //! @endrst + template + CUB_RUNTIME_FUNCTION static cudaError_t FindIf( + void* d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ScanOpT scan_op, + NumItemsT num_items, + cudaStream_t stream = 0) + { + _CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceFind::FindIf"); + + using OffsetT = detail::choose_offset_t; + + return detail::find::dispatch_t::Dispatch( + d_temp_storage, temp_storage_bytes, d_in, d_out, static_cast(num_items), scan_op, stream); + } + + //! @rst + //! Overview + //! +++++++++++++++++++++++++++++++++++++++++++++ + //! + //! For each ``value`` in ``[values_first, values_last)``, performs a binary search in the range ``[first, last)``, + //! using ``comp`` as the comparator to find the iterator to the element of said range which **is not** ordered + //! **before** ``value``. + //! + //! - The range ``[first, last)`` must be sorted consistently with ``comp``. + //! + //! @endrst + //! + //! @tparam RangeIteratorT + //! is a model of [Random Access Iterator], whose value type forms a [Relation] with the value type of + //! ``ValuesIteratorT`` using ``CompareOpT`` as the predicate. + //! + //! @tparam ValuesIteratorT + //! is a model of [Random Access Iterator], whose value type forms a [Relation] with the value type of + //! ``RangeIteratorT`` using ``CompareOpT`` as the predicate. + //! + //! @tparam OutputIteratorT + //! is a model of [Random Access Iterator], whose value type is assignable from ``RangeIteratorT``. + //! + //! @tparam CompareOpT + //! is a model of [Strict Weak Ordering], which forms a [Relation] with the value types of ``RangeIteratorT`` + //! and ``ValuesIteratorT``. + //! + //! @param[in] d_temp_storage + //! Device-accessible allocation of temporary storage. When `nullptr`, the + //! required allocation size is written to `temp_storage_bytes` and no work + //! is done. + //! + //! @param[in,out] temp_storage_bytes + //! Reference to size in bytes of `d_temp_storage` allocation + //! + //! @param[in] first + //! Iterator to the beginning of the ordered range to be searched. + //! + //! @param[in] last + //! Iterator denoting the one-past-the-end element of the ordered range to be searched. + //! + //! @param[in] values_first + //! Iterator to the beginning of the range of values to be searched for. + //! + //! @param[in] values_last + //! Iterator denoting the one-past-the-end element of the range of values to be searched for. + //! + //! @param[out] output + //! Iterator to the beginning of the output range. + //! + //! @param[in] comp + //! Comparison function object which returns true if its first argument is ordered before the second in the + //! [Strict Weak Ordering] of the range to be searched. + //! + //! @param[in] stream + //! **[optional]** CUDA stream to launch kernels within. + //! Default is stream0. + //! + //! [Random Access Iterator]: https://en.cppreference.com/w/cpp/iterator/random_access_iterator + //! [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order + //! [Relation]: https://en.cppreference.com/w/cpp/concepts/relation + template + CUB_RUNTIME_FUNCTION static cudaError_t LowerBound( + void* d_temp_storage, + size_t& temp_storage_bytes, + RangeIteratorT first, + RangeIteratorT last, + ValuesIteratorT values_first, + ValuesIteratorT values_last, + OutputIteratorT output, + CompareOpT comp, + cudaStream_t stream = 0) + { + _CCCL_NVTX_RANGE_SCOPE("cub::DeviceFind::LowerBound"); + return DeviceFor::ForEach( + d_temp_storage, + temp_storage_bytes, + ::cuda::make_zip_iterator(values_first, output), + ::cuda::make_zip_iterator(values_last, output + ::cuda::std::distance(values_first, values_last)), + detail::find::make_comp_wrapper(first, last, comp), + stream); + } + + //! @rst + //! Overview + //! +++++++++++++++++++++++++++++++++++++++++++++ + //! + //! For each ``value`` in ``[values_first, values_last)``, performs a binary search in the range ``[first, last)``, + //! using ``comp`` as the comparator to find the iterator to the element of said range which **is** ordered + //! **after** ``value``. + //! + //! - The range ``[first, last)`` must be sorted consistently with ``comp``. + //! + //! @endrst + //! + //! @tparam RangeIteratorT + //! is a model of [Random Access Iterator], whose value type forms a [Relation] with the value type of + //! ``ValuesIteratorT`` using ``CompareOpT`` as the predicate. + //! + //! @tparam ValuesIteratorT + //! is a model of [Random Access Iterator], whose value type forms a [Relation] with the value type of + //! ``RangeIteratorT`` using ``CompareOpT`` as the predicate. + //! + //! @tparam OutputIteratorT + //! is a model of [Random Access Iterator], whose value type is assignable from ``RangeIteratorT``. + //! + //! @tparam CompareOpT + //! is a model of [Strict Weak Ordering], which forms a [Relation] with the value types of ``RangeIteratorT`` + //! and ``ValuesIteratorT``. + //! + //! @param[in] d_temp_storage + //! Device-accessible allocation of temporary storage. When `nullptr`, the + //! required allocation size is written to `temp_storage_bytes` and no work + //! is done. + //! + //! @param[in,out] temp_storage_bytes + //! Reference to size in bytes of `d_temp_storage` allocation + //! + //! @param[in] first + //! Iterator to the beginning of the ordered range to be searched. + //! + //! @param[in] last + //! Iterator denoting the one-past-the-end element of the ordered range to be searched. + //! + //! @param[in] values_first + //! Iterator to the beginning of the range of values to be searched for. + //! + //! @param[in] values_last + //! Iterator denoting the one-past-the-end element of the range of values to be searched for. + //! + //! @param[out] output + //! Iterator to the beginning of the output range. + //! + //! @param[in] comp + //! Comparison function object which returns true if its first argument is ordered before the second in the + //! [Strict Weak Ordering] of the range to be searched. + //! + //! @param[in] stream + //! **[optional]** CUDA stream to launch kernels within. + //! Default is stream0. + //! + //! [Random Access Iterator]: https://en.cppreference.com/w/cpp/iterator/random_access_iterator + //! [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order + //! [Relation]: https://en.cppreference.com/w/cpp/concepts/relation + template + CUB_RUNTIME_FUNCTION static cudaError_t UpperBound( + void* d_temp_storage, + size_t& temp_storage_bytes, + RangeIteratorT first, + RangeIteratorT last, + ValuesIteratorT values_first, + ValuesIteratorT values_last, + OutputIteratorT output, + CompareOpT comp, + cudaStream_t stream = 0) + { + _CCCL_NVTX_RANGE_SCOPE("cub::DeviceFind::UpperBound"); + return DeviceFor::ForEach( + d_temp_storage, + temp_storage_bytes, + ::cuda::make_zip_iterator(values_first, output), + ::cuda::make_zip_iterator(values_last, output + ::cuda::std::distance(values_first, values_last)), + detail::find::make_comp_wrapper(first, last, comp), + stream); + } +}; + +CUB_NAMESPACE_END diff --git a/cub/cub/device/device_find_if.cuh b/cub/cub/device/device_find_if.cuh deleted file mode 100644 index 7e046c61810..00000000000 --- a/cub/cub/device/device_find_if.cuh +++ /dev/null @@ -1,113 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -//! @file -//! cub::DeviceFind provides device-wide, parallel operations for finding elements in sequences of data -//! items residing within device-accessible memory. - -#pragma once - -#include - -#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) -# pragma GCC system_header -#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) -# pragma clang system_header -#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) -# pragma system_header -#endif // no system header - -#include -#include -#include - -#include - -#include - -CUB_NAMESPACE_BEGIN - -struct DeviceFind -{ - //! @rst - //! Finds the first element in the input sequence that satisfies the given predicate. - //! - //! - The search terminates at the first element where the predicate evaluates to true. - //! - The index of the found element is written to ``d_out``. - //! - If no element satisfies the predicate, ``num_items`` is written to ``d_out``. - //! - The range ``[d_out, d_out + 1)`` shall not overlap ``[d_in, d_in + num_items)`` in any way. - //! - @devicestorage - //! - //! Snippet - //! ========================================================================== - //! - //! The code snippet below illustrates the finding of the first element that satisfies the predicate. - //! - //! .. literalinclude:: ../../../cub/test/catch2_test_device_find_if_api.cu - //! :language: c++ - //! :dedent: - //! :start-after: example-begin find-if-predicate - //! :end-before: example-end find-if-predicate - //! - //! .. literalinclude:: ../../../cub/test/catch2_test_device_find_if_api.cu - //! :language: c++ - //! :dedent: - //! :start-after: example-begin device-find-if - //! :end-before: example-end device-find-if - //! @endrst - //! - //! @tparam InputIteratorT - //! **[inferred]** Random-access input iterator type for reading input items @iterator - //! - //! @tparam OutputIteratorT - //! **[inferred]** Random-access output iterator type for writing the result index @iterator - //! - //! @tparam ScanOpT - //! **[inferred]** Unary predicate functor type having member `bool operator()(const T &a)` - //! - //! @tparam NumItemsT - //! **[inferred]** An integral type representing the number of input elements - //! - //! @param[in] d_temp_storage - //! Device-accessible allocation of temporary storage. When `nullptr`, the - //! required allocation size is written to `temp_storage_bytes` and no work is done. - //! - //! @param[in,out] temp_storage_bytes - //! Reference to size in bytes of `d_temp_storage` allocation - //! - //! @param[in] d_in - //! Random-access iterator to the input sequence of data items - //! - //! @param[out] d_out - //! Random-access iterator to the output location for the index of the found element - //! - //! @param[in] scan_op - //! Unary predicate functor for determining whether an element satisfies the search condition - //! - //! @param[in] num_items - //! Total number of input items (i.e., the length of `d_in`) - //! - //! @param[in] stream - //! @rst - //! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. - //! @endrst - template - CUB_RUNTIME_FUNCTION static cudaError_t FindIf( - void* d_temp_storage, - size_t& temp_storage_bytes, - InputIteratorT d_in, - OutputIteratorT d_out, - ScanOpT scan_op, - NumItemsT num_items, - cudaStream_t stream = 0) - { - _CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceFind::FindIf"); - - using OffsetT = detail::choose_offset_t; - - return detail::find::dispatch_t::Dispatch( - d_temp_storage, temp_storage_bytes, d_in, d_out, static_cast(num_items), scan_op, stream); - } -}; - -CUB_NAMESPACE_END diff --git a/cub/test/catch2_test_device_binary_search.cu b/cub/test/catch2_test_device_binary_search.cu new file mode 100644 index 00000000000..3f2a3de7416 --- /dev/null +++ b/cub/test/catch2_test_device_binary_search.cu @@ -0,0 +1,117 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include + +#include + +#include "catch2_test_launch_helper.h" +#include + +// %PARAM% TEST_LAUNCH lid 0:1:2 + +DECLARE_LAUNCH_WRAPPER(cub::DeviceFind::LowerBound, lower_bound); +DECLARE_LAUNCH_WRAPPER(cub::DeviceFind::UpperBound, upper_bound); + +using types = c2h::type_list; + +struct std_lower_bound_t +{ + template + RangeIteratorT operator()(RangeIteratorT first, RangeIteratorT last, const T& value, CompareOpT comp) const + { + return std::lower_bound(first, last, value, comp); + } +} std_lower_bound; + +struct std_upper_bound_t +{ + template + RangeIteratorT operator()(RangeIteratorT first, RangeIteratorT last, const T& value, CompareOpT comp) const + { + return std::upper_bound(first, last, value, comp); + } +} std_upper_bound; + +template > +void test_vectorized(Variant variant, HostVariant host_variant, std::size_t num_items = 7492, CompareOp compare_op = {}) +{ + c2h::device_vector target_values_d(num_items / 100, thrust::default_init); + c2h::gen(C2H_SEED(1), target_values_d); + + c2h::device_vector values_d(num_items + target_values_d.size(), thrust::default_init); + c2h::gen(C2H_SEED(1), values_d); + + thrust::copy(c2h::device_policy, target_values_d.begin(), target_values_d.end(), values_d.begin()); + thrust::sort(c2h::device_policy, values_d.begin(), values_d.end(), compare_op); + + using Result = Value*; + c2h::device_vector result_d(target_values_d.size(), thrust::default_init); + variant(thrust::raw_pointer_cast(values_d.data()), + thrust::raw_pointer_cast(values_d.data() + num_items), + thrust::raw_pointer_cast(target_values_d.data()), + thrust::raw_pointer_cast(target_values_d.data() + target_values_d.size()), + thrust::raw_pointer_cast(result_d.data()), + compare_op); + + c2h::host_vector target_values_h = target_values_d; + c2h::host_vector values_h = values_d; + + c2h::host_vector result_h = result_d; + + c2h::host_vector offsets_ref(result_h.size(), thrust::default_init); + c2h::host_vector offsets_h(result_h.size(), thrust::default_init); + + for (auto i = 0u; i < target_values_h.size(); ++i) + { + offsets_ref[i] = + host_variant(values_h.data(), values_h.data() + num_items, target_values_h[i], compare_op) - values_h.data(); + offsets_h[i] = result_h[i] - thrust::raw_pointer_cast(values_d.data()); + } + + CHECK(offsets_ref == offsets_h); +} + +C2H_TEST("DeviceFind::LowerBound works", "[find][device][binary-search]", types) +{ + using value_type = c2h::get<0, TestType>; + test_vectorized(lower_bound, std_lower_bound); +} + +C2H_TEST("DeviceFind::UpperBound works", "[find][device][binary-search]", types) +{ + using value_type = c2h::get<0, TestType>; + test_vectorized(upper_bound, std_upper_bound); +} + +// this test exceeds 4GiB of memory and the range of 32-bit integers +C2H_TEST("DeviceFind::LowerBound really large input", + "[find][device][binary-search][skip-cs-rangecheck][skip-cs-initcheck][skip-cs-synccheck]") +{ + try + { + using value_type = char; + const auto size = std::int64_t{1} << GENERATE(30, 31, 32, 33); + test_vectorized(lower_bound, std_lower_bound, size); + } + catch (const std::bad_alloc&) + { + // allocation failure is not a test failure, so we can run tests on smaller GPUs + } +} + +// this test exceeds 4GiB of memory and the range of 32-bit integers +C2H_TEST("DeviceFind::UpperBound really large input", + "[find][device][binary-search][skip-cs-rangecheck][skip-cs-initcheck][skip-cs-synccheck]") +{ + try + { + using value_type = char; + const auto size = std::int64_t{1} << GENERATE(30, 31, 32, 33); + test_vectorized(upper_bound, std_upper_bound, size); + } + catch (const std::bad_alloc&) + { + // allocation failure is not a test failure, so we can run tests on smaller GPUs + } +} diff --git a/cub/test/catch2_test_device_find_if.cu b/cub/test/catch2_test_device_find_if.cu index 35ee8c2b977..5dadefd0239 100644 --- a/cub/test/catch2_test_device_find_if.cu +++ b/cub/test/catch2_test_device_find_if.cu @@ -3,7 +3,7 @@ #include "insert_nested_NVTX_range_guard.h" -#include +#include #include diff --git a/cub/test/catch2_test_device_find_if_api.cu b/cub/test/catch2_test_device_find_if_api.cu index 774b504f26b..32e0c0a8041 100644 --- a/cub/test/catch2_test_device_find_if_api.cu +++ b/cub/test/catch2_test_device_find_if_api.cu @@ -3,7 +3,7 @@ #include "insert_nested_NVTX_range_guard.h" -#include +#include #include diff --git a/docs/cub/api_docs/device_wide.rst b/docs/cub/api_docs/device_wide.rst index c63f597cae8..fa126572be5 100644 --- a/docs/cub/api_docs/device_wide.rst +++ b/docs/cub/api_docs/device_wide.rst @@ -101,3 +101,4 @@ CUB device-level segmented-problem (batched) parallel algorithms: * :cpp:struct:`cub::DeviceSegmentedReduce` computes reductions across multiple sequences of data residing within device-accessible memory * :cpp:struct:`cub::DeviceCopy` provides device-wide, parallel operations for batched copying of data residing within device-accessible memory * :cpp:struct:`cub::DeviceMemcpy` provides device-wide, parallel operations for batched copying of data residing within device-accessible memory +* :cpp:struct:`cub::DeviceFind` provides vectorized binary search algorithms diff --git a/thrust/thrust/system/cuda/detail/find.h b/thrust/thrust/system/cuda/detail/find.h index c1d5d6c384a..88dfc57d2d5 100644 --- a/thrust/thrust/system/cuda/detail/find.h +++ b/thrust/thrust/system/cuda/detail/find.h @@ -39,7 +39,7 @@ #if _CCCL_CUDA_COMPILATION() # include -# include +# include # include # include diff --git a/thrust/thrust/system/detail/generic/binary_search.inl b/thrust/thrust/system/detail/generic/binary_search.inl index df371db1eec..b7d1676b797 100644 --- a/thrust/thrust/system/detail/generic/binary_search.inl +++ b/thrust/thrust/system/detail/generic/binary_search.inl @@ -194,7 +194,6 @@ template _CCCL_HOST_DEVICE ForwardIterator lower_bound(thrust::execution_policy& exec, ForwardIterator begin, ForwardIterator end, const T& value) { - namespace p = thrust::placeholders; return thrust::lower_bound(exec, begin, end, value, ::cuda::std::less<>{}); } @@ -215,7 +214,6 @@ template _CCCL_HOST_DEVICE ForwardIterator upper_bound(thrust::execution_policy& exec, ForwardIterator begin, ForwardIterator end, const T& value) { - namespace p = thrust::placeholders; return thrust::upper_bound(exec, begin, end, value, ::cuda::std::less<>{}); } @@ -263,7 +261,6 @@ _CCCL_HOST_DEVICE OutputIterator lower_bound( InputIterator values_end, OutputIterator output) { - namespace p = thrust::placeholders; return thrust::lower_bound(exec, begin, end, values_begin, values_end, output, ::cuda::std::less<>{}); } @@ -293,7 +290,6 @@ _CCCL_HOST_DEVICE OutputIterator upper_bound( InputIterator values_end, OutputIterator output) { - namespace p = thrust::placeholders; return thrust::upper_bound(exec, begin, end, values_begin, values_end, output, ::cuda::std::less<>{}); } @@ -323,7 +319,6 @@ _CCCL_HOST_DEVICE OutputIterator binary_search( InputIterator values_end, OutputIterator output) { - namespace p = thrust::placeholders; return thrust::binary_search(exec, begin, end, values_begin, values_end, output, ::cuda::std::less<>{}); }