12 changes: 6 additions & 6 deletions cmake/onnxruntime_providers_cpu.cmake
@@ -25,17 +25,17 @@ file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cu_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cuh"
)

# Quick build mode: Filter out non-hdim128 flash attention kernels for faster development iteration
# Quick build mode: Filter flash attention kernels for faster development iteration.
# - We keep only hdim128 fp16 flash attention kernels in quick build mode.
# - All other listed head dimensions are excluded (e.g., 32, 64, 96, 192, 256).
# - This regex matches both `flash_fwd_hdim*` and `flash_fwd_split_hdim*` kernels.
# If new head dimensions are added or removed, update this list to match the supported set.
if(onnxruntime_QUICK_BUILD)
message(STATUS "Quick build mode enabled: Only building hdim128 fp16 flash attention kernels")
# Filter non-hdim128 kernels
list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*hdim(32|64|96|160|192|224|256)")
# Filter all bfloat16 kernels (only keep fp16)
list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*hdim(32|64|96|192|256)")
list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*_bf16")
endif()



file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/contrib_ops/js/*.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/js/*.cc"
37 changes: 18 additions & 19 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h
@@ -1,10 +1,12 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/

#pragma once

namespace onnxruntime {
namespace flash {
#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h"
namespace FLASH_NAMESPACE {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Varlen = true>
@@ -17,43 +19,40 @@ struct BlockInfo {
: params.cu_seqlens_k[bidb]),
actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr
? params.seqlen_q
: params.cu_seqlens_q[bidb + 1] - sum_s_q)
: params.cu_seqlens_q[bidb + 1] - sum_s_q),
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
,
seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr
? params.seqlen_k
: (params.is_seqlens_k_cumulative
? params.cu_seqlens_k[bidb + 1] - sum_s_k
: params.cu_seqlens_k[bidb])),
leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]),
seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr
? params.seqlen_k
: (params.is_seqlens_k_cumulative
? params.cu_seqlens_k[bidb + 1] - sum_s_k
: params.cu_seqlens_k[bidb])) -
leftpad_k),
actual_seqlen_k(params.seqused_k
? params.seqused_k[bidb]
? params.seqused_k[bidb] - leftpad_k
: seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) {
}

template <typename index_t>
__forceinline__ __device__
index_t
q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
__forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}

template <typename index_t>
__forceinline__ __device__
index_t
k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
}

const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
const int leftpad_k;
const int seqlen_k_cache;
const int actual_seqlen_k;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace flash
} // namespace onnxruntime
} // namespace FLASH_NAMESPACE
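
Reviewer note: the new leftpad_k member shifts both the cached K length (seqlen_k_cache is reduced by leftpad_k) and the K row offset (k_offset skips leftpad_k rows). Below is a minimal host-side sketch of that arithmetic; BlockInfoSketch and its fields are illustrative stand-ins for the real struct, assuming cumulative cu_seqlens_k.

// Illustrative sketch only; not the actual BlockInfo.
#include <cassert>
#include <cstdint>

struct BlockInfoSketch {
  int sum_s_k;         // start of this batch's K rows in the packed varlen layout (-1 if not varlen)
  int leftpad_k;       // number of padded rows at the front of the K cache
  int seqlen_k_cache;  // usable cached K length, already reduced by leftpad_k

  // Offset of the first *valid* K row: skip the left padding.
  std::int64_t k_offset(std::int64_t batch_stride, std::int64_t row_stride, int bidb) const {
    return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride
                         : static_cast<std::int64_t>(sum_s_k + leftpad_k) * row_stride;
  }
};

int main() {
  // cu_seqlens_k = {0, 10}, leftpad_k = 3: 10 cached rows, the first 3 are padding.
  BlockInfoSketch b{/*sum_s_k=*/0, /*leftpad_k=*/3, /*seqlen_k_cache=*/10 - 3};
  assert(b.seqlen_k_cache == 7);
  assert(b.k_offset(/*batch_stride=*/0, /*row_stride=*/64, /*bidb=*/0) == 3 * 64);
  return 0;
}
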
11 changes: 6 additions & 5 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
@@ -80,12 +80,11 @@ struct Flash_fwd_params : public Qkv_params {
// array of length b+1 holding starting offset of each sequence.
int* __restrict__ cu_seqlens_q = nullptr;
int* __restrict__ cu_seqlens_k = nullptr;
int* __restrict__ leftpad_k = nullptr;

// If provided, the actual length of each k sequence.
int* __restrict__ seqused_k = nullptr;

int* __restrict__ blockmask = nullptr;

// The K_new and V_new matrices.
void* __restrict__ knew_ptr = nullptr;
void* __restrict__ vnew_ptr = nullptr;
@@ -131,15 +130,17 @@ struct Flash_fwd_params : public Qkv_params {
void* __restrict__ alibi_slopes_ptr = nullptr;
index_t alibi_slopes_batch_stride = 0;

bool unpadded_lse = false;
bool unpadded_lse = false; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
const cudaDeviceProp* dprops = nullptr;
bool seqlenq_ngroups_swapped = false; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, int Headdim>
template <typename T, int Headdim, bool Is_causal>
void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream);
template <typename T, int Headdim>

template <typename T, int Headdim, bool Is_causal>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream);

} // namespace flash
108 changes: 39 additions & 69 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -43,7 +43,9 @@ void set_params_fprop(Flash_fwd_params& params,
bool kv_bsnh = true,
int window_size_left = -1,
int window_size_right = -1,
const bool unpadded_lse = false) {
const bool unpadded_lse = false,
void* cache_batch_idx = nullptr,
void* leftpad_k = nullptr) {
// Set the pointers and strides.
params.q_ptr = q;
params.k_ptr = k;
@@ -147,6 +149,9 @@

params.is_seqlens_k_cumulative = true;
params.unpadded_lse = unpadded_lse;

params.leftpad_k = static_cast<int*>(leftpad_k);
params.cache_batch_idx = static_cast<int*>(cache_batch_idx);
}

size_t get_softmax_lse_size(size_t seqlen, size_t batch_size, size_t num_heads) {
Expand All @@ -173,11 +178,13 @@ size_t get_out_accum_size(size_t num_splits, size_t batch_size, size_t num_heads
void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) {
FP16_SWITCH(!params.is_bf16, [&] {
HEADDIM_SWITCH(params.d, [&] {
if (params.num_splits <= 1 && !force_split_kernel) {
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
} else {
run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
}
BOOL_SWITCH(params.is_causal, Is_causal_const, [&] {
if (params.num_splits <= 1 && !force_split_kernel) {  // num_splits == 0 means it was left unset by the caller
run_mha_fwd_<elem_type, kHeadDim, Is_causal_const>(params, stream);
} else {
run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal_const>(params, stream);
}
});
});
});
}
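
Reviewer note: run_mha_fwd now lowers the runtime flag params.is_causal to a compile-time template argument through BOOL_SWITCH before picking the regular or split-KV kernel. The sketch below shows the general pattern; bool_switch_sketch and run_kernel_sketch are illustrative names, not the actual ORT macros or kernels.

// Illustrative sketch of the BOOL_SWITCH-style dispatch.
#include <cstdio>
#include <type_traits>

template <bool IsCausal>
void run_kernel_sketch() {
  std::printf("dispatched with Is_causal = %d\n", IsCausal ? 1 : 0);
}

// Stand-in for the BOOL_SWITCH macro: invokes f with a std::integral_constant,
// so both branches are instantiated at compile time and selected at runtime.
template <typename F>
void bool_switch_sketch(bool cond, F f) {
  if (cond) {
    f(std::true_type{});
  } else {
    f(std::false_type{});
  }
}

int main() {
  bool is_causal = true;  // runtime value, e.g. params.is_causal
  bool_switch_sketch(is_causal, [](auto is_causal_const) {
    run_kernel_sketch<decltype(is_causal_const)::value>();
  });
  return 0;
}
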
@@ -258,20 +265,6 @@ std::tuple<size_t, size_t, size_t> get_num_splits_and_buffer_sizes(size_t batch_
}
}

// void set_params_alibi(Flash_fwd_params &params, void* alibi_slopes, int batch_size, int num_heads){
// if (alibi_slopes != nullptr) {
// // TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
// // CHECK_DEVICE(alibi_slopes);
// // TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
// // TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads})
// || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
// params.alibi_slopes_ptr = alibi_slopes;
// params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? num_heads : 0; // TODO: flag for bool
// } else {
// params.alibi_slopes_ptr = nullptr;
// }
// }

Status mha_fwd(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // batch_size x seqlen_q x num_heads x head_size
@@ -294,7 +287,9 @@ Status mha_fwd(const cudaDeviceProp& dprops,
void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
bool kv_bsnh,
int local_window_size) {
int local_window_size,
void* cache_batch_idx,
void* leftpad_k) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
@@ -322,28 +317,18 @@ Status mha_fwd(const cudaDeviceProp& dprops,
use_smooth_softmax,
kv_bsnh,
local_window_size,
is_causal ? 0 : -1);
is_causal ? 0 : -1,
/*unpadded_lse=*/false,
cache_batch_idx,
leftpad_k);
params.dprops = &dprops;
params.knew_ptr = nullptr;
params.vnew_ptr = nullptr;
params.knew_batch_stride = 0;
params.vnew_batch_stride = 0;
params.knew_row_stride = 0;
params.vnew_row_stride = 0;
params.knew_head_stride = 0;
params.vnew_head_stride = 0;

params.num_splits = num_splits;
if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) {
params.softmax_lseaccum_ptr = softmax_lse_accum;
params.oaccum_ptr = out_accum;
} else {
params.softmax_lseaccum_ptr = nullptr;
params.oaccum_ptr = nullptr;
}

params.alibi_slopes_ptr = nullptr;

run_mha_fwd(params, stream);
return Status::OK();
}
@@ -408,12 +393,6 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,

params.total_q = total_q;
params.dprops = &dprops;
params.num_splits = 0;
params.softmax_lseaccum_ptr = nullptr;
params.oaccum_ptr = nullptr;
params.knew_ptr = nullptr;
params.vnew_ptr = nullptr;
params.alibi_slopes_ptr = nullptr;
if (paged_KV) {
params.block_table = block_table;
params.block_table_batch_stride = max_num_blocks_per_seq;
@@ -440,18 +419,20 @@ bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_hea
// of max_sequence_length, so seqlen_k == max_sequence_length. The actual past sequence length is held in seqlens_k_.
Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // batch_size x seqlen_q x num_heads x head_size
void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* out, // batch_size x seqlen_q x num_heads x head_size
void* softmax_lse, // batch_size x num_heads x seqlen_q
void* seqlens_k_, // batch_size
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
void* head_sink, // num_heads
int* block_table, // batch_size x max_num_blocks_per_seq
void* q, // batch_size x seqlen_q x num_heads x head_size
void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k_max x head_size
void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k_max x head_size
void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* out, // batch_size x seqlen_q x num_heads x head_size
void* softmax_lse, // batch_size x num_heads x seqlen_q
void* seqlens_k_, // batch_size
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
void* cache_batch_idx, // (optional) indices to index into the KV cache
void* leftpad_k, // (optional) batch_size
void* head_sink, // num_heads
int* block_table, // batch_size x max_num_blocks_per_seq
int batch_size,
int num_heads,
int num_heads_k,
@@ -501,7 +482,10 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
use_smooth_softmax,
past_bsnh,
local_window_size,
is_causal ? 0 : -1);
is_causal ? 0 : -1,
/*unpadded_lse=*/false,
cache_batch_idx,
leftpad_k);
params.dprops = &dprops;

if (k_new != nullptr && v_new != nullptr) {
@@ -544,19 +528,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
params.q_head_stride = head_size;
params.knew_head_stride = head_size;
params.vnew_head_stride = head_size;

params.knew_ptr = nullptr;
params.vnew_ptr = nullptr;
} else {
params.seqlen_knew = 0;
params.knew_ptr = nullptr;
params.vnew_ptr = nullptr;
params.knew_batch_stride = 0;
params.vnew_batch_stride = 0;
params.knew_row_stride = 0;
params.vnew_row_stride = 0;
params.knew_head_stride = 0;
params.vnew_head_stride = 0;
}

params.is_seqlens_k_cumulative = seqlens_k_ == nullptr;
@@ -581,7 +552,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
params.oaccum_ptr = nullptr;
}

params.alibi_slopes_ptr = nullptr;
if (paged_KV) {
params.block_table = block_table;
params.block_table_batch_stride = max_num_blocks_per_seq;
Expand All @@ -600,7 +570,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
// or if using packed QKV (to ensure correct handling of strided inputs which might be better supported or isolated in split kernel logic).
// Note: if the fused kernel handles packing/rotary/appending, it should pass is_packed_qkv=false to this API (via use_packed_for_fa=false),
// effectively bypassing this check and allowing standard kernels if otherwise eligible.
bool force_split = (k_new != nullptr) || is_packed_qkv;
bool force_split = (k_new != nullptr) || is_packed_qkv || cache_batch_idx != nullptr;

run_mha_fwd(params, stream, force_split);
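
Reviewer note: cache_batch_idx, plumbed through set_params_fprop above, remaps which KV-cache slot each query batch reads, and its presence now forces the split-KV kernel (see force_split). A small illustrative sketch of the remapping, using made-up data:

// Illustrative only; the real indexing happens inside the CUDA kernels.
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> kv_cache = {10.f, 11.f, 12.f, 13.f};  // one value per cache slot
  std::vector<int> cache_batch_idx = {2, 0, 2};             // 3 query batches -> cache slots

  for (int b = 0; b < 3; ++b) {
    // Without cache_batch_idx, batch b reads slot b; with it, batch b reads slot cache_batch_idx[b].
    int slot = cache_batch_idx.empty() ? b : cache_batch_idx[b];
    std::printf("query batch %d -> KV cache slot %d (value %.0f)\n", b, slot, kv_cache[slot]);
  }
  return 0;
}
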

30 changes: 17 additions & 13 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -61,7 +61,9 @@ Status mha_fwd(const cudaDeviceProp& dprops,
void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
bool kv_bsnh = true,
int local_window_size = -1);
int local_window_size = -1,
void* cache_batch_idx = nullptr,
void* leftpad_k = nullptr);

Status mha_varlen_fwd(const cudaDeviceProp& dprops,
cudaStream_t stream,
@@ -91,18 +93,20 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,

Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // batch_size x seqlen_q x num_heads x head_size
void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size
void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size
void* k, // batch_size x seqlen_k_new x num_heads_k x head_size
void* v, // batch_size x seqlen_k_new x num_heads_k x head_size
void* out, // batch_size x seqlen_q x num_heads x head_size
void* softmax_lse, // batch_size x num_heads x seqlen_q
void* seqlens_k_, // batch_size
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
void* head_sink, // num_heads
int* block_table, // batch_size x max_num_blocks_per_seq
void* q, // batch_size x seqlen_q x num_heads x head_size
void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
void* k, // batch_size x seqlen_k_new x num_heads_k x head_size
void* v, // batch_size x seqlen_k_new x num_heads_k x head_size
void* out, // batch_size x seqlen_q x num_heads x head_size
void* softmax_lse, // batch_size x num_heads x seqlen_q
void* seqlens_k_, // batch_size
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
void* cache_batch_idx, // (optional) indices to index into the KV cache
void* leftpad_k, // (optional) batch_size
void* head_sink, // num_heads
int* block_table, // batch_size x max_num_blocks_per_seq
int batch_size,
int num_heads,
int num_heads_k,
15 changes: 15 additions & 0 deletions (new file)
@@ -0,0 +1,15 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/

#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template <>
void run_mha_fwd_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params& params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
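
Reviewer note: this new translation unit pins a single (dtype, head_dim, is_causal) combination of run_mha_fwd_, matching the Is_causal template parameter added in flash.h, so the expensive kernel instantiations stay split across files and compile in parallel. A minimal sketch of the one-specialization-per-file pattern (run_fwd_sketch and ParamsSketch are illustrative, not the real symbols):

#include <cstdio>

struct ParamsSketch { int head_dim; bool is_causal; };

// Primary template: declared once, specialized in separate translation units.
template <typename T, int HeadDim, bool IsCausal>
void run_fwd_sketch(const ParamsSketch& p);

// One explicit specialization per file keeps heavy instantiations parallelizable at build time.
template <>
void run_fwd_sketch<float, 128, true>(const ParamsSketch& p) {
  std::printf("running hdim=%d causal=%d kernel\n", p.head_dim, static_cast<int>(p.is_causal));
}

int main() {
  run_fwd_sketch<float, 128, true>(ParamsSketch{128, true});
  return 0;
}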