Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 1 addition & 16 deletions c/parallel/src/binary_search.cu
Original file line number Diff line number Diff line change
Expand Up @@ -156,22 +156,7 @@ CUB_DETAIL_KERNEL_ATTRIBUTES
__launch_bounds__(device_for_policy::ActivePolicy::for_policy_t::block_threads)
void binary_search_kernel({1} d_data, OffsetT num_data, {3} d_values, OffsetT num_values, {5} d_out, {7} op)
{{
auto d_out_typed = [&] {{
constexpr auto out_is_ptr = cuda::std::is_pointer_v<decltype(d_out)>;
constexpr auto out_matches_items = cuda::std::is_same_v<decltype(*d_out), decltype(d_data)>;
constexpr auto need_cast = out_is_ptr && !out_matches_items;

if constexpr (need_cast) {{
static_assert(sizeof(decltype(*d_out)) == sizeof(decltype(d_data)), "");
static_assert(alignof(decltype(*d_out)) == alignof(decltype(d_data)), "");
return reinterpret_cast<{1} *>(d_out);
}}
else {{
return d_out;
}}
}}();

auto input_it = cuda::make_zip_iterator(d_values, d_out_typed);
auto input_it = cuda::make_zip_iterator(d_values, d_out);
auto comp_wrapper = cub::detail::find::make_comp_wrapper<{8}>(d_data, d_data + num_data, op);
auto agent_op = [&comp_wrapper, &input_it](OffsetT index) {{
comp_wrapper(input_it[index]);
Expand Down
16 changes: 7 additions & 9 deletions c/parallel/test/test_binary_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,31 +151,29 @@ void test_vectorized(Variant variant, HostVariant host_variant)
std::vector<Value> data = generate<Value>(num_items);
std::copy(target_values.begin(), target_values.end(), data.begin());
std::sort(data.begin(), data.end());
const std::vector<Value*> output(target_values.size(), nullptr);
const std::vector<std::ptrdiff_t> output(target_values.size(), 0);

pointer_t<Value> target_values_ptr(target_values);
pointer_t<Value> data_ptr(data);
pointer_t<Value*> output_ptr(output);
pointer_t<std::ptrdiff_t> output_ptr(output);

auto& build_cache = get_cache<Fixture>();
const auto& test_key = make_binary_search_key<Value>(true, Variant::mode);

variant(data_ptr, num_items, target_values_ptr, target_values.size(), output_ptr, op, build_cache, test_key);

std::vector<Value*> results(output_ptr);
std::vector<Value*> expected(target_values.size(), nullptr);
std::vector<std::ptrdiff_t> results(output_ptr);
std::vector<std::ptrdiff_t> expected(target_values.size(), 0);

std::vector<std::ptrdiff_t> offsets(target_values.size(), 0);
std::vector<std::ptrdiff_t> expected_offsets(target_values.size(), 0);
std::vector<std::ptrdiff_t> expected_results(target_values.size(), 0);

for (auto i = 0u; i < target_values.size(); ++i)
{
offsets[i] = results[i] - data_ptr.ptr;
expected_offsets[i] =
expected_results[i] =
host_variant(data.data(), data.data() + num_items, target_values[i], std::less<>()) - data.data();
}

CHECK(expected_offsets == offsets);
CHECK(expected_results == results);
}

struct BinarySearch_IntegralTypes_LowerBound_Fixture_Tag;
Expand Down
9 changes: 5 additions & 4 deletions cub/cub/detail/binary_search_helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include <cuda/std/__algorithm/lower_bound.h>
#include <cuda/std/__algorithm/upper_bound.h>
#include <cuda/std/cstddef>
#include <cuda/std/tuple>

CUB_NAMESPACE_BEGIN
Expand Down Expand Up @@ -44,20 +45,20 @@ _CCCL_HOST_DEVICE auto make_comp_wrapper(RangeIteratorT first, RangeIteratorT la
struct lower_bound
{
template <typename RangeIteratorT, typename T, typename CompareOpT>
_CCCL_DEVICE _CCCL_FORCEINLINE static RangeIteratorT
_CCCL_DEVICE _CCCL_FORCEINLINE static ::cuda::std::ptrdiff_t
Invoke(RangeIteratorT first, RangeIteratorT last, const T& value, CompareOpT comp)
{
return ::cuda::std::lower_bound(first, last, value, comp);
return ::cuda::std::lower_bound(first, last, value, comp) - first;
}
};

struct upper_bound
{
template <typename RangeIteratorT, typename T, typename CompareOpT>
_CCCL_DEVICE _CCCL_FORCEINLINE static RangeIteratorT
_CCCL_DEVICE _CCCL_FORCEINLINE static ::cuda::std::ptrdiff_t
Invoke(RangeIteratorT first, RangeIteratorT last, const T& value, CompareOpT comp)
{
return ::cuda::std::upper_bound(first, last, value, comp);
return ::cuda::std::upper_bound(first, last, value, comp) - first;
}
};
} // namespace detail::find
Expand Down
6 changes: 4 additions & 2 deletions cub/cub/device/device_find.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ struct DeviceFind
//! ``RangeIteratorT`` using ``CompareOpT`` as the predicate.
//!
//! @tparam OutputIteratorT
//! is a model of [Random Access Iterator], whose value type is assignable from ``RangeIteratorT``.
//! is a model of [Random Access Iterator], whose value type is assignable from ``RangeIteratorT``'s difference
//! type.
//!
//! @tparam CompareOpT
//! is a model of [Strict Weak Ordering], which forms a [Relation] with the value types of ``RangeIteratorT``
Expand Down Expand Up @@ -209,7 +210,8 @@ struct DeviceFind
//! ``RangeIteratorT`` using ``CompareOpT`` as the predicate.
//!
//! @tparam OutputIteratorT
//! is a model of [Random Access Iterator], whose value type is assignable from ``RangeIteratorT``.
//! is a model of [Random Access Iterator], whose value type is assignable from ``RangeIteratorT``'s difference
//! type.
//!
//! @tparam CompareOpT
//! is a model of [Strict Weak Ordering], which forms a [Relation] with the value types of ``RangeIteratorT``
Expand Down
12 changes: 5 additions & 7 deletions cub/test/catch2_test_device_binary_search.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,26 @@ void test_vectorized(Variant variant, HostVariant host_variant, std::size_t num_
thrust::copy(c2h::device_policy, target_values_d.begin(), target_values_d.end(), values_d.begin());
thrust::sort(c2h::device_policy, values_d.begin(), values_d.end(), compare_op);

using Result = Value*;
c2h::device_vector<Result> result_d(target_values_d.size(), thrust::default_init);
using Result = std::ptrdiff_t;
c2h::device_vector<Result> offsets_d(target_values_d.size(), thrust::default_init);
variant(thrust::raw_pointer_cast(values_d.data()),
thrust::raw_pointer_cast(values_d.data() + num_items),
thrust::raw_pointer_cast(target_values_d.data()),
thrust::raw_pointer_cast(target_values_d.data() + target_values_d.size()),
thrust::raw_pointer_cast(result_d.data()),
thrust::raw_pointer_cast(offsets_d.data()),
compare_op);

c2h::host_vector<Value> target_values_h = target_values_d;
c2h::host_vector<Value> values_h = values_d;

c2h::host_vector<Result> result_h = result_d;
c2h::host_vector<Result> offsets_h = offsets_d;

c2h::host_vector<std::ptrdiff_t> offsets_ref(result_h.size(), thrust::default_init);
c2h::host_vector<std::ptrdiff_t> offsets_h(result_h.size(), thrust::default_init);
c2h::host_vector<std::ptrdiff_t> offsets_ref(offsets_h.size(), thrust::default_init);

for (auto i = 0u; i < target_values_h.size(); ++i)
{
offsets_ref[i] =
host_variant(values_h.data(), values_h.data() + num_items, target_values_h[i], compare_op) - values_h.data();
offsets_h[i] = result_h[i] - thrust::raw_pointer_cast(values_d.data());
}

CHECK(offsets_ref == offsets_h);
Expand Down
Loading