Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions cub/cub/device/device_transform.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,16 @@ private:
}

template <typename Env>
CUB_RUNTIME_FUNCTION static auto get_stream(Env env) -> cudaStream_t
CUB_RUNTIME_FUNCTION static auto get_stream([[maybe_unused]] Env env) -> cudaStream_t
{
return ::cuda::std::execution::__query_or(env, ::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}).get();
}

CUB_RUNTIME_FUNCTION static auto get_stream(cudaStream_t stream) -> cudaStream_t
{
return stream;
if constexpr (::cuda::std::is_invocable_v<::cuda::get_stream_t, Env>)
{
return ::cuda::get_stream(env).get();
}
else
{
return cudaStream_t{};
}
}

public:
Expand Down Expand Up @@ -489,7 +491,7 @@ public:
num_items,
::cuda::std::move(predicate),
::cuda::std::move(transform_op),
get_stream(env));
::cuda::std::move(env));
}

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
Expand Down Expand Up @@ -633,7 +635,7 @@ public:
::cuda::std::move(output),
num_items,
::cuda::std::move(transform_op),
get_stream(env));
::cuda::std::move(env));
}

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
Expand Down
100 changes: 75 additions & 25 deletions cub/test/catch2_test_device_transform_env.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ namespace stdexec = cuda::std::execution;

using namespace thrust::placeholders;

// Minimal stream wrapper used by the "custom stream" test sections: it is not a
// cudaStream_t itself, but converts to one implicitly.
struct custom_stream
{
  // Wrapped stream handle; this struct only stores it and never creates or destroys it.
  cudaStream_t stream;

  // Implicit conversion that hands back the wrapped handle unchanged.
  operator cudaStream_t() const noexcept { return stream; }
};

