diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index ebed2b1972ba9..41ec3dc348dfc 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -653,6 +653,7 @@ Do not modify directly.* |ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||12|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| +|Attention|*in* Q:**T1**
*in* K:**T1**
*in* V:**T2**
*in* attn_mask:**U**
*in* past_key:**T1**
*in* past_value:**T2**
*in* nonpad_kv_seqlen:**tensor(int64)**
*out* Y:**T1**
*out* present_key:**T1**
*out* present_value:**T2**
*out* qk_matmul_output:**T1**

or

*in* Q:**T1**
*in* K:**T1**
*in* V:**T2**
*in* attn_mask:**U**
*in* past_key:**T1**
*in* past_value:**T2**
*out* Y:**T1**
*out* present_key:**T1**
*out* present_value:**T2**
*out* qk_matmul_output:**T1**|23+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(bfloat16), tensor(float), tensor(float16)
**U** = tensor(bfloat16), tensor(bool), tensor(float), tensor(float16)| |AveragePool|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[19, 21]|**T** = tensor(double), tensor(float), tensor(float16)| |||[11, 18]|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h index f237b24b899a0..889f8124c02d8 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h @@ -34,6 +34,7 @@ struct AttentionParameters { float mask_filter_value; float scale; bool use_tf32; + bool is_output_bnsh = false; // whether the output format is BNSH AttentionMaskType mask_type; AttentionQkvFormat qkv_format; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index ee4be2c20362d..22862f1ba7b4c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -772,8 +772,9 @@ Status UnfusedAttention( DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length); - // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v - T* temp_output = data.q; + // compute R*V (as V*R), and store in output or temp workspace depending on whether transpose is needed + // For 4D input (BNSH), write directly to output. For 3D input (BSNH), write to temp then transpose. + T* temp_output = parameters.is_output_bnsh ? data.output : data.q; CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, v_head_size, sequence_length, total_sequence_length, @@ -781,11 +782,13 @@ Status UnfusedAttention( scratch2, total_sequence_length, sequence_length * total_sequence_length, &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32)); - // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v - Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, - device_prop.maxThreadsPerBlock, false, temp_output, data.output); + if (!parameters.is_output_bnsh) { + // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v + ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, + device_prop.maxThreadsPerBlock, false, temp_output, data.output)); + } DUMP_TENSOR_D("Attention Output", data.output, batch_size, sequence_length, num_heads, v_head_size); - return result; + return Status::OK(); } template @@ -960,7 +963,7 @@ Status QkvToContext( auto stream = static_cast(ort_stream->GetHandle()); const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; const int total_sequence_length = parameters.total_sequence_length; const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; @@ -981,12 +984,12 @@ Status QkvToContext( if (!parameters.past_present_share_buffer) { ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size, - sequence_length, total_sequence_length, + kv_sequence_length, total_sequence_length, stream, max_threads_per_block, data)); } else { // past_present_share_buffer ORT_RETURN_IF_ERROR(PastPresentBufferShare(batch_size, num_heads, qk_head_size, v_head_size, - sequence_length, fused_runner, + kv_sequence_length, 
fused_runner, parameters, data, stream, max_threads_per_block)); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index b2e8130ecd17e..7b37a3a4227b6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -251,7 +251,6 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, int max_threads_per_block) { - assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); assert(data.query != nullptr); assert(data.key != nullptr); assert(data.value != nullptr); @@ -262,82 +261,96 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters, assert(!parameters.is_unidirectional); assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_NoPast(data)); - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - - if (data.fused_cross_attention_kernel != nullptr) { - assert(qk_head_size == v_head_size); - assert(data.attention_bias == nullptr); - assert(data.mask_index == nullptr); - assert(parameters.hidden_size == parameters.v_hidden_size); - - // For fused cross attention, besides adding bias, K and V needed to be packed: - // Key (BxSxNxH), Value (BxSxNxH) => Q (BxSxNxH), K (BxSxNx2xH) - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length); - data.v = nullptr; - data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; - } else if (data.use_memory_efficient_attention || - data.use_flash_attention || - data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { - if (data.bias != nullptr) { - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, data.q, data.k, data.v); - } else { - data.q = const_cast(data.query); - data.k = const_cast(data.key); - data.v = const_cast(data.value); - } + if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) { + // 3D inputs in BSNH format (will be transposed) + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + if (data.fused_cross_attention_kernel != nullptr) { + assert(qk_head_size == v_head_size); + assert(data.attention_bias == nullptr); + assert(data.mask_index == nullptr); + assert(parameters.hidden_size == parameters.v_hidden_size); + + // For fused cross attention, besides adding bias, K and V needed to be packed: + // Key (BxSxNxH), Value (BxSxNxH) => Q (BxSxNxH), K (BxSxNx2xH) + LaunchAddBiasTransposeTrt( + stream, max_threads_per_block, + batch_size, sequence_length, + num_heads, qk_head_size, + data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length); + data.v = nullptr; + data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + } else if (data.use_memory_efficient_attention || + 
data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { + if (data.bias != nullptr) { + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.key, data.value, data.q, data.k, data.v); + } else { + data.q = const_cast(data.query); + data.k = const_cast(data.key); + data.v = const_cast(data.value); + } - data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else if (data.fused_runner != nullptr) { - assert(qk_head_size == v_head_size); - assert(data.attention_bias == nullptr); + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else if (data.fused_runner != nullptr) { + assert(qk_head_size == v_head_size); + assert(data.attention_bias == nullptr); - // Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H) - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, data.q, false, kv_sequence_length); - data.k = nullptr; - data.v = nullptr; + // Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H) + LaunchAddBiasTransposeTrt( + stream, max_threads_per_block, + batch_size, sequence_length, + num_heads, qk_head_size, + data.bias, data.query, data.key, data.value, data.q, false, kv_sequence_length); + data.k = nullptr; + data.v = nullptr; - data.qkv_format = AttentionQkvFormat::QKV_BSN3H; - } else { // unfused kernel + data.qkv_format = AttentionQkvFormat::QKV_BSN3H; + } else { // unfused kernel + assert(data.IsUnfused()); + // Query (BxSxNxH) => Q (BxNxSxH) + constexpr int format = 0; + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, data.q, + true, -1); + + // Key (BxLxNxH) => K (BxNxLxH) + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, data.k, + true, -1); + + // Value (BxLxNxH_v) => K (BxNxLxH_v) + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, data.v, + true, -1); + + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + } else if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { + // Currently, 4D inputs are only supported in unfused kernel for Attention-23. assert(data.IsUnfused()); - // Query (BxSxNxH) => Q (BxNxSxH) - constexpr int format = 0; - LaunchAddBiasTranspose( - stream, 1, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, data.q, - true, -1); - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose( - stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, data.k, - true, -1); - - // Value (BxLxNxH_v) => K (BxNxLxH_v) - LaunchAddBiasTranspose( - stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, nullptr == data.bias ? 
nullptr : data.bias + 2 * num_heads * qk_head_size, data.v, - true, -1); - - data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + // Attention-23 does not support bias with Q_K_V_BNSH format. + assert(data.bias == nullptr); + // No need to transpose since QKV is already in BNSH format. + data.q = const_cast(data.query); + data.k = const_cast(data.key); + data.v = const_cast(data.value); + } else { + ORT_THROW("Unsupported QKV format: ", parameters.qkv_format); } return Status::OK(); @@ -360,7 +373,6 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, int max_threads_per_block) { - assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); assert(data.query != nullptr); assert(data.key != nullptr); assert(data.value != nullptr); @@ -373,42 +385,53 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters, data.past_key != nullptr && data.past_value != nullptr); assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_WithPast_NoBias(data)); - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - // When there is no past state and there is present state, we output K and V directly to present state. if (data.past_key == nullptr && data.present_key != nullptr) { data.k = data.present_key; data.v = data.present_value; } + if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) { + // 3D inputs in BSNH format (will be transposed) + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + if (data.use_memory_efficient_attention || + data.use_flash_attention || + data.use_lean_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { + // Use oiginal Query (BSNH) since there is no bias. + data.q = const_cast(data.query); - if (data.use_memory_efficient_attention || - data.use_flash_attention || - data.use_lean_attention || - data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { - // Use oiginal Query (BSNH) since there is no bias. 
- data.q = const_cast(data.query); - - // Key (BxLxNxH) => K (BxNxLxH) - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, data.k)); - // Value (BxLxNxH) => V (BxNxLxH) - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, data.v)); - data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; - } else { // unfused kernel + // Key (BxLxNxH) => K (BxNxLxH) + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.key, data.k)); + // Value (BxLxNxH) => V (BxNxLxH) + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.value, data.v)); + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; + } else { // unfused kernel + assert(data.IsUnfused()); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, data.q)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.key, data.k)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.value, data.v)); + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + } else if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { + // Currently, 4D inputs are only supported in unfused kernel for Attention-23. assert(data.IsUnfused()); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, data.q)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, data.k)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, data.v)); - data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + // No need to transpose since QKV is already in BNSH format. 
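+      // The unfused kernel reads Q, K and V in BNSH directly, so the input tensors are reused here without a copy.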
+ data.q = const_cast(data.query); + data.k = const_cast(data.key); + data.v = const_cast(data.value); + } else { + ORT_THROW("Unsupported QKV format: ", parameters.qkv_format); } return Status::OK(); @@ -670,14 +693,27 @@ Status PrepareQkv_MultiHeadAttention(contrib::AttentionParameters& parameters, case AttentionQkvFormat::Q_K_V_BSNH: if (data.past_key != nullptr || data.present_key != nullptr) { if (data.bias == nullptr) { - DUMP_STRING("PrepareQkv_MHA_WithPast_NoBias"); + DUMP_STRING("PrepareQkv(3D)_MHA_WithPast_NoBias"); ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_NoBias(parameters, data, stream, max_threads_per_block)); } else { - DUMP_STRING("PrepareQkv_MHA_WithPast_Bias"); + DUMP_STRING("PrepareQkv(3D)_MHA_WithPast_Bias"); ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_Bias(parameters, data, stream, max_threads_per_block)); } } else { // no past state - DUMP_STRING("PrepareQkv_MHA_NoPast"); + DUMP_STRING("PrepareQkv(3D)_MHA_NoPast"); + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NoPast(parameters, data, stream, max_threads_per_block)); + } + break; + case AttentionQkvFormat::Q_K_V_BNSH: + if (data.past_key != nullptr || data.present_key != nullptr) { + if (data.bias == nullptr) { + DUMP_STRING("PrepareQkv(4D)_MHA_WithPast_NoBias"); + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_NoBias(parameters, data, stream, max_threads_per_block)); + } else { + ORT_THROW("Q_K_V_BNSH format with bias is not supported."); + } + } else { // no past state + DUMP_STRING("PrepareQkv(4D)_MHA_NoPast"); ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NoPast(parameters, data, stream, max_threads_per_block)); } break; @@ -708,6 +744,8 @@ bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData& } else { // no past state return NoQkvWorkspace_MHA_NoPast(data); } + case AttentionQkvFormat::Q_K_V_BNSH: + return false; // currently no scenario needs no workspace default: ORT_THROW("Unsupported QKV format: ", parameters.qkv_format); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu index 938033644a7d6..be3ae44f7a206 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu @@ -165,7 +165,7 @@ __device__ inline void SoftmaxSmall(const int total_sequence_length, // Update end position for causal. int end = valid_end; if (causal) { - const int end_causal = total_sequence_length - sequence_length + s + 1; + const int end_causal = (total_sequence_length - sequence_length) + s + 1; if (end_causal < end) { end = end_causal; } @@ -241,7 +241,7 @@ __global__ void SoftmaxLargeKernel(const int total_sequence_length, // Update end position for causal. int end = valid_end; if (causal) { - int end_causal = total_sequence_length - sequence_length + s + 1; + const int end_causal = (total_sequence_length - sequence_length) + s + 1; if (end_causal < end) { end = end_causal; } @@ -333,7 +333,7 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int total_sequence_length, : float(input[index]); float thread_data = input_data * rsqrt_head_size; if (causal) { - int from_index = total_sequence_length - sequence_length + s; // offset in total sequence length. + int from_index = (total_sequence_length - sequence_length) + s; // offset in total sequence length. 
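+    // Positions with i > from_index are future tokens for query row s and are forced to -inf below.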
if (i > from_index) { thread_data = -CUDART_INF_F; } @@ -439,7 +439,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int total_sequence_length, thread_data = float(input[index]) * rsqrt_head_size; if (causal) { - int from_index = total_sequence_length - sequence_length + s; // offset in total sequence length. + int from_index = (total_sequence_length - sequence_length) + s; // offset in total sequence length. if (threadIdx.x > from_index) { thread_data = -CUDART_INF_F; } diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc index afd5d55664160..164bea191bbef 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.cc +++ b/onnxruntime/core/providers/cpu/llm/attention.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/cpu/llm/attention.h" +#include "core/providers/cpu/llm/attention_helper.h" #include "core/common/common.h" #include "core/common/safeint.h" @@ -115,13 +116,13 @@ Attention::Attention(const OpKernelInfo& info) : AttentionBase(info) { q_num_heads_ = static_cast(info.GetAttrOrDefault("q_num_heads", 0)); int mode = static_cast(info.GetAttrOrDefault("qk_matmul_output_mode", 0)); qk_matmul_output_mode_ = info.node().OutputDefs().size() >= 4 && info.node().OutputDefs()[3]->Exists() - ? static_cast(mode) - : QKMatMulOutputMode::kNone; - ORT_ENFORCE(qk_matmul_output_mode_ == QKMatMulOutputMode::kNone || - qk_matmul_output_mode_ == QKMatMulOutputMode::kQK || - qk_matmul_output_mode_ == QKMatMulOutputMode::kQKMask || - qk_matmul_output_mode_ == QKMatMulOutputMode::kQKSoftCap || - qk_matmul_output_mode_ == QKMatMulOutputMode::kQKSoftMax, + ? static_cast(mode) + : attention_helper::QKMatMulOutputMode::kNone; + ORT_ENFORCE(qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kNone || + qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQK || + qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKMask || + qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKSoftCap || + qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKSoftMax, "qk_matmul_output_mode must be 0, 1, 2, or 3."); // The default scale depends on the input dimensions. It is set to nan to indicate that it should be computed. scale_ = info.GetAttrOrDefault("scale", std::numeric_limits::quiet_NaN()); @@ -140,11 +141,12 @@ Status Attention::Compute(OpKernelContext* context) const { const Tensor* past_value = context->Input(5); AttentionParameters parameters; - std::vector y_shape; - std::vector present_key_shape; - std::vector present_value_shape; - std::vector output_qk_shape; + TensorShape y_shape; + TensorShape present_key_shape; + TensorShape present_value_shape; + TensorShape output_qk_shape; + // ComputeOutputShapeForAttention also checks the validity of the inputs. ORT_ENFORCE(attention_helper::ComputeOutputShapeForAttention( Q, K, @@ -243,51 +245,49 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs, bool delete_mask_data = false; bool causal = parameters.is_causal && parameters.q_sequence_length > 1; if (mask_index == nullptr) { - // No mask = null mask. + // No external mask: allocate only if causal behavior needed. 
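+    // For causal attention, query row s may attend only to positions t <= past_sequence_length + s; later positions get mask_filter_value.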
if (causal) { - size_t mask_data_bytes = SafeInt(parameters.q_sequence_length) * parameters.total_sequence_length * sizeof(T); - void* allocated_ptr = allocator->Alloc(mask_data_bytes); - memset(allocated_ptr, 0, mask_data_bytes); - mask_data = static_cast(allocated_ptr); - for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) { - for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) { - mask_data[s_i * parameters.total_sequence_length + m_i] = mask_filter_value(); + size_t mask_bytes = SafeInt(parameters.q_sequence_length) * parameters.total_sequence_length * sizeof(T); + void* raw = allocator->Alloc(mask_bytes); + memset(raw, 0, mask_bytes); // start all allowed + mask_data = static_cast(raw); + for (int s = 0; s < parameters.q_sequence_length; ++s) { + for (int t = parameters.past_sequence_length + s + 1; t < parameters.total_sequence_length; ++t) { + mask_data[s * parameters.total_sequence_length + t] = mask_filter_value(); } } delete_mask_data = true; } - } else if (mask_index->IsDataType() || causal) { - // We need a copy. - size_t mask_data_bytes = SafeInt(mask_index->Shape().Size()) * sizeof(T); - mask_data = static_cast(allocator->Alloc(mask_data_bytes)); - delete_mask_data = true; - - if (mask_index->IsDataType()) { - // Convert bool mask to 0/1 - make_copy(mask_data, mask_index->Data(), SafeInt(mask_index->Shape().Size())); - } else if (mask_index != nullptr) { - // We make a copy because causal is True. - make_copy(mask_data, mask_index->Data(), SafeInt(mask_index->Shape().Size())); - } - if (causal) { - // This loop could be parallelized. - // According to the specifications, this configuration is not supported - // as is_causal=1 or mask is not None (exclusive or). - int n_iter = mask_batch_size * mask_num_heads; - for (int i = 0; i < n_iter; ++i) { - for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) { - for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) { - mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = mask_filter_value(); + } else { + const bool is_bool_mask = mask_index->IsDataType(); + const bool need_copy = is_bool_mask || causal; // copy if we must convert or overlay causal pattern + if (need_copy) { + size_t mask_bytes = SafeInt(mask_index->Shape().Size()) * sizeof(T); + mask_data = static_cast(allocator->Alloc(mask_bytes)); + delete_mask_data = true; + if (is_bool_mask) { + make_copy(mask_data, mask_index->Data(), SafeInt(mask_index->Shape().Size())); + } else { + make_copy(mask_data, mask_index->Data(), SafeInt(mask_index->Shape().Size())); + } + if (causal) { + // Overlay causal -inf above diagonal for every broadcast slice + int slices = mask_batch_size * mask_num_heads; + for (int slice = 0; slice < slices; ++slice) { + T* base = mask_data + probs_matrix_size * slice; + for (int s = 0; s < parameters.q_sequence_length; ++s) { + for (int t = parameters.past_sequence_length + s + 1; t < parameters.total_sequence_length; ++t) { + base[s * parameters.total_sequence_length + t] = mask_filter_value(); + } } } } + } else { + // Reuse mask memory directly (numeric, non-causal) + mask_data = const_cast(mask_index->Data()); } - } else { - // Nothing to do, no necessary copy. 
- mask_data = const_cast(mask_index->Data()); } - bool transposed_k = parameters.transpose_output && nullptr == present_key; if (nullptr != present_key && parameters.kv_num_heads != parameters.q_num_heads) { // This is not part of the main loop because it is not needed at every iteration and // we cannot ensure the inner body is executed first before getting used in another iteration. @@ -309,6 +309,7 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs, // If past_key is not null, then we need to concatenate it with K, the concatenation is not transposed. const int loop_len = parameters.batch_size * parameters.q_num_heads; const float alpha = parameters.scale; + bool transposed_k = parameters.transpose_output && nullptr == present_key; ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { diff --git a/onnxruntime/core/providers/cpu/llm/attention.h b/onnxruntime/core/providers/cpu/llm/attention.h index c8eff7c5006d6..867207e385dc9 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.h +++ b/onnxruntime/core/providers/cpu/llm/attention.h @@ -5,7 +5,7 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" #include "core/platform/threadpool.h" -#include "core/providers/cpu/llm/attention_helper.h" +#include "core/providers/cpu/llm/attention_parameters.h" namespace onnxruntime { @@ -100,4 +100,4 @@ class Attention final : public AttentionBase { int softmax_precision_; }; -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/llm/attention_helper.cc b/onnxruntime/core/providers/cpu/llm/attention_helper.cc deleted file mode 100644 index 9bd954f128454..0000000000000 --- a/onnxruntime/core/providers/cpu/llm/attention_helper.cc +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/providers/cpu/llm/attention_helper.h" -#include "core/util/shape_checker.h" - -namespace onnxruntime { -namespace attention_helper { - -void AttentionParameters::checkParameters() const { - ORT_ENFORCE(batch_size > 0, "Batch size must be greater than 0"); - ORT_ENFORCE(q_sequence_length > 0, "Q sequence length must be greater than 0"); - ORT_ENFORCE(kv_sequence_length > 0, "KV sequence length must be greater than 0"); - ORT_ENFORCE(head_size > 0, "Head size must be greater than 0"); - ORT_ENFORCE(v_head_size > 0, "V head size must be greater than 0"); - ORT_ENFORCE(past_sequence_length >= 0, "Past sequence length must be non-negative"); - ORT_ENFORCE(total_sequence_length > 0, "Total sequence length must be greater than 0"); - ORT_ENFORCE(kv_num_heads > 0, "KV number of heads must be greater than 0"); - ORT_ENFORCE(q_num_heads > 0, "Q number of heads must be greater than 0"); - ORT_ENFORCE(total_sequence_length == past_sequence_length + kv_sequence_length, - "Total sequence length must be equal to past sequence length plus KV sequence length"); -} - -Status ComputeOutputShapeForAttention( - const Tensor* Q, - const Tensor* K, - const Tensor* V, - const Tensor* attn_mask, - const Tensor* past_key, - const Tensor* past_value, - bool is_causal, - float softcap, - int softmax_precision, - attention_helper::QKMatMulOutputMode qk_matmul_output_mode, - int kv_num_heads, - int q_num_heads, - float scale, - AttentionParameters& parameters, - std::vector& y_shape, - std::vector& present_key_shape, - std::vector& present_value_shape, - std::vector& output_qk_shape) { - ORT_ENFORCE(Q != nullptr && K != nullptr && V != nullptr, - "Q, K, and V inputs must not be null"); - int q_dims = onnxruntime::narrow(Q->Shape().NumDimensions()); - int k_dims = onnxruntime::narrow(K->Shape().NumDimensions()); - int v_dims = onnxruntime::narrow(V->Shape().NumDimensions()); - ORT_ENFORCE(q_dims == 3 || q_dims == 4, "Q must be a 3D or 4D tensor"); - ORT_ENFORCE(q_dims == k_dims, "Q and K must have the same rank."); - ORT_ENFORCE(q_dims == v_dims, "Q and V must have the same rank."); - - ORT_ENFORCE((past_key == nullptr) == (past_value == nullptr), "past_key and past_value must be both null or both not null"); - ORT_ENFORCE(Q->Shape()[0] == K->Shape()[0], "inconsistent batch_size (between Q and K)"); - ORT_ENFORCE(Q->Shape()[0] == V->Shape()[0], "inconsistent batch_size (between Q and V)"); - ORT_ENFORCE(past_key == nullptr || Q->Shape()[0] == past_key->Shape()[0], "inconsistent batch_size (between Q and past_key)"); - ORT_ENFORCE(past_value == nullptr || Q->Shape()[0] == past_value->Shape()[0], "inconsistent batch_size (between Q and past_value)"); - ORT_ENFORCE(past_value == nullptr || past_value->Shape()[2] == past_key->Shape()[2], "inconsistent past_sequence_length (between past_key and past_value)"); - - parameters.is_causal = is_causal; - parameters.softcap = softcap; - parameters.softmax_precision = softmax_precision; - parameters.qk_matmul_output_mode = qk_matmul_output_mode; // output mode for Q*K matmul - parameters.batch_size = onnxruntime::narrow(Q->Shape()[0]); // Q.shape[0], K.shape[0], V.shape[0] (4D) - - ORT_ENFORCE(parameters.batch_size > 0, "Batch size must be greater than 0"); - ORT_ENFORCE(attn_mask == nullptr || (attn_mask->Shape().NumDimensions() >= 2 && attn_mask->Shape().NumDimensions() <= 4), "attn_mask must be 2D or 3D or 4D tensor"); - - if (q_dims == 4) { - // 4D - parameters.kv_num_heads = kv_num_heads > 0 ? 
kv_num_heads : onnxruntime::narrow(K->Shape()[1]); // K.shape[1] or V.shape[1] (4D) - parameters.q_num_heads = q_num_heads > 0 ? q_num_heads : onnxruntime::narrow(Q->Shape()[1]); // Q.shape[1] (4D) - - ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow(K->Shape()[1]), "kv_num_heads different from K.shape[1]"); - ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow(V->Shape()[1]), "kv_num_heads different from V.shape[1]"); - ORT_ENFORCE(parameters.q_num_heads == onnxruntime::narrow(Q->Shape()[1]), "q_num_heads different from Q.shape[1]"); - ORT_ENFORCE(Q->Shape()[3] == K->Shape()[3], "inconsistent head_size"); - ORT_ENFORCE(K->Shape()[2] == V->Shape()[2], "inconsistent kv_sequence_length"); - ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 2] == Q->Shape()[2], "inconsistent q_sequence_length (between attn_mask and Q)"); - - // From shapes - parameters.transpose_output = false; // whether to transpose the input/output with permutation (0, 2, 1, 3) - parameters.q_sequence_length = onnxruntime::narrow(Q->Shape()[2]); // Q.shape[2] (4D) - parameters.head_size = onnxruntime::narrow(Q->Shape()[3]); // Q.shape[3] (4D) - parameters.kv_sequence_length = onnxruntime::narrow(K->Shape()[2]); // K.shape[2] or V.shape[2] (4D) - parameters.v_head_size = onnxruntime::narrow(V->Shape()[3]); // V.shape[3] (4D) - parameters.past_sequence_length = past_key == nullptr // past_key.shape[2] or past_value.shape[2] (4D) or given by the mask - ? 0 - : onnxruntime::narrow(past_key->Shape()[2]); - - y_shape = {static_cast(parameters.batch_size), - static_cast(parameters.q_num_heads), - static_cast(parameters.q_sequence_length), - static_cast(parameters.v_head_size)}; - } else { - // 3D - parameters.kv_num_heads = kv_num_heads; - parameters.q_num_heads = q_num_heads; - - // From shapes - ORT_ENFORCE(Q->Shape()[2] % parameters.q_num_heads == 0, "inconsistent q_hidden_size, it should be a multiple of q_num_heads"); - ORT_ENFORCE(V->Shape()[2] % parameters.kv_num_heads == 0, "inconsistent v_hidden_size, it should be a multiple of kv_num_heads"); - - parameters.transpose_output = true; // whether to transpose the input/output with permutation (0, 2, 1, 3) - parameters.q_sequence_length = onnxruntime::narrow(Q->Shape()[1]); - parameters.head_size = onnxruntime::narrow(Q->Shape()[2]) / parameters.q_num_heads; - parameters.kv_sequence_length = onnxruntime::narrow(K->Shape()[1]); - parameters.v_head_size = onnxruntime::narrow(V->Shape()[2]) / parameters.kv_num_heads; - parameters.past_sequence_length = past_key == nullptr - ? 
0 - : onnxruntime::narrow(past_key->Shape()[2]); - - y_shape = {static_cast(parameters.batch_size), - static_cast(parameters.q_sequence_length), - static_cast(parameters.q_num_heads * parameters.v_head_size)}; - } - parameters.total_sequence_length = parameters.past_sequence_length + parameters.kv_sequence_length; - - ORT_ENFORCE(parameters.q_num_heads % parameters.kv_num_heads == 0, "q_num_heads % kv_num_heads == 0 is not verified"); - ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 1] == parameters.total_sequence_length, - "inconsistent total_sequence_length (between attn_mask and past_key and past_value)"); - ORT_ENFORCE(attn_mask == nullptr || - attn_mask->Shape().NumDimensions() < 3 || - attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == 1 || - attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == parameters.kv_num_heads, - "attn_mask must be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) but is not compatible with kv_num_heads"); - ORT_ENFORCE(attn_mask == nullptr || - attn_mask->Shape().NumDimensions() < 4 || - attn_mask->Shape()[0] == 1 || - attn_mask->Shape()[0] == parameters.batch_size, - "attn_mask must be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) but is not compatible with batch_size"); - ASSERT_TENSOR_DIMS(past_key, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.head_size); - ASSERT_TENSOR_DIMS(past_value, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.v_head_size); - - parameters.scale = std::isnan(scale) ? static_cast(1.0 / sqrt(parameters.head_size)) : scale; - parameters.checkParameters(); - - present_key_shape = {static_cast(parameters.batch_size), - static_cast(parameters.kv_num_heads), - static_cast(parameters.total_sequence_length), - static_cast(parameters.head_size)}; - present_value_shape = {static_cast(parameters.batch_size), - static_cast(parameters.kv_num_heads), - static_cast(parameters.total_sequence_length), - static_cast(parameters.v_head_size)}; - if (qk_matmul_output_mode == QKMatMulOutputMode::kNone) { - output_qk_shape.clear(); - } else { - output_qk_shape = {static_cast(parameters.batch_size), - static_cast(parameters.q_num_heads), - static_cast(parameters.q_sequence_length), - static_cast(parameters.total_sequence_length)}; - } - return Status::OK(); -} -} // namespace attention_helper -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/llm/attention_helper.h b/onnxruntime/core/providers/cpu/llm/attention_helper.h index 1cea27760408f..4b1e22df4b2f6 100644 --- a/onnxruntime/core/providers/cpu/llm/attention_helper.h +++ b/onnxruntime/core/providers/cpu/llm/attention_helper.h @@ -2,54 +2,13 @@ // Licensed under the MIT License. #pragma once -#include "core/common/common.h" -#include "core/providers/common.h" +#include "core/providers/cpu/llm/attention_parameters.h" +#include "core/util/shape_checker.h" namespace onnxruntime { namespace attention_helper { -// enum equivalent to the onnx defintion of qk_matmul_output_mode -enum QKMatMulOutputMode { - kNone = -1, // No output Q*K - kQK = 0, // Output Q*K - kQKMask = 1, // Output Q*K + Mask - kQKSoftCap = 2, // Output SoftCap(Q*K + Mask) - kQKSoftMax = 3, // Output SoftMax(SoftCap(Q*K + Mask)) -}; - -// Parameters deduced from node attributes and inputs/outputs. 
-struct AttentionParameters { - /* - * Attention Parameters - * MHA: q_num_heads == kv_num_heads -> MHA - * GQA: q_num_heads > kv_num_heads && q_num_heads % kv_num_heads == 0 - * MQA: q_num_heads > kv_num_heads && kv_num_heads == 1 - */ - bool is_causal; - int kv_num_heads; // K.shape[1] or V.shape[1] (4D) - int q_num_heads; // Q.shape[1] (4D) - float scale; - float softcap; - int softmax_precision; - QKMatMulOutputMode qk_matmul_output_mode; - - // From shapes - int batch_size; // Q.shape[0], K.shape[0], V.shape[0] (4D) - int q_sequence_length; // Q.shape[2] (4D) - int head_size; // Q.shape[3] or K.shape[3 (4D) - int kv_sequence_length; // K.shape[2] or V.shape[2] (4D) - int v_head_size; // V.shape[4] (4D) - int past_sequence_length; // pask_key.shape[2] or past_value.shape[2] (4D) - int total_sequence_length; // past_sequence_length + kv_sequence_length - bool transpose_output; // Whether to transpose the inputs and the outputs from BxNxSxH to BxSxNxH - // This covers the case where the inputs are 3D. - - // Checks the consistency of the parameters. - void checkParameters() const; -}; - -// Computes the output shape for attention based on the input tensors and parameters. -Status ComputeOutputShapeForAttention( +inline Status ComputeOutputShapeForAttention( const Tensor* Q, const Tensor* K, const Tensor* V, @@ -59,15 +18,122 @@ Status ComputeOutputShapeForAttention( bool is_causal, float softcap, int softmax_precision, - attention_helper::QKMatMulOutputMode qk_matmul_output_mode, + QKMatMulOutputMode qk_matmul_output_mode, int kv_num_heads, int q_num_heads, float scale, AttentionParameters& parameters, - std::vector& y_shape, - std::vector& present_key_shape, - std::vector& present_value_shape, - std::vector& output_qk_shape); + TensorShape& y_shape, + TensorShape& present_key_shape, + TensorShape& present_value_shape, + TensorShape& output_qk_shape) { + ORT_ENFORCE(Q != nullptr && K != nullptr && V != nullptr, + "Q, K, and V inputs must not be null"); + int q_dims = onnxruntime::narrow(Q->Shape().NumDimensions()); + int k_dims = onnxruntime::narrow(K->Shape().NumDimensions()); + int v_dims = onnxruntime::narrow(V->Shape().NumDimensions()); + ORT_ENFORCE(q_dims == 3 || q_dims == 4, "Q must be a 3D or 4D tensor"); + ORT_ENFORCE(q_dims == k_dims, "Q and K must have the same rank."); + ORT_ENFORCE(q_dims == v_dims, "Q and V must have the same rank."); + ORT_ENFORCE((past_key == nullptr) == (past_value == nullptr), "past_key and past_value must be both null or both not null"); + ORT_ENFORCE(Q->Shape()[0] == K->Shape()[0], "inconsistent batch_size (between Q and K)"); + ORT_ENFORCE(Q->Shape()[0] == V->Shape()[0], "inconsistent batch_size (between Q and V)"); + ORT_ENFORCE(past_key == nullptr || Q->Shape()[0] == past_key->Shape()[0], "inconsistent batch_size (between Q and past_key)"); + ORT_ENFORCE(past_value == nullptr || Q->Shape()[0] == past_value->Shape()[0], "inconsistent batch_size (between Q and past_value)"); + ORT_ENFORCE(past_value == nullptr || past_value->Shape()[2] == past_key->Shape()[2], "inconsistent past_sequence_length (between past_key and past_value)"); + + parameters.is_causal = is_causal; + parameters.softcap = softcap; + parameters.softmax_precision = softmax_precision; + parameters.qk_matmul_output_mode = qk_matmul_output_mode; // output mode for Q*K matmul + parameters.batch_size = onnxruntime::narrow(Q->Shape()[0]); // Q.shape[0], K.shape[0], V.shape[0] (4D) + + ORT_ENFORCE(parameters.batch_size > 0, "Batch size must be greater than 0"); + ORT_ENFORCE(attn_mask == 
nullptr || (attn_mask->Shape().NumDimensions() >= 2 && attn_mask->Shape().NumDimensions() <= 4), "attn_mask must be 2D or 3D or 4D tensor"); + + if (q_dims == 4) { + // 4D + parameters.kv_num_heads = kv_num_heads > 0 ? kv_num_heads : onnxruntime::narrow(K->Shape()[1]); // K.shape[1] or V.shape[1] (4D) + parameters.q_num_heads = q_num_heads > 0 ? q_num_heads : onnxruntime::narrow(Q->Shape()[1]); // Q.shape[1] (4D) + + ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow(K->Shape()[1]), "kv_num_heads different from K.shape[1]"); + ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow(V->Shape()[1]), "kv_num_heads different from V.shape[1]"); + ORT_ENFORCE(parameters.q_num_heads == onnxruntime::narrow(Q->Shape()[1]), "q_num_heads different from Q.shape[1]"); + ORT_ENFORCE(Q->Shape()[3] == K->Shape()[3], "inconsistent head_size"); + ORT_ENFORCE(K->Shape()[2] == V->Shape()[2], "inconsistent kv_sequence_length"); + ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 2] == Q->Shape()[2], "inconsistent q_sequence_length (between attn_mask and Q)"); + + // From shapes + parameters.transpose_output = false; // whether to transpose the input/output with permutation (0, 2, 1, 3) + parameters.q_sequence_length = onnxruntime::narrow(Q->Shape()[2]); // Q.shape[2] (4D) + parameters.head_size = onnxruntime::narrow(Q->Shape()[3]); // Q.shape[3] (4D) + parameters.kv_sequence_length = onnxruntime::narrow(K->Shape()[2]); // K.shape[2] or V.shape[2] (4D) + parameters.v_head_size = onnxruntime::narrow(V->Shape()[3]); // V.shape[3] (4D) + parameters.past_sequence_length = past_key == nullptr // past_key.shape[2] or past_value.shape[2] (4D) or given by the mask + ? 0 + : onnxruntime::narrow(past_key->Shape()[2]); + + y_shape = {static_cast(parameters.batch_size), + static_cast(parameters.q_num_heads), + static_cast(parameters.q_sequence_length), + static_cast(parameters.v_head_size)}; + } else { + // 3D + parameters.kv_num_heads = kv_num_heads; + parameters.q_num_heads = q_num_heads; + + // From shapes + ORT_ENFORCE(Q->Shape()[2] % parameters.q_num_heads == 0, "inconsistent q_hidden_size, it should be a multiple of q_num_heads"); + ORT_ENFORCE(V->Shape()[2] % parameters.kv_num_heads == 0, "inconsistent v_hidden_size, it should be a multiple of kv_num_heads"); + + parameters.transpose_output = true; // whether to transpose the input/output with permutation (0, 2, 1, 3) + parameters.q_sequence_length = onnxruntime::narrow(Q->Shape()[1]); + parameters.head_size = onnxruntime::narrow(Q->Shape()[2]) / parameters.q_num_heads; + parameters.kv_sequence_length = onnxruntime::narrow(K->Shape()[1]); + parameters.v_head_size = onnxruntime::narrow(V->Shape()[2]) / parameters.kv_num_heads; + parameters.past_sequence_length = past_key == nullptr + ? 0 + : onnxruntime::narrow(past_key->Shape()[2]); + + y_shape = {static_cast(parameters.batch_size), + static_cast(parameters.q_sequence_length), + static_cast(parameters.q_num_heads * parameters.v_head_size)}; + } + parameters.total_sequence_length = parameters.past_sequence_length + parameters.kv_sequence_length; + + ORT_ENFORCE(parameters.q_num_heads % parameters.kv_num_heads == 0, "q_num_heads must be a multiple of kv_num_heads. 
This is required for grouped/multi-query and multi-headed attention."); + ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 1] == parameters.total_sequence_length, + "inconsistent total_sequence_length (between attn_mask and past_key and past_value)"); + ORT_ENFORCE(attn_mask == nullptr || + attn_mask->Shape().NumDimensions() < 3 || + attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == 1 || + attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == parameters.q_num_heads, + "attn_mask must be broadcastable to (batch_size, q_num_heads, q_sequence_length, total_sequence_length) but is not compatible with q_num_heads"); + ORT_ENFORCE(attn_mask == nullptr || + attn_mask->Shape().NumDimensions() < 4 || + attn_mask->Shape()[0] == 1 || + attn_mask->Shape()[0] == parameters.batch_size, + "attn_mask must be broadcastable to (batch_size, q_num_heads, q_sequence_length, total_sequence_length) but is not compatible with batch_size"); + ASSERT_TENSOR_DIMS(past_key, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.head_size); + ASSERT_TENSOR_DIMS(past_value, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.v_head_size); + + parameters.scale = std::isnan(scale) ? static_cast(1.0 / sqrt(parameters.head_size)) : scale; + parameters.checkParameters(); + + present_key_shape = {static_cast(parameters.batch_size), + static_cast(parameters.kv_num_heads), + static_cast(parameters.total_sequence_length), + static_cast(parameters.head_size)}; + present_value_shape = {static_cast(parameters.batch_size), + static_cast(parameters.kv_num_heads), + static_cast(parameters.total_sequence_length), + static_cast(parameters.v_head_size)}; + output_qk_shape = {static_cast(parameters.batch_size), + static_cast(parameters.q_num_heads), + static_cast(parameters.q_sequence_length), + static_cast(parameters.total_sequence_length)}; + return Status::OK(); +} } // namespace attention_helper } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/llm/attention_parameters.h b/onnxruntime/core/providers/cpu/llm/attention_parameters.h new file mode 100644 index 0000000000000..b8586ca4d63dc --- /dev/null +++ b/onnxruntime/core/providers/cpu/llm/attention_parameters.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/common.h" + +namespace onnxruntime { +// Declares enum QKMatMulOutputMode and struct AttentionParameters inside namespace onnxruntime::attention_helper. +namespace attention_helper { + +// enum equivalent to the onnx defintion of qk_matmul_output_mode +enum QKMatMulOutputMode { + kNone = -1, // No output Q*K + kQK = 0, // Output Q*K + kQKMask = 1, // Output Q*K + Mask + kQKSoftCap = 2, // Output SoftCap(Q*K + Mask) + kQKSoftMax = 3, // Output SoftMax(SoftCap(Q*K + Mask)) +}; + +// Parameters deduced from node attributes and inputs/outputs. 
+struct AttentionParameters { + /* + * Attention Parameters + * MHA: q_num_heads == kv_num_heads -> MHA + * GQA: q_num_heads > kv_num_heads && q_num_heads % kv_num_heads == 0 + * MQA: q_num_heads > kv_num_heads && kv_num_heads == 1 + */ + bool is_causal; + int kv_num_heads; // K.shape[1] or V.shape[1] (4D) + int q_num_heads; // Q.shape[1] (4D) + float scale; + float softcap; + int softmax_precision; + QKMatMulOutputMode qk_matmul_output_mode; + + // From shapes + int batch_size; // Q.shape[0], K.shape[0], V.shape[0] (4D) + int q_sequence_length; // Q.shape[2] (4D) + int head_size; // Q.shape[3] or K.shape[3 (4D) + int kv_sequence_length; // K.shape[2] or V.shape[2] (4D) + int v_head_size; // V.shape[4] (4D) + int past_sequence_length; // pask_key.shape[2] or past_value.shape[2] (4D) + int total_sequence_length; // past_sequence_length + kv_sequence_length + bool transpose_output; // Whether to transpose the inputs and the outputs from BxNxSxH to BxSxNxH + // This covers the case where the inputs are 3D. + + // Checks the consistency of the parameters. + void checkParameters() const { + ORT_ENFORCE(batch_size > 0, "Batch size must be greater than 0"); + ORT_ENFORCE(q_sequence_length > 0, "Q sequence length must be greater than 0"); + ORT_ENFORCE(kv_sequence_length > 0, "KV sequence length must be greater than 0"); + ORT_ENFORCE(head_size > 0, "Head size must be greater than 0"); + ORT_ENFORCE(v_head_size > 0, "V head size must be greater than 0"); + ORT_ENFORCE(past_sequence_length >= 0, "Past sequence length must be non-negative"); + ORT_ENFORCE(total_sequence_length > 0, "Total sequence length must be greater than 0"); + ORT_ENFORCE(kv_num_heads > 0, "KV number of heads must be greater than 0"); + ORT_ENFORCE(q_num_heads > 0, "Q number of heads must be greater than 0"); + ORT_ENFORCE(total_sequence_length == past_sequence_length + kv_sequence_length, + "Total sequence length must be equal to past sequence length plus KV sequence length"); + } +}; + +} // namespace attention_helper +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index eab616388d6ae..eb29e4edbf897 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1585,6 +1585,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish); // Opset 23. 
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Attention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Attention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_MLFloat16, RMSNormalization); @@ -2653,6 +2656,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 23 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc new file mode 100644 index 0000000000000..99f297bba6444 --- /dev/null +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -0,0 +1,265 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cpu/llm/attention_helper.h" +#include "core/providers/cuda/llm/attention.h" +#include "contrib_ops/cuda/bert/attention_data.h" +#include "contrib_ops/cuda/bert/attention_impl.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Attention, \ + kOnnxDomain, \ + 23, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("U", BuildKernelDefConstraints()), \ + Attention); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) + +template +Attention::Attention(const OpKernelInfo& info) : CudaKernel(info) { + is_causal_ = static_cast(info.GetAttrOrDefault("is_causal", 0)) == 1; + // kv_num_heads, q_num_head are mandatory for 3D inputs but not used for 4D inputs. + // The dimension is not yet known. If not specified, the inputs is assumed to be 4D. + kv_num_heads_ = static_cast(info.GetAttrOrDefault("kv_num_heads", 0)); + q_num_heads_ = static_cast(info.GetAttrOrDefault("q_num_heads", 0)); + int mode = static_cast(info.GetAttrOrDefault("qk_matmul_output_mode", 0)); + qk_matmul_output_mode_ = info.node().OutputDefs().size() >= 4 && info.node().OutputDefs()[3]->Exists() + ? static_cast(mode) + : attention_helper::QKMatMulOutputMode::kNone; + ORT_ENFORCE(qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kNone || + qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQK || + qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKMask || + qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKSoftCap || + qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKSoftMax, + "qk_matmul_output_mode must be 0, 1, 2, or 3."); + // The default scale depends on the input dimensions. It is set to nan to indicate that it should be computed. 
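+  // ComputeOutputShapeForAttention later replaces the NaN with 1/sqrt(head_size) once the head size is known.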
+ scale_ = info.GetAttrOrDefault("scale", std::numeric_limits::quiet_NaN()); + softcap_ = info.GetAttrOrDefault("softcap", 0.0f); + softmax_precision_ = static_cast(info.GetAttrOrDefault("softmax_precision", 0)); + ORT_ENFORCE(scale_ > 0 || std::isnan(scale_), "scale must be greater than 0 if specified"); +} + +template +Status Attention::ComputeInternal(OpKernelContext* context) const { + const Tensor* Q = context->Input(0); + const Tensor* K = context->Input(1); + const Tensor* V = context->Input(2); + const Tensor* attn_mask = context->Input(3); + const Tensor* past_key = context->Input(4); + const Tensor* past_value = context->Input(5); + + attention_helper::AttentionParameters parameters; + TensorShape y_shape; + TensorShape present_key_shape; + TensorShape present_value_shape; + TensorShape output_qk_shape; + + ORT_ENFORCE(attention_helper::ComputeOutputShapeForAttention( + Q, + K, + V, + attn_mask, + past_key, + past_value, + is_causal_, + softcap_, + softmax_precision_, + qk_matmul_output_mode_, + kv_num_heads_, + q_num_heads_, + scale_, + parameters, + y_shape, + present_key_shape, + present_value_shape, + output_qk_shape) + .IsOK(), + "Output shapes for Attention could not be computed."); + + Tensor* Y = context->Output(0, y_shape); + Tensor* present_key = context->Output(1, present_key_shape); + Tensor* present_value = context->Output(2, present_value_shape); + Tensor* output_qk = context->Output(3, output_qk_shape); + + // To reuse the existing attention-cuda implementation in contrib ops, + // map the parameters to contribop_parameters. + onnxruntime::contrib::AttentionParameters contribop_parameters; + contribop_parameters.batch_size = parameters.batch_size; + contribop_parameters.sequence_length = parameters.q_sequence_length; + contribop_parameters.kv_sequence_length = parameters.kv_sequence_length; + contribop_parameters.past_sequence_length = parameters.past_sequence_length; + contribop_parameters.total_sequence_length = parameters.total_sequence_length; + // max_sequence_length: For non-buffer-sharing case, this equals total_sequence_length (the present KV cache size) + contribop_parameters.max_sequence_length = parameters.total_sequence_length; + contribop_parameters.input_hidden_size = 0; // Not applicable - new Attention op takes pre-projected Q/K/V + contribop_parameters.hidden_size = parameters.q_num_heads * parameters.head_size; + contribop_parameters.head_size = parameters.head_size; + contribop_parameters.v_head_size = parameters.v_head_size; + contribop_parameters.v_hidden_size = parameters.kv_num_heads * parameters.v_head_size; + contribop_parameters.num_heads = parameters.q_num_heads; + contribop_parameters.rotary_dim = 0; + contribop_parameters.num_splits = 1; + contribop_parameters.beam_width = 1; + contribop_parameters.is_unidirectional = parameters.is_causal; + contribop_parameters.past_present_share_buffer = false; // New Attention op doesn't share buffer + contribop_parameters.is_packed_qkv = false; + contribop_parameters.do_rotary = false; + + // The new Attention op uses attn_mask as attention_bias (additive bias), not as key_padding_mask + // So mask_type should always be MASK_NONE since we don't have a separate padding mask input + contribop_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE; + + // Determine broadcast flags for attention_bias (if it exists) + // Note: The new Attention op uses attn_mask as attention_bias + // The attention_bias should be broadcastable to (batch_size, kv_num_heads, q_sequence_length, 
total_sequence_length) + // attn_mask can be 2D, 3D, or 4D. Broadcasting aligns from the right (trailing dimensions). + if (attn_mask != nullptr) { + // TODO(titaiwang, xadupre): attn_mask bool is not supported yet + if (attn_mask->IsDataType()) { + ORT_THROW("Boolean attn_mask is not supported yet in Attention op (CUDA)."); + } + + size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions(); + auto attn_mask_dims = attn_mask->Shape().GetDims(); + // For 2D mask (q_seq_len, total_seq_len): both batch and heads dimensions need broadcasting + // For 3D mask (X, q_seq_len, total_seq_len): batch needs broadcasting if X==1, heads always needs broadcasting + // For 4D mask (B, H, q_seq_len, total_seq_len): check if B==1 and H==1 + + if (attn_mask_dims_size == 2) { + // 2D mask: both dimensions need broadcasting + contribop_parameters.broadcast_attn_bias_dim_0 = true; + contribop_parameters.broadcast_attn_bias_dim_1 = true; + } else if (attn_mask_dims_size == 3) { + // 3D mask: dim 0 broadcasts if it's 1, dim 1 (heads) always broadcasts + contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1; + contribop_parameters.broadcast_attn_bias_dim_1 = true; + } else { + // 4D mask: check both dim 0 and dim 1 explicitly + contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1; + contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[1] == 1; + } + } else { + contribop_parameters.broadcast_attn_bias_dim_0 = false; + contribop_parameters.broadcast_attn_bias_dim_1 = false; + } + + contribop_parameters.mask_filter_value = -10000.0f; + contribop_parameters.scale = parameters.scale; + contribop_parameters.use_tf32 = UseTF32(); + + // QKV format: Determine based on input dimensions + // 3D inputs (B, S, D): Q_K_V_BSNH - will be transposed by PrepareQkv to BNSH + // transpose_output is true for 3D inputs, false for 4D inputs + if (!parameters.transpose_output) { + contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH; + contribop_parameters.is_output_bnsh = true; + } else { + // 3D inputs in BSNH format (will be transposed) + contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BSNH; + contribop_parameters.is_output_bnsh = false; + } + + // TODO(titaiwang, xadupre): Group query attention is not supported yet + if (parameters.kv_num_heads != parameters.q_num_heads) { + ORT_THROW("Group query attention is not supported yet in Attention op (CUDA)."); + } + + // TODO(titaiwang, xadupre): qk_matmul_output_mode only supports kNone and kQK for now + if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone && + qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kQK) { + ORT_THROW("qk_matmul_output_mode other than -1 (None) and 0 (QK) is not supported yet in Attention op (CUDA)."); + } + // TODO(titaiwang, xadupre): softcap and softmax_precision are not used yet + if (parameters.softcap != 0.0f) { + ORT_THROW("softcap is not supported yet in Attention op (CUDA)."); + } + if (parameters.softmax_precision != 0) { + ORT_THROW("softmax_precision is not supported yet in Attention op (CUDA)."); + } + + // TODO(titaiwang): Continue on these parameters + // Construct AttentionData to pass to QkvToContext + typedef typename ToCudaType::MappedType CudaT; + onnxruntime::contrib::cuda::AttentionData data; + + // Set input pointers + data.query = reinterpret_cast(Q->Data()); + data.key = reinterpret_cast(K->Data()); + data.value = reinterpret_cast(V->Data()); + data.mask_index = nullptr; // New 
Attention op doesn't have key_padding_mask + data.mask_index_dims = gsl::span(); + data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); + data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data()); + + // Set output pointers + data.output = reinterpret_cast(Y->MutableData()); + data.present_key = (present_key == nullptr) ? nullptr : reinterpret_cast(present_key->MutableData()); + data.present_value = (present_value == nullptr) ? nullptr : reinterpret_cast(present_value->MutableData()); + if (nullptr != output_qk) { + data.output_qk = reinterpret_cast(output_qk->MutableData()); + } + + // Set additional fields + data.bias = nullptr; // New Attention op doesn't have bias + if (nullptr != attn_mask) { + data.attention_bias = reinterpret_cast(attn_mask->Data()); + } + data.qkv_format = contribop_parameters.qkv_format; + + // TODO: Determine which kernel to use (Flash Attention, Memory Efficient Attention, etc.) + // For now, set flags to false and let QkvToContext use the unfused path + data.use_flash_attention = false; + data.use_memory_efficient_attention = false; + data.fused_runner = nullptr; + data.fused_cross_attention_kernel = nullptr; + + // Allocate workspace for Q, K, V processing and scratch buffer + const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(contribop_parameters, data); + size_t workspace_bytes = onnxruntime::contrib::cuda::GetAttentionWorkspaceSize( + sizeof(T), + contribop_parameters.batch_size, + contribop_parameters.num_heads, + contribop_parameters.head_size, + contribop_parameters.v_head_size, + contribop_parameters.sequence_length, + contribop_parameters.kv_sequence_length, + contribop_parameters.total_sequence_length, + nullptr, // fused_runner + false, // use_flash_attention + false, // use_lean_attention + false, // use_fused_cross_attention + false, // use_memory_efficient_attention + false, // use_cudnn_flash_attention + no_qkv_workspace); + auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + + data.has_qkv_workspace = !no_qkv_workspace; + data.workspace = reinterpret_cast(work_space.get()); + data.workspace_bytes = workspace_bytes; + + // Call QkvToContext to perform the attention computation + auto& device_prop = GetDeviceProp(); + cublasHandle_t cublas = GetCublasHandle(context); + cudnnHandle_t cudnn = GetCudnnHandle(context); + + // QkvToContext takes two template parameters: T for computation type, QK for output_qk type + // For now, both are the same type (CudaT) + return onnxruntime::contrib::cuda::QkvToContext( + device_prop, cublas, cudnn, context->GetComputeStream(), contribop_parameters, data); +} +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/llm/attention.h b/onnxruntime/core/providers/cuda/llm/attention.h new file mode 100644 index 0000000000000..17e99bb935e1a --- /dev/null +++ b/onnxruntime/core/providers/cuda/llm/attention.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
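+//
+// Declares the CUDA kernel for the ONNX Attention operator (opset 23). The implementation in
+// attention.cc maps the op onto the contrib-ops unfused attention path (QkvToContext); GQA,
+// softcap, boolean attn_mask, and qk_matmul output modes other than None/QK are not supported yet.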
+ +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class Attention final : public CudaKernel { + public: + Attention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + bool is_causal_; + int kv_num_heads_; + int q_num_heads_; + attention_helper::QKMatMulOutputMode qk_matmul_output_mode_; + float scale_; + float softcap_; + int softmax_precision_; +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc index 894130d6ee991..5c5c2efdb50cd 100644 --- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -363,7 +363,7 @@ TEST(AttentionTest, Attention3DDefault) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -390,7 +390,7 @@ TEST(AttentionTest, Attention3DDefaultFloat16) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -417,7 +417,7 @@ TEST(AttentionTest, Attention4DDefaultBasic) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -444,7 +444,7 @@ TEST(AttentionTest, Attention4DDefault) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -501,7 +501,7 @@ TEST(AttentionTest, Attention4DDefaultFloat16) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -587,7 +587,7 @@ TEST(AttentionTest, Attention4DAttnMask) { q, k, v, m, std::initializer_list(), std::vector(), 
std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -655,7 +655,7 @@ TEST(AttentionTest, Attention4DAttnPastPresentBasic) { q, k, v, m, std::initializer_list(), past_key, past_value, -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -693,19 +693,20 @@ TEST(AttentionTest, Attention4DAttnPastPresent) { q, k, v, m, std::initializer_list(), past_key, past_value, -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } +// TODO(titaiwang, xadupre): Do we really need cross attention + causal mask test case? TEST(AttentionTest, Attention4DAttnIsCausal) { - int batch_size = 2; // Q.shape[0] - int q_num_heads = 3; // Q.shape[1] - int q_sequence_length = 4; // Q.shape[2] - int head_size = 8; // Q.shape[3] - int kv_sequence_length = 6; // K.shape[2] and V.shape[2] - int kv_num_heads = 3; // K.shape[1] and V.shape[1] - int v_head_size = 8; // V.shape[3] - int past_sequence_length = 12; // past_key.shape[2] and past_value.shape[2] + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2] std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 
0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 
0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; @@ -754,7 +755,7 @@ TEST(AttentionTest, Attention4DAttnIsCausalBasic) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -782,10 +783,11 @@ TEST(AttentionTest, Attention4DAttnIsCausalBasicFloat16) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } +// TODO(titaiwang, xadupre): Do we really need cross attention + causal mask test case? TEST(AttentionTest, Attention4DAttnIsCausalBasicDifferentSequenceLength) { int batch_size = 2; // Q.shape[0] int q_num_heads = 1; // Q.shape[1] @@ -857,7 +859,7 @@ TEST(AttentionTest, Attention4DDiffHeadsWithPastAndPresent) { q, k, v, m, std::initializer_list(), past_key, past_value, -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -1022,13 +1024,13 @@ TEST(AttentionTest, Attention4DWithPastAndPresentQkMatmul) { q, k, v, m, std::initializer_list(), past_key, past_value, -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, qk_matmul, - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, q, k, v, m, std::initializer_list(), past_key, past_value, -1, 0, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, qk_matmul, - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); qk_matmul = std::vector{1.786287f, 1.851782f, 1.433406f, 1.126638f, 1.074598f, 1.202869f, 1.806932f, 1.039214f, 1.155254f, 1.351381f, 1.709788f, 1.654608f, 0.904174f, 1.045790f, 1.828289f, 1.849986f, 0.982722f, 0.779313f, 1.067731f, 0.932425f, 1.164846f, 0.896809f, 
1.215540f, 1.155709f, 1.283348f, 0.972161f, 1.592545f, 1.841960f, 1.391534f, 0.932551f, 0.884336f, 0.881353f, 0.905360f, 1.564150f, 1.275840f, 0.946826f, 1.789871f, 1.878873f, 1.971947f, 1.398552f, 1.823965f, 1.960587f, 1.438784f, 1.481077f, 0.957099f, 1.756017f, 1.234584f, 0.990787f, 1.096593f, 1.033003f, 1.868677f, 1.788607f, 1.659495f, 0.667182f, 1.157819f, 0.870338f, 0.879745f, 1.636864f, 0.894962f, 1.714711f, 1.549994f, 0.733612f, 1.117046f, 0.686474f, 1.499953f, 1.123992f, 1.438267f, 0.931251f, 1.633272f, 0.944889f, 0.987120f, 1.218472f, 1.497553f, 1.638913f, 1.553980f, 0.982279f, 1.142558f, 1.193196f, 1.654746f, 1.014832f, 1.090946f, 1.017206f, 1.702928f, 1.601417f, 0.808653f, 1.406642f, 1.423106f, 1.871002f, 1.358196f, 0.931623f, 0.588504f, 0.783458f, 0.882957f, 0.489307f, 1.322660f, 0.934557f, 1.271919f, 0.800610f, 1.444240f, 1.450752f, 0.946420f, 0.900686f, 0.822093f, 1.113904f, 0.568116f, 1.171030f, 1.175384f, 0.910323f, 1.157407f, 1.345392f, 1.400021f, 0.751548f, 1.625352f, 1.456414f, 0.950937f, 1.145433f, 0.649070f, 1.298100f, 0.639947f, 0.927273f, 0.736265f, 1.065406f, 1.263197f, 1.012355f, 1.297169f, 0.495477f, 0.699773f, 0.500964f, 0.620178f, 1.275150f, 0.760687f, 1.387608f, 1.336798f, 0.539168f, 1.042187f, 0.417132f, 1.257103f, 1.163759f, 1.314552f, 0.982448f, 1.345221f, 0.663667f, 0.850426f, 1.238248f, 1.593812f, 1.438230f, 1.387601f, 0.823150f, 0.726727f, 0.832655f, 1.532544f, 0.946970f, 1.126112f, 1.112509f, 1.565497f, 1.938642f, 0.832394f, 1.284816f, 1.447452f, 1.599816f, 0.609072f, 0.743433f, 1.101475f, 0.490747f, 1.020954f, 0.668047f, 0.921248f, 0.721382f, 1.095978f, 0.794792f, 1.488673f, 1.681718f, 0.852196f, 1.102478f, 0.810369f, 1.130985f, 0.425544f, 1.051735f, 0.694759f, 0.764302f, 1.275671f, 1.157903f, 1.440112f, 0.837447f, 1.422500f, 1.150930f, 1.017296f, 1.116673f, 0.804505f, 1.315179f, 0.553615f, 0.871008f, 0.659033f, 1.116166f, 1.134977f, 0.944172f, 0.857236f, 0.531893f, 1.224364f, 0.670808f, 0.843351f, 1.607988f, 0.720031f, 1.438111f, 1.628858f, 0.904480f, 1.456536f, 0.828884f, 1.145072f, 1.586629f, 1.350379f, 1.396510f, 1.226688f, 0.524469f, 0.711242f, 1.413283f, 1.519931f, 1.444998f, 1.155023f, 0.928222f, 0.827857f, 1.092185f, 1.860113f, 1.373539f, 0.953664f, 1.435734f, 1.350082f, 1.735783f, 0.610580f, 1.155694f, 1.600251f, 1.602529f, 0.859450f, 1.156073f, 0.846617f, 0.916578f, 1.134056f, 1.053106f, 1.173786f, 1.246788f, 1.509772f, 1.256221f, 1.540197f, 2.009806f, 1.067828f, 1.164871f, 0.709226f, 1.221456f, 0.845411f, 1.504512f, 1.201048f, 1.402731f, 1.564370f, 1.576583f, 1.589067f, 1.257597f, 1.674126f, 1.954917f, 1.497631f, 1.948780f, 0.954539f, 2.070836f, 0.927942f, 1.418681f, 0.804113f, 1.388198f, 1.624642f, 1.581236f, 1.511648f, 1.311894f, 0.855986f, 0.902148f, 0.785342f, 1.820220f, 0.852723f, 1.696361f, 1.655653f, 1.089764f, 1.202390f, 1.120222f, 1.284748f, 1.475221f, 1.311156f, 1.243736f, 1.625873f, 0.823371f, 1.226631f, 1.673096f, 1.553962f, 1.025746f, 1.313852f, 1.030482f, 0.989448f, 0.936074f, 1.784927f, 0.708855f, 0.971949f, 1.223065f, 1.461189f, 1.747723f, 0.799575f, 0.823636f, 1.400882f, 1.160547f, 0.520804f, 0.836825f, 0.972166f, 0.543222f, 1.346498f, 1.034594f, 1.565712f, 1.361961f, 1.751214f, 0.736224f, 1.864534f, 1.977835f, 1.411005f, 1.496084f, 1.233789f, 1.105877f, 0.961602f, 1.009357f, 1.110593f, 1.390279f, 1.693497f, 1.302893f, 1.756735f, 1.433344f, 2.067142f, 1.916540f, 1.490259f, 1.488384f, 1.309675f, 1.758509f, 1.141796f, 1.534330f, 1.156855f, 1.274409f, 1.870354f, 1.045789f, 1.400564f, 0.876651f, 0.981051f, 0.559955f, 0.790979f, 
1.662600f, 1.021407f, 1.716358f, 1.630805f, 0.674263f, 1.320767f, 0.649261f, 1.538417f, 1.525061f, 1.419455f, 1.148088f, 1.820221f, 0.329244f, 1.033743f, 1.253892f, 1.790469f, 1.711897f, 1.467268f, 1.089224f, 0.834806f, 1.155425f, 2.043234f, 0.849033f, 1.136683f, 1.774663f, 1.735976f, 1.677263f, 0.902375f, 1.213391f, 1.758179f, 1.759598f, 0.879983f, 1.517559f, 0.812989f, 0.499876f, 0.998129f, 0.513259f, 1.094689f, 0.873050f, 1.131224f, 0.546321f, 1.364307f, 1.622263f, 0.652555f, 0.680481f, 0.729973f, 1.123450f, 0.722337f, 1.158875f, 0.845219f, 1.151906f, 1.343835f, 1.411206f, 1.638837f, 1.000100f, 1.652081f, 1.598655f, 0.980791f, 1.122207f, 0.848703f, 1.972988f, 0.610630f, 0.678227f, 0.839634f, 1.289163f, 1.497003f, 1.060701f, 0.971334f, 1.099509f, 1.158767f, 0.871929f, 0.972856f, 1.687900f, 0.854091f, 1.804623f, 1.804263f, 0.738135f, 1.209199f, 1.190654f, 1.425313f, 1.450061f, 1.529269f, 1.249452f, 1.921674f, 0.832500f, 0.940835f, 1.908224f}; @@ -1110,7 +1112,7 @@ TEST(AttentionTest, Attention3DWithPastAndPresentQkMatmul) { q, k, v, m, std::initializer_list(), past_key, past_value, -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, qk_matmul, - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -1161,7 +1163,7 @@ TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmul) { q, k, v, m, std::initializer_list(), past_key, past_value, -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, qk_matmul, - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -1183,9 +1185,9 @@ TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmulCausal) { std::vector v = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 
0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; // {2, 1, 4, 18} std::vector m = {-0.454545f, -0.444930f, -0.435315f, -0.425699f, -0.416084f, -0.406469f, -0.396853f, -0.387238f, -0.377622f, -0.368007f, -0.358392f, -0.348776f, -0.339161f, -0.329545f, -0.319930f, -0.310315f, -0.300699f, -0.291084f, -0.281469f, -0.271853f, -0.262238f, -0.252622f, -0.243007f, -0.233392f, -0.223776f, -0.214161f, -0.204545f, -0.194930f, -0.185315f, -0.175699f, -0.166084f, -0.156469f, -0.146853f, -0.137238f, -0.127622f, -0.118007f, -0.108392f, -0.098776f, -0.089161f, -0.079545f, -0.069930f, -0.060315f, -0.050699f, -0.041084f, -0.031469f, -0.021853f, -0.012238f, -0.002622f, 0.006993f, 0.016608f, 0.026224f, 0.035839f, 0.045455f, 0.055070f, 0.064685f, 0.074301f, 0.083916f, 0.093531f, 0.103147f, 0.112762f, 0.122378f, 0.131993f, 0.141608f, 0.151224f, 0.160839f, 0.170455f, 0.180070f, 0.189685f, 0.199301f, 0.208916f, 0.218531f, 0.228147f, 0.237762f, 0.247378f, 0.256993f, 0.266608f, 0.276224f, 0.285839f, 0.295455f, 0.305070f, 0.314685f, 0.324301f, 0.333916f, 0.343531f, 0.353147f, 0.362762f, 0.372378f, 0.381993f, 0.391608f, 0.401224f, 0.410839f, 0.420455f, 0.430070f, 0.439685f, 0.449301f, 0.458916f, 0.468531f, 0.478147f, 0.487762f, 0.497378f, 0.506993f, 0.516608f, 0.526224f, 0.535839f}; - // {2, 3, 12, 4} + // {2, 3, 7, 4} std::vector past_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; - // {2, 3, 12, 4} + // {2, 3, 7, 4} std::vector past_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, 
-0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 3565208833266..85d8993ab7166 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -40,6 +40,26 @@ "^test_attention_4d_attn_mask_3d_causal_expanded*", // webgpu "^test_attention_4d_diff_heads_mask4d_padded_kv*", // Need nonpad_kv_seqlen "^test_l2normalization*", // LpNormalization(22) not implemented + // TODO: support the following tests in Attention-cuda + "^test_attention_3d_gqa.*_cuda", // GQA not supported in Attention-cuda + "^test_attention_4d_gqa.*_cuda", // GQA not supported in Attention-cuda + "^test_attention_3d_diff_heads_sizes_softcap_cuda", // softcap not supported in Attention-cuda + "^test_attention_4d_diff_heads_sizes_softcap_cuda", // softcap not supported in Attention-cuda + "^test_attention_3d_softcap_cuda", // softcap not supported in Attention-cuda + "^test_attention_3d_with_past_and_present_qk_matmul_softcap_cuda", // softcap not supported in Attention-cuda + "test_attention_3d_with_past_and_present_qk_matmul_softcap_cuda", // softcap not supported in Attention-cuda + "^test_attention_4d_diff_heads_sizes_softcap_cuda", // softcap not supported in Attention-cuda + "^test_attention_4d_softcap_cuda", // softcap not supported in Attention-cuda + "^test_attention_4d_with_qk_matmul_softcap_cuda", // softcap not supported in Attention-cuda + "^test_attention_4d_attn_mask_bool_cuda", // bool mask not supported in Attention-cuda + 
"^test_attention_4d_attn_mask_bool_4d_cuda", // bool mask not supported in Attention-cuda + "^test_attention_3d_with_past_and_present_qk_matmul_bias_cuda", // QK matmul + bias not supported in Attention-cuda + "^test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_cuda", // QK matmul + bias not supported in Attention-cuda + "^test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_cuda", // QK matmul + bias not supported in Attention-cuda + "^test_attention_4d_with_qk_matmul_bias_cuda", // QK matmul + bias not supported in Attention-cuda + "^test_attention_4d_with_qk_matmul_softmax_cuda", // QK matmul + softmax not supported in Attention-cuda + "^test_attention_3d_with_past_and_present_qk_matmul_softmax_cuda", // QK matmul + softmax not supported in Attention-cuda + "^test_attention_4d_with_past_and_present_qk_matmul_bias_cuda", // QK matmul + bias not supported in Attention-cuda "^test_l1normalization*", // LpNormalization(22) not implemented "^test_lpnormalization*", // LpNormalization(22) not implemented "^test_tensorscatter*", // TensorScatter(24) not implemented