From 3450d89830017e7b5f20b3bb28553736137923e6 Mon Sep 17 00:00:00 2001 From: wangboyun Date: Mon, 15 Apr 2024 12:03:04 +0800 Subject: [PATCH 1/2] var mask fix --- csrc/flash_attn/src/flash_fwd_kernel.h | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 0bdfeab7981..bac51c97afd 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -175,7 +175,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; - const uint64_t row_offset_mask = (uint64_t)((bidb * params.mask_head_mod_size + (bidh % params.mask_head_mod_size)) * params.mask_seq_q_mod_size + (m_block * kBlockM % params.mask_seq_q_mod_size)) * params.seqlen_k @@ -392,9 +391,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); if (Is_attn_mask) { - flash::apply_attn_mask(scores, tPgMask, tPcMask, - m_block == m_block_max - 1 ? m_residue : params.seqlen_q, - n_block == n_block_max - 1 ? 
n_residue : params.seqlen_k, + flash::apply_attn_mask(scores, tPgMask, tPcMask, + params.seqlen_q, + params.seqlen_k, params.unscale_softmax); tPgMask.data() = tPgMask.data() + (-kBlockN); } From c450e752a72fcf0677c2f0f0ac00bf78dc53daea Mon Sep 17 00:00:00 2001 From: wangboyun Date: Mon, 15 Apr 2024 18:01:46 +0800 Subject: [PATCH 2/2] refine --- csrc/flash_attn/src/flash_fwd_kernel.h | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index bac51c97afd..9726af9d5f7 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -146,11 +146,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; - // umiswing: residue is for predication of additional mask gmem access. - // Additional mask for varlen qkv is supported, but a varlen mask is not supported. - const int m_residue = params.seqlen_q % kBlockM ? params.seqlen_q % kBlockM : kBlockM; - const int n_residue = params.seqlen_k % kBlockN ? params.seqlen_k % kBlockN : kBlockN; - const int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); @@ -518,9 +513,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); if (Is_attn_mask) { - flash::apply_attn_mask(scores, tPgMask, tPcMask, - m_block == m_block_max - 1 ? m_residue : params.seqlen_q, - n_block == n_block_max - 1 ? n_residue : params.seqlen_k, + flash::apply_attn_mask(scores, tPgMask, tPcMask, + params.seqlen_q, + params.seqlen_k, params.unscale_softmax); tPgMask.data() = tPgMask.data() + (-kBlockN); }