From 71bd17f03de2a8e6465ef0be5bc8a2eb3f92967e Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Tue, 30 Apr 2024 15:54:37 -0500 Subject: [PATCH] Add new triton kernel debug_fill_dropout_rng (#19) * tritonsrc/dropout_rng: Add new triton kernel debug_fill_dropout_rng This is for _fill_mem_eff_dropout_mask_ --- bindings/module.cc | 7 + include/aotriton/flash.h | 6 + test/aotriton_flash.py | 13 +- test/attn_torch_function.py | 8 +- test/test_forward.py | 367 ++++++++++-------- tritonsrc/attn_torch_function.py | 27 +- tritonsrc/dropout_rng.py | 37 ++ tritonsrc/flash.py | 1 + tritonsrc/fwd_kernel.py | 8 +- tritonsrc/test_forward.py | 354 +++++++++-------- v2python/rules/flash/__init__.py | 2 + .../rules/flash/debug_fill_dropout_rng.py | 40 ++ v2src/flash/attn_debug.cc | 47 +++ 13 files changed, 588 insertions(+), 329 deletions(-) create mode 100644 tritonsrc/dropout_rng.py create mode 100644 v2python/rules/flash/debug_fill_dropout_rng.py create mode 100644 v2src/flash/attn_debug.cc diff --git a/bindings/module.cc b/bindings/module.cc index 35d75d2..f15f212 100644 --- a/bindings/module.cc +++ b/bindings/module.cc @@ -53,6 +53,13 @@ namespace pyaotriton { py::arg("philox_offset"), py::arg("is_causal"), py::arg("stream") = nullptr); + m.def("debug_fill_dropout_rng", + &aotriton::v2::flash::debug_fill_dropout_rng, + "Flash Attention Debugging Function to get raw RNG numbers used in dropout", + py::arg("q"), + py::arg("philox_seed"), + py::arg("philox_offset"), + py::arg("stream") = nullptr); } } // namespace flash diff --git a/include/aotriton/flash.h b/include/aotriton/flash.h index 1da5569..5b39d19 100644 --- a/include/aotriton/flash.h +++ b/include/aotriton/flash.h @@ -51,6 +51,12 @@ attn_bwd(T4 q, // batch_size x num_heads x seqlen_q x head_size bool is_causal, aotriton::Stream stream); +hipError_t +debug_fill_dropout_rng(T4 r, + uint64_t philox_seed, + uint64_t philox_offset, + aotriton::Stream stream); + } // aotriton::v2::flash #endif diff --git a/test/aotriton_flash.py b/test/aotriton_flash.py index ab9f85c..cea6b0b 100644 --- a/test/aotriton_flash.py +++ b/test/aotriton_flash.py @@ -1,7 +1,11 @@ # Copyright © 2023-2024 Advanced Micro Devices, Inc. 
# SPDX-License-Identifier: MIT -from pyaotriton.v2.flash import attn_fwd as fa_forward, attn_bwd as fa_backward +from pyaotriton.v2.flash import ( + attn_fwd as fa_forward, + attn_bwd as fa_backward, + debug_fill_dropout_rng as fa_debug_fill_dropout_rng, +) from pyaotriton import T1, T2, T4, DType, Stream def cast_dtype(dtype): @@ -68,3 +72,10 @@ def attn_bwd(q, k, v, b, sm_scale, o, dout, dq, dk, dv, db, L, delta, is_causal, Stream()) print(f'{err=}') + +def debug_fill_dropout_rng(R, philox_seed, philox_offset): + err = fa_debug_fill_dropout_rng(mk_aotensor(R), + philox_seed, + philox_offset, + Stream()) + print(f'{err=}') diff --git a/test/attn_torch_function.py b/test/attn_torch_function.py index f55278c..638175a 100644 --- a/test/attn_torch_function.py +++ b/test/attn_torch_function.py @@ -3,9 +3,11 @@ # SPDX-License-Identifier: MIT import torch -from aotriton_flash import attn_fwd, attn_bwd +from aotriton_flash import attn_fwd, attn_bwd, debug_fill_dropout_rng VERBOSE=False +DEFAULT_PHILOX_SEED = 0x1BF52 +DEFAULT_PHILOX_OFFSET = 0x1D4B42 def is_power_of_two(n: int) -> bool: return (n & (n - 1) == 0) and n != 0 @@ -53,8 +55,8 @@ def forward(ctx, q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax if encoded_softmax is not None: print(f'{encoded_softmax.shape=} {encoded_softmax.dtype=}') - philox_seed = 114514 - philox_offset = 1919810 + philox_seed = DEFAULT_PHILOX_SEED + philox_offset = DEFAULT_PHILOX_OFFSET attn_fwd(q, k, v, b, sm_scale, M, o, dropout_p, philox_seed, philox_offset, encoded_softmax, causal); diff --git a/test/test_forward.py b/test/test_forward.py index 40c2fe9..451779a 100644 --- a/test/test_forward.py +++ b/test/test_forward.py @@ -5,7 +5,7 @@ import pytest import torch -from attn_torch_function import attention +from attn_torch_function import attention, debug_fill_dropout_rng, DEFAULT_PHILOX_SEED, DEFAULT_PHILOX_OFFSET def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: # Efficient implementation equivalent to the following: @@ -74,167 +74,194 @@ def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch. but in PyTorch API it does not present at all ''' -def _do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type): - if causal and seqlen_q != seqlen_k: - pytest.skip("PyTorch's Flash V2 does not accept casual=True when seqlen_q != seqlen_k. 
Skipping") - torch.manual_seed(20) - print(f"test_op_fwd {BATCH=}, {N_HEADS=}, {seqlen_q=}, {seqlen_k=}, {D_HEAD=}, {causal=}") - SPARSE_HEAD_SINCE = 3 - SPARSE_SEQ_SINCE = 3 - Z = BATCH - H = N_HEADS - if True: # Real UT - qdims = (BATCH, N_HEADS, seqlen_q, D_HEAD) - kdims = (BATCH, N_HEADS, seqlen_k, D_HEAD) - vdims = (BATCH, N_HEADS, seqlen_k, D_HEAD) - bdims = (BATCH, N_HEADS, seqlen_q, seqlen_k) - if storage_flip: - qdims = (qdims[0], qdims[2], qdims[1], qdims[3]) - kdims = (kdims[0], kdims[2], kdims[1], kdims[3]) - vdims = (vdims[0], vdims[2], vdims[1], vdims[3]) - bdims = (bdims[0], bdims[2], bdims[1], bdims[3]) - q = ( - torch.empty(qdims, dtype=dtype, device="cuda") - .normal_(mean=0., std=0.5) - .requires_grad_() - ) - k = ( - torch.empty(kdims, dtype=dtype, device="cuda") - .normal_(mean=0., std=0.5) - .requires_grad_() - ) - v = ( - torch.empty(vdims, dtype=dtype, device="cuda") - .normal_(mean=0., std=0.5) - .requires_grad_() - ) - if bias_type is None: - b = None - elif bias_type == 'matrix': - b = torch.empty(bdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5) - else: - assert False, f'Unsupported bias_type {bias_type}' - if storage_flip: - q = torch.transpose(q, 1, 2) - k = torch.transpose(k, 1, 2) - v = torch.transpose(v, 1, 2) - if b is not None: - b = torch.transpose(b, 1, 2) - assert q.shape == (BATCH, N_HEADS, seqlen_q, D_HEAD) - assert k.shape == (BATCH, N_HEADS, seqlen_k, D_HEAD) - assert v.shape == (BATCH, N_HEADS, seqlen_k, D_HEAD) - if False: # Debugging - q = ( - torch.empty((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0., std=0.5) - .requires_grad_() - ) - k = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 1.0 - v = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 1.0 - if False: - q = torch.ones((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda") * 1.0 - k = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 2.0 - v = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 3.0 - if False: - import numpy as np - q = torch.arange(np.prod([Z, H, seqlen_q, D_HEAD]), dtype=dtype, device="cuda").reshape((Z, H, seqlen_q, D_HEAD)) - k = torch.arange(np.prod([Z, H, seqlen_k, D_HEAD]), dtype=dtype, device="cuda").reshape((Z, H, seqlen_q, D_HEAD)) - v = torch.arange(np.prod([Z, H, seqlen_k, D_HEAD]), dtype=dtype, device="cuda").reshape((Z, H, seqlen_q, D_HEAD)) - q = (q - 128.0) * 0.01 - k = (k - 128.0) * 0.01 - v = (v - 128.0) * 0.01 - q[:, :, :, SPARSE_HEAD_SINCE: ] = 0.0 - k[:, :, :, SPARSE_HEAD_SINCE: ] = 0.0 - v[:, :, :, SPARSE_HEAD_SINCE: ] = 0.0 - q[:, :, SPARSE_SEQ_SINCE:, : ] = 0.0 - k[:, :, SPARSE_SEQ_SINCE:, : ] = 0.0 - v[:, :, SPARSE_SEQ_SINCE:, : ] = 0.0 +class FwdTester(object): - ''' - dout = torch.randn_like(q) - # reference implementation - M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) - p = torch.matmul(q, k.transpose(2, 3)) * sm_scale - if causal: - p[:, :, M == 0] = float("-inf") - p = torch.softmax(p.float(), dim=-1).half() - ref_out = torch.matmul(p, v) - ''' - return_encoded_softmax = dropout_p > 0.0 - higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32 - REF_DEVICE=None - q_ref, k_ref, v_ref = query_key_value_clones(q, k, v, dtype=higher_precision_dtype, device=REF_DEVICE) - def TO(ref_tensor): - return ref_tensor.to(device=q.device, dtype=dtype) - tri_out, encoded_softmax, _ = attention(q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax) + def __init__(self): + self.use_fill_rng_for_dropout = 
False - dropout_mask = encoded_softmax > 0 if encoded_softmax is not None else None - # assert torch.allclose(dropout_mask, dropout_mask_naive) - ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q_ref, k_ref, v_ref, - dropout_p=dropout_p, - is_causal=causal, - attn_mask=b, - scale=sm_scale, - dropout_mask=dropout_mask) - if False: - mref_out, mref_softmax = scaled_dot_product_attention(q, k, v, - dropout_p=dropout_p, - is_causal=causal, - scale=sm_scale, - dropout_mask=dropout_mask) - print(f'{tri_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - print(f'{ref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - print(f'{mref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - print(f'{tri_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]/ref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - print(f'{q.shape=} {q.stride()=}') - print(f'{k.shape=} {k.stride()=}') - print(f'{v.shape=} {v.stride()=}') - print(f'{encoded_softmax=}') - if encoded_softmax is not None: - print(f'{encoded_softmax.shape=} {encoded_softmax.stride()=}') - print(f'{encoded_softmax[:,:, :SPARSE_SEQ_SINCE, :SPARSE_SEQ_SINCE]=}') - print(f'{dropout_mask.shape=} {dropout_mask.stride()=}') - print(f'{dropout_mask[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - if dtype==torch.bfloat16: - ATOL = 1e-1 * max(1.0, (seqlen_q + seqlen_k + D_HEAD) / 128.0) - else: - ATOL = 1e-2 * max(1.0, (seqlen_q + seqlen_k + D_HEAD) / 128.0) - RTOL = 0.0 - print(f'Using ATOL={ATOL} RTOL={RTOL}') - is_allclose = torch.allclose(TO(ref_out), tri_out, atol=ATOL, rtol=RTOL) - if not is_allclose: - import numpy as np - err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_out) - tri_out)).cpu().numpy(), ref_out.shape) - print(f'{err_idx=}') - print(f'{tri_out[err_idx]=} {ref_out[err_idx]=} error: {tri_out[err_idx] - ref_out[err_idx]}') - # if not is_allclose: - if False: - import numpy as np - err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_out) - tri_out)).cpu().numpy(), ref_out.shape) - print(f'{tri_out[0][0][0][:]=}') - print(f'{ref_out[0][0][0][:]=}') - print(f'{mref_out[0][0][0][:]=}') - if encoded_softmax is not None: - print(f'{encoded_softmax[0][0][0][:]=}') - print(f'{ref_softmax[0][0][0][:]=}') - print(f'{tri_out[-1][0][0][:]=}') - print(f'{ref_out[-1][0][0][:]=}') - print(f'{err_idx=}') - print(f'{tri_out[err_idx]=}') - print(f'{ref_out[err_idx]=}') - if dropout_p > 0: - # print(f'{unmasked_ref_out[0][0][0][:]=}') - print(f'{dropout_mask[0][0][0][:]=}') - print(f'{dropout_mask[err_idx]=}') - # tri_cpu = tri_out[0, 0].cpu().detach().numpy() - # print(f'{tri_cpu.shape=}') - # compare - assert is_allclose + def do_test_op_fwd(self, BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type,): + if causal and seqlen_q != seqlen_k: + pytest.skip("PyTorch's Flash V2 does not accept casual=True when seqlen_q != seqlen_k. 
Skipping") + if causal and bias_type is not None: + pytest.skip("_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True") + torch.manual_seed(20) + print(f"test_op_fwd {BATCH=}, {N_HEADS=}, {seqlen_q=}, {seqlen_k=}, {D_HEAD=}, {causal=}") + SPARSE_HEAD_SINCE = 3 + SPARSE_SEQ_SINCE = 3 + Z = BATCH + H = N_HEADS + if True: # Real UT + qdims = (BATCH, N_HEADS, seqlen_q, D_HEAD) + kdims = (BATCH, N_HEADS, seqlen_k, D_HEAD) + vdims = (BATCH, N_HEADS, seqlen_k, D_HEAD) + bdims = (BATCH, N_HEADS, seqlen_q, seqlen_k) + if storage_flip: + qdims = (qdims[0], qdims[2], qdims[1], qdims[3]) + kdims = (kdims[0], kdims[2], kdims[1], kdims[3]) + vdims = (vdims[0], vdims[2], vdims[1], vdims[3]) + bdims = (bdims[0], bdims[2], bdims[1], bdims[3]) + q = ( + torch.empty(qdims, dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ) + k = ( + torch.empty(kdims, dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ) + v = ( + torch.empty(vdims, dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ) + if bias_type is None: + b = None + elif bias_type == 'matrix': + b = torch.empty(bdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5) + else: + assert False, f'Unsupported bias_type {bias_type}' + if storage_flip: + q = torch.transpose(q, 1, 2) + k = torch.transpose(k, 1, 2) + v = torch.transpose(v, 1, 2) + if b is not None: + b = torch.transpose(b, 1, 2) + assert q.shape == (BATCH, N_HEADS, seqlen_q, D_HEAD) + assert k.shape == (BATCH, N_HEADS, seqlen_k, D_HEAD) + assert v.shape == (BATCH, N_HEADS, seqlen_k, D_HEAD) + if False: # Debugging + q = ( + torch.empty((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ) + k = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 1.0 + v = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 1.0 + if False: + q = torch.ones((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda") * 1.0 + k = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 2.0 + v = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 3.0 + if False: + import numpy as np + q = torch.arange(np.prod([Z, H, seqlen_q, D_HEAD]), dtype=dtype, device="cuda").reshape((Z, H, seqlen_q, D_HEAD)) + k = torch.arange(np.prod([Z, H, seqlen_k, D_HEAD]), dtype=dtype, device="cuda").reshape((Z, H, seqlen_q, D_HEAD)) + v = torch.arange(np.prod([Z, H, seqlen_k, D_HEAD]), dtype=dtype, device="cuda").reshape((Z, H, seqlen_q, D_HEAD)) + q = (q - 128.0) * 0.01 + k = (k - 128.0) * 0.01 + v = (v - 128.0) * 0.01 + q[:, :, :, SPARSE_HEAD_SINCE: ] = 0.0 + k[:, :, :, SPARSE_HEAD_SINCE: ] = 0.0 + v[:, :, :, SPARSE_HEAD_SINCE: ] = 0.0 + q[:, :, SPARSE_SEQ_SINCE:, : ] = 0.0 + k[:, :, SPARSE_SEQ_SINCE:, : ] = 0.0 + v[:, :, SPARSE_SEQ_SINCE:, : ] = 0.0 + + ''' + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, v) + ''' + return_encoded_softmax = dropout_p > 0.0 and not self.use_fill_rng_for_dropout + # return_encoded_softmax = dropout_p > 0.0 # Reserved for debugging use_fill_rng_for_dropout + higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32 + REF_DEVICE=None + q_ref, k_ref, v_ref = query_key_value_clones(q, k, v, dtype=higher_precision_dtype, 
device=REF_DEVICE) + def TO(ref_tensor): + return ref_tensor.to(device=q.device, dtype=dtype) + autotune = False + return_autotune = False + tri_out, encoded_softmax, _ = attention(q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax, + autotune, return_autotune) + + if self.use_fill_rng_for_dropout: + rdims = (BATCH, N_HEADS, seqlen_q, seqlen_k) + if storage_flip: + rdims = (rdims[0], rdims[2], rdims[1], rdims[3]) + r = torch.empty(rdims, device=q.device, dtype=torch.float32) + if storage_flip: + r = torch.transpose(r, 1, 2) + philox_seed = DEFAULT_PHILOX_SEED + philox_offset = DEFAULT_PHILOX_OFFSET + debug_fill_dropout_rng(r, philox_seed, philox_offset) + # Reserved for debugging use_fill_rng_for_dropout + # print(f'{r[0,0,:16, :16]}=') + # print(f'{r[0,0,:16, :16] > dropout_p}=') + # print(f'{encoded_softmax[0,0,:16, :16] > 0}=') + dropout_mask = r > dropout_p + else: + dropout_mask = encoded_softmax > 0 if encoded_softmax is not None else None + # assert torch.allclose(dropout_mask, dropout_mask_naive) + ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q_ref, k_ref, v_ref, + dropout_p=dropout_p, + is_causal=causal, + attn_mask=b, + scale=sm_scale, + dropout_mask=dropout_mask) + if False: + mref_out, mref_softmax = scaled_dot_product_attention(q, k, v, + dropout_p=dropout_p, + is_causal=causal, + scale=sm_scale, + dropout_mask=dropout_mask) + print(f'{tri_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') + print(f'{ref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') + print(f'{mref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') + print(f'{tri_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]/ref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') + print(f'{q.shape=} {q.stride()=}') + print(f'{k.shape=} {k.stride()=}') + print(f'{v.shape=} {v.stride()=}') + print(f'{encoded_softmax=}') + if encoded_softmax is not None: + print(f'{encoded_softmax.shape=} {encoded_softmax.stride()=}') + print(f'{encoded_softmax[:,:, :SPARSE_SEQ_SINCE, :SPARSE_SEQ_SINCE]=}') + print(f'{dropout_mask.shape=} {dropout_mask.stride()=}') + print(f'{dropout_mask[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') + if dtype==torch.bfloat16: + ATOL = 1e-1 * max(1.0, (seqlen_q + seqlen_k + D_HEAD) / 128.0) + else: + ATOL = 1e-2 * max(1.0, (seqlen_q + seqlen_k + D_HEAD) / 128.0) + RTOL = 0.0 + print(f'Using ATOL={ATOL} RTOL={RTOL}') + is_allclose = torch.allclose(TO(ref_out), tri_out, atol=ATOL, rtol=RTOL) + if not is_allclose: + import numpy as np + err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_out) - tri_out)).cpu().numpy(), ref_out.shape) + print(f'{err_idx=}') + print(f'{tri_out[err_idx]=} {ref_out[err_idx]=} error: {tri_out[err_idx] - ref_out[err_idx]}') + # if not is_allclose: + if False: + import numpy as np + err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_out) - tri_out)).cpu().numpy(), ref_out.shape) + print(f'{tri_out[0][0][0][:]=}') + print(f'{ref_out[0][0][0][:]=}') + print(f'{mref_out[0][0][0][:]=}') + if encoded_softmax is not None: + print(f'{encoded_softmax[0][0][0][:]=}') + print(f'{ref_softmax[0][0][0][:]=}') + print(f'{tri_out[-1][0][0][:]=}') + print(f'{ref_out[-1][0][0][:]=}') + print(f'{err_idx=}') + print(f'{tri_out[err_idx]=}') + print(f'{ref_out[err_idx]=}') + if dropout_p > 0: + # print(f'{unmasked_ref_out[0][0][0][:]=}') + print(f'{dropout_mask[0][0][0][:]=}') + print(f'{dropout_mask[err_idx]=}') + # tri_cpu = tri_out[0, 0].cpu().detach().numpy() + # print(f'{tri_cpu.shape=}') + # compare + assert is_allclose # 
@pytest.mark.parametrize('BATCH', [1, 4]) # @pytest.mark.parametrize('N_HEADS', [1, 4]) -@pytest.mark.parametrize('BATCH', [1, 2, 4]) -@pytest.mark.parametrize('N_HEADS', [1, 2, 4]) +@pytest.mark.parametrize('BATCH', [1, 4]) +@pytest.mark.parametrize('N_HEADS', [1, 4]) @pytest.mark.parametrize('D_HEAD', [8, 16, 21, 32, 64, 72, 96, 128, 160, 192, 203, 256]) # @pytest.mark.parametrize('seqlen_q', [16,32,64,128,256,512,1024]) # @pytest.mark.parametrize('seqlen_k', [16,32,64,128,256,512,1024]) @@ -251,10 +278,11 @@ def TO(ref_tensor): # @pytest.mark.parametrize('return_encoded_softmax', [False]) def test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip): bias_type = None - _do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) + tester = FwdTester() + tester.do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) -@pytest.mark.parametrize('BATCH', [1, 2, 4]) -@pytest.mark.parametrize('N_HEADS', [1, 2, 4]) +@pytest.mark.parametrize('BATCH', [1, 4]) +@pytest.mark.parametrize('N_HEADS', [1, 4]) @pytest.mark.parametrize('D_HEAD', [16,32,64,128,256]) @pytest.mark.parametrize('seqlen_q', [4, 8, 64, 143, 256, 512, 1024, 2048]) @pytest.mark.parametrize('seqlen_k', [4, 8, 64, 128, 256, 587, 1024, 2048]) @@ -265,7 +293,24 @@ def test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dr def test_op_fwd_with_matrix_bias(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, sm_scale, dropout_p, dtype, storage_flip): causal = False bias_type = 'matrix' + tester = FwdTester() ''' _scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True ''' - _do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) + tester.do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) + +@pytest.mark.parametrize('BATCH', [1, 4]) +@pytest.mark.parametrize('N_HEADS', [1, 4]) +@pytest.mark.parametrize('seqlen_q', [4, 8, 64, 143, 256, 512, 1024, 2048]) +@pytest.mark.parametrize('seqlen_k', [4, 8, 64, 128, 256, 587, 1024, 2048]) +@pytest.mark.parametrize('causal', [False, True]) +@pytest.mark.parametrize('storage_flip', [False, True]) +def test_fill_dropout_rng(BATCH, N_HEADS, seqlen_q, seqlen_k, causal, storage_flip): + D_HEAD = 128 + dropout_p = 0.5 + dtype = torch.float16 + sm_scale = 1.2 + bias_type = None + tester = FwdTester() + tester.use_fill_rng_for_dropout = True + tester.do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) diff --git a/tritonsrc/attn_torch_function.py b/tritonsrc/attn_torch_function.py index e9aacf6..527c484 100644 --- a/tritonsrc/attn_torch_function.py +++ b/tritonsrc/attn_torch_function.py @@ -6,14 +6,17 @@ import torch import triton import triton.language as tl -from flash import attn_fwd as bare_attn_fwd from flash import ( + debug_fill_dropout_rng as bare_debug_fill_dropout_rng, + attn_fwd as bare_attn_fwd, bwd_preprocess as bare_bwd_preprocess, bwd_kernel_dk_dv as bare_bwd_kernel_dk_dv, bwd_kernel_dq as bare_bwd_kernel_dq ) VERBOSE=False +DEFAULT_PHILOX_SEED = 0x1BF52 +DEFAULT_PHILOX_OFFSET = 0x1D4B42 def is_power_of_two(n: int) -> bool: return (n & (n - 1) == 0) and n != 0 @@ -266,8 +269,8 @@ def forward(ctx, q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax if encoded_softmax is 
not None: print(f'{encoded_softmax.shape=} {encoded_softmax.dtype=}') - philox_seed = 114514 - philox_offset = 1919810 + philox_seed = DEFAULT_PHILOX_SEED + philox_offset = DEFAULT_PHILOX_OFFSET if b is None: b = torch.empty((0,0,0,0), device=q.device, dtype=q.dtype) BIAS_TYPE = 0 @@ -683,3 +686,21 @@ def backward(ctx, do, _, fwd_tuning_result): return dq, dk, dv, None if db.numel() == 0 else db, None, None, None, None, None, None, None attention = _attention.apply + +def debug_fill_dropout_rng(dropout_rng, philox_seed, philox_offset): + BLOCK_M = 64 + BLOCK_N = 32 + BATCH, N_HEADS, seqlen_q, seqlen_k = dropout_rng.size() + grid_rng = lambda META: ( + triton.cdiv(seqlen_q, META['BLOCK_M']), + N_HEADS, + BATCH, + ) + r = dropout_rng + bare_debug_fill_dropout_rng[grid_rng](r, + r.stride(0), r.stride(1), r.stride(2), r.stride(3), + seqlen_q, seqlen_k, + philox_seed, + philox_offset, + BLOCK_M, BLOCK_N, + num_stages=1) diff --git a/tritonsrc/dropout_rng.py b/tritonsrc/dropout_rng.py new file mode 100644 index 0000000..9dd6815 --- /dev/null +++ b/tritonsrc/dropout_rng.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python +# Copyright © 2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +import triton +import triton.language as tl +from fwd_kernel import dropout_rng + +@triton.jit +def debug_fill_dropout_rng(R, + stride_rz, stride_rh, stride_rm, stride_rn, + seqlen_q, seqlen_k, + philox_seed, + philox_offset_base, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + start_m = tl.program_id(0) + off_h = tl.program_id(1) # head index + off_z = tl.program_id(2) # batch index + d_offset = off_h * stride_rh + off_z * stride_rz + num_h = tl.num_programs(1) + off_zh = off_z * num_h + off_h * 1 + batch_philox_offset = philox_offset_base + off_zh * seqlen_q * seqlen_k + R_block_ptr = tl.make_block_ptr( + base=R + d_offset, + shape=(seqlen_q, seqlen_k), + strides=(stride_rm, stride_rn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0) + ) + for start_n in range(0, seqlen_k, BLOCK_N): + philox_offset = batch_philox_offset + start_m * BLOCK_M * seqlen_k + start_n + rng = dropout_rng(philox_seed, philox_offset, BLOCK_M, BLOCK_N, seqlen_k) + tl.store(R_block_ptr, rng.to(R_block_ptr.type.element_ty), boundary_check=(0,1)) + R_block_ptr = tl.advance(R_block_ptr, (0, BLOCK_N)) diff --git a/tritonsrc/flash.py b/tritonsrc/flash.py index f3a41bb..c24ca14 100644 --- a/tritonsrc/flash.py +++ b/tritonsrc/flash.py @@ -20,3 +20,4 @@ from fwd_kernel import attn_fwd from bwd_preprocess import bwd_preprocess from bwd_split_kernel import bwd_kernel_dk_dv, bwd_kernel_dq +from dropout_rng import debug_fill_dropout_rng diff --git a/tritonsrc/fwd_kernel.py b/tritonsrc/fwd_kernel.py index b12a83b..496ef9f 100644 --- a/tritonsrc/fwd_kernel.py +++ b/tritonsrc/fwd_kernel.py @@ -23,20 +23,20 @@ def max_fn(x, y): return tl.math.max(x, y) @triton.jit -def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): +def dropout_offsets(philox_seed, philox_offset, m, n, stride): ms = tl.arange(0, m) ns = tl.arange(0, n) return philox_offset + ms[:, None] * stride + ns[None, :] @triton.jit -def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32) +def dropout_rng(philox_seed, philox_offset, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, m, n, stride).to(tl.uint32) # TODO: use tl.randint for better performance return tl.rand(philox_seed, 
rng_offsets) @triton.jit def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) + rng_output = dropout_rng(philox_seed, philox_offset, m, n, stride) rng_keep = rng_output > dropout_p return rng_keep diff --git a/tritonsrc/test_forward.py b/tritonsrc/test_forward.py index 7da96a0..5f80647 100644 --- a/tritonsrc/test_forward.py +++ b/tritonsrc/test_forward.py @@ -5,7 +5,7 @@ import pytest import torch -from attn_torch_function import attention +from attn_torch_function import attention, debug_fill_dropout_rng, DEFAULT_PHILOX_SEED, DEFAULT_PHILOX_OFFSET def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: # Efficient implementation equivalent to the following: @@ -65,165 +65,187 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_mask but in PyTorch API it does not present at all ''' -def _do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type): - if causal and seqlen_q != seqlen_k: - pytest.skip("PyTorch's Flash V2 does not accept casual=True when seqlen_q != seqlen_k. Skipping") - if causal and bias_type is not None: - pytest.skip("_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True") - torch.manual_seed(20) - print(f"test_op_fwd {BATCH=}, {N_HEADS=}, {seqlen_q=}, {seqlen_k=}, {D_HEAD=}, {causal=}") - SPARSE_HEAD_SINCE = 3 - SPARSE_SEQ_SINCE = 3 - Z = BATCH - H = N_HEADS - if True: # Real UT - qdims = (BATCH, N_HEADS, seqlen_q, D_HEAD) - kdims = (BATCH, N_HEADS, seqlen_k, D_HEAD) - vdims = (BATCH, N_HEADS, seqlen_k, D_HEAD) - bdims = (BATCH, N_HEADS, seqlen_q, seqlen_k) - if storage_flip: - qdims = (qdims[0], qdims[2], qdims[1], qdims[3]) - kdims = (kdims[0], kdims[2], kdims[1], kdims[3]) - vdims = (vdims[0], vdims[2], vdims[1], vdims[3]) - bdims = (bdims[0], bdims[2], bdims[1], bdims[3]) - q = ( - torch.empty(qdims, dtype=dtype, device="cuda") - .normal_(mean=0., std=0.5) - .requires_grad_() - ) - k = ( - torch.empty(kdims, dtype=dtype, device="cuda") - .normal_(mean=0., std=0.5) - .requires_grad_() - ) - v = ( - torch.empty(vdims, dtype=dtype, device="cuda") - .normal_(mean=0., std=0.5) - .requires_grad_() - ) - if bias_type is None: - b = None - elif bias_type == 'matrix': - b = torch.empty(bdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5) - else: - assert False, f'Unsupported bias_type {bias_type}' - if storage_flip: - q = torch.transpose(q, 1, 2) - k = torch.transpose(k, 1, 2) - v = torch.transpose(v, 1, 2) - if b is not None: - b = torch.transpose(b, 1, 2) - assert q.shape == (BATCH, N_HEADS, seqlen_q, D_HEAD) - assert k.shape == (BATCH, N_HEADS, seqlen_k, D_HEAD) - assert v.shape == (BATCH, N_HEADS, seqlen_k, D_HEAD) - if False: # Debugging - q = ( - torch.empty((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0., std=0.5) - .requires_grad_() - ) - k = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 1.0 - v = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 1.0 - if False: - q = torch.ones((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda") * 1.0 - k = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 2.0 - v = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 3.0 - if False: - import numpy as np - q = torch.arange(np.prod([Z, H, seqlen_q, D_HEAD]), dtype=dtype, 
device="cuda").reshape((Z, H, seqlen_q, D_HEAD)) - k = torch.arange(np.prod([Z, H, seqlen_k, D_HEAD]), dtype=dtype, device="cuda").reshape((Z, H, seqlen_q, D_HEAD)) - v = torch.arange(np.prod([Z, H, seqlen_k, D_HEAD]), dtype=dtype, device="cuda").reshape((Z, H, seqlen_q, D_HEAD)) - q = (q - 128.0) * 0.01 - k = (k - 128.0) * 0.01 - v = (v - 128.0) * 0.01 - q[:, :, :, SPARSE_HEAD_SINCE: ] = 0.0 - k[:, :, :, SPARSE_HEAD_SINCE: ] = 0.0 - v[:, :, :, SPARSE_HEAD_SINCE: ] = 0.0 - q[:, :, SPARSE_SEQ_SINCE:, : ] = 0.0 - k[:, :, SPARSE_SEQ_SINCE:, : ] = 0.0 - v[:, :, SPARSE_SEQ_SINCE:, : ] = 0.0 +class FwdTester(object): - ''' - dout = torch.randn_like(q) - # reference implementation - M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) - p = torch.matmul(q, k.transpose(2, 3)) * sm_scale - if causal: - p[:, :, M == 0] = float("-inf") - p = torch.softmax(p.float(), dim=-1).half() - ref_out = torch.matmul(p, v) - ''' - return_encoded_softmax = dropout_p > 0.0 - autotune = False - tri_out, encoded_softmax, _ = attention(q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax, autotune) + def __init__(self): + self.use_fill_rng_for_dropout = False - dropout_mask = encoded_softmax > 0 if encoded_softmax is not None else None - # assert torch.allclose(dropout_mask, dropout_mask_naive) - ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, - dropout_p=dropout_p, - is_causal=causal, - attn_mask=b, - scale=sm_scale, - dropout_mask=dropout_mask) - if False: - mref_out, mref_softmax = scaled_dot_product_attention(q, k, v, - dropout_p=dropout_p, - is_causal=causal, - scale=sm_scale, - dropout_mask=dropout_mask) - print(f'{tri_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - print(f'{ref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - print(f'{mref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - print(f'{tri_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]/ref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - print(f'{q.shape=} {q.stride()=}') - print(f'{k.shape=} {k.stride()=}') - print(f'{v.shape=} {v.stride()=}') - print(f'{encoded_softmax=}') - if encoded_softmax is not None: - print(f'{encoded_softmax.shape=} {encoded_softmax.stride()=}') - print(f'{encoded_softmax[:,:, :SPARSE_SEQ_SINCE, :SPARSE_SEQ_SINCE]=}') - print(f'{dropout_mask.shape=} {dropout_mask.stride()=}') - print(f'{dropout_mask[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - if dtype==torch.bfloat16: - ATOL = 1e-1 * max(1.0, (seqlen_q + seqlen_k + D_HEAD) / 128.0) - else: - ATOL = 1e-2 * max(1.0, (seqlen_q + seqlen_k + D_HEAD) / 128.0) - print(f'Using ATOL={ATOL}') - is_allclose = torch.allclose(ref_out, tri_out, atol=ATOL, rtol=0) - if not is_allclose: - import numpy as np - err_idx = np.unravel_index(torch.argmax(torch.abs(ref_out - tri_out)).cpu().numpy(), ref_out.shape) - print(f'{err_idx=}') - print(f'{tri_out[err_idx]=} {ref_out[err_idx]=} error: {tri_out[err_idx] - ref_out[err_idx]}') - # if not is_allclose: - if False: - import numpy as np - err_idx = np.unravel_index(torch.argmax(torch.abs(ref_out - tri_out)).cpu().numpy(), ref_out.shape) - print(f'{tri_out[0][0][0][:]=}') - print(f'{ref_out[0][0][0][:]=}') - print(f'{mref_out[0][0][0][:]=}') - if encoded_softmax is not None: - print(f'{encoded_softmax[0][0][0][:]=}') - print(f'{ref_softmax[0][0][0][:]=}') - print(f'{tri_out[-1][0][0][:]=}') - print(f'{ref_out[-1][0][0][:]=}') - print(f'{err_idx=}') - print(f'{tri_out[err_idx]=}') - print(f'{ref_out[err_idx]=}') - if dropout_p > 0: - # 
print(f'{unmasked_ref_out[0][0][0][:]=}') - print(f'{dropout_mask[0][0][0][:]=}') - print(f'{dropout_mask[err_idx]=}') - # tri_cpu = tri_out[0, 0].cpu().detach().numpy() - # print(f'{tri_cpu.shape=}') - # compare - assert is_allclose + def do_test_op_fwd(self, BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type): + if causal and seqlen_q != seqlen_k: + pytest.skip("PyTorch's Flash V2 does not accept casual=True when seqlen_q != seqlen_k. Skipping") + if causal and bias_type is not None: + pytest.skip("_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True") + torch.manual_seed(20) + print(f"test_op_fwd {BATCH=}, {N_HEADS=}, {seqlen_q=}, {seqlen_k=}, {D_HEAD=}, {causal=}") + SPARSE_HEAD_SINCE = 3 + SPARSE_SEQ_SINCE = 3 + Z = BATCH + H = N_HEADS + if True: # Real UT + qdims = (BATCH, N_HEADS, seqlen_q, D_HEAD) + kdims = (BATCH, N_HEADS, seqlen_k, D_HEAD) + vdims = (BATCH, N_HEADS, seqlen_k, D_HEAD) + bdims = (BATCH, N_HEADS, seqlen_q, seqlen_k) + if storage_flip: + qdims = (qdims[0], qdims[2], qdims[1], qdims[3]) + kdims = (kdims[0], kdims[2], kdims[1], kdims[3]) + vdims = (vdims[0], vdims[2], vdims[1], vdims[3]) + bdims = (bdims[0], bdims[2], bdims[1], bdims[3]) + q = ( + torch.empty(qdims, dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ) + k = ( + torch.empty(kdims, dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ) + v = ( + torch.empty(vdims, dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ) + if bias_type is None: + b = None + elif bias_type == 'matrix': + b = torch.empty(bdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5) + else: + assert False, f'Unsupported bias_type {bias_type}' + if storage_flip: + q = torch.transpose(q, 1, 2) + k = torch.transpose(k, 1, 2) + v = torch.transpose(v, 1, 2) + if b is not None: + b = torch.transpose(b, 1, 2) + assert q.shape == (BATCH, N_HEADS, seqlen_q, D_HEAD) + assert k.shape == (BATCH, N_HEADS, seqlen_k, D_HEAD) + assert v.shape == (BATCH, N_HEADS, seqlen_k, D_HEAD) + if False: # Debugging + q = ( + torch.empty((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ) + k = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 1.0 + v = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 1.0 + if False: + q = torch.ones((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda") * 1.0 + k = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 2.0 + v = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 3.0 + if False: + import numpy as np + q = torch.arange(np.prod([Z, H, seqlen_q, D_HEAD]), dtype=dtype, device="cuda").reshape((Z, H, seqlen_q, D_HEAD)) + k = torch.arange(np.prod([Z, H, seqlen_k, D_HEAD]), dtype=dtype, device="cuda").reshape((Z, H, seqlen_q, D_HEAD)) + v = torch.arange(np.prod([Z, H, seqlen_k, D_HEAD]), dtype=dtype, device="cuda").reshape((Z, H, seqlen_q, D_HEAD)) + q = (q - 128.0) * 0.01 + k = (k - 128.0) * 0.01 + v = (v - 128.0) * 0.01 + q[:, :, :, SPARSE_HEAD_SINCE: ] = 0.0 + k[:, :, :, SPARSE_HEAD_SINCE: ] = 0.0 + v[:, :, :, SPARSE_HEAD_SINCE: ] = 0.0 + q[:, :, SPARSE_SEQ_SINCE:, : ] = 0.0 + k[:, :, SPARSE_SEQ_SINCE:, : ] = 0.0 + v[:, :, SPARSE_SEQ_SINCE:, : ] = 0.0 + + ''' + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if 
causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, v) + ''' + return_encoded_softmax = dropout_p > 0.0 and not self.use_fill_rng_for_dropout + # return_encoded_softmax = dropout_p > 0.0 # Reserved for debugging use_fill_rng_for_dropout + autotune = False + tri_out, encoded_softmax, _ = attention(q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax, autotune) + + if self.use_fill_rng_for_dropout: + rdims = (BATCH, N_HEADS, seqlen_q, seqlen_k) + if storage_flip: + rdims = (rdims[0], rdims[2], rdims[1], rdims[3]) + r = torch.empty(rdims, device=q.device, dtype=torch.float32) + if storage_flip: + r = torch.transpose(r, 1, 2) + philox_seed = DEFAULT_PHILOX_SEED + philox_offset = DEFAULT_PHILOX_OFFSET + debug_fill_dropout_rng(r, philox_seed, philox_offset) + # Reserved for debugging use_fill_rng_for_dropout + # print(f'{r[0,0,:16, :16]}=') + # print(f'{r[0,0,:16, :16] > dropout_p}=') + # print(f'{encoded_softmax[0,0,:16, :16] > 0}=') + dropout_mask = r > dropout_p + else: + dropout_mask = encoded_softmax > 0 if encoded_softmax is not None else None + # assert torch.allclose(dropout_mask, dropout_mask_naive) + ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, + dropout_p=dropout_p, + is_causal=causal, + attn_mask=b, + scale=sm_scale, + dropout_mask=dropout_mask) + if False: + mref_out, mref_softmax = scaled_dot_product_attention(q, k, v, + dropout_p=dropout_p, + is_causal=causal, + scale=sm_scale, + dropout_mask=dropout_mask) + print(f'{tri_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') + print(f'{ref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') + print(f'{mref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') + print(f'{tri_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]/ref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') + print(f'{q.shape=} {q.stride()=}') + print(f'{k.shape=} {k.stride()=}') + print(f'{v.shape=} {v.stride()=}') + print(f'{encoded_softmax=}') + if encoded_softmax is not None: + print(f'{encoded_softmax.shape=} {encoded_softmax.stride()=}') + print(f'{encoded_softmax[:,:, :SPARSE_SEQ_SINCE, :SPARSE_SEQ_SINCE]=}') + print(f'{dropout_mask.shape=} {dropout_mask.stride()=}') + print(f'{dropout_mask[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') + if dtype==torch.bfloat16: + ATOL = 1e-1 * max(1.0, (seqlen_q + seqlen_k + D_HEAD) / 128.0) + else: + ATOL = 1e-2 * max(1.0, (seqlen_q + seqlen_k + D_HEAD) / 128.0) + print(f'Using ATOL={ATOL}') + is_allclose = torch.allclose(ref_out, tri_out, atol=ATOL, rtol=0) + if not is_allclose: + import numpy as np + err_idx = np.unravel_index(torch.argmax(torch.abs(ref_out - tri_out)).cpu().numpy(), ref_out.shape) + print(f'{err_idx=}') + print(f'{tri_out[err_idx]=} {ref_out[err_idx]=} error: {tri_out[err_idx] - ref_out[err_idx]}') + # if not is_allclose: + if False: + import numpy as np + err_idx = np.unravel_index(torch.argmax(torch.abs(ref_out - tri_out)).cpu().numpy(), ref_out.shape) + print(f'{tri_out[0][0][0][:]=}') + print(f'{ref_out[0][0][0][:]=}') + print(f'{mref_out[0][0][0][:]=}') + if encoded_softmax is not None: + print(f'{encoded_softmax[0][0][0][:]=}') + print(f'{ref_softmax[0][0][0][:]=}') + print(f'{tri_out[-1][0][0][:]=}') + print(f'{ref_out[-1][0][0][:]=}') + print(f'{err_idx=}') + print(f'{tri_out[err_idx]=}') + print(f'{ref_out[err_idx]=}') + if dropout_p > 0: + # print(f'{unmasked_ref_out[0][0][0][:]=}') + print(f'{dropout_mask[0][0][0][:]=}') + print(f'{dropout_mask[err_idx]=}') + # tri_cpu = tri_out[0, 
0].cpu().detach().numpy() + # print(f'{tri_cpu.shape=}') + # compare + assert is_allclose # @pytest.mark.parametrize('BATCH', [1, 4]) # @pytest.mark.parametrize('N_HEADS', [1, 4]) -@pytest.mark.parametrize('BATCH', [1, 2, 4]) -@pytest.mark.parametrize('N_HEADS', [1, 2, 4]) -@pytest.mark.parametrize('D_HEAD', [16,32,64,128,256]) +@pytest.mark.parametrize('BATCH', [1, 4]) +@pytest.mark.parametrize('N_HEADS', [1, 4]) +@pytest.mark.parametrize('D_HEAD', [8, 16, 21, 32, 64, 72, 96, 128, 160, 192, 203, 256]) # @pytest.mark.parametrize('D_HEAD', [128]) # Complete set # @pytest.mark.parametrize('seqlen_q', [4,8,16,17,32,64,128,143,256,512,1024,2048]) @@ -245,7 +267,8 @@ def _do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale # @pytest.mark.parametrize('return_encoded_softmax', [False]) def test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip): bias_type = None - _do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) + tester = FwdTester() + tester.do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) # @pytest.mark.parametrize('BATCH', [1, 4]) # @pytest.mark.parametrize('N_HEADS', [1, 4]) @@ -273,7 +296,24 @@ def test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dr def test_op_fwd_with_matrix_bias(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, sm_scale, dropout_p, dtype, storage_flip): causal = False bias_type = 'matrix' + tester = FwdTester() ''' _scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True ''' - _do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) + tester.do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) + +@pytest.mark.parametrize('BATCH', [1, 4]) +@pytest.mark.parametrize('N_HEADS', [1, 4]) +@pytest.mark.parametrize('seqlen_q', [4, 8, 64, 143, 256, 512, 1024, 2048]) +@pytest.mark.parametrize('seqlen_k', [4, 8, 64, 128, 256, 587, 1024, 2048]) +@pytest.mark.parametrize('causal', [False, True]) +@pytest.mark.parametrize('storage_flip', [False, True]) +def test_fill_dropout_rng(BATCH, N_HEADS, seqlen_q, seqlen_k, causal, storage_flip): + D_HEAD = 128 + dropout_p = 0.5 + dtype = torch.float16 + sm_scale = 1.2 + bias_type = None + tester = FwdTester() + tester.use_fill_rng_for_dropout = True + tester.do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) diff --git a/v2python/rules/flash/__init__.py b/v2python/rules/flash/__init__.py index e9ea4b2..3e6a621 100644 --- a/v2python/rules/flash/__init__.py +++ b/v2python/rules/flash/__init__.py @@ -5,6 +5,7 @@ from .bwd_preprocess import bwd_preprocess from .bwd_kernel_dk_dv import bwd_kernel_dk_dv from .bwd_kernel_dq import bwd_kernel_dq +from .debug_fill_dropout_rng import debug_fill_dropout_rng SOURCE_FILE = 'tritonsrc/flash.py' kernels = [ @@ -12,4 +13,5 @@ bwd_preprocess('bwd_preprocess', SOURCE_FILE), bwd_kernel_dk_dv('bwd_kernel_dk_dv', SOURCE_FILE), bwd_kernel_dq('bwd_kernel_dq', SOURCE_FILE), + debug_fill_dropout_rng('debug_fill_dropout_rng', SOURCE_FILE), ] diff --git a/v2python/rules/flash/debug_fill_dropout_rng.py b/v2python/rules/flash/debug_fill_dropout_rng.py new file mode 100644 index 0000000..d8833cb --- /dev/null +++ b/v2python/rules/flash/debug_fill_dropout_rng.py @@ 
-0,0 +1,40 @@
+# Copyright © 2024 Advanced Micro Devices, Inc.
+# SPDX-License-Identifier: MIT
+
+from ._common import FlashKernel, select_pattern
+
+class debug_fill_dropout_rng(FlashKernel):
+    ARGUMENTS = [
+        'R',
+        'stride_rz', 'stride_rh', 'stride_rm', 'stride_rn',
+        'seqlen_q', 'seqlen_k',
+        'philox_seed',
+        'philox_offset_base',
+        'BLOCK_M', # tl.constexpr starts here
+        'BLOCK_N',
+    ]
+    TENSOR_STRIDE_INPUTS = {
+        'R' : select_pattern(ARGUMENTS, 'stride_r'),
+    }
+    TENSOR_RANKS = {
+        '_default' : 4,
+    }
+    TYPE_CHOICES = {
+        frozenset(['R']) : ['*fp32:16'],
+        frozenset(['seqlen_q', 'seqlen_k']) : ['i32'],
+        frozenset(['philox_seed']) : ['u64'],
+        frozenset(['philox_offset_base']) : ['u32'],
+    }
+    FEAT_CHOICES = {
+    }
+    PERF_CHOICES = {
+        frozenset(['BLOCK_M']) : [64],
+        frozenset(['BLOCK_N']) : [32],
+    }
+    DEFAULT_NUM_WARPS=4
+    DEFAULT_NUM_STAGES=1
+    SHIM_KERNEL_NAME = 'debug_fill_dropout_rng'
+
+    AUTOTUNE_KEYS = { }
+    PARTIALLY_TUNED_FUNCTIONALS = [('PADDED_HEAD', None)]
+    DOWNGRADER = []
diff --git a/v2src/flash/attn_debug.cc b/v2src/flash/attn_debug.cc
new file mode 100644
index 0000000..56194e5
--- /dev/null
+++ b/v2src/flash/attn_debug.cc
@@ -0,0 +1,47 @@
+// Copyright © 2024 Advanced Micro Devices, Inc.
+// SPDX-License-Identifier: MIT
+
+#include
+#include
+#include
+#include
+
+namespace aotriton::v2::flash {
+
+hipError_t
+debug_fill_dropout_rng(T4 r,
+                       uint64_t philox_seed,
+                       uint64_t philox_offset,
+                       aotriton::Stream stream_wrap) {
+  hipError_t err;
+  auto stream = stream_wrap.native();
+  auto arch = getArchFromStream(stream);
+  auto grid_calculator = [](const DebugFillDropoutRngParams& params) -> dim3 {
+    dim3 grid {
+      aotriton::cdiv(params.R->size(2), params.BLOCK_M),
+      uint32_t(params.R->size(1)),
+      uint32_t(params.R->size(0)),
+    };
+    // std::cerr << "Grid conf " << grid.x << " " << grid.y << " " << grid.z << std::endl;
+    return grid;
+  };
+  int seqlen_q = r.size(2);
+  int seqlen_k = r.size(3);
+  DebugFillDropoutRngParams params = {
+    .R = &r,
+    .seqlen_q = seqlen_q,
+    .seqlen_k = seqlen_k,
+    .philox_seed = philox_seed,
+    .philox_offset_base = static_cast<uint32_t>(philox_offset),
+  };
+  DebugFillDropoutRngContext context;
+  context.grid_calculator = grid_calculator;
+  err = context.lookup_optimal(params, arch);
+  if (err != hipSuccess) {
+    return err;
+  }
+  err = context.launch(params, stream);
+  return err;
+}
+
+}
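
A minimal usage sketch of the new debugging path, mirroring what the added
test_fill_dropout_rng test does (assuming the Python wrapper added in
test/aotriton_flash.py is importable and a ROCm GPU is present; the shapes and
dropout_p below are illustrative):

    import torch
    from aotriton_flash import debug_fill_dropout_rng

    DEFAULT_PHILOX_SEED = 0x1BF52      # same constants as attn_torch_function.py
    DEFAULT_PHILOX_OFFSET = 0x1D4B42

    batch, n_heads, seqlen_q, seqlen_k = 2, 4, 128, 128
    dropout_p = 0.5

    # One float32 RNG value per (batch, head, query, key) position.
    r = torch.empty((batch, n_heads, seqlen_q, seqlen_k),
                    device='cuda', dtype=torch.float32)
    debug_fill_dropout_rng(r, DEFAULT_PHILOX_SEED, DEFAULT_PHILOX_OFFSET)

    # Same rule as dropout_mask() in tritonsrc/fwd_kernel.py: keep where rng > p.
    dropout_mask = r > dropout_p
    print(dropout_mask.float().mean())  # expect roughly 1 - dropout_p

Because the kernel offsets each tile by philox_offset_base + off_zh * seqlen_q * seqlen_k,
a buffer filled this way is intended to reproduce the per-element randoms that attn_fwd
consumes for dropout. That is what lets test_fill_dropout_rng compare the Triton output
against torch.ops.aten._scaled_dot_product_attention_math with an explicit dropout_mask
instead of materializing the encoded softmax.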