Support group query attention in Attention(23) CUDA #27082
base: main
Conversation
@copilot Can you enable the following three tests?
Just turn
@copilot See onnxruntime/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc, lines 44 to 45 in fd21d0a.
Removed GQA test filters from onnx_backend_test_series_filters.jsonc.
GQA tests will now run on CUDA.
Please fix the CI: https://github.com/microsoft/onnxruntime/actions/runs/21189348181/job/60952159971?pr=27082
The error message:
Fixed CI linker error in commit 213a82d. The issue was that the GQA
@copilot
Reverted float support for GQA in commit f79c509, following the GroupQueryAttention pattern:
You can commit the suggested changes from lintrunner.
// Check if this is Group Query Attention (GQA)
const bool is_gqa = parameters.kv_num_heads != parameters.q_num_heads;

if (is_gqa) {
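As a minimal, self-contained sketch of what this check implies (not the PR's code), the snippet below shows the usual GQA constraint: q_num_heads must be an integer multiple of kv_num_heads, and each KV head is shared by a group of query heads. The AttentionParameters struct and the head counts here are hypothetical stand-ins for the real parameter type.

```cpp
// Illustrative sketch only: the struct and values are hypothetical stand-ins.
#include <cassert>
#include <cstdio>

struct AttentionParameters {
  int q_num_heads;
  int kv_num_heads;
};

int main() {
  AttentionParameters parameters{/*q_num_heads=*/32, /*kv_num_heads=*/8};

  // GQA: fewer KV heads than query heads; each KV head is shared by a group of query heads.
  const bool is_gqa = parameters.kv_num_heads != parameters.q_num_heads;

  if (is_gqa) {
    // The grouping is only well defined when q_num_heads is a multiple of kv_num_heads.
    assert(parameters.q_num_heads % parameters.kv_num_heads == 0);
    const int group_size = parameters.q_num_heads / parameters.kv_num_heads;
    std::printf("GQA: %d query heads share each KV head\n", group_size);
  }
  return 0;
}
```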
TODO: Currently, we do not support 4D inputs of QKV.
This pull request adds a new CUDA implementation for converting boolean attention masks to sequence lengths, specifically for use with GQA (Grouped Query Attention) kernels. The implementation includes both the CUDA kernel and its launch function, along with detailed validation to ensure the mask follows right-padding conventions and is broadcastable across various mask shapes (2D, 3D, 4D). The changes also provide comprehensive documentation and error handling.
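As a rough illustration of the mask-to-sequence-lengths idea described above (a minimal sketch, not the PR's kernel), the CUDA program below handles only the simplest 2D [batch, total_sequence_length] case with one thread per batch row. The kernel name, launch configuration, and shapes are hypothetical; the actual implementation additionally supports 3D/4D masks, broadcasting, and validation of the right-padding pattern.

```cuda
// Minimal sketch: convert a right-padded boolean mask [batch, total_seq_len]
// into per-batch sequence lengths. Names and launch parameters are hypothetical.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

__global__ void MaskToSeqlensSketch(const bool* mask, int* seqlens,
                                    int batch, int total_seq_len) {
  const int b = blockIdx.x * blockDim.x + threadIdx.x;
  if (b >= batch) return;

  // With right padding, the sequence length is the count of leading True values.
  const bool* row = mask + static_cast<size_t>(b) * total_seq_len;
  int len = 0;
  while (len < total_seq_len && row[len]) {
    ++len;
  }
  seqlens[b] = len;
}

int main() {
  const int batch = 2, total_seq_len = 4;
  // Two right-padded rows with sequence lengths 3 and 2.
  std::vector<char> h_mask = {1, 1, 1, 0,
                              1, 1, 0, 0};

  bool* d_mask = nullptr;
  int* d_seqlens = nullptr;
  cudaMalloc(&d_mask, h_mask.size() * sizeof(bool));
  cudaMalloc(&d_seqlens, batch * sizeof(int));
  cudaMemcpy(d_mask, h_mask.data(), h_mask.size() * sizeof(bool), cudaMemcpyHostToDevice);

  MaskToSeqlensSketch<<<1, 32>>>(d_mask, d_seqlens, batch, total_seq_len);

  std::vector<int> h_seqlens(batch);
  cudaMemcpy(h_seqlens.data(), d_seqlens, batch * sizeof(int), cudaMemcpyDeviceToHost);
  std::printf("seqlens: %d %d\n", h_seqlens[0], h_seqlens[1]);  // expected: 3 2

  cudaFree(d_mask);
  cudaFree(d_seqlens);
  return 0;
}
```

One thread per batch row keeps the sketch simple; a production kernel would more likely parallelize within each row (for example with a block-level reduction over the mask) and report a validation failure when the True values are not contiguous.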
New CUDA kernel and API for attention mask processing:
- Added the ConvertMaskToSeqlensKernel CUDA kernel in attention_mask_impl.cu to convert boolean attention masks to sequence lengths, supporting 2D, 3D, and 4D masks with broadcasting and validation for right-padding and contiguous True/False patterns.
- Added the LaunchConvertMaskToSeqlensK function in attention_mask_impl.cu and declared it in attention_mask_impl.h, providing a public API to launch the kernel and handle validation results. [1] [2]

Validation and error handling:
Documentation and code organization: