csrc/flashmask_v2/flash_bwd_launch_template.h (16 additions & 23 deletions)

@@ -26,7 +26,7 @@ using namespace cute;
 template <int Arch, int kHeadDim, int kBlockM, int kBlockN, typename Element,
           bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool Deterministic, bool GQA,
-          bool Is_flashmask, bool Has_lt_end, bool Has_ut_start,
+          bool Has_lt_end, bool Has_ut_start,
           int Stages_dO=2, int Stages_dS_or_QSm80=2,
           bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
           int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
@@ -94,7 +94,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
         Arch >= 90,
         flash::CollectiveMainloopBwdSm90<Stages, Stages_dO, Stages_dS, ClusterShape, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm90,
                                          Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
-                                         SdP_swapAB, dKV_swapAB, dQ_swapAB, Is_flashmask, Has_lt_end, Has_ut_start, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>,
+                                         SdP_swapAB, dKV_swapAB, dQ_swapAB, Has_lt_end, Has_ut_start, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>,
         flash::CollectiveMainloopBwdSm80<Stages, Stages_dO, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm80,
                                          Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
                                          SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>
@@ -115,9 +115,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
         flash::enable_sm80_to_sm89<flash::FlashAttnBwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
     >;

-    if constexpr (Is_flashmask) {
-        flash::flashmask::prepare_block_maxmin<kBlockN>(params, stream);
-    }
+    flash::flashmask::prepare_block_maxmin<kBlockN>(params, stream);

     if constexpr (Arch >= 90) {
         prepare_preemptive_scheduler(params, stream, params.num_sm);
@@ -351,7 +349,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
 }

 template<int Arch, typename T, int kBlockM, int kBlockN, int kHeadDim, bool Is_causal, bool Is_local, bool Has_softcap,
-         bool Is_flashmask_, bool Has_lt_end_, bool Has_ut_start_,
+         bool Has_lt_end_, bool Has_ut_start_,
          int Stages_dO=2, int Stages_dS_or_QSm80=2,
          bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
          int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
@@ -361,7 +359,7 @@ void run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream) {
     BOOL_SWITCH(params.h != params.h_k, GQA, [&] {
         // BOOL_SWITCH(params.deterministic, Deterministic, [&] {
             // run_flash_bwd<kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen, false, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>(params, stream);
-            run_flash_bwd<Arch, kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen /*Varlen*/, false /*Deterministic*/, GQA, Is_flashmask_, Has_lt_end_, Has_ut_start_, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>(params, stream);
+            run_flash_bwd<Arch, kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen /*Varlen*/, false /*Deterministic*/, GQA, Has_lt_end_, Has_ut_start_, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>(params, stream);
         // });
     });
 });
@@ -372,25 +370,21 @@ template<int Arch, typename T, bool Has_softcap, bool Is_causal>
 void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
     // printf("point2-1\n");
     static constexpr bool Is_local = false;
-    static constexpr bool Is_flashmask_ = true;
     FLASH_MASK_SWITCH(params.lt_end_ptr != nullptr, params.ut_start_ptr != nullptr, Has_lt_end, Has_ut_start, [&] {
         if constexpr (Arch >= 90) {
-            if constexpr (Is_flashmask_ && !Is_causal) {
-                run_mha_bwd_dispatch<Arch, T, 64, 96, 64, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, 2, 2, false, true, false, 2, 1, 2, 1, false>(params, stream);
-            } else if constexpr (Is_causal && Has_softcap || Is_flashmask_) {
-                // register spill with 128 x 128
-                run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, 2, 2, true, false, true, 2, 1, 2, 2, false>(params, stream);
+            if constexpr (!Is_causal) {
+                run_mha_bwd_dispatch<Arch, T, 64, 96, 64, Is_causal, Is_local, Has_softcap, Has_lt_end, Has_ut_start, 2, 2, false, true, false, 2, 1, 2, 1, false>(params, stream);
             } else {
-                // With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block.
-                run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, 2, 2, true, false, false, 2, 1, 2, 2, false>(params, stream);
+                // register spill with 128 x 128
+                run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, Has_lt_end, Has_ut_start, 2, 2, true, false, true, 2, 1, 2, 2, false>(params, stream);
             }
         } else if constexpr (Arch == 86 || Arch == 89) {
-            run_mha_bwd_dispatch<Arch, T, 64, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, true, Is_flashmask_>(params, stream);
+            run_mha_bwd_dispatch<Arch, T, 64, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, true, true>(params, stream);
             // run_mha_bwd_dispatch<Arch, T, 96, 96, 64, Is_causal, Is_local, Has_softcap, 1, 2, false, true, true, 2, 2, 4, 4, false>(params, stream);
             // run_mha_bwd_dispatch<Arch, T, 80, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 2, 4, 2, true>(params, stream);
             // run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 1, 8, 4, false>(params, stream);
         } else {
-            run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 4, 4, 4, false, Is_flashmask_>(params, stream);
+            run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 4, 4, 4, false, true>(params, stream);
         }
     });
 }
@@ -413,22 +407,21 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
 template<int Arch, typename T, bool Has_softcap, bool Is_causal>
 void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
     static constexpr bool Is_local = false;
-    static constexpr bool Is_flashmask_ = true;
     FLASH_MASK_SWITCH(params.lt_end_ptr != nullptr, params.ut_start_ptr != nullptr, Has_lt_end, Has_ut_start, [&] {
         if constexpr (Arch >= 90) {
             if constexpr (Is_causal || Is_local || Has_softcap) {
-                run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, 2, 2, true, false, false, 2, 1, 2, 1, false>(params, stream);
+                run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, Has_lt_end, Has_ut_start, 2, 2, true, false, false, 2, 1, 2, 1, false>(params, stream);
             } else {
                 if ((params.seqlen_q >= 1024 || params.seqlen_k >= 1024) && !(Has_lt_end && Has_ut_start)) {
-                    run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, 2, 2, true, false, true, 2, 1, 2, 1, false>(params, stream);
+                    run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, Has_lt_end, Has_ut_start, 2, 2, true, false, true, 2, 1, 2, 1, false>(params, stream);
                 } else {
-                    run_mha_bwd_dispatch<Arch, T, 64, 64, 128, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, 2, 2, false, true, false, 2, 1, 2, 1, false>(params, stream);
+                    run_mha_bwd_dispatch<Arch, T, 64, 64, 128, Is_causal, Is_local, Has_softcap, Has_lt_end, Has_ut_start, 2, 2, false, true, false, 2, 1, 2, 1, false>(params, stream);
                 }
             }
         } else if constexpr (Arch == 86 || Arch == 89) {
-            run_mha_bwd_dispatch<Arch, T, 64, 96, 128, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 2, 2, true, Is_flashmask_>(params, stream);
+            run_mha_bwd_dispatch<Arch, T, 64, 96, 128, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 2, 2, true, true>(params, stream);
         } else {
-            run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 2, 2, false, Is_flashmask_>(params, stream);
+            run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 2, 2, false, true>(params, stream);
         }
     });
 }
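Note on the dispatch above: with the `Is_flashmask` toggle removed, these launchers take the flashmask path unconditionally, and `FLASH_MASK_SWITCH` only lifts the runtime checks `params.lt_end_ptr != nullptr` / `params.ut_start_ptr != nullptr` into the compile-time flags `Has_lt_end` / `Has_ut_start` fed to `run_mha_bwd_dispatch`. The macro itself is defined elsewhere in the repo; the sketch below is only an assumed expansion in the style of the usual `BOOL_SWITCH` helper, not the actual definition.

```cpp
// Hypothetical two-level boolean switch in the style of BOOL_SWITCH.
// The real FLASH_MASK_SWITCH lives elsewhere in this repo and may differ in detail.
#define FLASH_MASK_SWITCH_SKETCH(COND1, COND2, NAME1, NAME2, ...)  \
  [&] {                                                            \
    if (COND1) {                                                   \
      constexpr static bool NAME1 = true;                          \
      if (COND2) {                                                 \
        constexpr static bool NAME2 = true;                        \
        return __VA_ARGS__();                                      \
      } else {                                                     \
        constexpr static bool NAME2 = false;                       \
        return __VA_ARGS__();                                      \
      }                                                            \
    } else {                                                       \
      constexpr static bool NAME1 = false;                         \
      if (COND2) {                                                 \
        constexpr static bool NAME2 = true;                        \
        return __VA_ARGS__();                                      \
      } else {                                                     \
        constexpr static bool NAME2 = false;                       \
        return __VA_ARGS__();                                      \
      }                                                            \
    }                                                              \
  }()
```

Each of the four `Has_lt_end` / `Has_ut_start` combinations then instantiates the matching `run_mha_bwd_dispatch<..., Has_lt_end, Has_ut_start, ...>` specialization shown in the diff.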
csrc/flashmask_v2/flash_fwd_kernel_sm90.h (13 additions & 7 deletions)

@@ -37,7 +37,6 @@ class FlashAttnFwdSm90 {
     static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
     static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
     static constexpr bool Varlen = CollectiveMainloop::Varlen;
-    static constexpr bool Split = CollectiveMainloop::Split;
     static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
     static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
     static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
@@ -49,7 +48,6 @@ class FlashAttnFwdSm90 {
     static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
     static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim;
     static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV;
-    static constexpr bool Is_flashmask = CollectiveMainloop::Is_flashmask;
     static constexpr bool Use_Sch_Pipeline = TileScheduler_::pipelining;
     static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV);
     using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
@@ -280,7 +278,6 @@ class FlashAttnFwdSm90 {
         int const m_block = get<0>(block_coord);
         int const bidh = get<1>(block_coord);
         int const bidb = get<2>(block_coord);
-        int const split_idx = get<3>(block_coord);
         SeqlenInfo_t seqlen_info{
             get<2>(block_coord) /*bidb*/,
             get<0>(params.mainloop.shape_Q),
@@ -290,11 +287,11 @@
             params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
         };
         auto [n_block_min, n_block_max] = CollectiveMainloop::BlockMN_t::get_n_block_min_max(
-            seqlen_info, m_block, bidb, split_idx, params.mainloop.num_splits,
+            seqlen_info, m_block, bidb, 0, 1,
             params.mainloop.window_size_left, params.mainloop.window_size_right, params.mainloop.qhead_per_khead_divmod);

         // It's possible to have n_block_max <= n_block_min. Loading K can cause illegal memory access.
-        if constexpr (Is_causal || Is_local || Varlen || Split) {
+        if constexpr (Is_causal || Is_local || Varlen) {
             if (n_block_max <= n_block_min) {
                 // skipping, don't forget to fetch us the next work!
                 scheduler.prefetch_next_work(params.scheduler, work_tile_info);
@@ -305,6 +302,15 @@
             // for padding 32 and padding 4: the num_chunk (pad_32) >= num_chunk (pad_4) is always true
             const int nblock_seqlen = ((seqlen_info.seqlen_k + kBlockN - 1) / kBlockN + 3) & 0xfffffffc; // umiswing: padding for int4 load
             const int num_chunk = (nblock_seqlen + CollectiveMainloop::Flashmask_n_block_buffer_valid_length - 1) / CollectiveMainloop::Flashmask_n_block_buffer_valid_length;
+            const int reverse_chunk_start = [&] {
+                if constexpr (Is_causal) {
+                    // if causal, the 'valid' sequence length is smaller. We don't need that many chunks, so we will start from chunks closer to the start
+                    const int seqlen_offset = seqlen_info.seqlen_k - seqlen_info.seqlen_q;
+                    return std::max(num_chunk - 1 - (kBlockM * (m_block + 1) + seqlen_offset) / (kBlockN * CollectiveMainloop::Flashmask_n_block_buffer_valid_length), 0);
+                } else {
+                    return 0;
+                }
+            } ();
             // reverse_chunk_idx, start from right to left: [5, 4, 3, 2, 1, 0], and fwd kernel scans from right to left
             bool valid_chunk = true;
             const int cppl_stage = scheduler.template stage<true>(); // coarse pipeline stage (offset, 0 or 2)
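The chunk bookkeeping added above is easy to check by hand. The snippet below replays the padded n-block count and the causal `reverse_chunk_start` formula on concrete numbers; `kChunkBlocks` stands in for `CollectiveMainloop::Flashmask_n_block_buffer_valid_length`, and every value here (tile sizes, chunk width, sequence lengths) is assumed for illustration rather than taken from the kernel.

```cpp
// Standalone illustration of the padded-block and reverse_chunk_start math.
// All constants are assumed for the example; the real ones come from the tile
// shape and CollectiveMainloop::Flashmask_n_block_buffer_valid_length.
#include <algorithm>
#include <cstdio>

int main() {
  constexpr int kBlockM = 64, kBlockN = 128;
  constexpr int kChunkBlocks = 16;            // assumed chunk width in n-blocks
  const int seqlen_q = 4200, seqlen_k = 4200; // assumed, so seqlen_offset == 0

  // Round the n-block count up to a multiple of 4 so int4 loads stay in bounds.
  const int nblock_seqlen = ((seqlen_k + kBlockN - 1) / kBlockN + 3) & 0xfffffffc;
  const int num_chunk = (nblock_seqlen + kChunkBlocks - 1) / kChunkBlocks;

  // Causal case: queries in m_block only see keys up to kBlockM*(m_block+1)+offset,
  // so chunks entirely to the right of that column are never needed.
  const int m_block = 0;
  const int seqlen_offset = seqlen_k - seqlen_q;
  const int reverse_chunk_start = std::max(
      num_chunk - 1 - (kBlockM * (m_block + 1) + seqlen_offset) / (kBlockN * kChunkBlocks), 0);

  // With these numbers: 33 n-blocks pad to 36, num_chunk == 3, and for the first
  // query block the producer loop starts at reverse_chunk_start == 2.
  std::printf("nblock_seqlen=%d num_chunk=%d reverse_chunk_start=%d\n",
              nblock_seqlen, num_chunk, reverse_chunk_start);
  return 0;
}
```

For the first query block of a square causal problem, only the chunk containing the diagonal needs to be staged, which is exactly what starting the producer loop at `reverse_chunk_start` instead of 0 achieves.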
@@ -314,11 +320,11 @@
                 reverse_chunk_idx, \
                 num_chunk, \
                 reverse_chunk_idx == num_chunk - 1 ? CollectiveMainloop::Flashmask_n_block_finish : CollectiveMainloop::Flashmask_n_block_chunk_end,\
-                n_block_min, n_block_max, seqlen_info.seqlen_q, \
+                n_block_min, n_block_max, nblock_seqlen, seqlen_info.seqlen_q, seqlen_info.seqlen_k, \
                 flashmask_maxmin_smem + 8 * CollectiveMainloop::Flashmask_n_block_buffer_length * (n_block_pipe_write.index() + cppl_stage), \
                 n_block_smem + CollectiveMainloop::Flashmask_n_block_buffer_length * (n_block_pipe_write.index() + cppl_stage), \
                 extra_flags + n_block_pipe_write.index() + cppl_stage)
-            for(int reverse_chunk_idx = 0; reverse_chunk_idx < num_chunk; reverse_chunk_idx++) {
+            for(int reverse_chunk_idx = reverse_chunk_start; reverse_chunk_idx < num_chunk; reverse_chunk_idx++) {
                 if (valid_chunk)
                     pipeline_n_block.producer_acquire(n_block_pipe_write);
                 mainloop.load_max_min(params.mainloop, seqlen_info, block_coord, reverse_chunk_idx, num_chunk, flashmask_maxmin_smem +