diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 01172bb8f3270..f7c54ad456925 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -351,7 +351,7 @@ Status CheckCustomAttentionInputs(const T* position_ids, if (pos_ids_shape[1] < parameters.sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "position_ids dimension 1 must be atleast sequence length, got ", pos_ids_shape[1]); + "position_ids dimension 1 must be at least sequence length, got ", pos_ids_shape[1]); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index e08d120750a40..2344b425ed263 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -153,24 +153,51 @@ struct GroupQueryAttentionData { const T* value = nullptr; const T* past_key = nullptr; const T* past_value = nullptr; - int* seqlens_k = nullptr; const T* cos_cache = nullptr; const T* sin_cache = nullptr; const T* head_sink = nullptr; + // Total sequence length for each batch. It has shape [batch_size]. + int* total_seq_lens = nullptr; + + // Past sequence length for each batch (i.e., the offset to append new tokens). Shape [batch_size]. + // For first prompt: past_seq_lens[b] = 0 + // For token generation or subsequent prompt: past_seq_lens[b] = total_seq_lens[b] - sequence_length + int* past_seq_lens = nullptr; + + // Padded sequence length for each batch. Shape [batch_size]. + // Only used for first prompt: padded_seq_lens[b] = sequence_length + int* padded_seq_lens = nullptr; + // Flash buffers T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; T* out_accum = nullptr; - int* seqlens_k_buff = nullptr; + + // Position IDs from Input + const int64_t* position_ids = nullptr; // Memory Efficient buffers T* fmha_buffer = nullptr; T* unpacked_qkv_buffer = nullptr; T* rotary_buffer = nullptr; + int64_t* position_ids_buffer = nullptr; // Separate buffer for generated position IDs T* k = nullptr; T* v = nullptr; +#ifndef NDEBUG + // Buffer size tracking for debug validation + // Allocated sizes are set during buffer allocation in group_query_attention.cc + // Max used sizes are updated during kernel calls in group_query_attention_impl.cu + // Validated before operator returns to ensure usage exactly matches allocation + size_t unpacked_qkv_buffer_size = 0; // Allocated size + size_t rotary_buffer_size = 0; // Allocated size + size_t position_ids_buffer_size = 0; // Allocated size + mutable size_t unpacked_qkv_max_used = 0; // Max offset accessed (updated by kernels) + mutable size_t rotary_max_used = 0; // Max offset accessed (updated by kernels) + mutable size_t position_ids_max_used = 0; // Max offset accessed (updated by kernels) +#endif + // Output Tensors T* output = nullptr; T* present_key = nullptr; @@ -179,6 +206,8 @@ struct GroupQueryAttentionData { // Kernel Flags bool use_flash_attention = false; bool use_memory_efficient_attention = false; + bool use_flash_attention_fast_decode = false; + bool disable_fused_kv = false; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 1ed4527018fd1..ee4be2c20362d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ 
-918,7 +918,7 @@ Status PastPresentBufferShare(int batch_size, int num_heads, int qk_head_size, i constexpr bool is_new_kv_bnsh_format = true; ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace( batch_size, num_heads, qk_head_size, parameters.max_sequence_length, - data.seqlens_k_total, nullptr, parameters.sequence_length, data.k, data.v, data.present_key, data.present_value, + nullptr, data.seqlens_k_total, parameters.sequence_length, data.k, data.v, data.present_key, data.present_value, is_past_kv_bnsh_format, is_new_kv_bnsh_format, stream, max_threads_per_block)); data.k = data.present_key; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu index 84f651ca5470d..7004edae31498 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu @@ -3,6 +3,7 @@ #include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/attention_kv_cache.h" +#include "contrib_ops/cuda/bert/rotary_common.cuh" #include "core/providers/cuda/cu_inc/common.cuh" using namespace onnxruntime::cuda; @@ -11,6 +12,32 @@ namespace onnxruntime { namespace contrib { namespace cuda { +// ============================================================================ +// ConcatTensorToTensor Kernel +// ============================================================================ +// PURPOSE: +// Concatenates past KV cache with new KV tokens to create present KV cache. +// Used for non-shared buffer mode (separate past and present tensors). +// +// INPUTS: +// tensor_add_sequence_length - Number of new tokens to append (L) +// tensor_in - Past KV cache [K, B, N, P, H] where P is past sequence length +// tensor_add - New KV tokens [K, B, N, L, H] where L is new sequence length +// +// OUTPUTS: +// tensor_out - Present KV cache [K, B, N, T, H] where T = P + L +// +// THREAD MAPPING: +// threadIdx.x = h (head dimension element) +// threadIdx.y = n (head index) +// blockIdx.x = s (sequence position in output) +// blockIdx.y = b (batch index) +// blockIdx.z = chunk_id (K dimension, typically 2 for K and V) +// +// ASSUMPTIONS: +// - H * num_heads <= max_threads_per_block (use ConcatTensorToTensorLarge otherwise) +// - Output format is BNSH +// ============================================================================ template __global__ void ConcatTensorToTensor(const int tensor_add_sequence_length, const T* tensor_in, @@ -33,18 +60,18 @@ __global__ void ConcatTensorToTensor(const int tensor_add_sequence_length, // tensor_out: K x BxNxTxH, where T = P + L const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length; - const int present_SH = all_sequence_length * H; - const int present_NSH = num_heads * present_SH; - int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size); + const int64_t present_SH = int64_t(all_sequence_length) * H; + const int64_t present_NSH = num_heads * present_SH; + int64_t out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size); if (s < tensor_in_sequence_length) { - const int past_SH = tensor_in_sequence_length * H; - const int past_NSH = num_heads * past_SH; - const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size); + const int64_t past_SH = int64_t(tensor_in_sequence_length) * H; + const int64_t past_NSH = num_heads * past_SH; + const int64_t in_offset = b * past_NSH + n * past_SH + s * H + h + 
chunk_id * (past_NSH * batch_size); tensor_out[out_offset] = tensor_in[in_offset]; } else if (s < all_sequence_length) { - const int SH = tensor_add_sequence_length * H; - const int NSH = num_heads * SH; - const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size); + const int64_t SH = int64_t(tensor_add_sequence_length) * H; + const int64_t NSH = num_heads * SH; + const int64_t in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size); tensor_out[out_offset] = tensor_add[in_offset]; } } @@ -73,19 +100,19 @@ __global__ void ConcatTensorToTensorLarge(const int tensor_add_sequence_length, // tensor_out: K x BxNxTxH const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length; - const int present_SH = all_sequence_length * H; - const int present_NSH = num_heads * present_SH; + const int64_t present_SH = int64_t(all_sequence_length) * H; + const int64_t present_NSH = num_heads * present_SH; while (h < H) { - int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size); + int64_t out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size); if (s < tensor_in_sequence_length) { - const int past_SH = tensor_in_sequence_length * H; - const int past_NSH = num_heads * past_SH; - const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size); + const int64_t past_SH = int64_t(tensor_in_sequence_length) * H; + const int64_t past_NSH = num_heads * past_SH; + const int64_t in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size); tensor_out[out_offset] = tensor_in[in_offset]; } else if (s < all_sequence_length) { - const int SH = tensor_add_sequence_length * H; - const int NSH = num_heads * SH; - const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size); + const int64_t SH = int64_t(tensor_add_sequence_length) * H; + const int64_t NSH = num_heads * SH; + const int64_t in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size); tensor_out[out_offset] = tensor_add[in_offset]; } @@ -210,7 +237,25 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream, BFloat16* tensor_out) { assert(num_heads <= max_threads_per_block); const dim3 grid(all_sequence_length, batch_size, matrix_num); - if (0 == (head_size & 1)) { + if (0 == (head_size % 8)) { + const int H = head_size / 8; + if (H * num_heads <= max_threads_per_block) { + const dim3 block(H, num_heads, 1); + ConcatTensorToTensor<<>>( + sequence_length, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + ConcatTensorToTensorLarge<<>>( + sequence_length, + H, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); + } + } else if (0 == (head_size & 1)) { const int H = head_size / 2; if (H * num_heads <= max_threads_per_block) { const dim3 block(H, num_heads, 1); @@ -268,20 +313,20 @@ __global__ void AddBiasTransAppendKvToPresentSmall( const int S = gridDim.x; const int B = gridDim.y; - constexpr int M = 3; // Matrix count in qkv - const int m = blockIdx.z + 1; // k = 1, v = 2 + constexpr int M = static_cast(QKV::COUNT); // Matrix count in qkv + const int m = blockIdx.z + 1; // k = 1, v = 2 - const int NH = N * head_size; - const int NHS 
= NH * S; + const int64_t NH = N * head_size; + const int64_t NHS = NH * S; qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); if (biases) { biases += (m * NH + n * head_size); } - const int MsH = max_sequence_length * head_size; - const int NMsH = N * MsH; - const int BNMsH = B * NMsH; + const int64_t MsH = int64_t(max_sequence_length) * head_size; + const int64_t NMsH = N * MsH; + const int64_t BNMsH = B * NMsH; present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); for (int h = threadIdx.x; h < head_size; h += blockDim.x) { @@ -304,20 +349,20 @@ __global__ void AddBiasTransAppendKvToPresent( const int S = gridDim.y; const int B = (gridDim.z >> 1); - constexpr int M = 3; // Matrix count in qkv - const int m = (blockIdx.z & 0x1) + 1; // k = 1, v = 2 + constexpr int M = static_cast(QKV::COUNT); // Matrix count in qkv + const int m = (blockIdx.z & 0x1) + 1; // k = 1, v = 2 - const int NH = N * head_size; - const int NHS = NH * S; + const int64_t NH = N * head_size; + const int64_t NHS = NH * S; qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); if (biases) { biases += (m * NH + n * head_size); } - const int MsH = max_sequence_length * head_size; - const int NMsH = N * MsH; - const int BNMsH = B * NMsH; + const int64_t MsH = int64_t(max_sequence_length) * head_size; + const int64_t NMsH = N * MsH; + const int64_t BNMsH = B * NMsH; present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); for (int h = threadIdx.x; h < head_size; h += blockDim.x) { @@ -396,102 +441,169 @@ template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, const BFloat16* qkv_buffer, BFloat16* present); -// Kernel to append new and past kv in either BSNH or BNSH format -// Adapted from ConcatTensorToTensor kernel in attention_kv_cache.cu file -template -__global__ void ConcatNewToPastKV(const int new_seqlen, - const int past_buffer_seqlen, - const T* past_kv, - const T* new_kv, - T* present_kv, - const int* seqlens_k, - const bool past_only, - // const int* seqlens_q, - const bool is_bsnh) { // refers to past; otherwise bnsh +// Fused kernel to append new K and V to past in either BSNH or BNSH format. +// Adapted from ConcatTensorToTensor kernel. 
+// +// Grid.z encodes K vs V: blockIdx.z == 0 -> K (with optional RoPE), blockIdx.z == 1 -> V (no RoPE) +// +// Input Format Requirements: +// - new_key/new_value: Must be contiguous BSNH format [batch, seq, kv_heads, head_size] +// - past_key/past_value: Either BSNH or BNSH based on is_bsnh flag +// - present_key/present_value: Same format as past (BSNH or BNSH) +// +// RoPE Requirements (when cos_cache != nullptr && rotary_dim > 0): +// - new_key must be contiguous BSNH so RotaryDispatcher can read pair values +// - The pair element for non-interleaved RoPE is read from new_key at offset (in_offset - h + pair_idx) +// - cos_cache/sin_cache: [max_position, rotary_dim/2] contiguous + +template +__global__ void ConcatNewToPastKVFused(const int new_seqlen, + const int past_buffer_seqlen, + const T* past_key, + const T* past_value, + const T* new_key, + const T* new_value, + T* present_key, + T* present_value, + const int* past_seq_lens, + const int* total_seq_lens, + const bool past_only, + const bool is_bsnh, + const T* cos_cache, + const T* sin_cache, + const int rotary_dim, + const int64_t* position_ids, + const bool interleaved) { const int h = threadIdx.x; const int n = threadIdx.y; const int s = blockIdx.x; const int b = blockIdx.y; + const int kind = blockIdx.z; // 0 for K, 1 for V - const int present_buffer_seqlen = gridDim.x; + const int present_buffer_seqlen = gridDim.x; // gridDim.x is present_sequence_length const int num_heads = blockDim.y; const int H = blockDim.x; - const int present_batch_stride = present_buffer_seqlen * num_heads * H; - const int row_stride = is_bsnh ? num_heads * H : H; - const int present_head_stride = is_bsnh ? H : present_buffer_seqlen * H; + const int64_t present_batch_stride = int64_t(present_buffer_seqlen) * num_heads * H; + const int64_t row_stride = is_bsnh ? num_heads * H : H; + const int64_t present_head_stride = is_bsnh ? H : int64_t(present_buffer_seqlen) * H; + + // Determine pointers based on kind + const T* past_ptr = (kind == 0) ? past_key : past_value; + const T* new_ptr = (kind == 0) ? new_key : new_value; + T* present_ptr = (kind == 0) ? present_key : present_value; - // past_kv: BPNH or BNPH - // new_kv: BLNH - // present_kv: BTNH or BNTH, where T = P + L + const int past_seqlen = past_seq_lens[b]; - // prompt, token, and interactive decoding cases - const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b] + 1 - new_seqlen; + int64_t out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; - int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; if (s < past_seqlen) { - const int past_batch_stride = past_buffer_seqlen * num_heads * H; - const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H; - const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; - present_kv[out_offset] = past_kv[in_offset]; + const int64_t past_batch_stride = int64_t(past_buffer_seqlen) * num_heads * H; + const int64_t past_head_stride = is_bsnh ? 
H : int64_t(past_buffer_seqlen) * H; + const int64_t in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; + present_ptr[out_offset] = past_ptr[in_offset]; } else if (!past_only && s < past_seqlen + new_seqlen) { - // Note: new KV always BSNH - const int new_batch_stride = new_seqlen * num_heads * H; - const int new_row_stride = num_heads * H; - const int new_head_stride = H; - const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; - present_kv[out_offset] = new_kv[in_offset]; + const int64_t new_batch_stride = int64_t(new_seqlen) * num_heads * H; + const int64_t new_row_stride = num_heads * H; + const int64_t new_head_stride = H; + const int64_t in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; + + T val = new_ptr[in_offset]; + + // Apply RoPE only for K (kind == 0) + if (kind == 0 && cos_cache != nullptr && rotary_dim > 0) { + int pos_id = 0; + if (position_ids) { + int new_s_idx = s - past_seqlen; + if (new_s_idx >= 0 && new_s_idx < new_seqlen) { + pos_id = static_cast(position_ids[b * new_seqlen + new_s_idx]); + } else { + pos_id = s; + } + } else { + pos_id = s; + } + + // Check bounds for pos_id to be safe? + // RoPE cache size usually matches max_seq_len. + + RotaryDispatcher::apply(val, cos_cache, sin_cache, rotary_dim, h, pos_id, interleaved, new_key, in_offset - h); + } + present_ptr[out_offset] = val; } } -// Use when (H*)*num_heads > 1024 -template -__global__ void ConcatNewToPastKVLarge(const int new_seqlen, - const int past_buffer_seqlen, - const int H, - const int num_heads, - const T* past_kv, - const T* new_kv, - T* present_kv, - const int* seqlens_k, - const bool past_only, - const bool is_bsnh) { +template +__global__ void ConcatNewToPastKVFusedLarge(const int new_seqlen, + const int past_buffer_seqlen, + const int H, + const int num_heads, + const T* past_key, + const T* past_value, + const T* new_key, + const T* new_value, + T* present_key, + T* present_value, + const int* past_seq_lens, + const int* total_seq_lens, + const bool past_only, + const bool is_bsnh, + const T* cos_cache, + const T* sin_cache, + const int rotary_dim, + const int64_t* position_ids, + const bool interleaved) { int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { const int h = i % H; const int n = i / H; const int s = blockIdx.y; - const int b = blockIdx.z; + const int b = blockIdx.z / 2; // Integer div + const int kind = blockIdx.z % 2; // 0 for K, 1 for V + const int present_buffer_seqlen = gridDim.y; + // gridDim.z is batch_size * 2 - const int present_batch_stride = present_buffer_seqlen * num_heads * H; - const int row_stride = is_bsnh ? num_heads * H : H; - const int present_head_stride = is_bsnh ? H : present_buffer_seqlen * H; + const int64_t present_batch_stride = int64_t(present_buffer_seqlen) * num_heads * H; + const int64_t row_stride = is_bsnh ? num_heads * H : H; + const int64_t present_head_stride = is_bsnh ? H : int64_t(present_buffer_seqlen) * H; - // past_kv: BPNH or BNPH - // new_kv: BLNH - // present_kv: BTNH or BNTH, where T = P + L + const T* past_ptr = (kind == 0) ? past_key : past_value; + const T* new_ptr = (kind == 0) ? new_key : new_value; + T* present_ptr = (kind == 0) ? present_key : present_value; - // prompt, token, and interactive decoding cases - const int past_seqlen = seqlens_k == nullptr ? 
0 : seqlens_k[b] + 1 - new_seqlen; + const int past_seqlen = past_seq_lens[b]; + + const int64_t out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; - int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; if (s < past_seqlen) { - const int past_batch_stride = past_buffer_seqlen * num_heads * H; - const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H; - const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; - present_kv[out_offset] = past_kv[in_offset]; + const int64_t past_batch_stride = int64_t(past_buffer_seqlen) * num_heads * H; + const int64_t past_head_stride = is_bsnh ? H : int64_t(past_buffer_seqlen) * H; + const int64_t in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; + present_ptr[out_offset] = past_ptr[in_offset]; } else if (!past_only && s < past_seqlen + new_seqlen) { - const int new_batch_stride = new_seqlen * num_heads * H; - const int new_row_stride = num_heads * H; - const int new_head_stride = H; - const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; - present_kv[out_offset] = new_kv[in_offset]; + const int64_t new_batch_stride = int64_t(new_seqlen) * num_heads * H; + const int64_t new_row_stride = num_heads * H; + const int64_t new_head_stride = H; + const int64_t in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; + + T val = new_ptr[in_offset]; + + if (kind == 0 && cos_cache != nullptr && rotary_dim > 0) { + int pos_id = s; + int new_s_idx = s - past_seqlen; + if (position_ids && new_s_idx >= 0 && new_s_idx < new_seqlen) { + pos_id = static_cast(position_ids[b * new_seqlen + new_s_idx]); + } + + RotaryDispatcher::apply(val, cos_cache, sin_cache, rotary_dim, h, pos_id, interleaved, new_key, in_offset - h); + } + present_ptr[out_offset] = val; } } } -// Concat new to kv buffer in place template Status LaunchConcatNewToPastKV(const int batch_size, const int kv_num_heads, @@ -500,7 +612,8 @@ Status LaunchConcatNewToPastKV(const int batch_size, const int past_sequence_length, const int present_sequence_length, const bool is_bsnh, - const int* seqlens_k, + const int* past_seq_lens, + const int* total_seq_lens, const T* past_key, const T* past_value, const T* new_key, @@ -509,51 +622,60 @@ Status LaunchConcatNewToPastKV(const int batch_size, T* present_value, cudaStream_t stream, const int max_threads_per_block, - const bool past_only) { - const int H = head_size / 4; // divide by 4 so kernel can operate on 4 float16 elements at a time. 
+ const bool past_only, + const T* cos_cache, + const T* sin_cache, + const int rotary_dim, + const int64_t* position_ids, + const bool interleaved) { + constexpr int num_elements_per_thread = std::max(1, 8 / int(sizeof(T))); + const int H = head_size / num_elements_per_thread; + if (H * kv_num_heads <= max_threads_per_block) { - const dim3 grid(present_sequence_length, batch_size, 1); + // Grid Z dim is 2: 0 for K, 1 for V + const dim3 grid(present_sequence_length, batch_size, 2); const dim3 block(H, kv_num_heads, 1); - ConcatNewToPastKV<<>>(kv_sequence_length, - past_sequence_length, - reinterpret_cast(past_key), - reinterpret_cast(new_key), - reinterpret_cast(present_key), - seqlens_k, - past_only, - is_bsnh); - ConcatNewToPastKV<<>>(kv_sequence_length, - past_sequence_length, - reinterpret_cast(past_value), - reinterpret_cast(new_value), - reinterpret_cast(present_value), - seqlens_k, - past_only, - is_bsnh); + + ConcatNewToPastKVFused<<>>(kv_sequence_length, + past_sequence_length, + reinterpret_cast(past_key), + reinterpret_cast(past_value), + reinterpret_cast(new_key), + reinterpret_cast(new_value), + reinterpret_cast(present_key), + reinterpret_cast(present_value), + past_seq_lens, + total_seq_lens, + past_only, + is_bsnh, + reinterpret_cast(cos_cache), + reinterpret_cast(sin_cache), + rotary_dim, position_ids, interleaved); } else { + // Large kernel version int steps = (H * kv_num_heads + 255) / 256; - const dim3 grid(steps, present_sequence_length, batch_size); + // Grid Z dim is batch_size * 2 + // We encode b and kind in blockIdx.z in the kernel + const dim3 grid(steps, present_sequence_length, batch_size * 2); const dim3 block(256, 1, 1); - ConcatNewToPastKVLarge<<>>(kv_sequence_length, - past_sequence_length, - H, - kv_num_heads, - reinterpret_cast(past_key), - reinterpret_cast(new_key), - reinterpret_cast(present_key), - seqlens_k, - past_only, - is_bsnh); - ConcatNewToPastKVLarge<<>>(kv_sequence_length, - past_sequence_length, - H, - kv_num_heads, - reinterpret_cast(past_value), - reinterpret_cast(new_value), - reinterpret_cast(present_value), - seqlens_k, - past_only, - is_bsnh); + + ConcatNewToPastKVFusedLarge<<>>(kv_sequence_length, + past_sequence_length, + H, + kv_num_heads, + reinterpret_cast(past_key), + reinterpret_cast(past_value), + reinterpret_cast(new_key), + reinterpret_cast(new_value), + reinterpret_cast(present_key), + reinterpret_cast(present_value), + past_seq_lens, + total_seq_lens, + past_only, + is_bsnh, + reinterpret_cast(cos_cache), + reinterpret_cast(sin_cache), + rotary_dim, position_ids, interleaved); } return CUDA_CALL(cudaGetLastError()); } @@ -565,7 +687,8 @@ template Status LaunchConcatNewToPastKV(const int batch_size, const int past_sequence_length, const int present_sequence_length, const bool is_bsnh, - const int* seqlens_k, + const int* past_seq_lens, + const int* total_seq_lens, const half* past_key, const half* past_value, const half* new_key, @@ -574,7 +697,12 @@ template Status LaunchConcatNewToPastKV(const int batch_size, half* present_value, cudaStream_t stream, const int max_threads_per_block, - const bool past_only); + const bool past_only, + const half* cos_cache, + const half* sin_cache, + const int rotary_dim, + const int64_t* position_ids, + const bool interleaved); template Status LaunchConcatNewToPastKV(const int batch_size, const int kv_num_heads, @@ -583,7 +711,8 @@ template Status LaunchConcatNewToPastKV(const int batch_size, const int past_sequence_length, const int present_sequence_length, const bool is_bsnh, - const int* 
seqlens_k, + const int* past_seq_lens, + const int* total_seq_lens, const BFloat16* past_key, const BFloat16* past_value, const BFloat16* new_key, @@ -592,15 +721,74 @@ template Status LaunchConcatNewToPastKV(const int batch_size, BFloat16* present_value, cudaStream_t stream, const int max_threads_per_block, - const bool past_only); - -// Kernel to append new kv to kv buffer in place + const bool past_only, + const BFloat16* cos_cache, + const BFloat16* sin_cache, + const int rotary_dim, + const int64_t* position_ids, + const bool interleaved); + +template Status LaunchConcatNewToPastKV(const int batch_size, + const int kv_num_heads, + const int head_size, + const int kv_sequence_length, + const int past_sequence_length, + const int present_sequence_length, + const bool is_bsnh, + const int* past_seq_lens, + const int* total_seq_lens, + const float* past_key, + const float* past_value, + const float* new_key, + const float* new_value, + float* present_key, + float* present_value, + cudaStream_t stream, + const int max_threads_per_block, + const bool past_only, + const float* cos_cache, + const float* sin_cache, + const int rotary_dim, + const int64_t* position_ids, + const bool interleaved); + +// ============================================================================ +// ConcatKVInPlace Kernel +// ============================================================================ +// PURPOSE: +// Appends new KV tokens to existing KV cache buffer IN-PLACE. +// Used when past and present KV share the same memory (kv_share_buffer=true). +// +// INPUTS: +// max_seqlen - Maximum sequence length (buffer size) +// new_kv - New KV tokens to append +// past_seq_lens - Per-batch offset where to write (can be null) +// total_seq_lens - Per-batch total valid tokens after appending +// is_past_kv_bnsh_format - True if KV buffer is BNSH, false for BSNH +// is_new_kv_bnsh_format - True if new_kv is BNSH, false for BSNH +// +// OUTPUTS: +// kv_buff - Updated KV cache with new tokens appended at past_seq_len offset +// +// THREAD MAPPING: +// threadIdx.x = h (head dimension element) +// threadIdx.y = n (head index) +// blockIdx.x = s (new token sequence position, 0 to new_seqlen-1) +// blockIdx.y = b (batch index) +// +// BOUNDS CHECK: +// Only writes when (s + past_seq_len < total_seq_lens[b]) to prevent +// out-of-bounds access with variable-length sequences. +// +// ASSUMPTIONS: +// - H * kv_num_heads <= max_threads_per_block (use ConcatKVInPlaceLarge otherwise) +// ============================================================================ template __global__ void ConcatKVInPlace(const int max_seqlen, T* kv_buff, const T* new_kv, - const int* seqlens_k, - const int* total_seqlens_k, + const int* past_seq_lens, + const int* total_seq_lens, const bool is_past_kv_bnsh_format, const bool is_new_kv_bnsh_format) { const int h = threadIdx.x; @@ -612,19 +800,19 @@ __global__ void ConcatKVInPlace(const int max_seqlen, const int kv_num_heads = blockDim.y; const int H = blockDim.x; - const int past_seq_len = (total_seqlens_k != nullptr) - ? (total_seqlens_k[b] - new_seqlen) - : (seqlens_k == nullptr ? 0 : (seqlens_k[b] + 1 - new_seqlen)); + const int past_seq_len = (past_seq_lens != nullptr) ? past_seq_lens[b] : (total_seq_lens[b] - new_seqlen); - int out_offset = is_past_kv_bnsh_format - ? INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h) - : INDEX_4D(max_seqlen, kv_num_heads, H, b, s + past_seq_len, n, h); + int64_t out_offset = is_past_kv_bnsh_format + ? 
INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h) + : INDEX_4D(max_seqlen, kv_num_heads, H, b, s + past_seq_len, n, h); - int in_offset = is_new_kv_bnsh_format - ? INDEX_4D(kv_num_heads, new_seqlen, H, b, n, s, h) - : INDEX_4D(new_seqlen, kv_num_heads, H, b, s, n, h); + int64_t in_offset = is_new_kv_bnsh_format + ? INDEX_4D(kv_num_heads, new_seqlen, H, b, n, s, h) + : INDEX_4D(new_seqlen, kv_num_heads, H, b, s, n, h); - kv_buff[out_offset] = new_kv[in_offset]; + if (s + past_seq_len < total_seq_lens[b]) { + kv_buff[out_offset] = new_kv[in_offset]; + } } template @@ -633,8 +821,8 @@ __global__ void ConcatKVInPlaceLarge(const int max_seqlen, const int kv_num_heads, T* kv_buff, const T* new_kv, - const int* seqlens_k, - const int* total_seqlens_k, + const int* past_seq_lens, + const int* total_seq_lens, const bool is_past_kv_bnsh_format, const bool is_new_kv_bnsh_format) { // refers to kv buff; otherwise bnsh int i = threadIdx.x + (blockDim.x * blockIdx.x); @@ -644,30 +832,29 @@ __global__ void ConcatKVInPlaceLarge(const int max_seqlen, const int s = blockIdx.y; const int b = blockIdx.z; const int new_seqlen = gridDim.y; - const int past_seq_len = (total_seqlens_k != nullptr) - ? (total_seqlens_k[b] - new_seqlen) - : (seqlens_k == nullptr ? 0 : (seqlens_k[b] + 1 - new_seqlen)); + const int past_seq_len = (past_seq_lens != nullptr) ? past_seq_lens[b] : (total_seq_lens[b] - new_seqlen); - int out_offset = is_past_kv_bnsh_format - ? INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h) - : INDEX_4D(max_seqlen, kv_num_heads, H, b, s + past_seq_len, n, h); + int64_t out_offset = is_past_kv_bnsh_format + ? INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h) + : INDEX_4D(max_seqlen, kv_num_heads, H, b, s + past_seq_len, n, h); - int in_offset = is_new_kv_bnsh_format - ? INDEX_4D(kv_num_heads, new_seqlen, H, b, n, s, h) - : INDEX_4D(new_seqlen, kv_num_heads, H, b, s, n, h); + int64_t in_offset = is_new_kv_bnsh_format + ? 
INDEX_4D(kv_num_heads, new_seqlen, H, b, n, s, h) + : INDEX_4D(new_seqlen, kv_num_heads, H, b, s, n, h); - kv_buff[out_offset] = new_kv[in_offset]; + if (s + past_seq_len < total_seq_lens[b]) { + kv_buff[out_offset] = new_kv[in_offset]; + } } } -// Concat new to kv buffer in place template Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, - const int* seqlens_k, - const int* total_seqlens_k, + const int* past_seq_lens, + const int* total_seq_lens, int new_seq_len, const T* new_key, const T* new_value, @@ -687,15 +874,15 @@ Status LaunchConcatKVInPlace(int batch_size, ConcatKVInPlace<<>>(max_sequence_length, reinterpret_cast(present_key), reinterpret_cast(new_key), - seqlens_k, - total_seqlens_k, + past_seq_lens, + total_seq_lens, is_past_kv_bnsh_format, is_new_kv_bnsh_format); ConcatKVInPlace<<>>(max_sequence_length, reinterpret_cast(present_value), reinterpret_cast(new_value), - seqlens_k, - total_seqlens_k, + past_seq_lens, + total_seq_lens, is_past_kv_bnsh_format, is_new_kv_bnsh_format); } else { @@ -707,8 +894,8 @@ Status LaunchConcatKVInPlace(int batch_size, kv_num_heads, reinterpret_cast(present_key), reinterpret_cast(new_key), - seqlens_k, - total_seqlens_k, + past_seq_lens, + total_seq_lens, is_past_kv_bnsh_format, is_new_kv_bnsh_format); ConcatKVInPlaceLarge<<>>(max_sequence_length, @@ -716,8 +903,8 @@ Status LaunchConcatKVInPlace(int batch_size, kv_num_heads, reinterpret_cast(present_value), reinterpret_cast(new_value), - seqlens_k, - total_seqlens_k, + past_seq_lens, + total_seq_lens, is_past_kv_bnsh_format, is_new_kv_bnsh_format); } @@ -728,8 +915,8 @@ template Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, - const int* seqlens_k, - const int* total_seqlens_k, + const int* past_seq_lens, + const int* total_seq_lens, int new_seq_len, const half* new_key, const half* new_value, @@ -744,8 +931,8 @@ template Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, - const int* seqlens_k, - const int* total_seqlens_k, + const int* past_seq_lens, + const int* total_seq_lens, int new_seq_len, const BFloat16* new_key, const BFloat16* new_value, @@ -760,8 +947,8 @@ template Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, - const int* seqlens_k, - const int* total_seqlens_k, + const int* past_seq_lens, + const int* total_seq_lens, int new_seq_len, const float* new_key, const float* new_value, @@ -772,6 +959,223 @@ template Status LaunchConcatKVInPlace(int batch_size, cudaStream_t stream, const int max_threads_per_block); +// ============================================================================ +// ConcatKVInPlaceFused Kernel +// ============================================================================ +// PURPOSE: +// Fused kernel that appends BOTH K and V in a single kernel launch. +// Eliminates one kernel launch compared to calling ConcatKVInPlace twice. 
+// +// INPUTS: +// max_seqlen, new_seqlen - Buffer and new sequence dimensions +// new_k, new_v - New K and V tokens to append (must be pre-rotated if RoPE is needed) +// past_seq_lens - Per-batch write offset (can be null) +// total_seq_lens - Per-batch total valid tokens +// is_past_kv_bnsh_format - True if KV buffer is BNSH, false for BSNH +// is_new_kv_bnsh_format - True if new K/V is BNSH, false for BSNH +// +// OUTPUTS: +// k_buff - Updated K cache +// v_buff - Updated V cache +// +// NOTE: +// RoPE should be applied BEFORE calling this kernel. +// For fused RoPE+append, use ConcatNewToPastKVFused instead. +// ============================================================================ +template +__global__ void ConcatKVInPlaceFused(const int max_seqlen, + const int new_seqlen, + T* k_buff, + T* v_buff, + const T* new_k, + const T* new_v, + const int* past_seq_lens, + const int* total_seq_lens, + const bool is_past_kv_bnsh_format, + const bool is_new_kv_bnsh_format) { + const int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int kv_num_heads = blockDim.y; + const int H = blockDim.x; + + const int past_seq_len = (past_seq_lens != nullptr) ? past_seq_lens[b] : (total_seq_lens[b] - new_seqlen); + + // Early exit to prevent out-of-bounds access and redundant writes + if (s + past_seq_len >= total_seq_lens[b]) { + return; + } + + // Use int64_t for offsets to prevent overflow + int64_t out_offset = is_past_kv_bnsh_format + ? INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h) + : INDEX_4D(max_seqlen, kv_num_heads, H, b, s + past_seq_len, n, h); + + int64_t in_offset = is_new_kv_bnsh_format + ? INDEX_4D(kv_num_heads, new_seqlen, H, b, n, s, h) + : INDEX_4D(new_seqlen, kv_num_heads, H, b, s, n, h); + + // Simple copy for K and V + k_buff[out_offset] = new_k[in_offset]; + v_buff[out_offset] = new_v[in_offset]; +} + +// Large version for when H * kv_num_heads > max_threads_per_block +template +__global__ void ConcatKVInPlaceFusedLarge(const int max_seqlen, + const int new_seqlen, + const int H, + const int kv_num_heads, + T* k_buff, + T* v_buff, + const T* new_k, + const T* new_v, + const int* past_seq_lens, + const int* total_seq_lens, + const bool is_past_kv_bnsh_format, + const bool is_new_kv_bnsh_format) { + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * kv_num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + + const int past_seq_len = (past_seq_lens != nullptr) ? past_seq_lens[b] : (total_seq_lens[b] - new_seqlen); + + if (s + past_seq_len >= total_seq_lens[b]) { + return; + } + + int64_t out_offset = is_past_kv_bnsh_format + ? INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h) + : INDEX_4D(max_seqlen, kv_num_heads, H, b, s + past_seq_len, n, h); + + int64_t in_offset = is_new_kv_bnsh_format + ? 
INDEX_4D(kv_num_heads, new_seqlen, H, b, n, s, h) + : INDEX_4D(new_seqlen, kv_num_heads, H, b, s, n, h); + + k_buff[out_offset] = new_k[in_offset]; + v_buff[out_offset] = new_v[in_offset]; + } +} + +// Launcher for fused K+V append +template +Status LaunchConcatKVInPlaceFused(int batch_size, + int kv_num_heads, + int head_size, + int max_sequence_length, + const int* past_seq_lens, + const int* total_seq_lens, + int new_seq_len, + const T* new_key, + const T* new_value, + T* present_key, + T* present_value, + bool is_past_kv_bnsh_format, + bool is_new_kv_bnsh_format, + cudaStream_t stream, + const int max_threads_per_block) { + // Determine vectorization factor (float2 is 8 bytes) + constexpr int vector_bytes = sizeof(float2); + constexpr int element_bytes = sizeof(T); + constexpr int elements_per_vector = vector_bytes / element_bytes; + + if (head_size % elements_per_vector != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size must be divisible by ", elements_per_vector, " for vectorized kernel."); + } + + const int H = head_size / elements_per_vector; + + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(new_seq_len, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + + // Single kernel for both K and V + ConcatKVInPlaceFused<<>>( + max_sequence_length, + new_seq_len, + reinterpret_cast(present_key), + reinterpret_cast(present_value), + reinterpret_cast(new_key), + reinterpret_cast(new_value), + past_seq_lens, + total_seq_lens, + is_past_kv_bnsh_format, + is_new_kv_bnsh_format); + } else { + int steps = int(ceil(float(H * kv_num_heads) / 256.0)); + const dim3 grid(steps, new_seq_len, batch_size); + const dim3 block(256, 1, 1); + + ConcatKVInPlaceFusedLarge<<>>( + max_sequence_length, + new_seq_len, + H, + kv_num_heads, + reinterpret_cast(present_key), + reinterpret_cast(present_value), + reinterpret_cast(new_key), + reinterpret_cast(new_value), + past_seq_lens, + total_seq_lens, + is_past_kv_bnsh_format, + is_new_kv_bnsh_format); + } + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchConcatKVInPlaceFused(int batch_size, + int kv_num_heads, + int head_size, + int max_sequence_length, + const int* past_seq_lens, + const int* total_seq_lens, + int new_seq_len, + const half* new_key, + const half* new_value, + half* present_key, + half* present_value, + bool is_past_kv_bnsh_format, + bool is_new_kv_bnsh_format, + cudaStream_t stream, + const int max_threads_per_block); + +template Status LaunchConcatKVInPlaceFused(int batch_size, + int kv_num_heads, + int head_size, + int max_sequence_length, + const int* past_seq_lens, + const int* total_seq_lens, + int new_seq_len, + const BFloat16* new_key, + const BFloat16* new_value, + BFloat16* present_key, + BFloat16* present_value, + bool is_past_kv_bnsh_format, + bool is_new_kv_bnsh_format, + cudaStream_t stream, + const int max_threads_per_block); + +template Status LaunchConcatKVInPlaceFused(int batch_size, + int kv_num_heads, + int head_size, + int max_sequence_length, + const int* past_seq_lens, + const int* total_seq_lens, + int new_seq_len, + const float* new_key, + const float* new_value, + float* present_key, + float* present_value, + bool is_past_kv_bnsh_format, + bool is_new_kv_bnsh_format, + cudaStream_t stream, + const int max_threads_per_block); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.h b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.h index d7d7bdd87d62f..6e54aa85131f1 100644 --- 
a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.h @@ -9,12 +9,31 @@ #include "core/providers/cuda/cuda_common.h" // Macro to help compute index of flatten 4D matrix, note that dim1 is not used so it is excluded. -#define INDEX_4D(dim2, dim3, dim4, i, j, k, l) ((i) * (dim2) * (dim3) * (dim4) + (j) * (dim3) * (dim4) + (k) * (dim4) + (l)) +#define INDEX_4D(dim2, dim3, dim4, i, j, k, l) (int64_t(i) * (dim2) * (dim3) * (dim4) + int64_t(j) * (dim3) * (dim4) + int64_t(k) * (dim4) + int64_t(l)) namespace onnxruntime { namespace contrib { namespace cuda { +// Matrix index constants +enum class QKV : int { + Q = 0, + K = 1, + V = 2, + COUNT = 3 +}; + +// KV Cache Layout Documentation: +// BSNH format: [batch_size, sequence_length, num_heads, head_size] +// - Preferred for most operations due to better memory coalescing for typical access patterns +// - Adjacent threads in a warp (h dimension) access contiguous memory +// - Used when is_bsnh=true +// +// BNSH format: [batch_size, num_heads, sequence_length, head_size] +// - Used when sequence dimension needs to be contiguous +// - May suffer from worse coalescing if head_size is small +// - Used when is_bsnh=false (or explicit bnsh flags) + Status LaunchConcatTensorToTensor(cudaStream_t stream, const int all_sequence_length, const int sequence_length, @@ -64,6 +83,8 @@ Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, const T* qkv_buffer, T* present); +// Fused KV Append for Separate Buffer Mode: Appends New K & V to Past in one kernel +// Uses blockIdx.z to distinguish between K and V template Status LaunchConcatNewToPastKV(const int batch_size, const int kv_num_heads, @@ -72,7 +93,8 @@ Status LaunchConcatNewToPastKV(const int batch_size, const int past_sequence_length, const int present_sequence_length, const bool is_bsnh, - const int* seqlens_k, + const int* past_seq_lens, + const int* total_seq_lens, const T* past_key, const T* past_value, const T* new_key, @@ -81,15 +103,20 @@ Status LaunchConcatNewToPastKV(const int batch_size, T* present_value, cudaStream_t stream, const int max_threads_per_block, - const bool past_only); + const bool past_only, + const T* cos_cache = nullptr, + const T* sin_cache = nullptr, + const int rotary_dim = 0, + const int64_t* position_ids = nullptr, + const bool interleaved = false); template Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, - int max_sequence_length, // max sequence length of present_key or present_value. - const int* seqlens_k, // it is not used when total_seqlens_k is available. - const int* total_seqlens_k, // optional, nullptr means it is not available. + int max_sequence_length, // max sequence length of present_key or present_value. + const int* past_seq_lens, + const int* total_seq_lens, int new_seq_len, const T* new_key, const T* new_value, @@ -100,6 +127,26 @@ Status LaunchConcatKVInPlace(int batch_size, cudaStream_t stream, const int max_threads_per_block); +// Truly fused K+V In-Place Append with RoPE +// Single kernel that appends K (with RoPE rotation) and V (without rotation) to KV cache. +// This eliminates a separate kernel launch for V, saving kernel overhead. 
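// Example call shape (an illustrative sketch; the caller-side variable names here are hypothetical).
// For a BNSH cache [batch, kv_heads, max_seq, head] and new K/V in contiguous BSNH of length
// new_seq_len, the fused append is launched roughly as:
//
//   ORT_RETURN_IF_ERROR(LaunchConcatKVInPlaceFused<half>(
//       batch_size, kv_num_heads, head_size, max_sequence_length,
//       past_seq_lens, total_seq_lens, new_seq_len,
//       new_key, new_value, present_key, present_value,
//       /*is_past_kv_bnsh_format*/ true, /*is_new_kv_bnsh_format*/ false,
//       stream, max_threads_per_block));
//
// New tokens for batch b land in cache rows [past_seq_lens[b], total_seq_lens[b]); rows at or
// beyond total_seq_lens[b] are not written. head_size must be divisible by the float2 vector
// width for T (4 for half/BFloat16, 2 for float). The kernel copies K and V verbatim, so when
// RoPE is needed, new_key must already be rotated; the separate-buffer path
// (ConcatNewToPastKVFused) fuses RoPE into the append instead.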
+template +Status LaunchConcatKVInPlaceFused(int batch_size, + int kv_num_heads, + int head_size, + int max_sequence_length, + const int* past_seq_lens, + const int* total_seq_lens, + int new_seq_len, + const T* new_key, + const T* new_value, + T* present_key, + T* present_value, + bool is_past_kv_bnsh_format, + bool is_new_kv_bnsh_format, + cudaStream_t stream, + const int max_threads_per_block); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 76704b5b29fcd..bdb0c6e8273cf 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -173,7 +173,7 @@ size_t get_out_accum_size(size_t num_splits, size_t batch_size, size_t num_heads void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) { FP16_SWITCH(!params.is_bf16, [&] { HEADDIM_SWITCH(params.d, [&] { - if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + if (params.num_splits <= 1 && !force_split_kernel) { run_mha_fwd_(params, stream); } else { run_mha_fwd_splitkv_dispatch(params, stream); @@ -524,6 +524,29 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, } params.knew_head_stride = head_size; params.vnew_head_stride = head_size; + } else if (is_packed_qkv) { + // Handle Packed QKV where K/V are part of Q + // q_ptr points to the start of the packed buffer (Batch, Seq, (H + 2*Hk)*D) + + params.seqlen_knew = seqlen_q; // For packed, new K len is same as Q len + + // Strides for Packed QKV + // layout: [batch, seq, (h + 2*hk), d] + int64_t row_stride = (num_heads + 2 * num_heads_k) * head_size; + params.q_batch_stride = seqlen_q * row_stride; + params.knew_batch_stride = seqlen_q * row_stride; + params.vnew_batch_stride = seqlen_q * row_stride; + + params.q_row_stride = row_stride; + params.knew_row_stride = row_stride; + params.vnew_row_stride = row_stride; + + params.q_head_stride = head_size; + params.knew_head_stride = head_size; + params.vnew_head_stride = head_size; + + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; } else { params.seqlen_knew = 0; params.knew_ptr = nullptr; @@ -539,6 +562,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, params.is_seqlens_k_cumulative = seqlens_k_ == nullptr; if (seqlens_k_ != nullptr) { params.cu_seqlens_k = static_cast(seqlens_k_); + params.seqused_k = static_cast(seqlens_k_); } if (rotary_cos != nullptr) { @@ -573,7 +597,12 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, } // Only split kernel supports appending to KV cache - run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr); + // or if using packed QKV (to ensure correct handling of strided inputs which might be better supported or isolated in split kernel logic). + // Note: if the fused kernel handles packing/rotary/appending, it should pass is_packed_qkv=false to this API (via use_packed_for_fa=false), + // effectively bypassing this check and allowing standard kernels if otherwise eligible. 
+ bool force_split = (k_new != nullptr) || is_packed_qkv; + + run_mha_fwd(params, stream, force_split); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 1c131af8453a1..c99db85f93421 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include "core/providers/cuda/cuda_common.h" #include "core/platform/env_var_utils.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" @@ -8,6 +9,8 @@ #include "contrib_ops/cpu/bert/group_query_attention_helper.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" +#include "contrib_ops/cpu/utils/debug_macros.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -35,6 +38,9 @@ namespace cuda { REGISTER_KERNEL_TYPED(MLFloat16) REGISTER_KERNEL_TYPED(BFloat16) +constexpr const char* kDisableFlashDecode = "ORT_DISABLE_FLASH_DECODE"; +constexpr const char* kDisableFusedKv = "ORT_DISABLE_FUSED_KV"; + template GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) : CudaKernel(info) { @@ -64,6 +70,9 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) zeros_ = this->GetScratchBuffer(kZerosCount, nullptr); CUDA_CALL_THROW(cudaMemset(zeros_.get(), 0, kZerosCount * sizeof(int))); } + + disable_flash_decode_ = ParseEnvironmentVariableWithDefault(kDisableFlashDecode, false); + disable_fused_kv_ = ParseEnvironmentVariableWithDefault(kDisableFusedKv, false); } template @@ -73,17 +82,23 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* value = context->Input(2); const Tensor* past_key = context->Input(3); const Tensor* past_value = context->Input(4); - const Tensor* seqlens_k = context->Input(5); + + // The input seqlens_k is total sequence length - 1 for historical reasons. + // Rename it to total_seq_lens_minus_one in cuda kernel to avoid confusion. + const Tensor* total_seq_lens_minus_one = context->Input(5); + + // The max of total sequence lengths. The content of this tensor is a scalar stored in CPU memory. 
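// Worked example (illustrative): for a batch of 2 with total sequence lengths [8, 13],
// total_seq_lens_minus_one (Input 5) holds [7, 12], while total_seqlen (Input 6) is the
// CPU-side scalar 13, i.e. the maximum total length across the batch.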
const Tensor* total_seqlen = context->Input(6); + const Tensor* cos_cache = context->Input(7); const Tensor* sin_cache = context->Input(8); const Tensor* position_ids = context->Input(9); const Tensor* attention_bias = context->Input(10); const Tensor* head_sink = context->Input(11); - if (position_ids != nullptr || attention_bias != nullptr) { + if (attention_bias != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "position_ids and attention_bias are not supported in GroupQueryAttention cuda kernel."); + "attention_bias is not supported in GroupQueryAttention cuda kernel."); } auto& device_prop = GetDeviceProp(); @@ -101,7 +116,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { ¶meters, num_heads_, kv_num_heads_, - seqlens_k, + total_seq_lens_minus_one, total_seqlen, scale_, softcap_, @@ -116,8 +131,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr; parameters.zeros_count = kZerosCount; parameters.zero_ptr = zeros_.get(); - - int sequence_length = parameters.sequence_length; parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; @@ -134,87 +147,165 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size); - output_shape[1] = static_cast(sequence_length); + output_shape[1] = static_cast(parameters.sequence_length); output_shape[2] = static_cast(parameters.hidden_size); Tensor* output = context->Output(0, output_shape); + // Set up present KV output shapes + std::vector present_dims = { + parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, parameters.head_size}; + + TensorShape present_shape(present_dims); + context->Output(1, present_shape); // present_key + context->Output(2, present_shape); // present_value + + IAllocatorUniquePtr k_buffer; + IAllocatorUniquePtr v_buffer; + IAllocatorUniquePtr rotary_buffer; + IAllocatorUniquePtr position_ids_buffer; + IAllocatorUniquePtr fmha_buffer; + IAllocatorUniquePtr unpacked_qkv_buffer; + IAllocatorUniquePtr seq_lens_buffer; + + // Flash Attention buffers + IAllocatorUniquePtr softmax_lse_buffer; + IAllocatorUniquePtr softmax_lse_accum_buffer; + IAllocatorUniquePtr out_accum_buffer; + + data.position_ids = (position_ids != nullptr) ? position_ids->Data() : nullptr; + + // Input pointers for both paths + data.query = reinterpret_cast(query->Data()); + data.key = key == nullptr ? nullptr : reinterpret_cast(key->Data()); + data.value = value == nullptr ? nullptr : reinterpret_cast(value->Data()); + + // Handle Past/Present pointers + data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); + data.present_key = reinterpret_cast(context->Output(1)->MutableData()); + data.past_value = (past_value == nullptr) ? 
nullptr : reinterpret_cast(past_value->Data()); + data.present_value = reinterpret_cast(context->Output(2)->MutableData()); + #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && onnxruntime::flash::is_supported(device_prop, parameters.head_size, parameters.num_heads, parameters.kv_num_heads); - // Allocate buffers - size_t softmax_lse_bytes = 0; - size_t softmax_lse_accum_bytes = 0; - size_t out_accum_bytes = 0; + data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.kv_share_buffer; if (use_flash_attention) { - // softmax buffer - softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads); - // split kv buffer - using namespace std; - auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( + data.use_flash_attention = true; + data.use_memory_efficient_attention = false; + + // Allocate Flash specific buffers (Softmax LSE, Accum) + size_t softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads); + auto [num_splits, softmax_lse_accum_bytes, out_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads, parameters.head_size, device_prop.multiProcessorCount); parameters.num_splits = static_cast(num_splits); - softmax_lse_accum_bytes = slse_accum_bytes; - out_accum_bytes = o_accum_bytes; + + softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); + softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); + + data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + data.out_accum = reinterpret_cast(out_accum_buffer.get()); } - auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); - auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); - auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); -#else - constexpr bool use_flash_attention = false; - auto softmax_lse_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr - auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr - auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif + if (data.use_flash_attention_fast_decode && parameters.sequence_length == 1) { + // FlashAttentionDecoding Fast Path: + // - Uses Flash Attention's internal KV append logic, so total_seq_lens and padded_seq_lens are not needed. + // - Past_seq_lens is passed as seqlens_k to Flash Attention, which uses it to: + // 1. Determine where to append new K/V in the cache + // 2. Apply correct causal masking (attention only to positions [0, past_seq_len]) + // - The input seqlens_k from ONNX graph is (total_len - 1), which equals past_seq_len when seq_len == 1. + // - This optimization avoids launching GetSequenceLengths kernel for single-token decoding. + data.past_seq_lens = const_cast(total_seq_lens_minus_one->Data()); + } else { + // Compute sequence length buffers (past_seq_lens and total_seq_lens). 
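// Worked example (illustrative): for token generation with batch_size = 2, sequence_length = 1 and
// total_seq_lens_minus_one = [7, 12], the launch below produces total_seq_lens = [8, 13] and
// past_seq_lens = [7, 12]. For a first prompt padded to sequence_length = 8 with actual lengths
// [5, 8] (total_seq_lens_minus_one = [4, 7]), it produces total_seq_lens = [5, 8],
// past_seq_lens = [0, 0], and padded_seq_lens = [8, 8].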
+ // Allocate a single scratch buffer of 3 * batch_size ints, laid out as past_seq_lens, then total_seq_lens, then padded_seq_lens. + seq_lens_buffer = GetScratchBuffer(3 * parameters.batch_size, context->GetComputeStream()); + auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle()); + data.past_seq_lens = seq_lens_buffer.get(); + data.total_seq_lens = seq_lens_buffer.get() + parameters.batch_size; + data.padded_seq_lens = data.total_seq_lens + parameters.batch_size; + ORT_RETURN_IF_ERROR(LaunchGetSequenceLengths(total_seq_lens_minus_one->Data(), + data.past_seq_lens, + data.total_seq_lens, + data.padded_seq_lens, + parameters.batch_size, + parameters.sequence_length, + parameters.is_first_prompt, + cuda_stream, + device_prop.maxThreadsPerBlock)); + } + + if (!use_flash_attention) { + // Fall back to memory efficient attention. #if USE_MEMORY_EFFICIENT_ATTENTION - int sm = (device_prop.major * 10) + device_prop.minor; - bool use_memory_efficient_attention = - !use_flash_attention && - !disable_memory_efficient_attention_ && - has_memory_efficient_attention(sm, std::is_same::value, std::is_same::value, parameters.head_size, parameters.head_size); - - // allocate buffers - size_t kv_buffer_bytes = 0; - // need a buffer if we must ungroup kv - const bool needs_buff = (parameters.num_heads != parameters.kv_num_heads); - if (use_memory_efficient_attention && needs_buff) { - kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.seqlen_present_kv_cache * parameters.head_size); + int sm = (device_prop.major * 10) + device_prop.minor; + bool use_memory_efficient_attention = + !use_flash_attention && + !disable_memory_efficient_attention_ && + has_memory_efficient_attention(sm, std::is_same::value, std::is_same::value, parameters.head_size, parameters.head_size); + + // KV buffer for head expansion (when num_heads != kv_num_heads) + size_t kv_buffer_bytes = (use_memory_efficient_attention && (parameters.num_heads != parameters.kv_num_heads)) + ? (sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.seqlen_present_kv_cache * parameters.head_size) + : 0; + // FMHA workspace + size_t fmha_buffer_bytes = (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) + ?
(sizeof(float) * parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size) + : 0; + + k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream()); +#else + constexpr bool use_memory_efficient_attention = false; +#endif + + data.use_memory_efficient_attention = use_memory_efficient_attention; + data.use_flash_attention = false; + + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); + data.disable_fused_kv = disable_fused_kv_; } - size_t rotary_buffer_bytes = 0; - if (use_memory_efficient_attention && do_rotary_) { - rotary_buffer_bytes = 2 * sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.sequence_length * parameters.head_size; - rotary_buffer_bytes += sizeof(int64_t) * parameters.batch_size * parameters.sequence_length; + + // Centralized scratch buffer allocation using GQABufferRequirements + // This ensures allocation logic stays in sync with kernel usage + auto buffer_req = GQABufferRequirements::Compute( + parameters, + use_flash_attention, + data.use_flash_attention_fast_decode, + data.use_memory_efficient_attention); + + if (buffer_req.unpacked_qkv_bytes > 0) { + unpacked_qkv_buffer = GetScratchBuffer(buffer_req.unpacked_qkv_bytes, context->GetComputeStream()); + data.unpacked_qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get()); } - size_t fmha_buffer_bytes = 0; - if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) { - fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float)); + if (buffer_req.rotary_buffer_bytes > 0) { + rotary_buffer = GetScratchBuffer(buffer_req.rotary_buffer_bytes, context->GetComputeStream()); + data.rotary_buffer = reinterpret_cast(rotary_buffer.get()); } - size_t unpacked_qkv_bytes = 0; - if (use_memory_efficient_attention && parameters.is_packed_qkv) { - unpacked_qkv_bytes = (parameters.batch_size * parameters.sequence_length * (parameters.num_heads + 2 * parameters.kv_num_heads) * parameters.head_size * sizeof(T)); + if (buffer_req.position_ids_bytes > 0) { + position_ids_buffer = GetScratchBuffer(buffer_req.position_ids_bytes, context->GetComputeStream()); + data.position_ids_buffer = reinterpret_cast(position_ids_buffer.get()); } - auto k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); - auto v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); - auto rotary_buffer = GetScratchBuffer(rotary_buffer_bytes, context->GetComputeStream()); - auto fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream()); - auto unpacked_qkv_buffer = GetScratchBuffer(unpacked_qkv_bytes, context->GetComputeStream()); -#else - constexpr bool use_memory_efficient_attention = false; - auto k_buffer = GetScratchBuffer(0, context->GetComputeStream()); - auto v_buffer = GetScratchBuffer(0, context->GetComputeStream()); - auto rotary_buffer = GetScratchBuffer(0, context->GetComputeStream()); - auto fmha_buffer = GetScratchBuffer(0, context->GetComputeStream()); - auto unpacked_qkv_buffer = GetScratchBuffer(0, context->GetComputeStream()); +#ifndef NDEBUG + // Track allocated sizes for validation + data.unpacked_qkv_buffer_size = 
buffer_req.unpacked_qkv_bytes; + data.rotary_buffer_size = buffer_req.rotary_buffer_bytes; + data.position_ids_buffer_size = buffer_req.position_ids_bytes; #endif if (kernel_options_->AllowDebugInfo()) { AttentionKernelDebugInfo debug_info; debug_info.use_flash_attention = use_flash_attention; - debug_info.use_efficient_attention = use_memory_efficient_attention; + debug_info.use_efficient_attention = data.use_memory_efficient_attention; debug_info.Print("GroupQueryAttention", this->Node().Name(), @@ -222,67 +313,16 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { std::is_same::value); } - // seqlens_k buffer - size_t seqlens_k_bytes = 0; - seqlens_k_bytes = sizeof(int) * parameters.batch_size; - auto seqlens_k_buffer = GetScratchBuffer(seqlens_k_bytes, context->GetComputeStream()); - - std::vector present_dims; - if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { - present_dims = { - parameters.batch_size, parameters.seqlen_present_kv_cache, parameters.kv_num_heads, parameters.head_size}; - } else { // BNSH - present_dims = { - parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, parameters.head_size}; - } - TensorShape present_shape(present_dims); - Tensor* present_key = context->Output(1, present_shape); - Tensor* present_value = context->Output(2, present_shape); - - data.query = reinterpret_cast(query->Data()); - data.key = key == nullptr ? nullptr : reinterpret_cast(key->Data()); - data.value = value == nullptr ? nullptr : reinterpret_cast(value->Data()); - data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); - data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast(past_value->Data()); - data.output = reinterpret_cast(output->MutableData()); - data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); - data.present_value = (nullptr == present_value) ? 
nullptr : reinterpret_cast(present_value->MutableData()); - data.seqlens_k = const_cast(seqlens_k->Data()); - data.use_flash_attention = use_flash_attention; - data.use_memory_efficient_attention = use_memory_efficient_attention; if (data.past_key == data.present_key) { parameters.kv_share_buffer = true; + ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when kv_share_buffer is true"); } else { parameters.kv_share_buffer = false; + ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when kv_share_buffer is false"); } - // Flash Buffers - if (softmax_lse_buffer != nullptr) { - data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); - } - if (softmax_lse_accum_buffer != nullptr) { - data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); - } - if (out_accum_buffer != nullptr) { - data.out_accum = reinterpret_cast(out_accum_buffer.get()); - } - if (seqlens_k_buffer != nullptr) { - data.seqlens_k_buff = reinterpret_cast(seqlens_k_buffer.get()); - } - // Memory Efficient Buffers - if (k_buffer != nullptr) { - data.k = reinterpret_cast(k_buffer.get()); - data.v = reinterpret_cast(v_buffer.get()); - } - if (fmha_buffer != nullptr) { - data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); - } - if (unpacked_qkv_buffer != nullptr) { - data.unpacked_qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get()); - } - if (rotary_buffer != nullptr) { - data.rotary_buffer = reinterpret_cast(rotary_buffer.get()); - } - // Rotary Embedding + + data.output = reinterpret_cast(output->MutableData()); + if (parameters.do_rotary) { data.cos_cache = reinterpret_cast(cos_cache->Data()); data.sin_cache = reinterpret_cast(sin_cache->Data()); @@ -294,8 +334,23 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { cublasHandle_t cublas = GetCublasHandle(context); - return QkvToContext( - device_prop, cublas, context->GetComputeStream(), parameters, data); + ORT_RETURN_IF_ERROR(QkvToContext( + device_prop, cublas, context->GetComputeStream(), parameters, data)); + +#ifndef NDEBUG + // Validate buffer usage matches allocation exactly + ORT_ENFORCE(data.unpacked_qkv_max_used == data.unpacked_qkv_buffer_size, + "unpacked_qkv_buffer: used ", data.unpacked_qkv_max_used, + " bytes but allocated ", data.unpacked_qkv_buffer_size); + ORT_ENFORCE(data.rotary_max_used == data.rotary_buffer_size, + "rotary_buffer: used ", data.rotary_max_used, + " bytes but allocated ", data.rotary_buffer_size); + ORT_ENFORCE(data.position_ids_max_used == data.position_ids_buffer_size, + "position_ids_buffer: used ", data.position_ids_max_used, + " bytes but allocated ", data.position_ids_buffer_size); +#endif + + return Status::OK(); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 08457feb099b3..5bf26e8c6edac 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -7,6 +7,7 @@ #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/attention_kernel_options.h" +#include "contrib_ops/cpu/bert/attention_common.h" namespace onnxruntime { namespace contrib { @@ -33,6 +34,9 @@ class GroupQueryAttention final : public CudaKernel { float softcap_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; + 
bool disable_flash_decode_; + bool disable_fused_kv_; + static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) IAllocatorUniquePtr zeros_; const AttentionKernelOptions* kernel_options_; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index cfdaf2aa74837..6643555dc30ee 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -19,31 +19,35 @@ limitations under the License. */ // Modifications: -// (1) support GPT-2 past state, unidirectional mask (causal) +// (1) support past state, unidirectional mask (causal) // (2) use flash attention kernel from (https://github.com/Dao-AILab/flash-attention) // (3) support different number of heads for Q and KV // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include +#include #include +#include // For getenv + +#include #include -#include "core/providers/cuda/cu_inc/common.cuh" -#include "core/providers/cuda/cuda_common.h" -#include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "contrib_ops/cuda/bert/attention_softmax.h" -#include "contrib_ops/cuda/bert/transformer_common.h" + +#include "contrib_ops/cpu/utils/debug_macros.h" #include "contrib_ops/cuda/bert/add_bias_transpose.h" -#include "contrib_ops/cpu/bert/attention_base.h" +#include "contrib_ops/cuda/bert/attention_impl.h" +#include "contrib_ops/cuda/bert/attention_softmax.h" #include "contrib_ops/cuda/bert/bert_padding.h" -#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" -#include "contrib_ops/cuda/bert/attention_impl.h" -#include "core/providers/cuda/shared_inc/cuda_call.h" #include "contrib_ops/cuda/bert/rotary_embedding_impl.h" -#include +#include "contrib_ops/cuda/bert/rotary_common.cuh" +#include "contrib_ops/cuda/bert/transformer_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" using namespace onnxruntime::cuda; @@ -57,28 +61,26 @@ namespace cuda { ////////// Auxiliary Kernels for KV prep -// Kernel for seqlens_k -__global__ void repeat_seqlen(int32_t* seqlens_k, int32_t seqlen, int batch_size) { - int id = blockDim.x * blockIdx.x + threadIdx.x; - if (id < batch_size) seqlens_k[id] = seqlen; -} - // Concat new to past in present. 
Supports past BSNH or past BNSH template -Status LaunchConcatNewToPastKV(GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, - const void* new_key, - const void* new_value, - cudaStream_t stream, - const int max_threads_per_block, - const bool past_only = false) { +Status LaunchConcatNewToPastKVHelper(GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + const void* new_key, + const void* new_value, + cudaStream_t stream, + const int max_threads_per_block, + const bool past_only = false, + const T* cos_cache = nullptr, + const T* sin_cache = nullptr, + const int rotary_dim = 0, + const int64_t* position_ids = nullptr, + const bool interleaved = false) { const int batch_size = parameters.batch_size; const int kv_sequence_length = parameters.sequence_length; const int past_sequence_length = parameters.seqlen_past_kv_cache; const int present_sequence_length = parameters.seqlen_present_kv_cache; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; - const int* seqlens_k = parameters.is_first_prompt ? nullptr : reinterpret_cast(data.seqlens_k); AttentionQkvFormat past_kv_format = parameters.past_kv_format; assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; @@ -90,7 +92,8 @@ Status LaunchConcatNewToPastKV(GroupQueryAttentionParameters& parameters, past_sequence_length, present_sequence_length, is_bsnh, - seqlens_k, + data.past_seq_lens, + data.total_seq_lens, data.past_key, data.past_value, reinterpret_cast(new_key), @@ -99,7 +102,12 @@ Status LaunchConcatNewToPastKV(GroupQueryAttentionParameters& parameters, data.present_value, stream, max_threads_per_block, - past_only); + past_only, + cos_cache, + sin_cache, + rotary_dim, + position_ids, + interleaved); } // Concat new to kv buffer in place @@ -112,8 +120,6 @@ Status LaunchConcatKVInPlace(GroupQueryAttentionParameters& parameters, cudaStream_t stream, const int max_threads_per_block) { const int max_sequence_length = parameters.seqlen_present_kv_cache; - const int* seqlens_k = (parameters.is_first_prompt && !parameters.is_subsequent_prompt) ? nullptr - : reinterpret_cast(data.seqlens_k); assert(parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); @@ -123,8 +129,8 @@ Status LaunchConcatKVInPlace(GroupQueryAttentionParameters& parameters, parameters.kv_num_heads, parameters.head_size, max_sequence_length, - seqlens_k, - nullptr, // total_seqlens_k would be wrong to use here + data.past_seq_lens, + data.total_seq_lens, parameters.sequence_length, reinterpret_cast(new_key), reinterpret_cast(new_value), @@ -136,7 +142,32 @@ Status LaunchConcatKVInPlace(GroupQueryAttentionParameters& parameters, max_threads_per_block); } -// Kernel for use with memory efficient kernel... kv_in is grouped and of bnsh or bsnh... kv_out is ungrouped and bsnh +// ============================================================================ +// Ungroup Kernel +// ============================================================================ +// PURPOSE: +// Expands grouped KV heads to match Q heads for Memory Efficient Attention. +// Each KV head is replicated q_num_heads/kv_num_heads times. 
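+// Example (hypothetical sizes): q_num_heads = 8, kv_num_heads = 2 => each KV head feeds
+// 8 / 2 = 4 output heads, i.e. out_n 0-3 read in_n 0 and out_n 4-7 read in_n 1 (in_n = out_n / 4).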
+// +// INPUTS: +// kv_in - Grouped KV tensor with kv_num_heads heads +// in_seqlen - Sequence length of input tensor +// kv_num_heads - Number of KV heads (fewer than Q heads) +// is_bsnh - True for BSNH format, False for BNSH format +// +// OUTPUTS: +// kv_out - Ungrouped tensor with q_num_heads heads (BSNH format) +// +// THREAD MAPPING: +// threadIdx.x = h (head dimension element) +// threadIdx.y = out_n (output head index) +// blockIdx.x = s (sequence position) +// blockIdx.y = b (batch index) +// +// ASSUMPTIONS: +// - q_num_heads is divisible by kv_num_heads +// - H * q_num_heads <= max_threads_per_block (use UngroupLarge otherwise) +// ============================================================================ template __global__ void Ungroup(const T* kv_in, T* kv_out, @@ -167,6 +198,19 @@ __global__ void Ungroup(const T* kv_in, kv_out[out_offset] = kv_in[in_offset]; } +// ============================================================================ +// UngroupLarge Kernel +// ============================================================================ +// PURPOSE: +// Same as Ungroup but for cases where H * q_num_heads > max_threads_per_block. +// Uses a 1D thread grid to avoid block dimension limit. +// +// THREAD MAPPING: +// Each thread processes one element indexed by (threadIdx.x + blockDim.x * blockIdx.x) +// This linear index is decomposed into (h, out_n) within the kernel. +// blockIdx.y = s (sequence position) +// blockIdx.z = b (batch index) +// ============================================================================ template __global__ void UngroupLarge(const T* kv_in, T* kv_out, @@ -200,7 +244,8 @@ __global__ void UngroupLarge(const T* kv_in, } // Ungroup kv or present kv for use in Memory Efficient kernel. If present kv is not null and is BNSH, transposes it. 
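+// Note: ungrouping (and the k/v scratch buffers backing it) is only needed when
+// num_heads != kv_num_heads; otherwise the present KV cache is passed to the attention kernel directly.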
-Status LaunchUngroup(GroupQueryAttentionParameters& parameters, +template +Status LaunchUngroup(const GroupQueryAttentionParameters& parameters, float2* k_buff, float2* v_buff, const float2* k_og, const float2* v_og, const int buff_seqlen, const int og_seqlen, @@ -248,43 +293,33 @@ Status LaunchUngroup(GroupQueryAttentionParameters& parameters, return CUDA_CALL(cudaGetLastError()); } -__global__ void PastToTotalSeqlen(int32_t* seqlens_k, - int32_t* seqlens_k_buff, - const int add_seqlen) { - seqlens_k_buff[threadIdx.x] = seqlens_k[threadIdx.x] + add_seqlen; -} - -// Calculate total sequence length from seqlens_k -Status LaunchGetSeqlensTotal(int32_t* seqlens_k, int32_t* seqlens_k_buff, const int batch_size, cudaStream_t stream, - const int /*threads_per_block*/) { - const dim3 grid(1, 1, 1); - // TODO(aciddelgado): unlikely but could have a bigger batch_size than max_threads - const dim3 block(batch_size, 1, 1); - PastToTotalSeqlen<<>>(seqlens_k, seqlens_k_buff, 1); - return CUDA_CALL(cudaGetLastError()); -} - -// Currently, interactive decoding only works for batch_size 1 -__global__ void GetSeqlensInteractive(const int32_t* seqlens_k, int32_t* seqlens_k_buff, - const int batch_size, const int sequence_length) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; - if (tid < batch_size) { - seqlens_k_buff[tid] = seqlens_k[tid] + 1 - sequence_length; - } -} - -// Calculate past sequence length for each batch entry for flash attention kernel -Status LaunchGetSeqlensInteractive(const int32_t* seqlens_k, int32_t* seqlens_k_buff, - const int batch_size, const int sequence_length, cudaStream_t stream, - const int max_threads_per_block) { - const int threads = std::min(batch_size, max_threads_per_block); - const int blocks = (threads / max_threads_per_block) + 1; - GetSeqlensInteractive<<>>(seqlens_k, seqlens_k_buff, batch_size, - sequence_length); - return CUDA_CALL(cudaGetLastError()); -} - -// Kernel to unpack qkv from packed qkv +// ============================================================================ +// UnpackQKV Kernel +// ============================================================================ +// PURPOSE: +// Unpacks packed QKV tensor into separate Q, K, V tensors. +// Packed input has interleaved [Q, K, V] per token. +// +// INPUTS: +// packed_qkv - Input tensor of shape [B, S, (Q_heads + 2*KV_heads) * head_size] +// num_heads - Number of Q heads +// kv_num_heads - Number of KV heads +// head_size - Head dimension +// sequence_length - Token sequence length +// batch_size - Batch size +// +// OUTPUTS: +// unpacked_q - Q tensor [B, S, num_heads, head_size] if BSNH, or [B, num_heads, S, head_size] if BNSH +// unpacked_k - K tensor [B, S, kv_num_heads, head_size] if BSNH, or [B, kv_num_heads, S, head_size] if BNSH +// unpacked_v - V tensor (same layout as K) +// +// TEMPLATE PARAM: +// output_bnsh - If true, outputs BNSH format; if false, outputs BSNH format +// +// THREAD MAPPING: +// One thread per element in packed_qkv. Thread determines which of Q/K/V +// the element belongs to based on the offset within the hidden dimension. 
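+// Example (hypothetical sizes): num_heads = 8, kv_num_heads = 2, head_size = 64 =>
+// q_hidden = 512, k_hidden = 128, d = 768; within each token's packed row, offsets [0, 512)
+// are Q, [512, 640) are K, and [640, 768) are V.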
+// ============================================================================ template __global__ void UnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, @@ -300,7 +335,7 @@ __global__ void UnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* int offset = tid % d; if (output_bnsh) { // output BNSH int head_count = kv_num_heads; - T* unpacked; + T* unpacked = nullptr; if (offset < q_hidden) { unpacked = unpacked_q; head_count = num_heads; @@ -311,23 +346,36 @@ __global__ void UnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked = unpacked_v; offset -= (q_hidden + k_hidden); } - int n = offset / head_size; - int h = offset % head_size; - int unpacked_i = INDEX_4D(head_count, sequence_length, head_size, b, n, s, h); - unpacked[unpacked_i] = packed_qkv[tid]; + if (unpacked != nullptr) { + int n = offset / head_size; + int h = offset % head_size; + + int unpacked_i = INDEX_4D(head_count, sequence_length, head_size, b, n, s, h); + unpacked[unpacked_i] = packed_qkv[tid]; + } else { +#ifndef NDEBUG + assert(false && "Unexpected null 'unpacked' pointer in GroupQueryAttention unpack kernel"); +#endif + } } else { // output BSNH if (offset < q_hidden) { - int unpacked_i = b * sequence_length * num_heads * head_size + s * num_heads * head_size + offset; - unpacked_q[unpacked_i] = packed_qkv[tid]; + if (unpacked_q != nullptr) { + int unpacked_i = b * sequence_length * num_heads * head_size + s * num_heads * head_size + offset; + unpacked_q[unpacked_i] = packed_qkv[tid]; + } } else if (offset < q_hidden + k_hidden) { - int unpacked_i = b * sequence_length * kv_num_heads * head_size + - s * kv_num_heads * head_size + (offset - q_hidden); - unpacked_k[unpacked_i] = packed_qkv[tid]; + if (unpacked_k != nullptr) { + int unpacked_i = b * sequence_length * kv_num_heads * head_size + + s * kv_num_heads * head_size + (offset - q_hidden); + unpacked_k[unpacked_i] = packed_qkv[tid]; + } } else { - int unpacked_i = b * sequence_length * kv_num_heads * head_size + - s * kv_num_heads * head_size + (offset - q_hidden - k_hidden); - unpacked_v[unpacked_i] = packed_qkv[tid]; + if (unpacked_v != nullptr) { + int unpacked_i = b * sequence_length * kv_num_heads * head_size + + s * kv_num_heads * head_size + (offset - q_hidden - k_hidden); + unpacked_v[unpacked_i] = packed_qkv[tid]; + } } } } @@ -345,72 +393,344 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp return CUDA_CALL(cudaGetLastError()); } -__global__ void SeqlensToPosIdsInteractive(const int32_t* seqlens_k, int64_t* position_ids, - const int seqlen, const int batch_size) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; - int b = tid / seqlen; - int s = tid % seqlen; - if (b < batch_size) { - const int total_seqlen = seqlens_k[b] + 1; - const int past_seqlen = total_seqlen - seqlen; - if (past_seqlen + s < total_seqlen) { - position_ids[tid] = past_seqlen + s; - } else { - position_ids[tid] = 1; - } +// Fused kernel: Unpack QKV + Apply RoPE to Q and K + Append K/V directly to cache +// This eliminates 4 kernel launches: Unpack -> Rotate Q -> Rotate K -> Append K -> Append V +// Becomes: Single kernel that does all operations in one pass +// +// Bounds Safety: +// - cache_s = past_seq_len + s is guaranteed < max_seqlen by the caller (group_query_attention.cc) +// because present_sequence_length = max(past + new_seq_len) across batches, and the present +// buffer is allocated 
with seqlen_present_kv_cache >= total_seq_lens[b] for all b. +// - The kernel processes exactly batch_size * sequence_length * (Q+K+V hidden) elements, +// which matches the packed_qkv input size allocated by the model. +// +// RoPE Contiguity Requirement: +// - packed_qkv MUST be strictly contiguous with layout [B, S, (H_q + 2*H_kv) * D] +// - The half-split RoPE logic (RotaryDispatcher::apply) fetches pair elements at offset +// (h + rotary_dim/2) relative to the start of each head +// - If strided/non-contiguous inputs are ever supported, this pointer arithmetic must change +// +// Performance Optimization: +// Uses 3D grid layout to eliminate expensive integer divisions: +// - blockIdx.z = batch index (b) +// - blockIdx.y = sequence index (s) +// - blockIdx.x * blockDim.x + threadIdx.x = offset within QKV hidden dimension +// This removes 4 divisions (/, %) per thread that would otherwise be needed. +template +__global__ void UnpackQKVWithRoPEAndAppendKV( + const T* packed_qkv, // Input: packed QKV [B, S, (Q+K+V) hidden] + T* unpacked_q, // Output: rotated Q [B, S, Q_heads, H] (BSNH) + T* k_cache, // Output: K cache [B, N, MaxS, H] or [B, MaxS, N, H] + T* v_cache, // Output: V cache [B, N, MaxS, H] or [B, MaxS, N, H] + const int num_heads, + const int kv_num_heads, + const int head_size, + const int d, // QKV hidden stride = (num_heads + 2*kv_num_heads) * head_size + const int max_seqlen, // KV cache max sequence length + const int* past_seq_lens, + // RoPE params + const T* cos_cache, + const T* sin_cache, + const int rotary_dim, + const int64_t* position_ids, + const bool interleaved, + const bool is_cache_bnsh) { + // Vectorized load/store using float4 (16 bytes) + using LoadT = float4; + constexpr int elements_per_thread = sizeof(LoadT) / sizeof(T); + + // 3D grid layout eliminates integer division: + // - blockIdx.z = batch index (b) - obtained from grid dimension, no division needed + // - blockIdx.y = sequence index (s) - obtained from grid dimension, no division needed + // - linear thread index within (b, s) gives offset directly + const int b = blockIdx.z; + const int s = blockIdx.y; + const int offset_vec_idx = blockIdx.x * blockDim.x + threadIdx.x; // Vector index within d + const int offset = offset_vec_idx * elements_per_thread; // Element offset within d + + // Bounds check: offset must be within the QKV hidden dimension + if (offset >= d) return; + + const int q_hidden = num_heads * head_size; + const int k_hidden = kv_num_heads * head_size; + const int sequence_length = gridDim.y; // Get from grid dimension + + // Calculate linear index for packed_qkv load + const int64_t packed_idx = static_cast(b) * sequence_length * d + + static_cast(s) * d + offset; + + // Load vector from packed buffer + LoadT val_vec = reinterpret_cast(packed_qkv)[packed_idx / elements_per_thread]; + + // Common RoPE Calculations + const int past_seq_len = past_seq_lens[b]; + int pos_id = 0; + if (position_ids != nullptr) { + pos_id = static_cast(position_ids[b * sequence_length + s]); + } else { + pos_id = past_seq_len + s; } -} -__global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, - const int batch_size) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; - int b = tid / seqlen; - int s = tid % seqlen; - if (b < batch_size) { - if (s < seqlens_k[b] + 1) { - position_ids[tid] = s; - } else { - position_ids[tid] = 1; + // Determine Q, K, or V based on offset + if (offset < q_hidden) { + // Q: Apply RoPE and write to unpacked_q buffer (BSNH 
format) + const int q_head_idx = offset / head_size; + const int h = offset % head_size; + const int h_idx = h / elements_per_thread; + + if (cos_cache != nullptr && rotary_dim > 0 && h < rotary_dim) { + // For half-split RoPE, pair values should be read relative to the START of the current Q head. + // Calculate offset to head start: (b, s, q_head_n, 0) in packed QKV. + const int64_t q_head_start_in_packed = static_cast(b) * sequence_length * d + + static_cast(s) * d + + static_cast(q_head_idx) * head_size; + RotaryDispatcher::apply(val_vec, + reinterpret_cast(cos_cache), + reinterpret_cast(sin_cache), + rotary_dim, h_idx, pos_id, interleaved, + reinterpret_cast(packed_qkv), + q_head_start_in_packed / elements_per_thread); } + + const int64_t q_idx = static_cast(b) * sequence_length * num_heads * head_size + + static_cast(s) * num_heads * head_size + offset; + // Vector store to unpacked_q + reinterpret_cast(unpacked_q)[q_idx / elements_per_thread] = val_vec; + + } else if (offset < q_hidden + k_hidden) { + // K: Apply RoPE and write DIRECTLY to K cache + const int k_offset = offset - q_hidden; + const int n = k_offset / head_size; + const int h = k_offset % head_size; + const int h_idx = h / elements_per_thread; + + if (cos_cache != nullptr && rotary_dim > 0 && h < rotary_dim) { + // For half-split RoPE, pair values should be read relative to the START of the current K head. + // Calculate offset to head start: (b, s, k_head_n, 0) in packed QKV. + const int64_t k_head_start_in_packed = static_cast(b) * sequence_length * d + + static_cast(s) * d + + q_hidden + + static_cast(n) * head_size; + RotaryDispatcher::apply(val_vec, + reinterpret_cast(cos_cache), + reinterpret_cast(sin_cache), + rotary_dim, h_idx, pos_id, interleaved, + reinterpret_cast(packed_qkv), + k_head_start_in_packed / elements_per_thread); + } + + const int cache_s = past_seq_len + s; + int64_t cache_idx; + if (is_cache_bnsh) { + cache_idx = static_cast(b) * kv_num_heads * max_seqlen * head_size + + static_cast(n) * max_seqlen * head_size + + static_cast(cache_s) * head_size + h; + } else { // BSNH + cache_idx = static_cast(b) * max_seqlen * kv_num_heads * head_size + + static_cast(cache_s) * kv_num_heads * head_size + + static_cast(n) * head_size + h; + } + // Vector store to k_cache + reinterpret_cast(k_cache)[cache_idx / elements_per_thread] = val_vec; + + } else { + // V: Write DIRECTLY to V cache (no rotation) + const int v_offset = offset - q_hidden - k_hidden; + const int n = v_offset / head_size; + const int h = v_offset % head_size; + + const int cache_s = past_seq_len + s; + int64_t cache_idx; + if (is_cache_bnsh) { + cache_idx = static_cast(b) * kv_num_heads * max_seqlen * head_size + + static_cast(n) * max_seqlen * head_size + + static_cast(cache_s) * head_size + h; + } else { // BSNH + cache_idx = static_cast(b) * max_seqlen * kv_num_heads * head_size + + static_cast(cache_s) * kv_num_heads * head_size + + static_cast(n) * head_size + h; + } + // Vector store to v_cache + reinterpret_cast(v_cache)[cache_idx / elements_per_thread] = val_vec; } } -__global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; - if (tid < batch_size) { - position_ids[tid] = seqlens_k[tid]; +// Launcher for fused UnpackQKV + RoPE + KV Append +template +Status LaunchUnpackQKVWithRoPEAndAppendKV( + const T* packed_qkv, + T* unpacked_q, + T* k_cache, + T* v_cache, + const int num_heads, + const int kv_num_heads, + const int head_size, + 
const int sequence_length, + const int batch_size, + const int max_seqlen, + const int* past_seq_lens, + const T* cos_cache, + const T* sin_cache, + const int rotary_dim, + const int64_t* position_ids, + const bool interleaved, + const bool is_cache_bnsh, + cudaStream_t stream, + const int max_threads_per_block) { + // Determine vectorization factor (float4 is 16 bytes) + constexpr int vector_bytes = sizeof(float4); + constexpr int element_bytes = sizeof(T); + constexpr int elements_per_vector = vector_bytes / element_bytes; + + // Validate head_size alignment + if (head_size % elements_per_vector != 0) { + // If strict alignment is not met (unlikely given GQA constraints), we should fall back or fail. + // Typically GQA enforces head_size % 8 == 0. + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size must be divisible by ", elements_per_vector, " for vectorized GQA kernel."); + } + + // Validate grid dimensions - CUDA limits gridDim.y to 65535 + if (sequence_length > 65535) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Sequence length ", sequence_length, + " exceeds CUDA grid dimension limit (65535) for fused UnpackQKV kernel."); } + +#ifndef NDEBUG + // Debug-mode alignment assertions for vectorized memory access + assert(reinterpret_cast(packed_qkv) % 16 == 0 && "packed_qkv must be 16-byte aligned"); + assert(reinterpret_cast(unpacked_q) % 16 == 0 && "unpacked_q must be 16-byte aligned"); + assert(reinterpret_cast(k_cache) % 16 == 0 && "k_cache must be 16-byte aligned"); + assert(reinterpret_cast(v_cache) % 16 == 0 && "v_cache must be 16-byte aligned"); + if (cos_cache != nullptr) { + assert(reinterpret_cast(cos_cache) % 16 == 0 && "cos_cache must be 16-byte aligned"); + assert(reinterpret_cast(sin_cache) % 16 == 0 && "sin_cache must be 16-byte aligned"); + } +#endif + + // QKV hidden dimension stride + const int d = (num_heads + 2 * kv_num_heads) * head_size; + const int d_vectors = d / elements_per_vector; // Number of vectors per (b, s) + + // 3D grid layout for eliminating integer divisions in kernel: + // grid.x = number of blocks needed to cover d_vectors with threads_per_block threads + // grid.y = sequence_length + // grid.z = batch_size + const int threads_per_block = std::min(max_threads_per_block, d_vectors); + const int blocks_x = (d_vectors + threads_per_block - 1) / threads_per_block; + const dim3 grid(blocks_x, sequence_length, batch_size); + const dim3 block(threads_per_block); + + UnpackQKVWithRoPEAndAppendKV<<>>( + packed_qkv, + unpacked_q, + k_cache, + v_cache, + num_heads, + kv_num_heads, + head_size, + d, + max_seqlen, + past_seq_lens, + cos_cache, + sin_cache, + rotary_dim, + position_ids, + interleaved, + is_cache_bnsh); + + return CUDA_CALL(cudaGetLastError()); } -// Convert seqlens_k to position_ids -Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, - int64_t* position_ids, cudaStream_t stream, - const int max_threads_per_block) { - const int seqlen = parameters.sequence_length; - const int batch_size = parameters.batch_size; - const int threads = max_threads_per_block; - const int blocks = (batch_size * seqlen + threads - 1) / threads; - if (parameters.is_subsequent_prompt) { - SeqlensToPosIdsInteractive<<>>(seqlens_k, position_ids, seqlen, batch_size); - } else if (parameters.is_first_prompt) { - SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); - } else { - SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); +// Explicit template instantiations +template Status 
LaunchUnpackQKVWithRoPEAndAppendKV( + const half*, half*, half*, half*, + int, int, int, int, int, int, const int*, + const half*, const half*, int, const int64_t*, bool, bool, + cudaStream_t, int); + +template Status LaunchUnpackQKVWithRoPEAndAppendKV( + const BFloat16*, BFloat16*, BFloat16*, BFloat16*, + int, int, int, int, int, int, const int*, + const BFloat16*, const BFloat16*, int, const int64_t*, bool, bool, + cudaStream_t, int); + +// ============================================================================ +// GetSequenceLengths Kernel +// ============================================================================ +// PURPOSE: +// Computes derived sequence length buffers from input seqlens_k. +// Input seqlens_k contains (total_sequence_length - 1) for historical reasons. +// +// INPUTS: +// total_seq_lens_minus_one - Input from ONNX graph: total_len - 1 per batch [B] +// sequence_length - Current Q sequence length (new tokens) +// is_first_prompt - True if this is the first prompt (no past) +// +// OUTPUTS: +// past_seq_lens - Offset where new KV should be appended [B] +// First prompt: 0 +// Otherwise: total_len - sequence_length +// total_seq_lens - Total valid tokens including new ones [B] +// padded_seq_lens - Padded length for masking (first prompt only) [B] +// First prompt: sequence_length +// Otherwise: not set (undefined) +// +// THREAD MAPPING: +// One thread per batch element. +// +// USAGE: +// Called once per inference to derive all sequence length variants. +// ============================================================================ +__global__ void GetSequenceLengths(const int* total_seq_lens_minus_one, + int* past_seq_lens, + int* total_seq_lens, + int* padded_seq_lens, + const int batch_size, + const int sequence_length, + const bool is_first_prompt) { + int i = threadIdx.x + blockIdx.x * blockDim.x; + if (i < batch_size) { + const int total_len = total_seq_lens_minus_one[i] + 1; + total_seq_lens[i] = total_len; + if (is_first_prompt) { + past_seq_lens[i] = 0; + padded_seq_lens[i] = sequence_length; + } else { + past_seq_lens[i] = total_len - sequence_length; + } } +} + +Status LaunchGetSequenceLengths( + const int* total_seq_lens_minus_one, + int* past_seq_lens, + int* total_seq_lens, + int* padded_seq_lens, + const int batch_size, + const int sequence_length, + const bool is_first_prompt, + cudaStream_t stream, + const int max_threads_per_block) { + int blocks = (batch_size + max_threads_per_block - 1) / max_threads_per_block; + GetSequenceLengths<<>>(total_seq_lens_minus_one, past_seq_lens, total_seq_lens, padded_seq_lens, batch_size, sequence_length, is_first_prompt); return CUDA_CALL(cudaGetLastError()); } -////////// Launch Kernels +////////// Kernels (supports right padding but not left padding) #if USE_FLASH_ATTENTION + +// Use flash attention for all workloads (rotary, kv append, attention, etc.). No extra kernel is used in this path. +// Currently, only decoding or subsequent prompt can use this path. First prompt will not use this path. 
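+// Example: with sequence_length == 1 and 9 tokens already in the cache (total becomes 10), the graph
+// input seqlens_k[b] is 10 - 1 = 9, which equals past_seq_lens[b]; Flash Attention appends the new
+// token at offset 9 and attends to positions [0, 9].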
template -Status FlashAttention( +Status FlashAttentionDecoding( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, float scale) { - const int max_threads_per_block = device_prop.maxThreadsPerBlock; + assert(!parameters.is_first_prompt && parameters.kv_share_buffer); + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int kv_sequence_length = parameters.sequence_length; @@ -435,30 +755,7 @@ Status FlashAttention( value = reinterpret_cast(key) + value_offset; } - void* seqlens_k = reinterpret_cast(data.seqlens_k); - if (parameters.is_subsequent_prompt) { - ORT_RETURN_IF_ERROR(LaunchGetSeqlensInteractive(reinterpret_cast(data.seqlens_k), - reinterpret_cast(data.seqlens_k_buff), batch_size, - sequence_length, stream, max_threads_per_block)); - seqlens_k = reinterpret_cast(data.seqlens_k_buff); - } else if (parameters.is_first_prompt) { - // set seqlens_k to zeros... flash api uses seqlens_k to indicate where to append key and value - // user should use seqlens_k to index into output to get new tokens - if (batch_size <= parameters.zeros_count) { - seqlens_k = parameters.zero_ptr; - } else { - // Launch kernel to create larger seqlen tensor when batch_size > 256 - constexpr int thr_per_blk = 256; - int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_buff, 0, batch_size); - seqlens_k = reinterpret_cast(data.seqlens_k_buff); - } - } - - if (!parameters.kv_share_buffer || parameters.is_first_prompt) { // copy past kv to present kv - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, nullptr, nullptr, stream, max_threads_per_block, - true)); - } + void* seqlens_k = reinterpret_cast(data.past_seq_lens); void* present_key = reinterpret_cast(const_cast(data.present_key)); void* present_value = reinterpret_cast(const_cast(data.present_value)); @@ -468,11 +765,6 @@ Status FlashAttention( bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - DUMP_TENSOR_INIT(); - DUMP_TENSOR("Q", reinterpret_cast(query), batch_size, sequence_length, num_heads, head_size); - DUMP_TENSOR("K", reinterpret_cast(present_key), batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); - DUMP_TENSOR("V", reinterpret_cast(present_value), batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, head_sink, /*block_table*/ nullptr, @@ -482,11 +774,293 @@ Status FlashAttention( reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), parameters.local_window_size - 1, parameters.rotary_interleaved, parameters.is_packed_qkv)); - // if (parameters.left_padding && parameters.is_first_prompt) { - // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); - // } + return Status::OK(); +} - DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); +// Use extra kernel(s) for unpacking, rotary and kv append. +// Flash attention is used for attention only. 
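+// Rough shape of this path (see the branches below):
+//   1. Packed QKV + shared KV buffer + rotary (fused KV enabled): one fused unpack/RoPE/append kernel.
+//   2. Packed QKV otherwise: UnpackQKV first, then (if rotary) rotate Q/K and append via the concat kernels.
+//   3. Unpacked Q/K/V: (if rotary) rotate Q into unpacked_qkv_buffer; K is rotated and K/V appended by the concat kernels.
+// In every case the new tokens land in present_key/present_value before attention, so mha_fwd_kvcache is
+// called with new_k/new_v = nullptr and its internal RoPE disabled.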
+template +Status FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + bool is_causal = parameters.is_unidirectional; + bool is_bf16 = std::is_same::value; + + void* query = reinterpret_cast(const_cast(data.query)); + void* key; + void* value; + + if (!parameters.is_packed_qkv) { + key = reinterpret_cast(const_cast(data.key)); + value = reinterpret_cast(const_cast(data.value)); + } else { + const size_t key_offset = static_cast(num_heads * head_size); + const size_t value_offset = static_cast(kv_num_heads * head_size); + key = reinterpret_cast(query) + key_offset; + value = reinterpret_cast(key) + value_offset; + } + +#if DUMP_TENSOR_LEVEL > 0 + printf("[GQA FlashAttention] is_packed_qkv: %d, is_first_prompt: %d, is_subsequent_prompt: %d, kv_share_buffer: %d\n", + static_cast(parameters.is_packed_qkv), + static_cast(parameters.is_first_prompt), + static_cast(parameters.is_subsequent_prompt), + static_cast(parameters.kv_share_buffer)); +#endif + DUMP_TENSOR_INIT(); + + // Track whether we keep packed QKV for FA kernels + bool use_packed_for_fa = parameters.is_packed_qkv; + + // Track if we used the fully fused path (packed + share_buffer + rotary) + bool used_fused_packed_path = false; + + // ========================================================================= + // Handle Packed QKV Input + // ========================================================================= + if (parameters.is_packed_qkv) { + T* unpacked_buffer = reinterpret_cast(data.unpacked_qkv_buffer); + if (unpacked_buffer != nullptr) { + T* unpacked_q = unpacked_buffer; + + // Check if we can use the fully fused path + if (parameters.kv_share_buffer && parameters.do_rotary && !data.disable_fused_kv) { + // FULLY FUSED PATH: Unpack + RoPE Q + RoPE K + Append KV in single kernel + // This eliminates 4 kernel launches! 
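+ // Note: LaunchUnpackQKVWithRoPEAndAppendKV requires head_size to be divisible by the float4
+ // vector width and sequence_length <= 65535 (CUDA gridDim.y limit); it returns an error otherwise.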
+ ORT_RETURN_IF_ERROR(LaunchUnpackQKVWithRoPEAndAppendKV( + reinterpret_cast(data.query), // packed QKV + unpacked_q, // Q output buffer (rotated) + data.present_key, // K cache (direct write) + data.present_value, // V cache (direct write) + num_heads, + kv_num_heads, + head_size, + sequence_length, + batch_size, + parameters.seqlen_present_kv_cache, + data.past_seq_lens, + data.cos_cache, + data.sin_cache, + parameters.rotary_dim, + data.position_ids, + parameters.rotary_interleaved, + !past_bsnh, // is_cache_bnsh + stream, + max_threads_per_block)); + + // Update query to point to rotated Q + query = unpacked_q; + use_packed_for_fa = false; + used_fused_packed_path = true; + + // Track buffer usage: Only Q is stored in unpacked_qkv_buffer (fused path writes K/V to cache) + size_t q_bytes = static_cast(batch_size) * sequence_length * num_heads * head_size * sizeof(T); + UpdateUnpackedQkvMaxUsed(data, q_bytes); + + // K and V are already in cache - no need to set key/value pointers + + } else { + // Standard path: Unpack first, then process K/V separately + size_t q_size = static_cast(batch_size) * sequence_length * num_heads * head_size; + T* unpacked_k = unpacked_buffer + q_size; + + size_t k_size = static_cast(batch_size) * sequence_length * kv_num_heads * head_size; + T* unpacked_v = unpacked_k + k_size; + + // If we need Q rotation, we MUST unpack Q as well. + T* q_dst = parameters.do_rotary ? unpacked_q : nullptr; + + // Always unpack to BSNH as LaunchConcatNewToPastKV expects contiguous BSNH input + ORT_RETURN_IF_ERROR((LaunchUnpackQKV(reinterpret_cast(data.query), q_dst, unpacked_k, unpacked_v, num_heads, kv_num_heads, head_size, sequence_length, batch_size, stream, max_threads_per_block))); + + // Update key/value to point to unpacked buffers + key = unpacked_k; + value = unpacked_v; + + if (parameters.do_rotary) { + query = unpacked_q; + use_packed_for_fa = false; + } + + // Track buffer usage: Q+K+V unpacked + size_t total_bytes = (q_size + 2 * k_size) * sizeof(T); + UpdateUnpackedQkvMaxUsed(data, total_bytes); + } + } + } + // ========================================================================= + // Handle Unpacked Q, K, V Input (with optional RoPE) + // ========================================================================= + else { + if (parameters.do_rotary) { + // For unpacked input, we need to rotate Q and K. + // The rotated Q and K will be stored in unpacked_qkv_buffer with layout [Q (B*S*H*D), K (B*S*H_kv*D)]. + T* unpacked_buffer = reinterpret_cast(data.unpacked_qkv_buffer); + if (unpacked_buffer != nullptr) { + query = unpacked_buffer; + // Do not update key here for Unpacked path. + // key must remain pointing to data.key (Input) for Explicit K Rotation (k_src). + // k_dst will be calculated from unpacked_buffer explicitly. + } + } + } + + const int64_t* position_ids = data.position_ids; + + // Explicit Q Rotation (skip if fused path already applied RoPE) + if (parameters.do_rotary && !used_fused_packed_path) { + // Rotate Q + // Q ptr is already set to the destination buffer (unpacked_buffer) above. + // Input for Rotation: + // If packed: we unpacked into `query` buffer. So Input==Output (In-place). + // If unpacked: we set `query = unpacked_buffer`. But Input is `data.query`. + const T* q_input_for_rope = parameters.is_packed_qkv ? 
reinterpret_cast(query) : reinterpret_cast(data.query); + T* q_output_for_rope = reinterpret_cast(query); // Destination + + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel( + stream, + q_output_for_rope, + q_input_for_rope, + nullptr, // position_ids unused for format 2/3 + data.past_seq_lens, + data.cos_cache, + data.sin_cache, + batch_size, + sequence_length, + num_heads, + head_size, + parameters.rotary_dim, + parameters.max_sequence_length, + 2, // position_ids_format = 2 (Implicit: past_seq_lens[b] + s) + parameters.rotary_interleaved, + max_threads_per_block, + false // is_input_bnsh_format (Q is BSNH) + )); + DUMP_TENSOR("Rotated Q", q_output_for_rope, batch_size, sequence_length, num_heads, head_size); + + // Rotate K will be done later in fused kernel. + } + + // Skip KV append if we used the fully fused path (KV already in cache) + if (!used_fused_packed_path) { + if (parameters.kv_share_buffer && !parameters.is_first_prompt) { + constexpr bool is_new_kv_bnsh_format = false; + if (parameters.do_rotary) { + // Explicit K Rotation (replacing internal RoPE in fused kernel) + size_t q_elements = static_cast(batch_size) * sequence_length * num_heads * head_size; + T* k_dst = reinterpret_cast(data.unpacked_qkv_buffer) + q_elements; + const T* k_src = reinterpret_cast(key); + + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel( + stream, + k_dst, + k_src, + position_ids, + data.past_seq_lens, + data.cos_cache, + data.sin_cache, + batch_size, + sequence_length, + kv_num_heads, + head_size, + parameters.rotary_dim, + parameters.max_sequence_length, + position_ids != nullptr ? 1 : 2, + parameters.rotary_interleaved, + max_threads_per_block, + false)); + + if (!data.disable_fused_kv) { + // Use fused kernel for K (rotated) + V append + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlaceFused( + batch_size, + kv_num_heads, + head_size, + parameters.seqlen_present_kv_cache, + data.past_seq_lens, + data.total_seq_lens, + sequence_length, + k_dst, + reinterpret_cast(data.value), + data.present_key, + data.present_value, + !past_bsnh, + is_new_kv_bnsh_format, + stream, + max_threads_per_block)); + } else { + // Unfused Fallback: LaunchConcatKVInPlace + // We must pass the ROTATED K (k_dst) to it. + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace( + parameters, data, k_dst, value, is_new_kv_bnsh_format, stream, max_threads_per_block)); + } + + // Track buffer usage: Q + K rotated in unpacked_qkv_buffer + size_t k_elements = static_cast(batch_size) * sequence_length * kv_num_heads * head_size; + size_t total_bytes = (q_elements + k_elements) * sizeof(T); + UpdateUnpackedQkvMaxUsed(data, total_bytes); + } else { + // No RoPE - use original kernel + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, key, value, is_new_kv_bnsh_format, stream, max_threads_per_block)); + } + } else { + // ORT MUST perform the append (using unpacked data for packed case) + bool skip_new_append = false; + // FUSED ROPE: Pass RoPE params to ConcatKV (applies RoPE to K as it is appended) + // IMPORTANT: For Fused RoPE with unpacked input, we must pass data.key (the original input), + // not the scratch buffer 'key' which is empty since explicit rotation was skipped. + const void* key_for_concat = parameters.is_packed_qkv ? 
key : data.key; + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKVHelper(parameters, data, key_for_concat, value, stream, max_threads_per_block, skip_new_append, + data.cos_cache, data.sin_cache, parameters.rotary_dim, nullptr, parameters.rotary_interleaved)); + } + } + + DUMP_TENSOR("Total Seq Lens", data.total_seq_lens, batch_size, 1); + DUMP_TENSOR("Past Seq Lens", data.past_seq_lens, batch_size, 1); + DUMP_TENSOR("Present Key", data.present_key, batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); + DUMP_TENSOR("Present Value", data.present_value, batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); + + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + + // Disable internal RoPE in Flash Attention (pass nullptr) + void* cos_cache = nullptr; + void* sin_cache = nullptr; + void* head_sink = reinterpret_cast(const_cast(data.head_sink)); + + // We have already appended (and quantized if needed) the new tokens into present_key/value. + // Pass nullptr for new_k/new_v to disable flash attention kernel's internal Append_KV logic. + void* kernel_new_k = nullptr; + void* kernel_new_v = nullptr; + + // Use padded seq lens for first prompt since mha_fwd_kvcache assumes uniform seqlen_q. + // The causal mask offset (seqlen_k - seqlen_q) becomes negative when seqlen_k < seqlen_q, causing incorrect masking. + int* seq_lens = parameters.is_first_prompt ? data.padded_seq_lens : data.total_seq_lens; + + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, + kernel_new_k, kernel_new_v, + data.output, reinterpret_cast(data.softmax_lse), seq_lens, + cos_cache, sin_cache, head_sink, /*block_table*/ nullptr, batch_size, + num_heads, kv_num_heads, head_size, sequence_length, + parameters.seqlen_present_kv_cache, kv_sequence_length, + parameters.rotary_dim, scale, parameters.softcap, is_causal, is_bf16, + parameters.use_smooth_softmax, past_bsnh, parameters.num_splits, + reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum), parameters.local_window_size - 1, + parameters.rotary_interleaved, use_packed_for_fa, 0, 1)); return Status::OK(); } @@ -509,6 +1083,14 @@ Status EfficientAttention( const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; +#if DUMP_TENSOR_LEVEL > 0 + printf("[GQA EfficientAttention] is_packed_qkv: %d, is_first_prompt: %d, is_subsequent_prompt: %d, kv_share_buffer: %d\n", + static_cast(parameters.is_packed_qkv), + static_cast(parameters.is_first_prompt), + static_cast(parameters.is_subsequent_prompt), + static_cast(parameters.kv_share_buffer)); +#endif + const void* query; const void* key; const void* value; @@ -518,8 +1100,8 @@ Status EfficientAttention( key = reinterpret_cast(data.key); value = reinterpret_cast(data.value); } else { - size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); - size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); + size_t q_size = static_cast(batch_size) * sequence_length * num_heads * head_size; + size_t k_size = static_cast(batch_size) * sequence_length * kv_num_heads * head_size; auto q = reinterpret_cast(data.unpacked_qkv_buffer); auto k = reinterpret_cast(data.unpacked_qkv_buffer + q_size); auto v = reinterpret_cast(data.unpacked_qkv_buffer + q_size + k_size); @@ -534,66 +1116,128 @@ Status EfficientAttention( query = 
reinterpret_cast(q); key = reinterpret_cast(k); value = reinterpret_cast(v); + + // Track buffer usage: Q+K+V unpacked + size_t total_bytes = (q_size + 2 * k_size) * sizeof(T); + UpdateUnpackedQkvMaxUsed(data, total_bytes); } + const int64_t* position_ids = data.position_ids; if (parameters.do_rotary) { - size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); - size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); auto q_buffer = reinterpret_cast(data.rotary_buffer); - auto k_buffer = q_buffer + q_size; - auto position_ids_buff = reinterpret_cast(k_buffer + k_size); - ORT_RETURN_IF_ERROR(LaunchSeqlensToPosIds(parameters, data.seqlens_k, position_ids_buff, stream, - max_threads_per_block)); - DUMP_TENSOR_INIT(); - DUMP_TENSOR("position_ids", position_ids_buff, batch_size, sequence_length); - // Launch rotary embedding kernel - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(stream, q_buffer, reinterpret_cast(query), - position_ids_buff, data.cos_cache, data.sin_cache, - parameters.batch_size, parameters.sequence_length, - parameters.num_heads, parameters.head_size, - parameters.rotary_dim, parameters.seqlen_present_kv_cache, - /*position_ids_format*/ 1, parameters.rotary_interleaved, - device_prop.maxThreadsPerBlock, /*transposed*/ false)); - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(stream, k_buffer, reinterpret_cast(key), - position_ids_buff, data.cos_cache, data.sin_cache, - parameters.batch_size, parameters.sequence_length, - parameters.kv_num_heads, parameters.head_size, - parameters.rotary_dim, parameters.seqlen_present_kv_cache, - /*position_ids_format*/ 1, parameters.rotary_interleaved, - device_prop.maxThreadsPerBlock, /*transposed*/ false)); + + // Launch rotary embedding kernel for Q + if (position_ids != nullptr) { + // User provided explicit position_ids - Use Format 1 + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel( + stream, q_buffer, reinterpret_cast(query), + position_ids, nullptr /*past_seq_lens not used in format 1*/, + data.cos_cache, data.sin_cache, + parameters.batch_size, parameters.sequence_length, + parameters.num_heads, parameters.head_size, + parameters.rotary_dim, parameters.max_sequence_length, + 1, // Format 1: Explicit position_ids + parameters.rotary_interleaved, + max_threads_per_block, + false)); + } else { + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel( + stream, q_buffer, reinterpret_cast(query), + nullptr, data.past_seq_lens, + data.cos_cache, data.sin_cache, + parameters.batch_size, parameters.sequence_length, + parameters.num_heads, parameters.head_size, + parameters.rotary_dim, parameters.max_sequence_length, + 2, // Format 2: Implicit (past_seq_lens[b] + s) + parameters.rotary_interleaved, + max_threads_per_block, + false)); + } query = reinterpret_cast(q_buffer); - key = reinterpret_cast(k_buffer); - } - if (parameters.is_subsequent_prompt || !parameters.is_first_prompt) { - ORT_RETURN_IF_ERROR(LaunchGetSeqlensTotal(data.seqlens_k, data.seqlens_k_buff, batch_size, stream, 256)); - } else { - // Launch kernel to copy seqlen - constexpr int thr_per_blk = 256; - int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_buff, parameters.sequence_length, - batch_size); + // For kv_share_buffer path, we use Fused RoPE in LaunchConcatKVInPlaceWithRoPE. + // For non-share-buffer path, we use Fused RoPE in LaunchConcatNewToPastKVHelper. + // No explicit K rotation needed here - handled by fused kernels. 
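+ // Concretely: in the kv_share_buffer branch below, K is rotated into rotary_buffer right after Q and
+ // then appended with LaunchConcatKVInPlaceFused (or LaunchConcatKVInPlace as the unfused fallback);
+ // in the non-share-buffer branch, K RoPE is fused inside LaunchConcatNewToPastKVHelper.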
+ + // key remains pointing to original source for use in fused kernel below + + // Track rotary buffer usage: Q rotated (K rotation is fused in KV append) + size_t q_bytes = static_cast(batch_size) * sequence_length * num_heads * head_size * sizeof(T); + size_t k_bytes = static_cast(batch_size) * sequence_length * kv_num_heads * head_size * sizeof(T); + // Note: rotary_buffer layout is [Q_rotated, K_rotated] - no position_ids here + UpdateRotaryMaxUsed(data, q_bytes + k_bytes); + + // Track position_ids_buffer usage + size_t pos_ids_bytes = static_cast(batch_size) * sequence_length * sizeof(int64_t); + UpdatePositionIdsMaxUsed(data, pos_ids_bytes); } - int* seqlens_k = data.seqlens_k_buff; if (parameters.kv_share_buffer) { - // Share buffer case - if (data.past_key == nullptr || data.past_key != data.present_key) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Past and present kv shall share the same tensor when kv_share_buffer is on."); - } // Concatenate new kv in place constexpr bool is_new_kv_bnsh_format = false; - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace( - parameters, data, key, value, is_new_kv_bnsh_format, stream, max_threads_per_block)); - } else { - // Not share buffer case - if (data.past_key != nullptr && data.past_key == data.present_key) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Past and present kv share the same tensor but kv_share_buffer is not on."); + + if (parameters.do_rotary) { + // Explicit K Rotation + size_t q_elements = static_cast(batch_size) * sequence_length * num_heads * head_size; + T* k_dst = reinterpret_cast(data.rotary_buffer) + q_elements; + const T* k_src = reinterpret_cast(key); + + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel( + stream, + k_dst, + k_src, + position_ids, + data.past_seq_lens, + data.cos_cache, + data.sin_cache, + batch_size, + sequence_length, + parameters.kv_num_heads, + parameters.head_size, + parameters.rotary_dim, + parameters.max_sequence_length, + position_ids != nullptr ? 
1 : 2, + parameters.rotary_interleaved, + max_threads_per_block, + false)); + + if (!data.disable_fused_kv) { + // Use truly fused kernel for K (already rotated) + V append in single kernel + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlaceFused( + batch_size, + parameters.kv_num_heads, + parameters.head_size, + parameters.seqlen_present_kv_cache, + data.past_seq_lens, + data.total_seq_lens, + parameters.sequence_length, + k_dst, + reinterpret_cast(value), + data.present_key, + data.present_value, + past_kv_format != AttentionQkvFormat::Q_K_V_BSNH, // is_past_kv_bnsh_format + is_new_kv_bnsh_format, + stream, + max_threads_per_block)); + } else { + // Unfused Fallback + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace( + parameters, data, k_dst, value, is_new_kv_bnsh_format, stream, max_threads_per_block)); + } + + // Track rotary buffer usage: Q + K rotated (no position_ids in rotary_buffer) + size_t k_elements = static_cast(batch_size) * sequence_length * kv_num_heads * head_size; + UpdateRotaryMaxUsed(data, (q_elements + k_elements) * sizeof(T)); + } else { + // No RoPE - use original kernel + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace( + parameters, data, key, value, is_new_kv_bnsh_format, stream, max_threads_per_block)); } + } else { // Copy past and concat new KV to present buffer - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, key, value, stream, max_threads_per_block)); + // FUSED ROPE: Pass RoPE params to ConcatKV + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKVHelper(parameters, data, key, value, stream, max_threads_per_block, false, + data.cos_cache, data.sin_cache, parameters.rotary_dim, nullptr, parameters.rotary_interleaved)); } // Ungroup if grouped, otherwise use present kv directly @@ -608,15 +1252,12 @@ Status EfficientAttention( float2* v_buff = reinterpret_cast(data.v); const float2* k_og = reinterpret_cast(data.present_key); const float2* v_og = reinterpret_cast(data.present_value); - ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, present_sequence_length, - present_sequence_length, is_bsnh, stream, max_threads_per_block)); + ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, present_sequence_length, + present_sequence_length, is_bsnh, stream, max_threads_per_block)); key = reinterpret_cast(data.k); value = reinterpret_cast(data.v); } - DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", seqlens_k, batch_size, 1); - MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; p.is_bf16 = std::is_same::value; @@ -631,7 +1272,7 @@ Status EfficientAttention( p.causal = true; p.scale = scale; p.softcap = parameters.softcap; - p.seqlen_k_ptr = seqlens_k; // Note: seqlens_k is total sequence length for efficient + p.seqlen_k_ptr = parameters.is_first_prompt ? data.padded_seq_lens : data.total_seq_lens; p.seqstart_q_ptr = nullptr; p.seqstart_k_ptr = nullptr; p.query = query; @@ -649,8 +1290,6 @@ Status EfficientAttention( p.local_window_size = parameters.local_window_size; run_memory_efficient_attention(p); - DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); - return Status::OK(); } #endif @@ -668,6 +1307,10 @@ Status QkvToContext( const float scale = parameters.scale == 0.0f ? 
1.f / sqrt(static_cast(parameters.head_size)) : parameters.scale; #if USE_FLASH_ATTENTION + if (data.use_flash_attention_fast_decode) { + return FlashAttentionDecoding(device_prop, stream, parameters, data, scale); + } + if (data.use_flash_attention) { return FlashAttention(device_prop, stream, parameters, data, scale); } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index 4ae4c450902f8..c42fe53e4b625 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -28,6 +28,141 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block); +// ============================================================================ +// GQABufferRequirements: Centralized buffer size calculation +// ============================================================================ +// This struct provides a single source of truth for scratch buffer allocation. +// It ensures allocation logic in group_query_attention.cc stays in sync with +// kernel usage in group_query_attention_impl.cu. +// +// Usage: +// auto req = GQABufferRequirements::Compute(params, use_flash, fast_decode, use_mea); +// unpacked_qkv_buffer = GetScratchBuffer(req.unpacked_qkv_bytes, ...); +// rotary_buffer = GetScratchBuffer(req.rotary_buffer_bytes, ...); +// position_ids_buffer = GetScratchBuffer(req.position_ids_bytes, ...); +// ============================================================================ +struct GQABufferRequirements { + size_t unpacked_qkv_bytes = 0; + size_t rotary_buffer_bytes = 0; + size_t position_ids_bytes = 0; + + template + static GQABufferRequirements Compute( + const GroupQueryAttentionParameters& params, + bool use_flash_attention, + bool use_flash_attention_fast_decode, + bool use_memory_efficient_attention) { + GQABufferRequirements req; + + const size_t elem_size = sizeof(T); + const size_t batch_size = static_cast(params.batch_size); + const size_t seq_len = static_cast(params.sequence_length); + const size_t num_heads = static_cast(params.num_heads); + const size_t kv_num_heads = static_cast(params.kv_num_heads); + const size_t head_size = static_cast(params.head_size); + + // Fast decode path: Flash Attention handles everything internally + if (use_flash_attention_fast_decode) { + return req; // All zeros - no scratch buffers needed + } + + // Q, K, V element counts + const size_t q_elements = batch_size * seq_len * num_heads * head_size; + const size_t k_elements = batch_size * seq_len * kv_num_heads * head_size; + const size_t v_elements = k_elements; + + if (use_flash_attention) { + // Flash Attention path: + // - unpacked_qkv_buffer is used for: + // 1. Unpacking packed QKV input + // 2.
Storing rotated Q (and K for non-fused path) + // - rotary_buffer is NOT used (rotations go to unpacked_qkv_buffer) + // - position_ids_buffer is NOT used (flash attention uses implicit position IDs) + + if (params.is_packed_qkv) { + // Need full Q+K+V for unpacking + req.unpacked_qkv_bytes = elem_size * (q_elements + k_elements + v_elements); + } else if (params.do_rotary) { + // Unpacked input with RoPE: need Q+K for rotation output + req.unpacked_qkv_bytes = elem_size * (q_elements + k_elements); + } + // Note: unpacked + no-RoPE case does NOT need unpacked_qkv_buffer + + } else if (use_memory_efficient_attention) { + // Memory Efficient Attention path: + // - unpacked_qkv_buffer: for unpacking packed QKV + // - rotary_buffer: for Q and K rotation output (separate from unpack buffer) + // - position_ids_buffer: for explicit position IDs if needed + + if (params.is_packed_qkv) { + req.unpacked_qkv_bytes = elem_size * (q_elements + k_elements + v_elements); + } + + if (params.do_rotary) { + // Q rotation + K rotation + // Note: K uses kv_num_heads which may be less than num_heads + req.rotary_buffer_bytes = elem_size * (q_elements + k_elements); + // Position IDs space (always allocated for MEA + RoPE path) + req.position_ids_bytes = sizeof(int64_t) * batch_size * seq_len; + } + } + + return req; + } +}; + +// ============================================================================ +// Debug helper for tracking buffer usage +// ============================================================================ +// Call these after buffer access to record the maximum offset used. +// In release builds, these are no-ops. +// +// Example: +// T* unpacked_q = data.unpacked_qkv_buffer; +// // ... kernel writes to unpacked_q[0..Q_size-1] ... +// UpdateUnpackedQkvMaxUsed(data, Q_size * sizeof(T)); +// ============================================================================ +#ifndef NDEBUG +template +inline void UpdateUnpackedQkvMaxUsed(GroupQueryAttentionData& data, size_t bytes_used) { + if (bytes_used > data.unpacked_qkv_max_used) { + data.unpacked_qkv_max_used = bytes_used; + } +} + +template +inline void UpdateRotaryMaxUsed(GroupQueryAttentionData& data, size_t bytes_used) { + if (bytes_used > data.rotary_max_used) { + data.rotary_max_used = bytes_used; + } +} + +template +inline void UpdatePositionIdsMaxUsed(GroupQueryAttentionData& data, size_t bytes_used) { + if (bytes_used > data.position_ids_max_used) { + data.position_ids_max_used = bytes_used; + } +} +#else +template +inline void UpdateUnpackedQkvMaxUsed(GroupQueryAttentionData&, size_t) {} +template +inline void UpdateRotaryMaxUsed(GroupQueryAttentionData&, size_t) {} +template +inline void UpdatePositionIdsMaxUsed(GroupQueryAttentionData&, size_t) {} +#endif + +Status LaunchGetSequenceLengths( + const int* total_seq_lens_minus_one, + int* past_seq_lens, + int* total_seq_lens, + int* padded_seq_lens, + const int batch_size, + const int sequence_length, + const bool is_first_prompt, + cudaStream_t stream, + const int max_threads_per_block); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_common.cuh b/onnxruntime/contrib_ops/cuda/bert/rotary_common.cuh new file mode 100644 index 0000000000000..1cab81e83b2ef --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_common.cuh @@ -0,0 +1,395 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +// ============================================================================ +// rotary_common.cuh - Vectorized Rotary Position Embedding (RoPE) Dispatcher +// ============================================================================ +// +// PURPOSE: +// Provides a unified, vectorized interface for applying Rotary Position +// Embedding (RoPE) to K or Q tensors directly within fused kernels. +// This enables RoPE application during KV cache append operations without +// requiring separate kernel launches. +// +// USAGE: +// Called from ConcatNewToPastKVFused and UnpackQKVWithRoPEAndAppendKV kernels +// to apply in-place rotation to Key vectors as they are being appended to cache. +// +// SUPPORTED TYPES: +// - float2 + float: 2 fp32 elements (8 bytes) +// - float4 + float: 4 fp32 elements (16 bytes) +// - float2 + half: 4 fp16 elements (8 bytes) +// - float4 + half: 8 fp16 elements (16 bytes) +// - float2 + BFloat16: 4 bf16 elements (8 bytes) +// - float4 + BFloat16: 8 bf16 elements (16 bytes) +// +// ROTATION MODES: +// 1. INTERLEAVED: Adjacent pairs (x0,x1), (x2,x3), ... are rotated together +// Formula: (x, y) -> (x*cos - y*sin, x*sin + y*cos) +// +// 2. HALF-SPLIT (Non-Interleaved): First half pairs with second half +// Pairs: (x0, x_{d/2}), (x1, x_{d/2+1}), ... +// Formula: x_i -> x_i * cos + sign * x_{pair} * sin +// where sign = -1 if i < d/2, else +1 +// +// INPUT REQUIREMENTS: +// - cos_cache, sin_cache: [max_position, rotary_dim/2] +// - new_kv_base: Contiguous BSNH tensor for fetching pair values (half-split mode) +// - in_offset: Element offset into new_kv_base for current thread's data +// +// ============================================================================ + +#pragma once + +#include +#include +#include "core/providers/cuda/cuda_common.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// ============================================================================ +// RotaryDispatcher Template +// ============================================================================ +// Template struct for vectorized RoPE application. +// +// TEMPLATE PARAMETERS: +// VectorT - Vector type used for memory access (float2, float4) +// ElementT - Underlying element type (float, half, BFloat16) +// +// STATIC METHOD: apply() +// Applies RoPE to 'val' in-place. 
+// +// PARAMETERS: +// val - [in/out] Vector of elements to rotate +// cos_cache - Precomputed cosine values [max_pos, rotary_dim/2] +// sin_cache - Precomputed sine values [max_pos, rotary_dim/2] +// rotary_dim - Number of dimensions with rotation (must be <= head_size) +// h_idx - Vector index within head (determines which rotary elements) +// pos_id - Position ID for this token (used to index cos/sin caches) +// interleaved - True for interleaved mode, false for half-split +// new_kv_base - Base pointer for fetching pair values (half-split mode only) +// in_offset - Offset to current element in new_kv_base (vector units) +// ============================================================================ +template +struct RotaryDispatcher { + __device__ static void apply(VectorT& val, const VectorT* cos_cache, const VectorT* sin_cache, + const int rotary_dim, const int h_idx, const int pos_id, + const bool interleaved, const VectorT* new_kv_base, const int64_t in_offset); +}; + +// ============================================================================ +// Specialization: float2 + float (2 fp32 elements per vector) +// ============================================================================ +// Handles 2 scalar float values per thread. +// For half-split mode, fetches pair values from new_kv_base at runtime. +// ============================================================================ +template <> +struct RotaryDispatcher { + __device__ static void apply(float2& val, const float2* cos_cache, const float2* sin_cache, + const int rotary_dim, const int h_idx, const int pos_id, + const bool interleaved, const float2* new_kv_base, const int64_t in_offset) { + if (2 * h_idx >= rotary_dim) return; + + const float* cos_ptr = reinterpret_cast(cos_cache); + const float* sin_ptr = reinterpret_cast(sin_cache); + const float* kv_ptr = reinterpret_cast(new_kv_base); + + // Use int64_t for byte offsets if needed, but here we index float array + int64_t scalar_in_offset = in_offset * 2; + int scalar_h = h_idx * 2; + int half_rot = rotary_dim / 2; + + float c, s; + float x = val.x; + float y = val.y; + + if (interleaved) { + int cs_idx = pos_id * half_rot + h_idx; + c = cos_ptr[cs_idx]; + s = sin_ptr[cs_idx]; + val.x = x * c - y * s; + val.y = x * s + y * c; + } else { + // Half-Split Logic + // Process x (idx = scalar_h) + { + int idx = scalar_h; + if (idx < rotary_dim) { // Should be true given h_idx check + int pair_idx = (idx < half_rot) ? (idx + half_rot) : (idx - half_rot); + float sign = (idx < half_rot) ? -1.0f : 1.0f; + int cos_idx = idx % half_rot; + int cs_idx = pos_id * half_rot + cos_idx; + + c = cos_ptr[cs_idx]; + s = sin_ptr[cs_idx]; + // Potential gather from new_kv if we are doing fused append+rotate from a source + // The source is 'new_kv_base'. + float pair_val = kv_ptr[scalar_in_offset + pair_idx]; + val.x = x * c + sign * pair_val * s; + } + } + + // Process y (idx = scalar_h + 1) + { + int idx = scalar_h + 1; + if (idx < rotary_dim) { + int pair_idx = (idx < half_rot) ? (idx + half_rot) : (idx - half_rot); + float sign = (idx < half_rot) ? 
-1.0f : 1.0f; + int cos_idx = idx % half_rot; + int cs_idx = pos_id * half_rot + cos_idx; + + c = cos_ptr[cs_idx]; + s = sin_ptr[cs_idx]; + float pair_val = kv_ptr[scalar_in_offset + pair_idx]; + val.y = y * c + sign * pair_val * s; + } + } + } + } +}; + +// ============================================================================ +// Specialization: float4 + float (4 fp32 elements per vector) +// ============================================================================ +// Delegates to float2+float specialization for each half of the float4. +// ============================================================================ +template <> +struct RotaryDispatcher { + __device__ static void apply(float4& val, const float4* cos_cache, const float4* sin_cache, + const int rotary_dim, const int h_idx, const int pos_id, + const bool interleaved, const float4* new_kv_base, const int64_t in_offset) { + float2 p1 = make_float2(val.x, val.y); + float2 p2 = make_float2(val.z, val.w); + const float2* c = reinterpret_cast(cos_cache); + const float2* s = reinterpret_cast(sin_cache); + const float2* b = reinterpret_cast(new_kv_base); + + // Update offsets for float2 components + RotaryDispatcher::apply(p1, c, s, rotary_dim, h_idx * 2, pos_id, interleaved, b, in_offset * 2); + RotaryDispatcher::apply(p2, c, s, rotary_dim, h_idx * 2 + 1, pos_id, interleaved, b, in_offset * 2); + + val.x = p1.x; + val.y = p1.y; + val.z = p2.x; + val.w = p2.y; + } +}; + +// ============================================================================ +// Specialization: float2 + half (4 fp16 elements packed in float2) +// ============================================================================ +// Uses half2 intrinsics for efficient fp16 computation. +// Each float2 contains 2 half2 values = 4 fp16 elements. +// ============================================================================ +template <> +struct RotaryDispatcher { + __device__ static void apply(float2& val, const float2* cos_cache, const float2* sin_cache, + const int rotary_dim, const int h_idx, const int pos_id, + const bool interleaved, const float2* new_kv_base, const int64_t in_offset) { + if (2 * h_idx * 2 >= rotary_dim) return; + + // Vector layout: float2 = 8 bytes = 4 half values + // v0 contains elements [4*h_idx, 4*h_idx+1] as half2 (2 fp16 values) + // v1 contains elements [4*h_idx+2, 4*h_idx+3] as half2 (2 fp16 values) + half2* v_ptr = reinterpret_cast(&val); + half2 v0 = v_ptr[0]; + half2 v1 = v_ptr[1]; + const half2* cos_ptr = reinterpret_cast(cos_cache); + const half2* sin_ptr = reinterpret_cast(sin_cache); + int half_rot = rotary_dim / 2; + + if (interleaved) { + int f0 = 2 * h_idx; + int cs0 = pos_id * half_rot + f0; + + const half2 c_pair = cos_ptr[cs0 / 2]; + const half2 s_pair = sin_ptr[cs0 / 2]; + + const float2 c_f = __half22float2(c_pair); + const float2 s_f = __half22float2(s_pair); + + // Rotate v0 (pair 0) + const float2 e0 = __half22float2(v0); + v0 = __float22half2_rn(make_float2(e0.x * c_f.x - e0.y * s_f.x, e0.x * s_f.x + e0.y * c_f.x)); + + // Rotate v1 (pair 1) + const float2 e1 = __half22float2(v1); + v1 = __float22half2_rn(make_float2(e1.x * c_f.y - e1.y * s_f.y, e1.x * s_f.y + e1.y * c_f.y)); + } else { + // Half-Split Logic + // Elements i and i + H/2 are paired. + // We have 4 elements: 4*h_idx, +1, +2, +3. + // We need to fetch pairs from new_kv_base. 
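+      // Illustrative example (hypothetical values): with rotary_dim = 8 (half_rot = 4) and h_idx = 0,
+      // this thread holds elements 0..3; each pairs with elements 4..7 and uses sign = -1, while the
+      // thread holding elements 4..7 pairs back with elements 0..3 and uses sign = +1, following
+      // x_i -> x_i * cos + sign * x_pair * sin from the half-split formula described above.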
+ + const half* kv_ptr = reinterpret_cast(new_kv_base); + int base_idx = 4 * h_idx; + int64_t scalar_in_offset = in_offset * 4; // 4 halfs per float2 + + auto rotate_element = [&](int idx, half& val) { + if (idx >= rotary_dim) return; // Should be covered + int pair_idx = (idx < half_rot) ? (idx + half_rot) : (idx - half_rot); + float sign = (idx < half_rot) ? -1.0f : 1.0f; + int cos_idx = idx % half_rot; + int cs_idx = pos_id * half_rot + cos_idx; + + half c_val = reinterpret_cast(cos_ptr)[cs_idx]; + half s_val = reinterpret_cast(sin_ptr)[cs_idx]; + + float val_f = __half2float(val); + float pair_f = __half2float(kv_ptr[scalar_in_offset + pair_idx]); + float cf = __half2float(c_val); + float sf = __half2float(s_val); + + val = __float2half(val_f * cf + sign * pair_f * sf); + }; + + rotate_element(base_idx, v0.x); + rotate_element(base_idx + 1, v0.y); + rotate_element(base_idx + 2, v1.x); + rotate_element(base_idx + 3, v1.y); + } + v_ptr[0] = v0; + v_ptr[1] = v1; + } +}; + +// ============================================================================ +// Specialization: float2 + BFloat16 (4 bf16 elements packed in float2) +// ============================================================================ +// Uses __nv_bfloat162 intrinsics for bf16 computation. +// Requires SM 80+ (Ampere) for native bf16 support. +// ============================================================================ +template <> +struct RotaryDispatcher { + __device__ static void apply(float2& val, const float2* cos_cache, const float2* sin_cache, + const int rotary_dim, const int h_idx, const int pos_id, + const bool interleaved, const float2* new_kv_base, const int64_t in_offset) { + if (2 * h_idx * 2 >= rotary_dim) return; + + using namespace onnxruntime::cuda; + // Vector layout: float2 = 8 bytes = 4 bf16 values + // v0 contains elements [4*h_idx, 4*h_idx+1] as bfloat162 (2 bf16 values) + // v1 contains elements [4*h_idx+2, 4*h_idx+3] as bfloat162 (2 bf16 values) + __nv_bfloat162* v_ptr = reinterpret_cast<__nv_bfloat162*>(&val); + __nv_bfloat162 v0 = v_ptr[0]; + __nv_bfloat162 v1 = v_ptr[1]; + const __nv_bfloat162* cos_ptr = reinterpret_cast(cos_cache); + const __nv_bfloat162* sin_ptr = reinterpret_cast(sin_cache); + int half_rot = rotary_dim / 2; + + if (interleaved) { + int f0 = 2 * h_idx; + int cs0 = pos_id * half_rot + f0; + + __nv_bfloat162 c_pair = cos_ptr[cs0 / 2]; + __nv_bfloat162 s_pair = sin_ptr[cs0 / 2]; + + // Process v0 (pair 1) + // v0.x, v0.y + float c0f = __bfloat162float(c_pair.x); + float s0f = __bfloat162float(s_pair.x); + float e0x = __bfloat162float(v0.x); + float e0y = __bfloat162float(v0.y); + v0.x = __float2bfloat16(e0x * c0f - e0y * s0f); + v0.y = __float2bfloat16(e0x * s0f + e0y * c0f); + + // Process v1 (pair 2) + // v1.x, v1.y + float c1f = __bfloat162float(c_pair.y); + float s1f = __bfloat162float(s_pair.y); + float e1x = __bfloat162float(v1.x); + float e1y = __bfloat162float(v1.y); + v1.x = __float2bfloat16(e1x * c1f - e1y * s1f); + v1.y = __float2bfloat16(e1x * s1f + e1y * c1f); + + } else { + // Half-Split Logic + const __nv_bfloat16* kv_ptr = reinterpret_cast(new_kv_base); + int base_idx = 4 * h_idx; + int64_t scalar_in_offset = in_offset * 4; + + auto rotate_element_bf16 = [&](int idx, __nv_bfloat16& val) { + if (idx >= rotary_dim) return; + int pair_idx = (idx < half_rot) ? (idx + half_rot) : (idx - half_rot); + float sign = (idx < half_rot) ? 
-1.0f : 1.0f; + int cos_idx = idx % half_rot; + int cs_idx = pos_id * half_rot + cos_idx; + + __nv_bfloat16 c_val = reinterpret_cast(cos_ptr)[cs_idx]; + __nv_bfloat16 s_val = reinterpret_cast(sin_ptr)[cs_idx]; + + float val_f = __bfloat162float(val); + float pair_f = __bfloat162float(kv_ptr[scalar_in_offset + pair_idx]); + float cf = __bfloat162float(c_val); + float sf = __bfloat162float(s_val); + + val = __float2bfloat16(val_f * cf + sign * pair_f * sf); + }; + + rotate_element_bf16(base_idx, v0.x); + rotate_element_bf16(base_idx + 1, v0.y); + rotate_element_bf16(base_idx + 2, v1.x); + rotate_element_bf16(base_idx + 3, v1.y); + } + v_ptr[0] = v0; + v_ptr[1] = v1; + } +}; + +// ============================================================================ +// Specialization: float4 + half (8 fp16 elements per vector) +// ============================================================================ +// Delegates to float2+half specialization for each half of the float4. +// ============================================================================ +template <> +struct RotaryDispatcher { + __device__ static void apply(float4& val, const float4* cos_cache, const float4* sin_cache, + const int rotary_dim, const int h_idx, const int pos_id, + const bool interleaved, const float4* new_kv_base, const int64_t in_offset) { + float2 p1 = make_float2(val.x, val.y); + float2 p2 = make_float2(val.z, val.w); + const float2* c = reinterpret_cast(cos_cache); + const float2* s = reinterpret_cast(sin_cache); + const float2* b = reinterpret_cast(new_kv_base); + + RotaryDispatcher::apply(p1, c, s, rotary_dim, h_idx * 2, pos_id, interleaved, b, in_offset * 2); + RotaryDispatcher::apply(p2, c, s, rotary_dim, h_idx * 2 + 1, pos_id, interleaved, b, in_offset * 2); + + val.x = p1.x; + val.y = p1.y; + val.z = p2.x; + val.w = p2.y; + } +}; + +// ============================================================================ +// Specialization: float4 + BFloat16 (8 bf16 elements per vector) +// ============================================================================ +// Delegates to float2+BFloat16 specialization for each half of the float4. 
+// ============================================================================ +template <> +struct RotaryDispatcher { + __device__ static void apply(float4& val, const float4* cos_cache, const float4* sin_cache, + const int rotary_dim, const int h_idx, const int pos_id, + const bool interleaved, const float4* new_kv_base, const int64_t in_offset) { + float2 p1 = make_float2(val.x, val.y); + float2 p2 = make_float2(val.z, val.w); + const float2* c = reinterpret_cast(cos_cache); + const float2* s = reinterpret_cast(sin_cache); + const float2* b = reinterpret_cast(new_kv_base); + + RotaryDispatcher::apply(p1, c, s, rotary_dim, h_idx * 2, pos_id, interleaved, b, in_offset * 2); + RotaryDispatcher::apply(p2, c, s, rotary_dim, h_idx * 2 + 1, pos_id, interleaved, b, in_offset * 2); + + val.x = p1.x; + val.y = p1.y; + val.z = p2.x; + val.w = p2.y; + } +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc index eef33192e6e6b..0b38c6f0e5484 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -71,6 +71,7 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { reinterpret_cast(output->template MutableData()), reinterpret_cast(input->template Data()), position_ids->Data(), + nullptr, // past_sequence_lengths reinterpret_cast(cos_cache->template Data()), reinterpret_cast(sin_cache->template Data()), parameters.batch_size, diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index ad0a83c9cde65..ce6b4724af705 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -18,11 +18,12 @@ namespace contrib { namespace cuda { template -__global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH - const T* input, // BxSxNxH - const T* cos_cache, // Mx(H/2) - const T* sin_cache, // Mx(H/2) - const int64_t* position_ids, // (1) or BxS +__global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH + const T* input, // BxSxNxH + const T* cos_cache, // Mx(H/2) + const T* sin_cache, // Mx(H/2) + const int64_t* position_ids, // (1) or BxS + const int* past_sequence_lengths, // (B) for format 2 const int sequence_length, const int num_heads, const int head_size, const int rotary_embedding_dim, const int position_ids_format, const bool interleaved, @@ -51,8 +52,17 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH // Cache is (M, H/2) const int half_rotary_embedding_dim = rotary_embedding_dim / 2; - const int position_id = (position_ids_format == 0) ? static_cast(position_ids[0]) + s - : static_cast(position_ids[b * sequence_length + s]); + int position_id = 0; + if (position_ids_format == 0) { + position_id = static_cast(position_ids[0]) + s; + } else if (position_ids_format == 1) { + position_id = static_cast(position_ids[b * sequence_length + s]); + } else if (position_ids_format == 2) { + // format 2: past_sequence_length + s + // used for Decoding (past_sequence_length = seqlens_k[b]) or First Prompt (past=0 if nullptr) + int past = (past_sequence_lengths == nullptr) ? 
0 : past_sequence_lengths[b]; + position_id = past + s; + } const int cache_offset = position_id * half_rotary_embedding_dim; const T* cos_data = cos_cache + cache_offset; const T* sin_data = sin_cache + cache_offset; @@ -74,6 +84,7 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH template Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids, + const int* past_sequence_lengths, const T* cos_cache, const T* sin_cache, const int batch_size, const int sequence_length, const int num_heads, const int head_size, const int rotary_embedding_dim, const int max_sequence_length, @@ -93,7 +104,7 @@ Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* inpu out_strides = int4{sequence_length * num_heads * out_head_stride, out_head_stride, num_heads * out_head_stride, 1}; } return LaunchRotaryEmbeddingKernel( - stream, output, input, position_ids, + stream, output, input, position_ids, past_sequence_lengths, cos_cache, sin_cache, batch_size, sequence_length, num_heads, head_size, rotary_embedding_dim, max_sequence_length, @@ -104,6 +115,7 @@ Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* inpu template Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids, + const int* past_sequence_lengths, const T* cos_cache, const T* sin_cache, const int batch_size, const int sequence_length, const int num_heads, const int head_size, const int rotary_embedding_dim, const int /*max_sequence_length*/, @@ -125,7 +137,7 @@ Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* inpu const dim3 grid(sequence_length, batch_size, num_heads); assert(head_size <= max_threads_per_block); - RotaryEmbeddingBSNH<<>>(output, input, cos_cache, sin_cache, position_ids, sequence_length, + RotaryEmbeddingBSNH<<>>(output, input, cos_cache, sin_cache, position_ids, past_sequence_lengths, sequence_length, num_heads, head_size, rotary_embedding_dim, position_ids_format, interleaved, in_strides, out_strides); @@ -133,7 +145,7 @@ Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* inpu } template Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, float* output, const float* input, - const int64_t* position_ids, const float* cos_cache, + const int64_t* position_ids, const int* past_sequence_lengths, const float* cos_cache, const float* sin_cache, const int batch_size, const int sequence_length, const int num_heads, const int head_size, const int rotary_embedding_dim, const int max_sequence_length, @@ -141,7 +153,7 @@ template Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, float* o const int max_threads_per_block, const bool is_input_bnsh_format); template Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, half* output, const half* input, - const int64_t* position_ids, const half* cos_cache, + const int64_t* position_ids, const int* past_sequence_lengths, const half* cos_cache, const half* sin_cache, const int batch_size, const int sequence_length, const int num_heads, const int head_size, const int rotary_embedding_dim, const int max_sequence_length, @@ -149,7 +161,7 @@ template Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, half* out const int max_threads_per_block, const bool is_input_bnsh_format); template Status LaunchRotaryEmbeddingKernel( - cudaStream_t stream, BFloat16* output, const BFloat16* input, const int64_t* position_ids, + cudaStream_t stream, BFloat16* output, const BFloat16* 
input, const int64_t* position_ids, const int* past_sequence_lengths, const BFloat16* cos_cache, const BFloat16* sin_cache, const int batch_size, const int sequence_length, const int num_heads, const int head_size, const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, const int max_threads_per_block, diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h index dd0ac6a6e3274..2d81e6b4067af 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -15,6 +15,7 @@ Status LaunchRotaryEmbeddingKernel( T* output, const T* input, const int64_t* position_ids, + const int* past_sequence_lengths, const T* cos_cache, const T* sin_cache, const int batch_size, @@ -34,6 +35,7 @@ Status LaunchRotaryEmbeddingKernel( T* output, const T* input, const int64_t* position_ids, + const int* past_sequence_lengths, const T* cos_cache, const T* sin_cache, const int batch_size, diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu index e55163186b505..b2a6eb89d4d23 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu @@ -178,14 +178,14 @@ Status QkvToContext( // Launch rotary embedding kernel. This requires separated Q, K and V ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(stream, q_buffer, reinterpret_cast(query), - position_ids_buff, data.cos_cache, data.sin_cache, + position_ids_buff, nullptr, data.cos_cache, data.sin_cache, parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size, parameters.rotary_dim, parameters.max_rotary_sequence_length, /*position_ids_format*/ 1, parameters.rotary_interleaved, max_threads_per_block, q_layout)); ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(stream, k_buffer, reinterpret_cast(key), - position_ids_buff, data.cos_cache, data.sin_cache, + position_ids_buff, nullptr, data.cos_cache, data.sin_cache, parameters.batch_size, parameters.sequence_length, parameters.kv_num_heads, parameters.head_size, parameters.rotary_dim, parameters.max_rotary_sequence_length, diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 280ec3f79c74e..b3a5c15718ffb 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -31,7 +31,7 @@ # Reduces number of tests to run for faster pipeline checks pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" -# Maximum number of values per parameter +# Number of values per parameter (compared to pipeline mode) param_count = int(os.getenv("PARAM_COUNT", "3")) if not pipeline_mode else 2 # When quick build is used, flash attention only supports fp16 and head_size=128 @@ -86,11 +86,12 @@ class GQAConfig: packed: bool = False softcap: float = 0.0 use_smooth_softmax: bool = False - # CPU-only parameters - has_position_ids: bool = False - has_attention_bias: bool = False has_head_sink: bool = False kv_cache_type: str = "" + share_buffer: bool = True + + has_position_ids: bool = False + has_attention_bias: bool = False # ################################################################################################# @@ -442,8 +443,9 @@ def gqa_prompt_func( bind_tensor(io_binding, "seqlens_k", seqlens_k.to(torch.int32), 
device, TensorProto.INT32) # total_sequence_length is INT32 [1] - tsl = torch.tensor([config.q_sequence_length], dtype=torch.int32, device=device) - bind_tensor(io_binding, "total_sequence_length", tsl, device, TensorProto.INT32) + # Schema requires this to be on CPU (OrtMemTypeCPUInput) + tsl = torch.tensor([config.q_sequence_length], dtype=torch.int32, device="cpu") + bind_tensor(io_binding, "total_sequence_length", tsl, "cpu", TensorProto.INT32) # 5. Optional inputs if cos is not None: @@ -564,14 +566,17 @@ def gqa_past_func( bind_tensor(io_binding, "past_key", k, device, cache_ort_type) bind_tensor(io_binding, "past_value", v, device, cache_ort_type) else: - # If not sharing buffer, 'k' and 'v' are the *past* states passed in (usually smaller?) - # Actually logic in test_gqa: k, v are the cache tensors. - # We bind them. - bind_tensor(io_binding, "past_key", k, device, cache_ort_type) - bind_tensor(io_binding, "past_value", v, device, cache_ort_type) + # If not sharing buffer, 'k' and 'v' are the *past* states passed in. + # We must slice the buffer to the valid past length expected by the graph. + past_len = config.past_kv_sequence_length + k_sliced = k[:, :, :past_len, :].contiguous() + v_sliced = v[:, :, :past_len, :].contiguous() + bind_tensor(io_binding, "past_key", k_sliced, device, cache_ort_type) + bind_tensor(io_binding, "past_value", v_sliced, device, cache_ort_type) # 4. Scalars - bind_tensor(io_binding, "seqlens_k", seqlens_k.to(torch.int32), device, TensorProto.INT32) + seqlens_k_int32 = seqlens_k.to(dtype=torch.int32, device=device) + bind_tensor(io_binding, "seqlens_k", seqlens_k_int32, device, TensorProto.INT32) tsl = torch.tensor([total_seq_len], dtype=torch.int32, device=device) bind_tensor(io_binding, "total_sequence_length", tsl, device, TensorProto.INT32) @@ -610,7 +615,7 @@ def gqa_past_func( if share_buffer: present_seqlen = config.buffer_sequence_length else: - present_seqlen = total_seq_len # For past_func, total seq len is accumulated + present_seqlen = total_seq_len present_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] @@ -816,20 +821,22 @@ def parity_check_gqa_prompt( k_ro = apply_rotary_embedding(new_k.clone(), cos, sin, rotary_seqlens, config.rotary_interleaved, device) position_ids, attention_bias = None, None - if ep == "CPUExecutionProvider": - if config.has_position_ids: - position_ids = ( - torch.arange(config.q_sequence_length, device=device).unsqueeze(0).expand(config.batch_size, -1) - ) - if config.has_attention_bias: - attention_bias = torch.zeros( - config.batch_size, - 1, - config.q_sequence_length, - config.kv_sequence_length, - device=device, - dtype=torch_type, - ) + if config.has_position_ids: + position_ids = ( + torch.arange(config.q_sequence_length, device=device) + .unsqueeze(0) + .expand(config.batch_size, -1) + .contiguous() + ) + if config.has_attention_bias: + attention_bias = torch.zeros( + config.batch_size, + 1, + config.q_sequence_length, + config.kv_sequence_length, + device=device, + dtype=torch_type, + ) arange = rearrange(torch.arange(config.buffer_sequence_length, device=device), "s -> 1 s") kv_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") @@ -881,7 +888,7 @@ def parity_check_gqa_prompt( head_sink=head_sink, ep=ep, device=device, - share_buffer=True, + share_buffer=config.share_buffer, ort_type=ort_type, ) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) @@ -906,6 +913,10 @@ def parity_check_gqa_prompt( present_k_np = 
present_k.to(torch.float32).detach().cpu().numpy() present_v_np = present_v.to(torch.float32).detach().cpu().numpy() + if not config.share_buffer: + k_cache_ref_np = k_cache_ref_np[:, :, : config.kv_sequence_length, :] + v_cache_ref_np = v_cache_ref_np[:, :, : config.kv_sequence_length, :] + numpy.testing.assert_allclose(present_k_np, k_cache_ref_np, rtol=rtol, atol=atol) numpy.testing.assert_allclose(present_v_np, v_cache_ref_np, rtol=rtol, atol=atol) @@ -955,6 +966,7 @@ def parity_check_gqa_past( ) v = torch.randn_like(k) * std + # Random past sequence lengths. This tests paddings in decoding. cache_seqlens = torch.randint( 0, config.past_kv_sequence_length - config.q_sequence_length + 1, @@ -1003,16 +1015,15 @@ def parity_check_gqa_past( position_ids, attention_bias = None, None total_seq_len = config.past_kv_sequence_length + config.q_sequence_length - if ep == "CPUExecutionProvider": - if config.has_position_ids: - position_ids = (cache_seqlens.unsqueeze(1) + torch.arange(config.q_sequence_length, device=device)).long() - if config.has_attention_bias: - attention_bias = torch.zeros( - config.batch_size, 1, config.q_sequence_length, total_seq_len, device=device, dtype=torch_type - ) - for b in range(config.batch_size): - end_pos = cache_seqlens[b] + config.q_sequence_length - attention_bias[b, :, :, end_pos:] = float("-inf") + if config.has_position_ids: + position_ids = (cache_seqlens.unsqueeze(1) + torch.arange(config.q_sequence_length, device=device)).long() + if config.has_attention_bias: + attention_bias = torch.zeros( + config.batch_size, 1, config.q_sequence_length, total_seq_len, device=device, dtype=torch_type + ) + for b in range(config.batch_size): + end_pos = cache_seqlens[b] + config.q_sequence_length + attention_bias[b, :, :, end_pos:] = float("-inf") arange = rearrange(torch.arange(config.buffer_sequence_length, device=device), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") @@ -1060,7 +1071,7 @@ def parity_check_gqa_past( head_sink=head_sink, ep=ep, device=device, - share_buffer=True, + share_buffer=config.share_buffer, ort_type=ort_type, ) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) @@ -1074,12 +1085,161 @@ def parity_check_gqa_past( present_k_np = present_k.to(torch.float32).detach().cpu().numpy() present_v_np = present_v.to(torch.float32).detach().cpu().numpy() + if not config.share_buffer: + total_len = config.past_kv_sequence_length + config.q_sequence_length + k_cache_ref_np = k_cache_ref_np[:, :, :total_len, :] + v_cache_ref_np = v_cache_ref_np[:, :, :total_len, :] + numpy.testing.assert_allclose(present_k_np, k_cache_ref_np, rtol=rtol, atol=atol) numpy.testing.assert_allclose(present_v_np, v_cache_ref_np, rtol=rtol, atol=atol) numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) +def parity_test_gqa_padding_prompt(): + device = "cuda" + torch_type = torch.float16 + ort_type = TensorProto.FLOAT16 + + # config + config = GQAConfig( + batch_size=2, + q_sequence_length=16, + kv_sequence_length=16, + num_heads=8, + kv_num_heads=2, + head_size=128, + buffer_sequence_length=16, + share_buffer=True, + packed=False, + rotary=True, + ) + + # Inputs + torch.manual_seed(0) + std = 0.02 + q = ( + torch.randn( + config.batch_size, + config.q_sequence_length, + config.num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + * std + ) + k = ( + torch.randn( + config.batch_size, + config.kv_num_heads, + config.kv_sequence_length, + config.head_size, + 
device=device, + dtype=torch_type, + ) + * std + ) + v = torch.randn_like(k) * std + + new_k = k.transpose(1, 2).contiguous() + new_v = v.transpose(1, 2).contiguous() + + seqlens_k = torch.tensor([9, 15], dtype=torch.int32, device=device) + + # Generate Rotary Embeddings + rotary_dim = config.head_size + max_seq_len = config.buffer_sequence_length + cos = torch.randn(1, max_seq_len, 1, rotary_dim // 2, device=device, dtype=torch_type) + sin = torch.randn(1, max_seq_len, 1, rotary_dim // 2, device=device, dtype=torch_type) + + # Apply Rotary to inputs for Reference + rotary_op = LlamaMSRotaryEmbedding() + pos = torch.zeros(config.batch_size, device=device, dtype=torch.long) + + # ORT receives the raw Q/K and applies rotary internally when `do_rotary` is true: Q is rotated + # inside the attention kernel, and `new_k` is rotated as it is appended to `past_key` (prompt mode + # with `share_buffer=True`). The reference must therefore rotate both Q and K manually. + + q_ref = rotary_op.rotate_tensor(q, cos, sin, pos, False) + k_ref = rotary_op.rotate_tensor(new_k, cos, sin, pos, False) + v_ref = new_v + + # Run ONNX Runtime + out_ort, present_key_ort, present_value_ort = gqa_prompt_func( + q=q, + k=k, + v=v, + config=config, + new_k=new_k, + new_v=new_v, + cos=cos.squeeze(2).squeeze(0), + sin=sin.squeeze(2).squeeze(0), + seqlens_k=seqlens_k, + position_ids=None, + attention_bias=None, + head_sink=None, + ep="CUDAExecutionProvider", + device=device, + share_buffer=config.share_buffer, + ort_type=ort_type, + ) + + # Compare present_key and present_value with reference + # ORT present_key is BNSH format: [batch, kv_num_heads, seq, head_size] + # k_ref is BSNH format: [batch, seq, kv_num_heads, head_size] + # Transpose k_ref to BNSH for comparison + k_ref_bnsh = k_ref.transpose(1, 2) # BSNH -> BNSH + v_ref_bnsh = v_ref.transpose(1, 2) # BSNH -> BNSH + + # Compare only valid positions (positions 0..9 for Batch 0, 0..15 for Batch 1) + torch.testing.assert_close(present_key_ort[0, :, :10, :], k_ref_bnsh[0, :, :10, :], rtol=1e-3, atol=1e-3) + torch.testing.assert_close(present_key_ort[1, :, :16, :], k_ref_bnsh[1, :, :16, :], rtol=1e-3, atol=1e-3) + torch.testing.assert_close(present_value_ort[0, :, :10, :], v_ref_bnsh[0, :, :10, :], rtol=1e-3, atol=1e-3) + torch.testing.assert_close(present_value_ort[1, :, :16, :], v_ref_bnsh[1, :, :16, :], rtol=1e-3, atol=1e-3) + + # Run Reference + # key_padding_mask is a "Validity Mask" where True=Valid, False=Invalid + key_padding_mask = torch.zeros((config.batch_size, config.q_sequence_length), dtype=torch.bool, device=device) + + # Batch 0: Valid 0..9 (length 10) + key_padding_mask[0, :10] = True + + # Batch 1: Valid 0..15 (length 16) + key_padding_mask[1, :16] = True + + out_ref, _ = attention_ref( + q_ref, k_ref, v_ref, key_padding_mask=key_padding_mask, query_padding_mask=key_padding_mask, causal=True + ) + + # Compare + # Batch 0: 10..15 are padding + out_ort[0, 10:] = 0 + out_ref[0, 10:] = 0 + + # Reshape ref to match ORT + out_ref = out_ref.reshape(config.batch_size, config.q_sequence_length, -1) + + # Debugging + diff = (out_ort - out_ref).abs() + max_diff = diff.max() + # Check Batch 0 + b0_diff = diff[0].max() + # Check Batch 1 + b1_diff = diff[1].max() + + if not torch.allclose(out_ort, out_ref, rtol=1e-2, atol=1e-2): + msg = f"Mismatch!
Max Diff: {max_diff}, Batch 0 Max: {b0_diff}, Batch 1 Max: {b1_diff}\n" + raise AssertionError(msg) + + torch.testing.assert_close(out_ort, out_ref, rtol=1e-2, atol=1e-2) + + # ################################################################################################# # Test Case Generators # ################################################################################################# @@ -1094,40 +1254,53 @@ def get_cpu_rotary_options(): def get_softmax_options(allow_head_sink: bool = True): - return [(False, False), (False, True), (True, False)] + options = [(False, False), (False, True), (True, False)] + if not allow_head_sink: + options = [opt for opt in options if not opt[1]] + return options def gqa_cuda_prompt_test_cases(allow_head_sink: bool = True): batches = [3, 1, 5] - seqs = [(35, 35), (64, 64), (128, 128), (240, 240), (2000, 2000)] - heads = [(6, 3), (9, 9), (32, 8)] + seqs = [(35, 35), (1, 1), (64, 64), (128, 128), (240, 240), (2000, 2000)] + heads = [(6, 3), (3, 1), (32, 8)] h_sizes = [128] if quick_build else [128, 32, 64, 256] smmoth_softmax__head_sink = get_softmax_options(allow_head_sink) rotary_opts = list(get_cuda_rotary_options()) packed_opts = [False, True] + share_buffer_opts = [True, False] softcap_opts = [0.0, 50.0] + # Use new strategy for both modes: iterate over key code path parameters + # The difference between modes is the number of head_sizes tested + # Pipeline mode: h_sizes[:1] = [128] -> 12 combinations (fast) + # Comprehensive mode: all h_sizes -> 40+ combinations (thorough) + h_sizes_to_test = h_sizes[:1] if pipeline_mode else h_sizes + combo_index = 0 - for b in batches[:param_count]: - for sq, skv in seqs[:param_count]: - for n, n2 in heads[:param_count]: - for h in h_sizes[:param_count]: + for h in h_sizes_to_test: + for packed in packed_opts: + for rotary, rotary_interleaved in rotary_opts: + # Skip invalid: rotary requires head_size divisible by 16 + if rotary and h % 16 > 0: + continue + + for share_buffer in share_buffer_opts: + # Rotate secondary parameters + b = batches[combo_index % len(batches)] + sq, skv = seqs[combo_index % len(seqs)] + n, n2 = heads[combo_index % len(heads)] lws_opts = [-1, random.randint(1, skv)] lws = lws_opts[combo_index % len(lws_opts)] - - rotary, rotary_interleaved = rotary_opts[combo_index % len(rotary_opts)] - packed = packed_opts[combo_index % len(packed_opts)] softcap = softcap_opts[combo_index % len(softcap_opts)] use_smooth_softmax, has_head_sink = smmoth_softmax__head_sink[ combo_index % len(smmoth_softmax__head_sink) ] + has_position_ids = False if pipeline_mode else combo_index % 2 == 0 combo_index += 1 - if rotary and h % 16 > 0: - continue - if softcap > 0 and (use_smooth_softmax or has_head_sink): continue @@ -1144,48 +1317,66 @@ def gqa_cuda_prompt_test_cases(allow_head_sink: bool = True): rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + share_buffer=share_buffer, softcap=softcap, use_smooth_softmax=use_smooth_softmax, has_head_sink=has_head_sink, + has_position_ids=has_position_ids, ) - name = f"b{b}_sq{sq}_skv{skv}_nh{n}_{n2}_h{h}_w{lws}_rot{rotary}{rotary_interleaved}_pkd{packed}_sc{softcap}_sm{use_smooth_softmax}_{has_head_sink}" + name = f"b{b}_sq{sq}_skv{skv}_nh{n}_{n2}_h{h}_w{lws}_rot{rotary}{rotary_interleaved}_pkd{packed}_sb{share_buffer}_sc{softcap}_sm{use_smooth_softmax}_{has_head_sink}_pid{has_position_ids}" yield name, config def gqa_cuda_past_test_cases(allow_head_sink: bool = True): batches = [2, 1, 3] # s: new sequence length, s2: past sequence length - 
seqs = [(1, 128), (3, 1024), (1, 2048), (1, 5000)] + seqs = [(1, 1), (1, 128), (1, 2048), (1, 5000)] + subsequent_prompt_seqs = [(3, 256)] heads = [(32, 8), (6, 3), (9, 9)] - h_sizes = [128] if quick_build else [128, 64, 256] + h_sizes = [128] if quick_build else [128, 40, 64, 256] smmoth_softmax__head_sink = get_softmax_options(allow_head_sink) rotary_opts = list(get_cuda_rotary_options()) packed_opts = [False, True] + # For past test: pipeline tests share_buffer=True only, comprehensive tests both + share_buffer_opts = [True] if pipeline_mode else [True, False] softcap_opts = [0.0, 50.0] + # Use new strategy for both modes: iterate over key code path parameters + # The difference between modes is the number of head_sizes tested + # Pipeline mode: h_sizes[:1] = [128] -> 6 combinations (share_buffer=[True] only) + # Comprehensive mode: all h_sizes -> 36+ combinations + h_sizes_to_test = h_sizes[:1] if pipeline_mode else h_sizes + all_seqs = seqs + subsequent_prompt_seqs + combo_index = 0 - for b in batches[:param_count]: - for s, s2 in seqs[:param_count]: - if s > 1 and b > 1: - continue - for n, n2 in heads[:param_count]: - for h in h_sizes[:param_count]: + for h in h_sizes_to_test: + for packed in packed_opts: + for rotary, rotary_interleaved in rotary_opts: + # Skip invalid: rotary requires head_size divisible by 16 + if rotary and h % 16 > 0: + continue + + for share_buffer in share_buffer_opts: + # Rotate secondary parameters + b = batches[combo_index % len(batches)] + s, s2 = all_seqs[combo_index % len(all_seqs)] + + # Skip subsequent prompt for batch > 1 + if s > 1 and b > 1: + b = 1 # Force batch=1 for subsequent prompt + + n, n2 = heads[combo_index % len(heads)] lws_opts = [-1, random.randint(1, s2)] lws = lws_opts[combo_index % len(lws_opts)] - - rotary, rotary_interleaved = rotary_opts[combo_index % len(rotary_opts)] - packed = packed_opts[combo_index % len(packed_opts)] softcap = softcap_opts[combo_index % len(softcap_opts)] use_smooth_softmax, has_head_sink = smmoth_softmax__head_sink[ combo_index % len(smmoth_softmax__head_sink) ] + has_position_ids = False if pipeline_mode else s > 1 combo_index += 1 - if rotary and h % 16 > 0: - continue - if softcap > 0 and (use_smooth_softmax or has_head_sink): continue @@ -1202,11 +1393,13 @@ def gqa_cuda_past_test_cases(allow_head_sink: bool = True): rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + share_buffer=share_buffer, softcap=softcap, use_smooth_softmax=use_smooth_softmax, has_head_sink=has_head_sink, + has_position_ids=has_position_ids, ) - name = f"b{b}_s{s}_{s2}_nh{n}_{n2}_h{h}_w{lws}_rot{rotary}{rotary_interleaved}_pkd{packed}_sc{softcap}_sm{use_smooth_softmax}_{has_head_sink}" + name = f"b{b}_s{s}_{s2}_nh{n}_{n2}_h{h}_w{lws}_rot{rotary}{rotary_interleaved}_pkd{packed}_sb{share_buffer}_sc{softcap}_sm{use_smooth_softmax}_{has_head_sink}_pid{has_position_ids}" yield name, config @@ -1232,6 +1425,10 @@ def has_flash_attention(bf16: bool = False): return has_cuda_device(80) +rtol = {"fp16": 5e-3, "bf16": 5e-2} +atol = {"fp16": 5e-3, "bf16": 1e-2} + + @unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") class TestFlashGQA(unittest.TestCase): @parameterized.expand(gqa_cuda_prompt_test_cases()) @@ -1244,8 +1441,8 @@ def test_gqa_prompt_flash_attention(self, name, config): torch_type=torch.float16, ort_type=TensorProto.FLOAT16, causal=True, - rtol=1e-1, - atol=1e-1, + rtol=rtol["fp16"], + atol=atol["fp16"], ) @parameterized.expand(gqa_cuda_past_test_cases()) @@ -1258,8 
+1455,8 @@ def test_gqa_past_flash_attention(self, name, config): torch_type=torch.float16, ort_type=TensorProto.FLOAT16, causal=True, - rtol=1e-1, - atol=1e-1, + rtol=rtol["fp16"], + atol=atol["fp16"], ) @@ -1279,8 +1476,8 @@ def test_gqa_prompt_flash_attention_bf16(self, name, config): torch_type=torch.bfloat16, ort_type=TensorProto.BFLOAT16, causal=True, - rtol=1e-1, # Relaxed tolerance for BF16 - atol=2e-1, + rtol=rtol["bf16"], + atol=atol["bf16"], ) @parameterized.expand(gqa_cuda_past_test_cases()) @@ -1297,8 +1494,8 @@ def test_gqa_past_flash_attention_bf16(self, name, config): torch_type=torch.bfloat16, ort_type=TensorProto.BFLOAT16, causal=True, - rtol=1e-1, - atol=2e-1, + rtol=rtol["bf16"], + atol=atol["bf16"], ) @@ -1314,8 +1511,8 @@ def test_gqa_prompt_memory_efficient(self, name, config): torch_type=torch.float16, ort_type=TensorProto.FLOAT16, causal=True, - rtol=2e-2, - atol=2e-2, + rtol=rtol["fp16"], + atol=atol["fp16"], ) @parameterized.expand(gqa_cuda_past_test_cases(allow_head_sink=False)) @@ -1328,8 +1525,8 @@ def test_gqa_past_memory_efficient(self, name, config): torch_type=torch.float16, ort_type=TensorProto.FLOAT16, causal=True, - rtol=5e-3, - atol=2e-2, + rtol=rtol["fp16"], + atol=atol["fp16"], ) @@ -1345,10 +1542,172 @@ def test_gqa_past_memory_efficient_bf16(self, name, config): torch_type=torch.bfloat16, ort_type=TensorProto.BFLOAT16, causal=True, - rtol=5e-3, - atol=2e-2, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) + + +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +class TestFlashGQAPaddingPrompt(unittest.TestCase): + def test_gqa_padding_prompt_flash_attention(self): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + parity_test_gqa_padding_prompt() + + +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +class TestMemoryEfficientGQAPaddingPrompt(unittest.TestCase): + def test_gqa_padding_prompt_memory_efficient_attention(self): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + parity_test_gqa_padding_prompt() + + +# ################################################################################################# +# Fused Kernel Parity Tests (ORT_DISABLE_FUSED_KV and ORT_DISABLE_FLASH_DECODE) +# ################################################################################################# + + +def fused_kernel_test_cases(): + """Test cases specifically for fused vs unfused kernel parity.""" + configs = [ + # Decoding with RoPE and shared buffer + GQAConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + num_heads=16, + kv_num_heads=4, + head_size=128, + past_kv_sequence_length=128, + buffer_sequence_length=256, + rotary=True, + packed=False, + share_buffer=True, + ), + # Packed QKV decoding with RoPE + GQAConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + num_heads=8, + kv_num_heads=2, + head_size=128, + past_kv_sequence_length=64, + buffer_sequence_length=128, + rotary=True, + packed=True, + share_buffer=True, + ), + # Subsequent prompt with RoPE + GQAConfig( + batch_size=1, + q_sequence_length=4, + kv_sequence_length=4, + num_heads=8, + kv_num_heads=4, + head_size=128, + past_kv_sequence_length=32, + buffer_sequence_length=64, + rotary=True, + packed=False, + share_buffer=True, + ), + ] + for i, config in enumerate(configs): + yield f"fused_config_{i}", config + + +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +class TestFusedKernelParity(unittest.TestCase): 
+ """Tests that verify fused kernels produce the same results as unfused kernels.""" + + @parameterized.expand(fused_kernel_test_cases()) + def test_fused_kv_parity(self, name, config): + """Test ORT_DISABLE_FUSED_KV: fused vs unfused KV append kernels.""" + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + + # Run with fused kernels (default) + if "ORT_DISABLE_FUSED_KV" in os.environ: + del os.environ["ORT_DISABLE_FUSED_KV"] + + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # Run with unfused kernels + os.environ["ORT_DISABLE_FUSED_KV"] = "1" + + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # Clean up + del os.environ["ORT_DISABLE_FUSED_KV"] + + def test_flash_decode_parity(self): + """Test ORT_DISABLE_FLASH_DECODE: fast decode vs standard path.""" + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + + # Decoding config (seq_len=1, share_buffer=True) + config = GQAConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + num_heads=16, + kv_num_heads=4, + head_size=128, + past_kv_sequence_length=128, + buffer_sequence_length=256, + rotary=True, + packed=False, + share_buffer=True, ) + # Run with flash decode enabled (default) + if "ORT_DISABLE_FLASH_DECODE" in os.environ: + del os.environ["ORT_DISABLE_FLASH_DECODE"] + + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # Run with flash decode disabled + os.environ["ORT_DISABLE_FLASH_DECODE"] = "1" + + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # Clean up + del os.environ["ORT_DISABLE_FLASH_DECODE"] + if __name__ == "__main__": unittest.main()