Skip to content

Commit

Permalink
Proclaim Thrust/CUB/libcu++ functor address stability
Browse files Browse the repository at this point in the history
Fixes: #2718
  • Loading branch information
bernhardmgruber committed Nov 6, 2024
1 parent c97f2e3 commit 561d6d8
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 7 deletions.
6 changes: 6 additions & 0 deletions cub/cub/thread/thread_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
#include <cub/util_cpp_dialect.cuh>
#include <cub/util_type.cuh>

#include <cuda/__functional/address_stability.h>
#include <cuda/std/functional> // cuda::std::plus
#include <cuda/std/type_traits> // cuda::std::common_type
#include <cuda/std/utility> // cuda::std::forward
Expand Down Expand Up @@ -590,3 +591,8 @@ using cub_operator_to_dpx_t = CubOperatorToDpx<ReduceOp, T>;
} // namespace internal

CUB_NAMESPACE_END

// Address-stability trait for CUB's InequalityWrapper: the specialization derives from the trait
// of the wrapped operator, so the wrapper reports exactly what the wrapped operator reports.
template <typename WrappedOp>
struct ::cuda::proclaims_copyable_arguments<CUB_NS_QUALIFIER::InequalityWrapper<WrappedOp>>
    : ::cuda::proclaims_copyable_arguments<WrappedOp>
{};
9 changes: 9 additions & 0 deletions libcudacxx/include/cuda/std/__functional/not_fn.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# pragma system_header
#endif // no system header

#include <cuda/__functional/address_stability.h>
#include <cuda/std/__functional/invoke.h>
#include <cuda/std/__functional/perfect_forward.h>
#include <cuda/std/__type_traits/decay.h>
Expand Down Expand Up @@ -72,4 +73,12 @@ _LIBCUDACXX_HIDE_FROM_ABI _CCCL_CONSTEXPR_CXX20 auto not_fn(_Fn&& __f)

_LIBCUDACXX_END_NAMESPACE_STD

// NOTE(review): only provided past C++14, presumably because __not_fn_t is not available in
// earlier dialects — confirm against the definition of __not_fn_t.
#if _CCCL_STD_VER > 2014
_LIBCUDACXX_BEGIN_NAMESPACE_CUDA
// A not_fn wrapper reports the same address-stability trait as the callable it wraps:
// the specialization simply derives from the trait of the wrapped callable.
template <typename _Callable>
struct proclaims_copyable_arguments<_CUDA_VSTD::__not_fn_t<_Callable>> : proclaims_copyable_arguments<_Callable>
{};
_LIBCUDACXX_END_NAMESPACE_CUDA
#endif // _CCCL_STD_VER > 2014

#endif // _LIBCUDACXX___FUNCTIONAL_NOT_FN_H
45 changes: 45 additions & 0 deletions libcudacxx/include/cuda/std/__functional/operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,13 @@
# pragma system_header
#endif // no system header

#include <cuda/__functional/address_stability.h>
#include <cuda/std/__functional/binary_function.h>
#include <cuda/std/__functional/unary_function.h>
#include <cuda/std/__type_traits/conjunction.h>
#include <cuda/std/__type_traits/is_class.h>
#include <cuda/std/__type_traits/is_enum.h>
#include <cuda/std/__type_traits/is_void.h>
#include <cuda/std/__utility/forward.h>

_LIBCUDACXX_BEGIN_NAMESPACE_STD
Expand Down Expand Up @@ -527,4 +532,44 @@ struct _CCCL_TYPE_VISIBILITY_DEFAULT logical_or<void>

_LIBCUDACXX_END_NAMESPACE_STD

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA

// Yields true for every type that is neither a class type, an enumeration nor void.
// Operator overloads need at least one class or enumeration operand, so for the remaining types
// an operator expression can only resolve to a built-in operator.
template <typename _Tp>
struct __has_builtin_operators
    : _CUDA_VSTD::bool_constant<!(_CUDA_VSTD::is_class<_Tp>::value || _CUDA_VSTD::is_enum<_Tp>::value
                                  || _CUDA_VSTD::is_void<_Tp>::value)>
{};

