[grok, attention]: support xai_temperature_len feature in attention for grok #217
base: main
Conversation
Please add some tests for page_size 1 with temperature attention enabled.
Please attach flash attention kernel benchmark results. Refer to benchmark/kernels/flash_attention/bench_flashattention.py.
@Iamleos Benchmark data added to the PR description. The page_size 1 test is also added.
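For reference, here is a minimal self-contained sketch of what a page_size=1 check with temperature enabled can look like, using a plain-JAX reference instead of the repo's kernel (all names here are illustrative, not the actual test code):

```python
import jax
import jax.numpy as jnp

def ref_decode_attn(q, k, v, xai_temperature_len=None):
    """Plain-JAX reference for one decode step: q [H, D], k/v [S, H, D]."""
    S, H, D = k.shape
    logits = jnp.einsum("hd,shd->hs", q, k) / jnp.sqrt(jnp.float32(D))
    if xai_temperature_len is not None and S - 1 > xai_temperature_len:
        # assumed position-dependent scaling (see the sketch in the PR description)
        logits = logits * (jnp.log2(float(S - 1))
                           / jnp.log2(float(xai_temperature_len)))
    return jnp.einsum("hs,shd->hd", jax.nn.softmax(logits, axis=-1), v)

key = jax.random.PRNGKey(0)
S, H, D = 128, 4, 64
q = jax.random.normal(key, (H, D))
k = jax.random.normal(jax.random.fold_in(key, 1), (S, H, D))
v = jax.random.normal(jax.random.fold_in(key, 2), (S, H, D))

# page_size=1: every token occupies its own page, so scatter KV into
# shuffled page slots and gather back through the page table.
page_ids = jax.random.permutation(jax.random.fold_in(key, 3), S)
pool_k = jnp.zeros_like(k).at[page_ids].set(k)
pool_v = jnp.zeros_like(v).at[page_ids].set(v)

out_contig = ref_decode_attn(q, k, v, xai_temperature_len=32)
out_paged = ref_decode_attn(q, pool_k[page_ids], pool_v[page_ids],
                            xai_temperature_len=32)
print(jnp.max(jnp.abs(out_contig - out_paged)))  # expected ~0
```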
/LGTM
```python
q_batch = jnp.stack(q_heads, axis=0)

if xai_temperature_len is not None:
    import numpy as np
```
Why import numpy here?
Merge blocked, please fix the lint error.
The original implementation in sglang can be found at https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/triton_ops/decode_attention.py#L89-L94

In this PR, the xai_temperature_len feature is added to the flash attention kernel, covering both decode and prefill (extend) attention; a rough sketch of the scaling follows.
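For context, this is a rough plain-JAX sketch of the position-dependent logit scaling as I read it from the linked sglang Triton kernel. The exact formula and the helper name `apply_xai_temperature` are assumptions for illustration; the linked source is authoritative.

```python
import jax.numpy as jnp

def apply_xai_temperature(logits, q_positions, xai_temperature_len):
    """Assumed reading of the referenced kernel: queries past
    xai_temperature_len get their logits multiplied by
    log2(position) / log2(xai_temperature_len), a factor that grows
    with position and sharpens the softmax on long contexts.

    logits:      [num_q_tokens, num_kv_tokens] raw q.k scores
    q_positions: [num_q_tokens] absolute positions of the query tokens
    """
    inv_log_len = 1.0 / jnp.log2(jnp.float32(xai_temperature_len))
    # maximum(..., 1) keeps log2 finite for positions the where() discards
    pos = jnp.maximum(q_positions, 1).astype(jnp.float32)
    reg = jnp.where(q_positions > xai_temperature_len,
                    jnp.log2(pos) * inv_log_len,
                    1.0)
    return logits * reg[:, None]
```

In decode there is a single query at position seq_len - 1, so the factor collapses to one scalar per request; in extend/prefill every query row gets its own factor, which is why both paths need the change.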
Correctness verification:

- For prefill (extend attention), the numerical difference is larger than for decode.
- Extend attention without temperature is compared using https://github.com/agi-scaler/sglang-jax/blob/temp-comp-baseline/test_sgl_baseline.py together with https://github.com/agi-scaler/sglang-jax/blob/temp-comp-baseline/python/sgl_jax/test/test_flashattention_dump.py#L396-L409.
- Even without this PR, the difference without temperature can be as high as 3.4688 (see the comparison sketch after this list).
- When temperature is enabled, the numerical difference stays at the same level, so the feature does not add extra error.
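The difference numbers above are presumably max absolute differences between kernel output and reference output; a comparison of that shape looks like the following (hypothetical helper, not the repo's actual test code):

```python
import jax.numpy as jnp

def max_abs_diff(kernel_out, ref_out):
    # Compare in float32 so low-precision kernel outputs are measured fairly.
    return jnp.max(jnp.abs(kernel_out.astype(jnp.float32)
                           - ref_out.astype(jnp.float32)))
```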
Performance benchmark: negligible difference compared to the main branch without temperature.
Raw benchmark data: attn-benchmark.xlsx