Skip to content

Commit

Permalink
Expose thrust's contiguous iterator unwrap helpers
Browse files Browse the repository at this point in the history
Lifts the following functions into the public API and renames them:
* contiguous_iterator_raw_pointer_t -> unwrap_contiguous_iterator_t
* contiguous_iterator_raw_pointer_cast -> unwrap_contiguous_iterator
* try_unwrap_contiguous_iterator_return_t -> try_unwrap_contiguous_iterator_t
* try_unwrap_contiguous_iterator

Fixes: NVIDIA#1711

Co-authored-by: Michael Schellenberger Costa <[email protected]>
  • Loading branch information
bernhardmgruber and miscco committed May 7, 2024
1 parent f53f8dd commit c9b9ca9
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 35 deletions.
4 changes: 2 additions & 2 deletions thrust/testing/is_contiguous_iterator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ struct expect_passthrough
template <typename IteratorT, typename PointerT, typename expected_unwrapped_type /* = expect_[pointer|passthrough] */>
struct check_unwrapped_iterator
{
using unwrapped_t = typename std::remove_reference<decltype(thrust::detail::try_unwrap_contiguous_iterator(
std::declval<IteratorT>()))>::type;
using unwrapped_t = ::cuda::std::__libcpp_remove_reference_t<decltype(thrust::try_unwrap_contiguous_iterator(
cuda::std::declval<IteratorT>()))>;

static constexpr bool value =
std::is_same<expected_unwrapped_type, expect_pointer>::value
Expand Down
8 changes: 4 additions & 4 deletions thrust/thrust/system/cuda/detail/adjacent_difference.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,17 @@ adjacent_difference(execution_policy<Derived>& policy, InputIt first, InputIt la
std::size_t storage_size = 0;
cudaStream_t stream = cuda_cub::stream(policy);

using UnwrapInputIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<InputIt>;
using UnwrapOutputIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<OutputIt>;
using UnwrapInputIt = thrust::try_unwrap_contiguous_iterator_t<InputIt>;
using UnwrapOutputIt = thrust::try_unwrap_contiguous_iterator_t<OutputIt>;

using InputValueT = thrust::iterator_value_t<UnwrapInputIt>;
using OutputValueT = thrust::iterator_value_t<UnwrapOutputIt>;

constexpr bool can_compare_iterators = std::is_pointer<UnwrapInputIt>::value && std::is_pointer<UnwrapOutputIt>::value
&& std::is_same<InputValueT, OutputValueT>::value;

auto first_unwrap = thrust::detail::try_unwrap_contiguous_iterator(first);
auto result_unwrap = thrust::detail::try_unwrap_contiguous_iterator(result);
auto first_unwrap = thrust::try_unwrap_contiguous_iterator(first);
auto result_unwrap = thrust::try_unwrap_contiguous_iterator(result);

thrust::detail::integral_constant<bool, can_compare_iterators> comparable;

Expand Down
24 changes: 12 additions & 12 deletions thrust/thrust/system/cuda/detail/scan_by_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ _CCCL_HOST_DEVICE ValuesOutIt inclusive_scan_by_key_n(
}

// Convert to raw pointers if possible:
using KeysInUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<KeysInIt>;
using ValuesInUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesInIt>;
using ValuesOutUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesOutIt>;
using KeysInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<KeysInIt>;
using ValuesInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<ValuesInIt>;
using ValuesOutUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<ValuesOutIt>;
using AccumT = typename thrust::iterator_traits<ValuesInUnwrapIt>::value_type;

auto keys_unwrap = thrust::detail::try_unwrap_contiguous_iterator(keys);
auto values_unwrap = thrust::detail::try_unwrap_contiguous_iterator(values);
auto result_unwrap = thrust::detail::try_unwrap_contiguous_iterator(result);
auto keys_unwrap = thrust::try_unwrap_contiguous_iterator(keys);
auto values_unwrap = thrust::try_unwrap_contiguous_iterator(values);
auto result_unwrap = thrust::try_unwrap_contiguous_iterator(result);

using Dispatch32 = cub::DispatchScanByKey<
KeysInUnwrapIt,
Expand Down Expand Up @@ -195,13 +195,13 @@ _CCCL_HOST_DEVICE ValuesOutIt exclusive_scan_by_key_n(
}

// Convert to raw pointers if possible:
using KeysInUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<KeysInIt>;
using ValuesInUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesInIt>;
using ValuesOutUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesOutIt>;
using KeysInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<KeysInIt>;
using ValuesInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<ValuesInIt>;
using ValuesOutUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<ValuesOutIt>;

auto keys_unwrap = thrust::detail::try_unwrap_contiguous_iterator(keys);
auto values_unwrap = thrust::detail::try_unwrap_contiguous_iterator(values);
auto result_unwrap = thrust::detail::try_unwrap_contiguous_iterator(result);
auto keys_unwrap = thrust::try_unwrap_contiguous_iterator(keys);
auto values_unwrap = thrust::try_unwrap_contiguous_iterator(values);
auto result_unwrap = thrust::try_unwrap_contiguous_iterator(result);

using Dispatch32 = cub::DispatchScanByKey<
KeysInUnwrapIt,
Expand Down
39 changes: 22 additions & 17 deletions thrust/thrust/type_traits/is_contiguous_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,20 +221,24 @@ struct contiguous_iterator_traits

using raw_pointer = typename thrust::detail::pointer_traits<decltype(&*std::declval<Iterator>())>::raw_pointer;
};
} // namespace detail

