Changes from all commits (24 commits)
e538f4b  Initial plan (Copilot, Jan 20, 2026)
f2007a2  Add GQA support to Attention(23) CUDA operator (Copilot, Jan 20, 2026)
53c333f  Add debug tracking fields and num_splits parameter (Copilot, Jan 20, 2026)
042ff32  Fix code review issues: use v_head_size and parameters.softcap (Copilot, Jan 20, 2026)
0e7a632  Set softcap to 0.0f explicitly with comment (Copilot, Jan 20, 2026)
2e10874  Enable CUDA tests for GQA attention tests (Copilot, Jan 20, 2026)
b86acbd  Remove GQA test filters from disabled tests list (Copilot, Jan 20, 2026)
213a82d  Add float template instantiation for GQA QkvToContext (Copilot, Jan 20, 2026)
f79c509  Revert float support for GQA and add type validation (Copilot, Jan 20, 2026)
e52efb2  change gqa tests to fp16 (titaiwangms, Jan 22, 2026)
98c5dcf  examine gqa parameters and move down MHA parameters (titaiwangms, Jan 23, 2026)
0f800f5  Merge branch 'main' into copilot/support-group-query-attention (titaiwangms, Jan 23, 2026)
4978e96  support gqa bool masking (titaiwangms, Jan 23, 2026)
4c644e2  add flash/memory draft (titaiwangms, Jan 27, 2026)
f04b38e  Merge branch 'main' into copilot/support-group-query-attention (titaiwangms, Jan 27, 2026)
16d5453  finish gqa default (titaiwangms, Jan 28, 2026)
54d77ae  Apply suggestion from @titaiwangms (titaiwangms, Jan 28, 2026)
87a5648  introduce python attention tests for gqa (titaiwangms, Jan 28, 2026)
5981041  lint (titaiwangms, Jan 28, 2026)
6d7e50a  support attn_mask (titaiwangms, Jan 30, 2026)
d1cb063  Merge branch 'main' into copilot/support-group-query-attention (titaiwangms, Jan 30, 2026)
e2a4032  clean up and use ORT_MAKE_STATUS (titaiwangms, Jan 30, 2026)
dcb937a  Merge branch 'main' into copilot/support-group-query-attention (titaiwangms, Jan 30, 2026)
2509464  fix cpu bugs on fp16 (titaiwangms, Feb 2, 2026)
24 changes: 12 additions & 12 deletions onnxruntime/core/providers/cpu/llm/attention.cc
@@ -396,7 +396,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
parameters.transpose_output
? parameters.head_size * parameters.q_num_heads
: static_cast<int>(parameters.head_size), // lda
-transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + (head_i % parameters.kv_num_heads) * parameters.head_size : k,
+transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size : k,
transposed_k
? parameters.head_size * parameters.kv_num_heads
: static_cast<int>(parameters.head_size), // ldb
@@ -413,7 +413,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
MLFloat16(alpha),
Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size,
parameters.head_size * parameters.q_num_heads, // lda
-transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + (head_i % parameters.kv_num_heads) * parameters.head_size : k,
+transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size : k,
transposed_k ? parameters.head_size * parameters.kv_num_heads : parameters.head_size, // ldb
MLFloat16(beta),
output,
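Both call sites above (the float and MLFloat16 Gemm paths) now pick the K block for the current query head via head_ki instead of the inline head_i % kv_num_heads expression. As a reading aid, a minimal sketch of that pointer arithmetic follows; the helper name and the grouping rule for head_ki (consecutive query heads sharing one KV head) are assumptions for illustration, since head_ki is defined outside this hunk.

// Hypothetical helper, not part of the PR: how a query head selects its K block
// under the transposed_k layout [batch, total_seq, kv_num_heads, head_size].
#include <cstddef>

const float* KHeadPointer(const float* K,
                          std::ptrdiff_t k_input_chunk_length,  // total_sequence_length * head_size
                          int batch_i, int head_i,
                          int q_num_heads, int kv_num_heads, int head_size) {
  // Assumed GQA grouping: consecutive query heads share one KV head.
  const int head_ki = head_i / (q_num_heads / kv_num_heads);
  // Same offset shape as the diff: per-batch block plus per-KV-head column offset.
  return K + k_input_chunk_length * kv_num_heads * batch_i + head_ki * head_size;
}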
@@ -617,23 +617,23 @@ void AttentionBase<T>::ComputeVxAttentionScore(T* output, // bu
total_sequence_length, // K
attention_probs + attention_probs_offset,
total_sequence_length, // lda
-transposed_v ? V + head_i * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v,
+transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v,
transposed_v ? static_cast<int>(v_head_size * kv_num_heads) : static_cast<int>(v_head_size), // ldb
output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size),
v_head_size * num_heads, // ldc
MLFloat16(1.f).val, MLFloat16(0.f).val, nullptr);
} else {
math::GemmEx<T, ThreadPool>(CblasNoTrans,
CblasNoTrans,
-sequence_length, // M
-v_head_size, // N
-total_sequence_length, // K
-MLFloat16(1.f), // alpha
-attention_probs + attention_probs_offset, // QK
-total_sequence_length, // lda
-transposed_v ? V + head_i * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V
-transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb
-MLFloat16(0.f), // beta
+sequence_length, // M
+v_head_size, // N
+total_sequence_length, // K
+MLFloat16(1.f), // alpha
+attention_probs + attention_probs_offset, // QK
+total_sequence_length, // lda
+transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V
+transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb
+MLFloat16(0.f), // beta
output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size),
v_head_size * num_heads, // ldc
nullptr);
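The ComputeVxAttentionScore hunk applies the same idea on the value side: head_vi picks which KV head's value rows are mixed into each query head's output slice. As a cross-check on the index bookkeeping, here is a naive single-batch reference of grouped-query attention; the layouts, names, and grouping rule below are assumptions for illustration, not the operator's actual interface.

// Naive grouped-query attention for one batch (illustrative sketch only).
// Q: [seq, q_num_heads, head_size]; K, V: [total_seq, kv_num_heads, head_size];
// out: [seq, q_num_heads, head_size]. Requires q_num_heads % kv_num_heads == 0.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

void NaiveGqaSingleBatch(const std::vector<float>& Q, const std::vector<float>& K,
                         const std::vector<float>& V, std::vector<float>& out,
                         int seq, int total_seq,
                         int q_num_heads, int kv_num_heads, int head_size) {
  const int group_size = q_num_heads / kv_num_heads;
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_size));
  out.assign(static_cast<std::size_t>(seq) * q_num_heads * head_size, 0.0f);
  std::vector<float> probs(total_seq);
  for (int head_i = 0; head_i < q_num_heads; ++head_i) {
    // Assumed grouping: consecutive query heads share one KV head.
    const int head_kv = head_i / group_size;
    for (int i = 0; i < seq; ++i) {
      // Scaled dot-product scores of query row i against every key row of the shared KV head.
      float max_score = -std::numeric_limits<float>::infinity();
      for (int j = 0; j < total_seq; ++j) {
        float s = 0.0f;
        for (int d = 0; d < head_size; ++d) {
          s += Q[(static_cast<std::size_t>(i) * q_num_heads + head_i) * head_size + d] *
               K[(static_cast<std::size_t>(j) * kv_num_heads + head_kv) * head_size + d];
        }
        probs[j] = s * scale;
        max_score = std::max(max_score, probs[j]);
      }
      // Numerically stable softmax over the key dimension.
      float sum = 0.0f;
      for (int j = 0; j < total_seq; ++j) {
        probs[j] = std::exp(probs[j] - max_score);
        sum += probs[j];
      }
      // Weighted sum of value rows from the same shared KV head.
      for (int j = 0; j < total_seq; ++j) {
        const float w = probs[j] / sum;
        for (int d = 0; d < head_size; ++d) {
          out[(static_cast<std::size_t>(i) * q_num_heads + head_i) * head_size + d] +=
              w * V[(static_cast<std::size_t>(j) * kv_num_heads + head_kv) * head_size + d];
        }
      }
    }
  }
}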