[V1] Implement Cascade Attention #11635

Merged 23 commits on Jan 1, 2025
CMakeLists.txt (1 addition & 1 deletion)
@@ -550,7 +550,7 @@ else()
FetchContent_Declare(
        vllm-flash-attn
        GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-       GIT_TAG 04325b6798bcc326c86fb35af62d05a9c8c8eceb
+       GIT_TAG 96266b1111111f3d11aabefaf3bacbab6a89d03c
        GIT_PROGRESS TRUE
        # Don't share the vllm-flash-attn build between build types
        BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
tests/kernels/test_cascade_flash_attn.py (182 additions, new file)
@@ -0,0 +1,182 @@
from typing import List, Optional, Tuple

import pytest
import torch

from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (cascade_attention,
                                                   merge_attn_states)
from vllm.vllm_flash_attn import flash_attn_varlen_func

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 192, 256]
BLOCK_SIZES = [16]
DTYPES = [torch.float16, torch.bfloat16]


@pytest.mark.parametrize("num_tokens", [1, 39, 16912])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_merge_kernel(
    num_tokens: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
):
    torch.set_default_device("cuda")
    current_platform.seed_everything(0)
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0

    # Prepare inputs.
    prefix_output = torch.randn(num_tokens,
                                num_query_heads,
                                head_size,
                                dtype=dtype)
    suffix_output = torch.randn(num_tokens,
                                num_query_heads,
                                head_size,
                                dtype=dtype)
    prefix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32)
    suffix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32)
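    # The LSE tensors follow the [num_query_heads, num_tokens] layout that
    # merge_attn_states expects for softmax log-sum-exp values.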

    # Run the kernel.
    output = torch.empty(num_tokens, num_query_heads, head_size, dtype=dtype)
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse)

    # Reference implementation.
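    # Merging partial attention outputs over two disjoint key sets:
    #   out = (exp(p_lse) * p_out + exp(s_lse) * s_out) / (exp(p_lse) + exp(s_lse))
    # Subtracting the elementwise max LSE before exponentiating is the usual
    # numerical-stability trick and does not change the result.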
    max_lse = torch.maximum(prefix_lse, suffix_lse)
    p_lse = torch.exp(prefix_lse - max_lse)
    s_lse = torch.exp(suffix_lse - max_lse)
    p_scale = p_lse / (p_lse + s_lse)
    s_scale = s_lse / (p_lse + s_lse)
    p_scale = p_scale.transpose(0, 1).unsqueeze(2)
    s_scale = s_scale.transpose(0, 1).unsqueeze(2)
    ref_output = p_scale * prefix_output + s_scale * suffix_output
    ref_output = ref_output.to(dtype)

    # Compare the results.
    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)


CASES = [
    # Case 1. A general case.
    ([(129, 871), (18, 280), (37, 988), (1023, 2304), (1, 257)], 256),
    # Case 2. Flash-decoding case.
    ([(1, 1023), (1, 879), (1, 778), (1, 1777)] * 100, 512),
]
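# Each case is a (seq_lens, common_prefix_len) pair, where seq_lens lists
# per-sequence (query_len, kv_len) tuples. Case 2 stresses the decode path:
# 400 single-token queries sharing a 512-token prefix.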


@pytest.mark.parametrize("seq_lens_and_common_prefix", CASES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("soft_cap", [None, 50])
@pytest.mark.parametrize("num_blocks", [2048])
@torch.inference_mode()
def test_cascade(
    seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int],
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
) -> None:
    torch.set_default_device("cuda")
    current_platform.seed_everything(0)

    window_size = (-1, -1)
    scale = head_size**-0.5
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=dtype)
    value_cache = torch.randn_like(key_cache)
value_cache = torch.randn_like(key_cache)

    seq_lens, common_prefix_len = seq_lens_and_common_prefix
    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)

    total_num_query_tokens = sum(query_lens)
    query = torch.randn(total_num_query_tokens,
                        num_query_heads,
                        head_size,
                        dtype=dtype)
    cu_query_lens = torch.tensor([0] + query_lens,
                                 dtype=torch.int32).cumsum(dim=0,
                                                           dtype=torch.int32)
    cu_kv_lens = torch.tensor([0] + kv_lens,
                              dtype=torch.int32).cumsum(dim=0,
                                                        dtype=torch.int32)
    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    assert common_prefix_len > 0
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    # Make sure the first `num_common_kv_blocks` blocks are the same.
    block_tables[:, :num_common_kv_blocks] = \
        block_tables[0, :num_common_kv_blocks]
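    # With a paged KV cache, pointing every sequence's leading block-table
    # entries at the same physical blocks is exactly what makes the prefix
    # shared across the batch.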

    # Run the regular attention.
    ref_output = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        cu_seqlens_k=cu_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=window_size,
        block_table=block_tables,
        softcap=soft_cap if soft_cap is not None else 0,
    )

    # Run cascade attention.
    assert all(common_prefix_len < kv_len for kv_len in kv_lens)
    cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens],
                                        dtype=torch.int32)
    cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], dtype=torch.int32)
    cu_suffix_kv_lens = (
        cu_kv_lens -
        torch.arange(num_seqs + 1, dtype=torch.int32) * common_prefix_len)
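    # Each sequence's suffix holds kv_len - common_prefix_len tokens, so the
    # cumulative suffix lengths equal the cumulative kv lengths shifted down
    # by i * common_prefix_len at index i.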
    output = torch.empty_like(query)
    cascade_attention(
        output=output,
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        cu_query_lens=cu_query_lens,
        max_query_len=max_query_len,
        cu_prefix_query_lens=cu_prefix_query_lens,
        cu_prefix_kv_lens=cu_prefix_kv_lens,
        cu_suffix_kv_lens=cu_suffix_kv_lens,
        max_kv_len=max_kv_len,
        softmax_scale=scale,
        alibi_slopes=None,
        sliding_window=window_size,
        logits_soft_cap=soft_cap if soft_cap is not None else 0,
        block_table=block_tables,
        common_prefix_len=common_prefix_len,
    )

    # Compare the results.
    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
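
For context on what the new test exercises: cascade attention splits the attention computation in two when a batch of sequences shares a common KV-cache prefix. All query tokens attend to the shared prefix once (non-causally, since the prefix precedes every query token), each sequence then attends causally to its own suffix, and the two partial outputs are combined with merge_attn_states using their softmax log-sum-exp (LSE) values, which is what test_merge_kernel validates above. The sketch below illustrates that decomposition with the same tensors the test constructs. It is a minimal sketch, not the PR's actual implementation: cascade_attention_sketch is a hypothetical helper name, and it assumes the vllm-flash-attn fork's flash_attn_varlen_func accepts return_softmax_lse=True and then returns an (output, lse) pair.

import torch

from vllm.v1.attention.backends.flash_attn import merge_attn_states
from vllm.vllm_flash_attn import flash_attn_varlen_func


def cascade_attention_sketch(
    output: torch.Tensor,      # [total_query_tokens, num_heads, head_size]
    query: torch.Tensor,       # [total_query_tokens, num_heads, head_size]
    key_cache: torch.Tensor,   # [num_blocks, block_size, num_kv_heads, head_size]
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
    cu_prefix_kv_lens: torch.Tensor,
    cu_suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
    block_size: int,
) -> None:
    num_common_kv_blocks = common_prefix_len // block_size

    # Pass 1: every query token attends to the shared prefix, treated as one
    # long "sequence". causal=False because the whole prefix precedes all
    # query tokens; block_table[:1] points at the shared physical blocks.
    prefix_output, prefix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        cu_seqlens_k=cu_prefix_kv_lens,
        max_seqlen_q=query.shape[0],
        max_seqlen_k=common_prefix_len,
        softmax_scale=softmax_scale,
        causal=False,
        block_table=block_table[:1],
        return_softmax_lse=True,  # assumed flag: also return per-token LSE
    )

    # Pass 2: each sequence attends causally to its own non-shared suffix.
    # Dropping the first num_common_kv_blocks block-table columns hides the
    # prefix KV blocks from the kernel.
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        cu_seqlens_k=cu_suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
        softmax_scale=softmax_scale,
        causal=True,
        block_table=block_table[:, num_common_kv_blocks:],
        return_softmax_lse=True,  # assumed flag
    )

    # Combine the partial results; the LSEs give each side its correct
    # softmax weight across the union of prefix and suffix keys.
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse)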