From 0223e8647b2672275321f8c14d4df47177afe838 Mon Sep 17 00:00:00 2001
From: aciddelgado <139922440+aciddelgado@users.noreply.github.com>
Date: Thu, 29 Aug 2024 15:00:53 -0700
Subject: [PATCH] Fix num splits bug (#21899)

### Description
Found a bug with num splits: the heuristic was not being applied correctly because the wrong sequence length was passed to the heuristic function.

### Motivation and Context
We were experiencing significant performance issues with long sequence lengths when using flash attention due to this misconfiguration. A standalone sketch of the effect follows the patch.
---
 onnxruntime/contrib_ops/cpu/bert/attention_common.h        | 1 +
 onnxruntime/contrib_ops/cuda/bert/attention.cc             | 2 +-
 onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc | 2 +-
 .../contrib_ops/cuda/bert/group_query_attention_helper.h   | 1 +
 onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc   | 2 +-
 5 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index 9e6671c26cf59..516ef57d8cd18 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -102,6 +102,7 @@ struct GroupQueryAttentionParameters {
   int sequence_length;          // sequence length of input query, key, value
   int seqlen_past_kv_cache;     // sequence length of past kv tensor
   int seqlen_present_kv_cache;  // sequence length of present kv tensor
+  int total_sequence_length;    // maximum total sequence length (past_sequence_length + sequence_length) among keys
   int hidden_size;
   int num_heads;
   int head_size;
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc
index e5686b255425c..efbc0b5031657 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -123,7 +123,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   if (use_flash_attention) {
     using namespace std;
     auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
-        parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads,
+        parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads,
         parameters.head_size, device_prop.multiProcessorCount);
     parameters.num_splits = static_cast<int>(num_splits);
     softmax_lse_accum_bytes = slse_accum_bytes;
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index 1f378a184ab9b..58d1d7f0e4af4 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -134,7 +134,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
     // split kv buffer
     using namespace std;
     auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
-        parameters.batch_size, parameters.sequence_length, parameters.sequence_length, parameters.num_heads,
+        parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads,
         parameters.head_size, device_prop.multiProcessorCount);
     parameters.num_splits = static_cast<int>(num_splits);
     softmax_lse_accum_bytes = slse_accum_bytes;
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
index 91418b17e6dbc..39efdfd66bcc6 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
@@ -251,6 +251,7 @@ Status CheckInputs(const Tensor* query,
   output_parameters->sequence_length = sequence_length;                  // sequence length of Q
   output_parameters->seqlen_past_kv_cache = past_sequence_length;        // max sequence length of past kv tensors
   output_parameters->seqlen_present_kv_cache = present_sequence_length;  // max sequence length of present kv tensors
+  output_parameters->total_sequence_length = total_sequence_length;      // total sequence length
   output_parameters->hidden_size = q_hidden_size;
   output_parameters->num_heads = num_heads;
   output_parameters->head_size = head_size;
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index 52bfe61608f62..9c558900d1fdb 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -171,7 +171,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   if (use_flash_attention) {
     using namespace std;
    auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
-        parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads,
+        parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads,
         parameters.head_size, device_prop.multiProcessorCount);
     parameters.num_splits = static_cast<int>(num_splits);
     softmax_lse_accum_bytes = slse_accum_bytes;
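To see why the passed sequence length matters, here is a minimal standalone sketch of a split-KV heuristic in the spirit of `get_num_splits_and_buffer_sizes`. It is not the actual `onnxruntime::flash` implementation; `choose_num_splits`, the KV block size, the SM count, and the "smallest split count that fills the SMs" rule are all simplifying assumptions for illustration. The point it demonstrates is real, though: the heuristic derives its parallelism from the number of KV blocks, so feeding it the query length (1 during decoding) instead of the total sequence length starves the GPU of splits.

```cpp
// Illustrative sketch only -- NOT the onnxruntime::flash heuristic.
// choose_num_splits, the block size, and the occupancy rule below are
// hypothetical simplifications of the split-KV idea.
#include <algorithm>
#include <cstdio>

// Pick the smallest split count that gives every SM a work unit, capped by
// the number of KV blocks (splitting finer than one block gains nothing).
int choose_num_splits(int batch_size, int num_heads, int total_seq_len,
                      int kv_block_size, int num_sms, int max_splits = 128) {
  int batch_nheads = batch_size * num_heads;
  int num_kv_blocks = (total_seq_len + kv_block_size - 1) / kv_block_size;
  for (int splits = 1; splits <= std::min(max_splits, num_kv_blocks); ++splits) {
    if (batch_nheads * splits >= num_sms) return splits;  // SMs are occupied
  }
  return std::min(max_splits, std::max(num_kv_blocks, 1));
}

int main() {
  const int num_sms = 108;  // e.g. an A100
  // Decoding step: query length 1, past KV cache of 8192 tokens.
  // Bug: passing the query sequence length (1) instead of the total
  // sequence length (8193) caps the heuristic at a single split.
  std::printf("splits with q_len     : %d\n", choose_num_splits(1, 32, 1, 128, num_sms));
  std::printf("splits with total_len : %d\n", choose_num_splits(1, 32, 8193, 128, num_sms));
  return 0;
}
```

Under these assumptions, the query length yields one KV block and therefore one split, leaving most SMs idle while the 8K-token cache is scanned serially; the total sequence length lets the heuristic fan the cache out across enough splits to occupy the device, which is the long-sequence slowdown this patch addresses.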