I tried to use the fp8 prefill attention kernel and found that on the L20 it gives no speedup over fp16; it is actually slower. Is this expected?
The benchmark script:

```python
import argparse

import torch
import flashinfer
from flash_attn.utils.benchmark import benchmark_forward
from flash_attn_interface import flash_attn_func as flash_attn_func_v3

parser = argparse.ArgumentParser(description='Benchmark FlashInfer')
parser.add_argument('--batch_size', type=int, default=4, help='Batch size')
parser.add_argument('--num_kv_heads', type=int, default=32, help='Number of heads')
parser.add_argument('--num_qo_heads', type=int, default=32, help='Number of heads')
parser.add_argument('--head_dim', type=int, default=128, help='Head dimension')
args = parser.parse_args()

qo_head = args.num_qo_heads
kv_head = args.num_kv_heads
batch = args.batch_size
headdim = args.head_dim

print("FlashInfer Benchmark")
print(f"batch: {batch}, qo_head: {qo_head}, kv_head: {kv_head}, headdim: {headdim}")

# FP16 ragged prefill
kv_layout = "NHD"
workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8).to(0)
wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
    workspace_buffer, kv_layout
)

is_causal = False
print(f"is_causal: {is_causal}")
for seq_len in [1024, 2048, 4096, 8192, 16384, 32768]:
    flops = 4 * qo_head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1)
    q = torch.randn(batch * seq_len, qo_head, headdim, dtype=torch.float16, device="cuda")
    k = torch.randn(batch * seq_len, kv_head, headdim, dtype=torch.float16, device="cuda")
    v = torch.randn(batch * seq_len, kv_head, headdim, dtype=torch.float16, device="cuda")
    q_indptr = torch.arange(0, batch + 1).to(0).int() * seq_len
    kv_indptr = torch.arange(0, batch + 1).to(0).int() * seq_len
    wrapper.plan(
        q_indptr,
        kv_indptr,
        qo_head,
        kv_head,
        headdim,
        causal=is_causal,
    )
    o = wrapper.run(q, k, v)
    for i in range(5):  # warmup
        wrapper.run(q, k, v)
    torch.cuda.synchronize()
    _, time = benchmark_forward(wrapper.run, q, k, v, repeats=100, verbose=False, desc='Flashinfer')
    print(f'{seq_len} flops:{flops / time.mean * 1e-12}')

is_causal = True
print(f"is_causal: {is_causal}")
for seq_len in [1024, 2048, 4096, 8192, 16384, 32768]:
    flops = 4 * qo_head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1)
    q = torch.randn(batch * seq_len, qo_head, headdim, dtype=torch.float16, device="cuda")
    k = torch.randn(batch * seq_len, kv_head, headdim, dtype=torch.float16, device="cuda")
    v = torch.randn(batch * seq_len, kv_head, headdim, dtype=torch.float16, device="cuda")
    q_indptr = torch.arange(0, batch + 1).to(0).int() * seq_len
    kv_indptr = torch.arange(0, batch + 1).to(0).int() * seq_len
    wrapper.plan(
        q_indptr,
        kv_indptr,
        qo_head,
        kv_head,
        headdim,
        causal=is_causal,
    )
    o = wrapper.run(q, k, v)
    for i in range(5):  # warmup
        wrapper.run(q, k, v)
    torch.cuda.synchronize()
    _, time = benchmark_forward(wrapper.run, q, k, v, repeats=100, verbose=False, desc='Flashinfer')
    print(f'{seq_len} flops:{flops / time.mean * 1e-12}')

# FP8 (e5m2) paged prefill
kv_layout = "NHD"
workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8).to(0)
wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, kv_layout
)
is_causal = True
print(f"is_causal: {is_causal}")
for seq_len in [1024, 2048, 4096, 8192, 16384, 32768]:
    flops = 4 * qo_head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1)
    q = torch.randn(batch * seq_len, qo_head, headdim, dtype=torch.float16, device="cuda")
    k = torch.randn(batch, seq_len, kv_head, headdim, dtype=torch.float16, device="cuda")
    v = torch.randn(batch, seq_len, kv_head, headdim, dtype=torch.float16, device="cuda")
    dtype = torch.float8_e5m2
    k_scale = k.amax().item() / 256
    v_scale = v.amax().item() / 256
    k_fp8 = (k / k_scale).to(dtype).transpose(0, 1)
    v_fp8 = (v / v_scale).to(dtype).transpose(0, 1)
    page_size = 1
    q_indptr = torch.arange(0, batch + 1).to(0).int() * seq_len
    kv_indptr = torch.arange(0, batch + 1).to(0).int() * seq_len
    kv_indices = torch.arange(0, seq_len).to(0).int()
    kv_last_page_len = torch.full(
        (batch,), (seq_len - 1) % page_size + 1, dtype=torch.int32
    ).to(0)
    wrapper.plan(
        q_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        qo_head,
        kv_head,
        headdim,
        page_size,
        causal=is_causal,
        q_data_type=torch.float16,
        kv_data_type=dtype,
        use_fp16_qk_reduction=False,
    )
    o = wrapper.run(q, (k_fp8, v_fp8), k_scale, v_scale)
    for i in range(5):  # warmup
        wrapper.run(q, (k_fp8, v_fp8), k_scale, v_scale)
    torch.cuda.synchronize()
    _, time = benchmark_forward(wrapper.run, q, (k_fp8, v_fp8), k_scale, v_scale, repeats=100, verbose=False, desc='Flashinfer FP8')
    print(f'fp8 {seq_len} flops:{flops / time.mean * 1e-12}')
```
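Side note on measurement: in case overhead from `benchmark_forward` is a concern, here is a minimal CUDA-event-based timing cross-check. This is only a sketch; it assumes the `wrapper`, `q`, `k_fp8`, `v_fp8`, `k_scale`, and `v_scale` objects from the script above are still in scope.

```python
import torch

def time_run(fn, *args, iters=100, warmup=5):
    # Time a callable with CUDA events; returns average milliseconds per call.
    for _ in range(warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Example usage (objects from the benchmark script above assumed in scope):
# ms = time_run(wrapper.run, q, (k_fp8, v_fp8), k_scale, v_scale)
# print(f"avg {ms:.3f} ms/iter")
```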
The results are as follows:
| seq_len | fp16, non-causal (TFLOPS) | fp16, causal (TFLOPS) | fp8 e5m2, causal (TFLOPS) |
|---------|---------------------------|------------------------|----------------------------|
| 1024    | 91.40                     | 68.56                  | 60.86                      |
| 2048    | 107.08                    | 81.82                  | 77.46                      |
| 4096    | 108.43                    | 95.29                  | 90.73                      |
| 8192    | 109.01                    | 102.80                 | 98.92                      |
| 16384   | 109.33                    | 106.82                 | 101.95                     |
| 32768   | 109.45                    | 108.20                 | 103.33                     |
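To put the gap in numbers (values rounded from the table above), the fp8 path comes out roughly 4-11% slower than the fp16 causal path rather than faster:

```python
# fp16 (causal) vs fp8 (causal) TFLOPS, rounded from the results above
seq_lens    = [1024, 2048, 4096, 8192, 16384, 32768]
fp16_causal = [68.56, 81.82, 95.29, 102.80, 106.82, 108.20]
fp8_causal  = [60.86, 77.46, 90.73, 98.92, 101.95, 103.33]
for n, f8, f16 in zip(seq_lens, fp8_causal, fp16_causal):
    print(f"{n}: fp8/fp16 = {f8 / f16:.3f}")  # ranges from ~0.89 at 1024 to ~0.96 at 32768
```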