// Declares the proclaims_copyable_arguments specializations for one operator functor template
// (e.g. plus): a partial specialization for functor<_Tp> and an explicit one for functor<void>.
// Must be expanded in a scope where proclaims_copyable_arguments and __has_builtin_operators
// are visible by unqualified name.
#define _LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(functor)                                               \
  /* functor<_Tp>: the applied operator is known as long as _Tp cannot carry a custom operatorX() */ \
  template <typename _Tp>                                                                          \
  struct proclaims_copyable_arguments<functor<_Tp>> : __has_builtin_operators<_Tp>                 \
  {};                                                                                              \
  /* functor<void>: what it does depends on the types of each invocation, so promise nothing */    \
  template <>                                                                                      \
  struct proclaims_copyable_arguments<functor<void>> : _CUDA_VSTD::false_type                      \
  {};

// Proclaim address stability for the libcu++ operator function objects.
// Each line expands to independent specializations, so the order below is purely for readability.

// arithmetic
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::plus);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::minus);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::multiplies);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::divides);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::modulus);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::negate);

// bitwise
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::bit_and);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::bit_or);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::bit_xor);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::bit_not);

// comparisons
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::equal_to);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::not_equal_to);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::less);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::less_equal);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::greater);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::greater_equal);

// logical
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::logical_and);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::logical_or);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(_CUDA_VSTD::logical_not);

_LIBCUDACXX_END_NAMESPACE_CUDA

#endif // _LIBCUDACXX___FUNCTIONAL_OPERATIONS_H
2 changes: 2 additions & 0 deletions libcudacxx/include/cuda/std/__functional/ranges_operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
_LIBCUDACXX_BEGIN_NAMESPACE_RANGES
_LIBCUDACXX_BEGIN_NAMESPACE_RANGES_ABI

// TODO(bgruber): do we need to specialize proclaims_copyable_arguments here as well?

