From a1d8b31106c4f2afd4b0c98d640b5f725c10972f Mon Sep 17 00:00:00 2001 From: Bernhard Manfred Gruber Date: Tue, 7 May 2024 21:52:34 +0200 Subject: [PATCH] Expose thrust's contiguous iterator unwrap helpers (#1717) Lifts the following functions into the public API and renames them: * contiguous_iterator_raw_pointer_t -> unwrap_contiguous_iterator_t * contiguous_iterator_raw_pointer_cast -> unwrap_contiguous_iterator * try_unwrap_contiguous_iterator_return_t -> try_unwrap_contiguous_iterator_t * try_unwrap_contiguous_iterator Fixes: #1711 Co-authored-by: Michael Schellenberger Costa --- thrust/testing/is_contiguous_iterator.cu | 4 +- .../system/cuda/detail/adjacent_difference.h | 8 ++-- .../thrust/system/cuda/detail/scan_by_key.h | 24 ++++++------ .../type_traits/is_contiguous_iterator.h | 39 +++++++++++-------- 4 files changed, 40 insertions(+), 35 deletions(-) diff --git a/thrust/testing/is_contiguous_iterator.cu b/thrust/testing/is_contiguous_iterator.cu index f84bc3dac2..3b9ee02c45 100644 --- a/thrust/testing/is_contiguous_iterator.cu +++ b/thrust/testing/is_contiguous_iterator.cu @@ -89,8 +89,8 @@ struct expect_passthrough template struct check_unwrapped_iterator { - using unwrapped_t = typename std::remove_reference()))>::type; + using unwrapped_t = ::cuda::std::__libcpp_remove_reference_t()))>; static constexpr bool value = std::is_same::value diff --git a/thrust/thrust/system/cuda/detail/adjacent_difference.h b/thrust/thrust/system/cuda/detail/adjacent_difference.h index 2e3696a7db..275ee47ce6 100644 --- a/thrust/thrust/system/cuda/detail/adjacent_difference.h +++ b/thrust/thrust/system/cuda/detail/adjacent_difference.h @@ -151,8 +151,8 @@ adjacent_difference(execution_policy& policy, InputIt first, InputIt la std::size_t storage_size = 0; cudaStream_t stream = cuda_cub::stream(policy); - using UnwrapInputIt = thrust::detail::try_unwrap_contiguous_iterator_return_t; - using UnwrapOutputIt = thrust::detail::try_unwrap_contiguous_iterator_return_t; + using UnwrapInputIt = thrust::try_unwrap_contiguous_iterator_t; + using UnwrapOutputIt = thrust::try_unwrap_contiguous_iterator_t; using InputValueT = thrust::iterator_value_t; using OutputValueT = thrust::iterator_value_t; @@ -160,8 +160,8 @@ adjacent_difference(execution_policy& policy, InputIt first, InputIt la constexpr bool can_compare_iterators = std::is_pointer::value && std::is_pointer::value && std::is_same::value; - auto first_unwrap = thrust::detail::try_unwrap_contiguous_iterator(first); - auto result_unwrap = thrust::detail::try_unwrap_contiguous_iterator(result); + auto first_unwrap = thrust::try_unwrap_contiguous_iterator(first); + auto result_unwrap = thrust::try_unwrap_contiguous_iterator(result); thrust::detail::integral_constant comparable; diff --git a/thrust/thrust/system/cuda/detail/scan_by_key.h b/thrust/thrust/system/cuda/detail/scan_by_key.h index 581c6dc0ba..7f06f29d65 100644 --- a/thrust/thrust/system/cuda/detail/scan_by_key.h +++ b/thrust/thrust/system/cuda/detail/scan_by_key.h @@ -85,14 +85,14 @@ _CCCL_HOST_DEVICE ValuesOutIt inclusive_scan_by_key_n( } // Convert to raw pointers if possible: - using KeysInUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t; - using ValuesInUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t; - using ValuesOutUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t; + using KeysInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t; + using ValuesInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t; + using ValuesOutUnwrapIt = thrust::try_unwrap_contiguous_iterator_t; using AccumT = typename thrust::iterator_traits::value_type; - auto keys_unwrap = thrust::detail::try_unwrap_contiguous_iterator(keys); - auto values_unwrap = thrust::detail::try_unwrap_contiguous_iterator(values); - auto result_unwrap = thrust::detail::try_unwrap_contiguous_iterator(result); + auto keys_unwrap = thrust::try_unwrap_contiguous_iterator(keys); + auto values_unwrap = thrust::try_unwrap_contiguous_iterator(values); + auto result_unwrap = thrust::try_unwrap_contiguous_iterator(result); using Dispatch32 = cub::DispatchScanByKey< KeysInUnwrapIt, @@ -195,13 +195,13 @@ _CCCL_HOST_DEVICE ValuesOutIt exclusive_scan_by_key_n( } // Convert to raw pointers if possible: - using KeysInUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t; - using ValuesInUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t; - using ValuesOutUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t; + using KeysInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t; + using ValuesInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t; + using ValuesOutUnwrapIt = thrust::try_unwrap_contiguous_iterator_t; - auto keys_unwrap = thrust::detail::try_unwrap_contiguous_iterator(keys); - auto values_unwrap = thrust::detail::try_unwrap_contiguous_iterator(values); - auto result_unwrap = thrust::detail::try_unwrap_contiguous_iterator(result); + auto keys_unwrap = thrust::try_unwrap_contiguous_iterator(keys); + auto values_unwrap = thrust::try_unwrap_contiguous_iterator(values); + auto result_unwrap = thrust::try_unwrap_contiguous_iterator(result); using Dispatch32 = cub::DispatchScanByKey< KeysInUnwrapIt, diff --git a/thrust/thrust/type_traits/is_contiguous_iterator.h b/thrust/thrust/type_traits/is_contiguous_iterator.h index ece349773a..764c58b3d2 100644 --- a/thrust/thrust/type_traits/is_contiguous_iterator.h +++ b/thrust/thrust/type_traits/is_contiguous_iterator.h @@ -221,20 +221,24 @@ struct contiguous_iterator_traits using raw_pointer = typename thrust::detail::pointer_traits())>::raw_pointer; }; +} // namespace detail -template -using contiguous_iterator_raw_pointer_t = typename contiguous_iterator_traits::raw_pointer; +//! Converts a contiguous iterator type to its underlying raw pointer type. +template +using unwrap_contiguous_iterator_t = typename detail::contiguous_iterator_traits::raw_pointer; -// Converts a contiguous iterator to a raw pointer: -template -_CCCL_HOST_DEVICE contiguous_iterator_raw_pointer_t contiguous_iterator_raw_pointer_cast(Iterator it) +//! Converts a contiguous iterator to its underlying raw pointer. +template +_CCCL_HOST_DEVICE auto unwrap_contiguous_iterator(ContiguousIterator it) + -> unwrap_contiguous_iterator_t { - static_assert(thrust::is_contiguous_iterator::value, - "contiguous_iterator_raw_pointer_cast called with " - "non-contiguous iterator."); + static_assert(thrust::is_contiguous_iterator::value, + "unwrap_contiguous_iterator called with non-contiguous iterator."); return thrust::raw_pointer_cast(&*it); } +namespace detail +{ // Implementation for non-contiguous iterators -- passthrough. template ::value> struct try_unwrap_contiguous_iterator_impl @@ -251,27 +255,28 @@ struct try_unwrap_contiguous_iterator_impl template struct try_unwrap_contiguous_iterator_impl { - using type = contiguous_iterator_raw_pointer_t; + using type = unwrap_contiguous_iterator_t; static _CCCL_HOST_DEVICE type get(Iterator it) { - return contiguous_iterator_raw_pointer_cast(it); + return unwrap_contiguous_iterator(it); } }; +} // namespace detail +//! Takes an iterator type and, if it is contiguous, yields the raw pointer type it represents. Otherwise returns the +//! iterator type unmodified. template -using try_unwrap_contiguous_iterator_return_t = typename try_unwrap_contiguous_iterator_impl::type; +using try_unwrap_contiguous_iterator_t = typename detail::try_unwrap_contiguous_iterator_impl::type; -// Casts to a raw pointer if iterator is marked as contiguous, otherwise returns -// the input iterator. +//! Takes an iterator and, if it is contiguous, unwraps it to the raw pointer it represents. Otherwise returns the +//! iterator unmodified. template -_CCCL_HOST_DEVICE try_unwrap_contiguous_iterator_return_t try_unwrap_contiguous_iterator(Iterator it) +_CCCL_HOST_DEVICE auto try_unwrap_contiguous_iterator(Iterator it) -> try_unwrap_contiguous_iterator_t { - return try_unwrap_contiguous_iterator_impl::get(it); + return detail::try_unwrap_contiguous_iterator_impl::get(it); } -} // namespace detail - /*! \endcond */