diff --git a/.gitmodules b/.gitmodules index 52fe5359e0d..9e7b7c207a6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,6 @@ [submodule "csrc/flashmask_v2/cutlass"] path = csrc/flashmask_v2/cutlass url = https://github.com/NVIDIA/cutlass.git +[submodule "flashmask/flash_mask/flashmask_attention_v3/cutlass"] + path = flashmask/flash_mask/flashmask_attention_v3/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/flashmask/flash_mask/CMakeLists.txt b/flashmask/flash_mask/CMakeLists.txt new file mode 100644 index 00000000000..e21494ef841 --- /dev/null +++ b/flashmask/flash_mask/CMakeLists.txt @@ -0,0 +1,292 @@ +cmake_minimum_required(VERSION 3.9 FATAL_ERROR) +project(flash-attention LANGUAGES CXX CUDA) + +find_package(Git REQUIRED) + +# Triggers download of all git submodules +execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE GIT_SUBMOD_RESULT) + +option(SKIP_BUILD_FA "Skip building flash-attention" OFF) +option(WITH_FLASHATTN_V3 "Enable compile with FA3" OFF) + +if(NOT GIT_SUBMOD_RESULT EQUAL 0) + message(FATAL_ERROR "Failed to update Git submodules") +endif() + +if(NOT SKIP_BUILD_FA) + + set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) + + + find_package(PythonInterp REQUIRED) + + if(WITH_FLASHATTN_V3) + execute_process( + COMMAND "${PYTHON_EXECUTABLE}" flashmask_attention_v3/generate_kernels.py -o flashmask_attention_v3/instantiations + WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}" + RESULT_VARIABLE result + OUTPUT_VARIABLE output + ERROR_VARIABLE error + ) + + option(DISABLE_FLASHMASK_V3_FP16 "Disable FP16 for flashmask_v3" ON) + option(DISABLE_FLASHMASK_V3_FP8 "Disable FP8 for flashmask_v3" ON) + option(DISABLE_FLASHMASK_V3_HDIM64 "Disable HDIM64 for flashmask_v3" OFF) + option(DISABLE_FLASHMASK_V3_HDIM96 "Disable HDIM96 for flashmask_v3" ON) + option(DISABLE_FLASHMASK_V3_HDIM128 "Disable HDIM128 for flashmask_v3" OFF) + option(DISABLE_FLASHMASK_V3_HDIM192 "Disable 
HDIM192 for flashmask_v3" ON) + option(DISABLE_FLASHMASK_V3_HDIM256 "Disable HDIM256 for flashmask_v3" OFF) + option(DISABLE_FLASHMASK_V3_SPLIT "Disable Split for flashmask_v3" ON) + option(DISABLE_FLASHMASK_V3_PAGEDKV "Disable PagedKV for flashmask_v3" ON) + option(DISABLE_FLASHMASK_V3_SOFTCAP "Disable Softcap for flashmask_v3" ON) + option(DISABLE_FLASHMASK_V3_PACKGQA "Disable PackGQA for flashmask_v3" ON) + option(DISABLE_FLASHMASK_V3_BACKWARD "Disable Backward for flashmask_v3" OFF) + option(DISABLE_FLASHMASK_V3_SM8X "Disable SM8x for flashmask_v3" ON) + + if(DISABLE_FLASHMASK_V3_FP16) + add_compile_definitions(FLASHMASK_V3_DISABLE_FP16) + endif() + + if(DISABLE_FLASHMASK_V3_FP8) + add_compile_definitions(FLASHMASK_V3_DISABLE_FP8) + endif() + + if(DISABLE_FLASHMASK_V3_HDIM64) + add_compile_definitions(FLASHMASK_V3_DISABLE_HDIM64) + endif() + + if(DISABLE_FLASHMASK_V3_HDIM96) + add_compile_definitions(FLASHMASK_V3_DISABLE_HDIM96) + endif() + + if(DISABLE_FLASHMASK_V3_HDIM128) + add_compile_definitions(FLASHMASK_V3_DISABLE_HDIM128) + endif() + + if(DISABLE_FLASHMASK_V3_HDIM192) + add_compile_definitions(FLASHMASK_V3_DISABLE_HDIM192) + endif() + + if(DISABLE_FLASHMASK_V3_HDIM256) + add_compile_definitions(FLASHMASK_V3_DISABLE_HDIM256) + endif() + + if(DISABLE_FLASHMASK_V3_SPLIT) + add_compile_definitions(FLASHMASK_V3_DISABLE_SPLIT) + endif() + + if(DISABLE_FLASHMASK_V3_PAGEDKV) + add_compile_definitions(FLASHMASK_V3_DISABLE_PAGEDKV) + endif() + + if(DISABLE_FLASHMASK_V3_SOFTCAP) + add_compile_definitions(FLASHMASK_V3_DISABLE_SOFTCAP) + endif() + + if(DISABLE_FLASHMASK_V3_PACKGQA) + add_compile_definitions(FLASHMASK_V3_DISABLE_PACKGQA) + endif() + + if(DISABLE_FLASHMASK_V3_BACKWARD) + add_compile_definitions(FLASHMASK_V3_DISABLE_BACKWARD) + endif() + + if(DISABLE_FLASHMASK_V3_SM8X) + add_compile_definitions(FLASHMASK_V3_DISABLE_SM8X) + endif() + + if(NOT result EQUAL 0) + message(FATAL_ERROR "Generating flashmask_v3 Python script execution failed with exit code 
${result}: ${error}") + endif() + + # 以 flashmask_v3 为前缀的变量逻辑 + set(FLASHMASKV3_DTYPE_FWD_SM80 "bf16") + if(NOT DISABLE_FLASHMASK_V3_FP16) + list(APPEND FLASHMASKV3_DTYPE_FWD_SM80 "fp16") + endif() + + set(FLASHMASKV3_DTYPE_FWD_SM90 "bf16") + if(NOT DISABLE_FLASHMASK_V3_FP16) + list(APPEND FLASHMASKV3_DTYPE_FWD_SM90 "fp16") + endif() + if(NOT DISABLE_FLASHMASK_V3_FP8) + list(APPEND FLASHMASKV3_DTYPE_FWD_SM90 "e4m3") + endif() + + set(FLASHMASKV3_DTYPE_BWD "bf16") + if(NOT DISABLE_FLASHMASK_V3_FP16) + list(APPEND FLASHMASKV3_DTYPE_BWD "fp16") + endif() + + set(FLASHMASKV3_HEAD_DIMENSIONS_BWD) + if(NOT DISABLE_FLASHMASK_V3_HDIM64) + list(APPEND FLASHMASKV3_HEAD_DIMENSIONS_BWD 64) + endif() + if(NOT DISABLE_FLASHMASK_V3_HDIM96) + list(APPEND FLASHMASKV3_HEAD_DIMENSIONS_BWD 96) + endif() + if(NOT DISABLE_FLASHMASK_V3_HDIM128) + list(APPEND FLASHMASKV3_HEAD_DIMENSIONS_BWD 128) + endif() + if(NOT DISABLE_FLASHMASK_V3_HDIM192) + list(APPEND FLASHMASKV3_HEAD_DIMENSIONS_BWD 192) + endif() + if(NOT DISABLE_FLASHMASK_V3_HDIM256) + list(APPEND FLASHMASKV3_HEAD_DIMENSIONS_BWD 256) + endif() + + # Disable diff, not support headdim != headdim_v + set(FLASHMASKV3_HEAD_DIMENSIONS_FWD ${FLASHMASKV3_HEAD_DIMENSIONS_BWD}) + set(FLASHMASKV3_HEAD_DIMENSIONS_FWD_SM80 ${FLASHMASKV3_HEAD_DIMENSIONS_BWD}) + + set(FLASHMASKV3_SPLIT "__EMPTY__") + if(NOT DISABLE_FLASHMASK_V3_SPLIT) + list(APPEND FLASHMASKV3_SPLIT "_split") + endif() + + set(FLASHMASKV3_PAGEDKV "__EMPTY__") + if(NOT DISABLE_FLASHMASK_V3_PAGEDKV) + list(APPEND FLASHMASKV3_PAGEDKV "_paged") + endif() + + set(FLASHMASKV3_SOFTCAP "__EMPTY__") + if(NOT DISABLE_FLASHMASK_V3_SOFTCAP) + list(APPEND FLASHMASKV3_SOFTCAP "_softcap") + endif() + + set(FLASHMASKV3_SOFTCAP_ALL) + if(DISABLE_FLASHMASK_V3_SOFTCAP) + set(FLASHMASKV3_SOFTCAP_ALL "__EMPTY__") + else() + set(FLASHMASKV3_SOFTCAP_ALL "_softcapall") + endif() + + set(FLASHMASKV3_PACKGQA "__EMPTY__") + if(NOT DISABLE_FLASHMASK_V3_PACKGQA) + list(APPEND FLASHMASKV3_PACKGQA "_packgqa") 
+ endif() + + set(flashmaskv3_sources_fwd_sm80) + foreach(hdim ${FLASHMASKV3_HEAD_DIMENSIONS_FWD_SM80}) + foreach(dtype ${FLASHMASKV3_DTYPE_FWD_SM80}) + foreach(split ${FLASHMASKV3_SPLIT}) + foreach(paged ${FLASHMASKV3_PAGEDKV}) + foreach(softcap ${FLASHMASKV3_SOFTCAP_ALL}) + set(name "flashmask_attention_v3/instantiations/flash_fwd_hdim${hdim}_${dtype}${paged}${split}${softcap}_sm80.cu") + string(REPLACE "__EMPTY__" "" refine_name "${name}") + list(APPEND flashmaskv3_sources_fwd_sm80 "${refine_name}") + endforeach() + endforeach() + endforeach() + endforeach() + endforeach() + + set(flashmaskv3_sources_fwd_sm90) + foreach(hdim ${FLASHMASKV3_HEAD_DIMENSIONS_FWD}) + foreach(dtype ${FLASHMASKV3_DTYPE_FWD_SM90}) + foreach(split ${FLASHMASKV3_SPLIT}) + foreach(paged ${FLASHMASKV3_PAGEDKV}) + foreach(softcap ${FLASHMASKV3_SOFTCAP}) + foreach(packgqa ${FLASHMASKV3_PACKGQA}) + if(packgqa STREQUAL "__EMPTY__" OR (paged STREQUAL "__EMPTY__" AND split STREQUAL "__EMPTY__")) + set(name "flashmask_attention_v3/instantiations/flash_fwd_hdim${hdim}_${dtype}${paged}${split}${softcap}${packgqa}_sm90.cu") + string(REPLACE "__EMPTY__" "" refine_name "${name}") + list(APPEND flashmaskv3_sources_fwd_sm90 "${refine_name}") + endif() + endforeach() + endforeach() + endforeach() + endforeach() + endforeach() + endforeach() + + set(flashmaskv3_sources_bwd_sm80) + foreach(hdim ${FLASHMASKV3_HEAD_DIMENSIONS_BWD}) + foreach(dtype ${FLASHMASKV3_DTYPE_BWD}) + foreach(softcap ${FLASHMASKV3_SOFTCAP}) + set(name "flashmask_attention_v3/instantiations/flash_bwd_hdim${hdim}_${dtype}${softcap}_sm80.cu") + string(REPLACE "__EMPTY__" "" refine_name "${name}") + list(APPEND flashmaskv3_sources_bwd_sm80 "${refine_name}") + endforeach() + endforeach() + endforeach() + + set(flashmaskv3_sources_bwd_sm90) + foreach(hdim ${FLASHMASKV3_HEAD_DIMENSIONS_BWD}) + foreach(dtype ${FLASHMASKV3_DTYPE_BWD}) + foreach(softcap ${FLASHMASKV3_SOFTCAP_ALL}) + foreach(causal IN ITEMS "" "_causal") + foreach(determ IN ITEMS 
"" "_determ") + set(name "flashmask_attention_v3/instantiations/flash_bwd_hdim${hdim}_${dtype}${causal}${determ}${softcap}_sm90.cu") + string(REPLACE "__EMPTY__" "" refine_name "${name}") + list(APPEND flashmaskv3_sources_bwd_sm90 "${refine_name}") + endforeach() + endforeach() + endforeach() + endforeach() + endforeach() + + if(DISABLE_FLASHMASK_V3_BACKWARD) + set(flashmaskv3_sources_bwd_sm80 "") + set(flashmaskv3_sources_bwd_sm90 "") + endif() + + + set(FLASHMASKV3_SOURCES_CU_SOURCES "flashmask_attention_v3/flash_api.cu") + if(NOT DISABLE_FLASHMASK_V3_SM8X) + list(APPEND FLASHMASKV3_SOURCES_CU_SOURCES ${flashmaskv3_sources_fwd_sm80}) + endif() + list(APPEND FLASHMASKV3_SOURCES_CU_SOURCES ${flashmaskv3_sources_fwd_sm90}) + if(NOT DISABLE_FLASHMASK_V3_SM8X) + list(APPEND FLASHMASKV3_SOURCES_CU_SOURCES ${flashmaskv3_sources_bwd_sm80}) + endif() + list(APPEND FLASHMASKV3_SOURCES_CU_SOURCES ${flashmaskv3_sources_bwd_sm90}) + + if(NOT DISABLE_FLASHMASK_V3_SPLIT) + list(APPEND FLASHMASKV3_SOURCES_CU_SOURCES "flashmask_attention_v3/flash_fwd_combine.cu") + endif() + + list(APPEND FLASHMASKV3_SOURCES_CU_SOURCES "flashmask_attention_v3/flash_prepare_scheduler.cu") + message(STATUS "Auto generated CUDA source files for flashmask_v3: ${FLASHMASKV3_SOURCES_CU_SOURCES}") + + # 3. 
Add the shared library + add_library(flashmaskv3 SHARED + ${FLASHMASKV3_SOURCES_CU_SOURCES} + ) + + target_include_directories(flashmaskv3 PRIVATE + flashmask_attention_v3 + flashmask_attention_v3/cutlass/include + ) + + target_compile_options(flashmaskv3 PRIVATE $<$<COMPILE_LANGUAGE:CUDA>: + -Xcompiler="-fPIC" + -Xcompiler="-O3" + -std=c++17 + --ftemplate-backtrace-limit=0 + --use_fast_math + --resource-usage + -lineinfo + -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED # Necessary for the WGMMA shapes that we use + -DCUTLASS_ENABLE_GDC_FOR_SM90 # For PDL + -DCUTLASS_DEBUG_TRACE_LEVEL=0 # Can toggle for debugging + -DNDEBUG # Important, otherwise performance is severely impacted + -gencode arch=compute_90a,code=sm_90a + --expt-relaxed-constexpr + >) + + INSTALL(TARGETS flashmaskv3 LIBRARY DESTINATION "lib") + + INSTALL(FILES flashmask_attention_v3/flash_api.h DESTINATION "include" RENAME flashmaskv3_api.h) + + endif() + +else() + INSTALL(FILES flashmask_attention_v3/flash_api.h DESTINATION "include" RENAME flashmaskv3_api.h) + +endif() +#SKIP_BUILD_FA + + diff --git a/flashmask/flash_mask/__init__.py b/flashmask/flash_mask/__init__.py index 290f972cf31..2bca3c54ad2 100644 --- a/flashmask/flash_mask/__init__.py +++ b/flashmask/flash_mask/__init__.py @@ -11,3 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +# [BQW_CHANGE] Load flash_mask_pd_.so and register the custom ops before the import below +# Paddle CUDAExtension produces flash_mask_pd_.so; it must be loaded and registered manually +import os +import paddle + +_curr_dir = os.path.dirname(os.path.abspath(__file__)) +_parent_dir = os.path.dirname(_curr_dir) +_so_path = os.path.join(_parent_dir, "flash_mask_pd_.so") + +if os.path.exists(_so_path): + paddle.utils.cpp_extension.load_op_meta_info_and_register_op(_so_path) +else: + print(f"[WARNING] flash_mask_pd_.so not found at {_so_path}, custom ops may not be available") + +from .flashmask_attention_v3.interface import flashmask_attention + +__all__ = ["flashmask_attention"] diff --git a/flashmask/flash_mask/flashmask_attention_v3/__init__.py b/flashmask/flash_mask/flashmask_attention_v3/__init__.py new file mode 100644 index 00000000000..21d5b5efa13 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/__init__.py @@ -0,0 +1,9 @@ + +""" +flashmask_attention - FlashAttention V3 with mask support +""" + +from .interface import flashmask_attention + +__all__ = ["flashmask_attention"] diff --git a/flashmask/flash_mask/flashmask_attention_v3/block.h b/flashmask/flash_mask/flashmask_attention_v3/block.h new file mode 100644 index 00000000000..e9227168324 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/block.h @@ -0,0 +1,108 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +namespace flash { + +template +struct BlockMN { + + static + CUTLASS_DEVICE + cute::tuple get_n_block_min_max( + SeqlenInfo_t const& seqlen_info, + int const m_block, int const bidb, int const split_idx, int const num_splits, + int const window_size_left, int const window_size_right, + cutlass::FastDivmod const& qhead_per_khead_divmod) { + + int const seqlen_k = seqlen_info.seqlen_k; + int const seqlen_q = seqlen_info.seqlen_q; + int n_block_max = cute::ceil_div(seqlen_k, kBlockN); + if constexpr (Is_causal || Is_local) { + int m_idx_max = (m_block + 1) * kBlockM; + // TODO: check off-by-1 error + if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; } + n_block_max = std::min(n_block_max, + cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right, kBlockN)); + } + int n_block_min = 0; + if constexpr (Is_local) { + int m_idx_min = m_block * kBlockM; + if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); } + n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - window_size_left) / kBlockN); + } + // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } + if constexpr (Split) { + uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits + int num_splits_dynamic = 
reinterpret_cast(num_splits_dynamic_u); + int split_idx_actual = split_idx & 0x0000FFFF; + int num_splits_actual = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; + int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits_actual); + n_block_min = n_block_min + split_idx_actual * num_n_blocks_per_split; + n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max); + // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); } + } + // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } + return {n_block_min, n_block_max}; + } + + static + CUTLASS_DEVICE + cute::tuple get_n_block_k_new_min_max( + SeqlenInfo_t const& seqlen_info, + int const m_block, int const bidb, int const split_idx, int const num_splits, + int const window_size_left, int const window_size_right, + cutlass::FastDivmod const& qhead_per_khead_divmod) { + + auto [n_block_min, n_block_max] = get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, num_splits, + window_size_left, window_size_right, qhead_per_khead_divmod); + int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0); + int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new); + int const n_block_new_min = idx_k_new_min / kBlockN; + int const n_block_new_max = idx_k_new_max > idx_k_new_min ? 
cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min; + // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);} + return {n_block_new_min, n_block_new_max}; + } + + static + CUTLASS_DEVICE + cute::tuple get_m_block_min_max( + SeqlenInfo_t const& seqlen_info, + int const n_block, int const bidb, + int const window_size_left, int const window_size_right, int const sink_token_length) { + + int const seqlen_q = seqlen_info.seqlen_q; + int const seqlen_k = seqlen_info.seqlen_k; + int m_block_max = cute::ceil_div(seqlen_q, kBlockM); + if constexpr (Is_local) { + if (n_block >= cute::ceil_div(sink_token_length, kBlockN)) { + m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM)); + } + } + int m_block_min = 0; + if constexpr (Is_causal || Is_local) { + m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM); + } + return {m_block_min, m_block_max}; + } + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/copy_sm90_bulk_reduce.hpp b/flashmask/flash_mask/flashmask_attention_v3/copy_sm90_bulk_reduce.hpp new file mode 100644 index 00000000000..4d78d5f0865 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/copy_sm90_bulk_reduce.hpp @@ -0,0 +1,63 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +#include + +namespace cute +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_BULK_REDUCE_ADD +{ + CUTE_HOST_DEVICE static void + copy(float const* smem_ptr, + float * gmem_ptr, int32_t store_bytes) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n" + : + : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + CUTE_HOST_DEVICE static void + copy(float const* smem_ptr, + float * gmem_ptr, int32_t store_bytes, uint64_t cache_hint) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [%0], [%1], %2, %3;\n" + : + : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end 
namespace cute diff --git a/flashmask/flash_mask/flashmask_attention_v3/csrc/flash_attn_v3_utils.cu b/flashmask/flash_mask/flashmask_attention_v3/csrc/flash_attn_v3_utils.cu new file mode 100644 index 00000000000..0b9f749db0d --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/csrc/flash_attn_v3_utils.cu @@ -0,0 +1,231 @@ +/****************************************************************************** + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ + +#include "flash_attn_v3_utils.h" +#include + +#ifdef PADDLE_WITH_FLASHATTN_V3 + +static void +destroy_flashmask_fwd_params_handle(Flash_fwd_params *params_handle) { + flashmaskv3_destroy_fwd_params_handle(params_handle); +} + +static void +destroy_flashmask_bwd_params_handle(Flash_bwd_params *params_handle) { + flashmaskv3_destroy_bwd_params_handle(params_handle); +} + +FlashMask_fwd_params *get_flashmask_fwd_params_handle() { + static std::unique_ptr + params_handle(flashmaskv3_create_fwd_params_handle(), + &destroy_flashmask_fwd_params_handle); + return params_handle.get(); +} + +FlashMask_bwd_params *get_flashmask_bwd_params_handle() { + static std::unique_ptr + params_handle(flashmaskv3_create_bwd_params_handle(), + &destroy_flashmask_bwd_params_handle); + return params_handle.get(); +} + +void set_flashmaskv3_params_fprop( + Flash_fwd_params *params_handle, + // sizes + const size_t b, const size_t seqlen_q, const size_t seqlen_k, + const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, + const size_t h, const size_t h_k, const size_t d, const size_t d_rounded, + // device pointers + const paddle::Tensor &q, const paddle::Tensor &k, const paddle::Tensor &v, + const paddle::Tensor *out, void *cu_seqlens_q_d, void *cu_seqlens_k_d, + void *seqused_q, void *seqused_k, void *softmax_lse_d, float p_dropout, + float softmax_scale, int window_size_left, int window_size_right, + const cudaDeviceProp &dprops, const float softcap, const int sm_margin) { + flashmaskv3_fwd_params_set_is_bf16(params_handle, + q.dtype() == paddle::DataType::BFLOAT16); + flashmaskv3_fwd_params_set_is_e4m3( + params_handle, q.dtype() == paddle::DataType::FLOAT8_E4M3FN); + + // Set the pointers and strides. 
+ flashmaskv3_fwd_params_set_q_ptr(params_handle, const_cast(q.data())); + flashmaskv3_fwd_params_set_k_ptr(params_handle, const_cast(k.data())); + flashmaskv3_fwd_params_set_v_ptr(params_handle, const_cast(v.data())); + // All stride are in elements, not bytes. + flashmaskv3_fwd_params_set_q_row_stride(params_handle, + q.strides()[q.strides().size() - 3]); + flashmaskv3_fwd_params_set_k_row_stride(params_handle, + k.strides()[k.strides().size() - 3]); + flashmaskv3_fwd_params_set_v_row_stride(params_handle, + v.strides()[v.strides().size() - 3]); + flashmaskv3_fwd_params_set_q_head_stride(params_handle, + q.strides()[q.strides().size() - 2]); + flashmaskv3_fwd_params_set_k_head_stride(params_handle, + k.strides()[k.strides().size() - 2]); + flashmaskv3_fwd_params_set_v_head_stride(params_handle, + v.strides()[v.strides().size() - 2]); + flashmaskv3_fwd_params_set_v_dim_stride(params_handle, + v.strides()[v.strides().size() - 1]); + flashmaskv3_fwd_params_set_o_ptr(params_handle, + const_cast(out->data())); + flashmaskv3_fwd_params_set_o_row_stride( + params_handle, out->strides()[out->strides().size() - 3]); + flashmaskv3_fwd_params_set_o_head_stride( + params_handle, out->strides()[out->strides().size() - 2]); + + if (cu_seqlens_q_d == nullptr) { + flashmaskv3_fwd_params_set_q_batch_stride(params_handle, q.strides()[0]); + flashmaskv3_fwd_params_set_o_batch_stride(params_handle, out->strides()[0]); + } + if (cu_seqlens_k_d == nullptr) { + flashmaskv3_fwd_params_set_k_batch_stride(params_handle, k.strides()[0]); + flashmaskv3_fwd_params_set_v_batch_stride(params_handle, v.strides()[0]); + } + + flashmaskv3_fwd_params_set_cu_seqlens_q(params_handle, + static_cast(cu_seqlens_q_d)); + flashmaskv3_fwd_params_set_cu_seqlens_k(params_handle, + static_cast(cu_seqlens_k_d)); + flashmaskv3_fwd_params_set_seqused_q(params_handle, + static_cast(seqused_q)); + flashmaskv3_fwd_params_set_seqused_k(params_handle, + static_cast(seqused_k)); + + // Softmax sum + 
flashmaskv3_fwd_params_set_softmax_lse_ptr(params_handle, softmax_lse_d); + + // Set the dimensions. + flashmaskv3_fwd_params_set_b(params_handle, b); + flashmaskv3_fwd_params_set_h(params_handle, h); + flashmaskv3_fwd_params_set_h_k(params_handle, h_k); + flashmaskv3_fwd_params_set_seqlen_q(params_handle, seqlen_q); + flashmaskv3_fwd_params_set_seqlen_k(params_handle, seqlen_k); + flashmaskv3_fwd_params_set_seqlen_q_rounded(params_handle, seqlen_q_rounded); + flashmaskv3_fwd_params_set_seqlen_k_rounded(params_handle, seqlen_k_rounded); + flashmaskv3_fwd_params_set_d(params_handle, d); + flashmaskv3_fwd_params_set_d_rounded(params_handle, d_rounded); + + // Set the different scale values. + flashmaskv3_fwd_params_set_scale_softmax(params_handle, softmax_scale); + flashmaskv3_fwd_params_set_softcap(params_handle, softcap); + + // Set this to probability of keeping an element to simplify things. + flashmaskv3_fwd_params_set_p_dropout(params_handle, 1.f - p_dropout); + flashmaskv3_fwd_params_set_p_dropout_in_uint8_t( + params_handle, + uint8_t(std::floor(flashmaskv3_fwd_params_get_p_dropout(params_handle) * + 255.0))); + flashmaskv3_fwd_params_set_rp_dropout( + params_handle, 1.f / flashmaskv3_fwd_params_get_p_dropout(params_handle)); + PADDLE_ENFORCE_LT( + p_dropout, 1.f, + common::errors::InvalidArgument("p_dropout must less than 1")); + + PADDLE_ENFORCE_EQ( + p_dropout, 0.0f, + common::errors::InvalidArgument( + "This flash attention build does not support dropout.")); + + // Causal is the special case where window_size_right == 0 and + // window_size_left < 0. Local is the more general case where + // window_size_right >= 0 or window_size_left >= 0. 
+ flashmaskv3_fwd_params_set_is_causal( + params_handle, window_size_left < 0 && window_size_right == 0); + flashmaskv3_fwd_params_set_is_local( + params_handle, (window_size_left >= 0 || window_size_right >= 0) && + !flashmaskv3_fwd_params_get_is_causal(params_handle)); + + if (window_size_left < 0 && window_size_right >= 0) { + window_size_left = seqlen_k - 1; + } + if (window_size_left >= 0 && window_size_right < 0) { + window_size_right = seqlen_q - 1; + } + flashmaskv3_fwd_params_set_window_size_left(params_handle, window_size_left); + flashmaskv3_fwd_params_set_window_size_right(params_handle, + window_size_right); + + int arch = dprops.major * 10 + dprops.minor; + int num_sm = dprops.multiProcessorCount - sm_margin; + + flashmaskv3_fwd_params_set_arch(params_handle, arch); + flashmaskv3_fwd_params_set_num_sm(params_handle, num_sm); +} + +void set_flashmaskv3_params_dgrad( + Flash_bwd_params *params_handle, + // sizes + const size_t b, const size_t seqlen_q, const size_t seqlen_k, + const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, + const size_t h, const size_t h_k, const size_t d, const size_t d_rounded, + // device pointers + const paddle::Tensor &q, const paddle::Tensor &k, const paddle::Tensor &v, + const paddle::Tensor &out, const paddle::Tensor &dout, paddle::Tensor *dq, + paddle::Tensor *dk, paddle::Tensor *dv, void *cu_seqlens_q_d, + void *cu_seqlens_k_d, void *seqused_q, void *seqused_k, void *dq_accum_d, + void *dk_accum_d, void *dv_accum_d, void *softmax_lse_d, + void *dsoftmax_sum_d, float p_dropout, float softmax_scale, + int window_size_left, int window_size_right, const cudaDeviceProp &dprops, + const float softcap, bool deterministic, int const sm_margin) { + set_flashmaskv3_params_fprop( + flashmaskv3_cast_to_fwd_params_handle(params_handle), b, seqlen_q, + seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, q, k, + v, &out, cu_seqlens_q_d, cu_seqlens_k_d, seqused_q, seqused_k, + softmax_lse_d, p_dropout, 
softmax_scale, window_size_left, + window_size_right, dprops, softcap, sm_margin); + + // Set the pointers and strides. + flashmaskv3_bwd_params_set_do_ptr(params_handle, + const_cast(dout.data())); + flashmaskv3_bwd_params_set_do_row_stride( + params_handle, dout.strides()[dout.strides().size() - 3]); + flashmaskv3_bwd_params_set_do_head_stride( + params_handle, dout.strides()[dout.strides().size() - 2]); + flashmaskv3_bwd_params_set_dq_ptr(params_handle, dq->data()); + flashmaskv3_bwd_params_set_dk_ptr(params_handle, dk->data()); + flashmaskv3_bwd_params_set_dv_ptr(params_handle, dv->data()); + flashmaskv3_bwd_params_set_dq_row_stride( + params_handle, dq->strides()[dq->strides().size() - 3]); + flashmaskv3_bwd_params_set_dk_row_stride( + params_handle, dk->strides()[dk->strides().size() - 3]); + flashmaskv3_bwd_params_set_dv_row_stride( + params_handle, dv->strides()[dv->strides().size() - 3]); + flashmaskv3_bwd_params_set_dq_head_stride( + params_handle, dq->strides()[dq->strides().size() - 2]); + flashmaskv3_bwd_params_set_dk_head_stride( + params_handle, dk->strides()[dk->strides().size() - 2]); + flashmaskv3_bwd_params_set_dv_head_stride( + params_handle, dv->strides()[dv->strides().size() - 2]); + + if (cu_seqlens_q_d == nullptr) { + flashmaskv3_bwd_params_set_do_batch_stride(params_handle, + dout.strides()[0]); + flashmaskv3_bwd_params_set_dq_batch_stride(params_handle, dq->strides()[0]); + flashmaskv3_bwd_params_set_dk_batch_stride(params_handle, dk->strides()[0]); + flashmaskv3_bwd_params_set_dv_batch_stride(params_handle, dv->strides()[0]); + } + + flashmaskv3_bwd_params_set_dq_accum_ptr(params_handle, dq_accum_d); + flashmaskv3_bwd_params_set_dk_accum_ptr(params_handle, dk_accum_d); + flashmaskv3_bwd_params_set_dv_accum_ptr(params_handle, dv_accum_d); + + // Softmax sum + flashmaskv3_bwd_params_set_dsoftmax_sum(params_handle, dsoftmax_sum_d); + + flashmaskv3_bwd_params_set_deterministic(params_handle, deterministic); +} +#endif diff --git 
a/flashmask/flash_mask/flashmask_attention_v3/csrc/flash_attn_v3_utils.h b/flashmask/flash_mask/flashmask_attention_v3/csrc/flash_attn_v3_utils.h new file mode 100644 index 00000000000..eaae65bafa1 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/csrc/flash_attn_v3_utils.h @@ -0,0 +1,62 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#ifdef PADDLE_WITH_FLASHATTN_V3 +#include "../flash_api.h" +#include "paddle/extension.h" +#include <cuda_runtime.h> + +FlashMask_fwd_params *get_flashmask_fwd_params_handle(); + +FlashMask_bwd_params *get_flashmask_bwd_params_handle(); + +inline int flashmaskv3_get_max_headdim() { return 256; } + +inline int flashmaskv3_round_up_headdim(int head_size) { +#ifndef FLASHMASK_V3_DISABLE_HDIM64 + if (head_size <= 64) { + return 64; + } +#endif +#ifndef FLASHMASK_V3_DISABLE_HDIM128 + if (head_size <= 128) { + return 128; + } +#endif + return 256; +} + +void set_flashmaskv3_params_fprop( + FlashMask_fwd_params *params_handle, + // sizes + const size_t b, const size_t seqlen_q, const size_t seqlen_k, + const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, + const size_t h, const size_t h_k, const size_t d, const size_t d_rounded, + // device pointers + const paddle::Tensor &q, const paddle::Tensor &k, const paddle::Tensor &v, + const paddle::Tensor *out, void *cu_seqlens_q_d, void *cu_seqlens_k_d, + void *seqused_q, void *seqused_k, void *softmax_lse_d, float p_dropout, + float softmax_scale, int window_size_left, int window_size_right, + const cudaDeviceProp &dprops, const float softcap = 0.f, + const int sm_margin = 0); + +void set_flashmaskv3_params_dgrad( + FlashMask_bwd_params *params_handle, + // sizes + const size_t b, const size_t seqlen_q, const size_t seqlen_k, + const size_t 
seqlen_q_rounded, const size_t seqlen_k_rounded, + const size_t h, const size_t h_k, const size_t d, const size_t d_rounded, + // device pointers + const paddle::Tensor &q, const paddle::Tensor &k, const paddle::Tensor &v, + const paddle::Tensor &out, const paddle::Tensor &dout, paddle::Tensor *dq, + paddle::Tensor *dk, paddle::Tensor *dv, void *cu_seqlens_q_d, + void *cu_seqlens_k_d, void *seqused_q, void *seqused_k, void *dq_accum_d, + void *dk_accum_d, void *dv_accum_d, void *softmax_lse_d, + void *dsoftmax_sum_d, float p_dropout, float softmax_scale, + int window_size_left, int window_size_right, const cudaDeviceProp &dprops, + const float softcap = 0.f, bool deterministic = false, + int const sm_margin = 0); +#endif diff --git a/flashmask/flash_mask/flashmask_attention_v3/csrc/flashmask_v3.cpp b/flashmask/flash_mask/flashmask_attention_v3/csrc/flashmask_v3.cpp new file mode 100644 index 00000000000..3bc03310a0b --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/csrc/flashmask_v3.cpp @@ -0,0 +1,232 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ + +#include "flashmask_v3.h" +#include "paddle/extension.h" + +// Raise an Unimplemented error when FlashAttention support is not compiled in; this extension targets FlashAttention 3, so the default version reported is 3. +static void RaiseNotSupportedError(int version = 3) { + PADDLE_THROW(common::errors::Unimplemented( + "FlashAttention %d is unsupported, please check " + "the GPU compatibility and CUDA Version.", + version)); +} + +std::vector +FlashMaskV3Forward(const paddle::Tensor &query, const paddle::Tensor &key, + const paddle::Tensor &value, + const paddle::Tensor &startend_row_indices, + const paddle::optional &block_mask, + float softmax_scale, bool is_causal) { +#ifdef PADDLE_WITH_FLASHATTN_V3 + + paddle::Tensor out; + paddle::Tensor softmax_lse; + paddle::Tensor out_accum; + paddle::Tensor softmax_lse_accum; + +#define CALL_FLASHMASK_V3_BASE_KERNEL(DType) \ + FlashMaskV3BaseKernel( \ + query, key, value, paddle::none, paddle::none, paddle::none, \ + paddle::none, paddle::none, paddle::none, paddle::none, paddle::none, \ + paddle::none, paddle::none, paddle::none, paddle::none, paddle::none, \ + paddle::none, paddle::none, paddle::none, paddle::none, paddle::none, \ + startend_row_indices, block_mask, 0, 0, softmax_scale, is_causal, -1, \ + -1, float{0}, true, 1, false, false, 0, &out, &softmax_lse, &out_accum, \ + &softmax_lse_accum) + + switch (query.dtype()) { + case paddle::DataType::FLOAT16: { + CALL_FLASHMASK_V3_BASE_KERNEL(paddle::float16); + break; + } + case paddle::DataType::BFLOAT16: { + CALL_FLASHMASK_V3_BASE_KERNEL(paddle::bfloat16); + break; + } + default: { + PADDLE_THROW(phi::errors::InvalidArgument( + "FlashMaskV3BaseKernel only support bfloat16 and float16, but got %d", + static_cast(query.dtype()))); + } + } + + return {out, softmax_lse}; +#else + RaiseNotSupportedError(); +#endif +} + +std::vector FlashMaskV3GradKernel( + const paddle::Tensor &query, const paddle::Tensor &key, + const paddle::Tensor &value, const paddle::Tensor &out, + const paddle::Tensor &softmax_lse, + const paddle::Tensor
&startend_row_indices, // TODO(xiehaoyang): remove this + const paddle::optional &block_mask, + const paddle::Tensor &out_grad, float const softmax_scale, bool is_causal) { +#ifdef PADDLE_WITH_FLASHATTN_V3 + + PADDLE_ENFORCE_EQ( + query.dims()[query.dims().size() - 1], + value.dims()[value.dims().size() - 1], + common::errors::InvalidArgument("head_dim_q != head_dim_v (%d != %d)", + query.dims()[query.dims().size() - 1], + value.dims()[value.dims().size() - 1])); + + // umiswing: fake grad tensor for FlashAttnV3GradBaseKernel + paddle::Tensor softmax_d; + paddle::Tensor softmax_lse_log2; + paddle::Tensor dq_accum; + paddle::Tensor dk_accum; + paddle::Tensor dv_accum; + + paddle::Tensor dq; + paddle::Tensor dk; + paddle::Tensor dv; + +#define CALL_FLASHMASK_V3_BASE_GRAD_KERNEL(DType) \ + FlashMaskV3GradBaseKernel( \ + out_grad, query, key, value, out, softmax_lse, paddle::none, \ + paddle::none, paddle::none, paddle::none, paddle::none, paddle::none, \ + paddle::none, startend_row_indices, block_mask, 0, 0, softmax_scale, \ + is_causal, -1, -1, 0, FLAGS_cudnn_deterministic, 0, &dq, &dk, &dv, \ + &softmax_d, &softmax_lse_log2, &dq_accum, &dk_accum, &dv_accum); + + static const char *env_val = std::getenv("FLAGS_cudnn_deterministic"); + static bool FLAGS_cudnn_deterministic = + (env_val != nullptr && std::string(env_val) == "1"); + + switch (query.dtype()) { + case paddle::DataType::FLOAT16: { + CALL_FLASHMASK_V3_BASE_GRAD_KERNEL(paddle::float16); + break; + } + case paddle::DataType::BFLOAT16: { + CALL_FLASHMASK_V3_BASE_GRAD_KERNEL(paddle::bfloat16); + break; + } + default: { + PADDLE_THROW( + phi::errors::InvalidArgument("FlashMaskV3GradBaseKernel only support " + "bfloat16 and float16, but got %d", + static_cast(query.dtype()))); + } + } + + // umiswing: some branch in upstream fa3 could have padded the head dimension + PADDLE_ENFORCE_EQ( + dq.dims()[dq.dims().size() - 1], + out_grad.dims()[out_grad.dims().size() - 1], + common::errors::InvalidArgument( + "head 
dimension of dq != head dimension of out_grad (%d != %d)", + dq.dims()[dq.dims().size() - 1], + out_grad.dims()[out_grad.dims().size() - 1])); + + PADDLE_ENFORCE_EQ( + dk.dims()[dk.dims().size() - 1], + out_grad.dims()[out_grad.dims().size() - 1], + common::errors::InvalidArgument( + "head dimension of dk != head dimension of out_grad (%d != %d)", + dk.dims()[dk.dims().size() - 1], + out_grad.dims()[out_grad.dims().size() - 1])); + + PADDLE_ENFORCE_EQ( + dv.dims()[dv.dims().size() - 1], + out_grad.dims()[out_grad.dims().size() - 1], + common::errors::InvalidArgument( + "head dimension of dv != head dimension of out_grad (%d != %d)", + dv.dims()[dv.dims().size() - 1], + out_grad.dims()[out_grad.dims().size() - 1])); + return {dq, dk, dv}; + +#else + RaiseNotSupportedError(); +#endif +} + +std::vector> FlashMaskV3FwdInferShape( + const std::vector &query_shape, + const std::vector &key_shape, + const std::vector &value_shape, + const std::vector &startend_row_indices_shape, + const paddle::optional> &block_mask_shape, + float softmax_scale, bool is_causal) { + int64_t batch_size = query_shape[0]; + int64_t seqlen_q = query_shape[1]; + int64_t num_heads = query_shape[query_shape.size() - 2]; + int64_t head_size_v = value_shape[value_shape.size() - 1]; + + return {{batch_size, seqlen_q, num_heads, head_size_v}, + {batch_size, num_heads, seqlen_q}}; +} + +std::vector> FlashMaskV3GradInferShape( + const std::vector &query_shape, + const std::vector &key_shape, + const std::vector &value_shape, + const std::vector &out_shape, + const std::vector &softmax_lse_shape, + const std::vector &startend_row_indices_shape, + const paddle::optional> &block_mask_shape, + const std::vector &dout_shape, float softmax_scale, + bool is_causal) { + + return {query_shape, key_shape, value_shape}; +} + +std::vector FlashMaskV3FwdInferDtype( + paddle::DataType query_dtype, paddle::DataType key_dtype, + paddle::DataType value_dtype, paddle::DataType startend_row_indices_dtype, + const 
paddle::optional &block_mask_dtype, + float softmax_scale, bool is_causal) { + auto out_type = (query_dtype == paddle::DataType::FLOAT8_E4M3FN) + ? paddle::DataType::BFLOAT16 + : query_dtype; + return {out_type, paddle::DataType::FLOAT32}; +} + +std::vector FlashMaskV3GradInferDtype( + paddle::DataType query_dtype, paddle::DataType key_dtype, + paddle::DataType value_dtype, paddle::DataType out_dtype, + paddle::DataType softmax_lse_dtype, + paddle::DataType startend_row_indices_dtype, + const paddle::optional &block_mask_dtype, + paddle::DataType dout_dtype, float softmax_scale, bool is_causal) { + return {query_dtype, key_dtype, value_dtype}; +} + +PD_BUILD_OP(flashmask_attention_v3) + .Inputs({"query", "key", "value", "startend_row_indices", + paddle::Optional("block_mask")}) + .Outputs({"out", "softmax_lse"}) + .Attrs({"softmax_scale: float", "is_causal: bool"}) + .SetKernelFn(PD_KERNEL(FlashMaskV3Forward)) + .SetInferShapeFn(PD_INFER_SHAPE(FlashMaskV3FwdInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(FlashMaskV3FwdInferDtype)); + +PD_BUILD_GRAD_OP(flashmask_attention_v3) + .Inputs({ + "query", "key", "value", "out", "softmax_lse", "startend_row_indices", + paddle::Optional("block_mask"), + paddle::Grad("out") // dout + }) + .Outputs({paddle::Grad("query"), paddle::Grad("key"), + paddle::Grad("value")}) + .Attrs({"softmax_scale: float", "is_causal: bool"}) + .SetKernelFn(PD_KERNEL(FlashMaskV3GradKernel)) + .SetInferShapeFn(PD_INFER_SHAPE(FlashMaskV3GradInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(FlashMaskV3GradInferDtype)); diff --git a/flashmask/flash_mask/flashmask_attention_v3/csrc/flashmask_v3.h b/flashmask/flash_mask/flashmask_attention_v3/csrc/flashmask_v3.h new file mode 100644 index 00000000000..ce0e15bba3a --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/csrc/flashmask_v3.h @@ -0,0 +1,128 @@ +/****************************************************************************** + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +#include "paddle/extension.h" + +#ifdef PADDLE_WITH_FLASHATTN_V3 + +#define CHECK_DEVICE(x) PD_CHECK(x.is_gpu(), #x " must be on CUDA Device") + +#define CHECK_SHAPE(x, ...) \ + PADDLE_ENFORCE_EQ(x.dims(), common::make_ddim({__VA_ARGS__}), \ + common::errors::InvalidArgument( \ + #x " must have shape (" #__VA_ARGS__ ")")) + +#define CHECK_CONTIGUOUS(x) \ + PADDLE_ENFORCE_EQ(x.is_contiguous(), true, \ + common::errors::InvalidArgument(#x " must be contiguous")) +#endif + +template +void FlashMaskV3BaseKernel( + const paddle::Tensor &q, const paddle::Tensor &k, const paddle::Tensor &v, + const paddle::optional + &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is + // cu_seqlens_k_new + const paddle::optional + &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is + // cu_seqlens_k_new + const paddle::optional + &q_v_, // (b, s_q, h, dv) or (total_q_new, h, + // dv) if there is cu_seqlens_q + const paddle::optional + &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + const paddle::optional &cu_seqlens_q_, // b+1 + const paddle::optional &cu_seqlens_k_, // b+1 + const paddle::optional &cu_seqlens_k_new_, // b+1 + const paddle::optional + &seqused_q_, // b. If given, only this many elements of each batch + // element's queries and outputs are used. 
+ const paddle::optional + &seqused_k_, // b. If given, only this many elements of each batch + // element's keys are used. + const paddle::optional + &page_table_, // (b_k, max_num_pages_per_seq) + const paddle::optional + &kv_batch_idx_, // b. indices to index into the KV cache + const paddle::optional &leftpad_k_, // b + const paddle::optional + &rotary_cos_, // seqlen_ro x (rotary_dim / 2) + const paddle::optional + &rotary_sin_, // seqlen_ro x (rotary_dim / 2) + const paddle::optional &q_descale_, // (b, h_k), not (b, h) + const paddle::optional &k_descale_, // (b, h_k) + const paddle::optional &v_descale_, // (b, h_k) + const paddle::optional &scheduler_metadata_, // (b + 1) + const paddle::optional + &startend_row_indices_, // (b,h,s_1,[1,2,4]) + const paddle::optional + &block_mask_, // ((b,h,s// 128,s // 128) + const int max_seqlen_q_, // if max_seqlen_q_ is set to 0, it indicates that + // it is uninitialized and should not be referenced + // TODO(tridao): check if we need max_seqlen_k + const int max_seqlen_k_, // if max_seqlen_q_ is set to 0, it indicates that + // it is uninitialized and should not be referenced + const float softmax_scale, bool is_causal, int window_size_left, + int window_size_right, const float softcap, + const bool is_rotary_interleaved, // if true, rotary combines indices 0 & + // 1, else indices 0 & rotary_dim / 2 + int num_splits, const bool manual_set_pack_gqa, + const bool + pack_gqa_, // the pack_gqa_ will be used only if manual_set_pack_gqa is + // set to True; otherwise, the internal heuristic + // get_pack_gqa() from fa3 will decide whether to pack gqa + const int sm_margin, paddle::Tensor *out, paddle::Tensor *softmax_lse, + paddle::Tensor *out_accum, paddle::Tensor *softmax_lse_accum); + +template +void FlashMaskV3GradBaseKernel( + const paddle::Tensor + &dout, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const paddle::Tensor + &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const 
paddle::Tensor + &k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + const paddle::Tensor + &v, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + const paddle::Tensor + &out, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const paddle::Tensor + &softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q + const paddle::optional + &dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const paddle::optional + &dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + const paddle::optional + &dv_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + const paddle::optional &cu_seqlens_q_, // b+1 + const paddle::optional &cu_seqlens_k_, // b+1 + const paddle::optional + &seqused_q_, // b. If given, only this many elements of each batch + // element's queries and outputs are used. + const paddle::optional + &seqused_k_, // b. If given, only this many elements of each batch + // element's keys are used. 
+ const paddle::optional &startend_row_indices_, + const paddle::optional + &block_mask_, // (b, h, s//128, s//128) + int max_seqlen_q_, int max_seqlen_k_, float const softmax_scale, + bool is_causal, int window_size_left, int window_size_right, + float const softcap, bool const deterministic, int const sm_margin, + paddle::Tensor *dq, paddle::Tensor *dk, paddle::Tensor *dv, + paddle::Tensor *softmax_d, paddle::Tensor *softmax_lse_log2, + paddle::Tensor *dq_accum, paddle::Tensor *dk_accum, + paddle::Tensor *dv_accum); diff --git a/flashmask/flash_mask/flashmask_attention_v3/csrc/flashmask_v3_grad_kernel.cu b/flashmask/flash_mask/flashmask_attention_v3/csrc/flashmask_v3_grad_kernel.cu new file mode 100644 index 00000000000..e1dd25a80ec --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/csrc/flashmask_v3_grad_kernel.cu @@ -0,0 +1,699 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao.
+ ******************************************************************************/ + +#include "flash_attn_v3_utils.h" +#include "flashmask_v3.h" + +#include +#include + +template +void FlashMaskV3GradBaseKernel( + const paddle::Tensor &dout, const paddle::Tensor &q, + const paddle::Tensor &k, const paddle::Tensor &v, const paddle::Tensor &out, + const paddle::Tensor &softmax_lse, + const paddle::optional &dq_, + const paddle::optional &dk_, + const paddle::optional &dv_, + const paddle::optional &cu_seqlens_q_, + const paddle::optional &cu_seqlens_k_, + const paddle::optional &seqused_q_, + const paddle::optional &seqused_k_, + const paddle::optional &startend_row_indices_, + const paddle::optional &block_mask_, int max_seqlen_q_, + int max_seqlen_k_, float const softmax_scale, bool is_causal, + int window_size_left, int window_size_right, float const softcap, + bool const deterministic, int const sm_margin, paddle::Tensor *dq, + paddle::Tensor *dk, paddle::Tensor *dv, paddle::Tensor *softmax_d, + paddle::Tensor *softmax_lse_log2, paddle::Tensor *dq_accum, + paddle::Tensor *dk_accum, paddle::Tensor *dv_accum) { +#ifdef PADDLE_WITH_FLASHATTN_V3 + // TODO(umiswing): support ampere + cudaStream_t stream = static_cast(q.stream()); + auto place = q.place(); + + int device_id = place.GetDeviceId(); + cudaDeviceProp dprops; + cudaGetDeviceProperties(&dprops, device_id); + + const bool is_sm90 = dprops.major == 9 && dprops.minor == 0; + PADDLE_ENFORCE_EQ(is_sm90, true, + common::errors::Unavailable( + "FlashAttention-3 only supports Hopper GPUs.")); + + auto q_type = q.dtype(); + + PADDLE_ENFORCE_EQ( + (q_type == paddle::DataType::FLOAT16 || + q_type == paddle::DataType::BFLOAT16), + true, + common::errors::InvalidArgument( + "FlashAttention-3 bwd only support fp16 and bf16 data type")); + PADDLE_ENFORCE_EQ(k.dtype(), q_type, + common::errors::InvalidArgument( + "query and key must have the same dtype")); + PADDLE_ENFORCE_EQ(v.dtype(), q_type, + 
common::errors::InvalidArgument( + "query and value must have the same dtype")); + PADDLE_ENFORCE_EQ(out.dtype(), q_type, + common::errors::InvalidArgument( + "query and out must have the same dtype")); + PADDLE_ENFORCE_EQ(dout.dtype(), q_type, + common::errors::InvalidArgument( + "query and dout must have the same dtype")); + + CHECK_DEVICE(q); + CHECK_DEVICE(k); + CHECK_DEVICE(v); + CHECK_DEVICE(out); + CHECK_DEVICE(dout); + CHECK_DEVICE(softmax_lse); + + PADDLE_ENFORCE_EQ(q.strides()[q.strides().size() - 1], 1, + common::errors::InvalidArgument( + "Input tensor must have contiguous last dimension")); + PADDLE_ENFORCE_EQ(k.strides()[k.strides().size() - 1], 1, + common::errors::InvalidArgument( + "Input tensor must have contiguous last dimension")); + PADDLE_ENFORCE_EQ(v.strides()[v.strides().size() - 1], 1, + common::errors::InvalidArgument( + "Input tensor must have contiguous last dimension")); + PADDLE_ENFORCE_EQ(out.strides()[out.strides().size() - 1], 1, + common::errors::InvalidArgument( + "out tensor must have contiguous last dimension")); + PADDLE_ENFORCE_EQ(dout.strides()[dout.strides().size() - 1], 1, + common::errors::InvalidArgument( + "dout tensor must have contiguous last dimension")); + + paddle::Tensor cu_seqlens_q; + bool const is_varlen_q = cu_seqlens_q_.is_initialized(); + if (is_varlen_q) { + cu_seqlens_q = cu_seqlens_q_.get(); + CHECK_DEVICE(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_q); + PADDLE_ENFORCE_EQ(cu_seqlens_q.dtype(), paddle::DataType::INT32, + common::errors::InvalidArgument( + "cu_seqlens_q must have dtype paddle.int32")); + PADDLE_ENFORCE_GT( + max_seqlen_q_, 0, + common::errors::InvalidArgument( + "max_seqlen_q must be provided if cu_seqlens_q is provided")); + } + paddle::Tensor cu_seqlens_k; + bool const is_varlen_k = cu_seqlens_k_.is_initialized(); + if (is_varlen_k) { + cu_seqlens_k = cu_seqlens_k_.get(); + CHECK_DEVICE(cu_seqlens_k); + CHECK_CONTIGUOUS(cu_seqlens_k); + PADDLE_ENFORCE_EQ(cu_seqlens_k.dtype(), 
paddle::DataType::INT32, + common::errors::InvalidArgument( + "cu_seqlens_k must have dtype paddle.int32")); + PADDLE_ENFORCE_GT( + max_seqlen_k_, 0, + common::errors::InvalidArgument( + "max_seqlen_k must be provided if cu_seqlens_k is provided")); + } + // This is what we will template on + bool const is_varlen = is_varlen_q || is_varlen_k || + seqused_q_.is_initialized() || + seqused_k_.is_initialized(); +#ifdef FLASHATTENTION_DISABLE_VARLEN + PADDLE_ENFORCE_EQ(!is_varlen, true, + common::errors::Unavailable( + "This flash attention build does not support varlen.")); +#endif + + auto const sizes = q.dims(); + int const batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.dims()[0] - 1; + int const seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_; + int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; + int const num_heads = q.dims()[q.dims().size() - 2]; + int const head_size = q.dims()[q.dims().size() - 1]; + int const seqlen_k = !is_varlen_k ? k.dims()[1] : max_seqlen_k_; + int const total_k = !is_varlen_k ? 
batch_size * k.dims()[1] : k.dims()[0]; + int const num_heads_k = k.dims()[k.dims().size() - 2]; + PADDLE_ENFORCE_EQ( + head_size % 8, 0, + common::errors::InvalidArgument("head_size should be a multiple of 8")); + int const max_headdim = flashmaskv3_get_max_headdim(); + PADDLE_ENFORCE_LE( + head_size, max_headdim, + common::errors::InvalidArgument( + "FlashAttention forward only supports head dimension at most %d", + max_headdim)); + PADDLE_ENFORCE_EQ( + num_heads % num_heads_k, 0, + common::errors::InvalidArgument( + "Number of heads in key/value must divide number of heads in query")); + + // This needs to go before kBlockM & kBlockN since we rely on the correct + // window_size and is_causal to set kBlockM + if (window_size_left >= seqlen_k - 1) { + window_size_left = -1; + } + if (window_size_right >= seqlen_q - 1) { + window_size_right = -1; + } + if (is_causal) { + window_size_right = 0; + } + // There's a case where is_causal=false, window_size=(-1, 0). Then + // set_params_bprop will set params.is_causal=true. If we don't have is_causal + // here matching params.is_causal, we might get the wrong kBlockM (and cause + // IMA). 
+ is_causal = window_size_left < 0 && window_size_right == 0; + + int const arch = dprops.major * 10 + dprops.minor; + int const head_size_rounded = flashmaskv3_round_up_headdim(head_size); + // Very important that these match the kernel configs + bool const is_local = + (window_size_left >= 0 || window_size_right >= 0) && !is_causal; + bool const is_flashmask = startend_row_indices_.is_initialized(); + paddle::Tensor startend_row_indices; + if (is_flashmask) + startend_row_indices = startend_row_indices_.get(); + bool const has_softcap = softcap > 0.0; + + paddle::Tensor flashmask_maxmin; + paddle::Tensor lt_start_slice, lt_end_slice, ut_start_slice, ut_end_slice; + const int32_t *lt_start_ptr; + const int32_t *lt_end_ptr; + const int32_t *ut_start_ptr; + const int32_t *ut_end_ptr; + + if (is_flashmask) { + PADDLE_ENFORCE_EQ( + startend_row_indices.dtype(), paddle::DataType::INT32, + common::errors::InvalidArgument( + "flashmask_attention startend_row_indices must be INT32 type")); + PADDLE_ENFORCE_EQ( + startend_row_indices.dims().size(), 4, + common::errors::InvalidArgument( + "flashmask_attention receive startend_row_indices with dim " + "[batch_size, num_heads,seq_len, mask_bounds]")); + PADDLE_ENFORCE_EQ(startend_row_indices.dims()[3] == 1 || + startend_row_indices.dims()[3] == 2 || + startend_row_indices.dims()[3] == 4, + true, + common::errors::InvalidArgument( + "flashmask_attention startend_row_indices " + "mask_bounds must in [1,2,4]")); + + auto flashmask_maxmin_shape = startend_row_indices.dims(); + // TODO(umiswing): refine this block constraint (kBlockN % 32), since some + // of kBlockN is not divisible by 32 flashmask_maxmin_shape[2] = + // (flashmask_maxmin_shape[2] + 31) / 32 * 8; + flashmask_maxmin_shape[2] = + ((flashmask_maxmin_shape[2] + 31) / 32 + 3) / 4 * 4; + flashmask_maxmin_shape[3] = 8; + + flashmask_maxmin = + paddle::empty({flashmask_maxmin_shape[0], flashmask_maxmin_shape[1], + flashmask_maxmin_shape[2], flashmask_maxmin_shape[3]}, + 
paddle::DataType::INT32, place); + + const int32_t *mask_base_ptr = startend_row_indices.data(); + auto mask_dims = startend_row_indices.dims(); + int B_mask = mask_dims[0]; + int H_mask = mask_dims[1]; + int S_mask = mask_dims[2]; + int C = mask_dims[3]; + int total_elements = B_mask * H_mask * S_mask; + + lt_start_ptr = nullptr; + lt_end_ptr = nullptr; + ut_start_ptr = nullptr; + ut_end_ptr = nullptr; + + auto extract_channel = [&](int channel_idx) -> paddle::Tensor { + auto slice = paddle::empty({B_mask, H_mask, S_mask}, + paddle::DataType::INT32, place); + cudaMemcpy2DAsync(slice.data(), sizeof(int32_t), + mask_base_ptr + channel_idx, C * sizeof(int32_t), + sizeof(int32_t), total_elements, + cudaMemcpyDeviceToDevice, stream); + return slice; + }; + + if (C == 1) { + lt_start_ptr = mask_base_ptr; + } else if (C == 2) { + lt_start_slice = extract_channel(0); + lt_start_ptr = lt_start_slice.data(); + if (!is_causal) { + ut_end_slice = extract_channel(1); + ut_end_ptr = ut_end_slice.data(); + } else { + lt_end_slice = extract_channel(1); + lt_end_ptr = lt_end_slice.data(); + } + } else if (C == 4) { + lt_start_slice = extract_channel(0); + lt_start_ptr = lt_start_slice.data(); + lt_end_slice = extract_channel(1); + lt_end_ptr = lt_end_slice.data(); + ut_start_slice = extract_channel(2); + ut_start_ptr = ut_start_slice.data(); + ut_end_slice = extract_channel(3); + ut_end_ptr = ut_end_slice.data(); + } + } + + bool const is_blockmask = block_mask_.is_initialized(); + paddle::Tensor block_mask; + if (is_blockmask) + block_mask = block_mask_.get(); + + if (is_blockmask) { + PADDLE_ENFORCE_EQ( + is_flashmask, true, + common::errors::InvalidArgument( + "blockmask should be used with flashmask at the same time ")); + + PADDLE_ENFORCE_EQ(block_mask.dims().size(), 4, + common::errors::InvalidArgument( + "blockmask receive blockmask_indices with dim " + "[batch_size, num_heads, blocklen_q, blocklen_k]")); + + PADDLE_ENFORCE_EQ(block_mask.dims()[2], (seqlen_q + 127) / 128, + 
common::errors::InvalidArgument( + "blockmask only supports blockdim_q = 128 now")); + + PADDLE_ENFORCE_EQ(block_mask.dims()[3], (seqlen_k + 127) / 128, + common::errors::InvalidArgument( + "blockmask only supports blockdim_k = 128 now")); + + PADDLE_ENFORCE_EQ( + block_mask.dims()[1], startend_row_indices.dims()[1], + common::errors::InvalidArgument( + "blockmask only supports same dim num_heads with flashmask now")); + + PADDLE_ENFORCE_LE(seqlen_k, 1024 * 128, + common::errors::InvalidArgument( + "blockmask only supports seqlen <= 128k in bwd now")); + + PADDLE_ENFORCE_LE(seqlen_q, 1024 * 128, + common::errors::InvalidArgument( + "blockmask only supports seqlen <= 128k in bwd now")); + } + + // const bool has_lt_start = lt_start_row_indices.initialized(); + // const bool has_lt_end = lt_end_row_indices.initialized(); + // const bool has_ut_start = ut_start_row_indices.initialized(); + // const bool has_ut_end = ut_end_row_indices.initialized(); + + const bool has_lt_start = (lt_start_ptr != nullptr); + const bool has_lt_end = (lt_end_ptr != nullptr); + const bool has_ut_start = (ut_start_ptr != nullptr); + const bool has_ut_end = (ut_end_ptr != nullptr); + + // umiswing: The tile dispatch for flashmask is now different from fa3. + // Replacing the original ternary operator with lambda makes the code + // easier to reason about and less error-prone. 
+ const auto [kBlockM_sm90, kBlockN_sm90] = [&]() -> std::pair { + if (head_size_rounded <= 64) { + if (is_flashmask && !is_causal) { + return {64, 96}; + } else if (is_causal && has_softcap || is_flashmask) { + return {96, 128}; + } else { + return {128, 128}; + } + } else if (head_size_rounded <= 128) { + // umiswing: by now, we reuse template instantiation of head dim 128 for + // head dim in range (64, 128], and therefore no separate dispatch for + // head dim in range (64, 96] + if (is_causal || is_local || has_softcap) { + return {64, 128}; + } else { + if ((seqlen_q >= 1024 || seqlen_k >= 1024) && + !(has_lt_end && has_ut_start)) { + return {64, 128}; + } else { + return {64, 64}; + } + } + } else if (head_size_rounded <= 256) { + // umiswing: by now, we reuse template instantiation of head dim 256 for + // head dim in range (128, 256], and therefore no separate dispatch for + // head dim in range (128, 192] + if (has_lt_end && has_ut_start) { + return {64, 32}; + } else { + return {64, 64}; + } + } else { + PADDLE_THROW( + common::errors::Unimplemented("head dim is rounded to %d, which is " + "not supported in FlashMask V3 now.", + head_size_rounded)); + return {0, 0}; + } + }(); + + int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64; + int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32; + int const kBlockM = + arch >= 90 ? kBlockM_sm90 + : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80); + int const kBlockN_sm80 = + head_size_rounded <= 128 ? 128 : (head_size_rounded <= 192 ? 80 : 64); + int const kBlockN_sm86 = + head_size_rounded <= 64 + ? 128 + : (head_size_rounded <= 96 + ? 128 + : (head_size_rounded <= 128 + ? 96 + : (head_size_rounded <= 192 ? 64 : 64))); + int const kBlockN = + arch >= 90 ? kBlockN_sm90 + : (arch == 86 || arch == 89 ? 
kBlockN_sm86 : kBlockN_sm80); + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM); + int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN); + int const total_q_padded_rounded = + round_multiple(total_q + batch_size * kBlockM, kBlockM); + int const total_k_padded_rounded = + round_multiple(total_k + batch_size * kBlockN, kBlockN); + + if (!is_varlen_q) { + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); + } else { + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size); + CHECK_SHAPE(dout, total_q, num_heads, head_size); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + } + if (!is_varlen_k) { + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); + } else { + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + } + + if (seqused_q_.is_initialized()) { + auto seqused_q = seqused_q_.get(); + PADDLE_ENFORCE_EQ( + seqused_q.dtype(), paddle::DataType::INT32, + common::errors::InvalidArgument("seqused_q must have dtype int32")); + CHECK_DEVICE(seqused_q); + CHECK_CONTIGUOUS(seqused_q); + CHECK_SHAPE(seqused_q, batch_size); + } + if (seqused_k_.is_initialized()) { + auto seqused_k = seqused_k_.get(); + PADDLE_ENFORCE_EQ( + seqused_k.dtype(), paddle::DataType::INT32, + common::errors::InvalidArgument("seqused_k must have dtype int32")); + CHECK_DEVICE(seqused_k); + CHECK_CONTIGUOUS(seqused_k); + CHECK_SHAPE(seqused_k, batch_size); + } + + if (dq_.is_initialized()) { + *dq = dq_.get(); + PADDLE_ENFORCE_EQ( + dq->dtype(), q_type, + common::errors::InvalidArgument("dq must have the same dtype as q")); + CHECK_DEVICE((*dq)); + 
PADDLE_ENFORCE_EQ(dq->strides()[dq->strides().size() - 1], 1, + common::errors::InvalidArgument( + "dq must have contiguous last dimension")); + if (!is_varlen_q) { + CHECK_SHAPE((*dq), batch_size, seqlen_q, num_heads, head_size); + } else { + CHECK_SHAPE((*dq), total_q, num_heads, head_size); + } + } else { + *dq = paddle::empty_like(q); + } + if (dk_.is_initialized()) { + *dk = dk_.get(); + PADDLE_ENFORCE_EQ( + dk->dtype(), q_type, + common::errors::InvalidArgument("dk must have the same dtype as q")); + CHECK_DEVICE((*dk)); + PADDLE_ENFORCE_EQ(dk->strides()[dk->strides().size() - 1], 1, + common::errors::InvalidArgument( + "dk must have contiguous last dimension")); + if (!is_varlen_k) { + CHECK_SHAPE((*dk), batch_size, seqlen_k, num_heads_k, head_size); + } else { + CHECK_SHAPE((*dk), total_k, num_heads_k, head_size); + } + } else { + *dk = paddle::empty_like(k); + } + if (dv_.is_initialized()) { + *dv = dv_.get(); + PADDLE_ENFORCE_EQ( + dv->dtype(), q_type, + common::errors::InvalidArgument("dv must have the same dtype as q")); + CHECK_DEVICE((*dv)); + PADDLE_ENFORCE_EQ(dv->strides()[dv->strides().size() - 1], 1, + common::errors::InvalidArgument( + "dv must have contiguous last dimension")); + if (!is_varlen_k) { + CHECK_SHAPE((*dv), batch_size, seqlen_k, num_heads_k, head_size); + } else { + CHECK_SHAPE((*dv), total_k, num_heads_k, head_size); + } + } else { + *dv = paddle::empty_like(v); + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + + // Need softmax_d to have total_q_padded_rounded since we want its address to + // be aligned by 16/8 bytes for TMA / LDG.64 + if (!is_varlen) { + if (softmax_d) { + // Need softmax_d to have seqlen_q_rounded since we want its address to be + // aligned by 16/8 bytes for TMA / LDG.64 + *softmax_d = paddle::empty({batch_size, num_heads, seqlen_q_rounded}, + paddle::DataType::FLOAT32, place); + } + if (softmax_lse_log2) { + *softmax_lse_log2 = + 
paddle::empty({batch_size, num_heads, seqlen_q_rounded}, + paddle::DataType::FLOAT32, place); + } + } else { + if (softmax_d) { + *softmax_d = paddle::empty({num_heads, total_q_padded_rounded}, + paddle::DataType::FLOAT32, place); + } + if (softmax_lse_log2) { + *softmax_lse_log2 = paddle::empty({num_heads, total_q_padded_rounded}, + paddle::DataType::FLOAT32, place); + } + } + + if (dq_accum) { + if (!is_varlen) { + *dq_accum = paddle::empty( + {batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, + paddle::DataType::FLOAT32, place); + + } else { + *dq_accum = + paddle::empty({num_heads, total_q_padded_rounded * head_size_rounded}, + paddle::DataType::FLOAT32, place); + } + } + + if (num_heads_k != num_heads) { // MQA / GQA + if (!is_varlen) { + if (dk_accum) { + *dk_accum = paddle::empty( + {batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, + paddle::DataType::FLOAT32, place); + } + if (dv_accum) { + *dv_accum = paddle::empty( + {batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, + paddle::DataType::FLOAT32, place); + } + } else { + if (dk_accum) { + *dk_accum = paddle::empty( + {num_heads_k, total_k_padded_rounded, head_size_rounded}, + paddle::DataType::FLOAT32, place); + } + if (dv_accum) { + *dv_accum = paddle::empty( + {num_heads_k, total_k_padded_rounded, head_size_rounded}, + paddle::DataType::FLOAT32, place); + } + } + + if (dk_accum) { + *dk_accum = paddle::full(dk_accum->shape(), float{0}, + paddle::DataType::FLOAT32, place); + } + if (dv_accum) { + *dv_accum = paddle::full(dv_accum->shape(), float{0}, + paddle::DataType::FLOAT32, place); + } + } + + FlashMask_bwd_params *params_handle = get_flashmask_bwd_params_handle(); + flashmaskv3_clear_bwd_params_handle(params_handle); + set_flashmaskv3_params_dgrad( + params_handle, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, + seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded, q, + k, v, out, dout, dq, dk, dv, !is_varlen_q ? 
nullptr : cu_seqlens_q.data(), + !is_varlen_k ? nullptr : cu_seqlens_k.data(), + seqused_q_.is_initialized() ? const_cast(seqused_q_.get().data()) + : nullptr, + seqused_k_.is_initialized() ? const_cast(seqused_k_.get().data()) + : nullptr, + dq_accum ? dq_accum->data() : nullptr, + num_heads_k != num_heads && dk_accum ? dk_accum->data() : nullptr, + num_heads_k != num_heads && dv_accum ? dv_accum->data() : nullptr, + const_cast(softmax_lse.data()), + softmax_d ? (softmax_d->data()) : nullptr, + /*p_dropout=*/0.f, softmax_scale, window_size_left, window_size_right, + dprops, softcap, deterministic, sm_margin); + flashmaskv3_bwd_params_set_total_q(params_handle, total_q); + flashmaskv3_bwd_params_set_total_k(params_handle, total_k); + flashmaskv3_bwd_params_set_softmax_lse_log2_ptr( + params_handle, softmax_lse_log2 ? softmax_lse_log2->data() : nullptr); + flashmaskv3_bwd_params_set_dv(params_handle, + head_size); // We don't support hdim_v being + // different from hdim_qk for now + paddle::Tensor tile_count_semaphore; + if (arch >= 90) { + tile_count_semaphore = paddle::full({1}, 0, paddle::DataType::INT32, place); + + flashmaskv3_bwd_params_set_tile_count_semaphore( + params_handle, tile_count_semaphore.data()); + } else { + flashmaskv3_bwd_params_set_tile_count_semaphore(params_handle, nullptr); + } + + paddle::Tensor dq_semaphore = + paddle::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, + paddle::DataType::INT32, place); + flashmaskv3_bwd_params_set_dq_semaphore(params_handle, + dq_semaphore.data()); + + paddle::Tensor dk_semaphore; + paddle::Tensor dv_semaphore; + if (num_heads_k != num_heads && + flashmaskv3_bwd_params_get_deterministic(params_handle)) { + // xiangrui: we need to zero them out + dk_semaphore = paddle::full( + {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, 0, + paddle::DataType::INT32, place); + + dv_semaphore = paddle::full( + {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, 0, + 
paddle::DataType::INT32, place); + + flashmaskv3_bwd_params_set_dk_semaphore(params_handle, + dk_semaphore.data()); + flashmaskv3_bwd_params_set_dv_semaphore(params_handle, + dv_semaphore.data()); + } + + if (is_flashmask) { + flashmaskv3_bwd_params_set_lt_start_ptr( + params_handle, const_cast(lt_start_ptr)); + flashmaskv3_bwd_params_set_lt_end_ptr(params_handle, + const_cast(lt_end_ptr)); + flashmaskv3_bwd_params_set_ut_start_ptr( + params_handle, const_cast(ut_start_ptr)); + flashmaskv3_bwd_params_set_ut_end_ptr(params_handle, + const_cast(ut_end_ptr)); + + if (flashmask_maxmin.initialized()) + flashmaskv3_bwd_params_set_flashmask_maxmin_ptr( + params_handle, (flashmask_maxmin.data())); + else + flashmaskv3_bwd_params_set_flashmask_maxmin_ptr(params_handle, nullptr); + + flashmaskv3_bwd_params_set_h_flashmask(params_handle, + startend_row_indices.dims()[1]); + flashmaskv3_bwd_params_set_h_h_flashmask_ratio( + params_handle, num_heads / startend_row_indices.dims()[1]); + } else { + flashmaskv3_bwd_params_set_lt_start_ptr(params_handle, nullptr); + flashmaskv3_bwd_params_set_lt_end_ptr(params_handle, nullptr); + flashmaskv3_bwd_params_set_ut_start_ptr(params_handle, nullptr); + flashmaskv3_bwd_params_set_ut_end_ptr(params_handle, nullptr); + flashmaskv3_bwd_params_set_flashmask_maxmin_ptr(params_handle, nullptr); + flashmaskv3_bwd_params_set_h_flashmask(params_handle, 0); + flashmaskv3_bwd_params_set_h_h_flashmask_ratio(params_handle, 0); + } + + if (is_blockmask) { + // xhy: blockmask is now only support blockdim_q k = 128 + flashmaskv3_bwd_params_set_m_block_dim(params_handle, 128); + flashmaskv3_bwd_params_set_n_block_dim(params_handle, 128); + flashmaskv3_bwd_params_set_block_mask_ptr(params_handle, + (block_mask.data())); + } +#ifdef FLASHATTENTION_DISABLE_LOCAL + PADDLE_ENABLE_EQ( + !flashmaskv3_bwd_params_get_is_local(params_handle), true, + "This flash attention build does not support local attention."); +#endif +#ifdef FLASHATTENTION_DISABLE_SOFTCAP + 
PADDLE_ENABLE_EQ( + flashmaskv3_bwd_params_get_softcap(params_handle), 0.0, + "This flash attention build does not support tanh softcapping."); +#endif + + if (total_q > 0 && total_k > 0 && num_heads_k > 0) { + flashmaskv3_run_mha_bwd(params_handle, stream); + } else if (total_k > 0 && num_heads_k > 0) { + *dk = paddle::full(dk->shape(), T{0}, q_type, place); + *dv = paddle::full(dv->shape(), T{0}, q_type, place); + if (softmax_d) { + *softmax_d = paddle::full(softmax_d->shape(), float{0}, + paddle::DataType::FLOAT32, place); + } + } else if (total_q > 0 && num_heads_k > 0) { + *dq = paddle::full(dq->shape(), T{0}, q_type, place); + if (softmax_d) { + *softmax_d = paddle::full(softmax_d->shape(), float{0}, + paddle::DataType::FLOAT32, place); + } + } +#else + RaiseNotSupportedError(); +#endif +} + +#define FLASHMASK_V3_GRAD_BASE_KERNEL_IMPL(DType) \ + template void FlashMaskV3GradBaseKernel( \ + const paddle::Tensor &dout, const paddle::Tensor &q, \ + const paddle::Tensor &k, const paddle::Tensor &v, \ + const paddle::Tensor &out, const paddle::Tensor &softmax_lse, \ + const paddle::optional &dq_, \ + const paddle::optional &dk_, \ + const paddle::optional &dv_, \ + const paddle::optional &cu_seqlens_q_, \ + const paddle::optional &cu_seqlens_k_, \ + const paddle::optional &seqused_q_, \ + const paddle::optional &seqused_k_, \ + const paddle::optional &startend_row_indices_, \ + const paddle::optional &block_mask_, int max_seqlen_q_, \ + int max_seqlen_k_, float const softmax_scale, bool is_causal, \ + int window_size_left, int window_size_right, float const softcap, \ + bool const deterministic, int const sm_margin, paddle::Tensor *dq, \ + paddle::Tensor *dk, paddle::Tensor *dv, paddle::Tensor *softmax_d, \ + paddle::Tensor *softmax_lse_log2, paddle::Tensor *dq_accum, \ + paddle::Tensor *dk_accum, paddle::Tensor *dv_accum) + +FLASHMASK_V3_GRAD_BASE_KERNEL_IMPL(paddle::float16); +FLASHMASK_V3_GRAD_BASE_KERNEL_IMPL(paddle::bfloat16); diff --git 
a/flashmask/flash_mask/flashmask_attention_v3/csrc/flashmask_v3_kernel.cu b/flashmask/flash_mask/flashmask_attention_v3/csrc/flashmask_v3_kernel.cu new file mode 100644 index 00000000000..8d368143f38 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/csrc/flashmask_v3_kernel.cu @@ -0,0 +1,1055 @@ +/****************************************************************************** + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#include "flash_attn_v3_utils.h" +#include "flashmask_v3.h" +#include +#include + +template +void FlashMaskV3BaseKernel( + const paddle::Tensor &q, const paddle::Tensor &k, const paddle::Tensor &v, + const paddle::optional + &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is + // cu_seqlens_k_new + const paddle::optional + &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is + // cu_seqlens_k_new + const paddle::optional + &q_v_, // (b, s_q, h, dv) or (total_q_new, h, + // dv) if there is cu_seqlens_q + const paddle::optional + &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + const paddle::optional &cu_seqlens_q_, // b+1 + const paddle::optional &cu_seqlens_k_, // b+1 + const paddle::optional &cu_seqlens_k_new_, // b+1 + const paddle::optional + &seqused_q_, // b. 
If given, only this many elements of each batch + // element's queries and outputs are used. + const paddle::optional + &seqused_k_, // b. If given, only this many elements of each batch + // element's keys are used. + const paddle::optional + &page_table_, // (b_k, max_num_pages_per_seq) + const paddle::optional + &kv_batch_idx_, // b. indices to index into the KV cache + const paddle::optional &leftpad_k_, // b + const paddle::optional + &rotary_cos_, // seqlen_ro x (rotary_dim / 2) + const paddle::optional + &rotary_sin_, // seqlen_ro x (rotary_dim / 2) + const paddle::optional &q_descale_, // (b, h_k), not (b, h) + const paddle::optional &k_descale_, // (b, h_k) + const paddle::optional &v_descale_, // (b, h_k) + const paddle::optional &scheduler_metadata_, // (b + 1) + const paddle::optional + &startend_row_indices_, // (b,h,s_1,[1,2,4]) + const paddle::optional + &block_mask_, // ((b,h,s// 128,s // 128) + const int max_seqlen_q_, // if max_seqlen_q_ is set to 0, it indicates that + // it is uninitialized and should not be referenced + // TODO(tridao): check if we need max_seqlen_k + const int max_seqlen_k_, // if max_seqlen_q_ is set to 0, it indicates that + // it is uninitialized and should not be referenced + const float softmax_scale, bool is_causal, int window_size_left, + int window_size_right, const float softcap, + const bool is_rotary_interleaved, // if true, rotary combines indices 0 & + // 1, else indices 0 & rotary_dim / 2 + int num_splits, const bool manual_set_pack_gqa, + const bool + pack_gqa_, // the pack_gqa_ will be used only if manual_set_pack_gqa is + // set to True; otherwise, the internal heuristic + // get_pack_gqa() from fa3 will decide whether to pack gqa + const int sm_margin, paddle::Tensor *out, paddle::Tensor *softmax_lse, + paddle::Tensor *out_accum, paddle::Tensor *softmax_lse_accum) { +#ifdef PADDLE_WITH_FLASHATTN_V3 + + cudaStream_t stream = static_cast(q.stream()); + auto place = q.place(); + + // TODO(umiswing): support 
ampere + int device_id = place.GetDeviceId(); + cudaDeviceProp dprops; + cudaGetDeviceProperties(&dprops, device_id); + + const bool is_sm90 = dprops.major == 9 && dprops.minor == 0; + PADDLE_ENFORCE_EQ(is_sm90, true, + common::errors::Unavailable( + "FlashAttention-3 only supports Hopper GPUs.")); + + auto q_type = q.dtype(); + PADDLE_ENFORCE_EQ( + (q_type == paddle::DataType::FLOAT16 || + q_type == paddle::DataType::BFLOAT16 || + q_type == paddle::DataType::FLOAT8_E4M3FN), + true, + common::errors::InvalidArgument( + "FlashAttention-3 only supports fp16, bf16, and fp8_e4m3 data type")); + + PADDLE_ENFORCE_EQ(k.dtype(), q_type, + common::errors::InvalidArgument( + "query and key must have the same dtype")); + PADDLE_ENFORCE_EQ(v.dtype(), q_type, + common::errors::InvalidArgument( + "query and value must have the same dtype")); + + CHECK_DEVICE(q); + CHECK_DEVICE(k); + CHECK_DEVICE(v); + + PADDLE_ENFORCE_EQ(q.strides()[q.strides().size() - 1], 1, + common::errors::InvalidArgument( + "Input tensor must have contiguous last dimension")); + PADDLE_ENFORCE_EQ(k.strides()[k.strides().size() - 1], 1, + common::errors::InvalidArgument( + "Input tensor must have contiguous last dimension")); + PADDLE_ENFORCE_EQ(v.strides()[v.strides().size() - 1], 1, + common::errors::InvalidArgument( + "Input tensor must have contiguous last dimension")); + + paddle::Tensor page_table; + // const bool paged_KV = page_table_.has_value(); + // umiswing: this is stupid but idk how to use optional + const bool paged_KV = page_table_.is_initialized(); + if (paged_KV) { + page_table = page_table_.get(); + CHECK_DEVICE(page_table); + PADDLE_ENFORCE_EQ(page_table.dtype(), paddle::DataType::INT32, + common::errors::InvalidArgument( + "page_table must have dtype paddle.int32")); + PADDLE_ENFORCE_EQ(page_table.strides()[page_table.strides().size() - 1], 1, + common::errors::InvalidArgument( + "page_table must have contiguous last dimension")); + } + + // TODO(umiswing): support cusum + + 
paddle::Tensor cu_seqlens_q; + // bool const is_varlen_q = cu_seqlens_q_.has_value(); + // TODO(umiswing): this is stupid, must fix it (after understand + // optional) + const bool is_varlen_q = cu_seqlens_q_.is_initialized(); + if (is_varlen_q) { + cu_seqlens_q = cu_seqlens_q_.get(); + CHECK_DEVICE(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_q); + PADDLE_ENFORCE_EQ(cu_seqlens_q.dtype(), paddle::DataType::INT32, + common::errors::InvalidArgument( + "cu_seqlens_q must have dtype paddle.int32")); + PADDLE_ENFORCE_NE( + max_seqlen_q_, 0, + common::errors::InvalidArgument( + "max_seqlen_q must be provided if cu_seqlens_q is provided")); + } + + paddle::Tensor cu_seqlens_k; + const bool is_varlen_k = cu_seqlens_k_.is_initialized(); + if (is_varlen_k) { + cu_seqlens_k = cu_seqlens_k_.get(); + CHECK_DEVICE(cu_seqlens_k); + CHECK_CONTIGUOUS(cu_seqlens_k); + PADDLE_ENFORCE_EQ(cu_seqlens_k.dtype(), paddle::DataType::INT32, + common::errors::InvalidArgument( + "cu_seqlens_k must have dtype paddle.int32")); + PADDLE_ENFORCE_NE( + max_seqlen_k_, 0, + common::errors::InvalidArgument( + "max_seqlen_k must be provided if cu_seqlens_k is provided")); + PADDLE_ENFORCE_EQ( + !paged_KV, true, + common::errors::InvalidArgument( + "If cu_seqlens_k is passed in, then page table is not supported")); + PADDLE_ENFORCE_EQ( + !kv_batch_idx_, true, + common::errors::InvalidArgument( + "If cu_seqlens_k is passed in, then page table is not supported")); + } + + auto const sizes = q.dims(); + const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.dims()[0] - 1; + int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_; + int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; + int64_t num_heads = q.dims()[q.dims().size() - 2]; + int64_t const head_size = q.dims()[q.dims().size() - 1]; + int const head_size_v = v.dims()[v.dims().size() - 1]; + int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.dims()[1]; + int const num_pages = !paged_KV ? 
0 : k.dims()[0]; + int const page_size = !paged_KV ? 1 : k.dims()[1]; + int const seqlen_k = + !is_varlen_k + ? (!paged_KV ? k.dims()[1] : max_num_pages_per_seq * page_size) + : max_seqlen_k_; + int const total_k = !is_varlen_k ? batch_size * k.dims()[1] : k.dims()[0]; + int const num_heads_k = k.dims()[k.dims().size() - 2]; + int const batch_size_k = + !paged_KV ? (!is_varlen_k ? k.dims()[0] : cu_seqlens_k.dims()[0] - 1) + : page_table.dims()[0]; + if (!kv_batch_idx_.is_initialized()) { + PADDLE_ENFORCE_EQ(batch_size, batch_size_k, + common::errors::InvalidArgument( + "batch_size must be equal to batch_size_k")); + } + int const max_headdim = flashmaskv3_get_max_headdim(); + PADDLE_ENFORCE_LE( + head_size, max_headdim, + common::errors::InvalidArgument( + "FlashAttention forward only supports head dimension at most %d", + max_headdim)); + PADDLE_ENFORCE_EQ( + num_heads % num_heads_k, 0, + common::errors::InvalidArgument( + "Number of heads in key/value must divide number of heads in query")); + if (head_size_v != head_size) { + PADDLE_ENFORCE_EQ( + ((head_size > 128 && head_size <= 192 && head_size_v > 96 && + head_size_v <= 128) || + (head_size <= 64 && head_size_v <= 512)), + true, + common::errors::InvalidArgument( + "If V headdim is different from Q/K dim, we only support " + "Q/K headdim in (128, 192] and V headdim in (96, 128], " + "or (Q/K <= 64 and V <= 512).")); + PADDLE_ENFORCE_EQ(dprops.major, 9, + common::errors::InvalidArgument( + "Only Hopper supports different V headdim")); + if (head_size_v > 256) { + PADDLE_ENFORCE_EQ((q_type == paddle::DataType::FLOAT16 || + q_type == paddle::DataType::BFLOAT16), + true, + common::errors::InvalidArgument( + "HeaddimV > 256 requires fp16 and bf16 data type")); + } + } + + bool const is_flashmask = startend_row_indices_.is_initialized(); + bool const is_blockmask = block_mask_.is_initialized(); + + // This needs to go before kBlockM & kBlockN since we rely on the correct + // window_size and is_causal to set 
kBlockM + // TODO(tridao): check this + if (window_size_left >= seqlen_k - 1) { + window_size_left = -1; + } + if (window_size_right >= seqlen_q - 1) { + window_size_right = -1; + } + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) { + // Special case of hdim 128 where we want causal to have kBlockN=128, better + // for pagedKV and TMA + if (((head_size <= 64 || head_size > 128) || !paged_KV) && !is_flashmask) { + is_causal = false; + } + } + if (is_causal) { + window_size_right = 0; + } + // There's a case where is_causal=false, window_size=(-1, 0). Then + // set_params_fprop will set params.is_causal=true. If we don't have is_causal + // here matching params.is_causal, we might get the wrong kBlockM. + is_causal = window_size_left < 0 && window_size_right == 0; + + if (!is_varlen_q) { + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + } else { + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + } + if (!paged_KV) { + if (!is_varlen_k) { + CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + } + } else { + CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); + CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); + } + + if (seqused_q_.is_initialized()) { + auto seqused_q = seqused_q_.get(); + PADDLE_ENFORCE_EQ( + seqused_q.dtype(), paddle::DataType::INT32, + common::errors::InvalidArgument("seqused_q must have dtype int32")); + CHECK_DEVICE(seqused_q); + CHECK_CONTIGUOUS(seqused_q); + CHECK_SHAPE(seqused_q, batch_size); + } + if (seqused_k_.is_initialized()) { + auto seqused_k = seqused_k_.get(); + 
PADDLE_ENFORCE_EQ( + seqused_k.dtype(), paddle::DataType::INT32, + common::errors::InvalidArgument("seqused_k must have dtype int32")); + CHECK_DEVICE(seqused_k); + CHECK_CONTIGUOUS(seqused_k); + CHECK_SHAPE(seqused_k, batch_size); + } + + if (leftpad_k_.is_initialized()) { + auto leftpad_k = leftpad_k_.get(); + PADDLE_ENFORCE_EQ( + leftpad_k.dtype(), paddle::DataType::INT32, + common::errors::InvalidArgument("leftpad_k must have dtype int32")); + CHECK_DEVICE(leftpad_k); + CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + } + + // This is what we will template on + bool const is_varlen = + is_varlen_q || is_varlen_k || seqused_q_.is_initialized() || + seqused_k_.is_initialized() || leftpad_k_.is_initialized(); +#ifdef FLASHATTENTION_DISABLE_VARLEN + PADDLE_ENFORCE_EQ(!is_varlen, true, + common::errors::Unavailable( + "This flash attention build does not support varlen.")); +#endif + + int const alignment = q_type == paddle::DataType::FLOAT8_E4M3FN ? 16 : 8; + PADDLE_ENFORCE_EQ(head_size % alignment, 0, + common::errors::InvalidArgument( + "head_size should be a multiple of %d", alignment)); + PADDLE_ENFORCE_EQ(head_size_v % alignment, 0, + common::errors::InvalidArgument( + "head_size_v should be a multiple of %d", alignment)); + + auto out_type = q_type == paddle::DataType::FLOAT8_E4M3FN + ? paddle::DataType::BFLOAT16 + : q_type; + if (out_.is_initialized()) { + *out = out_.get(); + PADDLE_ENFORCE_EQ( + out->dtype(), out_type, + common::errors::InvalidArgument( + "For FP16/BF16 input, output must have the same dtype as " + "inputs. 
For FP8 input, output must have dtype BF16")); + CHECK_DEVICE((*out)); + PADDLE_ENFORCE_EQ(out->strides()[out->strides().size() - 1], 1, + common::errors::InvalidArgument( + "Output tensor must have contiguous last dimension")); + if (!is_varlen_q) { + CHECK_SHAPE((*out), batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE((*out), total_q, num_heads, head_size_v); + } + } else { + // TODO 明确一下,q type是不是和 t 是一样的 + auto out_type = q_type == paddle::DataType::FLOAT8_E4M3FN + ? paddle::DataType::BFLOAT16 + : q_type; + + if (!is_varlen_q) { + *out = paddle::empty({batch_size, seqlen_q, num_heads, head_size_v}, + out_type, place); + + } else { + *out = paddle::empty({total_q, num_heads, head_size_v}, out_type, place); + } + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + int const head_size_rounded = flashmaskv3_round_up_headdim(head_size); + int const head_size_v_rounded = flashmaskv3_round_up_headdim(head_size_v); + int const seqlen_q_rounded = round_multiple(seqlen_q, 128); + int const seqlen_k_rounded = round_multiple(seqlen_k, 128); + + *softmax_lse = paddle::empty({batch_size, num_heads, seqlen_q}, + paddle::DataType::FLOAT32, place); + + FlashMask_fwd_params *params_handle = get_flashmask_fwd_params_handle(); + flashmaskv3_clear_fwd_params_handle(params_handle); + set_flashmaskv3_params_fprop( + params_handle, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, + seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded, q, + k, v, out, !is_varlen_q ? nullptr : cu_seqlens_q.data(), + !is_varlen_k ? nullptr : cu_seqlens_k.data(), + seqused_q_.is_initialized() ? const_cast(seqused_q_.get().data()) + : nullptr, + seqused_k_.is_initialized() ? 
const_cast(seqused_k_.get().data()) + : nullptr, + softmax_lse->data(), + /*p_dropout=*/0.f, softmax_scale, window_size_left, window_size_right, + dprops, softcap, sm_margin); + flashmaskv3_fwd_params_set_total_q(params_handle, total_q); + flashmaskv3_fwd_params_set_total_k(params_handle, total_k); + flashmaskv3_fwd_params_set_b_k(params_handle, batch_size_k); + flashmaskv3_fwd_params_set_dv(params_handle, head_size_v); + flashmaskv3_fwd_params_set_dv_rounded(params_handle, head_size_v_rounded); + + if (leftpad_k_ + .is_initialized()) { // This needs to be set before get_pagedkv_tma + flashmaskv3_fwd_params_set_leftpad_k(params_handle, + leftpad_k_.get().data()); + } + if (paged_KV) { + flashmaskv3_fwd_params_set_page_table(params_handle, + page_table.data()); + flashmaskv3_fwd_params_set_page_table_batch_stride(params_handle, + page_table.strides()[0]); + } + flashmaskv3_fwd_params_set_page_size(params_handle, page_size); + flashmaskv3_fwd_params_set_num_pages(params_handle, num_pages); + + if (k_new_.is_initialized()) { // This needs to be set before get_pagedkv_tma + paddle::Tensor k_new, v_new; + PADDLE_ENFORCE_EQ( + v_new_.is_initialized(), true, + common::errors::InvalidArgument( + "If k_new is supplied, v_new must also be passed in")); + PADDLE_ENFORCE_EQ( + seqused_k_.is_initialized(), true, + common::errors::InvalidArgument( + "If k_new is supplied, seqlens_k must also be passed in")); + PADDLE_ENFORCE_LE( + seqlen_q, seqlen_k, + common::errors::InvalidArgument( + "If k_new is supplied, it must have seqlen <= the seqlen " + "of the KV cache")); + paddle::Tensor cu_seqlens_k_new; + bool const is_varlen_k_new = cu_seqlens_k_new_.is_initialized(); + if (is_varlen_k_new) { + cu_seqlens_k_new = cu_seqlens_k_new_.get(); + CHECK_DEVICE(cu_seqlens_k_new); + CHECK_CONTIGUOUS(cu_seqlens_k_new); + PADDLE_ENFORCE_EQ(cu_seqlens_k_new.dtype(), paddle::DataType::INT32, + common::errors::InvalidArgument( + "cu_seqlens_k_new must have dtype paddle.int32")); + } + k_new = 
k_new_.get(); + v_new = v_new_.get(); + PADDLE_ENFORCE_EQ(k_new.dtype(), q_type, + common::errors::InvalidArgument( + "k_new must have the same dtype as query")); + PADDLE_ENFORCE_EQ(v_new.dtype(), q_type, + common::errors::InvalidArgument( + "v_new must have the same dtype as query")); + CHECK_DEVICE(k_new); + CHECK_DEVICE(v_new); + PADDLE_ENFORCE_EQ(k_new.strides()[k_new.strides().size() - 1], 1, + common::errors::InvalidArgument( + "k_new tensor must have contiguous last dimension")); + PADDLE_ENFORCE_EQ(v_new.strides()[v_new.strides().size() - 1], 1, + common::errors::InvalidArgument( + "v_new tensor must have contiguous last dimension")); + // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when + // is_varlen_k_new + int seqlen_k_new = !is_varlen_k_new ? k_new.dims()[1] : 0; + int total_k_new = + !is_varlen_k_new ? batch_size * k_new.dims()[1] : k_new.dims()[0]; + if (!is_varlen_k_new) { + CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1); + } + // umiswing: dump this to shared library + flashmaskv3_fwd_params_set_seqlen_knew(params_handle, seqlen_k_new); + flashmaskv3_fwd_params_set_total_knew(params_handle, total_k_new); + flashmaskv3_fwd_params_set_knew_ptr(params_handle, (k_new.data())); + flashmaskv3_fwd_params_set_vnew_ptr(params_handle, (v_new.data())); + // All stride are in elements, not bytes. 
+ flashmaskv3_fwd_params_set_knew_row_stride( + params_handle, k_new.strides()[k_new.strides().size() - 3]); + flashmaskv3_fwd_params_set_vnew_row_stride( + params_handle, v_new.strides()[v_new.strides().size() - 3]); + flashmaskv3_fwd_params_set_knew_head_stride( + params_handle, k_new.strides()[k_new.strides().size() - 2]); + flashmaskv3_fwd_params_set_vnew_head_stride( + params_handle, v_new.strides()[v_new.strides().size() - 2]); + if (!is_varlen_k_new) { + flashmaskv3_fwd_params_set_knew_batch_stride(params_handle, + k_new.strides()[0]); + flashmaskv3_fwd_params_set_vnew_batch_stride(params_handle, + v_new.strides()[0]); + } + if (is_varlen_k_new) { + flashmaskv3_fwd_params_set_cu_seqlens_knew(params_handle, + cu_seqlens_k_new.data()); + } + } + + // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks + // kernel + bool const use_dynamic_split = + is_varlen && flashmaskv3_fwd_params_get_b(params_handle) <= 992; + // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it + flashmaskv3_fwd_params_set_num_splits_dynamic_ptr( + params_handle, !use_dynamic_split ? nullptr : reinterpret_cast(1)); + + flashmaskv3_fwd_params_set_pagedkv_tma( + params_handle, flashmaskv3_get_pagedkv_tma(params_handle)); + if (num_splits <= 0) { + num_splits = flashmaskv3_get_num_splits(params_handle); + } + flashmaskv3_fwd_params_set_num_splits(params_handle, num_splits); + + // Always enable PackGQA for Split, and get_pack_gqa requires + // params.num_splits to decide + const bool pack_gqa = + manual_set_pack_gqa ? 
pack_gqa_ : flashmaskv3_get_pack_gqa(params_handle); + flashmaskv3_fwd_params_set_pack_gqa(params_handle, pack_gqa); + + // This needs to be set after get_num_splits + paddle::Tensor tile_count_semaphore; // Contains the semaphore and optionally + // num_splits_dynamic + // We don't use the persistent scheduler if Split and not Varlen + const bool params_is_causal = + flashmaskv3_fwd_params_get_is_causal(params_handle); + const bool params_is_local = + flashmaskv3_fwd_params_get_is_local(params_handle); + const int params_num_splits = + flashmaskv3_fwd_params_get_num_splits(params_handle); + const int params_b = flashmaskv3_fwd_params_get_b(params_handle); + const int params_arch = flashmaskv3_fwd_params_get_arch(params_handle); + bool const scheduler_needs_semaphore = + params_arch >= 90 ? true + : ((params_is_causal && !is_varlen) || + (is_varlen && params_num_splits > 1)); + int metadata_size = 0; + if (scheduler_needs_semaphore || use_dynamic_split) { + metadata_size = static_cast(scheduler_needs_semaphore) + + static_cast(use_dynamic_split) * params_b; + + flashmaskv3_fwd_params_set_skip_scheduler_metadata_computation( + params_handle, scheduler_metadata_.is_initialized()); + + if (scheduler_metadata_.is_initialized()) { + paddle::Tensor scheduler_metadata = scheduler_metadata_.get(); + CHECK_DEVICE(scheduler_metadata); + CHECK_SHAPE(scheduler_metadata, metadata_size); + CHECK_CONTIGUOUS(scheduler_metadata); + PADDLE_ENFORCE_EQ(scheduler_metadata.dtype(), paddle::DataType::INT32, + common::errors::InvalidArgument( + "scheduler_metadata must have dtype int32")); + tile_count_semaphore = scheduler_metadata; + } else { + tile_count_semaphore = + paddle::empty({metadata_size}, paddle::DataType::INT32, place); + } + if (scheduler_needs_semaphore && !use_dynamic_split) { + // If varlen we'll manually do the zero-ing + tile_count_semaphore = + paddle::full(tile_count_semaphore.shape(), int32_t{0}, + paddle::DataType::INT32, place); + } + 
flashmaskv3_fwd_params_set_tile_count_semaphore(
+        params_handle, scheduler_needs_semaphore
+                           ? (tile_count_semaphore.data())
+                           : nullptr);
+    flashmaskv3_fwd_params_set_num_splits_dynamic_ptr(
+        params_handle,
+        use_dynamic_split ? (tile_count_semaphore.data()) + 1 : nullptr);
+  }
+
+  if (q_v_.is_initialized()) {
+    // LE, not LT: the message (and the head_size<=64 / head_size_v<=512 rule
+    // above) allow head_size == 64.
+    PADDLE_ENFORCE_LE(head_size, 64,
+                      common::errors::InvalidArgument(
+                          "q_v is only supported for head_size <= 64"));
+    // Fix copy-paste bug: the second operand compared FLOAT16 twice, so bf16
+    // inputs were never actually accepted even though the message claims
+    // fp16 and bf16 support.
+    PADDLE_ENFORCE_EQ((q_type == paddle::DataType::FLOAT16 ||
+                       q_type == paddle::DataType::BFLOAT16),
+                      true,
+                      common::errors::InvalidArgument(
+                          "q_v is only supported for fp16 and bf16 data type"));
+    PADDLE_ENFORCE_EQ(params_arch, 90,
+                      common::errors::InvalidArgument(
+                          "q_v is only supported for Hopper GPUs"));
+    paddle::Tensor q_v = q_v_.get();
+    PADDLE_ENFORCE_EQ(q_v.dtype(), q_type,
+                      common::errors::InvalidArgument(
+                          "q_v must have the same dtype as query"));
+    CHECK_DEVICE(q_v);
+    PADDLE_ENFORCE_EQ(q_v.strides()[q_v.strides().size() - 1], 1,
+                      common::errors::InvalidArgument(
+                          "q_v tensor must have contiguous last dimension"));
+    if (!is_varlen_q) {
+      CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v);
+    } else {
+      CHECK_SHAPE(q_v, total_q, num_heads, head_size_v);
+    }
+    flashmaskv3_fwd_params_set_qv_ptr(params_handle, (q_v.data()));
+    // All strides are in elements, not bytes. 
+ flashmaskv3_fwd_params_set_qv_row_stride( + params_handle, q_v.strides()[q_v.strides().size() - 3]); + flashmaskv3_fwd_params_set_qv_head_stride( + params_handle, q_v.strides()[q_v.strides().size() - 2]); + if (!is_varlen_q) { + flashmaskv3_fwd_params_set_qv_batch_stride(params_handle, + q_v.strides()[0]); + } + } + + if (rotary_cos_.is_initialized()) { + PADDLE_ENFORCE_EQ( + k_new_.is_initialized(), true, + common::errors::InvalidArgument( + "If rotary cos/sin are provided, new key / value to be " + "appended to KV cache must also be provided")); + paddle::Tensor rotary_cos = rotary_cos_.get(); + CHECK_DEVICE(rotary_cos); + CHECK_CONTIGUOUS(rotary_cos); + int params_rotary_dim = rotary_cos.dims()[1] * 2; + flashmaskv3_fwd_params_set_rotary_dim(params_handle, params_rotary_dim); + PADDLE_ENFORCE_LE( + params_rotary_dim, head_size, + common::errors::InvalidArgument("rotary_dim must be <= headdim")); + PADDLE_ENFORCE_EQ( + params_rotary_dim % 16, 0, + common::errors::InvalidArgument( + "Only rotary dimensions divisible by 16 are currently supported")); + // TODO(large-tensor): downstream functors may still use int; guard until + // upgraded. 
+ int64_t seqlen_ro = rotary_cos.dims()[0];
+
+ if (paged_KV) {
+ PADDLE_ENFORCE_GE(
+ seqlen_ro, seqlen_k,
+ common::errors::InvalidArgument(
+ "cos/sin seqlen must be at least the seqlen of KV cache"));
+ }
+ CHECK_SHAPE(rotary_cos, seqlen_ro, params_rotary_dim / 2);
+ PADDLE_ENFORCE_EQ(rotary_cos.dtype(), q_type,
+ common::errors::InvalidArgument(
+ "rotary_cos must have the same dtype as query"));
+
+ PADDLE_ENFORCE_EQ(
+ rotary_sin_.is_initialized(), true,
+ common::errors::InvalidArgument(
+ "If rotary cos is provided, rotary sin must also be provided"));
+ auto rotary_sin = rotary_sin_.get();
+ CHECK_DEVICE(rotary_sin);
+ CHECK_CONTIGUOUS(rotary_sin);
+ CHECK_SHAPE(rotary_sin, seqlen_ro, params_rotary_dim / 2);
+ PADDLE_ENFORCE_EQ(rotary_sin.dtype(), q_type,
+ common::errors::InvalidArgument(
+ "rotary_sin must have the same dtype as query"));
+
+ flashmaskv3_fwd_params_set_rotary_cos_ptr(params_handle,
+ (rotary_cos.data()));
+ flashmaskv3_fwd_params_set_rotary_sin_ptr(params_handle,
+ (rotary_sin.data()));
+ flashmaskv3_fwd_params_set_is_rotary_interleaved(params_handle,
+ is_rotary_interleaved);
+ } else {
+ flashmaskv3_fwd_params_set_rotary_dim(params_handle, 0);
+ }
+
+ if (kv_batch_idx_.is_initialized()) {
+ paddle::Tensor kv_batch_idx = kv_batch_idx_.get();
+ CHECK_DEVICE(kv_batch_idx);
+ CHECK_CONTIGUOUS(kv_batch_idx);
+ PADDLE_ENFORCE_EQ(
+ kv_batch_idx.dtype(), paddle::DataType::INT32,
+ common::errors::InvalidArgument("kv_batch_idx must have dtype int32"));
+ flashmaskv3_fwd_params_set_kv_batch_idx(
+ params_handle, reinterpret_cast(kv_batch_idx.data()));
+ }
+
+ if (flashmaskv3_fwd_params_get_num_splits(params_handle) > 1) {
+ PADDLE_ENFORCE_LE(
+ flashmaskv3_fwd_params_get_num_splits(params_handle), 256,
+ common::errors::InvalidArgument("num_splits > 256 not supported"));
+ if (!is_varlen_q) {
+
+ *out_accum =
+ paddle::empty({flashmaskv3_fwd_params_get_num_splits(params_handle),
+ batch_size, num_heads, seqlen_q, head_size_v},
+ 
paddle::DataType::FLOAT32, place); + + *softmax_lse_accum = + paddle::empty({flashmaskv3_fwd_params_get_num_splits(params_handle), + batch_size, num_heads, seqlen_q}, + paddle::DataType::FLOAT32, place); + + flashmaskv3_fwd_params_set_oaccum_batch_stride(params_handle, + out_accum->strides()[1]); + flashmaskv3_fwd_params_set_lseaccum_batch_stride( + params_handle, softmax_lse_accum->strides()[1]); + } else { + *out_accum = + paddle::empty({flashmaskv3_fwd_params_get_num_splits(params_handle), + num_heads, total_q, head_size_v}, + paddle::DataType::FLOAT32, place); + + *softmax_lse_accum = + paddle::empty({flashmaskv3_fwd_params_get_num_splits(params_handle), + num_heads, total_q}, + paddle::DataType::FLOAT32, place); + } + flashmaskv3_fwd_params_set_is_fp32(params_handle, false); + flashmaskv3_fwd_params_set_oaccum_ptr(params_handle, (out_accum->data())); + flashmaskv3_fwd_params_set_softmax_lseaccum_ptr( + params_handle, (softmax_lse_accum->data())); + flashmaskv3_fwd_params_set_oaccum_split_stride(params_handle, + out_accum->strides()[0]); + flashmaskv3_fwd_params_set_oaccum_row_stride( + params_handle, out_accum->strides()[out_accum->strides().size() - 2]); + flashmaskv3_fwd_params_set_oaccum_head_stride( + params_handle, out_accum->strides()[out_accum->strides().size() - 3]); + flashmaskv3_fwd_params_set_lseaccum_split_stride( + params_handle, softmax_lse_accum->strides()[0]); + flashmaskv3_fwd_params_set_lseaccum_head_stride( + params_handle, + softmax_lse_accum->strides()[softmax_lse_accum->strides().size() - 2]); + } + + if (q_type == paddle::DataType::FLOAT8_E4M3FN) { + if (q_descale_.is_initialized()) { + paddle::Tensor q_descale = q_descale_.get(); + CHECK_DEVICE(q_descale); + CHECK_SHAPE(q_descale, batch_size, num_heads_k); + flashmaskv3_fwd_params_set_q_descale_ptr(params_handle, + (q_descale.data())); + flashmaskv3_fwd_params_set_q_descale_batch_stride(params_handle, + q_descale.strides()[0]); + 
flashmaskv3_fwd_params_set_q_descale_head_stride(params_handle, + q_descale.strides()[1]); + } else { + flashmaskv3_fwd_params_set_q_descale_ptr(params_handle, nullptr); + } + if (k_descale_.is_initialized()) { + paddle::Tensor k_descale = k_descale_.get(); + CHECK_DEVICE(k_descale); + CHECK_SHAPE(k_descale, batch_size, num_heads_k); + flashmaskv3_fwd_params_set_k_descale_ptr(params_handle, + (k_descale.data())); + flashmaskv3_fwd_params_set_k_descale_batch_stride(params_handle, + k_descale.strides()[0]); + flashmaskv3_fwd_params_set_k_descale_head_stride(params_handle, + k_descale.strides()[1]); + } else { + flashmaskv3_fwd_params_set_k_descale_ptr(params_handle, nullptr); + } + if (v_descale_.is_initialized()) { + paddle::Tensor v_descale = v_descale_.get(); + CHECK_DEVICE(v_descale); + CHECK_SHAPE(v_descale, batch_size, num_heads_k); + flashmaskv3_fwd_params_set_v_descale_ptr(params_handle, + (v_descale.data())); + flashmaskv3_fwd_params_set_v_descale_batch_stride(params_handle, + v_descale.strides()[0]); + flashmaskv3_fwd_params_set_v_descale_head_stride(params_handle, + v_descale.strides()[1]); + } else { + flashmaskv3_fwd_params_set_v_descale_ptr(params_handle, nullptr); + } + } + +#ifdef FLASHATTENTION_DISABLE_LOCAL + PADDLE_ENFORCE_EQ( + !flashmaskv3_fwd_params_get_is_local(params_handle), true, + common::errors::InvalidArgument( + "This flash attention build does not support local attention.")); +#endif +#ifdef FLASHATTENTION_DISABLE_SOFTCAP + PADDLE_ENFORCE_EQ( + flashmaskv3_fwd_params_get_softcap(params_handle), 0.0, + common::errors::InvalidArgument( + "This flash attention build does not support tanh softcapping.")); +#endif +#ifdef FLASHATTENTION_DISABLE_SPLIT + PADDLE_ENFORCE_EQ(flashmaskv3_fwd_params_get_num_splits(params_handle), 1, + common::errors::InvalidArgument( + "This flash attention build does not support splits.")); +#endif +#ifdef FLASHATTENTION_DISABLE_PACKGQA + PADDLE_ENFORCE_EQ( + (!flashmaskv3_fwd_params_get_pack_gqa(params_handle) || 
+ flashmaskv3_fwd_params_get_arch(params_handle) < 90 || + (flashmaskv3_fwd_params_get_page_table(params_handle) && + !flashmaskv3_fwd_params_get_pagedkv_tma(params_handle)) || + flashmaskv3_fwd_params_get_num_splits(params_handle) > 1), + true, + common::errors::InvalidArgument( + "This flash attention build does not support pack_gqa.")); +#endif +#ifdef FLASHATTENTION_DISABLE_PAGEDKV + PADDLE_ENFORCE_EQ( + (!(flashmaskv3_fwd_params_get_page_table(params_handle) && + !flashmaskv3_fwd_params_get_pagedkv_tma(params_handle))), + true, + common::errors::InvalidArgument( + "This flash attention build does not support paged KV.")); +#endif +#ifdef FLASHATTENTION_DISABLE_APPENDKV + PADDLE_ENFORCE_EQ( + !k_new_.is_initialized(), true, + common::errors::InvalidArgument( + "This flash attention build does not support appending KV.")); +#endif + + // flashmask + paddle::Tensor startend_row_indices; + if (is_flashmask) + startend_row_indices = startend_row_indices_.get(); + paddle::Tensor block_mask; + if (is_blockmask) + block_mask = block_mask_.get(); + + paddle::Tensor flashmask_maxmin; + paddle::Tensor lt_start_slice, lt_end_slice, ut_start_slice, ut_end_slice; + const int32_t *lt_start_ptr; + const int32_t *lt_end_ptr; + const int32_t *ut_start_ptr; + const int32_t *ut_end_ptr; + + if (is_flashmask) { + PADDLE_ENFORCE_EQ( + startend_row_indices.dims().size(), 4, + common::errors::InvalidArgument( + "flashmask_attention receive startend_row_indices with dim " + "[batch_size, num_heads,seq_len, mask_bounds]")); + PADDLE_ENFORCE_EQ(startend_row_indices.dims()[3] == 1 || + startend_row_indices.dims()[3] == 2 || + startend_row_indices.dims()[3] == 4, + true, + common::errors::InvalidArgument( + "flashmask_attention startend_row_indices " + "mask_bounds must in [1,2,4]")); + + auto flashmask_maxmin_shape = startend_row_indices.dims(); + // TODO(umiswing): refine this block constraint (kBlockN % 32), since some + // of kBlockN is not divisible by 32 flashmask_maxmin_shape[2] = 
+ // (flashmask_maxmin_shape[2] + 31) / 32 * 8; + + if (is_sm90) { + // seqlen_k to nblock_seqlen, here we use kBlockN = 64 + // as a conservative estimation (reduce allocation size) + flashmask_maxmin_shape[2] = + ((flashmask_maxmin_shape[2] + 63) / 64 + 3) / 4 * 4; + // make sure this is the same with FlashMaskV3 fwd main loop + static constexpr int flashmask_buffer_length = 16 * 1024; + // estimate the upper bound of the possible chunk size + static constexpr int chunk_padded_length = + ((flashmask_buffer_length + 63) / 64 + 31) & 0xffffffe0; + static constexpr int chunk_valid_length = + ((flashmask_buffer_length + 63) / 64 + 3) & 0xfffffffc; + const int num_chunk = + (flashmask_maxmin_shape[2] + chunk_valid_length - 1) / + chunk_valid_length; + flashmask_maxmin_shape[2] = num_chunk * chunk_padded_length; + } else { + // seqlen_k to nblock_seqlen + flashmask_maxmin_shape[2] = + ((flashmask_maxmin_shape[2] + 31) / 32 + 3) / 4 * 4; + } + flashmask_maxmin_shape[3] = 8; + + flashmask_maxmin = + paddle::empty({flashmask_maxmin_shape[0], flashmask_maxmin_shape[1], + flashmask_maxmin_shape[2], flashmask_maxmin_shape[3]}, + paddle::DataType::INT32, place); + + const int32_t *mask_base_ptr = startend_row_indices.data(); + auto mask_dims = startend_row_indices.dims(); + int B_mask = mask_dims[0]; + int H_mask = mask_dims[1]; + int S_mask = mask_dims[2]; + int C = mask_dims[3]; + int total_elements = B_mask * H_mask * S_mask; + + lt_start_ptr = nullptr; + lt_end_ptr = nullptr; + ut_start_ptr = nullptr; + ut_end_ptr = nullptr; + + auto extract_channel = [&](int channel_idx) -> paddle::Tensor { + auto slice = paddle::empty({B_mask, H_mask, S_mask}, + paddle::DataType::INT32, place); + cudaMemcpy2DAsync(slice.data(), sizeof(int32_t), + mask_base_ptr + channel_idx, C * sizeof(int32_t), + sizeof(int32_t), total_elements, + cudaMemcpyDeviceToDevice, stream); + return slice; + }; + + if (C == 1) { + lt_start_ptr = mask_base_ptr; + } else if (C == 2) { + lt_start_slice = 
extract_channel(0); + lt_start_ptr = lt_start_slice.data(); + if (!is_causal) { + ut_end_slice = extract_channel(1); + ut_end_ptr = ut_end_slice.data(); + } else { + lt_end_slice = extract_channel(1); + lt_end_ptr = lt_end_slice.data(); + } + } else if (C == 4) { + lt_start_slice = extract_channel(0); + lt_start_ptr = lt_start_slice.data(); + lt_end_slice = extract_channel(1); + lt_end_ptr = lt_end_slice.data(); + ut_start_slice = extract_channel(2); + ut_start_ptr = ut_start_slice.data(); + ut_end_slice = extract_channel(3); + ut_end_ptr = ut_end_slice.data(); + } + } + + if (is_blockmask) { + PADDLE_ENFORCE_EQ( + is_flashmask, true, + common::errors::InvalidArgument( + "blockmask should be used with flashmask at the same time ")); + + PADDLE_ENFORCE_EQ(block_mask.dims().size(), 4, + common::errors::InvalidArgument( + "blockmask receive blockmask_indices with dim " + "[batch_size, num_heads, blocklen_q, blocklen_k]")); + + PADDLE_ENFORCE_EQ(block_mask.dims()[2], (seqlen_q + 127) / 128, + common::errors::InvalidArgument( + "blockmask is now only support blockdim_q = 128 ")); + + PADDLE_ENFORCE_EQ(block_mask.dims()[3], (seqlen_k + 127) / 128, + common::errors::InvalidArgument( + "blockmask is now only support blockdim_k = 128 ")); + + PADDLE_ENFORCE_EQ( + block_mask.dims()[1], startend_row_indices.dims()[1], + common::errors::InvalidArgument("blockmask is now only support same " + "dim num_heads with flashmask ")); + } + + if (is_blockmask) { + // xhy: blockmask is now only support blockdim_q k = 128 + flashmaskv3_fwd_params_set_m_block_dim(params_handle, 128); + flashmaskv3_fwd_params_set_n_block_dim(params_handle, 128); + flashmaskv3_fwd_params_set_block_mask_ptr(params_handle, + (block_mask.data())); + } + + if (is_flashmask) { + flashmaskv3_fwd_params_set_lt_start_ptr( + params_handle, const_cast(lt_start_ptr)); + flashmaskv3_fwd_params_set_lt_end_ptr(params_handle, + const_cast(lt_end_ptr)); + flashmaskv3_fwd_params_set_ut_start_ptr( + params_handle, 
const_cast(ut_start_ptr)); + flashmaskv3_fwd_params_set_ut_end_ptr(params_handle, + const_cast(ut_end_ptr)); + + if (flashmask_maxmin.initialized()) + flashmaskv3_fwd_params_set_flashmask_maxmin_ptr( + params_handle, (flashmask_maxmin.data())); + else + flashmaskv3_fwd_params_set_flashmask_maxmin_ptr(params_handle, nullptr); + + flashmaskv3_fwd_params_set_h_flashmask(params_handle, + startend_row_indices.dims()[1]); + flashmaskv3_fwd_params_set_h_h_flashmask_ratio( + params_handle, num_heads / startend_row_indices.dims()[1]); + } else { + flashmaskv3_fwd_params_set_lt_start_ptr(params_handle, nullptr); + flashmaskv3_fwd_params_set_lt_end_ptr(params_handle, nullptr); + flashmaskv3_fwd_params_set_ut_start_ptr(params_handle, nullptr); + flashmaskv3_fwd_params_set_ut_end_ptr(params_handle, nullptr); + flashmaskv3_fwd_params_set_flashmask_maxmin_ptr(params_handle, nullptr); + flashmaskv3_fwd_params_set_h_flashmask(params_handle, 0); + flashmaskv3_fwd_params_set_h_h_flashmask_ratio(params_handle, 0); + } + + if (total_q > 0 && + (total_k + flashmaskv3_fwd_params_get_total_knew(params_handle)) > 0 && + num_heads_k > 0) { + // flashmaskv3_run_mha_fwd(params_handle, dev_ctx.stream()); + flashmaskv3_run_mha_fwd(params_handle, stream); + if (flashmaskv3_fwd_params_get_num_splits(params_handle) > 1) { + if (out_type == paddle::DataType::BFLOAT16) { + // Since we want output in BF16. Otherwise fwd_combine will output to + // FP16 + flashmaskv3_fwd_params_set_is_bf16(params_handle, true); + } + // Unless there's seqused_q, for the purpose of attn_combine, we can just + // treat it as batch=1 and seqlen = total_q, and don't need to dispatch to + // Varlen there. However, with dynamic split, each row needs to know which + // batch it belongs to to read the number of splits, so we just use the + // varlen version of combine kernel. 
if (is_varlen_q && + // !seqused_q_.has_value()) { if (is_varlen_q) { + // params.b = 1; + // params.seqlen_q = total_q; + // } + // } + flashmaskv3_run_mha_fwd_combine(params_handle, stream, + true /*enable_pdl*/); + } + } else if (total_q > 0 && num_heads_k > 0) { + PADDLE_ENFORCE_EQ( + (out->dtype() == paddle::DataType::BFLOAT16 || + out->dtype() == paddle::DataType::FLOAT16 || + out->dtype() == paddle::DataType::FLOAT8_E4M3FN), + true, + common::errors::InvalidArgument("flash attention 3 supports bfloat16, " + "float16 and float8_e4m3fn only.")); + // If seqlen_k == 0, then we have an empty tensor. We need to set the output + // to 0. + int64_t out_numel = batch_size * seqlen_q * num_heads * head_size_v; + if (out->dtype() == paddle::DataType::BFLOAT16) { + cudaMemsetAsync(out->data(), 0, out_numel * 2, stream); // bf16 = 2 bytes + } else if (out->dtype() == paddle::DataType::FLOAT16) { + cudaMemsetAsync(out->data(), 0, out_numel * 2, stream); // fp16 = 2 bytes + } else if (out->dtype() == paddle::DataType::FLOAT8_E4M3FN) { + cudaMemsetAsync(out->data(), 0, out_numel * 1, stream); // fp8 = 1 byte + } + + *softmax_lse = paddle::full({batch_size, num_heads, seqlen_q}, + std::numeric_limits::infinity(), + paddle::DataType::FLOAT32, place); + } + +#else + RaiseNotSupportedError(); +#endif +} + +#define FLASHMASK_V3_BASE_KERNEL_IMPL(DType) \ + template void FlashMaskV3BaseKernel( \ + const paddle::Tensor &q, const paddle::Tensor &k, \ + const paddle::Tensor &v, const paddle::optional &k_new_, \ + const paddle::optional &v_new_, \ + const paddle::optional &q_v_, \ + const paddle::optional &out_, \ + const paddle::optional &cu_seqlens_q_, \ + const paddle::optional &cu_seqlens_k_, \ + const paddle::optional &cu_seqlens_k_new_, \ + const paddle::optional &seqused_q_, \ + const paddle::optional &seqused_k_, \ + const paddle::optional &page_table_, \ + const paddle::optional &kv_batch_idx_, \ + const paddle::optional &leftpad_k_, \ + const paddle::optional &rotary_cos_, 
\ + const paddle::optional &rotary_sin_, \ + const paddle::optional &q_descale_, \ + const paddle::optional &k_descale_, \ + const paddle::optional &v_descale_, \ + const paddle::optional &scheduler_metadata_, \ + const paddle::optional &startend_row_indices_, \ + const paddle::optional &block_mask_, \ + const int max_seqlen_q_, const int max_seqlen_k_, \ + const float softmax_scale, bool is_causal, int window_size_left, \ + int window_size_right, const float softcap, \ + const bool is_rotary_interleaved, int num_splits, \ + const bool manual_set_pack_gqa, const bool pack_gqa_, \ + const int sm_margin, paddle::Tensor *out, paddle::Tensor *softmax_lse, \ + paddle::Tensor *out_accum, paddle::Tensor *softmax_lse_accum); + +FLASHMASK_V3_BASE_KERNEL_IMPL(paddle::float16); +FLASHMASK_V3_BASE_KERNEL_IMPL(paddle::bfloat16); diff --git a/flashmask/flash_mask/flashmask_attention_v3/cuda_check.h b/flashmask/flash_mask/flashmask_attention_v3/cuda_check.h new file mode 100644 index 00000000000..fad9aaf27ae --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/cuda_check.h @@ -0,0 +1,33 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ + +#pragma once + +#include +#include + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while(0) + +#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) diff --git a/flashmask/flash_mask/flashmask_attention_v3/cutlass b/flashmask/flash_mask/flashmask_attention_v3/cutlass new file mode 160000 index 00000000000..afa17722036 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/cutlass @@ -0,0 +1 @@ +Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 diff --git a/flashmask/flash_mask/flashmask_attention_v3/epilogue_bwd.hpp b/flashmask/flash_mask/flashmask_attention_v3/epilogue_bwd.hpp new file mode 100644 index 00000000000..032e9cfbeaa --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/epilogue_bwd.hpp @@ -0,0 +1,554 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/barrier.h" +#include "cute/tensor.hpp" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +#include "seqlen.h" +#include "named_barrier.hpp" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct CollectiveEpilogueBwd { + + using TileShape_MNK = TileShape_MNK_; + using Element = Element_; + using ArchTag = ArchTag_; + static constexpr int NumEpilogueThreads = NumEpilogueThreads_; + static constexpr bool Varlen = Varlen_; + static constexpr bool dKV_swapAB = dKV_swapAB_; + static constexpr bool Use_TMA = !Varlen && ArchTag::kMinComputeCapability >= 90; + + static_assert(ArchTag::kMinComputeCapability >= 80); + + using GmemTiledCopydKVTMA = cute::SM90_TMA_STORE; + + // These are for storing the output tensor without TMA (e.g., for setting output to zero) + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(get<2>(TileShape_MNK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads); + static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + using GmemTiledCopydKV = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + using SmemLayoutAtomdKVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), Int(TileShape_MNK{})) / AtomLayoutKdKV>>()); + using SmemLayoutdKVTMA = decltype(tile_to_shape(SmemLayoutAtomdKVTMA{}, select<1, 2>(TileShape_MNK{}))); + using SmemLayoutdKVtTMA = + 
decltype(cute::composition(SmemLayoutdKVTMA{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), + make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{})))); + + // If we don't use TMA + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : (kHeadDim % 32 == 0 ? 32 : 16); + static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1); + using SmemLayoutAtomdKVSTG = + decltype(composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + + using SmemLayoutAtomdKV = std::conditional_t; + using SmemLayoutdKV = decltype(tile_to_shape(SmemLayoutAtomdKV{}, select<1, 2>(TileShape_MNK{}))); + using SmemLayoutdKVt = + decltype(cute::composition(SmemLayoutdKV{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), + make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{})))); + + using SmemCopyAtomdKV = Copy_Atom< + std::conditional_t< + ArchTag::kMinComputeCapability >= 90, + std::conditional_t, + AutoVectorizingCopyWithAssumedAlignment<128> + >, + Element>; + + static constexpr size_t SmemAlignmentdKV = ArchTag::kMinComputeCapability >= 90 ? 
cutlass::detail::alignment_for_swizzle(SmemLayoutdKV{}) : 128; + static_assert(SmemAlignmentdKV >= 128, "Require at least 128B alignment"); + + struct TensorStorage : cute::aligned_struct { + cute::array_aligned, SmemAlignmentdKV> smem_dk; + cute::array_aligned, SmemAlignmentdKV> smem_dv; + }; + + using ShapedKV = cute::Shape; // (seqlen_k, d, head, batch) + using StridedKV = cute::Stride; + + using TMA_dKV = std::conditional_t< + Use_TMA, + decltype(make_tma_copy( + GmemTiledCopydKVTMA{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapedKV{}, StridedKV{}), + SmemLayoutdKVTMA{}, + select<1, 2>(TileShape_MNK{}), + _1{})), // no mcast for dKV + std::nullptr_t + >; + + // Host side kernel arguments + struct Arguments { + Element* ptr_dK; + ShapedKV const shape_dK; + StridedKV const stride_dK; + Element* ptr_dV; + StridedKV const stride_dV; + int const num_heads_q; + int* dk_semaphore; + int* dv_semaphore; + int const* cu_seqlens; + int const* seqused; + }; + + // Device side kernel params + struct Params { + Element* ptr_dK; + ShapedKV const shape_dK; + StridedKV const stride_dK; + Element* ptr_dV; + StridedKV const stride_dV; + TMA_dKV tma_store_dK, tma_store_dV; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK); + Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dK, args.stride_dV); + TMA_dKV tma_store_dK = [&] { + if constexpr (Use_TMA) { + return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV + } else { + return nullptr; + } + }(); + TMA_dKV tma_store_dV = [&] { + if constexpr (Use_TMA) { + return make_tma_copy(GmemTiledCopydKVTMA{}, mdV, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV + } else { + return nullptr; + } + }(); + // print("tma_dv:\n"); + // 
cute::print(tma_store_dV); + return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV, + tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + if constexpr (Use_TMA) { + cute::prefetch_tma_descriptor(params.tma_store_dK.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_store_dV.get_tma_descriptor()); + } + } + + template + CUTLASS_DEVICE void + store(Params const& params, + FrgTensorO const& tdKrdK, + FrgTensorO const& tdVrdV, + SharedStorage& shared_storage, + TiledMma tiled_mma, + int thread_idx, + cute::tuple const& block_coord + ) { + + auto [n_block, bidh, bidb] = block_coord; + Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKV{})); + Tensor sdV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKV{})); + Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKVt{})); + Tensor sdVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKVt{})); + auto smem_tiled_copy_dKV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma); + auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(thread_idx); + + Tensor tdVrdV_out = make_tensor_like(tdVrdV); + flash::convert_type_out(tdVrdV, tdVrdV_out); + Tensor tdKrdK_out = make_tensor_like(tdKrdK); + flash::convert_type_out(tdKrdK, tdKrdK_out); + Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N) + // if (blockIdx.x == 0 && 
threadIdx.x == 128) { print(smem_thr_copy_dKV); print(sdK); printf("\n"); print(sdKt); printf("\n"); } + Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(cute::conditional_return(sdK, sdKt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(cute::conditional_return(sdV, sdVt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Make sure all WGs have finished reading K and V + flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV); + cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK); + if constexpr (Use_TMA ) { + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + + Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK); + Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dK); + Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + auto block_tma_dK = params.tma_store_dK.get_slice(_0{}); + auto block_tma_dV = params.tma_store_dV.get_slice(_0{}); + Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K) + Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K) + Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K) + Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K) + int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0); + // if (blockIdx.x == 0 && threadIdx.x == 128) { + // printf("sdV:\n"); + // // cute::print(block_tma_dV); + // cute::print_tensor(sdV); + // } + if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) { + 
cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + if (cute::elect_one_sync()) { + cute::copy(params.tma_store_dV, tdVsdV, tdVgdV); + cute::copy(params.tma_store_dK, tdKsdK, tdKgdK); + tma_store_arrive(); + } + } + tma_store_wait<0>(); + // if (blockIdx.x == 0 && threadIdx.x == 128) { + // printf("gdv:\n"); + // cute::print_tensor(tdVgdV); + // // cute::print_tensor(gdV); + // } + // if (blockIdx.x == 0 && threadIdx.x == 128) { + // Tensor mdV1 = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, 0); + // printf("mdv:\n"); + // cute::print_tensor(mdV1); + // } + // // Tell warp 0 that smem_k and smem_v are ready + // cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); + + } else { + flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused}; + bool const is_varlen = Varlen && params.cu_seqlens; + Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); + Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? 
bidb : 0); + Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + + GmemTiledCopydKV gmem_tiled_copy_dKV; + auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx); + Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV); + Tensor tdKVsdV = gmem_thr_copy_dKV.partition_S(sdV); // (TMA, TMA_M, TMA_K) + Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK); + Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K) + Tensor tdKVrdV = make_fragment_like(tdKVgdV); + Tensor tdKVrdK = make_fragment_like(tdKVgdK); + Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); + Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKVgdV))); + #pragma unroll + for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } + // Need to check OOB when reading from smem if kBlockN isn't evenly tiled + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0; + flash::copy( + gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdKV, kBlockN); + flash::copy( + gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdKV, kBlockN); + // // Tell warp 0 that smem_k and smem_v are ready + // cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_k/v + // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); + // Construct identity layout for gdKV + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) + ); + flash::copy( + gmem_tiled_copy_dKV, tdKVrdK, 
tdKVgdK, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) + ); + } + } + + CUTLASS_DEVICE void + store_tail() { + // if constexpr (Use_TMA) { tma_store_wait<0>(); } + } + + // Write 0 to dK and dV + CUTLASS_DEVICE void + store_zero( + Params const& params, + int thread_idx, + cute::tuple const& block_coord + ) { + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + auto [n_block, bidh, bidb] = block_coord; + flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused}; + bool const is_varlen = Varlen && params.cu_seqlens; + Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); + Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? 
bidb : 0); + Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + + GmemTiledCopydKV gmem_tiled_copy_dKV; + auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx); + Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK); + Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV); + Tensor tdKVrdKV = make_fragment_like(tdKVgdK); + clear(tdKVrdKV); + // Construct identity layout for gdKV + Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); + Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKVgdK))); + #pragma unroll + for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN + ); + flash::copy( + gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN + ); + } + +}; + +template +struct CollectiveEpilogueBwdGQA { + + using TileShape_MNK = TileShape_MNK_; + using Element = ElementAccum; + using ArchTag = ArchTag_; + static constexpr int NumEpilogueThreads = NumEpilogueThreads_; + static constexpr bool Varlen = Varlen_; + static constexpr bool Use_TMA = ArchTag::kMinComputeCapability >= 90; + + static_assert(ArchTag::kMinComputeCapability >= 80); + + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static_assert(NumEpilogueThreads % cutlass::NumThreadsPerWarp == 0, "NumEpilogueThreads must be a multiple of NumThreadsPerWarp"); + static constexpr int NumWarpGroups = NumEpilogueThreads / cutlass::NumThreadsPerWarpGroup; + // Thread layout, 
256 or 384 threads per row + // We split into NumWarpGroups so that we can use the same postprocessing kernel as dQ + using R2SLayoutAtomdKVaccum = Layout, Int>>; + using R2STiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2SLayoutAtomdKVaccum{}, + Layout>{})); // Val layout, 4 vals per store + // For Sm80 + using R2GLayoutAtomdKVaccum = Layout>>; + using R2GTiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2GLayoutAtomdKVaccum{}, + Layout>{})); // Val layout, 1 vals per store + + using SmemLayoutdKVaccum = Layout, Int>>; + using SmemLayoutdKVaccumFlat = Layout>>; + + // Strangely without this SmemAlignment, the total smem for hdim 128 (80 x 128) is 228KB even though we + // only need 227KB. We use the same alignment as the non-GQA epilogue to avoid this issue. + static constexpr int SmemAlignment = kHeadDim % 64 == 0 ? 1024 : (kHeadDim % 32 == 0 ? 512 : 256); + struct TensorStorageTMA : cute::aligned_struct { + cute::array_aligned, SmemAlignment> smem_dkv; + }; + struct TensorStorageSTG { + cute::array smem_dkv; + }; + using TensorStorage = std::conditional_t; + + using ShapedKV = cute::Shape; // (seqlen_k_rounded * d, head, batch) + using StridedKV = cute::Stride<_1, int64_t, int64_t>; + + // Host side kernel arguments + struct Arguments { + ElementAccum* ptr_dKaccum; + ShapedKV const shape_dKaccum; + StridedKV const stride_dKaccum; + ElementAccum* ptr_dVaccum; + StridedKV const stride_dVaccum; + int num_heads_q; + int* dk_semaphore; + int* dv_semaphore; + int const* cu_seqlens; + int const* seqused; + }; + + // Device side kernel params + struct Params { + ElementAccum* ptr_dKaccum; + ShapedKV const shape_dKaccum; + StridedKV const stride_dKaccum; + ElementAccum* ptr_dVaccum; + StridedKV const stride_dVaccum; + cutlass::FastDivmod qhead_per_khead_divmod; + int* dk_semaphore; + int* dv_semaphore; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + static Params + 
to_underlying_arguments(Arguments const& args) { + if constexpr (Deterministic) { + assert(args.dk_semaphore != nullptr); + assert(args.dv_semaphore != nullptr); + } + return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.stride_dVaccum, + cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))), + args.dk_semaphore, args.dv_semaphore, + args.cu_seqlens, args.seqused}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + } + + template + CUTLASS_DEVICE void + store(Params const& params, + FrgTensorO const& tdKrdK, + FrgTensorO const& tdVrdV, + SharedStorage& shared_storage, + TiledMma tiled_mma, + int thread_idx, + cute::tuple const& block_coord + ) { + + auto [n_block, bidh, bidb] = block_coord; + int bidh_idx_in_group; + int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh); + Tensor sdKV = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccum{}); + Tensor sdKV_flat = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccumFlat{}); + static constexpr int dKV_TMA_num_bytes = CUTE_STATIC_V(size(sdKV_flat)) * sizeof(ElementAccum); + + flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused}; + bool const is_varlen = Varlen && params.cu_seqlens; + Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0); + Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dKaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? 
bidb : 0); + Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape>{}, make_coord(n_block)); // (M * K) + Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape>{}, make_coord(n_block)); // (M * K) + + R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum; + auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx); + Tensor tdKVsdKVaccum = r2s_thr_copy_dKVaccum.partition_D(sdKV); + + // Only used if !Use_TMA + R2GTiledCopydKVaccum r2g_tiled_copy_dKVaccum; + auto r2g_thr_copy_dKVaccum = r2g_tiled_copy_dKVaccum.get_thread_slice(thread_idx); + + // Make sure all WGs have finished reading K and V, otherwise we get racy dQ + // because smem_q could be changed. + flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + if constexpr (Use_TMA) { + Tensor taccdKVrdV = r2s_thr_copy_dKVaccum.retile_S(tdVrdV); // ((Atom,AtomNum), MMA_M, MMA_N) + cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum); + } + + // int const num_batch = params.num_batch; + int const num_batch = get<2>(params.shape_dKaccum); + int const num_head_kv = get<1>(params.shape_dKaccum); + int *lock_ptr = !Deterministic ? 
nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv; + using Barrier = cutlass::GenericBarrier; + + // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);} + + if constexpr (Deterministic) { + Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group); + } + // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore);} + if constexpr (Use_TMA) { + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + if (thread_idx == 0) { + SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdVaccum.data()), dKV_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); + tma_store_arrive(); + tma_store_wait<0>(); + } + } else { + Tensor tdVrdV_atomic = r2g_thr_copy_dKVaccum.retile_S(tdVrdV); + Tensor tdVgdV_atomic = r2g_thr_copy_dKVaccum.partition_D(gdVaccum); + static_assert(CUTE_STATIC_V(size(tdVrdV_atomic)) == CUTE_STATIC_V(size(tdVgdV_atomic))); + #pragma unroll + for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdV_atomic(i), tdVrdV_atomic(i)); } + } + if constexpr (Deterministic) { + Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv); + } + + if constexpr (Use_TMA) { + cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + Tensor taccdKVrdK = r2s_thr_copy_dKVaccum.retile_S(tdKrdK); // ((Atom,AtomNum), MMA_M, MMA_N) + 
cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKVaccum); + } + lock_ptr = !Deterministic ? nullptr : params.dk_semaphore + bidb * num_head_kv + bidh_kv; + // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);} + + if constexpr (Deterministic) { + Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group); + } + // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore);} + if constexpr (Use_TMA) { + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + if (thread_idx == 0) { + SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdKaccum.data()), dKV_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); + tma_store_arrive(); + tma_store_wait<0>(); + } + } else { + Tensor tdKrdK_atomic = r2g_thr_copy_dKVaccum.retile_S(tdKrdK); + Tensor tdKgdK_atomic = r2g_thr_copy_dKVaccum.partition_D(gdKaccum); + static_assert(CUTE_STATIC_V(size(tdKrdK_atomic)) == CUTE_STATIC_V(size(tdKgdK_atomic))); + #pragma unroll + for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdK_atomic(i), tdKrdK_atomic(i)); } + } + if constexpr (Deterministic) { + Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv); + } + // // Tell warp 0 that smem_k and smem_v are ready + // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); + } + + 
CUTLASS_DEVICE void + store_tail() { + } + + // Write 0 to dK and dV + CUTLASS_DEVICE void + store_zero( + Params const& params, + int thread_idx, + cute::tuple const& block_coord + ) { + // Don't need to do anything since dKaccum and dVaccum are already zero-initialized + } + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/epilogue_fwd.hpp b/flashmask/flash_mask/flashmask_attention_v3/epilogue_fwd.hpp new file mode 100644 index 00000000000..bae35354190 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/epilogue_fwd.hpp @@ -0,0 +1,498 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ + +#pragma once + +#include +#include // For FastDivMod +#include "cute/tensor.hpp" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" +#include "cutlass/epilogue/collective/builders/sm90_common.inl" + +#include "seqlen.h" +#include "named_barrier.hpp" +#include "pack_gqa.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct CollectiveEpilogueFwd { + + using TileShape_MNK_PV = TileShape_MNK_PV_; + using ClusterShape = ClusterShape_; + using Element = Element_; + using ElementPartial = float; + using ArchTag = ArchTag_; + static constexpr int NumEpilogueThreads = NumEpilogueThreads_; + static constexpr bool Varlen = Varlen_; + static constexpr bool PackGQA = PackGQA_; + static constexpr bool Split = Split_; + static constexpr bool Use_smem = !(Split && !Varlen); + static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA; + + static_assert(ArchTag::kMinComputeCapability >= 80); + static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1); + static_assert(sizeof(Element) <= 2); + + static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); + static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{}); + + static constexpr bool LargeHeadDimV = kHeadDimV > 256; + + using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; + + // These are for storing the output tensor without TMA (e.g., for setting output to zero) + static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDimV % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements + // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times + // we need to call divmod. 
+ static constexpr int kBytePerRow = kHeadDimV * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore; + // If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp + static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0); + static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow"); + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>()); + using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{}))); + static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)); + static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 
3 : 4); + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{}))); + using SmemLayoutO = std::conditional_t= 90, SmemLayoutOTMA, SmemLayoutOSTS>; + + using ShapeO = cute::Shape; // (seqlen_q, d, head, batch, num_splits) + using StrideO = cute::Stride; + using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen_q, head, batch, num_splits) + // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits) + using ShapeOPacked = std::conditional_t, int32_t, int32_t, int32_t, int32_t>>; + using StrideOPacked = std::conditional_t, _1, int64_t, int64_t, int64_t>>; + // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits) + using ShapeLSEPacked = std::conditional_t, cute::Shape, int32_t, int32_t, int32_t>>; + using StrideLSEPacked = std::conditional_t, int64_t, int64_t, int64_t>>; + + using CopyOpR2S = std::conditional_t< + ArchTag::kMinComputeCapability >= 90, + // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16) + decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator()), + AutoVectorizingCopyWithAssumedAlignment<128> + >; + using SmemCopyAtomO = Copy_Atom; + + // static constexpr size_t SmemAlignmentO = cutlass::detail::alignment_for_swizzle(SmemLayoutO{}); + // static_assert(SmemAlignmentO >= 128, "Require at least 128B alignment"); + // struct TensorStorage : cute::aligned_struct { + // cute::array_aligned : 0, SmemAlignmentO> smem_o; + // }; + struct TensorStorage : cute::aligned_struct<128> { + cute::array_aligned : 0> smem_o; + }; + + using TMA_O = std::conditional_t< + Use_TMA_O, + decltype(make_tma_copy( + GmemTiledCopyOTMA{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeO{}, StrideO{}), + SmemLayoutOTMA{}, + select<0, 1>(TileShape_MNK_PV{}), + _1{})), // no mcast for O + std::nullptr_t + >; + + // Host side kernel arguments + struct Arguments { + 
Element* ptr_O; + ShapeO const shape_O; + StrideO const stride_O; + ElementPartial* ptr_O_partial; + StrideO const stride_O_partial; + float* ptr_LSE; + StrideLSE const stride_LSE; + float* ptr_LSE_partial; + StrideLSE const stride_LSE_partial; + int32_t const nheads_kv; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + // Device side kernel params + struct Params { + Element* ptr_O; + ShapeO const shape_O; + StrideO const stride_O; + ShapeOPacked const shape_O_packed; + StrideOPacked const stride_O_packed; + ElementPartial* ptr_O_partial; + StrideO const stride_O_partial; + StrideOPacked const stride_O_partial_packed; + float* ptr_LSE; + StrideLSE const stride_LSE; + ShapeLSEPacked const shape_LSE_packed; + StrideLSEPacked const stride_LSE_packed; + float* ptr_LSE_partial; + StrideLSE const stride_LSE_partial; + StrideLSEPacked const stride_LSE_partial_packed; + cutlass::FastDivmod qhead_per_khead_divmod; + TMA_O tma_store_O; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O); + TMA_O tma_store_O = [&]{ + if constexpr (Use_TMA_O) { + return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast + } else { + return nullptr; + } + }(); + // If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits) + int const qhead_per_khead = !PackGQA ? 
1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv); + auto const shape_O_packed = cute::conditional_return( + args.shape_O, + make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O)) + ); + auto const stride_O_packed = cute::conditional_return( + args.stride_O, + make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O)) + ); + auto const stride_O_partial_packed = cute::conditional_return( + args.stride_O_partial, + make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial)) + ); + // If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits) + auto const shape_LSE_packed = cute::conditional_return( + select<0, 2, 3, 4>(args.shape_O), + make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O)) + ); + auto const stride_LSE_packed = cute::conditional_return( + args.stride_LSE, + make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE)) + ); + auto const stride_LSE_partial_packed = cute::conditional_return( + args.stride_LSE_partial, + make_stride(make_stride(get<1>(args.stride_LSE_partial), get<0>(args.stride_LSE_partial)), get<1>(args.stride_LSE_partial) * qhead_per_khead, get<2>(args.stride_LSE_partial), get<3>(args.stride_LSE_partial)) + ); + return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed, + args.ptr_O_partial, args.stride_O_partial, stride_O_partial_packed, + args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed, + args.ptr_LSE_partial, 
args.stride_LSE_partial, stride_LSE_partial_packed, + cutlass::FastDivmod(qhead_per_khead), + tma_store_O, args.cu_seqlens, args.seqused}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + if constexpr (Use_TMA_O) { + cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor()); + } + } + + template + CUTLASS_DEVICE void + store(Params const& params, + FrgTensorO& tOrO, + FrgTensorLSE const& lse, + SharedStorage& shared_storage, + TiledMma tiled_mma, + int thread_idx, + cute::tuple const& block_coord + ) { + + auto [m_block, bidh, bidb, split_idx] = block_coord; + int num_splits = get<4>(params.shape_O_packed); + if constexpr (Split && Varlen) { + uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits + int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); + num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; + split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx + } + bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1); + + Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{}); + // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO); + + static constexpr bool NeedFP8Permute = FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4); + // If we will possibly need tOrO in FP32, we'd want to permute tOrO before type conversion. + // Otherwise we can permute after conversion. 
+ if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); } + Tensor tOrO_out = make_tensor_like(tOrO); + flash::convert_type_out(tOrO, tOrO_out); + if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); } + + // Make sure all WGs have finished reading V + // Technically we don't need this if we're not using smem, but the mainloop makes the assumption that + // all epilogue threads sync at least once during the epilogue (so that we can start loading Q with + // cp.async if we need). + flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + + // Step 1: Write O from rmem -> smem + if constexpr (Use_smem) { + auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + // Tensor taccOsO = smem_thr_copy_O.partition_D(sO_pi); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + if constexpr (Use_TMA_O) { + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } else { + flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + } else { + if constexpr (ArchTag::kMinComputeCapability >= 90) { + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.pipelines.barrier_O.arrive(cta_id); + } + } + } + + flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; + bool is_varlen = Varlen && params.cu_seqlens; + int offset_o = seqlen_info.offset; + int seqlen_o = 
seqlen_info.seqlen; + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); + + // Step 2: Write LSE from rmem -> gmem + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + // (MMA,MMA_M,MMA_K) + Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + static_assert(decltype(size<0, 0>(taccOcO))::value == 2); + static_assert(decltype(size<0, 1>(taccOcO))::value == 2); + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + + using PackGQA_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; + using PackGQApartial_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, ElementPartial>; + + Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)), + params.shape_LSE_packed, + !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 
0 : split_idx); + // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); } + if (!LargeHeadDimV || warp_group_idx == 0) { + if constexpr (!PackGQA) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); + if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); } + } + } else { + PackGQA_t::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + } + } + + // Step 3: Write O from smem -> gmem + if constexpr (Use_TMA_O) { + Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx); + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) + auto block_tma_O = params.tma_store_O.get_slice(_0{}); + Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) + Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K) + int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0); + if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) { + cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + if (cute::elect_one_sync()) { + cute::copy(params.tma_store_O, tOsO, tOgO); + tma_store_arrive(); + tma_store_wait<0>(); + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.pipelines.barrier_O.arrive(cta_id); + } + } + } + } else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence + if (!is_split) { + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? 
bidb : 0, _0{}); + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) + // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast(&mO(0)) - reinterpret_cast(params.ptr_O)); } + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + // Tensor tOsO = gmem_thr_copy_O.partition_S(sO_pi); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOrO = make_fragment_like(tOsO); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + if constexpr (ArchTag::kMinComputeCapability >= 90) { + cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_v + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.pipelines.barrier_O.arrive(cta_id); + } + } + if constexpr (!PackGQA) { + // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + Tensor tOpO = make_tensor(make_shape(size<2>(tOsO))); + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM + ); + } else { + // If PackGQA, we split the work of compute O_ptr among threads in the same row + PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + } + } else { + Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, 
params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx); + Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) + // We already arrived on barrier_O earlier if !Use_smem + if constexpr (Use_smem) { + if constexpr (ArchTag::kMinComputeCapability >= 90) { + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.pipelines.barrier_O.arrive(cta_id); + } + } + } + if constexpr (!PackGQA) { + static constexpr int kGmemElemsPerStoreDirect = 2; + cute::Copy_Atom, ElementPartial> gmem_copy_direct; + // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); + Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int>{}); + Tensor tOgO = thread_mma.partition_C(gOpartial); + Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout())); + Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int>{}); + Tensor taccOcO_col = taccOcO_rowcol(_0{}, _); + #pragma unroll + for (int m = 0; m < size(taccOcO_row); ++m) { + if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) { + #pragma unroll + for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) { + if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < get<1>(params.shape_O)) { + cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k)); + } + } + } + } + } else { + PackGQApartial_t::store_O_direct(mOpartial, tOrO, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + } + } + } + } + + CUTLASS_DEVICE void + store_tail() { + // Don't need to do tma_store_wait<0>() here since we already did in @store + } + + // Write 0 to output and -inf to LSE + CUTLASS_DEVICE void + store_zero( + Params const& params, + int thread_idx, + cute::tuple const& block_coord + ) { + static constexpr int 
kBlockM = get<0>(TileShape_MNK_PV{}); + auto [m_block, bidh, bidb, split_idx] = block_coord; + int num_splits = get<4>(params.shape_O_packed); + if constexpr (Split && Varlen) { + uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits + int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); + num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; + split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx + } + bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1); + + flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; + bool const is_varlen = Varlen && params.cu_seqlens; + int offset_o = seqlen_info.offset; + int seqlen_o = seqlen_info.seqlen; + int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; + Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)), + params.shape_LSE_packed, + !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx); + Tensor gLSE = local_tile(mLSE, Shape>{}, make_coord(m_block)); + + static_assert(kBlockM <= NumEpilogueThreads); + if (thread_idx < kBlockM) { + const int row = m_block * kBlockM + thread_idx; + if constexpr (!PackGQA) { + if (row < seqlen_o) { mLSE(row) = -INFINITY; } + } else { + if (row < seqlen_o * qhead_per_khead) { + int m_idx, h_idx; + m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row); + // mLSE has shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord" + mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY; + } + } + } + + // If split, we don't have to write 0 to mOpartial if the mha_combine kernel is used, + // since it will not use the value of O if LSE is -inf. 
+ if (!is_split) { + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{}); + + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + if constexpr (!PackGQA) { + Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_fragment_like(tOgO); + cute::clear(tOrO); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM + ); + } else { + // If PackGQA, we split the work of compute O_ptr among threads in the same row + using PackGQA_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; + Tensor tOrO = make_tensor(make_shape(Shape<_1, Int>{}, size<1>(tOcO), size<2>(tOcO))); + cute::clear(tOrO); + PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + } + } + + } + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/flash.h b/flashmask/flash_mask/flashmask_attention_v3/flash.h new file mode 100644 index 00000000000..57eac059d83 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/flash.h @@ -0,0 +1,260 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +#include +#include "cuda_runtime.h" +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = int64_t; + // The QKV matrices. + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + index_t v_dim_stride; + + // The number of heads. + int h, h_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + using index_t = int64_t; + + // The O matrix (output). + void * __restrict__ o_ptr; + void * __restrict__ oaccum_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the softmax sum. 
+ void * __restrict__ softmax_lse_ptr; + void * __restrict__ softmax_lseaccum_ptr; + + // For FP8 scaling + float * __restrict__ q_descale_ptr; + float * __restrict__ k_descale_ptr; + float * __restrict__ v_descale_ptr; + index_t q_descale_batch_stride; + index_t q_descale_head_stride; + index_t k_descale_batch_stride; + index_t k_descale_head_stride; + index_t v_descale_batch_stride; + index_t v_descale_head_stride; + + // The dimensions. + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; + int total_q, total_k, total_knew; + int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q + int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim + + // The scaling factors for the kernel. + float scale_softmax; + float softcap; + + // array of length b+1 holding starting offset of each sequence. + int * __restrict__ cu_seqlens_q; + int * __restrict__ cu_seqlens_k; + int * __restrict__ cu_seqlens_knew; + int * __restrict__ leftpad_k; + + // If provided, the actual length of each q/k sequence. + int *__restrict__ seqused_q; + int *__restrict__ seqused_k; + + // The stride between rows of Oaccum. + index_t oaccum_split_stride; + index_t oaccum_batch_stride; + index_t oaccum_row_stride; + index_t oaccum_head_stride; + + // The stride between rows of LSEaccum. + index_t lseaccum_split_stride; + index_t lseaccum_batch_stride; + index_t lseaccum_head_stride; + + // The K_new and V_new matrices. + void * __restrict__ knew_ptr; + void * __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + void *__restrict__ qv_ptr; + index_t qv_batch_stride; + index_t qv_row_stride; + index_t qv_head_stride; + + // The cos and sin matrices for rotary embedding. 
+ void * __restrict__ rotary_cos_ptr; + void * __restrict__ rotary_sin_ptr; + + // The indices to index into the KV cache. + int * __restrict__ kv_batch_idx; + + // Paged KV cache + int * __restrict__ page_table = nullptr; + index_t page_table_batch_stride; + int page_size; + int num_pages; + bool pagedkv_tma; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + + // Local window size + int window_size_left, window_size_right; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). + uint64_t * rng_state; + + bool is_bf16; + bool is_fp32; + bool is_e4m3; + bool is_causal; + bool is_local; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version + bool pack_gqa; + + int * __restrict__ tile_count_semaphore; + // int * __restrict__ num_m_blocks_ptr; + // int * __restrict__ num_n_blocks_ptr; + int * __restrict__ num_splits_dynamic_ptr; + bool skip_scheduler_metadata_computation; + + int arch; + int num_sm; + + // FlashMask + int h_flashmask; + int h_h_flashmask_ratio; + + int32_t * __restrict__ lt_start_ptr = nullptr; + int32_t * __restrict__ lt_end_ptr = nullptr; + + int32_t * __restrict__ ut_start_ptr = nullptr; + int32_t * __restrict__ ut_end_ptr = nullptr; + + int32_t * __restrict__ flashmask_maxmin_ptr = nullptr; + + int32_t * __restrict__ lt_start_nblockmax = nullptr; + int32_t * __restrict__ lt_start_nblockmin = nullptr; + + int32_t * __restrict__ lt_end_nblockmax = nullptr; + int32_t * __restrict__ lt_end_nblockmin = nullptr; + + int32_t * __restrict__ ut_start_nblockmax = nullptr; + int32_t * __restrict__ ut_start_nblockmin = nullptr; + + int32_t * __restrict__ ut_end_nblockmax = nullptr; + int32_t * __restrict__ ut_end_nblockmin = nullptr; + + int m_block_dim,n_block_dim; + int32_t * __restrict__ block_mask_ptr = nullptr; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_bwd_params : public Flash_fwd_params { + using index_t = int64_t; + + // The dO and dQKV matrices. + void *__restrict__ do_ptr; + void *__restrict__ dq_ptr; + void *__restrict__ dk_ptr; + void *__restrict__ dv_ptr; + + // To accumulate dQ + void *__restrict__ dq_accum_ptr; + void *__restrict__ dk_accum_ptr; + void *__restrict__ dv_accum_ptr; + + // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q + // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ + // dv_accum_ptr; + + // The stride between rows of the dO, dQ, dK and dV matrices. + index_t do_batch_stride; + index_t do_row_stride; + index_t do_head_stride; + index_t dq_batch_stride; + index_t dk_batch_stride; + index_t dv_batch_stride; + index_t dq_row_stride; + index_t dk_row_stride; + index_t dv_row_stride; + index_t dq_head_stride; + index_t dk_head_stride; + index_t dv_head_stride; + + // The pointer to the softmax d sum. 
+ void *__restrict__ dsoftmax_sum; + void *__restrict__ softmax_lse_log2_ptr; + + int *__restrict__ dq_semaphore; + int *__restrict__ dk_semaphore; + int *__restrict__ dv_semaphore; + + bool deterministic; + index_t dq_accum_split_stride; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl); +template +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); +template +void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); +void prepare_preemptive_scheduler(Flash_fwd_params ¶ms, cudaStream_t stream, int num_sm, bool is_dual_pptx = false); +void prepare_preemptive_scheduler(Flash_bwd_params ¶ms, cudaStream_t stream, int num_sm); diff --git a/flashmask/flash_mask/flashmask_attention_v3/flash_api.cu b/flashmask/flash_mask/flashmask_attention_v3/flash_api.cu new file mode 100644 index 00000000000..9937222b55e --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/flash_api.cu @@ -0,0 +1,607 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers. +#include + +#include "flash.h" +#include "static_switch.h" +#include "tile_size.h" +#include "heuristics.h" +#include "cuda_check.h" + +#define PADDLE_CHECK(__cond, message) \ + do { \ + const bool __cond_var = (__cond); \ + if (!__cond_var) { \ + ::std::string __err_msg = ::std::string("`") + \ + #__cond + "` check failed at " + \ + __FILE__ + ":" + \ + ::std::to_string(__LINE__) + \ + message; \ + throw std::runtime_error(__err_msg); \ + } \ + } while (0) + +#define CHECK_DEVICE(x) PADDLE_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) PADDLE_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) PADDLE_CHECK(x.is_contiguous(), #x " must be contiguous") + +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + // HEADDIM_SWITCH(params.d, [&] { + // run_mha_fwd_(params, stream); + // }); + PADDLE_CHECK(params.num_splits >= 1, "num_splits should >= 1"); + ARCH_SWITCH(params.arch, Arch, [&] { + SPLIT_SWITCH(params.num_splits > 1, Split, [&] { + PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] { + PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] { + // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation + static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split; + SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] { + if (!params.is_e4m3) { + if (params.is_bf16) { + #ifndef FLASHMASK_V3_DISABLE_HDIM64 + if (params.d <= 64) { + if (params.dv > 64 && Arch == 90) { + return run_mha_fwd_(params, stream); + } + else { + return run_mha_fwd_(params, stream); + } + } + #endif + #ifndef 
FLASHMASK_V3_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM192 + if (params.d <= 192) { + if (params.dv <= 128 && Arch == 90) { + return run_mha_fwd_(params, stream); + } else { + return run_mha_fwd_(params, stream); + } + } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_(params, stream); } + #endif + } else { + #ifndef FLASHMASK_V3_DISABLE_FP16 + #ifndef FLASHMASK_V3_DISABLE_HDIM64 + if (params.d <= 64) { + if (params.dv > 64 && Arch == 90) { + return run_mha_fwd_(params, stream); + } + else { + return run_mha_fwd_(params, stream); + } + } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM192 + if (params.d <= 192) { + if (params.dv <= 128 && Arch == 90) { + return run_mha_fwd_(params, stream); + } else { + return run_mha_fwd_(params, stream); + } + } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_(params, stream); } + #endif + #else + PADDLE_CHECK(false, "This flash attention build does not support FP16."); + #endif + } + } else { + #ifndef FLASHMASK_V3_DISABLE_FP8 + #ifndef FLASHMASK_V3_DISABLE_HDIM64 + if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, 
Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM192 + if (params.d <= 192) { + if (params.dv <= 128 && Arch == 90) { + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); + } else { + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); + } + } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #else + #endif + } + }); + }); + }); + }); + }); +} + +void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl=false) { + #ifndef FLASHMASK_V3_DISABLE_SPLIT + // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively + // so that kBlockM is smaller and we have more parallelism. + if (params.is_fp32) { + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); + } else { + run_mha_fwd_combine_(params, stream, enable_pdl); + } + } else if (params.is_bf16) { + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); + } else { + run_mha_fwd_combine_(params, stream, enable_pdl); + } + } else { + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); + } else { + run_mha_fwd_combine_(params, stream, enable_pdl); + } + } + #else + PADDLE_CHECK(false, "This flash attention build does not support combine kernels."); + #endif +} + +inline bool is_short_seqlen(Flash_fwd_params const& params) { + return params.seqlen_k < 128 && params.seqlen_q < 128; +} + +inline bool get_pagedkv_tma(Flash_fwd_params const& params) { + if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; } + // This needs to match the kernel configs + bool const short_seqlen = is_short_seqlen(params); + auto kBlockMN_kernel_args_sm90 = 
tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f, short_seqlen); + int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); + int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90); + // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower, + // at least for MLA. + return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM; +} + +inline bool get_pack_gqa(Flash_fwd_params const& params) { + // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size. + // Has little effect on speed. + if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; } + #ifdef FLASHMASK_V3_DISABLE_PACKGQA + return false; + #else + // params.page_table must already be set + if (params.h == params.h_k) { return false; } + // This needs to match the kernel configs + bool const short_seqlen = is_short_seqlen(params); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 
1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, short_seqlen); + int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); + return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); + #endif +} + +inline int get_num_splits(Flash_fwd_params const& params) { + #ifdef FLASHMASK_V3_DISABLE_SPLIT + return 1; + #else + // Always enable PackGQA for Split + // params.page_table must already be set + // This needs to match the kernel configs + bool const varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k; + bool const short_seqlen = is_short_seqlen(params); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, short_seqlen); + // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits + // has not been set here. It's OK though because we might just underestimate kBlockN a bit + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); + int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); + int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); + int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k); + // If is_local, we're not going to load all of seqlen_k + int const seqlen_k_loaded = !params.is_local + ? 
params.seqlen_k + : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM)); + int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN; + int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM; + int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2); + // Always enable PackGQA for Split + // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits. + // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending + // that batch = 1. + int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks; + return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128); + #endif +} + +void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { + #ifndef FLASHMASK_V3_DISABLE_BACKWARD + // FP16_SWITCH(!params.is_bf16, [&] { + // HEADDIM_SWITCH(params.d, [&] { + // run_mha_bwd_(params, stream); + // }); + // }); + // printf("params.d = %d",params.d); + ARCH_SWITCH(params.arch, Arch, [&] { + SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.deterministic, Deterministic, [&] { + if (!params.is_bf16) { + #ifndef FLASHMASK_V3_DISABLE_FP16 + #ifndef FLASHMASK_V3_DISABLE_HDIM64 + if (params.d <= 64) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM192 + if (params.d <= 192) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_bwd_(params, stream); } + #endif + #else + 
PADDLE_CHECK(false, "This flash attention build does not support FP16."); + #endif + } else { + #ifndef FLASHMASK_V3_DISABLE_HDIM64 + if (params.d <= 64) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM192 + if (params.d <= 192) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHMASK_V3_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_bwd_(params, stream); } + #endif + PADDLE_CHECK(false, "This flash attention build does not support "); + } + }); + }); + }); + }); + #endif +} + +#ifdef __cplusplus +extern "C" { +#endif +Flash_fwd_params* flashmaskv3_create_fwd_params_handle() { + Flash_fwd_params* params_handle = (Flash_fwd_params*)malloc(sizeof(Flash_fwd_params)); + if(params_handle) { + *params_handle = Flash_fwd_params{}; + } + return params_handle; +} + +Flash_bwd_params* flashmaskv3_create_bwd_params_handle() { + Flash_bwd_params* params_handle = (Flash_bwd_params*)malloc(sizeof(Flash_bwd_params)); + if(params_handle) { + *params_handle = Flash_bwd_params{}; + } + return params_handle; +} + +void flashmaskv3_clear_fwd_params_handle(Flash_fwd_params* params_handle) { + if(params_handle) { + *params_handle = Flash_fwd_params{}; + } +} + +void flashmaskv3_clear_bwd_params_handle(Flash_bwd_params* params_handle) { + if(params_handle) { + *params_handle = Flash_bwd_params{}; + } +} + +Flash_fwd_params* flashmaskv3_cast_to_fwd_params_handle(Flash_bwd_params* params_handle) { + return static_cast(params_handle); +} + +void flashmaskv3_destroy_fwd_params_handle(Flash_fwd_params* params_handle) { + PADDLE_CHECK(params_handle, "params_handle is nullptr"); + free(params_handle); +} + +void flashmaskv3_destroy_bwd_params_handle(Flash_bwd_params* params_handle) { + PADDLE_CHECK(params_handle, 
"params_handle is nullptr"); + free(params_handle); +} + +void flashmaskv3_run_mha_fwd_combine(Flash_fwd_params* params_handle, cudaStream_t stream, bool enable_pdl=false) { + run_mha_fwd_combine(*params_handle, stream, enable_pdl); +} + +void flashmaskv3_run_mha_fwd(Flash_fwd_params* params_handle, cudaStream_t stream) { + run_mha_fwd(*params_handle, stream); +} + +void flashmaskv3_run_mha_bwd(Flash_bwd_params* params_handle, cudaStream_t stream) { + // printf("point1\n"); + run_mha_bwd(*params_handle, stream); +} + +bool flashmaskv3_get_pagedkv_tma(Flash_fwd_params* params_handle) { + return get_pagedkv_tma(*params_handle); +} + +bool flashmaskv3_get_pack_gqa(Flash_fwd_params* params_handle) { + return get_pack_gqa(*params_handle); +} + +int flashmaskv3_get_num_splits(Flash_fwd_params* params_handle) { + return get_num_splits(*params_handle); +} + +#define DEFINE_GETTER_SETTER(type, member) \ +type flashmaskv3_fwd_params_get_##member(const Flash_fwd_params* params_handle) { return params_handle->member; } \ +void flashmaskv3_fwd_params_set_##member(Flash_fwd_params* params_handle, type value) { params_handle->member = value; } \ +type flashmaskv3_bwd_params_get_##member(const Flash_bwd_params* params_handle) { return params_handle->member; } \ +void flashmaskv3_bwd_params_set_##member(Flash_bwd_params* params_handle, type value) { params_handle->member = value; } + +// The QKV matrices. +DEFINE_GETTER_SETTER(void *, q_ptr) +DEFINE_GETTER_SETTER(void *, k_ptr) +DEFINE_GETTER_SETTER(void *, v_ptr) + +// The stride between rows of the Q, K and V matrices. 
+DEFINE_GETTER_SETTER(int64_t, q_batch_stride) +DEFINE_GETTER_SETTER(int64_t, k_batch_stride) +DEFINE_GETTER_SETTER(int64_t, v_batch_stride) +DEFINE_GETTER_SETTER(int64_t, q_row_stride) +DEFINE_GETTER_SETTER(int64_t, k_row_stride) +DEFINE_GETTER_SETTER(int64_t, v_row_stride) +DEFINE_GETTER_SETTER(int64_t, q_head_stride) +DEFINE_GETTER_SETTER(int64_t, k_head_stride) +DEFINE_GETTER_SETTER(int64_t, v_head_stride) +DEFINE_GETTER_SETTER(int64_t, v_dim_stride) + +// The number of heads. +DEFINE_GETTER_SETTER(int, h) +DEFINE_GETTER_SETTER(int, h_k) + +// The O matrix (output). +DEFINE_GETTER_SETTER(void *, o_ptr) +DEFINE_GETTER_SETTER(void *, oaccum_ptr) + +// The stride between rows of O. +DEFINE_GETTER_SETTER(int64_t, o_batch_stride) +DEFINE_GETTER_SETTER(int64_t, o_row_stride) +DEFINE_GETTER_SETTER(int64_t, o_head_stride) + +// The pointer to the softmax sum. +DEFINE_GETTER_SETTER(void*, softmax_lse_ptr) +DEFINE_GETTER_SETTER(void*, softmax_lseaccum_ptr) + +// For FP8 scaling +DEFINE_GETTER_SETTER(float *, q_descale_ptr) +DEFINE_GETTER_SETTER(float *, k_descale_ptr) +DEFINE_GETTER_SETTER(float *, v_descale_ptr) +DEFINE_GETTER_SETTER(int64_t, q_descale_batch_stride) +DEFINE_GETTER_SETTER(int64_t, q_descale_head_stride) +DEFINE_GETTER_SETTER(int64_t, k_descale_batch_stride) +DEFINE_GETTER_SETTER(int64_t, k_descale_head_stride) +DEFINE_GETTER_SETTER(int64_t, v_descale_batch_stride) +DEFINE_GETTER_SETTER(int64_t, v_descale_head_stride) + +// The dimensions. 
+DEFINE_GETTER_SETTER(int, b) +DEFINE_GETTER_SETTER(int, seqlen_q) +DEFINE_GETTER_SETTER(int, seqlen_k) +DEFINE_GETTER_SETTER(int, seqlen_knew) +DEFINE_GETTER_SETTER(int, d) +DEFINE_GETTER_SETTER(int, seqlen_q_rounded) +DEFINE_GETTER_SETTER(int, seqlen_k_rounded) +DEFINE_GETTER_SETTER(int, d_rounded) +DEFINE_GETTER_SETTER(int, rotary_dim) +DEFINE_GETTER_SETTER(int, total_q) +DEFINE_GETTER_SETTER(int, total_k) +DEFINE_GETTER_SETTER(int, total_knew) +DEFINE_GETTER_SETTER(int, b_k) +DEFINE_GETTER_SETTER(int, dv) +DEFINE_GETTER_SETTER(int, dv_rounded) + +// The scaling factors for the kernel. +DEFINE_GETTER_SETTER(float, scale_softmax) +DEFINE_GETTER_SETTER(float, softcap) + +// array of length b+1 holding starting offset of each sequence. +DEFINE_GETTER_SETTER(int *, cu_seqlens_q) +DEFINE_GETTER_SETTER(int *, cu_seqlens_k) +DEFINE_GETTER_SETTER(int *, cu_seqlens_knew) +DEFINE_GETTER_SETTER(int *, leftpad_k) + +// If provided, the actual length of each q/k sequence. +DEFINE_GETTER_SETTER(int *, seqused_q) +DEFINE_GETTER_SETTER(int *, seqused_k) + +// The stride between rows of Oaccum. +DEFINE_GETTER_SETTER(int64_t, oaccum_split_stride) +DEFINE_GETTER_SETTER(int64_t, oaccum_batch_stride) +DEFINE_GETTER_SETTER(int64_t, oaccum_row_stride) +DEFINE_GETTER_SETTER(int64_t, oaccum_head_stride) + +// The stride between rows of LSEaccum. +DEFINE_GETTER_SETTER(int64_t, lseaccum_split_stride) +DEFINE_GETTER_SETTER(int64_t, lseaccum_batch_stride) +DEFINE_GETTER_SETTER(int64_t, lseaccum_head_stride) + +// The K_new and V_new matrices. +DEFINE_GETTER_SETTER(void *, knew_ptr) +DEFINE_GETTER_SETTER(void *, vnew_ptr) + +// The stride between rows of the Q, K and V matrices. 
+DEFINE_GETTER_SETTER(int64_t, knew_batch_stride) +DEFINE_GETTER_SETTER(int64_t, vnew_batch_stride) +DEFINE_GETTER_SETTER(int64_t, knew_row_stride) +DEFINE_GETTER_SETTER(int64_t, vnew_row_stride) +DEFINE_GETTER_SETTER(int64_t, knew_head_stride) +DEFINE_GETTER_SETTER(int64_t, vnew_head_stride) + +DEFINE_GETTER_SETTER(void *, qv_ptr) +DEFINE_GETTER_SETTER(int64_t, qv_batch_stride) +DEFINE_GETTER_SETTER(int64_t, qv_row_stride) +DEFINE_GETTER_SETTER(int64_t, qv_head_stride) + +// The cos and sin matrices for rotary embedding. +DEFINE_GETTER_SETTER(void *, rotary_cos_ptr) +DEFINE_GETTER_SETTER(void *, rotary_sin_ptr) + +// The indices to index into the KV cache. +DEFINE_GETTER_SETTER(int *, kv_batch_idx) + +// Paged KV cache +DEFINE_GETTER_SETTER(int *, page_table) +DEFINE_GETTER_SETTER(int64_t, page_table_batch_stride) +DEFINE_GETTER_SETTER(int, page_size) +DEFINE_GETTER_SETTER(int, num_pages) +DEFINE_GETTER_SETTER(bool, pagedkv_tma) + +// The dropout probability (probability of keeping an activation). +DEFINE_GETTER_SETTER(float, p_dropout) +// uint32_t p_dropout_in_uint; +// uint16_t p_dropout_in_uint16_t; +DEFINE_GETTER_SETTER(uint8_t, p_dropout_in_uint8_t) + +// Scale factor of 1 / (1 - p_dropout). +DEFINE_GETTER_SETTER(float, rp_dropout) + +// Local window size +DEFINE_GETTER_SETTER(int, window_size_left) +DEFINE_GETTER_SETTER(int, window_size_right) + +// Pointer to the RNG seed (idx 0) and offset (idx 1). 
+DEFINE_GETTER_SETTER(uint64_t *, rng_state) + +DEFINE_GETTER_SETTER(bool, is_bf16) +DEFINE_GETTER_SETTER(bool, is_fp32) +DEFINE_GETTER_SETTER(bool, is_e4m3) +DEFINE_GETTER_SETTER(bool, is_causal) +DEFINE_GETTER_SETTER(bool, is_local) + +DEFINE_GETTER_SETTER(bool, is_rotary_interleaved) + +DEFINE_GETTER_SETTER(int, num_splits) // For split-KV version +DEFINE_GETTER_SETTER(bool, pack_gqa) + +DEFINE_GETTER_SETTER(int *, tile_count_semaphore) +// int * __restrict__ num_m_blocks_ptr; +// int * __restrict__ num_n_blocks_ptr; +DEFINE_GETTER_SETTER(int *, num_splits_dynamic_ptr) +DEFINE_GETTER_SETTER(bool, skip_scheduler_metadata_computation) + +DEFINE_GETTER_SETTER(int, arch) +DEFINE_GETTER_SETTER(int, num_sm) + +DEFINE_GETTER_SETTER(int, h_flashmask) +DEFINE_GETTER_SETTER(int, h_h_flashmask_ratio) + +DEFINE_GETTER_SETTER(int32_t *, lt_start_ptr) +DEFINE_GETTER_SETTER(int32_t *, lt_end_ptr) + +DEFINE_GETTER_SETTER(int32_t *, ut_start_ptr) +DEFINE_GETTER_SETTER(int32_t *, ut_end_ptr) + +DEFINE_GETTER_SETTER(int32_t *, flashmask_maxmin_ptr) + +DEFINE_GETTER_SETTER(int32_t *, lt_start_nblockmax) +DEFINE_GETTER_SETTER(int32_t *, lt_start_nblockmin) + +DEFINE_GETTER_SETTER(int32_t *, lt_end_nblockmax) +DEFINE_GETTER_SETTER(int32_t *, lt_end_nblockmin) + +DEFINE_GETTER_SETTER(int32_t *, ut_start_nblockmax) +DEFINE_GETTER_SETTER(int32_t *, ut_start_nblockmin) + +DEFINE_GETTER_SETTER(int32_t *, ut_end_nblockmax) +DEFINE_GETTER_SETTER(int32_t *, ut_end_nblockmin) + +DEFINE_GETTER_SETTER(int, m_block_dim) +DEFINE_GETTER_SETTER(int, n_block_dim) +DEFINE_GETTER_SETTER(int32_t *, block_mask_ptr) + +#define DEFINE_BWD_GETTER_SETTER(type, member) \ +type flashmaskv3_bwd_params_get_##member(const Flash_bwd_params* params_handle) { return params_handle->member; } \ +void flashmaskv3_bwd_params_set_##member(Flash_bwd_params* params_handle, type value) { params_handle->member = value; } + +// The dO and dQKV matrices. 
+// C-ABI accessor definitions for members that exist only on Flash_bwd_params.
+// Each DEFINE_BWD_GETTER_SETTER(type, member) invocation (macro defined above)
+// expands to a flashmaskv3_bwd_params_get_<member> / flashmaskv3_bwd_params_set_<member>
+// pair operating on an opaque Flash_bwd_params* handle.
+DEFINE_BWD_GETTER_SETTER(void *, do_ptr)
+DEFINE_BWD_GETTER_SETTER(void *, dq_ptr)
+DEFINE_BWD_GETTER_SETTER(void *, dk_ptr)
+DEFINE_BWD_GETTER_SETTER(void *, dv_ptr)
+
+// To accumulate dQ
+DEFINE_BWD_GETTER_SETTER(void *, dq_accum_ptr)
+DEFINE_BWD_GETTER_SETTER(void *, dk_accum_ptr)
+DEFINE_BWD_GETTER_SETTER(void *, dv_accum_ptr)
+
+// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
+// dimension void *__restrict__ dk_accum_ptr; void *__restrict__
+// dv_accum_ptr;
+
+// The stride between rows of the dO, dQ, dK and dV matrices.
+DEFINE_BWD_GETTER_SETTER(int64_t, do_batch_stride)
+DEFINE_BWD_GETTER_SETTER(int64_t, do_row_stride)
+DEFINE_BWD_GETTER_SETTER(int64_t, do_head_stride)
+DEFINE_BWD_GETTER_SETTER(int64_t, dq_batch_stride)
+DEFINE_BWD_GETTER_SETTER(int64_t, dk_batch_stride)
+DEFINE_BWD_GETTER_SETTER(int64_t, dv_batch_stride)
+DEFINE_BWD_GETTER_SETTER(int64_t, dq_row_stride)
+DEFINE_BWD_GETTER_SETTER(int64_t, dk_row_stride)
+DEFINE_BWD_GETTER_SETTER(int64_t, dv_row_stride)
+DEFINE_BWD_GETTER_SETTER(int64_t, dq_head_stride)
+DEFINE_BWD_GETTER_SETTER(int64_t, dk_head_stride)
+DEFINE_BWD_GETTER_SETTER(int64_t, dv_head_stride)
+
+// The pointer to the softmax d sum.
+DEFINE_BWD_GETTER_SETTER(void *, dsoftmax_sum) +DEFINE_BWD_GETTER_SETTER(void *, softmax_lse_log2_ptr) + +DEFINE_BWD_GETTER_SETTER(int *, dq_semaphore) +DEFINE_BWD_GETTER_SETTER(int *, dk_semaphore) +DEFINE_BWD_GETTER_SETTER(int *, dv_semaphore) + +DEFINE_BWD_GETTER_SETTER(bool, deterministic) +DEFINE_BWD_GETTER_SETTER(int64_t, dq_accum_split_stride) + +#ifdef __cplusplus +} +#endif diff --git a/flashmask/flash_mask/flashmask_attention_v3/flash_api.h b/flashmask/flash_mask/flashmask_attention_v3/flash_api.h new file mode 100644 index 00000000000..b76f5b24286 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/flash_api.h @@ -0,0 +1,291 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ + +#pragma once + +#include "cuda.h" +#include "cuda_runtime.h" +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct Flash_fwd_params FlashMask_fwd_params; +typedef struct Flash_bwd_params FlashMask_bwd_params; + +Flash_fwd_params* flashmaskv3_create_fwd_params_handle(); +Flash_bwd_params* flashmaskv3_create_bwd_params_handle(); +void flashmaskv3_clear_fwd_params_handle(Flash_fwd_params* params_handle); +void flashmaskv3_clear_bwd_params_handle(Flash_bwd_params* params_handle); +Flash_fwd_params* flashmaskv3_cast_to_fwd_params_handle(Flash_bwd_params* params_handle); +void flashmaskv3_destroy_fwd_params_handle(Flash_fwd_params* params_handle); +void flashmaskv3_destroy_bwd_params_handle(Flash_bwd_params* params_handle); +void flashmaskv3_run_mha_fwd_combine(Flash_fwd_params* params_handle, cudaStream_t stream, bool enable_pdl=false); +void flashmaskv3_run_mha_fwd(Flash_fwd_params* params_handle, cudaStream_t stream); +bool flashmaskv3_get_pagedkv_tma(Flash_fwd_params* params_handle); +bool flashmaskv3_get_pack_gqa(Flash_fwd_params* params_handle); +int flashmaskv3_get_num_splits(Flash_fwd_params* params_handle); +void flashmaskv3_run_mha_bwd(Flash_bwd_params* params_handle, cudaStream_t stream); + +#define DECLARE_GETTER_SETTER(type, member) \ +type flashmaskv3_fwd_params_get_##member(const Flash_fwd_params* params_handle); \ +void flashmaskv3_fwd_params_set_##member(Flash_fwd_params* params_handle, const type value); \ +type flashmaskv3_bwd_params_get_##member(const Flash_bwd_params* params_handle); \ +void flashmaskv3_bwd_params_set_##member(Flash_bwd_params* params_handle, type value); + +// The QKV matrices. +DECLARE_GETTER_SETTER(void *, q_ptr) +DECLARE_GETTER_SETTER(void *, k_ptr) +DECLARE_GETTER_SETTER(void *, v_ptr) + +// The stride between rows of the Q, K and V matrices. 
+DECLARE_GETTER_SETTER(int64_t, q_batch_stride) +DECLARE_GETTER_SETTER(int64_t, k_batch_stride) +DECLARE_GETTER_SETTER(int64_t, v_batch_stride) +DECLARE_GETTER_SETTER(int64_t, q_row_stride) +DECLARE_GETTER_SETTER(int64_t, k_row_stride) +DECLARE_GETTER_SETTER(int64_t, v_row_stride) +DECLARE_GETTER_SETTER(int64_t, q_head_stride) +DECLARE_GETTER_SETTER(int64_t, k_head_stride) +DECLARE_GETTER_SETTER(int64_t, v_head_stride) +DECLARE_GETTER_SETTER(int64_t, v_dim_stride) + +// The number of heads. +DECLARE_GETTER_SETTER(int, h) +DECLARE_GETTER_SETTER(int, h_k) + +// The O matrix (output). +DECLARE_GETTER_SETTER(void *, o_ptr) +DECLARE_GETTER_SETTER(void *, oaccum_ptr) + +// The stride between rows of O. +DECLARE_GETTER_SETTER(int64_t, o_batch_stride) +DECLARE_GETTER_SETTER(int64_t, o_row_stride) +DECLARE_GETTER_SETTER(int64_t, o_head_stride) + +// The pointer to the softmax sum. +DECLARE_GETTER_SETTER(void *, softmax_lse_ptr) +DECLARE_GETTER_SETTER(void *, softmax_lseaccum_ptr) + +// For FP8 scaling +DECLARE_GETTER_SETTER(float *, q_descale_ptr) +DECLARE_GETTER_SETTER(float *, k_descale_ptr) +DECLARE_GETTER_SETTER(float *, v_descale_ptr) +DECLARE_GETTER_SETTER(int64_t, q_descale_batch_stride) +DECLARE_GETTER_SETTER(int64_t, q_descale_head_stride) +DECLARE_GETTER_SETTER(int64_t, k_descale_batch_stride) +DECLARE_GETTER_SETTER(int64_t, k_descale_head_stride) +DECLARE_GETTER_SETTER(int64_t, v_descale_batch_stride) +DECLARE_GETTER_SETTER(int64_t, v_descale_head_stride) + +// The dimensions. 
+DECLARE_GETTER_SETTER(int, b) +DECLARE_GETTER_SETTER(int, seqlen_q) +DECLARE_GETTER_SETTER(int, seqlen_k) +DECLARE_GETTER_SETTER(int, seqlen_knew) +DECLARE_GETTER_SETTER(int, d) +DECLARE_GETTER_SETTER(int, seqlen_q_rounded) +DECLARE_GETTER_SETTER(int, seqlen_k_rounded) +DECLARE_GETTER_SETTER(int, d_rounded) +DECLARE_GETTER_SETTER(int, rotary_dim) +DECLARE_GETTER_SETTER(int, total_q) +DECLARE_GETTER_SETTER(int, total_k) +DECLARE_GETTER_SETTER(int, total_knew) +DECLARE_GETTER_SETTER(int, b_k) +DECLARE_GETTER_SETTER(int, dv) +DECLARE_GETTER_SETTER(int, dv_rounded) + +// The scaling factors for the kernel. +DECLARE_GETTER_SETTER(float, scale_softmax) +DECLARE_GETTER_SETTER(float, softcap) + +// array of length b+1 holding starting offset of each sequence. +DECLARE_GETTER_SETTER(int *, cu_seqlens_q) +DECLARE_GETTER_SETTER(int *, cu_seqlens_k) +DECLARE_GETTER_SETTER(int *, cu_seqlens_knew) +DECLARE_GETTER_SETTER(int *, leftpad_k) + +// If provided, the actual length of each q/k sequence. +DECLARE_GETTER_SETTER(int *, seqused_q) +DECLARE_GETTER_SETTER(int *, seqused_k) + +// The stride between rows of Oaccum. +DECLARE_GETTER_SETTER(int64_t, oaccum_split_stride) +DECLARE_GETTER_SETTER(int64_t, oaccum_batch_stride) +DECLARE_GETTER_SETTER(int64_t, oaccum_row_stride) +DECLARE_GETTER_SETTER(int64_t, oaccum_head_stride) + +// The stride between rows of LSEaccum. +DECLARE_GETTER_SETTER(int64_t, lseaccum_split_stride) +DECLARE_GETTER_SETTER(int64_t, lseaccum_batch_stride) +DECLARE_GETTER_SETTER(int64_t, lseaccum_head_stride) + +// The K_new and V_new matrices. +DECLARE_GETTER_SETTER(void *, knew_ptr) +DECLARE_GETTER_SETTER(void *, vnew_ptr) + +// The stride between rows of the Q, K and V matrices. 
+DECLARE_GETTER_SETTER(int64_t, knew_batch_stride) +DECLARE_GETTER_SETTER(int64_t, vnew_batch_stride) +DECLARE_GETTER_SETTER(int64_t, knew_row_stride) +DECLARE_GETTER_SETTER(int64_t, vnew_row_stride) +DECLARE_GETTER_SETTER(int64_t, knew_head_stride) +DECLARE_GETTER_SETTER(int64_t, vnew_head_stride) + +DECLARE_GETTER_SETTER(void *, qv_ptr) +DECLARE_GETTER_SETTER(int64_t, qv_batch_stride) +DECLARE_GETTER_SETTER(int64_t, qv_row_stride) +DECLARE_GETTER_SETTER(int64_t, qv_head_stride) + +// The cos and sin matrices for rotary embedding. +DECLARE_GETTER_SETTER(void *, rotary_cos_ptr) +DECLARE_GETTER_SETTER(void *, rotary_sin_ptr) + +// The indices to index into the KV cache. +DECLARE_GETTER_SETTER(int *, kv_batch_idx) + +// Paged KV cache +DECLARE_GETTER_SETTER(int *, page_table) +DECLARE_GETTER_SETTER(int64_t, page_table_batch_stride) +DECLARE_GETTER_SETTER(int, page_size) +DECLARE_GETTER_SETTER(int, num_pages) +DECLARE_GETTER_SETTER(bool, pagedkv_tma) + +// The dropout probability (probability of keeping an activation). +DECLARE_GETTER_SETTER(float, p_dropout) +// uint32_t p_dropout_in_uint; +// uint16_t p_dropout_in_uint16_t; +DECLARE_GETTER_SETTER(uint8_t, p_dropout_in_uint8_t) + +// Scale factor of 1 / (1 - p_dropout). +DECLARE_GETTER_SETTER(float, rp_dropout) + +// Local window size +DECLARE_GETTER_SETTER(int, window_size_left) +DECLARE_GETTER_SETTER(int, window_size_right) + +// Pointer to the RNG seed (idx 0) and offset (idx 1). 
+DECLARE_GETTER_SETTER(uint64_t *, rng_state) + +DECLARE_GETTER_SETTER(bool, is_bf16) +DECLARE_GETTER_SETTER(bool, is_fp32) +DECLARE_GETTER_SETTER(bool, is_e4m3) +DECLARE_GETTER_SETTER(bool, is_causal) +DECLARE_GETTER_SETTER(bool, is_local) + +DECLARE_GETTER_SETTER(bool, is_rotary_interleaved) + +DECLARE_GETTER_SETTER(int, num_splits) // For split-KV version +DECLARE_GETTER_SETTER(bool, pack_gqa) + +DECLARE_GETTER_SETTER(int, num_splits) // For split-KV version +DECLARE_GETTER_SETTER(bool, pack_gqa) + +DECLARE_GETTER_SETTER(int *, tile_count_semaphore) +// int * __restrict__ num_m_blocks_ptr; +// int * __restrict__ num_n_blocks_ptr; +DECLARE_GETTER_SETTER(int *, num_splits_dynamic_ptr) +DECLARE_GETTER_SETTER(bool, skip_scheduler_metadata_computation) + +DECLARE_GETTER_SETTER(int, arch) +DECLARE_GETTER_SETTER(int, num_sm) + +DECLARE_GETTER_SETTER(int, h_flashmask) +DECLARE_GETTER_SETTER(int, h_h_flashmask_ratio) + +DECLARE_GETTER_SETTER(int32_t *, lt_start_ptr) +DECLARE_GETTER_SETTER(int32_t *, lt_end_ptr) + +DECLARE_GETTER_SETTER(int32_t *, ut_start_ptr) +DECLARE_GETTER_SETTER(int32_t *, ut_end_ptr) + +DECLARE_GETTER_SETTER(int32_t *, flashmask_maxmin_ptr) + +DECLARE_GETTER_SETTER(int, m_block_dim) +DECLARE_GETTER_SETTER(int, n_block_dim) +DECLARE_GETTER_SETTER(int32_t *, block_mask_ptr) + +#define DECLARE_BWD_GETTER_SETTER(type, member) \ +type flashmaskv3_bwd_params_get_##member(const Flash_bwd_params* params_handle); \ +void flashmaskv3_bwd_params_set_##member(Flash_bwd_params* params_handle, type value); + +// The dO and dQKV matrices. 
+DECLARE_BWD_GETTER_SETTER(void *, do_ptr) +DECLARE_BWD_GETTER_SETTER(void *, dq_ptr) +DECLARE_BWD_GETTER_SETTER(void *, dk_ptr) +DECLARE_BWD_GETTER_SETTER(void *, dv_ptr) + +// To accumulate dQ +DECLARE_BWD_GETTER_SETTER(void *, dq_accum_ptr) +DECLARE_BWD_GETTER_SETTER(void *, dk_accum_ptr) +DECLARE_BWD_GETTER_SETTER(void *, dv_accum_ptr) + +// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q +// dimension void *__restrict__ dk_accum_ptr; void *__restrict__ +// dv_accum_ptr; + +// The stride between rows of the dO, dQ, dK and dV matrices. +DECLARE_BWD_GETTER_SETTER(int64_t, do_batch_stride) +DECLARE_BWD_GETTER_SETTER(int64_t, do_row_stride) +DECLARE_BWD_GETTER_SETTER(int64_t, do_head_stride) +DECLARE_BWD_GETTER_SETTER(int64_t, dq_batch_stride) +DECLARE_BWD_GETTER_SETTER(int64_t, dk_batch_stride) +DECLARE_BWD_GETTER_SETTER(int64_t, dv_batch_stride) +DECLARE_BWD_GETTER_SETTER(int64_t, dq_row_stride) +DECLARE_BWD_GETTER_SETTER(int64_t, dk_row_stride) +DECLARE_BWD_GETTER_SETTER(int64_t, dv_row_stride) +DECLARE_BWD_GETTER_SETTER(int64_t, dq_head_stride) +DECLARE_BWD_GETTER_SETTER(int64_t, dk_head_stride) +DECLARE_BWD_GETTER_SETTER(int64_t, dv_head_stride) + +// The pointer to the softmax d sum. 
+DECLARE_BWD_GETTER_SETTER(void *, dsoftmax_sum) +DECLARE_BWD_GETTER_SETTER(void *, softmax_lse_log2_ptr) + +DECLARE_BWD_GETTER_SETTER(int *, dq_semaphore) +DECLARE_BWD_GETTER_SETTER(int *, dk_semaphore) +DECLARE_BWD_GETTER_SETTER(int *, dv_semaphore) + +DECLARE_BWD_GETTER_SETTER(bool, deterministic) +DECLARE_BWD_GETTER_SETTER(int64_t, dq_accum_split_stride) +DECLARE_BWD_GETTER_SETTER(int, h_flashmask) +DECLARE_BWD_GETTER_SETTER(int, h_h_flashmask_ratio) + +DECLARE_BWD_GETTER_SETTER(int32_t *, lt_start_ptr) +DECLARE_BWD_GETTER_SETTER(int32_t *, lt_end_ptr) + +DECLARE_BWD_GETTER_SETTER(int32_t *, ut_start_ptr) +DECLARE_BWD_GETTER_SETTER(int32_t *, ut_end_ptr) + +DECLARE_BWD_GETTER_SETTER(int32_t *, flashmask_maxmin_ptr) + +DECLARE_BWD_GETTER_SETTER(int32_t *, lt_start_nblockmax) +DECLARE_BWD_GETTER_SETTER(int32_t *, lt_start_nblockmin) + +DECLARE_BWD_GETTER_SETTER(int32_t *, lt_end_nblockmax) +DECLARE_BWD_GETTER_SETTER(int32_t *, lt_end_nblockmin) + +DECLARE_BWD_GETTER_SETTER(int32_t *, ut_start_nblockmax) +DECLARE_BWD_GETTER_SETTER(int32_t *, ut_start_nblockmin) + +DECLARE_BWD_GETTER_SETTER(int32_t *, ut_end_nblockmax) +DECLARE_BWD_GETTER_SETTER(int32_t *, ut_end_nblockmin) +#ifdef __cplusplus +} +#endif diff --git a/flashmask/flash_mask/flashmask_attention_v3/flash_bwd_kernel_sm80.h b/flashmask/flash_mask/flashmask_attention_v3/flash_bwd_kernel_sm80.h new file mode 100644 index 00000000000..e69734b7e23 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/flash_bwd_kernel_sm80.h @@ -0,0 +1,187 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +class FlashAttnBwdSm80 { + +public: + + // Type Aliases + static constexpr bool Is_causal = CollectiveMainloop_::Is_causal; + static constexpr bool Is_local = CollectiveMainloop_::Is_local; + static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen); + static constexpr bool Varlen = CollectiveMainloop_::Varlen; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; + using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP; + using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV; + using ArchTag = typename CollectiveMainloop::ArchTag; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 80); + + using TileScheduler = TileScheduler_; + using TileSchedulerArguments = typename flash::TileSchedulerArguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static 
constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMmaSdP{})); + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Kernel level shared memory storage + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + union { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + } tensors; + + alignas(16) typename TileScheduler::SharedStorage smem_scheduler; + + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
+ static + Params + to_underlying_arguments(Arguments const& args) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + return { + CollectiveMainloop::to_underlying_arguments(args.mainloop), + CollectiveEpilogue::to_underlying_arguments(args.epilogue), + hw_info, + TileScheduler::to_underlying_arguments(args.scheduler) + }; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; + + TileScheduler scheduler(reinterpret_cast(&shared_storage.smem_scheduler)); + // Initialize matmul objects. + TiledMmadKV tiled_mma_dKV; + + scheduler.init_consumer(); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + CUTLASS_PRAGMA_NO_UNROLL + for (auto work_tile_info = warp_idx == 0 ? 
scheduler.template get_initial_work(params.scheduler) : scheduler.template get_initial_work(params.scheduler); + work_tile_info.is_valid(params.scheduler); + work_tile_info = warp_idx == 0 ? scheduler.template get_next_work(params.scheduler, work_tile_info) : scheduler.template get_next_work(params.scheduler, work_tile_info)) { + + auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); + auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_; + cute::tuple block_coord = {n_block, bidh, bidb}; + + // dK and dV output accumulator. + Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); + Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); + bool tile_valid = mainloop.mma(params.mainloop, tdKrdK, tdVrdV, threadIdx.x, + block_coord, shared_storage); + scheduler.prefetch_next_work(params.scheduler, work_tile_info); + if (tile_valid) { + epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, + threadIdx.x, block_coord); + } else { + epilogue.store_zero(params.epilogue, threadIdx.x, block_coord); + } + } + + } + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/flash_bwd_kernel_sm90.h b/flashmask/flash_mask/flashmask_attention_v3/flash_bwd_kernel_sm90.h new file mode 100644 index 00000000000..4b76d378b88 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/flash_bwd_kernel_sm90.h @@ -0,0 +1,344 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +class FlashAttnBwdSm90 { + +public: + + // Type Aliases + static constexpr bool Is_causal = CollectiveMainloop_::Is_causal; + static constexpr bool Is_local = CollectiveMainloop_::Is_local; + static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen); + static constexpr bool Varlen = CollectiveMainloop_::Varlen; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; + using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP; + using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ClusterShape = typename CollectiveMainloop::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 90); + + using TileScheduler = TileScheduler_; + using 
TileSchedulerArguments = typename flash::TileSchedulerArguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaSdP{})) / cutlass::NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); + + /// Register requirement for Load and Math WGs + static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 2 ? 24 : 32; + static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 240 : 160; + // If you want to print from the producer warp, you'd need to increase the number of registers + // Otherwise you'll get CUDA error. + // static constexpr uint32_t LoadRegisterRequirement = 40; + // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 
232 : 152; + + // Kernel level shared memory storage + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + union { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> { + alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_KV; + alignas(16) typename CollectiveMainloop::MainloopPipeline::SharedStorage pipeline_q; + alignas(16) typename CollectiveMainloop::MainloopPipeline_dO::SharedStorage pipeline_do; + // alignas(16) typename CollectiveMainloop::MainloopPipeline_flashmask::SharedStorage pipeline_flashmask; + alignas(16) typename TileScheduler::SharedStorage smem_scheduler; + } pipelines; + + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + // static constexpr int TensorStorageSize = sizeof(SharedStorage::tensors); + // static constexpr int PipelineStorageSize = sizeof(SharedStorage::pipelines); + + // Device side arguments + struct Arguments { + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
+ static + Params + to_underlying_arguments(Arguments const& args) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + return { + CollectiveMainloop::to_underlying_arguments(args.mainloop), + CollectiveEpilogue::to_underlying_arguments(args.epilogue), + hw_info, + TileScheduler::to_underlying_arguments(args.scheduler) + }; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; + static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup; + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + using MainloopPipeline_dO = typename CollectiveMainloop::MainloopPipeline_dO; + using PipelineParams_dO = typename MainloopPipeline_dO::Params; + using PipelineState_dO = typename 
MainloopPipeline_dO::PipelineState; + // using MainloopPipeline_flashmask = typename CollectiveMainloop::MainloopPipeline_flashmask; + // using PipelineParams_flashmask = typename MainloopPipeline_flashmask::Params; + // using PipelineState_flashmask = typename MainloopPipeline_flashmask::PipelineState; + static constexpr bool Q_dO_same_stages = std::is_same_v; + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + // int32_t * flashmask_smem_ = nullptr; + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if (warp_idx == 0 && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Obtain warp index + int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + // printf("enter"); + + PipelineParams pipeline_params; + pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesQ + CollectiveMainloop::TmaTransactionBytesLSE; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + pipeline_params.role = warp_group_idx == 0 + ? MainloopPipeline::ThreadCategory::Producer + : MainloopPipeline::ThreadCategory::Consumer; + pipeline_params.is_leader = warp_group_thread_idx == 0; + pipeline_params.num_consumers = NumMmaThreads; + + if (warp_idx == 0 && lane_predicate) { + shared_storage.pipelines.barrier_KV.init(1 /*numThreads*/); + } + // We're counting on pipeline_q to call cutlass::arch::fence_barrier_init(); + MainloopPipeline pipeline_q(shared_storage.pipelines.pipeline_q, pipeline_params, ClusterShape{}); + auto role_dO = warp_group_idx == 0 + ? 
MainloopPipeline_dO::ThreadCategory::Producer + : MainloopPipeline_dO::ThreadCategory::Consumer; + PipelineParams_dO pipeline_params_dO {pipeline_params.transaction_bytes, role_dO, pipeline_params.is_leader, pipeline_params.num_consumers}; + MainloopPipeline_dO pipeline_do(shared_storage.pipelines.pipeline_do, cute::conditional_return(pipeline_params, pipeline_params_dO), ClusterShape{}); + + // PipelineParams_flashmask pipeline_params_flashmask; + // pipeline_params_flashmask.role = warp_group_idx == 0 + // ? MainloopPipeline_flashmask::ThreadCategory::Producer + // : MainloopPipeline_flashmask::ThreadCategory::Consumer; + // pipeline_params_flashmask.consumer_arv_count = NumMmaThreads + 32; //store_dp and mma + // MainloopPipeline_flashmask pipeline_flashmask(shared_storage.pipelines.pipeline_flashmask, pipeline_params_flashmask); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; + __shared__ __align__(16) int32_t flashmask_smem_[8]; + __shared__ __align__(128) int32_t flashmask_index_smem_[kBlockN * 4]; + constexpr int max_seqlen_k = 1024 * 128; + static constexpr bool Is_blockmask = CollectiveMainloop_::Is_blockmask; + int32_t* blockmask_smem_ = nullptr; + + if constexpr (Is_blockmask) { + __shared__ __align__(128) int32_t blockmask_smem_array[max_seqlen_k / 128]; + blockmask_smem_ = blockmask_smem_array; + } + + // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } else { + __syncthreads(); + } + + TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); + + if (warp_group_idx == 0) { // Producer + cutlass::arch::warpgroup_reg_dealloc(); + // if(threadIdx.x == 0){ + // int producer_num = LoadRegisterRequirement; + // printf("producer_num = %d\n", producer_num); + // } + + // TODO(heqianyue): some optimization that can be migrated + // 1. 
schedulers (using dynamic schedulers): including DualPPTX smem buffer support (some labor) + // 3. warp group 3, 4 do nothing at all? Can we get rid of them? Put the loader warps to the end + // and using only 64 threads + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x >> 5) & 0x03, 0); + PipelineState smem_pipe_write; + PipelineState_dO smem_pipe_write_do; + if (warp_idx_in_warpgroup == 0) { // pipeline can only be initialized once, otherwise we will hang with persistent tile scheduler + smem_pipe_write = cutlass::make_producer_start_state(); + smem_pipe_write_do = cutlass::make_producer_start_state(); + } + for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); + work_tile_info.is_valid(params.scheduler); + work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info) + ) { + auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); + auto [n_block, bidh, bidb] = block_coord_; + cute::tuple block_coord = {n_block, bidh, bidb}; + mainloop.load_n_block_info(flashmask_smem_, flashmask_index_smem_, blockmask_smem_, block_coord, params.mainloop); + scheduler.producer_notify(); + if (warp_idx_in_warpgroup == 0) { // Load K, V, and do TMA on Q and dO + + // auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); + // auto [n_block, bidh, bidb] = block_coord_; + // cute::tuple block_coord = {n_block, bidh, bidb}; + // mainloop.load_n_block_info(flashmask_smem_,flashmask_index_smem_,block_coord, params.mainloop); + mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write, + smem_pipe_write_do, shared_storage, block_coord, flashmask_smem_, blockmask_smem_); + mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write, smem_pipe_write_do); + } else if (warp_idx_in_warpgroup == 1) { + // auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); + // auto [n_block, bidh, bidb] = block_coord_; + // cute::tuple block_coord = {n_block, bidh, bidb}; + // 
mainloop.load_n_block_info(flashmask_smem_,flashmask_index_smem_,block_coord, params.mainloop); + mainloop.store_dq(params.mainloop, shared_storage, block_coord, flashmask_smem_, blockmask_smem_); + } + scheduler.prefetch_next_work(params.scheduler, work_tile_info); + } + // notify the consumer that there is no more work to do + scheduler.producer_notify(); + } else { // Consumer + cutlass::arch::warpgroup_reg_alloc(); + // Initialize matmul objects. + // if(cute::elect_one_sync()){ + // int mma_num = MmaRegisterRequirement; + // printf("mma_num = %d\n", mma_num); + // } + TiledMmadKV tiled_mma_dKV; + + PipelineState smem_pipe_read; + PipelineState_dO smem_pipe_read_do; + + mainloop.mma_init(); + scheduler.init_consumer(); + + int binary_work_idx = 0; + CUTLASS_PRAGMA_NO_UNROLL + for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); + work_tile_info.is_valid(params.scheduler); + work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info) + ) { + // dK and dV output accumulator. 
+ Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); + Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); + auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); + auto [n_block, bidh, bidb] = block_coord_; + cute::tuple block_coord = {n_block, bidh, bidb}; + + bool tile_valid = mainloop.mma( + params.mainloop, pipeline_q, pipeline_do, smem_pipe_read, smem_pipe_read_do, + tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, binary_work_idx, block_coord, shared_storage,flashmask_smem_, flashmask_index_smem_, blockmask_smem_); + if (tile_valid) { + binary_work_idx = 1 - binary_work_idx; + epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, + threadIdx.x - NumCopyThreads, block_coord); + } else { + epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord); + } + scheduler.consumer_notify(); + } + // if(threadIdx.x == 128) printf("consumer blockid: %d\n", blockIdx.x ); + epilogue.store_tail(); + } + + } + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/flash_bwd_launch_template.h b/flashmask/flash_mask/flashmask_attention_v3/flash_bwd_launch_template.h new file mode 100644 index 00000000000..3df881c876e --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/flash_bwd_launch_template.h @@ -0,0 +1,491 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/device_kernel.h" // For device_kernel +#include "cutlass/kernel_launch.h" // For kernel_launch +#include "cutlass/cluster_launch.hpp" // For ClusterLauncher + +#include "static_switch.h" +#include "flash.h" +#include "flash_bwd_preprocess_kernel.h" +#include "flash_bwd_postprocess_kernel.h" +#include "tile_scheduler.hpp" +#include "mainloop_bwd_sm90_tma_gmma_ws.hpp" +#include "mainloop_bwd_sm80.hpp" +#include "epilogue_bwd.hpp" +#include "flash_bwd_kernel_sm90.h" +#include "flash_bwd_kernel_sm80.h" +#include "utils.h" + +using namespace cute; + +template +void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { + // printf("point3\n"); + // flash::print_addr_value<<<1, 1,0,stream>>>(params.lt_start_ptr, 0); + static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time."); + using ElementAccum = float; + using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; + + int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * kBlockM, kBlockM); + int const total_k_padded_rounded = cute::round_up(params.total_k + params.b * kBlockN, kBlockN); + bool const is_varlen_q = params.cu_seqlens_q; + bool const is_varlen_k = params.cu_seqlens_k; + int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q; + int seqlen_k = !is_varlen_k ? params.seqlen_k : params.total_k; + int seqlen_q_rounded = !is_varlen_q ? 
params.seqlen_q_rounded : total_q_padded_rounded; + int seqlen_k_rounded = !is_varlen_k ? params.seqlen_k_rounded : total_k_padded_rounded; + int batch_q = !is_varlen_q ? params.b : 1; + int batch_k = !is_varlen_k ? params.b : 1; + // printf("params.dv_ptr:%p\n",params.dv_ptr); + // printf("seqlen_q_rounded:%d\n",seqlen_q_rounded); + // printf("d_rounded:%d\n",params.d_rounded); + + using TileShape_MK = cute::Shape, Int>; + using PreprocessKernel = flash::FlashAttnBwdPreprocess; + typename PreprocessKernel::Arguments preprocess_args { + static_cast(params.o_ptr), + {seqlen_q, params.d, params.h, batch_q}, // shape_O + {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0}, // stride_O + static_cast(params.do_ptr), + {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO + static_cast(params.dsoftmax_sum), + {seqlen_q_rounded, params.h, batch_q}, // shape_dPsum + {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum + static_cast(params.softmax_lse_ptr), + {_1{}, seqlen_q, !is_varlen_q ? params.h * params.seqlen_q : 0}, // stride_LSE + static_cast(params.softmax_lse_log2_ptr), + {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2 + static_cast(params.dq_accum_ptr), + {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum + {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? 
params.d_rounded * seqlen_q_rounded * params.h : 0}, // stride_dQaccum + params.b, + params.dq_semaphore, + params.cu_seqlens_q, + params.seqused_q + }; + typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args); + int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM); + dim3 grid_m(num_m_block, params.h, params.b); + flash::flashmask_kernel_launch(grid_m, PreprocessKernel::MaxThreadsPerBlock, PreprocessKernel::SharedStorageSize, stream, preprocess_params, false /*launch_with_pdl*/); + CHECK_CUDA(cudaGetLastError()); + // flash::print_addr_value<<<1, 1,0,stream>>>(params.lt_start_ptr, 0); + // printf("point2\n"); + CHECK_CUDA_KERNEL_LAUNCH(); + CHECK_CUDA(cudaGetLastError()); + using TileShape_MNK = cute::Shape, Int, Int>; + using ClusterShape = cute::Shape<_1, Int<1>, _1>; // Currently doesn't not support cluster + // Stages_dS_or_QSm80 is Stages_dS if Sm90 and Stages if Sm80 + static constexpr int Stages = Arch >= 90 ? 2 : Stages_dS_or_QSm80; + static constexpr int Stages_dS = Arch >= 90 ? Stages_dS_or_QSm80 : 1; + using CollectiveMainloop = std::conditional_t< + Arch >= 90, + flash::CollectiveMainloopBwdSm90, + flash::CollectiveMainloopBwdSm80 + >; + using CollectiveEpilogue = std::conditional_t< + !GQA, + flash::CollectiveEpilogueBwd= 90 ? 
1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV>, + flash::CollectiveEpilogueBwdGQA + >; + using Scheduler = std::conditional_t< + Arch >= 90, + flash::BwdPreemptivePersistentTileScheduler, + flash::SingleTileScheduler + >; + using AttnKernel = std::conditional_t< + Arch >= 90, + flash::enable_sm90_or_later>, + flash::enable_sm80_to_sm89> + >; + + if constexpr (Is_flashmask) { + flash::flashmask::prepare_block_maxmin(params, stream); + } + + if constexpr (Arch >= 90) { + prepare_preemptive_scheduler(params, stream, params.num_sm); + } + + typename CollectiveMainloop::Arguments mainloop_args = [&] () { + if constexpr(Arch >= 90) + return typename CollectiveMainloop::Arguments { + static_cast(params.q_ptr), + {seqlen_q, params.d, params.h, batch_q}, // shape_Q + {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q + static_cast(params.k_ptr), + {seqlen_k, params.d, params.h_k, batch_k}, // shape_K + {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K + static_cast(params.v_ptr), + {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0}, // stride_V + static_cast(params.do_ptr), + {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO + static_cast(params.dq_accum_ptr), + {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum + {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum + static_cast(params.softmax_lse_log2_ptr), + {seqlen_q_rounded, params.h, batch_q}, // shape_LSE + {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2 + static_cast(params.dsoftmax_sum), + {_1{}, seqlen_q_rounded, !is_varlen_q ? 
params.h * params.seqlen_q_rounded : 0}, // stride_dPsum + params.scale_softmax, + params.window_size_left, params.window_size_right, + params.softcap, + params.b, + params.dq_semaphore, + params.cu_seqlens_q, params.cu_seqlens_k, + params.seqused_q, params.seqused_k, + params.h_flashmask, params.h_h_flashmask_ratio, + params.lt_start_ptr, params.lt_end_ptr, + params.ut_start_ptr, params.ut_end_ptr, + params.flashmask_maxmin_ptr, + params.lt_start_nblockmax, params.lt_start_nblockmin, + params.lt_end_nblockmax, params.lt_end_nblockmin, + params.ut_start_nblockmax, params.ut_start_nblockmin, + params.ut_end_nblockmax, params.ut_end_nblockmin, + params.m_block_dim, params.n_block_dim, + params.block_mask_ptr + }; + else + return typename CollectiveMainloop::Arguments { + static_cast(params.q_ptr), + {seqlen_q, params.d, params.h, batch_q}, // shape_Q + {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q + static_cast(params.k_ptr), + {seqlen_k, params.d, params.h_k, batch_k}, // shape_K + {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K + static_cast(params.v_ptr), + {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0}, // stride_V + static_cast(params.do_ptr), + {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO + static_cast(params.dq_accum_ptr), + {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum + {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum + static_cast(params.softmax_lse_log2_ptr), + {seqlen_q_rounded, params.h, batch_q}, // shape_LSE + {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2 + static_cast(params.dsoftmax_sum), + {_1{}, seqlen_q_rounded, !is_varlen_q ? 
params.h * params.seqlen_q_rounded : 0}, // stride_dPsum + params.scale_softmax, + params.window_size_left, params.window_size_right, + params.softcap, + params.b, + params.dq_semaphore, + params.cu_seqlens_q, params.cu_seqlens_k, + params.seqused_q, params.seqused_k}; + }(); + // The case work with GQA is ugly but idk how to fix it. + typename CollectiveEpilogue::Arguments epilogue_args { + static_cast(!GQA ? params.dk_ptr : params.dk_accum_ptr), + [&] { + if constexpr (!GQA) { + return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k}; // shape_dK + } else { + return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}; // shape_dKaccum + } + }(), + [&] { + if constexpr (!GQA) { + return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ? params.dk_batch_stride : 0}; // stride_dK + } else { + return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dKaccum + } + }(), + static_cast(!GQA ? params.dv_ptr : params.dv_accum_ptr), + [&] { + if constexpr (!GQA) { + return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0}; // stride_dV + } else { + return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? 
params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum + } + }(), + params.h, + params.dk_semaphore, + params.dv_semaphore, + params.cu_seqlens_k, + params.seqused_k, + }; + + int num_blocks_n = cutlass::ceil_div(params.seqlen_k, get<1>(TileShape_MNK{})); + num_blocks_n = cutlass::round_up(num_blocks_n, size<1>(ClusterShape{})); + typename flash::TileSchedulerArguments scheduler_args { + num_blocks_n, params.h, params.b, 1 /*num_splits*/, + params.h / params.h_k, + params.seqlen_k, + params.seqlen_q, params.d, params.dv, sizeof(Element), + params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k + }; + + int device; + CHECK_CUDA(cudaGetDevice(&device)); + CHECK_CUDA(cudaGetLastError()); + typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({ + mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args + }); + + dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params); + dim3 block_dims = AttnKernel::get_block_shape(); + int smem_size = AttnKernel::SharedStorageSize; + // printf("tensor_size = %d\n",AttnKernel::TensorStorageSize); + // printf("ppl_size = %d\n",AttnKernel::PipelineStorageSize); + // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q)); + // int smem_size_do = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_do)); + // int smem_size_ds = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_ds)); + // int smem_size_dqacc = [&] { + // if constexpr (Arch >= 90) { + // return sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dqacc)); + // } else { + // return 0; + // } + // }(); + // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)); + // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)); + // int smem_size_lse = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_lse)); + // int smem_size_dpsum = 
sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dpsum)); + // printf("smem_size = %d, q = %d, k = %d, v = %d, do = %d, ds = %d, dqacc = %d, lse = %d, dpsum = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_do, smem_size_ds, smem_size_dqacc, smem_size_lse, smem_size_dpsum); + if constexpr (size(ClusterShape{}) > 1) { + void const* kernel = (void const*) flash::cutlass_flashmask_kernel; + if (smem_size >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); + cutlass::ClusterLauncher::launch( + grid_dims, cluster_dims, block_dims, smem_size, stream, kernel, kernel_params, false /*launch_with_pdl*/); + } else { + void const* kernel = (void const*) flash::cutlass_flashmask_kernel; + if (smem_size >= 48 * 1024) { + int max_smem; + CHECK_CUDA(cudaGetLastError()); + CHECK_CUDA(cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlock, device)); + // printf("smem_size = %d, max_smem = %d\n", smem_size, max_smem); + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + // printf("pass"); + } + flash::flashmask_kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, false /*launch_with_pdl*/); + } + CHECK_CUDA_KERNEL_LAUNCH(); + + using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ; + typename PostprocessKernel::Arguments postprocess_args { + static_cast(params.dq_accum_ptr), + {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum + {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? 
params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum + static_cast(params.dq_ptr), + {seqlen_q, params.d, params.h, batch_q}, // shape_dQ + {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride}, // stride_dQ + params.scale_softmax, + params.cu_seqlens_q, + params.seqused_q + }; + typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args); + int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{})); + dim3 grid_m_postprocess(num_m_block_postprocess, params.h, params.b); + int smem_size_postprocess = PostprocessKernel::SharedStorageSize; + if (smem_size_postprocess >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(flash::cutlass_flashmask_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess)); + } + flash::flashmask_kernel_launch(grid_m_postprocess, PostprocessKernel::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_params, false /*launch_with_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + + if constexpr (GQA) { + using TileShape_NK = cute::Shape, Int>; + using PostprocessKerneldKV = flash::FlashAttnBwdPostprocessConvertdQ; + typename PostprocessKerneldKV::Arguments postprocess_dK_args { + static_cast(params.dk_accum_ptr), + {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dKaccum + {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? 
params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dKaccum + static_cast(params.dk_ptr), + {seqlen_k, params.d, params.h_k, batch_k}, // shape_dK + {params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride}, // stride_dK + 1.f, + params.cu_seqlens_k, + params.seqused_k + }; + typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args); + typename PostprocessKerneldKV::Arguments postprocess_dV_args { + static_cast(params.dv_accum_ptr), + {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dVaccum + {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum + static_cast(params.dv_ptr), + {seqlen_k, params.d, params.h_k, batch_k}, // shape_dV + {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride}, // stride_dV + 1.f, + params.cu_seqlens_k, + params.seqused_k + }; + typename PostprocessKerneldKV::Params postprocess_dV_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dV_args); + int num_n_block_postprocess = cute::ceil_div(params.seqlen_k, get<0>(TileShape_NK{})); + dim3 grid_n_postprocess(num_n_block_postprocess, params.h_k, params.b); + int smem_size_postprocess = PostprocessKerneldKV::SharedStorageSize; + if (smem_size_postprocess >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(flash::cutlass_flashmask_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess)); + } + flash::flashmask_kernel_launch(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dK_params, false /*launch_with_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + flash::flashmask_kernel_launch(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dV_params, false /*launch_with_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + } + +} + +template +void 
run_mha_bwd_dispatch(Flash_bwd_params ¶ms, cudaStream_t stream) { + VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { + BOOL_SWITCH(params.h != params.h_k, GQA, [&] { + // run_flash_bwd(params, stream); + run_flash_bwd(params, stream); + }); + }); +} + + +template +void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { + // printf("point2-1\n"); + static constexpr bool Is_local = false; + static constexpr bool Is_flashmask_ = true; + BOOL_SWITCH(params.block_mask_ptr != nullptr, Is_blockmask_, [&]{ + FLASH_MASK_SWITCH(params.lt_end_ptr != nullptr, params.ut_start_ptr != nullptr, Has_lt_end, Has_ut_start, [&] { + if constexpr (Arch >= 90) { + if constexpr (Is_flashmask_ && !Is_causal) { + run_mha_bwd_dispatch(params, stream); + } else if constexpr (Is_causal && Has_softcap || Is_flashmask_) { + // register spill with 128 x 128 + run_mha_bwd_dispatch(params, stream); + } else { + // With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block. 
+ run_mha_bwd_dispatch(params, stream); + } + } else if constexpr (Arch == 86 || Arch == 89) { + run_mha_bwd_dispatch(params, stream); + // run_mha_bwd_dispatch(params, stream); + // run_mha_bwd_dispatch(params, stream); + // run_mha_bwd_dispatch(params, stream); + } else { + run_mha_bwd_dispatch(params, stream); + } + }); + }); +} + +template +void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) { + static constexpr bool Is_local = false; + static constexpr bool Is_flashmask_ = true; + FLASH_MASK_SWITCH(params.lt_end_ptr != nullptr, params.ut_start_ptr != nullptr, Has_lt_end, Has_ut_start, [&] { + if constexpr (Arch >= 90) { + run_mha_bwd_dispatch(params, stream); + } else if constexpr (Arch == 86 || Arch == 89) { + run_mha_bwd_dispatch(params, stream); + } else { + run_mha_bwd_dispatch(params, stream); + } + }); +} + +template +void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { + static constexpr bool Is_local = false; + static constexpr bool Is_flashmask_ = true; + BOOL_SWITCH(params.block_mask_ptr != nullptr, Is_blockmask_, [&]{ + FLASH_MASK_SWITCH(params.lt_end_ptr != nullptr, params.ut_start_ptr != nullptr, Has_lt_end, Has_ut_start, [&] { + if constexpr (Arch >= 90) { + if constexpr (Is_causal || Is_local || Has_softcap) { + run_mha_bwd_dispatch(params, stream); + } else { + if ((params.seqlen_q >= 1024 || params.seqlen_k >= 1024) && !(Has_lt_end && Has_ut_start)) { + run_mha_bwd_dispatch(params, stream); + } else { + run_mha_bwd_dispatch(params, stream); + } + } + } else if constexpr (Arch == 86 || Arch == 89) { + run_mha_bwd_dispatch(params, stream); + } else { + run_mha_bwd_dispatch(params, stream); + } + }); + }); +} + +template +void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) { + static constexpr bool Is_local = false; + static constexpr bool Is_flashmask_ = true; + FLASH_MASK_SWITCH(params.lt_end_ptr != nullptr, params.ut_start_ptr != nullptr, Has_lt_end, Has_ut_start, [&] { + if constexpr (Arch >= 
90) { + if (Has_lt_end && Has_ut_start) { + run_mha_bwd_dispatch(params, stream); + } else { + run_mha_bwd_dispatch(params, stream); + } + } else if constexpr (Arch == 86 || Arch == 89) { + run_mha_bwd_dispatch(params, stream); + } else { + run_mha_bwd_dispatch(params, stream); + } + }); +} + +template +void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) { + static constexpr bool Is_local = false; + static constexpr bool Is_flashmask_ = true; + BOOL_SWITCH(params.block_mask_ptr != nullptr, Is_blockmask_, [&]{ + FLASH_MASK_SWITCH(params.lt_end_ptr != nullptr, params.ut_start_ptr != nullptr, Has_lt_end, Has_ut_start, [&] { + if constexpr (Arch >= 90) { + if (Has_lt_end && Has_ut_start) { + run_mha_bwd_dispatch(params, stream); + } else { + run_mha_bwd_dispatch(params, stream); + } + } else if constexpr (Arch == 86 || Arch == 89) { + run_mha_bwd_dispatch(params, stream); + } else { + run_mha_bwd_dispatch(params, stream); + } + }); + }); +} diff --git a/flashmask/flash_mask/flashmask_attention_v3/flash_bwd_postprocess_kernel.h b/flashmask/flash_mask/flashmask_attention_v3/flash_bwd_postprocess_kernel.h new file mode 100644 index 00000000000..c91e261507d --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/flash_bwd_postprocess_kernel.h @@ -0,0 +1,256 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include +#include "cutlass/arch/barrier.h" + +#include "seqlen.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +class FlashAttnBwdPostprocessConvertdQ { + +public: + + // Type Aliases + using TileShape_MK = TileShape_MK_; + using ArchTag = ArchTag_; + + static_assert(ArchTag::kMinComputeCapability >= 75); + static constexpr bool IsSm90 = ArchTag::kMinComputeCapability >= 90; + + static constexpr uint32_t MaxThreadsPerBlock = kNThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 2; + + static constexpr int kBlockM = get<0>(TileShape_MK{}); + static constexpr int kHeadDim = get<1>(TileShape_MK{}); + static_assert(!IsSm90 || kNThreads % cutlass::NumThreadsPerWarpGroup == 0, "kNThreads must be a multiple of NumThreadsPerWarpGroup"); + static constexpr int NumdQWarpGgroups = kNThreads / cutlass::NumThreadsPerWarpGroup; + using R2SLayoutAtomdQaccum = std::conditional_t< + IsSm90, + Layout, Int>>, + Layout>> + >; + using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2SLayoutAtomdQaccum{}, + Layout>>{})); // Val layout, 1 or 4 vals per read + using G2SLayoutAtomdQaccum = Layout>>; + // UniversalCopy instead of AutoVectorizingCopyWithAssumedAlignment as the latter generates cp.async instructions + using G2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, G2SLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 4 vals per read + // We don't do bound checking for the gmem -> smem load so we just assert here. + static_assert(IsSm90 || (kBlockM * kHeadDim) % (kNThreads * 4) == 0); + static constexpr int SmemdQaccumSize = size(TileShape_MK{}); + using SmemLayoutdQaccumFlat = Layout>>; + using SmemLayoutdQaccum = std::conditional_t< + IsSm90, + Layout, Int>>, + Layout>> + >; + + // We can't just use kHeadDim here. E.g. 
if MMA shape is 64 x 96 but split across 2 WGs, + // then setting kBlockKSmem to 32 will cause "Static shape_div failure". + // We want to treat it as 64 x 48, so kBlockKSmem should be 16. + static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{}); + static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16); + static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1); + using SmemLayoutAtomdQ = + decltype(composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{})); + using SmemLayoutdQt = + decltype(cute::composition(SmemLayoutdQ{}, + make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})), + make_stride(Int(TileShape_MK{})>{}, _1{})))); + + using SmemCopyAtomdQ = Copy_Atom< + std::conditional_t< + IsSm90, + std::conditional_t, + AutoVectorizingCopyWithAssumedAlignment<128> + >, + Element>; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock)); + static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per load + + struct SharedStorage : cute::aligned_struct<128> { + cute::array_aligned> smem_dqacc; + cute::array_aligned> smem_dq; + alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + using ShapedQ = cute::Shape; // (seqlen_q, d, head, batch) + using StridedQ = cute::Stride; + using ShapedQaccum = 
cute::Shape; // (seqlen_q * d, head, batch) + using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; + + // Device side arguments + struct Arguments { + ElementAccum const* ptr_dQaccum; + ShapedQaccum const shape_dQaccum; + StridedQaccum const stride_dQaccum; + Element* ptr_dQ; + ShapedQ const shape_dQ; + StridedQ const stride_dQ; + float const softmax_scale; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + // Kernel entry point API + struct Params { + ElementAccum const* ptr_dQaccum; + ShapedQaccum const shape_dQaccum; + StridedQaccum const stride_dQaccum; + Element* ptr_dQ; + ShapedQ const shape_dQ; + StridedQ const stride_dQ; + float const softmax_scale; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args) { + return { + args.ptr_dQaccum, + args.shape_dQaccum, + args.stride_dQaccum, + args.ptr_dQ, + args.shape_dQ, + args.stride_dQ, + args.softmax_scale, + args.cu_seqlens, + args.seqused + }; + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + static constexpr int kBlockM = get<0>(TileShape_MK{}); + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{}); + Tensor sdQaccum_flat = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumFlat{}); + Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{}); + Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{}); + + int const thread_idx = threadIdx.x; + int const m_block = blockIdx.x; + int const bidh = blockIdx.y; + int const bidb = blockIdx.z; + + flash::SeqlenInfo seqlen_info(bidb, size<0>(params.shape_dQ), params.cu_seqlens, params.seqused); + bool const is_varlen = params.cu_seqlens; 
+ if (is_varlen && m_block * kBlockM >= seqlen_info.seqlen) { return; } + + // Step 1: load dQaccum from gmem to smem + Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), + params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0); + Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(m_block)); // (M * K) + if constexpr (IsSm90) { // Use BulkCopy + static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v / 8); + auto bulk_copy = Copy_Traits{}; + // if (thread0()) { print(gdQaccum); printf("\n"); print(sdQaccum_flat); printf("\n"); } + if (thread_idx == 0) { + shared_storage.barrier_dQaccum.init(1 /*numThreads*/); + shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum); + copy(bulk_copy.with(*reinterpret_cast(&shared_storage.barrier_dQaccum)), gdQaccum, sdQaccum_flat); + } + __syncthreads(); + shared_storage.barrier_dQaccum.wait(0); + } else { + G2STiledCopydQaccum g2s_tiled_copy_dQaccum; + auto g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_thread_slice(thread_idx); + Tensor tdQgdQaccumg2s = g2s_thr_copy_dQaccum.partition_S(gdQaccum); + Tensor tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum); + cute::copy(g2s_tiled_copy_dQaccum, tdQgdQaccumg2s, tdQsdQaccumg2s); + __syncthreads(); + } + + // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccum); } + + // Step 2: Load dQaccum from smem to register, then convert fp32 -> fp16/bf16 + R2STiledCopydQaccum s2r_tiled_copy_dQaccum; + auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx); + Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum); + TiledMma tiled_mma_dQ; + Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select(TileShape_MK{})); + // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tiled_mma_dQ); printf("\n"); } + 
// if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tdQsdQaccum); } + // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(taccdQrdQaccum); } + CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum)); + Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum); + cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum); + #pragma unroll + for (int i = 0; i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; } + // Convert tdQrdQ from fp32 to fp16 + Tensor rdQ = make_tensor_like(taccdQrdQaccum); + flash::convert_type_out(taccdQrdQaccum, rdQ); + + // Step 3: Copy dQ from register to smem + auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ); + auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx); + Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) + // if (cute::thread0()) { print(smem_tiled_copy_dQ); } + // if (cute::thread0()) { print(smem_thr_copy_dQ); } + // if (cute::thread0()) { print(sdQ); } + Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(cute::conditional_return(sdQ, sdQt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); + __syncthreads(); + + // Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem + Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ? 
bidb : 0); + Tensor gdQ = local_tile(domain_offset(make_coord(seqlen_info.offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K) + GmemTiledCopy gmem_tiled_copy_dQ; + auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx); + Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); + + Tensor tdQrdQ = make_fragment_like(tdQsdQ); + Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cute::make_identity_tensor(TileShape_MK{})); + Tensor tdQpdQ = make_tensor(make_shape(size<2>(tdQgdQ))); + #pragma unroll + for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); } + // Need to check OOB when reading from smem if kBlockM isn't evenly tiled + static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0; + flash::copy( + gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ, tdQcdQ, tdQpdQ, kBlockM); + + // Step 5: Copy dQ from register to gmem + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, std::min(seqlen_info.seqlen - m_block * kBlockM, kBlockM) + ); + } + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/flash_bwd_preprocess_kernel.h b/flashmask/flash_mask/flashmask_attention_v3/flash_bwd_preprocess_kernel.h new file mode 100644 index 00000000000..d3dade8b5d2 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/flash_bwd_preprocess_kernel.h @@ -0,0 +1,266 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include + +#include "seqlen.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +class FlashAttnBwdPreprocess { + +public: + + // Type Aliases + using TileShape_MK = TileShape_MK_; + using ArchTag = ArchTag_; + + static_assert(std::is_same_v && ArchTag::kMinComputeCapability >= 75 || + std::is_same_v && ArchTag::kMinComputeCapability >= 80 || + std::is_same_v && ArchTag::kMinComputeCapability >= 89); + + static constexpr uint32_t MaxThreadsPerBlock = 256; + static constexpr uint32_t MinBlocksPerMultiprocessor = 2; + static constexpr int SharedStorageSize = 0; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(get<1>(TileShape_MK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kBlockM = get<0>(TileShape_MK{}); + static constexpr int kHeadDim = get<1>(TileShape_MK{}); + // We want kBlockKGmem to be a power of 2 so that when we do the summing, + // it's just between threads in the same warp + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 
64 : 32); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per load + + static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); + static_assert((kBlockM * kHeadDim / kGmemElemsPerLoadAccum) % MaxThreadsPerBlock == 0, "MaxThreadsPerBlock must divide kBlockM * kHeadDim / kGmemElemsPerLoadAccum"); + using GmemLayoutAtomAccum = Layout>>; + using GmemTiledCopyAccum = decltype( + make_tiled_copy(Copy_Atom, ElementAccum>{}, + GmemLayoutAtomAccum{}, + Layout>>{})); // Val layout, 4 vals per store + + using ShapeO = cute::Shape; // (seqlen_q, d, head, batch) + using StrideO = cute::Stride; + using ShapedPsum = cute::Shape; // (seqlen_q, head, batch) + using StridedPsum = cute::Stride<_1, int64_t, int64_t>; + using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) + using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; + + // Device side arguments + struct Arguments { + Element const* ptr_O; + ShapeO const shape_O; + StrideO const stride_O; + Element const* ptr_dO; + StrideO const stride_dO; + float* ptr_dPsum; + ShapedPsum const shape_dPsum; + StridedPsum const stride_dPsum; + float const* ptr_LSE; + StridedPsum const stride_LSE; + float *ptr_LSE_log2; + StridedPsum const stride_LSE_log2; + ElementAccum* ptr_dQaccum; + ShapedQaccum const shape_dQaccum; + StridedQaccum const stride_dQaccum; + int num_batch; // We need this to know the size of dq_semaphore in case of varlen + int* dq_semaphore; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + // Kernel entry point API + struct Params { + Element const* ptr_O; + ShapeO const shape_O; + StrideO const stride_O; 
+ Element const* ptr_dO; + StrideO const stride_dO; + float* ptr_dPsum; + ShapedPsum const shape_dPsum; + StridedPsum const stride_dPsum; + float const* ptr_LSE; + StridedPsum const stride_LSE; + float* ptr_LSE_log2; + StridedPsum const stride_LSE_log2; + ElementAccum* ptr_dQaccum; + ShapedQaccum const shape_dQaccum; + StridedQaccum const stride_dQaccum; + int num_batch; + int* dq_semaphore; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args) { + return { + args.ptr_O, + args.shape_O, + args.stride_O, + args.ptr_dO, + args.stride_dO, + args.ptr_dPsum, + args.shape_dPsum, + args.stride_dPsum, + args.ptr_LSE, + args.stride_LSE, + args.ptr_LSE_log2, + args.stride_LSE_log2, + args.ptr_dQaccum, + args.shape_dQaccum, + args.stride_dQaccum, + args.num_batch, + args.dq_semaphore, + args.cu_seqlens, + args.seqused + }; + } + + CUTLASS_DEVICE + void + operator()(Params const& params, [[maybe_unused]] char* smem_buf) { + + static constexpr int kBlockM = get<0>(TileShape_MK{}); + + int const thread_idx = threadIdx.x; + int const m_block = blockIdx.x; + int const bidh = blockIdx.y; + int const bidb = blockIdx.z; + + flash::SeqlenInfo seqlen_info(bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused); + bool const is_varlen = Varlen && params.cu_seqlens; + int const seqlen_o = seqlen_info.seqlen; + if (is_varlen && m_block * kBlockM >= seqlen_o) { return; } + + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0); + Tensor gO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K) + Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_O, params.stride_dO)(_, _, bidh, !is_varlen ? 
bidb : 0); + Tensor gdO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K) + + auto shape_LSE = select<0, 2, 3>(params.shape_O); + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !is_varlen ? bidb : 0); + Tensor gLSE = local_tile(cute::domain_offset(make_coord(seqlen_info.offset), mLSE), Shape>{}, make_coord(m_block)); + static_assert(kBlockM <= MaxThreadsPerBlock); + float lse = thread_idx < seqlen_o - m_block * kBlockM && thread_idx < kBlockM ? gLSE(thread_idx) : INFINITY; + + GmemTiledCopy gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + + Tensor tOgO = gmem_thr_copy_O.partition_S(gO); + Tensor tOgdO = gmem_thr_copy_O.partition_S(gdO); + // Construct identity layout for gO + Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } + + // (8, kBlockM / 32, kHeadDim / 64) or (8, kBlockM / 16, kHeadDim / 128) + Tensor tOrO = make_fragment_like(tOgO); + Tensor tOrdO = make_fragment_like(tOgdO); + flash::copy( + gmem_tiled_copy_O, tOgO, tOrO, tOcO, tOpO, seqlen_o - m_block * kBlockM + ); + flash::copy( + gmem_tiled_copy_O, tOgdO, tOrdO, tOcO, tOpO, seqlen_o - m_block * kBlockM + ); + // if (threadIdx.x == 222) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_o = %d, m_block = %d, seqlen_o - m_block * kBlockM = %d, tOgO addr = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_o, m_block, seqlen_o - m_block * kBlockM, &tOgO(0));} + + // Reshape from e.g. 
(8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, (8, kHeadDim / 64)) + Layout l = make_layout(get<1>(tOrO.layout()), make_layout(get<0>(tOrO.layout()), get<2>(tOrO.layout()))); + Tensor tOrO_l = make_tensor(tOrO.data(), l); + Tensor o_fp32 = make_tensor_like(tOrO_l); + flash::convert_type_out(tOrO_l, o_fp32); + Tensor tOrdO_l = make_tensor(tOrdO.data(), l); + Tensor do_fp32 = make_tensor_like(tOrdO_l); + flash::convert_type_out(tOrdO_l, do_fp32); + // Sum across the last dimension + Tensor dP_sum = make_tensor(make_shape(size<0>(o_fp32))); + #pragma unroll + for (int mi = 0; mi < size<0>(o_fp32); ++mi) { + float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0); + #pragma unroll + for (int ni = 1; ni < size<1>(o_fp32); ni++) { + dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni); + } + flash::SumOp sum_op; + dP_sum(mi) = flash::Allreduce::run(dP_sum_cur, sum_op); + } + + Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(_, bidh, !is_varlen ? bidb : 0); + Tensor gdPsum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mdPsum), Shape>{}, make_coord(m_block)); + if (get<1>(tOcO(_0{}, _0{}, _0{})) == 0) { + #pragma unroll + for (int mi = 0; mi < size(dP_sum); ++mi) { + int const row = get<0>(tOcO(_0{}, mi, _0{})); + gdPsum(row) = row < seqlen_o - m_block * kBlockM ? dP_sum(mi) : 0; + } + } + + int const seqlen_rounded = cute::round_up(seqlen_o, kBlockM); + Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(_, bidh, !is_varlen ? bidb : 0); + Tensor gLSElog2 = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mLSElog2), Shape>{}, make_coord(m_block)); + if (thread_idx < seqlen_rounded - m_block * kBlockM && thread_idx < kBlockM) { + gLSElog2(thread_idx) = lse == -INFINITY ? 
0.f : lse * float(M_LOG2E); + } + + if constexpr (Clear_dQaccum) { + Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0); + Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(m_block)); + GmemTiledCopyAccum gmem_tiled_copy_dQaccum; + auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx); + Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); + Tensor zero = make_fragment_like(tdQgdQaccum); + clear(zero); + cute::copy(Copy_Atom, ElementAccum>{}, zero, tdQgdQaccum); + } + + if (params.dq_semaphore != nullptr && thread_idx == 0) { + int const num_batch = params.num_batch; + int const num_head = get<2>(params.shape_O); + params.dq_semaphore[bidh + bidb * num_head + m_block * num_head * num_batch] = 0; + } + + } + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_combine.cu b/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_combine.cu new file mode 100644 index 00000000000..9e76539e113 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_combine.cu @@ -0,0 +1,28 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#include "flash_fwd_combine_launch_template.h" + +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); + +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); + +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); diff --git a/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_combine_kernel.h b/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_combine_kernel.h new file mode 100644 index 00000000000..c7f7a3b7ad7 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_combine_kernel.h @@ -0,0 +1,496 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include +#include + +#include "cutlass/arch/grid_dependency_control.h" + +#include "seqlen.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +class FlashAttnFwdCombine { + +public: + + // Type Aliases + using TileShape_MK = TileShape_MK_; + using ArchTag = ArchTag_; + static constexpr int kMaxSplits = 1 << kLogMaxSplits_; + static constexpr int AlignmentLSE = std::min(AlignmentLSE_, int(128 / 8 / sizeof(float))); + static_assert(AlignmentLSE >= 1); + static constexpr int kStages = 4; + + static_assert(ArchTag::kMinComputeCapability >= 75); + static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80; + + static constexpr uint32_t MaxThreadsPerBlock = kNThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 2; + + static constexpr int kBlockM = get<0>(TileShape_MK{}); + static constexpr int kBlockK = get<1>(TileShape_MK{}); + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial); + static_assert(kBlockK % kGmemElemsPerLoad == 0, "kBlockK must be a multiple of kGmemElemsPerLoad"); + static constexpr int kBlockKGmem = kBlockK % 128 == 0 ? 128 : (kBlockK % 64 == 0 ? 
64 : 32); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow"); + using GmemCopyAtom = std::conditional_t< + Has_cp_async, + cute::Copy_Atom, ElementPartial>, + cute::Copy_Atom, ElementPartial> + >; + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0); + using GmemTiledCopyAccum = decltype( + make_tiled_copy(GmemCopyAtom{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 4 vals per load + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 4 vals per load + + using AlignmentTypeLSE = cute::uint_byte_t(sizeof(float)) * AlignmentLSE>; + static constexpr int kGmemElemsPerLoadLSE = sizeof(AlignmentTypeLSE) / sizeof(float); + static_assert(kBlockM % kGmemElemsPerLoadLSE == 0, "kBlockM must be a multiple of kGmemElemsPerLoadLSE"); + static_assert(kBlockM % 8 == 0, "kBlockM must be a multiple of 8"); + static constexpr int kBlockMSmem = kBlockM % 128 == 0 ? 128 : (kBlockM % 64 == 0 ? 64 : (kBlockM % 32 == 0 ? 32 : (kBlockM % 16 == 0 ? 
16 : 8))); + static constexpr int kGmemThreadsPerRowLSE = kBlockMSmem / kGmemElemsPerLoadLSE; + static_assert(MaxThreadsPerBlock % kGmemThreadsPerRowLSE == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRowLSE"); + using GmemLayoutAtomLSE = Layout, Int>, + Stride, _1>>; + static_assert(kMaxSplits % CUTE_STATIC_V(shape<0>(GmemLayoutAtomLSE{})) == 0); + using GmemCopyAtomLSE = std::conditional_t< + Has_cp_async, + cute::Copy_Atom, float>, + cute::Copy_Atom, float> + >; + using GmemTiledCopyLSE = decltype( + make_tiled_copy(GmemCopyAtomLSE{}, + GmemLayoutAtomLSE{}, + Layout>>{})); // Val layout, 4 vals per load + + // Otherwise we get IMA when some threads access sLSE, as we're not doing any masking + static_assert((kBlockM * kMaxSplits * AlignmentLSE) % kNThreads == 0, "kNThreads must divide kBlockM * kMaxSplits * AlignmentLSE"); + // This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts + using SmemLSESwizzle = std::conditional_t< + kBlockMSmem == 8, + Swizzle<5, 0, 5>, + std::conditional_t, Swizzle<3, 2, 3>> + >; + using SmemLayoutAtomLSE = + decltype(composition(SmemLSESwizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape, Int>{})); + + using SmemLayoutO = Layout, Int, Int>, + Stride, _1, Int>>; + + // We want each column (kMaxSplits) to be processed by threads in the same warp. + // To reduce the number of shuffles, we want as few threads on the same column as possible. + // E.g., if kBlockM is divisible by 64, and there are 256 threads, we want 4 threads (0, 1, 2, 4) per column + // have have 64 such quads. 
+ static_assert(MaxThreadsPerBlock % kBlockMSmem == 0, "MaxThreadsPerBlock must be a multiple of kBlockMSmem"); + static constexpr int kSmemThreadsPerColLSEt = MaxThreadsPerBlock / kBlockMSmem; + static_assert(cutlass::NumThreadsPerWarp % kSmemThreadsPerColLSEt == 0, "kSmemThreadsPerColLSEt must divide NumThreadsPerWarp"); + using S2RLayoutAtomLSE = Layout, Int>>; + using S2RTiledCopyLSE = decltype(make_tiled_copy(cute::Copy_Atom{}, S2RLayoutAtomLSE{}, Layout<_1>{})); + + using ShapeOPartial = cute::Shape; // (seqlen, d, num_splits, head, batch) + using StrideOPartial = cute::Stride; + using ShapeLSEPartial = cute::Shape; // (seqlen, num_splits, head, batch) + using StrideLSEPartial = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen, num_splits, head, batch) + using ShapeO = cute::Shape; // (seqlen, d, head, batch) + using StrideO = cute::Stride; + using ShapeLSE = cute::Shape; // (seqlen, head, batch) + using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch) + + struct SharedStorage : cute::aligned_struct<128> { + cute::array_aligned> smem_lse_partial; + cute::array_aligned smem_max_valid_split; + cute::array_aligned> smem_o_partial; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + ElementPartial const* const ptr_O_partial; + ShapeOPartial const shape_O_partial; + StrideOPartial const stride_O_partial; + float const* const ptr_LSE_partial; + ShapeLSEPartial const shape_LSE_partial; + StrideLSEPartial const stride_LSE_partial; + Element* const ptr_O; + StrideO const stride_O; + float* const ptr_LSE; + StrideLSE const stride_LSE; + int const* const cu_seqlens = nullptr; + int const* const seqused = nullptr; + int const* const num_splits_dynamic_ptr = nullptr; + int* const semaphore_to_reset = nullptr; + }; + + // Kernel entry point API + struct Params { + ElementPartial const* const ptr_O_partial; + ShapeOPartial const shape_O_partial; + StrideOPartial const 
stride_O_partial; + float const* const ptr_LSE_partial; + ShapeLSEPartial const shape_LSE_partial; + StrideLSEPartial const stride_LSE_partial; + Element* const ptr_O; + StrideO const stride_O; + float* const ptr_LSE; + StrideLSE const stride_LSE; + cutlass::FastDivmod seqlen_divmod, head_divmod; + int const* const cu_seqlens = nullptr; + int const* const seqused = nullptr; + int const* const num_splits_dynamic_ptr = nullptr; + int* const semaphore_to_reset = nullptr; + }; + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args) { + assert(get<1>(args.shape_LSE_partial) <= kMaxSplits); + return { + args.ptr_O_partial, + args.shape_O_partial, + args.stride_O_partial, + args.ptr_LSE_partial, + args.shape_LSE_partial, + args.stride_LSE_partial, + args.ptr_O, + args.stride_O, + args.ptr_LSE, + args.stride_LSE, + cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)), + args.cu_seqlens, + args.seqused, + args.num_splits_dynamic_ptr, + args.semaphore_to_reset + }; + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{}); + Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape>{}); + Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{}); + + int const thread_idx = threadIdx.x; + int const m_block = blockIdx.x; + int const k_block = blockIdx.y; + int const batch = blockIdx.z; + int const num_splits = params.num_splits_dynamic_ptr ? 
params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial); + + if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) { + cutlass::arch::wait_on_dependent_grids(); + *params.semaphore_to_reset = 0; + } + if (num_splits <= 1) { return; } + flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused}; + int const offset = seqlen_info.offset; + int const seqlen = seqlen_info.seqlen; + int max_idx = seqlen * get<2>(params.shape_LSE_partial); + if constexpr (Varlen) { + if (m_block * kBlockM >= max_idx) { return; } + } + + cutlass::FastDivmod seqlen_divmod_dynamic(seqlen); + + // Step 1: load LSE_partial from gmem -> smem + Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)), + select<1, 0, 2, 3>(params.shape_LSE_partial), + select<1, 0, 2, 3>(params.stride_LSE_partial))(_, _, _, !Varlen ? 
batch : 0); // (num_splits, seqlen, head) + Tensor mLSEpartial_copy = cute::tiled_divide(mLSEpartial, Shape<_1, Int>{}); + GmemTiledCopyLSE gmem_tiled_copy_LSE; + auto gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_thread_slice(thread_idx); + Tensor tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE); + + // Construct identity layout for sLSE + Tensor cLSE = make_identity_tensor(make_shape(size<0>(sLSE), size<1>(sLSE))); // (NUM_SPLITS, BLK_M) -> (num_splits, blk_m) + // Repeat the partitioning with identity layouts + Tensor tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE); + + cutlass::arch::wait_on_dependent_grids(); + + #pragma unroll + for (int m = 0; m < size<2>(tLSEcLSE); ++m) { + int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m))); + int idx = m_block * kBlockM + mi; + if (idx < max_idx) { + int m_idx, bidh; + if constexpr (!Varlen) { + bidh = params.seqlen_divmod.divmod(m_idx, idx); + } else { + bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); + } + Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh); + #pragma unroll + for (int s = 0; s < size<1>(tLSEcLSE); ++s) { + int si = get<0>(tLSEcLSE(_0{}, s, _0{})); + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast(&(tLSEsLSE(_0{}, s, m))), reinterpret_cast(&(tLSEsLSE(_0{}, s, m))) / 4 % 32);} + if (si < num_splits) { + cute::copy(gmem_tiled_copy_LSE, mLSEpartial_cur_copy(_, si), tLSEsLSE(_, s, m)); + } else { + cute::fill(tLSEsLSE(_, s, m), -INFINITY); + } + } + } else { + // We don't need to zero out the rest of the LSEs, as we will not write the output to gmem + // cute::fill(tLSEsLSE(_, _, m), -INFINITY); + } + } + if constexpr (Has_cp_async) { cute::cp_async_fence(); } + + // Step 2: Load O_partial from gmem -> smem for split = 0, 1, ..., kStages - 2. + // We want these async loads to be in flight as we compute the LSE. 
+ GmemTiledCopyAccum gmem_tiled_copy_O_partial; + auto gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_thread_slice(thread_idx); + // Construct identity layout for gO + Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO); + Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)), + params.shape_O_partial, params.stride_O_partial)(_, _, _, _, !Varlen ? batch : 0); // (seqlen, d, num_splits, head) + + // Precompute these values to avoid recomputing them in the loop + Tensor tOmidx = make_tensor(make_shape(size<1>(tOcO))); + Tensor tObidh = make_tensor(make_shape(size<1>(tOcO))); + Tensor tOrOptr = make_tensor(make_shape(size<1>(tOcO))); + #pragma unroll + for (int m = 0; m < size<1>(tOcO); ++m) { + int mi = get<0>(tOcO(_0{}, m, _0{})); + int idx = m_block * kBlockM + mi; + if constexpr (!Varlen) { + tObidh(m) = params.seqlen_divmod.divmod(tOmidx(m), idx); + } else { + tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx); + } + tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m)); + if (idx >= max_idx) { + tObidh[m] = -1; + } + } + + Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); + if constexpr (!(Is_even_K)) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial) - k_block * kBlockK; } + } + + Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO); + + auto load_O_partial = [&] (int split, int stage) { + Tensor tOsOpartial_cur = tOsOpartial(_, _, _, stage); + #pragma unroll + for (int m = 0; m < size<1>(tOcO); ++m) { + if (tObidh(m) >= 0) { + Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}).layout()); + Tensor mOpartial_cur_copy = cute::tiled_divide(mOpartial_cur, Shape>{}); + #pragma unroll + for (int k = 0; k < 
size<2>(tOcO); ++k) { + int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad; + if (Is_even_K || tOpO(k)) { + cute::copy(gmem_tiled_copy_O_partial, mOpartial_cur_copy(_, k_idx, split), tOsOpartial_cur(_, m, k)); + } + } + } + } + }; + + for (int s = 0; s < kStages - 1; ++s) { + if (s < num_splits) { load_O_partial(s, s); } + if constexpr (Has_cp_async) { cute::cp_async_fence(); } + } + + // Step 3: load and transpose LSE_partial from smem -> rmem + if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait(); } + __syncthreads(); + + S2RTiledCopyLSE s2r_tiled_copy_LSE; + auto s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_thread_slice(thread_idx); + Tensor ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE); + Tensor ts2rrLSE = make_fragment_like(ts2rsLSE); + cute::copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE); + + // Step 4: compute the final LSE along the split dimension + Tensor lse_sum = make_tensor(make_shape(size<2>(ts2rrLSE))); + Tensor ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE); + // We compute the max valid split for each row to short-circuit the computation later + Tensor max_valid_split = make_tensor(make_shape(size<2>(ts2rrLSE))); + static_assert(CUTE_STATIC_V(size<0>(ts2rrLSE)) == 1); + #pragma unroll + for (int m = 0; m < size<2>(ts2rrLSE); ++m) { + float lse_max = ts2rrLSE(_0{}, _0{}, m); + #pragma unroll + for (int s = 1; s < size<1>(ts2rrLSE); ++s) { lse_max = max(lse_max, ts2rrLSE(_0{}, s, m)); } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + int max_valid_idx = -1; + #pragma unroll + for (int s = 0; s < size<1>(ts2rrLSE); ++s) { + if (ts2rrLSE(_0{}, s, m) != -INFINITY) { max_valid_idx = get<0>(ts2rcLSE(_0{}, s, _0{})); } + } + MaxOp max_int_op; + max_valid_split[m] = Allreduce::run(max_valid_idx, max_int_op); + float lse_max_cur = lse_max == -INFINITY ? 
0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum_cur = 0.f; + #pragma unroll + for (int s = 0; s < size<1>(ts2rrLSE); ++s) { + float scale = expf(ts2rrLSE(_0{}, s, m) - lse_max_cur); + lse_sum_cur += scale; + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast(&(ts2rsLSE(_0{}, s, m))), reinterpret_cast(&(ts2rsLSE(_0{}, s, m))) / 4 % 32);} + // ts2rsLSE(_0{}, m, s) = scale; + ts2rrLSE(_0{}, s, m) = scale; + } + SumOp sum_op; + lse_sum_cur = Allreduce::run(lse_sum_cur, sum_op); + lse_sum(m) = logf(lse_sum_cur) + lse_max; + float inv_sum = (lse_sum_cur == 0.f || lse_sum_cur != lse_sum_cur) ? 0.f : 1.f / lse_sum_cur; + #pragma unroll + for (int s = 0; s < size<1>(ts2rrLSE); ++s) { ts2rrLSE(_0{}, s, m) *= inv_sum; } + } + // Store the scales exp(lse - lse_logsum) back to smem + cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE); + + // Store max_valid_split to smem + #pragma unroll + for (int m = 0; m < size<2>(ts2rrLSE); ++m) { + if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to smem + int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); + if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; } + } + } + + // Step 5: store final LSE back to gmem + if (k_block == 0) { + auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial); + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE)(_, _, !Varlen ? 
batch : 0); + #pragma unroll + for (int m = 0; m < size<2>(ts2rrLSE); ++m) { + if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem + int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); + int idx = m_block * kBlockM + mi; + if (idx < max_idx) { + int m_idx, bidh; + if constexpr (!Varlen) { + bidh = params.seqlen_divmod.divmod(m_idx, idx); + } else { + bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); + } + // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m)); + mLSE(m_idx, bidh) = lse_sum(m); + } + } + } + } + + // Step 6: read O_partial from gmem -> smem -> rmem and accumulate the final O + __syncthreads(); + int thr_max_valid_split = sMaxValidSplit[get<0>(tOcO(_0{}, _0{}, _0{}))]; + #pragma unroll + for (int m = 1; m < size<1>(tOcO); ++m) { thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[get<0>(tOcO(_0{}, m, _0{}))]); } + Layout tOrOpartial_layout = gmem_thr_copy_O_partial.partition_S(make_tensor(TileShape_MK{})).layout(); + Tensor tOrOpartial = make_fragment_like(tOrOpartial_layout); + Tensor tOrO = make_fragment_like(tOrOpartial); + clear(tOrO); + int stage_load = kStages - 1, stage_compute = 0; + #pragma unroll 4 // Already tuned for speed + for (int s = 0; s <= thr_max_valid_split; ++s) { + Tensor scale = make_tensor(make_shape(size<1>(tOrOpartial))); + #pragma unroll + for (int m = 0; m < size<1>(tOrOpartial); ++m) { scale(m) = sLSE(s, get<0>(tOcO(_0{}, m, _0{}))); } + + if (s + kStages - 1 <= thr_max_valid_split) { load_O_partial(s + kStages - 1, stage_load); } + if constexpr (Has_cp_async) { cute::cp_async_fence(); } + stage_load = stage_load < kStages - 1 ? 
stage_load + 1 : 0; + if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait(); } + // We don't need __syncthreads() because each thread is just reading its own data from smem + cute::copy(Copy_Atom, ElementPartial>{}, + tOsOpartial(_, _, _, stage_compute), tOrOpartial); + stage_compute = stage_compute < kStages - 1 ? stage_compute + 1 : 0; + + #pragma unroll + for (int m = 0; m < size<1>(tOrOpartial); ++m) { + if (tObidh(m) >= 0 && scale(m) > 0.f) { + #pragma unroll + for (int k = 0; k < size<2>(tOrOpartial); ++k) { + if (Is_even_K || tOpO(k)) { + Tensor rOpartial = make_tensor_like(tOrOpartial(_, m, k)); + flash::convert_type_out(tOrOpartial(_, m, k), rOpartial); + #pragma unroll + for (int i = 0; i < size<0>(tOrOpartial); ++i) { + tOrO(i, m, k) += scale(m) * rOpartial[i]; + } + } + } + } + } + } + + // Step 7: Write the final O to gmem + Tensor rO = make_tensor_like(tOrO); + flash::convert_type_out(tOrO, rO); + auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial)); + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)), + shape_O, params.stride_O)(_, _, _, !Varlen ? 
batch : 0); + Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int>{}); + GmemTiledCopy gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + + #pragma unroll + for (int m = 0; m < size<1>(tOcO); ++m) { + if (tObidh(m) >= 0) { + #pragma unroll + for (int k = 0; k < size<2>(tOcO); ++k) { + int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad; + if (Is_even_K || tOpO(k)) { + cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m))); + } + } + } + } + + } + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_combine_launch_template.h b/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_combine_launch_template.h new file mode 100644 index 00000000000..07f78beacf5 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_combine_launch_template.h @@ -0,0 +1,94 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/arch/arch.h" // For cutlass::arch::Sm80 +#include "cutlass/device_kernel.h" // For device_kernel +#include "cutlass/kernel_launch.h" // For kernel_launch + +#include "static_switch.h" +#include "flash.h" +#include "flash_fwd_combine_kernel.h" + +using namespace cute; + +template +void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl) { + using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; + using TileShape_MK = cute::Shape, Int>; + using CombineKernel = flash::FlashAttnFwdCombine; + + typename CombineKernel::Arguments args { + static_cast(params.oaccum_ptr), + {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial + {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0}, // stride_O_partial + static_cast(params.softmax_lseaccum_ptr), + {!Varlen ? params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_LSE_partial + {_1{}, params.lseaccum_split_stride, params.lseaccum_head_stride, !Varlen ? params.lseaccum_batch_stride : 0}, // stride_LSE_partial + static_cast(params.o_ptr), + {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O + static_cast(params.softmax_lse_ptr), + {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? 
params.h * params.seqlen_q : 0}, // stride_LSE + params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore + }; + + typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); + int num_blocks_k = cute::ceil_div(params.dv, kBlockK); + int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h, kBlockM); + dim3 grid_m(num_blocks_m, num_blocks_k, params.b); + auto kernel = cutlass::device_kernel; + int smem_size = CombineKernel::SharedStorageSize; + if (smem_size >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // kernel<<>>(kernel_params); + cutlass::kernel_launch(grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream, kernel_params, Arch >= 90 && enable_pdl /*launch_with_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +template +void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl) { + // We want kBlockM to be as small as possible to maximize parallelism. + // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). + static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32"); + static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32); + ARCH_SWITCH(params.arch, Arch, [&] { + BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] { + if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32. 
+ if (params.num_splits <= 16) { + run_flash_fwd_combine(params, stream, enable_pdl); + return; + } + } + if (params.num_splits <= 32) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else if (params.num_splits <= 64) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else if (params.num_splits <= 128) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else { + run_flash_fwd_combine(params, stream, enable_pdl); + } + }); + }); +} diff --git a/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_kernel_sm80.h b/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_kernel_sm80.h new file mode 100644 index 00000000000..59e6252aabf --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_kernel_sm80.h @@ -0,0 +1,228 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include + +#include "seqlen.h" +#include "utils.h" +#include "softmax.h" + +namespace flash { + +using namespace cute; + +template +class FlashAttnFwdSm80 { + +public: + + // Type Aliases + using CollectiveMainloop = CollectiveMainloop_; + using CollectiveEpilogue = CollectiveEpilogue_; + static constexpr bool Is_causal = CollectiveMainloop::Is_causal; + static constexpr bool Is_local = CollectiveMainloop::Is_local; + static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen); + static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap; + static constexpr bool Varlen = CollectiveMainloop::Varlen; + static constexpr bool PagedKV = CollectiveMainloop::PagedKV; + static constexpr bool Split = CollectiveMainloop::Split; + static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8; + static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V; + static constexpr bool AppendKV = CollectiveMainloop::AppendKV; + static constexpr bool PackGQA = CollectiveMainloop::PackGQA; + static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads; + using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t; + + // Mainloop derived types + using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + // Epilogue derived types + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 80); + + using TileScheduler = TileScheduler_; + using TileSchedulerArguments = typename flash::TileSchedulerArguments; + 
using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMma{})); + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})); + static constexpr uint32_t MinBlocksPerMultiprocessor = NumThreads == 128 ? 2 : 1; + + // Kernel level shared memory storage + // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v + smem_k and not smem_q + // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v) + sizeof(smem_k). + static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) + - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v))) + - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k))); + static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_; + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + union { + struct { + cute::array padding_; + typename CollectiveMainloop::TensorStorage mainloop; + }; + // We want smem_o to line up with the start of smem_v + typename CollectiveEpilogue::TensorStorage epilogue; + }; + } tensors; + + alignas(16) typename TileScheduler::SharedStorage smem_scheduler; + + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
+ static + Params + to_underlying_arguments(Arguments const& args) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + return { + CollectiveMainloop::to_underlying_arguments(args.mainloop), + CollectiveEpilogue::to_underlying_arguments(args.epilogue), + hw_info, + TileScheduler::to_underlying_arguments(args.scheduler) + }; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count * MinBlocksPerMultiprocessor); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; + + TileScheduler scheduler(reinterpret_cast(&shared_storage.smem_scheduler)); + // Initialize matmul objects. + TiledMma tiled_mma; + + scheduler.init_consumer(); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + CUTLASS_PRAGMA_NO_UNROLL + for (auto work_tile_info = warp_idx == 0 ? 
scheduler.template get_initial_work(params.scheduler) : scheduler.template get_initial_work(params.scheduler); + work_tile_info.is_valid(params.scheduler); + work_tile_info = warp_idx == 0 ? scheduler.template get_next_work(params.scheduler, work_tile_info) : scheduler.template get_next_work(params.scheduler, work_tile_info)) { + // Attention output (GEMM-II) accumulator. + Tensor tOrO = partition_fragment_C(tiled_mma, select<0, 2>(TileShape_MNK{})); + float softmax_scale_log2 = params.mainloop.softmax_scale_log2; + // If there's tanh softcap, the scaling will be done before tanh. + auto block_coord = work_tile_info.get_block_coord(params.scheduler); + int const bidb = get<2>(block_coord); + if constexpr (Is_FP8 && !Has_softcap) { + int const bidh = get<1>(block_coord); + int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh; + float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)]; + float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)]; + softmax_scale_log2 *= q_descale * k_descale; + } + flash::Softmax<2 * (2 * kBlockM / NumThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2); + + SeqlenInfo_t seqlen_info{ + bidb, + get<0>(params.mainloop.shape_Q), + !PagedKV ? 
size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), + get<0>(params.mainloop.shape_K_new), + params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, + params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + }; + if constexpr (AppendKV) { + bool tile_new_valid = mainloop.store_kv_new( + params.mainloop, threadIdx.x, shared_storage, seqlen_info, block_coord); + if (tile_new_valid) { __syncthreads(); } + } + bool tile_valid = mainloop.mma( + params.mainloop, tOrO, softmax, threadIdx.x, seqlen_info, block_coord, + shared_storage); + scheduler.prefetch_next_work(params.scheduler, work_tile_info); + if (tile_valid) { + // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } + epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma, + threadIdx.x, block_coord); + } else { + // Write 0 to gO and -inf to gLSE. + epilogue.store_zero(params.epilogue, threadIdx.x, block_coord); + } + } + + } + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_kernel_sm90.h b/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_kernel_sm90.h new file mode 100644 index 00000000000..ef88554325f --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_kernel_sm90.h @@ -0,0 +1,573 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "cutlass/arch/grid_dependency_control.h" + +#include "seqlen.h" +#include "utils.h" +#include "softmax.h" + +namespace flash { + +using namespace cute; + +template +class FlashAttnFwdSm90 { + +public: + + // Type Aliases + using CollectiveMainloop = CollectiveMainloop_; + using CollectiveEpilogue = CollectiveEpilogue_; + static constexpr bool Is_causal = CollectiveMainloop::Is_causal; + static constexpr bool Is_local = CollectiveMainloop::Is_local; + static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen); + static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap; + static constexpr bool Varlen = CollectiveMainloop::Varlen; + static constexpr bool Split = CollectiveMainloop::Split; + static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8; + static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V; + static constexpr bool AppendKV = CollectiveMainloop::AppendKV; + static constexpr bool HasQv = CollectiveMainloop::HasQv; + static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q; + static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV; + static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O; + static constexpr bool PackGQA = CollectiveMainloop::PackGQA; + static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads; + 
static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim; + static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV; + static constexpr bool Is_flashmask = CollectiveMainloop::Is_flashmask; + static constexpr bool Use_Sch_Pipeline = TileScheduler_::pipelining; + static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV); + using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t; + + // Mainloop derived types + using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV; + using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; + using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ClusterShape = typename CollectiveMainloop::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + using BarrierQ = std::conditional_t; + + // Epilogue derived types + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 90); + + using TileScheduler = TileScheduler_; + using TileSchedulerArguments = typename flash::TileSchedulerArguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumGenerateWarpGroups = 1; + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); + static_assert(Use_TMA_KV); + + /// Register requirement for Load and Math WGs + // If we use cp.async to load K and V, we 
need more registers for the producer WG. + // static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32); + // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160); + + static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? 24 : 32); + static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? 240 : 160); + + // If you want to print from the producer warp, you'd need to increase the number of registers + // Otherwise you'll get CUDA error. + // static constexpr uint32_t LoadRegisterRequirement = 40; + // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152; + + // Kernel level shared memory storage + // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v + // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v). + static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v))); + static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 
0 : mainloop_smem_padding_; + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _1> { + union { + struct { + cute::array padding_; + typename CollectiveMainloop::TensorStorage mainloop; + }; + // We want smem_o to line up with the start of smem_v + typename CollectiveEpilogue::TensorStorage epilogue; + }; + } tensors; + struct PipelineStorage : cute::aligned_struct<16, _1> { + alignas(16) BarrierQ barrier_Q; + alignas(16) BarrierQ barrier_Qv; + alignas(16) cutlass::arch::ClusterBarrier barrier_O; + alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k; + alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v; + alignas(16) typename CollectiveMainloop::MainloopPipelineVt::SharedStorage pipeline_vt; + alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_k_new; + alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_v_new; + alignas(16) typename CollectiveMainloop::MainloopPipelineNBlock::SharedStorage pipeline_n_block; + alignas(16) typename CollectiveMainloop::MainloopPipelineFlashMaskApply::SharedStorage pipeline_flashmask_apply; + // Use_Sch_Pipeline: 2, otherwise: 1 + alignas(16) typename TileScheduler::SharedStorage smem_scheduler[Use_Sch_Pipeline ? 2 : 1]; + } pipelines; + + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
+ static + Params + to_underlying_arguments(Arguments const& args) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + return { + CollectiveMainloop::to_underlying_arguments(args.mainloop), + CollectiveEpilogue::to_underlying_arguments(args.epilogue), + hw_info, + TileScheduler::to_underlying_arguments(args.scheduler) + }; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; + static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup; + static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + + using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK; + using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV; + using MainloopPipelineVt = typename CollectiveMainloop::MainloopPipelineVt; + using MainloopPipelineKVNew = typename CollectiveMainloop::MainloopPipelineKVNew; + using MainloopPipelineNBlock = typename 
CollectiveMainloop::MainloopPipelineNBlock; + using MainloopPipelineFlashMaskApply = typename CollectiveMainloop::MainloopPipelineFlashMaskApply; + using PipelineState = typename CollectiveMainloop::PipelineState; + using PipelineParamsK = typename MainloopPipelineK::Params; + using PipelineParamsV = typename MainloopPipelineV::Params; + using PipelineParamsVt = typename MainloopPipelineVt::Params; + using PipelineParamsKVNew = typename MainloopPipelineKVNew::Params; + using PipelineParamsNBlock = typename MainloopPipelineNBlock::Params; + using PipelineParamsFlashMaskApply = typename MainloopPipelineFlashMaskApply::Params; + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + static constexpr int num_sch_stage = Use_Sch_Pipeline ? 2 : 1; + __shared__ int32_t flashmask_smem_[4 * kBlockN * CollectiveMainloop::kStages]; + __shared__ __align__(128) int32_t flashmask_maxmin_smem[num_sch_stage * 8 * CollectiveMainloop::Flashmask_n_block_buffer_length * CollectiveMainloop::kNBlockStages]; + __shared__ int32_t n_block_smem[num_sch_stage * CollectiveMainloop::Flashmask_n_block_buffer_length * CollectiveMainloop::kNBlockStages]; + __shared__ __align__(128) int32_t blockmask_smem_[CollectiveMainloop::Blockmask_n_block_buffer_valid_length * CollectiveMainloop::kNBlockStages]; + // When n_block_smem is full, we need to store the flag in the following extra flag storage, instead of allocating 4 more elements + __shared__ int32_t extra_flags[4]; // if num_sch_stage is 1, we actually only need two (kNBlockStages = 2) + + bool Is_blockmask = params.mainloop.block_mask_ptr != nullptr; + + if constexpr (Use_Sch_Pipeline) { + if (threadIdx.x < 2) { + shared_storage.pipelines.smem_scheduler[threadIdx.x] = -1; + } + } + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if (warp_idx == 0 && lane_predicate) { + 
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Obtain warp index + int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + + if (warp_idx == 0 && lane_predicate) { + shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/); + } + shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/); + } + + PipelineParamsNBlock pipeline_params_n_block; + pipeline_params_n_block.role = warp_group_idx == 0 && warp_idx_in_warpgroup != 0 + ? MainloopPipelineNBlock::ThreadCategory::Producer + : MainloopPipelineNBlock::ThreadCategory::Consumer; + pipeline_params_n_block.consumer_arv_count = (!LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup) + NumProducerThreads; // TODO(umiswing): how to deal with LargeHeadDimV? 
+ pipeline_params_n_block.producer_arv_count = cutlass::NumThreadsPerWarpGroup - NumProducerThreads; + + MainloopPipelineNBlock pipeline_n_block(shared_storage.pipelines.pipeline_n_block, pipeline_params_n_block); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; + + // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } else { + __syncthreads(); + } + + TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); + + if (warp_group_idx == 0 && warp_idx_in_warpgroup != 0) { // n_block generator + cutlass::arch::warpgroup_reg_dealloc(); + cutlass::PipelineState n_block_pipe_write = cutlass::make_producer_start_state(); + // Manually specify the scheduler role: producer. For StaticPersistentTileSch, passing template args won't change the behavior + for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); + work_tile_info.is_valid(params.scheduler); + work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info) + ) { + auto block_coord = work_tile_info.get_block_coord(params.scheduler); + int const m_block = get<0>(block_coord); + int const bidh = get<1>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); + SeqlenInfo_t seqlen_info{ + get<2>(block_coord) /*bidb*/, + get<0>(params.mainloop.shape_Q), + !params.mainloop.ptr_pagetable ? 
size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), + get<0>(params.mainloop.shape_K_new), + params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, + params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + }; + auto [n_block_min, n_block_max] = CollectiveMainloop::BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, params.mainloop.num_splits, + params.mainloop.window_size_left, params.mainloop.window_size_right, params.mainloop.qhead_per_khead_divmod); + + // It's possible to have n_block_max <= n_block_min. Loading K can cause illegal memory access. + if constexpr (Is_causal || Is_local || Varlen || Split) { + if (n_block_max <= n_block_min) { + // skipping, don't forget to fetch us the next work! + scheduler.prefetch_next_work(params.scheduler, work_tile_info); + continue; + } + } + + // for padding 32 and padding 4: the num_chunk (pad_32) >= num_chunk (pad_4) is always true + const int nblock_seqlen = ((seqlen_info.seqlen_k + kBlockN - 1) / kBlockN + 3) & 0xfffffffc; // umiswing: padding for int4 load + const int num_chunk = (nblock_seqlen + CollectiveMainloop::Flashmask_n_block_buffer_valid_length - 1) / CollectiveMainloop::Flashmask_n_block_buffer_valid_length; + // reverse_chunk_idx, start from right to left: [5, 4, 3, 2, 1, 0], and fwd kernel scans from right to left + bool valid_chunk = true; + const int cppl_stage = scheduler.template stage(); // coarse pipeline stage (offset, 0 or 2) + +#define GEN_N_BLOCK_DISPATCH(DispatchTag) \ + valid_chunk = mainloop.generate_n_block(get<0>(block_coord), \ + reverse_chunk_idx, \ + num_chunk, \ + reverse_chunk_idx == num_chunk - 1 ? 
CollectiveMainloop::Flashmask_n_block_finish : CollectiveMainloop::Flashmask_n_block_chunk_end,\ + n_block_min, n_block_max, seqlen_info.seqlen_q, \ + flashmask_maxmin_smem + 8 * CollectiveMainloop::Flashmask_n_block_buffer_length * (n_block_pipe_write.index() + cppl_stage), \ + n_block_smem + CollectiveMainloop::Flashmask_n_block_buffer_length * (n_block_pipe_write.index() + cppl_stage), \ + extra_flags + n_block_pipe_write.index() + cppl_stage, \ + Is_blockmask, \ + blockmask_smem_ + CollectiveMainloop::Blockmask_n_block_buffer_valid_length * (n_block_pipe_write.index() + cppl_stage)) + + + for(int reverse_chunk_idx = 0; reverse_chunk_idx < num_chunk; reverse_chunk_idx++) { + if (valid_chunk) + pipeline_n_block.producer_acquire(n_block_pipe_write); + if (Is_blockmask) { + mainloop.load_blockmask(params.mainloop, seqlen_info, block_coord, reverse_chunk_idx, num_chunk, + blockmask_smem_ + CollectiveMainloop::Blockmask_n_block_buffer_valid_length * (n_block_pipe_write.index() + cppl_stage)); + } + mainloop.load_max_min(params.mainloop, seqlen_info, block_coord, reverse_chunk_idx, num_chunk, flashmask_maxmin_smem + + 8 * CollectiveMainloop::Flashmask_n_block_buffer_length * (n_block_pipe_write.index() + cppl_stage)); + if (params.mainloop.ut_start_ptr) { + GEN_N_BLOCK_DISPATCH(CollectiveMainloop::PtrExistDispatchTag::FULL_PTR); + } else if (params.mainloop.lt_end_ptr || params.mainloop.ut_end_ptr) { + GEN_N_BLOCK_DISPATCH(CollectiveMainloop::PtrExistDispatchTag::DUAL_PTR); + } else { + GEN_N_BLOCK_DISPATCH(CollectiveMainloop::PtrExistDispatchTag::SINGLE_PTR); + } + if (valid_chunk) { + pipeline_n_block.producer_commit(n_block_pipe_write); + ++n_block_pipe_write; + } + } +#undef GEN_N_BLOCK_DISPATCH + + // heqianyue note: the execution time of reverse_chunk for loop will be influenced by the workload of computation pipeline + // therefore, **works with more partially/fully masked block** will have longer execution time for this producer. 
So, the + // interval between two consecutive `get_next_work` of this producer will increase, thus lowering the frequency of preemptive + // scheduling. However, since there is double-buffer, the for-loop execution time of reverse_chunk is only a rough estimator for + // the workload of computation pipeline, but I think it is good enough. + scheduler.prefetch_next_work(params.scheduler, work_tile_info); + } + } else { + // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init(); + PipelineParamsK pipeline_params_k; + pipeline_params_k.role = warp_group_idx == 0 + ? MainloopPipelineK::ThreadCategory::Producer + : MainloopPipelineK::ThreadCategory::Consumer; + if constexpr (Use_TMA_KV) { + pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; + pipeline_params_k.is_leader = warp_group_thread_idx == 0; + pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup; + } else { + pipeline_params_k.consumer_arv_count = !LargeHeadDimV ? 
NumMmaThreads : cutlass::NumThreadsPerWarpGroup; + pipeline_params_k.producer_arv_count = NumProducerThreads; + } + + static_assert(is_same_v); + PipelineParamsVt pipeline_params_vt = pipeline_params_k; + if constexpr (Use_TMA_KV && !SameHeadDim) { + pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; } + } else { + if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; } + } + + MainloopPipelineK pipeline_k = [&] { + if constexpr (Use_TMA_KV) { + return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{}); + } else { + return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k); + } + }(); + // MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{}); + MainloopPipelineV pipeline_v = [&] { + if constexpr (!Transpose_V) { + static_assert(is_same_v); + if constexpr (Use_TMA_KV) { + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{}); + } else { + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt); + } + } else { + PipelineParamsV pipeline_params_v; + pipeline_params_v.role = warp_group_idx == 0 + ? MainloopPipelineV::ThreadCategory::Producer + : MainloopPipelineV::ThreadCategory::Consumer; + pipeline_params_v.producer_arv_count = NumProducerThreads; + pipeline_params_v.consumer_arv_count = NumMmaThreads; + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v); + } + }(); + // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then + // the producer WG will read from pipeline_vt and write to pipeline_v. + // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used. 
+ // Technically for pipeline_params_vt, warp0 of WG0 is the producer and all of WG0 are consumers. + // However, the thread role isn't used in the pipeline implementation. + MainloopPipelineVt pipeline_vt = [&] { + if constexpr (Use_TMA_KV) { + pipeline_params_vt.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{}); + } else { + pipeline_params_vt.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt); + } + }(); + + PipelineParamsKVNew pipeline_params_kv_new; + pipeline_params_kv_new.role = warp_group_idx == 0 + ? MainloopPipelineKVNew::ThreadCategory::Producer + : MainloopPipelineKVNew::ThreadCategory::Consumer; + pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; + pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0; + pipeline_params_kv_new.num_consumers = NumMmaThreads; + auto pipeline_k_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr); + if constexpr (!SameHeadDim) { + pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + } + auto pipeline_v_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr); + cutlass::PipelineState n_block_pipe_read; + + PipelineParamsFlashMaskApply pipeline_params_flashmask_apply; + pipeline_params_flashmask_apply.role = warp_group_idx == 0 + ? MainloopPipelineFlashMaskApply::ThreadCategory::Producer + : MainloopPipelineFlashMaskApply::ThreadCategory::Consumer; + pipeline_params_flashmask_apply.consumer_arv_count = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup; // TODO(umiswing): how to deal with LargeHeadDimV? 
+ pipeline_params_flashmask_apply.producer_arv_count = NumProducerThreads; + + MainloopPipelineFlashMaskApply pipeline_flashmask_apply(shared_storage.pipelines.pipeline_flashmask_apply, pipeline_params_flashmask_apply); + + if (warp_group_idx == 0) { // Producer + cutlass::arch::warpgroup_reg_dealloc(); + // The pipelines for AppendKV and main attention are different, since e.g. main attention + // might use cp.async to load KV (if PagedKVNonTMA) while AppendKV always uses TMA to load + // KV_new. Since the pipeline states are different, we have to manually sync to make + // sure the two pipelines don't race when accessing smem_k and smem_v. + PipelineState smem_pipe_write = cutlass::make_producer_start_state(); + PipelineState smem_pipe_write_new = cutlass::make_producer_start_state(); + + int work_idx = 0; + static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; + static_assert(SingleProducerWarp); + + scheduler.init_consumer(); + if constexpr (SingleProducerWarp) { + if (warp_idx_in_warpgroup != 0) { return; } + } + + cutlass::arch::wait_on_dependent_grids(); + + // Load Q, K, V + for (auto work_tile_info = scheduler.get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); work_tile_info = scheduler.get_next_work(params.scheduler, work_tile_info)) { + auto block_coord = work_tile_info.get_block_coord(params.scheduler); + SeqlenInfo_t seqlen_info{ + get<2>(block_coord) /*bidb*/, + get<0>(params.mainloop.shape_Q), + !params.mainloop.ptr_pagetable ? 
size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), + get<0>(params.mainloop.shape_K_new), + params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, + params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + }; + mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, pipeline_n_block, pipeline_flashmask_apply, smem_pipe_write, + n_block_pipe_read, + shared_storage, seqlen_info, block_coord, work_idx, + flashmask_smem_, n_block_smem + CollectiveMainloop::Flashmask_n_block_buffer_length * scheduler.stage(), + extra_flags + scheduler.stage()); + // coarse pipeline stage (offset, 0 or 2) + } + mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx); + } else { // Consumer + cutlass::arch::warpgroup_reg_alloc(); + + // Initialize matmul objects. + TiledMmaPV tiled_mma_pv; + + PipelineState smem_pipe_read; + PipelineState smem_pipe_read_new; + // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v + // (like in Cutlass's gemm) because the read and release pipeline states are always the same. + + scheduler.init_consumer(); + mainloop.mma_init(); + + int work_idx = 0; + CUTLASS_PRAGMA_NO_UNROLL + for (auto work_tile_info = scheduler.get_initial_work(params.scheduler); + work_tile_info.is_valid(params.scheduler); + // get_next_work will be called before the epilogue + ) { + auto block_coord = work_tile_info.get_block_coord(params.scheduler); + int const bidb = get<2>(block_coord); + SeqlenInfo_t seqlen_info{ + bidb, + get<0>(params.mainloop.shape_Q), + !params.mainloop.ptr_pagetable ? 
size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), + get<0>(params.mainloop.shape_K_new), + params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, + params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + }; + // If there's tanh softcap, the scaling will be done before tanh. + float softmax_scale_log2 = params.mainloop.softmax_scale_log2; + flash::Softmax softmax(softmax_scale_log2); + // Attention output (GEMM-II) accumulator. + Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{})); + bool tile_valid; + if constexpr (!LargeHeadDimV) { + tile_valid = mainloop.mma( + params.mainloop, pipeline_k, pipeline_v, pipeline_n_block, pipeline_flashmask_apply, smem_pipe_read, + n_block_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage, + flashmask_smem_, n_block_smem + CollectiveMainloop::Flashmask_n_block_buffer_length * scheduler.stage(), + extra_flags + scheduler.stage()); + } else { // mma_pv might not compile if !LargeHeadDimV + if (warp_group_idx == 1) { + tile_valid = mainloop.mma( + params.mainloop, pipeline_k, pipeline_v, pipeline_n_block, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage, + flashmask_smem_, n_block_smem + CollectiveMainloop::Flashmask_n_block_buffer_length * scheduler.stage(), + extra_flags + scheduler.stage()); + } else { + tile_valid = mainloop.mma_pv( + params.mainloop, pipeline_v, pipeline_n_block, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage, + flashmask_smem_); + } + } + // Do this here before the epilogue so that the next tile is ready to go. 
+ work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info); + if (tile_valid) { + epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv, + threadIdx.x - MmaThreadOffset, block_coord); + } else { + // Write 0 to gO and -inf to gLSE. + epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); + } + } + epilogue.store_tail(); + } + } + } + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_launch_template.h b/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_launch_template.h new file mode 100644 index 00000000000..6ebfc9bc83f --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/flash_fwd_launch_template.h @@ -0,0 +1,273 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" // For device_kernel +#include +#include "cutlass/cluster_launch.hpp" +#include "cutlass/kernel_launch.h" + +#include "static_switch.h" +#include "flash.h" +#include "tile_size.h" +#include "tile_scheduler.hpp" +#include "flash_fwd_kernel_sm90.h" +#include "flash_fwd_kernel_sm80.h" +#include "mainloop_fwd_sm90_tma_gmma_ws.hpp" +#include "mainloop_fwd_sm80.hpp" +#include "epilogue_fwd.hpp" +#include "flash_mask.hpp" + +using namespace cute; + +template +void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); + static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time"); + static_assert(!(AppendKV && !Varlen), "AppendKV requires Varlen"); + static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v; + static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor; + using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; + + // Can't use structured binding since it's not compatible with constexpr + static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, short_seqlen); + static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV); + static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS); + static constexpr int kBlockN = Arch >= 90 ? 
std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS); + static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); + static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap); + static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS); + static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS); + static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS); + + // if true: use Dual PPTX, else, use PPT + static constexpr bool No_Scheduler_Pipeline = true; + // TODO(heqianyue): headdim = 64 comparison is actually worse for DualPPTX, for unknown reasons + static constexpr bool Predicate_for_Headdim = true; + + using TileShape_MNK = cute::Shape, Int, Int>; + using TileShape_MNK_PV = cute::Shape, Int, Int>; + using ClusterShape = cute::Shape, _1, _1>; + using CollectiveMainloop = std::conditional_t< + Arch >= 90, + flash::CollectiveMainloopFwdSm90, + flash::CollectiveMainloopFwdSm80 + >; + using CollectiveEpilogue = flash::CollectiveEpilogueFwd; + + // If Split then we probably don't have enough work for PersistentScheduler to be useful. + // However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better + // since we'll avoid launching a bunch of thread blocks that immediately exit. + // On Sm80, noncausal persistent seems a bit slower. + static constexpr int _NumProducerThreads = cutlass::NumThreadsPerWarpGroup - cutlass::NumThreadsPerWarp; // expect: 96 + static constexpr int _NumConsumerThreads = CollectiveMainloop::NumMmaThreads + cutlass::NumThreadsPerWarpGroup - _NumProducerThreads; + // TODO(heqianyue): The following Predicate_for_Headdim might be removed in the future. 
Currently, Dual PPTX cannot be as fast as PPT + // in headdim = 64 case, I suspect I've fixed it, but there is no testing facility (9.30 EB5 occupied) + // The current logic: only headdim=128 will use Dual PPTX + using Scheduler = std::conditional_t< + Arch >= 90, + std::conditional_t< + (Predicate_for_Headdim && (kHeadDimV != 128 || kHeadDim != 128)) || No_Scheduler_Pipeline, + flash::PreemptivePersistentTileScheduler<_NumConsumerThreads, _NumProducerThreads, Split>, + flash::DualPreemptivePersistentTileExecutionScheduler<_NumConsumerThreads, _NumProducerThreads, Split> + >, + flash::StaticPersistentTileScheduler + >; + + using AttnKernel = std::conditional_t< + Arch >= 90, + flash::enable_sm90_or_later>, + flash::enable_sm80_to_sm89> + >; + + bool const is_varlen_q = params.cu_seqlens_q; + bool const is_varlen_k = params.cu_seqlens_k; + bool const is_varlen_k_new = params.cu_seqlens_knew; + int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q; + int batch_q = !is_varlen_q ? params.b : 1; + int batch_k = !is_varlen_k ? (params.kv_batch_idx ? params.b_k : params.b) : 1; + typename CollectiveMainloop::StrideV v_strides = + cute::conditional_return( + make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0), + make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0)); + + if constexpr (Is_flashmask) { + flash::flashmask::prepare_block_maxmin(params, stream, true); + } + + typename CollectiveMainloop::Arguments mainloop_args { + static_cast(params.q_ptr), + {seqlen_q, params.d, params.h, batch_q}, // shape_Q + {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q + static_cast(params.k_ptr), + {!params.page_table ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size, + params.d, params.h_k, !params.page_table ? 
batch_k : params.num_pages}, // shape_K + {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K + static_cast(params.v_ptr), + params.dv, // headdim_v + v_strides, // stride_V + static_cast(params.knew_ptr), + {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1}, // shape_K_new + {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0}, // stride_K_new + static_cast(params.vnew_ptr), + {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0}, // stride_V_new + static_cast(params.qv_ptr), + {params.qv_row_stride, _1{}, params.qv_head_stride, !is_varlen_q ? params.qv_batch_stride : 0}, // stride_Qv + static_cast(params.rotary_cos_ptr), + {params.seqlen_k, params.rotary_dim / 2}, // shape_rotary, the seqlen shape doesn't matter + {params.rotary_dim / 2, _1{}}, // stride_rotary_cos + static_cast(params.rotary_sin_ptr), + {params.rotary_dim / 2, _1{}}, // stride_rotary_sin + params.is_rotary_interleaved, + params.page_table, + // if page_size is not set, avoid dividing by zero + {params.kv_batch_idx ? params.b_k : params.b, !params.page_table ? 
0 : params.seqlen_k / params.page_size}, // shape_page_table + {params.page_table_batch_stride, _1{}}, // stride_page_table + params.scale_softmax, + params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr, + {params.q_descale_batch_stride, params.q_descale_head_stride}, + {params.k_descale_batch_stride, params.k_descale_head_stride}, + {params.v_descale_batch_stride, params.v_descale_head_stride}, + params.window_size_left, params.window_size_right, + params.softcap, + params.num_splits, + params.kv_batch_idx, + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, + params.seqused_q, params.seqused_k, + params.leftpad_k, + params.h_flashmask, params.h_h_flashmask_ratio, + params.lt_start_ptr, params.lt_end_ptr, + params.ut_start_ptr, params.ut_end_ptr, + params.flashmask_maxmin_ptr, + params.lt_start_nblockmax, params.lt_start_nblockmin, + params.lt_end_nblockmax, params.lt_end_nblockmin, + params.ut_start_nblockmax, params.ut_start_nblockmin, + params.ut_end_nblockmax, params.ut_end_nblockmin, + params.m_block_dim,params.n_block_dim, + params.block_mask_ptr + }; + + typename CollectiveEpilogue::Arguments epilogue_args { + static_cast(params.o_ptr), + {seqlen_q, params.dv, params.h, batch_q, params.num_splits}, // shape_O + {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0, 0}, // stride_O + static_cast(params.oaccum_ptr), + {params.oaccum_row_stride, _1{}, params.oaccum_head_stride, !is_varlen_q ? params.oaccum_batch_stride : 0, params.oaccum_split_stride}, // stride_O_partial + static_cast(params.softmax_lse_ptr), + {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, 0}, // stride_LSE + static_cast(params.softmax_lseaccum_ptr), + {_1{}, seqlen_q, !is_varlen_q ? 
params.h * seqlen_q : 0, params.h * seqlen_q * batch_q}, // stride_LSE_partial + params.h_k, + params.cu_seqlens_q, params.seqused_q + }; + + if constexpr (Arch >= 90) { + prepare_preemptive_scheduler(params, stream, params.num_sm, Scheduler::pipelining); + } + + int qhead_per_khead = !PackGQA ? 1 : cutlass::ceil_div(params.h, params.h_k); + int num_blocks_m = cutlass::ceil_div(params.seqlen_q * qhead_per_khead, get<0>(TileShape_MNK{})); + num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{})); + typename flash::TileSchedulerArguments scheduler_args { + num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits, + params.h / params.h_k, + params.seqlen_q, + params.seqlen_k, params.d, params.dv, sizeof(Element), + params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q, + // params.num_m_blocks_ptr, + params.num_splits_dynamic_ptr, + }; + + if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { + prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + } + + int device; + CHECK_CUDA(cudaGetDevice(&device)); + typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({ + mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args + }); + + dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params); + dim3 block_dims = AttnKernel::get_block_shape(); + int smem_size = AttnKernel::SharedStorageSize; + // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q)); + // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)); + // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)); + // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v); + // Get the ptr to kernel function. 
+ if constexpr (size(ClusterShape{}) > 1) { + void const* kernel = (void const*) flash::cutlass_flashmask_kernel; + if (smem_size >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); + cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream}; + cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params); //TODO: support cluster + } else { + auto kernel = flash::cutlass_flashmask_kernel; + if (smem_size >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + flash::flashmask_kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, + Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/); + } + CHECK_CUDA_KERNEL_LAUNCH(); +} + +template +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported"); + static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v; + using T_out = std::conditional_t; + CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { + VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] { + static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; + VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { + // Only needed here to decide if we should use cluster + BOOL_SWITCH(params.lt_start_ptr != nullptr, Is_flashmask, [&] { + static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? 
(kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen && !Is_flashmask; + BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { + BOOL_SWITCH(params.seqlen_k < 128 && params.seqlen_q < 128, ShortSeqlen, [&] { + // If the sequence length is (extremely) short, we should cut down the tile size + static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, + sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, ShortSeqlen)) : 128; + static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV == 512; + APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { + // Only use Cluster if number of tiles along seqlen_q is even and not varlen + CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { + static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; + run_flash_fwd(params, stream); + }); + }); + }); + }); + }); + }); + }); + }); +} diff --git a/flashmask/flash_mask/flashmask_attention_v3/flash_mask.hpp b/flashmask/flash_mask/flashmask_attention_v3/flash_mask.hpp new file mode 100644 index 00000000000..f8d355f36da --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/flash_mask.hpp @@ -0,0 +1,265 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once +#include +#include +#include +#include "utils.h" + +namespace flash { + + template + // CUTLASS_DEVICE + __device__ + void apply_flashmask_bwd(Tensor &tSrS, int const thread_idx, const int32_t* const __restrict__ flashmask_index_smem_, const int32_t m_block) { + + // static_assert(!PackGQA); + // static_assert(!SwapAB); + + const auto thread_mma = TiledMma{}.get_thread_slice(thread_idx); + + const Tensor cS = cute::make_identity_tensor(Shape, Int>{}); + const Tensor tScS = thread_mma.partition_C(cS); + Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); + const Tensor tScS_rowcol = make_tensor(tScS.data(), flash::convert_layout_acc_rowcol(tScS.layout())); + + static constexpr int Row = !SwapAB ? 0 : 1, Col = !SwapAB ? 
1 : 0; + const int32_t* const s_lt_start = flashmask_index_smem_; + const int32_t* const s_lt_end = flashmask_index_smem_ + kBlockN; + const int32_t* const s_ut_start = flashmask_index_smem_ + 2 * kBlockN; + const int32_t* const s_ut_end = flashmask_index_smem_ + 3 * kBlockN; + + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { + int const row_idx = get(tScS_rowcol(m, _0{})) + m_block * kBlockM; + // __syncwarp(); + // printf("\n>>>>>> wsm debug row_idx:%d, thread_idx:%d\n", row_idx, thread_idx); + // __syncwarp(); + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + int const col_idx = get(tScS_rowcol(m, n)); // col_idx within a block + + // Note(heqianyue): causal masking will be processed in generic fa-v3 `mask.apply`, so if causal, there is no need to apply mask again + if constexpr (Is_causal) { + // Note(heqianyue): if Has_lt_end == false, row_idx < s_lt_end[col_idx] is entirely unnecessary, but if we just + // throw it away, for sliding window and document mask, we might have about 3% performance loss + // due to if both predicates are present, some of the FSEL instructions are selectively performed + // instead of performed unconditionally. Through removing the latter predicate can save a lot of + // instructions (193 --> 99), we will actually store more / use more regs. 
This is basically a + // trade-off for speed and no performance recession + if (row_idx >= s_lt_start[col_idx] && row_idx < s_lt_end[col_idx]) + tSrS_rowcol(m, n) = -INFINITY; + } else { + if constexpr (Has_ut_start) { + // Note(heqianyue): currently, if we have ut_start, we will definitely have lt_end + // but if we have a new mask type other than global swin, the constraint might be violated + if (row_idx >= s_lt_start[col_idx] && row_idx < s_lt_end[col_idx]) + tSrS_rowcol(m, n) = -INFINITY; + if (row_idx >= s_ut_start[col_idx] && row_idx < s_ut_end[col_idx]) + tSrS_rowcol(m, n) = -INFINITY; + } else { + // Note(heqianyue): we don't have lt_start, lt_end, nullptr and ut_end composition, maybe in the future + if (row_idx >= s_lt_start[col_idx]) + tSrS_rowcol(m, n) = -INFINITY; + if (row_idx < s_ut_end[col_idx]) + tSrS_rowcol(m, n) = -INFINITY; + } + } + } + } + } + +// }; + +namespace flashmask { + + // make sure the following value is the same with the CooperativeMainLoopImpl + // for example, sm90 is 16 * 1024. 
static constexpr int flashmask_buffer_length = 16 * 1024;

// Note(heqianyue): this kernel is currently only used for fwd and sm90 (flashmask v3)
// for fully aligned minmax with no excessive global sector.
// One warp (blockDim.x == 32) reduces one kBlockN-long segment of `input` to its
// max and min; blockIdx.x indexes the segment, (blockIdx.y, threadIdx.y) the row.
// NOTE(review): the template header and the kernel-launch angle brackets below were
// stripped by markup mangling in the source dump; reconstructed as
// <int kBlockN, bool aligned_chunk> from the uses in this file -- confirm upstream.
template <int kBlockN, bool aligned_chunk>
__global__ void scanMaxMinChunkedKernel(
    const int *input, int b, int n, int *maxo, int *mino) {
  int bid = threadIdx.y + blockIdx.y * blockDim.y;
  if (bid >= b) {
    return;
  }
  // Each bid owns one row of length n.
  int i_offset = bid * n;
  input = input + i_offset;

  const int nblock_seqlen = ((n + kBlockN - 1) / kBlockN + 3) & 0xfffffffc; // umiswing: padding for int4 load

  // constexpr int nums = kBlockN / 32; // ensure N % 32 == 0
  constexpr int nums = (kBlockN + 31) / 32;
  int warpId = blockIdx.x; // ensure blockDim.x == 32
  int tid = threadIdx.x;
  int lane_id = threadIdx.x % 32;
  int maxv, minv;
  int idx = warpId * kBlockN + tid;
  if (warpId * kBlockN + kBlockN > n) {
    // Tail segment: guard every load against running past n.
    // maxv starts at 0 (not INT_MIN) -- assumes mask row indices are non-negative.
    maxv = 0;
    minv = INT_MAX;
    #pragma unroll
    for (int i = 0; i < nums; i++) {
      if (idx < n && lane_id + i * 32 < kBlockN) {
        maxv = max(maxv, input[idx]);
        minv = min(minv, input[idx]);
      }
      idx += 32;
    }
  } else {
    // Full segment: only the kBlockN bound matters (once false, stays false,
    // so advancing idx inside the guard is safe).
    maxv = 0;
    minv = INT_MAX;
    #pragma unroll
    for (int i = 0; i < nums; i++) {
      if (lane_id + i * 32 < kBlockN) {
        maxv = max(maxv, input[idx]);
        minv = min(minv, input[idx]);
        idx += 32;
      }
    }
  }
  __syncwarp();
  maxv = __reduce_max_sync(0xffffffff, maxv);
  minv = __reduce_min_sync(0xffffffff, minv);
  if (tid == 0) {
    if constexpr (aligned_chunk) {
      // the length of the buffer that actually takes part in computation
      constexpr int chunk_valid_length = ((flashmask_buffer_length + kBlockN - 1) / kBlockN + 3) & 0xfffffffc;
      // the padded length for the sake of 128B aligned reading sector (ceil to multiple of 32)
      constexpr int chunk_padded_length = ((flashmask_buffer_length + kBlockN - 1) / kBlockN + 31) & 0xffffffe0;

      const int num_chunk = (nblock_seqlen + chunk_valid_length - 1) / chunk_valid_length;
      const int total_length = num_chunk * chunk_padded_length;
      // TODO(heqianyue): This can be made faster by fast div mod, but I suppose this will not be a bottleneck
      const int chunk_id = warpId / chunk_valid_length;
      const int within_chunk_id = warpId % chunk_valid_length;

      // stores chunk (there will be 'padding -- invalid data' on the tail) continuously
      maxo[bid * total_length + chunk_padded_length * chunk_id + within_chunk_id] = maxv;
      mino[bid * total_length + chunk_padded_length * chunk_id + within_chunk_id] = minv;
    } else {
      // stores data continuously
      maxo[bid * nblock_seqlen + warpId] = maxv;
      mino[bid * nblock_seqlen + warpId] = minv;
    }
  }
}

// Host-side launcher: one warp per kBlockN segment, 4 rows per thread block.
template <int kBlockN>
void scanMaxMinGpu(
    const int *input, int b, int n, int *maxo, int *mino, cudaStream_t stream, bool use_aligned_chunk = false) {
  // static_assert(kBlockN % 32 == 0, "kBlockN must be a multiple of 32");
  dim3 block(32, 4);
  dim3 grid((n + kBlockN - 1) / kBlockN, (b + 3) / 4);
  if (use_aligned_chunk)
    scanMaxMinChunkedKernel<kBlockN, true><<<grid, block, 0, stream>>>(input, b, n, maxo, mino);
  else
    scanMaxMinChunkedKernel<kBlockN, false><<<grid, block, 0, stream>>>(input, b, n, maxo, mino);
}

// Compute per-block max/min of the flashmask row indices and lay out the eight
// result arrays (4 masks x {max, min}) inside params.flashmask_maxmin_ptr.
template <int kBlockN>
void prepare_block_maxmin(Flash_fwd_params &params, cudaStream_t stream, bool is_forward = false) {
  if (params.lt_start_ptr == nullptr &&
      params.ut_end_ptr == nullptr) {
    return;
  }
  int *nblock_smask = params.flashmask_maxmin_ptr;

  // only used in forward pass and SM90 (FlashMaskV3)
  const int nblock_seqlen = ((params.seqlen_k + kBlockN - 1) / kBlockN + 3) & 0xfffffffc; // umiswing: padding for int4 load
  int nblock_masklen = 0;

  const bool use_aligned_chunk = params.arch == 90 && is_forward;

  if (use_aligned_chunk) {
    constexpr int chunk_valid_length = ((flashmask_buffer_length + kBlockN - 1) / kBlockN + 3) & 0xfffffffc;
    constexpr int chunk_padded_length = ((flashmask_buffer_length + kBlockN - 1) / kBlockN + 31) & 0xffffffe0;
    const int num_chunk = (nblock_seqlen + chunk_valid_length - 1) / chunk_valid_length;
    nblock_masklen = params.b * params.h_flashmask *
num_chunk * chunk_padded_length; + } else { + nblock_masklen = params.b * params.h_flashmask * nblock_seqlen; + } + + params.lt_start_nblockmax = nblock_smask; + params.lt_start_nblockmin = nblock_smask + nblock_masklen; + params.ut_end_nblockmax = nblock_smask + 2 * nblock_masklen; + params.ut_end_nblockmin = nblock_smask + 3 * nblock_masklen; + params.lt_end_nblockmax = nblock_smask + 4 * nblock_masklen; + params.lt_end_nblockmin = nblock_smask + 5 * nblock_masklen; + params.ut_start_nblockmax = nblock_smask + 6 * nblock_masklen; + params.ut_start_nblockmin = nblock_smask + 7 * nblock_masklen; + if (params.lt_start_ptr != nullptr) { + scanMaxMinGpu( + params.lt_start_ptr, + params.b * params.h_flashmask, + params.seqlen_k, + params.lt_start_nblockmax, + params.lt_start_nblockmin, + stream, + use_aligned_chunk); + } else { + params.lt_start_nblockmax = nullptr; + params.lt_start_nblockmin = nullptr; + } + if (params.ut_end_ptr != nullptr) { + scanMaxMinGpu( + params.ut_end_ptr, + params.b * params.h_flashmask, + params.seqlen_k, + params.ut_end_nblockmax, + params.ut_end_nblockmin, + stream, + use_aligned_chunk); + } else { + params.ut_end_nblockmax = nullptr; + params.ut_end_nblockmin = nullptr; + } + if (params.lt_end_ptr != nullptr) { + scanMaxMinGpu( + params.lt_end_ptr, + params.b * params.h_flashmask, + params.seqlen_k, + params.lt_end_nblockmax, + params.lt_end_nblockmin, + stream, + use_aligned_chunk); + } else { + params.lt_end_nblockmax = nullptr; + params.lt_end_nblockmin = nullptr; + } + if (params.ut_start_ptr != nullptr) { + scanMaxMinGpu( + params.ut_start_ptr, + params.b * params.h_flashmask, + params.seqlen_k, + params.ut_start_nblockmax, + params.ut_start_nblockmin, + stream, + use_aligned_chunk); + } else { + params.ut_start_nblockmax = nullptr; + params.ut_start_nblockmin = nullptr; + } + } +} // namespace flashmask +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/flash_prepare_scheduler.cu 
b/flashmask/flash_mask/flashmask_attention_v3/flash_prepare_scheduler.cu new file mode 100644 index 00000000000..d2c412a5e8f --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/flash_prepare_scheduler.cu @@ -0,0 +1,170 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ + +#include "cutlass/fast_math.h" +#include "cutlass/barrier.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/grid_dependency_control.h" +#include "cuda_check.h" + +#include "flash.h" + +// fallback, in case that paddle end does not allocate valid GPU mem for semaphore +static __device__ int semaphore_storage_fwd[1]; +static __device__ int semaphore_storage_bwd[1]; + +namespace flash { + +__global__ void prepare_varlen_num_blocks_kernel( + int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static, + int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, + int const* const seqused_q, int const* const seqused_k, int const* const leftpad_k_ptr, + int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static, + cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod, + int* const tile_count_semaphore, + // int* const num_m_blocks_ptr, + int* const num_splits_dynamic_ptr, + bool enable_pdl) { + + static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1; + static constexpr int kSmemSize = 1; + // Assume that there's only one block in the grid + __shared__ int total_blocks_smem[kSmemSize]; + + // There's only 1 block in the grid, so might as well start launching the main attn kernel + if (enable_pdl) { cutlass::arch::launch_dependent_grids(); } + + if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; } + __syncthreads(); + + if (threadIdx.x == 0 && tile_count_semaphore) { *tile_count_semaphore = 0; } + + int lane = threadIdx.x % cutlass::NumThreadsPerWarp; + + auto get_num_m_blocks = [&](int bidb_start) { + int batch_idx = lane + bidb_start; + int seqlen; + if (seqused_q) { + seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0; + } else if (cu_seqlens_q) { + int cur_cu_seqlen = batch_idx <= num_batch ? 
cu_seqlens_q[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = seqlen_q_static; + } + seqlen *= qhead_per_khead; + return batch_idx < num_batch && lane < kNumBatchPerWarp + ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0; + }; + + auto get_num_n_blocks = [&](int bidb_start) { + int batch_idx = lane + bidb_start; + int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0; + int seqlen; + if (seqused_k) { + seqlen = batch_idx < num_batch ? seqused_k[batch_idx] : 0; + } else if (cu_seqlens_k) { + int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_k[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = seqlen_k_static; + } + int seqlen_new; + if (cu_seqlens_k_new) { + int cur_cu_seqlen_new = batch_idx <= num_batch ? cu_seqlens_k_new[batch_idx] : 0; + int next_cu_seqlen_new = __shfl_down_sync(0xffffffff, cur_cu_seqlen_new, 1); + seqlen_new = next_cu_seqlen_new - cur_cu_seqlen_new; + } else { + seqlen_new = seqlen_k_new_static; + } + // if (threadIdx.x == 0) { printf("seqlen = %d, seqlen_new = %d, leftpad_k = %d\n", seqlen, seqlen_new, leftpad_k); } + seqlen = seqlen - leftpad_k + seqlen_new; + return batch_idx < num_batch && lane < kNumBatchPerWarp + ? 
blockn_divmod.div(seqlen + blockn_divmod.divisor - 1) : 0; + }; + + int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp; + int bidb_start = kNumBatchPerWarp * warp_idx; + int num_m_blocks = get_num_m_blocks(bidb_start); + int num_n_blocks = get_num_n_blocks(bidb_start); + + int total_blocks = num_m_blocks * num_n_blocks; + // Warp sum + #pragma unroll + for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { + total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); + } + if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); } + __syncthreads(); + total_blocks = total_blocks_smem[0]; + // 10% margin + int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); + // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM + int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); + if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) { + num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic; + // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); + } +} + +__global__ void prepare_preemptive_scheduler_kernel( + int* const tile_count_semaphore, + int sm_count) { + // There's only 1 block in the grid, so might as well start launching the main attn kernel + cutlass::arch::launch_dependent_grids(); + if (threadIdx.x == 0 && tile_count_semaphore) { *tile_count_semaphore = sm_count; } +} + +} // flash + +void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, + int blockM, int blockN, bool enable_pdl) { + // Only support batch <= 992 (32 warps, each with 31 batches) + int qhead_per_khead = !packgqa ? 
1 : cutlass::ceil_div(params.h, params.h_k); + flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>( + params.seqlen_q, params.seqlen_k, params.seqlen_knew, + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, + cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), + params.tile_count_semaphore, + // params.num_m_blocks_ptr, + params.num_splits_dynamic_ptr, enable_pdl); +} + +void prepare_preemptive_scheduler(Flash_fwd_params ¶ms, cudaStream_t stream, int num_sm, bool is_dual_pptx) { + if (params.tile_count_semaphore == nullptr) { + CHECK_CUDA(cudaGetSymbolAddress((void**)¶ms.tile_count_semaphore, semaphore_storage_fwd)); + } + if (is_dual_pptx) + num_sm *= 2; // double buffer PPTX will have 2 * num_sm static scheduling + flash::prepare_preemptive_scheduler_kernel<<<1 /*grid*/, 32 /*block*/, 0, stream>>>( + params.tile_count_semaphore, + num_sm); +} + +void prepare_preemptive_scheduler(Flash_bwd_params ¶ms, cudaStream_t stream, int num_sm) { + if (params.tile_count_semaphore == nullptr) { + CHECK_CUDA(cudaGetSymbolAddress((void**)¶ms.tile_count_semaphore, semaphore_storage_bwd)); + } + flash::prepare_preemptive_scheduler_kernel<<<1 /*grid*/, 32 /*block*/, 0, stream>>>( + params.tile_count_semaphore, + num_sm); +} diff --git a/flashmask/flash_mask/flashmask_attention_v3/generate_kernels.py b/flashmask/flash_mask/flashmask_attention_v3/generate_kernels.py new file mode 100644 index 00000000000..7c4af24cb65 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/generate_kernels.py @@ -0,0 +1,227 @@ +# Copied from Driss Guessous's PR in PyTorch: https://github.com/pytorch/pytorch/pull/105602 + +# This file is run to generate the kernel instantiations for the flash_attn kernels +# They are written to several files in order to speed up compilation + +import 
argparse +import itertools +from collections import namedtuple +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional + +KERNEL_BATCH = namedtuple("Kernel", ["template", "filename"]) + +DTYPE_MAP = { + "fp16": "cutlass::half_t", + "bf16": "cutlass::bfloat16_t", + "e4m3": "cutlass::float_e4m3_t", +} + +DTYPE_MAP_FWD_SM8x = { + "fp16": "cutlass::half_t", + "bf16": "cutlass::bfloat16_t", +} + +DTYPE_MAP_BWD = { + "fp16": "cutlass::half_t", + "bf16": "cutlass::bfloat16_t", +} + +SM = [80, 90] # Sm kernels support up to +HEAD_DIMENSIONS = [64, 96, 128, 192, 256] +PAGEDKV = [False, True] +SPLIT = [False, True] +SOFTCAP = [False, True] +CAUSAL = [False, True] +PACKGQA = [False, True] +DETERM = [False, True] + +KERNEL_IMPL_TEMPLATE_FWD_SM90 = """#include "flash_fwd_launch_template.h" + +#ifndef FLASHMASK_V3_DISABLE_HDIM{HEAD_DIM} +template void run_mha_fwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +""" + +KERNEL_IMPL_TEMPLATE_FWD_SM8x = """#include "flash_fwd_launch_template.h" + +#ifndef FLASHMASK_V3_DISABLE_SM8x +#ifndef FLASHMASK_V3_DISABLE_HDIM{HEAD_DIM} +template void run_mha_fwd_<80, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif +""" + +KERNEL_IMPL_TEMPLATE_BWD_SM90 = """#include "flash_bwd_launch_template.h" + +#ifndef FLASHMASK_V3_DISABLE_HDIM{HEAD_DIM} +template<> +void run_mha_bwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {SOFTCAP}, {CAUSAL}, {DETERM}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ + run_mha_bwd_hdim{HEAD_DIM}<{ARCH}, {DTYPE}, {SOFTCAP}, {CAUSAL}, {DETERM}>(params, stream); +}} +#endif +""" + +KERNEL_IMPL_TEMPLATE_BWD_SM8x = """#include "flash_bwd_launch_template.h" + +#ifndef 
FLASHMASK_V3_DISABLE_SM8x +#ifndef FLASHMASK_V3_DISABLE_HDIM{HEAD_DIM} +template<> +void run_mha_bwd_<80, {DTYPE}, {HEAD_DIM}, {SOFTCAP}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ + run_mha_bwd_hdim{HEAD_DIM}<80, {DTYPE}, {SOFTCAP}>(params, stream); +}} +template<> +void run_mha_bwd_<86, {DTYPE}, {HEAD_DIM}, {SOFTCAP}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ + run_mha_bwd_hdim{HEAD_DIM}<86, {DTYPE}, {SOFTCAP}>(params, stream); +}} +#endif +#endif +""" + + + +@dataclass +class Kernel: + sm: int + dtype: str + head_dim: int + head_dim_v: int + split: bool + paged_kv: bool + softcap: bool + packgqa: bool + direction: str + causal: bool = False + determ: bool = False + + @property + def template(self) -> str: + if self.direction == "fwd": + if self.sm == 90: + # Always enable PackGQA for PagedKV or Split to reduce compilation + packgqa = self.packgqa or self.paged_kv or self.split + return KERNEL_IMPL_TEMPLATE_FWD_SM90.format( + ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], + HEAD_DIM=self.head_dim, HEAD_DIM_V=self.head_dim_v, + SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(), + SOFTCAP=str(self.softcap).lower(), PACKGQA=str(packgqa).lower() + ) + else: + # Always enable PackGQA for Sm8x to reduce compilation + return KERNEL_IMPL_TEMPLATE_FWD_SM8x.format( + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, HEAD_DIM_V=self.head_dim_v, + SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(), + SOFTCAP=str(self.softcap).lower(), PACKGQA=str(True).lower() + ) + elif self.direction == "bwd": + if self.sm == 90: + return KERNEL_IMPL_TEMPLATE_BWD_SM90.format( + ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, + SOFTCAP=str(self.softcap).lower(), + CAUSAL=str(self.causal).lower(), + DETERM=str(self.determ).lower() + ) + else: + return KERNEL_IMPL_TEMPLATE_BWD_SM8x.format( + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, + SOFTCAP=str(self.softcap).lower() + ) + + @property + def filename(self) -> str: + 
return f"flash_{self.direction}_hdim{self.head_dim}{f'_{self.head_dim_v}' if self.head_dim_v != self.head_dim else ''}_{self.dtype}{'_causal' if self.causal and self.direction == 'bwd' and self.sm == 90 else ''}{'_determ' if self.determ and self.direction == 'bwd' and self.sm == 90 else ''}{'_paged' if self.paged_kv else ''}{'_split' if self.split else ''}{'_softcap' if self.softcap else ''}{'_packgqa' if self.packgqa else ''}_sm{self.sm}.cu" + + +def get_all_kernels() -> List[Kernel]: + for dtype, head_dim, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM): + # We always enable PackGQA for Sm8x or PagedKV or Split + # so we should just pass in packgqa=False to avoid the `_packgqa` in the filename. + if packgqa and (sm < 90 or (sm >= 90 and (paged_kv or split))): + continue + if sm >= 90 or dtype in DTYPE_MAP_FWD_SM8x: + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") + if sm == 90 and head_dim == 192: + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=128, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") + if sm == 90 and head_dim == 64 and dtype in ["bf16", "fp16"]: + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=512, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") + for dtype, head_dim, softcap, causal, determ, sm in itertools.product(DTYPE_MAP_BWD.keys(), HEAD_DIMENSIONS, SOFTCAP, CAUSAL, DETERM, SM): + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction="bwd", causal=causal, determ=determ) + + +def batch_hdim(kernels_all) -> List[KERNEL_BATCH]: + for dtype, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM): + if sm < 90: 
+ continue + # Same hdim and hdimv + kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm and k.head_dim == k.head_dim_v] + if len(kernels) > 0: + filename = f"flash_fwd_hdimall_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}{'_softcap' if softcap else ''}{'_packgqa' if packgqa else ''}_sm{sm}.cu" + template = "\n".join([f"#include \"{k.filename}\"" for k in kernels]) + yield KERNEL_BATCH(template, filename) + # Different hdim and hdimv + kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm and k.head_dim != k.head_dim_v] + if len(kernels) > 0: + filename = f"flash_fwd_hdimdiff_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}{'_softcap' if softcap else ''}{'_packgqa' if packgqa else ''}_sm{sm}.cu" + template = "\n".join([f"#include \"{k.filename}\"" for k in kernels]) + yield KERNEL_BATCH(template, filename) + + +def batch_softcap(kernels_all) -> List[KERNEL_BATCH]: + for dtype, head_dim, split, paged_kv, packgqa, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SPLIT, PAGEDKV, PACKGQA, SM): + if sm >= 90: + continue + kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.head_dim == head_dim and k.split == split and k.paged_kv == paged_kv and k.packgqa == packgqa and k.sm == sm] + if len(kernels) > 0: + filename = f"flash_fwd_hdim{head_dim}_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}_softcapall{'_packgqa' if packgqa else ''}_sm{sm}.cu" + template = "\n".join([f"#include \"{k.filename}\"" for k in kernels]) + yield KERNEL_BATCH(template, filename) + + # Bwd + for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM): + if sm < 90: + continue + kernels = [k for k in 
kernels_all if k.direction == "bwd" and k.dtype == dtype and k.head_dim == head_dim and k.sm == sm] + if len(kernels) > 0: + filename = f"flash_bwd_hdim{head_dim}_{dtype}_softcapall_sm{sm}.cu" + template = "\n".join([f"#include \"{k.filename}\"" for k in kernels]) + yield KERNEL_BATCH(template, filename) + + +def write_kernel(kernel: Kernel, autogen_dir: Path) -> None: + prelude = """// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"\n +""" + (autogen_dir / kernel.filename).write_text(prelude + kernel.template) + + +def main(output_dir: Optional[str]) -> None: + output_dir = Path(output_dir) if output_dir is not None else Path(__file__).parent + output_dir.mkdir(parents=True, exist_ok=True) + kernels_all = list(get_all_kernels()) + for kernel in kernels_all: + write_kernel(kernel, output_dir) + for kernel in batch_hdim(kernels_all): + write_kernel(kernel, output_dir) + for kernel in batch_softcap(kernels_all): + write_kernel(kernel, output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate_kernels", + description="Generate the flash_attention kernels template instantiations", + ) + # Set an optional output directory + parser.add_argument( + "-o", + "--output_dir", + default="instantiations", + required=False, + help="Where to generate the kernels " + " will default to the current directory ", + ) + args = parser.parse_args() + main(args.output_dir) diff --git a/flashmask/flash_mask/flashmask_attention_v3/heuristics.h b/flashmask/flash_mask/flashmask_attention_v3/heuristics.h new file mode 100644 index 00000000000..66dcc25439a --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/heuristics.h @@ -0,0 +1,73 @@ +/****************************************************************************** + * Copyright (c) 2024, 
/* Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar,
 * Pradeep Ramani, Tri Dao.
 *
 * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/

#pragma once

// NOTE(review): the include target was garbled (bare `#include`) in the dump;
// restored to <vector>, which num_splits_heuristic below requires.
#include <vector>

// Decide whether GQA query heads should be packed along the M (query) dimension.
// Returns true when packing improves kBlockM tile utilization by more than ~10%,
// or unconditionally for varlen batches (only max_seqlen_q is known there).
inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) {
    // If varlen, we don't actually know seqlen_q but only max_seqlen_q.
    if (varlen_q) return true;
    // Heuristic: PackGQA is a bit slower but can help if seqlen_q is small
    // or not near a multiple of kBlockM.
    auto round_up = [](int a, int b) { return (a + b - 1) / b * b; };
    // Fraction of each M-tile that holds real rows, without and with packing.
    float nopack_gqa_efficiency = float(seqlen_q) / float(round_up(seqlen_q, blockM));
    float pack_gqa_efficiency = float(seqlen_q * qhead_per_khead) / float(round_up(seqlen_q * qhead_per_khead, blockM));
    return nopack_gqa_efficiency < 0.9 * pack_gqa_efficiency;
}

// Find the number of splits that maximizes the occupancy. For example, if we have
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
// splits as that would incur more HBM reads/writes.
// So we find the best efficiency, then find the smallest number of splits that gets 85%
// of the best efficiency.
+inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) { + // If we have enough to almost fill the SMs, then just use 1 split + // However, in the case of super long seqlen where each head of KV doesn't even fit into + // L2 (we assume that L2 size is 50MB), we want to split. + if (total_mblocks >= 0.8f * num_SMs) { + int const size_l2 = 50 * 1024 * 1024; + // Only split if there are enough queries to go over the KV at least twice + // Don't split if causal + if (size_one_kv_head > size_l2 && num_m_blocks >= num_SMs * 2 && !is_causal_or_local) { + return std::min((size_one_kv_head + size_l2 - 1) / size_l2, max_splits); + } else { + return 1; + } + } + // If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512. + if (num_n_blocks <= 4) { return 1; } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + float n_waves = float(total_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if (eff > max_efficiency) { max_efficiency = eff; } + efficiency.push_back(eff); + } + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} diff --git a/flashmask/flash_mask/flashmask_attention_v3/interface.py b/flashmask/flash_mask/flashmask_attention_v3/interface.py new file mode 100644 index 00000000000..04592d045f0 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/interface.py @@ -0,0 +1,884 @@ +import paddle +from paddle import Tensor +from typing import Optional + +from paddle import _C_ops + +def 
flashmask_attention( + query: Tensor, + key: Tensor, + value: Tensor, + startend_row_indices: Tensor | None = None, + *, + dropout: float = 0.0, + causal: bool = False, + window_size: int | tuple | None = None, + return_softmax_lse: bool = False, + return_seed_offset: bool = False, + fixed_seed_offset: Tensor | None = None, + rng_name: str = "", + training: bool = True, + name: str | None = None, + softmax_scale: float | None = None, + block_mask: Tensor | None = None, +): + r""" + FlashMask: Official Implementation + + This module provides the official implementation of the FlashMask algorithm as described in the paper. For more details, please refer to the paper available at: https://arxiv.org/abs/2410.01359. + + The core equation utilized in FlashMask is as follows: + + .. math:: + + \text{result} = \text{softmax}\left(\frac{Q \cdot K^T}{\sqrt{d}} + M\right) \cdot V + + In this equation: + + - ``Q``, ``K``, and ``V`` are the input tensors to the attention module. + - All these tensors share the same dimensions. + - ``d`` denotes the size of the last dimension of these tensors. + - ``M`` represents the column-wise sparse mask introduced by FlashMask. + + Args: + query (Tensor): The query tensor in the attention module. + A 4-D tensor with shape [batch_size, q_seq_len, num_heads, head_dim]. + The dtype can be float16 or bfloat16. + key (Tensor): The key tensor in the attention module. + A 4-D tensor with shape [batch_size, k_seq_len, k_num_heads, head_dim]. + The dtype can be float16 or bfloat16. + value (Tensor): The value tensor in the attention module. + A 4-D tensor with shape [batch_size, k_seq_len, k_num_heads, head_dim]. + The dtype can be float16 or bfloat16. + startend_row_indices(Tensor): + A column-wise sparse attention mask row indices tensor. + A 4-D tensor with shape [batch_size, k_num_heads, k_seq_len, {1, 2, 4}]. + The dtype must be int32. k_num_heads can be 1 or the same as key's num_heads. 
When num_heads is 1, it will be broadcast to match key's num_heads. + Depending on the value of the causal parameter, startend_row_indices can take different shapes and meanings. + + - When `causal=True` and the shape is [batch_size, k_num_heads, k_seq_len, 1], + indicating unidirectional attention. The value represents the starting row index of the left + lower triangular mask in the dense mask. The value startend_row_indices[..., 0] indicates that elements in the lower left triangle of the attention score matrix starting from the startend_row_indices[..., 0]-th row downwards (inclusive) will be masked. + - When `causal=True` and the shape is [batch_size, k_num_heads, k_seq_len, 2], + indicating unidirectional attention. The values represent the starting and ending row indices of + the left lower triangular mask in the dense mask. The values startend_row_indices[..., 0:2] in startend_row_indices indicate that elements in the lower left triangle of the attention score matrix starting from the startend_row_indices[..., 0]-th row downwards (inclusive) but above the startend_row_indices[..., 1]-th row (exclusive) will be masked. + - When `causal=False` and the shape is [batch_size, k_num_heads, k_seq_len, 2], + indicating bidirectional attention. The values represent the starting row index of the left + lower triangular mask and the ending row index of the right upper triangular mask in the dense mask. The values startend_row_indices[..., 0:2] in startend_row_indices indicate that elements in the lower left triangle of the attention score matrix starting from the startend_row_indices[..., 0]-th row downwards (inclusive) will be masked, and elements in the upper right triangle starting from the startend_row_indices[..., 1]-th row upwards (exclusive) will be masked. + - When `causal=False` and the shape is [batch_size, k_num_heads, k_seq_len, 4] , + indicating bidirectional attention. 
The values represent the start and end row indices of the + left lower triangular mask and the start and end row indices of the right upper triangular mask in the dense mask. The values startend_row_indices[..., 0:4] in startend_row_indices indicate that elements in the lower left triangle of the attention score matrix starting from the startend_row_indices[..., 0]-th row downwards (inclusive) but above the startend_row_indices[..., 1] row (exclusive) will be masked, and elements in the upper right triangle starting from the startend_row_indices[..., 2]-th row downwards (inclusive) but above the startend_row_indices[..., 3] row (exclusive) will be masked. + + dropout (float): The dropout ratio. Default is 0.0. + causal (bool): Whether to enable causal mode. Default is False. + window_size (int|tuple, optional): Indicates the window size of sliding window local attention. + If causal mode is enabled, Query at position i will only attend to keys between [i - window_size, i] or [i - window_size[0], i]. + If causal mode is disabled, Query at position i will only attend to keys between [i - window_size, i + window_size] or [i - window_size[0], i + window_size[1]]. + return_softmax_lse (bool): Whether to return the log-sum-exp of the softmax. Default is False. + return_seed_offset (bool): Whether to return the random seed offset. Default is False. + fixed_seed_offset(Tensor, optional): With fixed seed, offset for dropout mask. + rng_name (str): The name to select Generator. + training (bool): Whether the module is in training mode. Default is True. + name (str, optional): Name of the operation. Default is None. Normally, users do not need to set this property. + For more information, refer to :ref:`api_guide_Name` . + block_mask (tensor, optional): + A 4-D integer mask tensor indicating whether each block in the attention matrix should be kept or masked. Must be used together with flashmask. 
+ The shape should be [batch_size, num_heads, blocklen_q, blocklen_k], where: + + blocklen_q = ceil(seqlen_q / 128), i.e., block_mask.shape[2] must be (seqlen_q + 127) // 128 + blocklen_k = ceil(seqlen_k / 128), i.e., block_mask.shape[3] must be (seqlen_k + 127) // 128 + block_mask.shape[1] (number of heads) must match the num_heads dimension of the flashmask + Both seqlen_q and seqlen_k must be less than or equal to 128 * 1024 + The dtype should be int32, and each element should be either 0 or 1. + A value of 1 indicates that the corresponding block is kept (not masked), while 0 means the block is masked. + + Usage Notes: + + Only supported when blockdim_q = blockdim_k = 128 now. + Only supported when headdim = 128 now. + This argument must be provided together with flashmask. + The mask will be applied at the block level: each [i, j] position in block_mask controls whether the corresponding [128 x 128] block in the attention matrix is masked. + Any mismatch in expected shape or head dimension will raise an error. + + + Returns: + Tensor. The computed attention result with the same shape as the input `query`. + + Warning: + This API only supports inputs with dtype float16 and bfloat16. + + Hint: + This API supports GQA. + + Examples: + ..
code-block:: python + + >>> # doctest: +SKIP('flash_attn need A100 compile') + >>> import paddle + >>> paddle.seed(2023) + >>> q = paddle.rand((1, 10, 2, 32),dtype="bfloat16") # shape: [batch_size, seq_len, num_heads, head_dim] + >>> k = paddle.rand((1, 10, 2, 32),dtype="bfloat16") # shape: [batch_size, seq_len, num_heads, head_dim] + >>> v = paddle.rand((1, 10, 2, 32),dtype="bfloat16") # shape: [batch_size, seq_len, num_heads, head_dim] + >>> startend_row_indices = paddle.to_tensor([8]*10 + [5]*10, dtype="int32").reshape([1, 2, 10, 1]) + >>> output = paddle.nn.functional.flashmask_attention(q, k, v, startend_row_indices, causal=True) + >>> print(output) + Tensor(shape=[1, 10, 2, 32], dtype=bfloat16, place=Place(gpu:0), stop_gradient=True, + [[[[0.82421875, 0.27539062, 0.80859375, 0.98046875, 0.00251770, + 0.41992188, 0.17285156, 0.11767578, 0.42773438, 0.31250000, + 0.34570312, 0.70312500, 0.29296875, 0.44531250, 0.51562500, + 0.96093750, 0.85546875, 0.15625000, 0.34765625, 0.98437500, + 0.96484375, 0.45312500, 0.33593750, 0.56640625, 0.07714844, + 0.43750000, 0.83984375, 0.66796875, 0.93750000, 0.24804688, + 0.51171875, 0.55468750], + [0.54687500, 0.74609375, 0.43164062, 0.32421875, 0.10693359, + 0.37304688, 0.53906250, 0.17187500, 0.57421875, 0.75000000, + 0.13378906, 0.57031250, 0.19531250, 0.01403809, 0.29101562, + 0.14257812, 0.07568359, 0.88671875, 0.75390625, 0.17089844, + 0.87109375, 0.93359375, 0.89843750, 0.58203125, 0.75390625, + 0.27539062, 0.67968750, 0.24804688, 0.57812500, 0.67578125, + 0.92578125, 0.98046875]], + + [[0.59765625, 0.62890625, 0.62109375, 0.75781250, 0.03295898, + 0.64062500, 0.27929688, 0.20800781, 0.72265625, 0.52343750, + 0.53125000, 0.61718750, 0.57421875, 0.56640625, 0.65625000, + 0.48242188, 0.68359375, 0.42968750, 0.26562500, 0.86718750, + 0.83203125, 0.40820312, 0.38281250, 0.59765625, 0.43945312, + 0.22851562, 0.86328125, 0.51562500, 0.89453125, 0.62500000, + 0.50390625, 0.67968750], + [0.34765625, 0.61328125, 0.58593750, 
0.60156250, 0.43164062, + 0.41601562, 0.71093750, 0.59765625, 0.53515625, 0.78125000, + 0.13867188, 0.30664062, 0.48828125, 0.04394531, 0.24316406, + 0.18847656, 0.10644531, 0.71093750, 0.69140625, 0.35937500, + 0.44531250, 0.81640625, 0.44140625, 0.64062500, 0.81640625, + 0.61328125, 0.72265625, 0.53125000, 0.49414062, 0.59765625, + 0.54296875, 0.61328125]], + + [[0.65234375, 0.47656250, 0.71875000, 0.64843750, 0.23828125, + 0.61328125, 0.29101562, 0.26562500, 0.54296875, 0.60937500, + 0.67187500, 0.67578125, 0.64062500, 0.41406250, 0.47656250, + 0.40820312, 0.66406250, 0.39453125, 0.39453125, 0.62109375, + 0.58593750, 0.31054688, 0.31835938, 0.45703125, 0.52343750, + 0.43164062, 0.64453125, 0.49804688, 0.82812500, 0.48242188, + 0.38476562, 0.59375000], + [0.44921875, 0.62109375, 0.50390625, 0.51562500, 0.51953125, + 0.57812500, 0.78515625, 0.73437500, 0.60546875, 0.55078125, + 0.30273438, 0.23339844, 0.60546875, 0.33007812, 0.23242188, + 0.30468750, 0.34570312, 0.70703125, 0.72656250, 0.58593750, + 0.40234375, 0.62109375, 0.62109375, 0.69531250, 0.66796875, + 0.51562500, 0.45898438, 0.67968750, 0.48828125, 0.50000000, + 0.54687500, 0.71875000]], + + [[0.67578125, 0.50000000, 0.58203125, 0.62109375, 0.43554688, + 0.69531250, 0.30273438, 0.24023438, 0.57812500, 0.63671875, + 0.51171875, 0.52734375, 0.60546875, 0.45507812, 0.42382812, + 0.46093750, 0.55859375, 0.34960938, 0.39453125, 0.57031250, + 0.55078125, 0.47265625, 0.24609375, 0.51953125, 0.46093750, + 0.49218750, 0.49609375, 0.60156250, 0.76953125, 0.57421875, + 0.40429688, 0.57031250], + [0.45703125, 0.71093750, 0.58984375, 0.43164062, 0.54296875, + 0.57031250, 0.72265625, 0.61328125, 0.64453125, 0.50781250, + 0.28125000, 0.19531250, 0.60546875, 0.40625000, 0.18554688, + 0.33203125, 0.40039062, 0.58593750, 0.79687500, 0.45507812, + 0.32812500, 0.58203125, 0.70703125, 0.64453125, 0.53906250, + 0.57421875, 0.48828125, 0.53515625, 0.49804688, 0.50000000, + 0.48437500, 0.55468750]], + + [[0.64453125, 0.43164062, 
0.54687500, 0.53125000, 0.42187500, + 0.71484375, 0.30273438, 0.21484375, 0.50390625, 0.69531250, + 0.58203125, 0.51562500, 0.61328125, 0.41992188, 0.40039062, + 0.46679688, 0.58984375, 0.39062500, 0.41992188, 0.49023438, + 0.47851562, 0.47070312, 0.30078125, 0.50390625, 0.47656250, + 0.44921875, 0.43164062, 0.63671875, 0.78125000, 0.60156250, + 0.48242188, 0.58203125], + [0.52343750, 0.69921875, 0.58984375, 0.35156250, 0.49218750, + 0.58593750, 0.71093750, 0.59375000, 0.66406250, 0.49414062, + 0.24023438, 0.18554688, 0.66796875, 0.50000000, 0.23144531, + 0.29882812, 0.49414062, 0.57031250, 0.70312500, 0.42773438, + 0.35351562, 0.47460938, 0.73437500, 0.53125000, 0.47070312, + 0.49609375, 0.50000000, 0.55078125, 0.50000000, 0.45898438, + 0.45703125, 0.61328125]], + + [[0.63671875, 0.41210938, 0.52734375, 0.56640625, 0.44531250, + 0.64843750, 0.37890625, 0.31250000, 0.56640625, 0.62890625, + 0.53125000, 0.51562500, 0.54296875, 0.50781250, 0.35546875, + 0.41601562, 0.55468750, 0.36914062, 0.35937500, 0.45117188, + 0.46875000, 0.49609375, 0.28710938, 0.50000000, 0.49609375, + 0.50000000, 0.51562500, 0.57031250, 0.77734375, 0.62109375, + 0.43164062, 0.50781250], + [0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. ]], + + [[0.62109375, 0.44531250, 0.46875000, 0.61328125, 0.39062500, + 0.60156250, 0.41015625, 0.28710938, 0.58984375, 0.67968750, + 0.55859375, 0.48632812, 0.51562500, 0.42382812, 0.37695312, + 0.46679688, 0.54687500, 0.44921875, 0.33789062, 0.36328125, + 0.49023438, 0.44140625, 0.25000000, 0.45312500, 0.43945312, + 0.45507812, 0.46679688, 0.57812500, 0.65625000, 0.64062500, + 0.42382812, 0.57031250], + [0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. 
]], + + [[0.62500000, 0.47070312, 0.51562500, 0.61328125, 0.36718750, + 0.66406250, 0.37890625, 0.28320312, 0.65625000, 0.66015625, + 0.48632812, 0.53906250, 0.46679688, 0.47851562, 0.43359375, + 0.45703125, 0.47070312, 0.39843750, 0.32617188, 0.37304688, + 0.49023438, 0.50390625, 0.27148438, 0.46679688, 0.37695312, + 0.49023438, 0.47265625, 0.58593750, 0.64453125, 0.60156250, + 0.38476562, 0.62109375], + [0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. ]], + + [[0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. ], + [0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. ]], + + [[0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. ], + [0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. , 0. , + 0. , 0. ]]]]) + >>> # doctest: -SKIP + + + To convert FlashMask's `startend_row_indices` to `dense_mask`, use the code below: + + .. code-block:: python + + >>> import paddle + >>> import numpy as np + >>> def flashmask_to_densemask(startend_row_indices, dtype, causal=True): + ... if startend_row_indices is None: + ... return None + ... bz, num_head, seq_len, bound_num = startend_row_indices.shape + ... m = paddle.zeros((bz, num_head, seq_len, seq_len), dtype=dtype) + ... has_end = (causal and bound_num == 2) or ((not causal) and bound_num == 4) + ... for bi in range(bz): + ... for hi in range(num_head): + ... for j in range(seq_len): + ... downstart = startend_row_indices[bi, hi, j, 0] + ... 
if has_end: + ... downend = startend_row_indices[bi, hi, j, 1] + ... m[bi, hi, downstart:downend, j] = -np.inf + ... else: + ... m[bi, hi, downstart:, j] = -np.inf + ... if causal: + ... m[bi, hi, :j, j] = -np.inf + ... else: + ... if has_end: + ... upstart = startend_row_indices[bi, hi, j, 2] + ... upend = startend_row_indices[bi, hi, j, 3] + ... m[bi, hi, upstart:upend, j] = -np.inf + ... else: + ... upend = startend_row_indices[bi, hi, j, 1] + ... m[bi, hi, :upend, j] = -np.inf + ... return m + + For `Causal Mask`, where `causal=True`, the values of `startend_row_indices` are as follows: + + .. code-block:: python + + [[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import paddle + >>> startend_row_indices = paddle.to_tensor([8]*10, dtype="int32").reshape([1, 1, 10, 1]) + >>> print(startend_row_indices) + Tensor(shape=[1, 1, 10, 1], dtype=int32, place=Place(gpu:0), stop_gradient=True, + [[[[8], + [8], + [8], + [8], + [8], + [8], + [8], + [8], + [8], + [8]]]]) + >>> # doctest: -SKIP + + + For `Sliding Window Mask`, where `causal=True`, the values of `startend_row_indices` are as follows: + + .. 
code-block:: python + + [[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import paddle + >>> startend_row_indices = paddle.to_tensor([3, 4, 5, 6, 7, 8, 9, 10, 10, 10], dtype="int32").reshape([1, 1, 10, 1]) + >>> print(startend_row_indices) + Tensor(shape=[1, 1, 10, 1], dtype=int32, place=Place(gpu:0), stop_gradient=True, + [[[[3 ], + [4 ], + [5 ], + [6 ], + [7 ], + [8 ], + [9 ], + [10], + [10], + [10]]]]) + >>> # doctest: -SKIP + + For `Causal Document Mask`, where `causal=True`, the values of `startend_row_indices` are as follows: + + .. code-block:: python + + [[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import paddle + >>> startend_row_indices = paddle.to_tensor([4, 4, 4, 4, 7, 7, 7, 10, 10, 10], dtype="int32").reshape([1, 1, 10, 1]) + >>> print(startend_row_indices) + Tensor(shape=[1, 1, 10, 1], dtype=int32, place=Place(gpu:0), stop_gradient=True, + [[[[4 ], + [4 ], + [4 ], + [4 ], + [7 ], + [7 ], + [7 ], + [10], + [10], + [10]]]]) + >>> # doctest: -SKIP + + For `Document Mask`, where `causal=False`, the values of `startend_row_indices` are as follows: + + .. 
code-block:: python + + [[[[1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import paddle + >>> LTS = paddle.to_tensor([4, 4, 4, 4, 7, 7, 7, 10, 10, 10], dtype="int32").reshape([1, 1, 10, 1]) + >>> UTE = paddle.to_tensor([0, 0, 0, 0, 4, 4, 4, 7, 7, 7], dtype="int32").reshape([1, 1, 10, 1]) + >>> startend_row_indices = paddle.concat([LTS, UTE], axis=-1) + >>> print(startend_row_indices) + Tensor(shape=[1, 1, 10, 2], dtype=int32, place=Place(gpu:0), stop_gradient=True, + [[[[4 , 0 ], + [4 , 0 ], + [4 , 0 ], + [4 , 0 ], + [7 , 4 ], + [7 , 4 ], + [7 , 4 ], + [10, 7 ], + [10, 7 ], + [10, 7 ]]]]) + >>> # doctest: -SKIP + + For `Share Question Mask`, where `causal=True`, the values of `startend_row_indices` are as follows: + + .. code-block:: python + + [[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 1, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 1, 1, 0], + [1, 1, 1, 1, 0, 0, 0, 1, 1, 1]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import paddle + >>> startend_row_indices = paddle.to_tensor([10, 10, 10, 10, 7, 7, 7, 10, 10, 10], dtype="int32").reshape([1, 1, 10, 1]) + >>> print(startend_row_indices) + Tensor(shape=[1, 1, 10, 1], dtype=int32, place=Place(gpu:0), stop_gradient=True, + [[[[10], + [10], + [10], + [10], + [7 ], + [7 ], + [7 ], + [10], + [10], + [10]]]]) + >>> # doctest: -SKIP + + For `Global + Sliding Window Mask`, where `causal=False`, the values of `startend_row_indices` are as follows: + + .. 
code-block:: python + + >>> # doctest: +SKIP('Only example') + + [[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 0, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 1, 1, 1, 0, 0, 0], + [1, 1, 0, 0, 0, 1, 1, 1, 0, 0], + [1, 1, 0, 0, 0, 0, 1, 1, 1, 0], + [1, 1, 0, 0, 0, 0, 0, 1, 1, 1], + [1, 1, 0, 0, 0, 0, 0, 0, 1, 1]]]]) + + >>> import paddle + >>> LTS = paddle.to_tensor([10, 10, 4, 5, 6, 7, 8, 9, 10, 10], dtype="int32").reshape([1, 1, 10, 1]) + >>> LTE = paddle.to_tensor([10, 10, 10, 10, 10, 10, 10, 10, 10, 10], dtype="int32").reshape([1, 1, 10, 1]) + >>> UTS = paddle.to_tensor([0, 0, 0, 0, 2, 2, 2, 2, 2, 2], dtype="int32").reshape([1, 1, 10, 1]) + >>> UTE = paddle.to_tensor([0, 0, 0, 0, 3, 4, 5, 6, 7, 8], dtype="int32").reshape([1, 1, 10, 1]) + >>> startend_row_indices = paddle.concat([LTS, LTE, UTS, UTE], axis=-1) + >>> print(startend_row_indices) + Tensor(shape=[1, 1, 10, 4], dtype=int32, place=Place(gpu:0), stop_gradient=True, + [[[[10, 10, 0 , 0 ], + [10, 10, 0 , 0 ], + [4 , 10, 0 , 0 ], + [5 , 10, 0 , 0 ], + [6 , 10, 2 , 3 ], + [7 , 10, 2 , 4 ], + [8 , 10, 2 , 5 ], + [9 , 10, 2 , 6 ], + [10, 10, 2 , 7 ], + [10, 10, 2 , 8 ]]]]) + >>> # doctest: -SKIP + + For `Causal Blockwise Mask`, where `causal=True`, the values of `startend_row_indices` are as follows: + + .. 
code-block:: python + + [[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import paddle + >>> LTS = paddle.to_tensor([4, 4, 4, 4, 10, 10, 10, 10, 10, 10], dtype="int32").reshape([1, 1, 10, 1]) + >>> LTE = paddle.to_tensor([7, 7, 7, 7, 10, 10, 10, 10, 10, 10], dtype="int32").reshape([1, 1, 10, 1]) + >>> startend_row_indices = paddle.concat([LTS, LTE], axis=-1) + >>> print(startend_row_indices) + Tensor(shape=[1, 1, 10, 2], dtype=int32, place=Place(gpu:0), stop_gradient=True, + [[[[4 , 7 ], + [4 , 7 ], + [4 , 7 ], + [4 , 7 ], + [10, 10], + [10, 10], + [10, 10], + [10, 10], + [10, 10], + [10, 10]]]]) + >>> # doctest: -SKIP + + For `Prefix LM Document Mask`, where `causal=False`, the values of `startend_row_indices` are as follows: + + .. 
code-block:: python + + [[[[1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import paddle + >>> LTS = paddle.to_tensor([3, 3, 3, 5, 5, 10, 10, 10, 10, 10], dtype="int32").reshape([1, 1, 10, 1]) + >>> UTE = paddle.to_tensor([0, 0, 2, 3, 3, 5, 5, 7, 8, 9], dtype="int32").reshape([1, 1, 10, 1]) + >>> startend_row_indices = paddle.concat([LTS, UTE], axis=-1) + >>> print(startend_row_indices) + Tensor(shape=[1, 1, 10, 2], dtype=int32, place=Place(gpu:0), stop_gradient=True, + [[[[3 , 0 ], + [3 , 0 ], + [3 , 2 ], + [5 , 3 ], + [5 , 3 ], + [10, 5 ], + [10, 5 ], + [10, 7 ], + [10, 8 ], + [10, 9 ]]]]) + >>> # doctest: -SKIP + + For `Prefix LM Causal Mask`, where `causal=False`, the values of `startend_row_indices` are as follows: + + .. 
code-block:: python + + [[[[1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import paddle + >>> LTS = paddle.to_tensor([10, 10, 10, 10, 10, 10, 10, 10, 10, 10], dtype="int32").reshape([1, 1, 10, 1]) + >>> UTE = paddle.to_tensor([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype="int32").reshape([1, 1, 10, 1]) + >>> startend_row_indices = paddle.concat([LTS, UTE], axis=-1) + >>> print(startend_row_indices) + Tensor(shape=[1, 1, 10, 2], dtype=int32, place=Place(gpu:0), stop_gradient=True, + [[[[10, 0 ], + [10, 0 ], + [10, 0 ], + [10, 0 ], + [10, 0 ], + [10, 5 ], + [10, 6 ], + [10, 7 ], + [10, 8 ], + [10, 9 ]]]]) + >>> # doctest: -SKIP + + For `QK-sparse Mask`, where `causal=True`, the values of `startend_row_indices` are as follows: + + ..
code-block:: python + + [[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]]) + + >>> # doctest: +SKIP('Only example') + >>> import paddle + >>> LTS = paddle.to_tensor([10, 10, 2, 3, 4, 5, 6, 7, 10, 10], dtype="int32").reshape([1, 1, 10, 1]) + >>> LTE = paddle.to_tensor([10, 10, 5, 5, 5, 5, 8, 8, 10, 10], dtype="int32").reshape([1, 1, 10, 1]) + >>> startend_row_indices = paddle.concat([LTS, LTE], axis=-1) + >>> print(startend_row_indices) + Tensor(shape=[1, 1, 10, 2], dtype=int32, place=Place(gpu:0), stop_gradient=True, + [[[[10, 10], + [10, 10], + [2 , 5 ], + [3 , 5 ], + [4 , 5 ], + [5 , 5 ], + [6 , 8 ], + [7 , 8 ], + [10, 10], + [10, 10]]]]) + + >>> # doctest: -SKIP + """ + if window_size is not None: + if isinstance(window_size, int): + window_size = (window_size, window_size) + sq = query.shape[1] + bsz = query.shape[0] + assert startend_row_indices is None, ( + "can't use window_size with startend_row_indices" + ) + if causal: + startend_row_indices = paddle.arange( + window_size[0] + 1, sq + window_size[0] + 1, dtype="int32" + ).reshape((1, 1, sq, 1)) + startend_row_indices = paddle.clip( + startend_row_indices, max=sq + ).repeat_interleave(bsz, 0) + + else: + startend_row_indices = paddle.empty((1, 1, sq, 2), dtype="int32") + startend_row_indices[0, 0, :, 0] = paddle.arange( + window_size[0] + 1, sq + window_size[0] + 1, dtype="int32" + ) + startend_row_indices[0, 0, :, 1] = paddle.arange( + -window_size[1], sq - window_size[1], dtype="int32" + ) + startend_row_indices = paddle.clip( + startend_row_indices, min=0, max=sq + ).repeat_interleave(bsz, 0) + + if block_mask is not None: + # xhy: can set a full startend_row_indices for block_mask_attn when using block_mask_attn? 
+ assert startend_row_indices is not None, ( + "must provide startend_row_indices when using block_mask_attn" + ) + + if startend_row_indices is None: + ( + out, + result_softmax, + result_softmax_lse, + result_seed_offset, + ) = _C_ops.flash_attn( + query, + key, + value, + fixed_seed_offset, + None, + dropout, + causal, + False, + not training, + rng_name, + ) + + else: + assert startend_row_indices.dtype == paddle.int32, ( + f"startend_row_indices.dtype must be paddle.int32, but got {startend_row_indices.dtype}" + ) + assert len(startend_row_indices.shape) == 4, ( + f"startend_row_indices rank must be 4,but got {startend_row_indices.shape}" + ) + + assert startend_row_indices.shape[0] == key.shape[0], ( + f"startend_row_indices.shape[0] must be equal to batch_size, but got {startend_row_indices.shape[0]} and {key.shape[0]}" + ) + + assert startend_row_indices.shape[2] == key.shape[1], ( + f"startend_row_indices.shape[2] must be equal to seqlen_k, but got {startend_row_indices.shape[2]} and {key.shape[2]}" + ) + assert startend_row_indices.shape[1] in [ + 1, + query.shape[2], + key.shape[2], + ], ( + "startend_row_indices head_num must be equal to 1(broadcast) or head_num_q or head_num_k." 
+ ) + + if block_mask is not None: + assert block_mask.dtype == paddle.int32, ( + f"block_mask.dtype must be paddle.int32, but got {block_mask.dtype}" + ) + + assert block_mask.shape[0] == key.shape[0], ( + f"block_mask.shape[0] must be equal to batch_size, but got {block_mask.shape[0]} and {key.shape[0]}" + ) + + assert block_mask.shape[1] == startend_row_indices.shape[1], ( + f"block_mask.shape[1] must be equal to startend_row_indices.shape[1], but got {block_mask.shape[1]} and {key.shape[2]}" + ) + + assert block_mask.shape[2] == (query.shape[1] + 127) // 128, ( + "block_size must be 128 when using block_mask_attn" + ) + + assert block_mask.shape[3] == (key.shape[1] + 127) // 128, ( + "block_size must be 128 when using block_mask_attn" + ) + + assert key.shape[3] == 128, ( + "headdim must be 128 when using block_mask_attn" + ) + + if causal: + if startend_row_indices.shape[-1] == 1: + has_end = False + elif startend_row_indices.shape[-1] == 2: + has_end = True + else: + raise ValueError( + f"Invalid shape of startend_row_indices, when causal is True, the last dimension should be either 1 or 2 but got {startend_row_indices.shape[-1]}" + ) + else: + if startend_row_indices.shape[-1] == 2: + has_end = False + elif startend_row_indices.shape[-1] == 4: + has_end = True + else: + raise ValueError( + f"Invalid shape of startend_row_indices, when causal is False, the last dimension should be either 2 or 4 but got {startend_row_indices.shape[-1]}" + ) + + if ( + "xpu" not in paddle.get_device() + and paddle.get_flags(["FLAGS_cudnn_deterministic"])[ + "FLAGS_cudnn_deterministic" + ] + ): + assert block_mask is None, ( + " blockmask attention no supports deterministic now ." 
+ ) + + if "xpu" in paddle.get_device(): + fa_version = 2 + elif ( + paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])[ + "FLAGS_flash_attn_version" + ] + == 3 + and paddle.base.framework.get_flags(["FLAGS_cudnn_deterministic"])[ + "FLAGS_cudnn_deterministic" + ] + and query.shape[3] > 128 + ): + fa_version = 2 + else: + fa_version = paddle.base.framework.get_flags( + ["FLAGS_flash_attn_version"] + )["FLAGS_flash_attn_version"] + + if fa_version == 2: + assert softmax_scale is None, ( + "flashmask_attention does not support setting softmax_scale, use flashmask_attention_v3 instead" + ) + + assert block_mask is None, ( + " blockmask attention only supports sm >= 90 now." + ) + + ( + out, + result_softmax, + result_softmax_lse, + result_seed_offset, + ) = _C_ops.flashmask_attention( + query, + key, + value, + startend_row_indices, + fixed_seed_offset, + dropout, + causal, + False, + not training, + rng_name, + ) + + elif fa_version == 3: + assert dropout == 0.0, ( + "flashmask_attention_v3 does not support dropout" + ) + assert not return_seed_offset, ( + "flashmask_attention_v3 does not support return seed_offset" + ) + assert fixed_seed_offset is None, ( + "flashmask_attention_v3 does not support setting seed_offset" + ) + assert rng_name == "", ( + "flashmask_attention_v3 does not support setting rng_name" + ) + assert training, ( + "flashmask_attention_v3 does not support setting training to False" + ) + + assert name is None, ( + "flashmask_attention_v3 does not support setting name" + ) + + if softmax_scale is None: + softmax_scale = query.shape[-1] ** (-0.5) + + # ( + # out, + # result_softmax_lse, + # ) = _C_ops.flashmask_attention_v3( + # query, + # key, + # value, + # startend_row_indices, + # block_mask, + # softmax_scale, + # causal, + # ) + + ( + out, + result_softmax_lse, + ) = _C_ops._run_custom_op( + "flashmask_attention_v3", + query, + key, + value, + startend_row_indices, + block_mask, + softmax_scale, + causal, + ) + + else: + raise 
ValueError(f"Invalid flash attention version: {fa_version}") + + outputs = [out] + if return_softmax_lse: + outputs += [result_softmax_lse] + if return_seed_offset: + outputs += [result_seed_offset] + if len(outputs) == 1: + return outputs[0] + else: + return outputs diff --git a/flashmask/flash_mask/flashmask_attention_v3/mainloop_bwd_sm80.hpp b/flashmask/flash_mask/flashmask_attention_v3/mainloop_bwd_sm80.hpp new file mode 100644 index 00000000000..d78e69c7d30 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/mainloop_bwd_sm80.hpp @@ -0,0 +1,915 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include + +#include "cute/tensor.hpp" + +#include "seqlen.h" +#include "mask.h" +#include "mask.h" +#include "softmax.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct CollectiveMainloopBwdSm80 { + + static constexpr int kStages = Stages; + static constexpr int kStages_dO = Stages_dO; + static_assert(kStages >= kStages_dO); + using TileShape_MNK = TileShape_MNK_; + using Element = Element_; + using ElementAccum = ElementAccum_; + using ArchTag = ArchTag_; + static constexpr bool Is_causal = Is_causal_; + static constexpr bool Is_local = Is_local_; + static constexpr bool Has_softcap = Has_softcap_; + static constexpr bool Varlen = Varlen_; + static constexpr int NumMmaWarps = NumMmaWarpGroups * cutlass::NumWarpsPerWarpGroup; + + static constexpr bool SdP_swapAB = SdP_swapAB_; + static constexpr bool dKV_swapAB = dKV_swapAB_; + static constexpr bool dQ_swapAB = dQ_swapAB_; + + static constexpr bool Q_dO_same_stages = kStages == kStages_dO; + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + + using SeqlenInfo_t = flash::SeqlenInfoQK; + using BlockMN_t = flash::BlockMN; + + static_assert(ArchTag::kMinComputeCapability >= 80); + + static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80; + + static constexpr int NumMmaThreads = NumMmaWarps * cutlass::NumThreadsPerWarp; + static constexpr int NumProducerThreads = NumMmaThreads; // For compatibility with TileScheduler + + using MMA_Atom_Arch = std::conditional_t< + ArchTag::kMinComputeCapability >= 80, + std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >, + MMA_Atom + >; + + static_assert(NumMmaWarps % AtomLayoutMSdP == 0); + static_assert(NumMmaWarps % AtomLayoutNdKV == 0); + 
static_assert(NumMmaWarps % AtomLayoutMdQ == 0); + static constexpr bool Mma_dKV_is_RS = AtomLayoutMSdP == 1 && AtomLayoutNdKV == NumMmaWarps && SdP_swapAB && !dKV_swapAB; + static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == NumMmaWarps && AtomLayoutMdQ == NumMmaWarps && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS + + using AtomLayoutSdP = std::conditional_t< + !SdP_swapAB, + Layout, Int, _1>>, + Layout, Int, _1>> + >; + static constexpr bool MmaSdPEvenN = ((!SdP_swapAB ? kBlockN : kBlockM) / size<1>(AtomLayoutSdP{})) % 16 == 0; + using TiledMmaSdP = TiledMMA< + MMA_Atom_Arch, + AtomLayoutSdP, + Tile(AtomLayoutSdP{}))>, Int<(MmaSdPEvenN ? 16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutSdP{}))>, _16>>; + + using AtomLayoutdKV = std::conditional_t< + !dKV_swapAB, + Layout, Int, _1>>, + Layout, Int, _1>> + >; + static constexpr bool MmadKVEvenN = ((!dKV_swapAB ? kHeadDim : kBlockN) / size<1>(AtomLayoutdKV{})) % 16 == 0; + using TiledMmadKV = TiledMMA< + MMA_Atom_Arch, + AtomLayoutdKV, + Tile(AtomLayoutdKV{}))>, Int<(MmadKVEvenN ? 16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutdKV{}))>, _16>>; + + using AtomLayoutdQ = std::conditional_t< + !dQ_swapAB, + Layout, Int, _1>>, + Layout, Int, _1>> + >; + static constexpr bool MmadQEvenN = ((!dQ_swapAB ? kHeadDim : kBlockM) / size<1>(AtomLayoutdQ{})) % 16 == 0; + using TiledMmadQ = TiledMMA< + MMA_Atom_Arch, + AtomLayoutdQ, + Tile(AtomLayoutdQ{}))>, Int<(MmadQEvenN ? 16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutdQ{}))>, _16>>; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each + // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. 
+ static constexpr int kBytePerRow = kHeadDim * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + + static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)); + static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4); + + // We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory. + // Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma. + // Since this is GMMA::Major::K, the M dimension (kBlockM) doesn't matter for the layout, only the K dimension + // changes the layout. + using SmemLayoutAtomQdO = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = + decltype(tile_to_shape(SmemLayoutAtomQdO{}, + make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + using SmemLayoutdO = + decltype(tile_to_shape(SmemLayoutAtomQdO{}, + make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + + using SmemLayoutAtomKV = decltype( + composition(Swizzle{}, + // TODO: FA2 has a slightly different layout, does it matter? + Layout>, + Stride, _1>>{})); + using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomKV{}, select<1, 2>(TileShape_MNK{}))); + + using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomKV{}, select<1, 2>(TileShape_MNK{}))); + + // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. + static constexpr int kPBlockN = kBlockN % 64 == 0 ? 64 : (kBlockN % 32 == 0 ? 32 : 16); + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 
2 : 3); + static constexpr int kSwizzlePdS = 3; + using SmemLayoutAtomPdS = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + make_shape(Int{}, Int{}))); + + // We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds, + // it's still a valid smem address. + using SmemLayoutLSE = cute::Layout, Int>, cute::Stride<_1, Int>>; + using SmemLayoutLSEMma = std::conditional_t< + SdP_swapAB, + cute::Layout, Int, Int>, cute::Stride<_0, _1, Int>>, + cute::Layout, Int, Int>, cute::Stride<_1, _0, Int>> + >; + + // Note this is the transpose in terms of the view, not in terms of memory. + using SmemLayoutQt = + decltype(cute::composition(SmemLayoutQ{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), + make_stride(Int{}, _1{}, Int{})))); + using SmemLayoutdOt = + decltype(cute::composition(SmemLayoutdO{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), + make_stride(Int{}, _1{}, Int{})))); + using SmemLayoutKt = + decltype(cute::composition(SmemLayoutK{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), + make_stride(Int{}, _1{})))); + using SmemLayoutPdSt = + decltype(cute::composition(SmemLayoutPdS{}, + make_layout(make_shape(Int{}, Int{}), + make_stride(Int{}, _1{})))); + + // Thread layout, 256 or 384 threads per row + using R2SLayoutAtomdQaccum = Layout>>; + using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2SLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 1 vals per store + + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; + // For the case where the N dimension of MmaSdP is divisible by 8 but not by 16 + using SmemCopyAtomHalf = Copy_Atom; + // For the case where the N dimension of MmadQ is divisible by 8 but not by 16 + using SmemCopyAtomTransposedHalf = Copy_Atom; + // If 
!SdP_swapAB, the accum registers hold P / dS, otherwise they hold Pt / dSt. + // If PdS_major is MN, then we need to "transpose" the write. + // TODO: check this write + using R2SCopyAtomPdS = Copy_Atom, Element>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using GmemCopyStruct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL_ZFILL, + AutoVectorizingCopyWithAssumedAlignment<128> + >; + using GmemCopyAtom = Copy_Atom; + + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(GmemCopyAtom{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per read + using GmemCopyAtomLSE = Copy_Atom; + using GmemLayoutAtomLSE = Layout>>; + using GmemTiledCopyLSE = decltype(make_tiled_copy(GmemCopyAtomLSE{}, GmemLayoutAtomLSE{}, + Layout>{})); // Val layout, 4 vals per store + // So that we don't have to check if we overshot kBlockM when we load Q + // static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0); + + using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) + using StrideQKV = cute::Stride; + using ShapeLSE = cute::Shape; // (seqlen, head, batch) + using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch) + using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) + using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; + + // These are tuned for speed. They don't affect correctness. + // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 + // this helps quite a bit to not have to do causal masking for most of the iterations. 
+ // For hdim 192, separating masking iterations results in register spills. + // static constexpr bool SeparateMaskingIterations = kHeadDim <= 64; + static constexpr bool SeparateMaskingIterations = false; + // Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share them and then + // shuffle to get the value whenever we need? This can reduce register pressure when SdP_swapAB, where each + // thread needs to keep statistics for (kBlockM / 4) rows. If !SdP_swapAB, each thread only needs to keep + // statistic for 2 rows. + // static constexpr bool ShuffleLSE = SdP_swapAB && kHeadDim <= 64; + // static constexpr bool ShuffledPsum = SdP_swapAB && kHeadDim <= 64; + static constexpr bool ShuffleLSE = SdP_swapAB && false; + static constexpr bool ShuffledPsum = SdP_swapAB && false; + + static constexpr bool Share_QV_Smem = V_in_regs; + using SmemP_t = std::conditional_t, cute::array_aligned>>; + + struct TensorStorageSharedQV : cute::aligned_struct<128> { + cute::array_aligned> smem_k; + union { + cute::array_aligned> smem_v; + cute::array_aligned> smem_q; + }; + cute::array_aligned> smem_do; + cute::array_aligned, 128> smem_lse; + cute::array_aligned, 128> smem_dpsum; + SmemP_t smem_p; + cute::array_aligned> smem_ds; + }; + + struct TensorStorageSeparateQV : cute::aligned_struct<128> { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + cute::array_aligned> smem_q; + cute::array_aligned> smem_do; + cute::array_aligned, 128> smem_lse; + cute::array_aligned, 128> smem_dpsum; + SmemP_t smem_p; + cute::array_aligned> smem_ds; + }; + + using TensorStorage = std::conditional_t; + + // Host side kernel arguments + struct Arguments { + Element const* const ptr_Q; + ShapeQKV const shape_Q; + StrideQKV const stride_Q; + Element const* const ptr_K; + ShapeQKV const shape_K; + StrideQKV const stride_K; + Element const* const ptr_V; + StrideQKV const stride_V; + Element const* const ptr_dO; + StrideQKV const stride_dO; + ElementAccum* const 
ptr_dQaccum; + ShapedQaccum const shape_dQaccum; + StridedQaccum const stride_dQaccum; + float const* const ptr_LSE_log2; + ShapeLSE const shape_LSE; + StrideLSE const stride_LSE_log2; + float const* const ptr_dPsum; + StrideLSE const stride_dPsum; + float const softmax_scale; + int const window_size_left, window_size_right; + float const softcap_val; + int const num_batch; + int* const dq_semaphore; + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + }; + + // Device side kernel params + struct Params { + Element const* const ptr_Q; + ShapeQKV const shape_Q; + StrideQKV const stride_Q; + Element const* const ptr_K; + ShapeQKV const shape_K; + StrideQKV const stride_K; + Element const* const ptr_V; + StrideQKV const stride_V; + Element const* const ptr_dO; + StrideQKV const stride_dO; + ElementAccum* const ptr_dQaccum; + ShapedQaccum const shape_dQaccum; + StridedQaccum stride_dQaccum; + cutlass::FastDivmod qhead_per_khead_divmod; + float const* const ptr_LSE_log2; + ShapeLSE const shape_LSE; + StrideLSE const stride_LSE_log2; + float const* const ptr_dPsum; + StrideLSE const stride_dPsum; + float const softmax_scale, softmax_scale_log2; + int const window_size_left, window_size_right; + float const softcap_val; + int const num_batch; + int *const dq_semaphore; + int const *const cu_seqlens_q = nullptr; + int const *const cu_seqlens_k = nullptr; + int const *const seqused_q = nullptr; + int const *const seqused_k = nullptr; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); } + // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. + // Right after this, we multiply by log2(e) before applying exp2. 
+ // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val + // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) + // (assigning it to params.softmax_scale_log2). + // In the backward, we need to multiply by + // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale. + // Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale + // (the original softmax_scale) at the end. + return {args.ptr_Q, args.shape_Q, args.stride_Q, + args.ptr_K, args.shape_K, args.stride_K, + args.ptr_V, args.stride_V, + args.ptr_dO, args.stride_dO, + args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum, + cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), + args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, + args.softmax_scale, + !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), + args.window_size_left, args.window_size_right, + !Has_softcap ? 
0.f : args.softmax_scale / args.softcap_val, + args.num_batch, args.dq_semaphore, + args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k}; + } + + template + CUTLASS_DEVICE bool + mma(Params const& params, + FrgTensordKV& tdKrdK, + FrgTensordKV& tdVrdV, + int thread_idx, + cute::tuple block_coord, + SharedStorage& shared_storage + ) { + static_assert(is_rmem::value, "dK and dV tensor must be rmem resident."); + + int n_block = get<0>(block_coord); + int bidh = get<1>(block_coord); + int bidb = get<2>(block_coord); + SeqlenInfo_t seqlen_info{ + bidb, get<0>(params.shape_Q), size<0>(params.shape_K), + params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k + }; + auto m_block_min_max = BlockMN_t::get_m_block_min_max( + seqlen_info, n_block, bidb, + params.window_size_left, params.window_size_right, 0 /*sink_token_length*/); + int const m_block_min = get<0>(m_block_min_max); + int const m_block_max = get<1>(m_block_min_max); + // It's possible to have m_block_max <= m_block_min. 
Exit early + if constexpr (Is_causal || Is_local || Varlen) { + if (m_block_max <= m_block_min) { return false; } + } + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); + Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); + Tensor sQt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQt{}); + Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdOt{}); + Tensor sKt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutKt{}); + Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdS{}); + Tensor sPt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdSt{}); + Tensor sdS = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdS{}); + Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdSt{}); + Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSE{}); + Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSE{}); + Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSEMma{}); + Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{}); + + bool const is_varlen_q = Varlen && params.cu_seqlens_q; + bool const is_varlen_k = Varlen && params.cu_seqlens_k; + int bidh_kv = params.qhead_per_khead_divmod.divide(bidh); + Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q), 
params.shape_Q, params.stride_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_Q, params.stride_dO)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); + Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0); + Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0); + Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), + params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen_q ? bidb : 0); + + Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) + Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) + Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) + Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) + Tensor gLSE = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) + Tensor gdPsum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) + Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) + + GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto 
gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(thread_idx); + auto gmem_thr0_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(_0{}); // For index calculation + GmemTiledCopyLSE gmem_tiled_copy_lse; + auto gmem_thr_copy_lse = gmem_tiled_copy_lse.get_thread_slice(thread_idx); + R2STiledCopydQaccum r2s_tiled_copy_dQaccum; + auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tdOgdO = gmem_thr_copy_QKV.partition_S(gdO); + Tensor tdOsdO = gmem_thr_copy_QKV.partition_D(sdO); + Tensor tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE); + Tensor tLSEsLSE = gmem_thr_copy_lse.partition_D(sLSE); + Tensor tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum); + Tensor tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum); + // We can reuse r2s_thr_copy_dQaccum for this partitioning + Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum); + // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf("\n"); print(gdQaccum_); printf("\n"); print(gdQaccum); printf("\n"); print(tdQgdQaccum); printf("\n"); } + + TiledMmaSdP tiled_mma_SdP; + TiledMmadKV tiled_mma_dKV; + TiledMmadQ tiled_mma_dQ; + + auto thr_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx); + auto thr_mma_dKV = tiled_mma_dKV.get_thread_slice(thread_idx); + auto thr_mma_dQ = tiled_mma_dQ.get_thread_slice(thread_idx); + + // Allocate "fragments/descriptors" + // We have to use the templated mma_partition_fragment_AB instead of cute::conditional_return or lambda, + // because some partition_fragment_A/B don't compile. 
+ // https://stackoverflow.com/questions/50051473/if-constexpr-in-c17-does-not-work-in-a-non-templated-function + Tensor tdPrV = mma_partition_fragment_AB(thr_mma_SdP, sV); + + // Copy Atom retiling + auto smem_copy_atom_SdP_B = cute::conditional_return(SmemCopyAtom{}, SmemCopyAtomHalf{}); + auto smem_tiled_copy_QdO = cute::conditional_return(make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_SdP), make_tiled_copy_B(smem_copy_atom_SdP_B, tiled_mma_SdP)); + auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(thread_idx); + Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ); + Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO); + + auto smem_tiled_copy_KV = cute::conditional_return(make_tiled_copy_B(smem_copy_atom_SdP_B, tiled_mma_SdP), make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_SdP)); + auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(thread_idx); + Tensor tSsK = smem_thr_copy_KV.partition_S(sK); + Tensor tdPsV = smem_thr_copy_KV.partition_S(sV); + + auto r2s_tiled_copy_PdS = make_tiled_copy_C(R2SCopyAtomPdS{}, tiled_mma_SdP); + auto r2s_thr_copy_PdS = r2s_tiled_copy_PdS.get_thread_slice(thread_idx); + Tensor tPsP = r2s_thr_copy_PdS.partition_D(cute::conditional_return(sP, sPt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor tdSsdS = r2s_thr_copy_PdS.partition_D(cute::conditional_return(sdS, sdSt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) + // if (blockIdx.x == 0 && threadIdx.x == 128) { print(r2s_thr_copy_PdS); print(sP); printf("\n"); print(sPt); printf("\n"); print(tPsP); printf("\n"); print(tdSsdS); printf("\n"); } + + auto smem_copy_atom_dKV_B = cute::conditional_return(SmemCopyAtomTransposed{}, SmemCopyAtomTransposedHalf{}); + auto smem_tiled_copy_PdSt = cute::conditional_return(make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dKV), make_tiled_copy_B(smem_copy_atom_dKV_B, tiled_mma_dKV)); + auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(thread_idx); + Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt); + Tensor tdKsdSt = 
smem_thr_copy_PdSt.partition_S(sdSt); + + auto smem_tiled_copy_QdOt = cute::conditional_return(make_tiled_copy_B(smem_copy_atom_dKV_B, tiled_mma_dKV), make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dKV)); + auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(thread_idx); + Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt); + Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt); + + auto smem_tiled_copy_dS = cute::conditional_return( + make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_dQ), + make_tiled_copy_B(cute::conditional_return(SmemCopyAtom{}, SmemCopyAtomHalf{}), tiled_mma_dQ)); + auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(thread_idx); + Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS); + + auto smem_tiled_copy_Kt = cute::conditional_return( + make_tiled_copy_B(cute::conditional_return(SmemCopyAtomTransposed{}, SmemCopyAtomTransposedHalf{}), tiled_mma_dQ), + make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dQ)); + auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(thread_idx); + Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt); + + // thr_mma_SdP.partition_C(sLSEMma) has shape (MMA=4, MMA_M, MMA_N, PIPE), we only take the col indices + // or row indices, depending on whether SdP_swapAB. 
+ Tensor tSsLSEMma = logical_divide(thr_mma_SdP.partition_C(sLSEMma), Shape<_2>{}); // (2, 2, MMA_M, MMA_N, PIPE) + Tensor tSsLSE = group_modes<0, 2>(cute::conditional_return( + tSsLSEMma(make_coord(_0{}, _), _, _0{}, _), // (2, MMA_M, PIPE) + tSsLSEMma(make_coord(_, _0{}), _0{}, _, _))); // (2, MMA_N, PIPE) + Tensor tSsdPsumMma = logical_divide(thr_mma_SdP.partition_C(sdPsumMma), Shape<_2>{}); + Tensor tSsdPsum = group_modes<0, 2>(cute::conditional_return( + tSsdPsumMma(make_coord(_0{}, _), _, _0{}, _), // (2, MMA_M, PIPE) + tSsdPsumMma(make_coord(_, _0{}), _0{}, _, _))); // (2, MMA_N, PIPE) + // if (blockIdx.x == 0 && threadIdx.x == 128) { print(sLSEMma); printf("\n"); print(tLSEsLSE); printf("\n"); } + // If we want to split the stats among the 8 threads that share the same rows. + static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(tSsLSE))::value, 8); + + // Predicates + Tensor cQ = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); + Tensor t0QcQ = gmem_thr0_copy_QKV.partition_S(cQ); + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_Q); } + Tensor cLSE = cute::make_identity_tensor(select<0>(TileShape_MNK{})); + Tensor tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE); + + int const seqlen_q = seqlen_info.seqlen_q; + int const seqlen_k = seqlen_info.seqlen_k; + + flash::Mask mask( + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, + params.qhead_per_khead_divmod + ); + + { + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + // Predicates + Tensor cKV = cute::make_identity_tensor(select<1, 
2>(TileShape_MNK{})); + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); + Tensor t0KVcKV = gmem_thr0_copy_QKV.partition_S(cKV); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); } + // Do we need bound check to make sure the row doesn't go above kBlockN + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; + // static_assert(EvenN); // It simplifies the loading of K and V + // Instead of passing in tKVcKV, we pass in t0KVcKV and subtract the offset from the limit + // (seqlen_k - n_block * kBlockN). This is because the entries of t0KVcKV are known at compile time. + // int const seqlenk_row_limit = -int(get<0>(tKVcKV(_0{}, _0{}, _0{}))) + (EvenN + // ? seqlen_info.seqlen_k - n_block * kBlockN + // : std::min(seqlen_info.seqlen_k - n_block * kBlockN, kBlockN)); + // // Need Clear_OOB_MN to be true here since the gemm will sum over the kBlockN dimension + // flash::copy( + // gmem_tiled_copy_QKV, tVgV, tVsV, t0KVcKV, tKVpKV, seqlenk_row_limit); + int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKVcKV(_0{}, _0{}, _0{})); + #pragma unroll + for (int m = 0; m < size<1>(tVsV); ++m) { + // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to be checked + if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) { + bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit; + #pragma unroll + for (int k = 0; k < size<2>(tVsV); ++k) { + cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tVgV(_, m, k), tVsV(_, m, k)); + } + } + } + if constexpr (V_in_regs) { flash::cp_async_fence(); } + // flash::copy( + // gmem_tiled_copy_QKV, tKgK, tKsK, t0KVcKV, tKVpKV, seqlenk_row_limit); + #pragma unroll + for (int m = 0; m < size<1>(tKsK); ++m) { + if (EvenN || m < size<1>(tKsK) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < 
kBlockN) { + bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit; + #pragma unroll + for (int k = 0; k < size<2>(tKsK); ++k) { + cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tKgK(_, m, k), tKsK(_, m, k)); + } + } + } + flash::cp_async_fence(); + } + + if constexpr (V_in_regs) { + flash::cp_async_wait<1>(); + __syncthreads(); + Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV); + Tensor tdPsV_copy_view = smem_thr_copy_KV.partition_S(sV); + cute::copy(smem_tiled_copy_KV, tdPsV_copy_view, tdPrV_copy_view); + __syncthreads(); // Sync to avoid loading Q to smem_q, which overlaps with smem_v + } + + // Do we need bound check to make sure the row doesn't go above kBlockM + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; + + auto load_Q_LSE = [&] (int const m_block, int const smem_pipe_write) { + // if (cute::thread0()) { printf("Inside load_Q_LSE, m_block = %d, smem_pipe_write = %d\n", m_block, smem_pipe_write); } + Tensor tQsQ_cur = tQsQ(_, _, _, smem_pipe_write); + Tensor tQgQ_cur = tQgQ(_, _, _, m_block); + // Instead of passing in tQcQ, we pass in t0QcQ and subtract the offset from the limit + // (seqlen_q - m_block * kBlockM). This is because the entries of t0QcQ are known at compile time. + // int const seqlenq_row_limit = -int(get<0>(tQcQ(_0{}, _0{}, _0{}))) + (EvenM + // ? 
seqlen_info.seqlen_q - m_block * kBlockM + // : std::min(seqlen_info.seqlen_q - m_block * kBlockM, kBlockM)); + // Need Clear_OOB_MN to be true here since the gemm will sum over the kBlockM dimension + // flash::copy( + // gmem_tiled_copy_QKV, tQgQ(_, _, _, m_block), tQsQ_cur, t0QcQ, tQpQ, seqlenq_row_limit); + int const seqlenq_row_limit = seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{})); + #pragma unroll + for (int m = 0; m < size<1>(tQsQ); ++m) { + // If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked + if (EvenM || m < size<1>(tQsQ) - 1 || get<0>(tQcQ(_0{}, m, _0{})) < kBlockM) { + bool const predicate_m = get<0>(t0QcQ(_0{}, m, _0{})) < seqlenq_row_limit; + #pragma unroll + for (int k = 0; k < size<2>(tQsQ); ++k) { + cute::copy(gmem_tiled_copy_QKV.with(tQpQ(k) && predicate_m), tQgQ_cur(_, m, k), tQsQ_cur(_, m, k)); + } + } + } + Tensor tLSEgLSE_cur = tLSEgLSE(_, _, m_block); + Tensor tLSEsLSE_cur = tLSEsLSE(_, _, smem_pipe_write); + // We made sure LSE length is padded so we read `kBlockM` elements so that all + // elements in sLSE are filled. Without this we might have uninitialized sLSE values. + #pragma unroll + for (int m = 0; m < size<1>(tLSEsLSE); ++m) { + if (get<0>(tLSEcLSE(_0{}, m)) < kBlockM) { + cute::copy(gmem_tiled_copy_lse, tLSEgLSE_cur(_, m), tLSEsLSE_cur(_, m)); + } + } + }; + + auto load_dO_dPsum = [&] (int const m_block, int const smem_pipe_write) { + // if (cute::thread0()) { printf("Inside load_dO_dPsum, m_block = %d, smem_pipe_write = %d\n", m_block, smem_pipe_write); } + Tensor tdOsdO_cur = tdOsdO(_, _, _, smem_pipe_write); + Tensor tdOgdO_cur = tdOgdO(_, _, _, m_block); + // int const seqlenq_row_limit = -int(get<0>(tQcQ(_0{}, _0{}, _0{}))) + (EvenM + // ? 
seqlen_info.seqlen_q - m_block * kBlockM + // : std::min(seqlen_info.seqlen_q - m_block * kBlockM, kBlockM)); + // flash::copy( + // gmem_tiled_copy_QKV, tdOgdO(_, _, _, m_block), tdOsdO_cur, t0QcQ, tQpQ, seqlenq_row_limit); + int const seqlenq_row_limit = seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{})); + #pragma unroll + for (int m = 0; m < size<1>(tdOsdO); ++m) { + // If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked + if (EvenM || m < size<1>(tdOsdO) - 1 || get<0>(tQcQ(_0{}, m, _0{})) < kBlockM) { + bool const predicate_m = get<0>(t0QcQ(_0{}, m, _0{})) < seqlenq_row_limit; + #pragma unroll + for (int k = 0; k < size<2>(tdOsdO); ++k) { + cute::copy(gmem_tiled_copy_QKV.with(tQpQ(k) && predicate_m), tdOgdO_cur(_, m, k), tdOsdO_cur(_, m, k)); + } + } + } + Tensor tLSEgdPsum_cur = tLSEgdPsum(_, _, m_block); + Tensor tLSEsdPsum_cur = tLSEsdPsum(_, _, smem_pipe_write); + #pragma unroll + for (int m = 0; m < size<1>(tLSEsdPsum); ++m) { + if (get<0>(tLSEcLSE(_0{}, m)) < kBlockM) { + cute::copy(gmem_tiled_copy_lse, tLSEgdPsum_cur(_, m), tLSEsdPsum_cur(_, m)); + } + } + }; + + int m_block = m_block_min; + + // Note, using the for_each() function here to ensure `stage` is of type Int. + for_each(make_int_sequence{}, [&] (auto stage) { + static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0; + static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1; + if constexpr (!Is_last_stage || kStages == 1) { + if (Is_first_stage || m_block + stage < m_block_max) { + load_Q_LSE(m_block + stage, stage); + } + } + // We want the fence outside the if statement to have a fixed number of cp.async commits. + // so that we can wait with the correct number of outstanding commits. 
+ cute::cp_async_fence(); + if constexpr (stage < kStages_dO) { + if (Is_first_stage || m_block + stage < m_block_max) { + load_dO_dPsum(m_block + stage, stage); + } + cute::cp_async_fence(); + } + }); + + int smem_pipe_read = 0, smem_pipe_read_do = 0, smem_pipe_write = kStages - 1, smem_pipe_write_do = 0; + + auto load_Q_next = [&] { + // if (cute::thread0()) { printf("m_block = %d, m_block_max = %d, smem_pipe_write = %d\n", m_block, m_block_max, smem_pipe_write); } + if (m_block + (kStages > 1 ? kStages - 1 : 1) < m_block_max) { + load_Q_LSE(m_block + (kStages > 1 ? kStages - 1 : 1), kStages > 1 ? smem_pipe_write : 0); + } + cute::cp_async_fence(); + }; + + auto load_dO_next = [&] { + // int smem_pipe_write_do_cur = Q_dO_same_stages ? smem_pipe_write : smem_pipe_write_do; + if (m_block + kStages_dO < m_block_max) { + // load_dO_dPsum(m_block + kStages_dO, kStages_dO > 1 ? smem_pipe_write_do_cur : 0); + load_dO_dPsum(m_block + kStages_dO, kStages_dO > 1 ? smem_pipe_write_do : 0); + } + cute::cp_async_fence(); + }; + + clear(tdKrdK); + clear(tdVrdV); + + auto bwd_step = [&](int m_block, auto mask_fn) { + Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{})); + clear(tSrS); + flash::cp_async_wait<(kStages > 1) ? 1 : 0>(); + __syncthreads(); + Tensor tSrQ = mma_partition_fragment_AB(thr_mma_SdP, sQ(_, _, _0{})); + Tensor tSrK = mma_partition_fragment_AB(thr_mma_SdP, sK); + // if (cute::thread0()) { print(tiled_mma_SdP); print(tSrS); printf("\n"); print(tSrQ); printf("\n"); print(tSrK); printf("\n"); print(tSsQ); printf("\n"); print(tSsK); printf("\n"); } + flash::gemm_sm80( + tSrS, tSrQ, tSrK, tSsQ(_, _, _, kStages > 1 ? smem_pipe_read : 0), tSsK, + tiled_mma_SdP, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV, nullptr /*hook*/); + Tensor tLSErLSE = cute::conditional_return(make_fragment_like(tSsLSE(_, _0{})), make_tensor(Int{})); + if constexpr (!ShuffleLSE) { + cute::copy(tSsLSE(_, kStages > 1 ? 
smem_pipe_read : 0), tLSErLSE); + } else { + #pragma unroll + for (int i = 0; i < kStatsPerThread; ++i) { + // It's ok to read OOB, since we made sure sLSE is large enough and we won't use the OOB values + tLSErLSE(i) = tSsLSE((thread_idx % 32) / 4 + i * 8, kStages > 1 ? smem_pipe_read : 0); + } + } + if constexpr (Has_softcap) { flash::apply_softcap(tSrS, params.softcap_val); } + + // Reshape tSrS from (4, MMA_N, MMA_M) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); + // dtanh needs to happen before masking, otherwise we get 1 - (-inf)^2 = NaN in the dtanh + // if (cute::thread0()) { print_tensor(scores); } + auto dtanh = [&] { if constexpr (Has_softcap) return flash::calculate_dtanh(scores); else return nullptr; }(); + mask_fn(tSrS, m_block); + #pragma unroll + for (int mi = 0; mi < size<0>(scores); ++mi) { + float const lse_scaled = [&] { + if constexpr (!ShuffleLSE) return tLSErLSE(mi); + else return __shfl_sync(0xffffffff, tLSErLSE(mi / 8), (mi % 8) * 4 + (thread_idx % 4)); + }(); + #pragma unroll + for (int ni = 0; ni < size<1>(scores); ++ni) { + scores(mi, ni) = exp2f(scores(mi, ni) * params.softmax_scale_log2 - lse_scaled); + } + } + + Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{})); + clear(tdPrdP); + int smem_pipe_read_do_cur = Q_dO_same_stages ? smem_pipe_read : smem_pipe_read_do; + flash::cp_async_wait<(kStages_dO > 1) ? 1 : 0>(); + __syncthreads(); + auto hook = cute::conditional_return<(kStages > 1)>(load_Q_next, nullptr); + Tensor tdPrdO = mma_partition_fragment_AB(thr_mma_SdP, sdO(_, _, _0{})); + Tensor tdPrV_cur = cute::conditional_return(tdPrV, mma_partition_fragment_AB(thr_mma_SdP, sV)); + flash::gemm_sm80( + tdPrdP, tdPrdO, tdPrV_cur, tdPsdO(_, _, _, kStages_dO > 1 ? 
smem_pipe_read_do_cur : 0), tdPsV, + tiled_mma_SdP, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV, hook); + Tensor tLSErdPsum = cute::conditional_return(make_fragment_like(tSsdPsum(_, _0{})), make_tensor(Int{})); + if constexpr (!ShuffledPsum) { + cute::copy(tSsdPsum(_, kStages_dO > 1 ? smem_pipe_read_do_cur : 0), tLSErdPsum); + } else { + #pragma unroll + for (int i = 0; i < kStatsPerThread; ++i) { + tLSErdPsum(i) = tSsdPsum((thread_idx % 32) / 4 + i * 8, kStages_dO > 1 ? smem_pipe_read_do_cur : 0); + } + } + + // Reshape tdPrdP from (4, MMA_N, MMA_M) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { + float const dP_sum_cur = [&] { + if constexpr (!ShuffledPsum) return tLSErdPsum(mi); + else return __shfl_sync(0xffffffff, tLSErdPsum(mi / 8), (mi % 8) * 4 + (thread_idx % 4)); + }(); + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { + dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur); + if constexpr (Has_softcap) { dS(mi, ni) *= dtanh(mi, ni); } + } + } + // if (cute::thread0()) { print_tensor(dS); } + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = make_tensor_like(tSrS); + flash::convert_type_out(tSrS, rP); + if constexpr (!Mma_dKV_is_RS) { + Tensor tPaP = r2s_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(r2s_tiled_copy_PdS, tPaP, tPsP); + } + Tensor rdS = make_tensor_like(tdPrdP); + flash::convert_type_out(tdPrdP, rdS); + if constexpr (!Mma_dKV_is_RS) { __syncthreads(); } // Make sure P is written + // For hdim 64, It's faster to write to smem_dS first before the dV gemm + Tensor tdSadS = r2s_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(r2s_tiled_copy_PdS, tdSadS, tdSsdS); + + Tensor tdVrdO = mma_partition_fragment_AB(thr_mma_dKV, sdOt(_, _, _0{})); + Tensor tdVsdO_cur = tdVsdOt(_, _, _, kStages_dO > 1 ? 
smem_pipe_read_do_cur : 0); + if constexpr (Mma_dKV_is_RS) { + Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout())); + flash::gemm_rs_sm80(tdVrdV, tdVrP, tdVrdO, tdVsdO_cur, tiled_mma_dKV, smem_tiled_copy_QdOt, smem_thr_copy_QdOt); + } else { + Tensor tdVrP = mma_partition_fragment_AB(thr_mma_dKV, sPt); + flash::gemm_sm80( + tdVrdV, tdVrP, tdVrdO, tdVsPt, tdVsdO_cur, + tiled_mma_dKV, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt, nullptr); + } + // if (cute::thread0()) { print_tensor(tdVrdV); } + __syncthreads(); // make sure sdS is written + auto do_mma_dQ = [&] (auto hook) { + Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); + clear(tdQrdQ); + Tensor tdQrdS = mma_partition_fragment_AB(thr_mma_dQ, sdS); + Tensor tdQrK = mma_partition_fragment_AB(thr_mma_dQ, sKt); + flash::gemm_sm80( + tdQrdQ, tdQrdS, tdQrK, tdQsdS, tdQsKt, tiled_mma_dQ, + // smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, load_dO_next); + smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, hook); + // if (cute::thread0()) { print_tensor(tdQrdQ); } + // We can reuse r2s_thr_copy_dQaccum for this partitioning + Tensor tdQrdQ_atomic = r2s_thr_copy_dQaccum.retile_S(tdQrdQ); + Tensor tdQgdQaccum_atomic = tdQgdQaccum(_, _, m_block); + static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic))); + #pragma unroll + for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } + }; + // If kStages == 1, we want to do Mma_dK first so we can start loading Q for the next iteration + if constexpr (kStages > 1) { do_mma_dQ(load_dO_next); } + Tensor tdKrQ = mma_partition_fragment_AB(thr_mma_dKV, sQt(_, _, _0{})); + Tensor tdKsQ_cur = tdKsQt(_, _, _, kStages > 1 ? 
smem_pipe_read : 0); + if constexpr (Mma_dKV_is_RS) { + Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); + flash::gemm_rs_sm80(tdKrdK, tdKrdS, tdKrQ, tdKsQ_cur, tiled_mma_dKV, smem_tiled_copy_QdOt, smem_thr_copy_QdOt); + } else { + Tensor tdKrdS = mma_partition_fragment_AB(thr_mma_dKV, sdSt); + flash::gemm_sm80( + tdKrdK, tdKrdS, tdKrQ, tdKsdSt, tdKsQ_cur, + tiled_mma_dKV, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt, cute::conditional_return<(kStages > 1)>(nullptr, load_dO_next)); + } + if constexpr (kStages == 1) { + __syncthreads(); + do_mma_dQ(load_Q_next); + } + // if (cute::thread0()) { print_tensor(tdKrdK); } + + smem_pipe_read = smem_pipe_read < kStages - 1 ? smem_pipe_read + 1 : 0; + smem_pipe_read_do = smem_pipe_read_do < kStages_dO - 1 ? smem_pipe_read_do + 1 : 0; + smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0; + smem_pipe_write_do = smem_pipe_write_do < kStages_dO - 1 ? smem_pipe_write_do + 1 : 0; + + }; + + // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 + // this helps quite a bit to not have to do causal masking for most of the iterations. + if constexpr ((Is_causal || Is_local) && SeparateMaskingIterations) { + auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; + int const m_block_masking_max = ((n_block + 1) * kBlockN - 1 + seqlen_q - seqlen_k - params.window_size_right) / kBlockM + 1; + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block < std::min(m_block_max, m_block_masking_max); ++m_block) { + bwd_step(m_block, mask_fn); + } + } + + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + int const m_block_max_before_local_mask = !Is_local || !SeparateMaskingIterations + ? 
m_block_max + : std::min(m_block_max, (n_block * kBlockN + seqlen_q - seqlen_k + params.window_size_left) / kBlockM); + + auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block < m_block_max_before_local_mask; ++m_block) { + bwd_step(m_block, mask_fn); + } + + if constexpr (Is_local && SeparateMaskingIterations) { + auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block < m_block_max; ++m_block) { + bwd_step(m_block, mask_fn); + } + } + + // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tdVrdV); } + #pragma unroll + for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; } + + return true; + } + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/mainloop_bwd_sm90_tma_gmma_ws.hpp b/flashmask/flash_mask/flashmask_attention_v3/mainloop_bwd_sm90_tma_gmma_ws.hpp new file mode 100644 index 00000000000..b9b44b7b307 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -0,0 +1,1429 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "cute/tensor.hpp" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +#include "named_barrier.hpp" +#include "seqlen.h" +#include "block.h" +#include "mask.h" +#include "softmax.h" +#include "utils.h" +#include "copy_sm90_bulk_reduce.hpp" +#include "flash_mask.hpp" + +namespace flash { + +using namespace cute; + +template +struct CollectiveMainloopBwdSm90 { + + static constexpr int kStages = Stages; + static constexpr int kStages_dO = Stages_dO; + static constexpr int kStages_dS = Stages_dS; + static_assert(kStages >= kStages_dO); + static_assert(Stages_dS == 1 || Stages_dS == kStages); + static_assert(!Mma_dP_is_RS || SdP_swapAB_); // If Mma_dP_is_RS, we need SdP_SwapAB + using ClusterShape = ClusterShape_; + using TileShape_MNK = TileShape_MNK_; + using Element = Element_; + using ElementAccum = ElementAccum_; + using ArchTag = ArchTag_; + static constexpr bool Is_causal = Is_causal_; + static constexpr bool Is_local = Is_local_; + static constexpr bool Has_softcap = Has_softcap_; + static constexpr bool Varlen = Varlen_; + + static constexpr bool Has_lt_end = Has_lt_end_; + static constexpr bool Has_ut_start = Has_ut_start_; + static constexpr bool Is_blockmask = Is_blockmask_; + + static constexpr bool SdP_swapAB = SdP_swapAB_; + static constexpr bool dKV_swapAB = dKV_swapAB_; + static constexpr bool dQ_swapAB = dQ_swapAB_; + + static constexpr bool Q_dO_same_stages = kStages == kStages_dO; + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + + using SeqlenInfo_t = flash::SeqlenInfoQK; + using BlockMN_t = flash::BlockMN; + + static_assert(ArchTag::kMinComputeCapability >= 90); + static_assert(get<0>(ClusterShape{}) == 1 && 
get<2>(ClusterShape{}) == 1); + + static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; + static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarp * 2; + + static_assert(NumMmaWarpGroups % AtomLayoutMSdP == 0); + static_assert(NumMmaWarpGroups % AtomLayoutNdKV == 0); + static_assert(NumMmaWarpGroups % AtomLayoutMdQ == 0); + static constexpr bool Mma_dKV_is_RS = AtomLayoutMSdP == 1 && AtomLayoutNdKV == NumMmaWarpGroups && SdP_swapAB && !dKV_swapAB; + static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == NumMmaWarpGroups && AtomLayoutMdQ == NumMmaWarpGroups && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS + + static constexpr GMMA::Major PdS_Major = GMMA::Major::K; + // static constexpr GMMA::Major PdS_Major = GMMA::Major::MN; + static constexpr GMMA::Major PdSt_Major = PdS_Major == GMMA::Major::K ? GMMA::Major::MN : GMMA::Major::K; + + using TileShapeAtomSdP = std::conditional_t< + !SdP_swapAB, + Shape, Int, Int>, + Shape, Int, Int> + >; + using AtomLayoutSdP = std::conditional_t< + !SdP_swapAB, + Layout, Int, _1>>, + Layout, Int, _1>> + >; + using TiledMmaSdP = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutSdP{})); + + using TiledMmadPRS = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(), + AtomLayoutSdP{})); + + using TileShapeAtomdKV = std::conditional_t< + !dKV_swapAB, + Shape, Int, Int>, + Shape, Int, Int> + >; + using AtomLayoutdKV = std::conditional_t< + !dKV_swapAB, + Layout, Int, _1>>, + Layout, Int, _1>> + >; + using TiledMmadKV = decltype(cute::make_tiled_mma( + std::conditional_t< + Mma_dKV_is_RS, + decltype(cute::GMMA::rs_op_selector()), + decltype(cute::GMMA::ss_op_selector()) + >{}, + AtomLayoutdKV{})); + + using TileShapeAtomdQ = std::conditional_t< + !dQ_swapAB, + Shape, Int, Int>, + Shape, Int, Int> + >; + using AtomLayoutdQ = std::conditional_t< + !dQ_swapAB, + Layout, Int, _1>>, + Layout, Int, _1>> + >; + using TiledMmadQ = 
decltype(cute::make_tiled_mma( + std::conditional_t< + Mma_dQ_is_RS, + decltype(cute::GMMA::rs_op_selector()), + decltype(cute::GMMA::ss_op_selector()) + >{}, + AtomLayoutdQ{})); + + // We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory. + // Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma. + // Since this is GMMA::Major::K, the M dimension (kBlockM) doesn't matter for the layout, only the K dimension + // changes the layout. + using SmemLayoutAtomQdO = decltype(cutlass::gemm::collective::detail::ss_smem_selector, Int>()); // for dKV_Mma + using SmemLayoutQ = + decltype(tile_to_shape(SmemLayoutAtomQdO{}, + make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + using SmemLayoutdO = + decltype(tile_to_shape(SmemLayoutAtomQdO{}, + make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + + using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector, Int>()); + using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{}))); + + using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{}))); + + using SmemLayoutAtomPdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector, + Int>()); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + make_shape(Int{}, Int{}, Int{}), + std::conditional_t, cute::Step<_2, _1, _3>>{})); + + // Need stride to be multiple of 32, otherwise we get error (misaligned address) when doing TMA if e.g. kBlockM=80 + // We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds, + // it's still a valid smem address. 
+ using SmemLayoutLSE = cute::Layout, Int>, cute::Stride<_1, Int>>; + using SmemLayoutLSEMma = std::conditional_t< + SdP_swapAB, + cute::Layout, Int, Int>, cute::Stride<_0, _1, Int>>, + cute::Layout, Int, Int>, cute::Stride<_1, _0, Int>> + >; + + // Note this is the transpose in terms of the view, not in terms of memory. + using SmemLayoutQt = + decltype(cute::composition(SmemLayoutQ{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), + make_stride(Int{}, _1{}, Int{})))); + using SmemLayoutdOt = + decltype(cute::composition(SmemLayoutdO{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), + make_stride(Int{}, _1{}, Int{})))); + using SmemLayoutKt = + decltype(cute::composition(SmemLayoutK{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), + make_stride(Int{}, _1{})))); + using SmemLayoutPdSt = + decltype(cute::composition(SmemLayoutPdS{}, + make_layout(make_shape(Int{}, Int{}, Int{}), + make_stride(Int{}, _1{}, Int{})))); + + // Thread layout, 256 or 384 threads per row + // We split into NumMmaWarpGroups so that we can do Bulk reduce add for each WG separately. + using R2SLayoutAtomdQaccum = Layout, Int>>; + using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2SLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 4 vals per store + using SmemLayoutdQaccum = Layout, Int>>; + + static constexpr int kNumPdSStore = kBlockM * kBlockN / NumMmaThreads; + // If !SdP_swapAB, the accum registers hold P / dS, otherwise they hold Pt / dSt. + // If PdS_major is MN, then we need to "transpose" the write. 
+ using SmemCopyAtomPdS = Copy_Atom< + std::conditional_t<(!SdP_swapAB) ^ (PdS_Major == GMMA::Major::MN), + std::conditional_t, + std::conditional_t + >, + Element + >; + + using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape{}))); + using GmemTiledCopyKV = cute::SM90_TMA_LOAD; + + using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) + using StrideQKV = cute::Stride; + using ShapeLSE = cute::Shape; // (seqlen, head, batch) + using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch) + using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) + using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; + + using TMA_QdO = decltype(make_tma_copy_A_sm90( + GmemTiledCopyQdO{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), + take<0, 2>(SmemLayoutQ{}), + TileShape_MNK{}, + ClusterShape{})); // mcast along N mode for this M load, if any + + using TMA_K = decltype(make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), + SmemLayoutK{}, + TileShape_MNK{}, + ClusterShape{})); // no mcast for KV + + using TMA_V = decltype(make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), + SmemLayoutV{}, + TileShape_MNK{}, + ClusterShape{})); // no mcast for KV + + using MainloopPipeline = typename cutlass::PipelineTmaAsync; + using PipelineState = typename MainloopPipeline::PipelineState; + using MainloopPipeline_dO = typename cutlass::PipelineTmaAsync; + using PipelineState_dO = typename MainloopPipeline_dO::PipelineState; + // using MainloopPipeline_flashmask = typename cutlass::PipelineAsync; + // using PipelineState_flashmask = typename MainloopPipeline_flashmask::PipelineState; + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + static constexpr uint32_t TmaTransactionBytesQ = 
static_cast(size(take<0, 2>(SmemLayoutQ{})) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesK = static_cast(size(SmemLayoutK{}) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesV = static_cast(size(SmemLayoutV{}) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesLSE = static_cast(size(select<0>(SmemLayoutLSE{})) * cutlass::sizeof_bits_v / 8); + + // These are tuned for speed. They don't affect correctness. + // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 + // this helps quite a bit to not have to do causal masking for most of the iterations. + // For hdim 192, separating masking iterations results in register spills. + static constexpr bool SeparateMaskingIterations = kHeadDim <= 64; + // Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share them and then + // shuffle to get the value whenever we need? This can reduce register pressure when SdP_swapAB, where each + // thread needs to keep statistics for (kBlockM / 4) rows. If !SdP_swapAB, each thread only needs to keep + // statistic for 2 rows. + static constexpr bool ShuffleLSE = SdP_swapAB && kHeadDim <= 64; + static constexpr bool ShuffledPsum = SdP_swapAB && kHeadDim <= 64; + static constexpr bool dQacc_use_TMA = kHeadDim < 256; + // For hdim256, we want to slice the dQ MMA (64 x 256 on 2 WGs) into two (64 x 128 on 2 WGs) so that we can + // do atomic add on one half before doing the other half of the MMA, to reduce register pressure. 
+ static constexpr bool Slice_dQKV_Mma = kHeadDim == 256 && !dQacc_use_TMA && dQ_swapAB && AtomLayoutMdQ == 1 && NumMmaWarpGroups == 2; + // static_assert(!(Deterministic && Slice_dQKV_Mma), "Deterministic mode not supported with Slice_dQKV_Mma"); + + static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutPdS{}); + static constexpr size_t SmemAlignmentdS = cutlass::detail::alignment_for_swizzle(SmemLayoutPdS{}); + // Without this SmemAlignment, with hdim 256 we get "misaligned address" error in TMA + static constexpr size_t SmemAlignmentQKVdO = kHeadDim % 256 == 0 ? 256 : 128; + static constexpr size_t SmemAlignmentV = !Mma_dP_is_RS ? SmemAlignmentQKVdO : cutlass::detail::alignment_for_swizzle(SmemLayoutV{}); + static_assert(SmemAlignmentP >= 128 && SmemAlignmentdS >= 128, "Require at least 128B alignment"); + + // TODO: do we have to worry that smem_dk and smem_dv in the epilogue don't line up w smem_k and smem_v due to alignment? + using SmemdQacc_t = std::conditional_t, cute::array_aligned>>; + using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; + //flashmask + // using SmemLayoutBlockMask = decltype(cute::make_layout(cute::Shape{kStages, 2},cute::Stride{2,1})); + struct TensorStorage : cute::aligned_struct { + cute::array_aligned, SmemAlignmentQKVdO> smem_k; + cute::array_aligned, SmemAlignmentV> smem_v; + SmemdQacc_t smem_dqacc; + cute::array_aligned, SmemAlignmentQKVdO> smem_q; + cute::array_aligned, SmemAlignmentQKVdO> smem_do; + cute::array_aligned, 128> smem_lse; + cute::array_aligned, 128> smem_dpsum; + SmemP_t smem_p; + cute::array_aligned, SmemAlignmentdS> smem_ds; + // cute::array_aligned smem_block_mask; + }; + + // Host side kernel arguments + struct Arguments { + Element const* const ptr_Q; + ShapeQKV const shape_Q; + StrideQKV const stride_Q; + Element const* const ptr_K; + ShapeQKV const shape_K; + StrideQKV const stride_K; + Element const* const ptr_V; + StrideQKV const stride_V; + Element 
const* const ptr_dO; + StrideQKV const stride_dO; + ElementAccum* const ptr_dQaccum; + ShapedQaccum const shape_dQaccum; + StridedQaccum const stride_dQaccum; + float const* const ptr_LSE_log2; + ShapeLSE const shape_LSE; + StrideLSE const stride_LSE_log2; + float const* const ptr_dPsum; + StrideLSE const stride_dPsum; + float const softmax_scale; + int const window_size_left, window_size_right; + float const softcap_val; + int const num_batch; + int* const dq_semaphore; + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + + // FlashMask + int const h_flashmask; + int const h_h_flashmask_ratio; + + int32_t * __restrict__ const lt_start_ptr = nullptr; + int32_t * __restrict__ const lt_end_ptr = nullptr; + + int32_t * __restrict__ const ut_start_ptr = nullptr; + int32_t * __restrict__ const ut_end_ptr = nullptr; + + int32_t * __restrict__ const flashmask_maxmin_ptr = nullptr; + + int32_t * __restrict__ const lt_start_nblockmax = nullptr; + int32_t * __restrict__ const lt_start_nblockmin = nullptr; + + int32_t * __restrict__ const lt_end_nblockmax = nullptr; + int32_t * __restrict__ const lt_end_nblockmin = nullptr; + + int32_t * __restrict__ const ut_start_nblockmax = nullptr; + int32_t * __restrict__ const ut_start_nblockmin = nullptr; + + int32_t * __restrict__ const ut_end_nblockmax = nullptr; + int32_t * __restrict__ const ut_end_nblockmin = nullptr; + + int m_block_dim,n_block_dim; + int32_t * __restrict__ block_mask_ptr = nullptr; + }; + + // Device side kernel params + struct Params { + ShapeQKV const shape_Q; + ShapeQKV const shape_K; + ElementAccum* const ptr_dQaccum; + ShapedQaccum const shape_dQaccum; + StridedQaccum stride_dQaccum; + cutlass::FastDivmod qhead_per_khead_divmod; + TMA_QdO tma_load_Q, tma_load_dO; + TMA_K tma_load_K; + TMA_V tma_load_V; + float const* const ptr_LSE_log2; + ShapeLSE const shape_LSE; + StrideLSE const 
stride_LSE_log2; + float const* const ptr_dPsum; + StrideLSE const stride_dPsum; + float const softmax_scale, softmax_scale_log2; + int const window_size_left, window_size_right; + float const softcap_val; + int const num_batch; + int* const dq_semaphore; + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + + // FlashMask + int const h_flashmask; + int const h_h_flashmask_ratio; + + int32_t * __restrict__ const lt_start_ptr = nullptr; + int32_t * __restrict__ const lt_end_ptr = nullptr; + + int32_t * __restrict__ const ut_start_ptr = nullptr; + int32_t * __restrict__ const ut_end_ptr = nullptr; + + int32_t * __restrict__ const flashmask_maxmin_ptr = nullptr; + + int32_t * __restrict__ const lt_start_nblockmax = nullptr; + int32_t * __restrict__ const lt_start_nblockmin = nullptr; + + int32_t * __restrict__ const lt_end_nblockmax = nullptr; + int32_t * __restrict__ const lt_end_nblockmin = nullptr; + + int32_t * __restrict__ const ut_start_nblockmax = nullptr; + int32_t * __restrict__ const ut_start_nblockmin = nullptr; + + int32_t * __restrict__ const ut_end_nblockmax = nullptr; + int32_t * __restrict__ const ut_end_nblockmin = nullptr; + + int m_block_dim,n_block_dim; + int m_factor, n_factor; + int32_t * __restrict__ block_mask_ptr = nullptr; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q); + TMA_QdO tma_load_Q = make_tma_copy_A_sm90( + GmemTiledCopyQdO{}, + mQ, + SmemLayoutQ{}(_, _, _0{}), + TileShape_MNK{}, + ClusterShape{}); // mcast along N mode for this M load, if any + Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_Q, args.stride_dO); + TMA_QdO tma_load_dO = make_tma_copy_A_sm90( + GmemTiledCopyQdO{}, + mdO, + SmemLayoutdO{}(_, _, _0{}), + TileShape_MNK{}, + ClusterShape{}); // mcast along N mode for this M load, if any + 
Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K); + TMA_K tma_load_K = make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + mK, + SmemLayoutK{}, + TileShape_MNK{}, + ClusterShape{}); // no mcast for KV + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V); + TMA_V tma_load_V = make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + mV, + SmemLayoutV{}, + TileShape_MNK{}, + ClusterShape{}); // no mcast for KV + + assert(args.m_block_dim % kBlockM == 0); + assert(args.n_block_dim % kBlockN == 0); + int m_factor = args.m_block_dim / kBlockM; + int n_factor = args.n_block_dim / kBlockN; + if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); } + // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. + // Right after this, we multiply by log2(e) before applying exp2. + // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val + // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) + // (assigning it to params.softmax_scale_log2). + // In the backward, we need to multiply by + // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale. + // Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale + // (the original softmax_scale) at the end. + return {args.shape_Q, args.shape_K, + args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum, + cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), + tma_load_Q, tma_load_dO, tma_load_K, tma_load_V, + args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, + args.softmax_scale, + !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), + args.window_size_left, args.window_size_right, + !Has_softcap ? 
0.f : args.softmax_scale / args.softcap_val, + args.num_batch, args.dq_semaphore, + args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k, + args.h_flashmask, args.h_h_flashmask_ratio, + args.lt_start_ptr, args.lt_end_ptr, + args.ut_start_ptr, args.ut_end_ptr, + args.flashmask_maxmin_ptr, + args.lt_start_nblockmax, args.lt_start_nblockmin, + args.lt_end_nblockmax, args.lt_end_nblockmin, + args.ut_start_nblockmax, args.ut_start_nblockmin, + args.ut_end_nblockmax, args.ut_end_nblockmin, + args.m_block_dim,args.n_block_dim, + m_factor,n_factor, + args.block_mask_ptr}; + } + + enum class FmBlockInfo { + lt_start_max = 0, + lt_end_max = 1, + ut_start_max = 2, + ut_end_max = 3, + lt_start_min = 4, + lt_end_min = 5, + ut_start_min = 6, + ut_end_min = 7 + }; + constexpr int fm_idx(FmBlockInfo v) { + return static_cast(v); + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_dO.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor()); + } + + + CUTLASS_DEVICE + void load_n_block_info( int32_t * fm_mem, int32_t * flashmask_index_smem_, int32_t* blockmask_smem_, cute::tuple block_coord, Params const& params){ + auto [n_block, bidh, bidb] = block_coord; + int const seqlen_k = get<0>(params.shape_K); + int const seqlen_q = get<0>(params.shape_Q); + int const thread_idx = threadIdx.x; + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + int const bh_offset = bidb * params.h_flashmask + bidh / params.h_h_flashmask_ratio; + int const n_block_seqlen = ((seqlen_k + kBlockN - 1) / kBlockN + 3) & 0xfffffffc; // / 4 * 4 + int const bh_offset_block = bh_offset 
* n_block_seqlen; + + + const int valid_block_nblock_seqlen = (seqlen_k + params.n_block_dim - 1) / params.n_block_dim; + const int valid_block_mblock_seqlen = (seqlen_q + params.m_block_dim - 1) / params.m_block_dim; + + int blockmask_offset = (bidb * params.h_flashmask + bidh / params.h_h_flashmask_ratio) * valid_block_nblock_seqlen * valid_block_mblock_seqlen; + blockmask_offset += n_block / params.n_factor; + int stride_offset = valid_block_nblock_seqlen; + constexpr int ProducerThreadNum = 128; + + if(thread_idx == 0){ + // lt_start is always valid, otherwise this is not a valid flashmask computation instance + fm_mem[0] = (params.lt_start_nblockmax[bh_offset_block + n_block] - 1) / kBlockM; + fm_mem[1] = params.lt_start_nblockmin[bh_offset_block + n_block] / kBlockM; + // if(bidb ==1 and bidh == 0) printf("params.lt_start_nblockmax: %d, params.lt_start_nblockmin: %d ,n_block: %d\n", fm_mem[0], fm_mem[1], n_block); + if constexpr (Has_lt_end) { + fm_mem[2] = (params.lt_end_nblockmax[bh_offset_block + n_block] - 1) / kBlockM; + fm_mem[3] = params.lt_end_nblockmin[bh_offset_block + n_block] / kBlockM; + } + if constexpr (Has_ut_start) { + fm_mem[4] = (params.ut_start_nblockmax[bh_offset_block + n_block] - 1) / kBlockM; + fm_mem[5] = params.ut_start_nblockmin[bh_offset_block + n_block] / kBlockM; + } + if constexpr (!Is_causal) { + fm_mem[6] = (params.ut_end_nblockmax[bh_offset_block + n_block] - 1) / kBlockM; + fm_mem[7] = params.ut_end_nblockmin[bh_offset_block + n_block] / kBlockM; + } + // if(bidb ==1 and bidh == 0) printf("params.ut_end_nblockmax: %d, params.ut_end_nblockmin: %d ,n_block: %d\n", fm_mem[6], fm_mem[7], n_block); + // printf("bidh: %d, bidb: %d, n_block: %d\n", bidh, bidb, n_block); + // printf("params.h_flashmask: %d, params.h_h_flashmask_ratio: %d,get<0>(params.shape_Q): %d", params.h_flashmask, params.h_h_flashmask_ratio, get<0>(params.shape_Q)); + // int row_offset1 = (bidb * params.h_flashmask + bidh / params.h_h_flashmask_ratio) * seqlen 
+ n_block * kBlockN; + // printf("row_offset: %d",row_offset1); + } + int const row_offset = bh_offset * seqlen_k + n_block * kBlockN; + // if(thread_idx == 0 and n_block == 0) printf("row_offset: %d, bidb: %d,h_flashmask: %d, h_h_flashmask_ratio: %d\n",row_offset,bidb,params.h_flashmask,params.h_h_flashmask_ratio); + const bool in_range = n_block * kBlockN + thread_idx < seqlen_k; + // Note(xhy): kBlockN in fa3 is always less than 128 + + if (thread_idx < kBlockN) { + flashmask_index_smem_[thread_idx] = in_range ? params.lt_start_ptr[thread_idx + row_offset] : INT_MAX; + if constexpr (Has_lt_end) { + flashmask_index_smem_[thread_idx + kBlockN] = in_range ? (Has_lt_end ? params.lt_end_ptr[thread_idx + row_offset] : INT_MAX) : INT_MAX; + } else { + flashmask_index_smem_[thread_idx + kBlockN] = INT_MAX; + } + if constexpr (!Is_causal) { + // Note(heqianyue): make sure that `Is_causal` masks are actually causal (no unmasked elements on upper triangle) + if constexpr (Has_ut_start) { + flashmask_index_smem_[thread_idx + 2 * kBlockN] = in_range ? params.ut_start_ptr[thread_idx + row_offset] : INT_MAX; + } + // if causal, Has_ut_start won't be true, so if 'Is_causal' == true, ut_end loading and int branching can be skipped in its entirity + flashmask_index_smem_[thread_idx + 3 * kBlockN] = in_range ? 
params.ut_end_ptr[thread_idx + row_offset] : INT_MIN; + } + } + + if constexpr (Is_blockmask){ + for(int64_t idx = thread_idx; idx < valid_block_mblock_seqlen ; idx += ProducerThreadNum) { + asm volatile( + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + ::"r"(cutlass::arch::cutlass_get_smem_pointer(blockmask_smem_ + idx)), + "l"(params.block_mask_ptr + blockmask_offset + idx * stride_offset), + "n"(4)); + } + asm volatile("cp.async.commit_group;\n" ::); + asm volatile("cp.async.wait_group 0;\n" ::); + } + // if(thread_idx < kBlockN) if(bidb ==1 and bidh == 0) printf("threadidx: %d,bidb: %d,bidh: %d,n_block: %d, row_offset: %d, ut_end_flashmask_index_smem_%d: %d\n", thread_idx,bidb,bidh,n_block,thread_idx + row_offset-seqlen,thread_idx,flashmask_index_smem_[thread_idx + 3 * kBlockN]); + // if(bidb ==0 and (bidh == 0 or bidh == 2) and n_block * kBlockN + i < seqlen and params.ut_end_ptr != nullptr) printf("threadidx: %d,bidb: %d,bidh: %d,n_block: %d, row_offset: %d, ut_end_flashmask_index_smem_%d: %d, params.ut_end_ptr_val: %d, params.ut_end_ptr_ptr: %p\n", thread_idx,bidb,bidh,n_block,row_offset,i,flashmask_index_smem_[i + 3 * kBlockN],params.ut_end_ptr[i + row_offset],params.ut_end_ptr + i + row_offset); + cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarp * 4, static_cast(BwdNamedBarriers::FlashmaskProducer) /*id*/); + } + + template + CUTLASS_DEVICE void + load(Params const& params, + MainloopPipeline pipeline_q, + MainloopPipeline_dO pipeline_do, + PipelineState& smem_pipe_write, + PipelineState_dO& smem_pipe_write_do, + SharedStorage &shared_storage, + cute::tuple block_coord, + int32_t const * const flashmask_mem_, + int32_t const * const blockmask_smem_ + ) { + + auto [n_block, bidh, bidb] = block_coord; + SeqlenInfo_t seqlen_info{ + bidb, get<0>(params.shape_Q), size<0>(params.shape_K), + params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k + }; + auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( + 
seqlen_info, n_block, bidb, + params.window_size_left, params.window_size_right, 0 /*sink_token_length*/); + // It's possible to have m_block_max <= m_block_min. Loading Q, K can cause illegal memory access. + if constexpr (Is_causal || Is_local || Varlen) { + if (m_block_max <= m_block_min) { + return; + } + } + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); + Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); + Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSE{}); + Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSE{}); + // Tensor sBlockMask = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_block_mask.data()), SmemLayoutBlockMask{}); + + int bidh_kv = params.qhead_per_khead_divmod.divide(bidh); + + // Prepare the TMA loads + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + bool const is_varlen_q = Varlen && params.cu_seqlens_q; + bool const is_varlen_k = Varlen && params.cu_seqlens_k; + Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); + Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? 
bidb : 0); + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0); + Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0); + + Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) + Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) + Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) + Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) + Tensor gLSE = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) + Tensor gdPsum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) + + Tensor sK_x = make_tensor(sK.data(), make_layout(sK.layout(), Layout<_1>{})); + Tensor gK_x = make_tensor(gK.data(), make_layout(gK.layout(), Layout<_1>{})); + Tensor sV_x = make_tensor(sV.data(), make_layout(sV.layout(), Layout<_1>{})); + Tensor gV_x = make_tensor(gV.data(), make_layout(gV.layout(), Layout<_1>{})); + // auto [tQgQ, tQsQ] = tma_partition(params.tma_load_Q, block_rank_in_cluster, Layout{}, + // group_modes<0, 2>(sQ), group_modes<0, 2>(gQ)); // (TMA, k), (TMA, PIPE) + // auto [tdOgdO, tdOsdO] = tma_partition(params.tma_load_dO, block_rank_in_cluster, Layout{}, + // group_modes<0, 2>(sdO), group_modes<0, 2>(gdO)); // (TMA, k), (TMA, PIPE) + auto block_tma_Q = params.tma_load_Q.get_slice(cluster_local_block_id.y); + auto block_tma_dO = 
params.tma_load_dO.get_slice(cluster_local_block_id.y); + Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); + Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); + Tensor tdOgdO = group_modes<0, 3>(block_tma_dO.partition_S(gdO)); + Tensor tdOsdO = group_modes<0, 3>(block_tma_dO.partition_D(sdO)); + auto [tKgK, tKsK] = tma_partition(params.tma_load_K, _0{}, Layout<_1>{}, + group_modes<0, 2>(sK_x), group_modes<0, 2>(gK_x)); // (TMA), (TMA) + auto [tVgV, tVsV] = tma_partition(params.tma_load_V, _0{}, Layout<_1>{}, + group_modes<0, 2>(sV_x), group_modes<0, 2>(gV_x)); // (TMA), (TMA) + auto bulk_copy = Copy_Traits{}; + + uint16_t mcast_mask_qdo = 0; + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_qdo |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, _0{})); + } + } + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + int m_block = m_block_min; + // int const thread_idx = threadIdx.x % NumProducerThreads; + + int lane_predicate = cute::elect_one_sync(); + // if(lane_predicate){ + // printf("kBlockM: %d, kBlockN: %d\n", kBlockM, kBlockN); + // } + // int32_t flashmask_mem_[8]; + + // if(lane_predicate) { + // load_n_block_info(n_block, flashmask_mem_, params); + // // printf("nummma+numproducer: %d+%d", NumMmaThreads, NumProducerThreads); + // } + // printf("enter producer0 threadidx:%d", threadIdx.x); + + // // Wait for the MMA warpgroups to say that smem_k and smem_v are ready + // cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); + + if (lane_predicate) { + shared_storage.pipelines.barrier_KV.arrive_and_expect_tx(TmaTransactionBytesK + TmaTransactionBytesV); + copy(params.tma_load_K.with(reinterpret_cast(shared_storage.pipelines.barrier_KV), 0 /*mcast_mask*/), tKgK, tKsK); + 
copy(params.tma_load_V.with(reinterpret_cast(shared_storage.pipelines.barrier_KV), 0 /*mcast_mask*/), tVgV, tVsV); + + auto process_block = [&](int m_block) { + // If Q and dO have the same number of stages, we can use the same pipeline state variable + // to reduce registers + pipeline_q.producer_acquire(smem_pipe_write); + copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), + tQgQ(_, m_block), tQsQ(_, smem_pipe_write.index())); + copy(bulk_copy.with(*pipeline_q.producer_get_barrier(smem_pipe_write)), + gLSE(_, m_block), sLSE(_, smem_pipe_write.index())); + PipelineState_dO smem_pipe_write_do_cur = cute::conditional_return(smem_pipe_write, smem_pipe_write_do); + pipeline_do.producer_acquire(smem_pipe_write_do_cur); + copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), + tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write_do_cur.index())); + + copy(bulk_copy.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur)), + gdPsum(_, m_block), sdPsum(_, smem_pipe_write_do_cur.index())); + if constexpr (!Q_dO_same_stages) { ++smem_pipe_write_do; } + ++smem_pipe_write; + }; + int loop_end = m_block_max; + if constexpr(!Is_causal){ + if constexpr (Has_ut_start) { + loop_end = flashmask_mem_[4]; + #pragma unroll (kHeadDim < 256 ? 2 : 1) + for (; m_block <= loop_end; ++m_block) { + if constexpr (Is_blockmask){ + if(!blockmask_smem_[m_block / params.m_factor]) continue; + } + process_block(m_block); + } + } + m_block = std::max(m_block, flashmask_mem_[7]); + } + loop_end = std::min(m_block_max - 1, flashmask_mem_[0]); + // printf("flashmask_mem_0,lt_start_nblockmax,n_block: %d, %d, %d\n", flashmask_mem_[0],params.lt_start_nblockmax[n_block],n_block); + // printf("loop_end: %d\n", loop_end); + #pragma unroll (kHeadDim < 256 ? 
2 : 1) + for (; m_block <= loop_end; ++m_block) { + if constexpr (Is_blockmask){ + if(!blockmask_smem_[m_block / params.m_factor]) continue; + } + process_block(m_block); + } + if constexpr (Has_lt_end) { + m_block = std::max(m_block, flashmask_mem_[3]); + #pragma unroll (kHeadDim < 256 ? 2 : 1) + for (; m_block <= m_block_max - 1; ++m_block) { + // printf("producer1 m_block,n_block: %d, %d\n", m_block,n_block); + if constexpr (Is_blockmask){ + if(!blockmask_smem_[m_block / params.m_factor]) continue; + } + process_block(m_block); + } + } + } + if constexpr (Q_dO_same_stages) { smem_pipe_write_do = smem_pipe_write; } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do, + PipelineState& smem_pipe_write) { + static_assert(Q_dO_same_stages, "Q and dO must have the same number of stages"); + // Need to copy since pipeline_q.producer_tail(smem_pipe_write) will increment smem_pipe_write + PipelineState smem_pipe_write_do = smem_pipe_write; + // Issue the epilogue waits + if (cute::elect_one_sync()) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was still inverted from make_producer_start_state + */ + pipeline_q.producer_tail(smem_pipe_write); + pipeline_do.producer_tail(smem_pipe_write_do); + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do, + PipelineState& smem_pipe_write, PipelineState_dO& smem_pipe_write_do) { + // Issue the epilogue waits + if (cute::elect_one_sync()) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used + * then would just be acquired 
since the phase was still inverted from make_producer_start_state + */ + pipeline_q.producer_tail(smem_pipe_write); + pipeline_do.producer_tail(smem_pipe_write_do); + } + } + + template + CUTLASS_DEVICE void + store_dq(Params const& params, + SharedStorage &shared_storage, + cute::tuple block_coord, + int32_t const * const flashmask_mem_, + int32_t const * const blockmask_smem_ + // MainloopPipeline_flashmask pipeline_flashmask, + ) { + if constexpr (!dQacc_use_TMA) { return; } + + auto [n_block, bidh, bidb] = block_coord; + SeqlenInfo_t seqlen_info{ + bidb, get<0>(params.shape_Q), size<0>(params.shape_K), + params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k + }; + auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( + seqlen_info, n_block, bidb, params.window_size_left, + params.window_size_right, 0 /*sink_token_length*/); + // It's possible to have m_block_max <= m_block_min. Exit early + if constexpr (Is_causal || Is_local || Varlen) { + if (m_block_max <= m_block_min) { return; } + } + // printf("enter producer1 threadidx:%d", threadIdx.x); + // cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast(BwdNamedBarriers::Flashmask) /*id*/); + int m_block = m_block_min; + // if (threadIdx.x % 32 == 0) { printf("m_block:%d", m_block); } + + Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{}); + static constexpr int dQ_TMA_num_bytes = CUTE_STATIC_V(size<0>(sdQ)) * sizeof(ElementAccum); + + bool const is_varlen = Varlen && params.cu_seqlens_q; + Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), + params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? 
bidb : 0); + Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) + Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int{}); // (M * K / WG, WG, _) + + int const num_batch = params.num_batch; + int const num_head = get<2>(params.shape_Q); + int *lock_ptr = !Deterministic ? nullptr : params.dq_semaphore + bidb * num_head + bidh; + using Barrier = cutlass::GenericBarrier; + bool const lane_predicate = cute::elect_one_sync(); + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + + // int32_t flashmask_mem_[8]; + // load_n_block_info(n_block, flashmask_mem_, params); + // printf("m_block:%d", m_block); + // printf("m_block_max:%d\n", m_block_max); + if constexpr (Deterministic) { + for (int prefix_m_block=0; prefix_m_block < m_block; prefix_m_block++) { + Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, prefix_m_block * num_batch * num_head, n_block); + /* Do Nothing, just wait */ + Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, prefix_m_block * num_batch * num_head); + } + } + int loop_end = m_block_max; + if constexpr(!Is_causal){ + if constexpr (Has_ut_start) { + loop_end = flashmask_mem_[4]; + #pragma unroll 2 + for (; m_block <= loop_end; ++m_block) { + if constexpr (Is_blockmask){ + if(!blockmask_smem_[m_block / params.m_factor]) continue; + } + if constexpr (Deterministic) { + Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block); + } + #pragma unroll + for (int warpgroup_idx = 0; warpgroup_idx < NumMmaWarpGroups; ++warpgroup_idx) { + cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQFullWG1) + warpgroup_idx /*id*/); // sdQ full, to be written to gmem + if (lane_predicate) { + //cute::print_tensor(sdQ); + SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdQ(_, warpgroup_idx).data()), 
raw_pointer_cast(gdQaccum(_, warpgroup_idx, m_block).data()), dQ_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); + tma_store_arrive(); + } + } + // Note, the for_each() function is required here to ensure `warpgroup_idx` is of type Int. + for_each(make_int_sequence{}, [&] (auto warpgroup_idx) { + if (lane_predicate) { tma_store_wait(); } + cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmptyWG1) + warpgroup_idx /*id*/); // sdQ empty, ready to be written to + }); + if constexpr (Deterministic) { + Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head); + } + } + } + if constexpr (Deterministic) { + int cur_m_block = m_block; + m_block = std::max(m_block,flashmask_mem_[7]); + // up mask + for (; cur_m_block < m_block; cur_m_block++) { + Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, cur_m_block * num_batch * num_head, n_block); + /* Do Nothing, just wait */ + Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, cur_m_block * num_batch * num_head); + } + } + else { + m_block = std::max(m_block,flashmask_mem_[7]); + } + } + loop_end = std::min(m_block_max - 1, flashmask_mem_[0]); + #pragma unroll 2 + for (; m_block <= loop_end; ++m_block) { + if constexpr (Is_blockmask){ + if(!blockmask_smem_[m_block / params.m_factor] ) continue; + } + if constexpr (Deterministic) { + Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block); + } + #pragma unroll + for (int warpgroup_idx = 0; warpgroup_idx < NumMmaWarpGroups; ++warpgroup_idx) { + cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQFullWG1) + warpgroup_idx /*id*/); // sdQ full, to be written to gmem + if (lane_predicate) { + //cute::print_tensor(sdQ); + SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdQ(_, 
warpgroup_idx).data()), raw_pointer_cast(gdQaccum(_, warpgroup_idx, m_block).data()), dQ_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); + tma_store_arrive(); + } + } + // Note, the for_each() function is required here to ensure `warpgroup_idx` is of type Int. + for_each(make_int_sequence{}, [&] (auto warpgroup_idx) { + if (lane_predicate) { tma_store_wait(); } + cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmptyWG1) + warpgroup_idx /*id*/); // sdQ empty, ready to be written to + }); + if constexpr (Deterministic) { + Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head); + } + } + if constexpr (Has_lt_end) { + if constexpr (Deterministic) { + int cur_m_block = m_block; + m_block = std::max(m_block,flashmask_mem_[3]); + // down mask + for (; cur_m_block < m_block; cur_m_block++) { + Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, cur_m_block * num_batch * num_head, n_block); + /* Do Nothing, just wait */ + Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, cur_m_block * num_batch * num_head); + } + } + else m_block = std::max(m_block, flashmask_mem_[3]); + #pragma unroll 2 + for (; m_block < m_block_max; ++m_block) { + if constexpr (Is_blockmask){ + if(!blockmask_smem_[m_block / params.m_factor]) continue; + } + if constexpr (Deterministic) { + Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block); + } + #pragma unroll + for (int warpgroup_idx = 0; warpgroup_idx < NumMmaWarpGroups; ++warpgroup_idx) { + cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQFullWG1) + warpgroup_idx /*id*/); // sdQ full, to be written to gmem + if (lane_predicate) { + // cute::print_tensor(sdQ); + SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdQ(_, warpgroup_idx).data()), 
raw_pointer_cast(gdQaccum(_, warpgroup_idx, m_block).data()), dQ_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); + tma_store_arrive(); + } + } + // Note, the for_each() function is required here to ensure `warpgroup_idx` is of type Int. + for_each(make_int_sequence{}, [&] (auto warpgroup_idx) { + if (lane_predicate) { tma_store_wait(); } + cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmptyWG1) + warpgroup_idx /*id*/); // sdQ empty, ready to be written to + }); + if constexpr (Deterministic) { + Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head); + } + } + } + if constexpr (Deterministic) { + int const m_block_global_max = cute::ceil_div(seqlen_info.seqlen_q, kBlockM); + for (; m_block < m_block_global_max; m_block++) { + Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block); + /* Do Nothing, just wait */ + Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head); + } + } + } + + CUTLASS_DEVICE void + mma_init() { + // // Tell producer (warp 0) that smem_k and smem_v are ready + // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + if constexpr (dQacc_use_TMA) { + if (warp_idx_in_warpgroup == 0) { + cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmptyWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/); // sdQ empty, ready to be written to + } + } + } + + template + __device__ bool + // CUTLASS_DEVICE bool + mma(Params const& params, + MainloopPipeline pipeline_q, + MainloopPipeline_dO pipeline_do, + // MainloopPipeline_flashmask pipeline_flashmask, + PipelineState& 
smem_pipe_read, + PipelineState_dO& smem_pipe_read_do, + FrgTensordKV& tdKrdK, + FrgTensordKV& tdVrdV, + int const thread_idx, + int const binary_work_idx, + cute::tuple block_coord, + SharedStorage& shared_storage, + const int32_t* const __restrict__ flashmask_mem_, + const int32_t* const __restrict__ flashmask_index_smem_, + int32_t const * blockmask_smem_ + ) { + static_assert(is_rmem::value, "dK and dV tensor must be rmem resident."); + + int n_block = get<0>(block_coord); + int bidb = get<2>(block_coord); + SeqlenInfo_t seqlen_info{ + bidb, get<0>(params.shape_Q), size<0>(params.shape_K), + params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k + }; + auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( + seqlen_info, n_block, bidb, params.window_size_left, + params.window_size_right, 0 /*sink_token_length*/); + // It's possible to have m_block_max <= m_block_min. Exit early + if constexpr (Is_causal || Is_local || Varlen) { + if (m_block_max <= m_block_min) { return false; } + } + + // printf("enter consumer threadidx:%d", threadIdx.x); + // cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast(BwdNamedBarriers::Flashmask) /*id*/); + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); + Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); + Tensor sQt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQt{}); + Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdOt{}); + Tensor sKt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutKt{}); + Tensor sP = 
make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdS{}); + Tensor sP_pi = cute::as_position_independent_swizzle_tensor(sP); + Tensor sPt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdSt{}); + Tensor sPt_pi = cute::as_position_independent_swizzle_tensor(sPt); + Tensor sdS = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdS{}); + Tensor sdS_pi = cute::as_position_independent_swizzle_tensor(sdS); + Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdSt{}); + Tensor sdSt_pi = cute::as_position_independent_swizzle_tensor(sdSt); + Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{}); + Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSEMma{}); + Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{}); + + static_assert(stride<0>(typename TiledMmaSdP::ALayout{}) == 0 and + stride<0>(typename TiledMmaSdP::BLayout{}) == 0 and + size<0>(typename TiledMmaSdP::ALayout{}) == cutlass::NumThreadsPerWarpGroup and + size<0>(typename TiledMmaSdP::BLayout{}) == cutlass::NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + constexpr int MmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(make_shape(Int{}), + make_stride(Int{})); + Layout warp_group_thread_layout_dq = make_layout(make_shape(Int{}), + make_stride(Int{})); + + TiledMmaSdP tiled_mma_SdP; + using TiledMmadP = std::conditional_t; + TiledMmadP tiled_mma_dP; + TiledMmadKV tiled_mma_dKV; + TiledMmadQ tiled_mma_dQ; + + decltype(tiled_mma_SdP.get_slice(warp_group_thread_layout(0))) wg_mma_SdP; + decltype(tiled_mma_dP.get_slice(warp_group_thread_layout(0))) wg_mma_dP; 
+ decltype(tiled_mma_dKV.get_slice(warp_group_thread_layout(0))) wg_mma_dKV; + decltype(tiled_mma_dQ.get_slice(warp_group_thread_layout_dq(0))) wg_mma_dQ; + + { + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); + wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx)); + wg_mma_dP = tiled_mma_dP.get_slice(warp_group_thread_layout(warp_group_idx)); + wg_mma_dKV = tiled_mma_dKV.get_slice(warp_group_thread_layout(warp_group_idx)); + wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout_dq(warp_group_idx)); + } + + auto thread_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx); + + auto smem_tiled_copy_PdS = make_tiled_copy_C(SmemCopyAtomPdS{}, tiled_mma_SdP); + auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(thread_idx); + + R2STiledCopydQaccum r2s_tiled_copy_dQaccum; + auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx); + Tensor tdQsdQaccum = r2s_thr_copy_dQaccum.partition_D(sdQ); + // if (thread_idx == 0) { print(sdQ); printf("\n"); print(tdQsdQaccum); printf("\n"); } + + // Allocate "fragments/descriptors" + // We have to use the templated mma_partition_fragment_AB instead of cute::conditional_return or lambda, + // because some partition_fragment_A/B don't compile. 
+ // https://stackoverflow.com/questions/50051473/if-constexpr-in-c17-does-not-work-in-a-non-templated-function + Tensor tSrQ = mma_partition_fragment_AB(wg_mma_SdP, sQ); + Tensor tSrK = mma_partition_fragment_AB(wg_mma_SdP, sK); + Tensor tdPrdO = mma_partition_fragment_AB(wg_mma_SdP, sdO); + Tensor tdPrV = mma_partition_fragment_AB(wg_mma_dP, sV); + Tensor tdVrdO = mma_partition_fragment_AB(wg_mma_dKV, sdOt); + Tensor tdKrQ = mma_partition_fragment_AB(wg_mma_dKV, sQt); + Tensor tdQrdS = mma_partition_fragment_AB(wg_mma_dQ, sdS); + Tensor tdQrK = mma_partition_fragment_AB(wg_mma_dQ, sKt); + + Tensor tPsP = smem_thr_copy_PdS.partition_D(cute::conditional_return(sP_pi, sPt_pi)); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor tdSsdS = smem_thr_copy_PdS.partition_D(cute::conditional_return(sdS_pi, sdSt_pi)); // ((Atom,AtomNum),PIPE_M,PIPE_N) + // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_PdS); print(sP_pi); printf("\n"); print(sPt_pi); printf("\n"); print(tPsP); printf("\n"); print(tdSsdS); printf("\n"); } + + // thread_mma_SdP.partition_C(sLSEMma) has shape ((2, 2, V), MMA_M, MMA_N, PIPE), we only take the col indices + // or row indices, depending on whether SdP_swapAB. + Tensor tLSEsLSE = cute::conditional_return( + group_modes<0, 2>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)), // (2, MMA_M, PIPE) + group_modes<0, 3>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_, _0{}, _), _0{}, _, _))); // (2, V, MMA_N, PIPE) + Tensor tLSEsdPsum = cute::conditional_return( + group_modes<0, 2>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)), + group_modes<0, 3>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_, _0{}, _), _0{}, _, _))); + // if (blockIdx.x == 0 && threadIdx.x == 128) { print(sLSEMma); printf("\n"); print(tLSEsLSE); printf("\n"); } + // If we want to split the stats among the 8 threads that share the same rows. 
+ static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(tLSEsLSE))::value, 8); + + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + }; + + // For the case where we do atomicAdd directly to gdQaccum instead of using TMA + bool const is_varlen = Varlen && params.cu_seqlens_q; + Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), + params.shape_dQaccum, params.stride_dQaccum)(_, get<1>(block_coord)/*bidh*/, !is_varlen ? bidb : 0); + Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) + Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int{}); // (M * K / WG, WG, _) + // We can reuse r2s_thr_copy_dQaccum for this partitioning + Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum); + // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf("\n"); print(gdQaccum_); printf("\n"); print(gdQaccum); printf("\n"); print(tdQgdQaccum); printf("\n"); } + + flash::Mask mask( + thread_idx, seqlen_info.seqlen_q, seqlen_info.seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, + params.qhead_per_khead_divmod + ); + // int32_t flashmask_mem_[8]s; + // load_n_block_info(n_block, flashmask_mem_, params); + + int m_block = m_block_min; + // if(thread_idx == 0) printf("m_block:%d",m_block); + // get_next_m_block(n_block,m_block,partially_masked,m_block_max - 1,params); + + clear(tdKrdK); + clear(tdVrdV); + // tiled_mma_dKV.accumulate_ = GMMA::ScaleOut::Zero; + + cutlass::ConsumerToken barrier_token = static_cast(shared_storage.pipelines.barrier_KV.try_wait(binary_work_idx)); + if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.pipelines.barrier_KV.wait(binary_work_idx); } + + if constexpr (Mma_dP_is_RS) { + using SmemCopyAtomV = 
Copy_Atom; + auto smem_tiled_copy_V = make_tiled_copy_A(SmemCopyAtomV{}, tiled_mma_dP); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(thread_idx); + Tensor tdPrV_copy_view = smem_thr_copy_V.retile_D(tdPrV); + Tensor tdPsV_copy_view = smem_thr_copy_V.partition_S(cute::as_position_independent_swizzle_tensor(sV)); + cute::copy(smem_tiled_copy_V, tdPsV_copy_view, tdPrV_copy_view); + } + + auto bwd_step = [&](int m_block, auto mask_fn, bool partially_masked, const int32_t* const flashmask_index_smem_) { + Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{})); + consumer_wait(pipeline_q, smem_pipe_read); + flash::gemm(tiled_mma_SdP, tSrQ(_, _, _, smem_pipe_read.index()), tSrK, tSrS); + Tensor tLSErLSE = cute::conditional_return(make_fragment_like(tLSEsLSE(_, _0{})), make_tensor(Int{})); + if constexpr (!ShuffleLSE) { + cute::copy(tLSEsLSE(_, smem_pipe_read.index()), tLSErLSE); + } else { + #pragma unroll + for (int i = 0; i < kStatsPerThread; ++i) { + // It's ok to read OOB, since we made sure sLSE is large enough and we won't use the OOB values + tLSErLSE(i) = tLSEsLSE((thread_idx & 31) / 4 + i * 8, smem_pipe_read.index()); + } + } + Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{})); + PipelineState_dO smem_pipe_read_do_cur = cute::conditional_return(smem_pipe_read, smem_pipe_read_do); + consumer_wait(pipeline_do, smem_pipe_read_do_cur); + // printf("consumer2:stageid,do_stageid:%d,%d\n", smem_pipe_read.index(),smem_pipe_read_do_cur.index()); + flash::gemm(tiled_mma_dP, tdPrdO(_, _, _, smem_pipe_read_do_cur.index()), tdPrV, tdPrdP); + warpgroup_wait<1>(); + if constexpr (Has_softcap) { flash::apply_softcap(tSrS, params.softcap_val); } + + // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); + // dtanh needs to happen before masking, otherwise we get 1 - (-inf)^2 = NaN in the dtanh + 
auto dtanh = [&] { if constexpr (Has_softcap) return flash::calculate_dtanh(scores); else return nullptr; }(); + mask_fn(tSrS, m_block); + if(partially_masked) flash::apply_flashmask_bwd(tSrS, thread_idx, flashmask_index_smem_, m_block); + #pragma unroll + for (int mi = 0; mi < size<0>(scores); ++mi) { + float const lse_scaled = [&] { + if constexpr (!ShuffleLSE) return tLSErLSE(mi); + else return __shfl_sync(0xffffffff, tLSErLSE(mi >> 3), (mi & 7) * 4 + (thread_idx & 3)); + }(); + #pragma unroll + for (int ni = 0; ni < size<1>(scores); ++ni) { + // printf("score-point0,row_idx:%d, col_idx:%d,mi:%d,ni:%d,m_block:%d, thread_idx:%d, scores:%f\n", row_idx, col_idx,mi,ni,m_block, thread_idx, scores(mi, ni)); + scores(mi, ni) = exp2f(scores(mi, ni) * params.softmax_scale_log2 - lse_scaled); + // printf("score-point1,row_idx:%d, col_idx:%d,mi:%d,ni:%d,m_block:%d, thread_idx:%d, scores:%f\n", row_idx, col_idx,mi,ni,m_block, thread_idx, scores(mi, ni)); + } + } + + Tensor tLSErdPsum = cute::conditional_return(make_fragment_like(tLSEsdPsum(_, _0{})), make_tensor(Int{})); + if constexpr (!ShuffledPsum) { + cute::copy(tLSEsdPsum(_, smem_pipe_read_do_cur.index()), tLSErdPsum); + } else { + #pragma unroll + for (int i = 0; i < kStatsPerThread; ++i) { + tLSErdPsum(i) = tLSEsdPsum((thread_idx & 31) / 4 + i * 8, smem_pipe_read_do_cur.index()); + } + } + + warpgroup_wait<0>(); + // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) + Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { + float const dP_sum_cur = [&] { + if constexpr (!ShuffledPsum) return tLSErdPsum(mi); + else return __shfl_sync(0xffffffff, tLSErdPsum(mi >> 3), (mi & 7) * 4 + (thread_idx & 3)); + }(); + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { + dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur); + if constexpr (Has_softcap) { dS(mi, ni) *= dtanh(mi, ni); } + } + } + + // Convert scores 
from fp32 to fp16/bf16 + Tensor rP = make_tensor_like(tSrS); + flash::convert_type_out(tSrS, rP); + if constexpr (!Mma_dKV_is_RS) { + // Need to sync to make sure P has already been used in the previous iteration before writing new values + if constexpr (kStages_dS == 1) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); + } + Tensor tPaP = smem_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_PdS, tPaP, tPsP(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index()))); + } + Tensor rdS = make_tensor_like(tdPrdP); + flash::convert_type_out(tdPrdP, rdS); + // If there's double buffering on dS, we don't need to sync here. + // Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ. + // But because both WGs have to sync at the end of the loop and double buffering, + // this race condition is not possible. + // This sync is to ensure (1) P is written in case of !Mma_dKV_is_RS and + // (2) dS is already read by the Mma in the previous iteration in case of Mma_dKV_is_RS. 
+ if constexpr (!Mma_dKV_is_RS || (kStages_dS == 1 && Mma_dKV_is_RS)) { + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); + } + // For hdim 64, It's faster to write to smem_dS first before the dV gemm + Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index()))); + + if constexpr (!Slice_dQKV_Mma) { + // Most cases take this path, except for hdim256 where we want to slice to reduce register pressure + if constexpr (Mma_dKV_is_RS) { + Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout())); + flash::gemm(tiled_mma_dKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); + } else { + Tensor tdVrP = mma_partition_fragment_AB(wg_mma_dKV, sPt); + Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); + flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); + } + // SMEM fence to make sure sdS is written before it's read by WGMMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); + Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); + Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); + flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); + pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dQ + + if constexpr (Mma_dKV_is_RS) { + Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); + flash::gemm(tiled_mma_dKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); + } else { + Tensor tdKrdS = mma_partition_fragment_AB(wg_mma_dKV, sdSt); + Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); + 
flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); + } + if constexpr (dQacc_use_TMA) { + int const warp_group_idx = flash::canonical_warp_group_idx_nosync() - 1; + cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmptyWG1) + warp_group_idx /*id*/); // sdQ full, to be written to gmem + Tensor taccdQrdQ = r2s_thr_copy_dQaccum.retile_S(tdQrdQ); + cute::copy(r2s_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum); + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQFullWG1) + warp_group_idx /*id*/); // sdQ full, to be written to gmem + // if(blockIdx.x == 0 && threadIdx.x == 128){ + // printf("warp_group_idx: %d\n", warp_group_idx); + // printf("sdq\n"); + // cute::print_tensor(sdQ); + // printf("\n"); + // } + } else { + // We can reuse r2s_thr_copy_dQaccum for this partitioning + Tensor tdQrdQ_atomic = recast(r2s_thr_copy_dQaccum.retile_S(tdQrdQ)); + Tensor tdQgdQaccum_atomic = recast(tdQgdQaccum(_, _, _, m_block)); + static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic))); + #pragma unroll + for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } + } + + } else { // Slice_dQKV_Mma + + static_assert(!(Slice_dQKV_Mma && Mma_dKV_is_RS)); + Tensor tdVrP = mma_partition_fragment_AB(wg_mma_dKV, sPt); + Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); + flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); + + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); + Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); + Tensor tdQrdS_cur = tdQrdS(_, _, _, 
cute::conditional_return(_0{}, smem_pipe_read.index())); + flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); + flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); + Tensor tdQrdQ_atomic = recast(r2s_thr_copy_dQaccum.retile_S(tdQrdQ)); + Tensor tdQgdQaccum_atomic = recast(tdQgdQaccum(_, _, _, m_block)); + #pragma unroll + for (int i = 0; i < size(tdQrdQ_atomic) / 2; ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } + + Tensor tdKrdS = mma_partition_fragment_AB(wg_mma_dKV, sdSt); + Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); + flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); + pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dO + + flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); + #pragma unroll + for (int i = size(tdQrdQ_atomic) / 2; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } + + flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); + } + + warpgroup_wait<0>(); + pipeline_q.consumer_release(smem_pipe_read); // release Q + ++smem_pipe_read; + if constexpr (!Q_dO_same_stages) { ++smem_pipe_read_do; } + }; + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 + // this helps quite a bit to not have to do causal masking for most of the iterations. 
+ + auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; + int loop_end = m_block_max; + if constexpr(!Is_causal){ + if constexpr (Has_ut_start) { + loop_end = std::min(flashmask_mem_[5]/*ut_start_nblockmin*/, m_block_max); + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block < loop_end; m_block++) { + if constexpr (Is_blockmask){ + if(!blockmask_smem_[m_block / params.m_factor]) continue; + } + // if(threadIdx.x == 128) printf("consumer0 m_block,n_block: %d, %d\n", m_block,n_block); + bwd_step(m_block, mask_fn, false, flashmask_index_smem_); + } + loop_end = flashmask_mem_[4]/*ut_start_nblockmax*/; + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block <= loop_end; ++m_block) { + if constexpr (Is_blockmask){ + if(!blockmask_smem_[m_block / params.m_factor]) continue; + } + // if(threadIdx.x == 128) printf("consumer0 m_block,n_block: %d, %d\n", m_block,n_block); + bwd_step(m_block, mask_fn, true, flashmask_index_smem_); + } + } + m_block = std::max(m_block, flashmask_mem_[7]/*ut_end_nblockmin*/); + loop_end = std::min(flashmask_mem_[6]/*ut_end_nblockmax*/, m_block_max - 1); + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block <= loop_end; m_block++) { + if constexpr (Is_blockmask){ + if(!blockmask_smem_[m_block / params.m_factor]) continue; + } + // if(threadIdx.x == 128) printf("consumer-u-2 m_block,n_block,m_block_max,flashmask_mem_[2]: %d, %d, %d,%d\n", m_block,n_block,m_block_max,flashmask_mem_[6]); + bwd_step(m_block, mask_fn, true, flashmask_index_smem_); + } + } + loop_end = std::min(flashmask_mem_[1]/*lt_start_nblockmin*/, m_block_max); + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block < loop_end; m_block++) { + if constexpr (Is_blockmask){ + if(!blockmask_smem_[m_block / params.m_factor]) continue; + } + // if(threadIdx.x == 128) printf("consumer-l-0 m_block,n_block: %d, %d\n", m_block,n_block); + bwd_step(m_block, mask_fn, false, flashmask_index_smem_); + } + //partial_maskloop_end + loop_end = std::min(m_block_max - 1, 
flashmask_mem_[0]/*lt_start_nblockmax*/); + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block <= loop_end; m_block++) { + if constexpr (Is_blockmask){ + if(!blockmask_smem_[m_block / params.m_factor]) continue; + } + // if(threadIdx.x == 128) printf("consumer-l-1 m_block,n_block, flashmask_mem_[0]: %d, %d, %d\n", m_block,n_block,flashmask_mem_[0]); + bwd_step(m_block, mask_fn, true, flashmask_index_smem_); + } + if constexpr (Has_lt_end) { + m_block = std::max(m_block, flashmask_mem_[3]/*lt_end_nblockmin*/); + //partial_maskloop_end + loop_end = std::min(flashmask_mem_[2]/*lt_end_nblockmax*/, m_block_max - 1); + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block <= loop_end; m_block++) { + if constexpr (Is_blockmask){ + if(!blockmask_smem_[m_block / params.m_factor]) continue; + } + // if(threadIdx.x == 128) printf("consumer2 m_block,n_block,m_block_max,flashmask_mem_[2]: %d, %d, %d,%d\n", m_block,n_block,m_block_max,flashmask_mem_[2]); + bwd_step(m_block, mask_fn, true, flashmask_index_smem_); + } + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block < m_block_max; m_block++) { + if constexpr (Is_blockmask){ + if(!blockmask_smem_[m_block / params.m_factor]) continue; + } + bwd_step(m_block, mask_fn, false, flashmask_index_smem_); + } + } + + #pragma unroll + for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; } + + if constexpr (Q_dO_same_stages) { smem_pipe_read_do = smem_pipe_read; } + return true; + } + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/mainloop_fwd_sm80.hpp b/flashmask/flash_mask/flashmask_attention_v3/mainloop_fwd_sm80.hpp new file mode 100644 index 00000000000..30a4c7968f3 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/mainloop_fwd_sm80.hpp @@ -0,0 +1,863 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. 
All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include + +#include "cute/tensor.hpp" + +#include "seqlen.h" +#include "block.h" +#include "mask.h" +#include "pack_gqa.h" +#include "paged_kv.h" +#include "rotary.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct CollectiveMainloopFwdSm80 { + + static constexpr int kStages = Stages; + static_assert(kStages > 0, "kStages must be greater than 0"); + using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; + using Element = Element_; + using ElementAccum = ElementAccum_; + using ArchTag = ArchTag_; + static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v;; + static constexpr bool Is_causal = Is_causal_; + static constexpr bool Is_local = Is_local_; + static constexpr bool Has_softcap = Has_softcap_; + static constexpr bool Varlen = Varlen_; + static constexpr bool PagedKV = PagedKV_; + static constexpr bool AppendKV = AppendKV_; + static constexpr bool PackGQA = PackGQA_; + static constexpr bool Split = Split_; + static constexpr bool Transpose_V = Is_FP8; + + static_assert(ArchTag::kMinComputeCapability >= 80); + + static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80; + + static constexpr int kBlockM = 
get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + + using SeqlenInfo_t = flash::SeqlenInfoQKNewK; + using BlockMN_t = flash::BlockMN; + + using MMA_Atom_Arch = std::conditional_t< + ArchTag::kMinComputeCapability >= 80, + std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >, + MMA_Atom + >; + using TiledMma = TiledMMA< + MMA_Atom_Arch, + Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + Tile, _16, _16>>; + + static constexpr int NumMmaThreads = size(TiledMma{}); + static constexpr int NumProducerThreads = NumMmaThreads; // For compatibility with TileScheduler + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each + // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. + static constexpr int kBytePerRow = kHeadDim * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + + static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)); + static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 
3 : 4); + using SmemLayoutAtomQKV = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQKV{}, select<0, 2>(TileShape_MNK{}))); + + using SmemLayoutK = decltype(tile_to_shape( + SmemLayoutAtomQKV{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + + using SmemLayoutV = decltype(tile_to_shape( + SmemLayoutAtomQKV{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + using SmemLayoutVt = decltype( + composition(SmemLayoutV{}, + make_ordered_layout(make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}), + Step<_2, _1, _3>{}))); + + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using GmemCopyAtom = Copy_Atom, + AutoVectorizingCopyWithAssumedAlignment<128> + >, Element>; + + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(GmemCopyAtom{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per read + // So that we don't have to check if we overshot kBlockM when we load Q + static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0); + + // For AppendKV, We want each thread to have at least 2 loads in the K direction since in the case of + // non-interleaved rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), + // each thread will load twice from the same row. + static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element); + static constexpr int kBlockKGmemAppend = (kBytePerHalfRow % 128 == 0 ? 
128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + static constexpr int kGmemThreadsPerRowAppend = kBlockKGmemAppend / kGmemElemsPerLoad; + static_assert(NumMmaThreads % kGmemThreadsPerRowAppend == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRowAppend"); + // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where + // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. + static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRowAppend == 0, "kGmemThreadsPerRowAppend must divide NumThreadsPerWarp"); + using GmemLayoutAtomAppend = Layout, Int>, + Stride, _1>>; + // If AppendKV, we'll be loading Q for rotary, and we assume divisibility to avoid predication + static_assert(!AppendKV || kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtomAppend{})) == 0, "kBlockM must be a multiple of NumMmaThreads / kGmemThreadsPerRowAppend"); + using GmemTiledCopyAppendKV = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomAppend{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) + using StrideQK = cute::Stride; + using StrideV = StrideQK; + // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits) + using ShapeQPacked = std::conditional_t, int32_t, int32_t, int32_t>>; + using StrideQPacked = std::conditional_t, _1, int64_t, int64_t>>; + using ShapePageTable = cute::Shape; // (batch, max_num_pages_per_seq) + using StridePageTable = cute::Stride; + using ShapeRotary = cute::Shape; // (seqlen_ro, rotary_dim // 2) + using StrideRotary = cute::Stride; + using StrideDescale = cute::Stride; + + static constexpr bool Share_QV_Smem = Q_in_regs; + + struct TensorStorageSharedQV : cute::aligned_struct<128> { + union { + cute::array_aligned> smem_v; + cute::array_aligned> smem_q; + }; + cute::array_aligned> smem_k; + }; + + struct TensorStorageSeparateQV : 
cute::aligned_struct<128> { + cute::array_aligned> smem_v; + cute::array_aligned> smem_k; + cute::array_aligned> smem_q; + }; + + using TensorStorage = std::conditional_t; + + // Host side kernel arguments + struct Arguments { + Element const* const ptr_Q; + ShapeQKV const shape_Q; + StrideQK const stride_Q; + Element* const ptr_K; // Not Element const* since we might append to KV cache in-place + ShapeQKV const shape_K; + StrideQK const stride_K; + Element* const ptr_V; + int32_t const headdim_v; + StrideV const stride_V; + Element const* const ptr_K_new; + ShapeQKV const shape_K_new; + StrideQK const stride_K_new; + Element const* const ptr_V_new; + StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideQK const stride_Qv; + Element const* const ptr_rotary_cos; + ShapeRotary const shape_rotary; + StrideRotary const stride_rotary_cos; + Element const* const ptr_rotary_sin; + StrideRotary const stride_rotary_sin; + bool const is_rotary_interleaved; + int const* const ptr_pagetable; + ShapePageTable const shape_pagetable; + StridePageTable const stride_pagetable; + float const softmax_scale; + float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; + StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; + int const window_size_left = -1, window_size_right = -1; + float const softcap_val; + int const num_splits; + int const* const kv_batch_idx = nullptr; + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const cu_seqlens_k_new = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + int const* const leftpad_k = nullptr; + }; + + // Device side kernel params + struct Params { + Element const* const ptr_Q; + ShapeQKV const shape_Q; + StrideQK const stride_Q; + ShapeQPacked const shape_Q_packed; + StrideQPacked const stride_Q_packed; + Element* const ptr_K; + ShapeQKV const shape_K; + StrideQK const stride_K; + Element* const ptr_V; + int32_t const 
headdim_v; + StrideV const stride_V; + Element const* const ptr_K_new; + ShapeQKV const shape_K_new; + StrideQK const stride_K_new; + Element const* const ptr_V_new; + StrideV const stride_V_new; + Element const* const ptr_rotary_cos; + ShapeRotary const shape_rotary; + StrideRotary const stride_rotary_cos; + Element const* const ptr_rotary_sin; + StrideRotary const stride_rotary_sin; + bool const is_rotary_interleaved; + int const* const ptr_pagetable; + ShapePageTable const shape_pagetable; + StridePageTable const stride_pagetable; + cutlass::FastDivmod page_size_divmod; + cutlass::FastDivmod qhead_per_khead_divmod; + float const softmax_scale_log2; + float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; + StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; + float const softcap_val; + int const window_size_left, window_size_right; + int const num_splits; + int const* const kv_batch_idx = nullptr; + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const cu_seqlens_k_new = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + int const* const leftpad_k = nullptr; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) + int const qhead_per_khead = !PackGQA ? 
1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K)); + auto const shape_Q_packed = cute::conditional_return( + args.shape_Q, + make_shape(make_shape(qhead_per_khead, get<0>(args.shape_Q)), get<1>(args.shape_Q), get<2>(args.shape_K), get<3>(args.shape_Q)) + ); + auto const stride_Q_packed = cute::conditional_return( + args.stride_Q, + make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q)) + ); + if (get<1>(args.shape_rotary) > 0) { + assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr); + } + assert(args.num_splits >= 1); + // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. + // Right after this, we multiply by log2(e) before applying exp2. + // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val + // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) + // (assigning it to params.softmax_scale_log2). + return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, + args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, + args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, + args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, + args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, + args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable, + cutlass::FastDivmod(int(get<0>(args.shape_K))), + cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), + !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), + args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, + args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, + !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, + args.window_size_left, args.window_size_right, + !Split ? 
1 : args.num_splits, + args.kv_batch_idx, + args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, + args.seqused_q, args.seqused_k, args.leftpad_k}; + } + + template + CUTLASS_DEVICE bool + mma(Params const& params, + FrgTensorO& tOrO, + Softmax& softmax, + int const thread_idx, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + SharedStorage& shared_storage + ) { + static_assert(is_rmem::value, "O tensor must be rmem resident."); + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + + // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda + int const m_block = get<0>(block_coord); + int const bidh = get<1>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); + int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; + auto n_block_min_max = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + int const n_block_min = get<0>(n_block_min_max); + int const n_block_max = get<1>(n_block_min_max); + // It's possible to have n_block_max <= n_block_min. 
We don't want to load Q or change any barrier + if constexpr (Is_causal || Is_local || Varlen || Split) { + if (n_block_max <= n_block_min) { return false; } + } + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); + Tensor sVt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{}); + + bool const is_varlen_q = Varlen && params.cu_seqlens_q; + bool const is_varlen_k = Varlen && params.cu_seqlens_k; + + int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; + Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor gQ = local_tile(mQ, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K + seqlen_info.offset_k * get<0>(params.stride_K)), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + Tensor gK = local_tile(mK, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V + seqlen_info.offset_k * get<0>(params.stride_V)), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? 
bidb_kv : 0); + Tensor gV = local_tile(mV, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + + GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(thread_idx); + auto gmem_thr0_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(_0{}); // For index calculation + + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_slice(thread_idx); + + // Allocate "fragments/descriptors" + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); + + // Copy Atom retiling + auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx); + auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(thread_idx); + auto smem_tiled_copy_V = make_tiled_copy_B(SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(thread_idx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // Predicates + Tensor cKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); + Tensor t0KVcKV = gmem_thr0_copy_QKV.partition_S(cKV); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); } + + int const seqlen_q = seqlen_info.seqlen_q; + int const seqlen_k = seqlen_info.seqlen_k; + int n_block = n_block_max - 1; + + // Prologue: load Q, K, V + // If persistent, we don't need to wait for the previous 
work_idx to finish + // since we assume that all MMA threads sync in the epilogue before writing to smem_o. + // So any thread gets there, all threads must have finished the previous MMA and at least started + // writing to smem_o. + // If persistent, need to sync to make sure all threads have finished with smem_o before writing to smem_v + if constexpr (Share_QV_Smem) { __syncthreads(); } + if constexpr (!PackGQA) { + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor cQ = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); + Tensor t0QcQ = gmem_thr0_copy_QKV.partition_S(cQ); + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_Q); } + // Instead of passing in tQcQ, we pass in t0QcQ and subtract the offset from the limit + // (seqlen_q - m_block * kBlockM). This is because the entries of t0QcQ are known at compile time. 
+ // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy( + gmem_tiled_copy_QKV, tQgQ, tQsQ, t0QcQ, tQpQ, seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{})) + ); + } else { + using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element>; + PackGQAt::load_Q(mQ, sQ, params.qhead_per_khead_divmod, thread_idx, seqlen_q, m_block); + } + cute::cp_async_fence(); + + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/>; + PagedKVManager_t paged_kv_manager( + params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, + params.ptr_K, params.shape_K, params.stride_K, + params.ptr_V, params.headdim_v, params.stride_V, + params.page_size_divmod, + params.page_size_divmod /*blockN_per_page_size_divmod, not used since we don't use TMA*/, + bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k, + 0 /*bidb_kv_idx, not used since we don't use TMA for Sm8x*/ + ); + + auto load_K = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { + static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; + if constexpr (!PagedKV) { + // Do we need bound check to make sure the row doesn't go above kBlockN + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; + Tensor tKsK_cur = tKsK(_, _, _, smem_pipe_write); + // Instead of passing in tKVcKV, we pass in t0KVcKV and subtract the offset from the limit + // (seqlen_k - n_block * kBlockN). This is because the entries of t0KVcKV are known at compile time. + int const seqlenk_row_limit = -int(get<0>(tKVcKV(_0{}, _0{}, _0{}))) + (EvenN + ? seqlen_info.seqlen_k - n_block * kBlockN + : (!Seqlenk_mask ? 
kBlockN : std::min(seqlen_info.seqlen_k - n_block * kBlockN, kBlockN))); + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy( + gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK_cur, t0KVcKV, tKVpKV, seqlenk_row_limit); + } else { + paged_kv_manager.template load_page_table(n_block); + paged_kv_manager.template load_K(n_block, sK(_, _, smem_pipe_write)); + } + }; + + auto load_V = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { + static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; + if constexpr (!PagedKV) { + // Do we need bound check to make sure the row doesn't go above kBlockN + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; + Tensor tVsV_cur = tVsV(_, _, _, smem_pipe_write); + // We don't call flash::copy since it doesn't support bound checking + // to not overshot kBlockN when writing to smem. + Tensor tVgV_cur = tVgV(_, _, _, n_block); + int const seqlenk_row_limit = seqlen_info.seqlen_k - n_block * kBlockN - get<0>(tKVcKV(_0{}, _0{}, _0{})); + #pragma unroll + for (int m = 0; m < size<1>(tVsV); ++m) { + // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to be checked + if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) { + bool const predicate_n = !Seqlenk_mask || get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit; + #pragma unroll + for (int k = 0; k < size<2>(tVsV); ++k) { + cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tVgV_cur(_, m, k), tVsV_cur(_, m, k)); + } + } + } + } else { + paged_kv_manager.template load_V(n_block, sV(_, _, smem_pipe_write)); + } + }; + + auto preprocess_Q = [&] { + if constexpr (!AppendKV) { + flash::cp_async_wait(); + } else { + if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q + int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; + using Rotary_t = Rotary; + Rotary_t 
rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, + params.ptr_rotary_sin, params.stride_rotary_sin, + params.is_rotary_interleaved, thread_idx, seqlen_q, offset_rotary); + int const qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; + if (params.is_rotary_interleaved) { + auto [tRrCos, tRrSin] = cute::conditional_return( + rotary.template load_cos_sin(m_block), + rotary.template load_cos_sin_packgqa(m_block, params.qhead_per_khead_divmod) + ); + flash::cp_async_wait(); + __syncthreads(); + rotary.apply_Q_interleaved(sQ, tRrCos, tRrSin, m_block, qhead_per_khead); + } else { + auto [tRrCosCont, tRrSinCont] = cute::conditional_return( + rotary.template load_cos_sin(m_block), + rotary.template load_cos_sin_packgqa(m_block, params.qhead_per_khead_divmod) + ); + flash::cp_async_wait(); + __syncthreads(); + rotary.apply_Q_contiguous(sQ, tRrCosCont, tRrSinCont, m_block, qhead_per_khead); + } + } else { + flash::cp_async_wait(); + } + } + + if constexpr (Q_in_regs) { + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + Tensor tSsQ_copy_view = smem_thr_copy_Q.partition_S(sQ); + cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view); + } + }; + + // If Share_QV_Smem, we load Q, then load 1 stage of K, then (optionally) rotate Q and + // read from smem_q to registers, then load V. + // If !Share_QV, Smem, we load Q, load all stages of K & V, then (optionally) rotate Q. + + if constexpr (Share_QV_Smem) { + load_K(n_block, 0, cute::true_type{} /*Seqlenk_mask*/); + cute::cp_async_fence(); + preprocess_Q(); + __syncthreads(); // Make sure all threads have read smem_q before loading V + } + + // For persistent, make sure all threads have finished reading smem_o + if constexpr (!Share_QV_Smem) { __syncthreads(); } + // Note, using the for_each() function here to ensure `stage` is of type Int. 
+ for_each(make_int_sequence{}, [&] (auto stage) { + static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0; + static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1; + if constexpr (!Share_QV_Smem || !Is_first_stage) { + if (Is_first_stage || n_block - stage >= n_block_min) { + load_K(n_block - stage, stage, cute::bool_constant{} /*Seqlenk_mask*/); + } + // We want the fence outside the if statement to have a fixed number of cp.async commits. + // so that we can wait with the correct number of outstanding commits. + cute::cp_async_fence(); + } + if constexpr (!Is_last_stage) { + if (Is_first_stage || n_block - stage >= n_block_min) { + load_V(n_block - stage, stage, cute::bool_constant{} /*Seqlenk_mask*/); + } + cute::cp_async_fence(); + } + }); + + if constexpr (!Share_QV_Smem) { preprocess_Q(); } + + flash::Mask mask( + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, + params.qhead_per_khead_divmod + ); + + float softcap_val = params.softcap_val; + if constexpr (Has_softcap && Is_FP8) { + float const q_descale = params.ptr_q_descale == nullptr ? 1.0f : params.ptr_q_descale[bidb * get<0>(params.stride_q_descale) + bidh_kv * get<1>(params.stride_q_descale)]; + float const k_descale = params.ptr_k_descale == nullptr ? 1.0f : params.ptr_k_descale[bidb * get<0>(params.stride_k_descale) + bidh_kv * get<1>(params.stride_k_descale)]; + softcap_val *= q_descale * k_descale; + } + // Softcapping needs to happen before masking since if we apply after masking, softcapping can turn + // -inf to e.g. -50.0, which can affect the attention softmax. + auto scoremod_premask_fn = [&](auto& tSrS) { + if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); } + }; + + int smem_pipe_read = 0, smem_pipe_write = kStages - 1; + + auto load_K_next = [&] { + if (n_block - kStages >= n_block_min) { + load_K(n_block - kStages, kStages > 1 ? 
smem_pipe_write : 0, cute::false_type{} /*Seqlenk_mask*/); + } + cute::cp_async_fence(); + }; + + auto sync = [&] { + flash::cp_async_wait(); + __syncthreads(); + }; + + clear(tOrO); + + auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) { + static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value; + static constexpr bool Check_inf = decltype(check_inf_type)::value; + Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{})); + clear(tSrS); + sync(); + auto load_V_next = [&] { + if (n_block - kStages + 1 >= n_block_min) { + load_V(n_block - kStages + 1, kStages > 1 ? smem_pipe_write : 0, cute::bool_constant{} /*Seqlenk_mask*/); + } + cute::cp_async_fence(); + }; + Tensor tSrQ_cur = cute::conditional_return(tSrQ, thr_mma.partition_fragment_A(sQ)); + Tensor tSrK = thr_mma.partition_fragment_B(sK(_, _, _0{})); + flash::gemm_sm80( + tSrS, tSrQ_cur, tSrK, tSsQ, tSsK(_, _, _, kStages > 1 ? smem_pipe_read : 0), + tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K, load_V_next + ); + smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0; + scoremod_premask_fn(tSrS); + // Faster to load_K before gemm if we only have 1 stage + if constexpr (kStages == 1) { sync(); load_K_next(); } + mask_fn(tSrS, n_block); + Tensor scores_scale = softmax.template max_get_scale(tSrS); + softmax.template online_softmax(tSrS); + if constexpr (Is_FP8) { flash::permute_Cregs_fp8(tSrS); } + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP = make_tensor_like(tOrP_acc); + convert_type_out(tOrP_acc, tOrP); + if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); } + if constexpr (kStages > 1) { sync(); } + Tensor tOrV = thr_mma.partition_fragment_B(sVt(_, _, _0{})); + flash::gemm_rs_sm80(tOrO, tOrP, tOrV, tOsVt(_, _, _, kStages > 1 ? 
smem_pipe_read : 0), tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + if constexpr (kStages > 1) { load_K_next(); } + smem_pipe_read = smem_pipe_read < kStages - 1 ? smem_pipe_read + 1 : 0; + }; + + auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; + fwd_step(n_block, first_iter_mask_fn, cute::true_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/); + --n_block; + if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking + auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; + int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM); + int const n_block_min_causal_local_mask = + std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN); + #pragma unroll 1 + for (; n_block >= n_block_min_causal_local_mask; --n_block) { + fwd_step(n_block, mask_fn, cute::false_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/); + } + } + int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; + int const n_block_min_before_local_mask = !Is_local + ? 
n_block_min + : std::max(n_block_min, + cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN)); + auto no_mask_fn = [](auto& tSrS, int n_block) { }; + #pragma unroll 1 + for (; n_block >= n_block_min_before_local_mask; --n_block) { + fwd_step(n_block, no_mask_fn, cute::false_type{} /*is_first_iter*/, cute::false_type{} /*check_inf*/); + } + // Separate masking iterations on the left for local attention + if constexpr (Is_local) { + auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; + #pragma unroll 1 + for (; n_block >= n_block_min; --n_block) { + fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); + } + } + float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; + Tensor scores_scale = softmax.finalize(v_descale); + softmax.rescale_o(tOrO, scores_scale); + if constexpr (Is_FP8) { flash::permute_output_fp8(tOrO); } + return true; + } + + template + CUTLASS_DEVICE bool + store_kv_new(Params const& params, + int const thread_idx, + SharedStorage &shared_storage, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord + ) { + auto [m_block, bidh, bidb, split_idx] = block_coord; + auto n_block_new_min_max = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + int const n_block_new_min = get<0>(n_block_new_min_max); + int const n_block_new_max = get<1>(n_block_new_min_max); + if (n_block_new_max <= n_block_new_min) { return false; } + + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); + + int const bidh_kv = !PackGQA ? 
params.qhead_per_khead_divmod.divide(bidh) : bidh; + int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; + + bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new; + Tensor mKnew = make_tensor(make_gmem_ptr(params.ptr_K_new), params.shape_K_new, params.stride_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); + Tensor mVnew = make_tensor(make_gmem_ptr(params.ptr_V_new), params.shape_K_new, params.stride_V_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); + + bool const is_varlen_k = Varlen && params.cu_seqlens_k; + Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + + Tensor gKnew = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gVnew = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mVnew), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og; + Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; + int const seqlen_k_new = seqlen_info.seqlen_k_new; + using Rotary_t = Rotary; + Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, + params.ptr_rotary_sin, params.stride_rotary_sin, + params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary); + + using 
PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; + PagedKVManager_t paged_kv_manager( + params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, + params.ptr_K, params.shape_K, params.stride_K, + params.ptr_V, params.headdim_v, params.stride_V, + params.page_size_divmod, + params.page_size_divmod /*blockN_per_page_size_divmod, not used since we don't use TMA*/, + bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k, + // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position + 0 /*bidb_kv_idx, not used since we don't use TMA for Sm8x*/ + ); + + static_assert(std::is_same_v); + static_assert(!PagedKV || std::is_same_v); + GmemTiledCopyQKV gmem_tiled_copy_kv_g2s; + auto gmem_thr_copy_kv_g2s = gmem_tiled_copy_kv_g2s.get_thread_slice(thread_idx); + auto gmem_thr0_copy_kv_g2s = gmem_tiled_copy_kv_g2s.get_thread_slice(_0{}); // Only for index calculation + GmemTiledCopyAppendKV gmem_tiled_copy_kv_s2g; + auto gmem_thr_copy_kv_s2g = gmem_tiled_copy_kv_s2g.get_thread_slice(thread_idx); + auto gmem_thr0_copy_kv_s2g = gmem_tiled_copy_kv_s2g.get_thread_slice(_0{}); // Only for index calculation + Tensor tKgKnew = gmem_thr_copy_kv_g2s.partition_S(gKnew); + Tensor tKsKg2s = gmem_thr_copy_kv_g2s.partition_S(sK); + Tensor tKsKs2g = gmem_thr_copy_kv_s2g.partition_S(sK); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tKgK = gmem_thr_copy_kv_s2g.partition_D(gK); + Tensor tVgVnew = gmem_thr_copy_kv_g2s.partition_S(gVnew); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tVsVg2s = gmem_thr_copy_kv_g2s.partition_S(sV); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tVsVs2g = gmem_thr_copy_kv_s2g.partition_S(sV); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tVgV = gmem_thr_copy_kv_s2g.partition_D(gV); + Tensor cK = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tKcKg2s = 
gmem_thr_copy_kv_g2s.partition_D(cK); + Tensor t0KcKg2s = gmem_thr0_copy_kv_g2s.partition_D(cK); + Tensor tKpKg2s = make_tensor(make_shape(size<2>(tKsKg2s))); + Tensor tKcKs2g = gmem_thr_copy_kv_s2g.partition_D(cK); + Tensor t0KcKs2g = gmem_thr0_copy_kv_s2g.partition_D(cK); + Tensor tKpKs2g = make_tensor(make_shape(size<2>(tKsKs2g))); + #pragma unroll + for (int k = 0; k < size(tKpKg2s); ++k) { tKpKg2s(k) = get<1>(tKcKg2s(_0{}, _0{}, k)) < get<1>(params.shape_K); } + #pragma unroll + for (int k = 0; k < size(tKpKs2g); ++k) { tKpKs2g(k) = get<1>(tKcKs2g(_0{}, _0{}, k)) < get<1>(params.shape_K); } + + auto load_K_new = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { + static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; + Tensor tKsK_cur = tKsKg2s(_, _, _, smem_pipe_write); + int const seqlenk_row_limit = -int(get<0>(tKcKg2s(_0{}, _0{}, _0{}))) + (EvenN + ? seqlen_k_new - n_block * kBlockN + : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k_new - n_block * kBlockN, kBlockN))); + // We don't need to clear the sK smem tiles since we won't write them out + flash::copy( + gmem_tiled_copy_kv_g2s, tKgKnew(_, _, _, n_block), tKsK_cur, t0KcKg2s, tKpKg2s, seqlenk_row_limit); + }; + + auto load_V_new = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { + static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; + Tensor tVsV_cur = tVsVg2s(_, _, _, smem_pipe_write); + int const seqlenk_row_limit = -int(get<0>(tKcKg2s(_0{}, _0{}, _0{}))) + (EvenN + ? seqlen_k_new - n_block * kBlockN + : (!Seqlenk_mask ? 
kBlockN : std::min(seqlen_k_new - n_block * kBlockN, kBlockN))); + // We don't need to clear the sV smem tiles since we won't write them out + flash::copy( + gmem_tiled_copy_kv_g2s, tVgVnew(_, _, _, n_block), tVsV_cur, t0KcKg2s, tKpKg2s, seqlenk_row_limit); + }; + + auto store_K = [&] (int const n_block, int const smem_pipe_read) { + int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); + if (get<1>(params.shape_rotary) <= 0) { + Tensor tKsK_cur = tKsKs2g(_, _, _, smem_pipe_read); + if constexpr (!PagedKV) { + Tensor tKgK_cur = tKgK(_, _, _, n_block); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_kv_s2g, tKsK_cur, tKgK_cur, tKcKs2g, tKpKs2g, std::min(seqlen_k_new - n_block * kBlockN, kBlockN) + ); + } else { + paged_kv_manager.store_K(n_block, tKsK_cur); + } + } else { + Tensor gK_cur = gK(_, _, n_block); + auto tPrKPtr = cute::conditional_return(paged_kv_manager.compute_K_ptr(), nullptr); + if (params.is_rotary_interleaved) { + auto [tRrCos, tRrSin] = rotary.template load_cos_sin(n_block); + rotary.template apply_K_interleaved(sK(_, _, smem_pipe_read), gK_cur, tKpKs2g, tRrCos, tRrSin, tPrKPtr, n_block); + } else { + auto [tRrCosCont, tRrSinCont] = rotary.template load_cos_sin(n_block); + rotary.template apply_K_contiguous(sK(_, _, smem_pipe_read), gK_cur, tKpKs2g, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K)); + } + } + }; + + auto store_V = [&] (int const n_block, int const smem_pipe_read) { + int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); + Tensor tVsV_cur = tVsVs2g(_, _, _, smem_pipe_read); + if constexpr (!PagedKV) { + Tensor tVgV_cur = tVgV(_, _, _, n_block); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_kv_s2g, tVsV_cur, tVgV_cur, tKcKs2g, tKpKs2g, n_limit); + } else { + paged_kv_manager.store_V(n_block, tVsV_cur); + } + }; + + int n_block = n_block_new_max - 1; + // Note, 
using the for_each() function here to ensure `stage` is of type Int. + for_each(make_int_sequence{}, [&] (auto stage) { + static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0; + static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1; + if (Is_first_stage || n_block - stage >= n_block_new_min) { + load_K_new(n_block - stage, stage, cute::bool_constant{} /*Seqlenk_mask*/); + } + cute::cp_async_fence(); + // If persistent, need to sync to make sure all threads have finished with smem_o before writing to smem_v + if constexpr (Is_first_stage) { __syncthreads(); } + if constexpr (!Is_last_stage) { + if (Is_first_stage || n_block - stage >= n_block_new_min) { + load_V_new(n_block - stage, stage, cute::bool_constant{} /*Seqlenk_mask*/); + } + cute::cp_async_fence(); + } + }); + + int smem_pipe_read = 0, smem_pipe_write = kStages - 1; + #pragma unroll 1 + for (; n_block >= n_block_new_min; --n_block) { + if constexpr (PagedKV) { paged_kv_manager.template load_page_table(n_block); } + flash::cp_async_wait(); + __syncthreads(); + store_K(n_block, kStages > 1 ? smem_pipe_read : 0); + if (n_block - kStages + 1 >= n_block_new_min) { + load_V_new(n_block - kStages + 1, kStages > 1 ? smem_pipe_write : 0, cute::bool_constant{} /*Seqlenk_mask*/); + } + cute::cp_async_fence(); + smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0; + flash::cp_async_wait(); + __syncthreads(); + store_V(n_block, kStages > 1 ? smem_pipe_read : 0); + smem_pipe_read = smem_pipe_read < kStages - 1 ? smem_pipe_read + 1 : 0; + if (n_block - kStages >= n_block_new_min) { + load_K_new(n_block - kStages, kStages > 1 ? 
smem_pipe_write : 0, cute::false_type{} /*Seqlenk_mask*/); + } + cute::cp_async_fence(); + } + + return true; + + } + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/mainloop_fwd_sm90_tma_gmma_ws.hpp b/flashmask/flash_mask/flashmask_attention_v3/mainloop_fwd_sm90_tma_gmma_ws.hpp new file mode 100644 index 00000000000..d67d2c4d295 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -0,0 +1,2341 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "cute/tensor.hpp" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +#include "named_barrier.hpp" +#include "seqlen.h" +#include "block.h" +#include "mask.h" +#include "pack_gqa.h" +#include "paged_kv.h" +#include "rotary.h" +#include "utils.h" +#include "sm90_pipeline_no_cluster.hpp" +#include "flash_mask.hpp" +#include "cutlass/arch/memory_sm75.h" + +namespace flash { + +using namespace cute; + +template +struct CollectiveMainloopFwdSm90 { + + static constexpr int kStages = Stages; + static constexpr int kNBlockStages = 2; + using ClusterShape = ClusterShape_; + using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; + using TileShape_MNK_QV = Shape(TileShape_MNK{})), decltype(get<1>(TileShape_MNK{})), Int>; + using Element = Element_; + using ElementAccum = ElementAccum_; + using ArchTag = ArchTag_; + static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v;; + static constexpr bool Is_causal = Is_causal_; + static constexpr bool Is_local = Is_local_; + static constexpr bool Has_softcap = Has_softcap_; + static constexpr bool Varlen = Varlen_; + static constexpr bool PagedKVNonTMA = PagedKVNonTMA_; + static constexpr bool AppendKV = AppendKV_; + static constexpr bool HasQv = HasQv_; + static constexpr bool PackGQA = PackGQA_; + static constexpr bool Split = Split_; + static constexpr bool V_colmajor = V_colmajor_; + static constexpr bool Is_flashmask = Is_flashmask_; + static constexpr bool Transpose_V = Is_FP8 && !V_colmajor; + static constexpr bool Use_TMA_Q = !PackGQA; + static constexpr bool Use_TMA_KV = !PagedKVNonTMA; + static_assert(Use_TMA_KV || CUTE_STATIC_V(size(ClusterShape{})) == 1, "If not using TMA for KV, ClusterShape must be 1"); + static_assert(Use_TMA_KV || 
!V_colmajor, "If not using TMA for KV, V_colmajor is not supported"); + static constexpr bool SameHeadDim = get<2>(TileShape_MNK{}) == kHeadDimV; + static constexpr bool LargeHeadDimV = kHeadDimV > 256; + + static_assert(ArchTag::kMinComputeCapability >= 90); + + static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 && !V_colmajor ? GMMA::Major::MN : GMMA::Major::K; + static constexpr cute::GMMA::Major TmaMajorV = !V_colmajor ? GMMA::Major::MN : GMMA::Major::K; + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + + static constexpr int Flashmask_max_seqlen_k = 1024 * 16; // scheduler pipelining needs to cut smem in half + static constexpr int ProducerThreadNum = 96; + + // static constexpr bool Is_blockmask = Is_blockmask_; + //xhy : now only support blockmask blocksize == 128, and blockdim info need to be constexpr + // otherwise it will cause register spill + static constexpr int m_block_dim = 128; + static constexpr int n_block_dim = 128; + static constexpr int m_factor = std::max(m_block_dim / kBlockM,1); + static constexpr int n_factor = std::max(n_block_dim / kBlockN,1); + + // Flashmask_n_block_buffer_length is the multiple of 32 for 128B excessive-sector-free load/store + static constexpr int Flashmask_n_block_buffer_length = ((Flashmask_max_seqlen_k + kBlockN - 1) / kBlockN + 31) & 0xffffffe0; + static constexpr int Flashmask_n_block_buffer_valid_length = ((Flashmask_max_seqlen_k + kBlockN - 1) / kBlockN + 3) / 4 * 4; + static constexpr int Blockmask_n_block_buffer_valid_length = (Flashmask_n_block_buffer_valid_length * kBlockN + n_block_dim -1) / n_block_dim; + + // Using bool in smem will usually lead to 4-way bank conflict, in order to accelerate this func + // we encode `partially_masked` flags in `n_block_smem`, which both saved some smem, while eliminating 4-way bank conflict + // if partially mask: the original value is 
stored, otherwise we store `-n_block - 1` + // so -1 and -2 can not be used as flags any more (they will be meaningful) + static constexpr int Flashmask_n_block_chunk_end = INT_MIN + 1; // 0x80000001 + static constexpr int Flashmask_n_block_finish = INT_MIN; // 0x80000000 + + using SeqlenInfo_t = flash::SeqlenInfoQKNewK; + using BlockMN_t = flash::BlockMN; + + static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0); + static_assert(!LargeHeadDimV || kBlockM <= 64, "kBlockM must be 64 or less for large Headdim_V"); + static_assert(!LargeHeadDimV || !MmaPV_is_RS, "MmaPV must be SS for large Headdim_V"); + + // Register bandwidth is actually a bottleneck so we don't want Q to be in registers. + // Leaving this option here for reference. + static constexpr bool MmaQK_is_RS = false; + // We can have MmaPV with P in smem in rmem to reduce register pressure at the cost of more smem. + static_assert(!(!MmaPV_is_RS && Is_FP8), "MmaPV must be RS if FP8"); + static_assert(!(!MmaPV_is_RS && Transpose_V), "MmaPV must be RS if Transpose_V"); + + // Slightly faster in this case to have WG1 use RS instead of SS to avoid waiting for the P smem write + static constexpr bool MmaPV_use_RS_WG1 = !MmaPV_is_RS && kHeadDim == 64 && kHeadDimV == 512; + + using AtomLayoutQK = Layout, _1, _1>>; + using TiledMmaQK = decltype(cute::make_tiled_mma( + std::conditional_t< + !MmaQK_is_RS, + decltype(cute::GMMA::ss_op_selector()), + decltype(cute::GMMA::rs_op_selector()) + >{}, + AtomLayoutQK{})); + using AtomLayoutPV = std::conditional_t< + !LargeHeadDimV, + AtomLayoutQK, + Layout, _1>> + >; + using TiledMmaPV = decltype(cute::make_tiled_mma( + std::conditional_t< + !MmaPV_is_RS, + decltype(cute::GMMA::ss_op_selector()), + decltype(cute::GMMA::rs_op_selector()) + >{}, + AtomLayoutPV{})); + using TiledMmaQV = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutQK{})); + // For hdim64,512, WG1 can use RS but WG2 must use SS + using TiledMmaPV_RS = decltype(cute::make_tiled_mma( 
+ cute::GMMA::rs_op_selector(), + AtomLayoutPV{})); + + static constexpr int NumMmaThreadsQK = size(TiledMmaQK{}); + static constexpr int NumMmaThreads = size(TiledMmaPV{}); + static constexpr int NumProducerThreads = !Transpose_V && Use_TMA_KV && Use_TMA_Q ? cutlass::NumThreadsPerWarp : cutlass::NumThreadsPerWarpGroup; + static_assert(NumMmaThreadsQK % cutlass::NumThreadsPerWarpGroup == 0); + static_assert(NumMmaThreads % cutlass::NumThreadsPerWarpGroup == 0); + static constexpr int NumMmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; + static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); + + using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); + + using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutK = decltype(tile_to_shape( + SmemLayoutAtomK{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + + using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector, decltype(cute::get<2>(TileShape_MNK_PV{}))>()); + using SmemLayoutVt = decltype(tile_to_shape( + SmemLayoutAtomVt{}, + make_shape(Int{}, shape<2>(TileShape_MNK_PV{}), Int{}), + std::conditional_t, cute::Step<_2, _1, _3>>{})); + + using SmemLayoutAtomVtMma = decltype(cutlass::gemm::collective::detail::ss_smem_selector, decltype(cute::get<2>(TileShape_MNK_PV{}))>()); + using SmemLayoutVtMma = decltype(tile_to_shape( + SmemLayoutAtomVtMma{}, + make_shape(Int{}, shape<2>(TileShape_MNK_PV{}), Int{}), + std::conditional_t, cute::Step<_2, _1, _3>>{})); + + using SmemLayoutAtomQv = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>()); + 
using SmemLayoutQv = decltype(tile_to_shape(SmemLayoutAtomQv{}, select<0, 2>(TileShape_MNK_QV{}))); + using SmemLayoutAtomVMmaQV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>()); + using SmemLayoutVMmaQV = decltype(tile_to_shape( + SmemLayoutAtomVMmaQV{}, + make_shape(shape<1>(TileShape_MNK_QV{}), Int{}, Int{}))); + static_assert(CUTE_STATIC_V(size(SmemLayoutVMmaQV{})) == size(SmemLayoutVtMma{})); + + // Only used if we're using cp.async to load V + using SmemLayoutAtomVCpAsync = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), Int>()); + using SmemLayoutVCpAsync = decltype(tile_to_shape( + SmemLayoutAtomVCpAsync{}, + make_shape(shape<1>(TileShape_MNK{}), Int{}, Int{}))); + + using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{}))); + + // Only for LargeHeadDimV where WG0 sends WG1 the scales + using SmemLayoutScale = cute::Layout, Int>>; + + using SmemCopyAtomP = Copy_Atom; + + // Use LDSM.T and STSM to transpose V in the case of FP8 and V being row-major. + // For FP16/BF16 we don't do any transposing. + static_assert(!Transpose_V || (kHeadDimV % 32 == 0 && kBlockN % 32 == 0)); + static constexpr bool kHeadDimV_multiple_64 = kHeadDimV % 64 == 0; + // Either kHeadDimV is a multiple of 64 (in which case we use a block size of 64 x 32 for the transpose), + // or we need kBlockN to be a multiple of 64 (in which case we use a block size of 32 x 64 for the transpose). 
+ static_assert(!Transpose_V || (kHeadDimV_multiple_64 || kBlockN % 64 == 0)); + using LDSM_thread_shape = std::conditional_t, Shape<_16, _4, _1, _2>>; + using LDSM_thread_stride = std::conditional_t, Stride<_4, _1, _0, _64>>; + using LDSM_value_shape = Shape<_2, _2, _1, _4>; + using LDSM_value_stride = Stride<_1, _2, _16, _4>; + using LDSM_divide_shape = std::conditional_t, Shape<_32, _8>>; + using S2RTiledCopyVt = decltype(make_tiled_copy( + Copy_Atom{}, Layout{}, + Layout{})); + + using STSM_thread_shape = std::conditional_t, Shape<_8, _4, _2, _2>>; + using STSM_thread_stride = std::conditional_t, Stride<_4, _1, _32, _64>>; + using STSM_value_shape = Shape<_1, _4, _2, _2>; + using STSM_value_stride = Stride<_0, _1, _4, _8>; + using STSM_divide_shape = Shape<_8, _16>; + // These will not permute the columns of V (the kHeadDimV dimension) but incur bank conflicts + // so a little slower (e.g. 1150 TFLOPS for hdim 256 instead of 1200 TFLOPS). + // Instead we will permute the cols of V, and un-permute the cols of O in the epilogue. + // using STSM_value_shape = Shape<_2, _4, _1, _2>; + // using STSM_value_stride = Stride<_4, _1, _0, _8>; + // using STSM_divide_shape = Shape<_16, _16>; + using R2STiledCopyV = decltype(make_tiled_copy( + Copy_Atom{}, Layout{}, + Layout{})); + + using GmemTiledCopyQ = cute::SM90_TMA_LOAD; + using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); + + // We use CpAsync for K and V if PagedKVNonTMA and AppendKV, since TMA doesn't work there + static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. 
if hdim=128, we want each + // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. + // We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved + // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will + // load twice from the same row. + static constexpr int kBytePerHalfRow = kHeadDimGCD / 2 * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); + // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKVNonTMA where + // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. + static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + // If AppendKV, we'll be loading Q for rotary, and we assume divisibility to avoid predication + static_assert(!AppendKV || kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumMmaThreads / kGmemThreadsPerRow"); + using GmemTiledCopyAppendKV = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) + using StrideQK = cute::Stride; + using StrideV = std::conditional_t>; + // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits) + using ShapeQPacked = std::conditional_t, int32_t, int32_t, int32_t>>; + using StrideQPacked = std::conditional_t, _1, int64_t, int64_t>>; + using ShapePageTable = cute::Shape; // (batch, max_num_pages_per_seq) 
+ using StridePageTable = cute::Stride; + using ShapeRotary = cute::Shape; // (seqlen_ro, rotary_dim // 2) + using StrideRotary = cute::Stride; + using StrideDescale = cute::Stride; + + using TMA_Q = decltype(make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + SmemLayoutQ{}, + TileShape_MNK{}, + ClusterShape{})); + + using TMA_K = decltype(make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + take<0, 2>(SmemLayoutK{}), + TileShape_MNK{}, + ClusterShape{})); // mcast along M mode for this N load, if any + + using TMA_V = decltype(make_tma_copy( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, select<1, 0, 2, 3>(StrideV{})), + take<0, 2>(SmemLayoutVt{}), + select<1, 2>(TileShape_MNK_PV{}), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + + using TMA_Qv_ = decltype(make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + SmemLayoutQv{}, + TileShape_MNK_QV{}, + ClusterShape{})); + using TMA_Qv = std::conditional_t; + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesK = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesV = static_cast(size(take<0, 2>(SmemLayoutVt{})) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesQv = static_cast(size(SmemLayoutQv{}) * cutlass::sizeof_bits_v / 8); + + using PipelineTmaAsync = std::conditional_t, typename cutlass::PipelineTmaAsync>; + using MainloopPipelineK = std::conditional_t>; + using MainloopPipelineV = std::conditional_t>; + using MainloopPipelineVt = std::conditional_t>; + // We 
always use TMA for K_new and V_new + using MainloopPipelineKVNew = PipelineTmaAsync; + using MainloopPipelineNBlock = typename cutlass::PipelineAsync; + using MainloopPipelineFlashMaskApply = typename cutlass::PipelineAsync; + using PipelineState = cutlass::PipelineState; + + // If PackGQA, we use cp.async (instead of TMA) to load Q, so we want smem_q to be aligned + // and have sQ being position_independent_swizzle_tensor. + // If !Use_TMA_KV, we use cp.async (instead of TMA) to load K & V, so we want smem_k and smem_v to be aligned. + static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !MmaQK_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); + static constexpr size_t SmemAlignmentK = Use_TMA_KV && !AppendKV ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutK{}); + static constexpr size_t SmemAlignmentVtNoTranspose = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); + static constexpr size_t SmemAlignmentQv = Use_TMA_Q ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQv{}); + static_assert(SmemAlignmentQ >= 128 and SmemAlignmentK >= 128 && SmemAlignmentVtNoTranspose >= 128, "Require at least 128B alignment"); + static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutP{}); + static_assert(SmemAlignmentP >= 128, "Require at least 128B alignment"); + + using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; + using SmemScale_t = std::conditional_t, cute::array_aligned, 128>>; + using SmemQv_t = std::conditional_t, cute::array_aligned, SmemAlignmentQv>>; + // Sometimes even with SmemP_t = cute::array, putting it in the TensorStorage struct causes + // smem size to go from 227KB to 228KB and we get "invalid argument". 
+ + struct TensorStorageWithoutPNoTranspose : cute::aligned_struct { + cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; + cute::array_aligned, SmemAlignmentQ> smem_q; + cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; + }; + + struct TensorStorageWithPNoTranspose : cute::aligned_struct { + cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; + cute::array_aligned, SmemAlignmentQ> smem_q; + cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; + SmemP_t smem_p; + }; + struct TensorStorageWithPScaleNoTranspose : cute::aligned_struct { + cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; + cute::array_aligned, SmemAlignmentQ> smem_q; + cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; + SmemP_t smem_p; + SmemScale_t smem_scale; + }; + + using TensorStorageNoTranspose = std::conditional_t< + MmaPV_is_RS, + TensorStorageWithoutPNoTranspose, + std::conditional_t + >; + + static constexpr size_t SmemAlignmentVt = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); + static constexpr size_t SmemAlignmentV = cutlass::detail::alignment_for_swizzle(SmemLayoutVtMma{}); + static_assert(SmemAlignmentVt >= 128 and SmemAlignmentV >= 128, "Require at least 128B alignment"); + struct TensorStorageTransposeV : cute::aligned_struct { + cute::array_aligned, SmemAlignmentV> smem_v; + cute::array_aligned, SmemAlignmentVt> smem_vt; + cute::array_aligned, SmemAlignmentQ> smem_q; + cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; + SmemScale_t smem_scale; + }; + + using TensorStorage = std::conditional_t; + + // These are tuned for speed. They don't affect correctness. + static constexpr bool UseSchedulerBarrier = (IntraWGOverlap + ? (NumMmaWarpGroups >= 2) && (!Is_FP8 ? 
kHeadDim <= 128 : kHeadDim >= 128) + : NumMmaWarpGroups == 2) + && !LargeHeadDimV; + static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor) && IntraWGOverlap; + + // Host side kernel arguments + struct Arguments { + Element const* const ptr_Q; + ShapeQKV const shape_Q; + StrideQK const stride_Q; + Element* const ptr_K; // not Element const* since we might append to KV cache in-place + ShapeQKV const shape_K; + StrideQK const stride_K; + Element* const ptr_V; + int32_t const headdim_v; + StrideV const stride_V; + Element const* const ptr_K_new; + ShapeQKV const shape_K_new; + StrideQK const stride_K_new; + Element const* const ptr_V_new; + StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideQK const stride_Qv; + Element const* const ptr_rotary_cos; + ShapeRotary const shape_rotary; + StrideRotary const stride_rotary_cos; + Element const* const ptr_rotary_sin; + StrideRotary const stride_rotary_sin; + bool const is_rotary_interleaved; + int const* const ptr_pagetable; + ShapePageTable const shape_pagetable; + StridePageTable const stride_pagetable; + float const softmax_scale; + float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; + StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; + int const window_size_left = -1, window_size_right = -1; + float const softcap_val; + int const num_splits; + int const* const kv_batch_idx = nullptr; + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const cu_seqlens_k_new = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + int const* const leftpad_k = nullptr; + + // FlashMask + int const h_flashmask; + int const h_h_flashmask_ratio; + + int32_t * __restrict__ const lt_start_ptr = nullptr; + int32_t * __restrict__ const lt_end_ptr = nullptr; + + int32_t * __restrict__ const ut_start_ptr = nullptr; + int32_t * __restrict__ const ut_end_ptr = nullptr; + + int32_t * 
__restrict__ const flashmask_maxmin_ptr = nullptr; + + int32_t * __restrict__ const lt_start_nblockmax = nullptr; + int32_t * __restrict__ const lt_start_nblockmin = nullptr; + + int32_t * __restrict__ const lt_end_nblockmax = nullptr; + int32_t * __restrict__ const lt_end_nblockmin = nullptr; + + int32_t * __restrict__ const ut_start_nblockmax = nullptr; + int32_t * __restrict__ const ut_start_nblockmin = nullptr; + + int32_t * __restrict__ const ut_end_nblockmax = nullptr; + int32_t * __restrict__ const ut_end_nblockmin = nullptr; + + int m_block_dim,n_block_dim; + int32_t * __restrict__ block_mask_ptr = nullptr; + }; + + // Device side kernel params + struct Params { + Element const* const ptr_Q; + ShapeQKV const shape_Q; + StrideQK const stride_Q; + ShapeQPacked const shape_Q_packed; + StrideQPacked const stride_Q_packed; + Element* const ptr_K; + ShapeQKV const shape_K; + StrideQK const stride_K; + Element* const ptr_V; + int32_t const headdim_v; + StrideV const stride_V; + Element const* const ptr_K_new; + ShapeQKV const shape_K_new; + StrideQK const stride_K_new; + Element const* const ptr_V_new; + StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideV const stride_Qv; + ShapeQPacked const shape_Qv_packed; + StrideQPacked const stride_Qv_packed; + Element const* const ptr_rotary_cos; + ShapeRotary const shape_rotary; + StrideRotary const stride_rotary_cos; + Element const* const ptr_rotary_sin; + StrideRotary const stride_rotary_sin; + bool const is_rotary_interleaved; + int const* const ptr_pagetable; + ShapePageTable const shape_pagetable; + StridePageTable const stride_pagetable; + cutlass::FastDivmod page_size_divmod; + cutlass::FastDivmod blockN_per_page_size_divmod; + cutlass::FastDivmod qhead_per_khead_divmod; + TMA_Q tma_load_Q; + TMA_K tma_load_K; + TMA_V tma_load_V; + TMA_K tma_load_K_new; + TMA_V tma_load_V_new; + TMA_Qv tma_load_Qv; + float const softmax_scale_log2; + float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; + 
StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; + float const softcap_val; + int const window_size_left, window_size_right; + int const num_splits; + int const* const kv_batch_idx = nullptr; + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const cu_seqlens_k_new = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + int const* const leftpad_k = nullptr; + + // FlashMask + int const h_flashmask; + int const h_h_flashmask_ratio; + + int * __restrict__ const lt_start_ptr = nullptr; + int * __restrict__ const lt_end_ptr = nullptr; + + int * __restrict__ const ut_start_ptr = nullptr; + int * __restrict__ const ut_end_ptr = nullptr; + + int * __restrict__ const flashmask_maxmin_ptr = nullptr; + + int * __restrict__ const lt_start_nblockmax = nullptr; + int * __restrict__ const lt_start_nblockmin = nullptr; + + int * __restrict__ const lt_end_nblockmax = nullptr; + int * __restrict__ const lt_end_nblockmin = nullptr; + + int * __restrict__ const ut_start_nblockmax = nullptr; + int * __restrict__ const ut_start_nblockmin = nullptr; + + int * __restrict__ const ut_end_nblockmax = nullptr; + int * __restrict__ const ut_end_nblockmin = nullptr; + + // int m_block_dim,n_block_dim; + int32_t * __restrict__ block_mask_ptr = nullptr; + // int m_factor = 0, n_factor = 0; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q); + TMA_Q tma_load_Q = make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + mQ, + SmemLayoutQ{}, + TileShape_MNK{}, + ClusterShape{}); // no mcast for Q + Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K); + TMA_K tma_load_K = make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + mK, + take<0, 2>(SmemLayoutK{}), + TileShape_MNK{}, + ClusterShape{}); // mcast along M mode for this N load, if any + Tensor mV = 
make_tensor(make_gmem_ptr(args.ptr_V), + make_shape(args.headdim_v, get<0>(args.shape_K), get<2>(args.shape_K), get<3>(args.shape_K)), + select<1, 0, 2, 3>(args.stride_V)); + TMA_V tma_load_V = make_tma_copy( + GmemTiledCopyKV{}, + mV, + take<0, 2>(SmemLayoutVt{}), + select<1, 2>(TileShape_MNK_PV{}), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + Tensor mKnew = make_tensor(make_gmem_ptr(args.ptr_K_new), args.shape_K_new, args.stride_K_new); + TMA_K tma_load_K_new = make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + cute::conditional_return(mKnew, mK), + take<0, 2>(SmemLayoutK{}), + TileShape_MNK{}, + ClusterShape{}); // mcast along M mode for this N load, if any + Tensor mVnew = make_tensor(make_gmem_ptr(args.ptr_V_new), + make_shape(args.headdim_v, get<0>(args.shape_K_new), get<2>(args.shape_K_new), get<3>(args.shape_K_new)), + select<1, 0, 2, 3>(args.stride_V_new)); + TMA_V tma_load_V_new = make_tma_copy( + GmemTiledCopyKV{}, + cute::conditional_return(mVnew, mV), + take<0, 2>(SmemLayoutVt{}), + select<1, 2>(TileShape_MNK_PV{}), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + auto shape_Qv = make_shape(get<0>(args.shape_Q), args.headdim_v, get<2>(args.shape_Q), get<3>(args.shape_Q)); + Tensor mQv = make_tensor(make_gmem_ptr(args.ptr_Qv), shape_Qv, args.stride_Qv); + TMA_Qv tma_load_Qv = [&] { + if constexpr (HasQv) { + return make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + mQv, + SmemLayoutQv{}, + TileShape_MNK_QV{}, + ClusterShape{}); // no mcast for Qv + } else { + return nullptr; + } + }(); + // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) + int const qhead_per_khead = !PackGQA ? 
1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K)); + auto const shape_Q_packed = cute::conditional_return( + args.shape_Q, + make_shape(make_shape(qhead_per_khead, get<0>(args.shape_Q)), get<1>(args.shape_Q), get<2>(args.shape_K), get<3>(args.shape_Q)) + ); + auto const stride_Q_packed = cute::conditional_return( + args.stride_Q, + make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q)) + ); + auto const shape_Qv_packed = cute::conditional_return( + shape_Qv, + make_shape(make_shape(qhead_per_khead, get<0>(shape_Qv)), get<1>(shape_Qv), get<2>(args.shape_K), get<3>(shape_Qv)) + ); + auto const stride_Qv_packed = cute::conditional_return( + args.stride_Qv, + make_stride(make_stride(get<2>(args.stride_Qv), get<0>(args.stride_Qv)), get<1>(args.stride_Qv), get<2>(args.stride_Qv) * qhead_per_khead, get<3>(args.stride_Qv)) + ); + if (get<1>(args.shape_rotary) > 0) { + assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr); + } + assert(args.num_splits >= 1); + int page_size = !args.ptr_pagetable ? 1 : get<0>(args.shape_K); + if (!PagedKVNonTMA && args.ptr_pagetable != nullptr) { + assert(page_size % kBlockN == 0); + assert(!args.leftpad_k); + } + + //block sparse attn + assert(args.m_block_dim == 128); + assert(args.n_block_dim == 128); + // int m_factor = args.m_block_dim / kBlockM; + // int n_factor = args.n_block_dim / kBlockN; + + // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. + // Right after this, we multiply by log2(e) before applying exp2. + // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val + // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) + // (assigning it to params.softmax_scale_log2). 
+ return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, + args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, + args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, + args.ptr_Qv, args.stride_Qv, shape_Qv_packed, stride_Qv_packed, + args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, + args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, + args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable, + cutlass::FastDivmod(page_size), // page_size_divmod + cutlass::FastDivmod(!args.ptr_pagetable ? 1 : cute::ceil_div(page_size, kBlockN)), // blockN_per_page_size_divmod + cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), + tma_load_Q, tma_load_K, tma_load_V, tma_load_K_new, tma_load_V_new, tma_load_Qv, + !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), + args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, + args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, + !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, + args.window_size_left, args.window_size_right, + !Split ? 
1 : args.num_splits, + args.kv_batch_idx, + args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, + args.seqused_q, args.seqused_k, args.leftpad_k, + args.h_flashmask, args.h_h_flashmask_ratio, + args.lt_start_ptr, args.lt_end_ptr, + args.ut_start_ptr, args.ut_end_ptr, + args.flashmask_maxmin_ptr, + args.lt_start_nblockmax, args.lt_start_nblockmin, + args.lt_end_nblockmax, args.lt_end_nblockmin, + args.ut_start_nblockmax, args.ut_start_nblockmin, + args.ut_end_nblockmax, args.ut_end_nblockmin, + // args.m_block_dim,args.n_block_dim, + // m_factor,n_factor, + args.block_mask_ptr}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + if constexpr (Use_TMA_Q) { + cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor()); + if constexpr (HasQv) { + cute::prefetch_tma_descriptor(params.tma_load_Qv.get_tma_descriptor()); + } + } + if constexpr (Use_TMA_KV) { + cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor()); + } + if constexpr (AppendKV) { + cute::prefetch_tma_descriptor(params.tma_load_K_new.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_V_new.get_tma_descriptor()); + } + } + + enum PtrExistDispatchTag { + SINGLE_PTR = 0x0, // lt_start is always valid + DUAL_PTR = 0x1, // two ptrs, with one more lt_end or ut_end + FULL_PTR = 0x2 // all four ptrs + }; + + CUTLASS_DEVICE + void + load_max_min(Params const& params, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + int32_t reverse_chunk_idx, // reverse_chunk_idx, start from right to left: [5, 4, 3, 2, 1, 0] + int32_t total_num_chunks, + int32_t* const flashmask_maxmin_smem) { + int32_t bidh = get<1>(block_coord); + int32_t bidb = get<2>(block_coord); + // pad for fully 128B aligned load + int32_t m_block = get<0>(block_coord); + const int nblock_seqlen 
= ((seqlen_info.seqlen_k + kBlockN - 1) / kBlockN + 3) & 0xfffffffc; + + // change this to num_chunk * chunk_size (should be Flashmask_n_block_buffer_length) + const int chunks_size = total_num_chunks * Flashmask_n_block_buffer_length; + + const int offset = (bidb * params.h_flashmask + bidh / params.h_h_flashmask_ratio) * chunks_size + + (chunks_size - (reverse_chunk_idx + 1) * Flashmask_n_block_buffer_length); + + const int thread_idx = threadIdx.x - 32; + const int length = Flashmask_n_block_buffer_valid_length < nblock_seqlen ? Flashmask_n_block_buffer_valid_length : nblock_seqlen; + + // it's a pity that tag cannot have static dispatch, since load_max_min should remain the same + // across different main loop implementation. We can implement a func with default + const auto tag = [¶ms]() { + if (params.ut_start_ptr) + return PtrExistDispatchTag::FULL_PTR; + else if (params.lt_end_ptr || params.ut_end_ptr) + return PtrExistDispatchTag::DUAL_PTR; + return PtrExistDispatchTag::SINGLE_PTR; + }(); + + #define CP_ASYNC_SMEM_INT4(src_ptr, index) \ + asm volatile( \ + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" \ + ::"r"(cutlass::arch::cutlass_get_smem_pointer( \ + reinterpret_cast(flashmask_maxmin_smem) + Flashmask_n_block_buffer_length / 4 * index + idx)), \ + "l"(reinterpret_cast(params.src_ptr + offset) + idx), \ + "n"(16)) + + // Note(heqianyue): we make sure that length is the multiple of 4. 
If this constraint does not hold in the future + // one can refer to d0659db5c7 and re-implement the loading of remaining elements + if (tag == PtrExistDispatchTag::FULL_PTR) { + for(int64_t idx = thread_idx; idx * 4 < length; idx += ProducerThreadNum) { + // lt start is always valid in flashmask (otherwise it is a bug) + CP_ASYNC_SMEM_INT4(lt_start_nblockmax, 0); + CP_ASYNC_SMEM_INT4(lt_start_nblockmin, 1); + CP_ASYNC_SMEM_INT4(lt_end_nblockmax, 2); + CP_ASYNC_SMEM_INT4(lt_end_nblockmin, 3); + + CP_ASYNC_SMEM_INT4(ut_start_nblockmax, 4); + CP_ASYNC_SMEM_INT4(ut_start_nblockmin, 5); + CP_ASYNC_SMEM_INT4(ut_end_nblockmax, 6); + CP_ASYNC_SMEM_INT4(ut_end_nblockmin, 7); + } + } else if (tag == PtrExistDispatchTag::DUAL_PTR) { + for(int64_t idx = thread_idx; 4 * idx < length; idx += ProducerThreadNum) { + // lt start is always valid in flashmask (otherwise it is a bug) + CP_ASYNC_SMEM_INT4(lt_start_nblockmax, 0); + CP_ASYNC_SMEM_INT4(lt_start_nblockmin, 1); + // check: paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu (#L2073-L2119) + if constexpr (Is_causal) { + CP_ASYNC_SMEM_INT4(lt_end_nblockmax, 2); + CP_ASYNC_SMEM_INT4(lt_end_nblockmin, 3); + } else { + CP_ASYNC_SMEM_INT4(ut_end_nblockmax, 6); + CP_ASYNC_SMEM_INT4(ut_end_nblockmin, 7); + } + } + } else { + for(int64_t idx = thread_idx; 4 * idx < length; idx += ProducerThreadNum) { + // lt start is always valid in flashmask (otherwise it is a bug) + CP_ASYNC_SMEM_INT4(lt_start_nblockmax, 0); + CP_ASYNC_SMEM_INT4(lt_start_nblockmin, 1); + } + } + #undef CP_ASYNC_SMEM_INT4 + + asm volatile("cp.async.commit_group;\n" ::); + asm volatile("cp.async.wait_group 0;\n" ::); + + cutlass::arch::NamedBarrier::sync(ProducerThreadNum, static_cast(FwdNamedBarriers::FlashMaskNBlock)); + } + + CUTLASS_DEVICE + void + load_blockmask(Params const& params, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + int32_t reverse_chunk_idx, + int32_t total_num_chunks, + int32_t* block_mask_smem ) { + + int32_t bidh = 
get<1>(block_coord); + int32_t bidb = get<2>(block_coord); + int32_t m_block = get<0>(block_coord); + const int thread_idx = threadIdx.x - 32; + + const int chunks_size = total_num_chunks * Flashmask_n_block_buffer_length; + const int offset = (bidb * params.h_flashmask + bidh / params.h_h_flashmask_ratio) * chunks_size + + (chunks_size - (reverse_chunk_idx + 1) * Flashmask_n_block_buffer_length); + + const int nblock_seqlen = ((seqlen_info.seqlen_k + kBlockN - 1) / kBlockN + 3) & 0xfffffffc; + + const int valid_block_nblock_seqlen = (seqlen_info.seqlen_k + n_block_dim - 1) / n_block_dim ; //xhy :maybe nblock_seqlen - 4 + const int valid_block_mblock_seqlen = (seqlen_info.seqlen_q + m_block_dim - 1) / m_block_dim; + int blockmask_offset = (bidb * params.h_flashmask + bidh / params.h_h_flashmask_ratio) * valid_block_nblock_seqlen * valid_block_mblock_seqlen; // row_offset + blockmask_offset += m_block * valid_block_nblock_seqlen / m_factor; + blockmask_offset += std::max((valid_block_nblock_seqlen - (reverse_chunk_idx + 1) * Blockmask_n_block_buffer_valid_length), 0); + int blockmask_length = Blockmask_n_block_buffer_valid_length < valid_block_nblock_seqlen ? 
Blockmask_n_block_buffer_valid_length : valid_block_nblock_seqlen; + + //xhy: blockmask ptr maybe not 16-aligned, since load_blockmask is called before load_max_min, sync can be shared with load_max_min + for(int64_t idx = thread_idx; idx < blockmask_length && (offset + idx >= 0); idx += ProducerThreadNum) { + asm volatile( + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + ::"r"(cutlass::arch::cutlass_get_smem_pointer(block_mask_smem + idx)), + "l"(params.block_mask_ptr + blockmask_offset + idx), + "n"(4)); + } +} + + template + CUTLASS_DEVICE bool + generate_n_block(int32_t const m_block, + int32_t const reverse_chunk_idx, // reverse_chunk_idx, start from right to left: [5, 4, 3, 2, 1, 0] + int32_t const total_num_chunks, + int32_t const end_flag, + int32_t const n_block_min, + int32_t const n_block_max, + int32_t const seqlen_q, + int32_t* const __restrict__ flashmask_maxmin_smem, + int32_t* const __restrict__ mask_encode_n_block_smem_, + int32_t* const __restrict__ extra_flags, + bool is_blockmask, + int32_t* const __restrict__ block_mask_smem = nullptr) { + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + + int const thread_idx = threadIdx.x - 32; + + __shared__ int s_prefix_sum[4]; + int32_t lt_start_max = INT_MAX; + int32_t lt_start_min = INT_MAX; + + int32_t lt_end_max = INT_MAX; + int32_t lt_end_min = INT_MAX; + + int32_t ut_start_max = INT_MIN; + int32_t ut_start_min = INT_MIN; + + int32_t ut_end_max = INT_MIN; + int32_t ut_end_min = INT_MIN; + + const int32_t* const s_lt_start_max = flashmask_maxmin_smem; + const int32_t* const s_lt_start_min = flashmask_maxmin_smem + Flashmask_n_block_buffer_length; + + const int32_t *s_lt_end_max = nullptr, *s_lt_end_min = nullptr, *s_ut_start_max = nullptr, + *s_ut_start_min = nullptr, *s_ut_end_max = nullptr, *s_ut_end_min = nullptr; + + if constexpr (tag == PtrExistDispatchTag::FULL_PTR) { + s_lt_end_max = s_lt_start_min + 
Flashmask_n_block_buffer_length; + s_lt_end_min = s_lt_end_max + Flashmask_n_block_buffer_length; + + s_ut_start_max = s_lt_end_min + Flashmask_n_block_buffer_length; + s_ut_start_min = s_ut_start_max + Flashmask_n_block_buffer_length; + + s_ut_end_max = s_ut_start_min + Flashmask_n_block_buffer_length; + s_ut_end_min = s_ut_end_max + Flashmask_n_block_buffer_length; + } else if constexpr (tag == PtrExistDispatchTag::DUAL_PTR) { + if constexpr (Is_causal) { + s_lt_end_max = s_lt_start_min + Flashmask_n_block_buffer_length; + s_lt_end_min = s_lt_end_max + Flashmask_n_block_buffer_length; + } else { + s_ut_end_max = flashmask_maxmin_smem + 6 * Flashmask_n_block_buffer_length; + s_ut_end_min = s_ut_end_max + Flashmask_n_block_buffer_length; + } + } + + int32_t valid_n_block_num = 0; + + const int32_t base_offset = (total_num_chunks - 1 - reverse_chunk_idx) * Flashmask_n_block_buffer_valid_length; + + // explanation for the loop condition: + // -2, -1, 0, 1, 2 + // t4, t3, t2, t1, t0 + // although t4 and t3 are oob, they should not exit the loop, otherwise, the prefix-sum inside the loop will hang, just keep a default value is fine + const int m_block_s = m_block * kBlockM; + // Note(heqianyue): ute/lte will be seqlen_q (at most). 
Yet if m_block_e > seqlen_q, even if ute/lte are seqlen_q (masked to the end) + // we will still consider the block as partially masked, adding unnecessary computation for those fully-masked blocks + const int m_block_e = __viaddmin_s32(m_block_s, kBlockM, seqlen_q); // min(a + b, c) + for(int32_t idx = Flashmask_n_block_buffer_valid_length - 1 - thread_idx; // make sure thread_idx is in range [0, ProducerThreadNum) + idx >= (0 - (ProducerThreadNum - Flashmask_n_block_buffer_valid_length % ProducerThreadNum)); idx -= ProducerThreadNum + ) { + int32_t n_block = base_offset + idx; + int prefix_sum = 0; + bool fully_masked = true; + bool partially_masked; + if(n_block >= n_block_min && n_block < n_block_max && idx >= 0) { + lt_start_max = s_lt_start_max[idx]; + lt_start_min = s_lt_start_min[idx]; + + if constexpr (tag == PtrExistDispatchTag::FULL_PTR) { + lt_end_max = s_lt_end_max[idx]; + lt_end_min = s_lt_end_min[idx]; + ut_start_max = s_ut_start_max[idx]; + ut_start_min = s_ut_start_min[idx]; + ut_end_max = s_ut_end_max[idx]; + ut_end_min = s_ut_end_min[idx]; + + fully_masked = (m_block_s >= lt_start_max && m_block_e <= lt_end_min) || + (m_block_s >= ut_start_max && m_block_e <= ut_end_min); + partially_masked = (m_block_s < lt_end_max && m_block_e > lt_start_min) || + (m_block_s < ut_end_max && m_block_e > ut_start_min); + } else if constexpr (tag == PtrExistDispatchTag::DUAL_PTR) { + if constexpr (Is_causal) { + lt_end_max = s_lt_end_max[idx]; + lt_end_min = s_lt_end_min[idx]; + fully_masked = m_block_s >= lt_start_max && m_block_e <= lt_end_min; + partially_masked = m_block_s < lt_end_max && m_block_e > lt_start_min; + } else { + ut_end_max = s_ut_end_max[idx]; + ut_end_min = s_ut_end_min[idx]; + fully_masked = (m_block_s >= lt_start_max) || (m_block_e <= ut_end_min); + partially_masked = (m_block_e > lt_start_min) || (m_block_s < ut_end_max); + } + } else { + fully_masked = m_block_s >= lt_start_max; + partially_masked = m_block_e > lt_start_min; + } + if 
(is_blockmask){ + if(!block_mask_smem[idx / n_factor]){ + fully_masked = true; + } + } + + prefix_sum = int(!fully_masked); + } + + const int warp_id_ = thread_idx >> 5; + const int lane_id_ = thread_idx & 31; + // warp-wide prefix-sum + #pragma unroll + for(int i=1; i<32; i*=2) { + int tmp_prefix_sum = __shfl_up_sync(0xffffffff, prefix_sum, i); + prefix_sum = lane_id_ >= i ? prefix_sum + tmp_prefix_sum : prefix_sum; + } + + // inter-warp prefix-sum + if(lane_id_ == 31) { + s_prefix_sum[warp_id_] = prefix_sum; + } + + cutlass::arch::NamedBarrier::sync(ProducerThreadNum, static_cast(FwdNamedBarriers::FlashMaskNBlock)); + + // Currently, we use only 3 warps, so (warp_id_ <= 2) is always true, we can remove it from the predicate + if(warp_id_ >= 1) { + prefix_sum += s_prefix_sum[0]; + if (warp_id_ == 2) { + prefix_sum += s_prefix_sum[1]; + if (lane_id_ == 31) s_prefix_sum[2] = prefix_sum; + } + } + + // if not fully masked or not partially masked: unmasked, useless (no need to compute) + if(!fully_masked) + mask_encode_n_block_smem_[valid_n_block_num + prefix_sum - 1] = partially_masked ? 
n_block : (-n_block - 1); + cutlass::arch::NamedBarrier::sync(ProducerThreadNum, static_cast(FwdNamedBarriers::FlashMaskNBlock)); + valid_n_block_num += s_prefix_sum[2]; + cutlass::arch::NamedBarrier::sync(ProducerThreadNum, static_cast(FwdNamedBarriers::FlashMaskNBlock)); + } + // Do not allocate buffer length that is not the multiple of 32 (otherwise there will be global excessive sectors) + if (valid_n_block_num < Flashmask_n_block_buffer_valid_length) + mask_encode_n_block_smem_[valid_n_block_num] = end_flag; + else + *extra_flags = end_flag; + return valid_n_block_num != 0 || end_flag != Flashmask_n_block_chunk_end; + } + + template + CUTLASS_DEVICE void + load(Params const& params, + MainloopPipelineK pipeline_k, + MainloopPipelineV pipeline_v, + MainloopPipelineVt pipeline_vt, + MainloopPipelineNBlock pipeline_n_block, + MainloopPipelineFlashMaskApply pipeline_flashmask_apply, + PipelineState& smem_pipe_write, + cutlass::PipelineState& n_block_pipe_read, + SharedStorage &shared_storage, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + int &work_idx, + int32_t* const flashmask_smem_, + const int32_t* const n_block_smem, + const int32_t* const extra_flags + ) { + // some of these are captured in lambda so can't use structured binding + int const m_block = get<0>(block_coord); + + int const bidh = get<1>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); + auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + // It's possible to have n_block_max <= n_block_min. Loading K can cause illegal memory access. 
+ if constexpr (Is_causal || Is_local || Varlen || Split) { + if (n_block_max <= n_block_min) { + return; + } + } + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sK_pi = as_position_independent_swizzle_tensor(sK); + // as_position_independent_swizzle_tensor makes address calculation easier when we do LDSM & STSM to transpose. + // But it requires smem_vt and smem_v to be aligned to e.g 512 bytes. + Tensor sVt = [&] { + if constexpr (!Transpose_V) { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{}); + } else { + return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVt{})); + } + }(); + // Only used if Transpose_V + Tensor sV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{})); + // Only used if we're using cp.async to load V + Tensor sVcpasync = [&] { + if constexpr (!Transpose_V) { + return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVCpAsync{})); + } else { + return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVCpAsync{})); + } + }(); + Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{}); + + int const thread_idx = threadIdx.x % NumProducerThreads; + int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; + int const bidb_kv = params.kv_batch_idx == nullptr ? 
bidb : params.kv_batch_idx[bidb]; + + // Prepare the TMA loads + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + bool const is_varlen_q = Varlen && params.cu_seqlens_q; + bool const is_varlen_k = Varlen && params.cu_seqlens_k; + Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, _); + auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); + Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(shape_V)(_, _, bidh_kv, _); + + Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + // if (cute::thread0()) { printf("Varlen = %d, params.leftpad_k = %p, leftpad_k = %d\n", Varlen, params.leftpad_k, leftpad_k); } + Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}, _)); // (N, K, _, _) + Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k, _0{}), mVt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _, _)); // (K, N, _, _) + + auto block_tma_Q = params.tma_load_Q.get_slice(_0{}); + Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) + Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); // (TMA) + // tma_partition doesn't handle position_independent_swizzle_tensor correctly, so we need to do it manually + auto block_tma_K = params.tma_load_K.get_slice(cluster_local_block_id.x); + Tensor tKgK_TMA = group_modes<0, 3>(block_tma_K.partition_S(gK_TMA)); // (TMA, k, batch) + Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K.partition_D(sK)); // 
(TMA, PIPE) + auto block_tma_V = params.tma_load_V.get_slice(cluster_local_block_id.x); + Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA)); // (TMA, k, batch) + Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V.partition_D(sVt)); // (TMA, PIPE) + auto [tQvgQv, tQvsQv] = [&] { + if constexpr (HasQv) { + auto shape_Qv = make_shape(get<0>(params.shape_Q), params.headdim_v, get<2>(params.shape_Q), get<3>(params.shape_Q)); + Tensor mQv = params.tma_load_Qv.get_tma_tensor(shape_Qv)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor gQv = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQv), select<0, 2>(TileShape_MNK_QV{}), make_coord(m_block, _0{})); // (M, Kv) + auto block_tma_Qv = params.tma_load_Qv.get_slice(_0{}); + Tensor tQvgQv = group_modes<0, 3>(block_tma_Qv.partition_S(gQv)); // (TMA) + Tensor tQvsQv = group_modes<0, 3>(block_tma_Qv.partition_D(sQv)); // (TMA) + return cute::make_tuple(tQvgQv, tQvsQv); + } else { + return cute::make_tuple(nullptr, nullptr); + } + }(); + + // This is used to index into the batch dimension of mK and mV + int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? 
bidb_kv : 0; + + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>; + PagedKVManager_t paged_kv_manager( + params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, + params.ptr_K, params.shape_K, params.stride_K, + params.ptr_V, params.headdim_v, params.stride_V, + params.page_size_divmod, params.blockN_per_page_size_divmod, + bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k, bidb_kv_idx + ); + + // Set up for transposing V, only used if Transpose_V + S2RTiledCopyVt s2r_tiled_copy_vt; + R2STiledCopyV r2s_tiled_copy_v; + auto s2r_thr_copy_vt = s2r_tiled_copy_vt.get_thread_slice(thread_idx); + auto r2s_thr_copy_v = r2s_tiled_copy_v.get_thread_slice(thread_idx); + // flat_divide(sVt, LDSM_divide_shape{}): (64, 8, kHeadDim / 64, kBlockN / 8, kStages) + Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sVt, LDSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32, kStages) + // flat_divide(sV, STSM_divide_shape{}): (8, 16, kHeadDim / 8, (4, kBlockN / 64), kStages) + Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sV, STSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64), kStages) + CUTE_STATIC_ASSERT_V(rank(tTranssVt_) == rank(tTranssV_)); + CUTE_STATIC_ASSERT_V(size<0>(tTranssVt_) == size<0>(tTranssV_)); + CUTE_STATIC_ASSERT_V(size<1>(tTranssVt_) == size<1>(tTranssV_)); + CUTE_STATIC_ASSERT_V(size<2>(tTranssVt_) == size<2>(tTranssV_)); + CUTE_STATIC_ASSERT_V(size<3>(tTranssVt_) == size<3>(tTranssV_)); + CUTE_STATIC_ASSERT_V(size<4>(tTranssVt_) == size<4>(tTranssV_)); + // Faster to have 2 LDSM.T, byte permute, STSM for better ILP + static constexpr int Transpose_ILP = (size<2>(tTranssVt_) * size<3>(tTranssVt_)) % 2 == 0 ? 
2 : 1; + Tensor tTranssVt = logical_divide(group_modes<1, rank(tTranssVt_) - 1>(tTranssVt_), Shape>{}); // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2), kStages) + Tensor tTranssV = logical_divide(group_modes<1, rank(tTranssV_) - 1>(tTranssV_), Shape>{}); // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2), kStages) + auto transpose_V = [&](int stage) { + if constexpr (Transpose_V) { + #pragma unroll + for (int i = 0; i < size<1, 1>(tTranssVt); ++i) { + Tensor tTransrV = make_fragment_like(tTranssV(_, make_coord(_, _0{}), _0{})); + static_assert(size<0>(tTransrV) == 16); + Tensor tTransrV_64 = recast(tTransrV); + cute::copy(s2r_tiled_copy_vt, tTranssVt(_, make_coord(_, i), stage), tTransrV); + #pragma unroll + for (int j = 0; j < size(tTransrV_64); ++j) { + uint32_t upper = tTransrV_64[j].x; + uint32_t lower = tTransrV_64[j].y; + tTransrV_64[j].x = __byte_perm(upper, lower, 0x6420); + tTransrV_64[j].y = __byte_perm(upper, lower, 0x7531); + } + cute::copy(r2s_tiled_copy_v, tTransrV, tTranssV(_, make_coord(_, i), stage)); + } + } + }; + + uint16_t mcast_mask_kv = 0; + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{})); + } + } + + auto load_K = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) { + pipeline_k.producer_acquire(smem_pipe_write); + if constexpr (!PagedKVNonTMA) { + auto [n_block_idx, bidb_kv_idx] = paged_kv_manager.get_indices_for_K_TMA(); + copy(params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST), + tKgK_TMA(_, n_block_idx, bidb_kv_idx), tKsK_TMA(_, smem_pipe_write.index())); + } else { + constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; + paged_kv_manager.template load_K(n_block, sK_pi(_, _, smem_pipe_write.index())); + 
pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); + } + }; + + auto load_V = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) { + auto pipeline_v_load = cute::conditional_return(pipeline_v, pipeline_vt); + pipeline_v_load.producer_acquire(smem_pipe_write); + if constexpr (!PagedKVNonTMA) { + auto [n_block_idx, bidb_kv_idx] = paged_kv_manager.get_indices_for_V_TMA(); + copy(params.tma_load_V.with(*pipeline_v_load.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST), + tVgVt_TMA(_, n_block_idx, bidb_kv_idx), tVsVt_TMA(_, smem_pipe_write.index())); + } else { + constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; + paged_kv_manager.template load_V(n_block, sVcpasync(_, _, smem_pipe_write.index())); + pipeline_v_load.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); + } + }; + + auto copy_Vt_to_V = [&] (auto const& smem_pipe_write) { + // Instead of maintaining smem_pipe_read as a separate variable, we can just use smem_pipe_write, + // and exploit the invariance that smem_pipe_write.phase() == smem_pipe_read.phase() ^ 1. + // This saves 1 or 2 registers. + PipelineState smem_pipe_read{smem_pipe_write.index(), smem_pipe_write.phase() ^ 1, smem_pipe_write.count()}; + pipeline_vt.consumer_wait(smem_pipe_read); + pipeline_v.producer_acquire(smem_pipe_write); + transpose_V(smem_pipe_write.index()); + // SMEM fence to make sure V is transposed before math + cutlass::arch::fence_view_async_shared(); + pipeline_v.producer_commit(smem_pipe_write); + // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized + // before calling. Without this we get race conditions. 
+ cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::ProducerWG) /*id*/); + pipeline_vt.consumer_release(smem_pipe_read); + }; + + int n_block = n_block_max; + int32_t n_block_idx = 0; + + const int32_t* mask_encode_n_block_smem_ = n_block_smem + Flashmask_n_block_buffer_length * n_block_pipe_read.index(); + const int32_t* extra_flags_smem = extra_flags + n_block_pipe_read.index(); + + auto load_flashmask = [&] (auto const& smem_pipe_write) { + if constexpr (Is_flashmask) { + pipeline_flashmask_apply.producer_acquire(smem_pipe_write); + int32_t* const flashmask_base_addr = flashmask_smem_ + smem_pipe_write.index() * 4 * kBlockN; + if(n_block_idx < Flashmask_n_block_buffer_valid_length && mask_encode_n_block_smem_[n_block_idx] >= 0) { + const int row_offset = (bidb * params.h_flashmask + bidh / params.h_h_flashmask_ratio) * seqlen_info.seqlen_k; + const int nb_mul_kBN = n_block * kBlockN; + const int loop_ub = std::min(kBlockN, seqlen_info.seqlen_k - nb_mul_kBN); + if (params.ut_start_ptr != nullptr) { + for(int idx = thread_idx; idx < loop_ub; idx += NumProducerThreads) { + asm volatile( + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + idx)), + "l"(params.lt_start_ptr + row_offset + nb_mul_kBN + idx), + "n"(4)); + asm volatile( + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + kBlockN + idx)), + "l"(params.lt_end_ptr + row_offset + nb_mul_kBN + idx), + "n"(4)); + asm volatile( + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + kBlockN * 2 + idx)), + "l"(params.ut_start_ptr + row_offset + nb_mul_kBN + idx), + "n"(4)); + asm volatile( + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + kBlockN * 3 + idx)), + 
"l"(params.ut_end_ptr + row_offset + nb_mul_kBN + idx), + "n"(4)); + } + } else { + + for(int idx = thread_idx; idx < loop_ub; idx += NumProducerThreads) { + asm volatile( + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + idx)), + "l"(params.lt_start_ptr + row_offset + nb_mul_kBN + idx), + "n"(4)); + + if constexpr (Is_causal) { + if(params.lt_end_ptr != nullptr) { + asm volatile( + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + kBlockN + idx)), + "l"(params.lt_end_ptr + row_offset + nb_mul_kBN + idx), + "n"(4)); + } + } else { + if(params.ut_end_ptr != nullptr) { + asm volatile( + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + ::"r"(cutlass::arch::cutlass_get_smem_pointer(flashmask_base_addr + 3 * kBlockN + idx)), + "l"(params.ut_end_ptr + row_offset + nb_mul_kBN + idx), + "n"(4)); + } + } + } + } + asm volatile("cp.async.commit_group;\n" ::); + asm volatile("cp.async.wait_group 0;\n" ::); + } + pipeline_flashmask_apply.producer_commit(smem_pipe_write); + } + }; + + auto n_block_wait = [](auto& pipeline, auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + }; + + auto n_block_getter = [&mask_encode_n_block_smem_, &extra_flags_smem](int32_t index) { + // if val >= 0 or val in [end, finish]: return val, else: return -val - 1 + if (index < Flashmask_n_block_buffer_valid_length) { + const int32_t encoded = mask_encode_n_block_smem_[index]; + const int32_t mask = -static_cast(encoded <= INT_MIN + 1); // INT_MIN is Flashmask_n_block_chunk_end + const int32_t converted = encoded ^ (encoded >> 31); + return (converted & ~mask) | (encoded & mask); + } else { + return *extra_flags_smem; + } + }; + + n_block_wait(pipeline_n_block, n_block_pipe_read); + n_block = n_block_getter(n_block_idx); + + if(n_block < 
n_block_min && n_block != Flashmask_n_block_chunk_end) { + pipeline_n_block.consumer_release(n_block_pipe_read); + ++n_block_pipe_read; + return; + } + + + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + // If this is true, we're guaranteed that only the first warp will execute this function + static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; + bool should_load_KV = !Use_TMA_KV || ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()); + + if (should_load_KV) { + if constexpr (PagedKVNonTMA) { + paged_kv_manager.template load_page_table(n_block); + } else { + paged_kv_manager.template load_page_table_TMA(n_block); + } + if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } + load_K(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); + } + + if constexpr (Use_TMA_Q) { + // Wait for the MMA warpgroups to signal that smem_q is ready + if (SingleProducerWarp || warp_idx_in_warpgroup == 0) { + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + } + + if ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()) { + shared_storage.pipelines.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); + copy(params.tma_load_Q.with(reinterpret_cast(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), + tQgQ, tQsQ); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.arrive_and_expect_tx(TmaTransactionBytesQv); + copy(params.tma_load_Qv.with(reinterpret_cast(shared_storage.pipelines.barrier_Qv), 0 /*mcast_mask*/, !Split ? 
TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), + tQvgQv, tQvsQv); + } + } + } else { // Load Q with cp.async + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); + using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element>; + PackGQAt::load_Q(mQ, sQ_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block); + auto &barrier_Q = shared_storage.pipelines.barrier_Q; + cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Q)); + barrier_Q.arrive(); + if constexpr (HasQv) { + Tensor mQv = make_tensor(make_gmem_ptr(params.ptr_Qv + seqlen_info.offset_q * get<0>(params.stride_Qv)), params.shape_Qv_packed, params.stride_Qv_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor sQv_pi = cute::as_position_independent_swizzle_tensor(sQv); + using PackGQAt = flash::PackGQAManager(TileShape_MNK_QV{}), get<2>(TileShape_MNK_QV{}), NumProducerThreads, Element>; + PackGQAt::load_Q(mQv, sQv_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block); + auto &barrier_Qv = shared_storage.pipelines.barrier_Qv; + cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Qv)); + barrier_Qv.arrive(); + } + } + + // Wait for the MMA WGs to signal that smem_v are ready and V can be copied from gmem + // Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finishing the + // TMA store on O first, call TMA multicast load on V, before CTA 1 can finishing TMA store on O. 
+ // if (thread_idx == 0) { printf("Producer: main load, before barrier_O, work_idx = %d\n", work_idx);} + shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2); + // if (thread_idx == 0) { printf("Producer: main load, after barrier_O\n");} + + load_flashmask(smem_pipe_write); + + if constexpr (!Transpose_V && !IntraWGOverlap) { + if (should_load_KV) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } + } + int n_block_prev = n_block; + + n_block = n_block_getter(++n_block_idx); + + #pragma unroll (!Transpose_V && Use_TMA_KV ? 2 : 1) + for (; n_block >= n_block_min || n_block == Flashmask_n_block_chunk_end;) { + for(; n_block >= n_block_min; n_block = n_block_getter(++n_block_idx)) { + PipelineState smem_pipe_write_v = smem_pipe_write; // copy the state, write_v is always 1 step behind + ++smem_pipe_write; + if (should_load_KV) { + if constexpr (PagedKVNonTMA) { + paged_kv_manager.template load_page_table(n_block); + } else { + paged_kv_manager.load_page_table_TMA(n_block); + } + if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); } + load_K(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); + if constexpr (!Transpose_V) { + if constexpr (IntraWGOverlap) { + load_V(n_block_prev, smem_pipe_write_v, cute::true_type{} /*Seqlenk_mask*/); + } else { + load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); + } + } + } + n_block_prev = n_block; + if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write_v); } + + load_flashmask(smem_pipe_write); + + } + + if (n_block == Flashmask_n_block_chunk_end) { + pipeline_n_block.consumer_release(n_block_pipe_read); + ++n_block_pipe_read; + n_block_wait(pipeline_n_block, n_block_pipe_read); + mask_encode_n_block_smem_ = n_block_smem + Flashmask_n_block_buffer_length * n_block_pipe_read.index(); + extra_flags_smem = extra_flags + n_block_pipe_read.index(); + n_block_idx = 0; + n_block = n_block_getter(0); + } + } + + 
pipeline_n_block.consumer_release(n_block_pipe_read); + ++n_block_pipe_read; + + if constexpr (!Transpose_V && IntraWGOverlap) { + if (should_load_KV) { load_V(n_block_prev, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } + } + if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write); } + ++smem_pipe_write; + // At the end, all threads have the correct smem_pipe_write. + ++work_idx; + } + + template + CUTLASS_DEVICE void + load_tail(MainloopPipelineK pipeline_k, MainloopPipelineV pipeline_v, MainloopPipelineVt pipeline_vt, + PipelineState& smem_pipe_write, SharedStorage &shared_storage, int const work_idx) { + // If we don't wait for barrier_O here, when using Cluster, CTA0 might exit early and CTA1 will + // try to arrive on barrier_O of CTA0, causing "unspecified launch failure". + shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2); + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + // Issue the epilogue waits + // TODO: check if this should be called by 1 thread or more + if (warp_idx_in_warpgroup == 0 && cute::elect_one_sync()) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was still inverted from make_producer_start_state + */ + pipeline_k.producer_tail(smem_pipe_write); + pipeline_v.producer_tail(smem_pipe_write); + if constexpr (Transpose_V) { pipeline_vt.producer_tail(smem_pipe_write); } + } + } + + CUTLASS_DEVICE void + warp_scheduler_barrier_sync() { + if constexpr (UseSchedulerBarrier) { + cutlass::arch::NamedBarrier::sync(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/); + } + } + + CUTLASS_DEVICE void + warp_scheduler_barrier_arrive() { + if constexpr (UseSchedulerBarrier) { + static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); + int const 
cur_WG = flash::canonical_warp_group_idx_nosync() - 1; + int const next_WG = NumMmaWarpGroups == 2 + ? 1 - cur_WG + : (cur_WG < NumMmaWarpGroups - 1 ? cur_WG + 1 : 0); + cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) + next_WG /*id*/); + } + } + + CUTLASS_DEVICE void + mma_init() { + int warp_group_idx = flash::canonical_warp_group_idx_nosync(); + // Tell producers that smem_q is ready + if (!LargeHeadDimV || warp_group_idx == 1) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + } + if (LargeHeadDimV && warp_group_idx > 1) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } + if constexpr (UseSchedulerBarrier) { + // We have NamedBarrier for up to 3 WGs + static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); + // WG1 needs the very first signal to start + if (warp_group_idx == 1) { + cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) /*id*/); + } + } + } + + template + CUTLASS_DEVICE + void flashmask_apply(Tensor &tSrS, int m_block, int const thread_idx, int const index, int32_t* const flashmask_smem_, + int32_t* const lt_start_ptr, int32_t* const lt_end_ptr, + int32_t* const ut_start_ptr, int32_t* const ut_end_ptr) { + auto thread_mma = TiledMma{}.get_thread_slice(thread_idx); + + Tensor cS = cute::make_identity_tensor(Shape, Int>{}); + Tensor tScS = thread_mma.partition_C(cS); + Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); + Tensor tScS_rowcol = make_tensor(tScS.data(), flash::convert_layout_acc_rowcol(tScS.layout())); + + static constexpr int Row = 0, Col = 1; + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + const int32_t* const s_lt_start = flashmask_smem_ + 4 * kBlockN * 
index; + m_block *= kBlockM; + if constexpr (tag == PtrExistDispatchTag::SINGLE_PTR) { + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + int const col_idx = get(tScS_rowcol(_0{}, n)); // col_idx within a block + int lts = s_lt_start[col_idx] - m_block; + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { + int const row_idx = get(tScS_rowcol(m, n)); + if(row_idx >= lts) tSrS_rowcol(m, n) = -INFINITY; + } + } + return; + } else if constexpr (tag == PtrExistDispatchTag::FULL_PTR) { + const int32_t* const s_lt_end = flashmask_smem_ + 4 * kBlockN * index + kBlockN; + const int32_t* const s_ut_start = flashmask_smem_ + 4 * kBlockN * index + 2 * kBlockN; + const int32_t* const s_ut_end = flashmask_smem_ + 4 * kBlockN * index + 3 * kBlockN; + + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + int const col_idx = get(tScS_rowcol(_0{}, n)); // col_idx within a block + int lts = s_lt_start[col_idx] - m_block; + int lte = s_lt_end[col_idx] - m_block; + int uts = s_ut_start[col_idx] - m_block; + int ute = s_ut_end[col_idx] - m_block; + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { + int const row_idx = get(tScS_rowcol(m, n)); + if((row_idx >= lts && row_idx < lte) || (row_idx >= uts && row_idx < ute)) + tSrS_rowcol(m, n) = -INFINITY; + } + } + } else { + if constexpr (Is_causal) { + const int32_t* const s_lt_end = flashmask_smem_ + 4 * kBlockN * index + kBlockN; + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + int const col_idx = get(tScS_rowcol(_0{}, n)); // col_idx within a block + int lts = s_lt_start[col_idx] - m_block; + int lte = s_lt_end[col_idx] - m_block; + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { + int const row_idx = get(tScS_rowcol(m, n)); + if(row_idx >= lts && row_idx < lte) + tSrS_rowcol(m, n) = -INFINITY; + } + } + } else { + const int32_t* const s_ut_end = flashmask_smem_ + 4 * kBlockN * index + 3 * kBlockN; + #pragma unroll + for (int n = 0; 
// (tail of flashmask_apply) Per-column masking: for each column, positions with
// row_idx >= lts (lower-triangle start) or row_idx < ute (upper-triangle end) are
// set to -INFINITY so they contribute nothing after softmax. Indices are relative
// to the current m_block. NOTE(review): text appears diff-collapsed; tokens kept verbatim.
n < size<1>(tSrS_rowcol); ++n) { + int const col_idx = get(tScS_rowcol(_0{}, n)); // col_idx within a block + int lts = s_lt_start[col_idx] - m_block; + int ute = s_ut_end[col_idx] - m_block; + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { + int const row_idx = get(tScS_rowcol(m, n)); + if((row_idx >= lts) || (row_idx < ute)) + tSrS_rowcol(m, n) = -INFINITY; + } + } + } + } + } + + template + CUTLASS_DEVICE bool + mma(Params const& params, + MainloopPipelineK pipeline_k, + MainloopPipelineV pipeline_v, + MainloopPipelineNBlock pipeline_n_block, + MainloopPipelineFlashMaskApply pipeline_flashmask_apply, + PipelineState& smem_pipe_read, + cutlass::PipelineState& n_block_pipe_read, + FrgTensorO& tOrO, + Softmax& softmax, + int const thread_idx, + int &work_idx, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + SharedStorage& shared_storage, + int32_t* const flashmask_smem_, + const int32_t* const n_block_smem, + const int32_t* const extra_flags + ) { + static_assert(is_rmem::value, "O tensor must be rmem resident."); + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + + // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda + int const m_block = get<0>(block_coord); + + int const bidh = get<1>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); + int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; + auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + // It's possible to have n_block_max <= n_block_min.
We don't want to load Q or change any barrier + if constexpr (Is_causal || Is_local || Varlen || Split) { + if (n_block_max <= n_block_min) { return false; } + } + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}); + Tensor sP = [&] { + if constexpr (MmaPV_is_RS) { + // We might not have smem_p if !MmaPV_is_RS, just use smem_q as a placeholder since we don't use it + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutP{}); + } else { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); + } + }(); + Tensor sScale = [&] { + if constexpr (LargeHeadDimV) { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{}); + } else { // won't be used, just a placeholder + return make_tensor(make_smem_ptr(static_cast(nullptr)), SmemLayoutScale{}); + } + }(); + Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{}); + Tensor sVMmaQV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVMmaQV{}); + + if constexpr (!MmaQK_is_RS) { + static_assert(stride<0>(typename TiledMmaQK::ALayout{}) == 0 and + stride<0>(typename TiledMmaQK::BLayout{}) == 0 and + size<0>(typename TiledMmaQK::ALayout{}) == cutlass::NumThreadsPerWarpGroup and + size<0>(typename TiledMmaQK::BLayout{}) == cutlass::NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + } + static constexpr int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(make_shape(Int{}), + make_stride(Int{})); + + int warp_group_idx =
// warp_group_idx is broadcast from lane 0 via __shfl_sync so the whole warp agrees;
// per-warp-group MMA slices (QK, PV, QV) are then built from the same layout.
__shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); + TiledMmaQK tiled_mma_qk; + TiledMmaPV tiled_mma_pv; + TiledMmaQV tiled_mma_qv; + auto wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)); + auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)); + auto wg_mma_qv = tiled_mma_qv.get_slice(warp_group_thread_layout(warp_group_idx)); + + auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtomP{}, tiled_mma_qk); + auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx); + + // Allocate "fragments/descriptors" + Tensor tSrQ = wg_mma_qk.partition_fragment_A(sQ); + Tensor tSrK = wg_mma_qk.partition_fragment_B(sK); + Tensor tOrV = wg_mma_pv.partition_fragment_B(sV); + Tensor tOsP = wg_mma_pv.partition_fragment_A(sP); + Tensor tSrQv = wg_mma_qv.partition_fragment_A(sQv); + Tensor tSrV = wg_mma_qv.partition_fragment_B(sVMmaQV); + Tensor tPsP = smem_thr_copy_P.partition_D(cute::as_position_independent_swizzle_tensor(sP)); + + // For storing scales to smem, only used when LargeHeadDimV + auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx); + Tensor taccOcO = thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + auto store_scales = [&](auto& scales, int stage) { + static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row))); + #pragma unroll + for (int mi = 0; mi < size(taccOcO_row); ++mi) { + if (get<1>(taccOcO_row(_0{})) == 0) { + sScale(get<0>(taccOcO_row(mi)), stage) = scales(mi); + } + } + }; + + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + }; + + int const seqlen_q = seqlen_info.seqlen_q; + int const seqlen_k =
// NOTE(review): store_scales guards on taccOcO_row(_0{}) while mma_pv's load_scales
// indexes taccOcO_row(mi); if the column coord differs per mi this drops scales — verify.
// Below: n_block list for this tile is read from smem as an encoded stream; values
// encoded as -v-1 mark (presumably) partially-masked blocks, and INT_MIN-family
// sentinels mark chunk ends (Flashmask_n_block_chunk_end) — TODO confirm encoding.
seqlen_info.seqlen_k; + + int n_block = n_block_max; + + consumer_wait(pipeline_n_block, n_block_pipe_read); + const int32_t* mask_encode_n_block_smem_ = n_block_smem + Flashmask_n_block_buffer_length * n_block_pipe_read.index(); + const int32_t* extra_flags_smem = extra_flags + n_block_pipe_read.index(); + + int n_block_idx = 0; + auto n_block_getter = [&mask_encode_n_block_smem_, &extra_flags_smem](int32_t index) { + // if val >= 0 or val in [end, finish]: return val, else: return -val - 1 + if (index < Flashmask_n_block_buffer_valid_length) { + const int32_t encoded = mask_encode_n_block_smem_[index]; + const int32_t mask = -static_cast(encoded <= INT_MIN + 1); // INT_MIN is Flashmask_n_block_chunk_end + const int32_t converted = encoded ^ (encoded >> 31); + return (converted & ~mask) | (encoded & mask); + } else { + return *extra_flags_smem; + } + }; + n_block = n_block_getter(0); + + if(n_block < n_block_min && n_block != Flashmask_n_block_chunk_end) { + pipeline_n_block.consumer_release(n_block_pipe_read); + ++n_block_pipe_read; + return false; + } + + flash::Mask mask( + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, + params.qhead_per_khead_divmod + ); + + float softcap_val = params.softcap_val; + if constexpr (Has_softcap && Is_FP8) { + float const q_descale = params.ptr_q_descale == nullptr ? 1.0f : params.ptr_q_descale[bidb * get<0>(params.stride_q_descale) + bidh_kv * get<1>(params.stride_q_descale)]; + float const k_descale = params.ptr_k_descale == nullptr ? 1.0f : params.ptr_k_descale[bidb * get<0>(params.stride_k_descale) + bidh_kv * get<1>(params.stride_k_descale)]; + softcap_val *= q_descale * k_descale; + } + // Softcapping needs to happen before masking since if we apply after masking, softcapping + // can turn -inf to e.g. -50.0, which can affect the attention softmax.
// Helper lambdas: write_P_to_smem copies converted P (attention probs) into smem for
// the PV warpgroup; arrive_on_P_write_barrier fences smem writes (__syncwarp suffices
// because each warp consumes its own P) and, for LargeHeadDimV, signals PFull.
+ auto scoremod_premask_fn = [&](auto& tSrS) { + if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); } + }; + + auto write_P_to_smem = [&](auto& tOrP) { + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } + cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); + }; + + auto arrive_on_P_write_barrier = [&] { + cutlass::arch::fence_view_async_shared(); + __syncwarp(); // Only need syncwarp since each warp is using its own P values for MmaPV + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } + }; + + auto &barrier_Q = shared_storage.pipelines.barrier_Q; + if constexpr (!AppendKV) { + barrier_Q.wait(work_idx % 2); + } else { + if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q + int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; + using Rotary_t = Rotary; + Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, + params.ptr_rotary_sin, params.stride_rotary_sin, + params.is_rotary_interleaved, thread_idx, seqlen_q, offset_rotary); + Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); + int const qhead_per_khead = !PackGQA ?
// AppendKV path: rotary embedding is applied to Q in smem (interleaved or contiguous
// layout), then an async-shared fence + QueryRotated barrier publish it before GMMA reads.
1 : params.qhead_per_khead_divmod.divisor; + if (params.is_rotary_interleaved) { + auto [tRrCos, tRrSin] = cute::conditional_return( + rotary.template load_cos_sin(m_block), + rotary.template load_cos_sin_packgqa(m_block, params.qhead_per_khead_divmod) + ); + barrier_Q.wait(work_idx % 2); + rotary.apply_Q_interleaved(sQ_pi, tRrCos, tRrSin, m_block, qhead_per_khead); + } else { + auto [tRrCosCont, tRrSinCont] = cute::conditional_return( + rotary.template load_cos_sin(m_block), + rotary.template load_cos_sin_packgqa(m_block, params.qhead_per_khead_divmod) + ); + barrier_Q.wait(work_idx % 2); + rotary.apply_Q_contiguous(sQ_pi, tRrCosCont, tRrSinCont, m_block, qhead_per_khead); + } + // SMEM fence to make sure the rotated Q is visible to GMMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK, static_cast(FwdNamedBarriers::QueryRotated) /*id*/); + } else { + barrier_Q.wait(work_idx % 2); + } + } + + if constexpr (MmaQK_is_RS) { + using SmemCopyAtomQ = Copy_Atom; + auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtomQ{}, tiled_mma_qk); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + Tensor tSsQ_copy_view = smem_thr_copy_Q.partition_S(cute::as_position_independent_swizzle_tensor(sQ)); + cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view); + } + + if constexpr (IntraWGOverlap) { + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); + consumer_wait(pipeline_k, smem_pipe_read); + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + warpgroup_wait<0>(); + pipeline_k.consumer_release(smem_pipe_read); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.wait(work_idx % 2); + consumer_wait(pipeline_v, smem_pipe_read); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + } + scoremod_premask_fn(tSrS); + mask.template apply(tSrS, m_block,
// First iteration (peeled): after the standard causal/local mask, the flashmask
// column ranges are applied when this n_block is flagged in the encoded list; the
// three flashmask_apply call shapes pass progressively fewer of the LT/UT pointers
// depending on which of params.{lt_end,ut_start,ut_end}_ptr are non-null.
n_block); + + if constexpr(Is_flashmask) { + consumer_wait(pipeline_flashmask_apply, smem_pipe_read); + if (n_block_idx < Flashmask_n_block_buffer_valid_length && mask_encode_n_block_smem_[n_block_idx] >= 0) { + if (params.ut_start_ptr) { + flashmask_apply( + tSrS, m_block, thread_idx, smem_pipe_read.index(), flashmask_smem_, + params.lt_start_ptr, params.lt_end_ptr, + params.ut_start_ptr, params.ut_end_ptr); + } else if (params.lt_end_ptr || params.ut_end_ptr) { + flashmask_apply( + tSrS, m_block, thread_idx, smem_pipe_read.index(), flashmask_smem_, + params.lt_start_ptr, params.lt_end_ptr, + nullptr, params.ut_end_ptr); + } else { + flashmask_apply( + tSrS, m_block, thread_idx, smem_pipe_read.index(), flashmask_smem_, + params.lt_start_ptr, nullptr, nullptr, nullptr); + } + } + pipeline_flashmask_apply.consumer_release(smem_pipe_read); + } + + Tensor scores_scale = softmax.template max_get_scale(tSrS); + // Don't need to store scales to send to WG1 (in the case of LargeHeadDimV) since it's 1.f + + softmax.template online_softmax(tSrS); + if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP = make_tensor_like(tOrP_acc); + convert_type_out(tOrP_acc, tOrP); + if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } + if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } + if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); } + n_block = n_block_getter(++n_block_idx); + + // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter + clear(tOrO); + // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; + + // Each step does gemm0 for iter n_block, gemm1 for iter n_block + 1, and softmax for iter n_block.
// fwd_step is the software-pipelined steady state: QK gemm for the current n_block
// overlaps the PV gemm of the previous one (smem_pipe_read_v lags smem_pipe_read by
// one stage); K is released after warpgroup_wait<1>, V after its PV gemm completes.
+ auto fwd_step = [&](int const n_block, auto mask_fn, auto check_inf_type) { + static constexpr bool Check_inf = decltype(check_inf_type)::value; + PipelineState smem_pipe_read_v(smem_pipe_read.index(), smem_pipe_read.phase(), smem_pipe_read.count()); + ++smem_pipe_read; + // PipelineState smem_pipe_read_v(smem_pipe_read.index(), smem_pipe_read.phase(), smem_pipe_read.count()); + // ++smem_pipe_read; + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); + if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_k, smem_pipe_read); } + warp_scheduler_barrier_sync(); + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } + if constexpr(!HasQv) { + if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } + } + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + warp_scheduler_barrier_arrive(); + warpgroup_wait<1>(); + pipeline_k.consumer_release(smem_pipe_read); // release K + if constexpr (HasQv) { + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + consumer_wait(pipeline_v, smem_pipe_read); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + } + scoremod_premask_fn(tSrS); + mask_fn(tSrS, n_block); + + if constexpr (Is_flashmask) { + consumer_wait(pipeline_flashmask_apply, smem_pipe_read); + if (n_block_idx < Flashmask_n_block_buffer_valid_length && mask_encode_n_block_smem_[n_block_idx] >= 0) { + if (params.ut_start_ptr) { + flashmask_apply( + tSrS, m_block, thread_idx, smem_pipe_read.index(), flashmask_smem_, + params.lt_start_ptr, params.lt_end_ptr, params.ut_start_ptr, params.ut_end_ptr); + } else if (params.lt_end_ptr || params.ut_end_ptr) { + flashmask_apply( + tSrS, m_block, thread_idx, smem_pipe_read.index(), flashmask_smem_, +
params.lt_start_ptr, params.lt_end_ptr, nullptr, params.ut_end_ptr); + } else { + flashmask_apply( + tSrS, m_block, thread_idx, smem_pipe_read.index(), flashmask_smem_, + params.lt_start_ptr, nullptr, nullptr, nullptr); + } + } + pipeline_flashmask_apply.consumer_release(smem_pipe_read); + } + + cute::copy(softmax.template max_get_scale(tSrS), scores_scale); + if constexpr (LargeHeadDimV) { store_scales(scores_scale, smem_pipe_read_v.index()); } + softmax.template online_softmax(tSrS); + if constexpr (!HasQv) { + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + } + if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } + convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP); + if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } + if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } + if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } + if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); } + }; + + int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; + int const n_block_min_before_local_mask = !Is_local + ? n_block_min + : std::max(n_block_min, + cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN)); + + if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking + auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; + int const m_idx_min = !PackGQA ?
// Three phases over n_blocks (each phase re-arms the encoded n_block chunk stream
// whenever a Flashmask_n_block_chunk_end sentinel is hit): (1) causal/local-masked
// iterations, (2) unmasked bulk, (3) left-side local-window masked iterations.
m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM); + int const n_block_min_causal_local_mask = + std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN); + for(; n_block >= n_block_min_causal_local_mask || n_block == Flashmask_n_block_chunk_end;) { + #pragma unroll 1 + for (; n_block >= n_block_min_causal_local_mask; n_block = n_block_getter(++n_block_idx)) { + fwd_step(n_block, mask_fn, cute::true_type{} /*check_inf*/); + } + if (n_block == Flashmask_n_block_chunk_end) { + pipeline_n_block.consumer_release(n_block_pipe_read); + ++n_block_pipe_read; + consumer_wait(pipeline_n_block, n_block_pipe_read); + mask_encode_n_block_smem_ = n_block_smem + Flashmask_n_block_buffer_length * n_block_pipe_read.index(); + extra_flags_smem = extra_flags + n_block_pipe_read.index(); + n_block_idx = 0; + n_block = n_block_getter(0); + } + } + } + + auto no_mask_fn = [](auto& tSrS, int n_block) { }; + for(; n_block >= n_block_min_before_local_mask || n_block == Flashmask_n_block_chunk_end;) { + #pragma unroll 1 + for (; n_block >= n_block_min_before_local_mask; n_block = n_block_getter(++n_block_idx)) { + fwd_step(n_block, no_mask_fn, cute::bool_constant{} /*check_inf*/); + } + if (n_block == Flashmask_n_block_chunk_end) { + pipeline_n_block.consumer_release(n_block_pipe_read); + ++n_block_pipe_read; + consumer_wait(pipeline_n_block, n_block_pipe_read); + mask_encode_n_block_smem_ = n_block_smem + Flashmask_n_block_buffer_length * n_block_pipe_read.index(); + extra_flags_smem = extra_flags + n_block_pipe_read.index(); + n_block_idx = 0; + n_block = n_block_getter(0); + } + } + + // Separate masking iterations on the left for local attention + if constexpr (Is_local) { + auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; + for(; n_block >= n_block_min || n_block == Flashmask_n_block_chunk_end;) { + #pragma unroll 1 + for (; n_block >= n_block_min; n_block =
// Epilogue: release the n_block pipeline, signal QueryEmpty so producers may reuse
// smem_q, run the final (peeled) PV gemm, finalize softmax with the V descale, and
// (LargeHeadDimV) hand the final scales to the PV warpgroup via PEmpty/PFull.
n_block_getter(++n_block_idx)) { + fwd_step(n_block, local_mask_fn, cute::bool_constant{} /*check_inf*/); + } + if (n_block == Flashmask_n_block_chunk_end) { + pipeline_n_block.consumer_release(n_block_pipe_read); + ++n_block_pipe_read; + consumer_wait(pipeline_n_block, n_block_pipe_read); + mask_encode_n_block_smem_ = n_block_smem + Flashmask_n_block_buffer_length * n_block_pipe_read.index(); + extra_flags_smem = extra_flags + n_block_pipe_read.index(); + n_block_idx = 0; + n_block = n_block_getter(0); + } + } + } + pipeline_n_block.consumer_release(n_block_pipe_read); + ++n_block_pipe_read; + // Tell producers that smem_q is ready + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } + if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ?
// No-IntraWGOverlap fallback is deliberately unsupported for flashmask v3.
1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; + cute::copy(softmax.finalize(v_descale), scores_scale); + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + store_scales(scores_scale, smem_pipe_read.index()); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read); // release V, otherwise producers will hang + softmax.rescale_o(tOrO, scores_scale); + if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } + ++smem_pipe_read; + } else { // No intra-WG overlap + static_assert(!Is_flashmask, "flashmaskv3 does not support no intra-wg overlap"); + } + ++work_idx; + return true; + } + + template + CUTLASS_DEVICE bool + mma_pv(Params const& params, + MainloopPipelineV pipeline_v, + MainloopPipelineNBlock pipeline_flashmask, + PipelineState& smem_pipe_read, + FrgTensorO& tOrO, + Softmax& softmax, + int const thread_idx, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + SharedStorage& shared_storage, + int32_t* const flashmask_smem_ + ) { + static_assert(is_rmem::value, "O tensor must be rmem resident."); + // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda + int const m_block = get<0>(block_coord); + int const bidh = get<1>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); + auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + // It's possible to have n_block_max <= n_block_min.
We don't want to load Q or change any barrier + if constexpr (Is_causal || Is_local || Varlen || Split) { + if (n_block_max <= n_block_min) { return false; } + } + + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}); + Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); + Tensor sScale = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{}); + static constexpr int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(make_shape(Int{}), + make_stride(Int{})); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); + TiledMmaPV tiled_mma_pv; + auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)); + + // Allocate "fragments/descriptors" + Tensor tOrV = wg_mma_pv.partition_fragment_B(sV); + Tensor tOsP = wg_mma_pv.partition_fragment_A(sP); + + // For load scales to smem, pretend thread_idx is thread_idx % 128 + auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx % cutlass::NumThreadsPerWarpGroup); + Tensor taccOcO = thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + auto load_scales = [&](auto& scales, int stage) { + static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row))); + #pragma unroll + for (int mi = 0; mi < size(taccOcO_row); ++mi) { + scales(mi) = sScale(get<0>(taccOcO_row(mi)), stage); + } + }; + + // clear(tOrO); + // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; + + typename Softmax::TensorT scores_scale; + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + + int32_t n_block = n_block_max
// mma_pv: PV-only consumer (used when LargeHeadDimV splits work across warpgroups).
// It synchronizes with the QK/softmax warpgroup through the PFull/PEmpty named
// barriers: each iteration waits for P + per-row scales in smem, rescales the O
// accumulator by the incoming scales, runs the PV gemm for the current V stage,
// then releases P and V. A final PFull wait applies the finalize-scales.
// NOTE(review): `partially_masked` and `min_n_block_in_smem` are declared but never
// used in this visible span — possibly leftovers; confirm before removing.
- 1; + bool partially_masked; + static constexpr int NumThreads = NumMmaThreads; + int32_t min_n_block_in_smem = INT_MAX; + + // If HasQv, then by the time P is ready, V must have been ready as well + if constexpr (!HasQv) { pipeline_v.consumer_wait(smem_pipe_read); } + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + pipeline_v.consumer_release(smem_pipe_read); // release V + + #pragma unroll 1 + for (; n_block >= n_block_min; n_block--) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + load_scales(scores_scale, smem_pipe_read.index()); + softmax.rescale_o(tOrO, scores_scale); + ++smem_pipe_read; + if constexpr (!HasQv) { + auto barrier_token = pipeline_v.consumer_try_wait(smem_pipe_read); + pipeline_v.consumer_wait(smem_pipe_read, barrier_token); + } + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + pipeline_v.consumer_release(smem_pipe_read); // release V + }; + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + load_scales(scores_scale, smem_pipe_read.index()); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + softmax.rescale_o(tOrO, scores_scale); + if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } + ++smem_pipe_read; + return true; + } + + template + CUTLASS_DEVICE bool + load_kv_new(Params const& params, + MainloopPipelineKVNew pipeline_k_new, + MainloopPipelineKVNew pipeline_v_new, + PipelineState& smem_pipe_write, + SharedStorage &shared_storage, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + int const work_idx + ) {
// load_kv_new: TMA producer for the AppendKV path. Stages the *newly appended*
// K tiles (and V, transposed or not depending on Transpose_V) from gmem into the
// smem pipeline, iterating n_block from n_block_new_max-1 down to n_block_new_min.
// Only one elected lane of the (first) producer warp issues the TMA copies; all
// threads still advance smem_pipe_write so pipeline state stays consistent.
+ + auto [m_block, bidh, bidb, split_idx] = block_coord; + auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + + if (n_block_new_max <= n_block_new_min) { return false; } + + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sVt = [&] { + if constexpr (!Transpose_V) { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{}); + } else { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVt{}); + } + }(); + + // int const thread_idx = threadIdx.x % NumProducerThreads; + int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; + + // Prepare the TMA loads + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new; + Tensor mKnew_TMA = params.tma_load_K_new.get_tma_tensor(params.shape_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); + auto shape_Vnew = make_shape(params.headdim_v, get<0>(params.shape_K_new), get<2>(params.shape_K_new), get<3>(params.shape_K_new)); + Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(shape_Vnew)(_, _, bidh_kv, !is_varlen_k_new ?
// Tiles are offset by seqlen_info.offset_k_new (varlen support); the multicast mask
// covers peer blocks along the cluster's first mode when TMA multicast is enabled.
bidb : 0); + + Tensor gKnew_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) + + auto block_tma_K_new = params.tma_load_K_new.get_slice(cluster_local_block_id.x); + Tensor tKgKnew_TMA = group_modes<0, 3>(block_tma_K_new.partition_S(gKnew_TMA)); // (TMA, k) + Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K_new.partition_D(sK)); // (TMA, PIPE) + auto block_tma_V_new = params.tma_load_V_new.get_slice(cluster_local_block_id.x); + Tensor tVgVnewt_TMA = group_modes<0, 3>(block_tma_V_new.partition_S(gVnewt_TMA)); // (TMA, k) + Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V_new.partition_D(sVt)); // (TMA, PIPE) + + uint16_t mcast_mask_kv = 0; + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{})); + } + } + + auto load_K_new = [&] (int const n_block, auto const& smem_pipe_write) { + pipeline_k_new.producer_acquire(smem_pipe_write); + copy(params.tma_load_K_new.with(*pipeline_k_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST), + tKgKnew_TMA(_, n_block), tKsK_TMA(_, smem_pipe_write.index())); + }; + + auto load_V_new = [&] (int const n_block, auto const& smem_pipe_write) { + pipeline_v_new.producer_acquire(smem_pipe_write); + copy(params.tma_load_V_new.with(*pipeline_v_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST), + tVgVnewt_TMA(_, n_block), tVsVt_TMA(_, smem_pipe_write.index())); + }; + + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + // If this is true, we're guaranteed that only the first warp will execute this
function + static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; + bool should_load_KV = (SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync(); + + int n_block = n_block_new_max - 1; + // Need to wait for barrier_O even before load_K_new since the pipelines for AppendKV + // and the main attention are not the same. We want to make sure the consumers + // have finished reading all smem_k and smem_v for the previous iteration. + shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2); + if (should_load_KV) { load_K_new(n_block, smem_pipe_write); } + // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + if (should_load_KV) { load_V_new(n_block, smem_pipe_write); } + // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + ++smem_pipe_write; + --n_block; + // if (thread_idx == 0) { printf("Producer: before for loop\n"); } + #pragma unroll 1 + for (; n_block >= n_block_new_min; --n_block) { + if (should_load_KV) { + load_K_new(n_block, smem_pipe_write); + // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + load_V_new(n_block, smem_pipe_write); + // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + } + ++smem_pipe_write; + } + // if (thread_idx == 0) { printf("Producer: after for loop\n"); } + // At the end, all threads have the correct smem_pipe_write.
+ return true; + } + + template + CUTLASS_DEVICE bool + store_kv_new(Params const& params, + MainloopPipelineKVNew pipeline_k_new, + MainloopPipelineKVNew pipeline_v_new, + PipelineState& smem_pipe_read, + int const thread_idx, + SharedStorage &shared_storage, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord + ) { + auto [m_block, bidh, bidb, split_idx] = block_coord; + auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + if (n_block_new_max <= n_block_new_min) { return false; } + + // as_position_independent_swizzle_tensor makes address calculation easier + Tensor sK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{})); + // We want to use SmemLayoutVCpAsync to have shape (kBlockN, kHeadDim) instead of (kHeadDim, kBlockN) + Tensor sV = [&] { + if constexpr (!Transpose_V) { + return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVCpAsync{})); + } else { + return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVCpAsync{})); + } + }(); + + int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; + int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; + + bool const is_varlen_k = Varlen && params.cu_seqlens_k; + Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); + Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), shape_V, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? 
bidb_kv : 0); + + int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og; + Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<2, 1>(TileShape_MNK_PV{}), make_coord(_, _0{})); // (N, K_v, _) + + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; + int const seqlen_k_new = seqlen_info.seqlen_k_new; + using Rotary_t = Rotary; + Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, + params.ptr_rotary_sin, params.stride_rotary_sin, + params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary); + + // This is used to index into the batch dimension of mK and mV + int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? bidb_kv : 0; + + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; + PagedKVManager_t paged_kv_manager( + params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, + params.ptr_K, params.shape_K, params.stride_K, + params.ptr_V, params.headdim_v, params.stride_V, + params.page_size_divmod, params.blockN_per_page_size_divmod, + bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k, bidb_kv_idx + // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position + ); + + if constexpr (UseSchedulerBarrier) { + // WG1 already got the very first signal from mma_init(), but we'll be using the same NamedBarrier. + // So we'll need to "cancel it out" here and then re-signal it at the end. 
+ if (flash::canonical_warp_group_idx_nosync() == 1) { + cutlass::arch::NamedBarrier::sync(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) /*id*/); + } + } + + static_assert(std::is_same_v); + static_assert(!PagedKVNonTMA || std::is_same_v); + GmemTiledCopyAppendKV gmem_tiled_copy_kv; + auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_thread_slice(thread_idx); + Tensor tKsK = gmem_thr_copy_kv.partition_S(sK); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tKgK = gmem_thr_copy_kv.partition_D(gK); + Tensor tVsV = gmem_thr_copy_kv.partition_S(sV); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tVgV = gmem_thr_copy_kv.partition_D(gV); + Tensor cK = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tKcK = gmem_thr_copy_kv.partition_D(cK); + Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); + #pragma unroll + for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(params.shape_K); } + Tensor cV = cute::make_identity_tensor(select<2, 1>(TileShape_MNK_PV{})); // (BLK_N,BLK_K_V) -> (blk_n,blk_k_v) + Tensor tVcV = cute::conditional_return(tKcK, gmem_thr_copy_kv.partition_D(cV)); + Tensor tVpV_ = make_tensor(make_shape(size<2>(tVsV))); + #pragma unroll + for (int k = 0; k < size(tVpV_); ++k) { tVpV_(k) = get<1>(tVcV(_0{}, _0{}, k)) < params.headdim_v; } + Tensor tVpV = cute::conditional_return(tKpK, tVpV_); + + auto store_K = [&] (int const n_block, auto const& smem_pipe_read) { + int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); + if (get<1>(params.shape_rotary) <= 0) { + pipeline_k_new.consumer_wait(smem_pipe_read); + Tensor tKsK_cur = tKsK(_, _, _, smem_pipe_read.index()); + if constexpr (!PagedKVNonTMA) { + Tensor tKgK_cur = tKgK(_, _, _, n_block); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_kv, tKsK_cur, tKgK_cur, tKcK, tKpK, std::min(seqlen_k_new - n_block * kBlockN, 
kBlockN) + ); + } else { + paged_kv_manager.store_K(n_block, tKsK_cur); + } + } else { + Tensor gK_cur = gK(_, _, n_block); + auto tPrKPtr = cute::conditional_return(paged_kv_manager.compute_K_ptr(), nullptr); + if (params.is_rotary_interleaved) { + auto [tRrCos, tRrSin] = rotary.template load_cos_sin(n_block); + pipeline_k_new.consumer_wait(smem_pipe_read); + rotary.template apply_K_interleaved(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCos, tRrSin, tPrKPtr, n_block); + } else { + auto [tRrCosCont, tRrSinCont] = rotary.template load_cos_sin(n_block); + pipeline_k_new.consumer_wait(smem_pipe_read); + rotary.template apply_K_contiguous(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K)); + } + } + // Without this sync I'm getting race condition when seqlen_k is large + cutlass::arch::fence_view_async_shared(); + // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized + // before calling. 
+ cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/); + pipeline_k_new.consumer_release(smem_pipe_read); + // if (thread_idx == 0) { print_tensor(tKpK); printf("\n"); printf("seqlen_limit = %d\n", seqlen_k_new - n_block * kBlockN);} + }; + + auto store_V = [&] (int const n_block, auto const& smem_pipe_read) { + pipeline_v_new.consumer_wait(smem_pipe_read); + int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); + Tensor tVsV_cur = tVsV(_, _, _, smem_pipe_read.index()); + if constexpr (!PagedKVNonTMA) { + Tensor tVgV_cur = tVgV(_, _, _, n_block); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_kv, tVsV_cur, tVgV_cur, tVcV, tVpV, n_limit); + } else { + paged_kv_manager.store_V(n_block, tVsV_cur); + } + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/); + pipeline_v_new.consumer_release(smem_pipe_read); + }; + + #pragma unroll 1 + for (int n_block = n_block_new_max - 1; n_block >= n_block_new_min; --n_block) { + if constexpr (PagedKVNonTMA) { paged_kv_manager.template load_page_table(n_block); } + store_K(n_block, smem_pipe_read); + // if (thread_idx == 0) { printf("Done storing K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + store_V(n_block, smem_pipe_read); + // if (thread_idx == 0) { printf("Done storing V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + ++smem_pipe_read; + } + // if (thread_idx == 0) { printf("After for loop\n"); } + + // Re-signaling the NamedBarrier that we "canceled out" + if constexpr (UseSchedulerBarrier) { + if (flash::canonical_warp_group_idx_nosync() == 1) { + cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, 
static_cast(FwdNamedBarriers::WarpSchedulerWG1) /*id*/); + } + } + + return true; + + } + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/mask.h b/flashmask/flash_mask/flashmask_attention_v3/mask.h new file mode 100644 index 00000000000..fffb4780bc9 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/mask.h @@ -0,0 +1,171 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ + +#pragma once + +#include + +#include "cutlass/fast_math.h" // For cutlass::FastDivmod + +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct Mask { + + static_assert(!(PackGQA && SwapAB), "Cannot be both PackGQA and SwapAB"); + + int const thread_idx; + int const seqlen_q, seqlen_k; + int const window_size_left, window_size_right, sink_token_length; + cutlass::FastDivmod const qhead_per_khead_divmod; + + CUTLASS_DEVICE + Mask(const int thread_idx, const int seqlen_q, const int seqlen_k, + const int window_size_left, const int window_size_right, const int sink_token_length, + cutlass::FastDivmod const &qhead_per_khead_divmod) + : thread_idx(thread_idx) + , seqlen_q(seqlen_q) + , seqlen_k(seqlen_k) + , window_size_left(window_size_left) + , window_size_right(window_size_right) + , sink_token_length(sink_token_length) + , qhead_per_khead_divmod(qhead_per_khead_divmod) + { + }; + + template + CUTLASS_DEVICE + void apply(Tensor &tSrS, const int m_block, const int n_block) const { + static_assert(!(Causal_mask && Local_mask), "Cannot be both causal and local"); + static_assert(Layout::rank == 3, "Only support 3D Tensor"); + if (!Seqlenk_mask && !Causal_mask && !Local_mask) { return; } + + auto thread_mma = TiledMma{}.get_thread_slice(thread_idx); + auto thread0_mma = TiledMma{}.get_thread_slice(_0{}); + + static constexpr int Row = !SwapAB ? 0 : 1, Col = !SwapAB ? 
1 : 0; + + Tensor cS = cute::make_identity_tensor(Shape, Int>{}); + Tensor tScS = thread_mma.partition_C(cS); + Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); + Tensor tScS_rowcol = make_tensor(tScS.data(), flash::convert_layout_acc_rowcol(tScS.layout())); + Tensor t0ScS = thread0_mma.partition_C(cS); + Tensor t0ScS_rowcol = make_tensor(t0ScS.data(), flash::convert_layout_acc_rowcol(t0ScS.layout())); + // We want to use the col indices of thread0 to compare, since that is known at compile time. + // So we subtract the limit by the first col index of this thread (get(tScS_rowcol(_0{}, _0{}))) + int const thread_col_offset = get(tScS_rowcol(_0{}, _0{})); + int const seqlenk_col_limit = seqlen_k - n_block * kBlockN - thread_col_offset; + if constexpr (!Causal_mask && !Local_mask) { + if constexpr (Seqlenk_mask) { // Just masking based on col + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + if (int(get(t0ScS_rowcol(_0{}, n))) >= seqlenk_col_limit) { + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { tSrS_rowcol(m, n) = -INFINITY; } + } + } + } + } else { // mask based on both row and col + if constexpr (!SwapAB) { + // If PackGQA, we split the work of compute divmod among threads in the same row + static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{}); + static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0); + static_assert(!PackGQA || CUTE_STATIC_V(size<0>(tSrS_rowcol)) <= kMmaThreadsPerRow); + int mma_m_idx; + // Might get OOB but it's ok since we'll check it later + if constexpr (PackGQA) { + mma_m_idx = qhead_per_khead_divmod.divide(m_block * kBlockM + get(tScS_rowcol(thread_idx % kMmaThreadsPerRow, _0{}))); + } + int const causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q - thread_col_offset; + if constexpr (Causal_mask) { + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { + int const row_idx = 
!PackGQA + ? get(tScS_rowcol(m, _0{})) + m_block * kBlockM + : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow); + int const col_limit_right = !Seqlenk_mask + ? row_idx + causal_row_offset + : __viaddmin_s32(row_idx, causal_row_offset, seqlenk_col_limit); + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + if (int(get(t0ScS_rowcol(_0{}, n))) >= col_limit_right) { tSrS_rowcol(m, n) = -INFINITY; } + } + } + } else { + int const local_row_offset_right = causal_row_offset + window_size_right; + int const local_row_offset_left = causal_row_offset - 1 - window_size_left; + int const col_limit_sink = sink_token_length - n_block * kBlockN; + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { + int const row_idx = !PackGQA + ? get(tScS_rowcol(m, _0{})) + m_block * kBlockM + : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow); + int const col_limit_right = !Seqlenk_mask + ? row_idx + local_row_offset_right + : __viaddmin_s32(row_idx, local_row_offset_right, seqlenk_col_limit); + int const col_limit_left = row_idx + local_row_offset_left; + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + int const col_idx = int(get(t0ScS_rowcol(m, n))); + if (col_idx >= col_limit_right || (col_idx < col_limit_left && col_idx >= col_limit_sink)) { tSrS_rowcol(m, n) = -INFINITY; } + } + } + } + } else { + int const thread_row_offset = get(tScS_rowcol(_0{}, _0{})); + int const causal_row_offset = seqlenk_col_limit - seqlen_q + m_block * kBlockM + thread_row_offset; + if constexpr (Causal_mask) { + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + int const col0 = int(get(t0ScS_rowcol(_0{}, n))); + // If col0 is beyond the column limit, we want to mask out the entire column, by setting + // row limit to be kBlockM. + int const row_limit_top = col0 >= seqlenk_col_limit ? 
kBlockM : col0 - causal_row_offset; + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { + if (int(get(t0ScS_rowcol(m, _0{}))) < row_limit_top) { tSrS_rowcol(m, n) = -INFINITY; } + } + } + } else { + int const col_limit_sink = sink_token_length - n_block * kBlockN - thread_col_offset; + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + int const col0 = int(get(t0ScS_rowcol(_0{}, n))); + // If col0 is beyond the column limit, we want to mask out the entire column, by setting + // row limit to be kBlockM. + int const row_limit_top = col0 >= seqlenk_col_limit ? kBlockM : col0 - causal_row_offset - window_size_right; + int const row_limit_bot = col0 < col_limit_sink ? kBlockM : col0 - causal_row_offset + window_size_left; + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { + int const row_idx = int(get(t0ScS_rowcol(m, _0{}))); + if (row_idx < row_limit_top || row_idx > row_limit_bot) { tSrS_rowcol(m, n) = -INFINITY; } + } + } + } + } + } + }; + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/named_barrier.hpp b/flashmask/flash_mask/flashmask_attention_v3/named_barrier.hpp new file mode 100644 index 00000000000..8cc8dd28d84 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/named_barrier.hpp @@ -0,0 +1,100 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +#include "cutlass/arch/barrier.h" + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// cutlass::arch::NamedBarrier::sync/arrive are only enabled Sm90 even though they work +// for Sm80 as well. We reimplement them here, enabled for both Sm90 and Sm80. + +CUTLASS_DEVICE +static void named_barrier_sync(uint32_t num_threads, uint32_t barrier_id_) { + static constexpr uint32_t ReservedNamedBarrierCount = static_cast(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier); + uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount; + asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); + cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); +} + +CUTLASS_DEVICE +static void named_barrier_sync(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) { + uint32_t barrier_id = static_cast(reserved_named_barriers); + asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); + cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); +} + +CUTLASS_DEVICE +static void named_barrier_arrive(uint32_t num_threads, uint32_t barrier_id_) { + static constexpr uint32_t ReservedNamedBarrierCount = static_cast(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier); + uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount; + 
cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); + asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); +} + +CUTLASS_DEVICE +static void named_barrier_arrive(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) { + uint32_t barrier_id = static_cast(reserved_named_barriers); + cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); + asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Enumerates the reserved named barriers to avoid potential conflicts + +enum class FwdNamedBarriers { + QueryEmpty = 0, + WarpSchedulerWG1 = 1, + WarpSchedulerWG2 = 2, + WarpSchedulerWG3 = 3, + FlashMaskNBlock = 4, // AppendKV is not used in FlashMask V3 + QueryRotated = 5, + PFull = 6, + TileCountSmemEmpty = 7, + TileCountSmemFull = 8, + ProducerWG = 9, // ProducerWG is only used in Transpose_V, so it is currently not used + TileCountSmemEmptyDual = 9, // HACK: ProducerWG is useless in FlashMask currently, reuse + // HACK: 10 + 6 --> we will simplicity use barrier 0 (syncthreads bar) + // This is not safe to use! 
Since we do not know for sure whether bar.sync 16 is equivalent to bar.sync 0 + // Even if it is, bar.sync 0 might conflicts with __syncthreads, but if there is no __syncthreads + // in the kernel, we will be fine (by far, nearly 20k tests passed) + TileCountSmemFullDual = 10, + NBlockProducer = 4, // HACK: NBlockProducer is only used in PPTScheduler + PEmpty = 3, // HACK: PEmpty is only used when we don't have 3 WGs +}; + +enum class BwdNamedBarriers { + KVEmpty = 0, + PdS = 1, + // This needs to match FwdNamedBarriers::TileCountSmemEmpty since TileScheduler uses it + FlashmaskSmemEmpty = 2, + FlashmaskSmemFull = 3, + dQEmptyWG1 = 4, + dQEmptyWG2 = 5, + dQEmptyWG3 = 6, + dQFullWG1 = 7, + dQFullWG2 = 8, + dQFullWG3 = 9, + FlashmaskProducer = 0, // HACK: KVEmpty sync is unused +}; + +} // flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/pack_gqa.h b/flashmask/flash_mask/flashmask_attention_v3/pack_gqa.h new file mode 100644 index 00000000000..f60383cbb2d --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/pack_gqa.h @@ -0,0 +1,269 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ + +#pragma once + +#include + +#include "cutlass/fast_math.h" // For cutlass::FastDivmod + +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct PackGQAManager { + // We use CpAsync for Q, since TMA doesn't work there + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static constexpr int kGmemElemsPerStore = kGmemElemsPerLoad; + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each + // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. + // In the case of PackGQA, this reduces the number of times we need to call divmod. + static constexpr int kBytePerRow = kHeadDim * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); + // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where + // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. 
+ static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); + using GmemCopyAtomCpAsync = cute::Copy_Atom, Element>; + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + using GmemTiledCopyQCpAsync = decltype( + make_tiled_copy(GmemCopyAtomCpAsync{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per load + + // Was trying to have each WG loading Q to the rows in sQ that only that WG needs so that we only need + // to sync within each WG, but didn't seem to be any faster. + // using GmemLayoutAtomWG = Layout, Int, Int >, + // Stride, _128, _1>>; + // using GmemTiledCopyQCpAsyncWG = decltype( + // make_tiled_copy(GmemCopyAtomCpAsync{}, + // GmemLayoutAtomNew{}, + // Layout>>{})); // Val layout, 8 or 16 vals per load + + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + template + CUTLASS_DEVICE + static auto + compute_ptr(Tensor &tensor, TensorC const &tRows, + cutlass::FastDivmod const &qhead_per_khead_divmod, int const thread_idx, int const m_block) { + // tensor of shape ((qhead_per_khead, seqlen_q)) + static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size(tRows)), NumThreadsPerRow); + using TensorType = typename Engine::value_type; + Tensor tPrPtr = make_tensor(Shape>{}); + #pragma unroll + for (int i = 0; i < NumPtrPerThread; ++i) { + int const row = i * NumThreads + get<0>(tRows(thread_idx % NumThreadsPerRow)); + int const idx = m_block * kBlockM + row; + int m_idx, h_idx; + m_idx = qhead_per_khead_divmod.divmod(h_idx, idx); + tPrPtr[i] = &tensor(make_coord(make_coord(h_idx, m_idx))); + } + return tPrPtr; + } + + + template + CUTLASS_DEVICE + static void + load_Q(TensormQ const &mQ, // ((qhead_per_khead, seqlen_q), headdim) + TensorsQ &sQ, // (kBlockM, kHeadDim) + cutlass::FastDivmod const &qhead_per_khead_divmod, + int const thread_idx, int const 
seqlen_q, int const m_block + ) + { + GmemTiledCopyQCpAsync gmem_tiled_copy_Q_cp_async; + // GmemTiledCopyQCpAsyncNew gmem_tiled_copy_Q_cp_async; + auto gmem_thr_copy_Q_cp_async = gmem_tiled_copy_Q_cp_async.get_thread_slice(thread_idx); + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tQcQ = gmem_thr_copy_Q_cp_async.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tQsQ = gmem_thr_copy_Q_cp_async.partition_D(sQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + // Tensor tQcQ_ = gmem_thr_copy_Q_cp_async.partition_S(cute::flat_divide(cQ, _64{})); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + // Tensor tQsQ_ = gmem_thr_copy_Q_cp_async.partition_D(cute::flat_divide(sQ, _64{})); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + // Tensor tQcQ = group_modes<1, rank(tQcQ_) - 1>(tQcQ_); + // Tensor tQsQ = group_modes<1, rank(tQsQ_) - 1>(tQsQ_); + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < size<1>(mQ); } + + // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for Q. + // We split the work among threads loading the same row of Q, then __shfl_sync the pointers. 
+ Tensor mQ_0 = mQ(_, _0{}); + Tensor tQcQ_row = tQcQ(_0{}, _, _0{}); + Tensor tPrQPtr = compute_ptr(mQ_0, tQcQ_row, qhead_per_khead_divmod, thread_idx, m_block); + int const qhead_per_khead = qhead_per_khead_divmod.divisor; + #pragma unroll + for (int m = 0; m < size<1>(tQsQ); ++m) { + int idx = m_block * kBlockM + get<0>(tQcQ(_0{}, m, _0{})); + Element const* q_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrQPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + if (idx < seqlen_q * qhead_per_khead) { + // if (thread_idx == 0) { printf("m: %d, m_idx: %d, h_idx: %d, q_ptr = %p, q_ptr_og = %p\n", m, m_idx, h_idx, q_ptr, &mQ_copy(0, make_coord(h_idx, m_idx), 0));} + Tensor mQ_cur = make_tensor(make_gmem_ptr(q_ptr), Shape>{}); + Tensor mQ_cur_copy = cute::tiled_divide(mQ_cur, Shape>{}); + #pragma unroll + for (int k = 0; k < size<2>(tQsQ); ++k) { + int ki = get<1>(tQcQ(_0{}, _0{}, k)) / kGmemElemsPerLoad; + // the "tiled_copy.with(tQpQ(k))"" will fill in zero for columns where tQpQ(k) is false + // TODO: check this + cute::copy(gmem_tiled_copy_Q_cp_async.with(tQpQ(k)), mQ_cur_copy(_, ki), tQsQ(_, m, k)); + } + } // Don't need to fill in 0s for sQ since we're not gonna write the output to gmem for those rows + } + }; + + template + CUTLASS_DEVICE + static void + store_LSE(TensormLSE &mLSE, // ((qhead_per_khead, seqlen_q)) + TensorsLSE const &tLSErLSE, // (kBlockM) split across threads according to tiled_mma + TiledMma tiled_mma, + cutlass::FastDivmod const &qhead_per_khead_divmod, + int const thread_idx, int const seqlen_o, int const m_block + ) + { + Tensor caccO = cute::make_identity_tensor(Shape, Int>{}); + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + Tensor taccOcO_row = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()))(_, _0{}); + CUTE_STATIC_ASSERT_V(size(tLSErLSE) == size(taccOcO_row)); // MMA_M + + // If 
PackGQA, we split the work of compute divmod among threads in the same row + static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{}); + static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0); + static_assert(CUTE_STATIC_V(size(tLSErLSE)) <= kMmaThreadsPerRow); + static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow); + + Tensor tPrLSEPtr = compute_ptr(mLSE, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block); + static_assert(CUTE_STATIC_V(size(tPrLSEPtr)) == 1); + int const qhead_per_khead = qhead_per_khead_divmod.divisor; + #pragma unroll + for (int mi = 0; mi < size(tLSErLSE); ++mi) { + int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); + float* ptr_LSE_cur = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrLSEPtr[0]), mi % kMmaThreadsPerRow, kMmaThreadsPerRow)); + if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o * qhead_per_khead) { + *ptr_LSE_cur = tLSErLSE(mi); + } + } + }; + + template + CUTLASS_DEVICE + static void + store_O(TensormO &mO, // ((qhead_per_khead, seqlen_o), headdim) + TensorrO const &tOrO, // (kBlockM, kHeadDim) split across threads according to gmem_tiled_copy_O + cutlass::FastDivmod const &qhead_per_khead_divmod, + int const thread_idx, int const seqlen_o, int const m_block + ) + { + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor cO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < size<1>(mO); } + + // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for O. + // We split the work among threads loading the same row of O, then __shfl_sync the pointers. 
+ Tensor mO_0 = mO(_, _0{}); + Tensor tOcO_row = tOcO(_0{}, _, _0{}); + Tensor tPrOPtr = compute_ptr(mO_0, tOcO_row, qhead_per_khead_divmod, thread_idx, m_block); + int const qhead_per_khead = qhead_per_khead_divmod.divisor; + #pragma unroll + for (int m = 0; m < size<1>(tOrO); ++m) { + int idx = m_block * kBlockM + get<0>(tOcO(_0{}, m, _0{})); + Element* o_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrOPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + if (idx < seqlen_o * qhead_per_khead) { + Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape>{}); + Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape>{}); + #pragma unroll + for (int k = 0; k < size<2>(tOrO); ++k) { + int ki = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerStore; + if (tOpO(k)) { + cute::copy(gmem_tiled_copy_O, tOrO(_, m, k), mO_cur_copy(_, ki)); + } + } + } + } + }; + + template + CUTLASS_DEVICE + static void + store_O_direct(TensormO &mO, // ((qhead_per_khead, seqlen_o), headdim) + TensorrO const &tOrO, // (kBlockM, kHeadDim) split across threads according to tiled_mma + TiledMma tiled_mma, + cutlass::FastDivmod const &qhead_per_khead_divmod, + int const thread_idx, int const seqlen_o, int const m_block + ) + { + static constexpr int kGmemElemsPerStoreDirect = 2; + cute::Copy_Atom, Element> gmem_copy_direct; + // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); + Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int>{}); + + Tensor caccO = cute::make_identity_tensor(Shape, Int>{}); + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + Tensor taccOcO_col = 
taccOcO_rowcol(_0{}, _); + + // If PackGQA, we split the work of compute divmod among threads in the same row + static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{}); + static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0); + static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow); + + // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for O. + // We split the work among threads loading the same row of O, then __shfl_sync the pointers. + Tensor mO_0 = mO(_, _0{}); + Tensor tPrOPtr = compute_ptr(mO_0, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block); + static_assert(CUTE_STATIC_V(size(tPrOPtr)) == 1); + + int const qhead_per_khead = qhead_per_khead_divmod.divisor; + #pragma unroll + for (int m = 0; m < size<1>(tOrO_copy); ++m) { + int row = m_block * kBlockM + get<0>(taccOcO_row(m)); + Element* o_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrOPtr[0]), m % kMmaThreadsPerRow, kMmaThreadsPerRow)); + if (row < seqlen_o * qhead_per_khead) { + Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape>{}); + Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape>{}); + #pragma unroll + for (int k = 0; k < size<2>(tOrO_copy); ++k) { + int col = get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)); + if (col < size<1>(mO)) { + cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), mO_cur_copy(_, col / kGmemElemsPerStoreDirect)); + } + } + } + } + }; + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/paged_kv.h b/flashmask/flash_mask/flashmask_attention_v3/paged_kv.h new file mode 100644 index 00000000000..0ed83615ed0 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/paged_kv.h @@ -0,0 +1,368 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. 
+ * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +#include + +#include "cutlass/fast_math.h" // For cutlass::FastDivmod + +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct PagedKVManager { + // If KV_Same_Iter=false, then we do load_page_table(0), load_K(0), load_page_table(1), load_K(1), load_V(0), + // load_page_table(2), load_K(2), load_V(1), etc. + // So we need to compute the V pointers for the previous iteration. + + // LoadsPerRow_LB is the lower bound on number of loads per row in the K direction. This is useful for + // rotary where we want each thread to have at least 2 loads per row. + + static constexpr bool SameHeadDim = (kHeadDim == kHeadDimV); + static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); + + // We use CpAsync for K and V if PagedKV, since TMA doesn't work there + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each + // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. 
+ // In the case of PackGQA, this reduces the number of times we need to call divmod. + static_assert(kHeadDimGCD % LoadsPerRow_LB == 0, "Headdim and HeaddimV must be a multiple of LoadsPerRow_LB"); + static constexpr int kBytePerRow = kHeadDimGCD / LoadsPerRow_LB * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); + // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where + // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. + static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); + using GmemCopyAtomCpAsync = cute::Copy_Atom, Element>; + using GmemLayoutAtomKVCpAsync = Layout, Int>, + Stride, _1>>; + using GmemTiledCopyKVCpAsync = decltype( + make_tiled_copy(GmemCopyAtomCpAsync{}, + GmemLayoutAtomKVCpAsync{}, + Layout>>{})); // Val layout, 8 or 16 vals per load + using GmemTiledCopyKVStore = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomKVCpAsync{}, + Layout>>{})); // Val layout, 8 or 16 vals per load + + using ShapeKV = cute::Shape; // (seqlen, d, head, batch) + using StrideKV = cute::Stride; + using ShapePageTable = cute::Shape; // (batch, max_num_pages_per_seq) + using StridePageTable = cute::Stride; + + using TensorPageTable = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapePageTable{}, StridePageTable{})(int(0), _)); + using TensorKV = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeKV{}, StrideKV{})(_, _, int(0), _)); + using GmemThrCopyKVCpAsync = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0))); + using TensortKcK = 
decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); + using TensortKpK = decltype(make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{})); + using TensortVcV = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); + using TensortVpV = decltype(make_tensor(make_shape(size<1>(TensortVcV{}), size<2>(TensortVcV{})), Stride<_0, _1>{})); + + // For PagedKV, it's expensive the calculate the pointers to K and V for each page table entry, + // since those require int64_t arithmetic. We optimize by having threads split this work. + // Typically there are 8 threads loading per row (e.g. hdim 64 and 128), and there are 11 rows + // that each thread needs to load for the case of hdim 128 and kBlockN = 176. + // So each of those 8 threads will calculate the K_ptr and V_ptr for 11 / 8 = 2 rows. + // We then use __shfl_sync to broadcast the pointers to the other threads in the warp. 
+ static_assert(CUTE_STATIC_V(size<1>(TensortKcK{})) == CUTE_STATIC_V(size<1>(TensortVcV{}))); + static constexpr int kPageEntryPerThread = cute::ceil_div(size<1>(TensortKcK{}), kGmemThreadsPerRow); + using TensorPageOffset = decltype(make_tensor>(Shape>{})); + using TensorKVPtr = decltype(make_tensor(Shape>{})); + + GmemTiledCopyKVCpAsync gmem_tiled_copy_kv; + cutlass::FastDivmod const &page_size_divmod; + cutlass::FastDivmod const &blockN_per_page_size_divmod; + int const thread_idx; + int const seqlen_k; + int const leftpad_k; + int const* const ptr_page_table; + GmemThrCopyKVCpAsync const gmem_thr_copy_kv; + TensorPageTable mPageTable; + TensorKV mK_paged, mV_paged; + TensortKpK tKpK; + TensortVpV tVpV; + TensorPageOffset tPrPageOffset; + TensorKVPtr tPrVPtr; + int bidb_kv_idx, bidb_kv_idx_prev, n_block_idx, n_block_idx_prev; // Only used for TMA + + CUTLASS_DEVICE + PagedKVManager(int const* const ptr_page_table_, + ShapePageTable const &shape_pagetable, StridePageTable const &stride_pagetable, + Element* const ptr_K, ShapeKV const &shape_K, StrideKV const &stride_K, + Element* const ptr_V, int const headdim_v, StrideKV const &stride_V, + cutlass::FastDivmod const &page_size_divmod, + cutlass::FastDivmod const &blockN_per_page_size_divmod, + int const bidb, int const bidh, int const thread_idx, int const seqlen_k, int const leftpad_k, + int bidb_kv_idx + ) + : page_size_divmod(page_size_divmod) + , blockN_per_page_size_divmod(blockN_per_page_size_divmod) + , thread_idx(thread_idx) + , seqlen_k(seqlen_k) + , leftpad_k(leftpad_k) + , ptr_page_table(ptr_page_table_) + , gmem_thr_copy_kv(gmem_tiled_copy_kv.get_thread_slice(thread_idx)) + , bidb_kv_idx(bidb_kv_idx) + , bidb_kv_idx_prev(bidb_kv_idx) + + { + mPageTable = make_tensor(make_gmem_ptr(ptr_page_table), shape_pagetable, stride_pagetable)(bidb, _); + mK_paged = make_tensor(make_gmem_ptr(ptr_K), shape_K, stride_K)(_, _, bidh, _); + auto shape_V = make_shape(get<0>(shape_K), headdim_v, get<2>(shape_K), 
get<3>(shape_K)); + mV_paged = make_tensor(make_gmem_ptr(ptr_V), shape_V, stride_V)(_, _, bidh, _); + tKpK = make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{}); + Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); + #pragma unroll + for (int k = 0; k < size<1>(tKpK); ++k) { tKpK(_0{}, k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(shape_K); } + Tensor tVpV_ = make_tensor(make_shape(size<1>(TensortVcV{}), size<2>(TensortVcV{})), Stride<_0, _1>{}); + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); + #pragma unroll + for (int k = 0; k < size<1>(tVpV_); ++k) { tVpV_(_0{}, k) = get<1>(tVcV(_0{}, _0{}, k)) < get<1>(shape_V); } + tVpV = cute::conditional_return(tKpK, tVpV_); + }; + + template + CUTLASS_DEVICE + void load_page_table(const int n_block) { + // The uncoalesced gmem load is intentional. This is so that each thread only loads the page table entries + // it needs, and we don't need any sync between warps. + // Assuming 8 threads per row, and 176 rows, then the rows from 0 to 175 are loaded by + // threads 0, 8, 16, ..., 120, 1, 9, ..., 121, 2, 10, ..., 122, etc. + #pragma unroll + for (int i = 0; i < kPageEntryPerThread; ++i) { + int const row = i * NumThreads + (thread_idx % kGmemThreadsPerRow) * (NumThreads / kGmemThreadsPerRow) + (thread_idx / kGmemThreadsPerRow); + int const row_idx = n_block * kBlockN + row; + int page_idx, page_offset; + page_idx = page_size_divmod.divmod(page_offset, row_idx + leftpad_k); + // Add the condition (i + 1) * NumThreads <= kBlockN since that is an upper bound of row + // and is known at compile time. It avoids branching when e.g., kBlockN = 176 and i = 0. + int const page = ((i + 1) * NumThreads <= kBlockN || row < kBlockN) && (!Seqlenk_mask || row_idx < seqlen_k) ? 
mPageTable[page_idx] : 0; + tPrPageOffset[i] = {page, page_offset}; + // if (cute::thread0()) { printf("row = %d, page_idx = %d, page_offset = %d, page = %d, leftpad_k = %d, seqlen_k = %d\n", row, page_idx, page_offset, page, leftpad_k, seqlen_k); } + } + if constexpr (First_iter && !KV_Same_Iter) { compute_V_ptr(); } + }; + + template + CUTLASS_DEVICE + void load_page_table_TMA(const int n_block) { + // We require that page size is a multiple of kBlockN, and there's no leftpad_k + if (ptr_page_table) { + bidb_kv_idx = mPageTable[blockN_per_page_size_divmod.divmod(n_block_idx, n_block)]; + } else { + n_block_idx = n_block; + } + if constexpr (First_iter && !KV_Same_Iter) { + bidb_kv_idx_prev = bidb_kv_idx; + n_block_idx_prev = n_block_idx; + } + }; + + CUTLASS_DEVICE + cute::tuple get_indices_for_K_TMA() { + return {n_block_idx, bidb_kv_idx}; + }; + + CUTLASS_DEVICE + cute::tuple get_indices_for_V_TMA() { + if constexpr (KV_Same_Iter) { + return {n_block_idx, bidb_kv_idx}; + } else { + cute::tuple const indices = {n_block_idx_prev, bidb_kv_idx_prev}; + bidb_kv_idx_prev = bidb_kv_idx; + n_block_idx_prev = n_block_idx; + return indices; + } + }; + + CUTLASS_DEVICE + TensorKVPtr compute_K_ptr() { + Tensor tPrKPtr = make_tensor(Shape>{}); + #pragma unroll + for (int i = 0; i < kPageEntryPerThread; ++i) { + auto [page, page_offset] = tPrPageOffset[i]; + tPrKPtr[i] = &mK_paged(page_offset, _0{}, page); + } + return tPrKPtr; + }; + + CUTLASS_DEVICE + void compute_V_ptr() { + #pragma unroll + for (int i = 0; i < kPageEntryPerThread; ++i) { + auto [page, page_offset] = tPrPageOffset[i]; + tPrVPtr[i] = &mV_paged(page_offset, _0{}, page); + } + }; + + template + CUTLASS_DEVICE + void load_K(const int n_block, TensorK &&sK) { + // Do we need bound check to make sure the row doesn't go above kBlockN + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtomKVCpAsync{})) == 0; + + Tensor tPrKPtr = compute_K_ptr(); + + // Only for index calculation, since all 
the indices of thread 0 are known at compile time + auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); + Tensor tKsK = gmem_thr_copy_kv.partition_D(sK); + Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); + Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); + + // We want to use the row indices of thread0 to compare, since that is known at compile time. + // So we subtract the limit by the first row index of this thread (get<0>(tKcK(_0{}, _0{}, _0{}))) + int const seqlenk_row_limit = -int(get<0>(tKcK(_0{}, _0{}, _0{}))) + (EvenN + ? seqlen_k - n_block * kBlockN + : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k - n_block * kBlockN, kBlockN))); + #pragma unroll + for (int m = 0; m < size<1>(tKsK); ++m) { + bool const should_load = EvenN + ? (!Seqlenk_mask || get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit) + : get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; + Element const* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); + Tensor mK_paged_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); + Tensor mK_paged_cur_copy = cute::tiled_divide(mK_paged_cur, Shape>{}); + if (should_load) { + #pragma unroll + for (int k = 0; k < size<2>(tKsK); ++k) { + int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; + cute::copy(gmem_tiled_copy_kv.with(tKpK(_0{}, k)), mK_paged_cur_copy(_, ki), tKsK(_, m, k)); + } + } // Don't need to clear out the rest of the smem since we'll mask out the scores anyway + } + }; + + template + CUTLASS_DEVICE + void load_V(const int n_block, TensorV &&sV) { + // Do we need bound check to make sure the row doesn't go above kBlockN + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtomKVCpAsync{})) == 0; + + if constexpr (KV_Same_Iter) { compute_V_ptr(); } + // Only for 
index calculation, since all the indices of thread 0 are known at compile time + auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); + Tensor tVsV = gmem_thr_copy_kv.partition_D(sV); + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); + Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV); + + int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tVcV(_0{}, _0{}, _0{})); + #pragma unroll + for (int m = 0; m < size<1>(tVsV); ++m) { + // Faster to rely on the cp.async to clear smem that are out of bound, + // rather than calling cute::clear directly. + // We have to be careful not to write to smem past `kBlockN` if !EvenN. + // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to checked + if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tVcV(_0{}, m, _0{})) < kBlockN) { + bool const should_load = !Seqlenk_mask || get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit; + Element const* v_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); + Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{}); + #pragma unroll + for (int k = 0; k < size<2>(tVsV); ++k) { + int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad; + cute::copy(gmem_tiled_copy_kv.with(tVpV(_0{}, k) && should_load), mV_paged_cur_copy(_, ki), tVsV(_, m, k)); + } + } + } + if constexpr (!KV_Same_Iter) { compute_V_ptr(); } + }; + + template + CUTLASS_DEVICE + void store_K(const int n_block, TensorK &&tKrK) { + Tensor tPrKPtr = compute_K_ptr(); + // We're using the same partitioning as GmemTiledCopyKVCpAsync (used for loading) + // Only for index calculation, since all the indices of thread 0 are known at compile time + auto gmem_thr0_copy_kv = 
gmem_tiled_copy_kv.get_thread_slice(_0{}); + Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); + Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); + + GmemTiledCopyKVStore gmem_tiled_copy_kv_store; + // We want to use the row indices of thread0 to compare, since that is known at compile time. + // So we subtract the limit by the first row index of this thread (get<0>(tKcK(_0{}, _0{}, _0{}))) + // int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKcK(_0{}, _0{}, _0{})); + int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tKcK(_0{}, _0{}, _0{})); + // if (threadIdx.x == 128) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_k = %d, seqlenk_row_limit = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_k, seqlenk_row_limit); } + #pragma unroll + for (int m = 0; m < size<1>(tKrK); ++m) { + bool const should_load = get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; + Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); + Tensor mK_paged_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); + Tensor mK_paged_cur_copy = cute::tiled_divide(mK_paged_cur, Shape>{}); + if (should_load) { + #pragma unroll + for (int k = 0; k < size<2>(tKrK); ++k) { + int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; + if (tKpK(_0{}, k)) { + cute::copy(gmem_tiled_copy_kv_store, tKrK(_, m, k), mK_paged_cur_copy(_, ki)); + } + } + } + } + }; + + template + CUTLASS_DEVICE + void store_V(const int n_block, TensorV &&tVrV) { + if constexpr (KV_Same_Iter) { compute_V_ptr(); } + // Only for index calculation, since all the indices of thread 0 are known at compile time + auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // 
(BLK_N,BLK_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); + Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV); + + GmemTiledCopyKVStore gmem_tiled_copy_kv_store; + int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tVcV(_0{}, _0{}, _0{})); + #pragma unroll + for (int m = 0; m < size<1>(tVrV); ++m) { + bool const should_load = get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit; + Element* v_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); + Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{}); + if (should_load) { + #pragma unroll + for (int k = 0; k < size<2>(tVrV); ++k) { + int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad; + if (tVpV(_0{}, k)) { + cute::copy(gmem_tiled_copy_kv_store, tVrV(_, m, k), mV_paged_cur_copy(_, ki)); + } + } + } + } + if constexpr (!KV_Same_Iter) { compute_V_ptr(); } + }; + + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/print_val.cu b/flashmask/flash_mask/flashmask_attention_v3/print_val.cu new file mode 100644 index 00000000000..3564d50f304 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/print_val.cu @@ -0,0 +1,43 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#include +#include "utils.h" + +namespace flash{ + __global__ void print_addr_value(int* base, size_t offset_bytes) { + int* ptr = (int*)((char*)base + offset_bytes); + printf("Value at address %p: %d\n", ptr, *ptr); + } + + __global__ void print_addr_value_ordered(int* base, size_t start_offset_bytes, int count) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int total_threads = gridDim.x * blockDim.x; + + // 按线程ID顺序打印,避免输出混乱 + for (int current_thread = 0; current_thread < total_threads; current_thread++) { + if (tid == current_thread && tid < count) { + size_t offset_bytes = start_offset_bytes + tid * sizeof(int); + int* ptr = (int*)((char*)base + offset_bytes); + printf("Thread %d - Value at address %p (offset %zu): %d\n", + tid, ptr, offset_bytes, *ptr); + } + __syncthreads(); // 同步保证顺序 + } +} +} \ No newline at end of file diff --git a/flashmask/flash_mask/flashmask_attention_v3/rotary.h b/flashmask/flash_mask/flashmask_attention_v3/rotary.h new file mode 100644 index 00000000000..5a5dc8886b9 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/rotary.h @@ -0,0 +1,503 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void +apply_rotary_interleaved(Tensor &rK, + Tensor const &rCos, + Tensor const &rSin) { + CUTE_STATIC_ASSERT_V(rank(rK) == _1{}); + CUTE_STATIC_ASSERT_V(rank(rCos) == _1{}); + CUTE_STATIC_ASSERT_V(rank(rSin) == _1{}); + CUTE_STATIC_ASSERT_V(size<0>(rCos) == size<0>(rSin)); + static_assert(decltype(size<0>(rK))::value == decltype(size<0>(rCos))::value * 2); + static_assert(decltype(size<0>(rCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor K_fp32 = make_tensor_like(rK); + convert_type_out(rK, K_fp32); + Tensor cos_fp32 = make_tensor_like(rCos); + convert_type_out(rCos, cos_fp32); + Tensor sin_fp32 = make_tensor_like(rSin); + convert_type_out(rSin, sin_fp32); + #pragma unroll + for (int i = 0; i < size<0>(K_fp32) / 2; ++i) { + float real = K_fp32[2 * i] * cos_fp32[i] - K_fp32[2 * i + 1] * sin_fp32[i]; + float imag = K_fp32[2 * i] * sin_fp32[i] + K_fp32[2 * i + 1] * cos_fp32[i]; + K_fp32[2 * i] = real; + K_fp32[2 * i + 1] = imag; + } + convert_type_out(K_fp32, rK); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+template +CUTLASS_DEVICE void +apply_rotary_contiguous(Tensor &rK_left, + Tensor &rK_right, + Tensor const &rCos, + Tensor const &rSin) { + CUTE_STATIC_ASSERT_V(rank(rK_left) == _1{}); + CUTE_STATIC_ASSERT_V(rank(rK_right) == _1{}); + CUTE_STATIC_ASSERT_V(rank(rCos) == _1{}); + CUTE_STATIC_ASSERT_V(rank(rSin) == _1{}); + CUTE_STATIC_ASSERT_V(size<0>(rK_left) == size<0>(rK_right)); + CUTE_STATIC_ASSERT_V(size<0>(rK_left) == size<0>(rCos)); + CUTE_STATIC_ASSERT_V(size<0>(rCos) == size<0>(rSin)); + static_assert(decltype(size<0>(rCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor K_left_fp32 = make_tensor_like(rK_left); + convert_type_out(rK_left, K_left_fp32); + Tensor K_right_fp32 = make_tensor_like(rK_right); + convert_type_out(rK_right, K_right_fp32); + Tensor cos_fp32 = make_tensor_like(rCos); + convert_type_out(rCos, cos_fp32); + Tensor sin_fp32 = make_tensor_like(rSin); + convert_type_out(rSin, sin_fp32); + #pragma unroll + for (int i = 0; i < size<0>(K_left_fp32); ++i) { + float real = K_left_fp32[i] * cos_fp32[i] - K_right_fp32[i] * sin_fp32[i]; + float imag = K_left_fp32[i] * sin_fp32[i] + K_right_fp32[i] * cos_fp32[i]; + K_left_fp32[i] = real; + K_right_fp32[i] = imag; + } + convert_type_out(K_left_fp32, rK_left); + convert_type_out(K_right_fp32, rK_right); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rotary { + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each + // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. 
+ // We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved + // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will + // load twice from the same row. + static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); + // We assume threads loading the same row are in the same warp. + static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); + + using LayoutAtom = Layout, Int>, + Stride, _1>>; + using TiledCopyQK = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + LayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + using GmemTiledCopyRotary = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + LayoutAtom{}, + Layout>>{})); // Val layout, 4 or 8 vals per store + using GmemTiledCopyRotaryCont = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + LayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + using ShapeRotary = cute::Shape; // (seqlen_ro, rotary_dim // 2) + using StrideRotary = cute::Stride; + + using GmemThrCopyRotary = decltype(GmemTiledCopyRotary{}.get_thread_slice(int(0))); + using GmemThrCopyRotaryCont = decltype(GmemTiledCopyRotaryCont{}.get_thread_slice(int(0))); + using TensortRcR = decltype(GmemTiledCopyRotary{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); + using TensortRpR = decltype(make_tensor(make_shape(size<2>(TensortRcR{})))); + using TensortRcRCont = decltype(GmemTiledCopyRotaryCont{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); + using TensortRpRCont = 
decltype(make_tensor(make_shape(size<2>(TensortRcRCont{})))); + using TensormR = decltype(make_tensor( + make_gmem_ptr((Element const*)nullptr), + ShapeRotary{}, + make_stride(cute::conditional_return(_0{}, int64_t(0)), _1{}))); + using TensortRgR = decltype( + GmemTiledCopyRotary{}.get_thread_slice(int(0)).partition_S(make_tensor( + make_gmem_ptr((Element const*)nullptr), + make_shape(Int{}, Int{}, int(0)), + make_stride(cute::conditional_return(_0{}, int64_t(0)), _1{}, cute::conditional_return(_0{}, int64_t(0)))))); + using TensortRgRCont = decltype( + GmemTiledCopyRotaryCont{}.get_thread_slice(int(0)).partition_S(make_tensor( + make_gmem_ptr((Element const*)nullptr), + make_shape(Int{}, Int{}, int(0)), + make_stride(cute::conditional_return(_0{}, int64_t(0)), _1{}, cute::conditional_return(_0{}, int64_t(0)))))); + + GmemTiledCopyRotary gmem_tiled_copy_rotary; + GmemTiledCopyRotaryCont gmem_tiled_copy_rotary_cont; + bool const is_rotary_interleaved; + int const rotary_dim; + int const thread_idx; + int const max_seqlen; + GmemThrCopyRotary const gmem_thr_copy_rotary; + GmemThrCopyRotaryCont const gmem_thr_copy_rotary_cont; + TensortRpR tRpR; + TensortRpRCont tRpRCont; + TensormR mCos, mSin; + TensortRgR tRgCos, tRgSin; + TensortRgRCont tRgCosCont, tRgSinCont; + + CUTLASS_DEVICE + Rotary(Element const* const ptr_rotary_cos, ShapeRotary const &shape_rotary, StrideRotary const &stride_rotary_cos_, + Element const* const ptr_rotary_sin, StrideRotary const &stride_rotary_sin_, + bool const is_rotary_interleaved, int const thread_idx, int const max_seqlen, int const start_idx) + : is_rotary_interleaved(is_rotary_interleaved) + , rotary_dim(get<1>(shape_rotary) * 2) + , thread_idx(thread_idx) + , max_seqlen(max_seqlen) + , gmem_thr_copy_rotary(gmem_tiled_copy_rotary.get_thread_slice(thread_idx)) + , gmem_thr_copy_rotary_cont(gmem_tiled_copy_rotary_cont.get_thread_slice(thread_idx)) + + { + auto stride_rotary_cos = 
make_stride(cute::conditional_return(get<0>(stride_rotary_cos_), _0{}), get<1>(stride_rotary_cos_)); + auto stride_rotary_sin = make_stride(cute::conditional_return(get<0>(stride_rotary_sin_), _0{}), get<1>(stride_rotary_sin_)); + mCos = make_tensor(make_gmem_ptr(ptr_rotary_cos + start_idx * get<0>(stride_rotary_cos_)), shape_rotary, stride_rotary_cos); + mSin = make_tensor(make_gmem_ptr(ptr_rotary_sin + start_idx * get<0>(stride_rotary_sin_)), shape_rotary, stride_rotary_sin); + Tensor gCos = local_tile(mCos, Shape, Int>{}, make_coord(_, _0{})); // (MN, K / 2, _) + Tensor gSin = local_tile(mSin, Shape, Int>{}, make_coord(_, _0{})); // (MN, K / 2, _) + tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCos); + tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSin); + Tensor cR = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K / 2) + Tensor tRcR = gmem_thr_copy_rotary.partition_D(cR); + tRpR = make_tensor(make_shape(size<2>(tRcR))); + #pragma unroll + for (int k = 0; k < size(tRpR); ++k) { tRpR(k) = get<1>(tRcR(_0{}, _0{}, k)) < get<1>(shape_rotary); } + Tensor tRcRCont = gmem_thr_copy_rotary_cont.partition_D(cR); + tRpRCont = make_tensor(make_shape(size<2>(tRcRCont))); + #pragma unroll + for (int k = 0; k < size(tRpRCont); ++k) { tRpRCont(k) = get<1>(tRcRCont(_0{}, _0{}, k)) < get<1>(shape_rotary); } + }; + + template + CUTLASS_DEVICE + auto load_cos_sin(int const block) { + using GmemTiledCopyRo = std::conditional_t; + auto gmem_thr_copy_ro = cute::conditional_return(gmem_thr_copy_rotary, gmem_thr_copy_rotary_cont); + Tensor tRpRCur = cute::conditional_return(tRpR, tRpRCont); + Tensor tRgCosCur = cute::conditional_return(tRgCos, tRgCosCont)(_, _, _, block); + Tensor tRgSinCur = cute::conditional_return(tRgSin, tRgSinCont)(_, _, _, block); + // make_tensor_like, not make_fragment_like. 
If the row_stride is _0{} we want to keep it that way + Tensor tRrCos = make_tensor_like(tRgCosCur); + Tensor tRrSin = make_tensor_like(tRgSinCur); + Tensor cR = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K / 2) + Tensor tRcR = gmem_thr_copy_ro.partition_D(cR); + // If FixedPosition, only copy the first row as we only need the cos/sin for position cache_seqlens + #pragma unroll + for (int m = 0; m < (!FixedPosition ? size<1>(tRrCos) : 1); ++m) { + if (get<0>(tRcR(_0{}, m, _0{})) < std::min(max_seqlen - block * kBlockMN, kBlockMN)) { + #pragma unroll + for (int k = 0; k < size<2>(tRrCos); ++k) { + if (tRpRCur(k)) { + cute::copy(GmemTiledCopyRo{}, tRgCosCur(_, m, k), tRrCos(_, m, k)); + cute::copy(GmemTiledCopyRo{}, tRgSinCur(_, m, k), tRrSin(_, m, k)); + } + } + } + } + return cute::make_tuple(tRrCos, tRrSin);; + } + + template + CUTLASS_DEVICE + auto load_cos_sin_packgqa(int const block, cutlass::FastDivmod const &qhead_per_khead_divmod) { + static constexpr int kGmemElemsPerLoadCur = kInterleaved ? kGmemElemsPerLoad / 2 : kGmemElemsPerLoad; + using GmemTiledCopyRo = std::conditional_t; + auto gmem_thr_copy_ro = cute::conditional_return(gmem_thr_copy_rotary, gmem_thr_copy_rotary_cont); + Tensor tRpRCur = cute::conditional_return(tRpR, tRpRCont); + // make_tensor_like, not make_fragment_like. If the row_stride is _0{} we want to keep it that way + Tensor tRrCos = make_tensor_like(cute::conditional_return(tRgCos, tRgCosCont)(_, _, _, _0{})); + Tensor tRrSin = make_tensor_like(cute::conditional_return(tRgSin, tRgSinCont)(_, _, _, _0{})); + int const qhead_per_khead = qhead_per_khead_divmod.divisor; + Tensor cR = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K / 2) + Tensor tRcR = gmem_thr_copy_ro.partition_D(cR); + + // The main bottleneck here is actually instruction cache misses. + + // Similar to PagedKVNonTMA, it's expensive to compute the pointers. + // We split the work among threads loading the same row, then __shfl_sync the pointers. 
+ static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size<1>(tRrCos)), kGmemThreadsPerRow); + Tensor tPrCosPtr = make_tensor(Shape>{}); + Tensor tPrSinPtr = make_tensor(Shape>{}); + #pragma unroll + for (int i = 0; i < NumPtrPerThread; ++i) { + int const row = i * NumThreads + get<0>(tRcR(_0{}, thread_idx % kGmemThreadsPerRow, _0{})); + int const idx = block * kBlockMN + row; + int row_actual = qhead_per_khead_divmod.divide(idx); + tPrCosPtr[i] = &mCos(row_actual, _0{}); + tPrSinPtr[i] = &mSin(row_actual, _0{}); + } + + #pragma unroll + for (int m = 0; m < (!FixedPosition ? size<1>(tRgCos) : 1); ++m) { + int const idx = block * kBlockMN + get<0>(tRcR(_0{}, m, _0{})); + Element const* cos_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrCosPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + Element const* sin_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrSinPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + if (idx < max_seqlen * qhead_per_khead) { + Tensor mCos_copy = cute::tiled_divide(make_tensor(make_gmem_ptr(cos_ptr), Shape>{}), + Shape>{}); + Tensor mSin_copy = cute::tiled_divide(make_tensor(make_gmem_ptr(sin_ptr), Shape>{}), + Shape>{}); + #pragma unroll + for (int k = 0; k < size<2>(tRgCos); ++k) { + int const ki = get<1>(tRcR(_0{}, _0{}, k)) / (kGmemElemsPerLoadCur); + if (tRpRCur(k)) { + cute::copy(GmemTiledCopyRo{}, mCos_copy(_, ki), tRrCos(_, m, k)); + cute::copy(GmemTiledCopyRo{}, mSin_copy(_, ki), tRrSin(_, m, k)); + } + } + } + } + return cute::make_tuple(tRrCos, tRrSin); + } + + template + CUTLASS_DEVICE + void + apply_Q_interleaved(TensorsQ &sQ, // (kBlockM, kHeadDim) + TensortRrR const &tRrCos, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotary + TensortRrR const &tRrSin, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotary + int const m_block, int const qhead_per_khead=1) + { + TiledCopyQK tiled_copy_q; + auto 
gmem_thr_copy_q = tiled_copy_q.get_thread_slice(thread_idx); + Tensor tQsQ = gmem_thr_copy_q.partition_S(sQ); + Tensor tQcQ = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{})); + + CUTE_STATIC_ASSERT_V(rank(tQsQ) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrCos) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrSin) == _3{}); + CUTE_STATIC_ASSERT_V(size<1>(tQsQ) == size<1>(tRrCos)); + CUTE_STATIC_ASSERT_V(size<2>(tQsQ) == size<2>(tRrCos)); + CUTE_STATIC_ASSERT_V(size<1>(tQsQ) == size<1>(tRrSin)); + CUTE_STATIC_ASSERT_V(size<2>(tQsQ) == size<2>(tRrSin)); + CUTE_STATIC_ASSERT_V(size<0>(tRrCos) == size<0>(tRrSin)); + static_assert(decltype(size<0>(tQsQ))::value == decltype(size<0>(tRrCos))::value * 2); + static_assert(decltype(size<0>(tRrCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + + #pragma unroll + for (int m = 0; m < size<1>(tQsQ); ++m) { + if (get<0>(tQcQ(_0{}, m, _0{})) < std::min(max_seqlen * qhead_per_khead - m_block * kBlockMN, kBlockMN)) { + #pragma unroll + for (int k = 0; k < size<2>(tQsQ); ++k) { + if (tRpR(k)) { + Tensor rQ = make_fragment_like(tQsQ(_, m, k)); + cute::copy(tiled_copy_q, tQsQ(_, m, k), rQ); + apply_rotary_interleaved(rQ, tRrCos(_, m, k), tRrSin(_, m, k)); + cute::copy(tiled_copy_q, rQ, tQsQ(_, m, k)); + } + } + } + } + }; + + template + CUTLASS_DEVICE + void + apply_Q_contiguous(TensorsQ &sQ, // (kBlockM, kHeadDim) + TensortRrR const &tRrCosCont, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotaryCont + TensortRrR const &tRrSinCont, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotaryCont + int const m_block, int const qhead_per_khead=1) + { + TiledCopyQK tiled_copy_q; + auto gmem_thr_copy_q = tiled_copy_q.get_thread_slice(thread_idx); + Tensor sQ_copy = cute::tiled_divide(sQ, Shape<_1, Int>{}); + Tensor tQcQ = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{})); + + CUTE_STATIC_ASSERT_V(rank(tQcQ) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrCosCont) == 
_3{}); + CUTE_STATIC_ASSERT_V(rank(tRrSinCont) == _3{}); + CUTE_STATIC_ASSERT_V(size<1>(tQcQ) == size<1>(tRrCosCont)); + CUTE_STATIC_ASSERT_V(size<2>(tQcQ) == size<2>(tRrCosCont)); + CUTE_STATIC_ASSERT_V(size<1>(tQcQ) == size<1>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<2>(tQcQ) == size<2>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<0>(tRrCosCont) == size<0>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<0>(tQcQ) == size<0>(tRrCosCont)); + static_assert(decltype(size<0>(tRrCosCont))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + + #pragma unroll + for (int m = 0; m < size<1>(tQcQ); ++m) { + int const row = get<0>(tQcQ(_0{}, m, _0{})); + if (row < std::min(max_seqlen * qhead_per_khead - m_block * kBlockMN, kBlockMN)) { + #pragma unroll + for (int k = 0; k < size<2>(tQcQ); ++k) { + int const col = get<1>(tQcQ(_0{}, _0{}, k)); + if (col < rotary_dim / 2) { + int const col_idx_left = col / kGmemElemsPerLoad; + int const col_idx_right = col / kGmemElemsPerLoad + rotary_dim / (2 * kGmemElemsPerLoad); + Tensor rQ_left = make_fragment_like(sQ_copy(_, row, col_idx_left)); + cute::copy(tiled_copy_q, sQ_copy(_, row, col_idx_left), rQ_left); + Tensor rQ_right = make_fragment_like(rQ_left); + cute::copy(tiled_copy_q, sQ_copy(_, row, col_idx_right), rQ_right); + apply_rotary_contiguous(rQ_left, rQ_right, tRrCosCont(_, m, k), tRrSinCont(_, m, k)); + cute::copy(tiled_copy_q, rQ_left, sQ_copy(_, row, col_idx_left)); + cute::copy(tiled_copy_q, rQ_right, sQ_copy(_, row, col_idx_right)); + } + } + } + } + }; + + template + CUTLASS_DEVICE + void + apply_K_interleaved(TensorsK const &sK, // (kBlockN, kHeadDim) + TensorgK &gK, // (kBlockN, kHeadDim) + TensorpK const &tKpK, // (kBlockN, kHeadDim) split according to ThrCopyKV + TensortRrR const &tRrCos, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotary + TensortRrR const &tRrSin, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotary + TensorKPtr const &tPrKPtr, + int const n_block) + { + TiledCopyQK 
tiled_copy_k; + auto gmem_thr_copy_q = tiled_copy_k.get_thread_slice(thread_idx); + Tensor tKsK = gmem_thr_copy_q.partition_S(sK); + Tensor tKgK = gmem_thr_copy_q.partition_S(gK); + Tensor tKcK = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{})); + + CUTE_STATIC_ASSERT_V(rank(tKsK) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrCos) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrSin) == _3{}); + CUTE_STATIC_ASSERT_V(size<1>(tKsK) == size<1>(tRrCos)); + CUTE_STATIC_ASSERT_V(size<2>(tKsK) == size<2>(tRrCos)); + CUTE_STATIC_ASSERT_V(size<1>(tKsK) == size<1>(tRrSin)); + CUTE_STATIC_ASSERT_V(size<2>(tKsK) == size<2>(tRrSin)); + CUTE_STATIC_ASSERT_V(size<0>(tRrCos) == size<0>(tRrSin)); + static_assert(decltype(size<0>(tKsK))::value == decltype(size<0>(tRrCos))::value * 2); + static_assert(decltype(size<0>(tRrCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + if constexpr (PagedKVNonTMA) { + static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow)); + } + + #pragma unroll + for (int m = 0; m < size<1>(tKsK); ++m) { + int const row = get<0>(tKcK(_0{}, m, _0{})); + auto mK_cur_copy = [&] { + if constexpr (PagedKVNonTMA) { + Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); + Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); + return cute::tiled_divide(mK_cur, Shape>{}); + } else { + return nullptr; + } + }(); + if (row < std::min(max_seqlen - n_block * kBlockMN, kBlockMN)) { + #pragma unroll + for (int k = 0; k < size<2>(tKsK); ++k) { + if (tKpK(k)) { + Tensor rK = make_fragment_like(tKsK(_, m, k)); + cute::copy(tiled_copy_k, tKsK(_, m, k), rK); + if (tRpR(k)) { apply_rotary_interleaved(rK, tRrCos(_, m, k), tRrSin(_, m, k)); } + if constexpr (!PagedKVNonTMA) { + cute::copy(tiled_copy_k, rK, tKgK(_, m, k)); + } else { + int const ki = get<1>(tKcK(_0{}, _0{}, k)) / 
kGmemElemsPerLoad; + cute::copy(tiled_copy_k, rK, mK_cur_copy(_, ki)); + } + } + } + } + } + }; + + template + CUTLASS_DEVICE + void + apply_K_contiguous(TensorsK const &sK, // (kBlockN, kHeadDim) + TensorgK &gK, // (kBlockN, kHeadDim) + TensorpK const &tKpK, // (kBlockN, kHeadDim) split according to ThrCopyKV + TensortRrR const &tRrCosCont, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotaryCont + TensortRrR const &tRrSinCont, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotaryCont + TensorKPtr const &tPrKPtr, + int const n_block, int const max_k) + { + TiledCopyQK tiled_copy_k; + auto gmem_thr_copy_q = tiled_copy_k.get_thread_slice(thread_idx); + Tensor sK_copy = cute::tiled_divide(sK, Shape<_1, Int>{}); + Tensor gK_copy = cute::tiled_divide(gK, Shape<_1, Int>{}); + Tensor tKcK = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{})); + + CUTE_STATIC_ASSERT_V(rank(tKcK) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrCosCont) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrSinCont) == _3{}); + CUTE_STATIC_ASSERT_V(size<1>(tKcK) == size<1>(tRrCosCont)); + CUTE_STATIC_ASSERT_V(size<2>(tKcK) == size<2>(tRrCosCont)); + CUTE_STATIC_ASSERT_V(size<1>(tKcK) == size<1>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<2>(tKcK) == size<2>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<0>(tRrCosCont) == size<0>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<0>(tKcK) == size<0>(tRrCosCont)); + static_assert(decltype(size<0>(tRrCosCont))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + if constexpr (PagedKVNonTMA) { + static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow)); + } + + const int ro_dim_vec = rotary_dim / kGmemElemsPerLoad; + const int non_ro_dim_vec = (max_k - rotary_dim) / kGmemElemsPerLoad; + #pragma unroll + for (int m = 0; m < size<1>(tKcK); ++m) { + int const row = get<0>(tKcK(_0{}, m, _0{})); + Tensor gK_cur_copy = [&] { + if constexpr (PagedKVNonTMA) { + Element* k_ptr = 
reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); + Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); + return cute::tiled_divide(mK_cur, Shape>{}); + } else { + return gK_copy(_, row, _); + } + }(); + if (row < std::min(max_seqlen - n_block * kBlockMN, kBlockMN)) { + #pragma unroll + for (int k = 0; k < size<2>(tKcK); ++k) { + if (tKpK(k)) { + int const col = get<1>(tKcK(_0{}, _0{}, k)); + bool rotate = col < rotary_dim / 2; + int const col_idx_left = rotate ? col / kGmemElemsPerLoad : (col + rotary_dim / 2) / kGmemElemsPerLoad; + int const col_idx_right = col_idx_left + (rotate ? ro_dim_vec / 2 : non_ro_dim_vec / 2); + Tensor rK_left = make_fragment_like(sK_copy(_, row, col_idx_left)); + cute::copy(tiled_copy_k, sK_copy(_, row, col_idx_left), rK_left); + Tensor rK_right = make_fragment_like(rK_left); + cute::copy(tiled_copy_k, sK_copy(_, row, col_idx_right), rK_right); + if (rotate) { + apply_rotary_contiguous(rK_left, rK_right, tRrCosCont(_, m, k), tRrSinCont(_, m, k)); + } + cute::copy(tiled_copy_k, rK_left, gK_cur_copy(_, col_idx_left)); + if (col_idx_right * kGmemElemsPerLoad < max_k) { + cute::copy(tiled_copy_k, rK_right, gK_cur_copy(_, col_idx_right)); + } + } + } + } + } + }; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/seqlen.h b/flashmask/flash_mask/flashmask_attention_v3/seqlen.h new file mode 100644 index 00000000000..b2c77a5664f --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/seqlen.h @@ -0,0 +1,107 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +namespace flash { + +// We consolidate all the info related to sequence length here. This is so that we can do all +// the gmem reads once at the beginning of each tile, rather than having to repeat these reads +// to compute various things like n_block_min, n_block_max, etc. + +template +struct SeqlenInfo { + + int const offset, offset_padded; + int const seqlen; + + CUTLASS_DEVICE + SeqlenInfo(int const bidb, int const seqlen_static, int const* const cu_seqlens, int const* const seqused) + : offset(!Varlen || cu_seqlens == nullptr ? 0 : cu_seqlens[bidb]) + , offset_padded(!Varlen || cu_seqlens == nullptr ? 0 : (cu_seqlens[bidb] + bidb * kBlock) / kBlock * kBlock) + , seqlen(!Varlen + ? seqlen_static + : (seqused ? seqused[bidb] : (cu_seqlens ? cu_seqlens[bidb + 1] - cu_seqlens[bidb] : seqlen_static))) + { + } + +}; + +template +struct SeqlenInfoQK { + + int const offset_q, offset_k, offset_q_padded; + int const seqlen_q, seqlen_k; + + CUTLASS_DEVICE + SeqlenInfoQK(int const bidb, int const seqlen_q_static, int const seqlen_k_static, + int const* const cu_seqlens_q, int const* const cu_seqlens_k, + int const* const seqused_q, int const* const seqused_k + ) + : offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb]) + , offset_k(!Varlen || cu_seqlens_k == nullptr ? 
0 : cu_seqlens_k[bidb]) + // If varlen, the layout for dPSum, LSE_log2, and dQaccum is that we pad each sequence in the batch + // by an extra kBlockM, so that the write for each sequence doesn't touch the next sequence. + // Sequence i starts at cu_seqlens[i] + i * kBlockM and ends at cu_seqlens[i + 1] + i * kBlockM + // However, the start must align to multiples of kBlockM. + , offset_q_padded(!Varlen || cu_seqlens_q == nullptr ? 0 : (cu_seqlens_q[bidb] + bidb * kBlockM) / kBlockM * kBlockM) + , seqlen_q(!Varlen + ? seqlen_q_static + : (seqused_q ? seqused_q[bidb] : (cu_seqlens_q ? cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb] : seqlen_q_static))) + , seqlen_k(!Varlen + ? seqlen_k_static + : (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ? cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static))) + { + } + +}; + +template +struct SeqlenInfoQKNewK { + + static_assert(!(AppendKV && !Varlen), "AppendKV is only supported with Varlen"); + + int const leftpad_k; + int const offset_q, offset_k, offset_k_new; + int const seqlen_q, seqlen_k_og, seqlen_k_new, seqlen_k; + + CUTLASS_DEVICE + SeqlenInfoQKNewK(int const bidb, int const seqlen_q_static, int const seqlen_k_static, int const shape_K_new_0, + int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, + int const* const seqused_q, int const* const seqused_k, int const* const ptr_leftpad_k + ) + : leftpad_k(ptr_leftpad_k ? ptr_leftpad_k[bidb] : 0) + , offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb]) + , offset_k(!Varlen ? 0 : (cu_seqlens_k ? cu_seqlens_k[bidb] : 0) + leftpad_k) + , offset_k_new(!AppendKV || cu_seqlens_k_new == nullptr ? 0 : cu_seqlens_k_new[bidb]) + , seqlen_q(!Varlen + ? seqlen_q_static + : (seqused_q ? seqused_q[bidb] : (cu_seqlens_q ? cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb] : seqlen_q_static))) + , seqlen_k_og(!Varlen + ? seqlen_k_static + : (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ? 
cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static)) - leftpad_k) + , seqlen_k_new(!AppendKV + ? 0 + : (cu_seqlens_k_new ? cu_seqlens_k_new[bidb + 1] - cu_seqlens_k_new[bidb] : shape_K_new_0)) + , seqlen_k(!AppendKV ? seqlen_k_og : seqlen_k_og + seqlen_k_new) + { + } + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/sm90_pipeline_no_cluster.hpp b/flashmask/flash_mask/flashmask_attention_v3/sm90_pipeline_no_cluster.hpp new file mode 100644 index 00000000000..9211878a394 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/sm90_pipeline_no_cluster.hpp @@ -0,0 +1,113 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +#include + +namespace cutlass { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// As of Cutlass v3.6.0, if size(ClusterShape) == 1, PipelineTmaAsync has all threads +// signaling the barrier during consumer_release. This causes a perf regression in FA3 +// forward pass (especially hdim 128 causal). 
We instead reimplement the version of +// PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier. +// +// Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0 +template > +class PipelineTmaAsyncNoCluster: public Base { +public: + using FullBarrier = typename Base::FullBarrier; + using EmptyBarrier = typename Base::EmptyBarrier; + static constexpr uint32_t Stages = Stages_; + using PipelineState = typename Base::PipelineState; + + using SharedStorage = typename Base::SharedStorage; + using ThreadCategory = typename Base::ThreadCategory; + using Params = typename Base::Params; + + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params params) { + int warp_idx = canonical_warp_idx_sync(); + bool is_initializing_warp = (warp_idx == 0); + if (is_initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster; + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + cutlass::arch::fence_barrier_init(); + } + + template + CUTLASS_DEVICE + PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) + : Base(storage, params, make_shape(_1{}, _1{}, _1{}) /*cluster_shape*/, cute::false_type{} /*init_barriers*/, cute::false_type{} /*init_masks*/) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) { + + int warp_idx = canonical_warp_idx_sync(); + int lane_predicate = cute::elect_one_sync(); + + static_assert(cute::is_same_v || cute::is_same_v); + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params); + } + + } + + // Constructor + template + CUTLASS_DEVICE + 
PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape) + : PipelineTmaAsyncNoCluster(storage, params, cluster_shape, cute::true_type{}, cute::true_type{}) { } + + template + CUTLASS_DEVICE + PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}) + : PipelineTmaAsyncNoCluster(storage, params, cluster_shape, InitBarriers{}, cute::true_type{}) { } + + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + consumer_release(state.index()); + } + +private: + EmptyBarrier* const empty_barrier_ptr_ = nullptr; + + // Consumer signalling Producer of completion + // Ensures all blocks in the Same Row and Column get notifed. + CUTLASS_DEVICE + void consumer_release(uint32_t stage, uint32_t skip = false) { + empty_barrier_ptr_[stage].arrive(0 /*dst_blockid_*/, uint32_t(threadIdx.x % cutlass::NumThreadsPerWarpGroup == 0) & (!skip) /*is_signaling_thread*/); + } + +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cutlass diff --git a/flashmask/flash_mask/flashmask_attention_v3/softmax.h b/flashmask/flash_mask/flashmask_attention_v3/softmax.h new file mode 100644 index 00000000000..af417c32d07 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/softmax.h @@ -0,0 +1,184 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +#include + +#include + +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ni++) { + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init && ni == 0 ? 
tensor(mi, ni) : op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); + if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); } +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + // For FP8, we can subtract max by 8.0 so that the value after exp2 is in the range of [0, 256]. + // This lets us use more of the FP8 range (instead of just [0, 1]) to reduce underflow. + static constexpr float max_offset = float(Max_offset); // We can only template on int, not float + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = Check_inf + ? (max(mi) == -INFINITY ? 0.f : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset) + : (!Scale_max ? 
max(mi) : max(mi) * scale) - max_offset; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)). This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + float const softmax_scale_log2; + + CUTLASS_DEVICE Softmax(float const softmax_scale_log2_) : softmax_scale_log2(softmax_scale_log2_) {}; + + template + __forceinline__ __device__ TensorT max_get_scale(Tensor0 &acc_s) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(CUTE_STATIC_V(size<0>(scores)) == kNRows); + TensorT scores_scale; + if constexpr (Is_first) { + flash::template reduce_max(scores, row_max); + cute::fill(scores_scale, 1.f); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max(scores, row_max); + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 
0.0f : row_max(mi)); + scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale(mi); + } + } + return scores_scale; + }; + + template + __forceinline__ __device__ void online_softmax(Tensor0 &acc_s) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(CUTE_STATIC_V(size<0>(scores)) == kNRows); + flash::template scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + flash::reduce_sum(scores, row_sum); + }; + + __forceinline__ __device__ TensorT finalize(float const final_scale=1.f) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT scores_scale; + #pragma unroll + for (int mi = 0; mi < size(row_sum); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1.f / sum; + scores_scale(mi) = inv_sum * final_scale; + // For FP8, we might have scaled the output of exp by 2**8 so we need to divide sum by that amount. + if constexpr (Max_offset != 0) { + static constexpr float sum_scale = 1.f / float(1 << Max_offset); + sum *= sum_scale; + } + row_sum(mi) = (sum == 0.f || sum != sum) ? 
-INFINITY : row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum); + } + return scores_scale; + }; + + template + __forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) { + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(CUTE_STATIC_V(size<0>(acc_o_rowcol)) == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale(mi); } + } + }; + +}; + +} // namespace flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/static_switch.h b/flashmask/flash_mask/flashmask_attention_v3/static_switch.h new file mode 100644 index 00000000000..a9b58d3b96e --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/static_switch.h @@ -0,0 +1,204 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#define BOOL_SWITCH(COND, CONST_NAME, ...) 
\ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#define FLASH_MASK_SWITCH(LTE_COND, UTS_COND, LTS_CONST_NAME, UTS_CONST_NAME, ...) \ + [&] { \ + if (LTE_COND) { \ + constexpr static bool LTS_CONST_NAME = true; \ + if (UTS_COND) { \ + constexpr static bool UTS_CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool UTS_CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + } else { \ + constexpr static bool LTS_CONST_NAME = false; \ + if (UTS_COND) { \ + constexpr static bool UTS_CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool UTS_CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + } \ + }() + +#ifdef FLASHMASK_V3_DISABLE_LOCAL + #define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \ + [&] { \ + constexpr static bool LOCAL_CONST_NAME = false; \ + if (CAUSAL_COND) { \ + constexpr static bool CAUSAL_CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CAUSAL_CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() +#else + #define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \ + [&] { \ + if (CAUSAL_COND) { \ + constexpr static bool CAUSAL_CONST_NAME = true; \ + constexpr static bool LOCAL_CONST_NAME = false; \ + return __VA_ARGS__(); \ + } else if (LOCAL_COND) { \ + constexpr static bool CAUSAL_CONST_NAME = false; \ + constexpr static bool LOCAL_CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CAUSAL_CONST_NAME = false; \ + constexpr static bool LOCAL_CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() +#endif + +#ifdef FLASHMASK_V3_DISABLE_SOFTCAP + #define SOFTCAP_SWITCH(COND, CONST_NAME, ...) 
\ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define SOFTCAP_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHMASK_V3_DISABLE_PAGEDKV + #define PAGEDKV_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define PAGEDKV_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHMASK_V3_DISABLE_SPLIT + #define SPLIT_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define SPLIT_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHMASK_V3_DISABLE_APPENDKV + #define APPENDKV_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define APPENDKV_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHMASK_V3_DISABLE_PACKGQA + #define PACKGQA_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define PACKGQA_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHMASK_V3_DISABLE_VARLEN + #define VARLEN_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define VARLEN_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHMASK_V3_DISABLE_CLUSTER + #define CLUSTER_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define CLUSTER_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHMASK_V3_DISABLE_SM8x + #define ARCH_SWITCH(ARCH, ARCH_NAME, ...) \ + [&] { \ + constexpr static int ARCH_NAME = 90; \ + return __VA_ARGS__(); \ + }() +#else + #define ARCH_SWITCH(ARCH, ARCH_NAME, ...) 
\ + [&] { \ + if (ARCH == 86 || ARCH == 89) { \ + constexpr static int ARCH_NAME = 86; \ + return __VA_ARGS__(); \ + } else if (ARCH < 90) { \ + constexpr static int ARCH_NAME = 80; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int ARCH_NAME = 90; \ + return __VA_ARGS__(); \ + } \ + }() +#endif + +#ifndef FLASHMASK_V3_ENABLE_VCOLMAJOR + #define VCOLMAJOR_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define VCOLMAJOR_SWITCH BOOL_SWITCH +#endif + +#define HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM == 64) { \ + constexpr static int kHeadSize = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 96) { \ + constexpr static int kHeadSize = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 128) { \ + constexpr static int kHeadSize = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 96) { \ + constexpr static int kHeadSize = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 256) { \ + constexpr static int kHeadSize = 256; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/flashmask/flash_mask/flashmask_attention_v3/tile_scheduler.hpp b/flashmask/flash_mask/flashmask_attention_v3/tile_scheduler.hpp new file mode 100644 index 00000000000..6a264183ce5 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/tile_scheduler.hpp @@ -0,0 +1,1032 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +#include "cutlass/fast_math.h" +#include "cutlass/arch/barrier.h" + +#include "named_barrier.hpp" +#include "utils.h" + +namespace flash { + +/////////////////////////////////////////////////////////////////////////////// + +#define DEFINE_DUMMY_NOTIFY_FUNCS \ + CUTLASS_DEVICE \ + void \ + producer_notify() const {} \ + CUTLASS_DEVICE \ + void \ + consumer_notify() const {} + +// Host side kernel arguments +struct TileSchedulerArguments { + // num_head is num_head_q if not PackGQA, else num_head_k + int const num_blocks, num_head, num_batch, num_splits; + int const qhead_per_khead; + int const seqlen; // Only used if Varlen and cu_seqlens == nullptr and seqused == nullptr + int const seqlen_k, headdim, headdim_v, element_size; // Used to calculate L2 swizzling + int* const tile_count_semaphore = nullptr; + int const* const cu_seqlens = nullptr; + int const* const seqused = nullptr; + // int const* const num_m_blocks_ptr = nullptr; + int const* const num_splits_dynamic_ptr = nullptr; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template +class SingleTileScheduler { +public: + static constexpr bool pipelining = false; + + using SharedStorage = int; + + // Device side kernel params + struct Params { + int const num_blocks, num_head, num_batch, num_splits; + int const qhead_per_khead; + int const seqlen; + cutlass::FastDivmod nsplits_divmod; + int const* const cu_seqlens; + int const* const seqused; + int 
const* const num_splits_dynamic_ptr = nullptr; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + assert(!Split || !Varlen || args.num_splits_dynamic_ptr != nullptr); + assert(!Split || !Varlen || args.num_splits < (1 << 16)); // We use the top 16 bits to store num_splits + return {args.num_blocks, args.num_head, args.num_batch, !Split ? 1 : args.num_splits, + args.qhead_per_khead, args.seqlen, + cutlass::FastDivmod(!Split ? 1 : args.num_splits), + !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused, + args.num_splits_dynamic_ptr}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(params.num_blocks), uint32_t((!Split ? 1 : params.num_splits) * params.num_head), uint32_t(params.num_batch)}; + } + + DEFINE_DUMMY_NOTIFY_FUNCS + + struct WorkTileInfo { + int block_idx = 0; + int bidh = 0; + int bidb = 0; + int split_idx = 0; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return bidb >= 0; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + return {block_idx, bidh, bidb, !Split ? 0 : split_idx}; + } + + }; + + CUTLASS_DEVICE + SingleTileScheduler(SharedStorage* const smem_scheduler) { } + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) const { + WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), 0}; + if constexpr (Split) { + int split_idx; + work_info.bidh = params.nsplits_divmod.divmod(split_idx, work_info.bidh); + work_info.split_idx = split_idx; + } + bool is_valid_tile = true; + if constexpr (Varlen) { + int seqlen = params.seqused + ? params.seqused[work_info.bidb] + : (params.cu_seqlens ? 
params.cu_seqlens[work_info.bidb + 1] - params.cu_seqlens[work_info.bidb] : params.seqlen);
+            if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; }
+            is_valid_tile = work_info.block_idx * kBlock < seqlen;
+        }
+        if constexpr (Varlen && Split) {
+            int num_splits_dynamic = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[work_info.bidb] : params.num_splits;
+            // Check validity against the raw split index BEFORE packing num_splits into
+            // the top 16 bits; comparing the packed value would almost always fail.
+            is_valid_tile &= work_info.split_idx < num_splits_dynamic;
+            // Use the top 16 bits to store num_splits
+            work_info.split_idx |= (num_splits_dynamic << 16);
+        }
+        work_info.bidb = is_valid_tile ? work_info.bidb : -1;
+        return work_info;
+    }
+
+    CUTLASS_DEVICE
+    void
+    init_consumer() const {}
+
+    CUTLASS_DEVICE
+    void
+    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
+
+    template
+    CUTLASS_DEVICE
+    WorkTileInfo
+    get_next_work(Params const& params, WorkTileInfo const& current_work) const {
+        return {0, 0, -1, 0};
+    }
+
+    template
+    CUTLASS_DEVICE
+    constexpr uint32_t stage() const noexcept { return 0; }
+};
+
+///////////////////////////////////////////////////////////////////////////////
+
+template
+class StaticPersistentTileScheduler {
+
+public:
+    static constexpr bool pipelining = false;
+    using SharedStorage = int;
+
+    // Device side kernel params
+    struct Params {
+        int total_blocks;
+        cutlass::FastDivmod m_block_divmod, head_divmod;
+        cutlass::FastDivmod nsplits_divmod;
+    };
+
+    static Params
+    to_underlying_arguments(TileSchedulerArguments const& args) {
+        return {args.num_blocks * args.num_head * args.num_batch * (!Split ? 1 : args.num_splits),
+                cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head * (!Split ? 1 : args.num_splits)),
+                cutlass::FastDivmod(!Split ?
1 : args.num_splits)}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(num_sm)}; + } + + DEFINE_DUMMY_NOTIFY_FUNCS + + struct WorkTileInfo { + int tile_idx; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return tile_idx < params.total_blocks; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + int block, bidh, bidb; + bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx)); + int split_idx = 0; + if constexpr (Split) { + bidh = params.nsplits_divmod.divmod(split_idx, bidh); + } + return {block, bidh, bidb, split_idx}; + } + + }; + + CUTLASS_DEVICE + StaticPersistentTileScheduler(SharedStorage* const smem_scheduler) {}; + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) const { + return {int(blockIdx.x)}; + } + + CUTLASS_DEVICE + void + init_consumer() const {} + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + return {current_work.tile_idx + int(gridDim.x)}; + } + + template + CUTLASS_DEVICE + constexpr uint32_t stage() const noexcept { return 0; } +}; + +template +class PreemptivePersistentTileScheduler { + // **PPT** scheduler: performs correct synchronization for producer (generate_n_block) and consumer (KV load and computation pipeline) + // This scheduler has the same coordinate computation logic as StaticPersistentTileSch, the difference is that + // we employ a preemptive scheduling strategy based on a rough estimation of the workload for the consumer + // In PPT, NumConsumerThreads is the total number of threads for (KV load and computation pipeline), and for FlashMask V3 + // it will be the #threads for (wg_id = 0, wp_id = 0) + (wg_id > 0, wp_id = *). The NumProducerThreads is simply 96 (hard-coded). 
+ static_assert(NumProducerThreads == 96, "PreemptivePersistentTileScheduler has incorrect producer thread num."); + static constexpr int NumThreads = NumConsumerThreads + NumProducerThreads; +public: + using SharedStorage = int; + static constexpr bool pipelining = false; +protected: + SharedStorage* const tile_count_smem; + +public: + + // Device side kernel params + + struct Params { + int total_blocks; + cutlass::FastDivmod m_block_divmod, head_divmod; + cutlass::FastDivmod nsplits_divmod; + int* const tile_count_semaphore; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + assert(args.tile_count_semaphore != nullptr); + return {args.num_blocks * args.num_head * args.num_batch * (!Split ? 1 : args.num_splits), + cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head * (!Split ? 1 : args.num_splits)), + cutlass::FastDivmod(!Split ? 1 : args.num_splits), args.tile_count_semaphore}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(num_sm)}; + } + + struct WorkTileInfo { + int tile_idx; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return tile_idx < params.total_blocks; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + int block, bidh, bidb; + bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx)); + int split_idx = 0; + if constexpr (Split) { + bidh = params.nsplits_divmod.divmod(split_idx, bidh); + } + return {block, bidh, bidb, split_idx}; + } + + }; + + DEFINE_DUMMY_NOTIFY_FUNCS + + CUTLASS_DEVICE + PreemptivePersistentTileScheduler(SharedStorage* const smem_scheduler) : tile_count_smem(smem_scheduler) {}; + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) const { + // when all the blocks (SMs) done initializing and no SM has done the first task, tile_count_semaphore will be + // at least `gridDim.x`, then, we just let prefetch_next_work and 
non-deterministic schedule (workload-related) take over + + // For FlashMask V3, only generate_n_block pipeline is the big brother producer to be preemptively scheduled! + // since the initial work is assigned deterministically via blockIdx.x, we need to ensure that the initial state of + // tile_count_semaphore is gridDim.x. Can't use atomicAdd here, since if we do, for example, SM1 is really fast, it performs + // prefetch_next_work even before SM2 calls get_initial_work, then SM1 will risk computing the same block as SM2. + + // for the initial work: assign deterministically + return {int(blockIdx.x)}; + } + + CUTLASS_DEVICE + void + init_consumer() const { + // this is a kick-off for the whole producer (producer waits for TileCountSmemEmpty), otherwise we will have a dead-lock, also + // this init_consumer can only be called in consumer warps, otherwise we will have more arriving threads than needed + // NumConsumerThreads: including (wg_id = 0, warp_id = 0: KV load) and (wg_id > 0, warp_id = *: computation) + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + } + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { + // only producer will call this method + if (threadIdx.x == 96) { // hard-coded, since n_block producer threads are in [32, 128) + // the next job we are going to process: number of currently blocks done + current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1); + } + } + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + if constexpr (IsProducerWarp) { + // only threadIdx.x == 96 has the correct `current_work.tile_idx` (see prefetch next_work) + // so there is no need to use shfl_sync to broadcast. 
Also shfl cannot broadcast across warps + flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + if (threadIdx.x == 96) { // hard-coded, since n_block producer threads are in [32, 128) + *tile_count_smem = current_work.tile_idx; + } + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + // Sync all the producers in case some of the producers return before the smem is updated + flash::named_barrier_sync(NumProducerThreads, static_cast(FwdNamedBarriers::NBlockProducer) /*id*/); + return {*tile_count_smem}; + } else { + flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + int tile_idx = *tile_count_smem; + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + return {tile_idx}; + } + } + + template + CUTLASS_DEVICE + constexpr uint32_t stage() const noexcept { return 0; } +}; + +template +class BwdPreemptivePersistentTileScheduler { + static constexpr int NumThreads = NumConsumerThreads + NumProducerThreads; +public: + using SharedStorage = int; + static constexpr bool pipelining = false; +protected: + SharedStorage* const tile_count_smem; + +public: + + // Device side kernel params + + struct Params { + int total_blocks; + cutlass::FastDivmod m_block_divmod, head_divmod; + int* const tile_count_semaphore; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + assert(args.tile_count_semaphore != nullptr); + return {args.num_blocks * args.num_head * args.num_batch, + cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head), + args.tile_count_semaphore}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(num_sm)}; + } + + struct WorkTileInfo { + int tile_idx; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return tile_idx < params.total_blocks; + } + + CUTLASS_DEVICE + cute::tuple 
+ get_block_coord(Params const& params) const { + int block, bidh, bidb; + bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx)); + return {block, bidh, bidb}; + } + + }; + + CUTLASS_DEVICE + BwdPreemptivePersistentTileScheduler(SharedStorage* const smem_scheduler) : tile_count_smem(smem_scheduler) {}; + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) const { + if constexpr (!IsProducerWarp) { + flash::named_barrier_sync(NumThreads, static_cast(BwdNamedBarriers::FlashmaskSmemFull) /*id*/); + } + return {int(blockIdx.x)}; + } + + CUTLASS_DEVICE + void + init_consumer() const { + // flash::named_barrier_arrive(NumThreads, static_cast(BwdNamedBarriers::FlashmaskSmemEmpty) /*id*/); + } + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + CUTLASS_DEVICE + void + producer_notify() const { // notify the consumer that we've written data into the buffer + flash::named_barrier_arrive(NumThreads, static_cast(BwdNamedBarriers::FlashmaskSmemFull) /*id*/); + } + + CUTLASS_DEVICE + void + consumer_notify() const { + // sync to make sure (*tile_count_smem) modification is visible to consumers + flash::named_barrier_arrive(NumThreads, static_cast(BwdNamedBarriers::FlashmaskSmemEmpty) /*id*/); + } + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + if constexpr (IsProducerWarp) { + flash::named_barrier_sync(NumThreads, static_cast(BwdNamedBarriers::FlashmaskSmemEmpty) /*id*/); + // TODO(heqianyue): atomicAdd here? 
+ if (threadIdx.x == 0) { // hard-coded, since n_block producer threads are in [32, 128) + if constexpr (Deterministic) { + *tile_count_smem = current_work.tile_idx + gridDim.x; + } + else { + // the next job we are going to process: number of currently blocks done + *tile_count_smem = atomicAdd(params.tile_count_semaphore, 1); + } + } + flash::named_barrier_sync(NumProducerThreads, static_cast(BwdNamedBarriers::FlashmaskProducer) /*id*/); + } else { + flash::named_barrier_sync(NumThreads, static_cast(BwdNamedBarriers::FlashmaskSmemFull) /*id*/); + } + // how to make sure consumers can actually get this? + return {*tile_count_smem}; + } + + template + CUTLASS_DEVICE + constexpr uint32_t stage() const noexcept { return 0; } +}; + + +template +class DualPreemptivePersistentTileExecutionScheduler { + // **PPT** scheduler: performs correct synchronization for producer (generate_n_block) and consumer (KV load and computation pipeline) + // This scheduler has the same coordinate computation logic as StaticPersistentTileSch, the difference is that + // we employ a preemptive scheduling strategy based on a rough estimation of the workload for the consumer + // In PPT, NumConsumerThreads is the total number of threads for (KV load and computation pipeline), and for FlashMask V3 + // it will be the #threads for (wg_id = 0, wp_id = 0) + (wg_id > 0, wp_id = *). The NumProducerThreads is simply 96 (hard-coded). 
+ + // The following static_assert is NOT compulsory, it's just that we found that 64 producer threads performs worse + static_assert(NumProducerThreads == 96, "DualPPTX Scheduler has incorrect producer thread num."); + static constexpr int NumThreads = NumConsumerThreads + NumProducerThreads; +public: + using SharedStorage = int; + static constexpr bool pipelining = true; // DualPPTX has coarse-grained pipelining +protected: + SharedStorage* const tile_count_smem; + uint32_t sch_stage_; +public: + // Device side kernel params + + struct Params { + int total_blocks; + cutlass::FastDivmod m_block_divmod, head_divmod; + cutlass::FastDivmod nsplits_divmod; + int* const tile_count_semaphore; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + assert(args.tile_count_semaphore != nullptr); + return {args.num_blocks * args.num_head * args.num_batch * (!Split ? 1 : args.num_splits), + cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head * (!Split ? 1 : args.num_splits)), + cutlass::FastDivmod(!Split ? 
1 : args.num_splits), args.tile_count_semaphore}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(num_sm)}; + } + + struct WorkTileInfo { + int tile_idx; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return tile_idx < params.total_blocks; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + int block, bidh, bidb; + bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx)); + int split_idx = 0; + if constexpr (Split) { + bidh = params.nsplits_divmod.divmod(split_idx, bidh); + } + return {block, bidh, bidb, split_idx}; + } + + }; + + CUTLASS_DEVICE + DualPreemptivePersistentTileExecutionScheduler(SharedStorage* const smem_scheduler) : tile_count_smem(smem_scheduler) {} + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) { + // when all the blocks (SMs) done initializing and no SM has done the first task, tile_count_semaphore will be + // at least `gridDim.x`, then, we just let prefetch_next_work and non-deterministic schedule (workload-related) take over + + // For FlashMask V3, only generate_n_block pipeline is the big brother producer to be preemptively scheduled! + // since the initial work is assigned deterministically via blockIdx.x, we need to ensure that the initial state of + // tile_count_semaphore is gridDim.x. Can't use atomicAdd here, since if we do, for example, SM1 is really fast, it performs + // prefetch_next_work even before SM2 calls get_initial_work, then SM1 will risk computing the same block as SM2. 
+ + // for the initial work: assign deterministically + if constexpr (IsProducerWarp) { + sch_stage_ = 0; // producer initial state is 0, since the first get_next, producer should sync full-1 (dual) + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + } else { + sch_stage_ = 1; // consumer initial state is 1, since the first get_next, producer should sync empty-0 (non-dual) + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFullDual) /*id*/); + } + return {int(blockIdx.x)}; + } + + DEFINE_DUMMY_NOTIFY_FUNCS + + CUTLASS_DEVICE + void + init_consumer() const { /* Init is done in get_initial work, therefore no need to repeat. */ } + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { + // PPTX prefetch is moved to consumer for more exact delay scheduling + } + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) { + // change state immediately, since we are to get next work + // Note that for the return value: except from the initial work, PPT always dynamic schedules + // Dual PPTX will have static schedule for only twice: get initial work and the first time get_next_work + // This is intentional, since in the first get_next_work, smem is not fully ready. + if constexpr (IsProducerWarp) { + sch_stage_ = 0x1 ^ sch_stage_; + flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) + (sch_stage_ << 1) /*id*/); + int tile_idx = tile_count_smem[sch_stage_]; + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) + (sch_stage_ << 1) /*id*/); + // Sync all the producers in case some of the producers return before the smem is updated + return {tile_idx >= 0 ? 
tile_idx : int(blockIdx.x + gridDim.x)}; + } else { + // for example: + // the 1st get_next_work of consumer: load from 1, and atomicAdd store to 0 + // load from 1 not initialized, use blockIdx.x + gridDim.x (static scheduling) + // the 2nd get_next_work of consumer: load from 0, and atomicAdd store to 1 + // load from 0 initialized: the 3rd consumer work ID is correctly set + int tile_idx = tile_count_smem[sch_stage_]; + sch_stage_ = 0x1 ^ sch_stage_; + if (threadIdx.x == NumConsumerThreads) { // thread 288 hard-coded, since n_block consumer threads are in [128, 384) + tile_count_smem[sch_stage_] = atomicAdd(params.tile_count_semaphore, 1); + } + flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) + (sch_stage_ << 1) /*id*/); + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) + (sch_stage_ << 1) /*id*/); + return {tile_idx >= 0 ? tile_idx : int(blockIdx.x + gridDim.x)}; + } + } + + template + CUTLASS_DEVICE + uint32_t stage() const noexcept { + // Returns stage offset: sch_stage_ * 2. Producer always returns the current stage, + // while consumer returns 1 - current stage, so that consumer can always have valid input + if constexpr (IsProducerWarp) + return sch_stage_ << 1; + else + return (0x1 ^ sch_stage_) << 1; + } +}; + +template +class DynamicPersistentTileScheduler { + + // This scheduler targets the causal (or local) case where each tile takes different + // amount of time. We use longest-processing-time-first scheduling: + // the longest remaining tile is assigned to the first SM that's free. + // SM indicates they are free by incrementing a semaphore. + // However, we have to make sure K & V still fit into L2 cache, so we perform scheduling + // on "sections" of the head & batch dimension, each section consisting of e.g. 8 heads. + // This is the L2 swizzling part. The size of each section is precomputed based on the + // size of K & V and the L2 cache size. 
+ + static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); + static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + (Is_flashmask ? 128 : NumProducerThreads) : NumMmaThreads; + +public: + using SharedStorage = int; + static constexpr bool pipelining = false; +protected: + SharedStorage* const tile_count_smem; + +public: + + // Device side kernel params + struct Params { + int const total_blocks; + cutlass::FastDivmod const m_block_divmod, head_divmod; + cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod; + cutlass::FastDivmod const l2_minor_residual_divmod; + int const num_hb_quotient; + int* const tile_count_semaphore; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + int const size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size * 2; + int const size_l2 = 32 * 1024 * 1024; // 32 MB for K & V + // Swizzle is the size of each "section". Round swizzle to a power of 2 + // If not PackGQA already, the size of each section can increase by qhead_per_khead + // Need to be careful about the case where only one head will fit + int const swizzle = (size_l2 < size_one_kv_head ? 1 : (1 << cutlass::find_log2(size_l2 / size_one_kv_head))) * (PackGQA ? 1 : args.qhead_per_khead); + // If we're in the last section (called residual), we don't want to divide by + // swizzle. Instead we want to divide by the remainder. + int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle; + int const num_split_blocks = args.num_blocks * (!Split ? 1 : args.num_splits); + assert(args.tile_count_semaphore != nullptr); + return {num_split_blocks * args.num_head * args.num_batch, + cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head), + cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * num_split_blocks), + // don't divide by 0 + cutlass::FastDivmod(num_hb_remainder > 0 ? 
num_hb_remainder : 1), + (args.num_head * args.num_batch) / swizzle, + args.tile_count_semaphore}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(num_sm)}; + } + + struct WorkTileInfo { + int tile_idx; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return tile_idx < params.total_blocks; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + int block, bidh, bidb; + int l2_mod, bidhb, bidhb_residual; + bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx); + // If we're in the last section (called residual), we don't want to divide by + // swizzle. Instead we want to divide by the remainder. + if (bidhb < params.num_hb_quotient) { + block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod); + } else { + block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod); + } + bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); + int split_idx = 0; + if constexpr (Split) { + split_idx = params.m_block_divmod.divmod(block, block); + } + // Longest-processing-time-first + block = params.m_block_divmod.divisor - 1 - block; + return {block, bidh, bidb, split_idx}; + } + + }; + + CUTLASS_DEVICE + DynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : tile_count_smem(smem_scheduler) {}; + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) const { + return {int(blockIdx.x)}; + } + + DEFINE_DUMMY_NOTIFY_FUNCS + + CUTLASS_DEVICE + void + init_consumer() const { + if (WarpSpecialized || cutlass::canonical_warp_idx_sync() > 0) { + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + } + } + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { + if (threadIdx.x % NumProducerThreads == 0) { + current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); + } + } + 
+ template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + if constexpr (IsProducerWarp) { + // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0 + int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); + flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + if (threadIdx.x % NumProducerThreads == 0) { + *tile_count_smem = current_work.tile_idx; + } + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + return {new_tile_idx}; + } else { + flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + int tile_idx = *tile_count_smem; + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + return {tile_idx}; + } + } + + template + CUTLASS_DEVICE + constexpr uint32_t stage() const noexcept { return 0; } +}; + +template +class VarlenDynamicPersistentTileScheduler { + + static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); + static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + NumProducerThreads : NumMmaThreads; + +public: + using SharedStorage = int4; + static constexpr bool pipelining = false; +protected: + SharedStorage* const work_info_smem; + +public: + + // Device side kernel params + struct Params { + int num_head, num_batch; + int const qhead_per_khead; + int const seqlen; + cutlass::FastDivmod head_divmod; + cutlass::FastDivmod nsplits_divmod; + int* const tile_count_semaphore; + int const* const cu_seqlens; + int const* const seqused; + // int* const num_m_blocks_ptr; + int const* const num_splits_dynamic_ptr; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + // If Split, for the purpose of scheduling, we pretend that instead there are + // (args.num_splits * args.num_head) number of heads. 
+ assert(args.tile_count_semaphore != nullptr); + assert(num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx + assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits + return {args.num_head, args.num_batch, + args.qhead_per_khead, args.seqlen, + cutlass::FastDivmod(args.num_head), + cutlass::FastDivmod(!Split ? 1 : args.num_splits), + args.tile_count_semaphore, args.cu_seqlens, args.seqused, + // args.num_m_blocks_ptr, + args.num_splits_dynamic_ptr}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(num_sm)}; + } + + struct WorkTileInfo { + int tile_idx, block, bidh, bidb; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + // if (blockIdx.x >= 0 && (threadIdx.x == 128 || threadIdx.x == 0)) { printf("blockIdx.x = %d, threadIdx.x = %d, checking valid, bidb = %d, params.num_batch = %d\n", blockIdx.x, threadIdx.x, bidb, params.num_batch); } + return bidb < params.num_batch; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + if constexpr (!Split) { + return {block, bidh, bidb, 0 /*split_idx*/}; + } else { + // the top 8 bits of bidh store num_splits and the next 8 bits store split_idx + // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift + uint32_t bidh_packed = reinterpret_cast(bidh); + uint32_t bidh_actual_u = bidh_packed & 0x0000FFFF; + int bidh_actual = reinterpret_cast(bidh_actual_u); + // Use the top 16 bits of split_idx to store num_splits and the next 16 bits to store split_idx + uint32_t split_idx_u = ((bidh_packed & 0x00FF0000) >> 16) + ((bidh_packed & 0xFF000000) >> 8); + int split_idx = reinterpret_cast(split_idx_u); + // int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh); + // if (threadIdx.x == 128) { + // printf("blockIdx.x = %d, bidb = %d, bidh = %d, bidh_actual = %d, split_idx = %d\n", blockIdx.x, bidb, bidh, bidh_actual, split_idx); + // } + 
return {block, bidh_actual, bidb, split_idx}; + } + } + }; + + CUTLASS_DEVICE + VarlenDynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : work_info_smem(smem_scheduler) {}; + + DEFINE_DUMMY_NOTIFY_FUNCS + + CUTLASS_DEVICE + WorkTileInfo + tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkTileInfo const& current_work) const { + int lane = threadIdx.x % cutlass::NumThreadsPerWarp; + auto get_num_m_blocks = [&] (int bidb_start) { + int batch_idx = lane + bidb_start; + int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); + if (seqlen > kBlock) { + if (params.seqused) { + seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; + } else if (params.cu_seqlens) { + int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = params.seqlen; + } + if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } + } + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? cute::ceil_div(seqlen, kBlock) : 0; + // ? params.num_m_blocks_ptr[batch_idx] : 0; + }; + + auto get_num_splits = [&] (int bidb_start) { + int batch_idx = lane + bidb_start; + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? (!Split ? 1 : (params.num_splits_dynamic_ptr + ? params.num_splits_dynamic_ptr[batch_idx] + : params.nsplits_divmod.divisor)) + : 0; + }; + + int num_m_blocks = get_num_m_blocks(current_work.bidb); // Different for each lane + int num_splits = get_num_splits(current_work.bidb); + int num_split_m_blocks = !Split ? 
num_m_blocks : num_m_blocks * num_splits; + // Cumulative number of blocks for the next 31 batches + int num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks); + // Total number of blocks for the next 31 batches + int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); + // Only the lower 16 bits are the actual bidh + int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF); + int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes + if constexpr (Split) { + int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16; + group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/); + } + int bidb = current_work.bidb; + // if (blockIdx.x <= 9 && threadIdx.x == 0) { + // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, cur tile_idx = %d, cur block = %d, cur bidh = %d, num_split_m_blocks = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, current_work.bidb, num_m_blocks, next_tile_idx, current_work.tile_idx, current_work.block, current_bidh, num_split_m_blocks, group_end_tile, m_blocks_in_group); + // } + // if (threadIdx.x == 0 && blockIdx.x == 0) { printf("tile_idx = %d, group_end_tile = %d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d\n", current_work.tile_idx, group_end_tile, num_m_blocks_cumulative, m_blocks_in_group); } + while (group_end_tile <= next_tile_idx) { + bidb += cutlass::NumThreadsPerWarp - 1; + if (bidb >= params.num_batch) { + // if (blockIdx.x <= 9 && threadIdx.x == 0) { + // printf("Returning early, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, 
m_blocks_in_group); + // } + return {next_tile_idx, 0, 0, params.num_batch}; + } + num_m_blocks = get_num_m_blocks(bidb); + num_splits = get_num_splits(bidb); + num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits; + num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks); + m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); + group_end_tile += m_blocks_in_group * params.num_head; + // if (blockIdx.x <= 9 && threadIdx.x == 0) { + // printf("Bottom of while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group); + // } + } + int group_start_tile = group_end_tile - m_blocks_in_group * params.num_head; + // The next problem to process is the first one that does not have ending tile position + // that is greater than or equal to tile index. + int batch_idx_in_group = __popc(__ballot_sync(0xffffffff, group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx)); + // if (threadIdx.x == 31 || threadIdx.x == 0) { printf("blockIdx.x = %d, tidx %d, group_start_tile = %d, num_m_blocks_cumulative = %d, num_head = %d, next_tile_idx = %d, ballot = %x, batch_idx_in_group = %d\n", blockIdx.x, threadIdx.x, group_start_tile, num_m_blocks_cumulative, params.num_head, next_tile_idx, tmp, batch_idx_in_group); } + bidb += batch_idx_in_group; + num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group); + if constexpr (Split) { num_splits = __shfl_sync(0xffffffff, num_splits, batch_idx_in_group); } + int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 
0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; + int bidh = mh_block / num_m_blocks; + int block = mh_block - bidh * num_m_blocks; + if constexpr (Split) { + int bidh_actual = bidh / num_splits; + int split_idx = bidh - bidh_actual * num_splits; + // TODO: idk why this gives wrong answer nondeterministically + // int bidh_actual, split_idx; + // split_idx = params.head_divmod.divmod(bidh_actual, bidh); + // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx + // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift + uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + // if (threadIdx.x == 0) { + // printf("blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d, split_idx = %d, num_splits = %d, bidh_packed = %d\n", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed); + // } + bidh = reinterpret_cast(bidh_packed); + } + // if (blockIdx.x <= 9 && threadIdx.x == 0) { + // printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); + // } + return {next_tile_idx, block, bidh, bidb}; + } + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) const { + if constexpr (IsProducerWarp) { + WorkTileInfo work_info = tile_idx_to_work_tile(params, int(blockIdx.x), {0, 0, 0, 0}); + if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { + *work_info_smem = 
make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); + } + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + return work_info; + } else { + return get_next_work(params, {0, 0, 0, 0}); + } + } + + CUTLASS_DEVICE + void + init_consumer() const { + // Don't arrive at the TileCountSmemEmpty barrier here, because get_initial_work will do that + } + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { + if (threadIdx.x % NumProducerThreads == 0) { + current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); + } + } + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + if constexpr (IsProducerWarp) { + // thread 0 has the next tile_idx, just need to broadcast to the rest of warp 0 + int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); + WorkTileInfo work_info = {__shfl_sync(0xffffffff, current_work.tile_idx, 1 /*lane*/), current_work.block, current_work.bidh, current_work.bidb}; + work_info = tile_idx_to_work_tile(params, new_tile_idx, work_info); + flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { + *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); + } + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + return work_info; + } else { + flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + int4 work_info = *work_info_smem; + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + return WorkTileInfo{work_info.x, work_info.y, work_info.z, work_info.w}; + } + } + + template + CUTLASS_DEVICE + constexpr uint32_t stage() const noexcept { return 0; 
} +}; + +} // flash diff --git a/flashmask/flash_mask/flashmask_attention_v3/tile_size.h b/flashmask/flash_mask/flashmask_attention_v3/tile_size.h new file mode 100644 index 00000000000..54627e677b7 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/tile_size.h @@ -0,0 +1,95 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +#include + +// Return {kBlockM, kBlockN, MmaPV_is_RS, IntraWGOverlap} +constexpr std::tuple tile_size_fwd_sm90( + int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, + bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false, bool short_seqlen=false) { + if (element_size == 2) { + if (short_seqlen) { + return {64, 64, false, true}; + } + if (headdim <= 64) { + bool same_hdim = (headdim == headdim_v); // if not same hdim, we're targeting hdimv=512 + // return {same_hdim ? 192 : 64, same_hdim ? 
128 : 64, same_hdim, same_hdim}; + // With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why + // https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131 + // Switch to tile size 192 x 192 for now + bool const use_blockN_128 = is_causal || is_local; + // return {same_hdim ? 192 : 64, same_hdim ? (use_blockN_128 ? 128 : 192) : 64, same_hdim && use_blockN_128, same_hdim}; + return {192, use_blockN_128 ? 80 : 144, same_hdim && use_blockN_128, same_hdim}; + // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen + // return {192, is_causal || is_local ? 192 : 176, true, false}; + } else if (headdim <= 96) { + return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true}; + } else if (headdim <= 128) { + // return {128, 96, true, true}; + // return {128, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true}; + return {128, 128, true, true}; + // {128, 192, false, false} and {192, 128, false, true} are quite good too + // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS + } else if (headdim <= 192) { + return {128, 96, true, true}; + // return {128, paged_kv_non_TMA || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem + } else { + return {128, 64, true, true}; + // return {128, is_local ? 64 : 80, true, true}; // 128 x 80 hits the limit of smem + } + } else { + if (headdim <= 64) { + return {192, 160, true, true}; + } else if (headdim <= 96) { + return {192, 128, true, true}; + } else if (headdim <= 128) { + return {128, paged_kv_non_TMA ? 160 : (v_colmajor || (softcap && is_local) ? 192 : 224), true, true}; + } else if (headdim <= 192) { + return {128, (paged_kv_non_TMA || softcap) && is_local ? 128 : 160, true, true}; + } else { + return {128, is_local ? 
64 : 128, true, !paged_kv_non_TMA}; // PagedKV uses more registers so we disabled IntraWGOverlap + } + } +} + +// Return {kBlockM, kBlockN, kNWarps, kStages, Q_in_regs} +constexpr std::tuple tile_size_fwd_sm8x( + bool sm86_or_89, int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, + bool paged_kv=false, bool varlen_and_split=false, + bool softcap=false, bool append_kv=false) { + if (element_size == 2) { + if (headdim <= 64) { + return {128, varlen_and_split ? 80 : (is_local ? 96 : 112), 4, 1, false}; + } else if (headdim <= 96) { + return {128, varlen_and_split || is_local ? 48 : 64, 4, 1, false}; + } else if (headdim <= 128) { + bool const use_8_warps = sm86_or_89 | varlen_and_split; + return {128, use_8_warps ? (varlen_and_split ? (is_local ? 96 : 112) : (is_local ? 96 : 128)) : (is_local ? 48 : 64), use_8_warps ? 8 : 4, 1, use_8_warps}; + } else if (headdim <= 192) { + bool const kBlockN_64 = append_kv || is_local || varlen_and_split || paged_kv; + return {128, kBlockN_64 ? 64 : 96, 8, sm86_or_89 ? 1 : 2, !kBlockN_64}; + } else { + return {128, sm86_or_89 ? (append_kv ? 32 : (varlen_and_split || is_local ? 48 : 64)) : (append_kv ? 48 : (varlen_and_split || is_local ? 64 : 96)), 8, 1, sm86_or_89 && !append_kv}; + } + } else { + // Placeholder for now + return {128, 64, 8, 2, false}; + } +} diff --git a/flashmask/flash_mask/flashmask_attention_v3/utils.h b/flashmask/flash_mask/flashmask_attention_v3/utils.h new file mode 100644 index 00000000000..5fe021ddb45 --- /dev/null +++ b/flashmask/flash_mask/flashmask_attention_v3/utils.h @@ -0,0 +1,760 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. + * + * Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include + +#include +#include +#include +#include + +#include "cuda_check.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// A wrapper for the kernel that is used to guard against compilation on +// architectures that will never use the kernel. The purpose of this is to +// reduce the size of the compiled binary. +// Adapted from https://github.com/vllm-project/vllm/blob/4d29e91be84d27ca313d657eee92c067439a4c23/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh#L55 +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +template +struct enable_sm80_to_sm89 : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ <= 890) + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? 
x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ __forceinline__ T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) +template +CUTLASS_DEVICE auto convert_layout_acc_rowcol(Layout0 acc_layout) { + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = acc_layout; + if constexpr (!Transposed) { + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); + } else { + return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), 
make_layout(get<0, 1>(l), get<1>(l))); + } + + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + if constexpr (!Transposed) { + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + } else { + return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. +// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) +// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N)) +template +CUTLASS_DEVICE auto convert_layout_acc_Aregs(Layout0 acc_layout) { + using X = Underscore; + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); + if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) { + auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16)) + return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); + } else { + static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1); + static_assert(decltype(stride<0, 0>(acc_layout))::value == 1); + static_assert(decltype(stride<0, 1>(acc_layout))::value == 2); + auto l = logical_divide(get<0, 2>(acc_layout), Tile>>{}); // 
(((2, 2), N / 32)) + // This combines the first two modes (<0, 0> and <0, 1>) into one mode. + // Will require register shuffling later to be correct. + return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)), + get<1>(acc_layout), + coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N) + // This combination is right but doesn't work with register shuffling. + // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)), + // get<1>(acc_layout), + // coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); + } + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE auto convert_type_unsafe(Tensor const &tensor) { + using From_type = typename Engine::value_type; + static constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast *>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); + // Unsafe because we're returning a tensor with memory allocated on the stack. If the compiler does not + // inline this function, then the memory might not be valid. 
+} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void convert_type_out(Tensor const &tensor, Tensor &out) { + // Somehow if we allocate out inside this function and return it, e2e is slower and the output can be wrong. + using From_type = typename Engine::value_type; + using To_type = typename EngineOut::value_type; + static constexpr int FragmentSize = std::max(sizeof(From_type) / sizeof(To_type), sizeof(To_type) / sizeof(From_type)); + static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly"); + Tensor frag = recast const>(tensor); + Tensor out_frg = recast>(out); + static_assert(size(frag) == size(out_frg)); + cutlass::NumericArrayConverter convert_op; + #pragma unroll + for (int i = 0; i < size(frag); ++i) { out_frg[i] = convert_op(frag[i]); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. 
+// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE +void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE +auto mma_partition_fragment_AB(Mma const& mma, Tensor0 const& tensor0) { + if constexpr (A) { + return mma.partition_fragment_A(tensor0); + } else { + return mma.partition_fragment_B(tensor0); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const& tCrB, Tensor2& tCrC) { + if constexpr (M_slice >= 0) { + static constexpr int MMA_M = decltype(size<1>(tCrC))::value; + static_assert(M_slice < MMA_M); + // After logical_divide, C has shape ((2,2,V), (MMA_M, 1), MMA_N) + Tensor tCrC_slice = cute::logical_divide(tCrC, Shape>{})(_, make_coord(Int{}, _), _); + if constexpr (!SwapAB) { + Tensor tCrA_slice = cute::logical_divide(tCrA, Shape>{})(_, make_coord(Int{}, _), _); + gemm(tiled_mma, tCrA_slice, tCrB, tCrC_slice); + } else { + Tensor tCrB_slice = cute::logical_divide(tCrB, Shape>{})(_, make_coord(Int{}, _), _); + gemm(tiled_mma, tCrA, tCrB_slice, tCrC_slice); + } + } else { + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { + if constexpr (!SwapAB) { + warpgroup_fence_operand(const_cast(tCrA)); + } else { + warpgroup_fence_operand(const_cast(tCrB)); + } + } + warpgroup_fence_operand(tCrC); + warpgroup_arrive(); + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + static constexpr int kNumKIters = CUTE_STATIC_V(size<2>(tCrA)); + static constexpr int kMaxKIters = 16; + // Unroll the K mode manually to set 
scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < std::min(kNumKIters, kMaxKIters); ++k_block) { + if constexpr (!SwapAB) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + } else { + cute::gemm(tiled_mma, tCrB(_,_,k_block), tCrA(_,_,k_block), tCrC); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + // In the case of large kNumKIters, the compiler chooses to store the smem addresses + // in registers, causing spills. This loop forces the compiler to recompute the addresses. + if constexpr (kNumKIters > kMaxKIters) { + // This will always be zero, just a way to force the compiler to recompute the smem + // addresses. This results in USEL instructions. There's probably a better way to do this. + int const k_offset = cutlass::canonical_warp_group_idx() < 128 ? 0 : 1; + CUTLASS_PRAGMA_UNROLL + for (int k_block = kMaxKIters; k_block < kNumKIters; ++k_block) { + if constexpr (!SwapAB) { + cute::gemm(tiled_mma, tCrA(_,_,k_block + k_offset), tCrB(_,_,k_block + k_offset), tCrC); + } else { + cute::gemm(tiled_mma, tCrB(_,_,k_block + k_offset), tCrA(_,_,k_block + k_offset), tCrC); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + warpgroup_commit_batch(); + if constexpr (wg_wait >= 0) { warpgroup_wait(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { + if constexpr (!SwapAB) { + warpgroup_fence_operand(const_cast(tCrA)); + } else { + warpgroup_fence_operand(const_cast(tCrB)); + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void gemm_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B, Hook fn) { + if constexpr (SwapAB) { + gemm_sm80(acc, tCrB, tCrA, tCsB, tCsA, tiled_mma, smem_tiled_copy_B, smem_tiled_copy_A, smem_thr_copy_B, 
smem_thr_copy_A, fn); + } else { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } + } + if constexpr (!std::is_same_v) { + if (i == 0) { fn(); } + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void gemm_rs_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + 
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void gemm_sm100(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + static constexpr int rA = decltype(rank(tA))::value; + static constexpr int rB = decltype(rank(tB))::value; + static constexpr int rC = decltype(rank(tC))::value; + static_assert(rA == 3 && rB == 3 && rC == 3); + + if constexpr (zero_init) { atom.accumulate_ = decltype(atom.accumulate_)::Zero; } + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tA); k_block++) { + cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC); + atom.accumulate_ = decltype(atom.accumulate_)::One; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant>, + TAs...>, TMs...>) { + + return TiledMMA>, + TAs...>, TMs...>{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, + TAs...>, TMs...>) { + return TiledMMA, + TAs...>, TMs...>{}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void copy(TiledCopy const &tiled_copy, Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, const int max_MN=0) { + // Decay TiledCopy to CopyAtom + auto copy_atom = static_cast(tiled_copy); + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && 
!Clear_OOB_K)); + auto has_with_bool = cute::is_valid([](auto t)->void_t().with(true))>{}, copy_atom); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + bool predicate_mn = Is_even_MN || get<0>(identity_MN(_0{}, m, _0{})) < max_MN; + if constexpr (Is_even_MN || !Clear_OOB_MN) { + if (Is_even_MN || predicate_mn) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if constexpr (Is_even_K || !Clear_OOB_K) { + if (Is_even_K || predicate_K(k)) { cute::copy(copy_atom, S(_, m, k), D(_, m, k)); } + } else { // Clear_OOB_K == true && Is_even_K == false + // If copy traits can be transformed with a predicate value, do it, otherwise branch here + if constexpr (has_with_bool) { + cute::copy(copy_atom.with(predicate_K(k)), S(_, m, k), D(_, m, k)); + } else { + if (predicate_K(k)) { + cute::copy(copy_atom, S(_, m, k), D(_, m, k)); + } else { + cute::clear(D(_, m, k)); + } + } + } + } + } + } else { // Clear_OOB_MN == true && Is_even_MN == false, also implies Clear_OOB_K == true + if constexpr (!has_with_bool) { + if (predicate_mn) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(copy_atom, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else { + cute::clear(D(_, m, _)); + } + } else { // combine the mn predicate with the k predicate + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + cute::copy(copy_atom.with(predicate_mn && (Is_even_K || predicate_K(k))), S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Byte permute and shuffle to match register layout of +// (FP8 downcasted) accumulator of GEMM-I to FP8 operand A of GEMM-II. 
+// Permute the A-operand registers of an fp8 MMA so the 8-bit values land in the
+// byte order the tensor core expects. Operates across the 4 threads of a quad
+// via __shfl_sync, then re-packs bytes with __byte_perm.
+template <typename Fragment>
+CUTLASS_DEVICE void permute_Aregs_fp8(Fragment &frag) {
+    // frag has shape ((4, 2, 2), MMA_M, MMA_N), each element is 8 bits
+    static_assert(decltype(size<0, 0>(frag))::value == 4);
+    static_assert(decltype(size<0, 1>(frag))::value == 2);
+    static_assert(decltype(stride<0, 0>(frag))::value == 1);
+    static_assert(decltype(stride<0, 1>(frag))::value == 4);
+    static_assert(sizeof(typename Fragment::value_type) == 1);
+
+    int quad_idx = threadIdx.x % 4;
+    bool lane_03 = quad_idx == 0 || quad_idx == 3;
+    // __byte_perm selectors: lanes 0/3 keep byte order, lanes 1/2 swap halves.
+    int selector_upper = lane_03 ? 0x5410 : 0x1054;
+    int selector_lower = lane_03 ? 0x7632 : 0x3276;
+
+    static constexpr int upper_map[4] = {0, 3, 1, 2};
+    // static constexpr int lower_map[4] = {1, 2, 0, 3};
+
+    // NOTE(review): recast target reconstructed as uint2 (frag_64b[i].x/.y are
+    // 32-bit halves of a 64-bit chunk) -- confirm against upstream utils.h.
+    Tensor frag_64b = recast<uint2>(frag);  // ((1, 1, 2), MMA_M, MMA_N)
+    #pragma unroll
+    for (int i = 0; i < size(frag_64b); ++i) {
+        uint32_t upper = frag_64b[i].x;
+        uint32_t lower = frag_64b[i].y;
+        uint32_t upper0 = lane_03 ? upper : lower;
+        uint32_t lower0 = lane_03 ? lower : upper;
+        upper0 = __shfl_sync(uint32_t(-1), upper0, upper_map[quad_idx], 4);
+        // lower0 = __shfl_sync(uint32_t(-1), lower0, lower_map[quad_idx], 4);
+        lower0 = __shfl_sync(uint32_t(-1), lower0, upper_map[quad_idx] ^ 1, 4);
+        frag_64b[i].x = __byte_perm(upper0, lower0, selector_upper);
+        frag_64b[i].y = __byte_perm(upper0, lower0, selector_lower);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Permute the C (accumulator) registers after an fp8 MMA: swaps paired 64-bit
+// chunks within each thread (no cross-thread traffic, unlike permute_Aregs_fp8).
+template <typename Fragment>
+CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) {
+    // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits
+    static_assert(decltype(size<0, 0>(frag))::value == 2);
+    static_assert(decltype(size<0, 1>(frag))::value == 2);
+    static_assert(decltype(size<0, 2>(frag))::value % 2 == 0);
+    static_assert(decltype(stride<0, 0>(frag))::value == 1);
+    static_assert(sizeof(typename Fragment::value_type) == 4);
+    // NOTE(review): recast target reconstructed as uint2 (two 32-bit elements
+    // viewed as one 64-bit chunk) -- confirm against upstream utils.h.
+    Tensor frag_64b = group_modes<1, 3>(recast<uint2>(frag));  // ((1, 2, N / 8), (MMA_M, MMA_N))
+    #pragma unroll
+    for (int mi = 0; mi < size<1>(frag_64b); ++mi) {
+        #pragma unroll
+        for (int i = 0; i < size<0, 2>(frag_64b) / 2; ++i) {
+            cutlass::swap(frag_64b(make_coord(_0{}, _1{}, 2 * i), mi), frag_64b(make_coord(_0{}, _0{}, 2 * i + 1), mi));
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Undo the fp8 accumulator permutation on the output fragment: swaps element
+// pairs (row 1 of even chunk <-> row 0 of odd chunk) within each thread.
+template <typename Fragment>
+CUTLASS_DEVICE void permute_output_fp8(Fragment &out) {
+    // out has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits
+    static_assert(decltype(size<0, 0>(out))::value == 2);
+    static_assert(decltype(size<0, 1>(out))::value == 2);
+    static_assert(decltype(size<0, 2>(out))::value % 2 == 0);
+    static_assert(decltype(stride<0, 0>(out))::value == 1);
+    static_assert(sizeof(typename Fragment::value_type) == 4);
+    Tensor frag = group_modes<1, 3>(out);  // ((2, 2, N / 8), (MMA_M, MMA_N))
+    #pragma unroll
+    for (int mi = 0; mi < size<1>(frag); ++mi) {
+        #pragma unroll
+        for (int j = 0; j < size<0, 1>(frag); ++j) {
+            #pragma unroll
+            for (int i = 0; i < size<0, 2>(frag) / 2; ++i) {
+                cutlass::swap(frag(make_coord(_1{}, j, 2 * i), mi), frag(make_coord(_0{}, j, 2 * i + 1), mi));
+            }
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Output permutation for the column-major-V fp8 path: exchanges paired values
+// across the quad with __shfl_sync (different lane map than permute_Aregs_fp8).
+template <typename Fragment>
+CUTLASS_DEVICE void permute_output_fp8_Vcolmajor(Fragment &frag) {
+    // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 16 bits
+    static_assert(decltype(size<0, 0>(frag))::value == 2);
+    static_assert(decltype(size<0, 1>(frag))::value == 2);
+    static_assert(decltype(stride<0, 0>(frag))::value == 1);
+    static_assert(sizeof(typename Fragment::value_type) == 2 || sizeof(typename Fragment::value_type) == 4);
+
+    int quad_idx = threadIdx.x % 4;
+    bool lane_03 = quad_idx == 0 || quad_idx == 3;
+
+    static constexpr int upper_map[4] = {0, 2, 3, 1};
+    // static constexpr int lower_map[4] = {2, 0, 1, 3};
+
+    // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(frag); }
+    // NOTE(review): conditional_t arguments reconstructed as uint32_t/uint64_t
+    // (pack two 16- or 32-bit elements into one shuffle word) -- confirm.
+    using type2 = std::conditional_t<sizeof(typename Fragment::value_type) == 2, uint32_t, uint64_t>;
+    Tensor frag_2 = group_modes<1, 3>(recast<type2>(frag));  // ((1, 2, N / 8), (MMA_M, MMA_N))
+    // if (blockIdx.x == 0 && threadIdx.x == 128) { print(frag); printf("\n"); print(frag_2); }
+    #pragma unroll
+    for (int mi = 0; mi < size<1>(frag_2); ++mi) {
+        #pragma unroll
+        for (int j = 0; j < size<0, 1>(frag_2); ++j) {
+            #pragma unroll
+            for (int i = 0; i < size<0, 2>(frag_2) / 2; ++i) {
+                type2 upper = frag_2(make_coord(_0{}, j, 2 * i), mi);
+                type2 lower = frag_2(make_coord(_0{}, j, 2 * i + 1), mi);
+                type2 upper0 = lane_03 ? upper : lower;
+                type2 lower0 = lane_03 ? lower : upper;
+                upper0 = __shfl_sync(uint32_t(-1), upper0, upper_map[quad_idx], 4);
+                // lower0 = __shfl_sync(uint32_t(-1), lower0, lower_map[quad_idx], 4);
+                lower0 = __shfl_sync(uint32_t(-1), lower0, upper_map[quad_idx] ^ 2, 4);
+                frag_2(make_coord(_0{}, j, 2 * i), mi) = lane_03 ? upper0 : lower0;
+                frag_2(make_coord(_0{}, j, 2 * i + 1), mi) = lane_03 ? lower0 : upper0;
+            }
+        }
+    }
+    // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(frag); }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// In-place elementwise softcap: tensor(i) = tanh(tensor(i) * softcap).
+template <typename Engine, typename Layout>
+CUTLASS_DEVICE void apply_softcap(Tensor<Engine, Layout> &tensor, float const softcap) {
+    #pragma unroll
+    for (int i = 0; i < size(tensor); ++i) {
+        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
+    }
+}
+
+// Given t = tanh(x), returns a fresh fragment holding dtanh/dx = 1 - t^2.
+template <typename Engine, typename Layout>
+CUTLASS_DEVICE auto calculate_dtanh(Tensor<Engine, Layout> &tensor) {
+    Tensor out = make_fragment_like(tensor);
+    #pragma unroll
+    for (int i = 0; i < size(tensor); ++i) {
+        out(i) = 1.f - (tensor(i) * tensor(i));
+    }
+    return out;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Inclusive prefix sum over the calling warp (Hillis-Steele via __shfl_up_sync).
+template <typename T>
+CUTE_DEVICE T warp_prefix_sum(T val) {
+    int lane = threadIdx.x % cutlass::NumThreadsPerWarp;
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 1; i < cutlass::NumThreadsPerWarp; i <<= 1) {
+        T partial_sum = __shfl_up_sync(0xffffffff, val, i);
+        if (lane >= i) { val += partial_sum; }
+    }
+    return val;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Broadcast lane 0's value to the whole warp so the compiler can treat it as
+// warp-uniform.
+template <typename T>
+CUTE_DEVICE T warp_uniform(T a) {
+    return __shfl_sync(0xffffffff, a, 0);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Warp-group index of the calling thread; no synchronization performed.
+CUTLASS_DEVICE
+int canonical_warp_group_idx_nosync() {
+    return threadIdx.x / cutlass::NumThreadsPerWarpGroup;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Generic CUTLASS kernel template.
+template <typename Operator>
+CUTLASS_GLOBAL
+#ifdef __CUDACC__
+// Enclosing this in __CUDACC__ suppresses MSVC warnings.
+__launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor)
+#endif // __CUDACC__
+void cutlass_flashmask_kernel(CUTLASS_GRID_CONSTANT typename Operator::Params const params)
+{
+    // Dynamic shared memory base pointer. FA3 TMA needs the base aligned to
+    // 1024 bytes when using CU_TENSOR_MAP_SWIZZLE_128B.
+    extern __shared__ __align__(1024) char smem[];
+    Operator op;
+    op(params, smem);
+    cutlass::arch::synclog_print();
+}
+
+// Launches cutlass_flashmask_kernel<GemmKernel>, either as a plain <<<>>>
+// launch or via cudaLaunchKernelEx with programmatic dependent launch (PDL).
+// Returns kSuccess, kInvalid (PDL unsupported), or kErrorInternal.
+template <typename GemmKernel, typename Params>
+cutlass::Status flashmask_kernel_launch(
+    dim3 const grid_dims,
+    dim3 const block_dims,
+    size_t const smem_size,
+    cudaStream_t cuda_stream,
+    const Params &kernel_params,
+    bool launch_with_pdl) {
+#if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
+  CUTLASS_TRACE_HOST("cutlass::kernel_launch");
+#endif
+
+  if (not launch_with_pdl) {
+#if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
+    CUTLASS_TRACE_HOST("cutlass::kernel_launch: No PDL");
+#endif
+    cutlass_flashmask_kernel<GemmKernel><<<grid_dims, block_dims, smem_size, cuda_stream>>>(kernel_params);
+  }
+  else {
+#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)))
+    // PDL requires SM90; reject earlier architectures at compile time.
+    if constexpr (GemmKernel::ArchTag::kMinComputeCapability < 90) {
+      CUTLASS_TRACE_HOST("  Programmatic dependent launch (PDL) is only supported for SM90.");
+      return cutlass::Status::kInvalid;
+    }
+
+    cudaLaunchConfig_t config;
+    cudaLaunchAttribute attrs[1];
+
+    config.gridDim = grid_dims;
+    config.blockDim = block_dims;
+    config.dynamicSmemBytes = smem_size;
+    config.stream = cuda_stream;
+
+    config.attrs = attrs;
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attrs[0].val.programmaticStreamSerializationAllowed = 1;
+    config.numAttrs = 1;
+
+#if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
+    CUTLASS_TRACE_HOST("cutlass::kernel_launch: Calling cudaLaunchKernelEx");
+#endif
+    cudaError_t launch_result = cudaLaunchKernelEx(&config, &cutlass_flashmask_kernel<GemmKernel>, kernel_params);
+    if (cudaSuccess != launch_result) {
+      CUTLASS_TRACE_HOST("cutlass::kernel_launch: cudaLaunchKernelEx failed with error: " << cudaGetErrorString(launch_result));
+      return cutlass::Status::kErrorInternal;
+    }
+#else
+    CUTLASS_TRACE_HOST("  Programmatic dependent launch (PDL) is only supported starting CUDA 11.8.");
+    return cutlass::Status::kInvalid;
+#endif
+  }
+
+  cudaError_t result = cudaGetLastError();
+  if (cudaSuccess == result) {
+#if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
+    CUTLASS_TRACE_HOST("cutlass::kernel_launch: cudaGetLastError reports success");
+#endif
+    return cutlass::Status::kSuccess;
+  }
+  else {
+    CUTLASS_TRACE_HOST("  Kernel launch failed. Reason: " << result);
+    return cutlass::Status::kErrorInternal;
+  }
+}
+
+}  // namespace flash
diff --git a/flashmask/flash_mask/test/test.py b/flashmask/flash_mask/test/test.py
new file mode 100644
index 00000000000..767f9901696
--- /dev/null
+++ b/flashmask/flash_mask/test/test.py
@@ -0,0 +1,41 @@
+import sys
+import os
+
+# Add the current directory to sys.path (kept for reference)
+# sys.path.append('/workspace/paddle/flash-attention/flashmask/flash_mask')
+
+print("="*60)
+print("Testing FlashMask Import")
+print("="*60)
+
+from flash_mask.flashmask_attention_v3.interface import flashmask_attention
+
+try:
+    import paddle
+
+    # Small batch of test data
+    query = paddle.randn([1, 16, 2, 32], dtype='float16')
+    key = paddle.randn([1, 16, 2, 32], dtype='float16')
+    value = paddle.randn([1, 16, 2, 32], dtype='float16')
+    mask = paddle.to_tensor([8]*16, dtype='int32').reshape([1, 1, 16, 1])
+
+    result = flashmask_attention(
+        query=query,
+        key=key,
+        value=value,
+        startend_row_indices=mask,
+        causal=True
+    )
+
+    print(f"✓ Function executed successfully!")
+    print(f"  Result shape: {result.shape}")
+    print(f"  Result dtype: {result.dtype}")
+
+except Exception as e:
+    print(f"✗ Function execution failed: {e}")
+    import traceback
+    traceback.print_exc()
+
+print("\n" + "="*60)
+print("Test Complete!")
+print("="*60)
\ No newline at end of file
diff --git a/flashmask/setup.py b/flashmask/setup.py
index 28c4c7ac4d6..ec9ca434313 100644
--- a/flashmask/setup.py
+++ b/flashmask/setup.py
@@ -12,18 +12,155 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from setuptools import setup, find_packages
+# Integrates the cmake build into setup.py so `pip install .` does everything;
+# no manual `mkdir build && cmake .. && make` is needed.
+#
+# Usage:
+#   pip install .                          # standard install (runs cmake + compile)
+#   pip install -e . --no-build-isolation  # development install
+#
+# Environment variables:
+#   FLASH_MASK_SKIP_CMAKE=1     skip cmake (assumes libflashmaskv3.so exists)
+#   FLASH_MASK_FORCE_REBUILD=1  force a cmake rebuild
+#   FLASH_MASK_CMAKE_ARGS       extra cmake arguments (space separated)
+#   FLASH_MASK_LIB_DIR          directory containing libflashmaskv2/v3.so
+
+import os
+import sys
+import subprocess
+import shutil
+
+from setuptools import find_packages
+from paddle.utils.cpp_extension import CUDAExtension, setup
+
+# ============================================================
+# Configuration
+# ============================================================
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+FLASH_MASK_DIR = os.path.join(ROOT_DIR, 'flash_mask')
+BUILD_DIR = os.path.join(FLASH_MASK_DIR, 'build')
+# libflashmaskv3.so is placed under flash_mask/lib/ and shipped via package_data
+INSTALL_LIB_DIR = os.path.join(FLASH_MASK_DIR, 'lib')
+
+SKIP_CMAKE = os.environ.get('FLASH_MASK_SKIP_CMAKE', '0') == '1'
+FORCE_REBUILD = os.environ.get('FLASH_MASK_FORCE_REBUILD', '0') == '1'
+EXTRA_CMAKE_ARGS = os.environ.get('FLASH_MASK_CMAKE_ARGS', '').split()
+MANUAL_LIB_DIR = os.environ.get('FLASH_MASK_LIB_DIR', '')
+
+# ============================================================
+# Step 1: build libflashmaskv3.so (the low-level CUDA kernel library)
+# ============================================================
+LIB_NAME = 'flashmaskv3'
+LIB_FILE = f'lib{LIB_NAME}.so'
+
+def find_or_build_lib():
+    """Locate or build libflashmaskv3.so; returns (lib_dir, lib_name)."""
+    global LIB_NAME
+
+    if MANUAL_LIB_DIR:
+        lib_dir = os.path.abspath(MANUAL_LIB_DIR)
+        if os.path.exists(os.path.join(lib_dir, 'libflashmaskv3.so')):
+            return lib_dir, 'flashmaskv3'
+        else:
+            print(f"[WARNING] No flashmask lib found in {lib_dir}")
+            return lib_dir, 'flashmaskv3'
+
+    if SKIP_CMAKE:
+        return BUILD_DIR, 'flashmaskv3'
+
+    # Automatic cmake build
+    lib_so_path = os.path.join(BUILD_DIR, LIB_FILE)
+    need_build = FORCE_REBUILD or not os.path.exists(lib_so_path)
+
+    if need_build:
+        print("=" * 60)
+        print(f"Building {LIB_FILE} via cmake...")
+        print("=" * 60)
+
+        os.makedirs(BUILD_DIR, exist_ok=True)
+        cmake_args = [
+            'cmake', '..',
+            '-DWITH_FLASHATTN_V3=ON',
+            '-DDISABLE_FLASHMASK_V3_BACKWARD=OFF',
+        ] + EXTRA_CMAKE_ARGS
+
+        print(f"  cmake args: {' '.join(cmake_args)}")
+        subprocess.check_call(cmake_args, cwd=BUILD_DIR)
+
+        nproc = os.cpu_count() or 4
+        print(f"  make -j{nproc}")
+        subprocess.check_call(['make', f'-j{nproc}'], cwd=BUILD_DIR)
+
+        if not os.path.exists(lib_so_path):
+            raise RuntimeError(f"cmake build completed but {LIB_FILE} not found at {lib_so_path}")
+        print(f"  {LIB_FILE} built successfully")
+    else:
+        print(f"  {LIB_FILE} already exists, skipping cmake (set FLASH_MASK_FORCE_REBUILD=1 to force)")
+
+    return BUILD_DIR, 'flashmaskv3'
+
+LIB_DIR, LIB_NAME = find_or_build_lib()
+LIB_DIR = os.path.abspath(LIB_DIR)
+LIB_FILE = f'lib{LIB_NAME}.so'
+
+# Copy libflashmaskv3.so into flash_mask/lib/ (only when newer than the copy)
+os.makedirs(INSTALL_LIB_DIR, exist_ok=True)
+src_lib = os.path.join(LIB_DIR, LIB_FILE)
+dst_lib = os.path.join(INSTALL_LIB_DIR, LIB_FILE)
+if os.path.exists(src_lib) and (not os.path.exists(dst_lib) or
+                                os.path.getmtime(src_lib) > os.path.getmtime(dst_lib)):
+    shutil.copy2(src_lib, dst_lib)
+    print(f"  Copied {LIB_FILE} -> flash_mask/lib/")
+# ============================================================
+# Step 2: build the custom operator
+# ============================================================
 setup(
     name='flash_mask',
     version='4.0',
     packages=find_packages(),
+    package_data={
+        'flash_mask': ['lib/*.so'],
+    },
     author='PaddlePaddle',
     description='FlashMask: Efficient and Rich Mask Extension of FlashAttention',
     install_requires=[
-        'nvidia-cutlass==4.2.0.0',
-        'nvidia-cutlass-dsl==4.3.0',
         'typing_extensions',
     ],
     python_requires='>=3.10',
+    ext_modules=[
+        CUDAExtension(
+            name='flash_mask_package',
+            sources=[
+                'flash_mask/flashmask_attention_v3/csrc/flashmask_v3.cpp',
+                'flash_mask/flashmask_attention_v3/csrc/flashmask_v3_kernel.cu',
+                'flash_mask/flashmask_attention_v3/csrc/flashmask_v3_grad_kernel.cu',
+                'flash_mask/flashmask_attention_v3/csrc/flash_attn_v3_utils.cu',
+            ],
+            include_dirs=[
+                'flash_mask/flashmask_attention_v3/csrc',
+                'flash_mask/flashmask_attention_v3',
+            ],
+            library_dirs=[LIB_DIR, INSTALL_LIB_DIR],
+            libraries=[LIB_NAME],
+            extra_compile_args={
+                'nvcc': [
+                    '-gencode', 'arch=compute_90,code=sm_90',
+                    '-O3',
+                    '-DPADDLE_WITH_FLASHATTN_V3=1',
+                    '-std=c++17',
+                ],
+                'cxx': [
+                    '-O3',
+                    '-DPADDLE_WITH_FLASHATTN_V3=1',
+                    '-std=c++17'],
+            },
+            extra_link_args=[
+                '-Wl,-rpath,$ORIGIN/flash_mask/lib',
+                '-Wl,-rpath,$ORIGIN',
+                f'-Wl,-rpath,{INSTALL_LIB_DIR}',
+                f'-Wl,-rpath,{LIB_DIR}',
+            ],
+        )
+    ],
 )