Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 27 additions & 53 deletions src/myvllm/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,34 +337,19 @@ def paged_attention_decode_kernel(
# Compute attention scores for this chunk
qk = tl.zeros([BLOCK_N], dtype=tl.float32) - 1e10

# Load K for each valid position and compute scores
for i in range(BLOCK_N):
token_idx = token_start + i
if token_idx < context_len:
block_num = token_idx // block_size
block_offset = token_idx % block_size

if block_num < max_num_blocks:
# Look up physical block
block_table_offset = batch_idx * max_num_blocks + block_num
physical_block_idx = tl.load(block_tables_ptr + block_table_offset)

if physical_block_idx != -1:
# Load K
k_offset = (physical_block_idx * block_size * num_kv_heads * head_dim +
block_offset * num_kv_heads * head_dim +
kv_head_idx * head_dim + offs_d)
k_vec = tl.load(k_cache_ptr + k_offset)

# Compute score for this token
score = tl.sum(q * k_vec) * scale

# Update qk array at position i using tl.where
mask_i = tl.arange(0, BLOCK_N) == i
qk = tl.where(mask_i, score, qk)

# Apply mask to invalid positions
qk = tl.where(mask_n, qk, -1e10)
# Load K for each position and compute scores
block_idx = (token_start) // block_size
Copy link
Contributor

@77z-zhou 77z-zhou Mar 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

您好, 我认为这一块应该需要load BLOCK_N physical_blocks, 代码如下:

block_nums = offs_n // block_size
block_offsets = offs_n % block_size
physical_blocks = tl.load(block_table_ptr + block_nums, mask=mask_n, other=-1)  # (BLOCK_N)
k_offsets = (
    physical_blocks[:, None] * block_size * num_kv_heads * head_dim + 
    block_offsets[:, None] * num_kv_heads * head_dim +
    kv_head_idx * head_dim + 
    offs_d[None, :]
)  # (BLOCK_N,  head_dim)
k = tl.load(k_cache_ptr + k_offsets, mask=(physical_blocks[:, None] != -1) & mask_n[:, None], other=0.0)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

谢谢提醒,常规设置的blocksize 256 一般是大于 BLOCKN所以没有加上这个逻辑。

physical_block_idx = tl.load(block_tables_ptr + batch_idx * max_num_blocks + block_idx)
if physical_block_idx!=-1 :
# 物理块地址中读出BLOCK_N token的kv (read BLOCK_N tokens' KV from the physical block)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是读出 BLOCK_N token的kv吗?

k_offset = (physical_block_idx * block_size * num_kv_heads * head_dim # 块开始位置
+ offs_n[None,:] * num_kv_heads * head_dim
+ kv_head_idx * head_dim
+ offs_d[:,None])
k = tl.load(k_cache_ptr + k_offset, mask=mask_n[None,:],other=0.0)
k = tl.cast(k,tl.float32)
score = tl.sum(q[:,None] * k, axis=0) * scale
qk = tl.where(mask_n, score, -1e10)

# Online softmax
m_ij = tl.max(qk)
Expand All @@ -376,31 +361,20 @@ def paged_attention_decode_kernel(
acc = acc * alpha
l_i = l_i * alpha

# Load V and accumulate
for i in range(BLOCK_N):
token_idx = token_start + i
if token_idx < context_len:
block_num = token_idx // block_size
block_offset = token_idx % block_size

if block_num < max_num_blocks:
# Look up physical block
block_table_offset = batch_idx * max_num_blocks + block_num
physical_block_idx = tl.load(block_tables_ptr + block_table_offset)

if physical_block_idx != -1:
# Load V
v_offset = (physical_block_idx * block_size * num_kv_heads * head_dim +
block_offset * num_kv_heads * head_dim +
kv_head_idx * head_dim + offs_d)
v_vec = tl.load(v_cache_ptr + v_offset)

# Extract weight for this token from p
mask_i = tl.arange(0, BLOCK_N) == i
weight = tl.sum(tl.where(mask_i, p, 0.0))

acc = acc + weight * v_vec
l_i = l_i + weight
# Load V for each position and compute result
block_idx = (token_start) // block_size
physical_block_idx = tl.load(block_tables_ptr + batch_idx * max_num_blocks + block_idx)
if physical_block_idx!=-1 :
# 物理块地址中读出BLOCK_N token的kv (read BLOCK_N tokens' KV from the physical block)
v_offset = (physical_block_idx * block_size * num_kv_heads * head_dim # 块开始位置
+ offs_n[None,:] * num_kv_heads * head_dim
+ kv_head_idx * head_dim
+ offs_d[:,None])
v = tl.load(v_cache_ptr + v_offset, mask=mask_n[None,:],other=0.0)
v = tl.cast(v,tl.float32)
weight = tl.where(mask_n, p, 0.0)
acc = acc + tl.sum(weight[None,:] * v,axis=1)
l_i = l_i + tl.sum(weight)

m_i = m_i_new

Expand Down