diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 41ec3dc348dfc..ebed2b1972ba9 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -653,7 +653,6 @@ Do not modify directly.*
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
-|Attention|*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *in* nonpad_kv_seqlen:**tensor(int64)**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**<br><br>or<br><br>*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**|23+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **U** = tensor(bfloat16), tensor(bool), tensor(float), tensor(float16)|
|AveragePool|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[19, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[11, 18]|**T** = tensor(double), tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
index 1d966c3a6d169..f237b24b899a0 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
@@ -34,7 +34,6 @@ struct AttentionParameters {
float mask_filter_value;
float scale;
bool use_tf32;
- bool is_output_bnsh; // whether the output format is BNSH
AttentionMaskType mask_type;
AttentionQkvFormat qkv_format;
};
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 22862f1ba7b4c..ee4be2c20362d 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -772,9 +772,8 @@ Status UnfusedAttention(
DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length);
- // compute R*V (as V*R), and store in output or temp workspace depending on whether transpose is needed
- // For 4D input (BNSH), write directly to output. For 3D input (BSNH), write to temp then transpose.
- T* temp_output = parameters.is_output_bnsh ? data.output : data.q;
+ // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
+ T* temp_output = data.q;
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
v_head_size, sequence_length, total_sequence_length,
@@ -782,13 +781,11 @@ Status UnfusedAttention(
scratch2, total_sequence_length, sequence_length * total_sequence_length,
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));
- if (!parameters.is_output_bnsh) {
- // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
- ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
- device_prop.maxThreadsPerBlock, false, temp_output, data.output));
- }
+ // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
+ Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
+ device_prop.maxThreadsPerBlock, false, temp_output, data.output);
DUMP_TENSOR_D("Attention Output", data.output, batch_size, sequence_length, num_heads, v_head_size);
- return Status::OK();
+ return result;
}
template <typename T>
@@ -963,7 +960,7 @@ Status QkvToContext(
auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
const int batch_size = parameters.batch_size;
- const int kv_sequence_length = parameters.kv_sequence_length;
+ const int sequence_length = parameters.sequence_length;
const int total_sequence_length = parameters.total_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
@@ -984,12 +981,12 @@ Status QkvToContext(
if (!parameters.past_present_share_buffer) {
ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size,
- kv_sequence_length, total_sequence_length,
+ sequence_length, total_sequence_length,
stream, max_threads_per_block, data));
} else { // past_present_share_buffer
ORT_RETURN_IF_ERROR(PastPresentBufferShare(batch_size, num_heads, qk_head_size, v_head_size,
- kv_sequence_length, fused_runner,
+ sequence_length, fused_runner,
parameters, data, stream, max_threads_per_block));
}
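Note for reviewers: the hunk above restores the unconditional BxNxSxH -> BxSxNxH transpose of the unfused attention output via LaunchTransCtx. As a reading aid, here is a minimal CPU sketch of that layout change; the helper name TransposeBnshToBsnh and the plain loop nest are illustrative only and are not part of the patch.

// Minimal CPU reference of the BxNxSxH -> BxSxNxH copy done by LaunchTransCtx.
// Illustrative only: the CUDA kernel distributes this copy across thread blocks.
#include <cstddef>

template <typename T>
void TransposeBnshToBsnh(const T* in, T* out,
                         std::size_t B, std::size_t N, std::size_t S, std::size_t H) {
  for (std::size_t b = 0; b < B; ++b)
    for (std::size_t n = 0; n < N; ++n)
      for (std::size_t s = 0; s < S; ++s)
        for (std::size_t h = 0; h < H; ++h)
          out[((b * S + s) * N + n) * H + h] = in[((b * N + n) * S + s) * H + h];
}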
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
index 7b37a3a4227b6..b2e8130ecd17e 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -251,6 +251,7 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
AttentionData& data,
cudaStream_t stream,
int max_threads_per_block) {
+ assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
assert(data.query != nullptr);
assert(data.key != nullptr);
assert(data.value != nullptr);
@@ -261,96 +262,82 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
assert(!parameters.is_unidirectional);
assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_NoPast(data));
- if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) {
- // 3D inputs in BSNH format (will be transposed)
- const int batch_size = parameters.batch_size;
- const int sequence_length = parameters.sequence_length;
- const int kv_sequence_length = parameters.kv_sequence_length;
- const int num_heads = parameters.num_heads;
- const int qk_head_size = parameters.head_size;
- const int v_head_size = parameters.v_head_size;
-
- if (data.fused_cross_attention_kernel != nullptr) {
- assert(qk_head_size == v_head_size);
- assert(data.attention_bias == nullptr);
- assert(data.mask_index == nullptr);
- assert(parameters.hidden_size == parameters.v_hidden_size);
-
- // For fused cross attention, besides adding bias, K and V needed to be packed:
- // Key (BxSxNxH), Value (BxSxNxH) => Q (BxSxNxH), K (BxSxNx2xH)
- LaunchAddBiasTransposeTrt(
- stream, max_threads_per_block,
- batch_size, sequence_length,
- num_heads, qk_head_size,
- data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length);
- data.v = nullptr;
- data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
- } else if (data.use_memory_efficient_attention ||
- data.use_flash_attention ||
- data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
- if (data.bias != nullptr) {
- LaunchAddBias(stream, max_threads_per_block,
- batch_size, sequence_length, kv_sequence_length,
- num_heads, qk_head_size, v_head_size,
- data.bias, data.query, data.key, data.value, data.q, data.k, data.v);
- } else {
- data.q = const_cast<T*>(data.query);
- data.k = const_cast<T*>(data.key);
- data.v = const_cast<T*>(data.value);
- }
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int kv_sequence_length = parameters.kv_sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int qk_head_size = parameters.head_size;
+ const int v_head_size = parameters.v_head_size;
- data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
- } else if (data.fused_runner != nullptr) {
- assert(qk_head_size == v_head_size);
- assert(data.attention_bias == nullptr);
+ if (data.fused_cross_attention_kernel != nullptr) {
+ assert(qk_head_size == v_head_size);
+ assert(data.attention_bias == nullptr);
+ assert(data.mask_index == nullptr);
+ assert(parameters.hidden_size == parameters.v_hidden_size);
+
+ // For fused cross attention, besides adding bias, K and V need to be packed:
+ // Key (BxSxNxH), Value (BxSxNxH) => Q (BxSxNxH), K (BxSxNx2xH)
+ LaunchAddBiasTransposeTrt(
+ stream, max_threads_per_block,
+ batch_size, sequence_length,
+ num_heads, qk_head_size,
+ data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length);
+ data.v = nullptr;
+ data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
+ } else if (data.use_memory_efficient_attention ||
+ data.use_flash_attention ||
+ data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
+ if (data.bias != nullptr) {
+ LaunchAddBias(stream, max_threads_per_block,
+ batch_size, sequence_length, kv_sequence_length,
+ num_heads, qk_head_size, v_head_size,
+ data.bias, data.query, data.key, data.value, data.q, data.k, data.v);
+ } else {
+ data.q = const_cast<T*>(data.query);
+ data.k = const_cast<T*>(data.key);
+ data.v = const_cast<T*>(data.value);
+ }
- // Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H)
- LaunchAddBiasTransposeTrt(
- stream, max_threads_per_block,
- batch_size, sequence_length,
- num_heads, qk_head_size,
- data.bias, data.query, data.key, data.value, data.q, false, kv_sequence_length);
- data.k = nullptr;
- data.v = nullptr;
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+ } else if (data.fused_runner != nullptr) {
+ assert(qk_head_size == v_head_size);
+ assert(data.attention_bias == nullptr);
- data.qkv_format = AttentionQkvFormat::QKV_BSN3H;
- } else { // unfused kernel
- assert(data.IsUnfused());
- // Query (BxSxNxH) => Q (BxNxSxH)
- constexpr int format = 0;
- LaunchAddBiasTranspose(
- stream, 1, format, max_threads_per_block,
- batch_size, sequence_length, num_heads, qk_head_size,
- data.query, data.bias, data.q,
- true, -1);
-
- // Key (BxLxNxH) => K (BxNxLxH)
- LaunchAddBiasTranspose(
- stream, 1, format, max_threads_per_block,
- batch_size, kv_sequence_length, num_heads, qk_head_size,
- data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, data.k,
- true, -1);
-
- // Value (BxLxNxH_v) => K (BxNxLxH_v)
- LaunchAddBiasTranspose(
- stream, 1, format, max_threads_per_block,
- batch_size, kv_sequence_length, num_heads, v_head_size,
- data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, data.v,
- true, -1);
-
- data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
- }
- } else if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
- // Currently, 4D inputs are only supported in unfused kernel for Attention-23.
+ // Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H)
+ LaunchAddBiasTransposeTrt(
+ stream, max_threads_per_block,
+ batch_size, sequence_length,
+ num_heads, qk_head_size,
+ data.bias, data.query, data.key, data.value, data.q, false, kv_sequence_length);
+ data.k = nullptr;
+ data.v = nullptr;
+
+ data.qkv_format = AttentionQkvFormat::QKV_BSN3H;
+ } else { // unfused kernel
assert(data.IsUnfused());
- // Attention-23 does not support bias with Q_K_V_BNSH format.
- assert(data.bias == nullptr);
- // No need to transpose since QKV is already in BNSH format.
- data.q = const_cast<T*>(data.query);
- data.k = const_cast<T*>(data.key);
- data.v = const_cast<T*>(data.value);
- } else {
- ORT_THROW("Unsupported QKV format: ", parameters.qkv_format);
+ // Query (BxSxNxH) => Q (BxNxSxH)
+ constexpr int format = 0;
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, sequence_length, num_heads, qk_head_size,
+ data.query, data.bias, data.q,
+ true, -1);
+
+ // Key (BxLxNxH) => K (BxNxLxH)
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, qk_head_size,
+ data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, data.k,
+ true, -1);
+
+ // Value (BxLxNxH_v) => V (BxNxLxH_v)
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, v_head_size,
+ data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, data.v,
+ true, -1);
+
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
}
return Status::OK();
@@ -373,6 +360,7 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,
AttentionData& data,
cudaStream_t stream,
int max_threads_per_block) {
+ assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
assert(data.query != nullptr);
assert(data.key != nullptr);
assert(data.value != nullptr);
@@ -385,53 +373,42 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,
data.past_key != nullptr && data.past_value != nullptr);
assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_WithPast_NoBias(data));
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int kv_sequence_length = parameters.kv_sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int qk_head_size = parameters.head_size;
+ const int v_head_size = parameters.v_head_size;
+
// When there is no past state and there is present state, we output K and V directly to present state.
if (data.past_key == nullptr && data.present_key != nullptr) {
data.k = data.present_key;
data.v = data.present_value;
}
- if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) {
- // 3D inputs in BSNH format (will be transposed)
- const int batch_size = parameters.batch_size;
- const int sequence_length = parameters.sequence_length;
- const int kv_sequence_length = parameters.kv_sequence_length;
- const int num_heads = parameters.num_heads;
- const int qk_head_size = parameters.head_size;
- const int v_head_size = parameters.v_head_size;
-
- if (data.use_memory_efficient_attention ||
- data.use_flash_attention ||
- data.use_lean_attention ||
- data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
- // Use oiginal Query (BSNH) since there is no bias.
- data.q = const_cast(data.query);
- // Key (BxLxNxH) => K (BxNxLxH)
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.key, data.k));
- // Value (BxLxNxH) => V (BxNxLxH)
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
- max_threads_per_block, false, data.value, data.v));
- data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
- } else { // unfused kernel
- assert(data.IsUnfused());
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.query, data.q));
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.key, data.k));
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
- max_threads_per_block, false, data.value, data.v));
- data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
- }
- } else if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
- // Currently, 4D inputs are only supported in unfused kernel for Attention-23.
- assert(data.IsUnfused());
- // No need to transpose since QKV is already in BNSH format.
+ if (data.use_memory_efficient_attention ||
+ data.use_flash_attention ||
+ data.use_lean_attention ||
+ data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
+ // Use original Query (BSNH) since there is no bias.
data.q = const_cast<T*>(data.query);
- data.k = const_cast<T*>(data.key);
- data.v = const_cast<T*>(data.value);
- } else {
- ORT_THROW("Unsupported QKV format: ", parameters.qkv_format);
+
+ // Key (BxLxNxH) => K (BxNxLxH)
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, false, data.key, data.k));
+ // Value (BxLxNxH) => V (BxNxLxH)
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
+ max_threads_per_block, false, data.value, data.v));
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
+ } else { // unfused kernel
+ assert(data.IsUnfused());
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, false, data.query, data.q));
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, false, data.key, data.k));
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
+ max_threads_per_block, false, data.value, data.v));
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
}
return Status::OK();
@@ -693,27 +670,14 @@ Status PrepareQkv_MultiHeadAttention(contrib::AttentionParameters& parameters,
case AttentionQkvFormat::Q_K_V_BSNH:
if (data.past_key != nullptr || data.present_key != nullptr) {
if (data.bias == nullptr) {
- DUMP_STRING("PrepareQkv(3D)_MHA_WithPast_NoBias");
+ DUMP_STRING("PrepareQkv_MHA_WithPast_NoBias");
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_NoBias(parameters, data, stream, max_threads_per_block));
} else {
- DUMP_STRING("PrepareQkv(3D)_MHA_WithPast_Bias");
+ DUMP_STRING("PrepareQkv_MHA_WithPast_Bias");
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_Bias(parameters, data, stream, max_threads_per_block));
}
} else { // no past state
- DUMP_STRING("PrepareQkv(3D)_MHA_NoPast");
- ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NoPast(parameters, data, stream, max_threads_per_block));
- }
- break;
- case AttentionQkvFormat::Q_K_V_BNSH:
- if (data.past_key != nullptr || data.present_key != nullptr) {
- if (data.bias == nullptr) {
- DUMP_STRING("PrepareQkv(4D)_MHA_WithPast_NoBias");
- ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_NoBias(parameters, data, stream, max_threads_per_block));
- } else {
- ORT_THROW("Q_K_V_BNSH format with bias is not supported.");
- }
- } else { // no past state
- DUMP_STRING("PrepareQkv(4D)_MHA_NoPast");
+ DUMP_STRING("PrepareQkv_MHA_NoPast");
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NoPast(parameters, data, stream, max_threads_per_block));
}
break;
@@ -744,8 +708,6 @@ bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData&
} else { // no past state
return NoQkvWorkspace_MHA_NoPast(data);
}
- case AttentionQkvFormat::Q_K_V_BNSH:
- return false; // currently no scenario needs no workspace
default:
ORT_THROW("Unsupported QKV format: ", parameters.qkv_format);
}
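Note for reviewers: in the memory-efficient/flash path above, LaunchAddBias keeps Q, K and V in BxSxNxH and only adds the per-head bias; the unfused branch shows that the bias buffer is laid out as [q_bias | k_bias | v_bias]. The sketch below is an illustrative CPU equivalent under the assumption that Q and K share the same head size; the helper name AddBiasBsnh is not part of the patch.

// Illustrative CPU sketch of the per-head bias addition (BxSxNxH layout kept).
#include <cstddef>

template <typename T>
void AddBiasBsnh(const T* in, const T* bias, T* out,
                 std::size_t B, std::size_t S, std::size_t N, std::size_t H) {
  for (std::size_t b = 0; b < B; ++b)
    for (std::size_t s = 0; s < S; ++s)
      for (std::size_t n = 0; n < N; ++n)
        for (std::size_t h = 0; h < H; ++h) {
          const std::size_t idx = ((b * S + s) * N + n) * H + h;
          out[idx] = in[idx] + bias[n * H + h];  // bias indexed per head element
        }
}
// A caller mirroring the offsets used by the unfused branch would pass bias,
// bias + N*H and bias + 2*N*H as the Q, K and V bias slices respectively.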
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu
index be3ae44f7a206..938033644a7d6 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu
@@ -165,7 +165,7 @@ __device__ inline void SoftmaxSmall(const int total_sequence_length,
// Update end position for causal.
int end = valid_end;
if (causal) {
- const int end_causal = (total_sequence_length - sequence_length) + s + 1;
+ const int end_causal = total_sequence_length - sequence_length + s + 1;
if (end_causal < end) {
end = end_causal;
}
@@ -241,7 +241,7 @@ __global__ void SoftmaxLargeKernel(const int total_sequence_length,
// Update end position for causal.
int end = valid_end;
if (causal) {
- const int end_causal = (total_sequence_length - sequence_length) + s + 1;
+ int end_causal = total_sequence_length - sequence_length + s + 1;
if (end_causal < end) {
end = end_causal;
}
@@ -333,7 +333,7 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int total_sequence_length,
: float(input[index]);
float thread_data = input_data * rsqrt_head_size;
if (causal) {
- int from_index = (total_sequence_length - sequence_length) + s; // offset in total sequence length.
+ int from_index = total_sequence_length - sequence_length + s; // offset in total sequence length.
if (i > from_index) {
thread_data = -CUDART_INF_F;
}
@@ -439,7 +439,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int total_sequence_length,
thread_data = float(input[index]) * rsqrt_head_size;
if (causal) {
- int from_index = (total_sequence_length - sequence_length) + s; // offset in total sequence length.
+ int from_index = total_sequence_length - sequence_length + s; // offset in total sequence length.
if (threadIdx.x > from_index) {
thread_data = -CUDART_INF_F;
}
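Note for reviewers: the softmax hunks above only drop the redundant parentheses around total_sequence_length - sequence_length; the causal bound itself is unchanged. For query row s, the kernel attends to key positions [0, past + s + 1), where past = total_sequence_length - sequence_length. A small illustrative helper (CausalEnd is a hypothetical name, not in the patch):

#include <algorithm>

// Returns the exclusive end of the attended key range for query row s.
inline int CausalEnd(int total_sequence_length, int sequence_length, int s, int valid_end) {
  const int end_causal = total_sequence_length - sequence_length + s + 1;  // past + s + 1
  return std::min(valid_end, end_causal);
}
// Example: total=6, seq=2 (past=4): row 0 attends to [0,5), row 1 attends to [0,6).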
diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc
index 164bea191bbef..afd5d55664160 100644
--- a/onnxruntime/core/providers/cpu/llm/attention.cc
+++ b/onnxruntime/core/providers/cpu/llm/attention.cc
@@ -2,7 +2,6 @@
// Licensed under the MIT License.
#include "core/providers/cpu/llm/attention.h"
-#include "core/providers/cpu/llm/attention_helper.h"
#include "core/common/common.h"
#include "core/common/safeint.h"
@@ -116,13 +115,13 @@ Attention::Attention(const OpKernelInfo& info) : AttentionBase(info) {
q_num_heads_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("q_num_heads", 0));
int mode = static_cast<int>(info.GetAttrOrDefault<int64_t>("qk_matmul_output_mode", 0));
qk_matmul_output_mode_ = info.node().OutputDefs().size() >= 4 && info.node().OutputDefs()[3]->Exists()
- ? static_cast<attention_helper::QKMatMulOutputMode>(mode)
- : attention_helper::QKMatMulOutputMode::kNone;
- ORT_ENFORCE(qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kNone ||
- qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQK ||
- qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKMask ||
- qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKSoftCap ||
- qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKSoftMax,
+ ? static_cast<QKMatMulOutputMode>(mode)
+ : QKMatMulOutputMode::kNone;
+ ORT_ENFORCE(qk_matmul_output_mode_ == QKMatMulOutputMode::kNone ||
+ qk_matmul_output_mode_ == QKMatMulOutputMode::kQK ||
+ qk_matmul_output_mode_ == QKMatMulOutputMode::kQKMask ||
+ qk_matmul_output_mode_ == QKMatMulOutputMode::kQKSoftCap ||
+ qk_matmul_output_mode_ == QKMatMulOutputMode::kQKSoftMax,
"qk_matmul_output_mode must be 0, 1, 2, or 3.");
// The default scale depends on the input dimensions. It is set to nan to indicate that it should be computed.
scale_ = info.GetAttrOrDefault<float>("scale", std::numeric_limits<float>::quiet_NaN());
@@ -141,12 +140,11 @@ Status Attention::Compute(OpKernelContext* context) const {
const Tensor* past_value = context->Input<Tensor>(5);
AttentionParameters parameters;
- TensorShape y_shape;
- TensorShape present_key_shape;
- TensorShape present_value_shape;
- TensorShape output_qk_shape;
+ std::vector<int64_t> y_shape;
+ std::vector<int64_t> present_key_shape;
+ std::vector<int64_t> present_value_shape;
+ std::vector<int64_t> output_qk_shape;
- // ComputeOutputShapeForAttention also checks the validity of the inputs.
ORT_ENFORCE(attention_helper::ComputeOutputShapeForAttention(
Q,
K,
@@ -245,49 +243,51 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs,
bool delete_mask_data = false;
bool causal = parameters.is_causal && parameters.q_sequence_length > 1;
if (mask_index == nullptr) {
- // No external mask: allocate only if causal behavior needed.
+ // No mask provided: treat it as all-pass; only materialize an additive mask when causal masking is needed.
if (causal) {
- size_t mask_bytes = SafeInt<size_t>(parameters.q_sequence_length) * parameters.total_sequence_length * sizeof(T);
- void* raw = allocator->Alloc(mask_bytes);
- memset(raw, 0, mask_bytes); // start all allowed
- mask_data = static_cast<T*>(raw);
- for (int s = 0; s < parameters.q_sequence_length; ++s) {
- for (int t = parameters.past_sequence_length + s + 1; t < parameters.total_sequence_length; ++t) {
- mask_data[s * parameters.total_sequence_length + t] = mask_filter_value();
+ size_t mask_data_bytes = SafeInt<size_t>(parameters.q_sequence_length) * parameters.total_sequence_length * sizeof(T);
+ void* allocated_ptr = allocator->Alloc(mask_data_bytes);
+ memset(allocated_ptr, 0, mask_data_bytes);
+ mask_data = static_cast<T*>(allocated_ptr);
+ for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) {
+ for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) {
+ mask_data[s_i * parameters.total_sequence_length + m_i] = mask_filter_value();
}
}
delete_mask_data = true;
}
- } else {
- const bool is_bool_mask = mask_index->IsDataType<bool>();
- const bool need_copy = is_bool_mask || causal; // copy if we must convert or overlay causal pattern
- if (need_copy) {
- size_t mask_bytes = SafeInt<size_t>(mask_index->Shape().Size()) * sizeof(T);
- mask_data = static_cast<T*>(allocator->Alloc(mask_bytes));
- delete_mask_data = true;
- if (is_bool_mask) {
- make_copy(mask_data, mask_index->Data<bool>(), SafeInt<size_t>(mask_index->Shape().Size()));
- } else {
- make_copy(mask_data, mask_index->Data<T>(), SafeInt<size_t>(mask_index->Shape().Size()));
- }
- if (causal) {
- // Overlay causal -inf above diagonal for every broadcast slice
- int slices = mask_batch_size * mask_num_heads;
- for (int slice = 0; slice < slices; ++slice) {
- T* base = mask_data + probs_matrix_size * slice;
- for (int s = 0; s < parameters.q_sequence_length; ++s) {
- for (int t = parameters.past_sequence_length + s + 1; t < parameters.total_sequence_length; ++t) {
- base[s * parameters.total_sequence_length + t] = mask_filter_value();
- }
+ } else if (mask_index->IsDataType<bool>() || causal) {
+ // We need a copy.
+ size_t mask_data_bytes = SafeInt<size_t>(mask_index->Shape().Size()) * sizeof(T);
+ mask_data = static_cast<T*>(allocator->Alloc(mask_data_bytes));
+ delete_mask_data = true;
+
+ if (mask_index->IsDataType<bool>()) {
+ // Convert bool mask to 0/1
+ make_copy(mask_data, mask_index->Data<bool>(), SafeInt<size_t>(mask_index->Shape().Size()));
+ } else if (mask_index != nullptr) {
+ // We make a copy because causal is True.
+ make_copy(mask_data, mask_index->Data<T>(), SafeInt<size_t>(mask_index->Shape().Size()));
+ }
+ if (causal) {
+ // This loop could be parallelized.
+ // Note: per the ONNX specification this combination should not occur, since
+ // is_causal=1 and an explicit mask are mutually exclusive, but it is handled here anyway.
+ int n_iter = mask_batch_size * mask_num_heads;
+ for (int i = 0; i < n_iter; ++i) {
+ for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) {
+ for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) {
+ mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = mask_filter_value();
}
}
}
- } else {
- // Reuse mask memory directly (numeric, non-causal)
- mask_data = const_cast<T*>(mask_index->Data<T>());
}
+ } else {
+ // No copy needed: reuse the mask buffer directly.
+ mask_data = const_cast<T*>(mask_index->Data<T>());
}
+ bool transposed_k = parameters.transpose_output && nullptr == present_key;
if (nullptr != present_key && parameters.kv_num_heads != parameters.q_num_heads) {
// This is not part of the main loop because it is not needed at every iteration and
// we cannot ensure the inner body is executed first before getting used in another iteration.
@@ -309,7 +309,6 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs,
// If past_key is not null, then we need to concatenate it with K, the concatenation is not transposed.
const int loop_len = parameters.batch_size * parameters.q_num_heads;
const float alpha = parameters.scale;
- bool transposed_k = parameters.transpose_output && nullptr == present_key;
ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
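Note for reviewers: the CPU mask handling above materializes an additive mask and, for the causal case, writes mask_filter_value() at every position beyond past_sequence_length + s in query row s. A minimal stand-alone sketch of that fill (FillCausal is an illustrative name only, not part of the patch):

// Writes the additive causal filter value above the (past-shifted) diagonal of a
// q_sequence_length x total_sequence_length mask slice.
template <typename T>
void FillCausal(T* mask, int q_sequence_length, int past_sequence_length,
                int total_sequence_length, T filter_value) {
  for (int s = 0; s < q_sequence_length; ++s)
    for (int t = past_sequence_length + s + 1; t < total_sequence_length; ++t)
      mask[s * total_sequence_length + t] = filter_value;
}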
diff --git a/onnxruntime/core/providers/cpu/llm/attention.h b/onnxruntime/core/providers/cpu/llm/attention.h
index 867207e385dc9..c8eff7c5006d6 100644
--- a/onnxruntime/core/providers/cpu/llm/attention.h
+++ b/onnxruntime/core/providers/cpu/llm/attention.h
@@ -5,7 +5,7 @@
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/platform/threadpool.h"
-#include "core/providers/cpu/llm/attention_parameters.h"
+#include "core/providers/cpu/llm/attention_helper.h"
namespace onnxruntime {
@@ -100,4 +100,4 @@ class Attention final : public AttentionBase {
int softmax_precision_;
};
-} // namespace onnxruntime
+} // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/core/providers/cpu/llm/attention_helper.cc b/onnxruntime/core/providers/cpu/llm/attention_helper.cc
new file mode 100644
index 0000000000000..9bd954f128454
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/llm/attention_helper.cc
@@ -0,0 +1,156 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cpu/llm/attention_helper.h"
+#include "core/util/shape_checker.h"
+
+namespace onnxruntime {
+namespace attention_helper {
+
+void AttentionParameters::checkParameters() const {
+ ORT_ENFORCE(batch_size > 0, "Batch size must be greater than 0");
+ ORT_ENFORCE(q_sequence_length > 0, "Q sequence length must be greater than 0");
+ ORT_ENFORCE(kv_sequence_length > 0, "KV sequence length must be greater than 0");
+ ORT_ENFORCE(head_size > 0, "Head size must be greater than 0");
+ ORT_ENFORCE(v_head_size > 0, "V head size must be greater than 0");
+ ORT_ENFORCE(past_sequence_length >= 0, "Past sequence length must be non-negative");
+ ORT_ENFORCE(total_sequence_length > 0, "Total sequence length must be greater than 0");
+ ORT_ENFORCE(kv_num_heads > 0, "KV number of heads must be greater than 0");
+ ORT_ENFORCE(q_num_heads > 0, "Q number of heads must be greater than 0");
+ ORT_ENFORCE(total_sequence_length == past_sequence_length + kv_sequence_length,
+ "Total sequence length must be equal to past sequence length plus KV sequence length");
+}
+
+Status ComputeOutputShapeForAttention(
+ const Tensor* Q,
+ const Tensor* K,
+ const Tensor* V,
+ const Tensor* attn_mask,
+ const Tensor* past_key,
+ const Tensor* past_value,
+ bool is_causal,
+ float softcap,
+ int softmax_precision,
+ attention_helper::QKMatMulOutputMode qk_matmul_output_mode,
+ int kv_num_heads,
+ int q_num_heads,
+ float scale,
+ AttentionParameters& parameters,
+ std::vector<int64_t>& y_shape,
+ std::vector<int64_t>& present_key_shape,
+ std::vector<int64_t>& present_value_shape,
+ std::vector<int64_t>& output_qk_shape) {
+ ORT_ENFORCE(Q != nullptr && K != nullptr && V != nullptr,
+ "Q, K, and V inputs must not be null");
+ int q_dims = onnxruntime::narrow<int>(Q->Shape().NumDimensions());
+ int k_dims = onnxruntime::narrow<int>(K->Shape().NumDimensions());
+ int v_dims = onnxruntime::narrow<int>(V->Shape().NumDimensions());
+ ORT_ENFORCE(q_dims == 3 || q_dims == 4, "Q must be a 3D or 4D tensor");
+ ORT_ENFORCE(q_dims == k_dims, "Q and K must have the same rank.");
+ ORT_ENFORCE(q_dims == v_dims, "Q and V must have the same rank.");
+
+ ORT_ENFORCE((past_key == nullptr) == (past_value == nullptr), "past_key and past_value must be both null or both not null");
+ ORT_ENFORCE(Q->Shape()[0] == K->Shape()[0], "inconsistent batch_size (between Q and K)");
+ ORT_ENFORCE(Q->Shape()[0] == V->Shape()[0], "inconsistent batch_size (between Q and V)");
+ ORT_ENFORCE(past_key == nullptr || Q->Shape()[0] == past_key->Shape()[0], "inconsistent batch_size (between Q and past_key)");
+ ORT_ENFORCE(past_value == nullptr || Q->Shape()[0] == past_value->Shape()[0], "inconsistent batch_size (between Q and past_value)");
+ ORT_ENFORCE(past_value == nullptr || past_value->Shape()[2] == past_key->Shape()[2], "inconsistent past_sequence_length (between past_key and past_value)");
+
+ parameters.is_causal = is_causal;
+ parameters.softcap = softcap;
+ parameters.softmax_precision = softmax_precision;
+ parameters.qk_matmul_output_mode = qk_matmul_output_mode; // output mode for Q*K matmul
+ parameters.batch_size = onnxruntime::narrow<int>(Q->Shape()[0]); // Q.shape[0], K.shape[0], V.shape[0] (4D)
+
+ ORT_ENFORCE(parameters.batch_size > 0, "Batch size must be greater than 0");
+ ORT_ENFORCE(attn_mask == nullptr || (attn_mask->Shape().NumDimensions() >= 2 && attn_mask->Shape().NumDimensions() <= 4), "attn_mask must be 2D or 3D or 4D tensor");
+
+ if (q_dims == 4) {
+ // 4D
+ parameters.kv_num_heads = kv_num_heads > 0 ? kv_num_heads : onnxruntime::narrow<int>(K->Shape()[1]); // K.shape[1] or V.shape[1] (4D)
+ parameters.q_num_heads = q_num_heads > 0 ? q_num_heads : onnxruntime::narrow<int>(Q->Shape()[1]); // Q.shape[1] (4D)
+
+ ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow<int>(K->Shape()[1]), "kv_num_heads different from K.shape[1]");
+ ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow<int>(V->Shape()[1]), "kv_num_heads different from V.shape[1]");
+ ORT_ENFORCE(parameters.q_num_heads == onnxruntime::narrow<int>(Q->Shape()[1]), "q_num_heads different from Q.shape[1]");
+ ORT_ENFORCE(Q->Shape()[3] == K->Shape()[3], "inconsistent head_size");
+ ORT_ENFORCE(K->Shape()[2] == V->Shape()[2], "inconsistent kv_sequence_length");
+ ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 2] == Q->Shape()[2], "inconsistent q_sequence_length (between attn_mask and Q)");
+
+ // From shapes
+ parameters.transpose_output = false; // whether to transpose the input/output with permutation (0, 2, 1, 3)
+ parameters.q_sequence_length = onnxruntime::narrow<int>(Q->Shape()[2]); // Q.shape[2] (4D)
+ parameters.head_size = onnxruntime::narrow<int>(Q->Shape()[3]); // Q.shape[3] (4D)
+ parameters.kv_sequence_length = onnxruntime::narrow<int>(K->Shape()[2]); // K.shape[2] or V.shape[2] (4D)
+ parameters.v_head_size = onnxruntime::narrow<int>(V->Shape()[3]); // V.shape[3] (4D)
+ parameters.past_sequence_length = past_key == nullptr // past_key.shape[2] or past_value.shape[2] (4D) or given by the mask
+ ? 0
+ : onnxruntime::narrow<int>(past_key->Shape()[2]);
+
+ y_shape = {static_cast<int64_t>(parameters.batch_size),
+ static_cast<int64_t>(parameters.q_num_heads),
+ static_cast<int64_t>(parameters.q_sequence_length),
+ static_cast<int64_t>(parameters.v_head_size)};
+ } else {
+ // 3D
+ parameters.kv_num_heads = kv_num_heads;
+ parameters.q_num_heads = q_num_heads;
+
+ // From shapes
+ ORT_ENFORCE(Q->Shape()[2] % parameters.q_num_heads == 0, "inconsistent q_hidden_size, it should be a multiple of q_num_heads");
+ ORT_ENFORCE(V->Shape()[2] % parameters.kv_num_heads == 0, "inconsistent v_hidden_size, it should be a multiple of kv_num_heads");
+
+ parameters.transpose_output = true; // whether to transpose the input/output with permutation (0, 2, 1, 3)
+ parameters.q_sequence_length = onnxruntime::narrow<int>(Q->Shape()[1]);
+ parameters.head_size = onnxruntime::narrow<int>(Q->Shape()[2]) / parameters.q_num_heads;
+ parameters.kv_sequence_length = onnxruntime::narrow<int>(K->Shape()[1]);
+ parameters.v_head_size = onnxruntime::narrow<int>(V->Shape()[2]) / parameters.kv_num_heads;
+ parameters.past_sequence_length = past_key == nullptr
+ ? 0
+ : onnxruntime::narrow<int>(past_key->Shape()[2]);
+
+ y_shape = {static_cast<int64_t>(parameters.batch_size),
+ static_cast<int64_t>(parameters.q_sequence_length),
+ static_cast<int64_t>(parameters.q_num_heads * parameters.v_head_size)};
+ }
+ parameters.total_sequence_length = parameters.past_sequence_length + parameters.kv_sequence_length;
+
+ ORT_ENFORCE(parameters.q_num_heads % parameters.kv_num_heads == 0, "q_num_heads must be a multiple of kv_num_heads");
+ ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 1] == parameters.total_sequence_length,
+ "inconsistent total_sequence_length (between attn_mask and past_key and past_value)");
+ ORT_ENFORCE(attn_mask == nullptr ||
+ attn_mask->Shape().NumDimensions() < 3 ||
+ attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == 1 ||
+ attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == parameters.kv_num_heads,
+ "attn_mask must be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) but is not compatible with kv_num_heads");
+ ORT_ENFORCE(attn_mask == nullptr ||
+ attn_mask->Shape().NumDimensions() < 4 ||
+ attn_mask->Shape()[0] == 1 ||
+ attn_mask->Shape()[0] == parameters.batch_size,
+ "attn_mask must be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) but is not compatible with batch_size");
+ ASSERT_TENSOR_DIMS(past_key, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.head_size);
+ ASSERT_TENSOR_DIMS(past_value, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.v_head_size);
+
+ parameters.scale = std::isnan(scale) ? static_cast<float>(1.0 / sqrt(parameters.head_size)) : scale;
+ parameters.checkParameters();
+
+ present_key_shape = {static_cast<int64_t>(parameters.batch_size),
+ static_cast<int64_t>(parameters.kv_num_heads),
+ static_cast<int64_t>(parameters.total_sequence_length),
+ static_cast<int64_t>(parameters.head_size)};
+ present_value_shape = {static_cast<int64_t>(parameters.batch_size),
+ static_cast<int64_t>(parameters.kv_num_heads),
+ static_cast<int64_t>(parameters.total_sequence_length),
+ static_cast<int64_t>(parameters.v_head_size)};
+ if (qk_matmul_output_mode == QKMatMulOutputMode::kNone) {
+ output_qk_shape.clear();
+ } else {
+ output_qk_shape = {static_cast<int64_t>(parameters.batch_size),
+ static_cast<int64_t>(parameters.q_num_heads),
+ static_cast<int64_t>(parameters.q_sequence_length),
+ static_cast<int64_t>(parameters.total_sequence_length)};
+ }
+ return Status::OK();
+}
+} // namespace attention_helper
+} // namespace onnxruntime
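Note for reviewers: ComputeOutputShapeForAttention above produces a BxNxSxHv output shape for 4D inputs and a BxSx(N*Hv) shape for 3D inputs (where transpose_output is set). A small illustrative recap, with example numbers only; MakeYShape is not part of the patch.

#include <cstdint>
#include <vector>

// Illustrative recap of the two y_shape layouts derived by the helper above.
std::vector<int64_t> MakeYShape(bool transpose_output, int64_t batch_size, int64_t q_num_heads,
                                int64_t q_sequence_length, int64_t v_head_size) {
  if (!transpose_output)  // 4D inputs: output stays in BxNxSxHv
    return {batch_size, q_num_heads, q_sequence_length, v_head_size};
  return {batch_size, q_sequence_length, q_num_heads * v_head_size};  // 3D inputs
}
// Example: Q [2, 8, 16, 64] with K/V [2, 8, 32, 64] and no past gives
// y_shape {2, 8, 16, 64}, present_key/value {2, 8, 32, 64}, output_qk {2, 8, 16, 32};
// the 3D form of the same problem gives y_shape {2, 16, 512}.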
diff --git a/onnxruntime/core/providers/cpu/llm/attention_helper.h b/onnxruntime/core/providers/cpu/llm/attention_helper.h
index 4b1e22df4b2f6..1cea27760408f 100644
--- a/onnxruntime/core/providers/cpu/llm/attention_helper.h
+++ b/onnxruntime/core/providers/cpu/llm/attention_helper.h
@@ -2,13 +2,54 @@
// Licensed under the MIT License.
#pragma once
-#include "core/providers/cpu/llm/attention_parameters.h"
-#include "core/util/shape_checker.h"
+#include "core/common/common.h"
+#include "core/providers/common.h"
namespace onnxruntime {
namespace attention_helper {
-inline Status ComputeOutputShapeForAttention(
+// enum equivalent to the ONNX definition of qk_matmul_output_mode
+enum QKMatMulOutputMode {
+ kNone = -1, // No output Q*K
+ kQK = 0, // Output Q*K
+ kQKMask = 1, // Output Q*K + Mask
+ kQKSoftCap = 2, // Output SoftCap(Q*K + Mask)
+ kQKSoftMax = 3, // Output SoftMax(SoftCap(Q*K + Mask))
+};
+
+// Parameters deduced from node attributes and inputs/outputs.
+struct AttentionParameters {
+ /*
+ * Attention Parameters
+ * MHA: q_num_heads == kv_num_heads -> MHA
+ * GQA: q_num_heads > kv_num_heads && q_num_heads % kv_num_heads == 0
+ * MQA: q_num_heads > kv_num_heads && kv_num_heads == 1
+ */
+ bool is_causal;
+ int kv_num_heads; // K.shape[1] or V.shape[1] (4D)
+ int q_num_heads; // Q.shape[1] (4D)
+ float scale;
+ float softcap;
+ int softmax_precision;
+ QKMatMulOutputMode qk_matmul_output_mode;
+
+ // From shapes
+ int batch_size; // Q.shape[0], K.shape[0], V.shape[0] (4D)
+ int q_sequence_length; // Q.shape[2] (4D)
+ int head_size; // Q.shape[3] or K.shape[3] (4D)
+ int kv_sequence_length; // K.shape[2] or V.shape[2] (4D)
+ int v_head_size; // V.shape[3] (4D)
+ int past_sequence_length; // past_key.shape[2] or past_value.shape[2] (4D)
+ int total_sequence_length; // past_sequence_length + kv_sequence_length
+ bool transpose_output; // Whether to transpose the inputs and the outputs from BxNxSxH to BxSxNxH
+ // This covers the case where the inputs are 3D.
+
+ // Checks the consistency of the parameters.
+ void checkParameters() const;
+};
+
+// Computes the output shape for attention based on the input tensors and parameters.
+Status ComputeOutputShapeForAttention(
const Tensor* Q,
const Tensor* K,
const Tensor* V,
@@ -18,122 +59,15 @@ inline Status ComputeOutputShapeForAttention(
bool is_causal,
float softcap,
int softmax_precision,
- QKMatMulOutputMode qk_matmul_output_mode,
+ attention_helper::QKMatMulOutputMode qk_matmul_output_mode,
int kv_num_heads,
int q_num_heads,
float scale,
AttentionParameters& parameters,
- TensorShape& y_shape,
- TensorShape& present_key_shape,
- TensorShape& present_value_shape,
- TensorShape& output_qk_shape) {
- ORT_ENFORCE(Q != nullptr && K != nullptr && V != nullptr,
- "Q, K, and V inputs must not be null");
- int q_dims = onnxruntime::narrow<int>(Q->Shape().NumDimensions());
- int k_dims = onnxruntime::narrow<int>(K->Shape().NumDimensions());
- int v_dims = onnxruntime::narrow<int>(V->Shape().NumDimensions());
- ORT_ENFORCE(q_dims == 3 || q_dims == 4, "Q must be a 3D or 4D tensor");
- ORT_ENFORCE(q_dims == k_dims, "Q and K must have the same rank.");
- ORT_ENFORCE(q_dims == v_dims, "Q and V must have the same rank.");
+ std::vector<int64_t>& y_shape,
+ std::vector<int64_t>& present_key_shape,
+ std::vector<int64_t>& present_value_shape,
+ std::vector<int64_t>& output_qk_shape);
- ORT_ENFORCE((past_key == nullptr) == (past_value == nullptr), "past_key and past_value must be both null or both not null");
- ORT_ENFORCE(Q->Shape()[0] == K->Shape()[0], "inconsistent batch_size (between Q and K)");
- ORT_ENFORCE(Q->Shape()[0] == V->Shape()[0], "inconsistent batch_size (between Q and V)");
- ORT_ENFORCE(past_key == nullptr || Q->Shape()[0] == past_key->Shape()[0], "inconsistent batch_size (between Q and past_key)");
- ORT_ENFORCE(past_value == nullptr || Q->Shape()[0] == past_value->Shape()[0], "inconsistent batch_size (between Q and past_value)");
- ORT_ENFORCE(past_value == nullptr || past_value->Shape()[2] == past_key->Shape()[2], "inconsistent past_sequence_length (between past_key and past_value)");
-
- parameters.is_causal = is_causal;
- parameters.softcap = softcap;
- parameters.softmax_precision = softmax_precision;
- parameters.qk_matmul_output_mode = qk_matmul_output_mode; // output mode for Q*K matmul
- parameters.batch_size = onnxruntime::narrow<int>(Q->Shape()[0]); // Q.shape[0], K.shape[0], V.shape[0] (4D)
-
- ORT_ENFORCE(parameters.batch_size > 0, "Batch size must be greater than 0");
- ORT_ENFORCE(attn_mask == nullptr || (attn_mask->Shape().NumDimensions() >= 2 && attn_mask->Shape().NumDimensions() <= 4), "attn_mask must be 2D or 3D or 4D tensor");
-
- if (q_dims == 4) {
- // 4D
- parameters.kv_num_heads = kv_num_heads > 0 ? kv_num_heads : onnxruntime::narrow<int>(K->Shape()[1]); // K.shape[1] or V.shape[1] (4D)
- parameters.q_num_heads = q_num_heads > 0 ? q_num_heads : onnxruntime::narrow<int>(Q->Shape()[1]); // Q.shape[1] (4D)
-
- ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow<int>(K->Shape()[1]), "kv_num_heads different from K.shape[1]");
- ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow<int>(V->Shape()[1]), "kv_num_heads different from V.shape[1]");
- ORT_ENFORCE(parameters.q_num_heads == onnxruntime::narrow<int>(Q->Shape()[1]), "q_num_heads different from Q.shape[1]");
- ORT_ENFORCE(Q->Shape()[3] == K->Shape()[3], "inconsistent head_size");
- ORT_ENFORCE(K->Shape()[2] == V->Shape()[2], "inconsistent kv_sequence_length");
- ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 2] == Q->Shape()[2], "inconsistent q_sequence_length (between attn_mask and Q)");
-
- // From shapes
- parameters.transpose_output = false; // whether to transpose the input/output with permutation (0, 2, 1, 3)
- parameters.q_sequence_length = onnxruntime::narrow<int>(Q->Shape()[2]); // Q.shape[2] (4D)
- parameters.head_size = onnxruntime::narrow<int>(Q->Shape()[3]); // Q.shape[3] (4D)
- parameters.kv_sequence_length = onnxruntime::narrow<int>(K->Shape()[2]); // K.shape[2] or V.shape[2] (4D)
- parameters.v_head_size = onnxruntime::narrow<int>(V->Shape()[3]); // V.shape[3] (4D)
- parameters.past_sequence_length = past_key == nullptr // past_key.shape[2] or past_value.shape[2] (4D) or given by the mask
- ? 0
- : onnxruntime::narrow<int>(past_key->Shape()[2]);
-
- y_shape = {static_cast<int64_t>(parameters.batch_size),
- static_cast<int64_t>(parameters.q_num_heads),
- static_cast<int64_t>(parameters.q_sequence_length),
- static_cast<int64_t>(parameters.v_head_size)};
- } else {
- // 3D
- parameters.kv_num_heads = kv_num_heads;
- parameters.q_num_heads = q_num_heads;
-
- // From shapes
- ORT_ENFORCE(Q->Shape()[2] % parameters.q_num_heads == 0, "inconsistent q_hidden_size, it should be a multiple of q_num_heads");
- ORT_ENFORCE(V->Shape()[2] % parameters.kv_num_heads == 0, "inconsistent v_hidden_size, it should be a multiple of kv_num_heads");
-
- parameters.transpose_output = true; // whether to transpose the input/output with permutation (0, 2, 1, 3)
- parameters.q_sequence_length = onnxruntime::narrow<int>(Q->Shape()[1]);
- parameters.head_size = onnxruntime::narrow<int>(Q->Shape()[2]) / parameters.q_num_heads;
- parameters.kv_sequence_length = onnxruntime::narrow<int>(K->Shape()[1]);
- parameters.v_head_size = onnxruntime::narrow<int>(V->Shape()[2]) / parameters.kv_num_heads;
- parameters.past_sequence_length = past_key == nullptr
- ? 0
- : onnxruntime::narrow<int>(past_key->Shape()[2]);
-
- y_shape = {static_cast<int64_t>(parameters.batch_size),
- static_cast<int64_t>(parameters.q_sequence_length),
- static_cast<int64_t>(parameters.q_num_heads * parameters.v_head_size)};
- }
- parameters.total_sequence_length = parameters.past_sequence_length + parameters.kv_sequence_length;
-
- ORT_ENFORCE(parameters.q_num_heads % parameters.kv_num_heads == 0, "q_num_heads must be a multiple of kv_num_heads. This is required for grouped/multi-query and multi-headed attention.");
- ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 1] == parameters.total_sequence_length,
- "inconsistent total_sequence_length (between attn_mask and past_key and past_value)");
- ORT_ENFORCE(attn_mask == nullptr ||
- attn_mask->Shape().NumDimensions() < 3 ||
- attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == 1 ||
- attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == parameters.q_num_heads,
- "attn_mask must be broadcastable to (batch_size, q_num_heads, q_sequence_length, total_sequence_length) but is not compatible with q_num_heads");
- ORT_ENFORCE(attn_mask == nullptr ||
- attn_mask->Shape().NumDimensions() < 4 ||
- attn_mask->Shape()[0] == 1 ||
- attn_mask->Shape()[0] == parameters.batch_size,
- "attn_mask must be broadcastable to (batch_size, q_num_heads, q_sequence_length, total_sequence_length) but is not compatible with batch_size");
- ASSERT_TENSOR_DIMS(past_key, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.head_size);
- ASSERT_TENSOR_DIMS(past_value, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.v_head_size);
-
- parameters.scale = std::isnan(scale) ? static_cast<float>(1.0 / sqrt(parameters.head_size)) : scale;
- parameters.checkParameters();
-
- present_key_shape = {static_cast<int64_t>(parameters.batch_size),
- static_cast<int64_t>(parameters.kv_num_heads),
- static_cast<int64_t>(parameters.total_sequence_length),
- static_cast<int64_t>(parameters.head_size)};
- present_value_shape = {static_cast<int64_t>(parameters.batch_size),
- static_cast<int64_t>(parameters.kv_num_heads),
- static_cast<int64_t>(parameters.total_sequence_length),
- static_cast<int64_t>(parameters.v_head_size)};
- output_qk_shape = {static_cast<int64_t>(parameters.batch_size),
- static_cast<int64_t>(parameters.q_num_heads),
- static_cast<int64_t>(parameters.q_sequence_length),
- static_cast<int64_t>(parameters.total_sequence_length)};
- return Status::OK();
-}
} // namespace attention_helper
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/llm/attention_parameters.h b/onnxruntime/core/providers/cpu/llm/attention_parameters.h
deleted file mode 100644
index b8586ca4d63dc..0000000000000
--- a/onnxruntime/core/providers/cpu/llm/attention_parameters.h
+++ /dev/null
@@ -1,65 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-#include "core/common/common.h"
-#include "core/providers/common.h"
-
-namespace onnxruntime {
-// Declares enum QKMatMulOutputMode and struct AttentionParameters inside namespace onnxruntime::attention_helper.
-namespace attention_helper {
-
-// enum equivalent to the onnx defintion of qk_matmul_output_mode
-enum QKMatMulOutputMode {
- kNone = -1, // No output Q*K
- kQK = 0, // Output Q*K
- kQKMask = 1, // Output Q*K + Mask
- kQKSoftCap = 2, // Output SoftCap(Q*K + Mask)
- kQKSoftMax = 3, // Output SoftMax(SoftCap(Q*K + Mask))
-};
-
-// Parameters deduced from node attributes and inputs/outputs.
-struct AttentionParameters {
- /*
- * Attention Parameters
- * MHA: q_num_heads == kv_num_heads -> MHA
- * GQA: q_num_heads > kv_num_heads && q_num_heads % kv_num_heads == 0
- * MQA: q_num_heads > kv_num_heads && kv_num_heads == 1
- */
- bool is_causal;
- int kv_num_heads; // K.shape[1] or V.shape[1] (4D)
- int q_num_heads; // Q.shape[1] (4D)
- float scale;
- float softcap;
- int softmax_precision;
- QKMatMulOutputMode qk_matmul_output_mode;
-
- // From shapes
- int batch_size; // Q.shape[0], K.shape[0], V.shape[0] (4D)
- int q_sequence_length; // Q.shape[2] (4D)
- int head_size; // Q.shape[3] or K.shape[3 (4D)
- int kv_sequence_length; // K.shape[2] or V.shape[2] (4D)
- int v_head_size; // V.shape[4] (4D)
- int past_sequence_length; // pask_key.shape[2] or past_value.shape[2] (4D)
- int total_sequence_length; // past_sequence_length + kv_sequence_length
- bool transpose_output; // Whether to transpose the inputs and the outputs from BxNxSxH to BxSxNxH
- // This covers the case where the inputs are 3D.
-
- // Checks the consistency of the parameters.
- void checkParameters() const {
- ORT_ENFORCE(batch_size > 0, "Batch size must be greater than 0");
- ORT_ENFORCE(q_sequence_length > 0, "Q sequence length must be greater than 0");
- ORT_ENFORCE(kv_sequence_length > 0, "KV sequence length must be greater than 0");
- ORT_ENFORCE(head_size > 0, "Head size must be greater than 0");
- ORT_ENFORCE(v_head_size > 0, "V head size must be greater than 0");
- ORT_ENFORCE(past_sequence_length >= 0, "Past sequence length must be non-negative");
- ORT_ENFORCE(total_sequence_length > 0, "Total sequence length must be greater than 0");
- ORT_ENFORCE(kv_num_heads > 0, "KV number of heads must be greater than 0");
- ORT_ENFORCE(q_num_heads > 0, "Q number of heads must be greater than 0");
- ORT_ENFORCE(total_sequence_length == past_sequence_length + kv_sequence_length,
- "Total sequence length must be equal to past sequence length plus KV sequence length");
- }
-};
-
-} // namespace attention_helper
-} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index eb29e4edbf897..eab616388d6ae 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1585,9 +1585,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish);
// Opset 23.
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Attention);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_MLFloat16, RMSNormalization);
@@ -2656,9 +2653,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
// Opset 23
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Attention)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Attention)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc
deleted file mode 100644
index 99f297bba6444..0000000000000
--- a/onnxruntime/core/providers/cuda/llm/attention.cc
+++ /dev/null
@@ -1,265 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "core/providers/cuda/cuda_common.h"
-#include "core/providers/cpu/llm/attention_helper.h"
-#include "core/providers/cuda/llm/attention.h"
-#include "contrib_ops/cuda/bert/attention_data.h"
-#include "contrib_ops/cuda/bert/attention_impl.h"
-
-using namespace onnxruntime::cuda;
-
-namespace onnxruntime {
-namespace cuda {
-
-#define REGISTER_KERNEL_TYPED(T) \
- ONNX_OPERATOR_TYPED_KERNEL_EX( \
- Attention, \
- kOnnxDomain, \
- 23, \
- T, \
- kCudaExecutionProvider, \
- (*KernelDefBuilder::Create()) \
- .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
- .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \
- .TypeConstraint("U", BuildKernelDefConstraints()), \
- Attention);
-
-REGISTER_KERNEL_TYPED(float)
-REGISTER_KERNEL_TYPED(MLFloat16)
-REGISTER_KERNEL_TYPED(BFloat16)
-
-template <typename T>
-Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info) {
- is_causal_ = static_cast(info.GetAttrOrDefault("is_causal", 0)) == 1;
- // kv_num_heads, q_num_head are mandatory for 3D inputs but not used for 4D inputs.
- // The dimension is not yet known. If not specified, the inputs is assumed to be 4D.
- kv_num_heads_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("kv_num_heads", 0));
- q_num_heads_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("q_num_heads", 0));
- int mode = static_cast<int>(info.GetAttrOrDefault<int64_t>("qk_matmul_output_mode", 0));
- qk_matmul_output_mode_ = info.node().OutputDefs().size() >= 4 && info.node().OutputDefs()[3]->Exists()
- ? static_cast<attention_helper::QKMatMulOutputMode>(mode)
- : attention_helper::QKMatMulOutputMode::kNone;
- ORT_ENFORCE(qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kNone ||
- qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQK ||
- qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKMask ||
- qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKSoftCap ||
- qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKSoftMax,
- "qk_matmul_output_mode must be 0, 1, 2, or 3.");
- // The default scale depends on the input dimensions. It is set to nan to indicate that it should be computed.
- scale_ = info.GetAttrOrDefault<float>("scale", std::numeric_limits<float>::quiet_NaN());
- softcap_ = info.GetAttrOrDefault("softcap", 0.0f);
- softmax_precision_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("softmax_precision", 0));
- ORT_ENFORCE(scale_ > 0 || std::isnan(scale_), "scale must be greater than 0 if specified");
-}
-
-template <typename T>
-Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
- const Tensor* Q = context->Input<Tensor>(0);
- const Tensor* K = context->Input<Tensor>(1);
- const Tensor* V = context->Input<Tensor>(2);
- const Tensor* attn_mask = context->Input<Tensor>(3);
- const Tensor* past_key = context->Input<Tensor>(4);
- const Tensor* past_value = context->Input<Tensor>(5);
-
- attention_helper::AttentionParameters parameters;
- TensorShape y_shape;
- TensorShape present_key_shape;
- TensorShape present_value_shape;
- TensorShape output_qk_shape;
-
- ORT_ENFORCE(attention_helper::ComputeOutputShapeForAttention(
- Q,
- K,
- V,
- attn_mask,
- past_key,
- past_value,
- is_causal_,
- softcap_,
- softmax_precision_,
- qk_matmul_output_mode_,
- kv_num_heads_,
- q_num_heads_,
- scale_,
- parameters,
- y_shape,
- present_key_shape,
- present_value_shape,
- output_qk_shape)
- .IsOK(),
- "Output shapes for Attention could not be computed.");
-
- Tensor* Y = context->Output(0, y_shape);
- Tensor* present_key = context->Output(1, present_key_shape);
- Tensor* present_value = context->Output(2, present_value_shape);
- Tensor* output_qk = context->Output(3, output_qk_shape);
-
- // To reuse the existing attention-cuda implementation in contrib ops,
- // map the parameters to contribop_parameters.
- onnxruntime::contrib::AttentionParameters contribop_parameters;
- contribop_parameters.batch_size = parameters.batch_size;
- contribop_parameters.sequence_length = parameters.q_sequence_length;
- contribop_parameters.kv_sequence_length = parameters.kv_sequence_length;
- contribop_parameters.past_sequence_length = parameters.past_sequence_length;
- contribop_parameters.total_sequence_length = parameters.total_sequence_length;
- // max_sequence_length: For non-buffer-sharing case, this equals total_sequence_length (the present KV cache size)
- contribop_parameters.max_sequence_length = parameters.total_sequence_length;
- contribop_parameters.input_hidden_size = 0; // Not applicable - new Attention op takes pre-projected Q/K/V
- contribop_parameters.hidden_size = parameters.q_num_heads * parameters.head_size;
- contribop_parameters.head_size = parameters.head_size;
- contribop_parameters.v_head_size = parameters.v_head_size;
- contribop_parameters.v_hidden_size = parameters.kv_num_heads * parameters.v_head_size;
- contribop_parameters.num_heads = parameters.q_num_heads;
- contribop_parameters.rotary_dim = 0;
- contribop_parameters.num_splits = 1;
- contribop_parameters.beam_width = 1;
- contribop_parameters.is_unidirectional = parameters.is_causal;
- contribop_parameters.past_present_share_buffer = false; // New Attention op doesn't share buffer
- contribop_parameters.is_packed_qkv = false;
- contribop_parameters.do_rotary = false;
-
- // The new Attention op uses attn_mask as attention_bias (additive bias), not as key_padding_mask
- // So mask_type should always be MASK_NONE since we don't have a separate padding mask input
- contribop_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE;
-
- // Determine broadcast flags for attention_bias (if it exists)
- // Note: The new Attention op uses attn_mask as attention_bias
- // The attention_bias should be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length)
- // attn_mask can be 2D, 3D, or 4D. Broadcasting aligns from the right (trailing dimensions).
- if (attn_mask != nullptr) {
- // TODO(titaiwang, xadupre): attn_mask bool is not supported yet
- if (attn_mask->IsDataType<bool>()) {
- ORT_THROW("Boolean attn_mask is not supported yet in Attention op (CUDA).");
- }
-
- size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions();
- auto attn_mask_dims = attn_mask->Shape().GetDims();
- // For 2D mask (q_seq_len, total_seq_len): both batch and heads dimensions need broadcasting
- // For 3D mask (X, q_seq_len, total_seq_len): batch needs broadcasting if X==1, heads always needs broadcasting
- // For 4D mask (B, H, q_seq_len, total_seq_len): check if B==1 and H==1
-
- if (attn_mask_dims_size == 2) {
- // 2D mask: both dimensions need broadcasting
- contribop_parameters.broadcast_attn_bias_dim_0 = true;
- contribop_parameters.broadcast_attn_bias_dim_1 = true;
- } else if (attn_mask_dims_size == 3) {
- // 3D mask: dim 0 broadcasts if it's 1, dim 1 (heads) always broadcasts
- contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1;
- contribop_parameters.broadcast_attn_bias_dim_1 = true;
- } else {
- // 4D mask: check both dim 0 and dim 1 explicitly
- contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1;
- contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[1] == 1;
- }
- } else {
- contribop_parameters.broadcast_attn_bias_dim_0 = false;
- contribop_parameters.broadcast_attn_bias_dim_1 = false;
- }
-
- contribop_parameters.mask_filter_value = -10000.0f;
- contribop_parameters.scale = parameters.scale;
- contribop_parameters.use_tf32 = UseTF32();
-
- // QKV format: Determine based on input dimensions
- // 3D inputs (B, S, D): Q_K_V_BSNH - will be transposed by PrepareQkv to BNSH
- // transpose_output is true for 3D inputs, false for 4D inputs
- if (!parameters.transpose_output) {
- contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH;
- contribop_parameters.is_output_bnsh = true;
- } else {
- // 3D inputs in BSNH format (will be transposed)
- contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BSNH;
- contribop_parameters.is_output_bnsh = false;
- }
-
- // TODO(titaiwang, xadupre): Group query attention is not supported yet
- if (parameters.kv_num_heads != parameters.q_num_heads) {
- ORT_THROW("Group query attention is not supported yet in Attention op (CUDA).");
- }
-
- // TODO(titaiwang, xadupre): qk_matmul_output_mode only supports kNone and kQK for now
- if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone &&
- qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kQK) {
- ORT_THROW("qk_matmul_output_mode other than -1 (None) and 0 (QK) is not supported yet in Attention op (CUDA).");
- }
- // TODO(titaiwang, xadupre): softcap and softmax_precision are not used yet
- if (parameters.softcap != 0.0f) {
- ORT_THROW("softcap is not supported yet in Attention op (CUDA).");
- }
- if (parameters.softmax_precision != 0) {
- ORT_THROW("softmax_precision is not supported yet in Attention op (CUDA).");
- }
-
- // TODO(titaiwang): Continue on these parameters
- // Construct AttentionData to pass to QkvToContext
- typedef typename ToCudaType<T>::MappedType CudaT;
- onnxruntime::contrib::cuda::AttentionData<CudaT> data;
-
- // Set input pointers
- data.query = reinterpret_cast<const CudaT*>(Q->Data<T>());
- data.key = reinterpret_cast<const CudaT*>(K->Data<T>());
- data.value = reinterpret_cast<const CudaT*>(V->Data<T>());
- data.mask_index = nullptr; // New Attention op doesn't have key_padding_mask
- data.mask_index_dims = gsl::span<const int64_t>();
- data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
- data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
-
- // Set output pointers
- data.output = reinterpret_cast<CudaT*>(Y->MutableData<T>());
- data.present_key = (present_key == nullptr) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
- data.present_value = (present_value == nullptr) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
- if (nullptr != output_qk) {
- data.output_qk = reinterpret_cast<CudaT*>(output_qk->MutableData<T>());
- }
-
- // Set additional fields
- data.bias = nullptr; // New Attention op doesn't have bias
- if (nullptr != attn_mask) {
- data.attention_bias = reinterpret_cast<const CudaT*>(attn_mask->Data<T>());
- }
- data.qkv_format = contribop_parameters.qkv_format;
-
- // TODO: Determine which kernel to use (Flash Attention, Memory Efficient Attention, etc.)
- // For now, set flags to false and let QkvToContext use the unfused path
- data.use_flash_attention = false;
- data.use_memory_efficient_attention = false;
- data.fused_runner = nullptr;
- data.fused_cross_attention_kernel = nullptr;
-
- // Allocate workspace for Q, K, V processing and scratch buffer
- const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(contribop_parameters, data);
- size_t workspace_bytes = onnxruntime::contrib::cuda::GetAttentionWorkspaceSize(
- sizeof(T),
- contribop_parameters.batch_size,
- contribop_parameters.num_heads,
- contribop_parameters.head_size,
- contribop_parameters.v_head_size,
- contribop_parameters.sequence_length,
- contribop_parameters.kv_sequence_length,
- contribop_parameters.total_sequence_length,
- nullptr, // fused_runner
- false, // use_flash_attention
- false, // use_lean_attention
- false, // use_fused_cross_attention
- false, // use_memory_efficient_attention
- false, // use_cudnn_flash_attention
- no_qkv_workspace);
- auto work_space = GetScratchBuffer<void>(workspace_bytes, context->GetComputeStream());
-
- data.has_qkv_workspace = !no_qkv_workspace;
- data.workspace = reinterpret_cast<CudaT*>(work_space.get());
- data.workspace_bytes = workspace_bytes;
-
- // Call QkvToContext to perform the attention computation
- auto& device_prop = GetDeviceProp();
- cublasHandle_t cublas = GetCublasHandle(context);
- cudnnHandle_t cudnn = GetCudnnHandle(context);
-
- // QkvToContext takes two template parameters: T for computation type, QK for output_qk type
- // For now, both are the same type (CudaT)
- return onnxruntime::contrib::cuda::QkvToContext<CudaT, CudaT>(
- device_prop, cublas, cudnn, context->GetComputeStream(), contribop_parameters, data);
-}
-} // namespace cuda
-} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/llm/attention.h b/onnxruntime/core/providers/cuda/llm/attention.h
deleted file mode 100644
index 17e99bb935e1a..0000000000000
--- a/onnxruntime/core/providers/cuda/llm/attention.h
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-#include "core/common/common.h"
-#include "core/providers/cuda/cuda_kernel.h"
-
-namespace onnxruntime {
-namespace cuda {
-
-using namespace onnxruntime::cuda;
-
-template <typename T>
-class Attention final : public CudaKernel {
- public:
- Attention(const OpKernelInfo& info);
- Status ComputeInternal(OpKernelContext* context) const override;
-
- protected:
- bool is_causal_;
- int kv_num_heads_;
- int q_num_heads_;
- attention_helper::QKMatMulOutputMode qk_matmul_output_mode_;
- float scale_;
- float softcap_;
- int softmax_precision_;
-};
-
-} // namespace cuda
-} // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc
index 5c5c2efdb50cd..894130d6ee991 100644
--- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc
+++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc
@@ -363,7 +363,7 @@ TEST(AttentionTest, Attention3DDefault) {
q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(),
-1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
y, std::vector(), std::vector(), std::vector(),
- false, false, true // disable_cpu, disable_cuda, disable_dml
+ false, true, true // disable_cpu, disable_cuda, disable_dml
);
}
@@ -390,7 +390,7 @@ TEST(AttentionTest, Attention3DDefaultFloat16) {
q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(),
-1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
y, std::vector(), std::vector(), std::vector(),
- false, false, true // disable_cpu, disable_cuda, disable_dml
+ false, true, true // disable_cpu, disable_cuda, disable_dml
);
}
@@ -417,7 +417,7 @@ TEST(AttentionTest, Attention4DDefaultBasic) {
q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(),
-1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
y, std::vector(), std::vector(), std::vector(),
- false, false, true // disable_cpu, disable_cuda, disable_dml
+ false, true, true // disable_cpu, disable_cuda, disable_dml
);
}
@@ -444,7 +444,7 @@ TEST(AttentionTest, Attention4DDefault) {
q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(),
-1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
y, std::vector