-
Notifications
You must be signed in to change notification settings - Fork 291
PTX shfl_sync
#3241
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
PTX shfl_sync
#3241
Changes from 19 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
fee650c
btx shfl_sync implementation
fbusato 24b8e28
add documentation and tests
fbusato db3edb9
Update libcudacxx/test/libcudacxx/cuda/ptx/manual/shfl_test.h
fbusato eb6df1d
move documentation file
fbusato c0be178
use template parameter for input data
fbusato f0caaa9
fix return type
fbusato 969df92
update docs
fbusato d957522
modify return value type
fbusato f683984
Merge branch 'main' into ptx-shuffle
fbusato 9e577a0
Merge branch 'main' into ptx-shuffle
miscco c17b109
copyright update
fbusato cbb9d84
refactor to better match PTX generator
fbusato 3c1a37a
copyright update
fbusato 9804923
fix comparison of integers of different signs
fbusato 9258523
fix documentation
fbusato 5c38e1e
change function names to match ptx
fbusato f237b19
Merge branch 'main' into ptx-shuffle
fbusato 69dc225
NIT
fbusato 9acb0d7
move manual/shfl_test.h
fbusato 1b4798d
recover correct instructions.rst header
fbusato File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
|
|
||
| shfl.sync | ||
| ^^^^^^^^^ | ||
|
|
||
| .. code:: cuda | ||
|
|
||
| // PTX ISA 6.0 | ||
| // shfl.sync.mode.b32 d[|p], a, b, c, membermask; | ||
| // .mode = { .up, .down, .bfly, .idx }; | ||
|
|
||
| template<typename T> | ||
| [[nodiscard]] __device__ static inline | ||
| T shfl_sync_idx(T data, | ||
| uint32_t lane_idx_offset, | ||
| uint32_t clamp_segmask, | ||
| uint32_t lane_mask) noexcept; | ||
|
|
||
| template<typename T> | ||
| [[nodiscard]] __device__ static inline | ||
| T shfl_sync_idx(T data, | ||
| bool& pred, | ||
| uint32_t lane_idx_offset, | ||
| uint32_t clamp_segmask, | ||
| uint32_t lane_mask) noexcept; | ||
|
|
||
| template<typename T> | ||
| [[nodiscard]] __device__ static inline | ||
| T shfl_sync_up(T data, | ||
| uint32_t lane_idx_offset, | ||
| uint32_t clamp_segmask, | ||
| uint32_t lane_mask) noexcept; | ||
|
|
||
| template<typename T> | ||
| [[nodiscard]] __device__ static inline | ||
| T shfl_sync_up(T data, | ||
| bool& pred, | ||
| uint32_t lane_idx_offset, | ||
| uint32_t clamp_segmask, | ||
| uint32_t lane_mask) noexcept; | ||
|
|
||
| template<typename T> | ||
| [[nodiscard]] __device__ static inline | ||
| T shfl_sync_down(T data, | ||
| uint32_t lane_idx_offset, | ||
| uint32_t clamp_segmask, | ||
| uint32_t lane_mask) noexcept; | ||
|
|
||
| template<typename T> | ||
| [[nodiscard]] __device__ static inline | ||
| T shfl_sync_down(T data, | ||
| bool& pred, | ||
| uint32_t lane_idx_offset, | ||
| uint32_t clamp_segmask, | ||
| uint32_t lane_mask) noexcept; | ||
|
|
||
| template<typename T> | ||
| [[nodiscard]] __device__ static inline | ||
| T shfl_sync_bfly(T data, | ||
| uint32_t lane_idx_offset, | ||
| uint32_t clamp_segmask, | ||
| uint32_t lane_mask) noexcept; | ||
|
|
||
| template<typename T> | ||
| [[nodiscard]] __device__ static inline | ||
| T shfl_sync_bfly(T data, | ||
| bool& pred, | ||
| uint32_t lane_idx_offset, | ||
| uint32_t clamp_segmask, | ||
| uint32_t lane_mask) noexcept; | ||
|
|
||
| **Constrains and checks** | ||
|
|
||
| - ``T`` must have 32-bit size (compile-time) | ||
| - ``lane_idx_offset`` must be less than the warp size (debug mode) | ||
| - ``clamp_segmask`` must use the bit positions [0:4] and [8:12] (debug mode) | ||
| - ``lane_mask`` must be a subset of the active mask (debug mode) | ||
| - The destination lane must be a member of the ``lane_mask`` (debug mode) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
|
|
||
| .. _libcudacxx-ptx-instructions-shfl_sync: | ||
|
|
||
| shfl.sync | ||
| ========= | ||
|
|
||
| - PTX ISA: | ||
| `shfl.sync <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-shfl-sync>`__ | ||
|
|
||
| .. include:: generated/shfl_sync.rst |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,272 @@ | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // Part of libcu++, the C++ Standard Library for your entire system, | ||
| // under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef _CUDA_PTX_SHFL_SYNC_H | ||
| #define _CUDA_PTX_SHFL_SYNC_H | ||
|
|
||
| #include <cuda/std/detail/__config> | ||
|
|
||
| #if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) | ||
| # pragma GCC system_header | ||
| #elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) | ||
| # pragma clang system_header | ||
| #elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) | ||
| # pragma system_header | ||
| #endif // no system header | ||
|
|
||
| #include <cuda/__ptx/instructions/get_sreg.h> | ||
| #include <cuda/__ptx/ptx_dot_variants.h> | ||
| #include <cuda/std/__bit/bit_cast.h> | ||
| #include <cuda/std/cstdint> | ||
|
|
||
| #include <nv/target> // __CUDA_MINIMUM_ARCH__ and friends | ||
|
|
||
| _LIBCUDACXX_BEGIN_NAMESPACE_CUDA_PTX | ||
|
|
||
| #if __cccl_ptx_isa >= 600 | ||
|
|
||
| enum class __dot_shfl_mode | ||
| { | ||
| __up, | ||
| __down, | ||
| __bfly, | ||
| __idx | ||
| }; | ||
|
|
||
| [[maybe_unused]] | ||
| _CCCL_DEVICE static inline _CUDA_VSTD::uint32_t __shfl_sync_dst_lane( | ||
| __dot_shfl_mode __shfl_mode, | ||
| _CUDA_VSTD::uint32_t __lane_idx_offset, | ||
| _CUDA_VSTD::uint32_t __clamp_segmask, | ||
| _CUDA_VSTD::uint32_t __lane_mask) | ||
| { | ||
| auto __lane = ::cuda::ptx::get_sreg_laneid(); | ||
| auto __clamp = __clamp_segmask & 0b11111; | ||
| auto __segmask = __clamp_segmask >> 8; | ||
| auto __max_lane = (__lane & __segmask) | (__clamp & ~__segmask); | ||
| _CUDA_VSTD::uint32_t __j = 0; | ||
| if (__shfl_mode == __dot_shfl_mode::__idx) | ||
| { | ||
| auto __min_lane = __lane & __clamp; | ||
| __j = __min_lane | (__lane_idx_offset & ~__segmask); | ||
| } | ||
| else if (__shfl_mode == __dot_shfl_mode::__up) | ||
| { | ||
| __j = __lane - __lane_idx_offset; | ||
| } | ||
| else if (__shfl_mode == __dot_shfl_mode::__down) | ||
| { | ||
| __j = __lane + __lane_idx_offset; | ||
| } | ||
| else | ||
| { | ||
| __j = __lane ^ __lane_idx_offset; | ||
| } | ||
| auto __dst = __shfl_mode == __dot_shfl_mode::__up | ||
| ? (__j >= __max_lane ? __j : __lane) // | ||
| : (__j <= __max_lane ? __j : __lane); | ||
| return (1u << __dst); | ||
| } | ||
|
|
||
| template <typename _Tp> | ||
| _CCCL_DEVICE static inline void __shfl_sync_checks( | ||
| __dot_shfl_mode __shfl_mode, | ||
| _Tp, | ||
| _CUDA_VSTD::uint32_t __lane_idx_offset, | ||
| _CUDA_VSTD::uint32_t __clamp_segmask, | ||
| _CUDA_VSTD::uint32_t __lane_mask) | ||
| { | ||
| static_assert(sizeof(_Tp) == 4, "shfl.sync only accepts 4-byte data types"); | ||
| if (__shfl_mode != __dot_shfl_mode::__idx) | ||
| { | ||
| _CCCL_ASSERT(__lane_idx_offset < 32, "the lane index or offset must be less than the warp size"); | ||
| } | ||
| _CCCL_ASSERT((__clamp_segmask | 0b1111100011111) == 0b1111100011111, | ||
| "clamp value + segmentation mask must use the bit positions [0:4] and [8:12]"); | ||
| _CCCL_ASSERT((__lane_mask & __activemask()) == __lane_mask, "lane mask must be a subset of the active mask"); | ||
| _CCCL_ASSERT( | ||
| ::cuda::ptx::__shfl_sync_dst_lane(__shfl_mode, __lane_idx_offset, __clamp_segmask, __lane_mask) & __lane_mask, | ||
| "the destination lane must be a member of the lane mask"); | ||
| } | ||
|
|
||
| template <typename _Tp> | ||
| _CCCL_NODISCARD _CCCL_DEVICE static inline _Tp shfl_sync_idx( | ||
| _Tp __data, | ||
| bool& __pred, | ||
| _CUDA_VSTD::uint32_t __lane_idx_offset, | ||
| _CUDA_VSTD::uint32_t __clamp_segmask, | ||
| _CUDA_VSTD::uint32_t __lane_mask) noexcept | ||
| { | ||
| __shfl_sync_checks(__dot_shfl_mode::__idx, __data, __lane_idx_offset, __clamp_segmask, __lane_mask); | ||
| auto __data1 = _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__data); | ||
| _CUDA_VSTD::int32_t __pred1; | ||
| _CUDA_VSTD::uint32_t __ret; | ||
| asm volatile( | ||
| "{ \n\t\t" | ||
| ".reg .pred p; \n\t\t" | ||
| "shfl.sync.idx.b32 %0|p, %2, %3, %4, %5; \n\t\t" | ||
| "selp.s32 %1, 1, 0, p; \n\t" | ||
| "}" | ||
| : "=r"(__ret), "=r"(__pred1) | ||
| : "r"(__data1), "r"(__lane_idx_offset), "r"(__clamp_segmask), "r"(__lane_mask)); | ||
| __pred = static_cast<bool>(__pred1); | ||
| return _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__ret); | ||
| } | ||
|
|
||
| template <typename _Tp> | ||
| _CCCL_NODISCARD _CCCL_DEVICE static inline _Tp shfl_sync_idx( | ||
| _Tp __data, | ||
| _CUDA_VSTD::uint32_t __lane_idx_offset, | ||
| _CUDA_VSTD::uint32_t __clamp_segmask, | ||
| _CUDA_VSTD::uint32_t __lane_mask) noexcept | ||
| { | ||
| __shfl_sync_checks(__dot_shfl_mode::__idx, __data, __lane_idx_offset, __clamp_segmask, __lane_mask); | ||
| auto __data1 = _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__data); | ||
| _CUDA_VSTD::uint32_t __ret; | ||
| asm volatile("{ \n\t\t" | ||
| "shfl.sync.idx.b32 %0, %1, %2, %3, %4; \n\t\t" | ||
| "}" | ||
| : "=r"(__ret) | ||
| : "r"(__data1), "r"(__lane_idx_offset), "r"(__clamp_segmask), "r"(__lane_mask)); | ||
| return _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__ret); | ||
| } | ||
|
|
||
| template <typename _Tp> | ||
| _CCCL_NODISCARD _CCCL_DEVICE static inline _Tp shfl_sync_up( | ||
| _Tp __data, | ||
| bool& __pred, | ||
| _CUDA_VSTD::uint32_t __lane_idx_offset, | ||
| _CUDA_VSTD::uint32_t __clamp_segmask, | ||
| _CUDA_VSTD::uint32_t __lane_mask) noexcept | ||
| { | ||
| __shfl_sync_checks(__dot_shfl_mode::__up, __data, __lane_idx_offset, __clamp_segmask, __lane_mask); | ||
| auto __data1 = _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__data); | ||
| _CUDA_VSTD::int32_t __pred1; | ||
| _CUDA_VSTD::uint32_t __ret; | ||
| asm volatile( | ||
| "{ \n\t\t" | ||
| ".reg .pred p; \n\t\t" | ||
| "shfl.sync.up.b32 %0|p, %2, %3, %4, %5; \n\t\t" | ||
| "selp.s32 %1, 1, 0, p; \n\t" | ||
| "}" | ||
| : "=r"(__ret), "=r"(__pred1) | ||
| : "r"(__data1), "r"(__lane_idx_offset), "r"(__clamp_segmask), "r"(__lane_mask)); | ||
| __pred = static_cast<bool>(__pred1); | ||
| return _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__ret); | ||
| } | ||
|
|
||
| template <typename _Tp> | ||
| _CCCL_NODISCARD _CCCL_DEVICE static inline _Tp shfl_sync_up( | ||
| _Tp __data, | ||
| _CUDA_VSTD::uint32_t __lane_idx_offset, | ||
| _CUDA_VSTD::uint32_t __clamp_segmask, | ||
| _CUDA_VSTD::uint32_t __lane_mask) noexcept | ||
| { | ||
| __shfl_sync_checks(__dot_shfl_mode::__up, __data, __lane_idx_offset, __clamp_segmask, __lane_mask); | ||
| auto __data1 = _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__data); | ||
| _CUDA_VSTD::uint32_t __ret; | ||
| asm volatile("{ \n\t\t" | ||
| "shfl.sync.up.b32 %0, %1, %2, %3, %4; \n\t\t" | ||
| "}" | ||
| : "=r"(__ret) | ||
| : "r"(__data1), "r"(__lane_idx_offset), "r"(__clamp_segmask), "r"(__lane_mask)); | ||
| return _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__ret); | ||
| } | ||
|
|
||
| template <typename _Tp> | ||
| _CCCL_NODISCARD _CCCL_DEVICE static inline _Tp shfl_sync_down( | ||
| _Tp __data, | ||
| bool& __pred, | ||
| _CUDA_VSTD::uint32_t __lane_idx_offset, | ||
| _CUDA_VSTD::uint32_t __clamp_segmask, | ||
| _CUDA_VSTD::uint32_t __lane_mask) noexcept | ||
| { | ||
| __shfl_sync_checks(__dot_shfl_mode::__down, __data, __lane_idx_offset, __clamp_segmask, __lane_mask); | ||
| auto __data1 = _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__data); | ||
| _CUDA_VSTD::int32_t __pred1; | ||
| _CUDA_VSTD::uint32_t __ret; | ||
| asm volatile( | ||
| "{ \n\t\t" | ||
| ".reg .pred p; \n\t\t" | ||
| "shfl.sync.down.b32 %0|p, %2, %3, %4, %5; \n\t\t" | ||
| "selp.s32 %1, 1, 0, p; \n\t" | ||
| "}" | ||
| : "=r"(__ret), "=r"(__pred1) | ||
| : "r"(__data1), "r"(__lane_idx_offset), "r"(__clamp_segmask), "r"(__lane_mask)); | ||
| __pred = static_cast<bool>(__pred1); | ||
| return _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__ret); | ||
| } | ||
|
|
||
| template <typename _Tp> | ||
| _CCCL_NODISCARD _CCCL_DEVICE static inline _Tp shfl_sync_down( | ||
| _Tp __data, | ||
| _CUDA_VSTD::uint32_t __lane_idx_offset, | ||
| _CUDA_VSTD::uint32_t __clamp_segmask, | ||
| _CUDA_VSTD::uint32_t __lane_mask) noexcept | ||
| { | ||
| __shfl_sync_checks(__dot_shfl_mode::__down, __data, __lane_idx_offset, __clamp_segmask, __lane_mask); | ||
| auto __data1 = _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__data); | ||
| _CUDA_VSTD::uint32_t __ret; | ||
| asm volatile("{ \n\t\t" | ||
| "shfl.sync.down.b32 %0, %1, %2, %3, %4; \n\t\t" | ||
| "}" | ||
| : "=r"(__ret) | ||
| : "r"(__data1), "r"(__lane_idx_offset), "r"(__clamp_segmask), "r"(__lane_mask)); | ||
| return _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__ret); | ||
| } | ||
|
|
||
| template <typename _Tp> | ||
| _CCCL_NODISCARD _CCCL_DEVICE static inline _Tp shfl_sync_bfly( | ||
| _Tp __data, | ||
| bool& __pred, | ||
| _CUDA_VSTD::uint32_t __lane_idx_offset, | ||
| _CUDA_VSTD::uint32_t __clamp_segmask, | ||
| _CUDA_VSTD::uint32_t __lane_mask) noexcept | ||
| { | ||
| __shfl_sync_checks(__dot_shfl_mode::__bfly, __data, __lane_idx_offset, __clamp_segmask, __lane_mask); | ||
| auto __data1 = _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__data); | ||
| _CUDA_VSTD::int32_t __pred1; | ||
| _CUDA_VSTD::uint32_t __ret; | ||
| asm volatile( | ||
| "{ \n\t\t" | ||
| ".reg .pred p; \n\t\t" | ||
| "shfl.sync.bfly.b32 %0|p, %2, %3, %4, %5; \n\t\t" | ||
| "selp.s32 %1, 1, 0, p; \n\t" | ||
| "}" | ||
| : "=r"(__ret), "=r"(__pred1) | ||
| : "r"(__data1), "r"(__lane_idx_offset), "r"(__clamp_segmask), "r"(__lane_mask)); | ||
| __pred = static_cast<bool>(__pred1); | ||
| return _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__ret); | ||
| } | ||
|
|
||
| template <typename _Tp> | ||
| _CCCL_NODISCARD _CCCL_DEVICE static inline _Tp shfl_sync_bfly( | ||
| _Tp __data, | ||
| _CUDA_VSTD::uint32_t __lane_idx_offset, | ||
| _CUDA_VSTD::uint32_t __clamp_segmask, | ||
| _CUDA_VSTD::uint32_t __lane_mask) noexcept | ||
| { | ||
| __shfl_sync_checks(__dot_shfl_mode::__bfly, __data, __lane_idx_offset, __clamp_segmask, __lane_mask); | ||
| auto __data1 = _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__data); | ||
| _CUDA_VSTD::uint32_t __ret; | ||
| asm volatile( // | ||
| "{ \n\t\t" | ||
| "shfl.sync.bfly.b32 %0, %1, %2, %3, %4; \n\t\t" | ||
| "}" | ||
| : "=r"(__ret) | ||
| : "r"(__data1), "r"(__lane_idx_offset), "r"(__clamp_segmask), "r"(__lane_mask)); | ||
| return _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__ret); | ||
| } | ||
|
|
||
| #endif // __cccl_ptx_isa >= 600 | ||
|
|
||
| _LIBCUDACXX_END_NAMESPACE_CUDA_PTX | ||
|
|
||
| #endif // _CUDA_PTX_SHFL_SYNC_H |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.