117 changes: 76 additions & 41 deletions tests/kernels/attention/test_flashinfer_trtllm_attention.py
@@ -49,6 +49,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
BLOCK_SIZE = [16]
WINDOW_LEFT = [-1, 127]
SOFT_CAP = [None, 50.0]
HAS_SINKS = [True, False]

NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.

@@ -63,6 +64,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
@pytest.mark.parametrize("soft_cap", SOFT_CAP)
@pytest.mark.parametrize("has_sinks", HAS_SINKS)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
dtype: torch.dtype,
@@ -77,9 +79,10 @@ def test_flashinfer_trtllm_decode_with_baseline(
block_size: int,
window_left: int,
soft_cap: Optional[float],
has_sinks: bool,
) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(0)
current_platform.seed_everything(42)

q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
q_quant_dtype = q_quant_dtype or dtype
@@ -101,7 +104,16 @@ def test_flashinfer_trtllm_decode_with_baseline(
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")

query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
# max_q_len = 1
q_lens = torch.ones((batch_size,), dtype=torch.int32)
q_indptr = torch.cat(
[
torch.tensor([0], dtype=torch.int32),
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
]
)

query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
@@ -112,7 +124,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
kv_lens[-1] = max_kv_len

seq_lens = kv_lens
seq_lens = kv_lens + q_lens
max_seq_len = torch.max(seq_lens).item()

kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
@@ -148,27 +160,36 @@ def test_flashinfer_trtllm_decode_with_baseline(
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)

# Baseline Decode
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, kv_layout, use_tensor_cores=True
)
if has_sinks:
sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
)
else:
sinks = None
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
)

wrapper.plan(
kv_indptr,
kv_indices,
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
qo_indptr=q_indptr,
paged_kv_indptr=kv_indptr,
paged_kv_indices=kv_indices,
paged_kv_last_page_len=kv_last_page_lens,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim_qk=head_size,
page_size=block_size,
causal=True,
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
window_left=window_left,
logits_soft_cap=soft_cap,
q_data_type=dtype,
kv_data_type=dtype,
)

output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)

o_scale = 1.0
o_sf_scale_float = None
if o_quant_dtype == FP8_DTYPE:
Expand Down Expand Up @@ -202,6 +223,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
window_left=window_left,
sinks=sinks,
o_sf_scale=o_sf_scale_float,
out=output_trtllm,
)
@@ -217,11 +239,13 @@ def test_flashinfer_trtllm_decode_with_baseline(
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])

if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 3e-1, 1e0
rtol, atol = 7e-2, 9e-2
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
else:
rtol, atol = 2e-2, 4e-2
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
rtol, atol = 1e-2, 2e-2
else:
rtol, atol = 1e-2, 1e-2

(
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
@@ -239,6 +263,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("has_sinks", HAS_SINKS)
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
dtype: torch.dtype,
@@ -253,9 +278,10 @@ def test_flashinfer_trtllm_prefill_with_baseline(
block_size: int,
window_left: int,
soft_cap: Optional[float],
has_sinks: bool,
) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(0)
current_platform.seed_everything(42)

q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
q_quant_dtype = q_quant_dtype or dtype
@@ -297,7 +323,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
q_scale = 1.0
ref_query = query

kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
kv_lens[-1] = max_kv_len

seq_lens = kv_lens + q_lens
@@ -336,28 +362,36 @@ def test_flashinfer_trtllm_prefill_with_baseline(
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)

# Baseline Prefill
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout
)
if has_sinks:
sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
)
else:
sinks = None
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
)

wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_size,
block_size,
qo_indptr=q_indptr,
paged_kv_indptr=kv_indptr,
paged_kv_indices=kv_indices,
paged_kv_last_page_len=kv_last_page_lens,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim_qk=head_size,
page_size=block_size,
causal=True,
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
window_left=window_left,
logits_soft_cap=soft_cap,
q_data_type=dtype,
kv_data_type=dtype,
)

output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)

o_scale = 1.0
o_sf_scale_float = None
if o_quant_dtype == FP8_DTYPE:
Expand Down Expand Up @@ -395,6 +429,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
window_left=window_left,
sinks=sinks,
o_sf_scale=o_sf_scale_float,
out=output_trtllm,
)
@@ -410,11 +445,11 @@ def test_flashinfer_trtllm_prefill_with_baseline(
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])

if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 4e-1, 1e0
rtol, atol = 1e-1, 2e-1
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
rtol, atol = 4e-2, 6e-2
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
rtol, atol = 2e-2, 3e-2
else:
rtol, atol = 1e-2, 1e-2

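For reviewers unfamiliar with the new `sinks` argument that these tests now pass to `wrapper.run(...)` and to the TRTLLM kernels, here is a minimal reference sketch of single-query attention with per-head sinks. It assumes the gpt-oss-style semantics (each sink is an extra per-head logit that joins the softmax normalization but contributes no value vector); the helper name `attention_with_sinks`, the shapes, and the lack of GQA handling are illustrative only, not part of the test or of FlashInfer.

```python
import torch


def attention_with_sinks(q, k, v, sinks, sm_scale):
    """Single-query (decode) attention with per-head attention sinks.

    q: [num_heads, head_dim], k/v: [num_heads, seq_len, head_dim],
    sinks: [num_heads]. Each sink is an extra logit in the softmax
    denominator that absorbs probability mass but adds no value vector.
    """
    logits = torch.einsum("hd,hsd->hs", q, k) * sm_scale       # [h, s]
    m = torch.maximum(logits.max(dim=-1).values, sinks)        # stabilizer, [h]
    p = torch.exp(logits - m[:, None])                         # [h, s]
    denom = p.sum(dim=-1) + torch.exp(sinks - m)               # [h]
    return torch.einsum("hs,hsd->hd", p / denom[:, None], v)   # [h, d]
```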
5 changes: 0 additions & 5 deletions vllm/utils/flashinfer.py
@@ -269,11 +269,6 @@ def use_trtllm_attention(

# Must use TRTLLM attention if query is FP8 quantized
if q_dtype == current_platform.fp8_dtype():
if has_sinks:
raise RuntimeError(
"TRTLLM FP8-qkv kernel is not supported for attention sinks. "
"Use kv_cache_dtype=auto for now."
)
logger.info_once("Using TRTLLM attention (query is quantized).")
return True

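Finally, a quick sanity check of the sketch above (standalone illustration, not part of the PR; it reuses the hypothetical `attention_with_sinks` helper): with the sink logit at negative infinity the extra term vanishes and the result matches plain softmax attention, while a large positive sink only rescales each head's output toward zero.

```python
import torch

torch.manual_seed(0)
h, s, d = 8, 64, 128
q = torch.randn(h, d)
k = torch.randn(h, s, d)
v = torch.randn(h, s, d)
sm_scale = d ** -0.5

# Sink at -inf: exp(sink) == 0, so this reduces to ordinary softmax attention.
ref_p = torch.softmax(torch.einsum("hd,hsd->hs", q, k) * sm_scale, dim=-1)
ref = torch.einsum("hs,hsd->hd", ref_p, v)
out = attention_with_sinks(q, k, v, torch.full((h,), float("-inf")), sm_scale)
torch.testing.assert_close(out, ref, rtol=1e-5, atol=1e-5)

# A positive sink drains probability mass: the relative weights over the keys
# are unchanged, but each head's output is scaled down by a factor < 1.
damped = attention_with_sinks(q, k, v, torch.full((h,), 5.0), sm_scale)
assert damped.norm() < ref.norm()
```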