Commits (24)
e538f4b  Initial plan (Copilot, Jan 20, 2026)
f2007a2  Add GQA support to Attention(23) CUDA operator (Copilot, Jan 20, 2026)
53c333f  Add debug tracking fields and num_splits parameter (Copilot, Jan 20, 2026)
042ff32  Fix code review issues: use v_head_size and parameters.softcap (Copilot, Jan 20, 2026)
0e7a632  Set softcap to 0.0f explicitly with comment (Copilot, Jan 20, 2026)
2e10874  Enable CUDA tests for GQA attention tests (Copilot, Jan 20, 2026)
b86acbd  Remove GQA test filters from disabled tests list (Copilot, Jan 20, 2026)
213a82d  Add float template instantiation for GQA QkvToContext (Copilot, Jan 20, 2026)
f79c509  Revert float support for GQA and add type validation (Copilot, Jan 20, 2026)
e52efb2  change gqa tests to fp16 (titaiwangms, Jan 22, 2026)
98c5dcf  examine gqa parameters and move down MHA parameters (titaiwangms, Jan 23, 2026)
0f800f5  Merge branch 'main' into copilot/support-group-query-attention (titaiwangms, Jan 23, 2026)
4978e96  support gqa bool masking (titaiwangms, Jan 23, 2026)
4c644e2  add flash/memory draft (titaiwangms, Jan 27, 2026)
f04b38e  Merge branch 'main' into copilot/support-group-query-attention (titaiwangms, Jan 27, 2026)
16d5453  finish gqa default (titaiwangms, Jan 28, 2026)
54d77ae  Apply suggestion from @titaiwangms (titaiwangms, Jan 28, 2026)
87a5648  introduce python attention tests for gqa (titaiwangms, Jan 28, 2026)
5981041  lint (titaiwangms, Jan 28, 2026)
6d7e50a  support attn_mask (titaiwangms, Jan 30, 2026)
d1cb063  Merge branch 'main' into copilot/support-group-query-attention (titaiwangms, Jan 30, 2026)
e2a4032  clean up and use ORT_MAKE_STATUS (titaiwangms, Jan 30, 2026)
dcb937a  Merge branch 'main' into copilot/support-group-query-attention (titaiwangms, Jan 30, 2026)
2509464  fix cpu bugs on fp16 (titaiwangms, Feb 2, 2026)
417 changes: 396 additions & 21 deletions onnxruntime/core/providers/cuda/llm/attention.cc

Large diffs are not rendered by default.

149 changes: 149 additions & 0 deletions onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu
@@ -0,0 +1,149 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/llm/attention_mask_impl.h"

#include <algorithm>  // std::min

#include "core/providers/cuda/cu_inc/common.cuh"

namespace onnxruntime {
namespace cuda {

// Validation error codes (stored in validation_result buffer)
constexpr int kValidationOK = 0;
constexpr int kValidationErrorNotStartWithTrue = 1;
constexpr int kValidationErrorNotContiguous = 2;

// CUDA kernel to convert boolean attention mask to sequence lengths.
// Also validates that the mask follows right-padding convention.
//
// The kernel processes one batch per thread.
// For each batch, it finds the first False in the mask row, which marks where padding starts.
// The sequence length is the index of that first False (or total_seq_len if the row has no False).
//
// Validation:
// - The mask must start with True (the first element must be True)
// - After the first False, all remaining elements must be False (contiguous right padding)
//
// Broadcasting:
// - 2D mask (batch_size or 1, total_seq_len): one mask row per batch; a single shared row
//   is used when the first dimension is 1
// - 3D mask (num_heads, q_seq_len, total_seq_len): broadcasts to [1, num_heads, q_seq, total_seq];
//   there is no batch dimension, so the first head and first query position are used for all batches
// - 4D mask (B, H, q_seq_len, total_seq_len): the first head and first query position are used;
//   B may be batch_size or 1 (broadcast)
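//
// Example (illustrative): for a 2D mask row [T, T, T, F, F] with total_seq_len = 5, the first
// False is at index 3, so seq_len = 3 and seqlens_k = 2. A row such as [T, F, T, F, F] is
// rejected with kValidationErrorNotContiguous because a True follows the first False.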
__global__ void ConvertMaskToSeqlensKernel(
const bool* __restrict__ attn_mask,
int* __restrict__ seqlens_k,
int* __restrict__ validation_result,
const int batch_size,
const int total_seq_len,
const int mask_dims,
const int64_t mask_dim0,
const int64_t mask_dim1,
const int64_t mask_dim2) {
int batch_idx = threadIdx.x + blockIdx.x * blockDim.x;
if (batch_idx >= batch_size) {
return;
}

// Calculate the starting offset for this batch's mask row
// We need to figure out which row of the mask to use based on broadcasting rules
const bool* mask_row = nullptr;

if (mask_dims == 2) {
// Shape: (batch_size or 1, total_seq_len)
// If mask_dim0 == 1, broadcast across all batches
int effective_batch = (mask_dim0 == 1) ? 0 : batch_idx;
mask_row = attn_mask + effective_batch * total_seq_len;
} else if (mask_dims == 3) {
// Shape: (num_heads, q_seq_len, total_seq_len)
// This broadcasts to [1, num_heads, q_seq, total_seq] - same mask for all batches
// We look at first head (h_idx = 0) and first q position (q_idx = 0)
int h_idx = 0; // First head
int q_idx = 0; // First query position
// Stride: q_seq_len * total_seq_len per head
int64_t head_stride = mask_dim1 * total_seq_len; // mask_dim1 = q_seq_len
int64_t q_stride = total_seq_len;
// Same mask row for all batches since 3D has no batch dimension
mask_row = attn_mask + h_idx * head_stride + q_idx * q_stride;
} else {
// 4D: Shape (B, H, q_seq_len, total_seq_len)
// B could be batch_size or 1 (broadcast)
// H could be num_heads or 1 (broadcast)
// We look at first head (h_idx = 0) and first q position (q_idx = 0)
int effective_batch = (mask_dim0 == 1) ? 0 : batch_idx;
int h_idx = 0; // First head
int q_idx = 0; // First query position
// Strides
int64_t batch_stride = mask_dim1 * mask_dim2 * total_seq_len;
int64_t head_stride = mask_dim2 * total_seq_len;
int64_t q_stride = total_seq_len;
mask_row = attn_mask + effective_batch * batch_stride + h_idx * head_stride + q_idx * q_stride;
}

// Initialize validation result for this batch
validation_result[batch_idx] = kValidationOK;

// Check that mask starts with True
if (!mask_row[0]) {
validation_result[batch_idx] = kValidationErrorNotStartWithTrue;
seqlens_k[batch_idx] = -1; // Invalid
return;
}

// Find the first False (where padding starts)
// All elements before this should be True, all after should be False
int seq_len = total_seq_len; // Default: all True (no padding)
bool found_first_false = false;

for (int i = 1; i < total_seq_len; ++i) {
bool current = mask_row[i];

if (!found_first_false && !current) {
// Found first False - this is where padding starts
seq_len = i;
found_first_false = true;
} else if (found_first_false && current) {
// Found True after False - this is invalid (not contiguous)
validation_result[batch_idx] = kValidationErrorNotContiguous;
seqlens_k[batch_idx] = -1; // Invalid
return;
}
}

// Per the GQA convention, seqlens_k stores (valid sequence length - 1) for each batch.
seqlens_k[batch_idx] = seq_len - 1;
}

Status LaunchConvertMaskToSeqlensK(
const bool* attn_mask_bool,
int* seqlens_k,
int* validation_result,
int batch_size,
int total_seq_len,
int mask_dims,
int64_t mask_dim0,
int64_t mask_dim1,
int64_t mask_dim2,
cudaStream_t stream,
int max_threads_per_block) {
if (batch_size == 0 || total_seq_len == 0) {
return Status::OK();
}

int threads = std::min(batch_size, max_threads_per_block);

int blocks = (batch_size + threads - 1) / threads;

ConvertMaskToSeqlensKernel<<<blocks, threads, 0, stream>>>(
attn_mask_bool,
seqlens_k,
validation_result,
batch_size,
total_seq_len,
mask_dims,
mask_dim0,
mask_dim1,
mask_dim2);

return CUDA_CALL(cudaGetLastError());
}

} // namespace cuda
} // namespace onnxruntime
56 changes: 56 additions & 0 deletions onnxruntime/core/providers/cuda/llm/attention_mask_impl.h
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#include "core/common/status.h"

namespace onnxruntime {
namespace cuda {

// Convert a boolean attention mask to sequence lengths for use with GQA kernels.
//
// The mask is expected to have the following properties:
// 1. It represents right-padding only (valid tokens first, padding at the end)
// 2. Each batch's mask should start with True (valid) values
// 3. True values should be contiguous, followed by contiguous False (padding) values
// 4. The mask must be broadcastable to (batch_size, num_heads, q_seq_len, total_seq_len)
//
// For 2D mask (batch_size, total_seq_len): uses the mask directly per batch
// For 3D mask (num_heads, q_seq_len, total_seq_len): broadcasts across batches, uses first head/q
// For 4D mask (B, H, q_seq_len, total_seq_len): uses first head, first q position
//
// Parameters:
// attn_mask_bool: Input boolean mask on GPU (True = valid, False = padding)
// seqlens_k: Output buffer for sequence lengths (seqlen - 1 for GQA convention)
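//   validation_result: GPU buffer of size batch_size that receives a per-batch validation code (0 = valid)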
// batch_size: Number of batches
// total_seq_len: Total sequence length (last dimension of mask)
// mask_dims: Number of dimensions in the mask (2, 3, or 4)
// mask_dim0: First dimension of mask (batch_size for 2D, num_heads for 3D, batch_size for 4D)
// mask_dim1: Second dimension (0 for 2D, q_seq_len for 3D, num_heads for 4D)
// mask_dim2: Third dimension (0 for 2D/3D, q_seq_len for 4D)
// stream: CUDA stream
// max_threads_per_block: Maximum threads per block
//
// Returns:
//   Status::OK() if the kernel launch succeeds
//   Error status if the kernel launch fails
//
// Note: Mask validation happens on the GPU. The kernel writes a per-batch code into
// validation_result (0 = OK); the caller should check that buffer after synchronizing and
// reject the input if any batch:
//   - Does not start with True
//   - Has non-contiguous True/False values (e.g., a True, False, True pattern)
Status LaunchConvertMaskToSeqlensK(
const bool* attn_mask_bool,
int* seqlens_k,
int* validation_result, // GPU buffer for validation, size = batch_size
int batch_size,
int total_seq_len,
int mask_dims,
int64_t mask_dim0,
int64_t mask_dim1,
int64_t mask_dim2,
cudaStream_t stream,
int max_threads_per_block);

} // namespace cuda
} // namespace onnxruntime
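
For context, here is a minimal host-side sketch of how this launcher is meant to be used; the real wiring lives in attention.cc, whose diff is not rendered above. The helper name CheckMaskAndBuildSeqlensK, the device buffer handling, and the simplified error reporting are illustrative assumptions, not code from this PR: the caller launches the conversion, copies the per-batch validation codes back, and turns any nonzero code into an INVALID_ARGUMENT status.

// Illustrative sketch only. CheckMaskAndBuildSeqlensK is a hypothetical helper; attention.cc
// may allocate and check these buffers differently.
#include <cstdint>
#include <vector>

#include <cuda_runtime.h>

#include "core/common/common.h"  // ORT_RETURN_IF_ERROR, ORT_MAKE_STATUS
#include "core/providers/cuda/llm/attention_mask_impl.h"

namespace onnxruntime {
namespace cuda {

Status CheckMaskAndBuildSeqlensK(const bool* attn_mask_bool,  // device pointer to the boolean mask
                                 int* seqlens_k,              // device buffer, size = batch_size
                                 int* validation_result,      // device buffer, size = batch_size
                                 int batch_size, int total_seq_len, int mask_dims,
                                 int64_t mask_dim0, int64_t mask_dim1, int64_t mask_dim2,
                                 cudaStream_t stream, int max_threads_per_block) {
  // Launch the conversion + validation kernel.
  ORT_RETURN_IF_ERROR(LaunchConvertMaskToSeqlensK(
      attn_mask_bool, seqlens_k, validation_result, batch_size, total_seq_len,
      mask_dims, mask_dim0, mask_dim1, mask_dim2, stream, max_threads_per_block));

  // Copy the per-batch validation codes back to the host
  // (0 = OK, 1 = mask does not start with True, 2 = padding is not contiguous).
  std::vector<int> codes(batch_size);
  if (cudaMemcpyAsync(codes.data(), validation_result, batch_size * sizeof(int),
                      cudaMemcpyDeviceToHost, stream) != cudaSuccess ||
      cudaStreamSynchronize(stream) != cudaSuccess) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to copy attn_mask validation results to host.");
  }
  for (int b = 0; b < batch_size; ++b) {
    if (codes[b] != 0) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "attn_mask must be right-padded and start with True; batch ", b,
                             " failed validation with code ", codes[b]);
    }
  }
  return Status::OK();
}

}  // namespace cuda
}  // namespace onnxruntime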