Skip to content

Commit

Permalink
Expose thrust's contiguous iterator unwrap helpers (#1717)
Browse files Browse the repository at this point in the history
Lifts the following functions into the public API and renames them:
* contiguous_iterator_raw_pointer_t -> unwrap_contiguous_iterator_t
* contiguous_iterator_raw_pointer_cast -> unwrap_contiguous_iterator
* try_unwrap_contiguous_iterator_return_t -> try_unwrap_contiguous_iterator_t
* try_unwrap_contiguous_iterator

Fixes: #1711

Co-authored-by: Michael Schellenberger Costa <[email protected]>
  • Loading branch information
bernhardmgruber and miscco committed May 7, 2024
1 parent fb83b4a commit a1d8b31
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 35 deletions.
4 changes: 2 additions & 2 deletions thrust/testing/is_contiguous_iterator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ struct expect_passthrough
template <typename IteratorT, typename PointerT, typename expected_unwrapped_type /* = expect_[pointer|passthrough] */>
struct check_unwrapped_iterator
{
using unwrapped_t = typename std::remove_reference<decltype(thrust::detail::try_unwrap_contiguous_iterator(
std::declval<IteratorT>()))>::type;
using unwrapped_t = ::cuda::std::__libcpp_remove_reference_t<decltype(thrust::try_unwrap_contiguous_iterator(
cuda::std::declval<IteratorT>()))>;

static constexpr bool value =
std::is_same<expected_unwrapped_type, expect_pointer>::value
Expand Down
8 changes: 4 additions & 4 deletions thrust/thrust/system/cuda/detail/adjacent_difference.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,17 @@ adjacent_difference(execution_policy<Derived>& policy, InputIt first, InputIt la
std::size_t storage_size = 0;
cudaStream_t stream = cuda_cub::stream(policy);

using UnwrapInputIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<InputIt>;
using UnwrapOutputIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<OutputIt>;
using UnwrapInputIt = thrust::try_unwrap_contiguous_iterator_t<InputIt>;
using UnwrapOutputIt = thrust::try_unwrap_contiguous_iterator_t<OutputIt>;

using InputValueT = thrust::iterator_value_t<UnwrapInputIt>;
using OutputValueT = thrust::iterator_value_t<UnwrapOutputIt>;

constexpr bool can_compare_iterators = std::is_pointer<UnwrapInputIt>::value && std::is_pointer<UnwrapOutputIt>::value
&& std::is_same<InputValueT, OutputValueT>::value;

auto first_unwrap = thrust::detail::try_unwrap_contiguous_iterator(first);
auto result_unwrap = thrust::detail::try_unwrap_contiguous_iterator(result);
auto first_unwrap = thrust::try_unwrap_contiguous_iterator(first);
auto result_unwrap = thrust::try_unwrap_contiguous_iterator(result);

thrust::detail::integral_constant<bool, can_compare_iterators> comparable;

Expand Down
24 changes: 12 additions & 12 deletions thrust/thrust/system/cuda/detail/scan_by_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ _CCCL_HOST_DEVICE ValuesOutIt inclusive_scan_by_key_n(
}

// Convert to raw pointers if possible:
using KeysInUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<KeysInIt>;
using ValuesInUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesInIt>;
using ValuesOutUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesOutIt>;
using KeysInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<KeysInIt>;
using ValuesInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<ValuesInIt>;
using ValuesOutUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<ValuesOutIt>;
using AccumT = typename thrust::iterator_traits<ValuesInUnwrapIt>::value_type;

auto keys_unwrap = thrust::detail::try_unwrap_contiguous_iterator(keys);
auto values_unwrap = thrust::detail::try_unwrap_contiguous_iterator(values);
auto result_unwrap = thrust::detail::try_unwrap_contiguous_iterator(result);
auto keys_unwrap = thrust::try_unwrap_contiguous_iterator(keys);
auto values_unwrap = thrust::try_unwrap_contiguous_iterator(values);
auto result_unwrap = thrust::try_unwrap_contiguous_iterator(result);

using Dispatch32 = cub::DispatchScanByKey<
KeysInUnwrapIt,
Expand Down Expand Up @@ -195,13 +195,13 @@ _CCCL_HOST_DEVICE ValuesOutIt exclusive_scan_by_key_n(
}

// Convert to raw pointers if possible:
using KeysInUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<KeysInIt>;
using ValuesInUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesInIt>;
using ValuesOutUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesOutIt>;
using KeysInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<KeysInIt>;
using ValuesInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<ValuesInIt>;
using ValuesOutUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<ValuesOutIt>;

auto keys_unwrap = thrust::detail::try_unwrap_contiguous_iterator(keys);
auto values_unwrap = thrust::detail::try_unwrap_contiguous_iterator(values);
auto result_unwrap = thrust::detail::try_unwrap_contiguous_iterator(result);
auto keys_unwrap = thrust::try_unwrap_contiguous_iterator(keys);
auto values_unwrap = thrust::try_unwrap_contiguous_iterator(values);
auto result_unwrap = thrust::try_unwrap_contiguous_iterator(result);

using Dispatch32 = cub::DispatchScanByKey<
KeysInUnwrapIt,
Expand Down
39 changes: 22 additions & 17 deletions thrust/thrust/type_traits/is_contiguous_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,20 +221,24 @@ struct contiguous_iterator_traits

using raw_pointer = typename thrust::detail::pointer_traits<decltype(&*std::declval<Iterator>())>::raw_pointer;
};
} // namespace detail

