Add FP32 and Bias to fulfill the functionalities required by `torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION` (#22)

This PR includes the following major changes:

1. Add Bias support in the Triton kernel, for both the forward and backward passes (a PyTorch-side usage sketch follows this list)
2. Add `fp32` datatype support, and the corresponding tuning database information
3. Fix "argument list too long" error during linking
4. Improve `table_tool.py` to partially dump/load `.csv` files, allowing database merging (*)
5. Refactor the UT to use PyTorch's method to estimate ATOL/RTOL
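
A minimal PyTorch-side sketch of the workload these changes target, i.e. the mem-efficient SDPA backend driven with an explicit bias and `float32` inputs. This is not part of the PR; shapes and tensor names are illustrative, and it assumes a build that routes `EFFICIENT_ATTENTION` to AOTriton:

```
# Illustrative only: drive the mem-efficient SDPA backend with an fp32 bias.
# Shapes and names are placeholders, not taken from this PR.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

B, H, L, S, D = 2, 4, 128, 128, 64
q = torch.randn(B, H, L, D, device='cuda', dtype=torch.float32, requires_grad=True)
k = torch.randn(B, H, S, D, device='cuda', dtype=torch.float32, requires_grad=True)
v = torch.randn(B, H, S, D, device='cuda', dtype=torch.float32, requires_grad=True)
bias = torch.randn(B, H, L, S, device='cuda', dtype=torch.float32, requires_grad=True)

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias, dropout_p=0.0)
out.sum().backward()  # exercises the backward path, including the bias gradient
```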

Known limitations:
1. Gradient of Bias assumes a real rank-4 tensor (`.expand()`-ed ones are unlikely to work). This requirement is not checked, so failures may be silent; see the guard sketched after this list. Bias itself is not affected, since it is read-only.
2. `test_forward.py` still uses the old method to estimate ATOL/RTOL.
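
For limitation 1, a caller-side guard could look like the sketch below. This is a hypothetical helper, not part of this commit; it relies on `.expand()`-ed views carrying stride 0 along broadcast dimensions:

```
# Hypothetical guard for limitation 1; not part of this commit.
import torch

def assert_bias_is_materialized(b: torch.Tensor) -> None:
    # A "real" rank-4 bias has no zero (broadcast) strides; .expand()-ed views do.
    assert b.dim() == 4, f'bias must be rank 4, got rank {b.dim()}'
    assert 0 not in b.stride(), \
        'bias looks expanded/broadcast; materialize it (e.g. .contiguous()) before requiring its gradient'
```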

(*) Example of using `table_tool.py` to merge databases:

```
DB=v2python/rules/tuning_database.sqlite3

# Dump the torch.float32 rows of the FLASH$attn_fwd table from the current
# database into a CSV file.
python -m v2python.table_tool -k '' --action dumpcsv \
    -f $DB --table_name 'FLASH$attn_fwd' \
    --table_file 'attn_fwd.fp32mi300.csv' \
    --select_where 'inputs$Q_dtype = "torch.float32"'

# Replace the database with the copy from another branch.
git checkout another_branch -- $DB

# Load the dumped rows into that copy, ignoring the CSV's id column.
python -m v2python.table_tool -k '' --action loadcsv \
    -f $DB --table_name 'FLASH$attn_fwd' \
    --table_file attn_fwd.fp32mi300.csv \
    --ignore_id
```

Note: for simplicity, `--ignore_id` does not support CSV files in which 'id' is not the first column.
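
A plausible reading of that simplification (hypothetical; the real loader may differ) is that the 'id' values are simply stripped off as the first field of every row, so an 'id' column in any other position would be kept:

```
# Hypothetical illustration of the --ignore_id simplification;
# not the actual table_tool code.
import csv

def read_csv_ignoring_id(path):
    with open(path, newline='') as f:
        rows = list(csv.reader(f))
    header, body = rows[0], rows[1:]
    # Assumes 'id' is the first column; an 'id' stored elsewhere would survive.
    return header[1:], [row[1:] for row in body]
```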
xinyazhang committed May 3, 2024
1 parent 71bd17f commit 00ccbf3
Showing 18 changed files with 380 additions and 357 deletions.
1 change: 1 addition & 0 deletions test/_common_test.py
2 changes: 2 additions & 0 deletions test/aotriton_flash.py
@@ -30,6 +30,8 @@ def mk_aotensor(q, if_empty_then_like=None):
assert False, f'Unsupported tensor rank {rank}, shape {q.shape}'
if q is None:
return klass(0, [0] * rank, [0] * rank, cast_dtype(if_empty_then_like.dtype))
if q is not None:
assert q.stride(-1) == 1, "AOTriton assumes the last stride of Tensors be 1"
return klass(q.data_ptr(), tuple(q.size()), q.stride(), cast_dtype(q.dtype))

def attn_fwd(q, k, v, b, sm_scale, M, o,
182 changes: 35 additions & 147 deletions test/test_backward.py
@@ -6,42 +6,7 @@
import torch

from attn_torch_function import attention

def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
# Efficient implementation equivalent to the following:
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)

"""
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
"""
attn_weight = query @ key.transpose(-2, -1) * scale_factor
SPARSE_HEAD_SINCE = 5
SPARSE_SEQ_SINCE = 5
# attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
if dropout_p > 0.0:
if dropout_mask is not None:
attn_weight.masked_fill_(dropout_mask.logical_not(), float("0.0"))
value = value / (1 - dropout_p)
else:
# assert False, "TESTING dropout_mask code path"
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
else:
# assert False, "TESTING dropout_mask code path"
pass
av = attn_weight @ value
return av, attn_weight
from _common_test import SdpaContext, SdpaParams

def _make_block_eyes(q, base=1.0, inc=0.0):
dhead = q.shape[-1]
@@ -69,138 +34,62 @@ def RP(x):
Note: In Flash V2 API the ... is denoted as "num_heads", serving as uniformly sized sequences
but in PyTorch API it does not present at all
'''
def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, bias: torch.Tensor, dtype: torch.dtype = None, device=None):
""" Clones the query, key, and value tensors and moves them to the specified dtype. """
if dtype is None:
dtype = query.dtype
query_ref = query.clone().detach().to(dtype=dtype, device=device).requires_grad_(query.requires_grad)
key_ref = key.clone().detach().to(dtype=dtype, device=device).requires_grad_(key.requires_grad)
value_ref = value.clone().detach().to(dtype=dtype, device=device).requires_grad_(value.requires_grad)
bias_ref = bias.clone().detach().to(dtype=dtype, device=device).requires_grad_(bias.requires_grad) if bias is not None else None
return query_ref, key_ref, value_ref, bias_ref

def _do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type):
if causal and seqlen_q != seqlen_k:
pytest.skip("PyTorch's Flash V2 does not accept casual=True when seqlen_q != seqlen_k. Skipping")
if causal and bias_type is not None:
pytest.skip("_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True")
# if BATCH > 1 and seqlen_q >= 1024 and seqlen_k >= 1024:
# torch.cuda.empty_cache()
SKIP_DK_DV = False
SKIP_DQ = False
SKIP_DB = True if bias_type is None else False
USE_AUTOTUNE = True
torch.manual_seed(20)
SPARSE_HEAD_SINCE = 1
SPARSE_SEQ_SINCE = 1
qdims = (BATCH, N_HEADS, seqlen_q, D_HEAD)
kdims = (BATCH, N_HEADS, seqlen_k, D_HEAD)
vdims = (BATCH, N_HEADS, seqlen_k, D_HEAD)
bdims = (BATCH, N_HEADS, seqlen_q, seqlen_k)
if storage_flip:
qdims = (qdims[0], qdims[2], qdims[1], qdims[3])
kdims = (kdims[0], kdims[2], kdims[1], kdims[3])
vdims = (vdims[0], vdims[2], vdims[1], vdims[3])
bdims = (bdims[0], bdims[2], bdims[1], bdims[3])
q = torch.empty(qdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5)
k = torch.empty(kdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5)
v = torch.empty(vdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5)
if bias_type is None:
b = None
elif bias_type == 'matrix':
b = torch.empty(bdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5)
else:
assert False, f'Unsupported bias_type {bias_type}'
if storage_flip:
q = torch.transpose(q, 1, 2)
k = torch.transpose(k, 1, 2)
v = torch.transpose(v, 1, 2)
if b is not None:
b = torch.transpose(b, 1, 2)
if not SKIP_DQ:
q.requires_grad_()
if not SKIP_DK_DV:
k.requires_grad_()
v.requires_grad_()
if not SKIP_DB:
assert b is not None
b.requires_grad_()
transpose = (1, 2) if storage_flip else None
ctx = SdpaContext(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, dtype,
bias_type=bias_type, storage_flip=transpose, device='cuda')
ctx.create_ref_inputs()
ctx.set_require_grads(skip_dq=SKIP_DQ, skip_dk_dv=SKIP_DK_DV, skip_db=SKIP_DB)
return_encoded_softmax = True
# q_ref_lp, k_ref_lp, v_ref_lp = query_key_value_clones(q, k, v, dtype=dtype)
higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
# REF_DEVICE='cpu'
REF_DEVICE=None
q_ref, k_ref, v_ref, b_ref = query_key_value_clones(q, k, v, b, dtype=higher_precision_dtype, device=REF_DEVICE)
def TO(ref_tensor):
return ref_tensor.to(device=q.device, dtype=dtype)
q, k, v, b = ctx.dev_tensors
# autotune = True
# # triton implementation
tri_out, encoded_softmax, _ = attention(q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax, USE_AUTOTUNE)
dropout_mask = encoded_softmax >= 0
'''
ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v,
dropout_p=dropout_p,
is_causal=causal,
scale=sm_scale,
dropout_mask=dropout_mask)
'''
ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q_ref, k_ref, v_ref,
dropout_p=dropout_p,
is_causal=causal,
attn_mask=b_ref,
scale=sm_scale,
dropout_mask=dropout_mask)
dout = torch.randn_like(q)
tri_out.backward(dout)
tri_dv, v.grad = None if SKIP_DK_DV else v.grad.clone(), None
tri_dk, k.grad = None if SKIP_DK_DV else k.grad.clone(), None
tri_dq, q.grad = None if SKIP_DQ else q.grad.clone(), None
if not SKIP_DB:
tri_db = b.grad.clone()
else:
tri_db = None
sdpa_params = SdpaParams(causal=causal, sm_scale=sm_scale, dropout_p=dropout_p, dropout_mask=dropout_mask)
ref_out, _ = ctx.compute_ref_forward(sdpa_params)

'''
ref_out.backward(dout, None)
ref_dv, v.grad = None if SKIP_DK_DV else v.grad.clone(), None
ref_dk, k.grad = None if SKIP_DK_DV else k.grad.clone(), None
ref_dq, q.grad = None if SKIP_DQ else q.grad.clone(), None
'''
ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype))
ref_dv, v_ref.grad = None if SKIP_DK_DV else v_ref.grad.clone(), None
ref_dk, k_ref.grad = None if SKIP_DK_DV else k_ref.grad.clone(), None
ref_dq, q_ref.grad = None if SKIP_DQ else q_ref.grad.clone(), None
if SKIP_DB:
ref_db = None
else:
ref_db, b_ref.grad = b_ref.grad.clone(), None
# compare
if dtype == torch.bfloat16:
ATOL = 1e-1 * max(1.0, (RP(seqlen_q) + RP(seqlen_k) + RP(D_HEAD)) / 128.0)
else:
ATOL = 1e-2 * max(1.0, (RP(seqlen_q) + RP(seqlen_k) + RP(D_HEAD)) / 128.0)
# RTOL=1e-2 if dtype==torch.float16 else 5e-2
RTOL=0.02
print(f'Forward Using ATOL={ATOL} RTOL={RTOL}')
# FIXME: Need to raise tolerance
'''
is_allclose = torch.allclose(ref_out, tri_out, atol=ATOL, rtol=RTOL)
'''
is_allclose = torch.allclose(TO(ref_out), tri_out, atol=ATOL, rtol=RTOL)
dout = torch.rand_like(tri_out)
ctx.compute_backward(tri_out, dout)
is_allclose, adiff, grads_allclose, grads_adiff = ctx.validate_with_reference(tri_out, ctx.dout_tensors)
if not is_allclose:
import numpy as np
err_idx = np.unravel_index(torch.argmax(torch.abs(ref_out - tri_out)).cpu().numpy(), ref_out.shape)
print(f'{err_idx=}')
print(f'{tri_out[err_idx]=}')
print(f'{ref_out[err_idx]=}')
assert is_allclose, 'Forward pass {is_allclose=}'
if dtype == torch.bfloat16:
ATOL = 1e-1 * max(1.0, (RP(seqlen_q) + RP(seqlen_k) + RP(D_HEAD)) / 32.0)
elif dtype == torch.float32:
ATOL = 1e-3 * max(1.0, (RP(seqlen_q) + RP(seqlen_k) + RP(D_HEAD)) / 32.0)
else:
ATOL = 1e-2 * max(1.0, (RP(seqlen_q) + RP(seqlen_k) + RP(D_HEAD)) / 32.0)
print(f"Backward Using {ATOL=} {RTOL=}")

dv_allclose = SKIP_DK_DV or torch.allclose(TO(ref_dv), tri_dv, atol=ATOL, rtol=RTOL)
dq_allclose, dk_allclose, dv_allclose, db_allclose = grads_allclose
tri_dq, tri_dk, tri_dv, tri_db = ctx.dout_tensors
ref_dq, ref_dk, ref_dv, ref_db = ctx.dref_tensors
if not SKIP_DQ:
assert tri_dq is not None
assert ref_dq is not None
if not SKIP_DK_DV:
assert tri_dk is not None
assert tri_dv is not None
assert ref_dk is not None
assert ref_dv is not None
if not SKIP_DB:
assert tri_db is not None
assert ref_db is not None
def TO(ref_tensor):
return ref_tensor.to(device=q.device, dtype=dtype)
if not dv_allclose:
import numpy as np
err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_dv) - tri_dv)).cpu().numpy(), ref_dv.shape)
@@ -236,7 +125,6 @@ def TO(ref_tensor):
# print(f'{tri_dq[0,0]=}')
# print(f'{ref_dq[0,0]=}')

dk_allclose = SKIP_DK_DV or torch.allclose(TO(ref_dk), tri_dk, atol=ATOL, rtol=RTOL)
if dv_allclose and not dk_allclose:
print(f'{tri_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}')
print(f'{ref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}')
@@ -249,20 +137,19 @@ def TO(ref_tensor):
print(f'{tri_dk[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]/ref_dk[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}')
print(f'{dropout_mask[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}')

dq_allclose = SKIP_DQ or torch.allclose(TO(ref_dq), tri_dq, atol=ATOL, rtol=RTOL)
if dk_allclose and dv_allclose and not dq_allclose:
import numpy as np
err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_dq) - tri_dq)).cpu().numpy(), ref_dq.shape)
print(f'{err_idx=}')
print(f'{tri_dq[err_idx]=} {ref_dq[err_idx]=} error = {torch.abs(tri_dq[err_idx] - ref_dq[err_idx])}')

db_allclose = SKIP_DB or torch.allclose(TO(ref_db), tri_db, atol=ATOL, rtol=RTOL)
if dk_allclose and dv_allclose and dq_allclose and not db_allclose:
import numpy as np
err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_db) - tri_db)).cpu().numpy(), ref_db.shape)
print(f'{err_idx=}')
print(f'{tri_db[err_idx]=} {ref_db[err_idx]=} error = {torch.abs(tri_db[err_idx] - ref_db[err_idx])}')
assert dk_allclose and dv_allclose and dq_allclose and db_allclose, f'{dk_allclose=} {dv_allclose=} {dq_allclose=} {db_allclose=}'
print(f'{adiff=} {grads_adiff=}')

# @pytest.mark.parametrize('BATCH', [1])
# @pytest.mark.parametrize('N_HEADS', [1])
@@ -283,7 +170,8 @@ def TO(ref_tensor):
@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize('dropout_p', [0.0, 0.5])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16, torch.float32])
# @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('sm_scale', [0.0, 1.2])
@pytest.mark.parametrize('storage_flip', [False, True])
# @pytest.mark.parametrize('return_encoded_softmax', [False])
@@ -309,8 +197,8 @@ def test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dr
# @pytest.mark.parametrize('seqlen_k', [128, 79])
@pytest.mark.parametrize('dropout_p', [0.0, 0.5])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
# @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16, torch.float32])
# @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('sm_scale', [0.0, 1.2])
@pytest.mark.parametrize('storage_flip', [False, True])
# @pytest.mark.parametrize('return_encoded_softmax', [False])
