diff --git a/thrust/testing/transform_iterator.cu b/thrust/testing/transform_iterator.cu index 7bb87d4625..346f0fa8b5 100644 --- a/thrust/testing/transform_iterator.cu +++ b/thrust/testing/transform_iterator.cu @@ -6,6 +6,7 @@ #include #include +#include #include @@ -108,3 +109,110 @@ void TestTransformIteratorNonCopyable() } DECLARE_UNITTEST(TestTransformIteratorNonCopyable); + +struct flip_value +{ + _CCCL_HOST_DEVICE bool operator()(bool b) const + { + return !b; + } +}; + +struct pass_ref +{ + _CCCL_HOST_DEVICE const bool& operator()(const bool& b) const + { + return b; + } +}; + +// TODO(bgruber): replace by libc++ with C++14 +struct forward +{ + template + constexpr _Tp&& operator()(_Tp&& __t) const noexcept + { + return _CUDA_VSTD::forward<_Tp>(__t); + } +}; + +void TestTransformIteratorReferenceAndValueType() +{ + using ::cuda::std::is_same; + using ::cuda::std::negate; + { + thrust::host_vector v; + auto it = v.begin(); + static_assert(is_same::value, ""); + auto it_tr_val = thrust::make_transform_iterator(it, flip_value{}); + static_assert(is_same::value, ""); + auto it_tr_ref = thrust::make_transform_iterator(it, pass_ref{}); + static_assert(is_same::value, ""); + auto it_tr_fwd = thrust::make_transform_iterator(it, forward{}); + } + + { + thrust::device_vector v; + auto it = v.begin(); + static_assert(is_same>::value, ""); + auto it_tr_val = thrust::make_transform_iterator(it, flip_value{}); + static_assert(is_same::value, ""); + auto it_tr_ref = thrust::make_transform_iterator(it, pass_ref{}); + static_assert(is_same::value, ""); + auto it_tr_fwd = thrust::make_transform_iterator(it, forward{}); + static_assert(is_same::value, ""); // device ref. is decayed + } + + { + std::vector v; + auto it = v.begin(); + static_assert(is_same::reference>::value, ""); + auto it_tr_val = thrust::make_transform_iterator(it, flip_value{}); + static_assert(is_same::value, ""); + auto it_tr_ref = thrust::make_transform_iterator(it, pass_ref{}); + static_assert(is_same::value, ""); + auto it_tr_fwd = thrust::make_transform_iterator(it, forward{}); + static_assert(is_same::reference&&>::value, ""); // No handling for std + } +} +DECLARE_UNITTEST(TestTransformIteratorReferenceAndValueType); + +struct foo +{ + int x, y; +}; + +struct access_x +{ + _CCCL_HOST_DEVICE int& operator()(foo& f) const noexcept + { + return f.x; + } +}; + +template