diff --git a/generate_startend_row_indices.py b/generate_startend_row_indices.py index 5209376..74de86b 100644 --- a/generate_startend_row_indices.py +++ b/generate_startend_row_indices.py @@ -1,6 +1,122 @@ import paddle import numpy as np +def _scale_int_list(values, seqlen_k: int, base_seqlen_k: int = 8192, min_value: int = 1): + scale = seqlen_k / base_seqlen_k + return [max(min_value, int(v * scale)) for v in values] + +def _default_doc_seqlens_int(seqlen_k: int): + # Keep in sync with the defaults used by generate_{causal_,}document_mask/share_question/causal_blockwise. + doc_seqlens = [2538, 1742, 3213] + if seqlen_k != 8192: + doc_seqlens = _scale_int_list(doc_seqlens, seqlen_k=seqlen_k) + return doc_seqlens + +def _default_doc_seqlens_prefix_pairs(seqlen_k: int): + # Keep in sync with the defaults used by generate_prefix_lm_document_mask. + doc_seqlens = [(1024, 2538), (1742, 1742), (512, 3213)] + if seqlen_k != 8192: + scale = seqlen_k / 8192 + doc_seqlens = [tuple(max(1, int(v * scale)) for v in pair) for pair in doc_seqlens] + return doc_seqlens + +def _perturb_doc_seqlens_int(doc_seqlens, head_idx: int, seqlen_k: int): + """ + Deterministically perturb integer segment lengths to avoid identical per-head masks, + even when `doc_seqlens` is None. 
+ """ + if not doc_seqlens: + return doc_seqlens + if len(doc_seqlens) == 1: + return [min(seqlen_k, max(1, int(doc_seqlens[0]) + ((head_idx % 3) - 1)))] + + segs = [max(1, int(v)) for v in doc_seqlens] + delta = int((head_idx * 7) % 17) - 8 # [-8, 8], period 17 + if delta == 0: + delta = 1 + i = head_idx % len(segs) + j = (i + 1) % len(segs) + + segs[i] = max(1, segs[i] + delta) + segs[j] = max(1, segs[j] - delta) + + total = sum(segs) + if total > seqlen_k: + excess = total - seqlen_k + k = max(range(len(segs)), key=lambda t: segs[t]) + segs[k] = max(1, segs[k] - excess) + return segs + +def _perturb_doc_seqlens_prefix_pairs(doc_seqlens, head_idx: int, seqlen_k: int): + """ + Deterministically perturb prefix-LM (prefix_length, seq_length) segments per head. + Keeps 0 <= prefix_length <= seq_length and sum(seq_length) <= seqlen_k. + """ + if not doc_seqlens: + return doc_seqlens + pairs = [(max(0, int(p)), max(1, int(s))) for p, s in doc_seqlens] + if len(pairs) == 1: + p, s = pairs[0] + p = min(s, max(0, p + ((head_idx % 5) - 2))) + s = min(seqlen_k, max(1, s + ((head_idx % 3) - 1))) + p = min(p, s) + return [(p, s)] + + delta = int((head_idx * 7) % 17) - 8 + if delta == 0: + delta = 1 + i = head_idx % len(pairs) + j = (i + 1) % len(pairs) + + prefixes = [p for p, _ in pairs] + lengths = [s for _, s in pairs] + + lengths[i] = max(1, lengths[i] + delta) + lengths[j] = max(1, lengths[j] - delta) + + prefixes[i] = min(lengths[i], max(0, prefixes[i] + (delta // 2))) + prefixes[j] = min(lengths[j], max(0, prefixes[j] - (delta // 2))) + + total = sum(lengths) + if total > seqlen_k: + excess = total - seqlen_k + k = max(range(len(lengths)), key=lambda t: lengths[t]) + lengths[k] = max(1, lengths[k] - excess) + prefixes[k] = min(prefixes[k], lengths[k]) + + return list(zip(prefixes, lengths)) + +def _rotate_list(values, shift: int): + if not values: + return values + shift = shift % len(values) + return values[shift:] + values[:shift] + +def 
_stack_per_head_startend_row_indices( + batch_size, + seqlen_q, + seqlen_k, + h, + per_head_fn, +): + """ + Build [b, h, seqlen_k, bound_num] by calling a per-head generator that returns + [b, 1, seqlen_k, bound_num] for each head. + """ + if h == 1: + return per_head_fn(0) + startend_list = [] + causal = None + for head_idx in range(h): + s, c = per_head_fn(head_idx) + if causal is None: + causal = c + else: + assert c == causal + assert s.shape[0] == batch_size and s.shape[1] == 1 and s.shape[2] == seqlen_k + startend_list.append(s) + return paddle.concat(startend_list, axis=1), causal + def startend_row_indices_to_attn_bias(startend_row_indices, seqlen_q, nheads, dtype, causal=True): if startend_row_indices is None: return None @@ -54,6 +170,30 @@ def generate_sliding_window_mask(batch_size, seqlen_q, seqlen_k, h, window_size= causal=True return startend_row_indices, causal +def generate_sliding_window_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, window_size=None): + """ + Per-head distinct sliding window: each head uses a different window_size. + Output shape: [b, h, seqlen_k, 1] + """ + if window_size is None: + window_size = 1024 + if seqlen_k != 8192: + window_size = int(window_size * (seqlen_k / 8192)) + print(f"{seqlen_k=}, auto setting window_size to {window_size}") + + # Choose window sizes spanning roughly [window_size/2, window_size], clipped to [1, seqlen_q]. 
+ max_window = max(1, int(window_size)) + head_ids = paddle.arange(h, dtype=paddle.int32) + denom = max(1, h - 1) + # linear ramp: largest window at head 0, smallest at head h-1 + win = max_window - (head_ids * (max_window // 2) // denom) + win = paddle.clip(win, min=1, max=seqlen_q).reshape((1, h, 1, 1)) + + col = paddle.arange(seqlen_k, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)) + startend_row_indices = paddle.clip(col + win, max=seqlen_q).repeat_interleave(batch_size, 0) + causal = True + return startend_row_indices, causal + def generate_causal_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): # TODO: this seems buggy, to be fixed if doc_seqlens == None: @@ -75,6 +215,22 @@ def generate_causal_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens causal = True return startend_row_indices, causal +def generate_causal_document_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): + """ + Per-head distinct causal document mask via rotating document segments per head. 
+ Output shape: [b, h, seqlen_k, 1] + """ + base_doc_seqlens = _default_doc_seqlens_int(seqlen_k) if doc_seqlens is None else list(doc_seqlens) + + def _per_head(head_idx: int): + per_head = _rotate_list(list(base_doc_seqlens), head_idx) + per_head = _perturb_doc_seqlens_int(per_head, head_idx=head_idx, seqlen_k=seqlen_k) + return generate_causal_document_mask( + batch_size, seqlen_q, seqlen_k, 1, per_head + ) + + return _stack_per_head_startend_row_indices(batch_size, seqlen_q, seqlen_k, h, _per_head) + def generate_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): # TODO: this seems buggy, to be fixed if doc_seqlens == None: @@ -114,6 +270,22 @@ def generate_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): causal = False return startend_row_indices, causal +def generate_document_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): + """ + Per-head distinct bidirectional document mask via rotating document segments per head. + Output shape: [b, h, seqlen_k, 2] + """ + base_doc_seqlens = _default_doc_seqlens_int(seqlen_k) if doc_seqlens is None else list(doc_seqlens) + + def _per_head(head_idx: int): + per_head = _rotate_list(list(base_doc_seqlens), head_idx) + per_head = _perturb_doc_seqlens_int(per_head, head_idx=head_idx, seqlen_k=seqlen_k) + return generate_document_mask( + batch_size, seqlen_q, seqlen_k, 1, per_head + ) + + return _stack_per_head_startend_row_indices(batch_size, seqlen_q, seqlen_k, h, _per_head) + def generate_share_question_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): if doc_seqlens == None: doc_seqlens = [2538, 1742, 3213] @@ -146,6 +318,22 @@ def generate_share_question_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens= causal = True return startend_row_indices, causal +def generate_share_question_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): + """ + Per-head distinct share-question mask via rotating document segments per head. 
+ Output shape: [b, h, seqlen_k, 1] + """ + base_doc_seqlens = _default_doc_seqlens_int(seqlen_k) if doc_seqlens is None else list(doc_seqlens) + + def _per_head(head_idx: int): + per_head = _rotate_list(list(base_doc_seqlens), head_idx) + per_head = _perturb_doc_seqlens_int(per_head, head_idx=head_idx, seqlen_k=seqlen_k) + return generate_share_question_mask( + batch_size, seqlen_q, seqlen_k, 1, per_head + ) + + return _stack_per_head_startend_row_indices(batch_size, seqlen_q, seqlen_k, h, _per_head) + def generate_global_sliding_window_mask(batch_size, seqlen_q, seqlen_k, h, global_token=16, window_size=None): if window_size == None: window_size = (512, 512) @@ -184,6 +372,51 @@ def generate_global_sliding_window_mask(batch_size, seqlen_q, seqlen_k, h, globa causal = False return startend_row_indices, causal +def generate_global_sliding_window_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, global_token=16, window_size=None): + """ + Per-head distinct global sliding window mask by varying (global_token, left/right window) per head. + Output shape: [b, h, seqlen_k, 4] + """ + if window_size is None: + window_size = (512, 512) + if seqlen_k != 8192: + window_size = tuple(int(ws * (seqlen_k / 8192)) for ws in window_size) + print(f"{seqlen_k=}, auto setting window_size to {window_size}") + left_base, right_base = window_size + + head_ids = paddle.arange(h, dtype=paddle.int32) + denom = max(1, h - 1) + + # Vary global token count modestly across heads. + gt_base = int(global_token) + gt = gt_base + (head_ids - (h // 2)) # centered offsets + gt = paddle.clip(gt, min=0, max=seqlen_k).reshape((1, h, 1, 1)) + + # Vary left/right windows across heads. 
+ left = int(left_base) - (head_ids * max(1, int(left_base) // 4) // denom) + right = int(right_base) - (head_ids * max(1, int(right_base) // 4) // denom) + left = paddle.clip(left, min=0, max=seqlen_q).reshape((1, h, 1, 1)) + right = paddle.clip(right, min=0, max=seqlen_q).reshape((1, h, 1, 1)) + + col = paddle.arange(seqlen_k, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)) + + down_left_start = paddle.clip(col + left + 1, max=seqlen_q) + down_left_start = paddle.where(col < gt, paddle.full_like(down_left_start, seqlen_q), down_left_start) + down_left_end = paddle.full_like(down_left_start, seqlen_q) + + up_right_start = paddle.full_like(down_left_start, 0) + up_right_start = paddle.where(col >= (gt + right + 1), gt, up_right_start) + + up_right_end = col - right + up_right_end = paddle.where(col < (gt + right + 1), paddle.zeros_like(up_right_end), up_right_end) + + startend_row_indices = paddle.concat( + [down_left_start, down_left_end, up_right_start, up_right_end], axis=-1 + ).repeat_interleave(batch_size, 0) + startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) + causal = False + return startend_row_indices, causal + def generate_causal_blockwise_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): # TODO: this seems buggy, to be fixed if doc_seqlens == None: @@ -216,6 +449,22 @@ def generate_causal_blockwise_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlen causal = True return startend_row_indices, causal +def generate_causal_blockwise_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): + """ + Per-head distinct causal blockwise mask via rotating document segments per head. 
+ Output shape: [b, h, seqlen_k, 2] + """ + base_doc_seqlens = _default_doc_seqlens_int(seqlen_k) if doc_seqlens is None else list(doc_seqlens) + + def _per_head(head_idx: int): + per_head = _rotate_list(list(base_doc_seqlens), head_idx) + per_head = _perturb_doc_seqlens_int(per_head, head_idx=head_idx, seqlen_k=seqlen_k) + return generate_causal_blockwise_mask( + batch_size, seqlen_q, seqlen_k, 1, per_head + ) + + return _stack_per_head_startend_row_indices(batch_size, seqlen_q, seqlen_k, h, _per_head) + def generate_prefix_lm_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): """ tuple(prefix_length, seq_length) @@ -260,6 +509,22 @@ def generate_prefix_lm_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seql causal = False return startend_row_indices, causal +def generate_prefix_lm_document_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): + """ + Per-head distinct prefix-LM document mask via rotating (prefix_length, seq_length) segments per head. 
+ Output shape: [b, h, seqlen_k, 2] + """ + base_doc_seqlens = _default_doc_seqlens_prefix_pairs(seqlen_k) if doc_seqlens is None else list(doc_seqlens) + + def _per_head(head_idx: int): + per_head = _rotate_list(list(base_doc_seqlens), head_idx) + per_head = _perturb_doc_seqlens_prefix_pairs(per_head, head_idx=head_idx, seqlen_k=seqlen_k) + return generate_prefix_lm_document_mask( + batch_size, seqlen_q, seqlen_k, 1, per_head + ) + + return _stack_per_head_startend_row_indices(batch_size, seqlen_q, seqlen_k, h, _per_head) + def generate_prefix_lm_causal_mask(batch_size, seqlen_q, seqlen_k, h, prefix_length=None): """ tuple(prefix_length, seq_length) @@ -278,6 +543,32 @@ def generate_prefix_lm_causal_mask(batch_size, seqlen_q, seqlen_k, h, prefix_len causal = False return startend_row_indices, causal +def generate_prefix_lm_causal_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, prefix_length=None): + """ + Per-head distinct prefix-LM causal mask by varying prefix_length per head. + Output shape: [b, h, seqlen_k, 2] + """ + if prefix_length is None: + prefix_length = 1024 + if seqlen_k != 8192: + prefix_length = int(prefix_length * (seqlen_k / 8192)) + print(f"{seqlen_k=}, auto setting doc_seqlens to {prefix_length}") + base = int(prefix_length) + head_ids = paddle.arange(h, dtype=paddle.int32) + denom = max(1, h - 1) + # Spread prefix lengths roughly in [base/2, base], clipped. 
+ pref = base - (head_ids * (base // 2) // denom) + pref = paddle.clip(pref, min=0, max=seqlen_k).reshape((1, h, 1, 1)) + + col = paddle.arange(seqlen_k, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)) + down_left = paddle.full((1, h, seqlen_k, 1), seqlen_k, dtype=paddle.int32) + up_right = paddle.where(col < pref, paddle.zeros_like(col), col) + + startend_row_indices = paddle.concat([down_left, up_right], axis=-1).repeat_interleave(batch_size, 0) + startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) + causal = False + return startend_row_indices, causal + def generate_qk_sparse_mask(batch_size, seqlen_q, seqlen_k, h, maskout_pair=None): """ tuple(offset, maskout_len) @@ -313,6 +604,42 @@ def generate_qk_sparse_mask(batch_size, seqlen_q, seqlen_k, h, maskout_pair=None causal = True return startend_row_indices, causal +def generate_qk_sparse_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, maskout_pair=None): + """ + Per-head distinct qk-sparse mask by shifting maskout offsets per head. + Output shape: [b, h, seqlen_k, 2] + """ + if maskout_pair is None: + maskout_pair = [(1024, 538), (2358, 1700)] + if seqlen_k != 8192: + scale = seqlen_k / 8192 + maskout_pair = [tuple(int(v * scale) for v in pair) for pair in maskout_pair] + print(f"{seqlen_k=}, auto setting maskout_pair to {maskout_pair}") + + base_pairs = list(maskout_pair) + shift_stride = max(1, seqlen_k // max(8, h)) + + def _per_head(head_idx: int): + shift = head_idx * shift_stride + pairs = [] + for offset, length in base_pairs: + o = min(max(0, offset + shift), seqlen_k) + l = min(max(0, length), max(0, seqlen_k - o)) + pairs.append((o, l)) + # Ensure offsets are strictly increasing for the generator. 
+ pairs.sort(key=lambda x: x[0]) + dedup = [] + last_offset = -1 + for o, l in pairs: + o = max(o, last_offset + 1) + o = min(o, seqlen_k) + l = min(l, max(0, seqlen_k - o)) + dedup.append((o, l)) + last_offset = o + return generate_qk_sparse_mask(batch_size, seqlen_q, seqlen_k, 1, dedup) + + return _stack_per_head_startend_row_indices(batch_size, seqlen_q, seqlen_k, h, _per_head) + def generate_random_eviction_mask(batch_size, seqlen_q, seqlen_k, h, start_row=None): # np.random.seed(0) if start_row == None: diff --git a/test_blockmask_qh.py b/test_blockmask_qh.py new file mode 100644 index 0000000..58407d7 --- /dev/null +++ b/test_blockmask_qh.py @@ -0,0 +1,285 @@ +import os +import math +import itertools +import pytest +from einops import rearrange, repeat +import paddle +import time +from paddle.nn.functional.flash_attention import flashmask_attention +from generate_startend_row_indices import ( + startend_row_indices_to_attn_bias, + generate_none_mask, + generate_sliding_window_mask, + generate_sliding_window_mask_distinct_heads, + generate_causal_document_mask, + generate_causal_document_mask_distinct_heads, + generate_document_mask, + generate_document_mask_distinct_heads, + generate_share_question_mask, + generate_share_question_mask_distinct_heads, + generate_global_sliding_window_mask, + generate_global_sliding_window_mask_distinct_heads, + generate_causal_blockwise_mask, + generate_causal_blockwise_mask_distinct_heads, + generate_prefix_lm_document_mask, + generate_prefix_lm_document_mask_distinct_heads, + generate_prefix_lm_causal_mask, + generate_prefix_lm_causal_mask_distinct_heads, + generate_qk_sparse_mask, + generate_qk_sparse_mask_distinct_heads, + generate_random_eviction_mask +) +from functools import partial +from test_util import attention_ref, blockmask_to_densemask, random_blockmask, flashmask_to_densemask + +# batch_size, seqlen_q, seqlen_k, nheads, nheads_kv +shape_cases = ( + [ + (28, 128, 128, 16, 4), + (4, 256, 256, 4, 1), + # (2, 8192, 
32768, 32, 4), # this will oom + # (2, 8192, 8192, 32, 4), # this will oom + (1, 8192, 8192, 1, 1), + # (2, 16384, 16384, 1, 1), + (1, 128, 128, 1, 1), + (1, 127, 128, 1, 1), + (1, 16384, 16384, 1, 1), + # (2, 16384, 16383, 4, 1), + # my case + ] + # tridao case + + list(itertools.product( + [1], # batch_size + [1, 64, 128, 256, 239, 799, 113, 113, 128, 113, 108, 256, 384, 640, 512, 1024, 1023, 1024,], # seqlen_q + [128, 192, 256, 203, 128, 217, 211, 256, 512, 256, 128, 256, 1024, 1024, 1023,], # seqlen_k + [1,2], # nheads + [1], # nheads_kv + )) + + list(itertools.product( + [2], # batch_size + [4096, 4224], # seqlen_q + [4096, 4224], # seqlen_k + [6], # nheads + [6, 2, 1], # nheads_kv + )) +) + +# Generate all combinations for second param +def generate_shapes(): + for batch_size, seqlen_q, seqlen_k, nheads, nheads_kv in shape_cases: + # flashmask historically supported `num_head` = 1 or `nheads_kv`. + # We also want to test `num_head` = `nheads` (i.e. `nheads_q`). + nheads_startend_row_indices_values = list(dict.fromkeys([1, nheads_kv, nheads])) + for nheads_startend_row_indices in nheads_startend_row_indices_values: + yield ( + batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_startend_row_indices + ) + +def _expand_startend_row_indices_heads(startend_row_indices, target_num_heads: int): + if startend_row_indices is None: + return None + if target_num_heads <= 0: + raise ValueError(f"Invalid target_num_heads={target_num_heads}") + _, num_head, _, _ = startend_row_indices.shape + if num_head == target_num_heads: + return startend_row_indices + if target_num_heads % num_head != 0: + raise ValueError( + f"Cannot expand startend_row_indices head dim from {num_head} to {target_num_heads}" + ) + repeat_factor = target_num_heads // num_head + return startend_row_indices.repeat_interleave(repeat_factor, axis=1) + +@pytest.mark.parametrize("dtype", [paddle.bfloat16]) +@pytest.mark.parametrize("fa_version", [3]) +@pytest.mark.parametrize("d, dv", [(128, 128)]) 
+@pytest.mark.parametrize( + "batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_startend_row_indices", + list(generate_shapes()) +) +@pytest.mark.parametrize( + "gen_startend_row_indices", + ( + [ + # partial(generate_none_mask, causal=False), # full + # partial(generate_none_mask, causal=True), # causal + partial(generate_sliding_window_mask), # sliding window + partial(generate_causal_document_mask), # causal document mask + partial(generate_document_mask), # document mask + partial(generate_share_question_mask), # share question mask + partial(generate_global_sliding_window_mask), # global sliding window + partial(generate_causal_blockwise_mask), # causal blockwise mask + partial(generate_prefix_lm_document_mask), # prefix lm document mask + partial(generate_prefix_lm_causal_mask), # prefix lm causal mask + partial(generate_qk_sparse_mask), # qk-sparse mask + partial(generate_random_eviction_mask), # random eviction mask + ] + + ( + [ + # Enable per-query-head distinct payload tests explicitly to avoid case explosion by default. 
+ partial(generate_sliding_window_mask_distinct_heads), + partial(generate_causal_document_mask_distinct_heads), + partial(generate_document_mask_distinct_heads), + partial(generate_share_question_mask_distinct_heads), + partial(generate_global_sliding_window_mask_distinct_heads), + partial(generate_causal_blockwise_mask_distinct_heads), + partial(generate_prefix_lm_document_mask_distinct_heads), + partial(generate_prefix_lm_causal_mask_distinct_heads), + partial(generate_qk_sparse_mask_distinct_heads), + ] + if os.getenv("FLASHMASK_TEST_DISTINCT_HEADS", "0") == "1" + else [] + ) + ), +) +def test_flashmask( + batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, d, dv, nheads_startend_row_indices, fa_version, dtype, gen_startend_row_indices, softcap=0.0 +): + paddle.seed(2024) + assert nheads % nheads_kv == 0 + q_ref = paddle.randn(shape=[batch_size, seqlen_q, nheads, d], dtype=dtype) + # print(q_ref) + k_ref = paddle.randn(shape=[batch_size, seqlen_k, nheads_kv, d], dtype=dtype) + v_ref = paddle.randn(shape=[batch_size, seqlen_k, nheads_kv, dv], dtype=dtype) + + q_ref.stop_gradient = False + k_ref.stop_gradient = False + v_ref.stop_gradient = False + + q_bf16, k_bf16, v_bf16 = [x.detach().clone() for x in (q_ref, k_ref, v_ref)] + + q_bf16.stop_gradient = False + k_bf16.stop_gradient = False + v_bf16.stop_gradient = False + + q, k, v = [x.detach().clone() for x in (q_ref, k_ref, v_ref)] + # print(q_ref) + q.stop_gradient = False + k.stop_gradient = False + v.stop_gradient = False + + startend_row_indices, causal = gen_startend_row_indices(batch_size, seqlen_q, seqlen_k, nheads_startend_row_indices) + startend_row_indices = _expand_startend_row_indices_heads( + startend_row_indices, nheads_startend_row_indices + ) + assert startend_row_indices is None or startend_row_indices.shape[1] == nheads_startend_row_indices + + if startend_row_indices is None and causal and d == 80: + pytest.skip(f"Skipping because running headdim 80 with flash_attn in causal mask") + # 
print(q_ref) + print(k_ref.shape) + blockmask = random_blockmask( + shape=[ + startend_row_indices.shape[0], + startend_row_indices.shape[1], + (seqlen_q + 127)// 128, + (seqlen_k + 127)// 128 + ], + dtype=paddle.int32, + is_causal=causal, + ref_q = q_ref + ) + assert blockmask.shape[1] == nheads_startend_row_indices + # print(q_ref) + # paddle.save(q, 'query.pd') + # paddle.save(k, 'key.pd') + # paddle.save(v, 'value.pd') + # paddle.save(blockmask, 'blockmask.pd') + # paddle.save(startend_row_indices, 'startend_row_indices.pd') + + mask_flash = flashmask_to_densemask(startend_row_indices, seqlen_q, nheads_startend_row_indices, causal) + mask_block = blockmask_to_densemask(blockmask,seqlen_q,seqlen_k,paddle.int32,causal) + + mask_inf = mask_flash & mask_block + # print(mask_inf) + attn_bias = paddle.where( + mask_inf, + paddle.zeros_like(mask_inf, dtype=paddle.float32), + paddle.full_like(mask_inf, float("-inf"), dtype=paddle.float32), + ) + paddle.save(attn_bias, 'attn_bias.pd') + # time.sleep(0.1) + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + causal=causal, + attn_bias=attn_bias + ) + + out_bf16, attn_bf16 = attention_ref( + q_bf16, + k_bf16, + v_bf16, + causal=causal, + attn_bias=attn_bias, + upcast=False, + reorder_ops=True + ) + + # # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + assert softcap == 0.0 + rtol = 2 if softcap == 0.0 else 3 + + print(f"Paddle naive bf16 Output max diff: {(out_bf16 - out_ref).abs().max().item()}") + print(f"Paddle naive bf16 Output mean diff: {(out_bf16 - out_ref).abs().mean().item()}") + + if fa_version == 2: + paddle.set_flags({'FLAGS_flash_attn_version': 2}) + elif fa_version == 3: + paddle.set_flags({'FLAGS_flash_attn_version': 3}) + else: + raise ValueError( + f"Invalid flash attention version: {fa_version}" + ) + + out, lse = flashmask_attention( + q, + k, + v, + startend_row_indices=startend_row_indices, + causal=causal, + 
return_softmax_lse=True, + block_mask=blockmask + ) + print(f"flashmask output max at {(out - out_ref).abs().argmax()}") + print(f"flashmask Output max diff: {(out - out_ref).abs().max().item()}") + print(f"flashmask Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + + assert (out - out_ref).abs().max().item() <= rtol * (out_bf16 - out_ref).abs().max().item() + fwd_atol + # return + + g = paddle.randn(shape=out.shape, dtype=out.dtype) + paddle.save(g, 'g.pd') + out.backward(g) + out_ref.backward(g) + out_bf16.backward(g) + paddle.device.synchronize() + + print(f"flashmask dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}") + print(f"flashmask dK max diff: {(k.grad - k_ref.grad).abs().max().item()}") + print(f"flashmask dV max diff: {(v.grad - v_ref.grad).abs().max().item()}") + print(f"flashmask dQ mean diff: {(q.grad - q_ref.grad).abs().mean().item()}") + print(f"flashmask dK mean diff: {(k.grad - k_ref.grad).abs().mean().item()}") + print(f"flashmask dV mean diff: {(v.grad - v_ref.grad).abs().mean().item()}") + + print(f"Paddle naive bf16 dQ max diff: {(q_bf16.grad - q_ref.grad).abs().max().item()}") + print(f"Paddle naive bf16 dK max diff: {(k_bf16.grad - k_ref.grad).abs().max().item()}") + print(f"Paddle naive bf16 dV max diff: {(v_bf16.grad - v_ref.grad).abs().max().item()}") + print(f"Paddle naive bf16 dQ mean diff: {(q_bf16.grad - q_ref.grad).abs().mean().item()}") + print(f"Paddle naive bf16 dK mean diff: {(k_bf16.grad - k_ref.grad).abs().mean().item()}") + print(f"Paddle naive bf16 dV mean diff: {(v_bf16.grad - v_ref.grad).abs().mean().item()}") + + dq_atol = 2 * (q_ref.grad + 0.3 - 0.3 - q_ref.grad).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (q.grad - q_ref.grad).abs().max().item() <= rtol * (q_bf16.grad - 
q_ref.grad).abs().max().item() + dq_atol + dk_atol = 2 * (k_ref.grad + 0.3 - 0.3 - k_ref.grad).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (k.grad - k_ref.grad).abs().max().item() <= rtol * (k_bf16.grad - k_ref.grad).abs().max().item() + dk_atol + dv_atol = 2 * (v_ref.grad + 0.3 - 0.3 - v_ref.grad).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (v.grad - v_ref.grad).abs().max().item() <= rtol * (v_bf16.grad - v_ref.grad).abs().max().item() + dv_atol diff --git a/test_flashmask.py b/test_flashmask.py index ab98115..6d4bda9 100644 --- a/test_flashmask.py +++ b/test_flashmask.py @@ -9,14 +9,23 @@ startend_row_indices_to_attn_bias, generate_none_mask, generate_sliding_window_mask, + generate_sliding_window_mask_distinct_heads, generate_causal_document_mask, + generate_causal_document_mask_distinct_heads, generate_document_mask, + generate_document_mask_distinct_heads, generate_share_question_mask, + generate_share_question_mask_distinct_heads, generate_global_sliding_window_mask, + generate_global_sliding_window_mask_distinct_heads, generate_causal_blockwise_mask, + generate_causal_blockwise_mask_distinct_heads, generate_prefix_lm_document_mask, + generate_prefix_lm_document_mask_distinct_heads, generate_prefix_lm_causal_mask, + generate_prefix_lm_causal_mask_distinct_heads, generate_qk_sparse_mask, + generate_qk_sparse_mask_distinct_heads, generate_random_eviction_mask ) from functools import partial @@ -58,6 +67,11 @@ def generate_shapes(): for batch_size, seqlen_q, seqlen_k, nheads, nheads_kv in shape_cases: nheads_startend_row_indices_values = [1, nheads_kv] + if ( + os.getenv("FLASHMASK_TEST_DISTINCT_HEADS", "0") == "1" + ) and nheads != nheads_kv: + nheads_startend_row_indices_values.append(nheads) + nheads_startend_row_indices_values = list(dict.fromkeys(nheads_startend_row_indices_values)) for nheads_startend_row_indices in nheads_startend_row_indices_values: yield ( batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, 
nheads_startend_row_indices @@ -79,20 +93,39 @@ def generate_shapes(): ) @pytest.mark.parametrize( "gen_startend_row_indices", - [ - partial(generate_none_mask, causal=False), # full - partial(generate_none_mask, causal=True), # causal - partial(generate_sliding_window_mask), # sliding window - partial(generate_causal_document_mask), # causal document mask - partial(generate_document_mask), # document mask - partial(generate_share_question_mask), # share question mask - partial(generate_global_sliding_window_mask), # global sliding window - partial(generate_causal_blockwise_mask), # causal blockwise mask - partial(generate_prefix_lm_document_mask), # prefix lm document mask - partial(generate_prefix_lm_causal_mask), # prefix lm causal mask - partial(generate_qk_sparse_mask), # qk-sparse mask - partial(generate_random_eviction_mask), # random eviction mask - ], + ( + [ + partial(generate_none_mask, causal=False), # full + partial(generate_none_mask, causal=True), # causal + partial(generate_sliding_window_mask), # sliding window + partial(generate_causal_document_mask), # causal document mask + partial(generate_document_mask), # document mask + partial(generate_share_question_mask), # share question mask + partial(generate_global_sliding_window_mask), # global sliding window + partial(generate_causal_blockwise_mask), # causal blockwise mask + partial(generate_prefix_lm_document_mask), # prefix lm document mask + partial(generate_prefix_lm_causal_mask), # prefix lm causal mask + partial(generate_qk_sparse_mask), # qk-sparse mask + partial(generate_random_eviction_mask), # random eviction mask + ] + + ( + [ + # Per-head distinct startend_row_indices variants (enabled explicitly). 
def run_ref_estimate_from_dense_mask(
    query_states: paddle.Tensor,
    key_states: paddle.Tensor,
    mask_dense: paddle.Tensor,
    *,
    block_size: int = 128,
    stride: int = 8,
    causal: bool = False,
    chunk_size: int = 512,
    threshold: float = 1.0,
):
    """
    Reference estimator: per-block attention scores and boundary masks for sparse attention.

    Queries are subsampled round-robin (one row per `stride`, offset by head index),
    logits are summed per key stride, softmax'd over strides, then aggregated into
    (block_size x block_size) blocks.

    Args:
        query_states: [batch, q_len, num_q_heads, head_dim] (transposed internally).
        key_states: [batch, k_len, num_kv_heads, head_dim]; broadcast over heads for GQA.
        mask_dense: dense attention mask [batch, H or 1, q_len, k_len]; used
            multiplicatively, so nonzero entries mean "attend".
        block_size: output block granularity; must be divisible by `stride`.
        stride: round-robin sampling stride.
        causal: apply causal masking, bottom-right aligned via shift = k_len - q_len.
        chunk_size: sampled query rows processed per iteration (bounds peak memory);
            must be divisible by `stride`.
        threshold: unused in this reference; kept for signature parity with the kernel.

    Returns:
        (attn_sums, boundary_mask), each of shape
        [batch, num_q_heads, ceil(q_len/block_size), ceil(k_len/block_size)];
        boundary_mask is True where a block is only partially masked.
    """
    assert mask_dense is not None
    assert chunk_size % stride == 0, "chunk_size must be divisible by stride"
    assert block_size % stride == 0, "block_size must be divisible by stride"

    # [B, Q, H, D] -> [B, H, Q, D]
    query_states = query_states.transpose([0, 2, 1, 3])
    key_states = key_states.transpose([0, 2, 1, 3])

    batch_size, num_q_head, q_len, head_dim = query_states.shape
    _, num_kv_head, k_len, _ = key_states.shape

    # 1. Handle GQA (Group Query Attention): replicate KV heads along dim 1.
    if num_q_head != num_kv_head:
        assert num_q_head % num_kv_head == 0
        num_groups = num_q_head // num_kv_head
        # key_states: [B, H_kv, K, D] -> [B, H_q, K, D]
        key_states = repeat(key_states, 'b h k d -> b (h g) k d', g=num_groups)

    # Replicate the mask's head dim the same way if it is narrower than H_q.
    nheads_dense_mask = mask_dense.shape[1]
    if num_q_head != nheads_dense_mask:
        assert num_q_head % nheads_dense_mask == 0
        num_groups = num_q_head // nheads_dense_mask
        mask_dense = repeat(mask_dense, 'b h k d -> b (h g) k d', g=num_groups)

    # 2. Padding / alignment: pad everything up to a multiple of chunk_size,
    # which makes the chunked loop below a clean integer division.
    def get_pad_len(length, align):
        return (align - length % align) % align

    q_pad_len = get_pad_len(q_len, chunk_size)
    k_pad_len = get_pad_len(k_len, chunk_size)

    padded_q_len = q_len + q_pad_len
    padded_k_len = k_len + k_pad_len

    mask_dense = mask_dense.astype(paddle.float32)
    if q_pad_len > 0:
        query_states = F.pad(query_states, (0, 0, 0, q_pad_len), value=0)
        # Pad mask height (Q dim); padded rows are masked out (value=0).
        if mask_dense.shape[2] == q_len:
            mask_dense = F.pad(mask_dense, (0, 0, 0, q_pad_len), value=0)

    if k_pad_len > 0:
        key_states = F.pad(key_states, (0, 0, 0, k_pad_len), value=0)
        # Pad mask width (K dim).
        if mask_dense.shape[3] == k_len:
            mask_dense = F.pad(mask_dense, (0, k_pad_len, 0, 0), value=0)

    # 3. Precompute the K padding mask (used for "partial block" detection).
    # Shaped [1, 1, 1, KS, S] to broadcast against the stride view of logits.
    # NOTE(review): the `device=` kwarg on paddle.arange looks like the
    # torch-compat shim rather than native paddle API — confirm it is active here.
    num_k_strides = padded_k_len // stride
    k_global_indices = paddle.arange(padded_k_len, device=key_states.place)

    k_is_non_padding = (k_global_indices < k_len)
    k_is_non_padding = rearrange(k_is_non_padding, '(ks s) -> 1 1 1 ks s', s=stride)

    # True (non-padded) length of each key stride — the softmax-count denominator.
    stride_valid_length = k_is_non_padding.astype('float32').sum(axis=-1)  # -> [1, 1, 1, ks]

    # 4. Sample queries round-robin: head h takes row (h % stride) of every stride.
    # head_offsets: [1, H, 1, 1]
    head_offsets = rearrange(paddle.arange(num_q_head, device=query_states.place) % stride, 'h -> 1 h 1 1')

    num_q_strides = padded_q_len // stride
    # stride_starts: [1, 1, QS, 1]
    stride_starts = rearrange(paddle.arange(num_q_strides, device=query_states.place) * stride, 'qs -> 1 1 qs 1')

    gather_indices = head_offsets + stride_starts  # [1, H, QS, 1]

    # Expand to head_dim so take_along_axis can gather full query vectors.
    gather_indices_expanded = repeat(gather_indices, '1 h qs 1 -> b h qs d', b=batch_size, d=head_dim)
    sampled_query = paddle.take_along_axis(query_states, gather_indices_expanded, axis=2)

    # 5. Sample the dense mask with the same Q-row indices, keeping K intact.
    mask_h_dim = mask_dense.shape[1]

    mask_gather_idx = gather_indices  # [1, H, QS, 1]
    if mask_h_dim == 1:
        # Shared-head mask: any head's sampling pattern works; use head 0.
        mask_gather_idx = mask_gather_idx[:, 0:1, :, :]  # [1, 1, QS, 1]

    mask_gather_idx = repeat(mask_gather_idx, '1 h qs 1 -> b h qs k', b=batch_size, k=padded_k_len)
    sampled_mask_dense = paddle.take_along_axis(mask_dense, mask_gather_idx, axis=2)

    # 6. Chunked score computation over sampled query rows.
    attn_sums_list = []
    boundary_masks_list = []

    # Extra 1/stride folds the per-stride sum into a mean estimate.
    scale = 1.0 / math.sqrt(head_dim) / stride
    q_chunk_size = chunk_size // stride
    num_chunks = num_q_strides // q_chunk_size  # exact: padded to a multiple above

    for i in range(num_chunks):
        st = i * q_chunk_size
        ed = (i + 1) * q_chunk_size

        q_chunk = sampled_query[:, :, st:ed, :]        # [B, H, qc, D]
        mask_chunk = sampled_mask_dense[:, :, st:ed, :]  # [B, H/1, qc, K]

        # [B, H, qc, D] @ [B, H, K, D]^T -> [B, H, qc, K]
        logits = paddle.matmul(q_chunk, key_states, transpose_y=True)
        logits = logits * scale

        # Stride view: [B, H, qc, (ks s)] -> [B, H, qc, ks, s]
        logits = rearrange(logits, 'b h qc (ks s) -> b h qc ks s', s=stride)
        mask_chunk = rearrange(mask_chunk, 'b h qc (ks s) -> b h qc ks s', s=stride)

        logical_mask = mask_chunk

        if causal:
            # Recover the *original* row index of each sampled query row:
            # stride start plus the head's round-robin offset.
            q_idx_val = rearrange(paddle.arange(q_chunk_size, device=logits.place), 'qc -> 1 1 qc 1 1')
            global_q_stride_idx = st + q_idx_val
            h_idx_val = rearrange(paddle.arange(num_q_head, device=logits.place), 'h -> 1 h 1 1 1')
            real_row = global_q_stride_idx * stride + (h_idx_val % stride)

            # Original column index: [1, 1, 1, ks, s]
            k_idx_val = rearrange(paddle.arange(num_k_strides, device=logits.place), 'ks -> 1 1 1 ks 1')
            s_idx_val = rearrange(paddle.arange(stride, device=logits.place), 's -> 1 1 1 1 s')
            real_col = k_idx_val * stride + s_idx_val

            # Bottom-right alignment for q_len != k_len.
            shift = k_len - q_len

            is_causal = (real_row + shift >= real_col).astype(logits.dtype)
            logical_mask = logical_mask * is_causal

        # --- Mask fusion ---
        # final_effective_mask: [B, H, qc, ks, s] — a position counts only if it
        # is (a) allowed by the dense/causal mask AND (b) not K-padding.
        final_effective_mask = logical_mask * k_is_non_padding.astype(logits.dtype)

        # Zero out masked logits so the stride-sum ignores them.
        logits = logits * final_effective_mask

        # Per-stride counts to classify fully/partially masked strides.
        passed_counts = reduce(final_effective_mask, 'b h qc ks s -> b h qc ks', 'sum')
        total_valid_counts = stride_valid_length  # broadcasts

        is_fully_masked = (passed_counts == 0)
        is_partially_masked = (passed_counts > 0) & (passed_counts < total_valid_counts)

        # Collapse each stride to a single logit (mean estimate via `scale`).
        logits_stride = reduce(logits, 'b h qc ks s -> b h qc ks', 'sum')

        # Fully masked strides must not receive softmax mass.
        if is_fully_masked.any():
            neg_inf = paddle.to_tensor(float('-inf'), dtype=logits_stride.dtype, place=logits_stride.place)
            logits_stride = paddle.where(is_fully_masked, neg_inf, logits_stride)

        scores_stride = F.softmax(logits_stride, axis=-1)
        # All -inf rows produce NaN after softmax; treat as zero mass.
        scores_stride = paddle.nan_to_num(scores_stride, 0.0).astype(query_states.dtype)

        # Aggregate strides into blocks.
        ratio = block_size // stride

        # Sum-reduce scores into block mass.
        attn_sum_chunk = reduce(
            scores_stride,
            'b h (qb r1) (kb r2) -> b h qb kb',
            'sum',
            r1=ratio, r2=ratio
        )
        attn_sums_list.append(attn_sum_chunk)

        # Max-reduce: a block is a boundary block if ANY stride in it is partial.
        boundary_stride = is_partially_masked.astype('float32')
        boundary_mask_chunk = reduce(
            boundary_stride,
            'b h (qb r1) (kb r2) -> b h qb kb',
            'max',
            r1=ratio, r2=ratio
        )
        boundary_masks_list.append(boundary_mask_chunk.astype('bool'))

    # 7. Concatenate chunks and slice padding blocks away.
    final_attn_sums = paddle.concat(attn_sums_list, axis=2)
    final_boundary_mask = paddle.concat(boundary_masks_list, axis=2)

    # Block counts of the ORIGINAL (unpadded) problem.
    valid_q_blocks = (q_len + block_size - 1) // block_size
    valid_k_blocks = (k_len + block_size - 1) // block_size

    return (
        final_attn_sums[:, :, :valid_q_blocks, :valid_k_blocks],
        final_boundary_mask[:, :, :valid_q_blocks, :valid_k_blocks]
    )
def find_blocks_chunked(
    input_tensor, threshold,
):
    """
    Select the attention blocks covering a `threshold` fraction of each row's mass.

    Parameters:
    - input_tensor (paddle.Tensor): block scores, shape (batch_size, head_num, chunk_num, block_num).
    - threshold (float or None): fraction of the per-row total that must be covered.

    Returns:
    - paddle.Tensor: bool mask of shape (batch_size, head_num, chunk_num, block_num);
      True marks blocks that should be attended to.
    """
    assert threshold is not None

    scores = input_tensor.astype("float32")
    # Unpacking doubles as a rank-4 sanity check on the input.
    B, H, C, N = scores.shape

    row_total = scores.sum(axis=-1, keepdim=True)
    mass_cutoff = row_total * float(threshold)

    # Sort each row descending (torch-style compat API: returns values, indices).
    desc_vals, desc_idx = paddle.compat.sort(scores, dim=-1, descending=True)

    running = paddle.cumsum(desc_vals, axis=-1)  # [B, H, C, N]
    # Keep a block while the mass accumulated *before* it is still below the cutoff,
    # i.e. the smallest prefix whose sum reaches cutoff is selected.
    selected = (running - desc_vals) < mass_cutoff  # [B, H, C, N], bool

    # Scatter the keep flags back to the original block order.
    scatter_base = paddle.zeros_like(scores, dtype="int32")
    keep_mask = paddle.put_along_axis(
        scatter_base, desc_idx, selected.astype("int32"), axis=-1
    ).astype("bool")

    # Rows with zero (or negative) total mass select nothing.
    return paddle.logical_and(keep_mask, row_total > 0)
def verify_topp(
    out_kernel,
    topp_mask_kernel,
    top_p_value,
):
    """
    Cross-check the kernel's top-p block selection against the Python reference.

    Asserts that the kernel mask and `find_blocks_chunked` select the same number
    of blocks, and that the selected values (compared in sorted order) match
    within tolerance. Raises AssertionError on any mismatch.
    """
    out_tensor = out_kernel.astype('float32')

    mask_py = find_blocks_chunked(out_tensor, threshold=top_p_value)

    mask_ker_bool = topp_mask_kernel.astype('bool')
    values_kernel = paddle.masked_select(out_tensor, mask_ker_bool)

    mask_py_bool = mask_py.astype('bool')
    values_py = paddle.masked_select(out_tensor, mask_py_bool)

    count_ker = values_kernel.shape[0]
    count_py = values_py.shape[0]
    print(f"Selected Block Count - Kernel: {count_ker}, Python: {count_py}")

    if count_ker != count_py:
        print(f"Warning: Selection count mismatch! Diff: {abs(count_ker - count_py)}")
    # The original code did `assert False` here and then carried an unreachable
    # "counts differ, compare mass" fallback branch; the dead branch is removed
    # while keeping the hard-fail-on-mismatch behavior.
    assert count_ker == count_py, \
        f"Selection count mismatch! Kernel: {count_ker}, Python: {count_py}"

    sum_ker = paddle.sum(values_kernel).item()
    sum_py = paddle.sum(values_py).item()
    sum_diff = abs(sum_ker - sum_py)

    print(f"Selected Mass Sum - Kernel: {sum_ker:.6f}, Python: {sum_py:.6f}")
    print(f"Mass Diff: {sum_diff:.8f}")

    if sum_ker == 0:
        # Nothing selected (e.g. fully masked rows): nothing left to compare.
        return

    # Compare the multisets of selected values via descending sort.
    val_k_sorted = paddle.sort(values_kernel, descending=True)
    val_p_sorted = paddle.sort(values_py, descending=True)

    max_val_diff = paddle.max(paddle.abs(val_k_sorted - val_p_sorted)).item()
    print(f"Max Diff in Sorted Values: {max_val_diff:.8f}")

    assert max_val_diff < 1e-3, f"Values mismatch! Max diff: {max_val_diff}"
    print("Value sets match perfectly (Sorted check passed).")


# (batch_size, seqlen_q, seqlen_k, nheads, nheads_kv) combinations under test:
# a hand-picked list plus Tri Dao-style product sweeps.
shape_cases = (
    [
        (28, 128, 128, 16, 4),
        (4, 256, 256, 4, 1),
        # (2, 8192, 32768, 32, 4), # this will oom
        # (2, 8192, 8192, 32, 4), # this will oom
        (1, 8192, 8192, 1, 1),
        (2, 16384, 16384, 1, 1),
        (1, 128, 128, 1, 1),
        (1, 127, 128, 1, 1),
        (1, 16384, 16384, 1, 1),
        (2, 16384, 16383, 4, 1),
        # my case
    ]
    # tridao case
    + list(itertools.product(
        [1],  # batch_size
        [1, 64, 128, 256, 239, 799, 113, 113, 128, 113, 108, 256, 384, 640, 512, 1024, 1023, 1024,],  # seqlen_q
        [128, 192, 256, 203, 128, 217, 211, 256, 512, 256, 128, 256, 1024, 1024, 1023,],  # seqlen_k
        [1, 2],  # nheads
        [1],  # nheads_kv
    ))
    + list(itertools.product(
        [2],  # batch_size
        [4096, 4224],  # seqlen_q
        [4096, 4224],  # seqlen_k
        [6],  # nheads
        [6, 2, 1],  # nheads_kv
    ))
)


def generate_shapes():
    """Expand each shape case over the mask-head-count choices (1 and nheads_kv).

    NOTE(review): when nheads_kv == 1 this yields the same tuple twice, so those
    parameter sets run twice under pytest — confirm whether that is intended.
    """
    for batch_size, seqlen_q, seqlen_k, nheads, nheads_kv in shape_cases:
        nheads_startend_row_indices_values = [1, nheads_kv]
        for nheads_startend_row_indices in nheads_startend_row_indices_values:
            yield (
                batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_startend_row_indices
            )
@pytest.mark.parametrize("dtype", [paddle.bfloat16])
@pytest.mark.parametrize("stride", [8])
@pytest.mark.parametrize("dim", [128])
@pytest.mark.parametrize("threshold", [0.3, 0.8])
@pytest.mark.parametrize(
    "batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_startend_row_indices",
    list(generate_shapes())
)
@pytest.mark.parametrize(
    "gen_startend_row_indices",
    [
        # partial(generate_none_mask, causal=False), # full
        # partial(generate_none_mask, causal=True), # causal
        partial(generate_sliding_window_mask),          # sliding window
        partial(generate_causal_document_mask),         # causal document mask
        partial(generate_document_mask),                # document mask
        partial(generate_share_question_mask),          # share question mask
        partial(generate_global_sliding_window_mask),   # global sliding window
        partial(generate_causal_blockwise_mask),        # causal blockwise mask
        partial(generate_prefix_lm_document_mask),      # prefix lm document mask
        partial(generate_prefix_lm_causal_mask),        # prefix lm causal mask
        partial(generate_qk_sparse_mask),               # qk-sparse mask
        partial(generate_random_eviction_mask),         # random eviction mask
    ],
)
def test_rrattn_estimate(
    batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_startend_row_indices,
    dtype, gen_startend_row_indices, stride, dim, threshold,
):
    """
    Validate the triton rr_attn estimate kernel against the dense-mask reference:
    (1) boundary masks match exactly, (2) block scores match within a dynamic
    tolerance derived from the bf16 baseline error, (3) top-p block selection
    matches the Python reference.
    """
    paddle.seed(2024)
    np.random.seed(2024)
    assert nheads % nheads_kv == 0

    # fp32 reference inputs; bf16 copies feed both the naive baseline and the kernel.
    q_ref_t = paddle.randn(shape=[batch_size, seqlen_q, nheads, dim], dtype='float32')
    k_ref_t = paddle.randn(shape=[batch_size, seqlen_k, nheads_kv, dim], dtype='float32')

    q_naive_t = q_ref_t.astype(dtype)
    k_naive_t = k_ref_t.astype(dtype)

    q_kernel_t = q_naive_t.detach().clone()
    k_kernel_t = k_naive_t.detach().clone()

    startend_row_indices, causal = gen_startend_row_indices(
        batch_size, seqlen_q, seqlen_k, nheads_startend_row_indices
    )

    # NOTE(review): flashmask_to_densemask receives seqlen_q but not seqlen_k —
    # confirm it derives the key length from startend_row_indices.
    mask_dense = flashmask_to_densemask(
        startend_row_indices, seqlen_q, nheads_startend_row_indices, causal
    )

    print(f"Testing Config: B={batch_size}, Q={seqlen_q}, K={seqlen_k}, HQ={nheads}, H={nheads_kv}, Stride={stride}, Causal={causal}")

    out_ref, bound_ref = run_ref_estimate_from_dense_mask(
        q_ref_t,
        k_ref_t,
        mask_dense,
        block_size=128,
        stride=stride,
        causal=causal,
        chunk_size=2048,
    )

    out_naive, _ = run_ref_estimate_from_dense_mask(
        q_naive_t,
        k_naive_t,
        mask_dense,
        block_size=128,
        stride=stride,
        causal=causal,
        chunk_size=2048,
    )

    with paddle.compat.use_torch_proxy_guard():
        out_kernel, bound_kernel, topp_kernel = rr_attn_estimate_triton_op.rr_attn_estimate_triton_func(
            q=q_kernel_t,
            k=k_kernel_t,
            startend_row_indices=startend_row_indices,
            stride=stride,
            threshold=threshold,
            causal=causal
        )

    out_ref = out_ref.astype('float32')
    bound_ref = bound_ref.astype('int32')

    out_naive = out_naive.astype('float32')

    out_kernel = out_kernel.astype('float32')
    bound_kernel = bound_kernel.astype('int32')

    # -----------------------------------------------------------
    # Test 1: Boundary Check Mask (Exact Match)
    # -----------------------------------------------------------
    print("\n--- Testing Boundary Mask ---")
    assert bound_ref.shape == bound_kernel.shape, \
        f"Shape Mismatch! Ref: {bound_ref.shape}, Kernel: {bound_kernel.shape}"

    mask_diff_tensor = paddle.sum(paddle.abs(bound_ref - bound_kernel))
    mask_diff = mask_diff_tensor.item()
    total_elements = bound_ref.size

    print(f"Boundary Mask Mismatches: {mask_diff} / {total_elements} ({(mask_diff/total_elements)*100:.4f}%)")

    if mask_diff > 0:
        mismatch_indices = paddle.nonzero(bound_ref != bound_kernel)

        # Dump the first 5 mismatching positions for debugging.
        top_indices = mismatch_indices[:5]
        print("First 5 mismatches (Indices):", top_indices.tolist())

        ref_vals = paddle.gather_nd(bound_ref, top_indices)
        kernel_vals = paddle.gather_nd(bound_kernel, top_indices)

        print("Ref Values:", ref_vals.tolist())
        print("Kernel Values:", kernel_vals.tolist())

    assert mask_diff == 0, "[FAIL] Boundary masks do not match exactly!"
    print("✅ Boundary Mask Matched Exactly.")

    # -----------------------------------------------------------
    # Test 2: Attention Score Estimation (Dynamic Tolerance)
    # -----------------------------------------------------------
    print("\n--- Testing Attention Scores ---")

    # Round-trip through +0.3/-0.3 measures the fp32 representation noise floor.
    fwd_atol = 2 * paddle.max(paddle.abs(out_ref + 0.3 - 0.3 - out_ref)).item()
    rtol = 2

    # Baseline error of the bf16 inputs run through the fp32 reference path.
    naive_diff = paddle.abs(out_naive - out_ref)
    naive_err = paddle.max(naive_diff).item()
    print(f"Naive float32 Output max diff (vs FP32): {naive_err:.6f}")

    # Kernel error against the fp32 reference.
    kernel_diff = paddle.abs(out_kernel - out_ref)
    kernel_err = paddle.max(kernel_diff).item()
    kernel_mean_err = paddle.mean(kernel_diff).item()

    print(f"Kernel Output max diff (vs FP32): {kernel_err:.6f}")
    print(f"Kernel Output mean diff (vs FP32): {kernel_mean_err:.6f}")

    allowed_error = rtol * naive_err + fwd_atol + 1e-4

    if kernel_err > allowed_error:
        print(f"[FAIL] Score error exceeds tolerance!")
        print(f"Max Diff: {kernel_err}")
        print(f"Allowed: {allowed_error}")

        # Unravel the flat argmax index into multi-dimensional coordinates.
        flat_idx = paddle.argmax(kernel_diff).item()
        err_indices = []
        # FIX: loop variable renamed from `dim`, which shadowed the parametrized
        # head-dimension argument `dim`.
        for axis_size in reversed(out_ref.shape):
            err_indices.append(flat_idx % axis_size)
            flat_idx //= axis_size
        err_indices = tuple(reversed(err_indices))

        ref_val = out_ref[err_indices].item()
        kernel_val = out_kernel[err_indices].item()
        print(f"Max Error at {err_indices}: Ref={ref_val}, Kernel={kernel_val}")

    assert kernel_err <= allowed_error, \
        f"Output max diff {kernel_err} > Allowed {allowed_error}"

    # -----------------------------------------------------------
    # Test 3: Top-p block selection
    # -----------------------------------------------------------

    verify_topp(out_kernel, topp_kernel, threshold)

    print("Attention Score Matched within tolerance.")
    print("All Tests Passed!")