Conversation

@TianyuZhang1214 (Collaborator) commented Oct 20, 2025

Deploying DeepSeek-R1 on H20-96G with SGLang: Best Practices

Introduction

We published an article on LMSYS titled "Together with SGLang: Best Practices for Serving DeepSeek-R1 on H20-96G", sharing our best practices for deploying the DeepSeek-R1 model on H20-96G hardware.
To make our experimental results easy to reproduce and to provide access to our code, we have opened this pull request in the DeepSeek-R1 repository.

Reproduction Steps

Pulling the Docker Image

To obtain the Docker image, use the following command:

docker pull ghcr.io/antgroup/sglang:h20-blog-release

The image is hosted at: https://github.com/orgs/antgroup/packages/container/package/sglang
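For reference, here is a minimal sketch of starting one container from this image. The container name, mounted paths, and shared-memory size are placeholders; the RDMA NICs used by Mooncake may additionally require --privileged or explicit /dev/infiniband device mappings on your hosts.

docker run -d --name sglang-h20 \
  --gpus all --network host --ipc host --shm-size 32g \
  -v /path/to/DeepSeek-R1:/path/to/DeepSeek-R1 \
  -v /home/admin/logs:/home/admin/logs \
  ghcr.io/antgroup/sglang:h20-blog-release \
  sleep infinity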

Checking Environment Variables

All environment variables are stored in the /root/env.sh file, configured for our H20 environment. Before launching SGLang, verify that these variables are suitable for your environment.
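For example, assuming a container named sglang-h20 as in the sketch above, you can review the variables and load them into your shell before launching:

docker exec -it sglang-h20 bash
cat /root/env.sh     # review the preset variables
source /root/env.sh  # load them into the current shell before launching SGLang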

Launching SGLang

We recommend running four containers: two for Prefill nodes and two for Decode nodes.

1. Launching Prefill Nodes (Identical Configuration for Both Nodes)

Note:

  • Both Prefill nodes use the same launch parameters.
  • Adjust the port number if there is a conflict.
PYTHONUNBUFFERED=1 \
SGL_CHUNKED_PREFIX_CACHE_THRESHOLD=0 \
nohup python3 -m sglang.launch_server \
--trust-remote-code \
--model-path /path/to/DeepSeek-R1 \
--disaggregation-mode prefill \
--disaggregation-transfer-backend mooncake \
--disaggregation-ib-device mlx5_bond_0,mlx5_bond_1,mlx5_bond_2,mlx5_bond_3 \
--host 0.0.0.0 \
--port 61001 \
--tp-size 8 \
--page-size 64 \
--attention-backend fa3 \
--mem-fraction-static 0.9 \
--chunked-prefill-size 16384 \
--max-running-requests 512 \
--context-length 65535 \
--enable-cache-report \
--log-level info \
--load-balance-method round_robin \
--quantization fp8 \
--kv-cache-dtype fp8_e4m3 \
> /home/admin/logs/stdout.log 2>&1 &
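After launching, you can confirm that the node came up before wiring it into the router, for example by tailing the log and probing the HTTP port (a quick check assuming the port above and the server's standard /health endpoint):

tail -n 50 /home/admin/logs/stdout.log
curl -s http://127.0.0.1:61001/health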

2. Launching Decode Nodes

Note:

  • Set {node_rank} to 0 or 1 for the respective node.
  • Replace {decode_master_ip} with the IP address of Node 0.
  • Adjust the port number if there is a conflict.
Launch command (run on both nodes, substituting {node_rank} and {decode_master_ip}):
PYTHONUNBUFFERED=1 \
SGL_ENABLE_JIT_DEEPGEMM=1 \
SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=64 \
ENABLE_SWAPAB=1 \
nohup python3 -m sglang.launch_server \
--model-path /path/to/DeepSeek-R1 \
--disaggregation-mode decode \
--disaggregation-transfer-backend mooncake \
--disaggregation-ib-device mlx5_bond_0,mlx5_bond_1,mlx5_bond_2,mlx5_bond_3 \
--disaggregation-bootstrap-port 9000 \
--attention-backend flashmla \
--host 0.0.0.0 \
--port 61001 \
--trust-remote-code \
--dist-init-addr {decode_master_ip}:62001 \
--nnodes 2 \
--node-rank {node_rank} \
--tp-size 16 \
--dp-size 16 \
--enable-dp-attention \
--mem-fraction-static 0.88 \
--max-running-requests 512 \
--context-length 65535 \
--log-level info \
--decode-log-interval 50 \
--page-size 64 \
--schedule-conservativeness 0.3 \
--enable-cache-report \
--moe-dense-tp-size 1 \
--enable-deepep-moe \
--enable-dp-lm-head \
--cuda-graph-max-bs 32 \
--speculative-algorithm NEXTN \
--speculative-num-steps 1 \
--speculative-eagle-topk 1 \
--speculative-num-draft-tokens 2 \
--init-expert-location /root/expert_workload.json \
--prefill-round-robin-balance \
--quantization fp8 \
--kv-cache-dtype fp8_e4m3 \
--deepep-mode low_latency_overlap \
--enable-single-batch-overlap \
> /home/admin/logs/stdout.log 2>&1 &

3. Launching SGLang Router

Note:

  • Replace {decode_master_ip}, {prefill_node_0_ip}, and {prefill_node_1_ip} with the respective IP addresses.
  • Adjust the port number if there is a conflict.
nohup python3 -m sglang_router.launch_router \
--pd-disaggregation \
--mini-lb \
--host 0.0.0.0 \
--decode http://{decode_master_ip}:61001 \
--port 8000 \
--prefill http://{prefill_node_0_ip}:61001 \
--prefill http://{prefill_node_1_ip}:61001 \
> /home/admin/logs/router.log 2>&1 &
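Before benchmarking, you can send one request through the router to confirm that the Prefill and Decode nodes are reachable. This is a sketch assuming the router forwards the standard OpenAI-compatible endpoint; the model field is a placeholder, since the deployment serves a single model.

curl -s http://127.0.0.1:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "DeepSeek-R1", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 16}'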

Testing

1. Running the Benchmark

Note:

  • This benchmark is intended for observing peak throughput in the logs. Since --request-rate is set to inf, all requests are sent at once, so the reported TTFT and TPOT are less meaningful.
  • Replace {path-to-shareGPT} with the path to the ShareGPT dataset.
nohup python3 -m sglang.bench_serving \
--host 0.0.0.0 \
--port 8000 \
--dataset-path {path-to-shareGPT} \
--num-prompt 4096 \
--random-input 4096 \
--random-output 1536 \
--request-rate "inf" \
--max-concurrency 2048 \
--warmup-requests 0 \
--backend sglang \
--dataset-name random \
--random-range-ratio 1 \
> /home/local/workspace/bench.log 2>&1 &

2. Observing Logs

To monitor peak performance, filter logs for entries with running-req: 32:

grep -E 'Decode batch.*running-req: 32' /home/admin/logs/sglang.log

Example Output (for batch size = 32):

2025-10-20 03:02:22 INFO 31223 [DP3 TP3 EP3 scheduler_metrics_mixin.py:222] Decode batch. #running-req: 32, #token: 157952, token usage: 0.21, accept len: 1.93, pre-allocated usage: 0.00, #retracted-req: 0, cuda graph: True, gen throughput (token/s): 693.45, #queue-req: 0
2025-10-20 03:02:22 INFO 31225 [DP5 TP5 EP5 scheduler_metrics_mixin.py:222] Decode batch. #running-req: 32, #token: 164224, token usage: 0.22, accept len: 1.92, pre-allocated usage: 0.00, #retracted-req: 0, cuda graph: True, gen throughput (token/s): 674.19, #queue-req: 1
2025-10-20 03:02:22 INFO 31222 [DP2 TP2 EP2 scheduler_metrics_mixin.py:222] Decode batch. #running-req: 32, #token: 162112, token usage: 0.22, accept len: 1.90, pre-allocated usage: 0.00, #retracted-req: 0, cuda graph: True, gen throughput (token/s): 655.17, #queue-req: 1
2025-10-20 03:02:22 INFO 31224 [DP4 TP4 EP4 scheduler_metrics_mixin.py:222] Decode batch. #running-req: 32, #token: 168768, token usage: 0.22, accept len: 1.93, pre-allocated usage: 0.00, #retracted-req: 0, cuda graph: True, gen throughput (token/s): 679.00, #queue-req: 2
2025-10-20 03:02:22 INFO 31227 [DP7 TP7 EP7 scheduler_metrics_mixin.py:222] Decode batch. #running-req: 32, #token: 157696, token usage: 0.21, accept len: 1.92, pre-allocated usage: 0.00, #retracted-req: 0, cuda graph: True, gen throughput (token/s): 673.31, #queue-req: 0
2025-10-20 03:02:26 INFO 31222 [DP2 TP2 EP2 scheduler_metrics_mixin.py:222] Decode batch. #running-req: 32, #token: 159488, token usage: 0.21, accept len: 1.92, pre-allocated usage: 0.00, #retracted-req: 0, cuda graph: True, gen throughput (token/s): 679.66, #queue-req: 0
2025-10-20 03:02:27 INFO 31224 [DP4 TP4 EP4 scheduler_metrics_mixin.py:222] Decode batch. #running-req: 32, #token: 160320, token usage: 0.21, accept len: 1.94, pre-allocated usage: 0.00, #retracted-req: 0, cuda graph: True, gen throughput (token/s): 673.26, #queue-req: 0
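To summarize these entries, a one-liner such as the following (a sketch assuming the log format shown above) averages the per-rank generation throughput:

grep -E 'Decode batch.*running-req: 32' /home/admin/logs/sglang.log \
  | grep -oE 'gen throughput \(token/s\): [0-9.]+' \
  | awk '{sum += $NF; n++} END {if (n) printf "avg per-rank gen throughput: %.2f token/s (%d samples)\n", sum/n, n}'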

Related PRs

Summary by Sourcery

Enable a one-shot multi-head attention path and TMA-based MoE optimizations across SGLang’s DeepSeek and fused MoE kernels, and add a script for auto-tuning Triton MoE configurations.

New Features:

  • Introduce MHA_ONE_SHOT method in DeepseekV2 to perform prefix and extended KV attention in one pass for short sequences
  • Add Triton TensorDescriptor (TMA) support in fused MoE kernels to enable efficient A/B descriptors

Enhancements:

  • Parameterize fused MoE Python bindings with filter_expert, a_use_tma, b_use_tma and c_sorted flags and extend config loader for down_moe
  • Extend memory pool caching to support separate KV buffers and fast retrieval for one-shot MHA
  • Update flashinfer and flashattention backends to handle the new one-shot attention mode

Chores:

  • Add a standalone benchmarking and auto-tuning script for fused MoE Triton kernels along with example JSON config files

@sourcery-ai bot commented Oct 20, 2025

Reviewer's Guide

This PR implements a new one-shot multi-head attention mode for DeepSeek-V2, enriches fused MoE Triton kernels with descriptor/TMA/filtering support, introduces Triton-based KV buffer operations in the memory pool and utils, updates config generation for down-MoE scenarios, and adds a comprehensive benchmark/tuning script for the fused MoE kernels.

Sequence diagram for one-shot MHA attention path in DeepSeek-V2

sequenceDiagram
    participant FB as ForwardBatch
    participant Attn as DeepseekV2AttentionMLA
    participant KVPool as MLATokenToKVPool
    FB->>Attn: forward_prepare(...)
    Attn->>FB: _support_mha_one_shot(...)
    alt MHA_ONE_SHOT supported
        Attn->>Attn: forward_normal_one_shot_prepare(...)
        Attn->>FB: fetch_mha_one_shot_kv_indices()
        Attn->>KVPool: get_mla_kv_buffer(...)
        KVPool-->>Attn: (kv_a, k_pe)
        Attn->>Attn: forward_normal_one_shot_core(...)
    else fallback
        Attn->>Attn: forward_normal_chunked_kv_prepare(...)
    end

Sequence diagram for fused MoE Triton kernel invocation with TMA/descriptor support

sequenceDiagram
    participant Worker as BenchmarkWorker
    participant FusedMoE as FusedMoE
    participant Kernel as TritonKernel
    Worker->>FusedMoE: benchmark(...)
    FusedMoE->>Kernel: invoke_fused_moe_kernel(..., a_desc, b_desc, filter_expert)
    Kernel-->>FusedMoE: (results)
    FusedMoE-->>Worker: (latency results)

Class diagram for new and updated DeepSeek-V2 attention and MoE classes

classDiagram
    class AttnForwardMethod {
        +MHA_CHUNKED_KV
        +MHA_ONE_SHOT
        +MLA_FUSED_ROPE
    }
    class DeepseekV2AttentionMLA {
        +kv_cache_dtype
        +forward_normal_one_shot_prepare()
        +forward_normal_one_shot_core()
        +_set_mla_kv_buffer()
        +_get_mla_kv_buffer()
        +_concat_and_cast_mha_k()
    }
    class ForwardBatch {
        +mha_one_shot_kv_indices
        +mha_one_shot
        +fetch_mha_one_shot_kv_indices()
    }
    class MLATokenToKVPool {
        +get_mla_kv_buffer()
    }
    AttnForwardMethod <|-- DeepseekV2AttentionMLA
    DeepseekV2AttentionMLA <.. ForwardBatch
    ForwardBatch <.. MLATokenToKVPool

Class diagram for Fused MoE Triton kernel and config changes

classDiagram
    class BenchmarkWorker {
        +benchmark()
        +tune()
    }
    class BestConfigTrace {
        +update()
        +total_time
        +config_dict()
    }
    class MoeRunnerConfig {
        +inplace
        +num_experts
        +num_local_experts
    }
    class FusedMoE {
        +fused_experts_impl(..., filter_expert)
    }
    class FusedMoEConfig {
        +get_config_file_name(..., down_moe)
        +get_moe_configs(..., down_moe)
        +try_get_optimal_moe_config(..., return_down_config)
    }
    BenchmarkWorker <.. BestConfigTrace
    FusedMoEConfig <.. FusedMoE

File-Level Changes

Add one-shot MHA method to DeepSeekV2 attention pipeline
  • Introduce AttnForwardMethod.MHA_ONE_SHOT and support predicate
  • Update backend dispatch to select one-shot mode when capacity allows
  • Implement forward_normal_one_shot_prepare/core and KV buffer set/get helpers
  • Extend ForwardBatch with mha_one_shot fields and fetch_mha_one_shot_kv_indices
  • Adjust flashinfer and flashattention backends for one-shot execution
  Files:
    python/sglang/srt/models/deepseek_v2.py
    python/sglang/srt/model_executor/forward_batch_info.py
    python/sglang/srt/layers/attention/utils.py
    python/sglang/srt/layers/attention/flashinfer_mla_backend.py
    python/sglang/srt/layers/attention/flashattention_backend.py

Enhance fused MoE Triton kernels with tensor descriptor and TMA support
  • Add TensorDescriptor import path check and filter_expert flag
  • Extend invoke_fused_moe_kernel signature with a_desc, b_desc, c_sorted, filter_expert, a_use_tma, b_use_tma
  • Update fused_moe API to propagate new parameters and filter_expert logic
  • Augment config naming and try_get_optimal_moe_config to handle down_moe scenarios
  Files:
    python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py
    python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
    python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py

Implement low-level Triton kernels for KV buffer operations
  • Add get_mla_kv_buffer_triton kernel and wrapper in memory_pool
  • Add concat_and_cast_mha_k_kernel and triton wrapper in attention utils
  • Integrate MLATokenToKVPool.get_mla_kv_buffer method
  Files:
    python/sglang/srt/mem_cache/memory_pool.py
    python/sglang/srt/layers/attention/utils.py

Add benchmark and tuning script for fused MoE Triton kernels
  • Introduce tuning_fused_moe_triton_sep.py under benchmark/kernels
  • Implement ray-based distributed tuning, search-space generation and performance logging
  Files:
    benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py


@justinSmileDate commented Oct 22, 2025

Sorry to bother you, is this optimization only for the DeepSeek-R1 model? Is the DeepSeek-V3 model also available?

@TianyuZhang1214 (Collaborator, Author) replied

> Sorry to bother you, is this optimization only for the DeepSeek-R1 model? Is the DeepSeek-V3 model also available?

Yes, V3 is also available.

@justinSmileDate replied

> Sorry to bother you, is this optimization only for the DeepSeek-R1 model? Is the DeepSeek-V3 model also available?
>
> Yes, V3 is also available.

Thanks for your reply. When I try to reproduce your work, it reports a problem with the DeepGEMM library. Could you share a link to the DeepGEMM library you use? In this repository, I tried the sbo.v2.public branch to load DeepSeek-V3 for inference, but it complains that the deep_gemm.get_compile_mode() function cannot be found; in fact, the sbo.v2.public branch does not register get_compile_mode(). Although the sbo.v2.sgl branch does register get_compile_mode(), it fails the basic tests/test_fp8.py test. Could you tell me the correct steps to use it?

@JoyFuture commented

Hello, I have successfully run it with your configuration, and the performance is very good. However, there is an issue: during operation, the SGLang logs are not printed. Even when I set --log-level debug, there are no SGLang-related logs, only some logs from NCCL, DeepGEMM, and the transfer engine. How can I configure it to print the SGLang logs properly?
