Skip to content

Commit

Permalink
Fix num splits bug (#21899)
Browse files Browse the repository at this point in the history
### Description
Fixed a bug in the num_splits computation: the heuristic was not being
applied correctly because the wrong sequence length (the query sequence
length instead of the total sequence length) was passed to the heuristic
function.



### Motivation and Context
We were experiencing significant performance degradation with long
sequence lengths when using flash attention, caused by this
misconfiguration.
  • Loading branch information
aciddelgado authored Aug 29, 2024
1 parent fd88474 commit 0223e86
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 3 deletions.
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ struct GroupQueryAttentionParameters {
int sequence_length; // sequence length of input query, key, value
int seqlen_past_kv_cache; // sequence length of past kv tensor
int seqlen_present_kv_cache; // sequence length of present kv tensor
int total_sequence_length; // maximum total sequence length (past_sequence_length + sequence_length) among keys
int hidden_size;
int num_heads;
int head_size;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
if (use_flash_attention) {
using namespace std;
auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads,
parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads,
parameters.head_size, device_prop.multiProcessorCount);
parameters.num_splits = static_cast<int>(num_splits);
softmax_lse_accum_bytes = slse_accum_bytes;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
// split kv buffer
using namespace std;
auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
parameters.batch_size, parameters.sequence_length, parameters.sequence_length, parameters.num_heads,
parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads,
parameters.head_size, device_prop.multiProcessorCount);
parameters.num_splits = static_cast<int>(num_splits);
softmax_lse_accum_bytes = slse_accum_bytes;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ Status CheckInputs(const Tensor* query,
output_parameters->sequence_length = sequence_length; // sequence length of Q
output_parameters->seqlen_past_kv_cache = past_sequence_length; // max sequence length of past kv tensors
output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors
output_parameters->total_sequence_length = total_sequence_length; // total sequence length
output_parameters->hidden_size = q_hidden_size;
output_parameters->num_heads = num_heads;
output_parameters->head_size = head_size;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
if (use_flash_attention) {
using namespace std;
auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads,
parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads,
parameters.head_size, device_prop.multiProcessorCount);
parameters.num_splits = static_cast<int>(num_splits);
softmax_lse_accum_bytes = slse_accum_bytes;
Expand Down

0 comments on commit 0223e86

Please sign in to comment.