Fix extracting CUDA stream in cub::DeviceTransform
#7239
    num_items,
    ::cuda::std::move(transform_op),
    get_stream(env));
    ::cuda::std::move(env));
Drive-by fix
  auto run = [&](auto streamish) {
    cub::DeviceTransform::Transform(cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, streamish);
  };
  SECTION("raw stream")
  {
    cub::DeviceTransform::Transform(cuda::std::make_tuple(a, b), result.begin(), num_items, _1 + _2, stream);
    run(stream);
  }
  SECTION("custom stream")
  {
    run(custom_stream{stream});
I verified that the stream is extracted in the debugger, but I wonder if I could write the unit test in a way to detect if the default stream was taken anywhere. Does anybody know if I can query the stream whether something was really enqueued there?
One solution would be to start a graph capture on a stream and see if anything was captured, but that might have some limitations. Not sure if it's applicable here.
miscco
left a comment
I would love to better understand where the issue with the conversions from cudaStream_t lies so that we can extend our get_stream CPO.
What was the exact case failing? From what I can see, the second and third branches in the if constexpr should already work.
Maybe we just need another clause for get_stream.
cub/cub/device/device_transform.cuh
Outdated
  }
  else
  {
    return ::cuda::std::execution::__query_or(env, ::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}).get();
This is very close to what we have in get_stream.
I believe the issue is that we are checking whether it is convertible to stream_ref and not cudaStream_t.
But at the same time cudaStream_t should be convertible to stream_ref.
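A possible explanation for the mismatch discussed above, as a sketch only: C++ allows at most one user-defined conversion in an implicit conversion sequence. If stream_ref is built from cudaStream_t through a converting constructor (an assumption here), then a user type that converts to cudaStream_t through its own conversion operator would need two user-defined conversions to reach stream_ref, so a "convertible to stream_ref" check rejects it. The types below (`stream_handle`, `stream_ref_like`, `custom_stream`) are hypothetical stand-ins, not the CCCL types, so that the snippet builds without CUDA (C++17):

```cpp
#include <type_traits>

// Stand-in for the CUDA runtime's opaque stream handle (plays the role of cudaStream_t).
struct CUstream_st_stub;
using stream_handle = CUstream_st_stub*;

// Hypothetical stand-in for a stream_ref-like wrapper with an implicit converting constructor.
struct stream_ref_like
{
  stream_handle s;
  constexpr stream_ref_like(stream_handle h) noexcept : s(h) {}
};

// Hypothetical user-provided stream type with a conversion operator to the raw handle.
struct custom_stream
{
  stream_handle s;
  constexpr operator stream_handle() const noexcept { return s; }
};

// One user-defined conversion each: both are fine.
static_assert(std::is_convertible_v<custom_stream, stream_handle>);
static_assert(std::is_convertible_v<stream_handle, stream_ref_like>);

// Chaining both implicitly would need two user-defined conversions: rejected.
static_assert(!std::is_convertible_v<custom_stream, stream_ref_like>);

// Explicit construction still works, because only the constructor argument is converted.
static_assert(std::is_constructible_v<stream_ref_like, custom_stream>);
```

Under these assumptions, checking convertibility to the raw handle type first is what lets such user types through.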
Discussed this PR with @miscco and we concluded that the environment query for

/ok to test cc8f29e

I have updated our
NVIDIA#6204 changed cub::DeviceTransform APIs from taking cudaStream_t to environments. Special handling preserved support for cudaStream_t. However, user-provided stream types with conversion operators to cudaStream_t were now queried as environments, failing to return a stream.
We should not add a special-case overload around `get_stream`. Instead, we should ensure that types convertible to `::cudaStream_t` can be passed and that their stream is extracted. I have expanded the `get_stream` CPO to accept a `::cudaStream_t __stream` so that it works with those types.
cc8f29e to 8ff0cdd
🥳 CI Workflow Results
🟩 Finished in 6h 06m: Pass: 100%/126 | Total: 5d 13h | Max: 5h 42m | Hits: 89%/250239

Successfully created backport PR for
* Fix extracting CUDA stream in cub::DeviceTransform

  #6204 changed cub::DeviceTransform APIs from taking cudaStream_t to environments. Special handling preserved support for cudaStream_t. However, user-provided stream types with conversion operators to cudaStream_t were now queried as environments, failing to return a stream.

* Properly use `get_stream` in device transform

  We should not add a special-case overload around `get_stream`. Instead, we should ensure that types convertible to `::cudaStream_t` can be passed and that their stream is extracted. I have expanded the `get_stream` CPO to accept a `::cudaStream_t __stream` so that it works with those types.

Co-authored-by: Michael Schellenberger Costa <[email protected]>
(cherry picked from commit 31f8a13)
#6204 changed `cub::DeviceTransform` APIs from taking `cudaStream_t` to environments. Special handling preserved support for `cudaStream_t`. However, user-provided stream types with conversion operators to `cudaStream_t` were now queried as environments, failing to return a stream.

This PR no longer treats a type that is convertible to `cudaStream_t` as an environment and instead extracts the underlying stream.

Fixes NVBug 5813928
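
For illustration, the dispatch described above could look roughly like the sketch below. It is not the CCCL implementation: `stream_handle` stands in for `cudaStream_t`, and `extract_stream`, `query_stream`, `stream_env`, and `empty_env` are hypothetical names chosen so the example builds without CUDA (C++17). The point is only the ordering: test for convertibility to the raw handle first, and fall back to an environment query afterwards.

```cpp
#include <type_traits>
#include <utility>

// Stand-in for the CUDA runtime's opaque stream handle (plays the role of cudaStream_t).
struct CUstream_st_stub {};
using stream_handle = CUstream_st_stub*;

// Hypothetical user type, modelled on the custom_stream from the unit test.
struct custom_stream
{
  stream_handle s;
  operator stream_handle() const { return s; }
};

// Hypothetical environment that exposes its stream through a query member.
struct stream_env
{
  stream_handle s;
  stream_handle query_stream() const { return s; }
};

// Hypothetical environment that carries no stream at all.
struct empty_env {};

// Detects whether T offers the hypothetical query_stream() member.
template <class T, class = void>
struct has_stream_query : std::false_type {};
template <class T>
struct has_stream_query<T, std::void_t<decltype(std::declval<const T&>().query_stream())>> : std::true_type {};

template <class EnvOrStream>
stream_handle extract_stream(const EnvOrStream& e)
{
  if constexpr (std::is_convertible_v<const EnvOrStream&, stream_handle>)
  {
    // Raw handles and user types with a conversion operator: not an environment.
    return static_cast<stream_handle>(e);
  }
  else if constexpr (has_stream_query<EnvOrStream>::value)
  {
    // A real environment: ask it for its stream.
    return e.query_stream();
  }
  else
  {
    // No stream available: fall back to the default (null) stream.
    return stream_handle{};
  }
}
```

With this ordering, `extract_stream(custom_stream{h})` yields `h` instead of silently falling through to the default stream, which is the failure mode the PR describes.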