Skip to content

Commit

Permalink
Implement P3029R1: deduction from integral_constant (#1786)
Browse files Browse the repository at this point in the history
* Implement P3029R1: deduction from `integral_constant`

* Address review feedback
  • Loading branch information
miscco authored May 30, 2024
1 parent d3e15ee commit bf1a71a
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 20 deletions.
51 changes: 46 additions & 5 deletions libcudacxx/include/cuda/std/detail/libcxx/include/span
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ template<class R>
# pragma system_header
#endif // no system header

#include <cuda/std/__concepts/convertible_to.h>
#include <cuda/std/__concepts/equality_comparable.h>
#include <cuda/std/__fwd/array.h>
#include <cuda/std/__fwd/span.h>
#include <cuda/std/__fwd/string.h>
Expand All @@ -157,6 +159,9 @@ template<class R>
#include <cuda/std/__type_traits/is_array.h>
#include <cuda/std/__type_traits/is_const.h>
#include <cuda/std/__type_traits/is_convertible.h>
#include <cuda/std/__type_traits/is_integral.h>
#include <cuda/std/__type_traits/is_same.h>
#include <cuda/std/__type_traits/remove_const.h>
#include <cuda/std/__type_traits/remove_cv.h>
#include <cuda/std/__type_traits/remove_cvref.h>
#include <cuda/std/__type_traits/remove_pointer.h>
Expand Down Expand Up @@ -257,6 +262,40 @@ _LIBCUDACXX_INLINE_VAR constexpr bool __is_span_compatible_container<
nullptr_t>>> = true;
# endif // _CCCL_STD_VER <= 2014 || _CCCL_COMPILER_MSVC_2017

# if _CCCL_STD_VER >= 2020

template <class _Tp>
concept __integral_constant_like =
is_integral_v<decltype(_Tp::value)> //
&& !is_same_v<bool, remove_const_t<decltype(_Tp::value)>> //
&& convertible_to<_Tp, decltype(_Tp::value)> //
&& equality_comparable_with<_Tp, decltype(_Tp::value)> //
&& bool_constant<_Tp() == _Tp::value>::value
&& bool_constant<static_cast<decltype(_Tp::value)>(_Tp()) == _Tp::value>::value;

# else // ^^^ _CCCL_STD_VER >= 2020 ^^^ / vvv _CCCL_STD_VER <= 2017 vvv

template <class _Tp>
_LIBCUDACXX_CONCEPT_FRAGMENT(
__integral_constant_like_,
requires()( //
requires(_CCCL_TRAIT(is_integral, decltype(_Tp::value))),
requires(!_CCCL_TRAIT(is_same, bool, remove_const_t<decltype(_Tp::value)>)),
requires(convertible_to<_Tp, decltype(_Tp::value)>),
requires(equality_comparable_with<_Tp, decltype(_Tp::value)>),
(integral_constant<bool, _Tp() == _Tp::value>::value),
(integral_constant<bool, static_cast<decltype(_Tp::value)>(_Tp()) == _Tp::value>::value) //
));
template <class _Tp>
_LIBCUDACXX_CONCEPT __integral_constant_like = _LIBCUDACXX_FRAGMENT(__integral_constant_like_, _Tp);
# endif // _CCCL_STD_VER <= 2017

template <class _Tp, bool = __integral_constant_like<_Tp>>
_LIBCUDACXX_INLINE_VAR constexpr size_t __maybe_static_ext = dynamic_extent;

template <class _Tp>
_LIBCUDACXX_INLINE_VAR constexpr size_t __maybe_static_ext<_Tp, true> = {_Tp::value};

template <typename _Tp, size_t _Extent>
class _LIBCUDACXX_TEMPLATE_VIS span
{
Expand Down Expand Up @@ -333,13 +372,13 @@ public:

_LIBCUDACXX_TEMPLATE(class _OtherElementType)
_LIBCUDACXX_REQUIRES(__span_array_convertible<_OtherElementType, element_type>)
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 span(array<_OtherElementType, _Extent>& __arr) noexcept
_LIBCUDACXX_INLINE_VISIBILITY constexpr span(array<_OtherElementType, _Extent>& __arr) noexcept
: __data_{__arr.data()}
{}

_LIBCUDACXX_TEMPLATE(class _OtherElementType)
_LIBCUDACXX_REQUIRES(__span_array_convertible<const _OtherElementType, element_type>)
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 span(const array<_OtherElementType, _Extent>& __arr) noexcept
_LIBCUDACXX_INLINE_VISIBILITY constexpr span(const array<_OtherElementType, _Extent>& __arr) noexcept
: __data_{__arr.data()}
{}

Expand Down Expand Up @@ -569,14 +608,14 @@ public:

_LIBCUDACXX_TEMPLATE(class _OtherElementType, size_t _Sz)
_LIBCUDACXX_REQUIRES(__span_array_convertible<_OtherElementType, element_type>)
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 span(array<_OtherElementType, _Sz>& __arr) noexcept
_LIBCUDACXX_INLINE_VISIBILITY constexpr span(array<_OtherElementType, _Sz>& __arr) noexcept
: __data_{__arr.data()}
, __size_{_Sz}
{}

_LIBCUDACXX_TEMPLATE(class _OtherElementType, size_t _Sz)
_LIBCUDACXX_REQUIRES(__span_array_convertible<const _OtherElementType, element_type>)
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 span(const array<_OtherElementType, _Sz>& __arr) noexcept
_LIBCUDACXX_INLINE_VISIBILITY constexpr span(const array<_OtherElementType, _Sz>& __arr) noexcept
: __data_{__arr.data()}
, __size_{_Sz}
{}
Expand Down Expand Up @@ -769,9 +808,11 @@ _CCCL_HOST_DEVICE span(_Container&) -> span<typename _Container::value_type>;
template <class _Container>
_CCCL_HOST_DEVICE span(const _Container&) -> span<const typename _Container::value_type>;
# else // ^^^ _CCCL_COMPILER_MSVC_2017 ^^^ / vvv !_CCCL_COMPILER_MSVC_2017 vvv

_LIBCUDACXX_TEMPLATE(class _It, class _EndOrSize)
_LIBCUDACXX_REQUIRES(contiguous_iterator<_It>)
_CCCL_HOST_DEVICE span(_It, _EndOrSize) -> span<remove_reference_t<iter_reference_t<_It>>>;
_CCCL_HOST_DEVICE span(_It, _EndOrSize)
-> span<remove_reference_t<iter_reference_t<_It>>, __maybe_static_ext<_EndOrSize>>;

_LIBCUDACXX_TEMPLATE(class _Range)
_LIBCUDACXX_REQUIRES(_CUDA_VRANGES::contiguous_range<_Range>)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <cuda/std/cassert>
#include <cuda/std/iterator>
#include <cuda/std/span>
#include <cuda/std/type_traits>

#include "test_macros.h"

Expand All @@ -50,6 +51,16 @@ __host__ __device__ void test_iterator_sentinel()
assert(s.size() == cuda::std::size(arr));
assert(s.data() == cuda::std::data(arr));
}

#if !defined(TEST_COMPILER_MSVC)
// P3029R1: deduction from `integral_constant`
{
cuda::std::span s{cuda::std::begin(arr), cuda::std::integral_constant<size_t, 3>{}};
ASSERT_SAME_TYPE(decltype(s), cuda::std::span<int, 3>);
assert(s.size() == cuda::std::size(arr));
assert(s.data() == cuda::std::data(arr));
}
#endif // !TEST_COMPILER_MSVC
}

