Support varlen FlashAttention #14

Merged
merged 10 commits on Nov 19, 2024
334 changes: 158 additions & 176 deletions test/test_flash_attention_backward.py
@@ -1,179 +1,161 @@
import sys
import unittest
import os
import pytest

import numpy as np
import torch
import flash_attn_2_cuda as flash_attn_cuda

import torch_xla
import torch_xla.core.xla_model as xm

BATCH_SIZE = 4
SEQ_LEN = 256
DIMS = 32
N_HEADS = 8
DROPOUT = 0.8
SOFTMAX_SCALE = 0.25
ZERO_TENSORS = False
IS_CAUSAL = True
NUM_SPLITS = 0
GEN = None
WINDOW_SIZE = (-1, -1)
DETERMINISTIC = False
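# Shared test configuration. WINDOW_SIZE = (-1, -1) disables sliding-window
# (local) attention, and DROPOUT > 0 exercises the rng_state code path.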


class FlashAttentionBackwardTest(unittest.TestCase):

def _backward_internal(self,
tensor_dtype,
n_heads_kv=N_HEADS,
enable_alibi_slopes=False):
# original flash attention
torch.manual_seed(101)
device = 'cuda:0'
q_cuda = torch.linspace(
-0.5, 0.5, SEQ_LEN * BATCH_SIZE * N_HEADS * DIMS,
device=device).reshape(SEQ_LEN * BATCH_SIZE, N_HEADS,
DIMS).to(tensor_dtype)
dq_cuda = torch.zeros_like(q_cuda)
k_cuda = torch.linspace(
-0.5, 0.5, SEQ_LEN * BATCH_SIZE * n_heads_kv * DIMS,
device=device).reshape(SEQ_LEN * BATCH_SIZE, n_heads_kv,
DIMS).to(tensor_dtype)
dk_cuda = torch.zeros_like(k_cuda)
v_cuda = torch.linspace(
-0.5, 0.5, SEQ_LEN * BATCH_SIZE * n_heads_kv * DIMS,
device=device).reshape(SEQ_LEN * BATCH_SIZE, n_heads_kv,
DIMS).to(tensor_dtype)
dv_cuda = torch.zeros_like(v_cuda)
out_cuda = torch.linspace(
-0.5, 0.5, SEQ_LEN * BATCH_SIZE * N_HEADS * DIMS,
device=device).reshape(SEQ_LEN * BATCH_SIZE, N_HEADS,
DIMS).to(tensor_dtype)
dout_cuda = torch.linspace(
-0.5, 0.5, SEQ_LEN * BATCH_SIZE * N_HEADS * DIMS,
device=device).reshape(SEQ_LEN * BATCH_SIZE, N_HEADS,
DIMS).to(tensor_dtype)
softmax_lse_cuda = torch.linspace(
5, 6, BATCH_SIZE * N_HEADS * SEQ_LEN,
device=device).reshape(BATCH_SIZE, N_HEADS, SEQ_LEN).to(torch.float32)
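# Cumulative sequence lengths for the packed varlen layout:
# [0, SEQ_LEN, 2 * SEQ_LEN, ..., BATCH_SIZE * SEQ_LEN].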
cu_seqlens_q_cuda = torch.arange(
0, (BATCH_SIZE + 1) * SEQ_LEN,
step=SEQ_LEN,
dtype=torch.int32,
device=q_cuda.device)
cu_seqlens_k_cuda = cu_seqlens_q_cuda
alibi_slopes_cuda = torch.linspace(
-0.5, 0.5, N_HEADS, device=device,
dtype=torch.float32) * 0.3 if enable_alibi_slopes else None
rng_state_cuda = torch.Tensor([101, 102]).to(torch.int64).to(device)
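# Reference gradients from the flash-attn CUDA varlen backward kernel.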
dq_cuda, dk_cuda, dv_cuda, softmax_d_cuda = flash_attn_cuda.varlen_bwd(
dout_cuda, q_cuda, k_cuda, v_cuda, out_cuda, softmax_lse_cuda, dq_cuda,
dk_cuda, dv_cuda, cu_seqlens_q_cuda, cu_seqlens_k_cuda,
alibi_slopes_cuda, SEQ_LEN, SEQ_LEN, DROPOUT, SOFTMAX_SCALE,
ZERO_TENSORS, IS_CAUSAL, WINDOW_SIZE[0], WINDOW_SIZE[1], DETERMINISTIC,
GEN, rng_state_cuda) # must set rng_state when dropout > 0

# TorchXLA flash attention
torch.manual_seed(101)
device = torch_xla.core.xla_model.xla_device()
q_xla = torch.linspace(
-0.5, 0.5, SEQ_LEN * BATCH_SIZE * N_HEADS * DIMS,
device=device).reshape(SEQ_LEN * BATCH_SIZE, N_HEADS,
DIMS).to(tensor_dtype)
dq_xla = torch.zeros_like(q_xla)
k_xla = torch.linspace(
-0.5, 0.5, SEQ_LEN * BATCH_SIZE * n_heads_kv * DIMS,
device=device).reshape(SEQ_LEN * BATCH_SIZE, n_heads_kv,
DIMS).to(tensor_dtype)
dk_xla = torch.zeros_like(k_xla)
v_xla = torch.linspace(
-0.5, 0.5, SEQ_LEN * BATCH_SIZE * n_heads_kv * DIMS,
device=device).reshape(SEQ_LEN * BATCH_SIZE, n_heads_kv,
DIMS).to(tensor_dtype)
dv_xla = torch.zeros_like(v_xla)
out_xla = torch.linspace(
-0.5, 0.5, SEQ_LEN * BATCH_SIZE * N_HEADS * DIMS,
device=device).reshape(SEQ_LEN * BATCH_SIZE, N_HEADS,
DIMS).to(tensor_dtype)
dout_xla = torch.linspace(
-0.5, 0.5, SEQ_LEN * BATCH_SIZE * N_HEADS * DIMS,
device=device).reshape(SEQ_LEN * BATCH_SIZE, N_HEADS,
DIMS).to(tensor_dtype)
softmax_lse_xla = torch.linspace(
5, 6, BATCH_SIZE * N_HEADS * SEQ_LEN,
device=device).reshape(BATCH_SIZE, N_HEADS, SEQ_LEN).to(torch.float32)
cu_seqlens_q_xla = torch.arange(
0, (BATCH_SIZE + 1) * SEQ_LEN,
step=SEQ_LEN,
dtype=torch.int32,
device=q_xla.device)
cu_seqlens_k_xla = cu_seqlens_q_xla
alibi_slopes_xla = torch.linspace(
-0.5, 0.5, N_HEADS, device=device,
dtype=torch.float32) * 0.3 if enable_alibi_slopes else None
rng_state_xla = torch.Tensor([101, 102]).to(torch.int64).to(device)
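# Same inputs through the torch_xla custom op; unlike the CUDA entry point,
# it allocates and returns dq/dk/dv itself.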
dq_xla, dk_xla, dv_xla, softmax_d_xla = torch_xla._XLAC._flash_attention_backward(
dout_xla, q_xla, k_xla, v_xla, out_xla, softmax_lse_xla,
cu_seqlens_q_xla, cu_seqlens_k_xla, alibi_slopes_xla, SEQ_LEN, SEQ_LEN,
DROPOUT, SOFTMAX_SCALE, ZERO_TENSORS, IS_CAUSAL, WINDOW_SIZE[0],
WINDOW_SIZE[1], DETERMINISTIC, GEN, rng_state_xla)
xm.mark_step()
torch.cuda.synchronize()

assert torch.allclose(
softmax_d_cuda.cpu().detach(),
softmax_d_xla.cpu().detach(),
rtol=1e-3,
atol=1e-3)

# TODO(wenting.swt): alibi_slopes causes NaN and -inf values for both cuda and
# xla in this case, so we only compare the first half of the batch, which
# contains neither NaN nor -inf elements.
assert torch.allclose(
dq_cuda.cpu().detach()[0:int(BATCH_SIZE // 2)],
dq_xla.cpu().detach()[0:int(BATCH_SIZE // 2)],
rtol=1e-3,
atol=1e-3)
assert torch.allclose(
dk_cuda.cpu().detach()[0:int(BATCH_SIZE // 2)],
dk_xla.cpu().detach()[0:int(BATCH_SIZE // 2)],
rtol=1e-3,
atol=1e-3)
assert torch.allclose(
dv_cuda.cpu().detach()[0:int(BATCH_SIZE // 2)],
dv_xla.cpu().detach()[0:int(BATCH_SIZE // 2)],
rtol=1e-3,
atol=1e-3)

def test_flash_attn_gqa_backward_fp16(self):
self._backward_internal(torch.float16, n_heads_kv=int(N_HEADS // 2))

def test_flash_attn_gqa_backward_bf16(self):
self._backward_internal(torch.bfloat16, n_heads_kv=int(N_HEADS // 2))

def test_flash_attn_backward_fp16(self):
self._backward_internal(torch.float16, n_heads_kv=N_HEADS)

def test_flash_attn_backward_bf16(self):
self._backward_internal(torch.bfloat16, n_heads_kv=N_HEADS)

def test_flash_attn_gqa_backward_fp16_alibi(self):
self._backward_internal(
torch.float16, n_heads_kv=int(N_HEADS // 2), enable_alibi_slopes=True)

def test_flash_attn_gqa_backward_bf16_alibi(self):
self._backward_internal(
torch.bfloat16, n_heads_kv=int(N_HEADS // 2), enable_alibi_slopes=True)

def test_flash_attn_backward_fp16_alibi(self):
self._backward_internal(
torch.float16, n_heads_kv=N_HEADS, enable_alibi_slopes=True)

def test_flash_attn_backward_bf16_alibi(self):
self._backward_internal(
torch.bfloat16, n_heads_kv=N_HEADS, enable_alibi_slopes=True)


if __name__ == '__main__':
test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)

from flash_attn import flash_attn_func
import flash_attn_2_cuda as flash_attn_cuda
import torchacc as ta


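# Cap the XLA (PJRT) GPU memory pool for this module, presumably so the native
# flash-attn baseline and the XLA run can fit on the same device.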
@pytest.fixture(autouse=True, scope="module")
def setup_env():
orig_env = os.getenv('PJRT_ALLOCATOR_FRACTION')
os.environ['PJRT_ALLOCATOR_FRACTION'] = '0.5'
yield
if orig_env is None:
os.environ.pop('PJRT_ALLOCATOR_FRACTION', None)
else:
os.environ['PJRT_ALLOCATOR_FRACTION'] = orig_env


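# Dense (non-varlen) backward: compare flash_attn_cuda.bwd against
# torch_xla._XLAC._flash_attention_backward across dtypes, MHA/MQA/GQA layouts,
# causal and local masking, and ALiBi slopes.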
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(128, 113),
(128, 128),
(256, 256),
],
)
@pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_backward(seqlen_q, seqlen_k, d, dropout_p, causal, local,
alibi, deterministic, mha_type, dtype):
if d % 8 != 0:
pytest.skip(reason="Expected head_size_og % 8 == 0 to be true")

device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 4
nheads = 9
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)

assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else tuple(
torch.randint(0, seqlen_k, (2,)).tolist())
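# For local attention, pick a random (left, right) window within the key length.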
torch.cuda.synchronize()
q = torch.randn(
batch_size,
seqlen_q,
nheads,
d,
device=device,
dtype=dtype,
requires_grad=True)
softmax_scale = q.shape[-1]**(-0.5)
k = torch.randn(
batch_size,
seqlen_k,
nheads_k,
d,
device=device,
dtype=dtype,
requires_grad=True)
v = torch.randn(
batch_size,
seqlen_k,
nheads_k,
d,
device=device,
dtype=dtype,
requires_grad=True)
do = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)

rng_state = torch.Tensor([0, 0]).to(torch.int64).to(device)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)

if alibi:
alibi_slopes = torch.rand(
batch_size, nheads, device=device, dtype=torch.float32) * 0.3
else:
alibi_slopes = None

o, softmax_lse, _ = flash_attn_func(
q,
k,
v,
dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)

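# Reference gradients from the dense flash-attn CUDA backward kernel.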
dq, dk, dv, softmax_d = flash_attn_cuda.bwd(do, q, k, v, o, softmax_lse, dq,
dk, dv, alibi_slopes, dropout_p,
softmax_scale, causal,
window_size[0], window_size[1],
deterministic, None, rng_state)

torch.random.manual_seed(0)
q = q.cpu().detach()
k = k.cpu().detach()
v = v.cpu().detach()
o = o.cpu().detach()
do = do.cpu().detach()
rng_state = rng_state.cpu().detach()
softmax_lse = softmax_lse.cpu().detach()

dq = dq.cpu().detach()
dk = dk.cpu().detach()
dv = dv.cpu().detach()
softmax_d = softmax_d.cpu().detach()
torch.cuda.synchronize()

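# Replay the same case on the TorchAcc lazy (XLA) device.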
device = ta.lazy_device()
torch.random.manual_seed(0)
q_xla = q.to(device)
k_xla = k.to(device)
v_xla = v.to(device)
do_xla = do.to(device)

softmax_d_xla = softmax_d.to(device)
q_xla.requires_grad = True
k_xla.requires_grad = True
v_xla.requires_grad = True
if alibi:
alibi_slopes = alibi_slopes.cpu().to(device)
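# Run the forward through the XLA custom op first to get o, softmax_lse and
# rng_state on the lazy device, then feed them to the XLA backward.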
softmax_lse_xla, o_xla, rng_state_xla = torch_xla._XLAC._flash_attention_forward(
q_xla, k_xla, v_xla, None, alibi_slopes, dropout_p, softmax_scale, False,
causal, window_size[0], window_size[1], True, None)

dq_xla, dk_xla, dv_xla, softmax_d_xla = torch_xla._XLAC._flash_attention_backward(
do_xla, q_xla, k_xla, v_xla, o_xla, softmax_lse_xla, None, None,
alibi_slopes, dropout_p, softmax_scale, False, causal, window_size[0],
window_size[1], deterministic, None, rng_state_xla)

ta.mark_step(wait=True)
torch.cuda.synchronize()

dq_xla = dq_xla.cpu().detach()
dk_xla = dk_xla.cpu().detach()
dv_xla = dv_xla.cpu().detach()
softmax_d_xla = softmax_d_xla.cpu().detach()

assert torch.allclose(dq, dq_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
assert torch.allclose(dk, dk_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
assert torch.allclose(dv, dv_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
assert torch.allclose(
softmax_d, softmax_d_xla, rtol=1e-2, atol=1e-2, equal_nan=True)