diff --git a/csrc/flashmask_v2/flash_fwd_kernel_sm90.h b/csrc/flashmask_v2/flash_fwd_kernel_sm90.h index d94d8879c3d..23f761d91d3 100644 --- a/csrc/flashmask_v2/flash_fwd_kernel_sm90.h +++ b/csrc/flashmask_v2/flash_fwd_kernel_sm90.h @@ -87,8 +87,32 @@ class FlashAttnFwdSm90 { // static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32); // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160); - static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? 24 : 32); - static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? 240 : 160); + // static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? 24 : 32); + // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? 240 : 160); + + static constexpr int kHeadDim = CollectiveMainloop::kHeadDim; + + static constexpr uint32_t NBlockRegisterRequirement = [] { + if constexpr (kHeadDim <= 64) { + return 56; + } else { + return NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? 24 : 32); + } + }(); + static constexpr uint32_t LoadRegisterRequirement = [] { + if constexpr (kHeadDim <= 64) { + return 32; + } else { + return NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? 24 : 32); + } + }(); + static constexpr uint32_t MmaRegisterRequirement = [] { + if constexpr (kHeadDim <= 64) { + return 224; + } else { + return NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? 240 : 160); + } + }(); // If you want to print from the producer warp, you'd need to increase the number of registers // Otherwise you'll get CUDA error. 
@@ -272,7 +296,7 @@ class FlashAttnFwdSm90 { TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler)); if (warp_group_idx == 0 && warp_idx_in_warpgroup != 0) { // n_block generator - cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>(); + cutlass::arch::warpgroup_reg_dealloc<NBlockRegisterRequirement>(); cutlass::PipelineState n_block_pipe_write = cutlass::make_producer_start_state(); // Manually specify the scheduler role: producer. For StaticPersistentTileSch, passing template args won't change the behavior for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler); @@ -556,4 +580,4 @@ class FlashAttnFwdSm90 { }; -} // namespace flash +} // namespace flash \ No newline at end of file diff --git a/csrc/flashmask_v2/tile_size.h b/csrc/flashmask_v2/tile_size.h index 30c7e7086a5..08d1c140a54 100644 --- a/csrc/flashmask_v2/tile_size.h +++ b/csrc/flashmask_v2/tile_size.h @@ -15,16 +15,17 @@ constexpr std::tuple<int, int, bool, bool> tile_size_fwd_sm90( return {64, 64, false, true}; } if (headdim <= 64) { - bool same_hdim = (headdim == headdim_v); // if not same hdim, we're targeting hdimv=512 + // bool same_hdim = (headdim == headdim_v); // if not same hdim, we're targeting hdimv=512 // return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, same_hdim}; // With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why // https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131 // Switch to tile size 192 x 192 for now - bool const use_blockN_128 = is_causal || is_local; + // bool const use_blockN_128 = is_causal || is_local; // return {same_hdim ? 192 : 64, same_hdim ? (use_blockN_128 ? 128 : 192) : 64, same_hdim && use_blockN_128, same_hdim}; - return {192, use_blockN_128 ? 80 : 144, same_hdim && use_blockN_128, same_hdim}; + // return {192, use_blockN_128 ? 
80 : 144, same_hdim && use_blockN_128, same_hdim}; // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen // return {192, is_causal || is_local ? 192 : 176, true, false}; + return {128, 128, true, true}; } else if (headdim <= 96) { return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true}; } else if (headdim <= 128) {