[Optimize] Token parallel in paged attention kernel #65
Open
zip95297 wants to merge 1 commit into Wenyueh:main from
Conversation
77z-zhou reviewed on Mar 15, 2026
```python
block_idx = (token_start) // block_size
physical_block_idx = tl.load(block_tables_ptr + batch_idx * max_num_blocks + block_idx)
if physical_block_idx != -1:
    # read the KV of BLOCK_M tokens from the physical block's address
```
77z-zhou reviewed on Mar 15, 2026
```python
# Apply mask to invalid positions
qk = tl.where(mask_n, qk, -1e10)
# Load K for each position and compute scores
block_idx = (token_start) // block_size
```
Contributor

Hi, I think this part needs to load BLOCK_N physical blocks. Code as follows:
```python
block_nums = offs_n // block_size
block_offsets = offs_n % block_size
physical_blocks = tl.load(block_table_ptr + block_nums, mask=mask_n, other=-1)  # (BLOCK_N,)
k_offsets = (
    physical_blocks[:, None] * block_size * num_kv_heads * head_dim +
    block_offsets[:, None] * num_kv_heads * head_dim +
    kv_head_idx * head_dim +
    offs_d[None, :]
)  # (BLOCK_N, head_dim)
k = tl.load(k_cache_ptr + k_offsets, mask=(physical_blocks[:, None] != -1) & mask_n[:, None], other=0.0)
```
Contributor · Author

Thanks for the reminder. With the typical setting of block_size = 256, block_size is generally larger than BLOCK_N, so I didn't add that logic.
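The author's assumption can be checked with quick arithmetic: if `block_size` is a multiple of BLOCK_N and `token_start` is BLOCK_N-aligned, every position in a BLOCK_N tile falls into the same logical block, so a single `physical_block_idx` suffices. A minimal sketch (constants chosen for illustration, matching the 256-slot block size mentioned above):

```python
block_size, BLOCK_N = 256, 64
assert block_size % BLOCK_N == 0

# For every BLOCK_N-aligned tile, all positions share one logical block index.
for token_start in range(0, 1024, BLOCK_N):
    tile = range(token_start, token_start + BLOCK_N)
    block_indices = {pos // block_size for pos in tile}
    assert len(block_indices) == 1, "tile spans two physical blocks"
```

If either condition fails (e.g. block_size = 16 as in the benchmark configs below, with BLOCK_N = 64), a tile does span multiple blocks and the reviewer's per-position gather is needed.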
Owner

Thanks for this cool implementation! Could you provide a relevant benchmark and results for comparison? Using the current decoding benchmark, I'm not seeing a big change yet.
Contributor · Author

The following script is the benchmark I use:

```python
import os
import time
import numpy as np
import argparse
from random import randint, seed
from tqdm.auto import tqdm

from nanovllm import LLM, SamplingParams

# --- Constants ---
MODEL_PATH = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
MAX_INPUT_LEN = 1024
MAX_OUTPUT_LEN = 1024

# --- Seed for reproducibility ---
seed(0)
np.random.seed(0)


class RequestMetrics:
    """Stores metrics for a single request."""

    def __init__(self, request_id, input_len, max_output_len):
        self.request_id = request_id
        self.input_len = input_len
        self.max_output_len = max_output_len
        self.submission_time = -1
        self.first_token_time = -1
        self.completion_time = -1
        self.output_len = -1

    def record_submission(self):
        self.submission_time = time.perf_counter()

    def record_first_token(self):
        if self.first_token_time == -1:
            self.first_token_time = time.perf_counter()

    def record_completion(self, output_ids):
        self.completion_time = time.perf_counter()
        self.output_len = len(output_ids)

    @property
    def ttft(self):
        return self.first_token_time - self.submission_time

    @property
    def tpot(self):
        if self.output_len > 1:
            return (self.completion_time - self.first_token_time) / (self.output_len - 1)
        return float('nan')

    @property
    def latency(self):
        return self.completion_time - self.submission_time


def main():
    """Main function to run the serving benchmark."""
    parser = argparse.ArgumentParser(description="Serving benchmark for nano-vllm.")
    parser.add_argument("--num-requests", type=int, default=64, help="Number of requests to process.")
    parser.add_argument("--request-rate", type=int, default=8, help="Request rate (requests per second).")
    args = parser.parse_args()

    NUM_REQUESTS = args.num_requests
    REQUEST_RATE = args.request_rate
    print(f"\n--- Running benchmark with --num-requests {NUM_REQUESTS} --request-rate {REQUEST_RATE} ---")

    llm = LLM(MODEL_PATH, enforce_eager=False, max_model_len=4096, tensor_parallel_size=4)
    engine = llm

    # --- Generate random prompts ---
    prompts = [[randint(0, 10000) for _ in range(randint(100, MAX_INPUT_LEN))] for _ in range(NUM_REQUESTS)]
    sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, MAX_OUTPUT_LEN)) for _ in range(NUM_REQUESTS)]

    # --- Generate request arrival times ---
    request_intervals = np.random.poisson(1.0 / REQUEST_RATE, NUM_REQUESTS)
    arrival_times = np.cumsum(request_intervals)

    # --- Benchmark loop ---
    metrics = {}
    requests_sent = 0
    start_time = time.perf_counter()
    completed_latencies = []

    with tqdm(total=NUM_REQUESTS, desc="Processing Requests") as pbar:
        while requests_sent < NUM_REQUESTS or not engine.is_finished():
            # --- Send new requests ---
            current_time = time.perf_counter()
            while requests_sent < NUM_REQUESTS and current_time - start_time >= arrival_times[requests_sent]:
                prompt = prompts[requests_sent]
                sp = sampling_params[requests_sent]
                engine.add_request(prompt, sp)
                new_seq = engine.scheduler.waiting[-1]
                seq_id = new_seq.seq_id
                req_metrics = RequestMetrics(seq_id, len(prompt), sp.max_tokens)
                req_metrics.record_submission()
                metrics[seq_id] = req_metrics
                requests_sent += 1

            # --- Engine step ---
            if engine.scheduler.waiting or engine.scheduler.running:
                finished_outputs, _ = engine.step()
                # Record first token time for all processed sequences
                all_processed_seqs = list(engine.scheduler.running)
                for seq in all_processed_seqs:
                    if seq.seq_id in metrics:
                        metrics[seq.seq_id].record_first_token()
                for seq_id, output_ids in finished_outputs:
                    if seq_id in metrics:
                        metrics[seq_id].record_first_token()  # Ensure first token time is recorded
                        metrics[seq_id].record_completion(output_ids)
                        completed_latencies.append(metrics[seq_id].latency)
                        avg_latency = np.mean(completed_latencies)
                        pbar.set_postfix({"Avg Latency": f"{avg_latency:.2f}s"})
                        pbar.update(1)
            else:
                # If no requests are running or waiting, sleep briefly
                time.sleep(0.01)

    end_time = time.perf_counter()
    total_time = end_time - start_time

    # --- Calculate and print metrics ---
    total_input_tokens = sum(m.input_len for m in metrics.values())
    total_output_tokens = sum(m.output_len for m in metrics.values() if m.output_len != -1)
    avg_ttft = np.mean([m.ttft for m in metrics.values() if m.first_token_time != -1])
    avg_tpot = np.mean([m.tpot for m in metrics.values() if not np.isnan(m.tpot)])
    avg_latency = np.mean([m.latency for m in metrics.values() if m.completion_time != -1])
    throughput = total_output_tokens / total_time

    print("--- Benchmark Results ---")
    print(f"Total time: {total_time:.2f}s")
    print(f"Requests sent: {requests_sent}")
    print(f"Throughput: {throughput:.2f} tokens/s")
    print(f"Average TTFT: {avg_ttft * 1000:.2f} ms")
    print(f"Average TPOT: {avg_tpot * 1000:.2f} ms/token")
    print(f"Average latency: {avg_latency:.2f} s")
    print("-------------------------\n")


if __name__ == "__main__":
    main()
```

Optimized result:

--- Benchmark Results ---
Total time: 21.55s
Requests sent: 64
Throughput: 1783.59 tokens/s
Average TTFT: 647.12 ms
Average TPOT: 19.05 ms/token
Average latency: 11.69 s
-------------------------

Origin result:

--- Benchmark Results ---
Total time: 61.78s
Requests sent: 64
Throughput: 622.21 tokens/s
Average TTFT: 617.65 ms
Average TPOT: 53.08 ms/token
Average latency: 32.75 s
-------------------------

And the results from benchmark_decoding.py. Optimized:

======================================================================
batch_size=16, seq_len=256, num_heads=32
num_kv_heads=8, head_dim=128, block_size=16
======================================================================
1. Testing Naive PyTorch implementation...
Time: 8.840ms
2. Testing Optimized PyTorch implementation...
Time: 5.338ms
3. Testing Triton implementation...
Time: **0.155ms** OPTIMIZED
======================================================================
batch_size=4, seq_len=2048, num_heads=32
num_kv_heads=8, head_dim=128, block_size=16
======================================================================
1. Testing Naive PyTorch implementation...
Time: 14.066ms
2. Testing Optimized PyTorch implementation...
Time: 4.812ms
3. Testing Triton implementation...
Time: **0.283ms** OPTIMIZED

Origin:

======================================================================
batch_size=16, seq_len=256, num_heads=32
num_kv_heads=8, head_dim=128, block_size=16
======================================================================
1. Testing Naive PyTorch implementation...
Time: 8.281ms
2. Testing Optimized PyTorch implementation...
Time: 5.373ms
3. Testing Triton implementation...
Time: **0.719ms** ORIGIN
======================================================================
batch_size=4, seq_len=2048, num_heads=32
num_kv_heads=8, head_dim=128, block_size=16
======================================================================
1. Testing Naive PyTorch implementation...
Time: 13.413ms
2. Testing Optimized PyTorch implementation...
Time: 4.826ms
3. Testing Triton implementation...
Time: **2.285ms** ORIGIN

Thanks for the review!
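For quick comparison, the figures reported above work out to roughly a 2.9x serving-throughput gain and a 4-8x kernel-level speedup. A small sketch computing those ratios from the quoted numbers:

```python
# Ratios derived from the benchmark figures reported in this comment.
serving = {"origin_tok_s": 622.21, "optimized_tok_s": 1783.59}
kernel_ms = {  # Triton decode kernel times (ms)
    "seq_len_256": {"origin": 0.719, "optimized": 0.155},
    "seq_len_2048": {"origin": 2.285, "optimized": 0.283},
}

throughput_gain = serving["optimized_tok_s"] / serving["origin_tok_s"]
kernel_speedups = {name: t["origin"] / t["optimized"] for name, t in kernel_ms.items()}

print(f"throughput gain: {throughput_gain:.2f}x")        # ~2.87x
for name, s in kernel_speedups.items():
    print(f"{name}: {s:.2f}x kernel speedup")            # ~4.64x and ~8.07x
```

The kernel speedup grows with sequence length (4.6x at 256 tokens vs 8.1x at 2048), consistent with the token-block parallelization paying off most when there are many cached tokens to score.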
Performance Optimization: Token-Block Parallelism for PagedAttention

This PR optimizes the PagedAttention decode kernel in src/layers/attention by refactoring the BLOCK_M processing logic. I replaced the original linear, token-by-token sequential processing with a token-block parallelization approach, which significantly improves hardware utilization and instruction throughput during the decoding phase.

The performance gains were validated by integrating the kernel into nanovllm (adjusting a redundant scale factor in Attention.forward()), as the current repository's enforce_eager=False path is unstable. Benchmarks show a ~2.5x reduction in average decode latency, effectively addressing the performance bottlenecks in the existing implementation.

Closes #64
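To illustrate the idea behind the refactor (this is a hypothetical NumPy sketch, not the PR's actual Triton kernel): the sequential version scores one cached token at a time, while the token-block version processes a BLOCK_N-wide tile per iteration, a structure that maps onto wide vector and tensor units on the GPU. Both produce the same attention output.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, head_dim, BLOCK_N = 100, 8, 32
q = rng.standard_normal(head_dim)            # query for the single decode token
K = rng.standard_normal((seq_len, head_dim)) # cached keys
V = rng.standard_normal((seq_len, head_dim)) # cached values

def attend_token_by_token(q, K, V):
    # One cached token per step: poor parallelism on a GPU.
    scores = np.array([q @ K[i] for i in range(len(K))])
    w = np.exp(scores - scores.max()); w /= w.sum()
    return w @ V

def attend_token_blocks(q, K, V, block_n=BLOCK_N):
    # BLOCK_N cached tokens per step: each tile is one wide matvec.
    scores = np.full(len(K), -1e10)
    for start in range(0, len(K), block_n):
        tile = slice(start, min(start + block_n, len(K)))
        scores[tile] = K[tile] @ q  # score a whole tile at once
    w = np.exp(scores - scores.max()); w /= w.sum()
    return w @ V

assert np.allclose(attend_token_by_token(q, K, V), attend_token_blocks(q, K, V))
```

In the real kernel the tile-wide loads additionally go through the block table, as discussed in the review thread above.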
nanovllm(adjust a redundant scale factor inAttention.forward()), as the current repository'sforce_eager=Falsepath is unstable. Benchmarks show a ~2.5x reduction in average decode latency, effectively addressing the performance bottlenecks in the existing implementation.Closes #64