Skip to content

Commit

Permalink
Various fixes to cub::DeviceTransform (#2709)
Browse files Browse the repository at this point in the history
* Workaround non-copyable iterators
* Use a named constant for SMEM
* Cast to raw reference 2
* Fix passing non-copy-assignable iterators to transform_kernel via kernel_arg
  • Loading branch information
bernhardmgruber authored Nov 6, 2024
1 parent 2a7889b commit c358bde
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
30 changes: 24 additions & 6 deletions cub/cub/device/dispatch/dispatch_transform.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ _CCCL_DEVICE void transform_kernel_impl(

{
// TODO(bgruber): replace by fold over comma in C++17
int dummy[] = {(prefetch_tile<block_dim>(ins, tile_size), 0)..., 0}; // extra zero to handle empty packs
// extra zero at the end handles empty packs
int dummy[] = {(prefetch_tile<block_dim>(THRUST_NS_QUALIFIER::raw_reference_cast(ins), tile_size), 0)..., 0};
(void) &dummy; // nvcc 11.1 needs extra strong unused warning suppression
}

Expand Down Expand Up @@ -494,17 +495,34 @@ _CCCL_DEVICE void transform_kernel_impl(
template <typename It>
union kernel_arg
{
aligned_base_ptr<value_t<It>> aligned_ptr;
It iterator;
aligned_base_ptr<value_t<It>> aligned_ptr; // first member is trivial
It iterator; // may not be trivially [default|copy]-constructible

_CCCL_HOST_DEVICE kernel_arg() {} // in case It is not default-constructible
static_assert(::cuda::std::is_trivial<decltype(aligned_ptr)>::value, "");

// Sometimes It is not trivially [default|copy]-constructible (e.g.
// thrust::normal_iterator<thrust::device_pointer<T>>), so because of
// https://eel.is/c++draft/class.union#general-note-3, kernel_args's special members are deleted. We work around it by
// explicitly defining them.

_CCCL_HOST_DEVICE kernel_arg() {}

_CCCL_HOST_DEVICE kernel_arg(const kernel_arg& other)
{
// since we use kernel_arg only to pass data to the device, the contained data is semantically trivially copyable,
// even if the type system is telling us otherwise.
::cuda::std::memcpy(reinterpret_cast<char*>(this), reinterpret_cast<const char*>(&other), sizeof(kernel_arg));
}
};

template <typename It>
_CCCL_HOST_DEVICE auto make_iterator_kernel_arg(It it) -> kernel_arg<It>
{
kernel_arg<It> arg;
arg.iterator = it;
// since we switch the active member of the union, we must use placement new or construct_at. This also uses the copy
// constructor of It, which works in more cases than assignment (e.g. thrust::transform_iterator with
// non-copy-assignable functor, e.g. in merge sort tests)
::cuda::std::__construct_at(&arg.iterator, it);
return arg;
}

Expand Down Expand Up @@ -620,7 +638,7 @@ struct policy_hub<RequiresStableAddress, ::cuda::std::tuple<RandomAccessIterator
static constexpr bool exhaust_smem =
bulk_copy_smem_for_tile_size<RandomAccessIteratorsIn...>(
async_policy::block_threads * async_policy::min_items_per_thread)
> 48 * 1024;
> int{max_smem_per_block};
static constexpr bool any_type_is_overalinged =
# if _CCCL_STD_VER >= 2017
((alignof(value_t<RandomAccessIteratorsIn>) > bulk_copy_alignment) || ...);
Expand Down
9 changes: 9 additions & 0 deletions cub/test/catch2_test_device_transform.cu
Original file line number Diff line number Diff line change
Expand Up @@ -554,3 +554,12 @@ C2H_TEST("DeviceTransform::Transform aligned_base_ptr", "[device][device_transfo
CHECK(make_aligned_base_ptr(&arr[128], 128) == aligned_base_ptr<int>{reinterpret_cast<char*>(&arr[128]), 0});
CHECK(make_aligned_base_ptr(&arr[129], 128) == aligned_base_ptr<int>{reinterpret_cast<char*>(&arr[128]), 4});
}

C2H_TEST("DeviceTransform::Transform aligned_base_ptr", "[device][device_transform]")
{
using It = thrust::reverse_iterator<thrust::detail::normal_iterator<thrust::device_ptr<int>>>;
using kernel_arg = cub::detail::transform::kernel_arg<It>;

STATIC_REQUIRE(::cuda::std::is_constructible<kernel_arg>::value);
STATIC_REQUIRE(::cuda::std::is_copy_constructible<kernel_arg>::value);
}

0 comments on commit c358bde

Please sign in to comment.