Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion csrc/flash_attn/src/flash_bwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -1567,7 +1567,8 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
}

if (true/*Is_flashmask*/) {
if (tidx < kBlockN) {
const int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
if (tidx < kBlockN && (n_block != n_block_max - 1 || binfo.actual_seqlen_k % kBlockN == 0 || tidx < binfo.actual_seqlen_k % kBlockN)) {
sFlashMaskLTStart(tidx) = gFlashMaskLTStart(tidx);
if(!Is_causal) {
sFlashMaskUTEnd(tidx) = gFlashMaskUTEnd(tidx);
Expand Down
8 changes: 4 additions & 4 deletions csrc/flash_attn/src/flash_fwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -1348,7 +1348,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_flashmask(const Params &p
(((m_block + 1) * kBlockM > gFlashMaskLTStartMin[n_block] && (!flashmask_ut_has_start || m_block * kBlockM < gFlashMaskLTEndMax[n_block])) ||
(m_block * kBlockM < gFlashMaskUTEndMax[n_block] && (!flashmask_ut_has_start || (m_block + 1) * kBlockM > gFlashMaskUTStartMin[n_block]))))) {

if (tidx < kBlockN) {
if (tidx < kBlockN && (n_block != n_block_max - 1 || binfo.actual_seqlen_k % kBlockN == 0 || tidx < binfo.actual_seqlen_k % kBlockN)) {
sFlashMaskLTStart(tidx) = gFlashMaskLTStart(tidx);
sFlashMaskUTEnd(tidx) = gFlashMaskUTEnd(tidx);
if(flashmask_ut_has_start) {
Expand Down Expand Up @@ -1415,7 +1415,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_flashmask(const Params &p
if (true/*Is_flashmask*/ && (!enable_mask_bypass ||
(m_block + 1) * kBlockM > gFlashMaskLTStartMin[n_block] && (!flashmask_lt_has_end || m_block * kBlockM < gFlashMaskLTEndMax[n_block]))) {

if (tidx < kBlockN) {
if (tidx < kBlockN && (n_block != n_block_max - 1 || binfo.actual_seqlen_k % kBlockN == 0 || tidx < binfo.actual_seqlen_k % kBlockN)) {
sFlashMaskLTStart(tidx) = gFlashMaskLTStart(tidx);
if(flashmask_lt_has_end) {
sFlashMaskLTEnd(tidx) = gFlashMaskLTEnd(tidx);
Expand Down Expand Up @@ -1594,7 +1594,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_flashmask(const Params &p
(((m_block + 1) * kBlockM > gFlashMaskLTStartMin[n_block] && (!flashmask_ut_has_start || m_block * kBlockM < gFlashMaskLTEndMax[n_block])) ||
(m_block * kBlockM < gFlashMaskUTEndMax[n_block] && (!flashmask_ut_has_start || (m_block + 1) * kBlockM > gFlashMaskUTStartMin[n_block]))))) {

if (tidx < kBlockN) {
if (tidx < kBlockN && (n_block != n_block_max - 1 || binfo.actual_seqlen_k % kBlockN == 0 || tidx < binfo.actual_seqlen_k % kBlockN)) {
sFlashMaskLTStart(tidx) = gFlashMaskLTStart(tidx);
sFlashMaskUTEnd(tidx) = gFlashMaskUTEnd(tidx);
if(flashmask_ut_has_start) {
Expand Down Expand Up @@ -1646,7 +1646,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_flashmask(const Params &p
} else if (Is_causal && true/*Is_flashmask*/ && (!enable_mask_bypass ||
(m_block + 1) * kBlockM > gFlashMaskLTStartMin[n_block] && (!flashmask_lt_has_end || m_block * kBlockM < gFlashMaskLTEndMax[n_block]))) {

if (tidx < kBlockN) {
if (tidx < kBlockN && (n_block != n_block_max - 1 || binfo.actual_seqlen_k % kBlockN == 0 || tidx < binfo.actual_seqlen_k % kBlockN)) {
sFlashMaskLTStart(tidx) = gFlashMaskLTStart(tidx);
if(flashmask_lt_has_end) {
sFlashMaskLTEnd(tidx) = gFlashMaskLTEnd(tidx);
Expand Down