diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index ebed2b1972ba9..41ec3dc348dfc 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -653,6 +653,7 @@ Do not modify directly.*
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
+|Attention|*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *in* nonpad_kv_seqlen:**tensor(int64)**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**<br><br>or<br><br>*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**|23+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br> **T2** = tensor(bfloat16), tensor(float), tensor(float16)<br> **U** = tensor(bfloat16), tensor(bool), tensor(float), tensor(float16)|
|AveragePool|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[19, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[11, 18]|**T** = tensor(double), tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
index f237b24b899a0..889f8124c02d8 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
@@ -34,6 +34,7 @@ struct AttentionParameters {
float mask_filter_value;
float scale;
bool use_tf32;
+ bool is_output_bnsh = false; // whether the output format is BNSH
AttentionMaskType mask_type;
AttentionQkvFormat qkv_format;
};
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index ee4be2c20362d..22862f1ba7b4c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -772,8 +772,9 @@ Status UnfusedAttention(
DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length);
- // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
- T* temp_output = data.q;
+ // compute R*V (as V*R), and store in output or temp workspace depending on whether transpose is needed
+ // For 4D input (BNSH), write directly to output. For 3D input (BSNH), write to temp then transpose.
+ T* temp_output = parameters.is_output_bnsh ? data.output : data.q;
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
v_head_size, sequence_length, total_sequence_length,
@@ -781,11 +782,13 @@ Status UnfusedAttention(
scratch2, total_sequence_length, sequence_length * total_sequence_length,
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));
- // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
- Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
- device_prop.maxThreadsPerBlock, false, temp_output, data.output);
+ if (!parameters.is_output_bnsh) {
+ // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
+ ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
+ device_prop.maxThreadsPerBlock, false, temp_output, data.output));
+ }
DUMP_TENSOR_D("Attention Output", data.output, batch_size, sequence_length, num_heads, v_head_size);
- return result;
+ return Status::OK();
}
template
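Reviewer note: the strided-batched GEMM above emits one SxH_v block per (batch, head) pair, which is exactly BxNxSxH_v order. A minimal sketch of the two layouts that the new `is_output_bnsh` flag distinguishes (illustrative index helpers, not code from this change; contiguous row-major storage is assumed):

    inline size_t BnshOffset(size_t b, size_t n, size_t s, size_t h,
                             size_t num_heads, size_t seq_len, size_t head_size) {
      // Matches the GEMM output order, so the result can be written straight into data.output.
      return ((b * num_heads + n) * seq_len + s) * head_size + h;
    }
    inline size_t BsnhOffset(size_t b, size_t n, size_t s, size_t h,
                             size_t num_heads, size_t seq_len, size_t head_size) {
      // 3D outputs expect this order, hence the LaunchTransCtx transpose on that path.
      return ((b * seq_len + s) * num_heads + n) * head_size + h;
    }
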
@@ -960,7 +963,7 @@ Status QkvToContext(
auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
const int batch_size = parameters.batch_size;
- const int sequence_length = parameters.sequence_length;
+ const int kv_sequence_length = parameters.kv_sequence_length;
const int total_sequence_length = parameters.total_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
@@ -981,12 +984,12 @@ Status QkvToContext(
if (!parameters.past_present_share_buffer) {
ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size,
- sequence_length, total_sequence_length,
+ kv_sequence_length, total_sequence_length,
stream, max_threads_per_block, data));
} else { // past_present_share_buffer
ORT_RETURN_IF_ERROR(PastPresentBufferShare(batch_size, num_heads, qk_head_size, v_head_size,
- sequence_length, fused_runner,
+ kv_sequence_length, fused_runner,
parameters, data, stream, max_threads_per_block));
}
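Reviewer note: switching from `sequence_length` to `kv_sequence_length` matters whenever the query and key/value lengths differ (cross attention, decoding with a KV cache). A rough CPU sketch of the concatenation being sized here, assuming contiguous per-(batch, head) BNSH buffers (the helper name is made up for illustration):

    #include <cstring>
    template <typename T>
    void ConcatPastToPresentSketch(const T* past, const T* new_kv, T* present,
                                   int past_len, int kv_len, int head_size) {
      // present holds past_len + kv_len rows per (batch, head); the appended chunk
      // contributes kv_sequence_length rows, not the query sequence_length.
      std::memcpy(present, past, sizeof(T) * size_t(past_len) * head_size);
      std::memcpy(present + size_t(past_len) * head_size, new_kv, sizeof(T) * size_t(kv_len) * head_size);
    }
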
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
index b2e8130ecd17e..7b37a3a4227b6 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -251,7 +251,6 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block) {
- assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
assert(data.query != nullptr);
assert(data.key != nullptr);
assert(data.value != nullptr);
@@ -262,82 +261,96 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
assert(!parameters.is_unidirectional);
assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_NoPast(data));
- const int batch_size = parameters.batch_size;
- const int sequence_length = parameters.sequence_length;
- const int kv_sequence_length = parameters.kv_sequence_length;
- const int num_heads = parameters.num_heads;
- const int qk_head_size = parameters.head_size;
- const int v_head_size = parameters.v_head_size;
-
- if (data.fused_cross_attention_kernel != nullptr) {
- assert(qk_head_size == v_head_size);
- assert(data.attention_bias == nullptr);
- assert(data.mask_index == nullptr);
- assert(parameters.hidden_size == parameters.v_hidden_size);
-
- // For fused cross attention, besides adding bias, K and V needed to be packed:
- // Key (BxSxNxH), Value (BxSxNxH) => Q (BxSxNxH), K (BxSxNx2xH)
- LaunchAddBiasTransposeTrt(
- stream, max_threads_per_block,
- batch_size, sequence_length,
- num_heads, qk_head_size,
- data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length);
- data.v = nullptr;
- data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
- } else if (data.use_memory_efficient_attention ||
- data.use_flash_attention ||
- data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
- if (data.bias != nullptr) {
- LaunchAddBias(stream, max_threads_per_block,
- batch_size, sequence_length, kv_sequence_length,
- num_heads, qk_head_size, v_head_size,
- data.bias, data.query, data.key, data.value, data.q, data.k, data.v);
- } else {
- data.q = const_cast<T*>(data.query);
- data.k = const_cast<T*>(data.key);
- data.v = const_cast<T*>(data.value);
- }
+ if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) {
+ // 3D inputs in BSNH format (will be transposed)
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int kv_sequence_length = parameters.kv_sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int qk_head_size = parameters.head_size;
+ const int v_head_size = parameters.v_head_size;
+
+ if (data.fused_cross_attention_kernel != nullptr) {
+ assert(qk_head_size == v_head_size);
+ assert(data.attention_bias == nullptr);
+ assert(data.mask_index == nullptr);
+ assert(parameters.hidden_size == parameters.v_hidden_size);
+
+ // For fused cross attention, besides adding bias, K and V need to be packed:
+ // Key (BxSxNxH), Value (BxSxNxH) => Q (BxSxNxH), K (BxSxNx2xH)
+ LaunchAddBiasTransposeTrt(
+ stream, max_threads_per_block,
+ batch_size, sequence_length,
+ num_heads, qk_head_size,
+ data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length);
+ data.v = nullptr;
+ data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
+ } else if (data.use_memory_efficient_attention ||
+ data.use_flash_attention ||
+ data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
+ if (data.bias != nullptr) {
+ LaunchAddBias(stream, max_threads_per_block,
+ batch_size, sequence_length, kv_sequence_length,
+ num_heads, qk_head_size, v_head_size,
+ data.bias, data.query, data.key, data.value, data.q, data.k, data.v);
+ } else {
+ data.q = const_cast<T*>(data.query);
+ data.k = const_cast<T*>(data.key);
+ data.v = const_cast<T*>(data.value);
+ }
- data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
- } else if (data.fused_runner != nullptr) {
- assert(qk_head_size == v_head_size);
- assert(data.attention_bias == nullptr);
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+ } else if (data.fused_runner != nullptr) {
+ assert(qk_head_size == v_head_size);
+ assert(data.attention_bias == nullptr);
- // Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H)
- LaunchAddBiasTransposeTrt(
- stream, max_threads_per_block,
- batch_size, sequence_length,
- num_heads, qk_head_size,
- data.bias, data.query, data.key, data.value, data.q, false, kv_sequence_length);
- data.k = nullptr;
- data.v = nullptr;
+ // Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H)
+ LaunchAddBiasTransposeTrt(
+ stream, max_threads_per_block,
+ batch_size, sequence_length,
+ num_heads, qk_head_size,
+ data.bias, data.query, data.key, data.value, data.q, false, kv_sequence_length);
+ data.k = nullptr;
+ data.v = nullptr;
- data.qkv_format = AttentionQkvFormat::QKV_BSN3H;
- } else { // unfused kernel
+ data.qkv_format = AttentionQkvFormat::QKV_BSN3H;
+ } else { // unfused kernel
+ assert(data.IsUnfused());
+ // Query (BxSxNxH) => Q (BxNxSxH)
+ constexpr int format = 0;
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, sequence_length, num_heads, qk_head_size,
+ data.query, data.bias, data.q,
+ true, -1);
+
+ // Key (BxLxNxH) => K (BxNxLxH)
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, qk_head_size,
+ data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, data.k,
+ true, -1);
+
+ // Value (BxLxNxH_v) => V (BxNxLxH_v)
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, v_head_size,
+ data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, data.v,
+ true, -1);
+
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ }
+ } else if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
+ // Currently, 4D inputs are only supported by the unfused kernel for Attention-23.
assert(data.IsUnfused());
- // Query (BxSxNxH) => Q (BxNxSxH)
- constexpr int format = 0;
- LaunchAddBiasTranspose(
- stream, 1, format, max_threads_per_block,
- batch_size, sequence_length, num_heads, qk_head_size,
- data.query, data.bias, data.q,
- true, -1);
-
- // Key (BxLxNxH) => K (BxNxLxH)
- LaunchAddBiasTranspose(
- stream, 1, format, max_threads_per_block,
- batch_size, kv_sequence_length, num_heads, qk_head_size,
- data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, data.k,
- true, -1);
-
- // Value (BxLxNxH_v) => K (BxNxLxH_v)
- LaunchAddBiasTranspose(
- stream, 1, format, max_threads_per_block,
- batch_size, kv_sequence_length, num_heads, v_head_size,
- data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, data.v,
- true, -1);
-
- data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ // Attention-23 does not support bias with Q_K_V_BNSH format.
+ assert(data.bias == nullptr);
+ // No need to transpose since QKV is already in BNSH format.
+ data.q = const_cast<T*>(data.query);
+ data.k = const_cast<T*>(data.key);
+ data.v = const_cast<T*>(data.value);
+ } else {
+ ORT_THROW("Unsupported QKV format: ", parameters.qkv_format);
}
return Status::OK();
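Reviewer note: the two branches above correspond to the opset-23 Attention input ranks. A minimal sketch of the convention (the helper is hypothetical, shown only to summarize the dispatch):

    // 3D inputs (batch, seq, num_heads * head_size) arrive as Q_K_V_BSNH and get bias-add/transpose;
    // 4D inputs (batch, num_heads, seq, head_size) are already Q_K_V_BNSH and are used as-is (no bias).
    AttentionQkvFormat QkvFormatFromRank(size_t q_rank) {
      return q_rank == 4 ? AttentionQkvFormat::Q_K_V_BNSH
                         : AttentionQkvFormat::Q_K_V_BSNH;
    }
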
@@ -360,7 +373,6 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block) {
- assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
assert(data.query != nullptr);
assert(data.key != nullptr);
assert(data.value != nullptr);
@@ -373,42 +385,53 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,
data.past_key != nullptr && data.past_value != nullptr);
assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_WithPast_NoBias(data));
- const int batch_size = parameters.batch_size;
- const int sequence_length = parameters.sequence_length;
- const int kv_sequence_length = parameters.kv_sequence_length;
- const int num_heads = parameters.num_heads;
- const int qk_head_size = parameters.head_size;
- const int v_head_size = parameters.v_head_size;
-
// When there is no past state and there is present state, we output K and V directly to present state.
if (data.past_key == nullptr && data.present_key != nullptr) {
data.k = data.present_key;
data.v = data.present_value;
}
+ if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) {
+ // 3D inputs in BSNH format (will be transposed)
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int kv_sequence_length = parameters.kv_sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int qk_head_size = parameters.head_size;
+ const int v_head_size = parameters.v_head_size;
+
+ if (data.use_memory_efficient_attention ||
+ data.use_flash_attention ||
+ data.use_lean_attention ||
+ data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
+ // Use original Query (BSNH) since there is no bias.
+ data.q = const_cast<T*>(data.query);
- if (data.use_memory_efficient_attention ||
- data.use_flash_attention ||
- data.use_lean_attention ||
- data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
- // Use oiginal Query (BSNH) since there is no bias.
- data.q = const_cast<T*>(data.query);
-
- // Key (BxLxNxH) => K (BxNxLxH)
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.key, data.k));
- // Value (BxLxNxH) => V (BxNxLxH)
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
- max_threads_per_block, false, data.value, data.v));
- data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
- } else { // unfused kernel
+ // Key (BxLxNxH) => K (BxNxLxH)
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, false, data.key, data.k));
+ // Value (BxLxNxH) => V (BxNxLxH)
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
+ max_threads_per_block, false, data.value, data.v));
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
+ } else { // unfused kernel
+ assert(data.IsUnfused());
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, false, data.query, data.q));
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, false, data.key, data.k));
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
+ max_threads_per_block, false, data.value, data.v));
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ }
+ } else if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
+ // Currently, 4D inputs are only supported by the unfused kernel for Attention-23.
assert(data.IsUnfused());
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.query, data.q));
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.key, data.k));
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
- max_threads_per_block, false, data.value, data.v));
- data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ // No need to transpose since QKV is already in BNSH format.
+ data.q = const_cast<T*>(data.query);
+ data.k = const_cast<T*>(data.key);
+ data.v = const_cast<T*>(data.value);
+ } else {
+ ORT_THROW("Unsupported QKV format: ", parameters.qkv_format);
}
return Status::OK();
@@ -670,14 +693,27 @@ Status PrepareQkv_MultiHeadAttention(contrib::AttentionParameters& parameters,
case AttentionQkvFormat::Q_K_V_BSNH:
if (data.past_key != nullptr || data.present_key != nullptr) {
if (data.bias == nullptr) {
- DUMP_STRING("PrepareQkv_MHA_WithPast_NoBias");
+ DUMP_STRING("PrepareQkv(3D)_MHA_WithPast_NoBias");
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_NoBias(parameters, data, stream, max_threads_per_block));
} else {
- DUMP_STRING("PrepareQkv_MHA_WithPast_Bias");
+ DUMP_STRING("PrepareQkv(3D)_MHA_WithPast_Bias");
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_Bias(parameters, data, stream, max_threads_per_block));
}
} else { // no past state
- DUMP_STRING("PrepareQkv_MHA_NoPast");
+ DUMP_STRING("PrepareQkv(3D)_MHA_NoPast");
+ ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NoPast(parameters, data, stream, max_threads_per_block));
+ }
+ break;
+ case AttentionQkvFormat::Q_K_V_BNSH:
+ if (data.past_key != nullptr || data.present_key != nullptr) {
+ if (data.bias == nullptr) {
+ DUMP_STRING("PrepareQkv(4D)_MHA_WithPast_NoBias");
+ ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_NoBias(parameters, data, stream, max_threads_per_block));
+ } else {
+ ORT_THROW("Q_K_V_BNSH format with bias is not supported.");
+ }
+ } else { // no past state
+ DUMP_STRING("PrepareQkv(4D)_MHA_NoPast");
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NoPast(parameters, data, stream, max_threads_per_block));
}
break;
@@ -708,6 +744,8 @@ bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData&
} else { // no past state
return NoQkvWorkspace_MHA_NoPast(data);
}
+ case AttentionQkvFormat::Q_K_V_BNSH:
+ return false; // no Q_K_V_BNSH case can skip the QKV workspace yet
default:
ORT_THROW("Unsupported QKV format: ", parameters.qkv_format);
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu
index 938033644a7d6..be3ae44f7a206 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu
@@ -165,7 +165,7 @@ __device__ inline void SoftmaxSmall(const int total_sequence_length,
// Update end position for causal.
int end = valid_end;
if (causal) {
- const int end_causal = total_sequence_length - sequence_length + s + 1;
+ const int end_causal = (total_sequence_length - sequence_length) + s + 1;
if (end_causal < end) {
end = end_causal;
}
@@ -241,7 +241,7 @@ __global__ void SoftmaxLargeKernel(const int total_sequence_length,
// Update end position for causal.
int end = valid_end;
if (causal) {
- int end_causal = total_sequence_length - sequence_length + s + 1;
+ const int end_causal = (total_sequence_length - sequence_length) + s + 1;
if (end_causal < end) {
end = end_causal;
}
@@ -333,7 +333,7 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int total_sequence_length,
: float(input[index]);
float thread_data = input_data * rsqrt_head_size;
if (causal) {
- int from_index = total_sequence_length - sequence_length + s; // offset in total sequence length.
+ int from_index = (total_sequence_length - sequence_length) + s; // offset in total sequence length.
if (i > from_index) {
thread_data = -CUDART_INF_F;
}
@@ -439,7 +439,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int total_sequence_length,
thread_data = float(input[index]) * rsqrt_head_size;
if (causal) {
- int from_index = total_sequence_length - sequence_length + s; // offset in total sequence length.
+ int from_index = (total_sequence_length - sequence_length) + s; // offset in total sequence length.
if (threadIdx.x > from_index) {
thread_data = -CUDART_INF_F;
}
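Reviewer note: the added parentheses do not change the value; `total_sequence_length - sequence_length` is the number of past (cached) tokens, so the exclusive causal end for query row s is past + s + 1. A worked example with made-up sizes:

    // past = 5, sequence_length = 3, total_sequence_length = 8
    //   s = 0 -> end_causal = (8 - 3) + 0 + 1 = 6   (attends to key positions 0..5)
    //   s = 1 -> end_causal = (8 - 3) + 1 + 1 = 7   (attends to key positions 0..6)
    //   s = 2 -> end_causal = (8 - 3) + 2 + 1 = 8   (attends to every position, including itself)
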
diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc
index afd5d55664160..164bea191bbef 100644
--- a/onnxruntime/core/providers/cpu/llm/attention.cc
+++ b/onnxruntime/core/providers/cpu/llm/attention.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "core/providers/cpu/llm/attention.h"
+#include "core/providers/cpu/llm/attention_helper.h"
#include "core/common/common.h"
#include "core/common/safeint.h"
@@ -115,13 +116,13 @@ Attention::Attention(const OpKernelInfo& info) : AttentionBase(info) {
q_num_heads_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("q_num_heads", 0));
int mode = static_cast<int>(info.GetAttrOrDefault<int64_t>("qk_matmul_output_mode", 0));
qk_matmul_output_mode_ = info.node().OutputDefs().size() >= 4 && info.node().OutputDefs()[3]->Exists()
- ? static_cast<QKMatMulOutputMode>(mode)
- : QKMatMulOutputMode::kNone;
- ORT_ENFORCE(qk_matmul_output_mode_ == QKMatMulOutputMode::kNone ||
- qk_matmul_output_mode_ == QKMatMulOutputMode::kQK ||
- qk_matmul_output_mode_ == QKMatMulOutputMode::kQKMask ||
- qk_matmul_output_mode_ == QKMatMulOutputMode::kQKSoftCap ||
- qk_matmul_output_mode_ == QKMatMulOutputMode::kQKSoftMax,
+ ? static_cast<attention_helper::QKMatMulOutputMode>(mode)
+ : attention_helper::QKMatMulOutputMode::kNone;
+ ORT_ENFORCE(qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kNone ||
+ qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQK ||
+ qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKMask ||
+ qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKSoftCap ||
+ qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKSoftMax,
"qk_matmul_output_mode must be 0, 1, 2, or 3.");
// The default scale depends on the input dimensions. It is set to nan to indicate that it should be computed.
scale_ = info.GetAttrOrDefault<float>("scale", std::numeric_limits<float>::quiet_NaN());
@@ -140,11 +141,12 @@ Status Attention::Compute(OpKernelContext* context) const {
const Tensor* past_value = context->Input<Tensor>(5);
AttentionParameters parameters;
- std::vector<int64_t> y_shape;
- std::vector<int64_t> present_key_shape;
- std::vector<int64_t> present_value_shape;
- std::vector<int64_t> output_qk_shape;
+ TensorShape y_shape;
+ TensorShape present_key_shape;
+ TensorShape present_value_shape;
+ TensorShape output_qk_shape;
+ // ComputeOutputShapeForAttention also checks the validity of the inputs.
ORT_ENFORCE(attention_helper::ComputeOutputShapeForAttention(
Q,
K,
@@ -243,51 +245,49 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs,
bool delete_mask_data = false;
bool causal = parameters.is_causal && parameters.q_sequence_length > 1;
if (mask_index == nullptr) {
- // No mask = null mask.
+ // No external mask: allocate one only if causal masking is needed.
if (causal) {
- size_t mask_data_bytes = SafeInt<size_t>(parameters.q_sequence_length) * parameters.total_sequence_length * sizeof(T);
- void* allocated_ptr = allocator->Alloc(mask_data_bytes);
- memset(allocated_ptr, 0, mask_data_bytes);
- mask_data = static_cast<T*>(allocated_ptr);
- for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) {
- for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) {
- mask_data[s_i * parameters.total_sequence_length + m_i] = mask_filter_value();
+ size_t mask_bytes = SafeInt<size_t>(parameters.q_sequence_length) * parameters.total_sequence_length * sizeof(T);
+ void* raw = allocator->Alloc(mask_bytes);
+ memset(raw, 0, mask_bytes); // start with every position allowed
+ mask_data = static_cast<T*>(raw);
+ for (int s = 0; s < parameters.q_sequence_length; ++s) {
+ for (int t = parameters.past_sequence_length + s + 1; t < parameters.total_sequence_length; ++t) {
+ mask_data[s * parameters.total_sequence_length + t] = mask_filter_value();
}
}
delete_mask_data = true;
}
- } else if (mask_index->IsDataType<bool>() || causal) {
- // We need a copy.
- size_t mask_data_bytes = SafeInt<size_t>(mask_index->Shape().Size()) * sizeof(T);
- mask_data = static_cast<T*>(allocator->Alloc(mask_data_bytes));
- delete_mask_data = true;
-
- if (mask_index->IsDataType<bool>()) {
- // Convert bool mask to 0/1
- make_copy(mask_data, mask_index->Data<bool>(), SafeInt<size_t>(mask_index->Shape().Size()));
- } else if (mask_index != nullptr) {
- // We make a copy because causal is True.
- make_copy(mask_data, mask_index->Data<T>(), SafeInt<size_t>(mask_index->Shape().Size()));
- }
- if (causal) {
- // This loop could be parallelized.
- // According to the specifications, this configuration is not supported
- // as is_causal=1 or mask is not None (exclusive or).
- int n_iter = mask_batch_size * mask_num_heads;
- for (int i = 0; i < n_iter; ++i) {
- for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) {
- for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) {
- mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = mask_filter_value();
+ } else {
+ const bool is_bool_mask = mask_index->IsDataType<bool>();
+ const bool need_copy = is_bool_mask || causal; // copy if we must convert or overlay causal pattern
+ if (need_copy) {
+ size_t mask_bytes = SafeInt<size_t>(mask_index->Shape().Size()) * sizeof(T);
+ mask_data = static_cast<T*>(allocator->Alloc(mask_bytes));
+ delete_mask_data = true;
+ if (is_bool_mask) {
+ make_copy(mask_data, mask_index->Data<bool>(), SafeInt<size_t>(mask_index->Shape().Size()));
+ } else {
+ make_copy(mask_data, mask_index->Data<T>(), SafeInt<size_t>(mask_index->Shape().Size()));
+ }
+ if (causal) {
+ // Overlay the causal pattern (mask_filter_value above the diagonal) on every broadcast slice
+ int slices = mask_batch_size * mask_num_heads;
+ for (int slice = 0; slice < slices; ++slice) {
+ T* base = mask_data + probs_matrix_size * slice;
+ for (int s = 0; s < parameters.q_sequence_length; ++s) {
+ for (int t = parameters.past_sequence_length + s + 1; t < parameters.total_sequence_length; ++t) {
+ base[s * parameters.total_sequence_length + t] = mask_filter_value();
+ }
}
}
}
+ } else {
+ // Reuse mask memory directly (numeric, non-causal)
+ mask_data = const_cast<T*>(mask_index->Data<T>());
}
- } else {
- // Nothing to do, no necessary copy.
- mask_data = const_cast<T*>(mask_index->Data<T>());
}
- bool transposed_k = parameters.transpose_output && nullptr == present_key;
if (nullptr != present_key && parameters.kv_num_heads != parameters.q_num_heads) {
// This is not part of the main loop because it is not needed at every iteration and
// we cannot ensure the inner body is executed first before getting used in another iteration.
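Reviewer note: both causal branches above stamp the same pattern into a q_sequence_length x total_sequence_length slice. A self-contained sketch of that overlay (the function name is illustrative; mask_filter_value is typically a large negative number):

    template <typename T>
    void ApplyCausalOverlay(T* slice, int q_seq_len, int past_len, int total_len, T filter_value) {
      for (int s = 0; s < q_seq_len; ++s) {
        // Row s may attend to key positions [0, past_len + s]; everything beyond is masked out.
        for (int t = past_len + s + 1; t < total_len; ++t) {
          slice[s * total_len + t] = filter_value;
        }
      }
    }
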
@@ -309,6 +309,7 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs,
// If past_key is not null, then we need to concatenate it with K, the concatenation is not transposed.
const int loop_len = parameters.batch_size * parameters.q_num_heads;
const float alpha = parameters.scale;
+ bool transposed_k = parameters.transpose_output && nullptr == present_key;
ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
diff --git a/onnxruntime/core/providers/cpu/llm/attention.h b/onnxruntime/core/providers/cpu/llm/attention.h
index c8eff7c5006d6..867207e385dc9 100644
--- a/onnxruntime/core/providers/cpu/llm/attention.h
+++ b/onnxruntime/core/providers/cpu/llm/attention.h
@@ -5,7 +5,7 @@
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/platform/threadpool.h"
-#include "core/providers/cpu/llm/attention_helper.h"
+#include "core/providers/cpu/llm/attention_parameters.h"
namespace onnxruntime {
@@ -100,4 +100,4 @@ class Attention final : public AttentionBase {
int softmax_precision_;
};
-} // namespace onnxruntime
\ No newline at end of file
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/llm/attention_helper.cc b/onnxruntime/core/providers/cpu/llm/attention_helper.cc
deleted file mode 100644
index 9bd954f128454..0000000000000
--- a/onnxruntime/core/providers/cpu/llm/attention_helper.cc
+++ /dev/null
@@ -1,156 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "core/providers/cpu/llm/attention_helper.h"
-#include "core/util/shape_checker.h"
-
-namespace onnxruntime {
-namespace attention_helper {
-
-void AttentionParameters::checkParameters() const {
- ORT_ENFORCE(batch_size > 0, "Batch size must be greater than 0");
- ORT_ENFORCE(q_sequence_length > 0, "Q sequence length must be greater than 0");
- ORT_ENFORCE(kv_sequence_length > 0, "KV sequence length must be greater than 0");
- ORT_ENFORCE(head_size > 0, "Head size must be greater than 0");
- ORT_ENFORCE(v_head_size > 0, "V head size must be greater than 0");
- ORT_ENFORCE(past_sequence_length >= 0, "Past sequence length must be non-negative");
- ORT_ENFORCE(total_sequence_length > 0, "Total sequence length must be greater than 0");
- ORT_ENFORCE(kv_num_heads > 0, "KV number of heads must be greater than 0");
- ORT_ENFORCE(q_num_heads > 0, "Q number of heads must be greater than 0");
- ORT_ENFORCE(total_sequence_length == past_sequence_length + kv_sequence_length,
- "Total sequence length must be equal to past sequence length plus KV sequence length");
-}
-
-Status ComputeOutputShapeForAttention(
- const Tensor* Q,
- const Tensor* K,
- const Tensor* V,
- const Tensor* attn_mask,
- const Tensor* past_key,
- const Tensor* past_value,
- bool is_causal,
- float softcap,
- int softmax_precision,
- attention_helper::QKMatMulOutputMode qk_matmul_output_mode,
- int kv_num_heads,
- int q_num_heads,
- float scale,
- AttentionParameters& parameters,
- std::vector<int64_t>& y_shape,
- std::vector<int64_t>& present_key_shape,
- std::vector<int64_t>& present_value_shape,
- std::vector<int64_t>& output_qk_shape) {
- ORT_ENFORCE(Q != nullptr && K != nullptr && V != nullptr,
- "Q, K, and V inputs must not be null");
- int q_dims = onnxruntime::narrow<int>(Q->Shape().NumDimensions());
- int k_dims = onnxruntime::narrow<int>(K->Shape().NumDimensions());
- int v_dims = onnxruntime::narrow<int>(V->Shape().NumDimensions());
- ORT_ENFORCE(q_dims == 3 || q_dims == 4, "Q must be a 3D or 4D tensor");
- ORT_ENFORCE(q_dims == k_dims, "Q and K must have the same rank.");
- ORT_ENFORCE(q_dims == v_dims, "Q and V must have the same rank.");
-
- ORT_ENFORCE((past_key == nullptr) == (past_value == nullptr), "past_key and past_value must be both null or both not null");
- ORT_ENFORCE(Q->Shape()[0] == K->Shape()[0], "inconsistent batch_size (between Q and K)");
- ORT_ENFORCE(Q->Shape()[0] == V->Shape()[0], "inconsistent batch_size (between Q and V)");
- ORT_ENFORCE(past_key == nullptr || Q->Shape()[0] == past_key->Shape()[0], "inconsistent batch_size (between Q and past_key)");
- ORT_ENFORCE(past_value == nullptr || Q->Shape()[0] == past_value->Shape()[0], "inconsistent batch_size (between Q and past_value)");
- ORT_ENFORCE(past_value == nullptr || past_value->Shape()[2] == past_key->Shape()[2], "inconsistent past_sequence_length (between past_key and past_value)");
-
- parameters.is_causal = is_causal;
- parameters.softcap = softcap;
- parameters.softmax_precision = softmax_precision;
- parameters.qk_matmul_output_mode = qk_matmul_output_mode; // output mode for Q*K matmul
- parameters.batch_size = onnxruntime::narrow<int>(Q->Shape()[0]); // Q.shape[0], K.shape[0], V.shape[0] (4D)
-
- ORT_ENFORCE(parameters.batch_size > 0, "Batch size must be greater than 0");
- ORT_ENFORCE(attn_mask == nullptr || (attn_mask->Shape().NumDimensions() >= 2 && attn_mask->Shape().NumDimensions() <= 4), "attn_mask must be 2D or 3D or 4D tensor");
-
- if (q_dims == 4) {
- // 4D
- parameters.kv_num_heads = kv_num_heads > 0 ? kv_num_heads : onnxruntime::narrow<int>(K->Shape()[1]); // K.shape[1] or V.shape[1] (4D)
- parameters.q_num_heads = q_num_heads > 0 ? q_num_heads : onnxruntime::narrow<int>(Q->Shape()[1]); // Q.shape[1] (4D)
-
- ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow<int>(K->Shape()[1]), "kv_num_heads different from K.shape[1]");
- ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow<int>(V->Shape()[1]), "kv_num_heads different from V.shape[1]");
- ORT_ENFORCE(parameters.q_num_heads == onnxruntime::narrow<int>(Q->Shape()[1]), "q_num_heads different from Q.shape[1]");
- ORT_ENFORCE(Q->Shape()[3] == K->Shape()[3], "inconsistent head_size");
- ORT_ENFORCE(K->Shape()[2] == V->Shape()[2], "inconsistent kv_sequence_length");
- ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 2] == Q->Shape()[2], "inconsistent q_sequence_length (between attn_mask and Q)");
-
- // From shapes
- parameters.transpose_output = false; // whether to transpose the input/output with permutation (0, 2, 1, 3)
- parameters.q_sequence_length = onnxruntime::narrow<int>(Q->Shape()[2]); // Q.shape[2] (4D)
- parameters.head_size = onnxruntime::narrow<int>(Q->Shape()[3]); // Q.shape[3] (4D)
- parameters.kv_sequence_length = onnxruntime::narrow<int>(K->Shape()[2]); // K.shape[2] or V.shape[2] (4D)
- parameters.v_head_size = onnxruntime::narrow<int>(V->Shape()[3]); // V.shape[3] (4D)
- parameters.past_sequence_length = past_key == nullptr // past_key.shape[2] or past_value.shape[2] (4D) or given by the mask
- ? 0
- : onnxruntime::narrow<int>(past_key->Shape()[2]);
-
- y_shape = {static_cast<int64_t>(parameters.batch_size),
- static_cast<int64_t>(parameters.q_num_heads),
- static_cast<int64_t>(parameters.q_sequence_length),
- static_cast<int64_t>(parameters.v_head_size)};
- } else {
- // 3D
- parameters.kv_num_heads = kv_num_heads;
- parameters.q_num_heads = q_num_heads;
-
- // From shapes
- ORT_ENFORCE(Q->Shape()[2] % parameters.q_num_heads == 0, "inconsistent q_hidden_size, it should be a multiple of q_num_heads");
- ORT_ENFORCE(V->Shape()[2] % parameters.kv_num_heads == 0, "inconsistent v_hidden_size, it should be a multiple of kv_num_heads");
-
- parameters.transpose_output = true; // whether to transpose the input/output with permutation (0, 2, 1, 3)
- parameters.q_sequence_length = onnxruntime::narrow<int>(Q->Shape()[1]);
- parameters.head_size = onnxruntime::narrow<int>(Q->Shape()[2]) / parameters.q_num_heads;
- parameters.kv_sequence_length = onnxruntime::narrow<int>(K->Shape()[1]);
- parameters.v_head_size = onnxruntime::narrow<int>(V->Shape()[2]) / parameters.kv_num_heads;
- parameters.past_sequence_length = past_key == nullptr
- ? 0
- : onnxruntime::narrow<int>(past_key->Shape()[2]);
-
- y_shape = {static_cast<int64_t>(parameters.batch_size),
- static_cast<int64_t>(parameters.q_sequence_length),
- static_cast<int64_t>(parameters.q_num_heads * parameters.v_head_size)};
- }
- parameters.total_sequence_length = parameters.past_sequence_length + parameters.kv_sequence_length;
-
- ORT_ENFORCE(parameters.q_num_heads % parameters.kv_num_heads == 0, "q_num_heads % kv_num_heads == 0 is not verified");
- ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 1] == parameters.total_sequence_length,
- "inconsistent total_sequence_length (between attn_mask and past_key and past_value)");
- ORT_ENFORCE(attn_mask == nullptr ||
- attn_mask->Shape().NumDimensions() < 3 ||
- attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == 1 ||
- attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == parameters.kv_num_heads,
- "attn_mask must be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) but is not compatible with kv_num_heads");
- ORT_ENFORCE(attn_mask == nullptr ||
- attn_mask->Shape().NumDimensions() < 4 ||
- attn_mask->Shape()[0] == 1 ||
- attn_mask->Shape()[0] == parameters.batch_size,
- "attn_mask must be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) but is not compatible with batch_size");
- ASSERT_TENSOR_DIMS(past_key, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.head_size);
- ASSERT_TENSOR_DIMS(past_value, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.v_head_size);
-
- parameters.scale = std::isnan(scale) ? static_cast<float>(1.0 / sqrt(parameters.head_size)) : scale;
- parameters.checkParameters();
-
- present_key_shape = {static_cast<int64_t>(parameters.batch_size),
- static_cast<int64_t>(parameters.kv_num_heads),
- static_cast<int64_t>(parameters.total_sequence_length),
- static_cast<int64_t>(parameters.head_size)};
- present_value_shape = {static_cast<int64_t>(parameters.batch_size),
- static_cast<int64_t>(parameters.kv_num_heads),
- static_cast<int64_t>(parameters.total_sequence_length),
- static_cast<int64_t>(parameters.v_head_size)};
- if (qk_matmul_output_mode == QKMatMulOutputMode::kNone) {
- output_qk_shape.clear();
- } else {
- output_qk_shape = {static_cast<int64_t>(parameters.batch_size),
- static_cast<int64_t>(parameters.q_num_heads),
- static_cast<int64_t>(parameters.q_sequence_length),
- static_cast<int64_t>(parameters.total_sequence_length)};
- }
- return Status::OK();
-}
-} // namespace attention_helper
-} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/llm/attention_helper.h b/onnxruntime/core/providers/cpu/llm/attention_helper.h
index 1cea27760408f..4b1e22df4b2f6 100644
--- a/onnxruntime/core/providers/cpu/llm/attention_helper.h
+++ b/onnxruntime/core/providers/cpu/llm/attention_helper.h
@@ -2,54 +2,13 @@
// Licensed under the MIT License.
#pragma once
-#include "core/common/common.h"
-#include "core/providers/common.h"
+#include "core/providers/cpu/llm/attention_parameters.h"
+#include "core/util/shape_checker.h"
namespace onnxruntime {
namespace attention_helper {
-// enum equivalent to the onnx defintion of qk_matmul_output_mode
-enum QKMatMulOutputMode {
- kNone = -1, // No output Q*K
- kQK = 0, // Output Q*K
- kQKMask = 1, // Output Q*K + Mask
- kQKSoftCap = 2, // Output SoftCap(Q*K + Mask)
- kQKSoftMax = 3, // Output SoftMax(SoftCap(Q*K + Mask))
-};
-
-// Parameters deduced from node attributes and inputs/outputs.
-struct AttentionParameters {
- /*
- * Attention Parameters
- * MHA: q_num_heads == kv_num_heads -> MHA
- * GQA: q_num_heads > kv_num_heads && q_num_heads % kv_num_heads == 0
- * MQA: q_num_heads > kv_num_heads && kv_num_heads == 1
- */
- bool is_causal;
- int kv_num_heads; // K.shape[1] or V.shape[1] (4D)
- int q_num_heads; // Q.shape[1] (4D)
- float scale;
- float softcap;
- int softmax_precision;
- QKMatMulOutputMode qk_matmul_output_mode;
-
- // From shapes
- int batch_size; // Q.shape[0], K.shape[0], V.shape[0] (4D)
- int q_sequence_length; // Q.shape[2] (4D)
- int head_size; // Q.shape[3] or K.shape[3 (4D)
- int kv_sequence_length; // K.shape[2] or V.shape[2] (4D)
- int v_head_size; // V.shape[4] (4D)
- int past_sequence_length; // pask_key.shape[2] or past_value.shape[2] (4D)
- int total_sequence_length; // past_sequence_length + kv_sequence_length
- bool transpose_output; // Whether to transpose the inputs and the outputs from BxNxSxH to BxSxNxH
- // This covers the case where the inputs are 3D.
-
- // Checks the consistency of the parameters.
- void checkParameters() const;
-};
-
-// Computes the output shape for attention based on the input tensors and parameters.
-Status ComputeOutputShapeForAttention(
+inline Status ComputeOutputShapeForAttention(
const Tensor* Q,
const Tensor* K,
const Tensor* V,
@@ -59,15 +18,122 @@ Status ComputeOutputShapeForAttention(
bool is_causal,
float softcap,
int softmax_precision,
- attention_helper::QKMatMulOutputMode qk_matmul_output_mode,
+ QKMatMulOutputMode qk_matmul_output_mode,
int kv_num_heads,
int q_num_heads,
float scale,
AttentionParameters& parameters,
- std::vector<int64_t>& y_shape,
- std::vector<int64_t>& present_key_shape,
- std::vector<int64_t>& present_value_shape,
- std::vector<int64_t>& output_qk_shape);
+ TensorShape& y_shape,
+ TensorShape& present_key_shape,
+ TensorShape& present_value_shape,
+ TensorShape& output_qk_shape) {
+ ORT_ENFORCE(Q != nullptr && K != nullptr && V != nullptr,
+ "Q, K, and V inputs must not be null");
+ int q_dims = onnxruntime::narrow<int>(Q->Shape().NumDimensions());
+ int k_dims = onnxruntime::narrow<int>(K->Shape().NumDimensions());
+ int v_dims = onnxruntime::narrow<int>(V->Shape().NumDimensions());
+ ORT_ENFORCE(q_dims == 3 || q_dims == 4, "Q must be a 3D or 4D tensor");
+ ORT_ENFORCE(q_dims == k_dims, "Q and K must have the same rank.");
+ ORT_ENFORCE(q_dims == v_dims, "Q and V must have the same rank.");
+ ORT_ENFORCE((past_key == nullptr) == (past_value == nullptr), "past_key and past_value must be both null or both not null");
+ ORT_ENFORCE(Q->Shape()[0] == K->Shape()[0], "inconsistent batch_size (between Q and K)");
+ ORT_ENFORCE(Q->Shape()[0] == V->Shape()[0], "inconsistent batch_size (between Q and V)");
+ ORT_ENFORCE(past_key == nullptr || Q->Shape()[0] == past_key->Shape()[0], "inconsistent batch_size (between Q and past_key)");
+ ORT_ENFORCE(past_value == nullptr || Q->Shape()[0] == past_value->Shape()[0], "inconsistent batch_size (between Q and past_value)");
+ ORT_ENFORCE(past_value == nullptr || past_value->Shape()[2] == past_key->Shape()[2], "inconsistent past_sequence_length (between past_key and past_value)");
+
+ parameters.is_causal = is_causal;
+ parameters.softcap = softcap;
+ parameters.softmax_precision = softmax_precision;
+ parameters.qk_matmul_output_mode = qk_matmul_output_mode; // output mode for Q*K matmul
+ parameters.batch_size = onnxruntime::narrow<int>(Q->Shape()[0]); // Q.shape[0], K.shape[0], V.shape[0] (4D)
+
+ ORT_ENFORCE(parameters.batch_size > 0, "Batch size must be greater than 0");
+ ORT_ENFORCE(attn_mask == nullptr || (attn_mask->Shape().NumDimensions() >= 2 && attn_mask->Shape().NumDimensions() <= 4), "attn_mask must be 2D or 3D or 4D tensor");
+
+ if (q_dims == 4) {
+ // 4D
+ parameters.kv_num_heads = kv_num_heads > 0 ? kv_num_heads : onnxruntime::narrow<int>(K->Shape()[1]); // K.shape[1] or V.shape[1] (4D)
+ parameters.q_num_heads = q_num_heads > 0 ? q_num_heads : onnxruntime::narrow<int>(Q->Shape()[1]); // Q.shape[1] (4D)
+
+ ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow<int>(K->Shape()[1]), "kv_num_heads different from K.shape[1]");
+ ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow<int>(V->Shape()[1]), "kv_num_heads different from V.shape[1]");
+ ORT_ENFORCE(parameters.q_num_heads == onnxruntime::narrow<int>(Q->Shape()[1]), "q_num_heads different from Q.shape[1]");
+ ORT_ENFORCE(Q->Shape()[3] == K->Shape()[3], "inconsistent head_size");
+ ORT_ENFORCE(K->Shape()[2] == V->Shape()[2], "inconsistent kv_sequence_length");
+ ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 2] == Q->Shape()[2], "inconsistent q_sequence_length (between attn_mask and Q)");
+
+ // From shapes
+ parameters.transpose_output = false; // whether to transpose the input/output with permutation (0, 2, 1, 3)
+ parameters.q_sequence_length = onnxruntime::narrow<int>(Q->Shape()[2]); // Q.shape[2] (4D)
+ parameters.head_size = onnxruntime::narrow<int>(Q->Shape()[3]); // Q.shape[3] (4D)
+ parameters.kv_sequence_length = onnxruntime::narrow<int>(K->Shape()[2]); // K.shape[2] or V.shape[2] (4D)
+ parameters.v_head_size = onnxruntime::narrow<int>(V->Shape()[3]); // V.shape[3] (4D)
+ parameters.past_sequence_length = past_key == nullptr // past_key.shape[2] or past_value.shape[2] (4D) or given by the mask
+ ? 0
+ : onnxruntime::narrow<int>(past_key->Shape()[2]);
+
+ y_shape = {static_cast<int64_t>(parameters.batch_size),
+ static_cast<int64_t>(parameters.q_num_heads),
+ static_cast<int64_t>(parameters.q_sequence_length),
+ static_cast<int64_t>(parameters.v_head_size)};
+ } else {
+ // 3D
+ parameters.kv_num_heads = kv_num_heads;
+ parameters.q_num_heads = q_num_heads;
+
+ // From shapes
+ ORT_ENFORCE(Q->Shape()[2] % parameters.q_num_heads == 0, "inconsistent q_hidden_size, it should be a multiple of q_num_heads");
+ ORT_ENFORCE(V->Shape()[2] % parameters.kv_num_heads == 0, "inconsistent v_hidden_size, it should be a multiple of kv_num_heads");
+
+ parameters.transpose_output = true; // whether to transpose the input/output with permutation (0, 2, 1, 3)
+ parameters.q_sequence_length = onnxruntime::narrow<int>(Q->Shape()[1]);
+ parameters.head_size = onnxruntime::narrow<int>(Q->Shape()[2]) / parameters.q_num_heads;
+ parameters.kv_sequence_length = onnxruntime::narrow<int>(K->Shape()[1]);
+ parameters.v_head_size = onnxruntime::narrow<int>(V->Shape()[2]) / parameters.kv_num_heads;
+ parameters.past_sequence_length = past_key == nullptr
+ ? 0
+ : onnxruntime::narrow<int>(past_key->Shape()[2]);
+
+ y_shape = {static_cast<int64_t>(parameters.batch_size),
+ static_cast<int64_t>(parameters.q_sequence_length),
+ static_cast<int64_t>(parameters.q_num_heads * parameters.v_head_size)};
+ }
+ parameters.total_sequence_length = parameters.past_sequence_length + parameters.kv_sequence_length;
+
+ ORT_ENFORCE(parameters.q_num_heads % parameters.kv_num_heads == 0, "q_num_heads must be a multiple of kv_num_heads. This is required for grouped/multi-query and multi-headed attention.");
+ ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 1] == parameters.total_sequence_length,
+ "inconsistent total_sequence_length (between attn_mask and past_key and past_value)");
+ ORT_ENFORCE(attn_mask == nullptr ||
+ attn_mask->Shape().NumDimensions() < 3 ||
+ attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == 1 ||
+ attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == parameters.q_num_heads,
+ "attn_mask must be broadcastable to (batch_size, q_num_heads, q_sequence_length, total_sequence_length) but is not compatible with q_num_heads");
+ ORT_ENFORCE(attn_mask == nullptr ||
+ attn_mask->Shape().NumDimensions() < 4 ||
+ attn_mask->Shape()[0] == 1 ||
+ attn_mask->Shape()[0] == parameters.batch_size,
+ "attn_mask must be broadcastable to (batch_size, q_num_heads, q_sequence_length, total_sequence_length) but is not compatible with batch_size");
+ ASSERT_TENSOR_DIMS(past_key, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.head_size);
+ ASSERT_TENSOR_DIMS(past_value, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.v_head_size);
+
+ parameters.scale = std::isnan(scale) ? static_cast<float>(1.0 / sqrt(parameters.head_size)) : scale;
+ parameters.checkParameters();
+
+ present_key_shape = {static_cast<int64_t>(parameters.batch_size),
+ static_cast<int64_t>(parameters.kv_num_heads),
+ static_cast<int64_t>(parameters.total_sequence_length),
+ static_cast<int64_t>(parameters.head_size)};
+ present_value_shape = {static_cast<int64_t>(parameters.batch_size),
+ static_cast<int64_t>(parameters.kv_num_heads),
+ static_cast<int64_t>(parameters.total_sequence_length),
+ static_cast<int64_t>(parameters.v_head_size)};
+ output_qk_shape = {static_cast<int64_t>(parameters.batch_size),
+ static_cast<int64_t>(parameters.q_num_heads),
+ static_cast<int64_t>(parameters.q_sequence_length),
+ static_cast<int64_t>(parameters.total_sequence_length)};
+ return Status::OK();
+}
} // namespace attention_helper
} // namespace onnxruntime
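Reviewer note: for quick reference, the shapes the inlined helper produces, summarized from the code above (B = batch_size, Nq = q_num_heads, Nkv = kv_num_heads, S = q_sequence_length, T = total_sequence_length, H = head_size, Hv = v_head_size):

    // Y             : (B, Nq, S, Hv) for 4D inputs, or (B, S, Nq * Hv) for 3D inputs
    // present_key   : (B, Nkv, T, H)
    // present_value : (B, Nkv, T, Hv)
    // output_qk     : (B, Nq, S, T)   (set unconditionally in the inlined version)
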
diff --git a/onnxruntime/core/providers/cpu/llm/attention_parameters.h b/onnxruntime/core/providers/cpu/llm/attention_parameters.h
new file mode 100644
index 0000000000000..b8586ca4d63dc
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/llm/attention_parameters.h
@@ -0,0 +1,65 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/common.h"
+
+namespace onnxruntime {
+// Declares enum QKMatMulOutputMode and struct AttentionParameters inside namespace onnxruntime::attention_helper.
+namespace attention_helper {
+
+// enum equivalent to the onnx definition of qk_matmul_output_mode
+enum QKMatMulOutputMode {
+ kNone = -1, // No output Q*K
+ kQK = 0, // Output Q*K
+ kQKMask = 1, // Output Q*K + Mask
+ kQKSoftCap = 2, // Output SoftCap(Q*K + Mask)
+ kQKSoftMax = 3, // Output SoftMax(SoftCap(Q*K + Mask))
+};
+
+// Parameters deduced from node attributes and inputs/outputs.
+struct AttentionParameters {
+ /*
+ * Attention Parameters
+ * MHA: q_num_heads == kv_num_heads -> MHA
+ * GQA: q_num_heads > kv_num_heads && q_num_heads % kv_num_heads == 0
+ * MQA: q_num_heads > kv_num_heads && kv_num_heads == 1
+ */
+ bool is_causal;
+ int kv_num_heads; // K.shape[1] or V.shape[1] (4D)
+ int q_num_heads; // Q.shape[1] (4D)
+ float scale;
+ float softcap;
+ int softmax_precision;
+ QKMatMulOutputMode qk_matmul_output_mode;
+
+ // From shapes
+ int batch_size; // Q.shape[0], K.shape[0], V.shape[0] (4D)
+ int q_sequence_length; // Q.shape[2] (4D)
+ int head_size; // Q.shape[3] or K.shape[3] (4D)
+ int kv_sequence_length; // K.shape[2] or V.shape[2] (4D)
+ int v_head_size; // V.shape[3] (4D)
+ int past_sequence_length; // past_key.shape[2] or past_value.shape[2] (4D)
+ int total_sequence_length; // past_sequence_length + kv_sequence_length
+ bool transpose_output; // Whether to transpose the inputs and the outputs from BxNxSxH to BxSxNxH
+ // This covers the case where the inputs are 3D.
+
+ // Checks the consistency of the parameters.
+ void checkParameters() const {
+ ORT_ENFORCE(batch_size > 0, "Batch size must be greater than 0");
+ ORT_ENFORCE(q_sequence_length > 0, "Q sequence length must be greater than 0");
+ ORT_ENFORCE(kv_sequence_length > 0, "KV sequence length must be greater than 0");
+ ORT_ENFORCE(head_size > 0, "Head size must be greater than 0");
+ ORT_ENFORCE(v_head_size > 0, "V head size must be greater than 0");
+ ORT_ENFORCE(past_sequence_length >= 0, "Past sequence length must be non-negative");
+ ORT_ENFORCE(total_sequence_length > 0, "Total sequence length must be greater than 0");
+ ORT_ENFORCE(kv_num_heads > 0, "KV number of heads must be greater than 0");
+ ORT_ENFORCE(q_num_heads > 0, "Q number of heads must be greater than 0");
+ ORT_ENFORCE(total_sequence_length == past_sequence_length + kv_sequence_length,
+ "Total sequence length must be equal to past sequence length plus KV sequence length");
+ }
+};
+
+} // namespace attention_helper
+} // namespace onnxruntime
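Reviewer note: a minimal, illustrative use of the struct above for a grouped-query configuration (all values are made up; only the invariants enforced by checkParameters matter here):

    onnxruntime::attention_helper::AttentionParameters p{};
    p.is_causal = true;
    p.q_num_heads = 8;
    p.kv_num_heads = 2;           // GQA: q_num_heads > kv_num_heads and divisible by it
    p.batch_size = 1;
    p.q_sequence_length = 4;
    p.kv_sequence_length = 4;
    p.past_sequence_length = 0;
    p.total_sequence_length = 4;  // must equal past + kv
    p.head_size = 64;
    p.v_head_size = 64;
    p.checkParameters();          // ORT_ENFORCE fires if any invariant is violated
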
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index eab616388d6ae..eb29e4edbf897 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1585,6 +1585,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish);
// Opset 23.
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Attention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_MLFloat16, RMSNormalization);
@@ -2653,6 +2656,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
// Opset 23
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Attention)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Attention)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc
new file mode 100644
index 0000000000000..99f297bba6444
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/llm/attention.cc
@@ -0,0 +1,265 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cpu/llm/attention_helper.h"
+#include "core/providers/cuda/llm/attention.h"
+#include "contrib_ops/cuda/bert/attention_data.h"
+#include "contrib_ops/cuda/bert/attention_impl.h"
+
+using namespace onnxruntime::cuda;
+
+namespace onnxruntime {
+namespace cuda {
+
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ Attention, \
+ kOnnxDomain, \
+ 23, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("U", BuildKernelDefConstraints()), \
+ Attention);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
+REGISTER_KERNEL_TYPED(BFloat16)
+
+template <typename T>
+Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info) {
+ is_causal_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("is_causal", 0)) == 1;
+ // kv_num_heads and q_num_heads are mandatory for 3D inputs but not used for 4D inputs.
+ // The dimension is not yet known. If not specified, the inputs are assumed to be 4D.
+ kv_num_heads_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("kv_num_heads", 0));
+ q_num_heads_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("q_num_heads", 0));
+ int mode = static_cast<int>(info.GetAttrOrDefault<int64_t>("qk_matmul_output_mode", 0));
+ qk_matmul_output_mode_ = info.node().OutputDefs().size() >= 4 && info.node().OutputDefs()[3]->Exists()
+ ? static_cast<attention_helper::QKMatMulOutputMode>(mode)
+ : attention_helper::QKMatMulOutputMode::kNone;
+ ORT_ENFORCE(qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kNone ||
+ qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQK ||
+ qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKMask ||
+ qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKSoftCap ||
+ qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKSoftMax,
+ "qk_matmul_output_mode must be 0, 1, 2, or 3.");
+ // The default scale depends on the input dimensions. It is set to nan to indicate that it should be computed.
+ scale_ = info.GetAttrOrDefault<float>("scale", std::numeric_limits<float>::quiet_NaN());
+ softcap_ = info.GetAttrOrDefault<float>("softcap", 0.0f);
+ softmax_precision_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("softmax_precision", 0));
+ ORT_ENFORCE(scale_ > 0 || std::isnan(scale_), "scale must be greater than 0 if specified");
+}
+
+template <typename T>
+Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
+ const Tensor* Q = context->Input<Tensor>(0);
+ const Tensor* K = context->Input<Tensor>(1);
+ const Tensor* V = context->Input<Tensor>(2);
+ const Tensor* attn_mask = context->Input<Tensor>(3);
+ const Tensor* past_key = context->Input<Tensor>(4);
+ const Tensor* past_value = context->Input<Tensor>(5);
+
+ attention_helper::AttentionParameters parameters;
+ TensorShape y_shape;
+ TensorShape present_key_shape;
+ TensorShape present_value_shape;
+ TensorShape output_qk_shape;
+
+ ORT_ENFORCE(attention_helper::ComputeOutputShapeForAttention(
+ Q,
+ K,
+ V,
+ attn_mask,
+ past_key,
+ past_value,
+ is_causal_,
+ softcap_,
+ softmax_precision_,
+ qk_matmul_output_mode_,
+ kv_num_heads_,
+ q_num_heads_,
+ scale_,
+ parameters,
+ y_shape,
+ present_key_shape,
+ present_value_shape,
+ output_qk_shape)
+ .IsOK(),
+ "Output shapes for Attention could not be computed.");
+
+ Tensor* Y = context->Output(0, y_shape);
+ Tensor* present_key = context->Output(1, present_key_shape);
+ Tensor* present_value = context->Output(2, present_value_shape);
+ Tensor* output_qk = context->Output(3, output_qk_shape);
+
+ // To reuse the existing attention-cuda implementation in contrib ops,
+ // map the parameters to contribop_parameters.
+ onnxruntime::contrib::AttentionParameters contribop_parameters;
+ contribop_parameters.batch_size = parameters.batch_size;
+ contribop_parameters.sequence_length = parameters.q_sequence_length;
+ contribop_parameters.kv_sequence_length = parameters.kv_sequence_length;
+ contribop_parameters.past_sequence_length = parameters.past_sequence_length;
+ contribop_parameters.total_sequence_length = parameters.total_sequence_length;
+ // max_sequence_length: For non-buffer-sharing case, this equals total_sequence_length (the present KV cache size)
+ contribop_parameters.max_sequence_length = parameters.total_sequence_length;
+ contribop_parameters.input_hidden_size = 0; // Not applicable - new Attention op takes pre-projected Q/K/V
+ contribop_parameters.hidden_size = parameters.q_num_heads * parameters.head_size;
+ contribop_parameters.head_size = parameters.head_size;
+ contribop_parameters.v_head_size = parameters.v_head_size;
+ contribop_parameters.v_hidden_size = parameters.kv_num_heads * parameters.v_head_size;
+ contribop_parameters.num_heads = parameters.q_num_heads;
+ contribop_parameters.rotary_dim = 0;
+ contribop_parameters.num_splits = 1;
+ contribop_parameters.beam_width = 1;
+ contribop_parameters.is_unidirectional = parameters.is_causal;
+ contribop_parameters.past_present_share_buffer = false; // New Attention op doesn't share buffer
+ contribop_parameters.is_packed_qkv = false;
+ contribop_parameters.do_rotary = false;
+
+ // The new Attention op uses attn_mask as attention_bias (additive bias), not as key_padding_mask
+ // So mask_type should always be MASK_NONE since we don't have a separate padding mask input
+ contribop_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE;
+
+ // Determine broadcast flags for attention_bias (if it exists)
+ // Note: The new Attention op uses attn_mask as attention_bias
+ // The attention_bias should be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length)
+ // attn_mask can be 2D, 3D, or 4D. Broadcasting aligns from the right (trailing dimensions).
+ if (attn_mask != nullptr) {
+ // TODO(titaiwang, xadupre): attn_mask bool is not supported yet
+ if (attn_mask->IsDataType<bool>()) {
+ ORT_THROW("Boolean attn_mask is not supported yet in Attention op (CUDA).");
+ }
+
+ size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions();
+ auto attn_mask_dims = attn_mask->Shape().GetDims();
+ // For 2D mask (q_seq_len, total_seq_len): both batch and heads dimensions need broadcasting
+ // For 3D mask (X, q_seq_len, total_seq_len): batch needs broadcasting if X==1, heads always needs broadcasting
+ // For 4D mask (B, H, q_seq_len, total_seq_len): check if B==1 and H==1
+
+ if (attn_mask_dims_size == 2) {
+ // 2D mask: both dimensions need broadcasting
+ contribop_parameters.broadcast_attn_bias_dim_0 = true;
+ contribop_parameters.broadcast_attn_bias_dim_1 = true;
+ } else if (attn_mask_dims_size == 3) {
+ // 3D mask: dim 0 broadcasts if it's 1, dim 1 (heads) always broadcasts
+ contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1;
+ contribop_parameters.broadcast_attn_bias_dim_1 = true;
+ } else {
+ // 4D mask: check both dim 0 and dim 1 explicitly
+ contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1;
+ contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[1] == 1;
+ }
+ } else {
+ contribop_parameters.broadcast_attn_bias_dim_0 = false;
+ contribop_parameters.broadcast_attn_bias_dim_1 = false;
+ }
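+ // Illustrative examples (not exhaustive): a 2D (q_seq_len, total_seq_len) mask sets both broadcast flags,
+ // while a 4D (batch_size, num_heads, q_seq_len, total_seq_len) mask with batch_size > 1 and num_heads > 1 sets neither.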
+
+ contribop_parameters.mask_filter_value = -10000.0f;
+ contribop_parameters.scale = parameters.scale;
+ contribop_parameters.use_tf32 = UseTF32();
+
+ // QKV format is determined by the input rank:
+ //   4D inputs (B, N, S, H) are already Q_K_V_BNSH; 3D inputs (B, S, D) are Q_K_V_BSNH and are transposed to BNSH by PrepareQkv.
+ // parameters.transpose_output is true for 3D inputs and false for 4D inputs.
+ if (!parameters.transpose_output) {
+ contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH;
+ contribop_parameters.is_output_bnsh = true;
+ } else {
+ // 3D inputs in BSNH format (will be transposed)
+ contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BSNH;
+ contribop_parameters.is_output_bnsh = false;
+ }
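+ // Note: is_output_bnsh follows the input layout here; 4D (BNSH) inputs are expected to produce the attention
+ // output directly in BNSH, while 3D (BSNH) inputs go through the existing transpose path.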
+
+ // TODO(titaiwang, xadupre): Group query attention is not supported yet
+ if (parameters.kv_num_heads != parameters.q_num_heads) {
+ ORT_THROW("Group query attention is not supported yet in Attention op (CUDA).");
+ }
+
+ // TODO(titaiwang, xadupre): qk_matmul_output_mode only supports kNone and kQK for now
+ if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone &&
+ qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kQK) {
+ ORT_THROW("qk_matmul_output_mode other than -1 (None) and 0 (QK) is not supported yet in Attention op (CUDA).");
+ }
+ // TODO(titaiwang, xadupre): softcap and softmax_precision are not used yet
+ if (parameters.softcap != 0.0f) {
+ ORT_THROW("softcap is not supported yet in Attention op (CUDA).");
+ }
+ if (parameters.softmax_precision != 0) {
+ ORT_THROW("softmax_precision is not supported yet in Attention op (CUDA).");
+ }
+
+ // TODO(titaiwang): Continue on these parameters
+ // Construct AttentionData to pass to QkvToContext
+ typedef typename ToCudaType<T>::MappedType CudaT;
+ onnxruntime::contrib::cuda::AttentionData<CudaT> data;
+
+ // Set input pointers
+ data.query = reinterpret_cast<const CudaT*>(Q->Data<T>());
+ data.key = reinterpret_cast<const CudaT*>(K->Data<T>());
+ data.value = reinterpret_cast<const CudaT*>(V->Data<T>());
+ data.mask_index = nullptr; // New Attention op doesn't have key_padding_mask
+ data.mask_index_dims = gsl::span<const int64_t>();
+ data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
+ data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
+
+ // Set output pointers
+ data.output = reinterpret_cast<CudaT*>(Y->MutableData<T>());
+ data.present_key = (present_key == nullptr) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
+ data.present_value = (present_value == nullptr) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
+ if (nullptr != output_qk) {
+ data.output_qk = reinterpret_cast<CudaT*>(output_qk->MutableData<T>());
+ }
+
+ // Set additional fields
+ data.bias = nullptr; // New Attention op doesn't have bias
+ if (nullptr != attn_mask) {
+ data.attention_bias = reinterpret_cast<const CudaT*>(attn_mask->Data<T>());
+ }
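+ // When attn_mask is absent, data.attention_bias is assumed to keep its default (nullptr) from AttentionData.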
+ data.qkv_format = contribop_parameters.qkv_format;
+
+ // TODO: Determine which kernel to use (Flash Attention, Memory Efficient Attention, etc.)
+ // For now, set flags to false and let QkvToContext use the unfused path
+ data.use_flash_attention = false;
+ data.use_memory_efficient_attention = false;
+ data.fused_runner = nullptr;
+ data.fused_cross_attention_kernel = nullptr;
+
+ // Allocate workspace for Q, K, V processing and scratch buffer
+ const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(contribop_parameters, data);
+ size_t workspace_bytes = onnxruntime::contrib::cuda::GetAttentionWorkspaceSize(
+ sizeof(T),
+ contribop_parameters.batch_size,
+ contribop_parameters.num_heads,
+ contribop_parameters.head_size,
+ contribop_parameters.v_head_size,
+ contribop_parameters.sequence_length,
+ contribop_parameters.kv_sequence_length,
+ contribop_parameters.total_sequence_length,
+ nullptr, // fused_runner
+ false, // use_flash_attention
+ false, // use_lean_attention
+ false, // use_fused_cross_attention
+ false, // use_memory_efficient_attention
+ false, // use_cudnn_flash_attention
+ no_qkv_workspace);
+ auto work_space = GetScratchBuffer<void>(workspace_bytes, context->GetComputeStream());
+
+ data.has_qkv_workspace = !no_qkv_workspace;
+ data.workspace = reinterpret_cast<CudaT*>(work_space.get());
+ data.workspace_bytes = workspace_bytes;
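+ // The workspace size is expected to cover the intermediate Q/K/V staging buffers (when has_qkv_workspace is true)
+ // plus the attention-score scratch used by the unfused path; the exact layout is sized by GetAttentionWorkspaceSize.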
+
+ // Call QkvToContext to perform the attention computation
+ auto& device_prop = GetDeviceProp();
+ cublasHandle_t cublas = GetCublasHandle(context);
+ cudnnHandle_t cudnn = GetCudnnHandle(context);
+
+ // QkvToContext takes two template parameters: T for computation type, QK for output_qk type
+ // For now, both are the same type (CudaT)
+ return onnxruntime::contrib::cuda::QkvToContext<CudaT, CudaT>(
+ device_prop, cublas, cudnn, context->GetComputeStream(), contribop_parameters, data);
+}
+} // namespace cuda
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/llm/attention.h b/onnxruntime/core/providers/cuda/llm/attention.h
new file mode 100644
index 0000000000000..17e99bb935e1a
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/llm/attention.h
@@ -0,0 +1,30 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/cuda_kernel.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+using namespace onnxruntime::cuda;
+
+template <typename T>
+class Attention final : public CudaKernel {
+ public:
+ Attention(const OpKernelInfo& info);
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ protected:
+ bool is_causal_;
+ int kv_num_heads_;
+ int q_num_heads_;
+ attention_helper::QKMatMulOutputMode qk_matmul_output_mode_;
+ float scale_;
+ float softcap_;
+ int softmax_precision_;
+};
+
+} // namespace cuda
+} // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc
index 894130d6ee991..5c5c2efdb50cd 100644
--- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc
+++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc
@@ -363,7 +363,7 @@ TEST(AttentionTest, Attention3DDefault) {
q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(),
-1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
y, std::vector(), std::vector(), std::vector(),
- false, true, true // disable_cpu, disable_cuda, disable_dml
+ false, false, true // disable_cpu, disable_cuda, disable_dml
);
}
@@ -390,7 +390,7 @@ TEST(AttentionTest, Attention3DDefaultFloat16) {
q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(),
-1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
y, std::vector(), std::vector(), std::vector(),
- false, true, true // disable_cpu, disable_cuda, disable_dml
+ false, false, true // disable_cpu, disable_cuda, disable_dml
);
}
@@ -417,7 +417,7 @@ TEST(AttentionTest, Attention4DDefaultBasic) {
q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(),
-1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
y, std::vector(), std::vector(), std::vector(),
- false, true, true // disable_cpu, disable_cuda, disable_dml
+ false, false, true // disable_cpu, disable_cuda, disable_dml
);
}
@@ -444,7 +444,7 @@ TEST(AttentionTest, Attention4DDefault) {
q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(),
-1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits