Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions cub/cub/device/device_transform.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,16 @@ private:
}

template <typename Env>
CUB_RUNTIME_FUNCTION static auto get_stream(Env env) -> cudaStream_t
CUB_RUNTIME_FUNCTION static auto get_stream([[maybe_unused]] Env env) -> cudaStream_t
{
return ::cuda::std::execution::__query_or(env, ::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}).get();
}

CUB_RUNTIME_FUNCTION static auto get_stream(cudaStream_t stream) -> cudaStream_t
{
return stream;
if constexpr (::cuda::std::is_invocable_v<::cuda::get_stream_t, Env>)
{
return ::cuda::get_stream(env).get();
}
else
{
return cudaStream_t{};
}
}

public:
Expand Down Expand Up @@ -654,7 +656,7 @@ public:
::cuda::std::move(output),
num_items,
::cuda::std::move(transform_op),
get_stream(env));
::cuda::std::move(env));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drive-by fix

}

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
Expand Down
100 changes: 75 additions & 25 deletions cub/test/catch2_test_device_transform_env.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ namespace stdexec = cuda::std::execution;

using namespace thrust::placeholders;

// Test helper: a user-defined type that is not a `cudaStream_t` itself but converts to one implicitly.
// The "custom stream" sections pass it in place of a raw stream handle.
struct custom_stream
{
  // Raw stream handle returned by the conversion operator below.
  cudaStream_t stream;

  // Deliberately implicit: the conversion to `cudaStream_t` is what the tests exercise.
  operator cudaStream_t() const noexcept { return this->stream; }
};