struct equal_to
{
_LIBCUDACXX_TEMPLATE(class _Tp, class _Up)
Expand Down
75 changes: 68 additions & 7 deletions thrust/testing/address_stability.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,84 @@

#include <unittest/unittest.h>

// Minimal class type with a user-defined operator+. The tests below use it as an operand type
// whose addition is not a built-in operator.
struct addable
{
  _CCCL_HOST_DEVICE friend addable operator+(const addable&, const addable&)
  {
    return {};
  }
};

void TestAddressStabilityLibcuxx()
{
using ::cuda::proclaim_copyable_arguments;
using ::cuda::proclaims_copyable_arguments;

// libcu++ function objects with known types
static_assert(proclaims_copyable_arguments<::cuda::std::plus<int>>::value, "");
static_assert(!proclaims_copyable_arguments<::cuda::std::plus<>>::value, "");

// libcu++ function objects with unknown types
static_assert(!proclaims_copyable_arguments<::cuda::std::plus<addable>>::value, "");
static_assert(!proclaims_copyable_arguments<::cuda::std::plus<>>::value, "");

// libcu++ function objects with unknown types and opt-in
static_assert(proclaims_copyable_arguments<decltype(proclaim_copyable_arguments(cuda::std::plus<addable>{}))>::value,
"");
static_assert(proclaims_copyable_arguments<decltype(proclaim_copyable_arguments(cuda::std::plus<>{}))>::value, "");
}
DECLARE_UNITTEST(TestAddressStabilityLibcuxx);

void TestAddressStabilityThrust()
{
using ::cuda::proclaim_copyable_arguments;
using ::cuda::proclaims_copyable_arguments;

// thrust function objects with known types
static_assert(proclaims_copyable_arguments<thrust::plus<int>>::value, "");
static_assert(!proclaims_copyable_arguments<thrust::plus<>>::value, "");

// thrust function objects with unknown types
static_assert(!proclaims_copyable_arguments<thrust::plus<addable>>::value, "");
static_assert(!proclaims_copyable_arguments<thrust::plus<>>::value, "");

// thrust function objects with unknown types and opt-in
static_assert(proclaims_copyable_arguments<decltype(proclaim_copyable_arguments(thrust::plus<addable>{}))>::value,
"");
static_assert(proclaims_copyable_arguments<decltype(proclaim_copyable_arguments(thrust::plus<>{}))>::value, "");
}
DECLARE_UNITTEST(TestAddressStabilityThrust);

// User-defined binary functor with no proclaims_copyable_arguments specialization of its own.
// T is used for both parameters and for the return type, so instantiating with a value type or a
// reference type selects by-value or by-reference argument passing respectively.
template <typename T>
struct my_plus
{
  _CCCL_HOST_DEVICE auto operator()(T a, T b) const -> T
  {
    return a + b;
  }
};

void TestAddressStability()
void TestAddressStabilityUserDefinedFunctionObject()
{
using ::cuda::proclaim_copyable_arguments;
using ::cuda::proclaims_copyable_arguments;

static_assert(!proclaims_copyable_arguments<thrust::plus<int>>::value, "");
static_assert(proclaims_copyable_arguments<decltype(proclaim_copyable_arguments(thrust::plus<int>{}))>::value, "");
// by-value overload
static_assert(!proclaims_copyable_arguments<my_plus<int>>::value, "");

// by-value overload with opt-in
static_assert(proclaims_copyable_arguments<decltype(proclaim_copyable_arguments(my_plus<int>{}))>::value, "");

// by-reference overload
static_assert(!proclaims_copyable_arguments<my_plus<int&>>::value, "");
static_assert(!proclaims_copyable_arguments<my_plus<const int&>>::value, "");
static_assert(!proclaims_copyable_arguments<my_plus<int&&>>::value, "");
static_assert(!proclaims_copyable_arguments<my_plus<const int&&>>::value, "");

static_assert(!proclaims_copyable_arguments<my_plus>::value, "");
static_assert(proclaims_copyable_arguments<decltype(proclaim_copyable_arguments(my_plus{}))>::value, "");
// by-reference overload with opt-in
static_assert(proclaims_copyable_arguments<decltype(proclaim_copyable_arguments(my_plus<int&>{}))>::value, "");
static_assert(proclaims_copyable_arguments<decltype(proclaim_copyable_arguments(my_plus<const int&>{}))>::value, "");
static_assert(proclaims_copyable_arguments<decltype(proclaim_copyable_arguments(my_plus<int&&>{}))>::value, "");
static_assert(proclaims_copyable_arguments<decltype(proclaim_copyable_arguments(my_plus<const int&&>{}))>::value, "");
}
DECLARE_UNITTEST(TestAddressStability);
DECLARE_UNITTEST(TestAddressStabilityUserDefinedFunctionObject);
24 changes: 24 additions & 0 deletions thrust/thrust/functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -1388,6 +1388,30 @@ THRUST_INLINE_CONSTANT thrust::detail::functional::placeholder<9>::type _10;

THRUST_NAMESPACE_END

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA

// Proclaim address stability for the Thrust operator function objects.
// The functors are named through THRUST_NS_QUALIFIER instead of a bare `thrust::`, because this
// code sits inside namespace cuda and a bare `thrust::` is only found when Thrust lives directly
// in the global namespace.
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::plus);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::minus);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::multiplies);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::divides);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::modulus);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::negate);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::bit_and);
// NOTE(review): no proclamation for bit_not; Thrust does not appear to provide one — confirm.
//_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::bit_not);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::bit_or);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::bit_xor);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::equal_to);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::not_equal_to);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::less);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::less_equal);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::greater_equal);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::greater);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::logical_and);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::logical_not);
_LIBCUDACXX_MARK_CAN_COPY_ARGUMENTS(THRUST_NS_QUALIFIER::logical_or);

_LIBCUDACXX_END_NAMESPACE_CUDA

#include <thrust/detail/functional.inl>
#include <thrust/detail/functional/operators.h>
#include <thrust/detail/type_traits/is_commutative.h>
7 changes: 7 additions & 0 deletions thrust/thrust/zip_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
# include <thrust/tuple.h>
# include <thrust/type_traits/integer_sequence.h>

# include <cuda/__functional/address_stability.h>

THRUST_NAMESPACE_BEGIN

/*! \addtogroup function_objects Function Objects
Expand Down Expand Up @@ -201,4 +203,9 @@ _CCCL_HOST_DEVICE zip_function<typename std::decay<Function>::type> make_zip_fun

THRUST_NAMESPACE_END

// Address-stability trait for thrust::zip_function: the specialization derives from the trait of
// the wrapped function, so the adaptor reports exactly what the wrapped function reports.
template <typename WrappedFn>
struct ::cuda::proclaims_copyable_arguments<THRUST_NS_QUALIFIER::zip_function<WrappedFn>>
    : ::cuda::proclaims_copyable_arguments<WrappedFn>
{};

#endif

0 comments on commit 561d6d8

Please sign in to comment.