auto make_stream_env(cudaStream_t stream)
{
// MSVC has trouble nesting two aggregate initializations with CTAD
Expand All @@ -31,14 +41,20 @@ C2H_TEST("DeviceTransform::Transform custom stream", "[device][transform]")
REQUIRE(cudaStreamCreate(&stream) == cudaSuccess);

c2h::device_vector<type> result(num_items, thrust::no_init);
auto run = [&](auto streamish) {
cub::DeviceTransform::Transform(cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, streamish);
};
SECTION("raw stream")
{
cub::DeviceTransform::Transform(cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, stream);
run(stream);
}
SECTION("custom stream")
{
run(custom_stream{stream});
}
SECTION("environment")
{
auto env = make_stream_env(stream);
cub::DeviceTransform::Transform(cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, env);
run(make_stream_env(stream));
}

REQUIRE(cudaStreamDestroy(stream) == cudaSuccess);
Expand All @@ -55,14 +71,20 @@ C2H_TEST("DeviceTransform::Transform (single argument) custom stream", "[device]
REQUIRE(cudaStreamCreate(&stream) == cudaSuccess);

c2h::device_vector<type> result(num_items, thrust::no_init);
auto run = [&](auto streamish) {
cub::DeviceTransform::Transform(a, result.begin(), num_items, _1 + 13, streamish);
};
SECTION("raw stream")
{
cub::DeviceTransform::Transform(a, result.begin(), num_items, _1 + 13, stream);
run(stream);
}
SECTION("custom stream")
{
run(custom_stream{stream});
}
SECTION("environment")
{
auto env = make_stream_env(stream);
cub::DeviceTransform::Transform(a, result.begin(), num_items, _1 + 13, env);
run(make_stream_env(stream));
}

REQUIRE(cudaStreamDestroy(stream) == cudaSuccess);
Expand All @@ -79,14 +101,20 @@ C2H_TEST("DeviceTransform::Generate custom stream", "[device][transform]")
REQUIRE(cudaStreamCreate(&stream) == cudaSuccess);

c2h::device_vector<type> result(num_items, thrust::no_init);
auto run = [&](auto streamish) {
cub::DeviceTransform::Generate(result.begin(), num_items, generator, streamish);
};
SECTION("raw stream")
{
cub::DeviceTransform::Generate(result.begin(), num_items, generator, stream);
run(stream);
}
SECTION("custom stream")
{
run(custom_stream{stream});
}
SECTION("environment")
{
auto env = make_stream_env(stream);
cub::DeviceTransform::Generate(result.begin(), num_items, generator, env);
run(make_stream_env(stream));
}

REQUIRE(cudaStreamDestroy(stream) == cudaSuccess);
Expand All @@ -102,14 +130,20 @@ C2H_TEST("DeviceTransform::Fill custom stream", "[device][transform]")
REQUIRE(cudaStreamCreate(&stream) == cudaSuccess);

c2h::device_vector<type> result(num_items, thrust::no_init);
auto run = [&](auto streamish) {
cub::DeviceTransform::Fill(result.begin(), num_items, 0xBAD, streamish);
};
SECTION("raw stream")
{
cub::DeviceTransform::Fill(result.begin(), num_items, 0xBAD, stream);
run(stream);
}
SECTION("custom stream")
{
run(custom_stream{stream});
}
SECTION("environment")
{
auto env = make_stream_env(stream);
cub::DeviceTransform::Fill(result.begin(), num_items, 0xBAD, env);
run(make_stream_env(stream));
}

REQUIRE(cudaStreamDestroy(stream) == cudaSuccess);
Expand All @@ -136,16 +170,21 @@ C2H_TEST("DeviceTransform::TransformIf custom stream", "[device][transform]")
REQUIRE(cudaStreamCreate(&stream) == cudaSuccess);

c2h::device_vector<type> result(num_items, 1337);
auto run = [&](auto streamish) {
cub::DeviceTransform::TransformIf(
cuda::std::make_tuple(a, b), result.begin(), num_items, (_1 + _2) > 1000, _1 + _2, streamish);
};
SECTION("raw stream")
{
cub::DeviceTransform::TransformIf(
cuda::std::make_tuple(a, b), result.begin(), num_items, (_1 + _2) > 1000, _1 + _2, stream);
run(stream);
}
SECTION("custom stream")
{
run(custom_stream{stream});
}
SECTION("environment")
{
auto env = make_stream_env(stream);
cub::DeviceTransform::TransformIf(
cuda::std::make_tuple(a, b), result.begin(), num_items, (_1 + _2) > 1000, _1 + _2, env);
run(make_stream_env(stream));
}

auto reference_it = cuda::transform_iterator{cuda::counting_iterator{42}, reference_func{}};
Expand All @@ -164,14 +203,20 @@ C2H_TEST("DeviceTransform::TransformIf (single argument) custom stream", "[devic
REQUIRE(cudaStreamCreate(&stream) == cudaSuccess);

c2h::device_vector<type> result(num_items, 1337);
auto run = [&](auto streamish) {
cub::DeviceTransform::TransformIf(a, result.begin(), num_items, (_1 + 13) > 1000, _1 + 13, streamish);
};
SECTION("raw stream")
{
cub::DeviceTransform::TransformIf(a, result.begin(), num_items, (_1 + 13) > 1000, _1 + 13, stream);
run(stream);
}
SECTION("custom stream")
{
run(custom_stream{stream});
}
SECTION("environment")
{
auto env = make_stream_env(stream);
cub::DeviceTransform::TransformIf(a, result.begin(), num_items, (_1 + 13) > 1000, _1 + 13, env);
run(make_stream_env(stream));
}

auto reference_it = cuda::transform_iterator{cuda::counting_iterator{42}, reference_func{}};
Expand All @@ -191,16 +236,21 @@ C2H_TEST("DeviceTransform::TransformStableArgumentAddresses custom stream", "[de
REQUIRE(cudaStreamCreate(&stream) == cudaSuccess);

c2h::device_vector<type> result(num_items, thrust::no_init);
auto run = [&](auto streamish) {
cub::DeviceTransform::TransformStableArgumentAddresses(
cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, streamish);
};
SECTION("raw stream")
{
cub::DeviceTransform::TransformStableArgumentAddresses(
cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, stream);
run(stream);
}
SECTION("custom stream")
{
run(custom_stream{stream});
}
SECTION("environment")
{
auto env = make_stream_env(stream);
cub::DeviceTransform::TransformStableArgumentAddresses(
cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, env);
run(make_stream_env(stream));
}

REQUIRE(cudaStreamDestroy(stream) == cudaSuccess);
Expand Down
5 changes: 5 additions & 0 deletions libcudacxx/include/cuda/__stream/get_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ _CCCL_CONCEPT __has_query_get_stream = _CCCL_REQUIRES_EXPR((_Env), const _Env& _
//! @brief `get_stream` is a customization point object that queries a type `T` for an associated stream
struct get_stream_t
{
//! @brief Overload for a raw `::cudaStream_t`: wraps the handle in a `::cuda::stream_ref`.
//! @param __stream The stream handle to wrap.
//! @return A `::cuda::stream_ref` constructed from `__stream`.
[[nodiscard]] _CCCL_API constexpr auto operator()(::cudaStream_t __stream) const noexcept -> ::cuda::stream_ref
{
  return ::cuda::stream_ref{__stream};
}

_CCCL_EXEC_CHECK_DISABLE
_CCCL_TEMPLATE(class _Tp)
_CCCL_REQUIRES(__convertible_to_stream_ref<_Tp>)
Expand Down
28 changes: 28 additions & 0 deletions libcudacxx/test/libcudacxx/cuda/stream_ref/get_stream.pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,34 @@ __host__ __device__ void test()
assert(stream == ref);
}

{ // get_stream accepts an object whose type converts implicitly to cudaStream_t
  struct convertible_to_cuda_stream_t
  {
    ::cudaStream_t stream_;

    // Conversion operator returning the stored handle.
    __host__ __device__ operator ::cudaStream_t() const noexcept { return stream_; }
  };

  convertible_to_cuda_stream_t wrapper{stream};
  auto queried = ::cuda::get_stream(wrapper);
  // The queried result must compare equal to the stream the wrapper was built from.
  assert(stream == queried);
}

{ // get_stream accepts an object whose type converts implicitly to stream_ref
  struct convertible_to_stream_ref
  {
    ::cudaStream_t stream_;

    // Conversion operator producing a stream_ref built from the stored handle.
    __host__ __device__ operator ::cuda::stream_ref() const noexcept { return ::cuda::stream_ref{stream_}; }
  };

  convertible_to_stream_ref wrapper{stream};
  auto queried = ::cuda::get_stream(wrapper);
  // The queried result must compare equal to the stream the wrapper was built from.
  assert(stream == queried);
}

{ // Can call get_stream on a type with a get_stream method
struct with_const_get_stream
{
Expand Down