diff --git a/csrc/flashmask_v2/flash_bwd_launch_template.h b/csrc/flashmask_v2/flash_bwd_launch_template.h index 521c80cf00d..058426c3872 100644 --- a/csrc/flashmask_v2/flash_bwd_launch_template.h +++ b/csrc/flashmask_v2/flash_bwd_launch_template.h @@ -26,7 +26,7 @@ using namespace cute; template = 90, flash::CollectiveMainloopBwdSm90, + SdP_swapAB, dKV_swapAB, dQ_swapAB, Has_lt_end, Has_ut_start, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>, flash::CollectiveMainloopBwdSm80 @@ -115,9 +115,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { flash::enable_sm80_to_sm89> >; - if constexpr (Is_flashmask) { - flash::flashmask::prepare_block_maxmin(params, stream); - } + flash::flashmask::prepare_block_maxmin(params, stream); if constexpr (Arch >= 90) { prepare_preemptive_scheduler(params, stream, params.num_sm); @@ -351,7 +349,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { } template(params, stream); - run_flash_bwd(params, stream); + run_flash_bwd(params, stream); // }); }); }); @@ -372,25 +370,21 @@ template void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { // printf("point2-1\n"); static constexpr bool Is_local = false; - static constexpr bool Is_flashmask_ = true; FLASH_MASK_SWITCH(params.lt_end_ptr != nullptr, params.ut_start_ptr != nullptr, Has_lt_end, Has_ut_start, [&] { if constexpr (Arch >= 90) { - if constexpr (Is_flashmask_ && !Is_causal) { - run_mha_bwd_dispatch(params, stream); - } else if constexpr (Is_causal && Has_softcap || Is_flashmask_) { - // register spill with 128 x 128 - run_mha_bwd_dispatch(params, stream); + if constexpr (!Is_causal) { + run_mha_bwd_dispatch(params, stream); } else { - // With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block. 
- run_mha_bwd_dispatch(params, stream); + // register spill with 128 x 128 + run_mha_bwd_dispatch(params, stream); } } else if constexpr (Arch == 86 || Arch == 89) { - run_mha_bwd_dispatch(params, stream); + run_mha_bwd_dispatch(params, stream); // run_mha_bwd_dispatch(params, stream); // run_mha_bwd_dispatch(params, stream); // run_mha_bwd_dispatch(params, stream); } else { - run_mha_bwd_dispatch(params, stream); + run_mha_bwd_dispatch(params, stream); } }); } @@ -413,22 +407,21 @@ void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) { template void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { static constexpr bool Is_local = false; - static constexpr bool Is_flashmask_ = true; FLASH_MASK_SWITCH(params.lt_end_ptr != nullptr, params.ut_start_ptr != nullptr, Has_lt_end, Has_ut_start, [&] { if constexpr (Arch >= 90) { if constexpr (Is_causal || Is_local || Has_softcap) { - run_mha_bwd_dispatch(params, stream); + run_mha_bwd_dispatch(params, stream); } else { if ((params.seqlen_q >= 1024 || params.seqlen_k >= 1024) && !(Has_lt_end && Has_ut_start)) { - run_mha_bwd_dispatch(params, stream); + run_mha_bwd_dispatch(params, stream); } else { - run_mha_bwd_dispatch(params, stream); + run_mha_bwd_dispatch(params, stream); } } } else if constexpr (Arch == 86 || Arch == 89) { - run_mha_bwd_dispatch(params, stream); + run_mha_bwd_dispatch(params, stream); } else { - run_mha_bwd_dispatch(params, stream); + run_mha_bwd_dispatch(params, stream); } }); } diff --git a/csrc/flashmask_v2/flash_fwd_kernel_sm90.h b/csrc/flashmask_v2/flash_fwd_kernel_sm90.h index 3ba512689a8..fc71db7b08b 100644 --- a/csrc/flashmask_v2/flash_fwd_kernel_sm90.h +++ b/csrc/flashmask_v2/flash_fwd_kernel_sm90.h @@ -37,7 +37,6 @@ class FlashAttnFwdSm90 { static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen); static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap; static constexpr bool Varlen = CollectiveMainloop::Varlen; - static constexpr bool Split = CollectiveMainloop::Split; static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8; static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V; static constexpr bool AppendKV = CollectiveMainloop::AppendKV; @@ -49,7 +48,6 @@ class FlashAttnFwdSm90 { static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads; static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim; static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV; - static constexpr bool Is_flashmask = CollectiveMainloop::Is_flashmask; static constexpr bool Use_Sch_Pipeline = TileScheduler_::pipelining; static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV); using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t; @@ -280,7 +278,6 @@ class FlashAttnFwdSm90 { int const m_block = get<0>(block_coord); int const bidh = get<1>(block_coord); int const bidb = get<2>(block_coord); - int const split_idx = get<3>(block_coord); SeqlenInfo_t seqlen_info{ get<2>(block_coord) /*bidb*/, get<0>(params.mainloop.shape_Q), @@ -290,11 +287,11 @@ class FlashAttnFwdSm90 { params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, }; auto [n_block_min, n_block_max] = CollectiveMainloop::BlockMN_t::get_n_block_min_max( - seqlen_info, m_block, bidb, split_idx, params.mainloop.num_splits, + seqlen_info, m_block, bidb, 0, 1, params.mainloop.window_size_left, params.mainloop.window_size_right, params.mainloop.qhead_per_khead_divmod); // It's possible to have n_block_max <= 
n_block_min. Loading K can cause illegal memory access. - if constexpr (Is_causal || Is_local || Varlen || Split) { + if constexpr (Is_causal || Is_local || Varlen) { if (n_block_max <= n_block_min) { // skipping, don't forget to fetch us the next work! scheduler.prefetch_next_work(params.scheduler, work_tile_info); @@ -305,6 +302,15 @@ class FlashAttnFwdSm90 { // for padding 32 and padding 4: the num_chunk (pad_32) >= num_chunk (pad_4) is always true const int nblock_seqlen = ((seqlen_info.seqlen_k + kBlockN - 1) / kBlockN + 3) & 0xfffffffc; // umiswing: padding for int4 load const int num_chunk = (nblock_seqlen + CollectiveMainloop::Flashmask_n_block_buffer_valid_length - 1) / CollectiveMainloop::Flashmask_n_block_buffer_valid_length; + const int reverse_chunk_start = [&] { + if constexpr (Is_causal) { + // if causal, the 'valid' sequence length is smaller. We don't need that many chunks, so we will start from chunks closer to the start + const int seqlen_offset = seqlen_info.seqlen_k - seqlen_info.seqlen_q; + return std::max(num_chunk - 1 - (kBlockM * (m_block + 1) + seqlen_offset) / (kBlockN * CollectiveMainloop::Flashmask_n_block_buffer_valid_length), 0); + } else { + return 0; + } + } (); // reverse_chunk_idx, start from right to left: [5, 4, 3, 2, 1, 0], and fwd kernel scans from right to left bool valid_chunk = true; const int cppl_stage = scheduler.template stage(); // coarse pipeline stage (offset, 0 or 2) @@ -314,11 +320,11 @@ class FlashAttnFwdSm90 { reverse_chunk_idx, \ num_chunk, \ reverse_chunk_idx == num_chunk - 1 ? CollectiveMainloop::Flashmask_n_block_finish : CollectiveMainloop::Flashmask_n_block_chunk_end,\ - n_block_min, n_block_max, seqlen_info.seqlen_q, \ + n_block_min, n_block_max, nblock_seqlen, seqlen_info.seqlen_q, seqlen_info.seqlen_k, \ flashmask_maxmin_smem + 8 * CollectiveMainloop::Flashmask_n_block_buffer_length * (n_block_pipe_write.index() + cppl_stage), \ n_block_smem + CollectiveMainloop::Flashmask_n_block_buffer_length * (n_block_pipe_write.index() + cppl_stage), \ extra_flags + n_block_pipe_write.index() + cppl_stage) - for(int reverse_chunk_idx = 0; reverse_chunk_idx < num_chunk; reverse_chunk_idx++) { + for(int reverse_chunk_idx = reverse_chunk_start; reverse_chunk_idx < num_chunk; reverse_chunk_idx++) { if (valid_chunk) pipeline_n_block.producer_acquire(n_block_pipe_write); mainloop.load_max_min(params.mainloop, seqlen_info, block_coord, reverse_chunk_idx, num_chunk, flashmask_maxmin_smem + diff --git a/csrc/flashmask_v2/flash_fwd_launch_template.h b/csrc/flashmask_v2/flash_fwd_launch_template.h index 31bd52b85f4..2646709ad06 100644 --- a/csrc/flashmask_v2/flash_fwd_launch_template.h +++ b/csrc/flashmask_v2/flash_fwd_launch_template.h @@ -27,7 +27,7 @@ using namespace cute; template + bool PackGQA, bool Split, bool V_colmajor, bool short_seqlen = false> void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time"); @@ -42,7 +42,6 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS); static constexpr int kBlockN = Arch >= 90 ? 
std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS); static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); - static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap); static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS); static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS); static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS); @@ -57,7 +56,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using ClusterShape = cute::Shape, _1, _1>; using CollectiveMainloop = std::conditional_t< Arch >= 90, - flash::CollectiveMainloopFwdSm90, + flash::CollectiveMainloopFwdSm90, flash::CollectiveMainloopFwdSm80 >; using CollectiveEpilogue = flash::CollectiveEpilogueFwd; @@ -72,11 +71,11 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // in headdim = 64 case, I suspect I've fixed it, but there is no testing facility (9.30 EB5 occupied) // The current logic: only headdim=128 will use Dual PPTX using Scheduler = std::conditional_t< - Arch >= 90, + Arch >= 90 && !short_seqlen, std::conditional_t< (Predicate_for_Headdim && (kHeadDimV != 128 || kHeadDim != 128)) || No_Scheduler_Pipeline, - flash::PreemptivePersistentTileScheduler<_NumConsumerThreads, _NumProducerThreads, Split>, - flash::DualPreemptivePersistentTileExecutionScheduler<_NumConsumerThreads, _NumProducerThreads, Split> + flash::PreemptivePersistentTileScheduler<_NumConsumerThreads, _NumProducerThreads>, + flash::DualPreemptivePersistentTileExecutionScheduler<_NumConsumerThreads, _NumProducerThreads> >, flash::StaticPersistentTileScheduler >; @@ -98,9 +97,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0), make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0)); - if constexpr (Is_flashmask) { - flash::flashmask::prepare_block_maxmin(params, stream, true); - } + flash::flashmask::prepare_block_maxmin(params, stream, true); typename CollectiveMainloop::Arguments mainloop_args { static_cast(params.q_ptr), @@ -231,22 +228,20 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { // Only needed here to decide if we should use cluster - BOOL_SWITCH(params.lt_start_ptr != nullptr, Is_flashmask, [&] { - static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen && !Is_flashmask; - BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { - BOOL_SWITCH(params.seqlen_k < 128 && params.seqlen_q < 128, ShortSeqlen, [&] { - // If the sequence length is (extremely) short, we should cut down the tile size - static constexpr int kBlockM = Arch >= 90 ? 
std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, - sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, ShortSeqlen)) : 128; - static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV == 512; - APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { - // Only use Cluster if number of tiles along seqlen_q is even and not varlen - CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { - static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; - run_flash_fwd(params, stream); - }); + static constexpr bool Enable_cluster = false; // disable cluster for now, may be of use in the future + BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { + BOOL_SWITCH(params.seqlen_k < 128 && params.seqlen_q < 128, ShortSeqlen, [&] { + // If the sequence length is (extremely) short, we should cut down the tile size + static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, + sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, ShortSeqlen)) : 128; + static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV == 512; + APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { + // Only use Cluster if number of tiles along seqlen_q is even and not varlen + CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { + static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; + run_flash_fwd(params, stream); }); }); }); diff --git a/csrc/flashmask_v2/mainloop_bwd_sm90_tma_gmma_ws.hpp b/csrc/flashmask_v2/mainloop_bwd_sm90_tma_gmma_ws.hpp index 253938dc562..431ec6c772e 100644 --- a/csrc/flashmask_v2/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/csrc/flashmask_v2/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -31,7 +31,7 @@ using namespace cute; template struct CollectiveMainloopBwdSm90 { diff --git a/csrc/flashmask_v2/mainloop_fwd_sm90_tma_gmma_ws.hpp b/csrc/flashmask_v2/mainloop_fwd_sm90_tma_gmma_ws.hpp index bce5acf0e94..74391b68409 100644 --- a/csrc/flashmask_v2/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/csrc/flashmask_v2/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -32,7 +32,7 @@ using namespace cute; template + bool MmaPV_is_RS, bool PackGQA_, bool V_colmajor_> struct CollectiveMainloopFwdSm90 { static constexpr int kStages = Stages; @@ -53,9 +53,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr bool AppendKV = AppendKV_; static constexpr bool HasQv = HasQv_; static constexpr bool PackGQA = PackGQA_; - static constexpr bool Split = Split_; static constexpr bool V_colmajor = V_colmajor_; - static constexpr bool Is_flashmask = Is_flashmask_; static constexpr bool Transpose_V = Is_FP8 && !V_colmajor; static constexpr bool Use_TMA_Q = !PackGQA; static constexpr bool Use_TMA_KV = !PagedKVNonTMA; @@ -88,7 +86,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr int Flashmask_n_block_finish = INT_MIN; // 0x80000000 using SeqlenInfo_t = flash::SeqlenInfoQKNewK; - using BlockMN_t = flash::BlockMN; + using BlockMN_t = flash::BlockMN; static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0); static_assert(!LargeHeadDimV || kBlockM <= 64, "kBlockM must be 64 or less for large Headdim_V"); @@ -369,11 +367,10 @@ struct CollectiveMainloopFwdSm90 { using TensorStorage = std::conditional_t; // These are tuned for speed. They don't affect correctness. - static constexpr bool UseSchedulerBarrier = (IntraWGOverlap - ? 
(NumMmaWarpGroups >= 2) && (!Is_FP8 ? kHeadDim <= 128 : kHeadDim >= 128) - : NumMmaWarpGroups == 2) + static constexpr bool UseSchedulerBarrier = + ((NumMmaWarpGroups >= 2) && (!Is_FP8 ? kHeadDim <= 128 : kHeadDim >= 128)) && !LargeHeadDimV; - static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor) && IntraWGOverlap; + static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor); // Host side kernel arguments struct Arguments { @@ -623,8 +620,7 @@ struct CollectiveMainloopFwdSm90 { args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, args.window_size_left, args.window_size_right, - !Split ? 1 : args.num_splits, - args.kv_batch_idx, + 1 /* Split is always 1 */, args.kv_batch_idx, args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, args.seqused_q, args.seqused_k, args.leftpad_k, args.h_flashmask, args.h_h_flashmask_ratio, @@ -674,7 +670,20 @@ struct CollectiveMainloopFwdSm90 { int32_t bidh = get<1>(block_coord); int32_t bidb = get<2>(block_coord); // pad for fully 128B aligned load - const int nblock_seqlen = ((seqlen_info.seqlen_k + kBlockN - 1) / kBlockN + 3) & 0xfffffffc; + + // Note(heqianyue): compute only the part that actually needs loading and computation for causal masks + const int nblock_seqlen = [&] { + if constexpr (Is_causal) { + const int m_block = get<0>(block_coord); + // Warn(heqianyue): abs_coord might be greater than the unpadded nblock_seqlen (this is mathematically provable), + // but we don't need to clip to nblock_seqlen, since loading invalid data is fine; we can simply choose not to use it + // (mblock_id + 1) * kBlockM + diff_seqlen is the block ID of the rightmost valid block in the causal case, so + 1 for length + const int valid_nblock_seqlen = ((m_block + 1) * kBlockM + seqlen_info.seqlen_k - seqlen_info.seqlen_q) / kBlockN + 1; + return (valid_nblock_seqlen + 3) & 0xfffffffc; + } else { // original nblock_seqlen + return ((seqlen_info.seqlen_k + kBlockN - 1) / kBlockN + 3) & 0xfffffffc; + } + } (); // change this to num_chunk * chunk_size (should be Flashmask_n_block_buffer_length) const int chunks_size = total_num_chunks * Flashmask_n_block_buffer_length; @@ -755,7 +764,9 @@ int32_t const end_flag, int32_t const n_block_min, int32_t const n_block_max, + int32_t const nblock_seqlen, int32_t const seqlen_q, + int32_t const seqlen_k, int32_t* const __restrict__ flashmask_maxmin_smem, int32_t* const __restrict__ mask_encode_n_block_smem_, int32_t* const __restrict__ extra_flags) { @@ -804,6 +815,19 @@ int32_t valid_n_block_num = 0; const int32_t base_offset = (total_num_chunks - 1 - reverse_chunk_idx) * Flashmask_n_block_buffer_valid_length; + constexpr int valid_buffer_length = Flashmask_n_block_buffer_valid_length; + const int nblock_start = [&]() { + if constexpr (Is_causal) { + // Warn(heqianyue): abs_coord might be greater than the unpadded nblock_seqlen (this is mathematically provable); + // in order not to load invalid data (or even read out of range, which is highly unlikely), we clip it to nblock_seqlen + const int abs_coord = std::min(((m_block + 1) * kBlockM + seqlen_k - seqlen_q) / kBlockN, nblock_seqlen - 1); + return std::min(abs_coord - base_offset + 1, valid_buffer_length); + } else { + // Note(heqianyue): even for non-causal masks, we can skip some of the computation. For example, + // since the buffer is 16k, if seqlen is 8k, we can skip 50% of the computation here, which is a real saving.
+ return std::min(valid_buffer_length, nblock_seqlen); + } + } (); // explanation for the loop condition: // -2, -1, 0, 1, 2 @@ -813,8 +837,8 @@ struct CollectiveMainloopFwdSm90 { // Note(heqianyue): ute/lte will be seqlen_q (at most). Yet if m_block_e > seqlen_q, even if ute/lte are seqlen_q (masked to the end) // we will still consider the block as partially masked, adding unnecessary computation for those fully-masked blocks const int m_block_e = __viaddmin_s32(m_block_s, kBlockM, seqlen_q); // min(a + b, c) - for(int32_t idx = Flashmask_n_block_buffer_valid_length - 1 - thread_idx; // make sure thread_idx is in range [0, ProducerThreadNum) - idx >= (0 - (ProducerThreadNum - Flashmask_n_block_buffer_valid_length % ProducerThreadNum)); idx -= ProducerThreadNum + for(int32_t idx = nblock_start - 1 - thread_idx; // make sure thread_idx is in range [0, ProducerThreadNum) + idx >= (0 - (ProducerThreadNum - nblock_start % ProducerThreadNum)); idx -= ProducerThreadNum ) { int32_t n_block = base_offset + idx; int prefix_sum = 0; @@ -919,12 +943,11 @@ struct CollectiveMainloopFwdSm90 { int const bidh = get<1>(block_coord); int const bidb = get<2>(block_coord); - int const split_idx = get<3>(block_coord); auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, + seqlen_info, m_block, bidb, 0, 1, params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); // It's possible to have n_block_max <= n_block_min. Loading K can cause illegal memory access. - if constexpr (Is_causal || Is_local || Varlen || Split) { + if constexpr (Is_causal || Is_local || Varlen) { if (n_block_max <= n_block_min) { return; } @@ -1005,7 +1028,7 @@ struct CollectiveMainloopFwdSm90 { // This is used to index into the batch dimension of mK and mV int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? 
bidb_kv : 0; - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumProducerThreads, Element, Transpose_V>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, @@ -1112,69 +1135,67 @@ struct CollectiveMainloopFwdSm90 { const int32_t* extra_flags_smem = extra_flags + n_block_pipe_read.index(); auto load_flashmask = [&] (auto const& smem_pipe_write) { - if constexpr (Is_flashmask) { - pipeline_flashmask_apply.producer_acquire(smem_pipe_write); + pipeline_flashmask_apply.producer_acquire(smem_pipe_write); + if(n_block_idx < Flashmask_n_block_buffer_valid_length && mask_encode_n_block_smem_[n_block_idx] >= 0) { int32_t* const flashmask_base_addr = flashmask_smem_ + smem_pipe_write.index() * 4 * kBlockN; - if(n_block_idx < Flashmask_n_block_buffer_valid_length && mask_encode_n_block_smem_[n_block_idx] >= 0) { - const int row_offset = (bidb * params.h_flashmask + bidh / params.h_h_flashmask_ratio) * seqlen_info.seqlen_k; - const int nb_mul_kBN = n_block * kBlockN; - const int loop_ub = std::min(kBlockN, seqlen_info.seqlen_k - nb_mul_kBN); - if (params.ut_start_ptr != nullptr) { - for(int idx = thread_idx; idx < loop_ub; idx += NumProducerThreads) { - asm volatile( - "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" - ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + idx)), - "l"(params.lt_start_ptr + row_offset + nb_mul_kBN + idx), - "n"(4)); - asm volatile( - "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" - ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + kBlockN + idx)), - "l"(params.lt_end_ptr + row_offset + nb_mul_kBN + idx), - "n"(4)); - asm volatile( - "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" - ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + kBlockN * 2 + idx)), - "l"(params.ut_start_ptr + row_offset + nb_mul_kBN + idx), - "n"(4)); - asm volatile( - "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" - ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + kBlockN * 3 + idx)), - "l"(params.ut_end_ptr + row_offset + nb_mul_kBN + idx), - "n"(4)); - } - } else { + const int row_offset = (bidb * params.h_flashmask + bidh / params.h_h_flashmask_ratio) * seqlen_info.seqlen_k; + const int nb_mul_kBN = n_block * kBlockN; + const int loop_ub = std::min(kBlockN, seqlen_info.seqlen_k - nb_mul_kBN); + if (params.ut_start_ptr != nullptr) { + for(int idx = thread_idx; idx < loop_ub; idx += NumProducerThreads) { + asm volatile( + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + idx)), + "l"(params.lt_start_ptr + row_offset + nb_mul_kBN + idx), + "n"(4)); + asm volatile( + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + kBlockN + idx)), + "l"(params.lt_end_ptr + row_offset + nb_mul_kBN + idx), + "n"(4)); + asm volatile( + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + kBlockN * 2 + idx)), + "l"(params.ut_start_ptr + row_offset + nb_mul_kBN + idx), + "n"(4)); + asm volatile( + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + 
::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + kBlockN * 3 + idx)), + "l"(params.ut_end_ptr + row_offset + nb_mul_kBN + idx), + "n"(4)); + } + } else { + + for(int idx = thread_idx; idx < loop_ub; idx += NumProducerThreads) { + asm volatile( + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + idx)), + "l"(params.lt_start_ptr + row_offset + nb_mul_kBN + idx), + "n"(4)); - for(int idx = thread_idx; idx < loop_ub; idx += NumProducerThreads) { - asm volatile( - "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" - ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + idx)), - "l"(params.lt_start_ptr + row_offset + nb_mul_kBN + idx), - "n"(4)); - - if constexpr (Is_causal) { - if(params.lt_end_ptr != nullptr) { - asm volatile( - "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" - ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + kBlockN + idx)), - "l"(params.lt_end_ptr + row_offset + nb_mul_kBN + idx), - "n"(4)); - } - } else { - if(params.ut_end_ptr != nullptr) { - asm volatile( - "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" - ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + 3 * kBlockN + idx)), - "l"(params.ut_end_ptr + row_offset + nb_mul_kBN + idx), - "n"(4)); - } + if constexpr (Is_causal) { + if(params.lt_end_ptr != nullptr) { + asm volatile( + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + kBlockN + idx)), + "l"(params.lt_end_ptr + row_offset + nb_mul_kBN + idx), + "n"(4)); + } + } else { + if(params.ut_end_ptr != nullptr) { + asm volatile( + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + 3 * kBlockN + idx)), + "l"(params.ut_end_ptr + row_offset + nb_mul_kBN + idx), + "n"(4)); } } - } - asm volatile("cp.async.commit_group;\n" ::); - asm volatile("cp.async.wait_group 0;\n" ::); } - pipeline_flashmask_apply.producer_commit(smem_pipe_write); + } + asm volatile("cp.async.commit_group;\n" ::); + asm volatile("cp.async.wait_group 0;\n" ::); } + pipeline_flashmask_apply.producer_commit(smem_pipe_write); }; auto n_block_wait = [](auto& pipeline, auto& smem_pipe_read) { @@ -1227,11 +1248,11 @@ struct CollectiveMainloopFwdSm90 { if ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()) { shared_storage.pipelines.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); - copy(params.tma_load_Q.with(reinterpret_cast(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), + copy(params.tma_load_Q.with(reinterpret_cast(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, TMA::CacheHintSm90::EVICT_FIRST), tQgQ, tQsQ); if constexpr (HasQv) { shared_storage.pipelines.barrier_Qv.arrive_and_expect_tx(TmaTransactionBytesQv); - copy(params.tma_load_Qv.with(reinterpret_cast(shared_storage.pipelines.barrier_Qv), 0 /*mcast_mask*/, !Split ? 
TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), + copy(params.tma_load_Qv.with(reinterpret_cast(shared_storage.pipelines.barrier_Qv), 0 /*mcast_mask*/, TMA::CacheHintSm90::EVICT_FIRST), tQvgQv, tQvsQv); } } @@ -1262,11 +1283,9 @@ struct CollectiveMainloopFwdSm90 { shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2); // if (thread_idx == 0) { printf("Producer: main load, after barrier_O\n");} + // we do not want to load flashmask unconditionally, since it would be wasteful load_flashmask(smem_pipe_write); - if constexpr (!Transpose_V && !IntraWGOverlap) { - if (should_load_KV) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } - } int n_block_prev = n_block; n_block = n_block_getter(++n_block_idx); @@ -1285,18 +1304,13 @@ struct CollectiveMainloopFwdSm90 { if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); } load_K(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); if constexpr (!Transpose_V) { - if constexpr (IntraWGOverlap) { - load_V(n_block_prev, smem_pipe_write_v, cute::true_type{} /*Seqlenk_mask*/); - } else { - load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); - } + load_V(n_block_prev, smem_pipe_write_v, cute::true_type{} /*Seqlenk_mask*/); } } n_block_prev = n_block; if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write_v); } load_flashmask(smem_pipe_write); - } if (n_block == Flashmask_n_block_chunk_end) { @@ -1313,7 +1327,7 @@ struct CollectiveMainloopFwdSm90 { pipeline_n_block.consumer_release(n_block_pipe_read); ++n_block_pipe_read; - if constexpr (!Transpose_V && IntraWGOverlap) { + if constexpr (!Transpose_V) { if (should_load_KV) { load_V(n_block_prev, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } } if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write); } @@ -1491,13 +1505,12 @@ struct CollectiveMainloopFwdSm90 { int const bidh = get<1>(block_coord); int const bidb = get<2>(block_coord); - int const split_idx = get<3>(block_coord); int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, + seqlen_info, m_block, bidb, 0, 1, params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); // It's possible to have n_block_max <= n_block_min.
We don't want to load Q or change any barrier - if constexpr (Is_causal || Is_local || Varlen || Split) { + if constexpr (Is_causal || Is_local || Varlen) { if (n_block_max <= n_block_min) { return false; } } @@ -1678,139 +1691,134 @@ struct CollectiveMainloopFwdSm90 { cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view); } - if constexpr (IntraWGOverlap) { + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); + consumer_wait(pipeline_k, smem_pipe_read); + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + warpgroup_wait<0>(); + pipeline_k.consumer_release(smem_pipe_read); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.wait(work_idx % 2); + consumer_wait(pipeline_v, smem_pipe_read); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + } + scoremod_premask_fn(tSrS); + mask.template apply(tSrS, m_block, n_block); + + consumer_wait(pipeline_flashmask_apply, smem_pipe_read); + if (n_block_idx < Flashmask_n_block_buffer_valid_length && mask_encode_n_block_smem_[n_block_idx] >= 0) { + if (params.ut_start_ptr) { + flashmask_apply( + tSrS, m_block, thread_idx, smem_pipe_read.index(), flashmask_smem_, + params.lt_start_ptr, params.lt_end_ptr, + params.ut_start_ptr, params.ut_end_ptr); + } else if (params.lt_end_ptr || params.ut_end_ptr) { + flashmask_apply( + tSrS, m_block, thread_idx, smem_pipe_read.index(), flashmask_smem_, + params.lt_start_ptr, params.lt_end_ptr, + nullptr, params.ut_end_ptr); + } else { + flashmask_apply( + tSrS, m_block, thread_idx, smem_pipe_read.index(), flashmask_smem_, + params.lt_start_ptr, nullptr, nullptr, nullptr); + } + } + pipeline_flashmask_apply.consumer_release(smem_pipe_read); + + Tensor scores_scale = softmax.template max_get_scale(tSrS); + // Don't need to store scales to send to WG1 (in the case of LargeHeadDimV) since it's 1.f + + softmax.template online_softmax(tSrS); + if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP = make_tensor_like(tOrP_acc); + convert_type_out(tOrP_acc, tOrP); + if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } + if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } + if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); } + n_block = n_block_getter(++n_block_idx); + + // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter + clear(tOrO); + // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; + + // Each step does gemm0 for iter n_block, gemm1 for iter n_block + 1, and softmax for iter n_block. 
+ auto fwd_step = [&](int const n_block, auto mask_fn) { + static constexpr bool Check_inf = true; + PipelineState smem_pipe_read_v(smem_pipe_read.index(), smem_pipe_read.phase(), smem_pipe_read.count()); + ++smem_pipe_read; + // PipelineState smem_pipe_read_v(smem_pipe_read.index(), smem_pipe_read.phase(), smem_pipe_read.count()); + // ++smem_pipe_read; Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); - consumer_wait(pipeline_k, smem_pipe_read); + if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_k, smem_pipe_read); } + warp_scheduler_barrier_sync(); flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); - warpgroup_wait<0>(); - pipeline_k.consumer_release(smem_pipe_read); + if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } + if constexpr(!HasQv) { + if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } + } + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + warp_scheduler_barrier_arrive(); + warpgroup_wait<1>(); + pipeline_k.consumer_release(smem_pipe_read); // release K if constexpr (HasQv) { - shared_storage.pipelines.barrier_Qv.wait(work_idx % 2); + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V consumer_wait(pipeline_v, smem_pipe_read); flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); } scoremod_premask_fn(tSrS); - mask.template apply(tSrS, m_block, n_block); + mask_fn(tSrS, n_block); - if constexpr(Is_flashmask) { - consumer_wait(pipeline_flashmask_apply, smem_pipe_read); - if (n_block_idx < Flashmask_n_block_buffer_valid_length && mask_encode_n_block_smem_[n_block_idx] >= 0) { + consumer_wait(pipeline_flashmask_apply, smem_pipe_read); + if (n_block_idx < Flashmask_n_block_buffer_valid_length && mask_encode_n_block_smem_[n_block_idx] >= 0) { if (params.ut_start_ptr) { flashmask_apply( tSrS, m_block, thread_idx, smem_pipe_read.index(), flashmask_smem_, - params.lt_start_ptr, params.lt_end_ptr, - params.ut_start_ptr, params.ut_end_ptr); + params.lt_start_ptr, params.lt_end_ptr, params.ut_start_ptr, params.ut_end_ptr); } else if (params.lt_end_ptr || params.ut_end_ptr) { flashmask_apply( tSrS, m_block, thread_idx, smem_pipe_read.index(), flashmask_smem_, - params.lt_start_ptr, params.lt_end_ptr, - nullptr, params.ut_end_ptr); + params.lt_start_ptr, params.lt_end_ptr, nullptr, params.ut_end_ptr); } else { flashmask_apply( tSrS, m_block, thread_idx, smem_pipe_read.index(), flashmask_smem_, params.lt_start_ptr, nullptr, nullptr, nullptr); } - } - pipeline_flashmask_apply.consumer_release(smem_pipe_read); } + pipeline_flashmask_apply.consumer_release(smem_pipe_read); - Tensor scores_scale = softmax.template max_get_scale(tSrS); - // Don't need to store scales to send to WG1 (in the case of LargeHeadDimV) since it's 1.f - - softmax.template online_softmax(tSrS); + cute::copy(softmax.template max_get_scale(tSrS), scores_scale); + if constexpr (LargeHeadDimV) { store_scales(scores_scale, smem_pipe_read_v.index()); } + softmax.template online_softmax(tSrS); + if constexpr (!HasQv) { + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + } if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } - Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); - Tensor tOrP = make_tensor_like(tOrP_acc); - convert_type_out(tOrP_acc, tOrP); + 
convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } + if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); } - n_block = n_block_getter(++n_block_idx); - - // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter - clear(tOrO); - // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; - - // Each step does gemm0 for iter n_block, gemm1 for iter n_block + 1, and softmax for iter n_block. - auto fwd_step = [&](int const n_block, auto mask_fn, auto check_inf_type) { - static constexpr bool Check_inf = decltype(check_inf_type)::value; - PipelineState smem_pipe_read_v(smem_pipe_read.index(), smem_pipe_read.phase(), smem_pipe_read.count()); - ++smem_pipe_read; - // PipelineState smem_pipe_read_v(smem_pipe_read.index(), smem_pipe_read.phase(), smem_pipe_read.count()); - // ++smem_pipe_read; - Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); - if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_k, smem_pipe_read); } - warp_scheduler_barrier_sync(); - flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); - if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - if constexpr(!HasQv) { - if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } - } - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); - warp_scheduler_barrier_arrive(); - warpgroup_wait<1>(); - pipeline_k.consumer_release(smem_pipe_read); // release K - if constexpr (HasQv) { - warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read_v); // release V - consumer_wait(pipeline_v, smem_pipe_read); - flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); - } - scoremod_premask_fn(tSrS); - mask_fn(tSrS, n_block); - - if constexpr (Is_flashmask) { - consumer_wait(pipeline_flashmask_apply, smem_pipe_read); - if (n_block_idx < Flashmask_n_block_buffer_valid_length && mask_encode_n_block_smem_[n_block_idx] >= 0) { - if (params.ut_start_ptr) { - flashmask_apply( - tSrS, m_block, thread_idx, smem_pipe_read.index(), flashmask_smem_, - params.lt_start_ptr, params.lt_end_ptr, params.ut_start_ptr, params.ut_end_ptr); - } else if (params.lt_end_ptr || params.ut_end_ptr) { - flashmask_apply( - tSrS, m_block, thread_idx, smem_pipe_read.index(), flashmask_smem_, - params.lt_start_ptr, params.lt_end_ptr, nullptr, params.ut_end_ptr); - } else { - flashmask_apply( - tSrS, m_block, thread_idx, smem_pipe_read.index(), flashmask_smem_, - params.lt_start_ptr, nullptr, nullptr, nullptr); - } - } - pipeline_flashmask_apply.consumer_release(smem_pipe_read); - } + }; - cute::copy(softmax.template max_get_scale(tSrS), scores_scale); - if constexpr (LargeHeadDimV) { store_scales(scores_scale, smem_pipe_read_v.index()); } - softmax.template online_softmax(tSrS); - if constexpr (!HasQv) { - warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read_v); // release V + int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; + int const n_block_min_before_local_mask = !Is_local + ? 
n_block_min + : std::max(n_block_min, + cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN)); + + if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking + auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; + int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM); + int const n_block_min_causal_local_mask = + std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN); + for(; n_block >= n_block_min_causal_local_mask || n_block == Flashmask_n_block_chunk_end;) { + #pragma unroll 1 + for (; n_block >= n_block_min_causal_local_mask; n_block = n_block_getter(++n_block_idx)) { + fwd_step(n_block, mask_fn); } - if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } - convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP); - if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } - if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } - if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); } - }; - - int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; - int const n_block_min_before_local_mask = !Is_local - ? n_block_min - : std::max(n_block_min, - cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN)); - - if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking - auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; - int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM); - int const n_block_min_causal_local_mask = - std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN); - for(; n_block >= n_block_min_causal_local_mask || n_block == Flashmask_n_block_chunk_end;) { - #pragma unroll 1 - for (; n_block >= n_block_min_causal_local_mask; n_block = n_block_getter(++n_block_idx)) { - fwd_step(n_block, mask_fn, cute::true_type{} /*check_inf*/); - } - if (n_block == Flashmask_n_block_chunk_end) { + if (n_block == Flashmask_n_block_chunk_end) { pipeline_n_block.consumer_release(n_block_pipe_read); ++n_block_pipe_read; consumer_wait(pipeline_n_block, n_block_pipe_read); @@ -1818,17 +1826,17 @@ struct CollectiveMainloopFwdSm90 { extra_flags_smem = extra_flags + n_block_pipe_read.index(); n_block_idx = 0; n_block = n_block_getter(0); - } } } + } - auto no_mask_fn = [](auto& tSrS, int n_block) { }; - for(; n_block >= n_block_min_before_local_mask || n_block == Flashmask_n_block_chunk_end;) { - #pragma unroll 1 - for (; n_block >= n_block_min_before_local_mask; n_block = n_block_getter(++n_block_idx)) { - fwd_step(n_block, no_mask_fn, cute::bool_constant{} /*check_inf*/); - } - if (n_block == Flashmask_n_block_chunk_end) { + auto no_mask_fn = [](auto& tSrS, int n_block) { }; + for(; n_block >= n_block_min_before_local_mask || n_block == Flashmask_n_block_chunk_end;) { + #pragma unroll 1 + for (; n_block >= n_block_min_before_local_mask; n_block = n_block_getter(++n_block_idx)) { + fwd_step(n_block, no_mask_fn); + } + if (n_block == Flashmask_n_block_chunk_end) { pipeline_n_block.consumer_release(n_block_pipe_read); ++n_block_pipe_read; consumer_wait(pipeline_n_block, n_block_pipe_read); @@ -1836,18 +1844,18 @@ struct 
CollectiveMainloopFwdSm90 { extra_flags_smem = extra_flags + n_block_pipe_read.index(); n_block_idx = 0; n_block = n_block_getter(0); - } } + } - // Separate masking iterations on the left for local attention - if constexpr (Is_local) { - auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; - for(; n_block >= n_block_min || n_block == Flashmask_n_block_chunk_end;) { - #pragma unroll 1 - for (; n_block >= n_block_min; n_block = n_block_getter(++n_block_idx)) { - fwd_step(n_block, local_mask_fn, cute::bool_constant{} /*check_inf*/); - } - if (n_block == Flashmask_n_block_chunk_end) { + // Separate masking iterations on the left for local attention + if constexpr (Is_local) { + auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; + for(; n_block >= n_block_min || n_block == Flashmask_n_block_chunk_end;) { + #pragma unroll 1 + for (; n_block >= n_block_min; n_block = n_block_getter(++n_block_idx)) { + fwd_step(n_block, local_mask_fn); + } + if (n_block == Flashmask_n_block_chunk_end) { pipeline_n_block.consumer_release(n_block_pipe_read); ++n_block_pipe_read; consumer_wait(pipeline_n_block, n_block_pipe_read); @@ -1855,31 +1863,28 @@ struct CollectiveMainloopFwdSm90 { extra_flags_smem = extra_flags + n_block_pipe_read.index(); n_block_idx = 0; n_block = n_block_getter(0); - } } } - pipeline_n_block.consumer_release(n_block_pipe_read); - ++n_block_pipe_read; - // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); - if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); - float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; - cute::copy(softmax.finalize(v_descale), scores_scale); - if constexpr (LargeHeadDimV) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); - store_scales(scores_scale, smem_pipe_read.index()); - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - } - warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read); // release V, otherwise producers will hang - softmax.rescale_o(tOrO, scores_scale); - if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } - ++smem_pipe_read; - } else { // No intra-WG overlap - static_assert(!Is_flashmask, "flashmaskv2 does not support no intra-wg overlap"); } + pipeline_n_block.consumer_release(n_block_pipe_read); + ++n_block_pipe_read; + // Tell producers that smem_q is ready + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } + if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 
1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; + cute::copy(softmax.finalize(v_descale), scores_scale); + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + store_scales(scores_scale, smem_pipe_read.index()); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read); // release V, otherwise producers will hang + softmax.rescale_o(tOrO, scores_scale); + if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } + ++smem_pipe_read; ++work_idx; return true; } @@ -1903,12 +1908,11 @@ struct CollectiveMainloopFwdSm90 { int const m_block = get<0>(block_coord); int const bidh = get<1>(block_coord); int const bidb = get<2>(block_coord); - int const split_idx = get<3>(block_coord); auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, + seqlen_info, m_block, bidb, 0, 1, params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier - if constexpr (Is_causal || Is_local || Varlen || Split) { + if constexpr (Is_causal || Is_local || Varlen) { if (n_block_max <= n_block_min) { return false; } } @@ -1995,9 +1999,9 @@ struct CollectiveMainloopFwdSm90 { int const work_idx ) { - auto [m_block, bidh, bidb, split_idx] = block_coord; + auto [m_block, bidh, bidb, __split] = block_coord; auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, + seqlen_info, m_block, bidb, 0, 1, params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); if (n_block_new_max <= n_block_new_min) { return false; } @@ -2097,9 +2101,9 @@ struct CollectiveMainloopFwdSm90 { SeqlenInfo_t const& seqlen_info, cute::tuple block_coord ) { - auto [m_block, bidh, bidb, split_idx] = block_coord; + auto [m_block, bidh, bidb, __split] = block_coord; auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, + seqlen_info, m_block, bidb, 0, 1, params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); if (n_block_new_max <= n_block_new_min) { return false; } diff --git a/csrc/flashmask_v2/tile_scheduler.hpp b/csrc/flashmask_v2/tile_scheduler.hpp index 6845a9de012..567c9641fbd 100644 --- a/csrc/flashmask_v2/tile_scheduler.hpp +++ b/csrc/flashmask_v2/tile_scheduler.hpp @@ -14,14 +14,6 @@ namespace flash { /////////////////////////////////////////////////////////////////////////////// -#define DEFINE_DUMMY_NOTIFY_FUNCS \ - CUTLASS_DEVICE \ - void \ - producer_notify() const {} \ - CUTLASS_DEVICE \ - void \ - consumer_notify() const {} - // Host side kernel arguments struct TileSchedulerArguments { // num_head is num_head_q if not PackGQA, else num_head_k @@ -36,13 +28,34 @@ struct TileSchedulerArguments { int const* const num_splits_dynamic_ptr = nullptr; }; +// method / static vars needed for every scheduler +// overwrite some of the methods / vars in the derived class, if needed +class TileSchedulerBase { +public: + static constexpr bool pipelining = false; + static constexpr int stride = 1; +public: + + CUTLASS_DEVICE + TileSchedulerBase() {} + + CUTLASS_DEVICE void 
producer_notify() const {} + CUTLASS_DEVICE void consumer_notify() const {} + + template + CUTLASS_DEVICE + constexpr uint32_t stage() const noexcept { return 0; } + + CUTLASS_DEVICE + void + init_consumer() const {} +}; + /////////////////////////////////////////////////////////////////////////////// template -class SingleTileScheduler { +class SingleTileScheduler: public TileSchedulerBase { public: - static constexpr bool pipelining = false; - using SharedStorage = int; // Device side kernel params @@ -72,7 +85,6 @@ class SingleTileScheduler { return {uint32_t(params.num_blocks), uint32_t((!Split ? 1 : params.num_splits) * params.num_head), uint32_t(params.num_batch)}; } - DEFINE_DUMMY_NOTIFY_FUNCS struct WorkTileInfo { int block_idx = 0; @@ -125,10 +137,6 @@ class SingleTileScheduler { return work_info; } - CUTLASS_DEVICE - void - init_consumer() const {} - CUTLASS_DEVICE void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} @@ -139,19 +147,14 @@ class SingleTileScheduler { get_next_work(Params const& params, WorkTileInfo const& current_work) const { return {0, 0, -1, 0}; } - - template - CUTLASS_DEVICE - constexpr uint32_t stage() const noexcept { return 0; } }; /////////////////////////////////////////////////////////////////////////////// template -class StaticPersistentTileScheduler { +class StaticPersistentTileScheduler: public TileSchedulerBase { public: - static constexpr bool pipelining = false; using SharedStorage = int; // Device side kernel params @@ -173,7 +176,6 @@ class StaticPersistentTileScheduler { return {uint32_t(num_sm)}; } - DEFINE_DUMMY_NOTIFY_FUNCS struct WorkTileInfo { int tile_idx; @@ -208,10 +210,6 @@ class StaticPersistentTileScheduler { return {int(blockIdx.x)}; } - CUTLASS_DEVICE - void - init_consumer() const {} - CUTLASS_DEVICE void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} @@ -222,14 +220,10 @@ class StaticPersistentTileScheduler { get_next_work(Params const& params, WorkTileInfo const& current_work) const { return {current_work.tile_idx + int(gridDim.x)}; } - - template - CUTLASS_DEVICE - constexpr uint32_t stage() const noexcept { return 0; } }; -template -class PreemptivePersistentTileScheduler { +template +class PreemptivePersistentTileScheduler: public TileSchedulerBase { // **PPT** scheduler: performs correct synchronization for producer (generate_n_block) and consumer (KV load and computation pipeline) // This scheduler has the same coordinate computation logic as StaticPersistentTileSch, the difference is that // we employ a preemptive scheduling strategy based on a rough estimation of the workload for the consumer @@ -239,7 +233,7 @@ class PreemptivePersistentTileScheduler { static constexpr int NumThreads = NumConsumerThreads + NumProducerThreads; public: using SharedStorage = int; - static constexpr bool pipelining = false; + static constexpr int stride = Stride; protected: SharedStorage* const tile_count_smem; @@ -248,18 +242,18 @@ class PreemptivePersistentTileScheduler { // Device side kernel params struct Params { - int total_blocks; - cutlass::FastDivmod m_block_divmod, head_divmod; - cutlass::FastDivmod nsplits_divmod; + const int total_blocks; + const cutlass::FastDivmod m_block_divmod, head_divmod; int* const tile_count_semaphore; }; static Params to_underlying_arguments(TileSchedulerArguments const& args) { assert(args.tile_count_semaphore != nullptr); - return {args.num_blocks * args.num_head * args.num_batch * (!Split ? 
1 : args.num_splits), - cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head * (!Split ? 1 : args.num_splits)), - cutlass::FastDivmod(!Split ? 1 : args.num_splits), args.tile_count_semaphore}; + return {args.num_blocks * args.num_head * args.num_batch, + cutlass::FastDivmod(args.num_blocks), + cutlass::FastDivmod(args.num_head), + args.tile_count_semaphore}; } static dim3 @@ -281,16 +275,11 @@ class PreemptivePersistentTileScheduler { get_block_coord(Params const& params) const { int block, bidh, bidb; bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx)); - int split_idx = 0; - if constexpr (Split) { - bidh = params.nsplits_divmod.divmod(split_idx, bidh); - } - return {block, bidh, bidb, split_idx}; + return {block, bidh, bidb, 0}; } }; - DEFINE_DUMMY_NOTIFY_FUNCS CUTLASS_DEVICE PreemptivePersistentTileScheduler(SharedStorage* const smem_scheduler) : tile_count_smem(smem_scheduler) {}; @@ -308,7 +297,7 @@ class PreemptivePersistentTileScheduler { // prefetch_next_work even before SM2 calls get_initial_work, then SM1 will risk computing the same block as SM2. // for the initial work: assign deterministically - return {int(blockIdx.x)}; + return {int(blockIdx.x) * stride}; } CUTLASS_DEVICE @@ -325,8 +314,8 @@ class PreemptivePersistentTileScheduler { prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { // only producer will call this method if (threadIdx.x == 96) { // hard-coded, since n_block producer threads are in [32, 128) - // the next job we are going to process: number of currently blocks done - current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1); + // the next job we are going to process: number of blocks currently done + current_work.tile_idx = atomicAdd(params.tile_count_semaphore, stride); } } @@ -352,18 +341,14 @@ class PreemptivePersistentTileScheduler { return {tile_idx}; } } - - template - CUTLASS_DEVICE - constexpr uint32_t stage() const noexcept { return 0; } }; + template -class BwdPreemptivePersistentTileScheduler { +class BwdPreemptivePersistentTileScheduler: public TileSchedulerBase { static constexpr int NumThreads = NumConsumerThreads + NumProducerThreads; public: using SharedStorage = int; - static constexpr bool pipelining = false; protected: SharedStorage* const tile_count_smem; @@ -372,8 +357,8 @@ class BwdPreemptivePersistentTileScheduler { // Device side kernel params struct Params { - int total_blocks; - cutlass::FastDivmod m_block_divmod, head_divmod; + const int total_blocks; + const cutlass::FastDivmod m_block_divmod, head_divmod; int* const tile_count_semaphore; }; @@ -422,12 +407,6 @@ class BwdPreemptivePersistentTileScheduler { return {int(blockIdx.x)}; } - CUTLASS_DEVICE - void - init_consumer() const { - // flash::named_barrier_arrive(NumThreads, static_cast(BwdNamedBarriers::FlashmaskSmemEmpty) /*id*/); - } - CUTLASS_DEVICE void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} @@ -463,15 +442,11 @@ class BwdPreemptivePersistentTileScheduler { // how to make sure consumers can actually get this? 
return {*tile_count_smem}; } - - template - CUTLASS_DEVICE - constexpr uint32_t stage() const noexcept { return 0; } }; -template -class DualPreemptivePersistentTileExecutionScheduler { +template +class DualPreemptivePersistentTileExecutionScheduler: public TileSchedulerBase { // **PPT** scheduler: performs correct synchronization for producer (generate_n_block) and consumer (KV load and computation pipeline) // This scheduler has the same coordinate computation logic as StaticPersistentTileSch, the difference is that // we employ a preemptive scheduling strategy based on a rough estimation of the workload for the consumer @@ -491,18 +466,18 @@ class DualPreemptivePersistentTileExecutionScheduler { // Device side kernel params struct Params { - int total_blocks; - cutlass::FastDivmod m_block_divmod, head_divmod; - cutlass::FastDivmod nsplits_divmod; + const int total_blocks; + const cutlass::FastDivmod m_block_divmod, head_divmod; int* const tile_count_semaphore; }; static Params to_underlying_arguments(TileSchedulerArguments const& args) { assert(args.tile_count_semaphore != nullptr); - return {args.num_blocks * args.num_head * args.num_batch * (!Split ? 1 : args.num_splits), - cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head * (!Split ? 1 : args.num_splits)), - cutlass::FastDivmod(!Split ? 1 : args.num_splits), args.tile_count_semaphore}; + return {args.num_blocks * args.num_head * args.num_batch, + cutlass::FastDivmod(args.num_blocks), + cutlass::FastDivmod(args.num_head), + args.tile_count_semaphore}; } static dim3 @@ -524,11 +499,7 @@ class DualPreemptivePersistentTileExecutionScheduler { get_block_coord(Params const& params) const { int block, bidh, bidb; bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx)); - int split_idx = 0; - if constexpr (Split) { - bidh = params.nsplits_divmod.divmod(split_idx, bidh); - } - return {block, bidh, bidb, split_idx}; + return {block, bidh, bidb, 0}; } }; @@ -559,12 +530,6 @@ class DualPreemptivePersistentTileExecutionScheduler { return {int(blockIdx.x)}; } - DEFINE_DUMMY_NOTIFY_FUNCS - - CUTLASS_DEVICE - void - init_consumer() const { /* Init is done in get_initial work, therefore no need to repeat. */ } - CUTLASS_DEVICE void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { @@ -615,399 +580,4 @@ class DualPreemptivePersistentTileExecutionScheduler { } }; -template -class DynamicPersistentTileScheduler { - - // This scheduler targets the causal (or local) case where each tile takes different - // amount of time. We use longest-processing-time-first scheduling: - // the longest remaining tile is assigned to the first SM that's free. - // SM indicates they are free by incrementing a semaphore. - // However, we have to make sure K & V still fit into L2 cache, so we perform scheduling - // on "sections" of the head & batch dimension, each section consisting of e.g. 8 heads. - // This is the L2 swizzling part. The size of each section is precomputed based on the - // size of K & V and the L2 cache size. - - static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); - static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + (Is_flashmask ? 
128 : NumProducerThreads) : NumMmaThreads; - -public: - using SharedStorage = int; - static constexpr bool pipelining = false; -protected: - SharedStorage* const tile_count_smem; - -public: - - // Device side kernel params - struct Params { - int const total_blocks; - cutlass::FastDivmod const m_block_divmod, head_divmod; - cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod; - cutlass::FastDivmod const l2_minor_residual_divmod; - int const num_hb_quotient; - int* const tile_count_semaphore; - }; - - static Params - to_underlying_arguments(TileSchedulerArguments const& args) { - int const size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size * 2; - int const size_l2 = 32 * 1024 * 1024; // 32 MB for K & V - // Swizzle is the size of each "section". Round swizzle to a power of 2 - // If not PackGQA already, the size of each section can increase by qhead_per_khead - // Need to be careful about the case where only one head will fit - int const swizzle = (size_l2 < size_one_kv_head ? 1 : (1 << cutlass::find_log2(size_l2 / size_one_kv_head))) * (PackGQA ? 1 : args.qhead_per_khead); - // If we're in the last section (called residual), we don't want to divide by - // swizzle. Instead we want to divide by the remainder. - int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle; - int const num_split_blocks = args.num_blocks * (!Split ? 1 : args.num_splits); - assert(args.tile_count_semaphore != nullptr); - return {num_split_blocks * args.num_head * args.num_batch, - cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head), - cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * num_split_blocks), - // don't divide by 0 - cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1), - (args.num_head * args.num_batch) / swizzle, - args.tile_count_semaphore}; - } - - static dim3 - get_grid_shape(Params const& params, int num_sm) { - return {uint32_t(num_sm)}; - } - - struct WorkTileInfo { - int tile_idx; - - CUTLASS_DEVICE - bool - is_valid(Params const& params) const { - return tile_idx < params.total_blocks; - } - - CUTLASS_DEVICE - cute::tuple - get_block_coord(Params const& params) const { - int block, bidh, bidb; - int l2_mod, bidhb, bidhb_residual; - bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx); - // If we're in the last section (called residual), we don't want to divide by - // swizzle. Instead we want to divide by the remainder. 
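The section-size arithmetic in to_underlying_arguments above is worth a worked example. The host-side C++ below (C++20 for std::bit_floor) uses purely illustrative values, replaces the cutlass::find_log2-based rounding with std::bit_floor, and omits the PackGQA / qhead_per_khead factor.

#include <bit>       // std::bit_floor (C++20)
#include <cstdio>

int main() {
    // Illustrative problem size: bf16, seqlen_k = 8192, hdim = hdim_v = 128.
    long long const seqlen_k = 8192, headdim = 128, headdim_v = 128, element_size = 2;
    // Bytes of K plus V for a single KV head (the trailing * 2 covers the K/V pair).
    long long const size_one_kv_head = seqlen_k * (headdim + headdim_v) * element_size * 2;
    long long const size_l2 = 32ll * 1024 * 1024;   // portion of L2 reserved for K & V
    // Number of KV heads whose K and V fit in that budget, rounded down to a power
    // of two as the scheduler does; fall back to 1 when even one head is too big.
    unsigned long long const swizzle = size_l2 < size_one_kv_head
        ? 1ull
        : std::bit_floor(static_cast<unsigned long long>(size_l2 / size_one_kv_head));
    std::printf("size_one_kv_head = %lld B, section (swizzle) = %llu heads\n",
                size_one_kv_head, swizzle);          // 8388608 B, 4 heads
    return 0;
}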
- if (bidhb < params.num_hb_quotient) { - block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod); - } else { - block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod); - } - bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); - int split_idx = 0; - if constexpr (Split) { - split_idx = params.m_block_divmod.divmod(block, block); - } - // Longest-processing-time-first - block = params.m_block_divmod.divisor - 1 - block; - return {block, bidh, bidb, split_idx}; - } - - }; - - CUTLASS_DEVICE - DynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : tile_count_smem(smem_scheduler) {}; - - template - CUTLASS_DEVICE - WorkTileInfo - get_initial_work(Params const& params) const { - return {int(blockIdx.x)}; - } - - DEFINE_DUMMY_NOTIFY_FUNCS - - CUTLASS_DEVICE - void - init_consumer() const { - if (WarpSpecialized || cutlass::canonical_warp_idx_sync() > 0) { - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); - } - } - - CUTLASS_DEVICE - void - prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { - if (threadIdx.x % NumProducerThreads == 0) { - current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); - } - } - - template - CUTLASS_DEVICE - WorkTileInfo - get_next_work(Params const& params, WorkTileInfo const& current_work) const { - if constexpr (IsProducerWarp) { - // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0 - int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); - flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); - if (threadIdx.x % NumProducerThreads == 0) { - *tile_count_smem = current_work.tile_idx; - } - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); - return {new_tile_idx}; - } else { - flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); - int tile_idx = *tile_count_smem; - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); - return {tile_idx}; - } - } - - template - CUTLASS_DEVICE - constexpr uint32_t stage() const noexcept { return 0; } -}; - -template -class VarlenDynamicPersistentTileScheduler { - - static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); - static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + NumProducerThreads : NumMmaThreads; - -public: - using SharedStorage = int4; - static constexpr bool pipelining = false; -protected: - SharedStorage* const work_info_smem; - -public: - - // Device side kernel params - struct Params { - int num_head, num_batch; - int const qhead_per_khead; - int const seqlen; - cutlass::FastDivmod head_divmod; - cutlass::FastDivmod nsplits_divmod; - int* const tile_count_semaphore; - int const* const cu_seqlens; - int const* const seqused; - // int* const num_m_blocks_ptr; - int const* const num_splits_dynamic_ptr; - }; - - static Params - to_underlying_arguments(TileSchedulerArguments const& args) { - // If Split, for the purpose of scheduling, we pretend that instead there are - // (args.num_splits * args.num_head) number of heads. 
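The swizzled decode in get_block_coord above, including the residual-section branch and the final longest-processing-time-first reversal, is easier to follow in scalar form. A host-side sketch, with ordinary division standing in for the FastDivmod fields and the Split dimension left out:

#include <cstdio>
#include <tuple>

// tile_idx -> (m_block, bidh, bidb) for the L2-swizzled, LPT-ordered schedule.
// 'section' is the number of (head, batch) pairs grouped per L2 residency window.
std::tuple<int, int, int> decode_swizzled(int tile_idx, int num_m_blocks,
                                          int num_heads, int num_batch, int section) {
    int const num_hb          = num_heads * num_batch;
    int const num_hb_quotient = num_hb / section;            // count of full sections
    int const num_hb_rem      = num_hb % section;            // size of the trailing section
    int const bidhb  = tile_idx / (section * num_m_blocks);  // which section
    int const l2_mod = tile_idx % (section * num_m_blocks);  // offset inside it
    // Inside a full section divide by 'section'; inside the trailing (residual)
    // section divide by the remainder instead, exactly as the comment above says.
    // (The kernel guards num_hb_rem against 0; valid tile indices never reach the
    // residual branch when the remainder is 0.)
    int const divisor     = bidhb < num_hb_quotient ? section : num_hb_rem;
    int block             = l2_mod / divisor;
    int const hb_residual = l2_mod % divisor;
    int const hb   = bidhb * section + hb_residual;
    int const bidh = hb % num_heads;
    int const bidb = hb / num_heads;
    // Longest-processing-time-first: under causal masking a larger m_block has more
    // k/v blocks to visit, so it is handed out first.
    block = num_m_blocks - 1 - block;
    return {block, bidh, bidb};
}

int main() {
    auto [block, bidh, bidb] = decode_swizzled(/*tile_idx=*/37, /*num_m_blocks=*/10,
                                               /*num_heads=*/8, /*num_batch=*/3, /*section=*/8);
    std::printf("block=%d bidh=%d bidb=%d\n", block, bidh, bidb);   // block=5 bidh=5 bidb=0
    return 0;
}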
- assert(args.tile_count_semaphore != nullptr); - assert(num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx - assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits - return {args.num_head, args.num_batch, - args.qhead_per_khead, args.seqlen, - cutlass::FastDivmod(args.num_head), - cutlass::FastDivmod(!Split ? 1 : args.num_splits), - args.tile_count_semaphore, args.cu_seqlens, args.seqused, - // args.num_m_blocks_ptr, - args.num_splits_dynamic_ptr}; - } - - static dim3 - get_grid_shape(Params const& params, int num_sm) { - return {uint32_t(num_sm)}; - } - - struct WorkTileInfo { - int tile_idx, block, bidh, bidb; - - CUTLASS_DEVICE - bool - is_valid(Params const& params) const { - // if (blockIdx.x >= 0 && (threadIdx.x == 128 || threadIdx.x == 0)) { printf("blockIdx.x = %d, threadIdx.x = %d, checking valid, bidb = %d, params.num_batch = %d\n", blockIdx.x, threadIdx.x, bidb, params.num_batch); } - return bidb < params.num_batch; - } - - CUTLASS_DEVICE - cute::tuple - get_block_coord(Params const& params) const { - if constexpr (!Split) { - return {block, bidh, bidb, 0 /*split_idx*/}; - } else { - // the top 8 bits of bidh store num_splits and the next 8 bits store split_idx - // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift - uint32_t bidh_packed = reinterpret_cast(bidh); - uint32_t bidh_actual_u = bidh_packed & 0x0000FFFF; - int bidh_actual = reinterpret_cast(bidh_actual_u); - // Use the top 16 bits of split_idx to store num_splits and the next 16 bits to store split_idx - uint32_t split_idx_u = ((bidh_packed & 0x00FF0000) >> 16) + ((bidh_packed & 0xFF000000) >> 8); - int split_idx = reinterpret_cast(split_idx_u); - // int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh); - // if (threadIdx.x == 128) { - // printf("blockIdx.x = %d, bidb = %d, bidh = %d, bidh_actual = %d, split_idx = %d\n", blockIdx.x, bidb, bidh, bidh_actual, split_idx); - // } - return {block, bidh_actual, bidb, split_idx}; - } - } - }; - - CUTLASS_DEVICE - VarlenDynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : work_info_smem(smem_scheduler) {}; - - DEFINE_DUMMY_NOTIFY_FUNCS - - CUTLASS_DEVICE - WorkTileInfo - tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkTileInfo const& current_work) const { - int lane = threadIdx.x % cutlass::NumThreadsPerWarp; - auto get_num_m_blocks = [&] (int bidb_start) { - int batch_idx = lane + bidb_start; - int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); - if (seqlen > kBlock) { - if (params.seqused) { - seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; - } else if (params.cu_seqlens) { - int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens[batch_idx] : 0; - int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); - seqlen = next_cu_seqlen - cur_cu_seqlen; - } else { - seqlen = params.seqlen; - } - if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } - } - return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? cute::ceil_div(seqlen, kBlock) : 0; - // ? params.num_m_blocks_ptr[batch_idx] : 0; - }; - - auto get_num_splits = [&] (int bidb_start) { - int batch_idx = lane + bidb_start; - return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? (!Split ? 1 : (params.num_splits_dynamic_ptr - ? 
params.num_splits_dynamic_ptr[batch_idx] - : params.nsplits_divmod.divisor)) - : 0; - }; - - int num_m_blocks = get_num_m_blocks(current_work.bidb); // Different for each lane - int num_splits = get_num_splits(current_work.bidb); - int num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits; - // Cumulative number of blocks for the next 31 batches - int num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks); - // Total number of blocks for the next 31 batches - int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); - // Only the lower 16 bits are the actual bidh - int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF); - int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes - if constexpr (Split) { - int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16; - group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/); - } - int bidb = current_work.bidb; - // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, cur tile_idx = %d, cur block = %d, cur bidh = %d, num_split_m_blocks = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, current_work.bidb, num_m_blocks, next_tile_idx, current_work.tile_idx, current_work.block, current_bidh, num_split_m_blocks, group_end_tile, m_blocks_in_group); - // } - // if (threadIdx.x == 0 && blockIdx.x == 0) { printf("tile_idx = %d, group_end_tile = %d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d\n", current_work.tile_idx, group_end_tile, num_m_blocks_cumulative, m_blocks_in_group); } - while (group_end_tile <= next_tile_idx) { - bidb += cutlass::NumThreadsPerWarp - 1; - if (bidb >= params.num_batch) { - // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("Returning early, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group); - // } - return {next_tile_idx, 0, 0, params.num_batch}; - } - num_m_blocks = get_num_m_blocks(bidb); - num_splits = get_num_splits(bidb); - num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits; - num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks); - m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); - group_end_tile += m_blocks_in_group * params.num_head; - // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("Bottom of while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group); - // } - } - int group_start_tile = group_end_tile - m_blocks_in_group * params.num_head; - // The next problem to process is the first one that does not have ending tile position - // that is greater than or equal to tile index. 
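The warp-cooperative search above (an inclusive prefix sum of per-batch m-block counts over a 31-batch group, then ballot + popc to find the batch that owns the target tile) is the core trick of the varlen scheduler. Below is a stripped-down device-side sketch with hypothetical names; it assumes a full warp and ignores the num_head factor and the packed-bidh bookkeeping handled in the surrounding code.

// Inclusive warp prefix sum over each lane's value (Hillis-Steele with shuffles),
// equivalent in spirit to the warp_prefix_sum used above.
__device__ int warp_inclusive_scan(int v) {
    int const lane = threadIdx.x % 32;
    #pragma unroll
    for (int offset = 1; offset < 32; offset *= 2) {
        int const up = __shfl_up_sync(0xffffffff, v, offset);
        if (lane >= offset) { v += up; }
    }
    return v;
}

// Each lane holds the m-block count of one batch in the current group. A batch
// ends before the target tile iff its cumulative end is <= the target; counting
// those lanes with ballot + popc gives the owning batch's offset inside the group,
// which is how batch_idx_in_group is derived above.
__device__ int find_batch_in_group(int my_batch_num_blocks, int target_tile_in_group) {
    int const cumulative     = warp_inclusive_scan(my_batch_num_blocks);
    unsigned const ended_before = __ballot_sync(0xffffffff, cumulative <= target_tile_in_group);
    return __popc(ended_before);
}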
- int batch_idx_in_group = __popc(__ballot_sync(0xffffffff, group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx)); - // if (threadIdx.x == 31 || threadIdx.x == 0) { printf("blockIdx.x = %d, tidx %d, group_start_tile = %d, num_m_blocks_cumulative = %d, num_head = %d, next_tile_idx = %d, ballot = %x, batch_idx_in_group = %d\n", blockIdx.x, threadIdx.x, group_start_tile, num_m_blocks_cumulative, params.num_head, next_tile_idx, tmp, batch_idx_in_group); } - bidb += batch_idx_in_group; - num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group); - if constexpr (Split) { num_splits = __shfl_sync(0xffffffff, num_splits, batch_idx_in_group); } - int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; - int bidh = mh_block / num_m_blocks; - int block = mh_block - bidh * num_m_blocks; - if constexpr (Split) { - int bidh_actual = bidh / num_splits; - int split_idx = bidh - bidh_actual * num_splits; - // TODO: idk why this gives wrong answer nondeterministically - // int bidh_actual, split_idx; - // split_idx = params.head_divmod.divmod(bidh_actual, bidh); - // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx - // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift - uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); - // if (threadIdx.x == 0) { - // printf("blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d, split_idx = %d, num_splits = %d, bidh_packed = %d\n", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed); - // } - bidh = reinterpret_cast(bidh_packed); - } - // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); - // } - return {next_tile_idx, block, bidh, bidb}; - } - - template - CUTLASS_DEVICE - WorkTileInfo - get_initial_work(Params const& params) const { - if constexpr (IsProducerWarp) { - WorkTileInfo work_info = tile_idx_to_work_tile(params, int(blockIdx.x), {0, 0, 0, 0}); - if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { - *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); - } - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); - return work_info; - } else { - return get_next_work(params, {0, 0, 0, 0}); - } - } - - CUTLASS_DEVICE - void - init_consumer() const { - // Don't arrive at the TileCountSmemEmpty barrier here, because get_initial_work will do that - } - - CUTLASS_DEVICE - void - prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { - if (threadIdx.x % NumProducerThreads == 0) { - current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); - } - } - - template - CUTLASS_DEVICE - WorkTileInfo - get_next_work(Params const& params, WorkTileInfo const& current_work) const { - if 
constexpr (IsProducerWarp) { - // thread 0 has the next tile_idx, just need to broadcast to the rest of warp 0 - int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); - WorkTileInfo work_info = {__shfl_sync(0xffffffff, current_work.tile_idx, 1 /*lane*/), current_work.block, current_work.bidh, current_work.bidb}; - work_info = tile_idx_to_work_tile(params, new_tile_idx, work_info); - flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); - if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { - *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); - } - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); - return work_info; - } else { - flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); - int4 work_info = *work_info_smem; - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); - return WorkTileInfo{work_info.x, work_info.y, work_info.z, work_info.w}; - } - } - - template - CUTLASS_DEVICE - constexpr uint32_t stage() const noexcept { return 0; } -}; - } // flash diff --git a/csrc/flashmask_v2/tile_size.h b/csrc/flashmask_v2/tile_size.h index 30c7e7086a5..1e77ff3abd7 100644 --- a/csrc/flashmask_v2/tile_size.h +++ b/csrc/flashmask_v2/tile_size.h @@ -22,7 +22,8 @@ constexpr std::tuple tile_size_fwd_sm90( // Switch to tile size 192 x 192 for now bool const use_blockN_128 = is_causal || is_local; // return {same_hdim ? 192 : 64, same_hdim ? (use_blockN_128 ? 128 : 192) : 64, same_hdim && use_blockN_128, same_hdim}; - return {192, use_blockN_128 ? 80 : 144, same_hdim && use_blockN_128, same_hdim}; + // FlashMask does not support IntraWGOverlap = false + return {192, use_blockN_128 ? 80 : 144, same_hdim && use_blockN_128, true}; // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen // return {192, is_causal || is_local ? 192 : 176, true, false}; } else if (headdim <= 96) {
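Finally, the tile_size.h change pins the last tuple element to true. The constexpr sketch below only mirrors the shape of the changed return statement; it assumes (from the trailing `else if (headdim <= 96)`) that this is the small-headdim branch, and reads the tuple as block sizes plus two pipeline flags, the last being the intra-warpgroup-overlap toggle that FlashMask requires.

#include <tuple>

// Illustrative only: the last element is pinned to true so the FlashMask path
// never selects the unsupported IntraWGOverlap = false configuration.
constexpr std::tuple<int, int, bool, bool>
pick_fwd_tile_sm90_small_hdim(bool is_causal, bool is_local, bool same_hdim) {
    bool const use_blockN_128 = is_causal || is_local;
    return {192, use_blockN_128 ? 80 : 144, same_hdim && use_blockN_128, true};
}

static_assert(std::get<1>(pick_fwd_tile_sm90_small_hdim(true,  false, true)) == 80);
static_assert(std::get<1>(pick_fwd_tile_sm90_small_hdim(false, false, true)) == 144);
static_assert(std::get<3>(pick_fwd_tile_sm90_small_hdim(false, false, false)) == true);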