From 50dbc15043bf66bc134cab8ad3f5e1e3d8027f24 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 7 Jan 2026 16:09:28 -0800 Subject: [PATCH 1/3] Update Flash Attention Implementation and APIs --- cmake/onnxruntime_providers_cpu.cmake | 12 +- .../cuda/bert/flash_attention/block_info.h | 37 +- .../cuda/bert/flash_attention/flash.h | 11 +- .../cuda/bert/flash_attention/flash_api.cc | 81 ++-- .../cuda/bert/flash_attention/flash_api.h | 30 +- .../flash_fwd_hdim128_bf16_causal_sm80.cu | 15 + .../flash_fwd_hdim128_bf16_sm80.cu | 19 +- .../flash_fwd_hdim128_fp16_causal_sm80.cu | 15 + .../flash_fwd_hdim128_fp16_sm80.cu | 19 +- .../flash_fwd_hdim192_bf16_causal_sm80.cu | 15 + .../flash_fwd_hdim192_bf16_sm80.cu | 19 +- .../flash_fwd_hdim192_fp16_causal_sm80.cu | 15 + .../flash_fwd_hdim192_fp16_sm80.cu | 19 +- .../flash_fwd_hdim256_bf16_causal_sm80.cu | 15 + .../flash_fwd_hdim256_bf16_sm80.cu | 19 +- .../flash_fwd_hdim256_fp16_causal_sm80.cu | 15 + .../flash_fwd_hdim256_fp16_sm80.cu | 19 +- .../flash_fwd_hdim32_bf16_causal_sm80.cu | 15 + .../flash_fwd_hdim32_bf16_sm80.cu | 19 +- .../flash_fwd_hdim32_fp16_causal_sm80.cu | 15 + .../flash_fwd_hdim32_fp16_sm80.cu | 19 +- .../flash_fwd_hdim64_bf16_causal_sm80.cu | 15 + .../flash_fwd_hdim64_bf16_sm80.cu | 19 +- .../flash_fwd_hdim64_fp16_causal_sm80.cu | 15 + .../flash_fwd_hdim64_fp16_sm80.cu | 19 +- .../flash_fwd_hdim96_bf16_causal_sm80.cu | 15 + .../flash_fwd_hdim96_bf16_sm80.cu | 19 +- .../flash_fwd_hdim96_fp16_causal_sm80.cu | 15 + .../flash_fwd_hdim96_fp16_sm80.cu | 19 +- .../bert/flash_attention/flash_fwd_kernel.h | 325 +++++++-------- .../flash_fwd_launch_template.h | 392 +++++++----------- ...lash_fwd_split_hdim128_bf16_causal_sm80.cu | 12 + .../flash_fwd_split_hdim128_bf16_sm80.cu | 17 +- ...lash_fwd_split_hdim128_fp16_causal_sm80.cu | 12 + .../flash_fwd_split_hdim128_fp16_sm80.cu | 17 +- ...lash_fwd_split_hdim192_bf16_causal_sm80.cu | 12 + .../flash_fwd_split_hdim192_bf16_sm80.cu | 17 +- ...lash_fwd_split_hdim192_fp16_causal_sm80.cu | 12 + .../flash_fwd_split_hdim192_fp16_sm80.cu | 17 +- ...lash_fwd_split_hdim256_bf16_causal_sm80.cu | 12 + .../flash_fwd_split_hdim256_bf16_sm80.cu | 17 +- ...lash_fwd_split_hdim256_fp16_causal_sm80.cu | 12 + .../flash_fwd_split_hdim256_fp16_sm80.cu | 17 +- ...flash_fwd_split_hdim32_bf16_causal_sm80.cu | 12 + .../flash_fwd_split_hdim32_bf16_sm80.cu | 17 +- ...flash_fwd_split_hdim32_fp16_causal_sm80.cu | 12 + .../flash_fwd_split_hdim32_fp16_sm80.cu | 17 +- ...flash_fwd_split_hdim64_bf16_causal_sm80.cu | 12 + .../flash_fwd_split_hdim64_bf16_sm80.cu | 17 +- ...flash_fwd_split_hdim64_fp16_causal_sm80.cu | 12 + .../flash_fwd_split_hdim64_fp16_sm80.cu | 17 +- ...flash_fwd_split_hdim96_bf16_causal_sm80.cu | 12 + .../flash_fwd_split_hdim96_bf16_sm80.cu | 17 +- ...flash_fwd_split_hdim96_fp16_causal_sm80.cu | 12 + .../flash_fwd_split_hdim96_fp16_sm80.cu | 17 +- .../cuda/bert/flash_attention/kernel_traits.h | 49 +-- .../cuda/bert/flash_attention/mask.h | 33 +- .../bert/flash_attention/namespace_config.h | 24 ++ .../cuda/bert/flash_attention/rotary.h | 7 +- .../cuda/bert/flash_attention/softmax.h | 49 ++- .../cuda/bert/flash_attention/static_switch.h | 2 + .../bert/flash_attention/update_kernels.py | 77 ++++ .../cuda/bert/flash_attention/utils.h | 42 +- .../cuda/bert/group_query_attention_impl.cu | 2 +- 64 files changed, 1102 insertions(+), 827 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_bf16_causal_sm80.cu create mode 100644 
onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_fp16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_bf16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_fp16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_bf16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_fp16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_bf16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_fp16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_bf16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_fp16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_bf16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_fp16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_bf16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_bf16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_bf16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_bf16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_bf16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_bf16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_causal_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/namespace_config.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/update_kernels.py diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index 5f33b42c9b79a..f77a5dd78fcc5 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -25,17 +25,17 @@ file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cu_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cuh" ) -# Quick build mode: Filter out non-hdim128 flash attention kernels for faster development iteration +# Quick build mode: Filter flash attention kernels for faster development iteration. +# - We keep only hdim128 fp16 flash attention kernels in quick build mode. +# - All other listed head dimensions are excluded (e.g., 32, 64, 96, 192, 256). +# - This regex matches both `flash_fwd_hdim*` and `flash_fwd_split_hdim*` kernels. 
+# If new head dimensions are added or removed, update this list to match the supported set. if(onnxruntime_QUICK_BUILD) message(STATUS "Quick build mode enabled: Only building hdim128 fp16 flash attention kernels") - # Filter non-hdim128 kernels - list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*hdim(32|64|96|160|192|224|256)") - # Filter all bfloat16 kernels (only keep fp16) + list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*hdim(32|64|96|192|256)") list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*_bf16") endif() - - file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/contrib_ops/js/*.h" "${ONNXRUNTIME_ROOT}/contrib_ops/js/*.cc" diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h index dde6143153e8e..22d36bd281dbb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h @@ -1,10 +1,12 @@ /****************************************************************************** * Copyright (c) 2023, Tri Dao. ******************************************************************************/ + #pragma once -namespace onnxruntime { -namespace flash { +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +namespace FLASH_NAMESPACE { + //////////////////////////////////////////////////////////////////////////////////////////////////// template @@ -17,43 +19,40 @@ struct BlockInfo { : params.cu_seqlens_k[bidb]), actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q - : params.cu_seqlens_q[bidb + 1] - sum_s_q) + : params.cu_seqlens_q[bidb + 1] - sum_s_q), // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. - , - seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr - ? params.seqlen_k - : (params.is_seqlens_k_cumulative - ? params.cu_seqlens_k[bidb + 1] - sum_s_k - : params.cu_seqlens_k[bidb])), + leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]), + seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr + ? params.seqlen_k + : (params.is_seqlens_k_cumulative + ? params.cu_seqlens_k[bidb + 1] - sum_s_k + : params.cu_seqlens_k[bidb])) - + leftpad_k), actual_seqlen_k(params.seqused_k - ? params.seqused_k[bidb] + ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) { } template - __forceinline__ __device__ - index_t - q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; } template - __forceinline__ __device__ - index_t - k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? 
bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride; } const int sum_s_q; const int sum_s_k; const int actual_seqlen_q; // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int leftpad_k; const int seqlen_k_cache; const int actual_seqlen_k; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace flash -} // namespace onnxruntime +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index 09ead61e7d80d..d3859912fbc70 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -80,12 +80,11 @@ struct Flash_fwd_params : public Qkv_params { // array of length b+1 holding starting offset of each sequence. int* __restrict__ cu_seqlens_q = nullptr; int* __restrict__ cu_seqlens_k = nullptr; + int* __restrict__ leftpad_k = nullptr; // If provided, the actual length of each k sequence. int* __restrict__ seqused_k = nullptr; - int* __restrict__ blockmask = nullptr; - // The K_new and V_new matrices. void* __restrict__ knew_ptr = nullptr; void* __restrict__ vnew_ptr = nullptr; @@ -131,15 +130,17 @@ struct Flash_fwd_params : public Qkv_params { void* __restrict__ alibi_slopes_ptr = nullptr; index_t alibi_slopes_batch_stride = 0; - bool unpadded_lse = false; + bool unpadded_lse = false; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. const cudaDeviceProp* dprops = nullptr; + bool seqlenq_ngroups_swapped = false; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream); -template + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); } // namespace flash diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 76704b5b29fcd..e11137413a8cf 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -43,7 +43,9 @@ void set_params_fprop(Flash_fwd_params& params, bool kv_bsnh = true, int window_size_left = -1, int window_size_right = -1, - const bool unpadded_lse = false) { + const bool unpadded_lse = false, + void* cache_batch_idx = nullptr, + void* leftpad_k = nullptr) { // Set the pointers and strides. 
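As an aside on the BlockInfo change above: leftpad_k now shifts both the cached sequence length and the K base offset so that left-padded tokens are skipped. A minimal standalone sketch of the same arithmetic for the non-varlen path, using a hypothetical helper that is not part of this patch:

#include <cstdint>

// Mirrors the updated BlockInfo logic for a fixed-length (non-varlen) batch:
// the effective cache length drops by leftpad_k, and the K base offset is
// advanced by leftpad_k rows so the padded prefix is never read.
struct KCacheView {
  int64_t base_offset;   // element offset of the first valid K row for this batch entry
  int effective_seqlen;  // seqlen_k_cache after removing the left padding
};

inline KCacheView k_cache_view(int bidb, int seqlen_k, int leftpad_k,
                               int64_t k_batch_stride, int64_t k_row_stride) {
  KCacheView v;
  v.base_offset = int64_t(bidb) * k_batch_stride + int64_t(leftpad_k) * k_row_stride;
  v.effective_seqlen = seqlen_k - leftpad_k;
  return v;
}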
params.q_ptr = q; params.k_ptr = k; @@ -147,6 +149,15 @@ void set_params_fprop(Flash_fwd_params& params, params.is_seqlens_k_cumulative = true; params.unpadded_lse = unpadded_lse; + params.seqlenq_ngroups_swapped = false; + + params.leftpad_k = static_cast(leftpad_k); + params.cache_batch_idx = static_cast(cache_batch_idx); + params.rotary_cos_ptr = nullptr; + params.rotary_sin_ptr = nullptr; + params.is_rotary_interleaved = false; + params.alibi_slopes_ptr = nullptr; + params.alibi_slopes_batch_stride = 0; } size_t get_softmax_lse_size(size_t seqlen, size_t batch_size, size_t num_heads) { @@ -173,11 +184,13 @@ size_t get_out_accum_size(size_t num_splits, size_t batch_size, size_t num_heads void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) { FP16_SWITCH(!params.is_bf16, [&] { HEADDIM_SWITCH(params.d, [&] { - if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 - run_mha_fwd_(params, stream); - } else { - run_mha_fwd_splitkv_dispatch(params, stream); - } + BOOL_SWITCH(params.is_causal, Is_causal_const, [&] { + if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch(params, stream); + } + }); }); }); } @@ -258,20 +271,6 @@ std::tuple get_num_splits_and_buffer_sizes(size_t batch_ } } -// void set_params_alibi(Flash_fwd_params ¶ms, void* alibi_slopes, int batch_size, int num_heads){ -// if (alibi_slopes != nullptr) { -// // TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32"); -// // CHECK_DEVICE(alibi_slopes); -// // TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); -// // TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) -// || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads})); -// params.alibi_slopes_ptr = alibi_slopes; -// params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? num_heads : 0; // TODO: flag for bool -// } else { -// params.alibi_slopes_ptr = nullptr; -// } -// } - Status mha_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size @@ -294,7 +293,9 @@ Status mha_fwd(const cudaDeviceProp& dprops, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded bool kv_bsnh, - int local_window_size) { + int local_window_size, + void* cache_batch_idx, + void* leftpad_k) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); @@ -322,7 +323,10 @@ Status mha_fwd(const cudaDeviceProp& dprops, use_smooth_softmax, kv_bsnh, local_window_size, - is_causal ? 0 : -1); + is_causal ? 0 : -1, + /*unpadded_lse=*/false, + cache_batch_idx, + leftpad_k); params.dprops = &dprops; params.knew_ptr = nullptr; params.vnew_ptr = nullptr; @@ -440,18 +444,20 @@ bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_hea // of max_sequence_length, so seqlen_k == max_sequence_length. The actual past sequence length is held in seqlens_k_. 
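For reference on the dispatch change in run_mha_fwd above: is_causal is now lifted into a compile-time template parameter via BOOL_SWITCH, nested inside the existing FP16_SWITCH and HEADDIM_SWITCH. A simplified sketch of that switch pattern (illustrative only, with renamed macro and functions; the real macros live in static_switch.h):

// A runtime boolean is turned into a constexpr so the lambda body can
// instantiate a template with it. This is the pattern BOOL_SWITCH implements.
#define EXAMPLE_BOOL_SWITCH(COND, CONST_NAME, ...)  \
  [&] {                                             \
    if (COND) {                                     \
      constexpr static bool CONST_NAME = true;      \
      return __VA_ARGS__();                         \
    } else {                                        \
      constexpr static bool CONST_NAME = false;     \
      return __VA_ARGS__();                         \
    }                                               \
  }()

template <bool Is_causal>
void run_kernel_for() { /* one template instantiation per flag value */ }

void dispatch(bool is_causal) {
  EXAMPLE_BOOL_SWITCH(is_causal, Is_causal_const, [&] {
    run_kernel_for<Is_causal_const>();
  });
}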
Status mha_fwd_kvcache(const cudaDeviceProp& dprops, cudaStream_t stream, - void* q, // batch_size x seqlen_q x num_heads x head_size - void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size - void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size - void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size - void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size - void* out, // batch_size x seqlen_q x num_heads x head_size - void* softmax_lse, // batch_size x num_heads x seqlen_q - void* seqlens_k_, // batch_size - void* rotary_cos, // seqlen_ro x (rotary_dim / 2) - void* rotary_sin, // seqlen_ro x (rotary_dim / 2) - void* head_sink, // num_heads - int* block_table, // batch_size x max_num_blocks_per_seq + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + void* cache_batch_idx, // (optional) indices to index into the KV cache + void* leftpad_k, // (optional) batch_size + void* head_sink, // num_heads + int* block_table, // batch_size x max_num_blocks_per_seq int batch_size, int num_heads, int num_heads_k, @@ -501,7 +507,10 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, use_smooth_softmax, past_bsnh, local_window_size, - is_causal ? 0 : -1); + is_causal ? 
0 : -1, + /*unpadded_lse=*/false, + cache_batch_idx, + leftpad_k); params.dprops = &dprops; if (k_new != nullptr && v_new != nullptr) { @@ -573,7 +582,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, } // Only split kernel supports appending to KV cache - run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr); + run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr || cache_batch_idx != nullptr); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index b18790f6de4e0..22b075d8533f9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -61,7 +61,9 @@ Status mha_fwd(const cudaDeviceProp& dprops, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded bool kv_bsnh = true, - int local_window_size = -1); + int local_window_size = -1, + void* cache_batch_idx = nullptr, + void* leftpad_k = nullptr); Status mha_varlen_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, @@ -91,18 +93,20 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, Status mha_fwd_kvcache(const cudaDeviceProp& dprops, cudaStream_t stream, - void* q, // batch_size x seqlen_q x num_heads x head_size - void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size - void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size - void* k, // batch_size x seqlen_k_new x num_heads_k x head_size - void* v, // batch_size x seqlen_k_new x num_heads_k x head_size - void* out, // batch_size x seqlen_q x num_heads x head_size - void* softmax_lse, // batch_size x num_heads x seqlen_q - void* seqlens_k_, // batch_size - void* rotary_cos, // seqlen_ro x (rotary_dim / 2) - void* rotary_sin, // seqlen_ro x (rotary_dim / 2) - void* head_sink, // num_heads - int* block_table, // batch_size x max_num_blocks_per_seq + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. 
+ void* k, // batch_size x seqlen_k_new x num_heads_k x head_size + void* v, // batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + void* cache_batch_idx, // (optional) indices to index into the KV cache + void* leftpad_k, // (optional) batch_size + void* head_sink, // num_heads + int* block_table, // batch_size x max_num_blocks_per_seq int batch_size, int num_heads, int num_heads_k, diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_bf16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_bf16_causal_sm80.cu new file mode 100644 index 0000000000000..a273073195bd6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_bf16_causal_sm80.cu @@ -0,0 +1,15 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_bf16_sm80.cu index 1ef1ce251ecba..a52cb042e2473 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_bf16_sm80.cu @@ -1,18 +1,15 @@ -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { template <> -void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); } -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_fp16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_fp16_causal_sm80.cu new file mode 100644 index 0000000000000..27b6a353a33ac --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_fp16_causal_sm80.cu @@ -0,0 +1,15 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_fp16_sm80.cu index 44ea92e58c86e..12f41b7ee42cc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_fp16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_fp16_sm80.cu @@ -1,18 +1,15 @@ -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { template <> -void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); } -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_bf16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_bf16_causal_sm80.cu new file mode 100644 index 0000000000000..26e0d28367bcc --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_bf16_causal_sm80.cu @@ -0,0 +1,15 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_bf16_sm80.cu index 3bdc5e4b0443f..806c91bdefddf 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_bf16_sm80.cu @@ -1,18 +1,15 @@ -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { template <> -void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim192(params, stream); +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); } -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_fp16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_fp16_causal_sm80.cu new file mode 100644 index 0000000000000..581b1f96a6f24 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_fp16_causal_sm80.cu @@ -0,0 +1,15 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_fp16_sm80.cu index 56fc04126ab12..eac2fe7a8b452 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_fp16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_fp16_sm80.cu @@ -1,18 +1,15 @@ -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { template <> -void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim192(params, stream); +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); } -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_bf16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_bf16_causal_sm80.cu new file mode 100644 index 0000000000000..87c03f9594279 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_bf16_causal_sm80.cu @@ -0,0 +1,15 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_bf16_sm80.cu index 59568b0bb03ce..a5f7b0a5ce00d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_bf16_sm80.cu @@ -1,18 +1,15 @@ -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { template <> -void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); } -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_fp16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_fp16_causal_sm80.cu new file mode 100644 index 0000000000000..35304079e6cdb --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_fp16_causal_sm80.cu @@ -0,0 +1,15 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_fp16_sm80.cu index 94d51e922d7cb..5bf5d13820f3b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_fp16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_fp16_sm80.cu @@ -1,18 +1,15 @@ -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { template <> -void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); } -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_bf16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_bf16_causal_sm80.cu new file mode 100644 index 0000000000000..94eeffb3b4cc6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_bf16_causal_sm80.cu @@ -0,0 +1,15 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_bf16_sm80.cu index ad3d4df7dfc85..093b4232a1f12 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_bf16_sm80.cu @@ -1,18 +1,15 @@ -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { template <> -void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim32(params, stream); +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); } -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_fp16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_fp16_causal_sm80.cu new file mode 100644 index 0000000000000..bf184250fa960 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_fp16_causal_sm80.cu @@ -0,0 +1,15 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_fp16_sm80.cu index d32eec27634ce..9f2c46dd7b63a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_fp16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_fp16_sm80.cu @@ -1,18 +1,15 @@ -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { template <> -void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim32(params, stream); +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); } -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_bf16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_bf16_causal_sm80.cu new file mode 100644 index 0000000000000..d1ea0af5218b9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_bf16_causal_sm80.cu @@ -0,0 +1,15 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_bf16_sm80.cu index 006416458c91b..82bc2a313e721 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_bf16_sm80.cu @@ -1,18 +1,15 @@ -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { template <> -void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); } -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_fp16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_fp16_causal_sm80.cu new file mode 100644 index 0000000000000..583824f9a2b6b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_fp16_causal_sm80.cu @@ -0,0 +1,15 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_fp16_sm80.cu index 65a2e42192532..b90a2acb27a1b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_fp16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_fp16_sm80.cu @@ -1,18 +1,15 @@ -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { template <> -void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); } -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_bf16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_bf16_causal_sm80.cu new file mode 100644 index 0000000000000..795d80ad6bfa2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_bf16_causal_sm80.cu @@ -0,0 +1,15 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_bf16_sm80.cu index d5a273a3f4163..f91f07ff4e49d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_bf16_sm80.cu @@ -1,18 +1,15 @@ -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { template <> -void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim96(params, stream); +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); } -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_fp16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_fp16_causal_sm80.cu new file mode 100644 index 0000000000000..04798a07b535c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_fp16_causal_sm80.cu @@ -0,0 +1,15 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_fp16_sm80.cu index f37ee5005855a..c7f5099be18a1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_fp16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_fp16_sm80.cu @@ -1,18 +1,15 @@ -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { template <> -void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim96(params, stream); +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); } -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index 91104b8c3dfe0..e80d632c3b77a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -1,6 +1,7 @@ /****************************************************************************** - * Copyright (c) 2023, Tri Dao. + * Copyright (c) 2024, Tri Dao. ******************************************************************************/ + #pragma once #if defined(__GNUC__) @@ -21,6 +22,7 @@ #include #include +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/block_info.h" #include "contrib_ops/cuda/bert/flash_attention/kernel_traits.h" #include "contrib_ops/cuda/bert/flash_attention/utils.h" @@ -28,8 +30,8 @@ #include "contrib_ops/cuda/bert/flash_attention/mask.h" #include "contrib_ops/cuda/bert/flash_attention/rotary.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { + using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -39,14 +41,16 @@ __forceinline__ __device__ auto get_lse_tile(const Params& params, const int bid // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path. // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick. // Otherwise, it's written as (h, b, seqlen_q). - const bool varlen_q = params.unpadded_lse; + const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped; auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0; auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + lse_offset); auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q); - auto lse_stride = params.unpadded_lse - ? make_stride(params.h * params.total_q, params.total_q, 1) - : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1); + auto lse_stride = params.seqlenq_ngroups_swapped + ? make_stride(1, params.seqlen_q * params.b, params.b) + : (params.unpadded_lse + ? 
make_stride(params.h * params.total_q, params.total_q, 1) + : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1)); auto lse_layout = make_layout(lse_shape, lse_stride); Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout); @@ -70,64 +74,61 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; - // constexpr int MMA_M = kBlockM / decltype(cute::size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q) return; - const int n_block_min = !Is_local - ? 0 - : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); if (Is_causal || Is_local) { n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); - // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. - // Otherwise we might read OOB elements from gK and gV. - if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { - Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), - make_shape(binfo.actual_seqlen_q, params.h, params.d), - make_stride(params.o_row_stride, params.o_head_stride, _1{})); - Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, - make_coord(m_block, 0)); // (kBlockM, kHeadDim) - - Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); - - typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); - Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - Tensor tOrO = make_tensor(shape(tOgO)); - clear(tOrO); - // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_O.partition_D(cO); - Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); - if (!Is_even_K) { + } + // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. + // Otherwise we might read OOB elements from gK and gV. 
+ if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { + Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.o_row_stride, params.o_head_stride, _1{})); + Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + + Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { - tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; - } + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + FLASH_NAMESPACE::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); #pragma unroll - for (int m = 0; m < size<1>(tOgO); ++m) { - const int row = get<0>(tOcO(0, m, 0)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { - gLSE(row) = std::numeric_limits::infinity(); - } + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { + gLSE(row) = kInfinity; } - return; } + return; } // We iterate over the blocks in reverse order. This is because the last block is the only one // that needs masking when we read K and V from global memory. Moreover, iterating in reverse // might save us 1 register (we just need n_block instead of both n_block and n_block_max). - const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; // We move K and V to the last block. 
+ const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), make_shape(binfo.actual_seqlen_q, params.h, params.d), @@ -155,7 +156,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi typename Kernel_traits::SmemLayoutKV{}); Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); - Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); @@ -224,14 +225,14 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // Prologue // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } if (Kernel_traits::Share_Q_K_smem) { - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M @@ -241,12 +242,12 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { - flash::cp_async_wait<1>(); + FLASH_NAMESPACE::cp_async_wait<1>(); __syncthreads(); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M @@ -254,13 +255,11 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi } clear(acc_o); - flash::Softmax<2 * size<1>(acc_o)> softmax; - const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr - ? 0.0f - : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; - flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, - params.window_size_left, params.window_size_right, alibi_slope); + FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 
0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + FLASH_NAMESPACE::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. @@ -272,41 +271,39 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 - : ((Is_even_MN && Is_causal) - ? cute::ceil_div(kBlockM, kBlockN) - : cute::ceil_div(kBlockM, kBlockN) + 1); + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Advance gV if (masking_step > 0) { - flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); } cute::cp_async_fence(); - flash::gemm( + FLASH_NAMESPACE::gemm( acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K); // if (cute::thread0()) { print(acc_s); } if constexpr (Is_softcap) { - flash::apply_softcap(acc_s, params.softcap); + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } mask.template apply_mask( acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); if (n_block > n_block_min) { - flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -318,11 +315,16 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); // Convert acc_s from fp32 to fp16/bf16 - Tensor rP = flash::convert_type(acc_s); + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); + if (Return_softmax) { + cute::copy(rP, tSgS); + tSgS.data() = tSgS.data() + (-kBlockN); + } + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
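// Note on the reshape below (a reading of the existing code, not a behavior change):
// convert_layout_acc_Aregs only re-views the layout of the registers that already hold the
// converted probabilities; no data is copied. The accumulator fragment of the Q*K^T GEMM is
// reinterpreted as the A-operand fragment of the following P*V GEMM (gemm_rs), so the
// probabilities never leave registers between the two GEMMs.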
+ Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // This check is at the end of the loop since we always have at least 1 iteration if (n_masking_steps > 1 && n_block <= n_block_min) { @@ -335,22 +337,22 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi for (; n_block >= n_block_min; --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); - flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); - flash::gemm( + FLASH_NAMESPACE::gemm( acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K); if constexpr (Is_softcap) { - flash::apply_softcap(acc_s, params.softcap); + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); if (n_block > n_block_min) { - flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -361,11 +363,16 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - Tensor rP = flash::convert_type(acc_s); + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); + if (Return_softmax) { + cute::copy(rP, tSgS); + tSgS.data() = tSgS.data() + (-kBlockN); + } + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
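// Note on the Return_softmax path above (a reading of the code): because key blocks are visited
// in reverse order, each converted probability tile is first written out with cute::copy, and the
// softmax-output gmem pointer is then moved back by kBlockN columns so the next (lower) n_block
// lands in the preceding column slice of the softmax buffer.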
- Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } // Epilogue @@ -375,7 +382,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax, sink); // Convert acc_o from fp32 to fp16/bf16 - Tensor rO = flash::convert_type(acc_o); + Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Partition sO to match the accumulator partitioning auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); @@ -435,7 +442,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi } } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); } @@ -510,13 +517,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons } } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); #pragma unroll for (int m = 0; m < size<1>(tOgOaccum); ++m) { const int row = get<0>(tOcO(0, m, 0)); if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { - gLSEaccum(row) = Split ? -std::numeric_limits::infinity() : std::numeric_limits::infinity(); + gLSEaccum(row) = Split ? -kInfinity : kInfinity; } } return; @@ -528,15 +535,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // We move K and V to the last block. const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; - const int* block_table = params.block_table == nullptr - ? nullptr - : params.block_table + bidb * params.block_table_batch_stride; - const int block_table_idx = block_table == nullptr - ? 0 - : (n_block_max - 1) * kBlockN / params.page_block_size; - const int block_table_offset = block_table == nullptr - ? 0 - : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size; + const int* block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride; + const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size; + const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size; const index_t row_offset_k = block_table == nullptr ? 
binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; @@ -562,7 +563,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); - Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); @@ -598,7 +599,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - // // PREDICATES // @@ -631,6 +631,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons } // Prologue + // Copy from Knew to K, optionally apply rotary embedding. typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); @@ -640,7 +641,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. - const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 
0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2); Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), Shape, Int>{}, make_stride(params.rotary_dim / 2, _1{})); @@ -661,8 +662,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // if (cute::thread(8, 0)) { print_tensor(gCos); } // if (cute::thread(0, 0)) { print_tensor(tRgCos); } - const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; - const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + const index_t row_offset_knew = bidb * params.knew_batch_stride + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + const index_t row_offset_vnew = bidb * params.vnew_batch_stride + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. // This maps to accessing the first 64 rows of knew_ptr. @@ -680,23 +683,23 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons auto tKgK_data = tKgK.data(); auto tVgV_data = tVgV.data(); for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { - flash::copy_w_min_idx( + FLASH_NAMESPACE::copy_w_min_idx( tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); if (params.rotary_dim == 0) { - flash::copy_w_min_idx( + FLASH_NAMESPACE::copy_w_min_idx( tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); } else { if (params.is_rotary_interleaved) { // Don't clear OOB_K because we're writing to global memory - flash::copy_rotary_interleaved( + FLASH_NAMESPACE::copy_rotary_interleaved( tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); } else { // Don't clear OOB_K because we're writing to global memory - flash::copy_rotary_contiguous( + FLASH_NAMESPACE::copy_rotary_contiguous( tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); @@ -729,10 +732,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // Read Q from gmem to smem, optionally apply rotary embedding. 
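// Note on the branch below (a reading of the code): when rotary embedding is enabled together
// with Append_KV, the new query tokens are rotated at their absolute positions, i.e. the cos/sin
// tables are indexed starting at seqlen_k_cache (+ leftpad_k[bidb] when left padding is present),
// plus m_block * kBlockM for causal/local attention where each query row has its own position.
// For non-causal attention every query uses the same cos/sin row, achieved by giving gCos/gSin a
// row stride of 0.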
if (!Append_KV || params.rotary_dim == 0) { // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); } else { - const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. // We do this by setting the row stride of gCos / gSin to 0. Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), @@ -752,11 +755,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); if (params.is_rotary_interleaved) { - flash::copy_rotary_interleaved( + FLASH_NAMESPACE::copy_rotary_interleaved( tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, 0, params.d, params.rotary_dim); } else { - flash::copy_rotary_contiguous( + FLASH_NAMESPACE::copy_rotary_contiguous( tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, 0, params.d, params.rotary_dim); } @@ -764,22 +767,21 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); - // flash::cp_async_wait<0>(); + // FLASH_NAMESPACE::cp_async_wait<0>(); // __syncthreads(); // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } // __syncthreads(); clear(acc_o); - flash::Softmax<2 * size<1>(acc_o)> softmax; + FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; - const float alibi_slope = !Has_alibi ? 0.0f - : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; - flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + FLASH_NAMESPACE::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. 
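// Worked example for n_masking_steps (assumed tile sizes; same formula as in the first kernel
// above): with kBlockM = 128 and kBlockN = 64, a causal, evenly divisible problem needs
// ceil_div(128, 64) = 2 masked iterations, since the causal diagonal of one query block can cut
// through two key blocks; with ragged sequence lengths one extra masked step is added (3 total),
// and a non-causal, non-local problem only masks the single ragged last key block (1 step).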
@@ -796,7 +798,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Advance gV @@ -810,26 +812,26 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); } cute::cp_async_fence(); - flash::gemm( + FLASH_NAMESPACE::gemm( acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K); // if (cute::thread0()) { print(acc_s); } if constexpr (Is_softcap) { - flash::apply_softcap(acc_s, params.softcap); + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } mask.template apply_mask( acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } // __syncthreads(); @@ -845,7 +847,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -858,12 +860,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } // Convert acc_s from fp32 to fp16/bf16 - Tensor rP = flash::convert_type(acc_s); + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
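// Background sketch of the paged KV-cache addressing used in this loop (simplified; names match
// the code above):
//   logical column       c   = n_block * kBlockN
//   page index           idx = c / page_block_size
//   offset within page   off = c - idx * page_block_size
//   physical address     = base + block_table[idx] * batch_stride + off * row_stride
// Moving from one n_block to the next therefore adds the difference of two such addresses, which
// is exactly the (block_table[idx_next] - block_table[idx_cur]) * batch_stride
// + (off_next - off_cur) * row_stride update applied to tKgK / tVgV above.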
- Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); - flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // This check is at the end of the loop since we always have at least 1 iteration if (n_masking_steps > 1 && n_block <= n_block_min) { @@ -876,7 +878,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons for (; n_block >= n_block_min; --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Advance gV if (block_table == nullptr) { @@ -888,17 +890,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); - flash::gemm( + FLASH_NAMESPACE::gemm( acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K); if constexpr (Is_softcap) { - flash::apply_softcap(acc_s, params.softcap); + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); if (n_block > n_block_min) { // Advance gK @@ -911,7 +913,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -921,18 +923,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16); softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - Tensor rP = flash::convert_type(acc_s); + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); - flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } // Epilogue float sink = (params.head_sink_ptr != nullptr) ? 
reinterpret_cast(params.head_sink_ptr)[bidh] - : (params.smooth_softmax ? 0.0f : -std::numeric_limits::infinity()); + : (params.smooth_softmax ? 0.0f : -kInfinity); Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, sink); Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) @@ -943,7 +945,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons typename Kernel_traits::SmemCopyAtomOaccum>; auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor rO = flash::convert_type(acc_o); + Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) @@ -958,6 +960,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ? ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)) + m_block * kBlockM; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), Shape, Int>{}, make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); @@ -1003,7 +1006,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons } } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); } @@ -1017,15 +1020,7 @@ inline __device__ void compute_attn(const Params& params) { // The block index for the head. const int bidh = blockIdx.z; - // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting - // them to have the same number of threads or have to traverse the attention matrix - // in the same order. - // In the Philox RNG, we use the offset to store the batch, head, and the lane id - // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within - // the attention matrix. This way, as long as we have the batch, head, and the location of - // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1039,7 +1034,7 @@ inline __device__ void compute_attn_splitkv(const Params& params) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? 
gridDim.y : 1; - flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + FLASH_NAMESPACE::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1072,12 +1067,15 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { Shape, Int>{}, make_stride(lse_size, _1{})); + // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile. + // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}. Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape>{}, Stride<_1>{}); + // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}. Layout flat_layout = make_layout(lse_size); Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b)); - auto transposed_stride = make_stride(1, params.seqlen_q * params.b, params.seqlen_q); + auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q); Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride); Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout)); @@ -1085,13 +1083,13 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; - // Read the LSE values from gmem and store them in shared memory, then tranpose them. + // Read the LSE values from gmem and store them in shared memory, then transpose them. constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; #pragma unroll for (int l = 0; l < kNLsePerThread; ++l) { const int row = l * kRowsPerLoadLSE + tidx / kBlockM; const int col = tidx % kBlockM; - ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -std::numeric_limits::infinity(); + ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -kInfinity; if (row < kMaxSplits) { sLSE[row][col] = lse; } @@ -1112,7 +1110,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { for (int l = 0; l < kNLsePerThread; ++l) { const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; const int col = tidx / kRowsPerLoadTranspose; - lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -std::numeric_limits::infinity(); + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -kInfinity; // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } } @@ -1124,7 +1122,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { } MaxOp max_op; lse_max = Allreduce::run(lse_max, max_op); - lse_max = lse_max == -std::numeric_limits::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf + lse_max = lse_max == -kInfinity ? 
0.0f : lse_max; // In case all local LSEs are -inf float lse_sum = expf(lse_accum(0) - lse_max); #pragma unroll for (int l = 1; l < kNLsePerThread; ++l) { @@ -1132,9 +1130,9 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { } SumOp sum_op; lse_sum = Allreduce::run(lse_sum, sum_op); - // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise - // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. - ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? std::numeric_limits::infinity() : logf(lse_sum) + lse_max; + // For the case where all local lse == -kInfinity, we want to set lse_logsum to kInfinity. Otherwise + // lse_logsum is log(0.0) = -kInfinity and we get NaN when we do lse_accum(l) - lse_logsum. + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? kInfinity : logf(lse_sum) + lse_max; // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { if (params.unpadded_lse) { @@ -1163,7 +1161,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { Stride, _1>{}); constexpr int kBlockN = kNThreads / kBlockM; using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; - using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; @@ -1186,7 +1184,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { } // Load Oaccum in then scale and accumulate to O for (int split = 0; split < params.num_splits; ++split) { - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM); #pragma unroll for (int m = 0; m < size<1>(tOrOaccum); ++m) { @@ -1205,7 +1203,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { } // if (cute::thread0()) { print_tensor(tOrO); } - Tensor rO = flash::convert_type(tOrO); + Tensor rO = FLASH_NAMESPACE::convert_type(tOrO); // Write to gO #pragma unroll for (int m = 0; m < size<1>(rO); ++m) { @@ -1232,11 +1230,4 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { } } -} // namespace flash -} // namespace onnxruntime - -#if defined(__GNUC__) -#pragma GCC diagnostic pop -#elif defined(_MSC_VER) -#pragma warning(pop) -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h index a6a0133ae60c6..c9f53becb05b5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -1,14 +1,16 @@ /****************************************************************************** - * Copyright (c) 2023, Tri Dao. + * Copyright (c) 2024, Tri Dao. 
******************************************************************************/ + #pragma once +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/static_switch.h" #include "contrib_ops/cuda/bert/flash_attention/flash.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { // Determine if the architecture supports FLASH and define a macro to handle parameter modifiers #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 @@ -23,292 +25,210 @@ namespace flash { // Use a macro to clean up kernel definitions #define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \ - template \ - __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) +template \ +__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { -#if defined(ARCH_SUPPORTS_FLASH) - static_assert(!(Is_causal && Is_local)); // Enforce constraints - flash::compute_attn(params); -#else - FLASH_UNSUPPORTED_ARCH -#endif + #if defined(ARCH_SUPPORTS_FLASH) + static_assert(!(Is_causal && Is_local)); // Enforce constraints + FLASH_NAMESPACE::compute_attn(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif } DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) { -#if defined(ARCH_SUPPORTS_FLASH) - flash::compute_attn_splitkv(params); -#else - FLASH_UNSUPPORTED_ARCH -#endif + #if defined(ARCH_SUPPORTS_FLASH) + FLASH_NAMESPACE::compute_attn_splitkv(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif } DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) { - static_assert(Log_max_splits >= 1); - flash::combine_attn_seqk_parallel(params); + static_assert(Log_max_splits >= 1); + FLASH_NAMESPACE::combine_attn_seqk_parallel(params); } -template -void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { - constexpr size_t smem_size = Kernel_traits::kSmemSize; +template +void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; - // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. - // https://github.com/kokkos/kokkos-kernels/issues/349 - // https://github.com/HazyResearch/flash-attention/issues/21 + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. 
+ // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 - const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; - dim3 grid(num_m_block, params.b, params.h); - const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; - const bool is_even_K = params.d == Kernel_traits::kHeadDim; - BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { - EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { - LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { - ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { - SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { - // Will only return softmax if dropout, to reduce compilation time. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, false > ; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<(smem_size), stream>>>(params); - }); + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.b, params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + const bool return_softmax = params.p_ptr != nullptr; + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
+ // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + // printf(\"IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d\\n\", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst)); + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + CUDA_CALL_THROW(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); + }); + }); + }); }); - }); }); - }); } -template -void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { - static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); - static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); - constexpr size_t smem_size = Kernel_traits::kSmemSize; - const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; - dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h); - const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; - const bool is_even_K = params.d == Kernel_traits::kHeadDim; - BOOL_SWITCH(params.is_causal, Is_causal, [&] { +template +void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); + constexpr size_t smem_size = Kernel_traits::kSmemSize; + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { - EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { - LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_Local_Const, [&] { - BOOL_SWITCH(params.num_splits > 1, SplitConst, [&] { - BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV_Const, [&] { - ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { - SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { - // If Append_KV_Const, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
- // If Is_Local_Const, set Is_causal to false - auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal, Is_Local_Const && !Is_causal, Has_alibi, - IsEvenMNConst && !Append_KV_Const && IsEvenKConst && !Is_Local_Const && Kernel_traits::kHeadDim <= 128, - IsEvenKConst, Is_softcap, SplitConst, Append_KV_Const >; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(smem_size)); - } - kernel<<(smem_size), stream>>>(params); + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + CUDA_CALL_THROW(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + kernel<<>>(params); + }); + }); + }); }); - }); }); - }); }); - }); }); - }); - if (params.num_splits > 1) { - // We want kBlockM to be as small as possible for more parallelism. - // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. - // If headdim is divisible by 64, then we set kBlockM = 8, etc. - constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); - dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); - EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { - if (params.num_splits <= 2) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 4) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 8) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 16) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 32) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 64) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 128) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } - }); - } + if (params.num_splits > 1) { + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 
8 : 16); + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + if (params.num_splits <= 2) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 4) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 8) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 16) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 32) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 64) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 128) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } + }); + } } -template -void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) { - constexpr static int kBlockM = 64; // Fixed for all head dimensions - // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, - // and for headdim 192 with block size 64 x 128. - // Also for headdim 160 with block size 64 x 128 after the rotary addition. - constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); - run_flash_splitkv_fwd>(params, stream); +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int kBlockM = 64; // Fixed for all head dimensions + // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, + // and for headdim 192 with block size 64 x 128. + constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_flash_splitkv_fwd, Is_causal>(params, stream); } -template -void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { - constexpr static int Headdim = 32; - BOOL_SWITCH(params.is_causal, Is_causal, [&] { +template +void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 32; run_flash_fwd, Is_causal>(params, stream); - }); } -template -void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) { - constexpr static int Headdim = 64; - BOOL_SWITCH(params.is_causal, Is_causal, [&] { +template +void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 64; // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower // Using block size (64 x 256) is 27% slower for seqlen=2k // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - }); } -template -void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) { - constexpr static int Headdim = 96; - bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; - BOOL_SWITCH(params.is_causal, Is_causal, [&] { +template +void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 96; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), if (is_sm8x) { - if constexpr (!Is_causal) { - run_flash_fwd, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_causal>(params, stream); - } + if constexpr(!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, 
Is_causal>(params, stream); } - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // These two are always slower - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); } -template -void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { - constexpr static int Headdim = 128; - bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; - BOOL_SWITCH(params.is_causal, Is_causal, [&] { +template +void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 128; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. if (is_sm8x) { - if constexpr (!Is_causal) { - run_flash_fwd, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_causal>(params, stream); - } + if constexpr(!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } } else { - run_flash_fwd, Is_causal>(params, stream); - } - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // 1st ones are good for H100, A100 - // 2nd one is good for A6000 bc we get slightly better occupancy - }); -} - -template -void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) { - constexpr static int Headdim = 160; - bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // For A100, H100, 128 x 32 is the fastest. - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - // and 128 x 64 with 8 warps is the fastest for non-causal. 
- if (is_sm8x) { - if constexpr (!Is_causal) { - run_flash_fwd, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_causal>(params, stream); - } - } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_causal>(params, stream); } - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); } -template -void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) { - constexpr static int Headdim = 192; - BOOL_SWITCH(params.is_causal, Is_causal, [&] { +template +void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 192; run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); -} - -template -void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { - constexpr static int Headdim = 224; - size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; - // printf("max_smem_per_block = %d\n", max_smem_per_block); - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB - run_flash_fwd, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_causal>(params, stream); - } - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. - // If we have N = 32, there are only 1024 elements to load at once, where each load - // is 8 elements. This means we can only use 128 threads and not 256 threads. - // run_flash_fwd, Is_causal>(params, stream); - }); } -template -void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) { - constexpr static int Headdim = 256; - size_t max_smem_per_sm = params.dprops->sharedMemPerMultiprocessor; - size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; - // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); - BOOL_SWITCH(params.is_causal, Is_causal, [&] { +template +void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 256; + size_t max_smem_per_sm = params.dprops->sharedMemPerMultiprocessor; + size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; // For A100, we want to run with 128 x 64 (128KB smem). // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. 
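// Arithmetic behind the branch below (assuming 2-byte fp16/bf16 elements): the 128 x 64 tile
// needs 2 * 256 * (128 + 2 * 64) = 131072 B = 128 KB of shared memory per block (Q tile plus K
// and V tiles), while two concurrent 64 x 64 CTAs would need 4 * 256 * (64 + 2 * 64) = 196608 B
// = 192 KB per SM. So the 128 x 64 config is chosen only when a block may use 128 KB and the SM
// cannot fit two of the smaller CTAs anyway; otherwise the 64 x 64 config wins on occupancy.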
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_causal>(params, stream); } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_causal>(params, stream); } - // 64 KB - // run_flash_fwd, Is_causal>(params, stream); - // 96 KB - // run_flash_fwd, Is_causal>(params, stream); - }); } - -} // namespace flash -} // namespace onnxruntime +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_bf16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_bf16_causal_sm80.cu new file mode 100644 index 0000000000000..043e39fbf865a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_bf16_causal_sm80.cu @@ -0,0 +1,12 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu index 3ca416f6580c4..7bc28e28b9149 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu @@ -1,15 +1,12 @@ -// Copyright (c) 2023, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. - -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_causal_sm80.cu new file mode 100644 index 0000000000000..64d0d21530bc7 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_causal_sm80.cu @@ -0,0 +1,12 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu index 68ae2ea759813..534f936a1e6ec 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu @@ -1,15 +1,12 @@ -// Copyright (c) 2023, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. - -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_bf16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_bf16_causal_sm80.cu new file mode 100644 index 0000000000000..34b14f4cfb1bb --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_bf16_causal_sm80.cu @@ -0,0 +1,12 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu index 79606fd05b4d8..b0194f3359005 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu @@ -1,15 +1,12 @@ -// Copyright (c) 2023, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. - -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_causal_sm80.cu new file mode 100644 index 0000000000000..98cf57ee7676b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_causal_sm80.cu @@ -0,0 +1,12 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu index ec9e9e738c5b3..16d08ddd12a04 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu @@ -1,15 +1,12 @@ -// Copyright (c) 2023, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. - -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_bf16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_bf16_causal_sm80.cu new file mode 100644 index 0000000000000..22d2642e7ac36 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_bf16_causal_sm80.cu @@ -0,0 +1,12 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu index 8eb5c8f84544b..fe599d6604828 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu @@ -1,15 +1,12 @@ -// Copyright (c) 2023, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. - -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_causal_sm80.cu new file mode 100644 index 0000000000000..33ac6cefb771e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_causal_sm80.cu @@ -0,0 +1,12 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu index 552966852cdbe..440cf4d88a660 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu @@ -1,15 +1,12 @@ -// Copyright (c) 2023, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. - -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_bf16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_bf16_causal_sm80.cu new file mode 100644 index 0000000000000..8ad3fd93b0839 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_bf16_causal_sm80.cu @@ -0,0 +1,12 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu index 0141f27aa199f..9cd950c63885d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu @@ -1,15 +1,12 @@ -// Copyright (c) 2023, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. - -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_causal_sm80.cu new file mode 100644 index 0000000000000..585dab607bf06 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_causal_sm80.cu @@ -0,0 +1,12 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu index e9f191a4828d6..a54e82f2c8d08 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu @@ -1,15 +1,12 @@ -// Copyright (c) 2023, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. - -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_bf16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_bf16_causal_sm80.cu new file mode 100644 index 0000000000000..afdd82b8a658e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_bf16_causal_sm80.cu @@ -0,0 +1,12 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu index 489d2d47bc709..754e8ba273bed 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu @@ -1,15 +1,12 @@ -// Copyright (c) 2023, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. - -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_causal_sm80.cu new file mode 100644 index 0000000000000..31cef295cfd79 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_causal_sm80.cu @@ -0,0 +1,12 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu index d628a556680ad..2039998ad72b3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu @@ -1,15 +1,12 @@ -// Copyright (c) 2023, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. - -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_bf16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_bf16_causal_sm80.cu new file mode 100644 index 0000000000000..8637dd4072472 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_bf16_causal_sm80.cu @@ -0,0 +1,12 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu index bcfd47e76b99e..9ddb6eb467918 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu @@ -1,15 +1,12 @@ -// Copyright (c) 2023, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. - -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_causal_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_causal_sm80.cu new file mode 100644 index 0000000000000..8e05992607c07 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_causal_sm80.cu @@ -0,0 +1,12 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu index 88b6cc0fb1e22..47129e348913b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu @@ -1,15 +1,12 @@ -// Copyright (c) 2023, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. - -#if USE_FLASH_ATTENTION +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); -} // namespace flash -} // namespace onnxruntime -#endif +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h index cb08dbc853a91..ddbca756d0fd1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h @@ -1,18 +1,19 @@ /****************************************************************************** - * Copyright (c) 2023, Tri Dao. + * Copyright (c) 2024, Tri Dao. ******************************************************************************/ -#pragma once -#include +#pragma once -#include -#include +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" #include +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" + using namespace cute; -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { template struct Flash_kernel_traits { @@ -76,6 +77,7 @@ struct Flash_fwd_kernel_traits : public Base { typename Base::MMA_Atom_Arch, Layout, _1, _1>>, // 4x1x1 or 8x1x1 thread group Tile, _16, _16>>; + using SmemLayoutAtomQ = decltype(composition(Swizzle{}, // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 Layout>, @@ -98,8 +100,8 @@ struct Flash_fwd_kernel_traits : public Base { using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; - using SmemCopyAtomOaccum = Copy_Atom; + using SmemCopyAtomO = Copy_Atom, Element>; + using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); @@ -122,11 +124,11 @@ struct Flash_fwd_kernel_traits : public Base { using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, - DefaultCopy>; + AutoVectorizingCopyWithAssumedAlignment<128>>; using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read - using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom{}, + using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store @@ -136,14 +138,14 @@ struct Flash_fwd_kernel_traits : public Base { Stride<_8, _1>>, Layout, // Thread layout, 16 threads per row Stride<_16, _1>>>; - using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store using GmemLayoutAtomRotcossin = GmemLayoutAtom; using GmemTiledCopyRotcossin = decltype(make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 4 vals per load - using GmemTiledCopyRotcossinCont = decltype(make_tiled_copy(Copy_Atom{}, + using GmemTiledCopyRotcossinCont = decltype(make_tiled_copy(Copy_Atom, 
Element>{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 8 vals per load }; @@ -235,7 +237,7 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutPdStransposed = decltype(composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); - using SmemCopyAtomPdS = Copy_Atom; + using SmemCopyAtomPdS = Copy_Atom, elem_type>; using SmemLayoutQdOtransposed = decltype(composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); @@ -246,7 +248,7 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutdKV = decltype(tile_to_shape( SmemLayoutAtomdKV{}, make_shape(Int{}, Int{}))); - using SmemCopyAtomdKV = Copy_Atom; + using SmemCopyAtomdKV = Copy_Atom, elem_type>; using SmemLayoutAtomdQ = decltype(composition(Swizzle{}, Layout>, @@ -254,7 +256,7 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutdQ = decltype(tile_to_shape( SmemLayoutAtomdQ{}, make_shape(Int{}, Int{}))); - using SmemCopyAtomdQ = Copy_Atom; + using SmemCopyAtomdQ = Copy_Atom, elem_type>; // Double buffer for sQ static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element); @@ -283,17 +285,17 @@ struct Flash_bwd_kernel_traits : public Base { using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, - DefaultCopy>; + AutoVectorizingCopyWithAssumedAlignment<128>>; using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read - using GmemTiledCopydO = decltype(make_tiled_copy(Copy_Atom{}, + using GmemTiledCopydO = decltype(make_tiled_copy(Copy_Atom, elem_type>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store - using GmemTiledCopydKV = decltype(make_tiled_copy(Copy_Atom{}, + using GmemTiledCopydKV = decltype(make_tiled_copy(Copy_Atom, elem_type>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store - using GmemTiledCopydQ = decltype(make_tiled_copy(Copy_Atom{}, + using GmemTiledCopydQ = decltype(make_tiled_copy(Copy_Atom, elem_type>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomdQaccum = std::conditional_t< @@ -302,15 +304,14 @@ struct Flash_bwd_kernel_traits : public Base { Stride<_8, _1>>, Layout, // Thread layout, 16 threads per row Stride<_16, _1>>>; - using GmemTiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom{}, + using GmemTiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, GmemLayoutAtomdQaccum{}, Layout>{})); // Val layout, 4 vals per store - using GmemTiledCopydQaccumAtomicAdd = decltype(make_tiled_copy(Copy_Atom{}, + using GmemTiledCopydQaccumAtomicAdd = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, Layout, // Thread layout, 8 threads per row Stride<_32, _1>>{}, Layout>{})); // Val layout, 1 val per store }; -} // namespace flash -} // namespace onnxruntime +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/mask.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/mask.h index 71434002f8df1..ec84fc7d07493 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/mask.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/mask.h @@ -3,19 +3,18 @@ ******************************************************************************/ #pragma once +#include 
"contrib_ops/cuda/bert/flash_attention/namespace_config.h" -#include #include -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { using namespace cute; template __forceinline__ __device__ void apply_mask(Tensor& tensor, const int max_seqlen_k, const int col_idx_offset_ = 0) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); const int lane_id = threadIdx.x % 32; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; @@ -29,7 +28,7 @@ __forceinline__ __device__ void apply_mask(Tensor& tensor, const // Without the "make_coord" we get wrong results #pragma unroll for (int mi = 0; mi < size<0>(tensor); ++mi) { - tensor(mi, make_coord(j, nj)) = -std::numeric_limits::infinity(); + tensor(mi, make_coord(j, nj)) = -kInfinity; } } } @@ -41,7 +40,7 @@ __forceinline__ __device__ void apply_mask_local(Tensor& tensor, const int max_seqlen_k, const int row_idx_offset, const int max_seqlen_q, const int warp_row_stride, const int window_size_left, const int window_size_right) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); const int lane_id = threadIdx.x % 32; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; @@ -60,7 +59,7 @@ __forceinline__ __device__ void apply_mask_local(Tensor& tensor, for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); + tensor(make_coord(i, mi), make_coord(j, nj)) = -kInfinity; } } } @@ -86,7 +85,7 @@ template & tensor, Tensor const& idx_rowcol, const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 2, "Only support 2D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); @@ -97,7 +96,7 @@ __forceinline__ __device__ void apply_mask_causal_w_idx( #pragma unroll for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { - tensor(mi, ni) = -std::numeric_limits::infinity(); + tensor(mi, ni) = -kInfinity; } } // if (cute::thread0()) { @@ -117,7 +116,8 @@ struct Mask { __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q, const int window_size_left, const int window_size_right, const float alibi_slope = 0.f) - : max_seqlen_k(max_seqlen_k), max_seqlen_q(max_seqlen_q), window_size_left(window_size_left), window_size_right(window_size_right), alibi_slope(!Has_alibi ? 0.0 : alibi_slope) {}; + : max_seqlen_k(max_seqlen_k), max_seqlen_q(max_seqlen_q), window_size_left(window_size_left), window_size_right(window_size_right), alibi_slope(!Has_alibi ? 
0.0 : alibi_slope) { + }; // Causal_mask: whether this particular iteration needs causal masking template @@ -132,7 +132,7 @@ struct Mask { // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); } if constexpr (Need_masking) { // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout())); + Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout())); // Do we need both row and column indices, or just column incides? static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; const int lane_id = threadIdx.x % 32; @@ -152,7 +152,7 @@ struct Mask { } if constexpr (!Is_even_MN) { if (col_idx >= max_seqlen_k) { - tensor(mi, make_coord(j, nj)) = -std::numeric_limits::infinity(); + tensor(mi, make_coord(j, nj)) = -kInfinity; } } } @@ -182,18 +182,18 @@ struct Mask { } if constexpr (Causal_mask) { if (col_idx >= col_idx_limit_right) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); + tensor(make_coord(i, mi), make_coord(j, nj)) = -kInfinity; } } if constexpr (Is_local) { if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); + tensor(make_coord(i, mi), make_coord(j, nj)) = -kInfinity; } } if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { // Causal and Local already handles MN masking if (col_idx >= max_seqlen_k) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); + tensor(make_coord(i, mi), make_coord(j, nj)) = -kInfinity; } } } @@ -205,5 +205,4 @@ struct Mask { }; }; -} // namespace flash -} // namespace onnxruntime +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/namespace_config.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/namespace_config.h new file mode 100644 index 0000000000000..dffb42d188530 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/namespace_config.h @@ -0,0 +1,24 @@ +/** + * @file flash_namespace_config.h + * @brief Configuration file for Flash namespace management and isolation + */ +#pragma once + +#ifndef FLASH_NAMESPACE_CONFIG_H +#define FLASH_NAMESPACE_CONFIG_H + +// Set default namespace to onnxruntime::flash +#ifndef FLASH_NAMESPACE +#define FLASH_NAMESPACE onnxruntime::flash +#endif + +#define FLASH_NAMESPACE_ALIAS(name) FLASH_NAMESPACE::name + +#define FLASH_NAMESPACE_SCOPE(content) \ + namespace onnxruntime { \ + namespace flash { \ + content \ + } \ + } + +#endif // FLASH_NAMESPACE_CONFIG_H diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/rotary.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/rotary.h index dfc14ab4b4406..cc528ebe9d601 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/rotary.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/rotary.h @@ -6,12 +6,12 @@ #include +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/utils.h" //////////////////////////////////////////////////////////////////////////////////////////////////// -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { using namespace cute; @@ -150,5 +150,4 @@ __forceinline__ __device__ void 
copy_rotary_contiguous(Tensor //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace flash -} // namespace onnxruntime +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h index c7a8476f5beae..d7ea7e6ab09fc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h @@ -1,6 +1,7 @@ /****************************************************************************** - * Copyright (c) 2023, Tri Dao. + * Copyright (c) 2024, Tri Dao. ******************************************************************************/ + #pragma once #include @@ -10,10 +11,10 @@ #include +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" #include "contrib_ops/cuda/bert/flash_attention/utils.h" -namespace onnxruntime { -namespace flash { +namespace FLASH_NAMESPACE { using namespace cute; @@ -76,10 +77,17 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor& tenso const float max_scaled = max(mi) == -kInfinity ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); #pragma unroll for (int ni = 0; ni < size<1>(tensor); ++ni) { - // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - // max * log_2(e)) This allows the compiler to use the ffma - // instruction instead of fadd and fmul separately. +// Instead of computing exp(x - max), we compute exp2(x * log_2(e) - +// max * log_2(e)) This allows the compiler to use the ffma +// instruction instead of fadd and fmul separately. +// The following macro will disable the use of fma. +// See: https://github.com/pytorch/pytorch/issues/121558 for more details +// This macro is set in PyTorch and not FlashAttention +#ifdef UNFUSE_FMA + tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); +#else tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); +#endif } } } @@ -96,21 +104,21 @@ struct Softmax { template __forceinline__ __device__ void softmax_rescale_o(Tensor0& acc_s, Tensor1& acc_o, float softmax_scale_log2) { // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); static_assert(decltype(size<0>(scores))::value == kNRows); if (Is_first) { - flash::template reduce_max(scores, row_max); - flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); - flash::reduce_sum(scores, row_sum); + FLASH_NAMESPACE::template reduce_max(scores, row_max); + FLASH_NAMESPACE::scale_apply_exp2(scores, row_max, softmax_scale_log2); + FLASH_NAMESPACE::reduce_sum(scores, row_sum); } else { Tensor scores_max_prev = make_fragment_like(row_max); cute::copy(row_max, scores_max_prev); - flash::template reduce_max(scores, row_max); + FLASH_NAMESPACE::template reduce_max(scores, row_max); // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + Tensor acc_o_rowcol = make_tensor(acc_o.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_o.layout())); static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); #pragma unroll - for (int mi = 0; mi < size<0>(row_max); ++mi) { + for (int mi = 0; mi < size(row_max); ++mi) { float scores_max_cur 
= !Check_inf ? row_max(mi) : (row_max(mi) == -kInfinity ? 0.0f : row_max(mi)); @@ -121,11 +129,10 @@ struct Softmax { acc_o_rowcol(mi, ni) *= scores_scale; } } - - flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + FLASH_NAMESPACE::scale_apply_exp2(scores, row_max, softmax_scale_log2); // We don't do the reduce across threads here since we don't need to use the row_sum. // We do that reduce at the end when we need to normalize the softmax. - flash::reduce_sum(scores, row_sum); + FLASH_NAMESPACE::reduce_sum(scores, row_sum); } }; @@ -133,11 +140,10 @@ struct Softmax { __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale, float sink) { // IMPORTANT: sink is a pre-scaled logit - SumOp sum_op; quad_allreduce_(row_sum, row_sum, sum_op); TensorT lse = make_fragment_like(row_sum); - Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + Tensor acc_o_rowcol = make_tensor(acc_o.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_o.layout())); static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); const bool use_sink = (sink != -kInfinity); @@ -177,7 +183,7 @@ struct Softmax { ? (Split ? -kInfinity : kInfinity) : max_unscaled * softmax_scale + __logf(sum); - float inv_sum = (sum == 0.f || !isfinite(sum)) ? 1.f : 1.f / sum; + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; #pragma unroll for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= inv_sum; @@ -185,8 +191,7 @@ struct Softmax { } return lse; - } + }; }; -} // namespace flash -} // namespace onnxruntime +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h index 239365de8090b..eb1d6501a80f6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h @@ -1,5 +1,6 @@ // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + #pragma once /// @param COND - a boolean expression to switch by @@ -12,6 +13,7 @@ /// some_function(...); /// }); /// ``` + #define BOOL_SWITCH(COND, CONST_NAME, ...) \ [&] { \ if (COND) { \ diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/update_kernels.py b/onnxruntime/contrib_ops/cuda/bert/flash_attention/update_kernels.py new file mode 100644 index 0000000000000..6c556cfc15a75 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/update_kernels.py @@ -0,0 +1,77 @@ +import os + +base_dir = "/home/tlwu/onnxruntime/onnxruntime/contrib_ops/cuda/bert/flash_attention" +dims = [32, 64, 96, 128, 192, 256] +types = ["fp16", "bf16"] + +copyright_header = """/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ +""" + +template_standard = ( + copyright_header + + """ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE {{ + +template<> +void run_mha_fwd_<{cpp_type}, {dim}, {is_causal}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ + run_mha_fwd_hdim{dim}<{cpp_type}, {is_causal}>(params, stream); +}} + +}} // namespace FLASH_NAMESPACE +""" +) + +template_split = ( + copyright_header + + """ +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE {{ + +template void run_mha_fwd_splitkv_dispatch<{cpp_type}, {dim}, {is_causal}>(Flash_fwd_params ¶ms, cudaStream_t stream); + +}} // namespace FLASH_NAMESPACE +""" +) + +for dim in dims: + for t in types: + cpp_type = "cutlass::half_t" if t == "fp16" else "cutlass::bfloat16_t" + + # Standard - Non-causal + filename = f"flash_fwd_hdim{dim}_{t}_sm80.cu" + filepath = os.path.join(base_dir, filename) + content = template_standard.format(cpp_type=cpp_type, dim=dim, is_causal="false") + with open(filepath, "w") as f: + f.write(content) + print(f"Updated {filename}") + + # Standard - Causal + filename_causal = f"flash_fwd_hdim{dim}_{t}_causal_sm80.cu" + filepath_causal = os.path.join(base_dir, filename_causal) + content_causal = template_standard.format(cpp_type=cpp_type, dim=dim, is_causal="true") + with open(filepath_causal, "w") as f: + f.write(content_causal) + print(f"Updated {filename_causal}") + + # Split - Non-causal + filename_split = f"flash_fwd_split_hdim{dim}_{t}_sm80.cu" + filepath_split = os.path.join(base_dir, filename_split) + content_split = template_split.format(cpp_type=cpp_type, dim=dim, is_causal="false") + with open(filepath_split, "w") as f: + f.write(content_split) + print(f"Updated {filename_split}") + + # Split - Causal + filename_split_causal = f"flash_fwd_split_hdim{dim}_{t}_causal_sm80.cu" + filepath_split_causal = os.path.join(base_dir, filename_split_causal) + content_split_causal = template_split.format(cpp_type=cpp_type, dim=dim, is_causal="true") + with open(filepath_split_causal, "w") as f: + f.write(content_split_causal) + print(f"Updated {filename_split_causal}") diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h index 76b1aaefebeff..872466b367b29 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h @@ -1,6 +1,7 @@ /****************************************************************************** * Copyright (c) 2023, Tri Dao. 
******************************************************************************/ + #pragma once #include @@ -20,9 +21,11 @@ #include #include +#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h" + //////////////////////////////////////////////////////////////////////////////////////////////////// -namespace onnxruntime { -namespace flash { + +namespace FLASH_NAMESPACE { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -34,18 +37,14 @@ __forceinline__ __device__ uint32_t relu2(const uint32_t x) { uint32_t res; const uint32_t zero = 0u; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("max.f16x2 %0, %1, %2;\n" - : "=r"(res) - : "r"(x), "r"(zero)); + asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); #else asm volatile( "{\n" "\t .reg .f16x2 sela;\n" "\t set.gtu.u32.f16x2 sela, %1, %2;\n" "\t and.b32 %0, sela, %1;\n" - "}\n" - : "=r"(res) - : "r"(x), "r"(zero)); + "}\n" : "=r"(res) : "r"(x), "r"(zero)); #endif return res; } @@ -55,9 +54,7 @@ template <> __forceinline__ __device__ uint32_t relu2(const uint32_t x) { uint32_t res; const uint32_t zero = 0u; - asm volatile("max.bf16x2 %0, %1, %2;\n" - : "=r"(res) - : "r"(x), "r"(zero)); + asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); return res; } #endif @@ -74,9 +71,7 @@ __forceinline__ __device__ uint32_t convert_relu2(const float2 uint32_t res; const uint32_t a = reinterpret_cast(x.x); const uint32_t b = reinterpret_cast(x.y); - asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" - : "=r"(res) - : "r"(b), "r"(a)); + asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); return res; } @@ -85,9 +80,7 @@ __forceinline__ __device__ uint32_t convert_relu2(const flo uint32_t res; const uint32_t a = reinterpret_cast(x.x); const uint32_t b = reinterpret_cast(x.y); - asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" - : "=r"(res) - : "r"(b), "r"(a)); + asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); return res; } @@ -285,8 +278,8 @@ __forceinline__ __device__ auto convert_type_relu(Tensor const& } Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); #else - Tensor out = flash::convert_type(tensor); - flash::relu_(out); + Tensor out = FLASH_NAMESPACE::convert_type(tensor); + FLASH_NAMESPACE::relu_(out); #endif return out; } @@ -342,10 +335,10 @@ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor -inline __device__ void copy_w_min_idx(Tensor const& S, - Tensor& D, Tensor const& identity_MN, - Tensor const& predicate_K, - const int max_MN = 0, const int min_MN = 0) { +__forceinline__ __device__ void copy_w_min_idx(Tensor const& S, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, + const int max_MN = 0, const int min_MN = 0) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA @@ -379,5 +372,4 @@ __forceinline__ __device__ void apply_softcap(Tensor& tensor, co //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace flash -} // namespace onnxruntime +} // namespace FLASH_NAMESPACE diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index cfdaf2aa74837..83cd531d422b2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ 
b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -475,7 +475,7 @@ Status FlashAttention( ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( device_prop, stream, query, present_key, present_value, key, value, data.output, - reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, head_sink, /*block_table*/ nullptr, + reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, nullptr, nullptr, head_sink, /*block_table*/ nullptr, batch_size, num_heads, kv_num_heads, head_size, sequence_length, parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim, scale, parameters.softcap, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits, From 4f0942fa2d01ce1fe952e35d9a2cbb12d2b41493 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 8 Jan 2026 01:48:44 -0800 Subject: [PATCH 2/3] update base_dir --- .../contrib_ops/cuda/bert/flash_attention/update_kernels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/update_kernels.py b/onnxruntime/contrib_ops/cuda/bert/flash_attention/update_kernels.py index 6c556cfc15a75..80a1f2956f303 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/update_kernels.py +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/update_kernels.py @@ -1,6 +1,6 @@ import os -base_dir = "/home/tlwu/onnxruntime/onnxruntime/contrib_ops/cuda/bert/flash_attention" +base_dir = os.path.dirname(os.path.realpath(__file__)) dims = [32, 64, 96, 128, 192, 256] types = ["fp16", "bf16"] From ee4205f35897f392993809c606649d29fca1e2b8 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 8 Jan 2026 15:59:38 -0800 Subject: [PATCH 3/3] cleanup of redundant default assignments --- .../cuda/bert/flash_attention/flash_api.cc | 36 ------------------- 1 file changed, 36 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index e11137413a8cf..da140c270086f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -149,15 +149,9 @@ void set_params_fprop(Flash_fwd_params& params, params.is_seqlens_k_cumulative = true; params.unpadded_lse = unpadded_lse; - params.seqlenq_ngroups_swapped = false; params.leftpad_k = static_cast(leftpad_k); params.cache_batch_idx = static_cast(cache_batch_idx); - params.rotary_cos_ptr = nullptr; - params.rotary_sin_ptr = nullptr; - params.is_rotary_interleaved = false; - params.alibi_slopes_ptr = nullptr; - params.alibi_slopes_batch_stride = 0; } size_t get_softmax_lse_size(size_t seqlen, size_t batch_size, size_t num_heads) { @@ -328,26 +322,13 @@ Status mha_fwd(const cudaDeviceProp& dprops, cache_batch_idx, leftpad_k); params.dprops = &dprops; - params.knew_ptr = nullptr; - params.vnew_ptr = nullptr; - params.knew_batch_stride = 0; - params.vnew_batch_stride = 0; - params.knew_row_stride = 0; - params.vnew_row_stride = 0; - params.knew_head_stride = 0; - params.vnew_head_stride = 0; params.num_splits = num_splits; if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { params.softmax_lseaccum_ptr = softmax_lse_accum; params.oaccum_ptr = out_accum; - } else { - params.softmax_lseaccum_ptr = nullptr; - params.oaccum_ptr = nullptr; } - params.alibi_slopes_ptr = nullptr; - run_mha_fwd(params, stream); return Status::OK(); } @@ -412,12 +393,6 @@ Status mha_varlen_fwd(const cudaDeviceProp& 
dprops, params.total_q = total_q; params.dprops = &dprops; - params.num_splits = 0; - params.softmax_lseaccum_ptr = nullptr; - params.oaccum_ptr = nullptr; - params.knew_ptr = nullptr; - params.vnew_ptr = nullptr; - params.alibi_slopes_ptr = nullptr; if (paged_KV) { params.block_table = block_table; params.block_table_batch_stride = max_num_blocks_per_seq; @@ -533,16 +508,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, } params.knew_head_stride = head_size; params.vnew_head_stride = head_size; - } else { - params.seqlen_knew = 0; - params.knew_ptr = nullptr; - params.vnew_ptr = nullptr; - params.knew_batch_stride = 0; - params.vnew_batch_stride = 0; - params.knew_row_stride = 0; - params.vnew_row_stride = 0; - params.knew_head_stride = 0; - params.vnew_head_stride = 0; } params.is_seqlens_k_cumulative = seqlens_k_ == nullptr; @@ -566,7 +531,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, params.oaccum_ptr = nullptr; } - params.alibi_slopes_ptr = nullptr; if (paged_KV) { params.block_table = block_table; params.block_table_batch_stride = max_num_blocks_per_seq;
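For reference, rendering the template_split string in update_kernels.py with cpp_type=cutlass::half_t, dim=128 and is_causal=true yields, in essence, the body of the flash_fwd_split_hdim128_fp16_causal_sm80.cu file added above (copyright banner omitted, one explanatory comment added here):

#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

// Explicit instantiation of the split-KV forward dispatch for fp16, head dim 128, causal masking.
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, true>(Flash_fwd_params& params, cudaStream_t stream);

}  // namespace FLASH_NAMESPACE

FLASH_NAMESPACE expands to onnxruntime::flash by default (see namespace_config.h), and FLASH_NAMESPACE_ALIAS(run_mha_fwd_splitkv_dispatch) qualifies the same symbol from code outside the namespace. Overriding FLASH_NAMESPACE at compile time relocates every flash attention symbol, which is the isolation the header is written to provide.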
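The comment block added in softmax.h above describes the exp2 rewrite that scale_apply_exp2 relies on. A minimal standalone sketch of the identity follows; the names and values are illustrative only, not kernel code:

#include <cmath>
#include <cstdio>

int main() {
  // exp((x - m) * s) == exp2(x * s2 - m * s2) with s2 = s * log2(e).
  // Folding log2(e) into the scale lets the compiler feed exp2f with a single
  // fused multiply-add; defining UNFUSE_FMA instead rounds the multiply first
  // via __fmul_rn, which disables that contraction.
  const float x = 1.7f;                      // a raw attention score
  const float m = 2.3f;                      // the running row max
  const float s = 0.125f;                    // softmax scale, 1/sqrt(head_dim) for d = 64
  const float s2 = s * 1.4426950408889634f;  // scale * log2(e)
  std::printf("exp form:  %.8f\n", std::exp((x - m) * s));
  std::printf("exp2 form: %.8f\n", std::exp2(x * s2 - m * s2));
  return 0;
}

Both lines print the same value; the kernel applies the exp2 form element-wise with the row maximum already folded into max_scaled.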
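Patch 3 removes assignments that only restate default values. The removals are safe provided Flash_fwd_params reaches these call sites in a zeroed state; a minimal sketch of that pattern, using a hypothetical stand-in struct rather than the real Flash_fwd_params:

#include <cstdint>

// Hypothetical stand-in; the real Flash_fwd_params has many more fields.
struct ParamsSketch {
  void* knew_ptr = nullptr;
  void* alibi_slopes_ptr = nullptr;
  int64_t alibi_slopes_batch_stride = 0;
  int num_splits = 0;
};

void set_params_sketch(ParamsSketch& params, int num_splits) {
  params = ParamsSketch{};         // value-initialize once: every pointer is nullptr, every stride is 0
  params.num_splits = num_splits;  // afterwards, only fields that differ from the defaults are set
}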