diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc index 7ac19bf67fb8a..d6daa8daef5e0 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.cc +++ b/onnxruntime/core/providers/cpu/llm/attention.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/cpu/llm/attention.h" +#include "core/providers/cpu/llm/attention_helper.h" #include "core/common/common.h" #include "core/common/safeint.h" @@ -140,10 +141,10 @@ Status Attention::Compute(OpKernelContext* context) const { const Tensor* past_value = context->Input(5); AttentionParameters parameters; - std::vector y_shape; - std::vector present_key_shape; - std::vector present_value_shape; - std::vector output_qk_shape; + TensorShape y_shape; + TensorShape present_key_shape; + TensorShape present_value_shape; + TensorShape output_qk_shape; ORT_ENFORCE(attention_helper::ComputeOutputShapeForAttention( Q, diff --git a/onnxruntime/core/providers/cpu/llm/attention.h b/onnxruntime/core/providers/cpu/llm/attention.h index 4fad6914f933d..867c16724964a 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.h +++ b/onnxruntime/core/providers/cpu/llm/attention.h @@ -5,7 +5,7 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" #include "core/platform/threadpool.h" -#include "core/providers/cpu/llm/attention_helper.h" +#include "core/providers/cpu/llm/attention_parameters.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/llm/attention_helper.cc b/onnxruntime/core/providers/cpu/llm/attention_helper.cc deleted file mode 100644 index 9bd954f128454..0000000000000 --- a/onnxruntime/core/providers/cpu/llm/attention_helper.cc +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/providers/cpu/llm/attention_helper.h" -#include "core/util/shape_checker.h" - -namespace onnxruntime { -namespace attention_helper { - -void AttentionParameters::checkParameters() const { - ORT_ENFORCE(batch_size > 0, "Batch size must be greater than 0"); - ORT_ENFORCE(q_sequence_length > 0, "Q sequence length must be greater than 0"); - ORT_ENFORCE(kv_sequence_length > 0, "KV sequence length must be greater than 0"); - ORT_ENFORCE(head_size > 0, "Head size must be greater than 0"); - ORT_ENFORCE(v_head_size > 0, "V head size must be greater than 0"); - ORT_ENFORCE(past_sequence_length >= 0, "Past sequence length must be non-negative"); - ORT_ENFORCE(total_sequence_length > 0, "Total sequence length must be greater than 0"); - ORT_ENFORCE(kv_num_heads > 0, "KV number of heads must be greater than 0"); - ORT_ENFORCE(q_num_heads > 0, "Q number of heads must be greater than 0"); - ORT_ENFORCE(total_sequence_length == past_sequence_length + kv_sequence_length, - "Total sequence length must be equal to past sequence length plus KV sequence length"); -} - -Status ComputeOutputShapeForAttention( - const Tensor* Q, - const Tensor* K, - const Tensor* V, - const Tensor* attn_mask, - const Tensor* past_key, - const Tensor* past_value, - bool is_causal, - float softcap, - int softmax_precision, - attention_helper::QKMatMulOutputMode qk_matmul_output_mode, - int kv_num_heads, - int q_num_heads, - float scale, - AttentionParameters& parameters, - std::vector& y_shape, - std::vector& present_key_shape, - std::vector& present_value_shape, - std::vector& output_qk_shape) { - ORT_ENFORCE(Q != nullptr && K != nullptr && V != nullptr, - "Q, K, and V inputs must not be null"); - int q_dims = onnxruntime::narrow(Q->Shape().NumDimensions()); - int k_dims = onnxruntime::narrow(K->Shape().NumDimensions()); - int v_dims = onnxruntime::narrow(V->Shape().NumDimensions()); - ORT_ENFORCE(q_dims == 3 || q_dims == 4, "Q must be a 3D or 4D tensor"); - ORT_ENFORCE(q_dims == k_dims, "Q and K must have the same rank."); - ORT_ENFORCE(q_dims == v_dims, "Q and V must have the same rank."); - - ORT_ENFORCE((past_key == nullptr) == (past_value == nullptr), "past_key and past_value must be both null or both not null"); - ORT_ENFORCE(Q->Shape()[0] == K->Shape()[0], "inconsistent batch_size (between Q and K)"); - ORT_ENFORCE(Q->Shape()[0] == V->Shape()[0], "inconsistent batch_size (between Q and V)"); - ORT_ENFORCE(past_key == nullptr || Q->Shape()[0] == past_key->Shape()[0], "inconsistent batch_size (between Q and past_key)"); - ORT_ENFORCE(past_value == nullptr || Q->Shape()[0] == past_value->Shape()[0], "inconsistent batch_size (between Q and past_value)"); - ORT_ENFORCE(past_value == nullptr || past_value->Shape()[2] == past_key->Shape()[2], "inconsistent past_sequence_length (between past_key and past_value)"); - - parameters.is_causal = is_causal; - parameters.softcap = softcap; - parameters.softmax_precision = softmax_precision; - parameters.qk_matmul_output_mode = qk_matmul_output_mode; // output mode for Q*K matmul - parameters.batch_size = onnxruntime::narrow(Q->Shape()[0]); // Q.shape[0], K.shape[0], V.shape[0] (4D) - - ORT_ENFORCE(parameters.batch_size > 0, "Batch size must be greater than 0"); - ORT_ENFORCE(attn_mask == nullptr || (attn_mask->Shape().NumDimensions() >= 2 && attn_mask->Shape().NumDimensions() <= 4), "attn_mask must be 2D or 3D or 4D tensor"); - - if (q_dims == 4) { - // 4D - parameters.kv_num_heads = kv_num_heads > 0 ? 
kv_num_heads : onnxruntime::narrow(K->Shape()[1]); // K.shape[1] or V.shape[1] (4D) - parameters.q_num_heads = q_num_heads > 0 ? q_num_heads : onnxruntime::narrow(Q->Shape()[1]); // Q.shape[1] (4D) - - ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow(K->Shape()[1]), "kv_num_heads different from K.shape[1]"); - ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow(V->Shape()[1]), "kv_num_heads different from V.shape[1]"); - ORT_ENFORCE(parameters.q_num_heads == onnxruntime::narrow(Q->Shape()[1]), "q_num_heads different from Q.shape[1]"); - ORT_ENFORCE(Q->Shape()[3] == K->Shape()[3], "inconsistent head_size"); - ORT_ENFORCE(K->Shape()[2] == V->Shape()[2], "inconsistent kv_sequence_length"); - ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 2] == Q->Shape()[2], "inconsistent q_sequence_length (between attn_mask and Q)"); - - // From shapes - parameters.transpose_output = false; // whether to transpose the input/output with permutation (0, 2, 1, 3) - parameters.q_sequence_length = onnxruntime::narrow(Q->Shape()[2]); // Q.shape[2] (4D) - parameters.head_size = onnxruntime::narrow(Q->Shape()[3]); // Q.shape[3] (4D) - parameters.kv_sequence_length = onnxruntime::narrow(K->Shape()[2]); // K.shape[2] or V.shape[2] (4D) - parameters.v_head_size = onnxruntime::narrow(V->Shape()[3]); // V.shape[3] (4D) - parameters.past_sequence_length = past_key == nullptr // past_key.shape[2] or past_value.shape[2] (4D) or given by the mask - ? 0 - : onnxruntime::narrow(past_key->Shape()[2]); - - y_shape = {static_cast(parameters.batch_size), - static_cast(parameters.q_num_heads), - static_cast(parameters.q_sequence_length), - static_cast(parameters.v_head_size)}; - } else { - // 3D - parameters.kv_num_heads = kv_num_heads; - parameters.q_num_heads = q_num_heads; - - // From shapes - ORT_ENFORCE(Q->Shape()[2] % parameters.q_num_heads == 0, "inconsistent q_hidden_size, it should be a multiple of q_num_heads"); - ORT_ENFORCE(V->Shape()[2] % parameters.kv_num_heads == 0, "inconsistent v_hidden_size, it should be a multiple of kv_num_heads"); - - parameters.transpose_output = true; // whether to transpose the input/output with permutation (0, 2, 1, 3) - parameters.q_sequence_length = onnxruntime::narrow(Q->Shape()[1]); - parameters.head_size = onnxruntime::narrow(Q->Shape()[2]) / parameters.q_num_heads; - parameters.kv_sequence_length = onnxruntime::narrow(K->Shape()[1]); - parameters.v_head_size = onnxruntime::narrow(V->Shape()[2]) / parameters.kv_num_heads; - parameters.past_sequence_length = past_key == nullptr - ? 
0 - : onnxruntime::narrow(past_key->Shape()[2]); - - y_shape = {static_cast(parameters.batch_size), - static_cast(parameters.q_sequence_length), - static_cast(parameters.q_num_heads * parameters.v_head_size)}; - } - parameters.total_sequence_length = parameters.past_sequence_length + parameters.kv_sequence_length; - - ORT_ENFORCE(parameters.q_num_heads % parameters.kv_num_heads == 0, "q_num_heads % kv_num_heads == 0 is not verified"); - ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 1] == parameters.total_sequence_length, - "inconsistent total_sequence_length (between attn_mask and past_key and past_value)"); - ORT_ENFORCE(attn_mask == nullptr || - attn_mask->Shape().NumDimensions() < 3 || - attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == 1 || - attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == parameters.kv_num_heads, - "attn_mask must be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) but is not compatible with kv_num_heads"); - ORT_ENFORCE(attn_mask == nullptr || - attn_mask->Shape().NumDimensions() < 4 || - attn_mask->Shape()[0] == 1 || - attn_mask->Shape()[0] == parameters.batch_size, - "attn_mask must be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) but is not compatible with batch_size"); - ASSERT_TENSOR_DIMS(past_key, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.head_size); - ASSERT_TENSOR_DIMS(past_value, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.v_head_size); - - parameters.scale = std::isnan(scale) ? static_cast(1.0 / sqrt(parameters.head_size)) : scale; - parameters.checkParameters(); - - present_key_shape = {static_cast(parameters.batch_size), - static_cast(parameters.kv_num_heads), - static_cast(parameters.total_sequence_length), - static_cast(parameters.head_size)}; - present_value_shape = {static_cast(parameters.batch_size), - static_cast(parameters.kv_num_heads), - static_cast(parameters.total_sequence_length), - static_cast(parameters.v_head_size)}; - if (qk_matmul_output_mode == QKMatMulOutputMode::kNone) { - output_qk_shape.clear(); - } else { - output_qk_shape = {static_cast(parameters.batch_size), - static_cast(parameters.q_num_heads), - static_cast(parameters.q_sequence_length), - static_cast(parameters.total_sequence_length)}; - } - return Status::OK(); -} -} // namespace attention_helper -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/llm/attention_helper.h b/onnxruntime/core/providers/cpu/llm/attention_helper.h index 1cea27760408f..c3ddcf50f3015 100644 --- a/onnxruntime/core/providers/cpu/llm/attention_helper.h +++ b/onnxruntime/core/providers/cpu/llm/attention_helper.h @@ -2,54 +2,13 @@ // Licensed under the MIT License. #pragma once -#include "core/common/common.h" -#include "core/providers/common.h" +#include "core/providers/cpu/llm/attention_parameters.h" +#include "core/util/shape_checker.h" namespace onnxruntime { namespace attention_helper { -// enum equivalent to the onnx defintion of qk_matmul_output_mode -enum QKMatMulOutputMode { - kNone = -1, // No output Q*K - kQK = 0, // Output Q*K - kQKMask = 1, // Output Q*K + Mask - kQKSoftCap = 2, // Output SoftCap(Q*K + Mask) - kQKSoftMax = 3, // Output SoftMax(SoftCap(Q*K + Mask)) -}; - -// Parameters deduced from node attributes and inputs/outputs. 
-struct AttentionParameters { - /* - * Attention Parameters - * MHA: q_num_heads == kv_num_heads -> MHA - * GQA: q_num_heads > kv_num_heads && q_num_heads % kv_num_heads == 0 - * MQA: q_num_heads > kv_num_heads && kv_num_heads == 1 - */ - bool is_causal; - int kv_num_heads; // K.shape[1] or V.shape[1] (4D) - int q_num_heads; // Q.shape[1] (4D) - float scale; - float softcap; - int softmax_precision; - QKMatMulOutputMode qk_matmul_output_mode; - - // From shapes - int batch_size; // Q.shape[0], K.shape[0], V.shape[0] (4D) - int q_sequence_length; // Q.shape[2] (4D) - int head_size; // Q.shape[3] or K.shape[3 (4D) - int kv_sequence_length; // K.shape[2] or V.shape[2] (4D) - int v_head_size; // V.shape[4] (4D) - int past_sequence_length; // pask_key.shape[2] or past_value.shape[2] (4D) - int total_sequence_length; // past_sequence_length + kv_sequence_length - bool transpose_output; // Whether to transpose the inputs and the outputs from BxNxSxH to BxSxNxH - // This covers the case where the inputs are 3D. - - // Checks the consistency of the parameters. - void checkParameters() const; -}; - -// Computes the output shape for attention based on the input tensors and parameters. -Status ComputeOutputShapeForAttention( +inline Status ComputeOutputShapeForAttention( const Tensor* Q, const Tensor* K, const Tensor* V, @@ -64,10 +23,121 @@ Status ComputeOutputShapeForAttention( int q_num_heads, float scale, AttentionParameters& parameters, - std::vector& y_shape, - std::vector& present_key_shape, - std::vector& present_value_shape, - std::vector& output_qk_shape); + TensorShape& y_shape, + TensorShape& present_key_shape, + TensorShape& present_value_shape, + TensorShape& output_qk_shape) { + ORT_ENFORCE(Q != nullptr && K != nullptr && V != nullptr, + "Q, K, and V inputs must not be null"); + int q_dims = onnxruntime::narrow(Q->Shape().NumDimensions()); + int k_dims = onnxruntime::narrow(K->Shape().NumDimensions()); + int v_dims = onnxruntime::narrow(V->Shape().NumDimensions()); + ORT_ENFORCE(q_dims == 3 || q_dims == 4, "Q must be a 3D or 4D tensor"); + ORT_ENFORCE(q_dims == k_dims, "Q and K must have the same rank."); + ORT_ENFORCE(q_dims == v_dims, "Q and V must have the same rank."); + ORT_ENFORCE((past_key == nullptr) == (past_value == nullptr), "past_key and past_value must be both null or both not null"); + ORT_ENFORCE(Q->Shape()[0] == K->Shape()[0], "inconsistent batch_size (between Q and K)"); + ORT_ENFORCE(Q->Shape()[0] == V->Shape()[0], "inconsistent batch_size (between Q and V)"); + ORT_ENFORCE(past_key == nullptr || Q->Shape()[0] == past_key->Shape()[0], "inconsistent batch_size (between Q and past_key)"); + ORT_ENFORCE(past_value == nullptr || Q->Shape()[0] == past_value->Shape()[0], "inconsistent batch_size (between Q and past_value)"); + ORT_ENFORCE(past_value == nullptr || past_value->Shape()[2] == past_key->Shape()[2], "inconsistent past_sequence_length (between past_key and past_value)"); + + parameters.is_causal = is_causal; + parameters.softcap = softcap; + parameters.softmax_precision = softmax_precision; + parameters.qk_matmul_output_mode = qk_matmul_output_mode; // output mode for Q*K matmul + parameters.batch_size = onnxruntime::narrow(Q->Shape()[0]); // Q.shape[0], K.shape[0], V.shape[0] (4D) + + ORT_ENFORCE(parameters.batch_size > 0, "Batch size must be greater than 0"); + ORT_ENFORCE(attn_mask == nullptr || (attn_mask->Shape().NumDimensions() >= 2 && attn_mask->Shape().NumDimensions() <= 4), "attn_mask must be 2D or 3D or 4D tensor"); + + if (q_dims == 4) { + // 4D + 
parameters.kv_num_heads = kv_num_heads > 0 ? kv_num_heads : onnxruntime::narrow(K->Shape()[1]); // K.shape[1] or V.shape[1] (4D) + parameters.q_num_heads = q_num_heads > 0 ? q_num_heads : onnxruntime::narrow(Q->Shape()[1]); // Q.shape[1] (4D) + + ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow(K->Shape()[1]), "kv_num_heads different from K.shape[1]"); + ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow(V->Shape()[1]), "kv_num_heads different from V.shape[1]"); + ORT_ENFORCE(parameters.q_num_heads == onnxruntime::narrow(Q->Shape()[1]), "q_num_heads different from Q.shape[1]"); + ORT_ENFORCE(Q->Shape()[3] == K->Shape()[3], "inconsistent head_size"); + ORT_ENFORCE(K->Shape()[2] == V->Shape()[2], "inconsistent kv_sequence_length"); + ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 2] == Q->Shape()[2], "inconsistent q_sequence_length (between attn_mask and Q)"); + + // From shapes + parameters.transpose_output = false; // whether to transpose the input/output with permutation (0, 2, 1, 3) + parameters.q_sequence_length = onnxruntime::narrow(Q->Shape()[2]); // Q.shape[2] (4D) + parameters.head_size = onnxruntime::narrow(Q->Shape()[3]); // Q.shape[3] (4D) + parameters.kv_sequence_length = onnxruntime::narrow(K->Shape()[2]); // K.shape[2] or V.shape[2] (4D) + parameters.v_head_size = onnxruntime::narrow(V->Shape()[3]); // V.shape[3] (4D) + parameters.past_sequence_length = past_key == nullptr // past_key.shape[2] or past_value.shape[2] (4D) or given by the mask + ? 0 + : onnxruntime::narrow(past_key->Shape()[2]); + + y_shape = {static_cast(parameters.batch_size), + static_cast(parameters.q_num_heads), + static_cast(parameters.q_sequence_length), + static_cast(parameters.v_head_size)}; + } else { + // 3D + parameters.kv_num_heads = kv_num_heads; + parameters.q_num_heads = q_num_heads; + + // From shapes + ORT_ENFORCE(Q->Shape()[2] % parameters.q_num_heads == 0, "inconsistent q_hidden_size, it should be a multiple of q_num_heads"); + ORT_ENFORCE(V->Shape()[2] % parameters.kv_num_heads == 0, "inconsistent v_hidden_size, it should be a multiple of kv_num_heads"); + + parameters.transpose_output = true; // whether to transpose the input/output with permutation (0, 2, 1, 3) + parameters.q_sequence_length = onnxruntime::narrow(Q->Shape()[1]); + parameters.head_size = onnxruntime::narrow(Q->Shape()[2]) / parameters.q_num_heads; + parameters.kv_sequence_length = onnxruntime::narrow(K->Shape()[1]); + parameters.v_head_size = onnxruntime::narrow(V->Shape()[2]) / parameters.kv_num_heads; + parameters.past_sequence_length = past_key == nullptr + ? 
0 + : onnxruntime::narrow(past_key->Shape()[2]); + + y_shape = {static_cast(parameters.batch_size), + static_cast(parameters.q_sequence_length), + static_cast(parameters.q_num_heads * parameters.v_head_size)}; + } + parameters.total_sequence_length = parameters.past_sequence_length + parameters.kv_sequence_length; + + ORT_ENFORCE(parameters.q_num_heads % parameters.kv_num_heads == 0, "q_num_heads % kv_num_heads == 0 is not verified"); + ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 1] == parameters.total_sequence_length, + "inconsistent total_sequence_length (between attn_mask and past_key and past_value)"); + ORT_ENFORCE(attn_mask == nullptr || + attn_mask->Shape().NumDimensions() < 3 || + attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == 1 || + attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == parameters.kv_num_heads, + "attn_mask must be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) but is not compatible with kv_num_heads"); + ORT_ENFORCE(attn_mask == nullptr || + attn_mask->Shape().NumDimensions() < 4 || + attn_mask->Shape()[0] == 1 || + attn_mask->Shape()[0] == parameters.batch_size, + "attn_mask must be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) but is not compatible with batch_size"); + ASSERT_TENSOR_DIMS(past_key, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.head_size); + ASSERT_TENSOR_DIMS(past_value, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.v_head_size); + + parameters.scale = std::isnan(scale) ? static_cast(1.0 / sqrt(parameters.head_size)) : scale; + parameters.checkParameters(); + + present_key_shape = {static_cast(parameters.batch_size), + static_cast(parameters.kv_num_heads), + static_cast(parameters.total_sequence_length), + static_cast(parameters.head_size)}; + present_value_shape = {static_cast(parameters.batch_size), + static_cast(parameters.kv_num_heads), + static_cast(parameters.total_sequence_length), + static_cast(parameters.v_head_size)}; + if (qk_matmul_output_mode == QKMatMulOutputMode::kNone) { + output_qk_shape = {}; + } else { + output_qk_shape = {static_cast(parameters.batch_size), + static_cast(parameters.q_num_heads), + static_cast(parameters.q_sequence_length), + static_cast(parameters.total_sequence_length)}; + } + return Status::OK(); +} } // namespace attention_helper } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/llm/attention_parameters.h b/onnxruntime/core/providers/cpu/llm/attention_parameters.h new file mode 100644 index 0000000000000..7c227e9c1a52c --- /dev/null +++ b/onnxruntime/core/providers/cpu/llm/attention_parameters.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace attention_helper { + +// enum equivalent to the onnx defintion of qk_matmul_output_mode +enum QKMatMulOutputMode { + kNone = -1, // No output Q*K + kQK = 0, // Output Q*K + kQKMask = 1, // Output Q*K + Mask + kQKSoftCap = 2, // Output SoftCap(Q*K + Mask) + kQKSoftMax = 3, // Output SoftMax(SoftCap(Q*K + Mask)) +}; + +// Parameters deduced from node attributes and inputs/outputs. 
+struct AttentionParameters { + /* + * Attention Parameters + * MHA: q_num_heads == kv_num_heads -> MHA + * GQA: q_num_heads > kv_num_heads && q_num_heads % kv_num_heads == 0 + * MQA: q_num_heads > kv_num_heads && kv_num_heads == 1 + */ + bool is_causal; + int kv_num_heads; // K.shape[1] or V.shape[1] (4D) + int q_num_heads; // Q.shape[1] (4D) + float scale; + float softcap; + int softmax_precision; + QKMatMulOutputMode qk_matmul_output_mode; + + // From shapes + int batch_size; // Q.shape[0], K.shape[0], V.shape[0] (4D) + int q_sequence_length; // Q.shape[2] (4D) + int head_size; // Q.shape[3] or K.shape[3 (4D) + int kv_sequence_length; // K.shape[2] or V.shape[2] (4D) + int v_head_size; // V.shape[4] (4D) + int past_sequence_length; // pask_key.shape[2] or past_value.shape[2] (4D) + int total_sequence_length; // past_sequence_length + kv_sequence_length + bool transpose_output; // Whether to transpose the inputs and the outputs from BxNxSxH to BxSxNxH + // This covers the case where the inputs are 3D. + + // Checks the consistency of the parameters. + void checkParameters() const { + ORT_ENFORCE(batch_size > 0, "Batch size must be greater than 0"); + ORT_ENFORCE(q_sequence_length > 0, "Q sequence length must be greater than 0"); + ORT_ENFORCE(kv_sequence_length > 0, "KV sequence length must be greater than 0"); + ORT_ENFORCE(head_size > 0, "Head size must be greater than 0"); + ORT_ENFORCE(v_head_size > 0, "V head size must be greater than 0"); + ORT_ENFORCE(past_sequence_length >= 0, "Past sequence length must be non-negative"); + ORT_ENFORCE(total_sequence_length > 0, "Total sequence length must be greater than 0"); + ORT_ENFORCE(kv_num_heads > 0, "KV number of heads must be greater than 0"); + ORT_ENFORCE(q_num_heads > 0, "Q number of heads must be greater than 0"); + ORT_ENFORCE(total_sequence_length == past_sequence_length + kv_sequence_length, + "Total sequence length must be equal to past sequence length plus KV sequence length"); + } +}; + +} // namespace attention_helper +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 3816cc1f8f6b9..def3bbeb71c78 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1533,6 +1533,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish); // Opset 23. 
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Attention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Attention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_MLFloat16, RMSNormalization); @@ -2601,6 +2604,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 23 + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc new file mode 100644 index 0000000000000..22f57236abb63 --- /dev/null +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cpu/llm/attention_helper.h" +#include "core/providers/cuda/llm/attention.h" +#include "core/providers/cuda/llm/attention_naive.h" +#include "contrib_ops/cuda/bert/attention_data.h" +#include "contrib_ops/cuda/bert/attention_impl.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::attention_helper; + +namespace onnxruntime { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Attention, \ + kOnnxDomain, \ + 23, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("U", BuildKernelDefConstraints()), \ + Attention); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) +// REGISTER_KERNEL_TYPED(BFloat16) + +template +Attention::Attention(const OpKernelInfo& info) : CudaKernel(info) { + is_causal_ = static_cast(info.GetAttrOrDefault("is_causal", 0)) == 1; + // kv_num_heads, q_num_head are mandatory for 3D inputs but not used for 4D inputs. + // The dimension is not yet known. If not specified, the inputs is assumed to be 4D. + kv_num_heads_ = static_cast(info.GetAttrOrDefault("kv_num_heads", 0)); + q_num_heads_ = static_cast(info.GetAttrOrDefault("q_num_heads", 0)); + int mode = static_cast(info.GetAttrOrDefault("qk_matmul_output_mode", 0)); + qk_matmul_output_mode_ = info.node().OutputDefs().size() >= 4 && info.node().OutputDefs()[3]->Exists() + ? static_cast(mode) + : QKMatMulOutputMode::kNone; + ORT_ENFORCE(qk_matmul_output_mode_ == QKMatMulOutputMode::kNone || + qk_matmul_output_mode_ == QKMatMulOutputMode::kQK || + qk_matmul_output_mode_ == QKMatMulOutputMode::kQKMask || + qk_matmul_output_mode_ == QKMatMulOutputMode::kQKSoftCap || + qk_matmul_output_mode_ == QKMatMulOutputMode::kQKSoftMax, + "qk_matmul_output_mode must be 0, 1, 2, or 3."); + // The default scale depends on the input dimensions. It is set to nan to indicate that it should be computed. 
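+  // Worked default (taken from ComputeOutputShapeForAttention in this change): when the attribute is
+  // absent, the scale resolves to 1 / sqrt(head_size); for example head_size = 64 gives scale = 0.125
+  // and head_size = 128 gives scale ~= 0.0884.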
+ scale_ = info.GetAttrOrDefault("scale", std::numeric_limits::quiet_NaN()); + softcap_ = info.GetAttrOrDefault("softcap", 0.0f); + softmax_precision_ = static_cast(info.GetAttrOrDefault("softmax_precision", 0)); + ORT_ENFORCE(scale_ > 0 || std::isnan(scale_), "scale must be greater than 0 if specified"); +} + +template +Status Attention::ComputeInternal(OpKernelContext* context) const { + const Tensor* Q = context->Input(0); + const Tensor* K = context->Input(1); + const Tensor* V = context->Input(2); + const Tensor* attn_mask = context->Input(3); + const Tensor* past_key = context->Input(4); + const Tensor* past_value = context->Input(5); + + AttentionParameters parameters; + TensorShape y_shape; + TensorShape present_key_shape; + TensorShape present_value_shape; + TensorShape output_qk_shape; + + ORT_ENFORCE(attention_helper::ComputeOutputShapeForAttention( + Q, + K, + V, + attn_mask, + past_key, + past_value, + is_causal_, + softcap_, + softmax_precision_, + qk_matmul_output_mode_, + kv_num_heads_, + q_num_heads_, + scale_, + parameters, + y_shape, + present_key_shape, + present_value_shape, + output_qk_shape) + .IsOK(), + "Output shapes for Attention could not be computed."); + + Tensor* Y = context->Output(0, y_shape); + Tensor* present_key = context->Output(1, present_key_shape); + Tensor* present_value = context->Output(2, present_value_shape); + Tensor* output_qk = parameters.qk_matmul_output_mode == QKMatMulOutputMode::kNone + ? nullptr + : context->Output(3, output_qk_shape); + +#if 0 + // First tentative to following the CPU implementation with the idea behind to have a fallback + // when contrib ops do not work (for FP8 for example or for QKV output options). + // This is unfinished. + NaiveAttention attention_naive_impl; + cudaStream_t stream = Stream(context); + return attention_naive_impl.ApplyAttention(context, + stream, + Q->Data(), // Q + K->Data(), // K + V->Data(), // V + attn_mask, // const Tensor* mask_index, // mask, nullptr if no mask + past_key, // past K input tensor (if not using past state) + past_value, // past V input tensor (if not using past state) + Y, // first output + present_key, // present K output tensor (if separating present KV) + present_value, // present V output tensor (if separating present KV) + output_qk, // Q*K output tensor (if returning Q*K value) + parameters); // attention parameters +#else + + onnxruntime::contrib::AttentionParameters cparameters; + cparameters.batch_size = parameters.batch_size; + cparameters.sequence_length = parameters.q_sequence_length; + cparameters.kv_sequence_length = parameters.kv_sequence_length; + cparameters.past_sequence_length = parameters.past_sequence_length; + cparameters.total_sequence_length = parameters.total_sequence_length; + cparameters.max_sequence_length = parameters.total_sequence_length; // TODO ? + cparameters.input_hidden_size = parameters.batch_size; + cparameters.hidden_size = parameters.batch_size; + cparameters.head_size = parameters.head_size; + cparameters.v_head_size = parameters.v_head_size; + cparameters.v_hidden_size = parameters.kv_num_heads * parameters.v_head_size; + cparameters.num_heads = parameters.q_num_heads; + cparameters.head_size = parameters.head_size; + cparameters.num_splits = 1; // TODO? + cparameters.scale = parameters.scale; + cparameters.do_rotary = false; + cparameters.is_packed_qkv = false; + cparameters.is_unidirectional = false; + + // TODO: Mask 2D does not seem to be supported by the contrib ops. + cparameters.mask_type = attn_mask == nullptr + ? 
onnxruntime::contrib::AttentionMaskType::MASK_NONE + : (attn_mask->Shape().NumDimensions() == 4 + ? onnxruntime::contrib::AttentionMaskType::MASK_4D_MEGATRON + : onnxruntime::contrib::AttentionMaskType::MASK_3D_ATTENTION); + cparameters.qkv_format = Q->Shape().NumDimensions() == 3 + ? onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BSNH // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention + : onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH; // for non-packed qkv, permuted + + /* + // TODO: investigate these parameters. + int num_splits; // number of splits for splitkv + int rotary_dim = 0; // rotary embedding dimension + int beam_width; + bool is_unidirectional; + bool past_present_share_buffer; + bool is_packed_qkv = false; // whether qkv is packed + bool broadcast_attn_bias_dim_0; + bool broadcast_attn_bias_dim_1; + float mask_filter_value; + bool use_tf32; + */ + + // Calls the contrib ops implementation. + // The main issue is to retrieve the expected QK outputs. Attention(23) has many unsupported options. + typedef typename ToCudaType::MappedType CudaT; + onnxruntime::contrib::cuda::AttentionData data; + + // T* gemm_buffer = nullptr; + // const T* bias = nullptr; + // int* seqlens_k_total = nullptr; + + data.query = reinterpret_cast(Q->Data()); + data.key = reinterpret_cast(K->Data()); + data.value = reinterpret_cast(V->Data()); + + // const int* mask_index = nullptr; + // gsl::span mask_index_dims; + // const T* past = nullptr; + data.past_key = past_key == nullptr ? nullptr : reinterpret_cast(past_key->Data()); + data.past_value = past_value == nullptr ? nullptr : reinterpret_cast(past_value->Data()); + + // const int32_t* cache_indirection = nullptr; + // const T* attention_bias = nullptr; + + // bool has_qkv_workspace = false; + // T* workspace = nullptr; + + data.output = reinterpret_cast(Y->MutableData()); + // T* present = nullptr; + data.present_key = present_key != nullptr ? reinterpret_cast(present_key->MutableData()) : nullptr; + data.present_value = present_value != nullptr ? reinterpret_cast(present_value->MutableData()) : nullptr; + data.output_qk = output_qk != nullptr ? reinterpret_cast(output_qk->MutableData()) : nullptr; + + // void* fused_runner = nullptr; + // const void* fused_cross_attention_kernel = nullptr;= + // bool use_flash_attention = false; + // bool use_memory_efficient_attention = false; + // bool use_decoder_masked_multihead_attention = false; + // const int32_t* cumulated_sequence_length_q_cache = nullptr; + // const int32_t* cumulated_sequence_length_kv_cache = nullptr; + + // Intermediate data + // T* q = nullptr; + // T* k = nullptr; + // T* v = nullptr; + // T* scratch = nullptr; + data.qkv_format = cparameters.qkv_format; + + // Flash buffers + // T* softmax_lse = nullptr; + // T* softmax_lse_accum = nullptr; + // T* out_accum = nullptr; + + // Flash Atttention and Lean Attention + // int num_splits; + + // Lean Attention + // bool use_lean_attention = false; + // size_t workspace_bytes = 0; + // bool allow_debug_info = false; + + // For MultiHeadAttention only. 
+ data.kernel_type = onnxruntime::contrib::AttentionKernelType::AttentionKernel_Default; + // AllocatorPtr allocator = nullptr; + + const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(cparameters, data); + size_t workspace_bytes = onnxruntime::contrib::cuda::GetAttentionWorkspaceSize(sizeof(T), + cparameters.batch_size, + cparameters.num_heads, + cparameters.head_size, + cparameters.v_head_size, + cparameters.sequence_length, + cparameters.kv_sequence_length, + cparameters.total_sequence_length, + false, // fused_runner, + false, // use_flash_attention + false, // use_lean_attention, + false, // use_fused_cross_attention, + false, // use_memory_efficient_attention, + false, // use_cudnn_sdpa, + no_qkv_workspace); + auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + + data.has_qkv_workspace = !no_qkv_workspace; + data.workspace = reinterpret_cast(work_space.get()); + data.workspace_bytes = workspace_bytes; + + // QK type = T type + typedef typename ToCudaType::MappedType CudaQK; + cublasHandle_t cublas = GetCublasHandle(context); + cudnnHandle_t cudnn = GetCudnnHandle(context); + auto& device_prop = GetDeviceProp(); + return onnxruntime::contrib::cuda::QkvToContext( + device_prop, cublas, cudnn, context->GetComputeStream(), cparameters, data); + +#endif +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/llm/attention.h b/onnxruntime/core/providers/cuda/llm/attention.h new file mode 100644 index 0000000000000..17e99bb935e1a --- /dev/null +++ b/onnxruntime/core/providers/cuda/llm/attention.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class Attention final : public CudaKernel { + public: + Attention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + bool is_causal_; + int kv_num_heads_; + int q_num_heads_; + attention_helper::QKMatMulOutputMode qk_matmul_output_mode_; + float scale_; + float softcap_; + int softmax_precision_; +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/llm/attention_naive.cc b/onnxruntime/core/providers/cuda/llm/attention_naive.cc new file mode 100644 index 0000000000000..474d70b0a2592 --- /dev/null +++ b/onnxruntime/core/providers/cuda/llm/attention_naive.cc @@ -0,0 +1,420 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. +*/ + +/* +Kernel implementation for attention. +*/ + +#include "core/providers/cuda/llm/attention_naive.h" +#include "core/providers/cuda/llm/attention_naive_impl.h" +#include "contrib_ops/cuda/bert/attention_softmax.h" + +using namespace onnxruntime::cuda; +using onnxruntime::attention_helper::AttentionParameters; +using onnxruntime::attention_helper::QKMatMulOutputMode; + +namespace onnxruntime { +namespace cuda { + +template +void ComputeAttentionProbs(cudaStream_t stream, + T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. 
Its size is BxNxLxH + const Tensor* mask_index, // mask + const AttentionParameters& parameters, // attention parameters + const T* past_key, // past key only (if not using past state) + T* present_key, // present key only (if not using present state) + T* output_qk // Q*K output + ) { + // The case past_key != nullptr and present_key == nullptr is not supported. + // We use the fact present_key is requested to avoid any extra allocation. + // However, if present_key is not requested, we should avoid allocated more memory than needed but that mean + // allocating one buffer per thread. That's why the implementation is not done. + // The user should define a model with a present_key even if not used if past_key is not null. + ORT_ENFORCE((past_key == nullptr) == (present_key == nullptr), + "The implementation only supports past_key and present_key both null or both not null."); + const size_t past_chunk_length = static_cast(parameters.past_sequence_length) * parameters.head_size; // P x H + const size_t q_input_chunk_length = static_cast(parameters.q_sequence_length) * parameters.head_size; // S x H + const size_t k_input_chunk_length = static_cast(parameters.kv_sequence_length) * parameters.head_size; // L x H + const size_t present_chunk_length = past_chunk_length + k_input_chunk_length; // T x H + + const ptrdiff_t probs_matrix_size = SafeInt(parameters.q_sequence_length) * + parameters.total_sequence_length; + const ptrdiff_t probs_matrix_bytes = probs_matrix_size * sizeof(T); + + // Prepare mask + // Merge causal mask with padding mask, and convert values from 0/1 to -inf/0. + int mask_batch_size = static_cast(mask_index == nullptr || mask_index->Shape().NumDimensions() < 4 + ? 1 + : mask_index->Shape().GetDims()[0]); + int mask_num_heads = static_cast(mask_index == nullptr || mask_index->Shape().NumDimensions() < 3 + ? 1 + : (mask_index->Shape().NumDimensions() < 4 + ? mask_index->Shape().GetDims()[0] + : mask_index->Shape().GetDims()[1])); + + T* mask_data = nullptr; + bool delete_mask_data = false; + bool causal = parameters.is_causal && parameters.q_sequence_length > 1; + if (mask_index == nullptr) { + // No mask = null mask. + if (causal) { + ORT_THROW("causal not implemented yet."); + } + } else if (mask_index->IsDataType() || causal) { + ORT_THROW("boolean mask not implemented yet."); + } else { + // Nothing to do, no necessary copy. + mask_data = const_cast(mask_index->Data()); + } + + bool transposed_k = parameters.transpose_output && nullptr == present_key; + if (nullptr != present_key && parameters.kv_num_heads != parameters.q_num_heads) { + ORT_THROW("past cache and kv_num_heads != q_num_heads not implemented yet."); + } + + // If present_key is not null, it is already initialized to zero. + // Main loop + // With 3D inputs, both Q and K are transposed with permutations (0, 2, 1, 3). + // To avoid expressing the transposition, we use GemmEx with different values for lda, ldb. + // If past_key is not null, then we need to concatenate it with K, the concatenation is not transposed. 
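+  // Illustrative sketch (hedged, not part of the kernel logic; the local names kv_head_for and
+  // row_stride_3d are hypothetical): how the GQA head mapping and the lda trick used by the GEMM
+  // calls below behave for 3D (B x S x N*H) inputs.
+  [[maybe_unused]] const auto kv_head_for = [](std::ptrdiff_t head_i, int kv_num_heads) {
+    // Mirrors ki = batch_i * kv_num_heads + head_i % kv_num_heads below; e.g. q_num_heads = 8 and
+    // kv_num_heads = 2 sends even query heads to KV head 0 and odd query heads to KV head 1.
+    return head_i % kv_num_heads;
+  };
+  [[maybe_unused]] const auto row_stride_3d = [](int num_heads, int head_size) {
+    // In the B x S x (N*H) layout consecutive rows of one head are N*H elements apart, so passing
+    // lda = num_heads * head_size lets the GEMM read the (0, 2, 1, 3) permutation without a copy.
+    return static_cast<std::ptrdiff_t>(num_heads) * head_size;
+  };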
+ const int loop_len = parameters.batch_size * parameters.q_num_heads; + const float alpha = parameters.scale; + int dtype = utils::GetONNXTensorElementDataType(); + + for (std::ptrdiff_t i = 0; i != loop_len; ++i) { + const ptrdiff_t output_offset = SafeInt(i) * probs_matrix_size; + std::ptrdiff_t batch_i = i / parameters.q_num_heads; + std::ptrdiff_t head_i = i % parameters.q_num_heads; + const ptrdiff_t mask_data_offset = probs_matrix_size * + (head_i % mask_num_heads + (batch_i % mask_batch_size) * mask_num_heads); + + T* output = attention_probs + output_offset; + T* out_qk = output_qk == nullptr ? nullptr : output_qk + output_offset; + float beta; + + if (mask_data != nullptr && + (out_qk == nullptr || parameters.qk_matmul_output_mode != attention_helper::QKMatMulOutputMode::kQK)) { + // Broadcast mask data: SxT -> SxT + beta = 1; + ORT_THROW("mask_data != nullptr and out_qk == nullptr or parameters.qk_matmul_output_mode != attention_helper::QKMatMulOutputMode::kQK not implemented yet."); + // memcpy(output, mask_data + mask_data_offset, probs_matrix_bytes); + } else { + beta = 0; + } + + // handling GQA + std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_i % parameters.kv_num_heads; + const T* k = K + k_input_chunk_length * ki; + + if (nullptr != present_key) { + if (parameters.kv_num_heads != parameters.q_num_heads) { + // Already done in a loop before this one. + k = present_key + ki * present_chunk_length; + } else { + ORT_THROW("past_key and present_key not implemented yet."); + /* + k = ConcatStateChunk(past_key, K, present_key, + past_chunk_length, k_input_chunk_length, present_chunk_length, + parameters.kv_num_heads, parameters.head_size, batch_i, head_i, + parameters.transpose_output); + */ + } + } + + // Compute Q*K' + AttentionMask + // original transposed each iteration + // A: Q (B x N x) S x H (B x N x) S x H S x H + // B: K' (B x N x) T x H (B x N x) H x T H x T + // C: attention_probs (B x N x) S x T (B x N x) S x T S x T + GemmMatMul( + stream, + beta != 0, // has_bias, + false, // has_scales, --> subtypes + dtype, // dtype_A, + dtype, // dtype_B, + dtype, // dtype_C, + dtype, // dtype_Y, + false, // trans_A, + true, // trans_B, + parameters.transpose_output ? Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size + : Q + q_input_chunk_length * i, // p_input_a + transposed_k + ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_i * parameters.head_size + : k, // p_input_b + output, // p_output_y + beta == 0 ? nullptr : output, // p_input_c + nullptr, // p_scale_a + nullptr, // p_scale_b + nullptr, // p_scale_y + parameters.q_sequence_length, // M + parameters.total_sequence_length, // N + parameters.head_size, // K + parameters.head_size * parameters.q_num_heads, // lda + transposed_k + ? 
parameters.head_size * parameters.kv_num_heads + : parameters.head_size, // ldb + parameters.total_sequence_length, // ldc + true, // row_major_compute + -1, // sm_count + CUBLASLT_EPILOGUE_DEFAULT, + alpha, + beta); + + if (out_qk != nullptr && + (parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQKMask || + parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQK)) { + // ORT_THROW("out_qk != nullptr and parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQKMask or kQK not implemented yet."); + // memcpy(out_qk, output, SafeInt(probs_matrix_size) * sizeof(T)); + if (mask_data != nullptr && parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQK) { + // We need to add the bias we could not add because out_qk was requested without the mask. + // This can be optimized with vectorized add using MlasAddFloat32x4. + ORT_THROW("mask_data != nullptr and parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQK not implemented yet."); + // MlasEltwiseAdd(output, mask_data + mask_data_offset, output, probs_matrix_size); + } else { + ORT_THROW("out_qk != nullptr and parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQKMask or kQK not implemented yet."); + } + } + if (parameters.softcap > 0.0f) { + ORT_THROW("parameters.softcap > 0.0f not implemented yet."); + // ComputeAttentionSoftcapInplace(output, static_cast(probs_matrix_size), parameters.softcap); + } + if (out_qk != nullptr && parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQKSoftCap) { + ORT_THROW("out_qk != nullptr && parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQKSoftCap not implemented yet."); + // memcpy(out_qk, output, SafeInt(probs_matrix_size) * sizeof(T)); + } + typedef typename ToCudaType::MappedType CudaT; + + // ComputeAttentionSoftmaxInplace(output, parameters.q_sequence_length, parameters.total_sequence_length, nullptr, allocator); + onnxruntime::contrib::attention_softmax_cuda::ComputeSoftmax( + stream, + parameters.total_sequence_length, + parameters.q_sequence_length, + 1, // parameters.batch_size, + 1, // num_heads, + reinterpret_cast(nullptr), // data.attention_bias, + false, // broadcast_attn_bias_dim_0, + false, // broadcast_attn_bias_dim_1, + reinterpret_cast(output), // input + reinterpret_cast(output), // output + false); // causal + + if (output_qk != nullptr && parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQKSoftMax) { + ORT_THROW("output_qk != nullptr and parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQKSoftMax not implemented yet."); + // memcpy(output_qk + output_offset, output, + // SafeInt(parameters.q_sequence_length) * parameters.total_sequence_length * sizeof(T)); + } + } + + if (delete_mask_data) { + // allocator->Free(mask_data); + ORT_THROW("delete_mask_data not implemented yet."); + } +} + +template +void ComputeVxAttentionScore(cudaStream_t stream, + T* output, // buffer for the result with size BxSxNxH_v + const T* attention_probs, // Attention probs with size BxNxSxT + const T* V, // V value with size BxNxLxH_v + int batch_size, // batch size + int sequence_length, // sequence length + int kv_sequence_length, // sequence length of K or V + int past_sequence_length, // sequence length in past state + int total_sequence_length, // total sequence length = past_sequence_length + kv_sequence_length + int v_head_size, // head size of V (H_v) + int num_heads, // 
number of attention heads + int kv_num_heads, // number of KV heads + const T* past_value, // past value only (if not using past state) + T* present_value, // present value only (if not using present state) + bool transpose_output) { // whether to transpose the output (0, 2, 1, 3) + ORT_ENFORCE((past_value == nullptr) == (present_value == nullptr), + "The implementation only supports past_value and present_value both null or both not null."); + const ptrdiff_t past_chunk_length = SafeInt(past_sequence_length) * v_head_size; // P x H_v + const ptrdiff_t v_input_chunk_length = SafeInt(kv_sequence_length) * v_head_size; // L x H_v + const ptrdiff_t present_chunk_length = past_chunk_length + v_input_chunk_length; // T x H_v + + // The cost of Gemm + TensorOpCost unit_cost; + unit_cost.compute_cycles = + static_cast(SafeInt(2) * sequence_length * v_head_size * total_sequence_length); + unit_cost.bytes_loaded = + static_cast(SafeInt(sequence_length + v_head_size) * total_sequence_length * sizeof(T)); + unit_cost.bytes_stored = static_cast(sequence_length * v_head_size * sizeof(T)); + + const size_t bytes_to_copy_trans = SafeInt(v_head_size) * sizeof(T); + double bytes_to_copy_trans_all = static_cast(sequence_length * bytes_to_copy_trans); + unit_cost.bytes_loaded += bytes_to_copy_trans_all; + unit_cost.bytes_stored += bytes_to_copy_trans_all; + + bool transposed_v = transpose_output && nullptr == present_value; + if (nullptr != present_value && kv_num_heads != num_heads) { + ORT_THROW("past cache and kv_num_heads != q_num_heads not implemented yet."); + } + int dtype = utils::GetONNXTensorElementDataType(); + + for (std::ptrdiff_t i = 0; i != batch_size * num_heads; ++i) { + // handling GQA + std::ptrdiff_t batch_i = i / num_heads; + std::ptrdiff_t head_i = i % num_heads; + std::ptrdiff_t vi = batch_i * kv_num_heads + head_i % kv_num_heads; + const T* v = V + v_input_chunk_length * vi; + + if (nullptr != present_value) { + if (kv_num_heads != num_heads) { + // Already done in a loop before this one. + v = present_value + vi * present_chunk_length; + } else { + // transposed_v is false here. + ORT_THROW("past_value and present_value not implemented yet."); + } + } + + if (transpose_output) { + // transpose_output is false + ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * total_sequence_length * i; + + // V is transposed but not QK. We use a different value for ldb. + GemmMatMul( + stream, + false, // has_bias, + false, // has_scales, --> subtypes + dtype, // dtype_A, + dtype, // dtype_B, + dtype, // dtype_C, + dtype, // dtype_Y, + false, // trans_A, + false, // trans_B, + attention_probs + attention_probs_offset, // QK = p_input_a + transposed_v ? V + head_i * v_head_size + v_input_chunk_length * kv_num_heads * batch_i + : v, // V =p_input_b + output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size), // p_output_y + nullptr, // p_input_c + nullptr, // p_scale_a + nullptr, // p_scale_b + nullptr, // p_scale_y + sequence_length, // M + v_head_size, // N + total_sequence_length, // K + total_sequence_length, // lda + transposed_v ? 
v_head_size * kv_num_heads : v_head_size, // ldb + v_head_size * num_heads, // ldc + true, // row_major_compute + -1, // sm_count + CUBLASLT_EPILOGUE_DEFAULT, + 1.f, + 0.f); + } else { + // transpose_output is false + ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * total_sequence_length * i; + ptrdiff_t dest_offset = SafeInt(sequence_length) * v_head_size * i; + T* dest = output + dest_offset; + + GemmMatMul( + stream, + false, // has_bias, + false, // has_scales, --> subtypes + dtype, // dtype_A, + dtype, // dtype_B, + dtype, // dtype_C, + dtype, // dtype_Y, + false, // trans_A, + false, // trans_B, + attention_probs + attention_probs_offset, // QK = p_input_a + v, // V =p_input_b + dest, // p_output_y + nullptr, // p_input_c + nullptr, // p_scale_a + nullptr, // p_scale_b + nullptr, // p_scale_y + sequence_length, // M + v_head_size, // N + total_sequence_length, // K + total_sequence_length, // lda + v_head_size, // ldb + v_head_size, // ldc + true, // row_major_compute + -1, // sm_count + CUBLASLT_EPILOGUE_DEFAULT, + 1.0f, + 0.f); + } + } +} + +template +Status NaiveAttention::ApplyAttention(OpKernelContext* context, + cudaStream_t stream, + const T* Q, // Q data with shape BxNxSxH + const T* K, // K data with shape BxNxLxH + const T* V, // V value with size BxNxLxH_v + const Tensor* mask_index, // mask index. nullptr if no mask or its size is B + const Tensor* past_key, // past K input tensor (if not using past state) + const Tensor* past_value, // past V input tensor (if not using past state) + Tensor* output, // output tensor + Tensor* present_key, // present K output tensor (if separating present KV) + Tensor* present_value, // present V output tensor (if separating present KV) + Tensor* output_qk, // Q*K output tensor (if returning Q*K value) + const AttentionParameters& parameters // attention parameters +) const { + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + const T* past_key_data = past_key != nullptr ? past_key->Data() : nullptr; + T* present_key_data = present_key != nullptr ? present_key->MutableData() : nullptr; + const T* past_value_data = past_value != nullptr ? past_value->Data() : nullptr; + T* present_value_data = present_value != nullptr ? present_value->MutableData() : nullptr; + T* output_qk_data = output_qk != nullptr ? output_qk->MutableData() : nullptr; + + // Compute the attention score. 
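+  // Rough size of the scratch buffer allocated below (illustrative arithmetic): it holds the
+  // B x N x S x T attention probabilities, e.g. batch_size = 1, q_num_heads = 8,
+  // q_sequence_length = 128, total_sequence_length = 512 in float32 gives
+  // 1 * 8 * 128 * 512 * 4 bytes = 2 MiB.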
+ size_t bytes = SafeInt(parameters.batch_size) * parameters.q_num_heads * + parameters.q_sequence_length * parameters.total_sequence_length * sizeof(T); + auto attention_probs = allocator->Alloc(bytes); + BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); + // cudaStream_t stream = kernel->Stream(context); + ComputeAttentionProbs(stream, + static_cast(attention_probs), + Q, + K, + mask_index, + parameters, + past_key_data, + present_key_data, + output_qk_data); + + ComputeVxAttentionScore(stream, output->MutableData(), + static_cast(attention_probs), + V, + parameters.batch_size, + parameters.q_sequence_length, + parameters.kv_sequence_length, + parameters.past_sequence_length, + parameters.total_sequence_length, + parameters.v_head_size, + parameters.q_num_heads, + parameters.kv_num_heads, + past_value_data, + present_value_data, + parameters.transpose_output); + return Status::OK(); +} + +#define IMPLEMENT(T) \ + template class NaiveAttention; \ + template void ComputeAttentionProbs(cudaStream_t stream, T * attention_probs, const T* Q, const T* K, const Tensor* mask_index, \ + const AttentionParameters& parameters, const T* past_key, T* present_key, \ + T* output_qk); \ + template void ComputeVxAttentionScore(cudaStream_t stream, T * output, const T* attention_probs, const T* V, int batch_size, \ + int sequence_length, int kv_sequence_length, int past_sequence_length, int total_sequence_length, \ + int v_head_size, int num_heads, int kv_num_heads, const T* past_value, \ + T* present_value, bool transpose_output); + +IMPLEMENT(float); +IMPLEMENT(MLFloat16); +// IMPLEMENT(BFloat16); + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/llm/attention_naive.h b/onnxruntime/core/providers/cuda/llm/attention_naive.h new file mode 100644 index 0000000000000..b56cfaaa2f1ee --- /dev/null +++ b/onnxruntime/core/providers/cuda/llm/attention_naive.h @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cpu/llm/attention_helper.h" + +namespace onnxruntime { +namespace cuda { + +/** + * Follows the same algorithm as AttentionBase (CPU). + */ +template +class NaiveAttention { + public: + Status ApplyAttention(OpKernelContext* context, // OpKernelContext + cudaStream_t stream, // CUDA stream + const T* Q, // Q data with shape BxNxSxH + const T* K, // K data with shape BxNxLxH + const T* V, // V value with size BxNxLxH_v + const Tensor* mask_index, // mask index. 
nullptr if no mask or its size is B + const Tensor* past_key, // past K input tensor (if not using past state) + const Tensor* past_value, // past V input tensor (if not using past state) + Tensor* output, // output tensor + Tensor* present_key, // present K output tensor (if separating present KV) + Tensor* present_value, // present V output tensor (if separating present KV) + Tensor* output_qk, // Q*K output tensor (if returning Q*K value) + const attention_helper::AttentionParameters& parameters // attention parameters + ) const; + + // protected: + /* + void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH_v + const T* attention_probs, // Attention probs with size BxNxSxT + const T* V, // V value with size BxNxLxH_v + int batch_size, // batch size + int sequence_length, // sequence length + int kv_sequence_length, // sequence length of K or V + int past_sequence_length, // sequence length in past state + int total_sequence_length, // total sequence length = past_sequence_length + kv_sequence_length + int v_head_size, // head size of V (H_v) + int num_heads, // number of attention heads + int kv_num_heads, // number of KV heads + const T* past_value, // past value only (if not using past state) + T* present_value, // present value only (if not using present state) + bool transpose_output // whether to transpose the output from BxNxSxH to BxSxNxH + ) const; + */ + /* + void ComputeAttentionProbs(cudaStream_t stream, + T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxLxH + const Tensor* mask_index, // mask_index + const attention_helper::AttentionParameters& parameters, // attention parameters + const T* past_key, // past key only (if not using past state) + T* present_key, // present key only (if not using present state) + T* output_qk, // Q*K output + AllocatorPtr allocator) const; + */ + + /* +T* ConcatStateChunk(const T* past, + const T* chunk, + T* present, + size_t past_chunk_length, + size_t input_chunk_length, + size_t present_chunk_length, + size_t num_heads, + size_t head_size, + std::ptrdiff_t batch_i, + std::ptrdiff_t head_i, + bool transposed) const; + */ +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/llm/attention_naive_impl.cu b/onnxruntime/core/providers/cuda/llm/attention_naive_impl.cu new file mode 100644 index 0000000000000..af8962d0ac4f5 --- /dev/null +++ b/onnxruntime/core/providers/cuda/llm/attention_naive_impl.cu @@ -0,0 +1,258 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. +*/ + +#include "core/providers/cuda/llm/attention_naive_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include + +using namespace onnxruntime::cuda; +using onnxruntime::concurrency::ThreadPool; + +namespace onnxruntime { +namespace cuda { + +Status GemmMatMul( + cudaStream_t stream, bool has_bias, bool has_scales, + int32_t dtype_A, int32_t dtype_B, + int32_t dtype_C, int32_t dtype_Y, + bool trans_A, bool trans_B, const void* p_input_a, const void* p_input_b, + const void* p_input_c, const void* p_scale_a, const void* p_scale_b, + const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda, + int ldb, int ldd, bool row_major_compute, int64_t sm_count, int iepilogue, + float alpha, float beta) { + // TODO: Synchronization should be moved outside of this function. + // TODO: The function should be split in two parts: create descriptors and run cublasLtMatmul. 
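+  // Hedged usage sketch for this helper (A, B, Y, M, N, K are placeholders, not taken from the
+  // callers): a plain row-major float GEMM Y[M, N] = A[M, K] * B[K, N] with no bias and no scales
+  // would be invoked as
+  //   GemmMatMul(stream, /*has_bias*/ false, /*has_scales*/ false,
+  //              ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
+  //              ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
+  //              /*trans_A*/ false, /*trans_B*/ false, A, B, /*p_input_c*/ nullptr,
+  //              /*p_scale_a*/ nullptr, /*p_scale_b*/ nullptr, /*p_scale_y*/ nullptr, Y,
+  //              M, N, K, /*lda*/ K, /*ldb*/ N, /*ldd*/ N, /*row_major_compute*/ true,
+  //              /*sm_count*/ -1, CUBLASLT_EPILOGUE_DEFAULT, /*alpha*/ 1.0f, /*beta*/ 0.0f);
+  // The calls in attention_naive.cc instead pass per-head leading dimensions so that one head of a
+  // strided Q/K/V block can be read in place.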
+ cublasLtEpilogue_t epilogue = static_cast(iepilogue); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + + cublasLtHandle_t cublasLt; + CUBLAS_RETURN_IF_ERROR(cublasLtCreate(&cublasLt)); + + cublasLtMatmulDesc_t operationDesc = nullptr; + cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, + Ddesc = nullptr; + + // Create matrix descriptors. Not setting any extra attributes. + cudaDataType_t a_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_A); + cudaDataType_t b_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_B); + cudaDataType_t d_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_Y); + cudaDataType_t scale_cuda_type = + onnxruntime::cuda::ToCudaDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + cudaDataType_t bias_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_C); + + cublasComputeType_t compute_type; + switch (d_cuda_type) { + case CUDA_R_16F: + switch (a_cuda_type) { +#if !defined(DISABLE_FLOAT8_TYPES) +#if CUDA_VERSION < 11080 +#error CUDA_R_8F_E4M3 (float 8 types) is defined with CUDA>=11.8. Set flag DISABLE_FLOAT8_TYPES. +#endif + case CUDA_R_8F_E4M3: + case CUDA_R_8F_E5M2: + compute_type = CUBLAS_COMPUTE_32F; + break; +#endif + default: + compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + break; + } + break; + case CUDA_R_16BF: + compute_type = CUBLAS_COMPUTE_32F_FAST_16BF; + break; + case CUDA_R_32F: + compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + break; + default: + ORT_THROW("Unable to determine computeType in operator GemmFloat8."); + } + + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutCreate( + &Adesc, a_cuda_type, trans_A ? K : M, trans_A ? M : K, lda)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutCreate( + &Bdesc, b_cuda_type, trans_B ? N : K, trans_B ? K : N, ldb)); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Ddesc, d_cuda_type, M, N, ldd)); + + if (row_major_compute) { + cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW; + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + } + + CUBLAS_RETURN_IF_ERROR( + cublasLtMatmulDescCreate(&operationDesc, compute_type, scale_cuda_type)); + cublasOperation_t ctransa = trans_A ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t ctransb = trans_B ? 
CUBLAS_OP_T : CUBLAS_OP_N; + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &ctransa, sizeof(ctransa))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &ctransb, sizeof(ctransb))); + +#if CUDA_VERSION >= 11060 + // CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET exists from https://docs.nvidia.com/cuda/archive/11.6.0/pdf/CUBLAS_Library.pdf + if (sm_count != 0) { + int math_sm_count = static_cast(sm_count); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count, + sizeof(math_sm_count))); + } +#endif + + if (has_scales) { + // gemm float 8 +#if CUDA_VERSION >= 11080 + // CUBLASLT_MATMUL_DESC_FAST_ACCUM, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + // CUBLASLT_MATMUL_DESC_D_SCALE_POINTER exist from https://docs.nvidia.com/cuda/archive/11.8.0/pdf/CUBLAS_Library.pdf + const int8_t ifast_accumulation_mode = 1; + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, + cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_FAST_ACCUM, + &ifast_accumulation_mode, sizeof(ifast_accumulation_mode))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &p_scale_a, + sizeof(p_scale_a))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &p_scale_b, + sizeof(p_scale_b))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &p_scale_y, + sizeof(p_scale_b))); +#endif + + // float 8 +#if !defined(DISABLE_FLOAT8_TYPES) + if (dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN || + dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2) { + // For FP8 output, cuBLAS requires C_type to be same as bias_type + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, bias_cuda_type, M, N, ldd)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_cuda_type, + sizeof(bias_cuda_type))); + } else { + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); + } +#else + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); +#endif + } else { + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); + } + + if (row_major_compute) { + cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW; + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Ddesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + } + + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue, sizeof(epilogue)); + + // See + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmulPreferenceAttributes_t#cublasltmatmulpreferenceattributes-t + // The workspace should be allocated once from OpKernelContext assuming + // only one cuda function is running at a time (which is not necessarily true + // with H100). 
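// A possible follow-up to the comment above (a sketch only, not part of this change): the scratch
// memory could come from the kernel's CUDA arena instead of a cudaMalloc/cudaFree pair per call,
// e.g. something along the lines of
//   auto workspace = cuda_kernel->GetScratchBuffer<void>(workspaceSize, context->GetComputeStream());
// where cuda_kernel / context stand for the owning CudaKernel and OpKernelContext. That would tie
// the buffer lifetime to the allocator and avoid a device allocation on every call. The fixed
// 32 MiB below is only an upper bound handed to the cublasLt heuristic, not a required size.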
+ size_t workspaceSize = static_cast(1 << 25); // suggested fixed value 32Mb + cublasLtMatmulPreference_t preference = nullptr; + cublasLtMatmulPreferenceCreate(&preference); + cublasLtMatmulPreferenceSetAttribute(preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspaceSize, sizeof(workspaceSize)); + + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmulAlgoGetHeuristic#cublasltmatmulalgogetheuristic + cublasLtMatmulHeuristicResult_t heuristicResult = {}; + int returnedResults = 0; + cublasStatus_t cuda_status = cublasLtMatmulAlgoGetHeuristic( + cublasLt, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1, + &heuristicResult, &returnedResults); + ORT_ENFORCE( + returnedResults > 0 && cuda_status == CUBLAS_STATUS_SUCCESS, + " Unable to find any suitable algorithm due to ", + onnxruntime::cuda::cublasGetErrorEnum(cuda_status), + ", returnedResults=", returnedResults, + ", alpha=", alpha, ", beta=", beta, + ", A_type=", onnxruntime::cuda::CudaDataTypeToString(a_cuda_type), + ", B_type=", onnxruntime::cuda::CudaDataTypeToString(b_cuda_type), + ", C_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", result_type=", onnxruntime::cuda::CudaDataTypeToString(d_cuda_type), + ", bias_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", scale_type=", onnxruntime::cuda::CudaDataTypeToString(scale_cuda_type), + ", computeType=", onnxruntime::cuda::CublasComputeTypeToString(compute_type), + ", epilogue=", epilogue, ", smCount=", sm_count, ", transA=", trans_A, + ", transB=", trans_B, + ", fastAccumulationMode=", 1, + ", M=", M, ", N=", N, ", K=", K, + ", lda=", lda, ", ldb=", ldb, ", ldd=", ldd, + ", workspaceSize=", workspaceSize, ", rowMajorCompute=", (row_major_compute ? 1 : 0), + ". Check NVIDIA documentation to see what combination is valid: ", + "https://docs.nvidia.com/cuda/cublas/" + "index.html?highlight=cublasLtMatmulAlgoGetHeuristic#" + "cublasltmatmulalgogetheuristic. CUDA>=11.8 is required to use float 8 types."); + + void* workspace = nullptr; + if (workspaceSize > 0) { + CUDA_RETURN_IF_ERROR(cudaMalloc(reinterpret_cast(&workspace), workspaceSize)); + } + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmul#cublasltmatmul + const void* bias = has_bias ? 
p_input_c : p_output_y; + cuda_status = cublasLtMatmul( + cublasLt, operationDesc, static_cast(&alpha), /* alpha */ + p_input_a, /* A */ + Adesc, p_input_b, /* B */ + Bdesc, static_cast(&beta), /* beta */ + bias, /* C */ + Cdesc, p_output_y, /* Y */ + Ddesc, &heuristicResult.algo, /* algo */ + workspace, /* workspace */ + workspaceSize, stream); /* stream */ + ORT_ENFORCE( + cuda_status == CUBLAS_STATUS_SUCCESS, + " Unable to run cublasLtMatmul due to ", + onnxruntime::cuda::cublasGetErrorEnum(cuda_status), + ", returnedResults=", returnedResults, ", alpha=", alpha, + ", A_type=", onnxruntime::cuda::CudaDataTypeToString(a_cuda_type), + ", B_type=", onnxruntime::cuda::CudaDataTypeToString(b_cuda_type), + ", result_type=", onnxruntime::cuda::CudaDataTypeToString(d_cuda_type), + ", bias_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", scale_type=", onnxruntime::cuda::CudaDataTypeToString(scale_cuda_type), + ", computeType=", onnxruntime::cuda::CublasComputeTypeToString(compute_type), + ", epilogue=", epilogue, ", smCount=", sm_count, ", transA=", trans_A, + ", transB=", trans_B, + ", fastAccumulationMode=", 1, + " M=", M, " N=", N, ", K=", K, ", lda=", lda, ", ldb=", + ldb, ", ldd=", ldd, ", workspaceSize=", workspaceSize, + ", rowMajorCompute=", (row_major_compute ? 1 : 0), + ". CUDA>=11.8 is required to use float 8 types."); + + if (workspaceSize > 0) { + CUDA_RETURN_IF_ERROR(cudaFree(workspace)); + } + + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulPreferenceDestroy(preference)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Ddesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Cdesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Bdesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Adesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescDestroy(operationDesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtDestroy(cublasLt)); + return Status::OK(); +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/llm/attention_naive_impl.h b/onnxruntime/core/providers/cuda/llm/attention_naive_impl.h new file mode 100644 index 0000000000000..b46e2208f76e3 --- /dev/null +++ b/onnxruntime/core/providers/cuda/llm/attention_naive_impl.h @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace cuda { + +Status GemmMatMul( + cudaStream_t stream, bool has_bias, bool has_scales, + int32_t dtype_A, int32_t dtype_B, + int32_t dtype_C, int32_t dtype_Y, + bool trans_A, bool trans_B, const void* p_input_a, const void* p_input_b, + const void* p_input_c, const void* p_scale_a, const void* p_scale_b, + const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda, + int ldb, int ldd, bool row_major_compute, int64_t sm_count, int /* cublasLtEpilogue_t */ epilogue, + float alpha, float beta); + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc index ec6fa48681908..f71ca2b6ba131 100644 --- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -363,7 +363,7 @@ TEST(AttentionTest, Attention3DDefault) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -390,7 +390,7 @@ TEST(AttentionTest, Attention3DDefaultFloat16) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -417,7 +417,7 @@ TEST(AttentionTest, Attention4DDefaultBasic) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -444,7 +444,7 @@ TEST(AttentionTest, Attention4DDefault) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -471,7 +471,7 @@ TEST(AttentionTest, Attention4DDefaultFloat16) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -499,7 +499,7 @@ 
TEST(AttentionTest, Attention4DSoftCap) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), 2.0f, -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type ys, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -527,7 +527,7 @@ TEST(AttentionTest, Attention4DSoftCapFloat16) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), 2.0f, -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type ys, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -557,7 +557,7 @@ TEST(AttentionTest, Attention4DAttnMask) { q, k, v, m, std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -587,7 +587,7 @@ TEST(AttentionTest, Attention4DAttnMaskBool) { q, k, v, std::vector(), m, std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -625,7 +625,7 @@ TEST(AttentionTest, Attention4DAttnPastPresentBasic) { q, k, v, m, std::initializer_list(), past_key, past_value, -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -663,7 +663,7 @@ TEST(AttentionTest, Attention4DAttnPastPresent) { q, k, v, m, std::initializer_list(), past_key, past_value, -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -696,7 +696,7 @@ TEST(AttentionTest, Attention4DAttnIsCausal) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -724,7 +724,7 @@ TEST(AttentionTest, Attention4DAttnIsCausalBasic) { q, k, v, std::vector(), std::initializer_list(), std::vector(), 
std::vector(), 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -752,7 +752,7 @@ TEST(AttentionTest, Attention4DAttnIsCausalBasicFloat16) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -780,7 +780,7 @@ TEST(AttentionTest, Attention4DAttnIsCausalBasicDifferentSequenceLength) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -827,7 +827,7 @@ TEST(AttentionTest, Attention4DDiffHeadsWithPastAndPresent) { q, k, v, m, std::initializer_list(), past_key, past_value, -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -894,7 +894,7 @@ TEST(AttentionTest, Attention4DGqaAttnMask) { q, k, v, m, std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, std::vector(), std::vector(), std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -941,7 +941,7 @@ TEST(AttentionTest, Attention4DGqaWithPastAndPresent) { q, k, v, m, std::initializer_list(), past_key, past_value, -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, std::vector(), - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -992,13 +992,13 @@ TEST(AttentionTest, Attention4DWithPastAndPresentQkMatmul) { q, k, v, m, std::initializer_list(), past_key, past_value, -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, qk_matmul, - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, 
q, k, v, m, std::initializer_list(), past_key, past_value, -1, 0, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, qk_matmul, - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); qk_matmul = std::vector{1.786287f, 1.851782f, 1.433406f, 1.126638f, 1.074598f, 1.202869f, 1.806932f, 1.039214f, 1.155254f, 1.351381f, 1.709788f, 1.654608f, 0.904174f, 1.045790f, 1.828289f, 1.849986f, 0.982722f, 0.779313f, 1.067731f, 0.932425f, 1.164846f, 0.896809f, 1.215540f, 1.155709f, 1.283348f, 0.972161f, 1.592545f, 1.841960f, 1.391534f, 0.932551f, 0.884336f, 0.881353f, 0.905360f, 1.564150f, 1.275840f, 0.946826f, 1.789871f, 1.878873f, 1.971947f, 1.398552f, 1.823965f, 1.960587f, 1.438784f, 1.481077f, 0.957099f, 1.756017f, 1.234584f, 0.990787f, 1.096593f, 1.033003f, 1.868677f, 1.788607f, 1.659495f, 0.667182f, 1.157819f, 0.870338f, 0.879745f, 1.636864f, 0.894962f, 1.714711f, 1.549994f, 0.733612f, 1.117046f, 0.686474f, 1.499953f, 1.123992f, 1.438267f, 0.931251f, 1.633272f, 0.944889f, 0.987120f, 1.218472f, 1.497553f, 1.638913f, 1.553980f, 0.982279f, 1.142558f, 1.193196f, 1.654746f, 1.014832f, 1.090946f, 1.017206f, 1.702928f, 1.601417f, 0.808653f, 1.406642f, 1.423106f, 1.871002f, 1.358196f, 0.931623f, 0.588504f, 0.783458f, 0.882957f, 0.489307f, 1.322660f, 0.934557f, 1.271919f, 0.800610f, 1.444240f, 1.450752f, 0.946420f, 0.900686f, 0.822093f, 1.113904f, 0.568116f, 1.171030f, 1.175384f, 0.910323f, 1.157407f, 1.345392f, 1.400021f, 0.751548f, 1.625352f, 1.456414f, 0.950937f, 1.145433f, 0.649070f, 1.298100f, 0.639947f, 0.927273f, 0.736265f, 1.065406f, 1.263197f, 1.012355f, 1.297169f, 0.495477f, 0.699773f, 0.500964f, 0.620178f, 1.275150f, 0.760687f, 1.387608f, 1.336798f, 0.539168f, 1.042187f, 0.417132f, 1.257103f, 1.163759f, 1.314552f, 0.982448f, 1.345221f, 0.663667f, 0.850426f, 1.238248f, 1.593812f, 1.438230f, 1.387601f, 0.823150f, 0.726727f, 0.832655f, 1.532544f, 0.946970f, 1.126112f, 1.112509f, 1.565497f, 1.938642f, 0.832394f, 1.284816f, 1.447452f, 1.599816f, 0.609072f, 0.743433f, 1.101475f, 0.490747f, 1.020954f, 0.668047f, 0.921248f, 0.721382f, 1.095978f, 0.794792f, 1.488673f, 1.681718f, 0.852196f, 1.102478f, 0.810369f, 1.130985f, 0.425544f, 1.051735f, 0.694759f, 0.764302f, 1.275671f, 1.157903f, 1.440112f, 0.837447f, 1.422500f, 1.150930f, 1.017296f, 1.116673f, 0.804505f, 1.315179f, 0.553615f, 0.871008f, 0.659033f, 1.116166f, 1.134977f, 0.944172f, 0.857236f, 0.531893f, 1.224364f, 0.670808f, 0.843351f, 1.607988f, 0.720031f, 1.438111f, 1.628858f, 0.904480f, 1.456536f, 0.828884f, 1.145072f, 1.586629f, 1.350379f, 1.396510f, 1.226688f, 0.524469f, 0.711242f, 1.413283f, 1.519931f, 1.444998f, 1.155023f, 0.928222f, 0.827857f, 1.092185f, 1.860113f, 1.373539f, 0.953664f, 1.435734f, 1.350082f, 1.735783f, 0.610580f, 1.155694f, 1.600251f, 1.602529f, 0.859450f, 1.156073f, 0.846617f, 0.916578f, 1.134056f, 1.053106f, 1.173786f, 1.246788f, 1.509772f, 1.256221f, 1.540197f, 2.009806f, 1.067828f, 1.164871f, 0.709226f, 1.221456f, 0.845411f, 1.504512f, 1.201048f, 1.402731f, 1.564370f, 1.576583f, 1.589067f, 1.257597f, 1.674126f, 1.954917f, 1.497631f, 1.948780f, 0.954539f, 2.070836f, 0.927942f, 1.418681f, 0.804113f, 1.388198f, 1.624642f, 1.581236f, 1.511648f, 1.311894f, 0.855986f, 0.902148f, 0.785342f, 1.820220f, 0.852723f, 1.696361f, 1.655653f, 1.089764f, 1.202390f, 1.120222f, 1.284748f, 1.475221f, 1.311156f, 
1.243736f, 1.625873f, 0.823371f, 1.226631f, 1.673096f, 1.553962f, 1.025746f, 1.313852f, 1.030482f, 0.989448f, 0.936074f, 1.784927f, 0.708855f, 0.971949f, 1.223065f, 1.461189f, 1.747723f, 0.799575f, 0.823636f, 1.400882f, 1.160547f, 0.520804f, 0.836825f, 0.972166f, 0.543222f, 1.346498f, 1.034594f, 1.565712f, 1.361961f, 1.751214f, 0.736224f, 1.864534f, 1.977835f, 1.411005f, 1.496084f, 1.233789f, 1.105877f, 0.961602f, 1.009357f, 1.110593f, 1.390279f, 1.693497f, 1.302893f, 1.756735f, 1.433344f, 2.067142f, 1.916540f, 1.490259f, 1.488384f, 1.309675f, 1.758509f, 1.141796f, 1.534330f, 1.156855f, 1.274409f, 1.870354f, 1.045789f, 1.400564f, 0.876651f, 0.981051f, 0.559955f, 0.790979f, 1.662600f, 1.021407f, 1.716358f, 1.630805f, 0.674263f, 1.320767f, 0.649261f, 1.538417f, 1.525061f, 1.419455f, 1.148088f, 1.820221f, 0.329244f, 1.033743f, 1.253892f, 1.790469f, 1.711897f, 1.467268f, 1.089224f, 0.834806f, 1.155425f, 2.043234f, 0.849033f, 1.136683f, 1.774663f, 1.735976f, 1.677263f, 0.902375f, 1.213391f, 1.758179f, 1.759598f, 0.879983f, 1.517559f, 0.812989f, 0.499876f, 0.998129f, 0.513259f, 1.094689f, 0.873050f, 1.131224f, 0.546321f, 1.364307f, 1.622263f, 0.652555f, 0.680481f, 0.729973f, 1.123450f, 0.722337f, 1.158875f, 0.845219f, 1.151906f, 1.343835f, 1.411206f, 1.638837f, 1.000100f, 1.652081f, 1.598655f, 0.980791f, 1.122207f, 0.848703f, 1.972988f, 0.610630f, 0.678227f, 0.839634f, 1.289163f, 1.497003f, 1.060701f, 0.971334f, 1.099509f, 1.158767f, 0.871929f, 0.972856f, 1.687900f, 0.854091f, 1.804623f, 1.804263f, 0.738135f, 1.209199f, 1.190654f, 1.425313f, 1.450061f, 1.529269f, 1.249452f, 1.921674f, 0.832500f, 0.940835f, 1.908224f}; @@ -1006,13 +1006,13 @@ TEST(AttentionTest, Attention4DWithPastAndPresentQkMatmul) { q, k, v, m, std::initializer_list(), past_key, past_value, -1, 1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, qk_matmul, - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, q, k, v, m, std::initializer_list(), past_key, past_value, -1, 2, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, qk_matmul, - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); qk_matmul = std::vector{0.079204f, 0.084565f, 0.055653f, 0.040951f, 0.038874f, 0.044195f, 0.080856f, 0.037523f, 0.042140f, 0.051271f, 0.073371f, 0.069432f, 0.032783f, 0.037770f, 0.082601f, 0.084413f, 0.035462f, 0.028935f, 0.048528f, 0.042386f, 0.053477f, 0.040903f, 0.056258f, 0.052990f, 0.060205f, 0.044104f, 0.082018f, 0.105252f, 0.067083f, 0.042392f, 0.040396f, 0.040276f, 0.041254f, 0.079722f, 0.059754f, 0.043001f, 0.069900f, 0.076406f, 0.083859f, 0.047264f, 0.072324f, 0.082912f, 0.049204f, 0.051330f, 0.030395f, 0.067573f, 0.040116f, 0.031437f, 0.034945f, 0.032792f, 0.075631f, 0.069811f, 0.061356f, 0.022746f, 0.052157f, 0.039125f, 0.039495f, 0.084209f, 0.040101f, 0.091026f, 0.077202f, 0.034126f, 0.050073f, 0.032554f, 0.073434f, 0.050422f, 0.069041f, 0.041583f, 0.083907f, 0.042154f, 0.043972f, 0.055418f, 0.062936f, 0.072492f, 0.066589f, 0.037594f, 0.044129f, 
0.046421f, 0.073649f, 0.038838f, 0.041909f, 0.038930f, 0.077284f, 0.069824f, 0.031602f, 0.057467f, 0.058421f, 0.091429f, 0.054749f, 0.035737f, 0.036234f, 0.044034f, 0.048640f, 0.032812f, 0.075502f, 0.051216f, 0.071766f, 0.044795f, 0.085263f, 0.085820f, 0.051827f, 0.049510f, 0.045768f, 0.061277f, 0.035503f, 0.064879f, 0.065162f, 0.049990f, 0.057976f, 0.069967f, 0.073895f, 0.038636f, 0.092571f, 0.078182f, 0.047161f, 0.057286f, 0.034872f, 0.066735f, 0.034556f, 0.046058f, 0.038050f, 0.052880f, 0.064446f, 0.050148f, 0.066673f, 0.029907f, 0.040424f, 0.033136f, 0.037332f, 0.071867f, 0.042963f, 0.080421f, 0.076436f, 0.034427f, 0.056931f, 0.030472f, 0.070581f, 0.064291f, 0.074755f, 0.053630f, 0.077083f, 0.038991f, 0.046997f, 0.069263f, 0.077018f, 0.065921f, 0.062667f, 0.035637f, 0.032361f, 0.035977f, 0.072441f, 0.040334f, 0.048247f, 0.047595f, 0.074868f, 0.108730f, 0.035968f, 0.056545f, 0.066532f, 0.077482f, 0.028769f, 0.032906f, 0.062422f, 0.033892f, 0.057593f, 0.040467f, 0.052127f, 0.042684f, 0.062080f, 0.045935f, 0.091938f, 0.111515f, 0.048649f, 0.062485f, 0.046656f, 0.064291f, 0.031753f, 0.059393f, 0.041563f, 0.044556f, 0.069887f, 0.062123f, 0.082378f, 0.045090f, 0.080940f, 0.061691f, 0.053974f, 0.059613f, 0.043629f, 0.072703f, 0.033948f, 0.046629f, 0.037722f, 0.059583f, 0.060715f, 0.050168f, 0.045991f, 0.033218f, 0.056448f, 0.032452f, 0.038564f, 0.082843f, 0.034089f, 0.069900f, 0.084590f, 0.040994f, 0.071200f, 0.038010f, 0.052145f, 0.081092f, 0.064029f, 0.067052f, 0.056579f, 0.028034f, 0.033791f, 0.068186f, 0.068271f, 0.063343f, 0.047398f, 0.037780f, 0.034172f, 0.044511f, 0.095935f, 0.058974f, 0.038754f, 0.062758f, 0.057607f, 0.084719f, 0.027499f, 0.047430f, 0.073981f, 0.074150f, 0.035269f, 0.047448f, 0.036752f, 0.039415f, 0.048991f, 0.045181f, 0.050976f, 0.054837f, 0.071332f, 0.055356f, 0.073536f, 0.117610f, 0.045851f, 0.050524f, 0.032034f, 0.053465f, 0.036708f, 0.070958f, 0.052385f, 0.064091f, 0.057214f, 0.057917f, 0.058645f, 0.042099f, 0.063851f, 0.084550f, 0.053520f, 0.084033f, 0.031093f, 0.094942f, 0.030276f, 0.049457f, 0.026750f, 0.047972f, 0.060768f, 0.058187f, 0.054276f, 0.044448f, 0.035207f, 0.036870f, 0.032806f, 0.092340f, 0.035092f, 0.081583f, 0.078329f, 0.044479f, 0.049782f, 0.045855f, 0.054055f, 0.065397f, 0.055502f, 0.051883f, 0.076030f, 0.034077f, 0.051003f, 0.079707f, 0.080020f, 0.047184f, 0.062939f, 0.047408f, 0.045502f, 0.043137f, 0.100811f, 0.034370f, 0.044713f, 0.057477f, 0.072930f, 0.097129f, 0.037633f, 0.038550f, 0.068662f, 0.053994f, 0.028478f, 0.039062f, 0.038495f, 0.025068f, 0.055973f, 0.040975f, 0.069692f, 0.056845f, 0.083897f, 0.030405f, 0.093963f, 0.105236f, 0.059703f, 0.065004f, 0.050007f, 0.044003f, 0.038091f, 0.039954f, 0.044211f, 0.058478f, 0.065917f, 0.044603f, 0.070220f, 0.050818f, 0.095779f, 0.082388f, 0.053794f, 0.053693f, 0.044906f, 0.070345f, 0.037966f, 0.056218f, 0.038542f, 0.043350f, 0.078669f, 0.034491f, 0.049179f, 0.029124f, 0.042079f, 0.027618f, 0.034795f, 0.083187f, 0.043812f, 0.087782f, 0.080584f, 0.030962f, 0.059102f, 0.030197f, 0.073473f, 0.072498f, 0.065232f, 0.049729f, 0.097389f, 0.021927f, 0.044356f, 0.055279f, 0.076017f, 0.070273f, 0.055023f, 0.037702f, 0.029233f, 0.040282f, 0.097878f, 0.029652f, 0.039534f, 0.074825f, 0.071985f, 0.067881f, 0.031276f, 0.042686f, 0.073602f, 0.073706f, 0.030584f, 0.057861f, 0.047710f, 0.034884f, 0.057413f, 0.035354f, 0.063233f, 0.050663f, 0.065586f, 0.036542f, 0.082802f, 0.107169f, 0.040638f, 0.041789f, 0.043909f, 0.065079f, 0.043575f, 0.067425f, 0.049272f, 0.066957f, 0.059910f, 0.064085f, 0.080467f, 0.042483f, 
0.081539f, 0.077297f, 0.041671f, 0.048000f, 0.036514f, 0.112392f, 0.028779f, 0.030791f, 0.036185f, 0.056722f, 0.069826f, 0.045137f, 0.041278f, 0.046923f, 0.044357f, 0.033296f, 0.036832f, 0.075295f, 0.032707f, 0.084617f, 0.084586f, 0.029126f, 0.046652f, 0.045794f, 0.057906f, 0.059357f, 0.064250f, 0.048568f, 0.095124f, 0.032009f, 0.035671f, 0.093853f}; @@ -1020,7 +1020,7 @@ TEST(AttentionTest, Attention4DWithPastAndPresentQkMatmul) { q, k, v, m, std::initializer_list(), past_key, past_value, -1, 3, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, qk_matmul, - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); y = std::vector{0.466021f, 0.458662f, 0.433769f, 0.544055f, 0.483743f, 0.601701f, 0.452252f, 0.558874f, 0.462717f, 0.462769f, 0.429452f, 0.544879f, 0.480609f, 0.607708f, 0.462766f, 0.570020f, 0.465546f, 0.464215f, 0.442318f, 0.544785f, 0.481242f, 0.599103f, 0.465833f, 0.567976f, 0.466527f, 0.450295f, 0.420681f, 0.541622f, 0.478068f, 0.592818f, 0.453533f, 0.586057f, 0.586788f, 0.542723f, 0.521934f, 0.605385f, 0.523076f, 0.515204f, 0.538008f, 0.539990f, 0.580554f, 0.544345f, 0.524057f, 0.593493f, 0.520281f, 0.513084f, 0.549197f, 0.556567f, 0.590750f, 0.536522f, 0.528383f, 0.608365f, 0.523467f, 0.511267f, 0.533588f, 0.556113f, 0.589547f, 0.537869f, 0.512585f, 0.601047f, 0.507374f, 0.511124f, 0.547465f, 0.512627f, 0.537318f, 0.460441f, 0.540844f, 0.491120f, 0.495359f, 0.476360f, 0.487767f, 0.575867f, 0.522542f, 0.469555f, 0.552479f, 0.488850f, 0.498227f, 0.480921f, 0.484224f, 0.563258f, 0.536463f, 0.455656f, 0.529199f, 0.484251f, 0.487531f, 0.482517f, 0.496116f, 0.576080f, 0.527226f, 0.455449f, 0.525402f, 0.516090f, 0.487896f, 0.477256f, 0.499739f, 0.574474f, 0.520127f, 0.578615f, 0.430572f, 0.471035f, 0.475543f, 0.515079f, 0.488231f, 0.438589f, 0.525065f, 0.569547f, 0.430350f, 0.477609f, 0.478081f, 0.515330f, 0.479993f, 0.427992f, 0.520505f, 0.584227f, 0.430333f, 0.470616f, 0.468772f, 0.517313f, 0.478180f, 0.435562f, 0.527655f, 0.580609f, 0.440415f, 0.475648f, 0.474939f, 0.501466f, 0.474016f, 0.433277f, 0.489508f, 0.425301f, 0.542249f, 0.446878f, 0.532601f, 0.462732f, 0.460696f, 0.462333f, 0.480973f, 0.421038f, 0.522864f, 0.446350f, 0.525882f, 0.466933f, 0.459678f, 0.470179f, 0.485580f, 0.431242f, 0.545418f, 0.440407f, 0.527849f, 0.471587f, 0.464982f, 0.464551f, 0.502461f, 0.437563f, 0.528884f, 0.426691f, 0.531206f, 0.480744f, 0.460218f, 0.480733f, 0.543597f, 0.506559f, 0.419551f, 0.372524f, 0.622818f, 0.678228f, 0.309035f, 0.543150f, 0.561392f, 0.501923f, 0.420097f, 0.368626f, 0.607674f, 0.661294f, 0.315077f, 0.540017f, 0.552392f, 0.506226f, 0.409681f, 0.376208f, 0.608944f, 0.674258f, 0.301188f, 0.537046f, 0.536986f, 0.515894f, 0.402735f, 0.364314f, 0.612694f, 0.684161f, 0.315733f, 0.553979f}; @@ -1029,7 +1029,7 @@ TEST(AttentionTest, Attention4DWithPastAndPresentQkMatmul) { q, k, v, m, std::initializer_list(), past_key, past_value, -1, 2, std::numeric_limits::quiet_NaN(), 1.f, -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, qk_matmul, - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -1080,7 +1080,7 @@ TEST(AttentionTest, Attention3DWithPastAndPresentQkMatmul) { q, k, v, m, 
std::initializer_list(), past_key, past_value, -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, qk_matmul, - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -1131,7 +1131,7 @@ TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmul) { q, k, v, m, std::initializer_list(), past_key, past_value, -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, qk_matmul, - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -1182,7 +1182,7 @@ TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmulCausal) { q, k, v, m, std::initializer_list(), past_key, past_value, 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, qk_matmul, - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -1234,7 +1234,7 @@ TEST(AttentionTest, TestAttention4DWithPastAndPresentQkMatmulBias4DMaskCausal) { q, k, v, m, std::initializer_list(), past_key, past_value, 1, 1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, qk_matmul, - false, true, true // disable_cpu, disable_cuda, disable_dml + false, false, true // disable_cpu, disable_cuda, disable_dml ); }
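The only change in attention_op_test.cc is the middle of the three trailing booleans: disable_cuda flips from true to false, so each of these Attention tests now also runs against the CUDA execution provider backed by the new kernel. A minimal sketch of how such a flag is typically honoured inside a helper like RunTest4D (the helper itself is not part of this diff; DefaultCpuExecutionProvider / DefaultCudaExecutionProvider / DefaultDmlExecutionProvider are the standard onnxruntime test utilities):

    // Build the provider list from the disable_* flags and run the OpTester against it.
    std::vector<std::unique_ptr<IExecutionProvider>> providers;
    if (!disable_cpu) providers.push_back(DefaultCpuExecutionProvider());
    if (!disable_cuda) providers.push_back(DefaultCudaExecutionProvider());
    if (!disable_dml) providers.push_back(DefaultDmlExecutionProvider());
    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);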