diff --git a/cub/cub/device/device_transform.cuh b/cub/cub/device/device_transform.cuh index 994670c9d33..ac0953e599f 100644 --- a/cub/cub/device/device_transform.cuh +++ b/cub/cub/device/device_transform.cuh @@ -87,14 +87,16 @@ private: } template - CUB_RUNTIME_FUNCTION static auto get_stream(Env env) -> cudaStream_t + CUB_RUNTIME_FUNCTION static auto get_stream([[maybe_unused]] Env env) -> cudaStream_t { - return ::cuda::std::execution::__query_or(env, ::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}).get(); - } - - CUB_RUNTIME_FUNCTION static auto get_stream(cudaStream_t stream) -> cudaStream_t - { - return stream; + if constexpr (::cuda::std::is_invocable_v<::cuda::get_stream_t, Env>) + { + return ::cuda::get_stream(env).get(); + } + else + { + return cudaStream_t{}; + } } public: @@ -489,7 +491,7 @@ public: num_items, ::cuda::std::move(predicate), ::cuda::std::move(transform_op), - get_stream(env)); + ::cuda::std::move(env)); } #ifndef _CCCL_DOXYGEN_INVOKED // Do not document @@ -633,7 +635,7 @@ public: ::cuda::std::move(output), num_items, ::cuda::std::move(transform_op), - get_stream(env)); + ::cuda::std::move(env)); } #ifndef _CCCL_DOXYGEN_INVOKED // Do not document diff --git a/cub/test/catch2_test_device_transform_env.cu b/cub/test/catch2_test_device_transform_env.cu index f5a24225178..20191039465 100644 --- a/cub/test/catch2_test_device_transform_env.cu +++ b/cub/test/catch2_test_device_transform_env.cu @@ -13,6 +13,16 @@ namespace stdexec = cuda::std::execution; using namespace thrust::placeholders; +struct custom_stream +{ + cudaStream_t stream; + + operator cudaStream_t() const noexcept + { + return stream; + } +}; + auto make_stream_env(cudaStream_t stream) { // MSVC has trouble nesting two aggregate initializations with CTAD @@ -31,14 +41,20 @@ C2H_TEST("DeviceTransform::Transform custom stream", "[device][transform]") REQUIRE(cudaStreamCreate(&stream) == cudaSuccess); c2h::device_vector result(num_items, thrust::no_init); + auto run = [&](auto streamish) { + cub::DeviceTransform::Transform(cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, streamish); + }; SECTION("raw stream") { - cub::DeviceTransform::Transform(cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, stream); + run(stream); + } + SECTION("custom stream") + { + run(custom_stream{stream}); } SECTION("environment") { - auto env = make_stream_env(stream); - cub::DeviceTransform::Transform(cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, env); + run(make_stream_env(stream)); } REQUIRE(cudaStreamDestroy(stream) == cudaSuccess); @@ -55,14 +71,20 @@ C2H_TEST("DeviceTransform::Transform (single argument) custom stream", "[device] REQUIRE(cudaStreamCreate(&stream) == cudaSuccess); c2h::device_vector result(num_items, thrust::no_init); + auto run = [&](auto streamish) { + cub::DeviceTransform::Transform(a, result.begin(), num_items, _1 + 13, streamish); + }; SECTION("raw stream") { - cub::DeviceTransform::Transform(a, result.begin(), num_items, _1 + 13, stream); + run(stream); + } + SECTION("custom stream") + { + run(custom_stream{stream}); } SECTION("environment") { - auto env = make_stream_env(stream); - cub::DeviceTransform::Transform(a, result.begin(), num_items, _1 + 13, env); + run(make_stream_env(stream)); } REQUIRE(cudaStreamDestroy(stream) == cudaSuccess); @@ -79,14 +101,20 @@ C2H_TEST("DeviceTransform::Generate custom stream", "[device][transform]") REQUIRE(cudaStreamCreate(&stream) == cudaSuccess); c2h::device_vector result(num_items, thrust::no_init); + auto run = [&](auto streamish) { + cub::DeviceTransform::Generate(result.begin(), num_items, generator, streamish); + }; SECTION("raw stream") { - cub::DeviceTransform::Generate(result.begin(), num_items, generator, stream); + run(stream); + } + SECTION("custom stream") + { + run(custom_stream{stream}); } SECTION("environment") { - auto env = make_stream_env(stream); - cub::DeviceTransform::Generate(result.begin(), num_items, generator, env); + run(make_stream_env(stream)); } REQUIRE(cudaStreamDestroy(stream) == cudaSuccess); @@ -102,14 +130,20 @@ C2H_TEST("DeviceTransform::Fill custom stream", "[device][transform]") REQUIRE(cudaStreamCreate(&stream) == cudaSuccess); c2h::device_vector result(num_items, thrust::no_init); + auto run = [&](auto streamish) { + cub::DeviceTransform::Fill(result.begin(), num_items, 0xBAD, streamish); + }; SECTION("raw stream") { - cub::DeviceTransform::Fill(result.begin(), num_items, 0xBAD, stream); + run(stream); + } + SECTION("custom stream") + { + run(custom_stream{stream}); } SECTION("environment") { - auto env = make_stream_env(stream); - cub::DeviceTransform::Fill(result.begin(), num_items, 0xBAD, env); + run(make_stream_env(stream)); } REQUIRE(cudaStreamDestroy(stream) == cudaSuccess); @@ -136,16 +170,21 @@ C2H_TEST("DeviceTransform::TransformIf custom stream", "[device][transform]") REQUIRE(cudaStreamCreate(&stream) == cudaSuccess); c2h::device_vector result(num_items, 1337); + auto run = [&](auto streamish) { + cub::DeviceTransform::TransformIf( + cuda::std::make_tuple(a, b), result.begin(), num_items, (_1 + _2) > 1000, _1 + _2, streamish); + }; SECTION("raw stream") { - cub::DeviceTransform::TransformIf( - cuda::std::make_tuple(a, b), result.begin(), num_items, (_1 + _2) > 1000, _1 + _2, stream); + run(stream); + } + SECTION("custom stream") + { + run(custom_stream{stream}); } SECTION("environment") { - auto env = make_stream_env(stream); - cub::DeviceTransform::TransformIf( - cuda::std::make_tuple(a, b), result.begin(), num_items, (_1 + _2) > 1000, _1 + _2, env); + run(make_stream_env(stream)); } auto reference_it = cuda::transform_iterator{cuda::counting_iterator{42}, reference_func{}}; @@ -164,14 +203,20 @@ C2H_TEST("DeviceTransform::TransformIf (single argument) custom stream", "[devic REQUIRE(cudaStreamCreate(&stream) == cudaSuccess); c2h::device_vector result(num_items, 1337); + auto run = [&](auto streamish) { + cub::DeviceTransform::TransformIf(a, result.begin(), num_items, (_1 + 13) > 1000, _1 + 13, streamish); + }; SECTION("raw stream") { - cub::DeviceTransform::TransformIf(a, result.begin(), num_items, (_1 + 13) > 1000, _1 + 13, stream); + run(stream); + } + SECTION("custom stream") + { + run(custom_stream{stream}); } SECTION("environment") { - auto env = make_stream_env(stream); - cub::DeviceTransform::TransformIf(a, result.begin(), num_items, (_1 + 13) > 1000, _1 + 13, env); + run(make_stream_env(stream)); } auto reference_it = cuda::transform_iterator{cuda::counting_iterator{42}, reference_func{}}; @@ -191,16 +236,21 @@ C2H_TEST("DeviceTransform::TransformStableArgumentAddresses custom stream", "[de REQUIRE(cudaStreamCreate(&stream) == cudaSuccess); c2h::device_vector result(num_items, thrust::no_init); + auto run = [&](auto streamish) { + cub::DeviceTransform::TransformStableArgumentAddresses( + cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, streamish); + }; SECTION("raw stream") { - cub::DeviceTransform::TransformStableArgumentAddresses( - cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, stream); + run(stream); + } + SECTION("custom stream") + { + run(custom_stream{stream}); } SECTION("environment") { - auto env = make_stream_env(stream); - cub::DeviceTransform::TransformStableArgumentAddresses( - cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, env); + run(make_stream_env(stream)); } REQUIRE(cudaStreamDestroy(stream) == cudaSuccess); diff --git a/libcudacxx/include/cuda/__stream/get_stream.h b/libcudacxx/include/cuda/__stream/get_stream.h index f318050ab3d..45717423972 100644 --- a/libcudacxx/include/cuda/__stream/get_stream.h +++ b/libcudacxx/include/cuda/__stream/get_stream.h @@ -57,6 +57,11 @@ _CCCL_CONCEPT __has_query_get_stream = _CCCL_REQUIRES_EXPR((_Env), const _Env& _ //! @brief `get_stream` is a customization point object that queries a type `T` for an associated stream struct get_stream_t { + [[nodiscard]] _CCCL_API constexpr ::cuda::stream_ref operator()(::cudaStream_t __stream) const noexcept + { + return ::cuda::stream_ref{__stream}; + } + _CCCL_EXEC_CHECK_DISABLE _CCCL_TEMPLATE(class _Tp) _CCCL_REQUIRES(__convertible_to_stream_ref<_Tp>) diff --git a/libcudacxx/test/libcudacxx/cuda/stream_ref/get_stream.pass.cpp b/libcudacxx/test/libcudacxx/cuda/stream_ref/get_stream.pass.cpp index bf5b8bf0f13..9691127dbcf 100644 --- a/libcudacxx/test/libcudacxx/cuda/stream_ref/get_stream.pass.cpp +++ b/libcudacxx/test/libcudacxx/cuda/stream_ref/get_stream.pass.cpp @@ -20,6 +20,34 @@ __host__ __device__ void test() assert(stream == ref); } + { // Can call get_stream on a type convertible to cudaStream_t + struct convertible_to_cuda_stream_t + { + ::cudaStream_t stream_; + __host__ __device__ operator ::cudaStream_t() const noexcept + { + return stream_; + } + }; + convertible_to_cuda_stream_t str{stream}; + auto ref = ::cuda::get_stream(str); + assert(stream == ref); + } + + { // Can call get_stream on a type convertible to stream_ref + struct convertible_to_stream_ref + { + ::cudaStream_t stream_; + __host__ __device__ operator ::cuda::stream_ref() const noexcept + { + return ::cuda::stream_ref{stream_}; + } + }; + convertible_to_stream_ref str{stream}; + auto ref = ::cuda::get_stream(str); + assert(stream == ref); + } + { // Can call get_stream on a type with a get_stream method struct with_const_get_stream {