Merged
34 commits
1e7d5ae
refactor redundant condition checks
titaiwangms Oct 31, 2025
49d7a42
sync to Xavier's cpu refactors
titaiwangms Oct 31, 2025
246a4d1
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Nov 4, 2025
f244983
fix attention-cpu build
titaiwangms Nov 5, 2025
8274bb1
draft
titaiwangms Nov 7, 2025
78f5d61
lint - draft
titaiwangms Nov 7, 2025
53d4e83
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Nov 19, 2025
43623ad
fix typo
titaiwangms Nov 19, 2025
08f15f6
typo-2
titaiwangms Nov 19, 2025
0e75443
update namespace
titaiwangms Nov 19, 2025
277648d
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Nov 19, 2025
5253dd0
update doc
titaiwangms Nov 19, 2025
4db63bc
removed deprecated functions in onnx
titaiwangms Nov 20, 2025
0a7e5f9
Revert "removed deprecated functions in onnx"
titaiwangms Nov 20, 2025
6b18bb4
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Dec 1, 2025
a1ed3d9
fix qkv space - support 3d default
titaiwangms Dec 1, 2025
b462930
turn 4d to tru on disable cuda
titaiwangms Dec 2, 2025
0494e95
refactor attn_mask
titaiwangms Dec 2, 2025
2dc706a
simplify
titaiwangms Dec 9, 2025
5f0b6cd
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Dec 10, 2025
88e631c
support 4d and fix attn_mask bug
titaiwangms Dec 12, 2025
000d394
disregard softcap and softmax_precision
titaiwangms Dec 13, 2025
739e88f
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Dec 15, 2025
6d6d478
fix offset in is_causal
titaiwangms Dec 15, 2025
792445a
add past_seq_length to softmax bias add for causal
titaiwangms Dec 16, 2025
a26d812
resolve merge conflict
titaiwangms Dec 16, 2025
fbbf0b5
update failing cuda tests
titaiwangms Dec 16, 2025
7010308
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Jan 8, 2026
2c793b6
delete past_sequence_length and use flag output_is_Q_K_V_BNSH
titaiwangms Jan 9, 2026
ab41d04
add kv_sequence_length to softmax
titaiwangms Jan 9, 2026
4a8f502
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Jan 13, 2026
2a9167e
remove kv_sequence_length in softmax and disable cross attn causal tests
titaiwangms Jan 13, 2026
ff7e767
address reviews - comments
titaiwangms Jan 14, 2026
c001885
Merge branch 'main' into titaiwang/support_attention_cuda
titaiwangms Jan 14, 2026
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -653,6 +653,7 @@ Do not modify directly.*
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|Attention|*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *in* nonpad_kv_seqlen:**tensor(int64)**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**<br><br>or<br><br>*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**|23+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **U** = tensor(bfloat16), tensor(bool), tensor(float), tensor(float16)|
|AveragePool|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[19, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[11, 18]|**T** = tensor(double), tensor(float), tensor(float16)|
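For reference, a minimal sketch of the two Q/K/V layouts this kernel registration covers, assuming B = batch, S = query length, L = KV length, N = heads, H and Hv = QK and V head sizes (helper names are illustrative, not ORT API):

```cpp
#include <array>
#include <cstdint>

// Illustrative only: expected shapes for the Attention-23 CUDA kernel.
// 3D (packed BSNH) inputs: Q = (B, S, N*H), K = (B, L, N*H), V = (B, L, N*Hv),
//   with output Y = (B, S, N*Hv).
// 4D (BNSH) inputs:        Q = (B, N, S, H), K = (B, N, L, H), V = (B, N, L, Hv),
//   with output Y = (B, N, S, Hv).
inline std::array<int64_t, 4> ExpectedOutputShape4D(int64_t B, int64_t N, int64_t S, int64_t Hv) {
  return {B, N, S, Hv};  // BNSH output, matching the 4D input layout
}

inline std::array<int64_t, 3> ExpectedOutputShape3D(int64_t B, int64_t S, int64_t N, int64_t Hv) {
  return {B, S, N * Hv};  // packed 3D output
}
```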
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
@@ -34,6 +34,7 @@ struct AttentionParameters {
float mask_filter_value;
float scale;
bool use_tf32;
bool is_output_bnsh; // whether the output format is BNSH
AttentionMaskType mask_type;
AttentionQkvFormat qkv_format;
};
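A minimal sketch of the index arithmetic the new flag distinguishes (illustrative helpers, not part of the codebase): BSNH keeps all heads adjacent within each sequence position, while BNSH keeps the whole sequence for one head contiguous.

```cpp
#include <cstddef>

// Illustrative only: flat offsets for the two layouts.
// BSNH stores dims as (batch, sequence, heads, head_size);
// BNSH stores dims as (batch, heads, sequence, head_size).
inline size_t OffsetBSNH(size_t b, size_t s, size_t n, size_t h,
                         size_t S, size_t N, size_t H) {
  return ((b * S + s) * N + n) * H + h;
}

inline size_t OffsetBNSH(size_t b, size_t n, size_t s, size_t h,
                         size_t N, size_t S, size_t H) {
  return ((b * N + n) * S + s) * H + h;
}
```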
21 changes: 12 additions & 9 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -772,20 +772,23 @@ Status UnfusedAttention(

DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length);

// compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
T* temp_output = data.q;
// compute R*V (as V*R), and store in output or temp workspace depending on whether transpose is needed
// For 4D input (BNSH), write directly to output. For 3D input (BSNH), write to temp then transpose.
T* temp_output = parameters.is_output_bnsh ? data.output : data.q;
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
v_head_size, sequence_length, total_sequence_length,
&one, data.v, v_head_size, present_size_per_batch_v,
scratch2, total_sequence_length, sequence_length * total_sequence_length,
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));

// Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
device_prop.maxThreadsPerBlock, false, temp_output, data.output);
if (!parameters.is_output_bnsh) {
// Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
device_prop.maxThreadsPerBlock, false, temp_output, data.output));
}
DUMP_TENSOR_D("Attention Output", data.output, batch_size, sequence_length, num_heads, v_head_size);
return result;
return Status::OK();
}

template <typename T>
@@ -960,7 +963,7 @@ Status QkvToContext(
auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.kv_sequence_length;
const int total_sequence_length = parameters.total_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
@@ -981,12 +984,12 @@

if (!parameters.past_present_share_buffer) {
ORT_RETURN_IF_ERROR(ConcatPastToPresent<T>(batch_size, num_heads, qk_head_size, v_head_size,
sequence_length, total_sequence_length,
kv_sequence_length, total_sequence_length,
stream, max_threads_per_block, data));

} else { // past_present_share_buffer
ORT_RETURN_IF_ERROR(PastPresentBufferShare<T>(batch_size, num_heads, qk_head_size, v_head_size,
sequence_length, fused_runner,
kv_sequence_length, fused_runner,
parameters, data, stream, max_threads_per_block));
}

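The second hunk passes kv_sequence_length, the K/V length, which can differ from the query length in cross attention, into the past/present helpers. A rough CPU reference of the concat step, assuming BNSH layout with past length P and new length L (a sketch only, not the CUDA implementation):

```cpp
#include <algorithm>
#include <cstddef>

// Rough CPU sketch (not the CUDA kernel): append the new K (or V) chunk of
// sequence length L after the past K (or V) of length P, all in BNSH layout,
// producing a present buffer of total length P + L.
template <typename T>
void ConcatPastToPresentRef(const T* past, const T* chunk, T* present,
                            int B, int N, int P, int L, int H) {
  const int total = P + L;
  for (int b = 0; b < B; ++b) {
    for (int n = 0; n < N; ++n) {
      T* dst = present + (static_cast<size_t>(b) * N + n) * total * H;
      const T* src_past = past + (static_cast<size_t>(b) * N + n) * P * H;
      const T* src_new = chunk + (static_cast<size_t>(b) * N + n) * L * H;
      std::copy(src_past, src_past + static_cast<size_t>(P) * H, dst);
      std::copy(src_new, src_new + static_cast<size_t>(L) * H,
                dst + static_cast<size_t>(P) * H);
    }
  }
}
```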
250 changes: 144 additions & 106 deletions onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -251,7 +251,6 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block) {
assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
assert(data.query != nullptr);
assert(data.key != nullptr);
assert(data.value != nullptr);
@@ -262,82 +261,96 @@
assert(!parameters.is_unidirectional);
assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_NoPast(data));

const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;

if (data.fused_cross_attention_kernel != nullptr) {
assert(qk_head_size == v_head_size);
assert(data.attention_bias == nullptr);
assert(data.mask_index == nullptr);
assert(parameters.hidden_size == parameters.v_hidden_size);

// For fused cross attention, besides adding bias, K and V needed to be packed:
// Key (BxSxNxH), Value (BxSxNxH) => Q (BxSxNxH), K (BxSxNx2xH)
LaunchAddBiasTransposeTrt(
stream, max_threads_per_block,
batch_size, sequence_length,
num_heads, qk_head_size,
data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length);
data.v = nullptr;
data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
} else if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
if (data.bias != nullptr) {
LaunchAddBias(stream, max_threads_per_block,
batch_size, sequence_length, kv_sequence_length,
num_heads, qk_head_size, v_head_size,
data.bias, data.query, data.key, data.value, data.q, data.k, data.v);
} else {
data.q = const_cast<T*>(data.query);
data.k = const_cast<T*>(data.key);
data.v = const_cast<T*>(data.value);
}
if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) {
// 3D inputs in BSNH format (will be transposed)
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;

if (data.fused_cross_attention_kernel != nullptr) {
assert(qk_head_size == v_head_size);
assert(data.attention_bias == nullptr);
assert(data.mask_index == nullptr);
assert(parameters.hidden_size == parameters.v_hidden_size);

// For fused cross attention, besides adding bias, K and V needed to be packed:
// Key (BxSxNxH), Value (BxSxNxH) => Q (BxSxNxH), K (BxSxNx2xH)
LaunchAddBiasTransposeTrt(
stream, max_threads_per_block,
batch_size, sequence_length,
num_heads, qk_head_size,
data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length);
data.v = nullptr;
data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
} else if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
if (data.bias != nullptr) {
LaunchAddBias(stream, max_threads_per_block,
batch_size, sequence_length, kv_sequence_length,
num_heads, qk_head_size, v_head_size,
data.bias, data.query, data.key, data.value, data.q, data.k, data.v);
} else {
data.q = const_cast<T*>(data.query);
data.k = const_cast<T*>(data.key);
data.v = const_cast<T*>(data.value);
}

data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
} else if (data.fused_runner != nullptr) {
assert(qk_head_size == v_head_size);
assert(data.attention_bias == nullptr);
data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
} else if (data.fused_runner != nullptr) {
assert(qk_head_size == v_head_size);
assert(data.attention_bias == nullptr);

// Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H)
LaunchAddBiasTransposeTrt(
stream, max_threads_per_block,
batch_size, sequence_length,
num_heads, qk_head_size,
data.bias, data.query, data.key, data.value, data.q, false, kv_sequence_length);
data.k = nullptr;
data.v = nullptr;
// Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H)
LaunchAddBiasTransposeTrt(
stream, max_threads_per_block,
batch_size, sequence_length,
num_heads, qk_head_size,
data.bias, data.query, data.key, data.value, data.q, false, kv_sequence_length);
data.k = nullptr;
data.v = nullptr;

data.qkv_format = AttentionQkvFormat::QKV_BSN3H;
} else { // unfused kernel
data.qkv_format = AttentionQkvFormat::QKV_BSN3H;
} else { // unfused kernel
assert(data.IsUnfused());
// Query (BxSxNxH) => Q (BxNxSxH)
constexpr int format = 0;
LaunchAddBiasTranspose<T>(
stream, 1, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.query, data.bias, data.q,
true, -1);

// Key (BxLxNxH) => K (BxNxLxH)
LaunchAddBiasTranspose<T>(
stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, qk_head_size,
data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, data.k,
true, -1);

// Value (BxLxNxH_v) => V (BxNxLxH_v)
LaunchAddBiasTranspose<T>(
stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, v_head_size,
data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, data.v,
true, -1);

data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
}
} else if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
// Currently, 4D inputs are only supported in unfused kernel for Attention-23.
assert(data.IsUnfused());
// Query (BxSxNxH) => Q (BxNxSxH)
constexpr int format = 0;
LaunchAddBiasTranspose<T>(
stream, 1, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.query, data.bias, data.q,
true, -1);

// Key (BxLxNxH) => K (BxNxLxH)
LaunchAddBiasTranspose<T>(
stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, qk_head_size,
data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, data.k,
true, -1);

// Value (BxLxNxH_v) => V (BxNxLxH_v)
LaunchAddBiasTranspose<T>(
stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, v_head_size,
data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, data.v,
true, -1);

data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
// Attention-23 does not support bias with Q_K_V_BNSH format.
assert(data.bias == nullptr);
// No need to transpose since QKV is already in BNSH format.
data.q = const_cast<T*>(data.query);
data.k = const_cast<T*>(data.key);
data.v = const_cast<T*>(data.value);
} else {
ORT_THROW("Unsupported QKV format: ", parameters.qkv_format);
}

return Status::OK();
Expand All @@ -360,7 +373,6 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block) {
assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
assert(data.query != nullptr);
assert(data.key != nullptr);
assert(data.value != nullptr);
@@ -373,42 +385,53 @@
data.past_key != nullptr && data.past_value != nullptr);
assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_WithPast_NoBias(data));

const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;

// When there is no past state and there is present state, we output K and V directly to present state.
if (data.past_key == nullptr && data.present_key != nullptr) {
data.k = data.present_key;
data.v = data.present_value;
}
if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) {
// 3D inputs in BSNH format (will be transposed)
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;

if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.use_lean_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
// Use original Query (BSNH) since there is no bias.
data.q = const_cast<T*>(data.query);

if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.use_lean_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
// Use original Query (BSNH) since there is no bias.
data.q = const_cast<T*>(data.query);

// Key (BxLxNxH) => K (BxNxLxH)
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.key, data.k));
// Value (BxLxNxH) => V (BxNxLxH)
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.value, data.v));
data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
} else { // unfused kernel
// Key (BxLxNxH) => K (BxNxLxH)
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.key, data.k));
// Value (BxLxNxH) => V (BxNxLxH)
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.value, data.v));
data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
} else { // unfused kernel
assert(data.IsUnfused());
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.query, data.q));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.key, data.k));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.value, data.v));
data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
}
} else if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
// Currently, 4D inputs are only supported in unfused kernel for Attention-23.
assert(data.IsUnfused());
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.query, data.q));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.key, data.k));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.value, data.v));
data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
// No need to transpose since QKV is already in BNSH format.
data.q = const_cast<T*>(data.query);
data.k = const_cast<T*>(data.key);
data.v = const_cast<T*>(data.value);
} else {
ORT_THROW("Unsupported QKV format: ", parameters.qkv_format);
}

return Status::OK();
@@ -670,14 +693,27 @@ Status PrepareQkv_MultiHeadAttention(contrib::AttentionParameters& parameters,
case AttentionQkvFormat::Q_K_V_BSNH:
if (data.past_key != nullptr || data.present_key != nullptr) {
if (data.bias == nullptr) {
DUMP_STRING("PrepareQkv_MHA_WithPast_NoBias");
DUMP_STRING("PrepareQkv(3D)_MHA_WithPast_NoBias");
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_NoBias(parameters, data, stream, max_threads_per_block));
} else {
DUMP_STRING("PrepareQkv_MHA_WithPast_Bias");
DUMP_STRING("PrepareQkv(3D)_MHA_WithPast_Bias");
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_Bias(parameters, data, stream, max_threads_per_block));
}
} else { // no past state
DUMP_STRING("PrepareQkv_MHA_NoPast");
DUMP_STRING("PrepareQkv(3D)_MHA_NoPast");
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NoPast(parameters, data, stream, max_threads_per_block));
}
break;
case AttentionQkvFormat::Q_K_V_BNSH:
if (data.past_key != nullptr || data.present_key != nullptr) {
if (data.bias == nullptr) {
DUMP_STRING("PrepareQkv(4D)_MHA_WithPast_NoBias");
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_NoBias(parameters, data, stream, max_threads_per_block));
} else {
ORT_THROW("Q_K_V_BNSH format with bias is not supported.");
}
} else { // no past state
DUMP_STRING("PrepareQkv(4D)_MHA_NoPast");
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NoPast(parameters, data, stream, max_threads_per_block));
}
break;
@@ -708,6 +744,8 @@ bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData<T>&
} else { // no past state
return NoQkvWorkspace_MHA_NoPast(data);
}
case AttentionQkvFormat::Q_K_V_BNSH:
return false; // currently no scenario needs no workspace
default:
ORT_THROW("Unsupported QKV format: ", parameters.qkv_format);
}
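Taken together, the 4D (Q_K_V_BNSH) path skips both the input-side transposes in these helpers (LaunchAddBiasTranspose / LaunchTransQkv) and the output-side LaunchTransCtx in UnfusedAttention. As a layout reference only, a plain CPU sketch of the BSNH to BNSH relayout that the 3D path still performs:

```cpp
#include <cstddef>

// CPU reference of the BSNH -> BNSH relayout; the real code uses CUDA
// kernels (LaunchTransQkv for Q/K/V, and LaunchTransCtx for the reverse
// direction on the attention output).
template <typename T>
void TransposeBsnhToBnsh(const T* src, T* dst, int B, int S, int N, int H) {
  for (int b = 0; b < B; ++b)
    for (int s = 0; s < S; ++s)
      for (int n = 0; n < N; ++n)
        for (int h = 0; h < H; ++h)
          dst[((static_cast<size_t>(b) * N + n) * S + s) * H + h] =
              src[((static_cast<size_t>(b) * S + s) * N + n) * H + h];
}
```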