auto make_stream_env(cudaStream_t stream)
{
// MSVC has trouble nesting two aggregate initializations with CTAD
Expand All @@ -31,14 +41,20 @@ C2H_TEST("DeviceTransform::Transform custom stream", "[device][transform]")
REQUIRE(cudaStreamCreate(&stream) == cudaSuccess);

c2h::device_vector<type> result(num_items, thrust::no_init);
auto run = [&](auto streamish) {
cub::DeviceTransform::Transform(cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, streamish);
};
SECTION("raw stream")
{
cub::DeviceTransform::Transform(cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, stream);
run(stream);
}
SECTION("custom stream")
{
run(custom_stream{stream});
Comment on lines +44 to +53
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I verified in the debugger that the stream is extracted, but I wonder whether I could write the unit test so that it detects if the default stream was used anywhere. Does anybody know whether I can query a stream to check that something was really enqueued on it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One solution would be to start a graph capture on a stream and check whether anything was captured, but that might have some limitations; I'm not sure if it's applicable here.

}
SECTION("environment")
{
auto env = make_stream_env(stream);
cub::DeviceTransform::Transform(cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, env);
run(make_stream_env(stream));
}

REQUIRE(cudaStreamDestroy(stream) == cudaSuccess);
Expand All @@ -55,14 +71,20 @@ C2H_TEST("DeviceTransform::Transform (single argument) custom stream", "[device]
REQUIRE(cudaStreamCreate(&stream) == cudaSuccess);

c2h::device_vector<type> result(num_items, thrust::no_init);
auto run = [&](auto streamish) {
cub::DeviceTransform::Transform(a, result.begin(), num_items, _1 + 13, streamish);
};
SECTION("raw stream")
{
cub::DeviceTransform::Transform(a, result.begin(), num_items, _1 + 13, stream);
run(stream);
}
SECTION("custom stream")
{
run(custom_stream{stream});
}
SECTION("environment")
{
auto env = make_stream_env(stream);
cub::DeviceTransform::Transform(a, result.begin(), num_items, _1 + 13, env);
run(make_stream_env(stream));
}

REQUIRE(cudaStreamDestroy(stream) == cudaSuccess);
Expand All @@ -79,14 +101,20 @@ C2H_TEST("DeviceTransform::Generate custom stream", "[device][transform]")
REQUIRE(cudaStreamCreate(&stream) == cudaSuccess);

c2h::device_vector<type> result(num_items, thrust::no_init);
auto run = [&](auto streamish) {
cub::DeviceTransform::Generate(result.begin(), num_items, generator, streamish);
};
SECTION("raw stream")
{
cub::DeviceTransform::Generate(result.begin(), num_items, generator, stream);
run(stream);
}
SECTION("custom stream")
{
run(custom_stream{stream});
}
SECTION("environment")
{
auto env = make_stream_env(stream);
cub::DeviceTransform::Generate(result.begin(), num_items, generator, env);
run(make_stream_env(stream));
}

REQUIRE(cudaStreamDestroy(stream) == cudaSuccess);
Expand All @@ -102,14 +130,20 @@ C2H_TEST("DeviceTransform::Fill custom stream", "[device][transform]")
REQUIRE(cudaStreamCreate(&stream) == cudaSuccess);

c2h::device_vector<type> result(num_items, thrust::no_init);
auto run = [&](auto streamish) {
cub::DeviceTransform::Fill(result.begin(), num_items, 0xBAD, streamish);
};
SECTION("raw stream")
{
cub::DeviceTransform::Fill(result.begin(), num_items, 0xBAD, stream);
run(stream);
}
SECTION("custom stream")
{
run(custom_stream{stream});
}
SECTION("environment")
{
auto env = make_stream_env(stream);
cub::DeviceTransform::Fill(result.begin(), num_items, 0xBAD, env);
run(make_stream_env(stream));
}

REQUIRE(cudaStreamDestroy(stream) == cudaSuccess);
Expand All @@ -136,16 +170,21 @@ C2H_TEST("DeviceTransform::TransformIf custom stream", "[device][transform]")
REQUIRE(cudaStreamCreate(&stream) == cudaSuccess);

c2h::device_vector<type> result(num_items, 1337);
auto run = [&](auto streamish) {
cub::DeviceTransform::TransformIf(
cuda::std::make_tuple(a, b), result.begin(), num_items, (_1 + _2) > 1000, _1 + _2, streamish);
};
SECTION("raw stream")
{
cub::DeviceTransform::TransformIf(
cuda::std::make_tuple(a, b), result.begin(), num_items, (_1 + _2) > 1000, _1 + _2, stream);
run(stream);
}
SECTION("custom stream")
{
run(custom_stream{stream});
}
SECTION("environment")
{
auto env = make_stream_env(stream);
cub::DeviceTransform::TransformIf(
cuda::std::make_tuple(a, b), result.begin(), num_items, (_1 + _2) > 1000, _1 + _2, env);
run(make_stream_env(stream));
}

auto reference_it = cuda::transform_iterator{cuda::counting_iterator{42}, reference_func{}};
Expand All @@ -164,14 +203,20 @@ C2H_TEST("DeviceTransform::TransformIf (single argument) custom stream", "[devic
REQUIRE(cudaStreamCreate(&stream) == cudaSuccess);

c2h::device_vector<type> result(num_items, 1337);
auto run = [&](auto streamish) {
cub::DeviceTransform::TransformIf(a, result.begin(), num_items, (_1 + 13) > 1000, _1 + 13, streamish);
};
SECTION("raw stream")
{
cub::DeviceTransform::TransformIf(a, result.begin(), num_items, (_1 + 13) > 1000, _1 + 13, stream);
run(stream);
}
SECTION("custom stream")
{
run(custom_stream{stream});
}
SECTION("environment")
{
auto env = make_stream_env(stream);
cub::DeviceTransform::TransformIf(a, result.begin(), num_items, (_1 + 13) > 1000, _1 + 13, env);
run(make_stream_env(stream));
}

auto reference_it = cuda::transform_iterator{cuda::counting_iterator{42}, reference_func{}};
Expand All @@ -191,16 +236,21 @@ C2H_TEST("DeviceTransform::TransformStableArgumentAddresses custom stream", "[de
REQUIRE(cudaStreamCreate(&stream) == cudaSuccess);

c2h::device_vector<type> result(num_items, thrust::no_init);
auto run = [&](auto streamish) {
cub::DeviceTransform::TransformStableArgumentAddresses(
cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, streamish);
};
SECTION("raw stream")
{
cub::DeviceTransform::TransformStableArgumentAddresses(
cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, stream);
run(stream);
}
SECTION("custom stream")
{
run(custom_stream{stream});
}
SECTION("environment")
{
auto env = make_stream_env(stream);
cub::DeviceTransform::TransformStableArgumentAddresses(
cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, env);
run(make_stream_env(stream));
}

REQUIRE(cudaStreamDestroy(stream) == cudaSuccess);
Expand Down
5 changes: 5 additions & 0 deletions libcudacxx/include/cuda/__stream/get_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ _CCCL_CONCEPT __has_query_get_stream = _CCCL_REQUIRES_EXPR((_Env), const _Env& _
//! @brief `get_stream` is a customization point object that queries a type `T` for an associated stream
struct get_stream_t
{
//! @brief Overload taking a raw `::cudaStream_t` handle.
//! @param __stream The raw stream handle to wrap.
//! @return A `::cuda::stream_ref` constructed from `__stream`.
//! @note NOTE(review): as a non-template overload this presumably also serves arguments that are only
//! implicitly convertible to `::cudaStream_t` -- confirm against the `get_stream` tests.
[[nodiscard]] _CCCL_API constexpr ::cuda::stream_ref operator()(::cudaStream_t __stream) const noexcept
{
return ::cuda::stream_ref{__stream};
}

_CCCL_EXEC_CHECK_DISABLE
_CCCL_TEMPLATE(class _Tp)
_CCCL_REQUIRES(__convertible_to_stream_ref<_Tp>)
Expand Down
28 changes: 28 additions & 0 deletions libcudacxx/test/libcudacxx/cuda/stream_ref/get_stream.pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,34 @@ __host__ __device__ void test()
assert(stream == ref);
}

{ // Can call get_stream on a type convertible to cudaStream_t
struct convertible_to_cuda_stream_t
{
::cudaStream_t stream_;
__host__ __device__ operator ::cudaStream_t() const noexcept
{
return stream_;
}
};
convertible_to_cuda_stream_t str{stream};
auto ref = ::cuda::get_stream(str);
assert(stream == ref);
}

{ // Can call get_stream on a type convertible to stream_ref
struct convertible_to_stream_ref
{
::cudaStream_t stream_;
__host__ __device__ operator ::cuda::stream_ref() const noexcept
{
return ::cuda::stream_ref{stream_};
}
};
convertible_to_stream_ref str{stream};
auto ref = ::cuda::get_stream(str);
assert(stream == ref);
}

{ // Can call get_stream on a type with a get_stream method
struct with_const_get_stream
{
Expand Down