Merged

Commits (34; the diff below shows changes from 30 of them)
1e7d5ae
refactor redundant condition checks
titaiwangms Oct 31, 2025
49d7a42
sync to Xavier's cpu refactors
titaiwangms Oct 31, 2025
246a4d1
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Nov 4, 2025
f244983
fix attention-cpu build
titaiwangms Nov 5, 2025
8274bb1
draft
titaiwangms Nov 7, 2025
78f5d61
lint - draft
titaiwangms Nov 7, 2025
53d4e83
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Nov 19, 2025
43623ad
fix typo
titaiwangms Nov 19, 2025
08f15f6
typo-2
titaiwangms Nov 19, 2025
0e75443
update namespace
titaiwangms Nov 19, 2025
277648d
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Nov 19, 2025
5253dd0
update doc
titaiwangms Nov 19, 2025
4db63bc
removed deprecated functions in onnx
titaiwangms Nov 20, 2025
0a7e5f9
Revert "removed deprecated functions in onnx"
titaiwangms Nov 20, 2025
6b18bb4
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Dec 1, 2025
a1ed3d9
fix qkv space - support 3d default
titaiwangms Dec 1, 2025
b462930
turn 4d to tru on disable cuda
titaiwangms Dec 2, 2025
0494e95
refactor attn_mask
titaiwangms Dec 2, 2025
2dc706a
simplify
titaiwangms Dec 9, 2025
5f0b6cd
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Dec 10, 2025
88e631c
support 4d and fix attn_mask bug
titaiwangms Dec 12, 2025
000d394
disregard softcap and softmax_precision
titaiwangms Dec 13, 2025
739e88f
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Dec 15, 2025
6d6d478
fix offset in is_causal
titaiwangms Dec 15, 2025
792445a
add past_seq_length to softmax bias add for causal
titaiwangms Dec 16, 2025
a26d812
resolve merge conflict
titaiwangms Dec 16, 2025
fbbf0b5
update failing cuda tests
titaiwangms Dec 16, 2025
7010308
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Jan 8, 2026
2c793b6
delete past_sequence_length and use flag output_is_Q_K_V_BNSH
titaiwangms Jan 9, 2026
ab41d04
add kv_sequence_length to softmax
titaiwangms Jan 9, 2026
4a8f502
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Jan 13, 2026
2a9167e
remove kv_sequence_length in softmax and disable cross attn causal tests
titaiwangms Jan 13, 2026
ff7e767
address reviews - comments
titaiwangms Jan 14, 2026
c001885
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Jan 14, 2026
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -637,6 +637,7 @@ Do not modify directly.*
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|Attention|*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *in* nonpad_kv_seqlen:**tensor(int64)**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**<br><br>or<br><br>*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**|23+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **U** = tensor(bfloat16), tensor(bool), tensor(float), tensor(float16)|
|AveragePool|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[19, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[11, 18]|**T** = tensor(double), tensor(float), tensor(float16)|
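As context for the Attention row added above, here is a minimal single-head CPU reference of what the operator computes, Y = softmax(Q·Kᵀ·scale + attn_mask)·V. The flat row-major buffers, the additive mask, the 1/sqrt(head_size) scale, and the function name are assumptions for illustration only, not the CUDA kernel being registered.

// Hedged single-head reference of Y = softmax(Q*K^T * scale + attn_mask) * V.
// The real CUDA kernel handles batching, multiple heads, KV cache, and 3D/4D layouts.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> AttentionRef(const std::vector<float>& Q,    // [q_len * head_size]
                                const std::vector<float>& K,    // [kv_len * head_size]
                                const std::vector<float>& V,    // [kv_len * head_size]
                                const std::vector<float>& mask, // [q_len * kv_len], additive bias
                                int q_len, int kv_len, int head_size) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_size));
  std::vector<float> Y(static_cast<size_t>(q_len) * head_size, 0.0f);
  for (int i = 0; i < q_len; ++i) {
    // scores[j] = scale * dot(Q_i, K_j) + mask[i][j]
    std::vector<float> scores(kv_len);
    float max_score = -INFINITY;
    for (int j = 0; j < kv_len; ++j) {
      float s = 0.0f;
      for (int h = 0; h < head_size; ++h) s += Q[i * head_size + h] * K[j * head_size + h];
      scores[j] = s * scale + mask[static_cast<size_t>(i) * kv_len + j];
      max_score = std::max(max_score, scores[j]);
    }
    // numerically stable softmax over the key axis
    float denom = 0.0f;
    for (int j = 0; j < kv_len; ++j) {
      scores[j] = std::exp(scores[j] - max_score);
      denom += scores[j];
    }
    // Y_i = sum_j softmax(scores)[j] * V_j
    for (int j = 0; j < kv_len; ++j) {
      const float w = scores[j] / denom;
      for (int h = 0; h < head_size; ++h) Y[i * head_size + h] += w * V[j * head_size + h];
    }
  }
  return Y;
}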
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
@@ -34,6 +34,7 @@ struct AttentionParameters {
float mask_filter_value;
float scale;
bool use_tf32;
bool is_output_bnsh; // whether the output format is Q_K_V_BNSH
AttentionMaskType mask_type;
AttentionQkvFormat qkv_format;
};
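The new flag marks whether the kernel should leave its result in the BNSH layout (4D output) rather than the default BSNH layout that backs the flattened 3D (B, S, N*H) case. Below is a hedged sketch of what the two layouts mean for a contiguous row-major buffer; the helper names and the rank-based inference are assumptions for illustration, not code from this header.

// Hedged sketch of the two output layouts the new flag distinguishes.
// BNSH: (batch, num_heads, seq, head_size) -- what the attention GEMMs naturally produce.
// BSNH: (batch, seq, num_heads, head_size) -- the layout behind the flattened 3D (B, S, N*H) output.
#include <cstddef>
#include <cstdint>

inline int64_t IndexBNSH(int64_t b, int64_t n, int64_t s, int64_t h,
                         int64_t N, int64_t S, int64_t H) {
  return ((b * N + n) * S + s) * H + h;
}

inline int64_t IndexBSNH(int64_t b, int64_t n, int64_t s, int64_t h,
                         int64_t N, int64_t S, int64_t H) {
  return ((b * S + s) * N + n) * H + h;
}

// Assumption for illustration only: 4D (B, N, S, H) inputs keep the BNSH layout end to
// end, while 3D (B, S, N*H) inputs need a BNSH -> BSNH transpose on output.
inline bool InferOutputIsBNSH(size_t q_rank) { return q_rank == 4; }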
27 changes: 15 additions & 12 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -748,15 +748,15 @@ Status UnfusedAttention(
mask_index, nullptr, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
data.scratch, scratch2, parameters.is_unidirectional, scale, mask_dimension,
parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace,
parameters.mask_filter_value));
parameters.mask_filter_value, parameters.kv_sequence_length));
} else if (nullptr != mask_index) { // 1d mask index
assert(mask_index_dims.size() == 1);
// mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the latter one has start positions.
const int* mask_start = (mask_index_dims[0] > batch_size) ? mask_index + batch_size : nullptr;
ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D<T>(
stream, total_sequence_length, sequence_length, batch_size, num_heads,
mask_index, mask_start, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
data.scratch, scratch2, parameters.is_unidirectional));
data.scratch, scratch2, parameters.is_unidirectional, parameters.kv_sequence_length));
} else { // no mask
if (nullptr != data.output_qk) {
int64_t qk_size = (int64_t)batch_size * num_heads * sequence_length * total_sequence_length;
@@ -767,25 +767,28 @@
ComputeSoftmax<T>(
stream, total_sequence_length, sequence_length, batch_size, num_heads,
data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
data.scratch, scratch2, parameters.is_unidirectional));
data.scratch, scratch2, parameters.is_unidirectional, parameters.kv_sequence_length));
}

DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length);

// compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
T* temp_output = data.q;
// compute R*V (as V*R), and store in output or temp workspace depending on whether transpose is needed
// For 4D input (BNSH), write directly to output. For 3D input (BSNH), write to temp then transpose.
T* temp_output = parameters.is_output_bnsh ? data.output : data.q;
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
v_head_size, sequence_length, total_sequence_length,
&one, data.v, v_head_size, present_size_per_batch_v,
scratch2, total_sequence_length, sequence_length * total_sequence_length,
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));

// Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
device_prop.maxThreadsPerBlock, false, temp_output, data.output);
if (!parameters.is_output_bnsh) {
// Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
device_prop.maxThreadsPerBlock, false, temp_output, data.output));
}
DUMP_TENSOR_D("Attention Output", data.output, batch_size, sequence_length, num_heads, v_head_size);
return result;
return Status::OK();
}
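The epilogue above either writes the V GEMM result straight into the output (4D, BNSH) or stages it in the Q workspace and transposes. Below is a hedged CPU reference of the BNSH -> BSNH reshuffle that LaunchTransCtx performs in the 3D case; the loop nest and function name are illustrative, not the CUDA kernel.

// Hedged CPU reference: rearrange (B, N, S, H) into (B, S, N, H). The CUDA path
// (LaunchTransCtx) does this on device; when is_output_bnsh is set the step is skipped
// because the GEMM already produced the requested layout.
#include <cstdint>

void TransposeBNSHToBSNHRef(const float* in, float* out,
                            int64_t B, int64_t N, int64_t S, int64_t H) {
  for (int64_t b = 0; b < B; ++b)
    for (int64_t n = 0; n < N; ++n)
      for (int64_t s = 0; s < S; ++s)
        for (int64_t h = 0; h < H; ++h)
          out[((b * S + s) * N + n) * H + h] = in[((b * N + n) * S + s) * H + h];
}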

template <typename T>
@@ -960,7 +963,7 @@ Status QkvToContext(
auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.kv_sequence_length;
const int total_sequence_length = parameters.total_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
Expand All @@ -981,12 +984,12 @@ Status QkvToContext(

if (!parameters.past_present_share_buffer) {
ORT_RETURN_IF_ERROR(ConcatPastToPresent<T>(batch_size, num_heads, qk_head_size, v_head_size,
sequence_length, total_sequence_length,
kv_sequence_length, total_sequence_length,
stream, max_threads_per_block, data));

} else { // past_present_share_buffer
ORT_RETURN_IF_ERROR(PastPresentBufferShare<T>(batch_size, num_heads, qk_head_size, v_head_size,
sequence_length, fused_runner,
kv_sequence_length, fused_runner,
parameters, data, stream, max_threads_per_block));
}
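Two of the changes in this file thread parameters.kv_sequence_length through: into the ComputeSoftmax* calls for the causal (is_unidirectional) path, and into ConcatPastToPresent/PastPresentBufferShare in place of Q's sequence_length, since the number of newly appended key/value rows is a property of K and V. The sketch below is illustrative only; it assumes self-attention (the new K/V chunk is aligned with the query chunk) and the usual convention total_sequence_length = past_sequence_length + kv_sequence_length, and the helper names are made up for this note.

// Hedged sketch, not kernel code: sequence-length bookkeeping with a KV cache.
#include <algorithm>
#include <cstdint>
#include <vector>

// Causal rule: query row i of the current chunk sits at absolute position past + i,
// with past = total_sequence_length - kv_sequence_length (self-attention assumed),
// and may attend to key column j only when j <= i + past. Failing columns receive a
// large negative additive bias before the softmax so their weights collapse to zero.
inline bool CausalAllows(int64_t i, int64_t j,
                         int64_t kv_sequence_length, int64_t total_sequence_length) {
  const int64_t past_sequence_length = total_sequence_length - kv_sequence_length;
  return j <= i + past_sequence_length;
}

// Past/present concat per batch and head: present = [past rows | new chunk rows],
// where the new chunk contributes kv_sequence_length rows, not Q's sequence_length.
std::vector<float> ConcatPastToPresentRef(const std::vector<float>& past,   // [past_len * head_size]
                                          const std::vector<float>& chunk,  // [kv_sequence_length * head_size]
                                          int64_t past_len, int64_t kv_sequence_length, int64_t head_size) {
  std::vector<float> present(static_cast<size_t>((past_len + kv_sequence_length) * head_size));
  std::copy(past.begin(), past.end(), present.begin());
  std::copy(chunk.begin(), chunk.end(), present.begin() + past_len * head_size);
  return present;
}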
