Conversation

Copilot AI (Contributor) commented Jan 20, 2026

This pull request adds a new CUDA implementation for converting boolean attention masks to sequence lengths, specifically for use with GQA (Grouped Query Attention) kernels. The implementation includes both the CUDA kernel and its launch function, along with detailed validation to ensure the mask follows right-padding conventions and is broadcastable across various mask shapes (2D, 3D, 4D). The changes also provide comprehensive documentation and error handling.

New CUDA kernel and API for attention mask processing:

  • Added the ConvertMaskToSeqlensKernel CUDA kernel in attention_mask_impl.cu to convert boolean attention masks to sequence lengths, supporting 2D, 3D, and 4D masks with broadcasting and validation for right-padding and contiguous True/False patterns (a simplified sketch follows this list).
  • Implemented the LaunchConvertMaskToSeqlensK function in attention_mask_impl.cu and declared it in attention_mask_impl.h, providing a public API that launches the kernel and surfaces validation results.
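
For intuition, here is a minimal sketch of the conversion for the simplest case: a 2D mask of shape [batch_size, seq_len] with one thread per row. The actual kernel in attention_mask_impl.cu additionally handles 3D/4D shapes, broadcasting, and its own launch configuration, so the name and signature below are illustrative only:

```cuda
#include <cstdint>
#include <cuda_runtime.h>

// Simplified sketch of the mask-to-seqlens conversion for a 2D boolean mask
// [batch_size, seq_len]; the real kernel also supports 3D/4D masks with
// broadcasting. One thread handles one batch row.
__global__ void ConvertMaskToSeqlensSketch(const bool* mask, int32_t* seqlens,
                                           int32_t* validation_results,
                                           int batch_size, int seq_len) {
  const int b = blockIdx.x * blockDim.x + threadIdx.x;
  if (b >= batch_size) return;

  const bool* row = mask + static_cast<size_t>(b) * seq_len;

  // Count the leading run of True values.
  int count = 0;
  while (count < seq_len && row[count]) ++count;

  // Right-padding rule: the row must start with True, and after the first
  // False every remaining value must be False. Violations are flagged so the
  // host can report an error instead of using a bogus sequence length.
  bool valid = count > 0;
  for (int i = count; i < seq_len && valid; ++i) {
    if (row[i]) valid = false;
  }

  validation_results[b] = valid ? 0 : 1;
  seqlens[b] = count;  // meaningful only when the row validated cleanly
}
```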

Validation and error handling:

  • The kernel checks that each mask row starts with True and that all True values are contiguous, followed only by False (padding); errors are written to a validation results buffer.
  • The public API returns an error when the mask is not right-padded or contains non-contiguous True values, so malformed masks fail loudly instead of silently producing wrong sequence lengths; the examples below illustrate the rule.
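
To make the right-padding rule concrete, these hypothetical rows (seq_len = 8, 1 = True, 0 = False) show what passes validation and what is rejected:

```
1 1 1 1 0 0 0 0   -> valid, seqlen = 4
1 1 1 1 1 1 1 1   -> valid, seqlen = 8
0 0 0 1 1 1 1 1   -> invalid: row does not start with True (left padding)
1 1 0 1 1 0 0 0   -> invalid: True values are not contiguous
```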

Documentation and code organization:

  • Added comprehensive comments and documentation to both the kernel implementation and the header file, describing mask shape handling, broadcasting rules, and validation logic.

Copilot AI changed the title from "[WIP] Implement support for group query attention in Attention(23)" to "Support group query attention in Attention(23) CUDA" on Jan 20, 2026
Copilot AI requested a review from titaiwangms January 20, 2026 22:05
@titaiwangms (Contributor) commented:

@copilot Can you enable the following three tests

TEST(AttentionTest, Attention3DGqaAttn) {

TEST(AttentionTest, Attention4DGqaAttnMask) {

TEST(AttentionTest, Attention4DGqaWithPastAndPresent) {

Just turn disable_cuda to false

Copilot AI (Contributor Author) commented Jan 20, 2026

> @copilot Can you enable the following three tests
>
> TEST(AttentionTest, Attention3DGqaAttn) {
>
> ...

Enabled CUDA tests for all three GQA test cases in commit 2e10874:

  • Attention3DGqaAttn
  • Attention4DGqaAttnMask
  • Attention4DGqaWithPastAndPresent

@titaiwangms (Contributor) commented:

@copilot
Also remove these disabled GQA tests from the list:

"^test_attention_3d_gqa.*_cuda", // GQA not supported in Attention-cuda
"^test_attention_4d_gqa.*_cuda", // GQA not supported in Attention-cuda

Copilot AI (Contributor Author) commented Jan 20, 2026

> @copilot
> Also remove these disabled tests about GQA from the list:
> https://github.com/microsoft/onnxruntime/blob/fd21d0aa071edef70d715123b25e77891a63a5d9/onnxruntime/test/testdata/onnx_backend_test...

Removed GQA test filters from onnx_backend_test_series_filters.jsonc in commit b86acbd. The following lines were removed:

  • "^test_attention_3d_gqa.*_cuda"
  • "^test_attention_4d_gqa.*_cuda"

GQA tests will now run on CUDA.

@titaiwangms (Contributor) commented:

@copilot

Please fix the CI https://github.com/microsoft/onnxruntime/actions/runs/21189348181/job/60952159971?pr=27082

The error message:

LINK : warning LNK4075: ignoring '/INCREMENTAL' due to '/PROFILE' specification [E:\_work\_temp\build\RelWithDebInfo\onnxruntime_providers_cuda.vcxproj]
     Creating library E:/_work/_temp/build/RelWithDebInfo/RelWithDebInfo/onnxruntime_providers_cuda.lib and object E:/_work/_temp/build/RelWithDebInfo/RelWithDebInfo/onnxruntime_providers_cuda.exp
attention.cc.obj : error LNK2019: unresolved external symbol "class onnxruntime::common::Status __cdecl onnxruntime::contrib::cuda::QkvToContext<float>(struct cudaDeviceProp const &,struct cublasContext * &,class onnxruntime::Stream *,struct onnxruntime::contrib::GroupQueryAttentionParameters &,struct onnxruntime::contrib::cuda::GroupQueryAttentionData<float> &)" (??$QkvToContext@M@cuda@contrib@onnxruntime@@YA?AVStatus@common@2@AEBUcudaDeviceProp@@AEAPEAUcublasContext@@PEAVStream@2@AEAUGroupQueryAttentionParameters@12@AEAU?$GroupQueryAttentionData@M@012@@Z) referenced in function "public: virtual class onnxruntime::common::Status __cdecl onnxruntime::cuda::Attention<float>::ComputeInternal(struct onnxruntime::OpKernelContext *)const " (?ComputeInternal@?$Attention@M@cuda@onnxruntime@@UEBA?AVStatus@common@3@PEAUOpKernelContext@3@@Z) [E:\_work\_temp\build\RelWithDebInfo\onnxruntime_providers_cuda.vcxproj]
E:\_work\_temp\build\RelWithDebInfo\RelWithDebInfo\onnxruntime_providers_cuda.dll : fatal error LNK1120: 1 unresolved externals [E:\_work\_temp\build\RelWithDebInfo\onnxruntime_providers_cuda.vcxproj]
Traceback (most recent call last):
  File "E:\_work\onnxruntime\onnxruntime\tools\ci_build\build.py", line 2684, in <module>

Copilot AI (Contributor Author) commented Jan 20, 2026

> @copilot
>
> Please fix the CI https://github.com/microsoft/onnxruntime/actions/runs/21189348181/job/60952159971?pr=27082
> ...

Fixed CI linker error in commit 213a82d. The issue was that the GQA QkvToContext template function only had instantiations for half and BFloat16, but not for float. Added the missing float template instantiation along with GroupQueryAttentionData<float> and LaunchUnpackQKV<float, LAYOUT_BNSH> to group_query_attention_impl.cu.
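
The underlying C++ pattern is worth spelling out: when a template is defined in a .cu/.cc file rather than a header, object code is emitted only for the explicitly instantiated types, and any other type fails at link time with LNK2019. A standalone sketch of the pattern (names hypothetical, not the actual onnxruntime signatures):

```cpp
// impl.cc (sketch): the template body is visible only in this translation unit.
#include <cstddef>

template <typename T>
std::size_t CountNonZero(const T* data, std::size_t n) {
  std::size_t count = 0;
  for (std::size_t i = 0; i < n; ++i) {
    if (data[i] != T{}) ++count;
  }
  return count;
}

// Explicit instantiations: only these types get object code here. In the PR,
// QkvToContext had lines like this for half and BFloat16 but not for float,
// so the call from Attention<float>::ComputeInternal could not be resolved.
template std::size_t CountNonZero<double>(const double*, std::size_t);
template std::size_t CountNonZero<float>(const float*, std::size_t);  // the "missing" line
```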

@titaiwangms (Contributor) commented:

> @copilot
> Please fix the CI https://github.com/microsoft/onnxruntime/actions/runs/21189348181/job/60952159971?pr=27082
> ...

@copilot
Do not add float type. Just throw an error saying float is not supported with GQA. We are following GQA kernel.

REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)

Copilot AI (Contributor Author) commented Jan 20, 2026

> @copilot
> Please fix the CI https://github.com/microsoft/onnxruntime/actions/runs/21189348181/job/60952159971?pr=27082
> ...
> ...

Reverted float support for GQA in commit f79c509, following the GroupQueryAttention pattern:

  • Removed the float template instantiations from group_query_attention_impl.cu
  • Added type validation in attention.cc that raises an error for float32 inputs on the GQA path (see the sketch after this list)
  • GQA now only supports float16 and bfloat16, matching the GroupQueryAttention operator's registered types
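
A sketch of the kind of guard described in the second bullet; the exact message, macro usage, and placement inside Attention<T>::ComputeInternal are assumptions, since the PR only states that float32 is rejected on the GQA path:

```cpp
// Inside Attention<T>::ComputeInternal (attention.cc), once is_gqa is known.
// Hypothetical wording; onnxruntime's ORT_MAKE_STATUS macro and
// std::is_same (from <type_traits>) are assumed available here.
if (is_gqa && std::is_same<T, float>::value) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
                         "GQA in Attention CUDA supports only float16 and "
                         "bfloat16 inputs; float32 is not supported.");
}
```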

titaiwangms added the ep:CUDA label (issues related to the CUDA execution provider) on Jan 21, 2026
github-actions bot left a comment:

You can commit the suggested changes from lintrunner.

titaiwangms marked this pull request as ready for review on January 30, 2026 01:20
// Check if this is Group Query Attention (GQA)
const bool is_gqa = parameters.kv_num_heads != parameters.q_num_heads;

if (is_gqa) {
A Contributor commented on the diff above:

TODO: Currently, we do not support 4D inputs of QKV.
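
One possible resolution for this TODO, purely hypothetical (the query_dims variable and the message wording are illustrative, not from the PR), would be an explicit guard rather than a silent fallthrough:

```cpp
// Hypothetical guard for the unsupported case; query_dims is assumed to hold
// the query tensor's shape at this point in ComputeInternal.
if (is_gqa && query_dims.size() == 4) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
                         "4D QKV inputs are not yet supported on the GQA path.");
}
```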
