Commits (24)
e538f4b  Initial plan (Copilot, Jan 20, 2026)
f2007a2  Add GQA support to Attention(23) CUDA operator (Copilot, Jan 20, 2026)
53c333f  Add debug tracking fields and num_splits parameter (Copilot, Jan 20, 2026)
042ff32  Fix code review issues: use v_head_size and parameters.softcap (Copilot, Jan 20, 2026)
0e7a632  Set softcap to 0.0f explicitly with comment (Copilot, Jan 20, 2026)
2e10874  Enable CUDA tests for GQA attention tests (Copilot, Jan 20, 2026)
b86acbd  Remove GQA test filters from disabled tests list (Copilot, Jan 20, 2026)
213a82d  Add float template instantiation for GQA QkvToContext (Copilot, Jan 20, 2026)
f79c509  Revert float support for GQA and add type validation (Copilot, Jan 20, 2026)
e52efb2  change gqa tests to fp16 (titaiwangms, Jan 22, 2026)
98c5dcf  examine gqa parameters and move down MHA parameters (titaiwangms, Jan 23, 2026)
0f800f5  Merge branch 'main' into copilot/support-group-query-attention (titaiwangms, Jan 23, 2026)
4978e96  support gqa bool masking (titaiwangms, Jan 23, 2026)
4c644e2  add flash/memory draft (titaiwangms, Jan 27, 2026)
f04b38e  Merge branch 'main' into copilot/support-group-query-attention (titaiwangms, Jan 27, 2026)
16d5453  finish gqa default (titaiwangms, Jan 28, 2026)
54d77ae  Apply suggestion from @titaiwangms (titaiwangms, Jan 28, 2026)
87a5648  introduce python attention tests for gqa (titaiwangms, Jan 28, 2026)
5981041  lint (titaiwangms, Jan 28, 2026)
6d7e50a  support attn_mask (titaiwangms, Jan 30, 2026)
d1cb063  Merge branch 'main' into copilot/support-group-query-attention (titaiwangms, Jan 30, 2026)
e2a4032  clean up and use ORT_MAKE_STATUS (titaiwangms, Jan 30, 2026)
dcb937a  Merge branch 'main' into copilot/support-group-query-attention (titaiwangms, Jan 30, 2026)
2509464  fix cpu bugs on fp16 (titaiwangms, Feb 2, 2026)
362 changes: 341 additions & 21 deletions onnxruntime/core/providers/cuda/llm/attention.cc
@@ -1,11 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <vector>
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cpu/llm/attention_helper.h"
#include "core/providers/cuda/llm/attention.h"
#include "contrib_ops/cuda/bert/attention_data.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"

using namespace onnxruntime::cuda;

@@ -96,8 +100,344 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
Tensor* output_qk = context->Output(3, output_qk_shape);

// To reuse the existing attention-cuda implementation in contrib ops,
// map the parameters to contribop_parameters.
// map the parameters to contribop_parameters (MHA).
onnxruntime::contrib::AttentionParameters contribop_parameters;

// QKV format: Determine based on input dimensions
// 3D inputs (B, S, D): Q_K_V_BSNH - will be transposed by PrepareQkv to BNSH
// transpose_output is true for 3D inputs, false for 4D inputs
if (!parameters.transpose_output) {
contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH;
contribop_parameters.is_output_bnsh = true;
} else {
// 3D inputs in BSNH format (will be transposed)
contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BSNH;
contribop_parameters.is_output_bnsh = false;
}

typedef typename ToCudaType<T>::MappedType CudaT;

// Check if this is Group Query Attention (GQA)
const bool is_gqa = parameters.kv_num_heads != parameters.q_num_heads;

if (is_gqa) {

Review comment (Contributor):
TODO: Currently, we do not support 4D inputs of QKV.

Reply (Contributor):
Added exceptions.

Reply (Contributor):
Supporting this requires kernel changes in FlashAttention and EfficientAttention. If we want to support 4D, the best approach would be another CUDA kernel that transposes/reshapes the input from 4D to 3D before feeding it to those two attention kernels.
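
A minimal sketch of that 4D-to-3D rearrangement, for illustration only: the kernel name TransposeBNSHToBSNH and the launch configuration below are hypothetical and not part of this PR.

#include <cstddef>
#include <cuda_runtime.h>

// Illustrative only: rearrange a 4D [B, N, S, H] (BNSH) tensor into the
// 3D [B, S, N*H] (BSNH flattened) layout expected by the flash/memory-efficient kernels.
template <typename T>
__global__ void TransposeBNSHToBSNH(const T* src, T* dst,
                                    int num_heads, int seq_len, int head_size) {
  const int b = blockIdx.z;  // batch index
  const int n = blockIdx.y;  // head index
  const int s = blockIdx.x;  // sequence index
  for (int h = threadIdx.x; h < head_size; h += blockDim.x) {
    const size_t src_idx = ((static_cast<size_t>(b) * num_heads + n) * seq_len + s) * head_size + h;
    const size_t dst_idx = ((static_cast<size_t>(b) * seq_len + s) * num_heads + n) * head_size + h;
    dst[dst_idx] = src[src_idx];
  }
}

// Hypothetical launch: one block per (s, n, b) triple, threads strided over head_size.
//   dim3 grid(seq_len, num_heads, batch_size);
//   TransposeBNSHToBSNH<<<grid, 256, 0, stream>>>(src, dst, num_heads, seq_len, head_size);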

// Use GQA path with Flash Attention or Memory Efficient Attention
// GQA only supports float16 and bfloat16 types
if (std::is_same<T, float>::value) {
ORT_THROW("GQA in Attention op (CUDA) does not support float32. Please use float16 or bfloat16.");
}
// For now, GQA doesn't support qk_matmul_output_mode other than kNone
if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone) {
ORT_THROW("qk_matmul_output_mode is not supported yet in GQA path of Attention op (CUDA).");
}
// GQA doesn't support softmax_precision yet
if (parameters.softmax_precision != 0) {
ORT_THROW("softmax_precision is not supported yet in GQA path of Attention op (CUDA).");
}
// causal attention is required for GQA
if (!parameters.is_causal) {
ORT_THROW("Non-causal attention is not supported yet in GQA path of Attention op (CUDA).");
}
// GQA kernel expects K/V input sequence length == Q sequence length (self-attention only)
// Cross-attention (kv_sequence_length != q_sequence_length) is not supported
if (parameters.kv_sequence_length != parameters.q_sequence_length) {
ORT_THROW(
"Cross-attention (kv_sequence_length != q_sequence_length) is not supported in GQA path of Attention op (CUDA). "
"kv_sequence_length=",
parameters.kv_sequence_length, ", q_sequence_length=", parameters.q_sequence_length);
}

auto& device_prop = GetDeviceProp();

// Bridge parameters to GroupQueryAttentionParameters
onnxruntime::contrib::GroupQueryAttentionParameters gqa_parameters;
gqa_parameters.batch_size = parameters.batch_size;
gqa_parameters.sequence_length = parameters.q_sequence_length;
gqa_parameters.seqlen_past_kv_cache = parameters.past_sequence_length;
gqa_parameters.seqlen_present_kv_cache = parameters.total_sequence_length;
gqa_parameters.total_sequence_length = parameters.total_sequence_length;
gqa_parameters.kv_sequence_length = parameters.kv_sequence_length;
gqa_parameters.hidden_size = parameters.q_num_heads * parameters.head_size;
gqa_parameters.num_heads = parameters.q_num_heads;
gqa_parameters.head_size = parameters.head_size;
gqa_parameters.v_head_size = parameters.v_head_size;
gqa_parameters.kv_hidden_size = parameters.kv_num_heads * parameters.v_head_size;
gqa_parameters.kv_num_heads = parameters.kv_num_heads;
gqa_parameters.scale = parameters.scale;
gqa_parameters.softcap = parameters.softcap;
gqa_parameters.qkv_format = contribop_parameters.qkv_format;

// Unset or set to default values for GQA-specific fields
gqa_parameters.rotary_dim = 0; // New Attention op doesn't use rotary embeddings directly
gqa_parameters.is_unidirectional = true; // GQA requires causal attention
gqa_parameters.is_packed_qkv = false; // New Attention op has separate Q, K, V inputs
gqa_parameters.is_subsequent_prompt = false;
gqa_parameters.is_first_prompt = parameters.past_sequence_length == 0;
gqa_parameters.do_rotary = false; // New Attention op doesn't use rotary embeddings
gqa_parameters.rotary_interleaved = false;
gqa_parameters.use_smooth_softmax = false;
gqa_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE;
gqa_parameters.past_kv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH;
gqa_parameters.local_window_size = -1; // No local window for standard attention
gqa_parameters.zeros_count = 0;
gqa_parameters.zero_ptr = nullptr;
gqa_parameters.num_splits = 1;

// Construct GroupQueryAttentionData
onnxruntime::contrib::cuda::GroupQueryAttentionData<CudaT> gqa_data;

// Scratch buffers for flash/memory efficient attention
IAllocatorUniquePtr<void> k_buffer;
IAllocatorUniquePtr<void> v_buffer;
IAllocatorUniquePtr<void> fmha_buffer;
IAllocatorUniquePtr<void> unpacked_qkv_buffer;
IAllocatorUniquePtr<int> seq_lens_buffer;
IAllocatorUniquePtr<int> seqlens_k_buffer;

// Present KV cache buffers - GQA kernel uses these as working buffers
// If outputs are not provided, we allocate scratch buffers
IAllocatorUniquePtr<void> present_key_scratch;
IAllocatorUniquePtr<void> present_value_scratch;

// Set input pointers
gqa_data.query = reinterpret_cast<const CudaT*>(Q->Data<T>());
gqa_data.key = reinterpret_cast<const CudaT*>(K->Data<T>());
gqa_data.value = reinterpret_cast<const CudaT*>(V->Data<T>());
gqa_data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
gqa_data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());

// Set output pointers
gqa_data.output = reinterpret_cast<CudaT*>(Y->MutableData<T>());

// GQA kernel requires present_key/present_value buffers as working storage for KV cache
// Allocate scratch buffers if outputs are not provided
size_t present_kv_size = static_cast<size_t>(parameters.batch_size) *
static_cast<size_t>(parameters.kv_num_heads) *
static_cast<size_t>(parameters.total_sequence_length) *
static_cast<size_t>(parameters.head_size) * sizeof(CudaT);
if (present_key != nullptr) {
gqa_data.present_key = reinterpret_cast<CudaT*>(present_key->MutableData<T>());
} else {
present_key_scratch = GetScratchBuffer<void>(present_kv_size, context->GetComputeStream());
gqa_data.present_key = reinterpret_cast<CudaT*>(present_key_scratch.get());
}
if (present_value != nullptr) {
gqa_data.present_value = reinterpret_cast<CudaT*>(present_value->MutableData<T>());
} else {
present_value_scratch = GetScratchBuffer<void>(present_kv_size, context->GetComputeStream());
gqa_data.present_value = reinterpret_cast<CudaT*>(present_value_scratch.get());
}

// Compute past_present_share_buffer early since it's needed for flash attention path selection
gqa_parameters.past_present_share_buffer = (gqa_data.past_key == gqa_data.present_key);

// Flash Attention buffers
IAllocatorUniquePtr<void> softmax_lse_buffer;
IAllocatorUniquePtr<void> softmax_lse_accum_buffer;
IAllocatorUniquePtr<void> out_accum_buffer;

// Check Flash Attention support
#if USE_FLASH_ATTENTION
bool use_flash_attention = onnxruntime::flash::is_supported<T>(device_prop,
gqa_parameters.head_size,
gqa_parameters.num_heads,
gqa_parameters.kv_num_heads);

gqa_data.use_flash_attention = use_flash_attention;
gqa_data.use_flash_attention_fast_decode = use_flash_attention &&
!gqa_parameters.is_first_prompt &&
gqa_parameters.past_present_share_buffer;

if (use_flash_attention) {
// Allocate Flash specific buffers (Softmax LSE, Accum)
size_t softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(
gqa_parameters.sequence_length, gqa_parameters.batch_size, gqa_parameters.num_heads);

int num_heads_for_split = gqa_data.use_flash_attention_fast_decode
? gqa_parameters.kv_num_heads
: gqa_parameters.num_heads;
auto [num_splits, softmax_lse_accum_bytes, out_accum_bytes] =
onnxruntime::flash::get_num_splits_and_buffer_sizes(
gqa_parameters.batch_size, gqa_parameters.sequence_length,
gqa_parameters.total_sequence_length, num_heads_for_split,
gqa_parameters.head_size, device_prop.multiProcessorCount);

gqa_parameters.num_splits = static_cast<int>(num_splits);

if (gqa_data.use_flash_attention_fast_decode && num_splits > 1) {
// The heuristic used kv_num_heads to maximize occupancy for the GQA-aware kernel.
// However, the LSE and Accum buffers must store results for ALL num_heads.
softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size(
num_splits, gqa_parameters.batch_size, gqa_parameters.num_heads, gqa_parameters.sequence_length);
auto round_multiple = [](size_t x, size_t m) { return (x + m - 1) / m * m; };
out_accum_bytes = onnxruntime::flash::get_out_accum_size(
num_splits, gqa_parameters.batch_size, gqa_parameters.num_heads, gqa_parameters.sequence_length,
round_multiple(gqa_parameters.head_size, 32));
}

softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, context->GetComputeStream());
softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream());
out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream());

gqa_data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get());
gqa_data.softmax_lse_accum = reinterpret_cast<CudaT*>(softmax_lse_accum_buffer.get());
gqa_data.out_accum = reinterpret_cast<CudaT*>(out_accum_buffer.get());
} else {
gqa_data.softmax_lse = nullptr;
gqa_data.softmax_lse_accum = nullptr;
gqa_data.out_accum = nullptr;
}
#else
gqa_data.use_flash_attention = false;
gqa_data.use_flash_attention_fast_decode = false;
gqa_data.softmax_lse = nullptr;
gqa_data.softmax_lse_accum = nullptr;
gqa_data.out_accum = nullptr;
#endif

// Check Memory Efficient Attention support (fallback if flash attention not available)
#if USE_MEMORY_EFFICIENT_ATTENTION
if (!gqa_data.use_flash_attention) {
int sm = (device_prop.major * 10) + device_prop.minor;
bool use_memory_efficient_attention =
onnxruntime::contrib::cuda::has_memory_efficient_attention(
sm, std::is_same<T, MLFloat16>::value, std::is_same<T, BFloat16>::value,
gqa_parameters.head_size, gqa_parameters.head_size);
gqa_data.use_memory_efficient_attention = use_memory_efficient_attention;

// KV buffer for head expansion (when num_heads != kv_num_heads)
size_t kv_buffer_bytes = (use_memory_efficient_attention &&
(gqa_parameters.num_heads != gqa_parameters.kv_num_heads))
? (sizeof(T) * gqa_parameters.batch_size * gqa_parameters.num_heads *
gqa_parameters.seqlen_present_kv_cache * gqa_parameters.head_size)
: 0;
// FMHA workspace
size_t fmha_buffer_bytes =
(use_memory_efficient_attention &&
onnxruntime::contrib::cuda::MemoryEfficientAttentionParams::need_workspace(
gqa_parameters.head_size, sizeof(T) == sizeof(float)))
? (sizeof(float) * gqa_parameters.batch_size * gqa_parameters.sequence_length *
gqa_parameters.num_heads * gqa_parameters.head_size)
: 0;

k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, context->GetComputeStream());

gqa_data.k = reinterpret_cast<CudaT*>(k_buffer.get());
gqa_data.v = reinterpret_cast<CudaT*>(v_buffer.get());
gqa_data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
} else {
gqa_data.use_memory_efficient_attention = false;
gqa_data.k = nullptr;
gqa_data.v = nullptr;
gqa_data.fmha_buffer = nullptr;
}
#else
gqa_data.use_memory_efficient_attention = false;
gqa_data.k = nullptr;
gqa_data.v = nullptr;
gqa_data.fmha_buffer = nullptr;
#endif

// Centralized scratch buffer allocation using GQABufferRequirements
auto buffer_req = onnxruntime::contrib::cuda::GQABufferRequirements::Compute<T>(
gqa_parameters,
gqa_data.use_flash_attention,
gqa_data.use_flash_attention_fast_decode,
gqa_data.use_memory_efficient_attention);

if (buffer_req.qkv_buffer_bytes > 0) {
unpacked_qkv_buffer = GetScratchBuffer<void>(buffer_req.qkv_buffer_bytes, context->GetComputeStream());
gqa_data.qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get());
} else {
gqa_data.qkv_buffer = nullptr;
}

// Allocate CPU buffer for seqlens_k (total_sequence_length - 1) for GQA compatibility
// The GQA kernel expects sequence length information for flash/memory efficient attention
// We need a CPU buffer first, then copy to GPU
std::vector<int> seqlens_k_host(parameters.batch_size);

// GQA only supports masking, not additive bias.
// For bool mask, we need to convert it to sequence lengths.
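// Example (illustrative): with total_sequence_length = 6 and a mask row of
// [true, true, true, true, false, false], the last valid position is index 3,
// so seq_len = 4 and seqlens_k_host[b] = 3 (i.e. seq_len - 1).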
if (attn_mask != nullptr && attn_mask->IsDataType<bool>()) {
const bool* b_mask = attn_mask->Data<bool>();

for (int b = 0; b < parameters.batch_size; ++b) {
const bool* row = b_mask + b * parameters.total_sequence_length;
int seq_len = 0;

// Find the actual sequence length by looking for the last valid (true) position
// Mask convention per Attention spec: true = valid (should participate), false = masked out
for (int i = parameters.total_sequence_length - 1; i >= 0; --i) {
if (row[i]) {
seq_len = i + 1;
break;
}
}
// seqlens_k is total_sequence_length - 1 for historical reasons (matching GroupQueryAttention convention)
seqlens_k_host[b] = seq_len - 1;
}
} else if (attn_mask != nullptr) {
ORT_THROW("Non-boolean attn_mask is not supported yet in GQA path of Attention op (CUDA).");
} else {
// No mask provided - use full sequence length for all batches
// seqlens_k is total_sequence_length - 1 for historical reasons (matching GroupQueryAttention convention)
for (int b = 0; b < parameters.batch_size; ++b) {
seqlens_k_host[b] = parameters.total_sequence_length - 1;
}
}

// Copy seqlens_k to GPU
seqlens_k_buffer = GetScratchBuffer<int>(parameters.batch_size, context->GetComputeStream());
auto cuda_stream = static_cast<cudaStream_t>(context->GetComputeStream()->GetHandle());
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(seqlens_k_buffer.get(), seqlens_k_host.data(),
sizeof(int) * parameters.batch_size,
cudaMemcpyHostToDevice, cuda_stream));

// Process seqlens_k to compute past_seq_lens, total_seq_lens, and padded_seq_lens
// This is always needed for flash/memory efficient attention
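// Layout of seq_lens_buffer (3 * batch_size ints), as set up below:
//   [0, B)   -> past_seq_lens
//   [B, 2B)  -> total_seq_lens
//   [2B, 3B) -> padded_seq_lens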
seq_lens_buffer = GetScratchBuffer<int>(3 * parameters.batch_size, context->GetComputeStream());
gqa_data.past_seq_lens = seq_lens_buffer.get();
gqa_data.total_seq_lens = seq_lens_buffer.get() + parameters.batch_size;
gqa_data.padded_seq_lens = gqa_data.total_seq_lens + parameters.batch_size;

ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchGetSequenceLengths(
seqlens_k_buffer.get(),
gqa_data.past_seq_lens,
gqa_data.total_seq_lens,
gqa_data.padded_seq_lens,
parameters.batch_size,
parameters.q_sequence_length,
gqa_parameters.is_first_prompt,
cuda_stream,
device_prop.maxThreadsPerBlock));

// Set GQA-specific fields
gqa_data.cos_cache = nullptr; // No rotary embeddings
gqa_data.sin_cache = nullptr;
gqa_data.head_sink = nullptr;
gqa_data.position_ids = nullptr;

#ifndef NDEBUG
// Initialize debug tracking fields
gqa_data.unpacked_qkv_buffer_size = 0;
gqa_data.rotary_buffer_size = 0;
gqa_data.position_ids_buffer_size = 0;
gqa_data.unpacked_qkv_max_used = 0;
gqa_data.rotary_max_used = 0;
gqa_data.position_ids_max_used = 0;
#endif

// Call GQA kernel (with flash or memory efficient attention)
cublasHandle_t cublas = GetCublasHandle(context);

return onnxruntime::contrib::cuda::QkvToContext<CudaT>(
device_prop, cublas, context->GetComputeStream(), gqa_parameters, gqa_data);
}

// MHA path (kv_num_heads == q_num_heads)
contribop_parameters.batch_size = parameters.batch_size;
contribop_parameters.sequence_length = parameters.q_sequence_length;
contribop_parameters.kv_sequence_length = parameters.kv_sequence_length;
@@ -160,24 +500,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
contribop_parameters.mask_filter_value = -10000.0f;
contribop_parameters.scale = parameters.scale;
contribop_parameters.use_tf32 = UseTF32();

// QKV format: Determine based on input dimensions
// 3D inputs (B, S, D): Q_K_V_BSNH - will be transposed by PrepareQkv to BNSH
// transpose_output is true for 3D inputs, false for 4D inputs
if (!parameters.transpose_output) {
contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH;
contribop_parameters.is_output_bnsh = true;
} else {
// 3D inputs in BSNH format (will be transposed)
contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BSNH;
contribop_parameters.is_output_bnsh = false;
}

// TODO(titaiwang, xadupre): Group query attention is not supported yet
if (parameters.kv_num_heads != parameters.q_num_heads) {
ORT_THROW("Group query attention is not supported yet in Attention op (CUDA).");
}

// TODO(titaiwang, xadupre): qk_matmul_output_mode only supports kNone and kQK for now
if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone &&
qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kQK) {
@@ -191,9 +513,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
ORT_THROW("softmax_precision is not supported yet in Attention op (CUDA).");
}

// TODO(titaiwang): Continue on these parameters
// Construct AttentionData to pass to QkvToContext
typedef typename ToCudaType<T>::MappedType CudaT;
onnxruntime::contrib::cuda::AttentionData<CudaT> data;

// Set input pointers