flashinfer fp8 prefill has no speedup over fp16 on L20 #914

Open
yongchaoding opened this issue Mar 6, 2025 · 0 comments

Comments

@yongchaoding

I tried the fp8 prefill attention kernel and found that on an L20 it gives no speedup over fp16; it is even slightly slower. Is this expected?

```python
import torch

from flash_attn.utils.benchmark import benchmark_forward
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
import flashinfer

import argparse

parser = argparse.ArgumentParser(description='Benchmark FlashInfer')
parser.add_argument('--batch_size', type=int, default=4, help='Batch size')
parser.add_argument('--num_kv_heads', type=int, default=32, help='Number of heads')
parser.add_argument('--num_qo_heads', type=int, default=32, help='Number of heads')
parser.add_argument('--head_dim', type=int, default=128, help='Head dimension')
args = parser.parse_args()

qo_head = args.num_qo_heads
kv_head = args.num_kv_heads
batch = args.batch_size
headdim = args.head_dim

print(f"FlashInfer Benchmark")
print(f"batch: {batch}, qo_head: {qo_head}, kv_head: {kv_head}, headdim: {headdim}")

kv_layout = "NHD"
workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8).to(0)
wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
    workspace_buffer, kv_layout
)

is_causal = False
print(f"is_causal: {is_causal}")
for seq_len in {1024, 2048, 4096, 8192, 16384, 32768}:
    flops = 4 * qo_head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1)
    q = torch.randn(batch*seq_len, qo_head, headdim, dtype=torch.float16, device="cuda")
    k = torch.randn(batch*seq_len, kv_head, headdim, dtype=torch.float16, device="cuda")
    v = torch.randn(batch*seq_len, kv_head, headdim, dtype=torch.float16, device="cuda")

    q_indptr = torch.arange(0, batch + 1).to(0).int() * seq_len
    kv_indptr = torch.arange(0, batch + 1).to(0).int() * seq_len
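    # q_indptr / kv_indptr are cumulative sequence offsets into the packed
    # (batch*seq_len, num_heads, head_dim) Q/K/V tensors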
    wrapper.plan(
            q_indptr,
            kv_indptr,
            qo_head,
            kv_head,
            headdim,
            causal=is_causal,
        )
    o = wrapper.run(q, k, v)
    for i in range(5): wrapper.run(q, k, v)
    torch.cuda.synchronize()
    _, time = benchmark_forward(wrapper.run, q, k, v, repeats=100, verbose=False, desc='Flashinfer')
    print(f'{seq_len} flops:{flops/time.mean*1e-12}')


is_causal = True
print(f"is_causal: {is_causal}")
for seq_len in {1024, 2048, 4096, 8192, 16384, 32768}:
    flops = 4 * qo_head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1)
    q = torch.randn(batch*seq_len, qo_head, headdim, dtype=torch.float16, device="cuda")
    k = torch.randn(batch*seq_len, kv_head, headdim, dtype=torch.float16, device="cuda")
    v = torch.randn(batch*seq_len, kv_head, headdim, dtype=torch.float16, device="cuda")

    q_indptr = torch.arange(0, batch + 1).to(0).int() * seq_len
    kv_indptr = torch.arange(0, batch + 1).to(0).int() * seq_len
    wrapper.plan(
            q_indptr,
            kv_indptr,
            qo_head,
            kv_head,
            headdim,
            causal=is_causal,
        )
    o = wrapper.run(q, k, v)
    for i in range(5): wrapper.run(q, k, v)
    torch.cuda.synchronize()
    _, time = benchmark_forward(wrapper.run, q, k, v, repeats=100, verbose=False, desc='Flashinfer')
    print(f'{seq_len} flops:{flops/time.mean*1e-12}')

kv_layout = "NHD"
workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8).to(0)
wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, kv_layout
)

is_causal = True
print(f"is_causal: {is_causal}")
for seq_len in {1024, 2048, 4096, 8192, 16384, 32768}:
    flops = 4 * qo_head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1)
    q = torch.randn(batch*seq_len, qo_head, headdim, dtype=torch.float16, device="cuda")
    k = torch.randn(batch, seq_len, kv_head, headdim, dtype=torch.float16, device="cuda")
    v = torch.randn(batch, seq_len, kv_head, headdim, dtype=torch.float16, device="cuda")

    dtype = torch.float8_e5m2
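    # per-tensor scales for quantizing K/V to fp8 (e5m2); Q stays fp16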
    k_scale = k.amax().item() / 256
    v_scale = v.amax().item() / 256

    k_fp8 = (k / k_scale).to(dtype).transpose(0, 1)
    v_fp8 = (v / v_scale).to(dtype).transpose(0, 1)

    page_size = 1
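    # page-table metadata for the paged KV cache (page_size=1: one token per page)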
    q_indptr = torch.arange(0, batch + 1).to(0).int() * seq_len
    kv_indptr = torch.arange(0, batch + 1).to(0).int() * seq_len
    kv_indices = torch.arange(0, seq_len).to(0).int()
    kv_last_page_len = torch.full(
        (batch,), (seq_len - 1) % page_size + 1, dtype=torch.int32
    ).to(0)

    wrapper.plan(
            q_indptr,
            kv_indptr,
            kv_indices,
            kv_last_page_len,
            qo_head,
            kv_head,
            headdim,
            page_size,
            causal=is_causal,
            q_data_type=torch.float16,
            kv_data_type=dtype,
            use_fp16_qk_reduction=False,
        )
    o = wrapper.run(q, (k_fp8, v_fp8), k_scale, v_scale)
    for i in range(5): wrapper.run(q, (k_fp8, v_fp8), k_scale, v_scale)
    torch.cuda.synchronize()
    _, time = benchmark_forward(wrapper.run, q, (k_fp8, v_fp8), k_scale, v_scale, repeats=100, verbose=False, desc='Flashinfer FP8')
    print(f'fp8 {seq_len} flops:{flops/time.mean*1e-12}')
```
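For reference, the printed number is derived from the `flops` expression in the script above. Here is the same accounting worked through for one configuration (the argparse defaults: batch=4, 32 heads, head_dim=128); this is only a restatement of the formula, not an independent measurement:

```python
# FLOP accounting used by the benchmark: two matmuls (Q @ K^T and P @ V),
# each 2 * seq_len * seq_len * headdim FLOPs per head, so
# 4 * heads * batch * headdim * seq_len^2 in total; a causal mask roughly halves it.
batch, qo_head, headdim, seq_len = 4, 32, 128, 1024
is_causal = False
flops = 4 * qo_head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1)
print(flops / 1e9)  # ~68.7 GFLOP for this shape
# With the measured mean time t (seconds per call), the script reports
# flops / t * 1e-12, i.e. achieved TFLOP/s.
```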

The results are as follows (the printed "flops" value is throughput in TFLOP/s):

```
is_causal: False
1024 flops:91.40299148402637
2048 flops:107.07696479819627
4096 flops:108.43416398791643
8192 flops:109.00693589495772
16384 flops:109.32962903647683
32768 flops:109.45145192913175

is_causal: True
1024 flops:68.55602643189259
2048 flops:81.8157771148776
4096 flops:95.28764360811904
8192 flops:102.80001359824226
16384 flops:106.81962644318786
32768 flops:108.20337146264986

is_causal: True
fp8 1024 flops:60.862217383540795
fp8 2048 flops:77.45613669425248
fp8 4096 flops:90.72903913564886
fp8 8192 flops:98.91800674555174
fp8 16384 flops:101.95377017553793
fp8 32768 flops:103.3322699046888
```
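As a sanity check on the timing itself, a CUDA-event measurement of the same `wrapper.run` call can be compared against `benchmark_forward`. This is only a sketch and assumes the fp8 wrapper, `q`, `k_fp8`, `v_fp8`, `k_scale`, `v_scale`, and `flops` from the script above are still in scope:

```python
# Cross-check the measurement with torch.cuda.Event, which times GPU work
# directly and is independent of benchmark_forward.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
for _ in range(5):  # warm-up
    wrapper.run(q, (k_fp8, v_fp8), k_scale, v_scale)
torch.cuda.synchronize()
start.record()
for _ in range(100):
    wrapper.run(q, (k_fp8, v_fp8), k_scale, v_scale)
end.record()
torch.cuda.synchronize()
ms_per_call = start.elapsed_time(end) / 100  # elapsed_time returns milliseconds
print(f"{ms_per_call:.3f} ms/call, {flops / (ms_per_call * 1e-3) * 1e-12:.1f} TFLOP/s")
```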