[Feature, Hardware] Enable DeepseekV3 on AMD GPUs#2601

Merged
HaiShaw merged 27 commits into sgl-project:main from BruceXcluding:main
Jan 3, 2025

Conversation

@BruceXcluding
Contributor

@BruceXcluding BruceXcluding commented Dec 26, 2024

Motivation

  • Support DeepseekV3 on AMD Instinct MI300X GPU

Modifications

  • Add proper fix for AMD FP8 e4m3fnuz to support DeepseekV3 FP8 model
  • Bypass the FlashInfer backend's bmm_fp8 by casting FP8 to BF16 in MLA
  • Add AMD triton stages config

TODO

How to run

build env

cd sglang/docker

docker build -t sglang-rocm:latest -f Dockerfile.rocm .
 
docker run -it --ipc=host \
    --cap-add=SYS_PTRACE \
    --network=host \
    --device=/dev/kfd --device=/dev/dri \
    --security-opt seccomp=unconfined \
    --group-add video \
    --privileged \
    -w /workspace sglang-rocm:latest

offline:

python -m sglang.bench_one_batch --batch-size 32 --input 128 --output 32 --model /data/DeepSeek-V3-Base/ --tp 8 --trust-remote-code

Warmup ...
Prefill. latency: 6.46569 s, throughput:    633.50 token/s
Decode.  latency: 2.58990 s, throughput:     12.36 token/s
Decode.  latency: 0.07421 s, throughput:    431.21 token/s
Decode.  latency: 0.07358 s, throughput:    434.90 token/s
Decode.  latency: 0.07341 s, throughput:    435.91 token/s
Decode.  latency: 0.07385 s, throughput:    433.30 token/s
Decode.  median latency: 0.07383 s, median throughput:    433.44 token/s
Total. latency:  9.498 s, throughput:    458.19 token/s
Benchmark ...
Prefill. latency: 0.54745 s, throughput:   7482.01 token/s
Decode.  latency: 0.07250 s, throughput:    441.41 token/s
Decode.  latency: 0.07399 s, throughput:    432.46 token/s
Decode.  latency: 0.07309 s, throughput:    437.84 token/s
Decode.  latency: 0.07335 s, throughput:    436.27 token/s
Decode.  latency: 0.07333 s, throughput:    436.38 token/s
Decode.  median latency: 0.07358 s, median throughput:    434.88 token/s
Total. latency:  2.828 s, throughput:   1810.38 token/s

server:

python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-Base --tp 8 --trust-remote-code

python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 2000 --num-shots 8

Accuracy: 0.950
Invalid: 0.000

Issues

  • If you get an error like raise OutOfResources(self.metadata.shared, max_shared, "shared memory"), the same as [Bug] Deepseek-v2-lite AMD MI300 run failed #2384:
    Solved in python/sglang/srt/layers/attention/triton_ops/decode_attention.py +410
  • If you get an error like ImportError: cannot import name 'build_regex_from_schema' from 'outlines.fsm.json_schema', the same as [Bug] SGLang v0.4.0 with AMD MI300X #2530:
    Solved by downgrading vllm
  • If you get an error like RuntimeError: [enforce fail at /app/pytorch/third_party/gloo/gloo/transport/tcp/device.cc:83] ifa != nullptr. Unable to find address for: eth0:
    Solved by checking your interface name with ifconfig and setting export GLOO_SOCKET_IFNAME=<your interface>
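For the last issue, a minimal recipe; the interface name ens51f0 below is just a placeholder, not a name from this PR:

```shell
# List the network interfaces on this machine (there may be no eth0):
ls /sys/class/net

# Tell gloo which one to use before launching sglang;
# replace ens51f0 with the name you found above:
export GLOO_SOCKET_IFNAME=ens51f0
```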

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

@zhyncs zhyncs added the bug and amd labels Dec 26, 2024
@carlushuang

@HaiShaw

@HaiShaw
Collaborator

HaiShaw commented Dec 26, 2024

@BruceXcluding Can we just add the fix to unlock v3 from the triton kernel config error first?

@zhyncs
Collaborator

zhyncs commented Dec 26, 2024

@BruceXcluding Can we just add the fix to unlock v3 from the triton kernel config error first?

That would be nice. I plan to release v0.4.1.post1 soon to enable users to use AMD MI300X initially.


@HaiShaw HaiShaw left a comment


@BruceXcluding
Some to address, thanks!

Comment thread docker/Dockerfile.rocm Outdated
ENV NCCL_MIN_NCHANNELS=112

ENV MOE_PADDING=1
ENV MOE_PADDING=0
Collaborator


We need to keep MOE_PADDING on for performance; the error it incurs is what we need to fix.

Comment thread docker/Dockerfile.rocm Outdated
logit_cap,
):
BLOCK = 32
BLOCK = 16 if is_hip() else 32
Collaborator


We should not cut BLOCK in half for HIP globally here.

Contributor Author


It doesn't work well in the latest vllm with BLOCK = 32.

Collaborator


We cannot take this part as is; it will cost all other models a large margin of performance.

# WEIGHT
weight_dtype = (
torch.float8_e4m3fn
torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
Collaborator


We should not have this, serialized weight is always OCP (torch.float8_e4m3fn)

Contributor Author


It would encounter the error python/sglang/srt/layers/quantization/fp8_kernel.py:176:33: error: Unsupported conversion from 'f8E4M3FN' to 'f16' at accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] with torch.float8_e4m3fn in w8a8_block_fp8_matmul.

Collaborator


Please check how normalize_e4m3fn_to_e4m3fnuz is used.
Basically, we do not expect a non-OCP dtype (anything other than e4m3fn) in the quantized model.

is_marlin: bool,
) -> Dict[str, int]:
if dtype == "fp8_w8a8":
if dtype == "fp8_w8a8" and not is_hip():
Collaborator


The following block isn't a blocker for HIP.

Comment thread python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py Outdated
Collaborator

@HaiShaw HaiShaw left a comment


@BruceXcluding
Also see this error below with your version of pyproject.toml:

  File "/dockerx/1226/HS/sglang/python/sglang/srt/constrained/outlines_backend.py", line 23, in <module>
    from outlines.fsm.json_schema import build_regex_from_schema
ImportError: cannot import name 'build_regex_from_schema' from 'outlines.fsm.json_schema' (/usr/local/lib/python3.12/dist-packages/outlines/fsm/json_schema.py)

@ZJLi2013

The CI failure (PR Test / unit-test-backend-2-gpu) used a lite model, 'deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct', which doesn't have the FP8 block-level quant feature.

@BruceXcluding BruceXcluding marked this pull request as ready for review December 27, 2024 05:56

if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
params_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
Collaborator


Same problem here; check out the previous usage of normalize_e4m3fn_to_e4m3fnuz.

@BruceXcluding BruceXcluding marked this pull request as draft December 31, 2024 01:36
@BruceXcluding
Contributor Author

@AdjectiveAllison we are targeting the FP8 accuracy issue; do you see garbled output with bf16 as well? We will tune performance with the provided config.json soon. Are you using MI308?

No, output with full bf16 works perfectly. I'm on an 8x MI300X machine, 192 GB of VRAM each.

@AdjectiveAllison Can you try with the latest instructions?

Comment thread python/sglang/srt/server.py Outdated
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
if "GLOO_SOCKET_IFNAME" not in os.environ:
os.environ["GLOO_SOCKET_IFNAME"] = "eth0"
# TODO(fix socket error with gpu backend)
Collaborator


Why is this commented out?

Contributor


Is this used for the CPU backend or a specific workstation? We get RuntimeError: [enforce fail at pytorch/third_party/gloo/gloo/transport/tcp/device.cc:83] ifa != nullptr. Unable to find address for: eth0

Collaborator


This is used for multi-node tensor parallelism. Instead of commenting it out, we suggest adding an is_hip flag.

Contributor


I think the value set for the GLOO_SOCKET_IFNAME environment variable should depend on the name of the network interface in each user's system and should not be hard-coded as eth0.

Collaborator


@wufann If the user's value is not eth0, they should specify it explicitly; this applies only when no setting is provided, with eth0 as the default.

Contributor


@zhyncs A different network interface name (e.g. "ens...") may be used. Also, users may test in a single-node environment where IB is not configured; in that case IB should be disabled.
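One way to implement the suggestions above (default only when unset, don't hard-code eth0) is sketched here; the function name is illustrative and this is not the actual sglang change:

```python
import os

def ensure_gloo_socket_ifname():
    """Set GLOO_SOCKET_IFNAME only if the user hasn't, preferring the
    first non-loopback interface over a hard-coded eth0.
    Linux-only sketch (reads /sys/class/net)."""
    if "GLOO_SOCKET_IFNAME" in os.environ:
        return os.environ["GLOO_SOCKET_IFNAME"]
    if not os.path.isdir("/sys/class/net"):
        return None
    for name in sorted(os.listdir("/sys/class/net")):
        if name != "lo":
            os.environ["GLOO_SOCKET_IFNAME"] = name
            return name
    return None
```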

Comment thread sgl-kernel/amd/CMakeLists.txt Outdated
@@ -0,0 +1,51 @@
cmake_minimum_required(VERSION 3.18)
Collaborator

@zhyncs zhyncs Dec 31, 2024


Please remove this, we only use CMakeLists.txt for clangd indexing, so it's not necessary.

Comment thread sgl-kernel/amd/pyproject.toml Outdated
build-backend = "setuptools.build_meta"

[project]
name = "sgl-kernel"
Collaborator


Can we refer to the setup of flash-attention or vllm compatible with NVIDIA and AMD?
https://github.com/Dao-AILab/flash-attention/blob/main/setup.py
https://github.com/vllm-project/vllm/blob/main/setup.py

@zhyncs
Collaborator

zhyncs commented Jan 2, 2025

Hi @BruceXcluding @HaiShaw
#2712
You can now try using moe_align_block_size_triton on AMD.

@BruceXcluding
Contributor Author

Hi @BruceXcluding @HaiShaw #2712 You can now try using moe_align_block_size_triton on AMD.

Tested and works well. We could build sgl-kernel-amd after we add CK kernels.

@HaiShaw
Collaborator

HaiShaw commented Jan 2, 2025

Hi @BruceXcluding @HaiShaw #2712 You can now try using moe_align_block_size_triton on AMD.

Tested and works well. We could build sgl-kernel-amd after we add ck kernels

@BruceXcluding, how was the performance compared to sgl-kernel-amd?

@zhyncs
Collaborator

zhyncs commented Jan 2, 2025

Hi @BruceXcluding @HaiShaw
Before releasing v0.4.1.post4 #2713, I hope the main branch has a version compatible with AMD MI300X. What minimal changes are needed to achieve this? The requirement is just to get it running, performance optimization can be done later.

@zhyncs zhyncs marked this pull request as ready for review January 2, 2025 18:42
@HaiShaw
Collaborator

HaiShaw commented Jan 2, 2025

@zhyncs I am expecting @BruceXcluding to do the final update.
@BruceXcluding can you confirm the decode_attention.py change?

Collaborator

@HaiShaw HaiShaw left a comment


@BruceXcluding thanks!

@HaiShaw HaiShaw dismissed merrymercy’s stale review January 3, 2025 00:23

had been addressed above

@HaiShaw HaiShaw merged commit c7ae474 into sgl-project:main Jan 3, 2025
@BruceXcluding
Contributor Author

Hi @BruceXcluding @HaiShaw Before releasing v0.4.1.post4 #2713, I hope the main branch has a version compatible with AMD MI300X. What minimal changes are needed to achieve this? The requirement is just to get it running, performance optimization can be done later.

Thanks @zhyncs @HaiShaw. We will keep the TODO list on track for performance improvements.

XiaotongJiang pushed a commit to XiaotongJiang/sglang that referenced this pull request Jan 3, 2025
Co-authored-by: root <root@banff-cyxtera-s83-5.amd.com>
Co-authored-by: HAI <hixiao@gmail.com>
Co-authored-by: Bruce Xue <yigex@xilinx.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
@yiakwy-xpu-ml-framework-team
Contributor

yiakwy-xpu-ml-framework-team commented Jan 3, 2025

Hi @BruceXcluding @HaiShaw Before releasing v0.4.1.post4 #2713, I hope the main branch has a version compatible with AMD MI300X. What minimal changes are needed to achieve this? The requirement is just to get it running, performance optimization can be done later.

Thanks @zhyncs @HaiShaw. we will keep the TODO list on track for performance improvement.

Yes, the theoretical throughput is

4800 (memory transaction speed) / 37 * 1.8 (MTP multiplier) ≈ 233 tok/GPU/sec, around 1868 tok/sec for 8 cards.

There is room to improve.
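The arithmetic above works out as follows; the inputs are taken from the comment, not from measurement:

```python
bandwidth = 4800        # memory transaction speed, from the comment
model_factor = 37       # divisor used in the comment
mtp_multiplier = 1.8    # multi-token prediction multiplier

per_gpu = bandwidth / model_factor * mtp_multiplier
print(int(per_gpu))      # ~233 tok/gpu/sec
print(int(per_gpu * 8))  # ~1868 tok/sec across 8 GPUs
```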

timethink pushed a commit to timethink/sglang that referenced this pull request Mar 9, 2025
Co-authored-by: root <root@banff-cyxtera-s83-5.amd.com>
Co-authored-by: HAI <hixiao@gmail.com>
Co-authored-by: Bruce Xue <yigex@xilinx.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>

Labels

amd, bug (Something isn't working), high priority