__host__ __device__ void test_c_array()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ __host__ __device__ void checkCV()
}

template <typename T, typename U>
__host__ __device__ TEST_CONSTEXPR_CXX17 bool testConstructorArray()
__host__ __device__ constexpr bool testConstructorArray()
{
cuda::std::array<U, 2> val = {U(), U()};
ASSERT_NOEXCEPT(cuda::std::span<T>{val});
Expand All @@ -73,7 +73,7 @@ __host__ __device__ TEST_CONSTEXPR_CXX17 bool testConstructorArray()
}

template <typename T, typename U>
__host__ __device__ TEST_CONSTEXPR_CXX17 bool testConstructorConstArray()
__host__ __device__ constexpr bool testConstructorConstArray()
{
const cuda::std::array<U, 2> val = {U(), U()};
ASSERT_NOEXCEPT(cuda::std::span<const T>{val});
Expand All @@ -84,14 +84,14 @@ __host__ __device__ TEST_CONSTEXPR_CXX17 bool testConstructorConstArray()
}

template <typename T>
__host__ __device__ TEST_CONSTEXPR_CXX17 bool testConstructors()
__host__ __device__ constexpr bool testConstructors()
{
STATIC_ASSERT_CXX17((testConstructorArray<T, T>()));
STATIC_ASSERT_CXX17((testConstructorArray<const T, const T>()));
STATIC_ASSERT_CXX17((testConstructorArray<const T, T>()));
STATIC_ASSERT_CXX17((testConstructorConstArray<T, T>()));
STATIC_ASSERT_CXX17((testConstructorConstArray<const T, const T>()));
STATIC_ASSERT_CXX17((testConstructorConstArray<const T, T>()));
STATIC_ASSERT_CXX14((testConstructorArray<T, T>()));
STATIC_ASSERT_CXX14((testConstructorArray<const T, const T>()));
STATIC_ASSERT_CXX14((testConstructorArray<const T, T>()));
STATIC_ASSERT_CXX14((testConstructorConstArray<T, T>()));
STATIC_ASSERT_CXX14((testConstructorConstArray<const T, const T>()));
STATIC_ASSERT_CXX14((testConstructorConstArray<const T, T>()));

return testConstructorArray<T, T>() && testConstructorArray<const T, const T>() && testConstructorArray<const T, T>()
&& testConstructorConstArray<T, T>() && testConstructorConstArray<const T, const T>()
Expand Down
6 changes: 0 additions & 6 deletions libcudacxx/test/support/test_macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,6 @@
# define STATIC_ASSERT_CXX14(Pred) assert(Pred)
#endif

#if TEST_STD_VER > 2014
# define STATIC_ASSERT_CXX17(Pred) static_assert(Pred, "")
#else
# define STATIC_ASSERT_CXX17(Pred) assert(Pred)
#endif

/* Macros for testing libc++ specific behavior and extensions */
#if defined(_LIBCUDACXX_VERSION)
# define LIBCPP_ASSERT(...) assert(__VA_ARGS__)
Expand Down

0 comments on commit bf1a71a

Please sign in to comment.