diff --git a/csrc/flashmask_v2/__init__.py b/csrc/flashmask_v2/__init__.py new file mode 100644 index 00000000000..2e33087c53f --- /dev/null +++ b/csrc/flashmask_v2/__init__.py @@ -0,0 +1 @@ +__version__ = "3.0.0.b1" diff --git a/csrc/flashmask_v2/flash_api.cpp b/csrc/flashmask_v2/flash_api.cpp new file mode 100644 index 00000000000..c1694f93a7e --- /dev/null +++ b/csrc/flashmask_v2/flash_api.cpp @@ -0,0 +1,1536 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#include +#include +#include +#include + +#include + +#include "flash.h" +#include "static_switch.h" +#include "tile_size.h" +#include "heuristics.h" +#include "cuda_check.h" + +// 修改 begin +#include "flash_api.h" +// 修改 end + +extern "C" { +/* Creates a dummy empty _C module that can be imported from Python. + The import from Python will load the .so consisting of this file + in this extension, so that the TORCH_LIBRARY static initializers + below are run. */ +PyObject* PyInit__C(void) +{ + static struct PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + "_C", /* name of module */ + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + NULL, /* methods */ + }; + return PyModule_Create(&module_def); +} +} + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +#define PREPARE_VARLEN_MAX_BATCHES_1CTA 992 + +namespace { +inline at::cuda::CUDAGuard make_cuda_guard_from_tensor(const at::Tensor& t) { + return at::cuda::CUDAGuard(static_cast(t.get_device())); +} +} // namespace + +void set_params_fprop(Flash_fwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + const at::Tensor &out, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_q, + void *seqused_k, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + // int attention_chunk, + const cudaDeviceProp &dprops, + const float softcap=0.f, + const int sm_margin=0) { + + // Reset the parameters + params = {}; + + params.is_bf16 = q.dtype() == torch::kBFloat16; + params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + // All stride are in elements, not bytes. 
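+    // Editorial note (illustrative, not part of the original patch): for a fully contiguous
+    // (b, s, h, d) query tensor, the element strides read below reduce to
+    //     q.stride(-3) == h * d   (row stride: one step along the sequence)
+    //     q.stride(-2) == d       (head stride)
+    //     q.stride(-1) == 1       (last dim; contiguity of this dim is checked at the call sites)
+    // e.g. b = 2, s = 4096, h = 16, d = 128 gives row stride 2048, head stride 128,
+    // batch stride 8388608. Only last-dim contiguity is actually required; the other strides
+    // are taken as-is, which is what lets transposed or sliced Q/K/V be passed without a copy.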
+ params.q_row_stride = q.stride(-3); + params.k_row_stride = k.stride(-3); + params.v_row_stride = v.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = k.stride(-2); + params.v_head_stride = v.stride(-2); + params.v_dim_stride = v.stride(-1); + params.o_ptr = out.data_ptr(); + params.o_row_stride = out.stride(-3); + params.o_head_stride = out.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = q.stride(0); + params.o_batch_stride = out.stride(0); + } + if (cu_seqlens_k_d == nullptr) { + params.k_batch_stride = k.stride(0); + params.v_batch_stride = v.stride(0); + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_q = static_cast(seqused_q); + params.seqused_k = static_cast(seqused_k); + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.softcap = softcap; + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. + // [Minor] We want to round down since when we do the comparison we use <= instead of < + // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + TORCH_CHECK(p_dropout < 1.f,"p_dropout must be less than 1.0"); + TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
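+    // Editorial sketch of the window semantics (assuming seqlen_q == seqlen_k): with
+    // window_size = (L, R), query row i may attend to key columns j with i - L <= j <= i + R,
+    // where a negative value means "unbounded" on that side. So:
+    //     (-1, -1) -> full attention,
+    //     (-1,  0) -> causal (j <= i),
+    //     ( 2,  0) -> causal sliding window of width 3 (j in [i-2, i]).
+    // When seqlen_q != seqlen_k, the mask is aligned to the bottom-right of the score matrix,
+    // i.e. i above becomes i + seqlen_k - seqlen_q.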
+ params.is_causal = window_size_left < 0 && window_size_right == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; + + // TODO: check this + if (window_size_left < 0 && window_size_right >= 0) {window_size_left = seqlen_k - 1;} + if (window_size_left >= 0 && window_size_right < 0) {window_size_right = seqlen_q - 1;} + + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + params.arch = dprops.major * 10 + dprops.minor; + params.num_sm = dprops.multiProcessorCount - sm_margin; + + #ifdef FLASHMASK_V2_DISABLE_LOCAL + TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + #endif +} + +void set_params_dgrad(Flash_bwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + const at::Tensor &out, + const at::Tensor &dout, + at::Tensor &dq, + at::Tensor &dk, + at::Tensor &dv, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_q, + void *seqused_k, + void *dq_accum_d, + void *dk_accum_d, + void *dv_accum_d, + void *softmax_lse_d, + void *dsoftmax_sum_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + const cudaDeviceProp &dprops, + const float softcap=0.f, + bool deterministic=false, + int const sm_margin=0) { + + set_params_fprop(params, + b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, + q, k, v, out, + cu_seqlens_q_d, + cu_seqlens_k_d, + seqused_q, + seqused_k, + softmax_lse_d, + p_dropout, + softmax_scale, + window_size_left, + window_size_right, + dprops, + softcap, + sm_margin); + + // Set the pointers and strides. 
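+    // Editorial note (sketch of the backward quantities wired up here, following the standard
+    // FlashAttention backward): softmax_lse_d holds the per-row log-sum-exp
+    //     lse_i = m_i + log(sum_j exp(s_ij - m_i)),
+    // which lets the backward recompute P_ij = exp(s_ij - lse_i) without ever storing P, while
+    // dsoftmax_sum_d receives D_i = rowsum(dO_i * O_i), the correction term in
+    //     dS_ij = P_ij * (dP_ij - D_i).
+    // dq_accum / dk_accum / dv_accum are fp32 scratch buffers used to accumulate partial
+    // gradients across blocks before the final cast back to the input dtype.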
+ params.do_ptr = dout.data_ptr(); + params.do_row_stride = dout.stride(-3); + params.do_head_stride = dout.stride(-2); + params.dq_ptr = dq.data_ptr(); + params.dk_ptr = dk.data_ptr(); + params.dv_ptr = dv.data_ptr(); + params.dq_row_stride = dq.stride(-3); + params.dk_row_stride = dk.stride(-3); + params.dv_row_stride = dv.stride(-3); + params.dq_head_stride = dq.stride(-2); + params.dk_head_stride = dk.stride(-2); + params.dv_head_stride = dv.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.do_batch_stride = dout.stride(0); + params.dq_batch_stride = dq.stride(0); + params.dk_batch_stride = dk.stride(0); + params.dv_batch_stride = dv.stride(0); + } + + params.dq_accum_ptr = dq_accum_d; + params.dk_accum_ptr = dk_accum_d; + params.dv_accum_ptr = dv_accum_d; + + // Softmax sum + params.dsoftmax_sum = dsoftmax_sum_d; + + params.deterministic = deterministic; +} + +// 修改 begin +void run_mha_fwd_entry(Flash_fwd_params ¶ms, cudaStream_t stream) { + flashmaskv2_run_mha_fwd(¶ms, stream); +} + +void run_mha_fwd_combine_entry(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl=false) { + #ifndef FLASHMASK_V2_DISABLE_SPLIT + flashmaskv2_run_mha_fwd_combine(¶ms, stream, enable_pdl); + #else + TORCH_CHECK(false, "This flash attention build does not support combine kernels."); + #endif +} + +// 修改 end + +inline bool get_pagedkv_tma_entry(Flash_fwd_params ¶ms) { + return flashmaskv2_get_pagedkv_tma(¶ms); +} + +inline bool get_pack_gqa_entry(Flash_fwd_params ¶ms) { + return flashmaskv2_get_pack_gqa(¶ms); +} + +inline int get_num_splits_entry(Flash_fwd_params& params) { + return flashmaskv2_get_num_splits(¶ms); +} + +inline int get_max_headdim() { + return 256; +} + +inline int round_up_headdim(int head_size) { + #ifndef FLASHMASK_V2_DISABLE_HDIM64 + if (head_size <= 64) { return 64; } + #endif + #ifndef FLASHMASK_V2_DISABLE_HDIM128 + if (head_size <= 128) { return 128; } + #endif + return 256; +} + +inline int round_up_headdimv(int head_size) { + if (head_size <= 64) { return 64; } + if (head_size <= 96) { return 96; } + if (head_size <= 128) { return 128; } + if (head_size <= 192) { return 192; } + if (head_size <= 256) { return 256; } + return 512; +} + +// b: batch_size +// b_k: batch_size_k +// s_q: seqlen_q +// s_k: seqlen_k +// s_k_new: seqlen_k_new +// h: num_heads +// h_k: num_heads_k +// d: head_size +std::tuple +mha_fwd(const at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. + const at::Tensor &v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. + const std::optional &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new + const std::optional &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + const std::optional &q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + const std::optional &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + const std::optional &cu_seqlens_q_, // b+1 + const std::optional &cu_seqlens_k_, // b+1 + const std::optional &cu_seqlens_k_new_, // b+1 + const std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + const std::optional &seqused_k_, // b. If given, only this many elements of each batch element's keys are used. 
+ const std::optional max_seqlen_q_, + // TODO: check if we need max_seqlen_k + const std::optional max_seqlen_k_, + const std::optional &page_table_, // (b_k, max_num_pages_per_seq) + const std::optional &kv_batch_idx_, // b. indices to index into the KV cache + const std::optional &leftpad_k_, // b + const std::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) + const std::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) + // std::optional seqlens_rotary_, // b + const std::optional &q_descale_, // (b, h_k), not (b, h) + const std::optional &k_descale_, // (b, h_k) + const std::optional &v_descale_, // (b, h_k) + const double softmax_scale, + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + // int64_t attention_chunk, + const double softcap, + // 修改 begin + // === 新增参数 === + const std::optional &startend_row_indices_, // 新增:FlashMask 核心索引 + const std::optional &block_mask_, // 新增:BlockMask + // 修改 end + const bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + const std::optional &scheduler_metadata_, // (b + 1) + int64_t num_splits, + const std::optional pack_gqa_, + const int64_t sm_margin + ) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + // 修改 begin + // bool is_sm8x = dprops->major >= 8; + // TORCH_CHECK(is_sm8x, "FLASHMASK_V2 only supports Ampere GPUs or newer."); + // Paddle: const bool is_sm90 = dprops.major == 9 && dprops.minor == 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90, "FlashAttention-3 (FlashMaskV2) only supports Hopper GPUs (e.g., H100/H800)."); + // 修改 end + + auto q_type = q.scalar_type(); + TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn, + "FLASHMASK_V2 only supports fp16, bf16, and fp8_e4m3 data type"); + TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); + TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + at::Tensor page_table; + const bool paged_KV = page_table_.has_value(); + if (paged_KV) { + page_table = page_table_.value(); + CHECK_DEVICE(page_table); + TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32"); + TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension"); + } + + at::Tensor cu_seqlens_q; + bool const is_varlen_q = cu_seqlens_q_.has_value(); + if (is_varlen_q) { + cu_seqlens_q = cu_seqlens_q_.value(); + CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); + TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); + } + at::Tensor cu_seqlens_k; + bool const is_varlen_k = cu_seqlens_k_.has_value(); + if (is_varlen_k) { + cu_seqlens_k = cu_seqlens_k_.value(); + CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); + TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); + TORCH_CHECK(!paged_KV, "If 
cu_seqlens_k is passed in, then page table is not supported");
+        TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then kv_batch_idx is not supported");
+    }
+
+    auto const sizes = q.sizes();
+    const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
+    int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
+    int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
+    int num_heads = q.size(-2);
+    int const head_size = q.size(-1);
+    int const head_size_v = v.size(-1);
+    int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1);
+    int const num_pages = !paged_KV ? 0 : k.size(0);
+    int const page_size = !paged_KV ? 1 : k.size(1);
+    int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();
+    int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
+    int const num_heads_k = k.size(-2);
+    int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0);
+    if (!kv_batch_idx_.has_value()) {
+        TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k");
+    }
+    int const max_headdim = get_max_headdim();
+    TORCH_CHECK(head_size <= max_headdim, "FLASHMASK_V2 forward only supports head dimension at most " + std::to_string(max_headdim));
+    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+    if (head_size_v != head_size) {
+        TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) ||
+                    (head_size <= 64 && head_size_v <= 512),
+                    "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], "
+                    "or (Q/K <= 64 and V <= 512).");
+        if (head_size <= 64 && head_size_v > 64) {
+            std::cout << "[Warning] Falling back to an alternate kernel; performance may be limited..."
<< std::endl; + } + TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); + if (head_size_v > 256) { + TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "HeaddimV > 256 requires fp16 and bf16 data type"); + } + } + + // 修改 begin + bool const is_flashmask = startend_row_indices_.has_value(); + bool const is_blockmask = block_mask_.has_value(); + // 修改 end + // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM + // TODO: check this + if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } + if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } + // causal=true is the same as causal=false in this case + // 修改 begin + if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) { + // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA + if (((head_size <= 64 || head_size > 128) || !paged_KV) && !is_flashmask) { + is_causal = false; + } + } + // 修改 end + if (is_causal) { window_size_right = 0; } + is_causal = window_size_left < 0 && window_size_right == 0; + + if (!is_varlen_q) { + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + } else { + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + } + if (!paged_KV) { + if (!is_varlen_k) { + CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + } + } else { + CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); + CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); + } + + if (seqused_q_.has_value()){ + auto seqused_q = seqused_q_.value(); + TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32"); + CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q); + CHECK_SHAPE(seqused_q, batch_size); + } + if (seqused_k_.has_value()) { + auto seqused_k = seqused_k_.value(); + TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32"); + CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k); + CHECK_SHAPE(seqused_k, batch_size); + } + + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + } + + // This is what we will template on + bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value(); + #ifdef FLASHMASK_V2_DISABLE_VARLEN + TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); + #endif + + int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8; + TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); + TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); + + auto opts = q.options(); + auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? 
at::ScalarType::BFloat16 : q_type; + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(out, total_q, num_heads, head_size_v); + } + } else { + out = !is_varlen_q + ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type)) + : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type)); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + int const head_size_rounded = round_up_headdim(head_size); + int const head_size_v_rounded = round_up_headdim(head_size_v); + int const seqlen_q_rounded = round_multiple(seqlen_q, 128); + int const seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + auto device_guard = make_cuda_guard_from_tensor(q); + + at::Tensor softmax_lse; + if (!is_varlen_q) { + softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + } else { + softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + } + + Flash_fwd_params params = {}; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), + !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(), + seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr, + seqused_k_.has_value() ? 
seqused_k_.value().data_ptr() : nullptr, + softmax_lse.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right, + *dprops, + // attention_chunk, + softcap, + sm_margin); + // 疑惑 + // if (startend_row_indices_.has_value()) { + // params.flashmask_indices_ptr = startend_row_indices_.value().data_ptr(); + + // } else { + // params.flashmask_indices_ptr = nullptr; + // } + // if (block_mask_.has_value()) { + // params.block_mask_ptr = block_mask_.value().data_ptr(); + // } else { + // params.block_mask_ptr = nullptr; + // } + params.total_q = total_q; + params.total_k = total_k; + params.b_k = batch_size_k; + params.dv = head_size_v; + params.dv_rounded = head_size_v_rounded; + if (leftpad_k_.has_value()) { // This needs to be set before get_pagedkv_tma + params.leftpad_k = static_cast(leftpad_k_.value().data_ptr()); + } + if (paged_KV) { + params.page_table = page_table.data_ptr(); + params.page_table_batch_stride = page_table.stride(0); + } + params.page_size = page_size; + params.num_pages = num_pages; + + if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma + at::Tensor k_new, v_new; + TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in"); + TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in"); + TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache"); + at::Tensor cu_seqlens_k_new; + bool const is_varlen_k_new = cu_seqlens_k_new_.has_value(); + if (is_varlen_k_new) { + cu_seqlens_k_new = cu_seqlens_k_new_.value(); + CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new); + TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32"); + } + k_new = k_new_.value(); + v_new = v_new_.value(); + TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query"); + TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query"); + CHECK_DEVICE(k_new); CHECK_DEVICE(v_new); + TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension"); + TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension"); + // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new + int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0; + int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0); + if (!is_varlen_k_new) { + CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1); + } + params.seqlen_knew = seqlen_k_new; + params.total_knew = total_k_new; + params.knew_ptr = k_new.data_ptr(); + params.vnew_ptr = v_new.data_ptr(); + // All stride are in elements, not bytes. 
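+        // Editorial note (illustrative of the varlen convention used throughout this file):
+        // each cu_seqlens_* tensor is the exclusive prefix sum of the per-sequence lengths.
+        // For three sequences of lengths {3, 5, 2} packed into one (total, h, d) tensor:
+        //     cu_seqlens = {0, 3, 8, 10},  batch_size = cu_seqlens.size(0) - 1 = 3,  total = 10,
+        // and sequence b occupies rows [cu_seqlens[b], cu_seqlens[b+1]) of the packed tensor.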
+        params.knew_row_stride = k_new.stride(-3);
+        params.vnew_row_stride = v_new.stride(-3);
+        params.knew_head_stride = k_new.stride(-2);
+        params.vnew_head_stride = v_new.stride(-2);
+        if (!is_varlen_k_new) {
+            params.knew_batch_stride = k_new.stride(0);
+            params.vnew_batch_stride = v_new.stride(0);
+        }
+        if (is_varlen_k_new) {
+            params.cu_seqlens_knew = static_cast<int*>(cu_seqlens_k_new.data_ptr());
+        }
+    }
+
+    // Paddle limit: 992 = 32 * 31 (max supported batch in prepare_varlen_num_blocks)
+    bool const use_dynamic_split = is_varlen && (params.b <= 992);
+
+    // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it
+    params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
+
+    params.pagedkv_tma = get_pagedkv_tma_entry(params);
+    params.num_splits = num_splits <= 0 ? get_num_splits_entry(params) : num_splits;
+    // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
+    // modified: begin
+    // Use the FlashMask-specific GQA packing policy function
+    params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa_entry(params);
+    // modified: end
+
+    // This needs to be set after get_num_splits
+    at::Tensor tile_count_semaphore;  // Contains the semaphore and optionally num_splits_dynamic
+    // We don't use the persistent scheduler if Split and not Varlen
+    bool const scheduler_needs_semaphore = params.arch >= 90
+        ? true  // always enabled on Hopper
+        // ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)
+        : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));
+    // params.varlen_sort_batches = !params.is_local;  // Use this value for Sort in scheduler template
+    // params.head_swizzle = params.is_causal || params.is_local;  // Use this value for LPT in scheduler template
+    if (scheduler_needs_semaphore || use_dynamic_split) {
+        // int b_rounded = round_multiple(params.b, 4);  // for 16 byte alignment of pointers
+        int metadata_size = int(scheduler_needs_semaphore) + (use_dynamic_split ? params.b : 0);
+        // printf("Num prepare batch vectors = %d, metadata_size = %d.\n", num_prepare_batch_vectors, metadata_size);
+        params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value();
+        if (scheduler_metadata_.has_value()) {
+            at::Tensor scheduler_metadata = scheduler_metadata_.value();
+            CHECK_DEVICE(scheduler_metadata);
+            CHECK_SHAPE(scheduler_metadata, metadata_size);
+            CHECK_CONTIGUOUS(scheduler_metadata);
+            TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, "scheduler_metadata must have dtype int32");
+            tile_count_semaphore = scheduler_metadata;
+        } else {
+            tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32));
+        }
+        if (scheduler_needs_semaphore && !use_dynamic_split) {
+            tile_count_semaphore.zero_();  // If varlen we'll manually do the zero-ing
+        }
+
+        int* metadata_ptr = tile_count_semaphore.data_ptr<int>();
+
+        // 1. Set the semaphore pointer (corresponds to Paddle's set_tile_count_semaphore)
+        // Paddle: scheduler_needs_semaphore ? ptr : nullptr
+        params.tile_count_semaphore = scheduler_needs_semaphore ? metadata_ptr : nullptr;
+
+        // 2. Set the dynamic-split pointer (corresponds to Paddle's set_num_splits_dynamic_ptr)
+        // Paddle: use_dynamic_split ? (ptr + 1) : nullptr
+        // Key point: the pointer is explicitly offset by one int here, so the per-batch dynamic
+        // split counts live right after the tile-count semaphore in the same metadata buffer.
+        params.num_splits_dynamic_ptr = use_dynamic_split ?
(metadata_ptr + 1) : nullptr; + + } + + if (q_v_.has_value()) { + TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + // TORCH_CHECK(head_size_v >= 256, "q_v is only supported for hdim_v >= 256."); + TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "q_v is only supported for fp16 and bf16 data type"); + TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); + at::Tensor q_v = q_v_.value(); + TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); + CHECK_DEVICE(q_v); + TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); + } + params.qv_ptr = q_v.data_ptr(); + // All stride are in elements, not bytes. + params.qv_row_stride = q_v.stride(-3); + params.qv_head_stride = q_v.stride(-2); + if (!is_varlen_q) { + params.qv_batch_stride = q_v.stride(0); + } + } + + if (rotary_cos_.has_value()) { + TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); + auto rotary_cos = rotary_cos_.value(); + CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos); + params.rotary_dim = rotary_cos.size(1) * 2; + TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); + TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); + const int seqlen_ro = rotary_cos.size(0); + if (paged_KV) { + TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); + } + CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); + TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + + TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); + auto rotary_sin = rotary_sin_.value(); + CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin); + CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); + TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + params.rotary_cos_ptr = rotary_cos.data_ptr(); + params.rotary_sin_ptr = rotary_sin.data_ptr(); + params.is_rotary_interleaved = is_rotary_interleaved; + } else { + params.rotary_dim = 0; + } + + if (kv_batch_idx_.has_value()) { + auto kv_batch_idx = kv_batch_idx_.value(); + CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx); + TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32"); + params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); + } + + at::Tensor out_accum, softmax_lse_accum; + auto outaccum_type = at::ScalarType::Float; + if (params.num_splits > 1) { + TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported"); + if (!is_varlen_q) { + out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type)); + softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + params.oaccum_batch_stride = out_accum.stride(1); + params.lseaccum_batch_stride = softmax_lse_accum.stride(1); + } else { + out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type)); + softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat)); + } + params.is_fp32 
= false; + params.oaccum_ptr = out_accum.data_ptr(); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_split_stride = out_accum.stride(0); + params.oaccum_row_stride = out_accum.stride(-2); + params.oaccum_head_stride = out_accum.stride(-3); + params.lseaccum_split_stride = softmax_lse_accum.stride(0); + params.lseaccum_head_stride = softmax_lse_accum.stride(-2); + } + + if (q_type == at::ScalarType::Float8_e4m3fn) { + if (q_descale_.has_value()) { + auto q_descale = q_descale_.value(); + CHECK_DEVICE(q_descale); + CHECK_SHAPE(q_descale, batch_size, num_heads_k); + params.q_descale_ptr = q_descale.data_ptr(); + params.q_descale_batch_stride = q_descale.stride(0); + params.q_descale_head_stride = q_descale.stride(1); + } else { + params.q_descale_ptr = nullptr; + } + if (k_descale_.has_value()) { + auto k_descale = k_descale_.value(); + CHECK_DEVICE(k_descale); + CHECK_SHAPE(k_descale, batch_size, num_heads_k); + params.k_descale_ptr = k_descale.data_ptr(); + params.k_descale_batch_stride = k_descale.stride(0); + params.k_descale_head_stride = k_descale.stride(1); + } else { + params.k_descale_ptr = nullptr; + } + if (v_descale_.has_value()) { + auto v_descale = v_descale_.value(); + CHECK_DEVICE(v_descale); + CHECK_SHAPE(v_descale, batch_size, num_heads_k); + params.v_descale_ptr = v_descale.data_ptr(); + params.v_descale_batch_stride = v_descale.stride(0); + params.v_descale_head_stride = v_descale.stride(1); + } else { + params.v_descale_ptr = nullptr; + } + } + + #ifdef FLASHMASK_V2_DISABLE_LOCAL + TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + #endif + #ifdef FLASHMASK_V2_DISABLE_SOFTCAP + TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); + #endif + #ifdef FLASHMASK_V2_DISABLE_SPLIT + TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits."); + #endif + #ifdef FLASHMASK_V2_DISABLE_PACKGQA + TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa."); + #endif + #ifdef FLASHMASK_V2_DISABLE_PAGEDKV + TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV."); + #endif + #ifdef FLASHMASK_V2_DISABLE_APPENDKV + TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV."); + #endif + + at::Tensor startend_row_indices; + if (is_flashmask) { + startend_row_indices = startend_row_indices_.value(); + } + at::Tensor block_mask; + if (is_blockmask) { + block_mask = block_mask_.value(); + } + at::Tensor flashmask_maxmin, lt_start_row_indices, lt_end_row_indices, ut_start_row_indices, ut_end_row_indices; + if (is_flashmask) { + TORCH_CHECK(startend_row_indices.dim() == 4,"flashmask_attention receive startend_row_indices with dim [batch_size, num_heads, seq_len, mask_bounds]"); + int64_t mask_bounds = startend_row_indices.size(3); + TORCH_CHECK(mask_bounds == 1 || mask_bounds == 2 || mask_bounds == 4,"flashmask_attention startend_row_indices mask_bounds must in [1,2,4]"); + auto flashmask_maxmin_shape = startend_row_indices.sizes().vec(); + + int device_id = startend_row_indices.get_device(); + auto dprops = at::cuda::getDeviceProperties(device_id); + const bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + + if (is_sm90) { + // seqlen_k to nblock_seqlen, here we use kBlockN = 64 + flashmask_maxmin_shape[2] = 
((flashmask_maxmin_shape[2] + 63) / 64 + 3) / 4 * 4; + + // make sure this is the same with FlashMaskV3 fwd main loop + static constexpr int flashmask_buffer_length = 16 * 1024; + static constexpr int chunk_padded_length = + ((flashmask_buffer_length + 63) / 64 + 31) & 0xffffffe0; + static constexpr int chunk_valid_length = + ((flashmask_buffer_length + 63) / 64 + 3) & 0xfffffffc; + + const int num_chunk = + (flashmask_maxmin_shape[2] + chunk_valid_length - 1) / chunk_valid_length; + + flashmask_maxmin_shape[2] = num_chunk * chunk_padded_length; + } else { + // seqlen_k to nblock_seqlen + flashmask_maxmin_shape[2] = ((flashmask_maxmin_shape[2] + 31) / 32 + 3) / 4 * 4; + } + flashmask_maxmin_shape[3] = 8; + + flashmask_maxmin = at::empty( + flashmask_maxmin_shape, + startend_row_indices.options().dtype(at::kInt) + ); + + lt_start_row_indices = startend_row_indices.slice(3, 0, 1).contiguous(); + + if (mask_bounds == 2) { + if (!is_causal) { + ut_end_row_indices = startend_row_indices.slice(3, 1, 2).contiguous(); + } else { + lt_end_row_indices = startend_row_indices.slice(3, 1, 2).contiguous(); + } + } else if (mask_bounds == 4) { + ut_end_row_indices = startend_row_indices.slice(3, 3, 4).contiguous(); + lt_end_row_indices = startend_row_indices.slice(3, 1, 2).contiguous(); + ut_start_row_indices = startend_row_indices.slice(3, 2, 3).contiguous(); + } + } + + if (is_blockmask) { + TORCH_CHECK( + is_flashmask == true, + "blockmask should be used with flashmask at the same time " + ); + + TORCH_CHECK( + block_mask.dim() == 4, + "blockmask receive blockmask_indices with dim " + "[batch_size, num_heads, blocklen_q, blocklen_k]" + ); + + TORCH_CHECK( + block_mask.size(2) == (seqlen_q + 127) / 128, + "blockmask is now only support blockdim_q = 128 " + ); + + TORCH_CHECK( + block_mask.size(3) == (seqlen_k + 127) / 128, + "blockmask is now only support blockdim_k = 128 " + ); + + TORCH_CHECK( + block_mask.size(1) == startend_row_indices.size(1), + "blockmask is now only support same " + "dim num_heads with flashmask " + ); + } + + if (is_blockmask) { + params.m_block_dim = 128; + params.n_block_dim = 128; + params.block_mask_ptr = block_mask.data_ptr(); + } + + if (is_flashmask) { + if (lt_start_row_indices.defined()) + params.lt_start_ptr = lt_start_row_indices.data_ptr(); + else + params.lt_start_ptr = nullptr; + + if (lt_end_row_indices.defined()) + params.lt_end_ptr = lt_end_row_indices.data_ptr(); + else + params.lt_end_ptr = nullptr; + + if (ut_start_row_indices.defined()) + params.ut_start_ptr = ut_start_row_indices.data_ptr(); + else + params.ut_start_ptr = nullptr; + + if (ut_end_row_indices.defined()) + params.ut_end_ptr = ut_end_row_indices.data_ptr(); + else + params.ut_end_ptr = nullptr; + + if (flashmask_maxmin.defined()) + params.flashmask_maxmin_ptr = flashmask_maxmin.data_ptr(); + else + params.flashmask_maxmin_ptr = nullptr; + + params.h_flashmask = startend_row_indices.size(1); + params.h_h_flashmask_ratio = num_heads / startend_row_indices.size(1); + + } else { + params.lt_start_ptr = nullptr; + params.lt_end_ptr = nullptr; + params.ut_start_ptr = nullptr; + params.ut_end_ptr = nullptr; + params.flashmask_maxmin_ptr = nullptr; + params.h_flashmask = 0; + params.h_h_flashmask_ratio = 0; + } + + + + if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd_entry(params, stream); + if (params.num_splits > 1) { + if (out_type == at::ScalarType::BFloat16) { + // Since we want output in BF16. 
Otherwise fwd_combine will output to FP16 + params.is_bf16 = true; + } + // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1 + // and seqlen = total_q, and don't need to dispatch to Varlen there. + // However, with dynamic split, each row needs to know which batch it belongs to + // to read the number of splits, so we just use the varlen version of combine kernel. + // if (is_varlen_q && !seqused_q_.has_value()) { + // if (is_varlen_q) { + // params.b = 1; + // params.seqlen_q = total_q; + // } + // This will zero out the semaphore if needed + run_mha_fwd_combine_entry(params, stream, true /*enable_pdl*/); + } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) { + // need to zero out the semaphore in this case + // tile_count_semaphore.index({torch::indexing::Slice(params.tile_count_semaphore_offset, params.tile_count_semaphore_offset + 1)}).zero_(); + tile_count_semaphore.zero_(); + } + } else if (total_q > 0 && num_heads_k > 0) { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + return {out, softmax_lse}; + // return {out, softmax_lse, out_accum, softmax_lse_accum}; +} + +#ifdef FLASHMASK_V2_DISABLE_BACKWARD +void run_mha_bwd_entry(Flash_bwd_params ¶ms, cudaStream_t stream) { + TORCH_CHECK(false, "Flash-Attention was built with backward disabled"); +} +#else +void run_mha_bwd_entry(Flash_bwd_params ¶ms, cudaStream_t stream) { + flashmaskv2_run_mha_bwd(¶ms, stream); +} +#endif + + +// b: batch_size +// s_q: seqlen_q +// s_k: seqlen_k +// h: num_heads +// h_k: num_heads_k +// d: head_size +std::tuple mha_bwd( + const at::Tensor &dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + const at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const at::Tensor &k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + const at::Tensor &v, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k + const at::Tensor &out, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + const at::Tensor &softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q + const std::optional &dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const std::optional &dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + const std::optional &dv_, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k + const std::optional &cu_seqlens_q_, // b+1 + const std::optional &cu_seqlens_k_, // b+1 + const std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + const std::optional &seqused_k_, // b. If given, only this many elements of each batch element's keys are used. 
+ const std::optional &startend_row_indices_, + const std::optional &block_mask_, + std::optional max_seqlen_q_, + std::optional max_seqlen_k_, + const double softmax_scale, + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + const double softcap, + const bool deterministic, + const int64_t sm_margin +) { + + #ifdef FLASHMASK_V2_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); + #endif + + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90, "FlashAttention-3 only supports Hopper GPUs."); + + auto q_type = q.dtype(); + TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16, + "FlashAttention-3 bwd only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + + at::Tensor cu_seqlens_q; + bool const is_varlen_q = cu_seqlens_q_.has_value(); + if (is_varlen_q) { + cu_seqlens_q = cu_seqlens_q_.value(); + CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); + TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); + } + at::Tensor cu_seqlens_k; + bool const is_varlen_k = cu_seqlens_k_.has_value(); + if (is_varlen_k) { + cu_seqlens_k = cu_seqlens_k_.value(); + CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); + TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); + } + // This is what we will template on + bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value(); + #ifdef FLASHMASK_V2_DISABLE_VARLEN + TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); + #endif + + auto const sizes = q.sizes(); + int const batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1; + int const seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value(); + int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; + int const num_heads = q.size(-2); + int const head_size = q.size(-1); + // int const head_size_v = v.size(-1); + int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value(); + int const total_k = !is_varlen_k ? 
batch_size * k.size(1) : k.size(0); + int const num_heads_k = k.size(-2); + TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + // TORCH_CHECK(head_size_v % 8 == 0, "head_size_v should be a multiple of 8"); + int const max_headdim = get_max_headdim(); + TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM + if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } + if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } + if (is_causal) { window_size_right = 0; } + // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_bprop will set params.is_causal=true. + // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA). + is_causal = window_size_left < 0 && window_size_right == 0; + + int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; + int const head_size_rounded = round_up_headdim(head_size); + // int const head_size_v_rounded = head_size_rounded; + // TORCH_CHECK(!deterministic || head_size_rounded < 256, "Deterministic backward not supported for hdim 256."); + // Very important that these match the kernel configs + bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal; + + // Flashmask + bool const is_flashmask = startend_row_indices_.has_value(); + at::Tensor startend_row_indices; + if (is_flashmask) startend_row_indices = startend_row_indices_.value(); + bool const has_softcap = softcap > 0.0; + + at::Tensor flashmask_maxmin, lt_start_row_indices, lt_end_row_indices, + ut_start_row_indices, ut_end_row_indices; + + if (is_flashmask) { + TORCH_CHECK(startend_row_indices.dtype() == torch::kInt32, + "flashmask_attention startend_row_indices must be INT32 type"); + TORCH_CHECK(startend_row_indices.dim() == 4, + "flashmask_attention receive startend_row_indices with dim " + "[batch_size, num_heads, seq_len, mask_bounds]"); + int64_t last_dim = startend_row_indices.size(3); + TORCH_CHECK(last_dim == 1 || last_dim == 2 || last_dim == 4, + "flashmask_attention startend_row_indices " + "mask_bounds must in [1,2,4]"); + + auto startend_sizes = startend_row_indices.sizes(); + std::vector flashmask_maxmin_shape(startend_sizes.begin(), startend_sizes.end()); + + flashmask_maxmin_shape[2] = ((flashmask_maxmin_shape[2] + 31) / 32 + 3) / 4 * 4; + flashmask_maxmin_shape[3] = 8; + + flashmask_maxmin = at::empty(flashmask_maxmin_shape, + startend_row_indices.options().dtype(torch::kInt32)); + + lt_start_row_indices = startend_row_indices.slice(3, 0, 1).contiguous(); + + int64_t mask_bounds = startend_row_indices.size(3); + + if (mask_bounds == 2) { + if (!is_causal) { + // ut_end_row_indices + ut_end_row_indices = startend_row_indices.slice(3, 1, 2).contiguous(); + } else { + // lt_end_row_indices + lt_end_row_indices = startend_row_indices.slice(3, 1, 2).contiguous(); + } + } else if (mask_bounds == 4) { + // 对应 Paddle 的切片逻辑: + // ut_end: {3, 4} + ut_end_row_indices = startend_row_indices.slice(3, 3, 4).contiguous(); + // lt_end: {1, 2} + lt_end_row_indices = startend_row_indices.slice(3, 1, 2).contiguous(); + // ut_start: {2, 3} + ut_start_row_indices = startend_row_indices.slice(3, 2, 
3).contiguous(); + } + } + + bool const is_blockmask = block_mask_.has_value(); + at::Tensor block_mask; + if (is_blockmask) block_mask = block_mask_.value(); + + if (is_blockmask) { + TORCH_CHECK(is_flashmask, + "blockmask should be used with flashmask at the same time "); + + TORCH_CHECK(block_mask.dim() == 4, + "blockmask receive blockmask_indices with dim " + "[batch_size, num_heads, blocklen_q, blocklen_k]"); + TORCH_CHECK(block_mask.size(2) == (seqlen_q + 127) / 128, + "blockmask only supports blockdim_q = 128 now"); + + TORCH_CHECK(block_mask.size(3) == (seqlen_k + 127) / 128, + "blockmask only supports blockdim_k = 128 now"); + + TORCH_CHECK(block_mask.size(1) == startend_row_indices.size(1), + "blockmask only supports same dim num_heads with flashmask now"); + + TORCH_CHECK(seqlen_k <= 1024 * 128, + "blockmask only supports seqlen <= 128k in bwd now"); + + TORCH_CHECK(seqlen_q <= 1024 * 128, + "blockmask only supports seqlen <= 128k in bwd now"); + } + + const bool has_lt_start = lt_start_row_indices.defined(); + const bool has_lt_end = lt_end_row_indices.defined(); + const bool has_ut_start = ut_start_row_indices.defined(); + const bool has_ut_end = ut_end_row_indices.defined(); + + + const auto [kBlockM_sm90, kBlockN_sm90] = [&]() -> std::pair { + if (head_size_rounded <= 64) { + if (is_flashmask && !is_causal) { + return {64, 96}; + } else if ((is_causal && has_softcap) || is_flashmask) { + return {96, 128}; + } else { + return {128, 128}; + } + } else if (head_size_rounded <= 128) { + if (is_causal || is_local || has_softcap) { + return {64, 128}; + } else { + if ((seqlen_q >= 1024 || seqlen_k >= 1024) && + !(has_lt_end && has_ut_start)) { + return {64, 128}; + } else { + return {64, 64}; + } + } + } else if (head_size_rounded <= 256) { + if (has_lt_end && has_ut_start) { + return {64, 32}; + } else { + return {64, 64}; + } + } else { + TORCH_CHECK(false, "head dim is rounded to ", head_size_rounded, + ", which is not supported in FlashMask V3 now."); + return {0, 0}; + } + }(); + int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64; + int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32; + int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80); + // int const kBlockN_sm90 = head_size_rounded <= 128 + // ? 128 + // : (head_size_rounded <= 192 ? 96 : 80); + int const kBlockN_sm80 = + head_size_rounded <= 128 + ? 128 : + (head_size_rounded <= 192 ? 80 : 64); + int const kBlockN_sm86 = head_size_rounded <= 64 ? 128 + : (head_size_rounded <= 96 ? 128 + : (head_size_rounded <= 128 ? 96 + : (head_size_rounded <= 192 ? 64 : 64))); + int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? 
kBlockN_sm86 : kBlockN_sm80); + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM); + int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN); + int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM); + int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN); + + if (!is_varlen_q) { + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); + } else { + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size); + CHECK_SHAPE(dout, total_q, num_heads, head_size); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + } + if (!is_varlen_k) { + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); + } else { + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + } + + if (seqused_q_.has_value()){ + auto seqused_q = seqused_q_.value(); + TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32"); + CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q); + CHECK_SHAPE(seqused_q, batch_size); + } + if (seqused_k_.has_value()){ + auto seqused_k = seqused_k_.value(); + TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32"); + CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k); + CHECK_SHAPE(seqused_k, batch_size); + } + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); + } else { + CHECK_SHAPE(dq, total_q, num_heads, head_size); + } + } else { + dq = torch::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + if (!is_varlen_k) { + CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); + } else { + CHECK_SHAPE(dk, total_k, num_heads_k, head_size); + } + } else { + dk = torch::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + if (!is_varlen_k) { + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); + } else { + CHECK_SHAPE(dv, total_k, num_heads_k, head_size); + } + } else { + dv = torch::empty_like(v); + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + auto device_guard = make_cuda_guard_from_tensor(q); + + auto opts = q.options(); + // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64 + at::Tensor softmax_d, softmax_lse_log2; + if (!is_varlen) { + // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64 + softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); + 
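+        // Editorial note (worked example of the rounding above, assuming kBlockM = 64 and
+        // kBlockN = 128): round_multiple(x, m) = ceil(x / m) * m, so seqlen_q = 1000 rounds to
+        // 1024 and seqlen_k = 1000 rounds to 1024. In the varlen case, total_q_padded_rounded
+        // adds batch_size * kBlockM before rounding so that every sequence can be padded out to
+        // a block boundary independently. Rounding softmax_d (and softmax_lse_log2 below, which
+        // is meant to hold the LSE pre-scaled by log2(e) so the backward kernel can work in
+        // exp2) keeps their rows 16/8-byte aligned for TMA / LDG.64, as the comment above notes.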
softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); + } else { + softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat)); + softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat)); + } + at::Tensor dq_accum, dk_accum, dv_accum; + if (!is_varlen) { + dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, opts.dtype(at::kFloat)); + } else { + dq_accum = torch::empty({num_heads, total_q_padded_rounded * head_size_rounded}, opts.dtype(at::kFloat)); + } + if (num_heads_k != num_heads) { // MQA / GQA + if (!is_varlen) { + dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat)); + dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat)); + } else { + dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + } + } + + Flash_bwd_params params = {}; + set_params_dgrad(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + dout, dq, dk, dv, + !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), + !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(), + seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr, + seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr, + dq_accum.data_ptr(), + num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr, + num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr, + softmax_lse.data_ptr(), + softmax_d.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right, + *dprops, + softcap, + deterministic, + sm_margin); + params.total_q = total_q; + params.total_k = total_k; + params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); + params.dv = head_size; + // params.dv_rounded = head_size_v_rounded; + + // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); + // params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + // Will be zero'ed out in the backward preprocess kernel + + at::Tensor tile_count_semaphore; + if (arch >= 90) { + // Paddle 用 phi::Full(..., 0) -> Torch 用 torch::zeros + tile_count_semaphore = torch::zeros({1}, opts.dtype(torch::kInt32)); + params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + } else { + params.tile_count_semaphore = nullptr; + } + + at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32)); + params.dq_semaphore = dq_semaphore.data_ptr(); + at::Tensor dk_semaphore, dv_semaphore; + if (num_heads_k != num_heads && params.deterministic) { + // TODO: maybe also zero'ed out dk_semaphore and dv_semaphore in the backward preprocess kernel + dk_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); + dv_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); + params.dk_semaphore = dk_semaphore.data_ptr(); + params.dv_semaphore = dv_semaphore.data_ptr(); + } + + if (is_flashmask) { + params.lt_start_ptr = lt_start_row_indices.defined() ? 
lt_start_row_indices.data_ptr() : nullptr; + params.lt_end_ptr = lt_end_row_indices.defined() ? lt_end_row_indices.data_ptr() : nullptr; + params.ut_start_ptr = ut_start_row_indices.defined() ? ut_start_row_indices.data_ptr() : nullptr; + params.ut_end_ptr = ut_end_row_indices.defined() ? ut_end_row_indices.data_ptr() : nullptr; + params.flashmask_maxmin_ptr = flashmask_maxmin.defined() ? flashmask_maxmin.data_ptr() : nullptr; + params.h_flashmask = startend_row_indices.size(1); + params.h_h_flashmask_ratio = num_heads / startend_row_indices.size(1); + } else { + params.lt_start_ptr = nullptr; + params.lt_end_ptr = nullptr; + params.ut_start_ptr = nullptr; + params.ut_end_ptr = nullptr; + params.flashmask_maxmin_ptr = nullptr; + params.h_flashmask = 0; + params.h_h_flashmask_ratio = 0; + } + + if (is_blockmask) { + params.m_block_dim = 128; + params.n_block_dim = 128; + params.block_mask_ptr = block_mask.data_ptr(); + } + + #ifdef FLASHMASK_V2_DISABLE_LOCAL + TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + #endif + #ifdef FLASHMASK_V2_DISABLE_SOFTCAP + TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); + #endif + + if (total_q > 0 && total_k > 0 && num_heads_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_bwd_entry(params, stream); + } else if (total_k > 0 && num_heads_k > 0) { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dk.zero_(); + dv.zero_(); + if (softmax_d.defined()) softmax_d.zero_(); + } else if (total_q > 0 && num_heads_k > 0) { + dq.zero_(); + if (softmax_d.defined()) softmax_d.zero_(); + } + + return { dq, dk, dv }; +} + +TORCH_LIBRARY(flashmask, m) { + m.def("fwd(" + "Tensor q," + "Tensor k," + "Tensor v," + "Tensor(k_new!)? k_new = None," + "Tensor(v_new!)? v_new = None," + "Tensor? q_v = None," + "Tensor(out!)? out = None," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? cu_seqlens_k_new = None," + "Tensor? seqused_q = None," + "Tensor? seqused_k = None," + "int? max_seqlen_q = 0," + "int? max_seqlen_k = 0," + "Tensor? page_table = None," + "Tensor? kv_batch_idx = None," + "Tensor? leftpad_k = None," + "Tensor? rotary_cos = None," + "Tensor? rotary_sin = None," + // "Tensor? seqlens_rotary = None," + "Tensor? q_descale = None," + "Tensor? k_descale = None," + "Tensor? v_descale = None," + "float softmax_scale = 0.0," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + // "int attention_chunk = 0," + "float softcap = 0.0," + // 修改 begin + // === 新增参数声明 === + "Tensor? startend_row_indices = None," + "Tensor? block_mask = None," + // 修改 end + "bool is_rotary_interleaved = True," + "Tensor? scheduler_metadata = None," + "int num_splits = 1," + "bool? pack_gqa = False," + "int sm_margin = 0) -> (Tensor(out!), Tensor)"); + m.def("bwd(" + "Tensor dout," + "Tensor q," + "Tensor k," + "Tensor v," + "Tensor out," + "Tensor softmax_lse," + "Tensor? dq = None," + "Tensor? dk = None," + "Tensor? dv = None," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? seqused_q = None," + "Tensor? seqused_k = None," + "Tensor? startend_row_indices = None, " + "Tensor? block_mask = None, " + "int? max_seqlen_q = None," + "int? 
max_seqlen_k = None," + "float softmax_scale = 0.0," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "float softcap = 0.0," + "bool deterministic = False," + "int sm_margin = 0) -> (Tensor, Tensor, Tensor)"); +} + +TORCH_LIBRARY_IMPL(flashmask, CUDA, m) { + m.impl("fwd", &mha_fwd); + m.impl("bwd", &mha_bwd); +} diff --git a/csrc/flashmask_v2/flash_api.cu b/csrc/flashmask_v2/flash_api_cuda.cu similarity index 97% rename from csrc/flashmask_v2/flash_api.cu rename to csrc/flashmask_v2/flash_api_cuda.cu index 2b798e9dbd1..20bc283a8f4 100644 --- a/csrc/flashmask_v2/flash_api.cu +++ b/csrc/flashmask_v2/flash_api_cuda.cu @@ -44,12 +44,12 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { if (params.is_bf16) { #ifndef FLASHMASK_V2_DISABLE_HDIM64 if (params.d <= 64) { - if (params.dv > 64 && Arch == 90) { - return run_mha_fwd_(params, stream); - } - else { + // if (params.dv > 64 && Arch == 90) { + // return run_mha_fwd_(params, stream); + // } + // else { return run_mha_fwd_(params, stream); - } + // } } #endif #ifndef FLASHMASK_V2_DISABLE_HDIM96 @@ -74,12 +74,12 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { #ifndef FLASHMASK_V2_DISABLE_FP16 #ifndef FLASHMASK_V2_DISABLE_HDIM64 if (params.d <= 64) { - if (params.dv > 64 && Arch == 90) { - return run_mha_fwd_(params, stream); - } - else { + // if (params.dv > 64 && Arch == 90) { + // return run_mha_fwd_(params, stream); + // } + // else { return run_mha_fwd_(params, stream); - } + // } } #endif #ifndef FLASHMASK_V2_DISABLE_HDIM96 diff --git a/csrc/flashmask_v2/flashmask_interface.py b/csrc/flashmask_v2/flashmask_interface.py new file mode 100644 index 00000000000..7eac5d8115b --- /dev/null +++ b/csrc/flashmask_v2/flashmask_interface.py @@ -0,0 +1,762 @@ +# Copyright (c) 2023, Tri Dao. + +from typing import Optional, Union, List, Tuple + +import torch +import torch.nn as nn + +# isort: off +# We need to import the CUDA kernels after importing torch +import flashmask._C # Registers operators with PyTorch + +# isort: on + +flashmask_cuda = torch.ops.flashmask + +class FlashMaskFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, startend_row_indices, block_mask, softmax_scale, causal, window_size_left, window_size_right, softcap, deterministic): + # 1. 调用 C++ 前向算子 + # 使用关键字参数调用,避免参数位置对不上的问题 + out, lse = flashmask_cuda.fwd( + q=q, + k=k, + v=v, + startend_row_indices=startend_row_indices, + block_mask=block_mask, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=window_size_left, + window_size_right=window_size_right, + softcap=softcap + # 其他参数 C++ 端有默认值 None/0,这里不用传 + ) + + # 2. 保存用于反向传播的 Tensor 和参数 + ctx.save_for_backward(q, k, v, out, lse, startend_row_indices, block_mask) + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size_left = window_size_left + ctx.window_size_right = window_size_right + ctx.softcap = softcap + ctx.deterministic = deterministic + + return out, lse + + @staticmethod + def backward(ctx, dout, dlse): + # 1. 取出保存的 Tensor + q, k, v, out, lse, startend_row_indices, block_mask = ctx.saved_tensors + + # 2. 
调用 C++ 反向算子 + # 注意:这里的参数名必须和 flash_api.cpp 中 m.def("bwd(...") 定义的一致 + dq, dk, dv = flashmask_cuda.bwd( + dout=dout, + q=q, + k=k, + v=v, + out=out, + softmax_lse=lse, + dq=None, # 可选,C++ 会自动分配 + dk=None, + dv=None, + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_q=None, + seqused_k=None, + startend_row_indices=startend_row_indices, # 传入 mask + block_mask=block_mask, # 传入 block mask + max_seqlen_q=None, + max_seqlen_k=None, + softmax_scale=ctx.softmax_scale, + is_causal=ctx.causal, + window_size_left=ctx.window_size_left, + window_size_right=ctx.window_size_right, + softcap=ctx.softcap, + deterministic=ctx.deterministic, + sm_margin=0 + ) + + # 3. 返回梯度 + # 顺序必须对应 forward 的输入: + # (q, k, v, startend_row_indices, block_mask, softmax_scale, causal, window_size_left, window_size_right, softcap, deterministic) + return dq, dk, dv, None, None, None, None, None, None, None, None + +def flashmask_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + startend_row_indices: Optional[torch.Tensor] = None, + *, + dropout: float = 0.0, + causal: bool = False, + window_size: Optional[Union[int, Tuple[int, int]]] = None, + return_softmax_lse: bool = False, + return_seed_offset: bool = False, + fixed_seed_offset: Optional[torch.Tensor] = None, # Paddle 中通常传入 Tensor,PyTorch 中如果需要手动控制随机种子通常传 Generator 或 tuple + rng_name: str = "", # Paddle 特有:用于选择随机数生成器,PyTorch 中通常忽略或使用 generator 参数 + training: bool = True, + name: Optional[str] = None, # Paddle 特有:静态图命名,PyTorch 动态图不需要 + softmax_scale: Optional[float] = None, + block_mask: Optional[torch.Tensor] = None, +): + """ + FlashMask: Official Implementation (PyTorch Port) + + Args: + query (torch.Tensor): [batch_size, seq_len, num_heads, head_dim] + key (torch.Tensor): [batch_size, seq_len, num_heads, head_dim] + value (torch.Tensor): [batch_size, seq_len, num_heads, head_dim] + startend_row_indices (torch.Tensor): + A column-wise sparse attention mask row indices tensor. + Shape: [batch_size, num_heads, seq_len, {1, 2, 4}] + Dtype: torch.int32 + ... + """ + r""" + FlashMask: Official Implementation + + This module provides the official implementation of the FlashMask algorithm as described in the paper. For more details, please refer to the paper available at: https://arxiv.org/abs/2410.01359. + + The core equation utilized in FlashMask is as follows: + + .. math:: + + \text{result} = \text{softmax}\left(\frac{Q \cdot K^T}{\sqrt{d}} + M\right) \cdot V + + In this equation: + + - ``Q``, ``K``, and ``V`` are the input tensors to the attention module. + - All these tensors share the same dimensions. + - ``d`` denotes the size of the last dimension of these tensors. + - ``M`` represents the column-wise sparse mask introduced by FlashMask. + + Args: + query (torch.Tensor): The query tensor in the attention module. + A 4-D tensor with shape [batch_size, q_seq_len, num_heads, head_dim]. + The dtype can be float16 or bfloat16. + key (torch.Tensor): The key tensor in the attention module. + A 4-D tensor with shape [batch_size, k_seq_len, k_num_heads, head_dim]. + The dtype can be float16 or bfloat16. + value (torch.Tensor): The value tensor in the attention module. + A 4-D tensor with shape [batch_size, k_seq_len, k_num_heads, head_dim]. + The dtype can be float16 or bfloat16. + startend_row_indices (torch.Tensor): + A column-wise sparse attention mask row indices tensor. + A 4-D tensor with shape [batch_size, k_num_heads, k_seq_len, {1, 2, 4}]. + The dtype must be int32. k_num_heads can be 1 or the same as key's num_heads. 
When num_heads is 1, it will be broadcast to match key's num_heads. + Depending on the value of the causal parameter, startend_row_indices can take different shapes and meanings. + + - When `causal=True` and the shape is [batch_size, k_num_heads, k_seq_len, 1], + indicating unidirectional attention. The value represents the starting row index of the left + lower triangular mask in the dense mask. The value startend_row_indices[..., 0] indicates that elements in the lower left triangle of the attention score matrix starting from the startend_row_indices[..., 0]-th row downwards (inclusive) will be masked. + - When `causal=True` and the shape is [batch_size, k_num_heads, k_seq_len, 2], + indicating unidirectional attention. The values represent the starting and ending row indices of + the left lower triangular mask in the dense mask. The values startend_row_indices[..., 0:2] in startend_row_indices indicate that elements in the lower left triangle of the attention score matrix starting from the startend_row_indices[..., 0]-th row downwards (inclusive) but above the startend_row_indices[..., 1]-th row (exclusive) will be masked. + - When `causal=False` and the shape is [batch_size, k_num_heads, k_seq_len, 2], + indicating bidirectional attention. The values represent the starting row index of the left + lower triangular mask and the ending row index of the right upper triangular mask in the dense mask. The values startend_row_indices[..., 0:2] in startend_row_indices indicate that elements in the lower left triangle of the attention score matrix starting from the startend_row_indices[..., 0]-th row downwards (inclusive) will be masked, and elements in the upper right triangle starting from the startend_row_indices[..., 1]-th row upwards (exclusive) will be masked. + - When `causal=False` and the shape is [batch_size, k_num_heads, k_seq_len, 4] , + indicating bidirectional attention. The values represent the start and end row indices of the + left lower triangular mask and the start and end row indices of the right upper triangular mask in the dense mask. The values startend_row_indices[..., 0:4] in startend_row_indices indicate that elements in the lower left triangle of the attention score matrix starting from the startend_row_indices[..., 0]-th row downwards (inclusive) but above the startend_row_indices[..., 1] row (exclusive) will be masked, and elements in the upper right triangle starting from the startend_row_indices[..., 2]-th row downwards (inclusive) but above the startend_row_indices[..., 3] row (exclusive) will be masked. + + dropout (float): The dropout ratio. Default is 0.0. + causal (bool): Whether to enable causal mode. Default is False. + window_size (int|tuple, optional): Indicates the window size of sliding window local attention. + If causal mode is enabled, Query at position i will only attend to keys between [i - window_size, i] or [i - window_size[0], i]. + If causal mode is disabled, Query at position i will only attend to keys between [i - window_size, i + window_size] or [i - window_size[0], i + window_size[1]]. + return_softmax_lse (bool): Whether to return the log-sum-exp of the softmax. Default is False. + return_seed_offset (bool): Whether to return the random seed offset. Default is False. + fixed_seed_offset (torch.Tensor, optional): With fixed seed, offset for dropout mask. + rng_name (str): The name to select Generator. (Note: In PyTorch, this is typically unused or replaced by a torch.Generator object, kept here for interface compatibility). 
+ training (bool): Whether the module is in training mode. Default is True. + name (str, optional): Name of the operation. Default is None. (Note: Unused in PyTorch). + block_mask (torch.Tensor, optional): + A 4-D integer mask tensor indicating whether each block in the attention matrix should be kept or masked. Must be used together with flashmask. + The shape should be [batch_size, num_heads, blocklen_q, blocklen_k], where: + + blocklen_q = ceil(seqlen_q / 128), i.e., block_mask.shape[2] must be (seqlen_q + 127) // 128 + blocklen_k = ceil(seqlen_k / 128), i.e., block_mask.shape[3] must be (seqlen_k + 127) // 128 + block_mask.shape[1] (number of heads) must match the num_heads dimension of the flashmask + Both seqlen_q and seqlen_k must be less than or equal to 128 * 1024 + The dtype should be int32, and each element should be either 0 or 1. + A value of 1 indicates that the corresponding block is kept (not masked), while 0 means the block is masked. + + Usage Notes: + + Only supported when blockdim_q = blockdim_k = 128 now. + Only supported when headdim = 128 now. + This argument must be provided together with flashmask. + The mask will be applied at the block level: each [i, j] position in block_mask controls whether the corresponding [128 x 128] block in the attention matrix is masked. + Any mismatch in expected shape or head dimension will raise an error. + + + Returns + torch.Tensor. The computed attention result with the same shape as the input `query`. + + Warning: + This API only supports inputs with dtype float16 and bfloat16. + + Hint: + This API supports GQA. + + To convert FlashMask's `startend_row_indices` to `dense_mask`, use the code below: + + .. code-block:: python + + >>> import torch + >>> import numpy as np + >>> def flashmask_to_densemask(startend_row_indices, dtype, causal=True): + ... if startend_row_indices is None: + ... return None + ... bz, num_head, seq_len, bound_num = startend_row_indices.shape + ... m = torch.zeros((bz, num_head, seq_len, seq_len), dtype=dtype) + ... has_end = (causal and bound_num == 2) or ((not causal) and bound_num == 4) + ... for bi in range(bz): + ... for hi in range(num_head): + ... for j in range(seq_len): + ... downstart = startend_row_indices[bi, hi, j, 0] + ... if has_end: + ... downend = startend_row_indices[bi, hi, j, 1] + ... m[bi, hi, downstart:downend, j] = -np.inf + ... else: + ... m[bi, hi, downstart:, j] = -np.inf + ... if causal: + ... m[bi, hi, :j, j] = -np.inf + ... else: + ... if has_end: + ... upstart = startend_row_indices[bi, hi, j, 2] + ... upend = startend_row_indices[bi, hi, j, 3] + ... m[bi, hi, upstart:upend, j] = -np.inf + ... else: + ... upend = startend_row_indices[bi, hi, j, 1] + ... m[bi, hi, :upend, j] = -np.inf + ... return m + + For `Causal Mask`, where `causal=True`, the values of `startend_row_indices` are as follows: + + .. 
code-block:: python + + [[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import torch + >>> startend_row_indices = torch.tensor([8]*10, dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> print(startend_row_indices) + tensor([[[[8], + [8], + [8], + [8], + [8], + [8], + [8], + [8], + [8], + [8]]]], device='cuda:0', dtype=torch.int32) + >>> # doctest: -SKIP + + + For `Sliding Window Mask`, where `causal=True`, the values of `startend_row_indices` are as follows: + + .. code-block:: python + + [[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import torch + >>> startend_row_indices = torch.tensor([3, 4, 5, 6, 7, 8, 9, 10, 10, 10], dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> print(startend_row_indices) + tensor([[[[ 3], + [ 4], + [ 5], + [ 6], + [ 7], + [ 8], + [ 9], + [10], + [10], + [10]]]], device='cuda:0', dtype=torch.int32) + >>> # doctest: -SKIP + + For `Causal Document Mask`, where `causal=True`, the values of `startend_row_indices` are as follows: + + .. code-block:: python + + [[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import torch + >>> startend_row_indices = torch.tensor([4, 4, 4, 4, 7, 7, 7, 10, 10, 10], dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> print(startend_row_indices) + tensor([[[[ 4], + [ 4], + [ 4], + [ 4], + [ 7], + [ 7], + [ 7], + [10], + [10], + [10]]]], device='cuda:0', dtype=torch.int32) + >>> # doctest: -SKIP + + For `Document Mask`, where `causal=False`, the values of `startend_row_indices` are as follows: + + .. code-block:: python + + [[[[1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import torch + >>> LTS = torch.tensor([4, 4, 4, 4, 7, 7, 7, 10, 10, 10], dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> UTE = torch.tensor([0, 0, 0, 0, 4, 4, 4, 7, 7, 7], dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> startend_row_indices = torch.cat([LTS, UTE], dim=-1) + >>> print(startend_row_indices) + tensor([[[[ 4, 0], + [ 4, 0], + [ 4, 0], + [ 4, 0], + [ 7, 4], + [ 7, 4], + [ 7, 4], + [10, 7], + [10, 7], + [10, 7]]]], device='cuda:0', dtype=torch.int32) + >>> # doctest: -SKIP + + For `Share Question Mask`, where `causal=True`, the values of `startend_row_indices` are as follows: + + .. 
code-block:: python + + [[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 1, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 1, 1, 0], + [1, 1, 1, 1, 0, 0, 0, 1, 1, 1]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import torch + >>> startend_row_indices = torch.tensor([10, 10, 10, 10, 7, 7, 7, 10, 10, 10], dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> print(startend_row_indices) + tensor([[[[10], + [10], + [10], + [10], + [ 7], + [ 7], + [ 7], + [10], + [10], + [10]]]], device='cuda:0', dtype=torch.int32) + >>> # doctest: -SKIP + + For `Global + Sliding Window Mask`, where `causal=False`, the values of `startend_row_indices` are as follows: + + .. code-block:: python + + >>> # doctest: +SKIP('Only example') + + [[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 0, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 1, 1, 1, 0, 0, 0], + [1, 1, 0, 0, 0, 1, 1, 1, 0, 0], + [1, 1, 0, 0, 0, 0, 1, 1, 1, 0], + [1, 1, 0, 0, 0, 0, 0, 1, 1, 1], + [1, 1, 0, 0, 0, 0, 0, 0, 1, 1]]]]) + + >>> import torch + >>> LTS = torch.tensor([10, 10, 4, 5, 6, 7, 8, 9, 10, 10], dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> LTE = torch.tensor([10, 10, 10, 10, 10, 10, 10, 10, 10, 10], dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> UTS = torch.tensor([0, 0, 0, 0, 2, 2, 2, 2, 2, 2], dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> UTE = torch.tensor([0, 0, 0, 0, 3, 4, 5, 6, 7, 8], dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> startend_row_indices = torch.cat([LTS, LTE, UTS, UTE], dim=-1) + >>> print(startend_row_indices) + tensor([[[[10, 10, 0, 0], + [10, 10, 0, 0], + [ 4, 10, 0, 0], + [ 5, 10, 0, 0], + [ 6, 10, 2, 3], + [ 7, 10, 2, 4], + [ 8, 10, 2, 5], + [ 9, 10, 2, 6], + [10, 10, 2, 7], + [10, 10, 2, 8]]]], device='cuda:0', dtype=torch.int32) + >>> # doctest: -SKIP + + For `Causal Blockwise Mask`, where `causal=True`, the values of `startend_row_indices` are as follows: + + .. code-block:: python + + [[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import torch + >>> LTS = torch.tensor([4, 4, 4, 4, 10, 10, 10, 10, 10, 10], dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> LTE = torch.tensor([7, 7, 7, 7, 10, 10, 10, 10, 10, 10], dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> startend_row_indices = torch.cat([LTS, LTE], dim=-1) + >>> print(startend_row_indices) + tensor([[[[ 4, 7], + [ 4, 7], + [ 4, 7], + [ 4, 7], + [10, 10], + [10, 10], + [10, 10], + [10, 10], + [10, 10], + [10, 10]]]], device='cuda:0', dtype=torch.int32) + >>> # doctest: -SKIP + + For `Prefix LM Document Mask`, where `causal=False`, the values of `startend_row_indices` are as follows: + + .. 
code-block:: python + + [[[[1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import torch + >>> LTS = torch.tensor([3, 3, 3, 5, 5, 10, 10, 10, 10, 10], dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> UTE = torch.tensor([0, 0, 2, 3, 3, 5, 5, 7, 8, 9], dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> startend_row_indices = torch.cat([LTS, UTE], dim=-1) + >>> print(startend_row_indices) + tensor([[[[ 3, 0], + [ 3, 0], + [ 3, 2], + [ 5, 3], + [ 5, 3], + [10, 5], + [10, 5], + [10, 7], + [10, 8], + [10, 9]]]], device='cuda:0', dtype=torch.int32) + >>> # doctest: -SKIP + + For `Prefix LM Causal Mask`, where `causal=False`, the values of `startend_row_indices` are as follows: + + .. code-block:: python + + [[[[1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import torch + >>> LTS = torch.tensor([10, 10, 10, 10, 10, 10, 10, 10, 10, 10], dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> UTE = torch.tensor([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> startend_row_indices = torch.cat([LTS, UTE], dim=-1) + >>> print(startend_row_indices) + tensor([[[[10, 0], + [10, 0], + [10, 0], + [10, 0], + [10, 0], + [10, 5], + [10, 6], + [10, 7], + [10, 8], + [10, 9]]]], device='cuda:0', dtype=torch.int32) + + For `QK-sparse Mask`, where `causal=True`, the values of `startend_row_indices` are as follows: + + .. 
code-block:: python + + [[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import torch + >>> LTS = torch.tensor([10, 10, 2, 3, 4, 5, 6, 7, 10, 10], dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> LTE = torch.tensor([10, 10, 5, 5, 5, 5, 8, 8, 10, 10], dtype=torch.int32).reshape(1, 1, 10, 1).cuda() + >>> startend_row_indices = torch.cat([LTS, LTE], dim=-1) + >>> print(startend_row_indices) + tensor([[[[10, 10], + [10, 10], + [ 2, 5], + [ 3, 5], + [ 4, 5], + [ 5, 5], + [ 6, 8], + [ 7, 8], + [10, 10], + [10, 10]]]], device='cuda:0', dtype=torch.int32) + + >>> # doctest: -SKIP + """ + + if window_size is not None: + if isinstance(window_size, int): + window_size = (window_size, window_size) + sq = query.shape[1] + bsz = query.shape[0] + assert startend_row_indices is None, ( + "can't use window_size with startend_row_indices" + ) + + # 关键点:获取输入 tensor 的 device,确保新建的 tensor 在同一设备上 + device = query.device + + if causal: + # paddle.arange -> torch.arange + startend_row_indices = torch.arange( + window_size[0] + 1, sq + window_size[0] + 1, dtype=torch.int32, device=device + ).reshape(1, 1, sq, 1) + + # paddle.clip -> torch.clamp + startend_row_indices = torch.clamp( + startend_row_indices, max=sq + ).repeat_interleave(bsz, dim=0) + + else: + # paddle.empty -> torch.empty + startend_row_indices = torch.empty((1, 1, sq, 2), dtype=torch.int32, device=device) + startend_row_indices[0, 0, :, 0] = torch.arange( + window_size[0] + 1, sq + window_size[0] + 1, dtype=torch.int32, device=device + ) + startend_row_indices[0, 0, :, 1] = torch.arange( + -window_size[1], sq - window_size[1], dtype=torch.int32, device=device + ) + # paddle.clip -> torch.clamp + startend_row_indices = torch.clamp( + startend_row_indices, min=0, max=sq + ).repeat_interleave(bsz, dim=0) + + if block_mask is not None: + # xhy: can set a full startend_row_indices for block_mask_attn when using block_mask_attn? + assert startend_row_indices is not None, ( + "must provide startend_row_indices when using block_mask_attn" + ) + if startend_row_indices is None: + raise ValueError( + "startend_row_indices cannot be None when calling flashmask_attention. " + "This API is dedicated to FlashMask functionality. " + "If you intended to use standard FlashAttention (without sparse mask), " + "please import and use 'flash_attn_func' from the 'flash_attn' library directly." + ) + + else: + assert startend_row_indices.dtype == torch.int32, ( + f"startend_row_indices.dtype must be torch.int32, but got {startend_row_indices.dtype}" + ) + assert len(startend_row_indices.shape) == 4, ( + f"startend_row_indices rank must be 4, but got {startend_row_indices.shape}" + ) + assert startend_row_indices.shape[0] == key.shape[0], ( + f"startend_row_indices.shape[0] must be equal to batch_size, but got {startend_row_indices.shape[0]} and {key.shape[0]}" + ) + assert startend_row_indices.shape[2] == key.shape[1], ( + f"startend_row_indices.shape[2] must be equal to seqlen_k, but got {startend_row_indices.shape[2]} and {key.shape[2]}" + ) + assert startend_row_indices.shape[1] in [1, key.shape[2]], ( + "startend_row_indices head_num must be equal to 1(broadcast) or head_num_k." 
+ ) + + if block_mask is not None: + assert block_mask.dtype == torch.int32, ( + f"block_mask.dtype must be torch.int32, but got {block_mask.dtype}" + ) + assert block_mask.shape[0] == key.shape[0], ( + f"block_mask.shape[0] must be equal to batch_size, but got {block_mask.shape[0]} and {key.shape[0]}" + ) + assert block_mask.shape[1] == startend_row_indices.shape[1], ( + f"block_mask.shape[1] must be equal to startend_row_indices.shape[1], but got {block_mask.shape[1]} and {startend_row_indices.shape[1]}" + ) + assert block_mask.shape[2] == (query.shape[1] + 127) // 128, ( + "block_mask.shape[2] must be (seqlen_q + 127) // 128; only block size 128 is supported by block_mask_attn" + ) + assert block_mask.shape[3] == (key.shape[1] + 127) // 128, ( + "block_mask.shape[3] must be (seqlen_k + 127) // 128; only block size 128 is supported by block_mask_attn" + ) + assert key.shape[3] == 128, ( + "headdim must be 128 when using block_mask_attn" + ) + + if causal: + if startend_row_indices.shape[-1] == 1: + has_end = False + elif startend_row_indices.shape[-1] == 2: + has_end = True + else: + raise ValueError( + f"Invalid shape of startend_row_indices, when causal is True, the last dimension should be either 1 or 2 but got {startend_row_indices.shape[-1]}" + ) + else: + if startend_row_indices.shape[-1] == 2: + has_end = False + elif startend_row_indices.shape[-1] == 4: + has_end = True + else: + raise ValueError( + f"Invalid shape of startend_row_indices, when causal is False, the last dimension should be either 2 or 4 but got {startend_row_indices.shape[-1]}" + ) + + current_device_type = query.device.type + flag_cudnn_deterministic = torch.are_deterministic_algorithms_enabled() + + # TODO(unclear): double-check this version-selection logic + if current_device_type == "cuda": + major, _ = torch.cuda.get_device_capability(query.device) + flag_flash_attn_version = 3 if major >= 9 else 2 + else: + flag_flash_attn_version = 2 + + if ( + "xpu" not in current_device_type + and flag_cudnn_deterministic + ): + assert block_mask is None, ( + "block_mask attention does not support deterministic mode yet." + ) + + if "xpu" in current_device_type: + fa_version = 2 + elif ( + flag_flash_attn_version == 3 + and flag_cudnn_deterministic + and query.shape[3] > 128 + ): + fa_version = 2 + else: + fa_version = flag_flash_attn_version + + if fa_version == 2: + raise NotImplementedError("FlashMask v1 is not supported. 
Please use FlashMask v2.") + + elif fa_version == 3: + assert dropout == 0.0, ( + "flashmask_attention_v2 does not support dropout" + ) + assert not return_seed_offset, ( + "flashmask_attention_v2 does not support return seed_offset" + ) + assert fixed_seed_offset is None, ( + "flashmask_attention_v2 does not support setting seed_offset" + ) + # 在 PyTorch 接口中 rng_name 通常默认为 "" 或 None,保留此检查以确保兼容性 + assert rng_name == "", ( + "flashmask_attention_v2 does not support setting rng_name" + ) + assert training, ( + "flashmask_attention_v2 does not support setting training to False" + ) + + assert name is None, ( + "flashmask_attention_v2 does not support setting name" + ) + + if softmax_scale is None: + softmax_scale = query.shape[-1] ** (-0.5) + + window_size_left = -1 + window_size_right = -1 + softcap = 0.0 + deterministic = flag_cudnn_deterministic + + # 调用 PyTorch 注册的算子 (flashmask_cuda.flashmask_attention_v2) + ( + out, + result_softmax_lse, + ) = FlashMaskFunc.apply( + query, + key, + value, + startend_row_indices, + block_mask, + softmax_scale, + causal, + window_size_left, + window_size_right, + softcap, + deterministic + ) + else: + raise ValueError(f"Invalid flash attention version: {fa_version}") + + outputs = [out] + if return_softmax_lse: + outputs += [result_softmax_lse] + if return_seed_offset: + outputs += [result_seed_offset] + + if len(outputs) == 1: + return outputs[0] + else: + return tuple(outputs) + + \ No newline at end of file diff --git a/csrc/flashmask_v2/setup.py b/csrc/flashmask_v2/setup.py new file mode 100644 index 00000000000..45be66591ff --- /dev/null +++ b/csrc/flashmask_v2/setup.py @@ -0,0 +1,781 @@ +# Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + +import sys +import warnings +import os +import stat +import re +import shutil +import ast +from pathlib import Path +from packaging.version import parse, Version +import platform +import sysconfig +import tarfile +import itertools + +from setuptools import setup, find_packages +import subprocess + +import urllib.request +import urllib.error +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + +import torch +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME + + +# with open("../README.md", "r", encoding="utf-8") as fh: +# with open("../README.md", "r", encoding="utf-8") as fh: +# long_description = fh.read() +long_description="" +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) + +PACKAGE_NAME = "flashmask" + +BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}" + +# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels +# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation +FORCE_BUILD = os.getenv("FLASH_MASK_V2_FORCE_BUILD", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = os.getenv("FLASH_MASK_V2_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI +FORCE_CXX11_ABI = os.getenv("FLASH_MASK_V2_FORCE_CXX11_ABI", "FALSE") == "TRUE" + +DISABLE_BACKWARD = os.getenv("FLASH_MASK_V2_DISABLE_BACKWARD", "FALSE") == "TRUE" +DISABLE_SPLIT = os.getenv("FLASH_MASK_V2_DISABLE_SPLIT", "FALSE") == "TRUE" +DISABLE_PAGEDKV = os.getenv("FLASH_MASK_V2_DISABLE_PAGEDKV", "FALSE") == "TRUE" +DISABLE_APPENDKV = 
os.getenv("FLASH_MASK_V2_DISABLE_APPENDKV", "FALSE") == "TRUE" +DISABLE_LOCAL = os.getenv("FLASH_MASK_V2_DISABLE_LOCAL", "FALSE") == "TRUE" +DISABLE_SOFTCAP = os.getenv("FLASH_MASK_V2_DISABLE_SOFTCAP", "FALSE") == "TRUE" +DISABLE_PACKGQA = os.getenv("FLASH_MASK_V2_DISABLE_PACKGQA", "FALSE") == "TRUE" +DISABLE_FP16 = os.getenv("FLASH_MASK_V2_DISABLE_FP16", "FALSE") == "TRUE" +DISABLE_FP8 = os.getenv("FLASH_MASK_V2_DISABLE_FP8", "FALSE") == "TRUE" +DISABLE_VARLEN = os.getenv("FLASH_MASK_V2_DISABLE_VARLEN", "FALSE") == "TRUE" +DISABLE_CLUSTER = os.getenv("FLASH_MASK_V2_DISABLE_CLUSTER", "FALSE") == "TRUE" +DISABLE_HDIM64 = os.getenv("FLASH_MASK_V2_DISABLE_HDIM64", "FALSE") == "TRUE" +DISABLE_HDIM96 = os.getenv("FLASH_MASK_V2_DISABLE_HDIM96", "FALSE") == "TRUE" +DISABLE_HDIM128 = os.getenv("FLASH_MASK_V2_DISABLE_HDIM128", "FALSE") == "TRUE" +DISABLE_HDIM192 = os.getenv("FLASH_MASK_V2_DISABLE_HDIM192", "FALSE") == "TRUE" +DISABLE_HDIM256 = os.getenv("FLASH_MASK_V2_DISABLE_HDIM256", "FALSE") == "TRUE" +DISABLE_SM8x = os.getenv("FLASH_MASK_V2_DISABLE_SM80", "FALSE") == "TRUE" +DISABLE_SM8X = os.getenv("FLASH_MASK_V2_DISABLE_SM80", "FALSE") == "TRUE" + +ENABLE_VCOLMAJOR = os.getenv("FLASH_MASK_V2_ENABLE_VCOLMAJOR", "FALSE") == "TRUE" + +DISABLE_HDIMDIFF64 = os.getenv("FLASH_MASK_V2_DISABLE_HDIMDIFF64", "FALSE") == "TRUE" +DISABLE_HDIMDIFF192 = os.getenv("FLASH_MASK_V2_DISABLE_HDIMDIFF192", "FALSE") == "TRUE" + +# HACK: we monkey patch pytorch's _write_ninja_file to pass +# "-gencode arch=compute_sm90a,code=sm_90a" to files ending in '_sm90.cu', +# and pass "-gencode arch=compute_sm80,code=sm_80" to files ending in '_sm80.cu' +from torch.utils.cpp_extension import ( + IS_HIP_EXTENSION, + COMMON_HIP_FLAGS, + SUBPROCESS_DECODE_ARGS, + IS_WINDOWS, + get_cxx_compiler, + _join_rocm_home, + _join_cuda_home, + _is_cuda_file, + _maybe_write, +) + +def create_build_config_file(): + CONFIG = { + "build_flags": { + "FLASHMASK_V2_DISABLE_BACKWARD": DISABLE_BACKWARD, + "FLASHMASK_V2_DISABLE_SPLIT": DISABLE_SPLIT, + "FLASHMASK_V2_DISABLE_PAGEDKV": DISABLE_PAGEDKV, + "FLASHMASK_V2_DISABLE_APPENDKV": DISABLE_APPENDKV, + "FLASHMASK_V2_DISABLE_LOCAL": DISABLE_LOCAL, + "FLASHMASK_V2_DISABLE_SOFTCAP": DISABLE_SOFTCAP, + "FLASHMASK_V2_DISABLE_PACKGQA": DISABLE_PACKGQA, + "FLASHMASK_V2_DISABLE_FP16": DISABLE_FP16, + "FLASHMASK_V2_DISABLE_FP8": DISABLE_FP8, + "FLASHMASK_V2_DISABLE_VARLEN": DISABLE_VARLEN, + "FLASHMASK_V2_DISABLE_CLUSTER": DISABLE_CLUSTER, + "FLASHMASK_V2_DISABLE_HDIM64": DISABLE_HDIM64, + "FLASHMASK_V2_DISABLE_HDIM96": DISABLE_HDIM96, + "FLASHMASK_V2_DISABLE_HDIM128": DISABLE_HDIM128, + "FLASHMASK_V2_DISABLE_HDIM192": DISABLE_HDIM192, + "FLASHMASK_V2_DISABLE_HDIM256": DISABLE_HDIM256, + "FLASHMASK_V2_DISABLE_SM8x": DISABLE_SM8x, + "FLASHMASK_V2_DISABLE_SM8X": DISABLE_SM8X, + "FLASHMASK_V2_ENABLE_VCOLMAJOR": ENABLE_VCOLMAJOR, + "FLASH_MASK_V2_DISABLE_HDIMDIFF64": DISABLE_HDIMDIFF64, + "FLASH_MASK_V2_DISABLE_HDIMDIFF192": DISABLE_HDIMDIFF192, + } + } + + with open("flash_attn_config.py", "w") as f: + f.write("# Auto-generated by flash attention 3 setup.py\n") + f.write(f"CONFIG = {repr(CONFIG)}\n") + f.write("\n") + + f.write("def show():\n") + f.write(" from pprint import pprint\n") + f.write(" pprint(CONFIG)\n") + f.write("\n") + +def _write_ninja_file(path, + cflags, + post_cflags, + cuda_cflags, + cuda_post_cflags, + cuda_dlink_post_cflags, + sources, + objects, + ldflags, + library_target, + with_cuda, + **kwargs, # kwargs (ignored) to absorb new flags in torch.utils.cpp_extension + ) -> None: + 
r"""Write a ninja file that does the desired compiling and linking. + + `path`: Where to write this file + `cflags`: list of flags to pass to $cxx. Can be None. + `post_cflags`: list of flags to append to the $cxx invocation. Can be None. + `cuda_cflags`: list of flags to pass to $nvcc. Can be None. + `cuda_postflags`: list of flags to append to the $nvcc invocation. Can be None. + `sources`: list of paths to source files + `objects`: list of desired paths to objects, one per source. + `ldflags`: list of flags to pass to linker. Can be None. + `library_target`: Name of the output library. Can be None; in that case, + we do no linking. + `with_cuda`: If we should be compiling with CUDA. + """ + def sanitize_flags(flags): + if flags is None: + return [] + else: + return [flag.strip() for flag in flags] + + cflags = sanitize_flags(cflags) + post_cflags = sanitize_flags(post_cflags) + cuda_cflags = sanitize_flags(cuda_cflags) + cuda_post_cflags = sanitize_flags(cuda_post_cflags) + cuda_dlink_post_cflags = sanitize_flags(cuda_dlink_post_cflags) + ldflags = sanitize_flags(ldflags) + + # Sanity checks... + assert len(sources) == len(objects) + assert len(sources) > 0 + + compiler = get_cxx_compiler() + + # Version 1.3 is required for the `deps` directive. + config = ['ninja_required_version = 1.3'] + config.append(f'cxx = {compiler}') + if with_cuda or cuda_dlink_post_cflags: + if IS_HIP_EXTENSION: + nvcc = _join_rocm_home('bin', 'hipcc') + else: + nvcc = _join_cuda_home('bin', 'nvcc') + if "PYTORCH_NVCC" in os.environ: + nvcc_from_env = os.getenv("PYTORCH_NVCC") # user can set nvcc compiler with ccache using the environment variable here + else: + nvcc_from_env = nvcc + config.append(f'nvcc_from_env = {nvcc_from_env}') + config.append(f'nvcc = {nvcc}') + + if IS_HIP_EXTENSION: + post_cflags = COMMON_HIP_FLAGS + post_cflags + flags = [f'cflags = {" ".join(cflags)}'] + flags.append(f'post_cflags = {" ".join(post_cflags)}') + if with_cuda: + flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}') + flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}') + cuda_post_cflags_sm80 = [s if s != 'arch=compute_90a,code=sm_90a' else 'arch=compute_80,code=sm_80' for s in cuda_post_cflags] + flags.append(f'cuda_post_cflags_sm80 = {" ".join(cuda_post_cflags_sm80)}') + cuda_post_cflags_sm80_sm90 = cuda_post_cflags + ['-gencode', 'arch=compute_80,code=sm_80'] + flags.append(f'cuda_post_cflags_sm80_sm90 = {" ".join(cuda_post_cflags_sm80_sm90)}') + cuda_post_cflags_sm100 = [s if s != 'arch=compute_90a,code=sm_90a' else 'arch=compute_100a,code=sm_100a' for s in cuda_post_cflags] + flags.append(f'cuda_post_cflags_sm100 = {" ".join(cuda_post_cflags_sm100)}') + flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}') + flags.append(f'ldflags = {" ".join(ldflags)}') + + # Turn into absolute paths so we can emit them into the ninja build + # file wherever it is. + sources = [os.path.abspath(file) for file in sources] + + # See https://ninja-build.org/build.ninja.html for reference. 
+ compile_rule = ['rule compile'] + if IS_WINDOWS: + compile_rule.append( + ' command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags') + compile_rule.append(' deps = msvc') + else: + compile_rule.append( + ' command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags') + compile_rule.append(' depfile = $out.d') + compile_rule.append(' deps = gcc') + + if with_cuda: + cuda_compile_rule = ['rule cuda_compile'] + nvcc_gendeps = '' + # --generate-dependencies-with-compile is not supported by ROCm + # Nvcc flag `--generate-dependencies-with-compile` is not supported by sccache, which may increase build time. + if torch.version.cuda is not None and os.getenv('TORCH_EXTENSION_SKIP_NVCC_GEN_DEPENDENCIES', '0') != '1': + cuda_compile_rule.append(' depfile = $out.d') + cuda_compile_rule.append(' deps = gcc') + # Note: non-system deps with nvcc are only supported + # on Linux so use --generate-dependencies-with-compile + # to make this work on Windows too. + nvcc_gendeps = '--generate-dependencies-with-compile --dependency-output $out.d' + cuda_compile_rule_sm80 = ['rule cuda_compile_sm80'] + cuda_compile_rule[1:] + [ + f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80' + ] + cuda_compile_rule_sm80_sm90 = ['rule cuda_compile_sm80_sm90'] + cuda_compile_rule[1:] + [ + f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80_sm90' + ] + cuda_compile_rule_sm100 = ['rule cuda_compile_sm100'] + cuda_compile_rule[1:] + [ + f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm100' + ] + cuda_compile_rule.append( + f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags') + + # Emit one build rule per source to enable incremental build. 
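+ # Illustrative example (object paths are placeholders): an instantiation such as
+ # flash_bwd_hdim128_bf16_sm90.cu ends in '_sm90.cu' and keeps the default
+ # 'cuda_compile' rule, while a hypothetical flash_fwd_hdim128_bf16_sm80.cu would
+ # select 'cuda_compile_sm80', yielding statements roughly of the form:
+ #   build <objdir>/flash_bwd_hdim128_bf16_sm90.o: cuda_compile <srcdir>/instantiations/flash_bwd_hdim128_bf16_sm90.cu
+ #   build <objdir>/flash_fwd_hdim128_bf16_sm80.o: cuda_compile_sm80 <srcdir>/instantiations/flash_fwd_hdim128_bf16_sm80.cu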
+ build = [] + for source_file, object_file in zip(sources, objects): + is_cuda_source = _is_cuda_file(source_file) and with_cuda + if is_cuda_source: + if source_file.endswith('_sm90.cu'): + rule = 'cuda_compile' + elif source_file.endswith('_sm80.cu'): + rule = 'cuda_compile_sm80' + elif source_file.endswith('_sm100.cu'): + rule = 'cuda_compile_sm100' + else: + rule = 'cuda_compile_sm80_sm90' + else: + rule = 'compile' + if IS_WINDOWS: + source_file = source_file.replace(':', '$:') + object_file = object_file.replace(':', '$:') + source_file = source_file.replace(" ", "$ ") + object_file = object_file.replace(" ", "$ ") + build.append(f'build {object_file}: {rule} {source_file}') + + if cuda_dlink_post_cflags: + devlink_out = os.path.join(os.path.dirname(objects[0]), 'dlink.o') + devlink_rule = ['rule cuda_devlink'] + devlink_rule.append(' command = $nvcc $in -o $out $cuda_dlink_post_cflags') + devlink = [f'build {devlink_out}: cuda_devlink {" ".join(objects)}'] + objects += [devlink_out] + else: + devlink_rule, devlink = [], [] + + if library_target is not None: + link_rule = ['rule link'] + if IS_WINDOWS: + cl_paths = subprocess.check_output(['where', + 'cl']).decode(*SUBPROCESS_DECODE_ARGS).split('\r\n') + if len(cl_paths) >= 1: + cl_path = os.path.dirname(cl_paths[0]).replace(':', '$:') + else: + raise RuntimeError("MSVC is required to load C++ extensions") + link_rule.append(f' command = "{cl_path}/link.exe" $in /nologo $ldflags /out:$out') + else: + link_rule.append(' command = $cxx $in $ldflags -o $out') + + link = [f'build {library_target}: link {" ".join(objects)}'] + + default = [f'default {library_target}'] + else: + link_rule, link, default = [], [], [] + + # 'Blocks' should be separated by newlines, for visual benefit. + blocks = [config, flags, compile_rule] + if with_cuda: + blocks.append(cuda_compile_rule) # type: ignore[possibly-undefined] + blocks.append(cuda_compile_rule_sm80) # type: ignore[possibly-undefined] + blocks.append(cuda_compile_rule_sm80_sm90) # type: ignore[possibly-undefined] + blocks.append(cuda_compile_rule_sm100) # type: ignore[possibly-undefined] + blocks += [devlink_rule, link_rule, build, devlink, link, default] + content = "\n\n".join("\n".join(b) for b in blocks) + # Ninja requires a new lines at the end of the .ninja file + content += "\n" + _maybe_write(path, content) + + +# Monkey patching +torch.utils.cpp_extension._write_ninja_file = _write_ninja_file + + +def get_platform(): + """ + Returns the platform name as used in wheel filenames. + """ + if sys.platform.startswith("linux"): + return "linux_x86_64" + elif sys.platform == "darwin": + mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) + return f"macosx_{mac_version}_x86_64" + elif sys.platform == "win32": + return "win_amd64" + else: + raise ValueError("Unsupported platform: {}".format(sys.platform)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary + # in that case. + warnings.warn( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? 
" + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py +def check_env_flag(name: str, default: str = "") -> bool: + return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"] + +# ===== flashmask_v2 feature toggles (对应 CMake 里的 option) ===== +# DISABLE_FP16 = check_env_flag("DISABLE_FP16", "ON") +# DISABLE_FP8 = check_env_flag("DISABLE_FP8", "ON") +# DISABLE_HDIM64 = check_env_flag("DISABLE_HDIM64", "OFF") +# DISABLE_HDIM96 = check_env_flag("DISABLE_HDIM96", "ON") +# DISABLE_HDIM128 = check_env_flag("DISABLE_HDIM128", "OFF") +# DISABLE_HDIM192 = check_env_flag("DISABLE_HDIM192", "ON") +# DISABLE_HDIM256 = check_env_flag("DISABLE_HDIM256", "OFF") +# DISABLE_SPLIT = check_env_flag("DISABLE_SPLIT", "ON") +# DISABLE_PAGEDKV = check_env_flag("DISABLE_PAGEDKV", "ON") +# DISABLE_SOFTCAP = check_env_flag("DISABLE_SOFTCAP", "ON") +# DISABLE_PACKGQA = check_env_flag("DISABLE_PACKGQA", "ON") +# DISABLE_BACKWARD= check_env_flag("DISABLE_BACKWARD", "OFF") +# DISABLE_SM8x = check_env_flag("DISABLE_SM8x", "ON") +# DISABLE_SM8X = check_env_flag("DISABLE_SM8X", "ON") + +DISABLE_FP16 = True # 禁用 FP16 +DISABLE_FP8 = True # 禁用 FP8 +DISABLE_SM8x = True # 禁用 SM8x +DISABLE_SM8X = True +DISABLE_BACKWARD= False # 禁用 BWD (为了解决报错) + +DISABLE_HDIM96 = True # 禁用 96 +DISABLE_HDIM192 = True # 禁用 192 +DISABLE_SPLIT = True # 禁用 Split +DISABLE_PAGEDKV = True # 禁用 PagedKV +DISABLE_SOFTCAP = True # 禁用 Softcap +DISABLE_PACKGQA = True # 禁用 PackGQA + +# 2. 想要启用的 (设为 False = 不禁用) +DISABLE_HDIM64 = False # 启用 64 +DISABLE_HDIM128 = False # 启用 128 +DISABLE_HDIM256 = False # 启用 256 +# ================================================================ + + +# Copied from https://github.com/triton-lang/triton/blob/main/python/setup.py +def is_offline_build() -> bool: + """ + Downstream projects and distributions which bootstrap their own dependencies from scratch + and run builds in offline sandboxes + may set `FLASH_MASK_V2_OFFLINE_BUILD` in the build environment to prevent any attempts at downloading + pinned dependencies from the internet or at using dependencies vendored in-tree. + + Dependencies must be defined using respective search paths (cf. `syspath_var_name` in `Package`). + Missing dependencies lead to an early abortion. + Dependencies' compatibility is not verified. + + Note that this flag isn't tested by the CI and does not provide any guarantees. + """ + return check_env_flag("FLASH_MASK_V2_OFFLINE_BUILD", "") + + +# Copied from https://github.com/triton-lang/triton/blob/main/python/setup.py +def get_flashattn_cache_path(): + user_home = os.getenv("FLASH_MASK_V2_HOME") + if not user_home: + user_home = os.getenv("HOME") or os.getenv("USERPROFILE") or os.getenv("HOMEPATH") or None + if not user_home: + raise RuntimeError("Could not find user home directory") + return os.path.join(user_home, ".flashattn") + + +def open_url(url): + user_agent = 'Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/119.0' + headers = { + 'User-Agent': user_agent, + } + request = urllib.request.Request(url, None, headers) + # Set timeout to 300 seconds to prevent the request from hanging forever. 
+ return urllib.request.urlopen(request, timeout=300) + + +def download_and_copy(name, src_func, dst_path, version, url_func): + if is_offline_build(): + return + flashattn_cache_path = get_flashattn_cache_path() + base_dir = os.path.dirname(__file__) + system = platform.system() + arch = platform.machine() + arch = {"arm64": "aarch64"}.get(arch, arch) + supported = {"Linux": "linux", "Darwin": "linux"} + url = url_func(supported[system], arch, version) + src_path = src_func(supported[system], arch, version) + tmp_path = os.path.join(flashattn_cache_path, "nvidia", name) # path to cache the download + dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path) # final binary path + src_path = os.path.join(tmp_path, src_path) + download = not os.path.exists(src_path) + if download: + print(f'downloading and extracting {url} ...') + file = tarfile.open(fileobj=open_url(url), mode="r|*") + file.extractall(path=tmp_path) + os.makedirs(os.path.split(dst_path)[0], exist_ok=True) + print(f'copy {src_path} to {dst_path} ...') + if os.path.isdir(src_path): + shutil.copytree(src_path, dst_path, dirs_exist_ok=True) + else: + shutil.copy(src_path, dst_path) + + +def nvcc_threads_args(): + nvcc_threads = os.getenv("NVCC_THREADS") or "2" + return ["--threads", nvcc_threads] + + +# NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"} +NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.6.85", "ptxas": "12.8.93"} + +exe_extension = sysconfig.get_config_var("EXE") + + +cmdclass = {} +ext_modules = [] + +# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp +# files included in the source distribution, in case the user compiles from source. +# subprocess.run(["git", "submodule", "update", "--init", "../csrc/cutlass"]) +subprocess.run(["git", "submodule", "update", "--init", "cutlass"]) + +if not SKIP_CUDA_BUILD: + print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + + create_build_config_file() + check_if_cuda_home_none(PACKAGE_NAME) + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version < Version("12.3"): + raise RuntimeError("FlashAttention-3 is only supported on CUDA 12.3 and above") + elif bare_metal_version >= Version("13.0"): + # CUDA 13.0+ uses system nvcc and CCCL headers are in /usr/local/cuda/include/cccl/ + cccl_include = os.path.join(CUDA_HOME, "include", "cccl") + for env_var in ["CPLUS_INCLUDE_PATH", "C_INCLUDE_PATH"]: + current = os.environ.get(env_var, "") + os.environ[env_var] = cccl_include + (":" + current if current else "") + + # ptxas 12.8 gives the best perf currently + # We want to use the nvcc front end from 12.6 however, since if we use nvcc 12.8 + # Cutlass 3.8 will expect the new data types in cuda.h from CTK 12.8, which we don't have. 
+ # For CUDA 13.0+, use system nvcc instead of downloading CUDA 12.x toolchain + if bare_metal_version >= Version("12.3") and bare_metal_version < Version("13.0") and bare_metal_version != Version("12.8"): + download_and_copy( + name="nvcc", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin", + dst_path="bin", + version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", + ) + download_and_copy( + name="ptxas", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas", + dst_path="bin", + version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", + ) + download_and_copy( + name="ptxas", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/nvvm/bin", + dst_path="nvvm/bin", + version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", + ) + base_dir = os.path.dirname(__file__) + ctk_path_new = os.path.abspath(os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", "bin")) + nvcc_path_new = os.path.join(ctk_path_new, f"nvcc{exe_extension}") + # Need to append to path otherwise nvcc can't find cicc in nvvm/bin/cicc + # nvcc 12.8 seems to hard-code looking for cicc in ../nvvm/bin/cicc + os.environ["PATH"] = ctk_path_new + os.pathsep + os.environ["PATH"] + os.environ["PYTORCH_NVCC"] = nvcc_path_new + # Make nvcc executable, sometimes after the copy it loses its permissions + os.chmod(nvcc_path_new, os.stat(nvcc_path_new).st_mode | stat.S_IEXEC) + + cc_flag = [] + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90a,code=sm_90a") + + # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as + # torch._C._GLIBCXX_USE_CXX11_ABI + # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 + if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True + # repo_dir = Path(this_dir).parent + cutlass_dir = Path(this_dir) / "cutlass" + + feature_args = ( + [] + + (["-DFLASHMASK_V2_DISABLE_BACKWARD" ,"-DFLASHATTENTION_DISABLE_BACKWARD"] if DISABLE_BACKWARD else []) + + (["-DFLASHMASK_V2_DISABLE_PAGEDKV" ,"-DFLASHATTENTION_DISABLE_PAGEDKV"] if DISABLE_PAGEDKV else []) + + (["-DFLASHMASK_V2_DISABLE_SPLIT" ,"-DFLASHATTENTION_DISABLE_SPLIT"] if DISABLE_SPLIT else []) + + (["-DFLASHMASK_V2_DISABLE_APPENDKV" ,"-DFLASHATTENTION_DISABLE_APPENDKV"] if DISABLE_APPENDKV else []) + + (["-DFLASHMASK_V2_DISABLE_LOCAL" ,"-DFLASHATTENTION_DISABLE_LOCAL"] if DISABLE_LOCAL else []) + + (["-DFLASHMASK_V2_DISABLE_SOFTCAP" ,"-DFLASHATTENTION_DISABLE_SOFTCAP"] if DISABLE_SOFTCAP else []) + + (["-DFLASHMASK_V2_DISABLE_PACKGQA" ,"-DFLASHATTENTION_DISABLE_PACKGQA"] if DISABLE_PACKGQA else []) + + (["-DFLASHMASK_V2_DISABLE_FP16" ,"-DFLASHATTENTION_DISABLE_FP16"] if DISABLE_FP16 else []) + + (["-DFLASHMASK_V2_DISABLE_FP8" ,"-DFLASHATTENTION_DISABLE_FP8"] if DISABLE_FP8 else []) + + (["-DFLASHMASK_V2_DISABLE_VARLEN" ,"-DFLASHATTENTION_DISABLE_VARLEN"] if DISABLE_VARLEN else []) + + (["-DFLASHMASK_V2_DISABLE_CLUSTER" 
,"-DFLASHATTENTION_DISABLE_CLUSTER"] if DISABLE_CLUSTER else []) + + (["-DFLASHMASK_V2_DISABLE_HDIM64" ,"-DFLASHATTENTION_DISABLE_HDIM64"] if DISABLE_HDIM64 else []) + + (["-DFLASHMASK_V2_DISABLE_HDIM96" ,"-DFLASHATTENTION_DISABLE_HDIM96"] if DISABLE_HDIM96 else []) + + (["-DFLASHMASK_V2_DISABLE_HDIM128" ,"-DFLASHATTENTION_DISABLE_HDIM128"] if DISABLE_HDIM128 else []) + + (["-DFLASHMASK_V2_DISABLE_HDIM192" ,"-DFLASHATTENTION_DISABLE_HDIM192"] if DISABLE_HDIM192 else []) + + (["-DFLASHMASK_V2_DISABLE_HDIM256" ,"-DFLASHATTENTION_DISABLE_HDIM256"] if DISABLE_HDIM256 else []) + + (["-DFLASHMASK_V2_DISABLE_SM8x" ,"-DFLASHATTENTION_DISABLE_SM8x"] if DISABLE_SM8x else []) + + (["-DFLASHMASK_V2_DISABLE_SM8X" ,"-DFLASHATTENTION_DISABLE_SM8X"] if DISABLE_SM8X else []) + + (["-DFLASHMASK_V2_ENABLE_VCOLMAJOR" ,"-DFLASHATTENTION_ENABLE_VCOLMAJOR"] if ENABLE_VCOLMAJOR else []) + + (["-DFLASHMASK_V2_DISABLE_HDIMDIFF64" ,"-DFLASHATTENTION_DISABLE_HDIMDIFF64"] if DISABLE_HDIMDIFF64 else []) + + (["-DFLASHMASK_V2_DISABLE_HDIMDIFF192","-DFLASHATTENTION_DISABLE_HDIMDIFF192"] if DISABLE_HDIMDIFF192 else []) + ) + + DTYPE_FWD_SM80 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) + DTYPE_FWD_SM90 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) + (["e4m3"] if not DISABLE_FP8 else []) + HALF_DTYPE_FWD_SM90 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) + DTYPE_BWD = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) + HEAD_DIMENSIONS_BWD = ( + [] + + ([64] if not DISABLE_HDIM64 else []) + # + ([96] if not DISABLE_HDIM96 else []) + + ([128] if not DISABLE_HDIM128 else []) + # + ([192] if not DISABLE_HDIM192 else []) + + ([256] if not DISABLE_HDIM256 else []) + ) + # build will now explode with this compilation grouping given all our templating + # HEAD_DIMENSIONS_FWD = ["all", "diff"] + HEAD_DIMENSIONS_FWD = HEAD_DIMENSIONS_BWD + HEAD_DIMENSIONS_DIFF64_FWD = ( + [] + # + (["64_256"] if not DISABLE_HDIMDIFF64 else []) + # + (["64_512"] if not DISABLE_HDIMDIFF64 else []) + ) + HEAD_DIMENSIONS_DIFF192_FWD = ( + [] + + (["192_128"] if not DISABLE_HDIMDIFF192 else []) + ) + HEAD_DIMENSIONS_FWD_SM80 = HEAD_DIMENSIONS_BWD + SPLIT = [""] + (["_split"] if not DISABLE_SPLIT else []) + PAGEDKV = [""] + (["_paged"] if not DISABLE_PAGEDKV else []) + SOFTCAP = [""] + (["_softcap"] if not DISABLE_SOFTCAP else []) + SOFTCAP_ALL = [""] if DISABLE_SOFTCAP else ["_softcapall"] + PACKGQA = [""] + (["_packgqa"] if not DISABLE_PACKGQA else []) + # We already always hard-code PackGQA=true for Sm8x + sources_fwd_sm80 = [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}_sm80.cu" + for hdim, dtype, split, paged, softcap in itertools.product(HEAD_DIMENSIONS_FWD_SM80, DTYPE_FWD_SM80, SPLIT, PAGEDKV, SOFTCAP_ALL)] + # We already always hard-code PackGQA=true for Sm9x if PagedKV or Split + sources_fwd_sm90 = [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" + for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) + if not (packgqa and (paged or split))] + if not DISABLE_HDIMDIFF64: + sources_fwd_sm90 += [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" + for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF64_FWD, HALF_DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) + if not (packgqa and (paged or split))] + if not DISABLE_HDIMDIFF192: + sources_fwd_sm90 += 
[f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" + for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF192_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) + if not (packgqa and (paged or split))] + sources_bwd_sm80 = [f"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm80.cu" + for hdim, dtype, softcap in itertools.product(HEAD_DIMENSIONS_BWD, DTYPE_BWD, SOFTCAP)] + CAUSAL_FLAGS = ["", "_causal"] + DETERM_FLAGS = ["", "_determ"] + sources_bwd_sm90 = [f"instantiations/flash_bwd_hdim{hdim}_{dtype}{causal}{determ}{softcap}_sm90.cu" + for hdim, dtype, causal, determ, softcap in + itertools.product(HEAD_DIMENSIONS_BWD, DTYPE_BWD, CAUSAL_FLAGS, DETERM_FLAGS, SOFTCAP_ALL)] + if DISABLE_BACKWARD: + sources_bwd_sm90 = [] + sources_bwd_sm80 = [] + + # Choose between flash_api.cpp and flash_api_stable.cpp based on torch version + torch_version = parse(torch.__version__) + target_version = parse("2.9.0.dev20250830") + stable_args = [] + + if torch_version >= target_version: + flash_api_source = "flash_api_stable.cpp" + stable_args = ["-DTORCH_STABLE_ONLY"] # Checks against including unstable Tensor APIs + else: + flash_api_source = "flash_api.cpp" + + sources = ( + [flash_api_source,"flash_api_cuda.cu"] + + (sources_fwd_sm80 if not DISABLE_SM8x else []) + sources_fwd_sm90 + + (sources_bwd_sm80 if not DISABLE_SM8x else []) + sources_bwd_sm90 + ) + sources = [s for s in sources if "hdim64_512" not in s] + if not DISABLE_SPLIT: + sources += ["flash_fwd_combine.cu"] + sources += ["flash_prepare_scheduler.cu"] + nvcc_flags = [ + "-O3", + "-std=c++17", + "--ftemplate-backtrace-limit=0", # To debug template code + "--use_fast_math", + # "--keep", + # "--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage", # printing out number of registers + "--resource-usage", # printing out number of registers + # f"--split-compile={os.getenv('NVCC_THREADS', '4')}", # split-compile is faster + "-lineinfo", # TODO: disable this for release to reduce binary size + "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", # Necessary for the WGMMA shapes that we use + "-DCUTLASS_ENABLE_GDC_FOR_SM90", # For PDL + "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging + "-DNDEBUG", # Important, otherwise performance is severely impacted + ] + if get_platform() == "win_amd64": + nvcc_flags.extend( + [ + "-D_USE_MATH_DEFINES", # for M_LN2 + "-Xcompiler=/Zc:__cplusplus", # sets __cplusplus correctly, CUTLASS_CONSTEXPR_IF_CXX17 needed for cutlass::gcd + ] + ) + include_dirs = [ + Path(this_dir), + cutlass_dir / "include", + ] + + ext_modules.append( + CUDAExtension( + name=f"{PACKAGE_NAME}._C", + sources=sources, + extra_compile_args={ + "cxx": ["-O3", "-std=c++17", "-DPy_LIMITED_API=0x03090000"] + stable_args + feature_args, + "nvcc": nvcc_threads_args() + nvcc_flags + cc_flag + feature_args, + }, + include_dirs=include_dirs, + py_limited_api=True, + ) + ) + + +def get_package_version(): + with open(Path(this_dir) / "__init__.py", "r") as f: + version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) + public_version = ast.literal_eval(version_match.group(1)) + local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION") + if local_version: + return f"{public_version}+{local_version}" + else: + return str(public_version) + + +def get_wheel_url(): + # Determine the version numbers that will be used to determine the correct wheel + # We're using the CUDA version used to build torch, not the one currently 
+
+
+def get_wheel_url():
+    # Determine the version numbers that will be used to pick the correct wheel.
+    # We're using the CUDA version used to build torch, not the one currently installed
+    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
+    torch_cuda_version = parse(torch.version.cuda)
+    torch_version_raw = parse(torch.__version__)
+    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
+    # to save CI time. Minor versions should be compatible.
+    torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
+    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
+    platform_name = get_platform()
+    package_version = get_package_version()
+    # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
+    cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
+    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
+    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
+
+    # Determine the wheel URL based on the CUDA version, torch version, python version and OS
+    wheel_filename = f"{PACKAGE_NAME}-{package_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
+    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{package_version}", wheel_name=wheel_filename)
+    return wheel_url, wheel_filename
+
+
+class CachedWheelsCommand(_bdist_wheel):
+    """
+    The CachedWheelsCommand plugs into the default bdist_wheel, which is run by pip when it cannot
+    find an existing wheel (which is currently the case for all installs). We use
+    the environment parameters to detect whether there is already a pre-built version of a compatible
+    wheel available and, if so, short-circuit the standard full build pipeline.
+    """
+
+    def run(self):
+        if FORCE_BUILD:
+            return super().run()
+
+        wheel_url, wheel_filename = get_wheel_url()
+        print("Guessing wheel URL: ", wheel_url)
+        try:
+            urllib.request.urlretrieve(wheel_url, wheel_filename)
+
+            # Make the archive
+            # Lifted from the root wheel processing command
+            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
+            if not os.path.exists(self.dist_dir):
+                os.makedirs(self.dist_dir)
+
+            impl_tag, abi_tag, plat_tag = self.get_tag()
+            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
+
+            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
+            print("Raw wheel path", wheel_path)
+            shutil.move(wheel_filename, wheel_path)
+        except urllib.error.HTTPError:
+            print("Precompiled wheel not found. Building from source...")
+            # If the wheel could not be downloaded, build from source
+            super().run()
+
+
+setup(
+    name=PACKAGE_NAME,
+    version=get_package_version(),
+    packages=find_packages(
+        exclude=(
+            "build",
+            "csrc",
+            "include",
+            "tests",
+            "dist",
+            "docs",
+            "benchmarks",
+        )
+    ),
+    py_modules=["flashmask_interface", "flash_attn_config"],
+    description="FlashAttention-3",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: Apache Software License",
+        "Operating System :: Unix",
+    ],
+    ext_modules=ext_modules,
+    cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
+    if ext_modules
+    else {
+        "bdist_wheel": CachedWheelsCommand,
+    },
+    python_requires=">=3.8",
+    install_requires=[
+        "torch",
+        "einops",
+        "packaging",
+        "ninja",
+    ],
+    options={"bdist_wheel": {"py_limited_api": "cp39"}},
+)
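For reference, this is roughly the wheel filename that get_wheel_url() assembles and that CachedWheelsCommand then tries to download before falling back to a source build. All values below are made up, and "mypkg" stands in for whatever PACKAGE_NAME is set to earlier in this file.

package_version = "1.2.3"
cuda_version = "122"            # from the torch CUDA version, e.g. 12.2
torch_version = "2.4"           # major.minor of torch.__version__
cxx11_abi = "TRUE"              # str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
python_version = "cp310"
platform_name = "linux_x86_64"

wheel_filename = (
    f"mypkg-{package_version}+cu{cuda_version}torch{torch_version}"
    f"cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
)
print(wheel_filename)
# mypkg-1.2.3+cu122torch2.4cxx11abiTRUE-cp310-cp310-linux_x86_64.whl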