template <typename Iterator>
using contiguous_iterator_raw_pointer_t = typename contiguous_iterator_traits<Iterator>::raw_pointer;
//! Converts a contiguous iterator type to its underlying raw pointer type.
template <typename ContiguousIterator>
using unwrap_contiguous_iterator_t = typename detail::contiguous_iterator_traits<ContiguousIterator>::raw_pointer;

// Converts a contiguous iterator to a raw pointer:
template <typename Iterator>
_CCCL_HOST_DEVICE contiguous_iterator_raw_pointer_t<Iterator> contiguous_iterator_raw_pointer_cast(Iterator it)
//! Converts a contiguous iterator to its underlying raw pointer.
template <typename ContiguousIterator>
_CCCL_HOST_DEVICE auto unwrap_contiguous_iterator(ContiguousIterator it)
-> unwrap_contiguous_iterator_t<ContiguousIterator>
{
static_assert(thrust::is_contiguous_iterator<Iterator>::value,
"contiguous_iterator_raw_pointer_cast called with "
"non-contiguous iterator.");
static_assert(thrust::is_contiguous_iterator<ContiguousIterator>::value,
"unwrap_contiguous_iterator called with non-contiguous iterator.");
return thrust::raw_pointer_cast(&*it);
}

namespace detail
{
// Implementation for non-contiguous iterators -- passthrough.
template <typename Iterator, bool IsContiguous = thrust::is_contiguous_iterator<Iterator>::value>
struct try_unwrap_contiguous_iterator_impl
Expand All @@ -251,27 +255,28 @@ struct try_unwrap_contiguous_iterator_impl
template <typename Iterator>
struct try_unwrap_contiguous_iterator_impl<Iterator, true /*is_contiguous*/>
{
using type = contiguous_iterator_raw_pointer_t<Iterator>;
using type = unwrap_contiguous_iterator_t<Iterator>;

static _CCCL_HOST_DEVICE type get(Iterator it)
{
return contiguous_iterator_raw_pointer_cast(it);
return unwrap_contiguous_iterator(it);
}
};
} // namespace detail

//! Takes an iterator type and, if it is contiguous, yields the raw pointer type it represents. Otherwise returns the
//! iterator type unmodified.
template <typename Iterator>
using try_unwrap_contiguous_iterator_return_t = typename try_unwrap_contiguous_iterator_impl<Iterator>::type;
using try_unwrap_contiguous_iterator_t = typename detail::try_unwrap_contiguous_iterator_impl<Iterator>::type;

// Casts to a raw pointer if iterator is marked as contiguous, otherwise returns
// the input iterator.
//! Takes an iterator and, if it is contiguous, unwraps it to the raw pointer it represents. Otherwise returns the
//! iterator unmodified.
template <typename Iterator>
_CCCL_HOST_DEVICE try_unwrap_contiguous_iterator_return_t<Iterator> try_unwrap_contiguous_iterator(Iterator it)
_CCCL_HOST_DEVICE auto try_unwrap_contiguous_iterator(Iterator it) -> try_unwrap_contiguous_iterator_t<Iterator>
{
return try_unwrap_contiguous_iterator_impl<Iterator>::get(it);
return detail::try_unwrap_contiguous_iterator_impl<Iterator>::get(it);
}

} // namespace detail

/*! \endcond
*/

Expand Down

0 comments on commit c9b9ca9

Please sign in to comment.