
[Optimize] Token parallel in paged attention kernel #65

Open
zip95297 wants to merge 1 commit into Wenyueh:main from zip95297:optimize/PageAttn-TokenParallel

Conversation

@zip95297
Contributor

Performance Optimization: Token-Block Parallelism for PagedAttention

This PR optimizes the PagedAttention decode kernel in src/layers/attention by refactoring the BLOCK_M processing logic. The original token-by-token sequential processing is replaced with a token-block parallelization approach, which significantly improves hardware utilization and instruction throughput during the decoding phase.

The performance gains were validated by integrating the kernel into nanovllm (after adjusting a redundant scale factor in Attention.forward()), since the current repository's enforce_eager=False path is unstable. Benchmarks show a ~2.5x improvement in average decode latency, effectively addressing the performance bottlenecks in the existing implementation.
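To illustrate the idea, here is an editor's NumPy sketch of the two access patterns, not the PR's actual Triton kernel; all shapes, names, and the block table contents are hypothetical toy values:

```python
import numpy as np

# Hypothetical toy shapes for illustration only.
block_size = 4    # tokens per physical KV-cache block
head_dim = 8
BLOCK_N = 4       # tokens handled per tile in the parallel version
seq_len = 10      # valid tokens in this sequence

rng = np.random.default_rng(0)
k_cache = rng.standard_normal((6, block_size, head_dim)).astype(np.float32)
# Logical-to-physical block mapping for one sequence; -1 = unallocated.
block_table = np.array([3, 0, 5, -1, -1, -1])

def gather_sequential(token_start):
    """Token-by-token K gather (the original kernel's access pattern)."""
    rows = []
    for tok in range(token_start, min(token_start + BLOCK_N, seq_len)):
        phys = block_table[tok // block_size]
        rows.append(k_cache[phys, tok % block_size])
    return np.stack(rows)

def gather_token_block(token_start):
    """Vectorized K gather over a whole BLOCK_N tile (the PR's approach)."""
    offs_n = token_start + np.arange(BLOCK_N)
    mask_n = offs_n < seq_len
    safe = np.where(mask_n, offs_n, 0)    # clamp out-of-range lanes
    phys = block_table[safe // block_size]
    k = k_cache[phys, safe % block_size]  # one gather: (BLOCK_N, head_dim)
    return k[mask_n]

# Both strategies read the same K rows; only the access pattern differs.
assert np.allclose(gather_sequential(8), gather_token_block(8))
```

Both functions produce identical results; the difference is that the token-block version issues one vectorized gather per tile instead of one load per token, which is what lets the GPU kernel keep its memory pipeline full.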

Closes #64

block_idx = (token_start) // block_size
physical_block_idx = tl.load(block_tables_ptr + batch_idx * max_num_blocks + block_idx)
if physical_block_idx != -1:
    # Read the KV for BLOCK_M tokens from the physical block address
Contributor

Is this reading the KV for BLOCK_N tokens here?

# Apply mask to invalid positions
qk = tl.where(mask_n, qk, -1e10)
# Load K for each position and compute scores
block_idx = (token_start) // block_size
Contributor

@77z-zhou 77z-zhou Mar 15, 2026

Hi, I think this part should load BLOCK_N physical_blocks; code as follows:

block_nums = offs_n // block_size
block_offsets = offs_n % block_size
physical_blocks = tl.load(block_table_ptr + block_nums, mask=mask_n, other=-1)  # (BLOCK_N)
k_offsets = (
    physical_blocks[:, None] * block_size * num_kv_heads * head_dim + 
    block_offsets[:, None] * num_kv_heads * head_dim +
    kv_head_idx * head_dim + 
    offs_d[None, :]
)  # (BLOCK_N,  head_dim)
k = tl.load(k_cache_ptr + k_offsets, mask=(physical_blocks[:, None] != -1) & mask_n[:, None], other=0.0)

Contributor Author

Thanks for the reminder. With the common setting of block_size = 256, block_size is generally larger than BLOCK_N, so I didn't add this logic.
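The condition being discussed can be checked with a quick editor's sketch (BLOCK_N = 64 is an assumed illustrative value): when tile starts are BLOCK_N-aligned and block_size is a multiple of BLOCK_N, a tile never straddles two physical blocks, but with other block sizes it can, which is the case the per-token lookup above handles.

```python
def tile_crosses_block_boundary(token_start, block_n, block_size):
    """True if the tile [token_start, token_start + block_n) spans two physical blocks."""
    first = token_start // block_size
    last = (token_start + block_n - 1) // block_size
    return first != last

# With block_size=256 > BLOCK_N=64 and BLOCK_N-aligned tile starts
# (block_size % BLOCK_N == 0), every tile stays inside one physical block:
assert not any(tile_crosses_block_boundary(s, 64, 256) for s in range(0, 1024, 64))

# But with a block_size that is not a multiple of BLOCK_N, a tile can straddle
# two physical blocks, which requires a per-token block-table lookup:
assert tile_crosses_block_boundary(192, 64, 200)
```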

@Wenyueh
Owner

Wenyueh commented Mar 16, 2026

Thanks for this cool implementation! Could you provide relevant benchmark + result for comparison? Using the current benchmark for decoding, I'm not seeing the big change yet.

@zip95297
Contributor Author

The following script is the benchmark I used:

import os
import time
import numpy as np
import argparse
from random import randint, seed
from tqdm.auto import tqdm
from nanovllm import LLM, SamplingParams

# --- Constants ---
MODEL_PATH = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
MAX_INPUT_LEN = 1024
MAX_OUTPUT_LEN = 1024

# --- Seed for reproducibility ---
seed(0)
np.random.seed(0)

class RequestMetrics:
    """Stores metrics for a single request."""
    def __init__(self, request_id, input_len, max_output_len):
        self.request_id = request_id
        self.input_len = input_len
        self.max_output_len = max_output_len
        self.submission_time = -1
        self.first_token_time = -1
        self.completion_time = -1
        self.output_len = -1

    def record_submission(self):
        self.submission_time = time.perf_counter()

    def record_first_token(self):
        if self.first_token_time == -1:
            self.first_token_time = time.perf_counter()

    def record_completion(self, output_ids):
        self.completion_time = time.perf_counter()
        self.output_len = len(output_ids)

    @property
    def ttft(self):
        return self.first_token_time - self.submission_time

    @property
    def tpot(self):
        if self.output_len > 1:
            return (self.completion_time - self.first_token_time) / (self.output_len - 1)
        return float('nan')

    @property
    def latency(self):
        return self.completion_time - self.submission_time

def main():
    """Main function to run the serving benchmark."""
    parser = argparse.ArgumentParser(description="Serving benchmark for nano-vllm.")
    parser.add_argument("--num-requests", type=int, default=64, help="Number of requests to process.")
    parser.add_argument("--request-rate", type=int, default=8, help="Request rate (requests per second).")
    args = parser.parse_args()

    NUM_REQUESTS = args.num_requests
    REQUEST_RATE = args.request_rate

    print(f"\n--- Running benchmark with --num-requests {NUM_REQUESTS} --request-rate {REQUEST_RATE} ---")
    llm = LLM(MODEL_PATH, enforce_eager=False, max_model_len=4096, tensor_parallel_size=4)
    engine = llm

    # --- Generate random prompts ---
    prompts = [[randint(0, 10000) for _ in range(randint(100, MAX_INPUT_LEN))] for _ in range(NUM_REQUESTS)]
    sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, MAX_OUTPUT_LEN)) for _ in range(NUM_REQUESTS)]

    # --- Generate request arrival times ---
    request_intervals = np.random.poisson(1.0 / REQUEST_RATE, NUM_REQUESTS)
    arrival_times = np.cumsum(request_intervals)

    # --- Benchmark loop ---
    metrics = {}
    requests_sent = 0
    start_time = time.perf_counter()
    completed_latencies = []

    with tqdm(total=NUM_REQUESTS, desc="Processing Requests") as pbar:
        while requests_sent < NUM_REQUESTS or not engine.is_finished():
            # --- Send new requests ---
            current_time = time.perf_counter()
            while requests_sent < NUM_REQUESTS and current_time - start_time >= arrival_times[requests_sent]:
                prompt = prompts[requests_sent]
                sp = sampling_params[requests_sent]
                
                engine.add_request(prompt, sp)
                
                new_seq = engine.scheduler.waiting[-1]
                seq_id = new_seq.seq_id
                req_metrics = RequestMetrics(seq_id, len(prompt), sp.max_tokens)
                req_metrics.record_submission()
                metrics[seq_id] = req_metrics
                
                requests_sent += 1

            # --- Engine step ---
            if engine.scheduler.waiting or engine.scheduler.running:
                finished_outputs, _ = engine.step()

                # Record first token time for all processed sequences
                all_processed_seqs = list(engine.scheduler.running)
                for seq in all_processed_seqs:
                    if seq.seq_id in metrics:
                        metrics[seq.seq_id].record_first_token()

                for seq_id, output_ids in finished_outputs:
                    if seq_id in metrics:
                        metrics[seq_id].record_first_token() # Ensure first token time is recorded
                        metrics[seq_id].record_completion(output_ids)
                        
                        completed_latencies.append(metrics[seq_id].latency)
                        avg_latency = np.mean(completed_latencies)
                        pbar.set_postfix({"Avg Latency": f"{avg_latency:.2f}s"})
                        pbar.update(1)
            else:
                # If no requests are running or waiting, sleep briefly
                time.sleep(0.01)

    end_time = time.perf_counter()
    total_time = end_time - start_time

    # --- Calculate and print metrics ---
    total_input_tokens = sum(m.input_len for m in metrics.values())
    total_output_tokens = sum(m.output_len for m in metrics.values() if m.output_len != -1)
    
    avg_ttft = np.mean([m.ttft for m in metrics.values() if m.first_token_time != -1])
    avg_tpot = np.mean([m.tpot for m in metrics.values() if not np.isnan(m.tpot)])
    avg_latency = np.mean([m.latency for m in metrics.values() if m.completion_time != -1])
    throughput = total_output_tokens / total_time

    print("--- Benchmark Results ---")
    print(f"Total time: {total_time:.2f}s")
    print(f"Requests sent: {requests_sent}")
    print(f"Throughput: {throughput:.2f} tokens/s")
    print(f"Average TTFT: {avg_ttft * 1000:.2f} ms")
    print(f"Average TPOT: {avg_tpot * 1000:.2f} ms/token")
    print(f"Average latency: {avg_latency:.2f} s")
    print("-------------------------\n")

if __name__ == "__main__":
    main()

Optimized result:

--- Benchmark Results ---
Total time: 21.55s
Requests sent: 64
Throughput: 1783.59 tokens/s
Average TTFT: 647.12 ms
Average TPOT: 19.05 ms/token
Average latency: 11.69 s
-------------------------

Original result:

--- Benchmark Results ---
Total time: 61.78s
Requests sent: 64
Throughput: 622.21 tokens/s
Average TTFT: 617.65 ms
Average TPOT: 53.08 ms/token
Average latency: 32.75 s
-------------------------
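As a quick check, these are the speedups implied by the two serving runs above; the ratios come out slightly above the ~2.5x quoted in the PR description:

```python
# Ratios computed from the serving-benchmark numbers reported above.
orig = {"tput": 622.21, "tpot_ms": 53.08, "latency_s": 32.75}
opt = {"tput": 1783.59, "tpot_ms": 19.05, "latency_s": 11.69}

tpot_speedup = orig["tpot_ms"] / opt["tpot_ms"]         # per-token decode time
latency_speedup = orig["latency_s"] / opt["latency_s"]  # end-to-end request latency
tput_speedup = opt["tput"] / orig["tput"]               # tokens/s

print(f"TPOT: {tpot_speedup:.2f}x, latency: {latency_speedup:.2f}x, "
      f"throughput: {tput_speedup:.2f}x")
# → TPOT: 2.79x, latency: 2.80x, throughput: 2.87x
```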

Then the results from benchmark_decoding.py:

Optimized:

======================================================================
batch_size=16, seq_len=256, num_heads=32
num_kv_heads=8, head_dim=128, block_size=16
======================================================================
1. Testing Naive PyTorch implementation...
   Time: 8.840ms
2. Testing Optimized PyTorch implementation...
   Time: 5.338ms
3. Testing Triton implementation...
   Time: **0.155ms** OPTIMIZED
======================================================================
batch_size=4, seq_len=2048, num_heads=32
num_kv_heads=8, head_dim=128, block_size=16
======================================================================
1. Testing Naive PyTorch implementation...
   Time: 14.066ms
2. Testing Optimized PyTorch implementation...
   Time: 4.812ms
3. Testing Triton implementation...
   Time: **0.283ms** OPTIMIZED

Origin:

======================================================================
batch_size=16, seq_len=256, num_heads=32
num_kv_heads=8, head_dim=128, block_size=16
======================================================================
1. Testing Naive PyTorch implementation...
   Time: 8.281ms
2. Testing Optimized PyTorch implementation...
   Time: 5.373ms
3. Testing Triton implementation...
   Time: **0.719ms** ORIGIN
======================================================================
batch_size=4, seq_len=2048, num_heads=32
num_kv_heads=8, head_dim=128, block_size=16
======================================================================
1. Testing Naive PyTorch implementation...
   Time: 13.413ms
2. Testing Optimized PyTorch implementation...
   Time: 4.826ms
3. Testing Triton implementation...
   Time: **2.285ms** ORIGIN
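The kernel-level speedups implied by the Triton timings above, computed the same way:

```python
# Ratios computed from the benchmark_decoding.py Triton timings reported above.
speedup_short = 0.719 / 0.155  # batch_size=16, seq_len=256
speedup_long = 2.285 / 0.283   # batch_size=4, seq_len=2048
print(f"seq_len=256: {speedup_short:.2f}x, seq_len=2048: {speedup_long:.2f}x")
# → seq_len=256: 4.64x, seq_len=2048: 8.07x
```

The gap widens at longer sequence lengths, which is consistent with the token-block parallelism paying off most when there are many KV tokens to scan per decode step.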

Thanks for the review~

@Wenyueh Wenyueh self-requested a review March 16, 2026 13:36

Development

Successfully merging this pull request may close these issues.

[Optimize] PageAttention Decode Kernel in BLOCKM Processing with TokenBlock Parallelization
