diff --git a/test/_common_test.py b/test/_common_test.py new file mode 120000 index 0000000..e97308c --- /dev/null +++ b/test/_common_test.py @@ -0,0 +1 @@ +../tritonsrc/_common_test.py \ No newline at end of file diff --git a/test/aotriton_flash.py b/test/aotriton_flash.py index cea6b0b..cbefeb4 100644 --- a/test/aotriton_flash.py +++ b/test/aotriton_flash.py @@ -30,6 +30,8 @@ def mk_aotensor(q, if_empty_then_like=None): assert False, f'Unsupported tensor rank {rank}, shape {q.shape}' if q is None: return klass(0, [0] * rank, [0] * rank, cast_dtype(if_empty_then_like.dtype)) + if q is not None: + assert q.stride(-1) == 1, "AOTriton assumes the last stride of Tensors is 1" return klass(q.data_ptr(), tuple(q.size()), q.stride(), cast_dtype(q.dtype)) def attn_fwd(q, k, v, b, sm_scale, M, o, diff --git a/test/test_backward.py b/test/test_backward.py index 4ff4dca..afe695d 100644 --- a/test/test_backward.py +++ b/test/test_backward.py @@ -6,42 +6,7 @@ import torch from attn_torch_function import attention - -def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: - # Efficient implementation equivalent to the following: - L, S = query.size(-2), key.size(-2) - scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale - attn_bias = torch.zeros(L, S, dtype=query.dtype) - if is_causal: - assert attn_mask is None - temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) - attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(query.dtype) - - """ - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) - else: - attn_bias += attn_mask - """ - attn_weight = query @ key.transpose(-2, -1) * scale_factor - SPARSE_HEAD_SINCE = 5 - SPARSE_SEQ_SINCE = 5 - # attn_weight += attn_bias - attn_weight = torch.softmax(attn_weight, dim=-1) - if dropout_p > 0.0: - if dropout_mask is not None: - attn_weight.masked_fill_(dropout_mask.logical_not(), float("0.0")) - value = value / (1 - dropout_p) - else: - # assert False, "TESTING dropout_mask code path" - attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - else: - # assert False, "TESTING dropout_mask code path" - pass - av = attn_weight @ value - return av, attn_weight +from _common_test import SdpaContext, SdpaParams def _make_block_eyes(q, base=1.0, inc=0.0): dhead = q.shape[-1] @@ -69,21 +34,14 @@ def RP(x): Note: In Flash V2 API the ... is denoted as "num_heads", serving as uniformly sized sequences but in PyTorch API it does not present at all ''' -def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, bias: torch.Tensor, dtype: torch.dtype = None, device=None): - """ Clones the query, key, and value tensors and moves them to the specified dtype. 
""" - if dtype is None: - dtype = query.dtype - query_ref = query.clone().detach().to(dtype=dtype, device=device).requires_grad_(query.requires_grad) - key_ref = key.clone().detach().to(dtype=dtype, device=device).requires_grad_(key.requires_grad) - value_ref = value.clone().detach().to(dtype=dtype, device=device).requires_grad_(value.requires_grad) - bias_ref = bias.clone().detach().to(dtype=dtype, device=device).requires_grad_(bias.requires_grad) if bias is not None else None - return query_ref, key_ref, value_ref, bias_ref def _do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type): if causal and seqlen_q != seqlen_k: pytest.skip("PyTorch's Flash V2 does not accept casual=True when seqlen_q != seqlen_k. Skipping") if causal and bias_type is not None: pytest.skip("_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True") + # if BATCH > 1 and seqlen_q >= 1024 and seqlen_k >= 1024: + # torch.cuda.empty_cache() SKIP_DK_DV = False SKIP_DQ = False SKIP_DB = True if bias_type is None else False @@ -91,100 +49,23 @@ def _do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale torch.manual_seed(20) SPARSE_HEAD_SINCE = 1 SPARSE_SEQ_SINCE = 1 - qdims = (BATCH, N_HEADS, seqlen_q, D_HEAD) - kdims = (BATCH, N_HEADS, seqlen_k, D_HEAD) - vdims = (BATCH, N_HEADS, seqlen_k, D_HEAD) - bdims = (BATCH, N_HEADS, seqlen_q, seqlen_k) - if storage_flip: - qdims = (qdims[0], qdims[2], qdims[1], qdims[3]) - kdims = (kdims[0], kdims[2], kdims[1], kdims[3]) - vdims = (vdims[0], vdims[2], vdims[1], vdims[3]) - bdims = (bdims[0], bdims[2], bdims[1], bdims[3]) - q = torch.empty(qdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5) - k = torch.empty(kdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5) - v = torch.empty(vdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5) - if bias_type is None: - b = None - elif bias_type == 'matrix': - b = torch.empty(bdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5) - else: - assert False, f'Unsupported bias_type {bias_type}' - if storage_flip: - q = torch.transpose(q, 1, 2) - k = torch.transpose(k, 1, 2) - v = torch.transpose(v, 1, 2) - if b is not None: - b = torch.transpose(b, 1, 2) - if not SKIP_DQ: - q.requires_grad_() - if not SKIP_DK_DV: - k.requires_grad_() - v.requires_grad_() - if not SKIP_DB: - assert b is not None - b.requires_grad_() + transpose = (1, 2) if storage_flip else None + ctx = SdpaContext(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, dtype, + bias_type=bias_type, storage_flip=transpose, device='cuda') + ctx.create_ref_inputs() + ctx.set_require_grads(skip_dq=SKIP_DQ, skip_dk_dv=SKIP_DK_DV, skip_db=SKIP_DB) return_encoded_softmax = True - # q_ref_lp, k_ref_lp, v_ref_lp = query_key_value_clones(q, k, v, dtype=dtype) - higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32 - # REF_DEVICE='cpu' - REF_DEVICE=None - q_ref, k_ref, v_ref, b_ref = query_key_value_clones(q, k, v, b, dtype=higher_precision_dtype, device=REF_DEVICE) - def TO(ref_tensor): - return ref_tensor.to(device=q.device, dtype=dtype) + q, k, v, b = ctx.dev_tensors # autotune = True # # triton implementation tri_out, encoded_softmax, _ = attention(q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax, USE_AUTOTUNE) dropout_mask = encoded_softmax >= 0 - ''' - ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, - dropout_p=dropout_p, - is_causal=causal, - scale=sm_scale, - 
dropout_mask=dropout_mask) - ''' - ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q_ref, k_ref, v_ref, - dropout_p=dropout_p, - is_causal=causal, - attn_mask=b_ref, - scale=sm_scale, - dropout_mask=dropout_mask) - dout = torch.randn_like(q) - tri_out.backward(dout) - tri_dv, v.grad = None if SKIP_DK_DV else v.grad.clone(), None - tri_dk, k.grad = None if SKIP_DK_DV else k.grad.clone(), None - tri_dq, q.grad = None if SKIP_DQ else q.grad.clone(), None - if not SKIP_DB: - tri_db = b.grad.clone() - else: - tri_db = None + sdpa_params = SdpaParams(causal=causal, sm_scale=sm_scale, dropout_p=dropout_p, dropout_mask=dropout_mask) + ref_out, _ = ctx.compute_ref_forward(sdpa_params) - ''' - ref_out.backward(dout, None) - ref_dv, v.grad = None if SKIP_DK_DV else v.grad.clone(), None - ref_dk, k.grad = None if SKIP_DK_DV else k.grad.clone(), None - ref_dq, q.grad = None if SKIP_DQ else q.grad.clone(), None - ''' - ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype)) - ref_dv, v_ref.grad = None if SKIP_DK_DV else v_ref.grad.clone(), None - ref_dk, k_ref.grad = None if SKIP_DK_DV else k_ref.grad.clone(), None - ref_dq, q_ref.grad = None if SKIP_DQ else q_ref.grad.clone(), None - if SKIP_DB: - ref_db = None - else: - ref_db, b_ref.grad = b_ref.grad.clone(), None - # compare - if dtype == torch.bfloat16: - ATOL = 1e-1 * max(1.0, (RP(seqlen_q) + RP(seqlen_k) + RP(D_HEAD)) / 128.0) - else: - ATOL = 1e-2 * max(1.0, (RP(seqlen_q) + RP(seqlen_k) + RP(D_HEAD)) / 128.0) - # RTOL=1e-2 if dtype==torch.float16 else 5e-2 - RTOL=0.02 - print(f'Forward Using ATOL={ATOL} RTOL={RTOL}') - # FIXME: Need to raise tolerance - ''' - is_allclose = torch.allclose(ref_out, tri_out, atol=ATOL, rtol=RTOL) - ''' - is_allclose = torch.allclose(TO(ref_out), tri_out, atol=ATOL, rtol=RTOL) + dout = torch.rand_like(tri_out) + ctx.compute_backward(tri_out, dout) + is_allclose, adiff, grads_allclose, grads_adiff = ctx.validate_with_reference(tri_out, ctx.dout_tensors) if not is_allclose: import numpy as np err_idx = np.unravel_index(torch.argmax(torch.abs(ref_out - tri_out)).cpu().numpy(), ref_out.shape) @@ -192,15 +73,23 @@ def TO(ref_tensor): print(f'{tri_out[err_idx]=}') print(f'{ref_out[err_idx]=}') assert is_allclose, 'Forward pass {is_allclose=}' - if dtype == torch.bfloat16: - ATOL = 1e-1 * max(1.0, (RP(seqlen_q) + RP(seqlen_k) + RP(D_HEAD)) / 32.0) - elif dtype == torch.float32: - ATOL = 1e-3 * max(1.0, (RP(seqlen_q) + RP(seqlen_k) + RP(D_HEAD)) / 32.0) - else: - ATOL = 1e-2 * max(1.0, (RP(seqlen_q) + RP(seqlen_k) + RP(D_HEAD)) / 32.0) - print(f"Backward Using {ATOL=} {RTOL=}") - dv_allclose = SKIP_DK_DV or torch.allclose(TO(ref_dv), tri_dv, atol=ATOL, rtol=RTOL) + dq_allclose, dk_allclose, dv_allclose, db_allclose = grads_allclose + tri_dq, tri_dk, tri_dv, tri_db = ctx.dout_tensors + ref_dq, ref_dk, ref_dv, ref_db = ctx.dref_tensors + if not SKIP_DQ: + assert tri_dq is not None + assert ref_dq is not None + if not SKIP_DK_DV: + assert tri_dk is not None + assert tri_dv is not None + assert ref_dk is not None + assert ref_dv is not None + if not SKIP_DB: + assert tri_db is not None + assert ref_db is not None + def TO(ref_tensor): + return ref_tensor.to(device=q.device, dtype=dtype) if not dv_allclose: import numpy as np err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_dv) - tri_dv)).cpu().numpy(), ref_dv.shape) @@ -236,7 +125,6 @@ def TO(ref_tensor): # print(f'{tri_dq[0,0]=}') # print(f'{ref_dq[0,0]=}') - dk_allclose = SKIP_DK_DV or torch.allclose(TO(ref_dk), tri_dk, 
atol=ATOL, rtol=RTOL) if dv_allclose and not dk_allclose: print(f'{tri_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') print(f'{ref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') @@ -249,20 +137,19 @@ def TO(ref_tensor): print(f'{tri_dk[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]/ref_dk[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') print(f'{dropout_mask[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - dq_allclose = SKIP_DQ or torch.allclose(TO(ref_dq), tri_dq, atol=ATOL, rtol=RTOL) if dk_allclose and dv_allclose and not dq_allclose: import numpy as np err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_dq) - tri_dq)).cpu().numpy(), ref_dq.shape) print(f'{err_idx=}') print(f'{tri_dq[err_idx]=} {ref_dq[err_idx]=} error = {torch.abs(tri_dq[err_idx] - ref_dq[err_idx])}') - db_allclose = SKIP_DB or torch.allclose(TO(ref_db), tri_db, atol=ATOL, rtol=RTOL) if dk_allclose and dv_allclose and dq_allclose and not db_allclose: import numpy as np err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_db) - tri_db)).cpu().numpy(), ref_db.shape) print(f'{err_idx=}') print(f'{tri_db[err_idx]=} {ref_db[err_idx]=} error = {torch.abs(tri_db[err_idx] - ref_db[err_idx])}') assert dk_allclose and dv_allclose and dq_allclose and db_allclose, f'{dk_allclose=} {dv_allclose=} {dq_allclose=} {db_allclose=}' + print(f'{adiff=} {grads_adiff=}') # @pytest.mark.parametrize('BATCH', [1]) # @pytest.mark.parametrize('N_HEADS', [1]) @@ -283,7 +170,8 @@ def TO(ref_tensor): @pytest.mark.parametrize('causal', [False, True]) @pytest.mark.parametrize('dropout_p', [0.0, 0.5]) # @pytest.mark.parametrize('dropout_p', [0.0]) -@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16, torch.float32]) +# @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) @pytest.mark.parametrize('sm_scale', [0.0, 1.2]) @pytest.mark.parametrize('storage_flip', [False, True]) # @pytest.mark.parametrize('return_encoded_softmax', [False]) @@ -309,8 +197,8 @@ def test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dr # @pytest.mark.parametrize('seqlen_k', [128, 79]) @pytest.mark.parametrize('dropout_p', [0.0, 0.5]) # @pytest.mark.parametrize('dropout_p', [0.0]) -@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16, torch.float32]) +# @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) @pytest.mark.parametrize('sm_scale', [0.0, 1.2]) @pytest.mark.parametrize('storage_flip', [False, True]) # @pytest.mark.parametrize('return_encoded_softmax', [False]) diff --git a/tritonsrc/_common_test.py b/tritonsrc/_common_test.py new file mode 100644 index 0000000..ac264ef --- /dev/null +++ b/tritonsrc/_common_test.py @@ -0,0 +1,248 @@ +import os +from typing import List, Tuple, Optional +from collections import namedtuple +import torch + +def _reference_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: + # Efficient implementation equivalent to the following: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + 
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + attn_weight = query @ key.transpose(-2, -1) * scale_factor + SPARSE_HEAD_SINCE = 5 + SPARSE_SEQ_SINCE = 5 + # attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + if dropout_p > 0.0: + if dropout_mask is not None: + attn_weight.masked_fill_(dropout_mask.logical_not(), float("0.0")) + value = value / (1 - dropout_p) + else: + # assert False, "TESTING dropout_mask code path" + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + else: + # assert False, "TESTING dropout_mask code path" + pass + av = attn_weight @ value + return av, attn_weight + +default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5} +default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6} + +def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float: + deviation = true_value - computed_value + deviation = torch.abs(deviation / true_value) + # Fill in the nans with the default rtol + torch.nan_to_num_(deviation, nan=default_rtol[computed_value.dtype]) + return deviation.max().item() + +def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float: + # Low precision may yield NAN due to numerical instability + # See https://github.com/pytorch/pytorch/issues/116176 for a real-world example. + # Section 3 in https://arxiv.org/abs/2112.05682v3 explains how accelerated + # SDPA does not suffer from it. + deviation = torch.nan_to_num(true_value - computed_value) + atol = torch.abs(deviation).max().item() + return atol + +def get_tolerances( + true_value: torch.Tensor, + computed_value: torch.Tensor, + fudge_factor: Optional[float] = None, +) -> Tuple[float, float]: + """Returns the absolute and relative tolerances for comparing two tensors.""" + fudge_factor = fudge_factor if fudge_factor is not None else 1.0 + atol = get_atol(true_value, computed_value) + rtol = get_rtol(true_value, computed_value) + + atol = fudge_factor * max(atol, default_atol[computed_value.dtype]) + rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype]) + # torch.isclose() has weird behavior around see: + # https://github.com/pytorch/pytorch/issues/102400 + if rtol > 1e30: + rtol = default_rtol[computed_value.dtype] + return atol, rtol + +SdpaParams = namedtuple('SdpaParams', ['causal', 'sm_scale', 'dropout_p', 'dropout_mask']) + +class SdpaContext(object): + TENSOR_NAMES = ('q', 'k', 'v', 'b') + + def __init__(self, BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, dtype, + bias_type=None, storage_flip=None, device='cuda'): + qdims = (BATCH, N_HEADS, seqlen_q, D_HEAD) + kdims = (BATCH, N_HEADS, seqlen_k, D_HEAD) + vdims = (BATCH, N_HEADS, seqlen_k, D_HEAD) + bdims = (seqlen_q, seqlen_k) + if storage_flip is not None: + order = [0,1,2,3] + x, y = storage_flip + order[x], order[y] = order[y], order[x] + i, j, k, l = order + qdims = (qdims[i], qdims[j], qdims[k], qdims[l]) + kdims = (kdims[i], kdims[j], kdims[k], kdims[l]) + vdims = (vdims[i], vdims[j], vdims[k], vdims[l]) + # bdims = (bdims[1], bdims[0]) + # q = torch.empty(qdims, dtype=dtype, device=device).normal_(mean=0., std=0.5) + # k = torch.empty(kdims, dtype=dtype, device=device).normal_(mean=0., std=0.5) + # v = torch.empty(vdims, dtype=dtype, device=device).normal_(mean=0., std=0.5) + q = torch.rand(*qdims, 
dtype=dtype, device=device) + k = torch.rand(*kdims, dtype=dtype, device=device) + v = torch.rand(*vdims, dtype=dtype, device=device) + if bias_type is None: + b = None + elif bias_type == 'matrix': + # b = torch.empty(bdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5) + b = torch.rand(*bdims, dtype=dtype, device=device) + b = b.expand(BATCH, N_HEADS, b.shape[0], b.shape[1]) + else: + assert False, f'Unsupported bias_type {bias_type}' + if storage_flip is not None: + x, y = storage_flip + q = torch.transpose(q, x, y) + k = torch.transpose(k, x, y) + v = torch.transpose(v, x, y) + ''' + # No need to support flipped storage + # attn_mask.stride(-1) is assumed to be 1 in PyTorch + if b is not None: + b = torch.transpose(b, 2, 3) + print(f'{b.stride()=}') + ''' + self.dev_tensors = (q, k, v, b) + self.FUDGE_FACTORS = (4, 2, 2, 2) + self.OUT_FUDGE_FACTOR = 3 + + @property + def dtype(self): + return self.dev_tensors[0].dtype + + @property + def ref_device(self): + return self.ref_tensors[0].device + + @staticmethod + def clone_tensor(t, dtype, device=None): + if t is None: + return None + return t.clone().detach().to(dtype=dtype, device=device).requires_grad_(t.requires_grad) + + @staticmethod + def clone_tensor_tuple(in_tensors, dtype, device=None): + return tuple([SdpaContext.clone_tensor(t, dtype=dtype, device=device) for t in in_tensors]) + + def create_ref_inputs(self): + ref_device_option = os.getenv('AOTRITON_REF_DEVICE_OPTION', default='default') + if ref_device_option == 'default': + q, k, v, b = self.dev_tensors + seqlen_k = k.shape[2] + ''' + Shader _ZN2at6native12_GLOBAL__N_119cunn_SoftMaxForwardILi2EdddNS1_22SoftMaxForwardEpilogueEEEvPT2_PKT0_i causes Segfault + for Case test_op_bwd[False-0.0-dtype2-0.0-False-587-64-8-4-4], but cannot be reproduced by running this individual UT. + Avoiding running it on GPU for now + ''' + if seqlen_k == 587: + ref_device = 'cpu' + else: + ref_device = 'cuda' + elif ref_device_option == 'cuda': + ref_device = 'cuda' + elif ref_device_option == 'cpu': + ref_device = 'cpu' + else: + assert False, f'Unknown ref_device_option value {ref_device_option}. 
Allowed choices "default" "cpu" "cuda"' + self.create_ref_inputs_with_device(ref_device) + + def create_ref_inputs_with_device(self, ref_device): + dtype = self.dtype + hp_dtype = torch.float64 if dtype == torch.float32 else torch.float32 + self.ref_tensors = self.clone_tensor_tuple(self.dev_tensors, dtype=hp_dtype, device=ref_device) + self.lp_ref_tensors = self.clone_tensor_tuple(self.dev_tensors, dtype=dtype, device=ref_device) + + @staticmethod + def _require_grads(tensors, skip_dq=False, skip_dk_dv=False, skip_db=False): + q, k, v, b = tensors + if not skip_dq: + q.requires_grad_() + if not skip_dk_dv: + k.requires_grad_() + v.requires_grad_() + if not skip_db: + assert b is not None + b.requires_grad_() + + def set_require_grads(self, skip_dq=False, skip_dk_dv=False, skip_db=False): + self._require_grads(self.dev_tensors, skip_dq=skip_dq, skip_dk_dv=skip_dk_dv, skip_db=skip_db) + self._require_grads(self.ref_tensors, skip_dq=skip_dq, skip_dk_dv=skip_dk_dv, skip_db=skip_db) + self._require_grads(self.lp_ref_tensors, skip_dq=skip_dq, skip_dk_dv=skip_dk_dv, skip_db=skip_db) + + @staticmethod + def _compute_ref_forward(ref_tensors, p : SdpaParams): + ref_q, ref_k, ref_v, ref_b = ref_tensors + dropout_mask = p.dropout_mask if p.dropout_mask is None else p.dropout_mask.to(device=ref_q.device) + ref_out, ref_mask = torch.ops.aten._scaled_dot_product_attention_math(ref_q, ref_k, ref_v, + dropout_p=p.dropout_p, + is_causal=p.causal, + attn_mask=ref_b, + scale=p.sm_scale, + dropout_mask=dropout_mask) + return (ref_out, ref_mask) + + def compute_ref_forward(self, p : SdpaParams): + self.refout_tensors = self._compute_ref_forward(self.ref_tensors, p) + self.lp_refout_tensors = self._compute_ref_forward(self.lp_ref_tensors, p) + return self.lp_refout_tensors + + @staticmethod + def _compute_backward(in_tensors, out, dout): + q, k, v, b = in_tensors + out.backward(dout.to(device=out.device, dtype=out.dtype)) + dq, q.grad = None if not q.requires_grad else q.grad.clone(), None + dk, k.grad = None if not k.requires_grad else k.grad.clone(), None + dv, v.grad = None if not v.requires_grad else v.grad.clone(), None + if b is None or not b.requires_grad: + db = None + else: + db, b.grad = b.grad.clone(), None + return (dq, dk, dv, db) + + # Note: this follows pytorch's testing approach and expects low precision dout + def compute_backward(self, out, dout): + self.dout_tensors = self._compute_backward(self.dev_tensors, out, dout) + self.dref_tensors = self._compute_backward(self.ref_tensors, self.refout_tensors[0], dout) + self.lp_dref_tensors = self._compute_backward(self.lp_ref_tensors, self.lp_refout_tensors[0], dout) + + @staticmethod + def _validate(out, ref, lp_ref, fudge_factor, tname): + if out is None and ref is None: + return True, float('nan') + atol, rtol = get_tolerances(ref, lp_ref, fudge_factor) + assert out is not None, f'd{tname} is none' + assert ref is not None, f'd{tname}_ref is none' + # print(f'{out=}') + # print(f'{ref=}') + x = out.to(device=ref.device) + y = ref.to(out.dtype) + max_adiff = float(torch.max(torch.abs(x - y))) + return torch.allclose(x, y, atol=atol, rtol=rtol), max_adiff + + def validate_with_reference(self, out, grads): + out_allclose, out_adiff = self._validate(out, self.refout_tensors[0], self.lp_refout_tensors[0], self.OUT_FUDGE_FACTOR, 'out') + grads_allclose = [] + grads_adiff = [] + for grad, ref, lp_ref, fudge_factor, tname in zip(grads, self.dref_tensors, self.lp_dref_tensors, self.FUDGE_FACTORS, self.TENSOR_NAMES): + allclose, adiff = 
self._validate(grad, ref, lp_ref, fudge_factor, tname) + grads_allclose.append(allclose) + grads_adiff.append(adiff) + return out_allclose, out_adiff, grads_allclose, grads_adiff diff --git a/tritonsrc/attn_torch_function.py b/tritonsrc/attn_torch_function.py index 527c484..f1854d9 100644 --- a/tritonsrc/attn_torch_function.py +++ b/tritonsrc/attn_torch_function.py @@ -113,7 +113,7 @@ def tuned_attn_fwd( @triton.jit def sized_tuned_bwd_kernel_dk_dv( Q, K, V, B, sm_scale, Out, DO, - DK, DV, DB, + DK, DV, L, D, stride_qz, stride_qh, stride_qm, stride_qk, @@ -123,7 +123,6 @@ def sized_tuned_bwd_kernel_dk_dv( stride_oz, stride_oh, stride_om, stride_ok, stride_dkz, stride_dkh, stride_dkn, stride_dkk, stride_dvz, stride_dvh, stride_dvk, stride_dvn, - stride_dbz, stride_dbh, stride_dbm, stride_dbn, max_seqlens_q, max_seqlens_k, head_dim, dropout_p, @@ -139,7 +138,7 @@ def sized_tuned_bwd_kernel_dk_dv( ): bare_bwd_kernel_dk_dv( Q, K, V, B, sm_scale, Out, DO, - DK, DV, DB, + DK, DV, L, D, stride_qz, stride_qh, stride_qm, stride_qk, @@ -149,7 +148,6 @@ def sized_tuned_bwd_kernel_dk_dv( stride_oz, stride_oh, stride_om, stride_ok, stride_dkz, stride_dkh, stride_dkn, stride_dkk, stride_dvz, stride_dvh, stride_dvk, stride_dvn, - stride_dbz, stride_dbh, stride_dbm, stride_dbn, max_seqlens_q, max_seqlens_k, head_dim, dropout_p, @@ -171,7 +169,7 @@ def sized_tuned_bwd_kernel_dk_dv( @triton.jit def sized_tuned_bwd_kernel_dq( Q, K, V, B, sm_scale, Out, DO, - DQ, + DQ, DB, L, D, stride_qz, stride_qh, stride_qm, stride_qk, @@ -180,6 +178,7 @@ def sized_tuned_bwd_kernel_dq( stride_bz, stride_bh, stride_bm, stride_bn, stride_oz, stride_oh, stride_om, stride_ok, stride_dqz, stride_dqh, stride_dqm, stride_dqk, + stride_dbz, stride_dbh, stride_dbm, stride_dbn, max_seqlens_q, max_seqlens_k, head_dim, dropout_p, @@ -193,7 +192,7 @@ def sized_tuned_bwd_kernel_dq( BIAS_TYPE: tl.constexpr, ): bare_bwd_kernel_dq(Q, K, V, B, sm_scale, Out, DO, - DQ, + DQ, DB, L, D, stride_qz, stride_qh, stride_qm, stride_qk, @@ -202,6 +201,7 @@ def sized_tuned_bwd_kernel_dq( stride_bz, stride_bh, stride_bm, stride_bn, stride_oz, stride_oh, stride_om, stride_ok, stride_dqz, stride_dqh, stride_dqm, stride_dqk, + stride_dbz, stride_dbh, stride_dbm, stride_dbn, max_seqlens_q, max_seqlens_k, head_dim, dropout_p, @@ -471,8 +471,8 @@ def backward(ctx, do, _, fwd_tuning_result): BLOCK_M = 128 BLOCK_N = 64 if q.dtype == torch.float32: - BLOCK_M //= 2 - BLOCK_N //= 2 + BLOCK_M = max(16, BLOCK_M // 2) + BLOCK_N = max(16, BLOCK_N // 2) # debug_mask = torch.zeros((q.shape[0], q.shape[1], max_seqlens_q, max_seqlens_k), device=q.device, dtype=ctx.encoded_softmax.dtype) grid_dk_dv = lambda META: ( triton.cdiv(max_seqlens_k, META['BLOCK_N']), @@ -485,13 +485,13 @@ def backward(ctx, do, _, fwd_tuning_result): stride_dbz, stride_dbh, stride_dbm, stride_dbn = 0,0,0,0 else: db.fill_(float('nan')) - print(f'{ctx.bias_type=} {BLOCK_M=} {BLOCK_N=} {stride_dbz=} {stride_dbh=} {stride_dbm=} {stride_dbn=}') + print(f'backward {ctx.bias_type=} {ctx.autotune=} {BLOCK_M=} {BLOCK_N=} {stride_dbz=} {stride_dbh=} {stride_dbm=} {stride_dbn=}') if k.requires_grad and v.requires_grad: if ctx.autotune: sized_tuned_bwd_kernel_dk_dv[grid_dk_dv]( q, k, v, b, ctx.sm_scale, o, do, - dk, dv, db, + dk, dv, L, delta, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), @@ -500,7 +500,6 @@ def backward(ctx, do, _, fwd_tuning_result): do.stride(0), do.stride(1), do.stride(2), do.stride(3), dk.stride(0), dk.stride(1), dk.stride(2), 
dk.stride(3), dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3), - stride_dbz, stride_dbh, stride_dbm, stride_dbn, max_seqlens_q=max_seqlens_q, max_seqlens_k=max_seqlens_k, head_dim=Lk, @@ -550,7 +549,7 @@ def backward(ctx, do, _, fwd_tuning_result): bare_bwd_kernel_dk_dv[grid_dk_dv]( q, k, v, b, ctx.sm_scale, o, do, - dk, dv, db, + dk, dv, L, delta, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), @@ -559,7 +558,6 @@ def backward(ctx, do, _, fwd_tuning_result): do.stride(0), do.stride(1), do.stride(2), do.stride(3), dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3), dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3), - stride_dbz, stride_dbh, stride_dbm, stride_dbn, max_seqlens_q=max_seqlens_q, max_seqlens_k=max_seqlens_k, head_dim=Lk, @@ -603,7 +601,7 @@ def backward(ctx, do, _, fwd_tuning_result): sized_tuned_bwd_kernel_dq[grid_dq]( q, k, v, b, ctx.sm_scale, o, do, - dq, + dq, db, L, delta, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), @@ -611,6 +609,7 @@ def backward(ctx, do, _, fwd_tuning_result): b.stride(0), b.stride(1), b.stride(2), b.stride(3), do.stride(0), do.stride(1), do.stride(2), do.stride(3), dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3), + stride_dbz, stride_dbh, stride_dbm, stride_dbn, max_seqlens_q=max_seqlens_q, max_seqlens_k=max_seqlens_k, head_dim=Lk, @@ -659,7 +658,7 @@ def backward(ctx, do, _, fwd_tuning_result): bare_bwd_kernel_dq[grid_dq]( q, k, v, b, ctx.sm_scale, o, do, - dq, + dq, db, L, delta, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), @@ -667,6 +666,7 @@ def backward(ctx, do, _, fwd_tuning_result): b.stride(0), b.stride(1), b.stride(2), b.stride(3), do.stride(0), do.stride(1), do.stride(2), do.stride(3), dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3), + stride_dbz, stride_dbh, stride_dbm, stride_dbn, max_seqlens_q=max_seqlens_q, max_seqlens_k=max_seqlens_k, head_dim=Lk, diff --git a/tritonsrc/bwd_split_kernel.py b/tritonsrc/bwd_split_kernel.py index ed6cda0..79c7b4c 100644 --- a/tritonsrc/bwd_split_kernel.py +++ b/tritonsrc/bwd_split_kernel.py @@ -30,7 +30,7 @@ def dot(BLOCK_M : tl.constexpr, QDIM : tl.constexpr, KDIM : tl.constexpr, q, k): @triton.jit def bwd_kernel_dk_dv( Q, K, V, B, sm_scale, Out, DO, - DK, DV, DB, + DK, DV, L, D, stride_qz, stride_qh, stride_qm, stride_qk, @@ -40,7 +40,6 @@ def bwd_kernel_dk_dv( stride_oz, stride_oh, stride_om, stride_ok, stride_dkz, stride_dkh, stride_dkn, stride_dkk, stride_dvz, stride_dvh, stride_dvk, stride_dvn, - stride_dbz, stride_dbh, stride_dbm, stride_dbn, max_seqlens_q, max_seqlens_k, head_dim, dropout_p, @@ -111,7 +110,6 @@ def bwd_kernel_dk_dv( off_zh = off_z * num_h + off_h * 1 if BIAS_TYPE == 0: B_block_ptr = 0 - DB_block_ptr = 0 elif BIAS_TYPE == 1: B_block_ptr = tl.make_block_ptr( base=B + off_h * stride_bh + off_z * stride_bz, @@ -121,20 +119,6 @@ def bwd_kernel_dk_dv( block_shape=(BLOCK_M, BLOCK_N), order=(1, 0) ) - if (stride_dbz == 0 or stride_dbh == 0) or stride_dbm == 0: - store_db = False - else: - store_db = True - # Still have to make one even if no_db = False - # due to a limit of Triton: runtime branches must have identical data types. 
- DB_block_ptr = tl.make_block_ptr( - base=DB + off_h * stride_dbh + off_z * stride_dbz, - shape=(seqlen_q, seqlen_k), - strides=(stride_dbm, stride_dbn), - offsets=(0, start_m), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0) - ) else: tl.static_assert(False, f'Unsupported BIAS_TYPE {BIAS_TYPE}') # pointer to row-wise quantities in value-like data @@ -165,7 +149,6 @@ def bwd_kernel_dk_dv( batch_philox_offset = philox_offset_base + off_zh * seqlen_q * seqlen_k if BIAS_TYPE == 1: B_block_ptr = tl.advance(B_block_ptr, (lo, 0)) - DB_block_ptr = tl.advance(DB_block_ptr, (lo, 0)) ''' K1 K2 (d)V dO Q1 qk11 qk12 (d)v1 dO1 @@ -267,9 +250,6 @@ def bwd_kernel_dk_dv( dp = tl.where(keep, dp / (1 - dropout_p), 0) # compute ds = p * (dp - delta[:, None]) ds = p * (dp - Di) # (BLOCK_M, BLOCK_N) - if BIAS_TYPE == 1: - if store_db: - tl.store(DB_block_ptr, ds.to(DB.type.element_ty), boundary_check=(0,1)) # compute dk if BLOCK_M == 1: dk += ds.to(Q.dtype.element_ty) * q @@ -281,7 +261,6 @@ def bwd_kernel_dk_dv( DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0)) # Debug DO accessing problems if BIAS_TYPE == 1: B_block_ptr = tl.advance(B_block_ptr, (BLOCK_M, 0)) - DB_block_ptr = tl.advance(DB_block_ptr, (BLOCK_M, 0)) # initialize pointers to output dk_offset = off_h * stride_dkh + off_z * stride_dkz DK_block_ptr = tl.make_block_ptr( @@ -307,7 +286,7 @@ def bwd_kernel_dk_dv( @triton.jit def bwd_kernel_dq( Q, K, V, B, sm_scale, Out, DO, - DQ, + DQ, DB, L, D, stride_qz, stride_qh, stride_qm, stride_qk, @@ -316,6 +295,7 @@ def bwd_kernel_dq( stride_bz, stride_bh, stride_bm, stride_bn, stride_oz, stride_oh, stride_om, stride_ok, stride_dqz, stride_dqh, stride_dqm, stride_dqk, + stride_dbz, stride_dbh, stride_dbm, stride_dbn, max_seqlens_q, max_seqlens_k, head_dim, dropout_p, @@ -383,6 +363,7 @@ def bwd_kernel_dq( off_zh = off_z * num_h + off_h * 1 if BIAS_TYPE == 0: B_block_ptr = 0 + DB_block_ptr = 0 elif BIAS_TYPE == 1: B_block_ptr = tl.make_block_ptr( base=B + off_h * stride_bh + off_z * stride_bz, @@ -392,6 +373,20 @@ def bwd_kernel_dq( block_shape=(BLOCK_M, BLOCK_N), order=(1, 0) ) + if (stride_dbz == 0 and stride_dbh == 0) and stride_dbm == 0: + store_db = False + else: + store_db = True + # Still have to make one even if no_db = False + # due to a limit of Triton: runtime branches must have identical data types. 
+ DB_block_ptr = tl.make_block_ptr( + base=DB + off_h * stride_dbh + off_z * stride_dbz, + shape=(seqlen_q, seqlen_k), + strides=(stride_dbm, stride_dbn), + offsets=(start_m, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0) + ) else: tl.static_assert(False, f'Unsupported BIAS_TYPE {BIAS_TYPE}') # pointer to row-wise quantities in value-like data @@ -488,11 +483,15 @@ def bwd_kernel_dq( else: # ds.shape = (BLOCK_M, BLOCK_N), kt.shape = (BLOCK_DMODEL, BLOCK_N) dq += tl.dot(ds.to(Q.type.element_ty), tl.trans(kt)) # (BLOCK_M, BLOCK_DMODEL) + if BIAS_TYPE == 1: + if store_db: + tl.store(DB_block_ptr, ds.to(DB.type.element_ty), boundary_check=(0,1)) # update pointers K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N)) if BIAS_TYPE == 1: B_block_ptr = tl.advance(B_block_ptr, (0, BLOCK_N)) + DB_block_ptr = tl.advance(DB_block_ptr, (0, BLOCK_N)) # initialize pointers to output dq_offset = off_h * stride_dqh + off_z * stride_dqz DQ_block_ptr = tl.make_block_ptr( diff --git a/tritonsrc/performance_forward.py b/tritonsrc/performance_forward.py index 2977523..2c9445b 100644 --- a/tritonsrc/performance_forward.py +++ b/tritonsrc/performance_forward.py @@ -59,10 +59,11 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + b = None sm_scale = 1.3 autotune = True return_encoded_softmax = False - fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel, return_encoded_softmax, autotune) + fn = lambda: attention(q, k, v, b, causal, sm_scale, split_kernel, return_encoded_softmax, autotune) if mode == 'bwd': o = fn() do = torch.randn_like(o) diff --git a/tritonsrc/test_backward.py b/tritonsrc/test_backward.py index 0573985..96df3eb 100644 --- a/tritonsrc/test_backward.py +++ b/tritonsrc/test_backward.py @@ -4,44 +4,10 @@ import pytest import torch +import os from attn_torch_function import attention - -def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: - # Efficient implementation equivalent to the following: - L, S = query.size(-2), key.size(-2) - scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale - attn_bias = torch.zeros(L, S, dtype=query.dtype) - if is_causal: - assert attn_mask is None - temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) - attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(query.dtype) - - """ - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) - else: - attn_bias += attn_mask - """ - attn_weight = query @ key.transpose(-2, -1) * scale_factor - SPARSE_HEAD_SINCE = 5 - SPARSE_SEQ_SINCE = 5 - # attn_weight += attn_bias - attn_weight = torch.softmax(attn_weight, dim=-1) - if dropout_p > 0.0: - if dropout_mask is not None: - attn_weight.masked_fill_(dropout_mask.logical_not(), float("0.0")) - value = value / (1 - dropout_p) - else: - # assert False, "TESTING dropout_mask code path" - attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - else: - # assert False, "TESTING dropout_mask code path" - pass - av = attn_weight @ value - return av, attn_weight +from _common_test import SdpaContext, 
SdpaParams def _make_block_eyes(q, base=1.0, inc=0.0): dhead = q.shape[-1] @@ -70,16 +36,6 @@ def RP(x): but in PyTorch API it does not present at all ''' -def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, bias: torch.Tensor, dtype: torch.dtype = None, device=None): - """ Clones the query, key, and value tensors and moves them to the specified dtype. """ - if dtype is None: - dtype = query.dtype - query_ref = query.clone().detach().to(dtype=dtype, device=device).requires_grad_(query.requires_grad) - key_ref = key.clone().detach().to(dtype=dtype, device=device).requires_grad_(key.requires_grad) - value_ref = value.clone().detach().to(dtype=dtype, device=device).requires_grad_(value.requires_grad) - bias_ref = bias.clone().detach().to(dtype=dtype, device=device).requires_grad_(bias.requires_grad) if bias is not None else None - return query_ref, key_ref, value_ref, bias_ref - def _do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type): if causal and seqlen_q != seqlen_k: pytest.skip("PyTorch's Flash V2 does not accept casual=True when seqlen_q != seqlen_k. Skipping") @@ -92,101 +48,23 @@ def _do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale torch.manual_seed(20) SPARSE_HEAD_SINCE = 1 SPARSE_SEQ_SINCE = 1 - qdims = (BATCH, N_HEADS, seqlen_q, D_HEAD) - kdims = (BATCH, N_HEADS, seqlen_k, D_HEAD) - vdims = (BATCH, N_HEADS, seqlen_k, D_HEAD) - bdims = (BATCH, N_HEADS, seqlen_q, seqlen_k) - if storage_flip: - qdims = (qdims[0], qdims[2], qdims[1], qdims[3]) - kdims = (kdims[0], kdims[2], kdims[1], kdims[3]) - vdims = (vdims[0], vdims[2], vdims[1], vdims[3]) - bdims = (bdims[0], bdims[2], bdims[1], bdims[3]) - q = torch.empty(qdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5) - k = torch.empty(kdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5) - v = torch.empty(vdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5) - if bias_type is None: - b = None - elif bias_type == 'matrix': - b = torch.empty(bdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5) - else: - assert False, f'Unsupported bias_type {bias_type}' - if storage_flip: - q = torch.transpose(q, 1, 2) - k = torch.transpose(k, 1, 2) - v = torch.transpose(v, 1, 2) - if b is not None: - b = torch.transpose(b, 1, 2) - if not SKIP_DQ: - q.requires_grad_() - if not SKIP_DK_DV: - k.requires_grad_() - v.requires_grad_() - if not SKIP_DB: - assert b is not None - b.requires_grad_() - return_encoded_softmax = True + transpose = (1, 2) if storage_flip else None + ctx = SdpaContext(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, dtype, + bias_type=bias_type, storage_flip=transpose, device='cuda') + ctx.create_ref_inputs() + ctx.set_require_grads(skip_dq=SKIP_DQ, skip_dk_dv=SKIP_DK_DV, skip_db=SKIP_DB) - # q_ref_lp, k_ref_lp, v_ref_lp = query_key_value_clones(q, k, v, dtype=dtype) - higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32 - # REF_DEVICE='cpu' - REF_DEVICE=None - q_ref, k_ref, v_ref, b_ref = query_key_value_clones(q, k, v, b, dtype=higher_precision_dtype, device=REF_DEVICE) - def TO(ref_tensor): - return ref_tensor.to(device=q.device, dtype=dtype) + return_encoded_softmax = True + q, k, v, b = ctx.dev_tensors # autotune = True # # triton implementation tri_out, encoded_softmax, _ = attention(q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax, USE_AUTOTUNE) dropout_mask = encoded_softmax >= 0 - ''' - ref_out, ref_softmax = 
torch.ops.aten._scaled_dot_product_attention_math(q, k, v, - dropout_p=dropout_p, - is_causal=causal, - scale=sm_scale, - dropout_mask=dropout_mask) - ''' - ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q_ref, k_ref, v_ref, - dropout_p=dropout_p, - is_causal=causal, - attn_mask=b_ref, - scale=sm_scale, - dropout_mask=dropout_mask) - dout = torch.randn_like(q) - tri_out.backward(dout) - tri_dv, v.grad = None if SKIP_DK_DV else v.grad.clone(), None - tri_dk, k.grad = None if SKIP_DK_DV else k.grad.clone(), None - tri_dq, q.grad = None if SKIP_DQ else q.grad.clone(), None - if not SKIP_DB: - tri_db = b.grad.clone() - else: - tri_db = None - - ''' - ref_out.backward(dout, None) - ref_dv, v.grad = None if SKIP_DK_DV else v.grad.clone(), None - ref_dk, k.grad = None if SKIP_DK_DV else k.grad.clone(), None - ref_dq, q.grad = None if SKIP_DQ else q.grad.clone(), None - ''' - ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype)) - ref_dv, v_ref.grad = None if SKIP_DK_DV else v_ref.grad.clone(), None - ref_dk, k_ref.grad = None if SKIP_DK_DV else k_ref.grad.clone(), None - ref_dq, q_ref.grad = None if SKIP_DQ else q_ref.grad.clone(), None - if SKIP_DB: - ref_db = None - else: - ref_db, b_ref.grad = b_ref.grad.clone(), None - # compare - if dtype==torch.bfloat16: - ATOL = 1e-1 * max(1.0, (seqlen_q + seqlen_k) / 128.0) - else: - ATOL = 1e-2 * max(1.0, (seqlen_q + seqlen_k) / 128.0) - # RTOL=1e-2 if dtype==torch.float16 else 5e-2 - RTOL=0.02 - print(f'Forward Using ATOL={ATOL} RTOL={RTOL}') - # FIXME: Need to raise tolerance - ''' - is_allclose = torch.allclose(ref_out, tri_out, atol=ATOL, rtol=RTOL) - ''' - is_allclose = torch.allclose(TO(ref_out), tri_out, atol=ATOL, rtol=RTOL) + sdpa_params = SdpaParams(causal=causal, sm_scale=sm_scale, dropout_p=dropout_p, dropout_mask=dropout_mask) + ref_out, _ = ctx.compute_ref_forward(sdpa_params) + dout = torch.randn_like(tri_out) + ctx.compute_backward(tri_out, dout) + is_allclose, adiff, grads_allclose, grads_adiff = ctx.validate_with_reference(tri_out, ctx.dout_tensors) if not is_allclose: import numpy as np err_idx = np.unravel_index(torch.argmax(torch.abs(ref_out - tri_out)).cpu().numpy(), ref_out.shape) @@ -194,15 +72,12 @@ def TO(ref_tensor): print(f'{tri_out[err_idx]=}') print(f'{ref_out[err_idx]=}') assert is_allclose, 'Forward pass {is_allclose=}' - if dtype == torch.bfloat16: - ATOL = 1e-1 * max(1.0, (RP(seqlen_q) + RP(seqlen_k) + RP(D_HEAD)) / 32.0) - elif dtype == torch.float32: - ATOL = 1e-3 * max(1.0, (RP(seqlen_q) + RP(seqlen_k) + RP(D_HEAD)) / 32.0) - else: - ATOL = 1e-2 * max(1.0, (RP(seqlen_q) + RP(seqlen_k) + RP(D_HEAD)) / 32.0) - print(f'Backward Using ATOL={ATOL} RTOL={RTOL}') - dv_allclose = SKIP_DK_DV or torch.allclose(TO(ref_dv), tri_dv, atol=ATOL, rtol=RTOL) + dq_allclose, dk_allclose, dv_allclose, db_allclose = grads_allclose + tri_dq, tri_dk, tri_dv, tri_db = ctx.dout_tensors + ref_dq, ref_dk, ref_dv, ref_db = ctx.dref_tensors + def TO(ref_tensor): + return ref_tensor.to(device=q.device, dtype=dtype) if not dv_allclose: import numpy as np err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_dv) - tri_dv)).cpu().numpy(), ref_dv.shape) @@ -238,7 +113,6 @@ def TO(ref_tensor): # print(f'{tri_dq[0,0]=}') # print(f'{ref_dq[0,0]=}') - dk_allclose = SKIP_DK_DV or torch.allclose(TO(ref_dk), tri_dk, atol=ATOL, rtol=RTOL) if dv_allclose and not dk_allclose: print(f'{tri_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') print(f'{ref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') @@ -251,20 
+125,19 @@ def TO(ref_tensor): print(f'{tri_dk[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]/ref_dk[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') print(f'{dropout_mask[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - dq_allclose = SKIP_DQ or torch.allclose(TO(ref_dq), tri_dq, atol=ATOL, rtol=RTOL) if dk_allclose and dv_allclose and not dq_allclose: import numpy as np err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_dq) - tri_dq)).cpu().numpy(), ref_dq.shape) print(f'{err_idx=}') print(f'{tri_dq[err_idx]=} {ref_dq[err_idx]=} error = {torch.abs(tri_dq[err_idx] - ref_dq[err_idx])}') - db_allclose = SKIP_DB or torch.allclose(TO(ref_db), tri_db, atol=ATOL, rtol=RTOL) if dk_allclose and dv_allclose and dq_allclose and not db_allclose: import numpy as np err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_db) - tri_db)).cpu().numpy(), ref_db.shape) print(f'{err_idx=}') print(f'{tri_db[err_idx]=} {ref_db[err_idx]=} error = {torch.abs(tri_db[err_idx] - ref_db[err_idx])}') assert dk_allclose and dv_allclose and dq_allclose and db_allclose, f'{dk_allclose=} {dv_allclose=} {dq_allclose=} {db_allclose=}' + print(f'{adiff=} {grads_adiff=}') # @pytest.mark.parametrize('BATCH', [1]) # @pytest.mark.parametrize('N_HEADS', [1]) @@ -291,11 +164,11 @@ def TO(ref_tensor): # @pytest.mark.parametrize('seqlen_k', [32, 128]) @pytest.mark.parametrize('causal', [False, True]) @pytest.mark.parametrize('dropout_p', [0.0, 0.5]) -# @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16, torch.float32]) -@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16, torch.float32]) +# @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) # @pytest.mark.parametrize('dtype', [torch.float16]) -# @pytest.mark.parametrize('sm_scale', [0.0, 1.2]) -@pytest.mark.parametrize('sm_scale', [1.2]) +@pytest.mark.parametrize('sm_scale', [0.0, 1.2]) +# @pytest.mark.parametrize('sm_scale', [1.2]) # @pytest.mark.parametrize('storage_flip', [False]) @pytest.mark.parametrize('storage_flip', [False, True]) # @pytest.mark.parametrize('return_encoded_softmax', [False]) @@ -305,8 +178,8 @@ def test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dr # @pytest.mark.parametrize('BATCH', [1, 4]) # @pytest.mark.parametrize('N_HEADS', [1, 4]) -@pytest.mark.parametrize('BATCH', [1, 2, 4]) -@pytest.mark.parametrize('N_HEADS', [1, 2, 4]) +@pytest.mark.parametrize('BATCH', [1, 4]) +@pytest.mark.parametrize('N_HEADS', [1, 4]) @pytest.mark.parametrize('D_HEAD', [16,32,64,128,256]) # @pytest.mark.parametrize('D_HEAD', [128]) # Complete set @@ -321,8 +194,8 @@ def test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dr # @pytest.mark.parametrize('seqlen_k', [128, 79]) @pytest.mark.parametrize('dropout_p', [0.0, 0.5]) # @pytest.mark.parametrize('dropout_p', [0.0]) -@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16, torch.float32]) +# @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) @pytest.mark.parametrize('sm_scale', [0.0, 1.2]) @pytest.mark.parametrize('storage_flip', [False, True]) # @pytest.mark.parametrize('return_encoded_softmax', [False]) diff --git a/tritonsrc/tune_flash.py b/tritonsrc/tune_flash.py index ce4f1dc..2941f55 100644 --- a/tritonsrc/tune_flash.py +++ b/tritonsrc/tune_flash.py @@ -153,8 +153,8 @@ def parse(): 
p.add_argument('--causal', type=int, nargs='+', default=[True,False], choices=[0, 1], help='Causal mask. (Use 0/1 for False/True') p.add_argument('--dropout_p', type=float, nargs='+', default=[0.5, 0.0], help='Probablity to dropout (0 to disable).') p.add_argument('--dtype', type=str, nargs='+', - default=['float16', 'bfloat16'], - choices=['float16', 'bfloat16'], + default=['float16', 'bfloat16', 'float32'], + choices=['float16', 'bfloat16', 'float32'], help='Datatype to profile.') p.add_argument('--bias_type', type=int, nargs='+', default=[0, 1], choices=[0, 1], help='Bias types to profile, 0: None, 1: Matrix.') p.add_argument('--verbose', action='store_true', help='Verbose') diff --git a/v2python/generate_shim.py b/v2python/generate_shim.py index 8b001e2..2e869dd 100755 --- a/v2python/generate_shim.py +++ b/v2python/generate_shim.py @@ -144,11 +144,14 @@ def __init__(self, args): grand_target = LIBRARY_NAME self._build_dir = Path(args.build_dir) f = open(self._build_dir / 'Makefile.shim', 'w') + arf = open(self._build_dir / 'ar.txt', 'w') super().__init__(args=args, grand_target=grand_target, out=f) self._library_suffixes = ['.a'] if args.archive_only else ['.a', '.so'] + self._arf = arf def __del__(self): self._out.close() + self._arf.close() def gen_children(self, out): for k in triton_kernels: @@ -173,7 +176,8 @@ def write_conclude(self): fn = f'{LIBRARY_NAME}{s}' print(fn, ': ', all_object_files, file=self._out) if s == '.a': - print('\t', '${AR} -r ', fn, all_object_files, file=f) + print('\t', '${AR} -r ', fn, '@ar.txt', file=f) + print(all_object_files, file=self._arf) if s == '.so': print('\t', COMPILER, ' -g -shared -fPIC -o ', fn, all_object_files, file=f) print('\n\n', file=f) diff --git a/v2python/rules/flash/attn_fwd.py b/v2python/rules/flash/attn_fwd.py index c28c86e..1d56403 100644 --- a/v2python/rules/flash/attn_fwd.py +++ b/v2python/rules/flash/attn_fwd.py @@ -35,7 +35,7 @@ class attn_fwd(FlashKernel): 'Out' : select_pattern(ARGUMENTS, 'stride_o'), } TYPE_CHOICES = { - frozenset(['Q', 'K', 'V', 'B', 'Out', 'encoded_softmax']) : ['*fp16:16', '*bf16:16'], + frozenset(['Q', 'K', 'V', 'B', 'Out', 'encoded_softmax']) : ['*fp16:16', '*bf16:16', '*fp32:16'], frozenset(['sm_scale']) : ['fp32'], frozenset(['M']) : ['*fp32:16'], # frozenset(select_pattern(ARGUMENTS, 'stride_', trim=1)) : ['u64'], diff --git a/v2python/rules/flash/bwd_kernel_dk_dv.py b/v2python/rules/flash/bwd_kernel_dk_dv.py index b58e7ce..707cd78 100644 --- a/v2python/rules/flash/bwd_kernel_dk_dv.py +++ b/v2python/rules/flash/bwd_kernel_dk_dv.py @@ -7,7 +7,7 @@ class bwd_kernel_dk_dv(FlashKernel): ARGUMENTS = [ 'Q', 'K', 'V', 'B', 'sm_scale', 'Out', 'DO', - 'DK', 'DV', 'DB', + 'DK', 'DV', 'L', 'D', 'stride_qz', 'stride_qh', 'stride_qm', 'stride_qk', 'stride_kz', 'stride_kh', 'stride_kn', 'stride_kk', @@ -16,7 +16,6 @@ class bwd_kernel_dk_dv(FlashKernel): 'stride_oz', 'stride_oh', 'stride_om', 'stride_ok', 'stride_dkz', 'stride_dkh', 'stride_dkn', 'stride_dkk', 'stride_dvz', 'stride_dvh', 'stride_dvk', 'stride_dvn', - 'stride_dbz', 'stride_dbh', 'stride_dbm', 'stride_dbn', 'seqlen_q', 'seqlen_k', 'head_dim', 'dropout_p', @@ -39,7 +38,6 @@ class bwd_kernel_dk_dv(FlashKernel): 'DO' : select_pattern(ARGUMENTS, 'stride_o'), 'DK' : select_pattern(ARGUMENTS, 'stride_dk'), 'DV' : select_pattern(ARGUMENTS, 'stride_dv'), - 'DB' : select_pattern(ARGUMENTS, 'stride_db'), } TENSOR_RANKS = { '_default' : 4, @@ -47,7 +45,7 @@ class bwd_kernel_dk_dv(FlashKernel): 'D': 2, } TYPE_CHOICES = { - frozenset(['Q', 'K', 'V', 'B', 
'Out', 'DO', 'DK', 'DV', 'DB']) : match_fwd('Q'), + frozenset(['Q', 'K', 'V', 'B', 'Out', 'DO', 'DK', 'DV']) : match_fwd('Q'), frozenset(['sm_scale']) : match_fwd( 'sm_scale'), frozenset(['L', 'D']) : ['*fp32:16'], frozenset(['seqlen_q', 'seqlen_k']) : ['u64'], diff --git a/v2python/rules/flash/bwd_kernel_dq.py b/v2python/rules/flash/bwd_kernel_dq.py index 1fe5c56..389a1c4 100644 --- a/v2python/rules/flash/bwd_kernel_dq.py +++ b/v2python/rules/flash/bwd_kernel_dq.py @@ -8,7 +8,7 @@ class bwd_kernel_dq(FlashKernel): ARGUMENTS = [ 'Q', 'K', 'V', 'B', 'sm_scale', 'Out', 'dO', - 'dQ', + 'dQ', 'dB', 'L', 'D', 'stride_qz', 'stride_qh', 'stride_qm', 'stride_qk', 'stride_kz', 'stride_kh', 'stride_kn', 'stride_kk', @@ -16,6 +16,7 @@ class bwd_kernel_dq(FlashKernel): 'stride_bz', 'stride_bh', 'stride_bk', 'stride_bn', 'stride_oz', 'stride_oh', 'stride_om', 'stride_ok', 'stride_dqz', 'stride_dqh', 'stride_dqm', 'stride_dqk', + 'stride_dbz', 'stride_dbh', 'stride_dbm', 'stride_dbn', 'seqlen_q', 'seqlen_k', 'head_dim', 'dropout_p', @@ -38,6 +39,7 @@ class bwd_kernel_dq(FlashKernel): 'B' : select_pattern(ARGUMENTS, 'stride_b'), 'dO' : select_pattern(ARGUMENTS, 'stride_o'), 'dQ' : select_pattern(ARGUMENTS, 'stride_dq'), + 'dB' : select_pattern(ARGUMENTS, 'stride_db'), } TENSOR_RANKS = { '_default' : 4, @@ -45,7 +47,7 @@ class bwd_kernel_dq(FlashKernel): 'D': 2, } TYPE_CHOICES = { - frozenset(['Q', 'K', 'V', 'B', 'Out', 'dO', 'dQ']) : match_fwd('Q'), + frozenset(['Q', 'K', 'V', 'B', 'Out', 'dO', 'dQ', 'dB']) : match_fwd('Q'), frozenset(['sm_scale']) : match_fwd( 'sm_scale'), frozenset(['L', 'D']) : ['*fp32:16'], frozenset(['seqlen_q', 'seqlen_k']) : ['u64'], diff --git a/v2python/rules/flash/bwd_preprocess.py b/v2python/rules/flash/bwd_preprocess.py index 2934e64..69508c2 100644 --- a/v2python/rules/flash/bwd_preprocess.py +++ b/v2python/rules/flash/bwd_preprocess.py @@ -25,7 +25,7 @@ class bwd_preprocess(FlashKernel): 'Delta' : 2, } TYPE_CHOICES = { - frozenset(['Out', 'DO']) : ['*fp16:16', '*bf16:16'], + frozenset(['Out', 'DO']) : ['*fp16:16', '*bf16:16', '*fp32:16'], frozenset(['Delta']) : ['*fp32:16'], frozenset(['seqlen_q']) : ['u64'], frozenset(['head_dim']) : ['i32'], diff --git a/v2python/rules/tuning_database.sqlite3 b/v2python/rules/tuning_database.sqlite3 index 9ca001e..8188332 100644 Binary files a/v2python/rules/tuning_database.sqlite3 and b/v2python/rules/tuning_database.sqlite3 differ diff --git a/v2python/table_tool.py b/v2python/table_tool.py index 5543ba1..7a84027 100644 --- a/v2python/table_tool.py +++ b/v2python/table_tool.py @@ -17,6 +17,8 @@ def parse(): help='Action to perform') p.add_argument('--table_name', type=str, help='Table to dump/load') p.add_argument('--table_file', type=str, help='CSV file of dump/load') + p.add_argument('--select_where', type=str, default='', help='Extra WHERE clause for SQL to only dump selected rows to CSV file') + p.add_argument('--ignore_id', action='store_true', help='Ignore row IDs when loading CSV to database, useful for table merge') args = p.parse_args() return args @@ -129,7 +131,7 @@ def close(self): def dumpcsv(self, table_name, table_file): with open(table_file, mode='w', newline='') as file: - self._cur.execute(f"SELECT * FROM {table_name};") + self._cur.execute(f"SELECT * FROM {table_name} WHERE {self._args.select_where};") writer = csv.writer(file) colunm_names = [tup[0] for tup in self._cur.description] writer.writerow(colunm_names) @@ -143,10 +145,15 @@ def loadcsv(self, table_file, table_name): with open(table_file, mode='r', 
newline='') as file: reader = csv.reader(file) csv_headers = next(reader) + if self._args.ignore_id: + assert csv_headers[0] == 'id', "--ignore_id: First column of CSV is not 'id'. This tool does not handle more complicated situations." + csv_headers = csv_headers[1:] colunm_names = ', '.join(csv_headers) placeholders = ', '.join(['?'] * len(csv_headers)) stmt = f'INSERT INTO {table_name} ({colunm_names}) VALUES({placeholders});' for row in reader: + if self._args.ignore_id: + row = row[1:] self._cur.execute(stmt, row) self._conn.commit() diff --git a/v2python/tuning_database.py b/v2python/tuning_database.py index d205746..5bfa48c 100644 --- a/v2python/tuning_database.py +++ b/v2python/tuning_database.py @@ -62,7 +62,7 @@ def __init__(self, tune_info_dir : pathlib.Path, k : 'KernelDescription'): def select_gpu(self, gpu, index): arch = AOTRITON_GPU_ARCH_TUNING_STRING[gpu] if arch not in self.arch_dict: - print('For kernel {self._kdesc.KERNEL_FAMILY}.{self._kdesc.name}, Architecture {arch} was not found in tuning database, using dummy one instead') + print(f'For kernel {self._kdesc.KERNEL_FAMILY}.{self._kdesc.name}, Architecture {arch} was not found in tuning database, using dummy one instead') self.arch_dict[arch] = EmptyKernelTuningDatabaseForArch(self._kdesc, arch) return self.arch_dict[arch].set_gpu(gpu, index) diff --git a/v2src/flash/attn_bwd.cc b/v2src/flash/attn_bwd.cc index f3e2f1f..d9ac5a6 100644 --- a/v2src/flash/attn_bwd.cc +++ b/v2src/flash/attn_bwd.cc @@ -60,7 +60,6 @@ bwd_kernel_dk_dv(T4 q, T4 dout, T4 dk, T4 dv, - T4 db, T2 softmax_lse, T2 delta, float dropout_p, @@ -97,7 +96,6 @@ bwd_kernel_dk_dv(T4 q, .DO = &dout, .DK = &dk, .DV = &dv, - .DB = &db, .sm_scale = sm_scale, .L = &softmax_lse, .D = &delta, @@ -132,6 +130,7 @@ bwd_kernel_dq(T4 q, T4 out, T4 dout, T4 dq, + T4 db, T2 softmax_lse, T2 delta, float dropout_p, @@ -167,6 +166,7 @@ bwd_kernel_dq(T4 q, .Out = &out, .dO = &dout, .dQ = &dq, + .dB = &db, .sm_scale = sm_scale, .L = &softmax_lse, .D = &delta, @@ -224,7 +224,6 @@ attn_bwd(T4 q, dout, dk, dv, - db, softmax_lse, delta, dropout_p, @@ -243,6 +242,7 @@ attn_bwd(T4 q, out, dout, dq, + db, softmax_lse, delta, dropout_p,
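
Note on the new test harness: the diff above replaces the per-test reference plumbing in test_backward.py with the shared SdpaContext/SdpaParams helpers added in tritonsrc/_common_test.py. A minimal sketch of the resulting test flow is shown below; it mirrors the updated _do_test_op_bwd and assumes the repository's attn_torch_function.attention entry point, with the concrete shapes, dtype, and kernel flags chosen here purely as placeholders.

    import torch
    from attn_torch_function import attention          # Triton kernel wrapper from this repo
    from _common_test import SdpaContext, SdpaParams   # shared harness introduced by this diff

    # Placeholder problem sizes; the real tests take these from pytest parametrization.
    BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k = 4, 8, 64, 128, 128
    dtype, causal, sm_scale, dropout_p = torch.float16, False, 1.2, 0.5

    # Build device tensors plus higher- and low-precision reference clones.
    ctx = SdpaContext(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, dtype,
                      bias_type=None, storage_flip=None, device='cuda')
    ctx.create_ref_inputs()
    ctx.set_require_grads(skip_dq=False, skip_dk_dv=False, skip_db=True)
    q, k, v, b = ctx.dev_tensors

    # Triton forward pass; the dropout mask is recovered from the encoded softmax.
    tri_out, encoded_softmax, _ = attention(q, k, v, b, causal, sm_scale, dropout_p,
                                            True,   # return_encoded_softmax
                                            False)  # autotune
    dropout_mask = encoded_softmax >= 0

    # Reference forward/backward via torch.ops.aten._scaled_dot_product_attention_math.
    sdpa_params = SdpaParams(causal=causal, sm_scale=sm_scale,
                             dropout_p=dropout_p, dropout_mask=dropout_mask)
    ref_out, _ = ctx.compute_ref_forward(sdpa_params)
    dout = torch.randn_like(tri_out)
    ctx.compute_backward(tri_out, dout)

    # Per-tensor tolerances are derived from the low-precision reference (see _validate).
    out_ok, out_adiff, grads_ok, grads_adiff = ctx.validate_with_reference(tri_out, ctx.dout_tensors)
    assert out_ok and all(grads_ok)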