From 7fbee6f3aa1d7e65f17215421dd9021dba4c2db6 Mon Sep 17 00:00:00 2001 From: zip95297 Date: Mon, 16 Mar 2026 00:35:32 +0800 Subject: [PATCH] token parallel in paged attention kernel --- src/myvllm/layers/attention.py | 80 ++++++++++++---------------------- 1 file changed, 27 insertions(+), 53 deletions(-) diff --git a/src/myvllm/layers/attention.py b/src/myvllm/layers/attention.py index a50eb38..c4f1893 100644 --- a/src/myvllm/layers/attention.py +++ b/src/myvllm/layers/attention.py @@ -337,34 +337,19 @@ def paged_attention_decode_kernel( # Compute attention scores for this chunk qk = tl.zeros([BLOCK_N], dtype=tl.float32) - 1e10 - # Load K for each valid position and compute scores - for i in range(BLOCK_N): - token_idx = token_start + i - if token_idx < context_len: - block_num = token_idx // block_size - block_offset = token_idx % block_size - - if block_num < max_num_blocks: - # Look up physical block - block_table_offset = batch_idx * max_num_blocks + block_num - physical_block_idx = tl.load(block_tables_ptr + block_table_offset) - - if physical_block_idx != -1: - # Load K - k_offset = (physical_block_idx * block_size * num_kv_heads * head_dim + - block_offset * num_kv_heads * head_dim + - kv_head_idx * head_dim + offs_d) - k_vec = tl.load(k_cache_ptr + k_offset) - - # Compute score for this token - score = tl.sum(q * k_vec) * scale - - # Update qk array at position i using tl.where - mask_i = tl.arange(0, BLOCK_N) == i - qk = tl.where(mask_i, score, qk) - - # Apply mask to invalid positions - qk = tl.where(mask_n, qk, -1e10) + # Load K for each position and compute scores + block_idx = (token_start) // block_size + physical_block_idx = tl.load(block_tables_ptr + batch_idx * max_num_blocks + block_idx) + if physical_block_idx!=-1 : + # 物理块地址中读出BLOCK_M token的kv + k_offset = (physical_block_idx * block_size * num_kv_heads * head_dim # 块开始位置 + + offs_n[None,:] * num_kv_heads * head_dim + + kv_head_idx * head_dim + + offs_d[:,None]) + k = tl.load(k_cache_ptr + k_offset, mask=mask_n[None,:],other=0.0) + k = tl.cast(k,tl.float32) + score = tl.sum(q[:,None] * k, axis=0) * scale + qk = tl.where(mask_n, score, -1e10) # Online softmax m_ij = tl.max(qk) @@ -376,31 +361,20 @@ def paged_attention_decode_kernel( acc = acc * alpha l_i = l_i * alpha - # Load V and accumulate - for i in range(BLOCK_N): - token_idx = token_start + i - if token_idx < context_len: - block_num = token_idx // block_size - block_offset = token_idx % block_size - - if block_num < max_num_blocks: - # Look up physical block - block_table_offset = batch_idx * max_num_blocks + block_num - physical_block_idx = tl.load(block_tables_ptr + block_table_offset) - - if physical_block_idx != -1: - # Load V - v_offset = (physical_block_idx * block_size * num_kv_heads * head_dim + - block_offset * num_kv_heads * head_dim + - kv_head_idx * head_dim + offs_d) - v_vec = tl.load(v_cache_ptr + v_offset) - - # Extract weight for this token from p - mask_i = tl.arange(0, BLOCK_N) == i - weight = tl.sum(tl.where(mask_i, p, 0.0)) - - acc = acc + weight * v_vec - l_i = l_i + weight + # Load V for each position and compute result + block_idx = (token_start) // block_size + physical_block_idx = tl.load(block_tables_ptr + batch_idx * max_num_blocks + block_idx) + if physical_block_idx!=-1 : + # 物理块地址中读出BLOCK_M token的kv + v_offset = (physical_block_idx * block_size * num_kv_heads * head_dim # 块开始位置 + + offs_n[None,:] * num_kv_heads * head_dim + + kv_head_idx * head_dim + + offs_d[:,None]) + v = tl.load(v_cache_ptr + v_offset, mask=mask_n[None,:],other=0.0) + v = tl.cast(v,tl.float32) + weight = tl.where(mask_n, p, 0.0) + acc = acc + tl.sum(weight[None,:] * v,axis=1) + l_i = l_i + tl.sum(weight) m_i = m_i_new