diff --git a/run_varlen.sh b/run_varlen.sh
new file mode 100644
index 0000000..b795d76
--- /dev/null
+++ b/run_varlen.sh
@@ -0,0 +1,6 @@
+
+# Save the results from before the change
+/path/to/old/paddle/env/python test_varlen.py --mode save
+
+# Verify the results after the change
+/path/to/new/paddle/env/python test_varlen.py --mode verify
\ No newline at end of file
diff --git a/test_flashmask_bwd_determ.py b/test_flashmask_bwd_determ.py
new file mode 100644
index 0000000..bc1e223
--- /dev/null
+++ b/test_flashmask_bwd_determ.py
@@ -0,0 +1,239 @@
+import os
+import math
+import itertools
+import pytest
+from einops import rearrange, repeat
+import paddle
+from paddle.nn.functional.flash_attention import flashmask_attention
+from generate_startend_row_indices import (
+    startend_row_indices_to_attn_bias,
+    generate_none_mask,
+    generate_sliding_window_mask,
+    generate_causal_document_mask,
+    generate_document_mask,
+    generate_share_question_mask,
+    generate_global_sliding_window_mask,
+    generate_causal_blockwise_mask,
+    generate_prefix_lm_document_mask,
+    generate_prefix_lm_causal_mask,
+    generate_qk_sparse_mask,
+    generate_random_eviction_mask
+)
+from functools import partial
+from test_util import attention_ref
+
+# Enable deterministic kernels
+paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
+
+# batch_size, seqlen_q, seqlen_k, nheads, nheads_kv
+shape_cases = (
+    [
+        (2840, 32, 32, 16, 4),
+        (1, 300, 300, 16, 16),
+        # (2, 8192, 32768, 32, 4), # this will oom
+        # (2, 8192, 8192, 32, 4), # this will oom
+        (2, 8192, 8192, 14, 1),
+        (2, 16384, 16384, 4, 1),
+        (1, 128, 127, 1, 1),
+        (1, 127, 128, 1, 1),
+        (2, 16383, 16384, 4, 1),
+        (2, 16384, 16383, 4, 1),
+        # my case
+    ]
+    # tridao case
+    + list(itertools.product(
+        [9],  # batch_size
+        [1, 64, 128, 256, 239, 799, 113, 113, 128, 113, 108, 256, 384, 640, 512, 1024, 1023, 1024,],  # seqlen_q
+        [128, 192, 256, 203, 128, 217, 211, 256, 512, 256, 128, 256, 1024, 1024, 1023,],  # seqlen_k
+        [6],  # nheads
+        [6, 2, 1],  # nheads_kv
+    ))
+    + list(itertools.product(
+        [2],  # batch_size
+        [4096, 4224],  # seqlen_q
+        [4096, 4224],  # seqlen_k
+        [6],  # nheads
+        [6, 2, 1],  # nheads_kv
+    ))
+)
+
+# Generate all combinations for second param
+def generate_shapes():
+    for batch_size, seqlen_q, seqlen_k, nheads, nheads_kv in shape_cases:
+        if nheads_kv == 1:
+            nheads_startend_row_indices_values = [1]
+        else:
+            nheads_startend_row_indices_values = [1, nheads_kv]
+        for nheads_startend_row_indices in nheads_startend_row_indices_values:
+            yield (
+                batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_startend_row_indices
+            )
+
+@pytest.mark.parametrize("dtype", [paddle.bfloat16])
+@pytest.mark.parametrize("fa_version", [3])
+@pytest.mark.parametrize("d, dv", [(128, 128), (80, 80), (64, 64)])
+@pytest.mark.parametrize(
+    "batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_startend_row_indices",
+    list(generate_shapes())
+)
+@pytest.mark.parametrize(
+    "gen_startend_row_indices",
+    [
+        partial(generate_none_mask, causal=False),  # full
+        partial(generate_none_mask, causal=True),  # causal
+        partial(generate_sliding_window_mask),  # sliding window
+        partial(generate_causal_document_mask),  # causal document mask
+        partial(generate_document_mask),  # document mask
+        partial(generate_share_question_mask),  # share question mask
+        partial(generate_global_sliding_window_mask),  # global sliding window
+        partial(generate_causal_blockwise_mask),  # causal blockwise mask
+        partial(generate_prefix_lm_document_mask),  # prefix lm document mask
+        partial(generate_prefix_lm_causal_mask),  # prefix lm causal mask
+        partial(generate_qk_sparse_mask),  # qk-sparse mask
+        partial(generate_random_eviction_mask),  # random eviction mask
+    ],
+)
+def test_flashmask(
+    batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, d, dv, nheads_startend_row_indices, fa_version, dtype, gen_startend_row_indices, softcap=0.0
+):
+    paddle.seed(2024)
+    assert nheads % nheads_kv == 0
+    q_ref = paddle.randn(shape=[batch_size, seqlen_q, nheads, d], dtype=dtype)
+    k_ref = paddle.randn(shape=[batch_size, seqlen_k, nheads_kv, d], dtype=dtype)
+    v_ref = paddle.randn(shape=[batch_size, seqlen_k, nheads_kv, dv], dtype=dtype)
+
+    q_ref.stop_gradient = False
+    k_ref.stop_gradient = False
+    v_ref.stop_gradient = False
+
+    q_bf16, k_bf16, v_bf16 = [x.detach().clone() for x in (q_ref, k_ref, v_ref)]
+
+    q_bf16.stop_gradient = False
+    k_bf16.stop_gradient = False
+    v_bf16.stop_gradient = False
+
+    q, k, v = [x.detach().clone() for x in (q_ref, k_ref, v_ref)]
+
+    q.stop_gradient = False
+    k.stop_gradient = False
+    v.stop_gradient = False
+
+    startend_row_indices, causal = gen_startend_row_indices(batch_size, seqlen_q, seqlen_k, nheads_startend_row_indices)
+
+    if startend_row_indices is None and causal and d == 80:
+        pytest.skip(f"Skipping because running headdim 80 with flash_attn in causal mask")
+
+    attn_bias = startend_row_indices_to_attn_bias(startend_row_indices, seqlen_q, nheads, dtype, causal)
+
+    out_ref, attn_ref = attention_ref(
+        q_ref,
+        k_ref,
+        v_ref,
+        causal=causal,
+        attn_bias=attn_bias
+    )
+
+    out_bf16, attn_bf16 = attention_ref(
+        q_bf16,
+        k_bf16,
+        v_bf16,
+        causal=causal,
+        attn_bias=attn_bias,
+        upcast=False,
+        reorder_ops=True
+    )
+
+    # # Numerical error if we just do any arithmetic on out_ref
+    fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
+    assert softcap == 0.0
+    rtol = 2 if softcap == 0.0 else 3
+
+    print(f"Paddle naive bf16 Output max diff: {(out_bf16 - out_ref).abs().max().item()}")
+    print(f"Paddle naive bf16 Output mean diff: {(out_bf16 - out_ref).abs().mean().item()}")
+
+    if fa_version == 2:
+        paddle.set_flags({'FLAGS_flash_attn_version': 2})
+    elif fa_version == 3:
+        paddle.set_flags({'FLAGS_flash_attn_version': 3})
+    else:
+        raise ValueError(
+            f"Invalid flash attention version: {fa_version}"
+        )
+
+    out, lse = flashmask_attention(
+        q,
+        k,
+        v,
+        startend_row_indices=startend_row_indices,
+        causal=causal,
+        return_softmax_lse=True
+    )
+    print(f"flashmask Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"flashmask Output mean diff: {(out - out_ref).abs().mean().item()}")
+    # if not causal:
+    #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
+    # breakpoint()
+
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    assert (out - out_ref).abs().max().item() <= rtol * (out_bf16 - out_ref).abs().max().item() + fwd_atol
+
+    g = paddle.randn(shape=out.shape, dtype=out.dtype)
+    out.backward(g, retain_graph=True)
+    out_ref.backward(g)
+    out_bf16.backward(g)
+
+    print(f"flashmask dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
+    print(f"flashmask dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
+    print(f"flashmask dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
+    print(f"flashmask dQ mean diff: {(q.grad - q_ref.grad).abs().mean().item()}")
+    print(f"flashmask dK mean diff: {(k.grad - k_ref.grad).abs().mean().item()}")
+    print(f"flashmask dV mean diff: {(v.grad - v_ref.grad).abs().mean().item()}")
+
+    print(f"Paddle naive bf16 dQ max diff: {(q_bf16.grad - q_ref.grad).abs().max().item()}")
+    print(f"Paddle naive bf16 dK max diff: {(k_bf16.grad - k_ref.grad).abs().max().item()}")
+    print(f"Paddle naive bf16 dV max diff: {(v_bf16.grad - v_ref.grad).abs().max().item()}")
+    print(f"Paddle naive bf16 dQ mean diff: {(q_bf16.grad - q_ref.grad).abs().mean().item()}")
+    print(f"Paddle naive bf16 dK mean diff: {(k_bf16.grad - k_ref.grad).abs().mean().item()}")
+    print(f"Paddle naive bf16 dV mean diff: {(v_bf16.grad - v_ref.grad).abs().mean().item()}")
+
+    dq_atol = 2 * (q_ref.grad + 0.3 - 0.3 - q_ref.grad).abs().max().item() + (0 if softcap == 0 else 3e-4)
+    assert (q.grad - q_ref.grad).abs().max().item() <= rtol * (q_bf16.grad - q_ref.grad).abs().max().item() + dq_atol
+    dk_atol = 2 * (k_ref.grad + 0.3 - 0.3 - k_ref.grad).abs().max().item() + (0 if softcap == 0 else 3e-4)
+    assert (k.grad - k_ref.grad).abs().max().item() <= rtol * (k_bf16.grad - k_ref.grad).abs().max().item() + dk_atol
+    dv_atol = 2 * (v_ref.grad + 0.3 - 0.3 - v_ref.grad).abs().max().item() + (0 if softcap == 0 else 3e-4)
+    assert (v.grad - v_ref.grad).abs().max().item() <= rtol * (v_bf16.grad - v_ref.grad).abs().max().item() + dv_atol
+
+    # --- 3. Verify backward-pass determinism ---
+    # if is_deterministic:
+    print("\n--- Verifying Backward Pass Determinism ---")
+
+    # Keep the gradients from the first run as the baseline
+    dq_first = q.grad.detach().clone()
+    dk_first = k.grad.detach().clone()
+    dv_first = v.grad.detach().clone()
+
+    # Clear gradients before the next backward pass
+    q.clear_grad()
+    k.clear_grad()
+    v.clear_grad()
+
+    # Backpropagate through the retained graph again
+    out.backward(g)
+
+    dq2 = q.grad
+    dk2 = k.grad
+    dv2 = v.grad
+
+    # Print the difference versus the first run
+    print(f'dQ max diff with first run: {(dq2 - dq_first).abs().max().item()}')
+    print(f'dK max diff with first run: {(dk2 - dk_first).abs().max().item()}')
+    print(f'dV max diff with first run: {(dv2 - dv_first).abs().max().item()}')
+
+    # Assert: this run's gradients must be bit-exact against the baseline
+    assert dq2._md5sum() == dq_first._md5sum(), f"dQ not deterministic"
+    assert dk2._md5sum() == dk_first._md5sum(), f"dK not deterministic"
+    assert dv2._md5sum() == dv_first._md5sum(), f"dV not deterministic"
+
+    print(f"✅ Deterministic test passed!")
diff --git a/test_varlen.py b/test_varlen.py
new file mode 100644
index 0000000..c637b31
--- /dev/null
+++ b/test_varlen.py
@@ -0,0 +1,309 @@
+import numpy as np
+import paddle
+import itertools
+import random
+import os
+import argparse
+import glob
+
+# Force GPU execution
+paddle.set_device('gpu')
+
+# Select the Flash Attention version flag
+paddle.set_flags({'FLAGS_flash_attn_version': 2})
+# Enable deterministic kernels
+paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
+
+# Directory where baseline data is saved
+DATA_DIR = "./paddle_pure_comparison_data"
+
+def tonp(x):
+    """Convert a Paddle Tensor to a numpy array, with special handling for bfloat16."""
+    if isinstance(x, paddle.Tensor):
+        if x.dtype == paddle.bfloat16:
+            # numpy has no native bfloat16; store as uint16 to keep binary identity
+            return x.view('uint16').numpy()
+        elif x.dtype in [paddle.float32, paddle.float16, paddle.int32, paddle.int64]:
+            return x.numpy()
+        else:
+            assert False, f'Unsupported dtype for saving: {x.dtype}'
+    elif isinstance(x, np.ndarray):
+        return x
+    else:
+        assert False, f'wrong type: {type(x)}'
+
+def from_numpy(x_np, dtype_str, place):
+    """Restore a Paddle Tensor from a numpy array, with special handling for bfloat16."""
+    tensor = paddle.to_tensor(x_np, place=place)
+
+    if dtype_str == 'paddle.bfloat16':
+        # reinterpret the uint16 view back as bfloat16
+        return tensor.view(paddle.bfloat16)
+    elif dtype_str == 'paddle.float16':
+        return tensor.cast(paddle.float16)
+    elif dtype_str == 'paddle.float32':
+        return tensor.cast(paddle.float32)
+    elif dtype_str == 'paddle.int32':
+        return tensor.cast(paddle.int32)
+    else:
+        # fall back to direct conversion
+        return tensor
+
+def cmp(x_actual, x_ref_np, msg, array_equal=True, atol=0, rtol=0):
+    """Comparison helper: x_actual is a Paddle Tensor, x_ref_np is loaded numpy data."""
+    x = tonp(x_actual)  # bf16 becomes a uint16 numpy array here
+
+    if array_equal:
+        diff = np.abs(x - x_ref_np)  # NOTE(review): for bf16 this diffs uint16 bit patterns — debug output only
+        # apply a threshold and inspect the offending values
+        bad_mask = diff > atol + rtol * np.abs(x_ref_np)
+        if np.any(bad_mask):
+            print(f"--- Debug Fail {msg} ---")
+            print(f"Max Diff: {np.max(diff)}")
+            indices = np.where(bad_mask)
+            # print the first 5 mismatching points
+            for i in range(min(5, len(indices[0]))):
+                idx = tuple(ind[i] for ind in indices)
+                print(f"Index {idx}: Act={x[idx]}, Ref={x_ref_np[idx]}, Diff={diff[idx]}")
+
+        np.testing.assert_array_equal(x, x_ref_np, err_msg=f'{msg} mismatch', strict=True)
+    else:
+        # numeric-comparison mode
+        # prepare the actual values
+        if x_actual.dtype == paddle.bfloat16:
+            val_act = x_actual.cast(paddle.float32).numpy()
+        else:
+            val_act = x_actual.numpy()
+
+        # prepare the reference values (handle ref stored as uint16 (bf16))
+        if x_ref_np.dtype == np.uint16 and x_actual.dtype == paddle.bfloat16:
+            # reinterpret ref's uint16 back to paddle bf16, then float32 for numpy
+            tmp_tensor = paddle.to_tensor(x_ref_np).view(paddle.bfloat16)
+            val_ref = tmp_tensor.cast(paddle.float32).numpy()
+        else:
+            val_ref = x_ref_np
+
+        np.testing.assert_allclose(actual=val_act, desired=val_ref, rtol=rtol, atol=atol, equal_nan=False, err_msg=msg)
+
+def random_cu_seqlens_paddle(total_tokens, batch_size):
+    """Generate cu_seqlens entirely with Paddle."""
+    if batch_size == 1:
+        cu_seqlens = paddle.to_tensor([0, total_tokens], dtype='int32')
+        return cu_seqlens.cuda()
+
+    # draw the split points
+    # randperm returns 0..n-1; we need split points in 1..total_tokens-1
+    random_points = paddle.randperm(total_tokens - 1)[:batch_size - 1] + 1
+    random_points = paddle.sort(random_points)
+
+    # concatenate [0, ..., total]
+    zeros = paddle.to_tensor([0], dtype='int64')
+    total = paddle.to_tensor([total_tokens], dtype='int64')
+
+    cu_seqlens = paddle.concat([zeros, random_points, total])
+    return cu_seqlens.cast('int32').cuda()
+
+# ==========================================
+# Mode 1: generate data and save it (run in the baseline environment)
+# ==========================================
+def run_save(case_name, batch_size, total_q, total_k, nheads, nheads_k, headdim, headdim_v, softmax_scale, causal, dtype):
+    print(f'SAVE: {case_name}, {batch_size=}, {total_q=}, {total_k=}, {nheads=}, {nheads_k=}, {headdim=}, {headdim_v=} {softmax_scale=} {causal=} {dtype=}')
+
+    # 1. generate inputs with Paddle
+    q = paddle.randn([total_q, nheads, headdim], dtype=dtype)
+    k = paddle.randn([total_k, nheads_k, headdim], dtype=dtype)
+    v = paddle.randn([total_k, nheads_k, headdim_v], dtype=dtype)
+
+    q.stop_gradient = False
+    k.stop_gradient = False
+    v.stop_gradient = False
+
+    cu_seqlens_q = random_cu_seqlens_paddle(total_q, batch_size)
+    cu_seqlens_k = random_cu_seqlens_paddle(total_k, batch_size)
+
+    max_seqlen_q = paddle.max(cu_seqlens_q[1:] - cu_seqlens_q[:-1]).item()
+    max_seqlen_k = paddle.max(cu_seqlens_k[1:] - cu_seqlens_k[:-1]).item()
+
+    # 2. run the forward pass
+    out, softmax_lse = paddle.nn.functional.flash_attention.flash_attn_unpadded(
+        q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
+        scale=softmax_scale, dropout=0.0, causal=causal
+    )
+
+    # 3. run the backward pass
+    # random upstream gradient — NOTE(review): confirm paddle.randn_like exists in the target Paddle version; otherwise use paddle.randn(shape=out.shape, dtype=out.dtype)
+    out_grad = paddle.randn_like(out)
+
+    # only run backward when the dims allow it
+    # Paddle FA3 backward support: headdim==headdim_v, or (headdim!=headdim_v and not GQA/MQA)
+    # keep the logic simple here and run backward whenever possible
+    if headdim == headdim_v or (nheads == nheads_k):
+        out.backward(out_grad)
+
+    # 4. save everything to an npz file
+    save_path = os.path.join(DATA_DIR, f"{case_name}.npz")
+
+    data_dict = {
+        "config": np.array([max_seqlen_q, max_seqlen_k], dtype=np.int32),
+        "params": np.array([softmax_scale if softmax_scale is not None else -1.0], dtype=np.float32),
+        "meta": np.array([1 if causal else 0, 1 if softmax_scale is None else 0], dtype=np.int32),
+        "dtype_str": str(q.dtype),
+
+        # inputs (Tensor -> numpy)
+        "q": tonp(q),
+        "k": tonp(k),
+        "v": tonp(v),
+        "cu_seqlens_q": tonp(cu_seqlens_q),
+        "cu_seqlens_k": tonp(cu_seqlens_k),
+        "out_grad": tonp(out_grad),
+
+        # expected outputs (reference)
+        "ref_out": tonp(out),
+    }
+
+    if q.grad is not None: data_dict["ref_dq"] = tonp(q.grad)
+    if k.grad is not None: data_dict["ref_dk"] = tonp(k.grad)
+    if v.grad is not None: data_dict["ref_dv"] = tonp(v.grad)
+
+    np.savez(save_path, **data_dict)
+
+# ==========================================
+# Mode 2: load the data and verify (run in the environment under test)
+# ==========================================
+def run_verify(file_path):
+    print(f'\nVERIFYING: {os.path.basename(file_path)}')
+    try:
+        data = np.load(file_path)
+    except Exception as e:
+        print(f"Error loading {file_path}: {e}")
+        return
+
+    # 1. restore the configuration parameters
+    max_seqlen_q = int(data["config"][0])
+    max_seqlen_k = int(data["config"][1])
+    softmax_scale = float(data["params"][0])
+    is_causal = bool(data["meta"][0])
+    is_scale_none = bool(data["meta"][1])
+    if is_scale_none: softmax_scale = None
+    dtype_str = str(data["dtype_str"])
+
+    # 2. restore the tensors
+    place = paddle.CUDAPlace(0)
+    q = from_numpy(data["q"], dtype_str, place)
+    k = from_numpy(data["k"], dtype_str, place)
+    v = from_numpy(data["v"], dtype_str, place)
+    cu_seqlens_q = from_numpy(data["cu_seqlens_q"], 'paddle.int32', place)
+    cu_seqlens_k = from_numpy(data["cu_seqlens_k"], 'paddle.int32', place)
+    out_grad = from_numpy(data["out_grad"], dtype_str, place)
+
+    q.stop_gradient = False
+    k.stop_gradient = False
+    v.stop_gradient = False
+
+    # dimension info used by the branching logic below
+    headdim = q.shape[2]
+    headdim_v = v.shape[2]
+    nheads = q.shape[1]
+    nheads_k = k.shape[1]
+
+    # 3. run the forward pass in the current environment
+    out, softmax_lse = paddle.nn.functional.flash_attention.flash_attn_unpadded(
+        q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
+        scale=softmax_scale, dropout=0.0, causal=is_causal
+    )
+
+    # 4. compare the forward result
+    cmp(out, data["ref_out"], 'out', array_equal=True)
+    print('>> fwd out pass')
+
+    # 5. run the backward pass in the current environment
+    if headdim == headdim_v or (nheads == nheads_k):
+        out.backward(out_grad)
+
+    if headdim != headdim_v and nheads == nheads_k:
+        paddle.device.synchronize()
+        assert q.grad is not None, "q.grad is None in MLA case"
+        assert k.grad is not None, "k.grad is None in MLA case"
+        assert v.grad is not None, "v.grad is None in MLA case"
+        print(">> mla paddle can run bwd")
+
+    # 6. compare the backward results
+    if "ref_dq" in data:
+        try:
+            cmp(q.grad, data["ref_dq"], 'q.grad', array_equal=True)
+            print('>> dq pass')
+        except AssertionError as e:
+            print(f"!! dq mismatch: {e}")
+
+    if "ref_dk" in data:
+        try:
+            cmp(k.grad, data["ref_dk"], 'k.grad', array_equal=True)
+            print('>> dk pass')
+        except AssertionError as e:
+            print(f"!! dk mismatch: {e}")
+
+    if "ref_dv" in data:
+        try:
+            cmp(v.grad, data["ref_dv"], 'v.grad', array_equal=True)
+            print('>> dv pass')
+        except AssertionError as e:
+            print(f"!! dv mismatch: {e}")
+
+def main_gen_loops():
+    """Loop that generates the test cases."""
+    counter = 0
+    dtype_options = [paddle.bfloat16, paddle.float16]
+    causal_options = [True, False]
+
+    # ================= Case 7 =================
+    print('\nGenerating Case 7 (Random shapes, headdim=headdim_v)...')
+    for causal, dtype in itertools.product(causal_options, dtype_options):
+        # run a few random repetitions
+        for _ in range(2):
+            headdim = random.randrange(1, 33) * 8
+            headdim_v = headdim
+            total_q = random.randrange(100, 2048)
+            total_k = random.randrange(100, 2048)
+            batch_size = random.randint(1, min(total_q, total_k, 32))  # cap batch_size to avoid overly large values
+
+            nheads_k = random.randrange(1, 17)
+            group = random.randrange(1, 9)
+            nheads = group * nheads_k
+
+            softmax_scale = random.uniform(0.000001, 1.0)
+
+            run_save(f"case7_{counter}", batch_size, total_q, total_k, nheads, nheads_k, headdim, headdim_v, softmax_scale, causal, dtype)
+            counter += 1
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--mode", type=str, required=True, choices=["save", "verify"],
+                        help="'save': 在旧环境中生成并保存基准数据; 'verify': 在新环境中加载数据并对比")
+    args = parser.parse_args()
+
+    if args.mode == "save":
+        if os.path.exists(DATA_DIR):
+            print(f"Cleaning old data in {DATA_DIR}...")
+            import shutil
+            shutil.rmtree(DATA_DIR)
+        os.makedirs(DATA_DIR)
+
+        print(f"Running generation mode using Paddle {paddle.__version__}")
+        main_gen_loops()
+        print(f"Done. Data saved to {DATA_DIR}")
+
+    elif args.mode == "verify":
+        if not os.path.exists(DATA_DIR):
+            print(f"Error: Directory {DATA_DIR} does not exist. Run --mode save first.")
+            exit(1)
+
+        print(f"Running verification mode using Paddle {paddle.__version__}")
+        files = sorted(glob.glob(os.path.join(DATA_DIR, "*.npz")))
+        if not files:
+            print("No data files found.")
+            exit(1)
+
+        for f in files:
+            run_verify(f)
+        print("Done verification.")
\ No newline at end of file