@@ -40,6 +40,7 @@ def ref_ragged_paged_attention_fused(
mask_value: float | None = DEFAULT_MASK_VALUE,
k_scale: float | None = None,
v_scale: float | None = None,
xai_temperature_len: float | None = None,
):
if mask_value is None:
mask_value = DEFAULT_MASK_VALUE
@@ -84,6 +85,16 @@ def ref_ragged_paged_attention_fused(
mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
if soft_cap is not None:
attn = soft_cap * jnp.tanh(attn / soft_cap)

# xai_temperature_scale: ref implementation from sgl-project/sglang
# python/sglang/srt/layers/attention/triton_ops/decode_attention.py
if xai_temperature_len is not None:
xai_temperature_scale = 1.0 / jnp.log2(float(xai_temperature_len))
qidx = jnp.arange(q_start, q_end) - 1
_qtemp = jnp.log2(qidx.astype(jnp.float32)) * xai_temperature_scale
xai_temperature_reg = jnp.where(qidx > xai_temperature_len, _qtemp, 1.0)
attn = attn * xai_temperature_reg[:, None]

attn += jnp.where(mask, mask_value, 0.0)
attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
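
For intuition, the temperature term added above reduces to a per-position scaling of the attention logits: positions up to xai_temperature_len keep a factor of 1.0, and positions beyond it are scaled by log2(pos) / log2(xai_temperature_len). The standalone sketch below is derived directly from the reference code above and is not part of the diff:

import jax.numpy as jnp

def xai_temperature_factor(qidx, xai_temperature_len):
    # 1.0 up to xai_temperature_len, then log2(pos) / log2(xai_temperature_len).
    scale = 1.0 / jnp.log2(float(xai_temperature_len))
    qtemp = jnp.log2(qidx.astype(jnp.float32)) * scale
    return jnp.where(qidx > xai_temperature_len, qtemp, 1.0)

# With xai_temperature_len=512: position 1024 scales the logits by
# log2(1024) / log2(512) = 10/9 ≈ 1.11, position 4096 by 12/9 ≈ 1.33.
print(xai_temperature_factor(jnp.array([100, 512, 1024, 4096]), 512.0))
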
@@ -107,6 +118,7 @@ def ref_ragged_paged_attention(
mask_value: float | None = DEFAULT_MASK_VALUE,
k_scale: float | None = None,
v_scale: float | None = None,
xai_temperature_len: float | None = None,
):
if mask_value is None:
mask_value = DEFAULT_MASK_VALUE
@@ -141,6 +153,16 @@ def ref_ragged_paged_attention(
mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
if soft_cap is not None:
attn = soft_cap * jnp.tanh(attn / soft_cap)

# xai_temperature_scale: ref implementation from sgl-project/sglang
# python/sglang/srt/layers/attention/triton_ops/decode_attention.py
if xai_temperature_len is not None:
xai_temperature_scale = 1.0 / jnp.log2(float(xai_temperature_len))
qidx = jnp.arange(q_start, q_end) - 1
_qtemp = jnp.log2(qidx.astype(jnp.float32)) * xai_temperature_scale
xai_temperature_reg = jnp.where(qidx > xai_temperature_len, _qtemp, 1.0)
attn = attn * xai_temperature_reg[:, None]

attn += jnp.where(mask, mask_value, 0.0)
attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
@@ -258,6 +280,7 @@ def _ragged_paged_attention_kernel(
q_scale: float | None = None,
k_scale: float | None = None,
v_scale: float | None = None,
xai_temperature_len: float | None = None,
chunk_prefill_size: int | None = None,
bkv_p,
bq_sz,
@@ -661,12 +684,21 @@ def batch_prepare_queries():
for head_idx in range(actual_num_kv_heads):
bq = load_bq(bq_sem_idx, head_idx, actual_bq_sz=actual_bq_sz)
q_heads.append(bq)
q_batch = jnp.stack(q_heads, axis=0)

if xai_temperature_len is not None:
import numpy as np
Review comment (Collaborator): why here import numpy

return jnp.stack(q_heads, axis=0)
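# Per-token query index for this q block: block start offset minus 1 (matching
# the -1 offset in the reference implementation), tiled across kv heads.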
q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
q_batch_shape = q_batch.shape
base = (q_len_start - 1) + lax.iota(jnp.int32, q_batch_shape[1])
offs_qidx = jnp.tile(base, (q_batch_shape[0], 1))
return q_batch, offs_qidx
return q_batch, None

# Load batched data
k_batch, v_batch = batch_load_all_heads_kv()
q_batch = batch_prepare_queries()
q_batch, offs_qidx_batch = batch_prepare_queries()

def flash_attention(q_batch, k_batch, v_batch):
q_batch_f32 = q_batch.astype(jnp.float32)
@@ -709,6 +741,22 @@ def flash_attention(q_batch, k_batch, v_batch):
if soft_cap is not None:
s = soft_cap * jnp.tanh(s / soft_cap)

# xai_temperature_scale: ref implementation from sgl-project/sglang
# python/sglang/srt/layers/attention/triton_ops/decode_attention.py
if xai_temperature_len is not None:
xai_temperature_scale = 1.0 / jnp.log2(
float(xai_temperature_len)
)
_qtemp = (
jnp.log2(offs_qidx_batch.astype(jnp.float32))
* xai_temperature_scale
)
xai_temperature_reg = jnp.where(
offs_qidx_batch > xai_temperature_len, _qtemp, 1.0
)

s = s * xai_temperature_reg[:, :, None]

s += jnp.where(mask, mask_value, 0.0)

for head_idx in range(actual_num_kv_heads):
@@ -964,6 +1012,7 @@ def static_validate_inputs_fused(
q_scale: float | None = None,
k_scale: float | None = None,
v_scale: float | None = None,
xai_temperature_len: float | None = None,
# Kernel optimization params.
chunk_prefill_size: int | None = None,
# Kernel tuning params.
@@ -1058,6 +1107,8 @@ def static_validate_inputs_fused(
raise ValueError(f"{sliding_window=} must be positive.")
if soft_cap is not None and soft_cap == 0.0:
raise ValueError(f"{soft_cap=} must not be 0.0.")
if xai_temperature_len is not None and xai_temperature_len <= 0:
raise ValueError(f"{xai_temperature_len=} must be positive.")
if chunk_prefill_size is not None and chunk_prefill_size <= 0:
raise ValueError(f"{chunk_prefill_size=} must be positive.")
if num_kv_pages_per_block is not None:
@@ -1074,6 +1125,7 @@ def static_validate_inputs_fused(
del q_scale
del k_scale
del v_scale
del xai_temperature_len


@functools.partial(
@@ -1086,6 +1138,7 @@ def static_validate_inputs_fused(
"q_scale",
"k_scale",
"v_scale",
"xai_temperature_len",
"chunk_prefill_size",
"num_kv_pages_per_block",
"num_queries_per_block",
@@ -1111,6 +1164,7 @@ def ragged_paged_attention(
q_scale: float | None = None,
k_scale: float | None = None,
v_scale: float | None = None,
xai_temperature_len: float | None = None,
# Kernel optimization params.
chunk_prefill_size: int | None = None,
# Kernel tuning params.
@@ -1140,6 +1194,8 @@ def ragged_paged_attention(
mask_value: mask value for causal mask.
k_scale: the scale for the key cache.
v_scale: the scale for the value cache.
xai_temperature_len: the length-based attention temperature threshold used by xAI Grok;
query positions beyond it have their logits rescaled.
Reference: sgl-project/sglang, python/sglang/srt/layers/attention/triton_ops/decode_attention.py
num_kv_pages_per_block: number of kv pages to be processed in one flash
attention block in the pallas kernel.
num_queries_per_block: number of queries to be processed in one flash
@@ -1300,6 +1356,7 @@ def ragged_paged_attention(
q_scale=q_scale,
k_scale=k_scale,
v_scale=v_scale,
xai_temperature_len=xai_temperature_len,
chunk_prefill_size=chunk_prefill_size,
bq_sz=bq_sz,
bkv_p=bkv_p,
2 changes: 2 additions & 0 deletions python/sgl_jax/srt/layers/attention/flashattention_backend.py
@@ -207,6 +207,7 @@ def __call__(
scale = 1.0 / jnp.sqrt(layer.head_dim)
else:
scale = layer.scaling
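# Optional per-layer temperature threshold; falls back to None (disabled) for
# layers that do not define xai_temperature_len.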
xai_temperature_len = getattr(layer, "xai_temperature_len", None)

# Prepare fused KV cache for paged format: [num_pages, page_size, num_kv_heads * 2, head_dim]
total_tokens = kv_cache_fused.shape[0]
@@ -249,6 +250,7 @@ def _ragged_paged_attention_with_fused_kv(*args):
sm_scale=scale,
sliding_window=None,
soft_cap=None,
xai_temperature_len=xai_temperature_len,
vmem_limit_bytes=self.vmem_limit_bytes,
)

1 change: 1 addition & 0 deletions python/sgl_jax/srt/layers/radix_attention.py
@@ -46,6 +46,7 @@ def __init__(
self.scaling = scaling
self.layer_id = layer_id
self.attn_type = attn_type
self.xai_temperature_len = None

def __call__(
self,
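
Putting the three files together: RadixAttention now carries an xai_temperature_len attribute (defaulting to None, i.e. disabled), and the FlashAttention backend picks it up via getattr and forwards it to the kernel. A minimal usage sketch follows; the constructor arguments are illustrative assumptions, not the exact RadixAttention signature:

from sgl_jax.srt.layers.radix_attention import RadixAttention

# Constructor arguments here are assumptions for illustration only.
attn = RadixAttention(num_heads=32, head_dim=128, scaling=128**-0.5, num_kv_heads=8, layer_id=0)

# A positive value enables the length-based temperature; the backend reads it with
# getattr(layer, "xai_temperature_len", None) and passes it to ragged_paged_attention.
attn.xai_temperature_len = 512
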
60 changes: 59 additions & 1 deletion python/sgl_jax/test/test_flashattention.py
@@ -134,6 +134,7 @@ def create_test_data(
):
"""Create a real ForwardBatch for testing."""
assert mode in ["prefill", "decode"]
print("model_config", model_config)
batch_size = len(lens)
# Create sequence lengths array
seq_lens = jnp.array([kv_len for _, kv_len in lens], dtype=jnp.int32)
@@ -231,6 +232,8 @@ def align_to_size(l, size, value=0):
attention_backend = FlashAttention(
num_heads, num_kv_heads, head_dim, page_size=page_size
)
if model_config.get("xai_temperature_len", None):
attention_backend.xai_temperature_len = model_config["xai_temperature_len"]
forward_mode = ForwardMode.EXTEND if mode == "prefill" else ForwardMode.DECODE

mwb = ModelWorkerBatch(
@@ -286,7 +289,7 @@ def setUp(self):
self.rng_key = jax.random.PRNGKey(42)
np.random.seed(42)

def run_test(self, mode, lens, mode_args):
def run_test(self, mode, lens, mode_args, **mode_kwargs):
# Create mock forward_batch
num_heads, head_dim, num_kv_heads, page_size, dtype = mode_args

@@ -307,6 +310,7 @@ def run_test(self, mode, lens, mode_args):
"head_dim": head_dim,
"num_hidden_layers": 1,
"bf16": is_bf16,
"xai_temperature_len": mode_kwargs.get("xai_temperature_len", None),
},
)

@@ -373,6 +377,7 @@ def run_test(self, mode, lens, mode_args):
# forward_batch.attn_backend.forward_metadata.cu_kv_lens,
forward_batch.attn_backend.forward_metadata.num_seqs,
sm_scale=head_dim**-0.5,
**mode_kwargs,
)
jax.block_until_ready(expected)

Expand All @@ -381,6 +386,8 @@ def jit_attn(q, k, v, forward_batch):
out = attn(q, k, v, forward_batch)
return out

attn.xai_temperature_len = mode_kwargs.get("xai_temperature_len", None)

# run
jax_output, _ = jit_attn(q_shard, extend_k, extend_v, forward_batch)
jax.block_until_ready(jax_output)
@@ -600,6 +607,57 @@ def test_gqa_decode_accuracy_page_size_64(self):
"decode", lens, (num_heads, head_dim, num_kv_heads, 64, jnp.bfloat16)
)

def test_gqa_prefill_accuracy_page_size_64_temperature(self):
"""Test JAX attention accuracy against PyTorch reference
Testcase (1024, 1024) fails on token 607, possible precision issue?
Token 607: max_diff=0.023438, jax_mean=-0.011597, expected_mean=-0.011597, jax_std=0.048096, expected_std=0.047607
"""
# Parameters
num_heads = 32
num_kv_heads = 8
head_dim = 128
lens = [
(1, 128),
(3, 20),
(64, 64),
(20, 20),
(125, 125),
# (1024, 1024),
(123, 522),
(1, 511),
]
self.run_test(
"prefill",
lens,
(num_heads, head_dim, num_kv_heads, 64, jnp.bfloat16),
xai_temperature_len=512,
)

def test_gqa_decode_accuracy_page_size_64_temperature(self):
"""Test JAX attention accuracy against native fa"""
# Parameters
num_heads = 32
num_kv_heads = 8
head_dim = 128
lens = [
(1, 119),
(1, 127),
(1, 128),
(1, 129),
(1, 133),
(1, 1001),
(1, 1023),
(1, 1024),
(1, 1025),
]

self.run_test(
"decode",
lens,
(num_heads, head_dim, num_kv_heads, 64, jnp.bfloat16),
xai_temperature_len=512,
)


if __name__ == "__main__":
unittest.main()