Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ Status CheckCustomAttentionInputs(const T* position_ids,

if (pos_ids_shape[1] < parameters.sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"position_ids dimension 1 must be atleast sequence length, got ", pos_ids_shape[1]);
"position_ids dimension 1 must be at least sequence length, got ", pos_ids_shape[1]);
}
}

Expand Down
18 changes: 16 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/attention_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,29 @@ struct GroupQueryAttentionData {
const T* value = nullptr;
const T* past_key = nullptr;
const T* past_value = nullptr;
int* seqlens_k = nullptr;
const T* cos_cache = nullptr;
const T* sin_cache = nullptr;
const T* head_sink = nullptr;

// Total sequence length for each batch. It has shape [batch_size].
int* total_seq_lens = nullptr;

// Past sequence length for each batch (i.e., the offset to append new tokens). Shape [batch_size].
// For first prompt: past_seq_lens[b] = 0
// For token generation or subsequent prompt: past_seq_lens[b] = total_seq_lens[b] - sequence_length
int* past_seq_lens = nullptr;

// Padded sequence length for each batch. Shape [batch_size].
// Only used for first prompt: padded_seq_lens[b] = sequence_length
int* padded_seq_lens = nullptr;

// Flash buffers
T* softmax_lse = nullptr;
T* softmax_lse_accum = nullptr;
T* out_accum = nullptr;
int* seqlens_k_buff = nullptr;

// Position IDs from Input
const int64_t* position_ids = nullptr;

// Memory Efficient buffers
T* fmha_buffer = nullptr;
Expand All @@ -179,6 +192,7 @@ struct GroupQueryAttentionData {
// Kernel Flags
bool use_flash_attention = false;
bool use_memory_efficient_attention = false;
bool use_flash_attention_fast_decode = false;
};

template <typename T>
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,7 @@ Status PastPresentBufferShare(int batch_size, int num_heads, int qk_head_size, i
constexpr bool is_new_kv_bnsh_format = true;
ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(
batch_size, num_heads, qk_head_size, parameters.max_sequence_length,
data.seqlens_k_total, nullptr, parameters.sequence_length, data.k, data.v, data.present_key, data.present_value,
nullptr, data.seqlens_k_total, parameters.sequence_length, data.k, data.v, data.present_key, data.present_value,
is_past_kv_bnsh_format, is_new_kv_bnsh_format, stream, max_threads_per_block));

data.k = data.present_key;
Expand Down
Loading
Loading