diff --git a/gpt_oss/triton/attention.py b/gpt_oss/triton/attention.py index bf689055..018b59d0 100644 --- a/gpt_oss/triton/attention.py +++ b/gpt_oss/triton/attention.py @@ -59,9 +59,10 @@ def _attn_fwd( q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM]) if BANDWIDTH: - lo, hi = tl.maximum(start_q, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M + lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M else: - lo, hi = start_q, start_q + (start_m + 1) * BLOCK_M + lo, hi = 0, start_q + (start_m + 1) * BLOCK_M + hi = tl.minimum(hi, N_KV_CTX) for start_n in range(lo, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) @@ -226,4 +227,4 @@ def test_eq(batch_size, num_queries, num_keys, num_key_value_heads, num_key_valu o1 = attention(q, k, v, sinks, sm_scale, sliding_window, start_q) o2 = attention_ref(q, k, v, sinks, sm_scale, sliding_window, start_q) - torch.testing.assert_close(o1, o2) + torch.testing.assert_close(o1, o2) \ No newline at end of file