Llama3.1: Median decode latency is high with batch size 128 on the Triton backend #1935
Jackycheng0808 asked this question in Q&A
- Docker image: sglang 0.3.4.post2
- Hardware: H200
- Command (see the sweep sketch below for how I vary batch size and backend): `python3 -m sglang.bench_latency --batch-size 128 --input 128 --output 128 --model "amd/Meta-Llama-3.1-8B-Instruct-FP8-KV" --quantization fp8 --tp 1`
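For context, this is roughly how I sweep batch sizes on both backends, assuming the `--attention-backend` flag is how the Triton and FlashInfer backends are selected in this sglang version; treat it as a sketch of the procedure rather than an exact reproduction script:

```python
import subprocess

MODEL = "amd/Meta-Llama-3.1-8B-Instruct-FP8-KV"

# Re-run the benchmark command above for every (backend, batch size) pair.
# Assumes --attention-backend is the flag that switches between the Triton
# and FlashInfer decode kernels in this sglang version.
for backend in ["triton", "flashinfer"]:
    for batch_size in [64, 128, 256]:
        cmd = [
            "python3", "-m", "sglang.bench_latency",
            "--model", MODEL,
            "--quantization", "fp8",
            "--tp", "1",
            "--batch-size", str(batch_size),
            "--input", "128",
            "--output", "128",
            "--attention-backend", backend,
        ]
        print("Running:", " ".join(cmd))
        subprocess.run(cmd, check=True)
```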
Hi, I am testing Llama 3.1 on different attention backends (Triton and FlashInfer) and have found some unusual behavior on the Triton backend. The median decode latency and the total latency increase significantly from batch size 64 to batch size 128, then drop back down at batch size 256. At batch size 128, the Triton backend is about 5x slower than the FlashInfer backend.

After profiling, I found that the _fwd_grouped_kernel_stage1 kernel accounts for ~90% of the execution time, whereas the BatchDecode kernel of the FlashInfer backend accounts for only ~24%. Could there be an issue in the _fwd_grouped_kernel_stage1 Triton kernel implementation?
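For the kernel-level breakdown I looked at per-kernel GPU time shares in a profiler trace. A minimal sketch of that aggregation with torch.profiler is below; the matmul loop is just a stand-in workload rather than the actual sglang decode step, so only the aggregation logic carries over:

```python
import torch
from torch.profiler import ProfilerActivity, profile

def workload():
    # Stand-in GPU workload; in the real run this is the decode phase
    # of bench_latency, not a matmul loop.
    a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    for _ in range(10):
        a = a @ b
    torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    workload()

# Share of total GPU time per kernel; the ~90% figure for
# _fwd_grouped_kernel_stage1 was read off the trace this way.
events = prof.key_averages()
total = sum(e.self_cuda_time_total for e in events)
for e in sorted(events, key=lambda e: e.self_cuda_time_total, reverse=True)[:10]:
    print(f"{e.key:60s} {100.0 * e.self_cuda_time_total / total:5.1f}%")
```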