Fix the performance regression due to support of irregular shapes. (#10)
This is a squashed changeset. See branch `xinyazhang/perf-irregular`
for the breakdown of the performance impact on MI210 of each added
irregular-shape support.
xinyazhang committed Mar 19, 2024
1 parent 0168ad8 commit e537881
Showing 4 changed files with 182 additions and 241 deletions.
test/test_forward.py: 31 changes (23 additions, 8 deletions)
@@ -51,6 +51,15 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_mask
print(f'{av[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}')
return av, attn_weight

+def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dtype: torch.dtype = None, device=None):
+    """ Clones the query, key, and value tensors and moves them to the specified dtype. """
+    if dtype is None:
+        dtype = query.dtype
+    query_ref = query.clone().detach().to(dtype=dtype, device=device).requires_grad_(query.requires_grad)
+    key_ref = key.clone().detach().to(dtype=dtype, device=device).requires_grad_(key.requires_grad)
+    value_ref = value.clone().detach().to(dtype=dtype, device=device).requires_grad_(value.requires_grad)
+    return query_ref, key_ref, value_ref

'''
Flash Attention is batch operator that evaluates sm(QK')V
Q = batch_size x ... x seqlen_q x head_size
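For reference, the helper added in the hunk above boils down to the following per-tensor pattern: a detached clone, cast to the requested (typically higher-precision) dtype and optional device, with requires_grad restored. A minimal sketch with a made-up fp16 tensor on a CUDA device (illustration only, not from the diff):

import torch

q = torch.randn(4, 8, 143, 64, dtype=torch.float16, device='cuda', requires_grad=True)
# Equivalent of query_key_value_clones for a single tensor: a float32 reference copy
# on the same device that still tracks gradients.
q_ref = q.clone().detach().to(dtype=torch.float32).requires_grad_(q.requires_grad)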
@@ -72,16 +81,16 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_mask
@pytest.mark.parametrize('D_HEAD', [8, 16, 21, 32, 64, 72, 96, 128, 160, 192, 203, 256])
# @pytest.mark.parametrize('seqlen_q', [16,32,64,128,256,512,1024])
# @pytest.mark.parametrize('seqlen_k', [16,32,64,128,256,512,1024])
-@pytest.mark.parametrize('seqlen_q', [4,8,16,17,32,64,128,143,256,512,1024,2048])
-@pytest.mark.parametrize('seqlen_k', [4,8,16,23,32,64,128,256,512,587,1024,2048])
+@pytest.mark.parametrize('seqlen_q', [4, 8, 64, 143, 256, 512, 1024, 2048])
+@pytest.mark.parametrize('seqlen_k', [4, 8, 64, 128, 256, 587, 1024, 2048])
# @pytest.mark.parametrize('seqlen_q', [32, 128])
# @pytest.mark.parametrize('seqlen_k', [32, 128])
@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize('dropout_p', [0.0, 0.5])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('sm_scale', [0.0, 1.2])
-@pytest.mark.parametrize('storage_flip', [True, False])
+@pytest.mark.parametrize('storage_flip', [False, True])
# @pytest.mark.parametrize('return_encoded_softmax', [False])
def test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip):
if causal and seqlen_q != seqlen_k:
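For scale (a quick count from the lists above, not stated in the commit): the old parametrization gives 12 × 12 = 144 (seqlen_q, seqlen_k) pairs per combination of the other parametrized arguments, while the trimmed lists give 8 × 8 = 64.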
@@ -160,11 +169,16 @@ def test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dr
ref_out = torch.matmul(p, v)
'''
return_encoded_softmax = dropout_p > 0.0
+higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
+REF_DEVICE=None
+q_ref, k_ref, v_ref = query_key_value_clones(q, k, v, dtype=higher_precision_dtype, device=REF_DEVICE)
+def TO(ref_tensor):
+    return ref_tensor.to(device=q.device, dtype=dtype)
tri_out, encoded_softmax, _ = attention(q, k, v, causal, sm_scale, dropout_p, return_encoded_softmax)

dropout_mask = encoded_softmax > 0 if encoded_softmax is not None else None
# assert torch.allclose(dropout_mask, dropout_mask_naive)
-ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v,
+ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q_ref, k_ref, v_ref,
dropout_p=dropout_p,
is_causal=causal,
scale=sm_scale,
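Putting this hunk's pieces together, the validation pattern is roughly the following standalone sketch. Illustration only: shapes, sm_scale, and the CUDA device are made up, PyTorch 2.1+ is assumed for the scale= argument, and F.scaled_dot_product_attention stands in for the Triton kernel under test.

import torch
import torch.nn.functional as F

dtype = torch.float16
q = torch.randn(4, 8, 128, 64, dtype=dtype, device='cuda')
k = torch.randn(4, 8, 128, 64, dtype=dtype, device='cuda')
v = torch.randn(4, 8, 128, 64, dtype=dtype, device='cuda')
sm_scale = 1.2

# Higher-precision reference inputs (float32 here), as query_key_value_clones produces.
q_ref, k_ref, v_ref = [t.clone().detach().to(dtype=torch.float32) for t in (q, k, v)]

# PyTorch math reference, computed in float32.
ref_out, _ = torch.ops.aten._scaled_dot_product_attention_math(
    q_ref, k_ref, v_ref, dropout_p=0.0, is_causal=False, scale=sm_scale)

# Stand-in for the Triton kernel output being validated.
tri_out = F.scaled_dot_product_attention(q, k, v, scale=sm_scale)

def TO(ref_tensor):
    # Cast the reference back to the kernel's dtype/device before comparing.
    return ref_tensor.to(device=q.device, dtype=dtype)

ATOL = 1e-2 * max(1.0, (128 + 128 + 64) / 128.0)  # mirrors the tolerance formula in the next hunk
print(torch.allclose(TO(ref_out), tri_out, atol=ATOL, rtol=0.0))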
Expand Down Expand Up @@ -192,17 +206,18 @@ def test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dr
ATOL = 1e-1 * max(1.0, (seqlen_q + seqlen_k + D_HEAD) / 128.0)
else:
ATOL = 1e-2 * max(1.0, (seqlen_q + seqlen_k + D_HEAD) / 128.0)
-print(f'Using ATOL={ATOL}')
-is_allclose = torch.allclose(ref_out, tri_out, atol=ATOL, rtol=0)
+RTOL = 0.0
+print(f'Using ATOL={ATOL} RTOL={RTOL}')
+is_allclose = torch.allclose(TO(ref_out), tri_out, atol=ATOL, rtol=RTOL)
if not is_allclose:
import numpy as np
-err_idx = np.unravel_index(torch.argmax(torch.abs(ref_out - tri_out)).cpu().numpy(), ref_out.shape)
+err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_out) - tri_out)).cpu().numpy(), ref_out.shape)
print(f'{err_idx=}')
print(f'{tri_out[err_idx]=} {ref_out[err_idx]=} error: {tri_out[err_idx] - ref_out[err_idx]}')
# if not is_allclose:
if False:
import numpy as np
-err_idx = np.unravel_index(torch.argmax(torch.abs(ref_out - tri_out)).cpu().numpy(), ref_out.shape)
+err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_out) - tri_out)).cpu().numpy(), ref_out.shape)
print(f'{tri_out[0][0][0][:]=}')
print(f'{ref_out[0][0][0][:]=}')
print(f'{mref_out[0][0][0][:]=}')
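To put the tolerance formula above in perspective (a worked example, not from the commit): with seqlen_q = seqlen_k = 2048 and D_HEAD = 64, (2048 + 2048 + 64) / 128 = 32.5, so ATOL becomes 3.25 in the 1e-1 branch and 0.325 in the 1e-2 branch; for small shapes (say 4 + 4 + 8 = 16), the max(1.0, ...) clamp keeps ATOL at the base 1e-1 or 1e-2.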
(Diffs of the remaining three changed files are not shown.)
