diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 383aeab52c6..693f66d4d3a 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -1567,7 +1567,8 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params } if (true/*Is_flashmask*/) { - if (tidx < kBlockN) { + const int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (tidx < kBlockN && (n_block != n_block_max - 1 || binfo.actual_seqlen_k % kBlockN == 0 || tidx < binfo.actual_seqlen_k % kBlockN)) { sFlashMaskLTStart(tidx) = gFlashMaskLTStart(tidx); if(!Is_causal) { sFlashMaskUTEnd(tidx) = gFlashMaskUTEnd(tidx); diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 444d04bbab0..711cf1bc676 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -1348,7 +1348,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_flashmask(const Params &p (((m_block + 1) * kBlockM > gFlashMaskLTStartMin[n_block] && (!flashmask_ut_has_start || m_block * kBlockM < gFlashMaskLTEndMax[n_block])) || (m_block * kBlockM < gFlashMaskUTEndMax[n_block] && (!flashmask_ut_has_start || (m_block + 1) * kBlockM > gFlashMaskUTStartMin[n_block]))))) { - if (tidx < kBlockN) { + if (tidx < kBlockN && (n_block != n_block_max - 1 || binfo.actual_seqlen_k % kBlockN == 0 || tidx < binfo.actual_seqlen_k % kBlockN)) { sFlashMaskLTStart(tidx) = gFlashMaskLTStart(tidx); sFlashMaskUTEnd(tidx) = gFlashMaskUTEnd(tidx); if(flashmask_ut_has_start) { @@ -1415,7 +1415,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_flashmask(const Params &p if (true/*Is_flashmask*/ && (!enable_mask_bypass || (m_block + 1) * kBlockM > gFlashMaskLTStartMin[n_block] && (!flashmask_lt_has_end || m_block * kBlockM < gFlashMaskLTEndMax[n_block]))) { - if (tidx < kBlockN) { + if (tidx < kBlockN && (n_block != n_block_max - 1 || binfo.actual_seqlen_k % 
kBlockN == 0 || tidx < binfo.actual_seqlen_k % kBlockN)) { sFlashMaskLTStart(tidx) = gFlashMaskLTStart(tidx); if(flashmask_lt_has_end) { sFlashMaskLTEnd(tidx) = gFlashMaskLTEnd(tidx); @@ -1594,7 +1594,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_flashmask(const Params &p (((m_block + 1) * kBlockM > gFlashMaskLTStartMin[n_block] && (!flashmask_ut_has_start || m_block * kBlockM < gFlashMaskLTEndMax[n_block])) || (m_block * kBlockM < gFlashMaskUTEndMax[n_block] && (!flashmask_ut_has_start || (m_block + 1) * kBlockM > gFlashMaskUTStartMin[n_block]))))) { - if (tidx < kBlockN) { + if (tidx < kBlockN && (n_block != n_block_max - 1 || binfo.actual_seqlen_k % kBlockN == 0 || tidx < binfo.actual_seqlen_k % kBlockN)) { sFlashMaskLTStart(tidx) = gFlashMaskLTStart(tidx); sFlashMaskUTEnd(tidx) = gFlashMaskUTEnd(tidx); if(flashmask_ut_has_start) { @@ -1646,7 +1646,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_flashmask(const Params &p } else if (Is_causal && true/*Is_flashmask*/ && (!enable_mask_bypass || (m_block + 1) * kBlockM > gFlashMaskLTStartMin[n_block] && (!flashmask_lt_has_end || m_block * kBlockM < gFlashMaskLTEndMax[n_block]))) { - if (tidx < kBlockN) { + if (tidx < kBlockN && (n_block != n_block_max - 1 || binfo.actual_seqlen_k % kBlockN == 0 || tidx < binfo.actual_seqlen_k % kBlockN)) { sFlashMaskLTStart(tidx) = gFlashMaskLTStart(tidx); if(flashmask_lt_has_end) { sFlashMaskLTEnd(tidx) = gFlashMaskLTEnd(tidx);