
Conversation

@agi-scaler (Contributor) commented Sep 24, 2025

The original implementation in sglang can be found at https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/triton_ops/decode_attention.py#L89-L94.

In this PR, the xai_temperature_len feature is added to the following functions (an illustrative sketch of the scaling follows the list):

  • ref_ragged_paged_attention_fused
  • ref_ragged_paged_attention
  • ragged_paged_attention
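
For context, xai_temperature_len applies a position-dependent rescaling of the attention logits once the query position passes the threshold. The log-ratio form below is only a minimal JAX sketch of that idea, not a verbatim port of the linked sglang kernel; the function and variable names are illustrative.

```python
import jax.numpy as jnp

def xai_temperature_scale(q_positions, xai_temperature_len):
    """Illustrative per-position logit scale (assumed log-ratio form).

    Positions at or below xai_temperature_len keep a scale of 1.0; beyond
    that, logits are rescaled by log(pos) / log(xai_temperature_len).
    """
    pos = q_positions.astype(jnp.float32)
    scale = jnp.log(pos) / jnp.log(float(xai_temperature_len))
    return jnp.where(pos > xai_temperature_len, scale, 1.0)

# Rough usage inside a reference kernel such as ref_ragged_paged_attention:
#   logits = jnp.einsum("qhd,khd->hqk", q, k) * sm_scale
#   logits *= xai_temperature_scale(q_pos, xai_temperature_len)[None, :, None]
```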

Correctness verification:

  • Decode attention with varying sequence lengths: the TPU vs GPU numerical difference is within 6e-2 (screenshot attached).
```bash
# Produce TPU qkvo:
python3 python/sgl_jax/test/test_flashattention_dump.py -k test_gqa_prefill_accuracy_page_size_1_temperature_dump

# Compare with GPU qkvo:
python3 test_sgl_baseline.py -k test_extend_attention_dump
```
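
To compare the dumps offline, a small script along these lines works (paths are hypothetical; the file names follow the pattern visible in the results below):

```python
import numpy as np

def max_abs_diff(tpu_path, gpu_path):
    # Load the qkvo dumps from both backends and report the worst-case deviation.
    tpu = np.load(tpu_path).astype(np.float32)
    gpu = np.load(gpu_path).astype(np.float32)
    return float(np.max(np.abs(tpu - gpu)))

# Hypothetical directories; the dump names match those in the logs below.
print(max_abs_diff("tpu_dumps/prefill_32_128_8_1_128_tempNone.npy",
                   "gpu_dumps/prefill_32_128_8_1_128_tempNone.npy"))
```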

Even without this PR, the result difference with temperature disabled can be as high as 3.4688:

Testing prefill_32_128_8_1_128_tempNone.npy
qkv_np.shape (128, 8, 128) (128, 8, 128) (1, 32, 128) False False False False False
Diff tensor([-0.0010,  0.0010,  0.0015,  ...,  0.0015,  0.0007,  0.0000],
       device='cuda:0', dtype=torch.bfloat16) tensor(**0.0063**, device='cuda:0', dtype=torch.bfloat16)
Testing prefill_32_128_8_3_20_tempNone.npy
qkv_np.shape (20, 8, 128) (20, 8, 128) (3, 32, 128) False False False False False
Diff tensor([-0.0273,  0.0386, -0.1289,  ...,  0.0684, -0.0016,  0.0391],
       device='cuda:0', dtype=torch.bfloat16) tensor(**1.0156**, device='cuda:0', dtype=torch.bfloat16)
Testing prefill_32_128_8_64_64_tempNone.npy
qkv_np.shape (64, 8, 128) (64, 8, 128) (64, 32, 128) False False False False False
Diff tensor([-0.3555,  1.1875, -0.3477,  ...,  0.0957,  0.0972,  0.1709],
       device='cuda:0', dtype=torch.bfloat16) tensor(**3.1719**, device='cuda:0', dtype=torch.bfloat16)
Testing prefill_32_128_8_20_20_tempNone.npy
qkv_np.shape (20, 8, 128) (20, 8, 128) (20, 32, 128) False False False False False
Diff tensor([-0.3848,  1.4375, -0.1680,  ...,  0.1387,  0.0261,  0.2852],
       device='cuda:0', dtype=torch.bfloat16) tensor(**3.4688**, device='cuda:0', dtype=torch.bfloat16)
Testing prefill_32_128_8_125_125_tempNone.npy
qkv_np.shape (125, 8, 128) (125, 8, 128) (125, 32, 128) False False False False False
Diff tensor([-0.5859,  1.2656, -0.4746,  ...,  0.0796,  0.0586,  0.1050],
       device='cuda:0', dtype=torch.bfloat16) tensor(**3.0469**, device='cuda:0', dtype=torch.bfloat16)
Testing prefill_32_128_8_123_522_tempNone.npy
qkv_np.shape (522, 8, 128) (522, 8, 128) (123, 32, 128) False False False False False
Diff tensor([-0.0039,  0.0703,  0.0415,  ...,  0.0181,  0.0278,  0.0210],
       device='cuda:0', dtype=torch.bfloat16) tensor(**0.4082**, device='cuda:0', dtype=torch.bfloat16)
Testing prefill_32_128_8_1_511_tempNone.npy
qkv_np.shape (511, 8, 128) (511, 8, 128) (1, 32, 128) False False False False False
Diff tensor([ 0.0000e+00,  0.0000e+00, -5.4932e-04,  ...,  4.8828e-04,
        -1.8311e-04,  6.1035e-05], device='cuda:0', dtype=torch.bfloat16) tensor(**0.0020**, device='cuda:0', dtype=torch.bfloat16)

When temperature is enabled, the numerical difference stays at the same level.

Performance benchmark: the performance difference compared to the main branch (without temperature) is negligible.


Raw benchmark data:
attn-benchmark.xlsx


@jimoosciuc requested a review from Iamleos on September 24, 2025 at 12:07.
@Iamleos (Collaborator) commented Sep 24, 2025

Please add some tests for page_size 1 with temperature attention enabled.

@Iamleos (Collaborator) commented Sep 24, 2025

Please attach flash attention kernel benchmark results; refer to benchmark/kernels/flash_attention/bench_flashattention.py.
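
For example (assuming the script runs with its default arguments):

```bash
python3 benchmark/kernels/flash_attention/bench_flashattention.py
```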

@agi-scaler (Contributor, Author) commented

@Iamleos Benchmark data has been added to the PR description.

@agi-scaler (Contributor, Author) commented

A page_size 1 test has also been added.

Iamleos previously approved these changes Sep 26, 2025

@Iamleos (Collaborator) left a comment

/LGTM

```python
q_batch = jnp.stack(q_heads, axis=0)

if xai_temperature_len is not None:
    import numpy as np
```
Collaborator comment on the diff above:

Why is numpy imported here?

@Iamleos (Collaborator) commented Sep 26, 2025

Merge blocked; please fix the lint error.
