117 changes: 80 additions & 37 deletions csrc/flash_attn/src/flash_fwd_kernel.h
@@ -1080,9 +1080,86 @@ __forceinline__ __device__ void compute_attn_1rowblock_flashmask(const Params &p
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
if (Is_causal) {
n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
}

const index_t row_offset_sparse_mask = (bidb * params.h_sparsemask + bidh / params.h_h_sparsemask_ratio) * params.seqlen_k + (n_block_max - 1) * kBlockN;
const index_t row_offset_sparsemask_nblock =
(bidb * params.h_sparsemask + bidh / params.h_h_sparsemask_ratio) * cute::ceil_div(params.seqlen_k, kBlockN);
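// Layout note: each flashmask vector is [batch, num_mask_heads, seqlen_k], so the
// first offset selects this (batch, kv-head) row and points the block views at the
// last kBlockN-wide column tile; the second offset indexes the per-column-block
// min/max auxiliaries, which hold one value per kBlockN tile of the row.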
Tensor gFlashMaskLTStart = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.flashmask_downstart_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});
Tensor gFlashMaskLTEnd = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.flashmask_downend_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});
Tensor gFlashMaskUTStart = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.flashmask_upstart_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});
Tensor gFlashMaskUTEnd = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.flashmask_upend_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});
const int* gFlashMaskLTStartMax = reinterpret_cast<int32_t*>(params.flashmask_downstart_nblockmax) + row_offset_sparsemask_nblock;
const int* gFlashMaskLTStartMin = reinterpret_cast<int32_t*>(params.flashmask_downstart_nblockmin) + row_offset_sparsemask_nblock;
const int* gFlashMaskLTEndMax = reinterpret_cast<int32_t*>(params.flashmask_downend_nblockmax) + row_offset_sparsemask_nblock;
const int* gFlashMaskLTEndMin = reinterpret_cast<int32_t*>(params.flashmask_downend_nblockmin) + row_offset_sparsemask_nblock;
const int* gFlashMaskUTStartMax = reinterpret_cast<int32_t*>(params.flashmask_upstart_nblockmax) + row_offset_sparsemask_nblock;
const int* gFlashMaskUTStartMin = reinterpret_cast<int32_t*>(params.flashmask_upstart_nblockmin) + row_offset_sparsemask_nblock;
const int* gFlashMaskUTEndMax = reinterpret_cast<int32_t*>(params.flashmask_upend_nblockmax) + row_offset_sparsemask_nblock;
const int* gFlashMaskUTEndMin = reinterpret_cast<int32_t*>(params.flashmask_upend_nblockmin) + row_offset_sparsemask_nblock;
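// A minimal host-side sketch (not part of this diff; names and launcher placement
// are illustrative assumptions) of how the per-block reduction buffers consumed
// above could be produced from a row-level vector such as `downstart`:
//
//   for (int j = 0; j < cute::ceil_div(seqlen_k, kBlockN); ++j) {
//     const int lo = j * kBlockN, hi = std::min(lo + kBlockN, seqlen_k);
//     nblockmax[j] = *std::max_element(downstart + lo, downstart + hi);
//     nblockmin[j] = *std::min_element(downstart + lo, downstart + hi);
//   }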

const bool enable_mask_bypass = params.enable_mask_bypass;
const bool flashmask_lt_has_end = params.flashmask_downend_ptr != nullptr;
const bool flashmask_ut_has_start = params.flashmask_upstart_ptr != nullptr;

#define SPARSE_MASKED_DOWN(N_BLOCK) \
(((m_block * kBlockM) >= gFlashMaskLTStartMax[(N_BLOCK)]) && (!flashmask_lt_has_end || (m_block + 1) * kBlockM <= gFlashMaskLTEndMin[(N_BLOCK)]))

#define SPARSE_MASKED_UP(N_BLOCK) \
(!Is_causal && (m_block + 1) * kBlockM <= gFlashMaskUTEndMin[(N_BLOCK)] && (!flashmask_ut_has_start || m_block * kBlockM >= gFlashMaskUTStartMax[(N_BLOCK)]))

#define SPARSE_MASKED(N_BLOCK) \
(SPARSE_MASKED_DOWN(N_BLOCK) || SPARSE_MASKED_UP(N_BLOCK))
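// Worked example (assuming kBlockM == kBlockN == 128): for m_block == 2, rows
// 256..383 are in flight. SPARSE_MASKED_DOWN(n) holds when even the largest
// lower-triangle start in column block n is <= 256 (every row of the tile is at or
// past the mask start) and, when end vectors are present, the smallest end is
// >= 384 (no row has left the masked band). Because both bounds are conservative
// per-block reductions, skipping the whole kBlockM x kBlockN tile is safe.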

for (--n_block_max; n_block_max >= 0; --n_block_max) {
if (enable_mask_bypass && SPARSE_MASKED(n_block_max)) { // this kernel is always the flashmask path; n_block_max >= 0 is guaranteed by the loop condition
gFlashMaskLTStart.data() = gFlashMaskLTStart.data() + (-kBlockN);
gFlashMaskLTEnd.data() = gFlashMaskLTEnd.data() + (-kBlockN);
if (!Is_causal) {
gFlashMaskUTEnd.data() = gFlashMaskUTEnd.data() + (-kBlockN);
gFlashMaskUTStart.data() = gFlashMaskUTStart.data() + (-kBlockN);
}
continue;
} else {
n_block_max++;
break;
}
}
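// Postcondition: n_block_max is one past the highest column block that still needs
// computation. If every block in this row was bypassed, the scan ran off the left
// edge and n_block_max is now <= 0, which the early exit below handles.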

if (n_block_max <= 0) {
// Clear the O block when the whole row is skipped; otherwise the corresponding output elements would be left uninitialized.
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.o_row_stride, _1{}));

typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);

Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor tOrO = make_tensor<Element>(shape(tOgO));
cute::clear(tOrO);
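// tOrO is a zero-filled register fragment; the predicated copy below streams those
// zeros to gmem for exactly the in-bounds rows this CTA owns, leaving neighboring
// blocks untouched.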

// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(kBlockM, kHeadDim)); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);

return;
}

// We iterate over the blocks in reverse order. This is because the last block is the only one
@@ -1104,9 +1181,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_flashmask(const Params &p
+ (m_block * kBlockM % params.mask_seq_q_mod_size)) * params.seqlen_k
+ (n_block_max - 1) * kBlockN;

const index_t row_offset_sparse_mask = (bidb * params.h_sparsemask + bidh / params.h_h_sparsemask_ratio) * params.seqlen_k + (n_block_max - 1) * kBlockN;
const index_t row_offset_sparsemask_nblock =
(bidb * params.h_sparsemask + bidh / params.h_h_sparsemask_ratio) * cute::ceil_div(params.seqlen_k, kBlockN);

Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -1121,28 +1195,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_flashmask(const Params &p
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.seqlen_k_rounded, _1{}));

Tensor gFlashMaskLTStart = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.flashmask_downstart_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});
Tensor gFlashMaskLTEnd = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.flashmask_downend_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});
Tensor gFlashMaskUTStart = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.flashmask_upstart_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});
Tensor gFlashMaskUTEnd = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.flashmask_upend_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});
const int* gFlashMaskLTStartMax = reinterpret_cast<int32_t*>(params.flashmask_downstart_nblockmax) + row_offset_sparsemask_nblock;
const int* gFlashMaskLTStartMin = reinterpret_cast<int32_t*>(params.flashmask_downstart_nblockmin) + row_offset_sparsemask_nblock;
const int* gFlashMaskLTEndMax = reinterpret_cast<int32_t*>(params.flashmask_downend_nblockmax) + row_offset_sparsemask_nblock;
const int* gFlashMaskLTEndMin = reinterpret_cast<int32_t*>(params.flashmask_downend_nblockmin) + row_offset_sparsemask_nblock;
const int* gFlashMaskUTStartMax = reinterpret_cast<int32_t*>(params.flashmask_upstart_nblockmax) + row_offset_sparsemask_nblock;
const int* gFlashMaskUTStartMin = reinterpret_cast<int32_t*>(params.flashmask_upstart_nblockmin) + row_offset_sparsemask_nblock;
const int* gFlashMaskUTEndMax = reinterpret_cast<int32_t*>(params.flashmask_upend_nblockmax) + row_offset_sparsemask_nblock;
const int* gFlashMaskUTEndMin = reinterpret_cast<int32_t*>(params.flashmask_upend_nblockmin) + row_offset_sparsemask_nblock;


const bool enable_mask_bypass = params.enable_mask_bypass;
const bool flashmask_lt_has_end = params.flashmask_downend_ptr != nullptr;
const bool flashmask_ut_has_start = params.flashmask_upstart_ptr != nullptr;

Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQ{});
// Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
@@ -1299,15 +1351,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_flashmask(const Params &p
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration.

#define SPARSE_MASKED_DOWN(N_BLOCK) \
(((m_block * kBlockM) >= gFlashMaskLTStartMax[(N_BLOCK)]) && (!flashmask_lt_has_end || (m_block + 1) * kBlockM <= gFlashMaskLTEndMin[(N_BLOCK)]))

#define SPARSE_MASKED_UP(N_BLOCK) \
(!Is_causal && (m_block + 1) * kBlockM <= gFlashMaskUTEndMin[(N_BLOCK)] && (!flashmask_ut_has_start || m_block * kBlockM >= gFlashMaskUTStartMax[(N_BLOCK)]))

#define SPARSE_MASKED(N_BLOCK) \
(SPARSE_MASKED_DOWN(N_BLOCK) || SPARSE_MASKED_UP(N_BLOCK))

constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {