@@ -351,7 +351,7 @@ Status CheckCustomAttentionInputs(const T* position_ids,
 
   if (pos_ids_shape[1] < parameters.sequence_length) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "position_ids dimension 1 must be atleast sequence length, got ", pos_ids_shape[1]);
+                           "position_ids dimension 1 must be at least sequence length, got ", pos_ids_shape[1]);
   }
 }
 
onnxruntime/contrib_ops/cuda/bert/attention_data.h (31 additions, 2 deletions)
@@ -153,24 +153,51 @@ struct GroupQueryAttentionData {
   const T* value = nullptr;
   const T* past_key = nullptr;
   const T* past_value = nullptr;
-  int* seqlens_k = nullptr;
   const T* cos_cache = nullptr;
   const T* sin_cache = nullptr;
   const T* head_sink = nullptr;
 
+  // Total sequence length for each batch. It has shape [batch_size].
+  int* total_seq_lens = nullptr;
+
+  // Past sequence length for each batch (i.e., the offset to append new tokens). Shape [batch_size].
+  // For first prompt: past_seq_lens[b] = 0
+  // For token generation or subsequent prompt: past_seq_lens[b] = total_seq_lens[b] - sequence_length
+  int* past_seq_lens = nullptr;
+
+  // Padded sequence length for each batch. Shape [batch_size].
+  // Only used for first prompt: padded_seq_lens[b] = sequence_length
+  int* padded_seq_lens = nullptr;
+
   // Flash buffers
   T* softmax_lse = nullptr;
   T* softmax_lse_accum = nullptr;
   T* out_accum = nullptr;
-  int* seqlens_k_buff = nullptr;
 
+  // Position IDs from Input
+  const int64_t* position_ids = nullptr;
+
   // Memory Efficient buffers
   T* fmha_buffer = nullptr;
   T* unpacked_qkv_buffer = nullptr;
   T* rotary_buffer = nullptr;
+  int64_t* position_ids_buffer = nullptr;  // Separate buffer for generated position IDs
   T* k = nullptr;
   T* v = nullptr;
 
+#ifndef NDEBUG
+  // Buffer size tracking for debug validation
+  // Allocated sizes are set during buffer allocation in group_query_attention.cc
+  // Max used sizes are updated during kernel calls in group_query_attention_impl.cu
+  // Validated before operator returns to ensure usage exactly matches allocation
+  size_t unpacked_qkv_buffer_size = 0;       // Allocated size
+  size_t rotary_buffer_size = 0;             // Allocated size
+  size_t position_ids_buffer_size = 0;       // Allocated size
+  mutable size_t unpacked_qkv_max_used = 0;  // Max offset accessed (updated by kernels)
+  mutable size_t rotary_max_used = 0;        // Max offset accessed (updated by kernels)
+  mutable size_t position_ids_max_used = 0;  // Max offset accessed (updated by kernels)
+#endif
+
   // Output Tensors
   T* output = nullptr;
   T* present_key = nullptr;
@@ -179,6 +206,8 @@ struct GroupQueryAttentionData {
   // Kernel Flags
   bool use_flash_attention = false;
   bool use_memory_efficient_attention = false;
+  bool use_flash_attention_fast_decode = false;
+  bool disable_fused_kv = false;
 };
 
 template <typename T>
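
The three length arrays added above encode where each batch's new tokens land in the KV cache. Below is a minimal host-side sketch of the documented invariants; the helper name FillSeqLens, its signature, and the loop are illustrative assumptions, and only the arithmetic (past = 0 and padded = sequence_length for the first prompt, past = total - sequence_length otherwise) comes from the struct comments in the diff.

#include <vector>

// Sketch only: mirrors the comments in attention_data.h, not actual
// ONNX Runtime code. total_seq_lens comes from the operator input;
// sequence_length is the number of new tokens in this step.
void FillSeqLens(const std::vector<int>& total_seq_lens,
                 int sequence_length,
                 bool is_first_prompt,
                 std::vector<int>& past_seq_lens,
                 std::vector<int>& padded_seq_lens) {
  const size_t batch_size = total_seq_lens.size();
  past_seq_lens.assign(batch_size, 0);
  padded_seq_lens.assign(batch_size, 0);
  for (size_t b = 0; b < batch_size; ++b) {
    if (is_first_prompt) {
      past_seq_lens[b] = 0;                  // KV cache is empty
      padded_seq_lens[b] = sequence_length;  // prompt is right-padded to sequence_length
    } else {
      // Offset at which this step's tokens are appended to the KV cache.
      past_seq_lens[b] = total_seq_lens[b] - sequence_length;
    }
  }
}

The #ifndef NDEBUG fields follow the same pattern described in their comments: allocation sites record the *_buffer_size values, kernels update the *_max_used values, and the operator compares the two before returning.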
onnxruntime/contrib_ops/cuda/bert/attention_impl.cu (1 addition, 1 deletion)
@@ -918,7 +918,7 @@ Status PastPresentBufferShare(int batch_size, int num_heads, int qk_head_size, i
   constexpr bool is_new_kv_bnsh_format = true;
   ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(
       batch_size, num_heads, qk_head_size, parameters.max_sequence_length,
-      data.seqlens_k_total, nullptr, parameters.sequence_length, data.k, data.v, data.present_key, data.present_value,
+      nullptr, data.seqlens_k_total, parameters.sequence_length, data.k, data.v, data.present_key, data.present_value,
       is_past_kv_bnsh_format, is_new_kv_bnsh_format, stream, max_threads_per_block));
 
   data.k = data.present_key;
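
The only change at this call site is the order of the two sequence-length arguments: the past lengths now precede the total lengths, and this caller passes nullptr for the past lengths. A plausible reading, consistent with the past_seq_lens comment in attention_data.h (past = total - sequence_length), is that the launcher derives the append offset from the totals when no past-length array is supplied. A hedged device-side sketch follows; AppendOffset is a hypothetical helper, not the actual kernel code.

// Hypothetical illustration of the fallback implied by the call-site change.
__device__ inline int AppendOffset(const int* past_seq_lens,
                                   const int* total_seq_lens,
                                   int new_seq_len, int batch_idx) {
  // Prefer an explicit per-batch past length; otherwise derive it
  // from the total length and the number of new tokens.
  return past_seq_lens != nullptr ? past_seq_lens[batch_idx]
                                  : total_seq_lens[batch_idx] - new_seq_len;
}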