diff --git a/c/parallel/src/binary_search.cu b/c/parallel/src/binary_search.cu index 37728a40d2a..9b89d54a998 100644 --- a/c/parallel/src/binary_search.cu +++ b/c/parallel/src/binary_search.cu @@ -156,22 +156,7 @@ CUB_DETAIL_KERNEL_ATTRIBUTES __launch_bounds__(device_for_policy::ActivePolicy::for_policy_t::block_threads) void binary_search_kernel({1} d_data, OffsetT num_data, {3} d_values, OffsetT num_values, {5} d_out, {7} op) {{ - auto d_out_typed = [&] {{ - constexpr auto out_is_ptr = cuda::std::is_pointer_v; - constexpr auto out_matches_items = cuda::std::is_same_v; - constexpr auto need_cast = out_is_ptr && !out_matches_items; - - if constexpr (need_cast) {{ - static_assert(sizeof(decltype(*d_out)) == sizeof(decltype(d_data)), ""); - static_assert(alignof(decltype(*d_out)) == alignof(decltype(d_data)), ""); - return reinterpret_cast<{1} *>(d_out); - }} - else {{ - return d_out; - }} - }}(); - - auto input_it = cuda::make_zip_iterator(d_values, d_out_typed); + auto input_it = cuda::make_zip_iterator(d_values, d_out); auto comp_wrapper = cub::detail::find::make_comp_wrapper<{8}>(d_data, d_data + num_data, op); auto agent_op = [&comp_wrapper, &input_it](OffsetT index) {{ comp_wrapper(input_it[index]); diff --git a/c/parallel/test/test_binary_search.cpp b/c/parallel/test/test_binary_search.cpp index 4d05c4eb9a9..399128e3f03 100644 --- a/c/parallel/test/test_binary_search.cpp +++ b/c/parallel/test/test_binary_search.cpp @@ -151,31 +151,29 @@ void test_vectorized(Variant variant, HostVariant host_variant) std::vector data = generate(num_items); std::copy(target_values.begin(), target_values.end(), data.begin()); std::sort(data.begin(), data.end()); - const std::vector output(target_values.size(), nullptr); + const std::vector output(target_values.size(), 0); pointer_t target_values_ptr(target_values); pointer_t data_ptr(data); - pointer_t output_ptr(output); + pointer_t output_ptr(output); auto& build_cache = get_cache(); const auto& test_key = make_binary_search_key(true, Variant::mode); variant(data_ptr, num_items, target_values_ptr, target_values.size(), output_ptr, op, build_cache, test_key); - std::vector results(output_ptr); - std::vector expected(target_values.size(), nullptr); + std::vector results(output_ptr); + std::vector expected(target_values.size(), 0); - std::vector offsets(target_values.size(), 0); - std::vector expected_offsets(target_values.size(), 0); + std::vector expected_results(target_values.size(), 0); for (auto i = 0u; i < target_values.size(); ++i) { - offsets[i] = results[i] - data_ptr.ptr; - expected_offsets[i] = + expected_results[i] = host_variant(data.data(), data.data() + num_items, target_values[i], std::less<>()) - data.data(); } - CHECK(expected_offsets == offsets); + CHECK(expected_results == results); } struct BinarySearch_IntegralTypes_LowerBound_Fixture_Tag; diff --git a/cub/cub/detail/binary_search_helpers.cuh b/cub/cub/detail/binary_search_helpers.cuh index 0d60b3ae52d..57bb04c28a5 100644 --- a/cub/cub/detail/binary_search_helpers.cuh +++ b/cub/cub/detail/binary_search_helpers.cuh @@ -15,6 +15,7 @@ #include #include +#include #include CUB_NAMESPACE_BEGIN @@ -44,20 +45,20 @@ _CCCL_HOST_DEVICE auto make_comp_wrapper(RangeIteratorT first, RangeIteratorT la struct lower_bound { template - _CCCL_DEVICE _CCCL_FORCEINLINE static RangeIteratorT + _CCCL_DEVICE _CCCL_FORCEINLINE static ::cuda::std::ptrdiff_t Invoke(RangeIteratorT first, RangeIteratorT last, const T& value, CompareOpT comp) { - return ::cuda::std::lower_bound(first, last, value, comp); + return ::cuda::std::lower_bound(first, last, value, comp) - first; } }; struct upper_bound { template - _CCCL_DEVICE _CCCL_FORCEINLINE static RangeIteratorT + _CCCL_DEVICE _CCCL_FORCEINLINE static ::cuda::std::ptrdiff_t Invoke(RangeIteratorT first, RangeIteratorT last, const T& value, CompareOpT comp) { - return ::cuda::std::upper_bound(first, last, value, comp); + return ::cuda::std::upper_bound(first, last, value, comp) - first; } }; } // namespace detail::find diff --git a/cub/cub/device/device_find.cuh b/cub/cub/device/device_find.cuh index 74d1c0cc727..aefdd858645 100644 --- a/cub/cub/device/device_find.cuh +++ b/cub/cub/device/device_find.cuh @@ -126,7 +126,8 @@ struct DeviceFind //! ``RangeIteratorT`` using ``CompareOpT`` as the predicate. //! //! @tparam OutputIteratorT - //! is a model of [Random Access Iterator], whose value type is assignable from ``RangeIteratorT``. + //! is a model of [Random Access Iterator], whose value type is assignable from ``RangeIteratorT``'s difference + //! type. //! //! @tparam CompareOpT //! is a model of [Strict Weak Ordering], which forms a [Relation] with the value types of ``RangeIteratorT`` @@ -209,7 +210,8 @@ struct DeviceFind //! ``RangeIteratorT`` using ``CompareOpT`` as the predicate. //! //! @tparam OutputIteratorT - //! is a model of [Random Access Iterator], whose value type is assignable from ``RangeIteratorT``. + //! is a model of [Random Access Iterator], whose value type is assignable from ``RangeIteratorT``'s difference + //! type. //! //! @tparam CompareOpT //! is a model of [Strict Weak Ordering], which forms a [Relation] with the value types of ``RangeIteratorT`` diff --git a/cub/test/catch2_test_device_binary_search.cu b/cub/test/catch2_test_device_binary_search.cu index 3f2a3de7416..c82f8afb8f1 100644 --- a/cub/test/catch2_test_device_binary_search.cu +++ b/cub/test/catch2_test_device_binary_search.cu @@ -45,28 +45,26 @@ void test_vectorized(Variant variant, HostVariant host_variant, std::size_t num_ thrust::copy(c2h::device_policy, target_values_d.begin(), target_values_d.end(), values_d.begin()); thrust::sort(c2h::device_policy, values_d.begin(), values_d.end(), compare_op); - using Result = Value*; - c2h::device_vector result_d(target_values_d.size(), thrust::default_init); + using Result = std::ptrdiff_t; + c2h::device_vector offsets_d(target_values_d.size(), thrust::default_init); variant(thrust::raw_pointer_cast(values_d.data()), thrust::raw_pointer_cast(values_d.data() + num_items), thrust::raw_pointer_cast(target_values_d.data()), thrust::raw_pointer_cast(target_values_d.data() + target_values_d.size()), - thrust::raw_pointer_cast(result_d.data()), + thrust::raw_pointer_cast(offsets_d.data()), compare_op); c2h::host_vector target_values_h = target_values_d; c2h::host_vector values_h = values_d; - c2h::host_vector result_h = result_d; + c2h::host_vector offsets_h = offsets_d; - c2h::host_vector offsets_ref(result_h.size(), thrust::default_init); - c2h::host_vector offsets_h(result_h.size(), thrust::default_init); + c2h::host_vector offsets_ref(offsets_h.size(), thrust::default_init); for (auto i = 0u; i < target_values_h.size(); ++i) { offsets_ref[i] = host_variant(values_h.data(), values_h.data() + num_items, target_values_h[i], compare_op) - values_h.data(); - offsets_h[i] = result_h[i] - thrust::raw_pointer_cast(values_d.data()); } CHECK(offsets_ref == offsets_h);