
Commit 42135d6

[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (vllm-project#32414)
1 parent: e14467b

82 files changed: 2699 additions & 1552 deletions

Note: large commits hide some content by default, so only a subset of the 82 changed files is shown below.

.buildkite/test-pipeline.yaml

Lines changed: 40 additions & 0 deletions
@@ -634,6 +634,46 @@ steps:
   - pip install helion
   - pytest -v -s kernels/helion/
 
+
+- label: Kernels FP8 MoE Test (1 H100)
+  timeout_in_minutes: 90
+  gpu: h100
+  num_gpus: 1
+  optional: true
+  commands:
+  - pytest -v -s kernels/moe/test_cutlass_moe.py
+  - pytest -v -s kernels/moe/test_flashinfer.py
+  - pytest -v -s kernels/moe/test_gpt_oss_triton_kernels.py
+  - pytest -v -s kernels/moe/test_modular_oai_triton_moe.py
+  - pytest -v -s kernels/moe/test_moe.py
+  # - pytest -v -s kernels/moe/test_block_fp8.py - failing on main
+  - pytest -v -s kernels/moe/test_block_int8.py
+  - pytest -v -s kernels/moe/test_triton_moe_no_act_mul.py
+  - pytest -v -s kernels/moe/test_triton_moe_ptpc_fp8.py
+
+- label: Kernels FP8 MoE Test (2 H100s)
+  timeout_in_minutes: 90
+  gpu: h100
+  num_gpus: 2
+  optional: true
+  commands:
+  - pytest -v -s kernels/moe/test_deepep_deepgemm_moe.py
+  - pytest -v -s kernels/moe/test_deepep_moe.py
+  - pytest -v -s kernels/moe/test_pplx_cutlass_moe.py
+  # - pytest -v -s kernels/moe/test_pplx_moe.py - failing on main
+
+- label: Kernels Fp4 MoE Test (B200)
+  timeout_in_minutes: 60
+  gpu: b200
+  num_gpus: 1
+  optional: true
+  commands:
+  - pytest -v -s kernels/moe/test_cutedsl_moe.py
+  - pytest -v -s kernels/moe/test_flashinfer_moe.py
+  - pytest -v -s kernels/moe/test_nvfp4_moe.py
+  - pytest -v -s kernels/moe/test_ocp_mx_moe.py
+
+
 - label: Model Executor Test # 23min
   timeout_in_minutes: 35
   torch_nightly: true

benchmarks/kernels/benchmark_cutlass_moe_fp8.py

Lines changed: 7 additions & 5 deletions
@@ -9,6 +9,7 @@
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from tests.kernels.moe.utils import make_dummy_moe_config
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
 from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
@@ -138,12 +139,13 @@ def bench_run(
     fn = mk.FusedMoEModularKernel(
         MoEPrepareAndFinalizeNoEP(),
         CutlassExpertsFp8(
-            out_dtype=a.dtype,
-            e=num_experts,
-            n=n,
-            k=k,
+            moe_config=make_dummy_moe_config(
+                num_experts=num_experts,
+                hidden_dim=k,
+                intermediate_size_per_partition=n,
+                in_dtype=a.dtype,
+            ),
             quant_config=quant_config,
-            device=w1.device,
         ),
     )
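This is the recurring pattern of the refactor: the kernel's loose shape and dtype arguments (e, n, k, out_dtype, device) are folded into a single MoE config, built in the benchmarks via the make_dummy_moe_config test helper. A minimal sketch of the mapping, using placeholder sizes rather than the benchmark's real shapes:

import torch

from tests.kernels.moe.utils import make_dummy_moe_config

# Placeholder problem sizes; the benchmark derives these from its MoE weight shapes.
cfg = make_dummy_moe_config(
    num_experts=8,                          # was `e`
    hidden_dim=4096,                        # was `k`
    intermediate_size_per_partition=14336,  # was `n`
    in_dtype=torch.bfloat16,                # was `out_dtype`; `device` is no longer passed
)
# `cfg` is then handed to CutlassExpertsFp8 as `moe_config=` in place of the old arguments.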

benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py

Lines changed: 3 additions & 4 deletions
@@ -12,6 +12,7 @@
 import torch.utils.benchmark as benchmark
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from tests.kernels.moe.utils import make_dummy_moe_config
 from vllm import _custom_ops as ops
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe.config import (
@@ -198,8 +199,7 @@ def run_cutlass_moe_fp4(
     kernel = mk.FusedMoEModularKernel(
         MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
         CutlassExpertsFp4(
-            out_dtype=dtype,
-            max_experts_per_worker=e,
+            make_dummy_moe_config(),
             quant_config=quant_config,
         ),
     )
@@ -244,8 +244,7 @@ def run_cutlass_from_graph(
     kernel = mk.FusedMoEModularKernel(
         MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
         CutlassExpertsFp4(
-            out_dtype=dtype,
-            max_experts_per_worker=e,
+            make_dummy_moe_config(),
             quant_config=quant_config,
         ),
     )
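For the NVFP4 path the kernel no longer takes any shape information at construction time: a default dummy config is passed positionally, and out_dtype / max_experts_per_worker are dropped. A minimal sketch, assuming quant_config is an NVFP4 MoE quant config built as elsewhere in the benchmark:

from tests.kernels.moe.utils import make_dummy_moe_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp4

# Sketch only: the MoE config is now the first positional argument.
experts = CutlassExpertsFp4(
    make_dummy_moe_config(),    # defaults are sufficient for this benchmark path
    quant_config=quant_config,  # assumed: a pre-built NVFP4 quant config
)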

benchmarks/kernels/benchmark_grouped_gemm_cutlass.py

Lines changed: 13 additions & 12 deletions
@@ -6,6 +6,7 @@
 from benchmark_shapes import WEIGHT_SHAPES_MOE
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from tests.kernels.moe.utils import make_dummy_moe_config
 from vllm import _custom_ops as ops
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
@@ -134,13 +135,13 @@ def run_cutlass_moe(
     fn = mk.FusedMoEModularKernel(
         MoEPrepareAndFinalizeNoEP(),
         CutlassExpertsFp8(
-            out_dtype=a.dtype,
-            # NOTE(rob): w2 is shaped as [E, hidden, intermediate]
-            e=w2.shape[0],
-            n=w2.shape[2],
-            k=w2.shape[1],
+            moe_config=make_dummy_moe_config(
+                num_experts=w2.shape[0],
+                hidden_dim=w2.shape[1],
+                intermediate_size_per_partition=w2.shape[2],
+                in_dtype=a.dtype,
+            ),
             quant_config=quant_config,
-            device=w1.device,
         ),
     )
 
@@ -166,13 +167,13 @@ def run_cutlass_from_graph(
     fn = mk.FusedMoEModularKernel(
         MoEPrepareAndFinalizeNoEP(),
         CutlassExpertsFp8(
-            out_dtype=a.dtype,
-            # NOTE(rob): w2 is shaped as [E, hidden, intermediate]
-            e=w2.shape[0],
-            n=w2.shape[2],
-            k=w2.shape[1],
+            moe_config=make_dummy_moe_config(
+                num_experts=w2.shape[0],
+                hidden_dim=w2.shape[1],
+                intermediate_size_per_partition=w2.shape[2],
+                in_dtype=a.dtype,
+            ),
             quant_config=quant_config,
-            device=w1.device,
         ),
     )

benchmarks/kernels/benchmark_moe.py

Lines changed: 29 additions & 1 deletion
@@ -16,10 +16,16 @@
 from ray.experimental.tqdm_ray import tqdm
 
 from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig,
+    FusedMoEParallelConfig,
     FusedMoEQuantConfig,
+    RoutingMethodType,
     _get_config_dtype_str,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import *
+from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+    TritonOrDeepGemmExperts,
+)
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import get_config
 from vllm.triton_utils import triton
@@ -194,10 +200,33 @@ def run():
             block_shape=block_quant_shape,
         )
 
+        deep_gemm_experts = mk.FusedMoEModularKernel(
+            prepare_finalize=MoEPrepareAndFinalizeNoEP(),
+            fused_experts=TritonOrDeepGemmExperts(
+                moe_config=FusedMoEConfig(
+                    num_experts=num_experts,
+                    experts_per_token=topk,
+                    hidden_dim=hidden_size,
+                    intermediate_size_per_partition=shard_intermediate_size,
+                    num_local_experts=num_experts,
+                    activation="silu",
+                    parallel_config=FusedMoEParallelConfig.make_no_parallel(),
+                    in_dtype=init_dtype,
+                    routing_method=RoutingMethodType.TopK,
+                ),
+                quant_config=quant_config,
+            ),
+        )
+
         with override_config(config):
             topk_weights, topk_ids, token_expert_indices = fused_topk(
                 x, input_gating, topk, renormalize=not use_deep_gemm
             )
+
+            if use_deep_gemm:
+                return deep_gemm_experts(
+                    x, w1, w2, topk_weights, topk_ids, inplace=True
+                )
             return fused_experts(
                 x,
                 w1,
@@ -206,7 +235,6 @@ def run():
                 topk_ids,
                 inplace=True,
                 quant_config=quant_config,
-                allow_deep_gemm=use_deep_gemm,
             )
 
         # JIT compilation & warmup
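The substantive change in this benchmark is that fused_experts no longer accepts allow_deep_gemm; the DeepGEMM path is driven instead through the explicitly constructed modular kernel shown in the diff above. A condensed before/after sketch (variable names follow the benchmark; this is illustrative, not a drop-in snippet):

# Before this commit: DeepGEMM was selected inside fused_experts via a flag.
out = fused_experts(
    x, w1, w2, topk_weights, topk_ids,
    inplace=True,
    quant_config=quant_config,
    allow_deep_gemm=use_deep_gemm,  # removed by this commit
)

# After this commit: the caller builds a FusedMoEModularKernel wrapping
# TritonOrDeepGemmExperts (see the diff above) and invokes it directly.
if use_deep_gemm:
    out = deep_gemm_experts(x, w1, w2, topk_weights, topk_ids, inplace=True)
else:
    out = fused_experts(
        x, w1, w2, topk_weights, topk_ids,
        inplace=True,
        quant_config=quant_config,
    )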

docs/design/moe_kernel_features.md

Lines changed: 2 additions & 2 deletions
@@ -85,10 +85,10 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
 |--------|-------------------|--------------|---------------|---------------------|-----------------------|---------|--------|
 | triton | standard | all<sup>1</sup> | G,A,T | silu, gelu,</br>swigluoai,</br>silu_no_mul,</br>gelu_no_mul | Y | Y | [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts],</br>[`TritonExperts`][vllm.model_executor.layers.fused_moe.fused_moe.TritonExperts] |
 | triton (batched) | batched | all<sup>1</sup> | G,A,T | silu, gelu | <sup>6</sup> | Y | [`BatchedTritonExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedTritonExperts] |
-| deep gemm | standard,</br>batched | fp8 | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],</br>[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],</br>[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] |
+| deep gemm | standard,</br>batched | fp8 | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],</br>[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] |
 | cutlass_fp4 | standard,</br>batched | nvfp4 | A,T | silu | Y | Y | [`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] |
 | cutlass_fp8 | standard,</br>batched | fp8 | A,T | silu, gelu | Y | Y | [`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],</br>[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
-| flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
+| flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
 | gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
 | marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
 | trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |

tests/compile/test_fusion_attn.py

Lines changed: 3 additions & 3 deletions
@@ -43,7 +43,7 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     kFp8StaticTensorSym,
-    kNvfp4Quant,
+    kNvfp4Dynamic,
 )
 from vllm.platforms import current_platform
 from vllm.utils.flashinfer import has_flashinfer
@@ -215,7 +215,7 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
 class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
     """Test model for AttentionNvfp4QuantPattern fusion."""
 
-    quant_key = kNvfp4Quant
+    quant_key = kNvfp4Dynamic
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -468,7 +468,7 @@ def test_attention_quant_pattern(
 
     # Note: for fp8, fully_replaced=False because query quant ops remain in graph.
     # Only output quant ops are fused into attention.
-    test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Quant)
+    test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic)
 
     # access the underlying `AttnFusionPass` on the `LazyInitPass`
     assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)

tests/compile/test_silu_mul_quant_fusion.py

Lines changed: 3 additions & 3 deletions
@@ -44,7 +44,7 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
     kFp8StaticTensorSym,
-    kNvfp4Quant,
+    kNvfp4Dynamic,
 )
 from vllm.platforms import current_platform
 
@@ -134,11 +134,11 @@ def forward(self, x):
     def ops_in_model_before(self):
         return [
             SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
-            QUANT_OPS[kNvfp4Quant],
+            QUANT_OPS[kNvfp4Dynamic],
         ]
 
     def ops_in_model_after(self):
-        return [FUSED_OPS[kNvfp4Quant]]
+        return [FUSED_OPS[kNvfp4Dynamic]]
 
 
 class TestSiluMulGroupFp8QuantModel(torch.nn.Module):

tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml

Lines changed: 3 additions & 0 deletions
@@ -3,3 +3,6 @@ accuracy_threshold: 0.92
 num_questions: 1319
 num_fewshot: 5
 server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
+env:
+  VLLM_USE_FLASHINFER_MOE_FP8: "0"
+  VLLM_USE_DEEP_GEMM: "0"

tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml

Lines changed: 0 additions & 1 deletion
@@ -6,4 +6,3 @@ server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enab
 env:
   VLLM_USE_DEEP_GEMM: "1"
   VLLM_USE_DEEP_GEMM_MOE: "1"
-  VLLM_USE_DEEP_GEMM_E8M0: "0"

0 commit comments
