6 changes: 4 additions & 2 deletions gpt_oss/triton/attention.py

@@ -62,6 +62,7 @@ def _attn_fwd(
         lo, hi = tl.maximum(start_q, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
     else:
         lo, hi = start_q, start_q + (start_m + 1) * BLOCK_M
+    hi = tl.minimum(hi, N_KV_CTX)

     for start_n in range(lo, hi, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
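
For context on the added clamp: hi is the exclusive upper bound of the key-tile loop just above, and without the clamp it can run past the number of keys actually present in the KV cache. A worked example of the bound arithmetic in plain Python (the values for start_q, start_m, BLOCK_M, N_KV_CTX, and BANDWIDTH below are illustrative only, not taken from any real launch configuration):

    # Illustrative values; BLOCK_M and N_KV_CTX are constexpr kernel parameters.
    start_q, start_m, BLOCK_M = 64, 0, 64   # query tile 0 of a warm-cache call
    N_KV_CTX = 72                           # only 72 key positions exist
    BANDWIDTH = None                        # no sliding window in this example

    if BANDWIDTH:
        lo, hi = max(start_q, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
    else:
        lo, hi = start_q, start_q + (start_m + 1) * BLOCK_M

    print(lo, hi)           # 64 128 -> the key loop would cover positions [64, 128)
    hi = min(hi, N_KV_CTX)  # plain-Python stand-in for the added tl.minimum clamp
    print(lo, hi)           # 64 72  -> the loop now stops at the 72 valid positions
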
@@ -181,6 +182,7 @@ def attention_ref(
     pos_keys = torch.arange(num_keys, device=query.device)
     pos_queries = torch.arange(num_queries, device=query.device) + start_q
     mask = pos_keys[None, :] > pos_queries[:, None]
+    mask = mask | (pos_keys[None, :] < start_q)
     mask = mask.float().masked_fill(mask, float("-inf"))
Comment on lines 182 to 185

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1: Mask now drops cached prefix tokens

The new term mask = mask | (pos_keys[None, :] < start_q) makes the reference path drop every key with index < start_q. During cached decoding we set start_q to the cache offset (see AttentionBlock.forward in gpt_oss/triton/model.py, lines 218-253) while passing the full cache as k/v, so this change makes n_ctx == 1 and other small-context calls ignore all previously cached tokens and attend only to the current block. That silently breaks causal attention whenever offset > 0, producing wrong outputs for generation with a warm cache.
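
To make the failure mode concrete, here is a minimal standalone sketch of the new reference masking under a warm cache. The two mask lines are copied from attention_ref above (device argument dropped); the shapes and the start_q value are illustrative only:

    import torch

    # Illustrative warm-cache call: 8 cached tokens plus 1 new query token.
    num_keys, num_queries, start_q = 9, 1, 8

    pos_keys = torch.arange(num_keys)
    pos_queries = torch.arange(num_queries) + start_q

    mask = pos_keys[None, :] > pos_queries[:, None]   # causal mask: all False here
    mask = mask | (pos_keys[None, :] < start_q)       # term added in this commit

    print(mask.tolist())
    # [[True, True, True, True, True, True, True, True, False]]
    # All 8 cached prefix keys are masked out; only the current token stays visible.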

@xjmxyt (Author) commented on Dec 3, 2025

Solved this by matching the kernel implementation to the reference instead of modifying the reference.
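
For illustration, the same warm-cache shapes with the causal-only mask, i.e. the pre-change reference that the kernel would be brought in line with. This is only a sketch of the intended direction, not the follow-up commit itself:

    import torch

    num_keys, num_queries, start_q = 9, 1, 8            # 8 cached tokens + 1 new query

    pos_keys = torch.arange(num_keys)
    pos_queries = torch.arange(num_queries) + start_q
    mask = pos_keys[None, :] > pos_queries[:, None]     # causal constraint only

    print(mask.tolist())
    # [[False, False, False, False, False, False, False, False, False]]
    # All 9 key positions (8 cached + 1 current) remain attendable by the new query.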


     if sliding_window:
@@ -211,7 +213,7 @@ def attention_ref(
 @pytest.mark.parametrize("head_dim", [64])
 @pytest.mark.parametrize("sm_scale", [0.125])
 @pytest.mark.parametrize("sliding_window", [None, 128])
-@pytest.mark.parametrize("start_q", [0, 5])
+@pytest.mark.parametrize("start_q", [0, 64])
 def test_eq(batch_size, num_queries, num_keys, num_key_value_heads, num_key_value_groups, head_dim, sm_scale, sliding_window, start_q):
     if num_queries > num_keys:
         pytest.skip("too many queries")
@@ -226,4 +228,4 @@ def test_eq(batch_size, num_queries, num_keys, num_key_value_heads, num_key_value_groups, head_dim, sm_scale, sliding_window, start_q):
     o1 = attention(q, k, v, sinks, sm_scale, sliding_window, start_q)
     o2 = attention_ref(q, k, v, sinks, sm_scale, sliding_window, start_q)

-    torch.testing.assert_close(o1, o2)
+    torch.testing.assert_close(o1, o2)