template <typename Iterator>
using contiguous_iterator_raw_pointer_t = typename contiguous_iterator_traits<Iterator>::raw_pointer;
//! Converts a contiguous iterator type to its underlying raw pointer type.
template <typename ContiguousIterator>
using unwrap_contiguous_iterator_t = typename detail::contiguous_iterator_traits<ContiguousIterator>::raw_pointer;

// Converts a contiguous iterator to a raw pointer:
template <typename Iterator>
_CCCL_HOST_DEVICE contiguous_iterator_raw_pointer_t<Iterator> contiguous_iterator_raw_pointer_cast(Iterator it)
//! Converts a contiguous iterator to its underlying raw pointer.
template <typename ContiguousIterator>
_CCCL_HOST_DEVICE auto unwrap_contiguous_iterator(ContiguousIterator it)
-> unwrap_contiguous_iterator_t<ContiguousIterator>
{
static_assert(thrust::is_contiguous_iterator<Iterator>::value,
"contiguous_iterator_raw_pointer_cast called with "
"non-contiguous iterator.");
static_assert(thrust::is_contiguous_iterator<ContiguousIterator>::value,
"unwrap_contiguous_iterator called with non-contiguous iterator.");
return thrust::raw_pointer_cast(&*it);
}

namespace detail
{
// Implementation for non-contiguous iterators -- passthrough.
template <typename Iterator, bool IsContiguous = thrust::is_contiguous_iterator<Iterator>::value>
struct try_unwrap_contiguous_iterator_impl
Expand All @@ -251,27 +255,28 @@ struct try_unwrap_contiguous_iterator_impl
template <typename Iterator>
struct try_unwrap_contiguous_iterator_impl<Iterator, true /*is_contiguous*/>
{
using type = contiguous_iterator_raw_pointer_t<Iterator>;
using type = unwrap_contiguous_iterator_t<Iterator>;

static _CCCL_HOST_DEVICE type get(Iterator it)
{
return contiguous_iterator_raw_pointer_cast(it);
return unwrap_contiguous_iterator(it);
}
};
} // namespace detail

//! Takes an iterator type and, if it is contiguous, yields the raw pointer type it represents. Otherwise returns the
//! iterator type unmodified.
template <typename Iterator>
using try_unwrap_contiguous_iterator_return_t = typename try_unwrap_contiguous_iterator_impl<Iterator>::type;
using try_unwrap_contiguous_iterator_t = typename detail::try_unwrap_contiguous_iterator_impl<Iterator>::type;

// Casts to a raw pointer if iterator is marked as contiguous, otherwise returns
// the input iterator.
//! Takes an iterator and, if it is contiguous, unwraps it to the raw pointer it represents. Otherwise returns the
//! iterator unmodified.
template <typename Iterator>
_CCCL_HOST_DEVICE try_unwrap_contiguous_iterator_return_t<Iterator> try_unwrap_contiguous_iterator(Iterator it)
_CCCL_HOST_DEVICE auto try_unwrap_contiguous_iterator(Iterator it) -> try_unwrap_contiguous_iterator_t<Iterator>
{
return try_unwrap_contiguous_iterator_impl<Iterator>::get(it);
return detail::try_unwrap_contiguous_iterator_impl<Iterator>::get(it);
}

} // namespace detail

/*! \endcond
*/

Expand Down

0 comments on commit a1d8b31

Please sign in to comment.