Commit eae76db

init moe interface
1 parent 5036ca3 commit eae76db

File tree: 6 files changed, +220 -11 lines

include/sgl_kernel_ops.h

Lines changed: 1 addition & 1 deletion

@@ -255,7 +255,7 @@ void fp8_blockwise_scaled_grouped_mm(
     const torch::Tensor& expert_offsets,
     const torch::Tensor& workspace);

-void moe_grouped_mm_nt(
+void moe_grouped_mm_nn(
     torch::Tensor& output,
     const torch::Tensor& activations,
     const torch::Tensor& weights,

python/sgl_kernel/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -51,9 +51,10 @@
     apply_shuffle_mul_sum,
     cutlass_fp4_group_mm,
     fp8_blockwise_scaled_grouped_mm,
+    fused_experts,
     moe_align_block_size,
+    moe_align_block_size_impl,
     moe_fused_gate,
-    moe_grouped_mm_nt,
     moe_sum,
     moe_sum_reduce,
     prepare_moe_input,

python/sgl_kernel/moe.py

Lines changed: 208 additions & 3 deletions

@@ -3,7 +3,7 @@
 import torch


-def moe_align_block_size(
+def moe_align_block_size_impl(
     topk_ids,
     num_experts,
     block_size,
@@ -25,6 +25,43 @@ def moe_align_block_size(
     )


+def moe_align_block_size(
+    topk_ids,
+    num_experts,
+    block_size,
+    pad_sorted_token_ids=False,
+):
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
+
+    sorted_ids_xpu = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    )
+    if not pad_sorted_token_ids:
+        sorted_ids_xpu.fill_(topk_ids.numel())
+    max_num_m_blocks = max_num_tokens_padded // block_size
+    expert_ids_xpu = torch.zeros(
+        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+    )
+    num_tokens_post_pad_xpu = torch.empty(
+        (1,), dtype=torch.int32, device=topk_ids.device
+    )
+    cumsum_buffer = torch.empty(
+        num_experts + 2, dtype=torch.int32, device=topk_ids.device
+    )
+    moe_align_block_size_impl(
+        topk_ids,
+        num_experts + 1,
+        block_size,
+        sorted_ids_xpu,
+        expert_ids_xpu,
+        num_tokens_post_pad_xpu,
+        cumsum_buffer,
+        pad_sorted_token_ids,
+    )
+
+    return sorted_ids_xpu, expert_ids_xpu, num_tokens_post_pad_xpu
+
+
 def topk_softmax(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
@@ -219,7 +256,7 @@ def cutlass_fp4_group_mm(
     return c.to(dtype=out_dtype)


-def moe_grouped_mm_nt(activations, weights, total_rows_for_experts, n_experts):
+def moe_grouped_mm_nn(activations, weights, total_rows_for_experts, n_experts):
     """
     BF16/FP16 grouped GEMM for MoE with non-transposed weights.
     activations: (total_tokens, hidden_dim)
@@ -233,7 +270,175 @@ def moe_grouped_mm_nt(activations, weights, total_rows_for_experts, n_experts):
         device=activations.device,
         dtype=activations.dtype,
     )
-    torch.ops.sgl_kernel.moe_grouped_mm_nt(
+    torch.ops.sgl_kernel.moe_grouped_mm_nn(
         output, activations, weights, total_rows_for_experts, n_experts
     )
     return output
+
+
+def fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
+    inplace: bool = False,
+    activation: str = "silu",
+    use_fp8_w8a8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+    no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
+    gemm1_alpha: Optional[float] = None,
+    gemm1_limit: Optional[float] = None,
+) -> torch.Tensor:
+    """
+    Compute a Mixture of Experts (MoE) layer using two sets of expert
+    weights, w1 and w2, and a top-k gating mechanism.
+
+    Parameters:
+    - hidden_states [num_tokens, hidden_dim] (torch.Tensor): The input tensor to the MoE layer.
+    - w1 [num_experts, hidden_dim, 2 * intermediate_dim] (torch.Tensor): The fused gate/up expert weights.
+    - w2 [num_experts, intermediate_dim, hidden_dim] (torch.Tensor): The down-projection expert weights.
+    - topk_weights [num_tokens, topk] (torch.Tensor): The routing weights of the selected experts.
+    - topk_ids [num_tokens, topk] (torch.Tensor): The indices of the selected experts.
+    - b1 (Optional[torch.Tensor]): Optional bias for w1.
+    - b2 (Optional[torch.Tensor]): Optional bias for w2.
+    - inplace (bool): If True, perform operations in-place to save memory. Defaults to False.
+    - activation (str): The activation function to use ('silu' or 'gelu'). Defaults to 'silu'.
+    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
+      products for w1 and w2. Defaults to False.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
+    - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
+    - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
+    - block_shape (Optional[List[int]]): Optional block size for block-wise
+      quantization.
+    - no_combine (bool): If True, skip the combine step. Defaults to False.
+    - routed_scaling_factor (Optional[float]): Optional scaling factor for routed tokens, used by Llama4 only.
+    - gemm1_alpha (Optional[float]): Optional gemm1_alpha for the activation
+      function.
+    - gemm1_limit (Optional[float]): Optional gemm1_limit for the swiglu activation
+      function.
+
+    Returns:
+    - torch.Tensor: The output tensor after applying the MoE layer.
+    """
+
+    assert use_fp8_w8a8 is False, "current MoE does not support use_fp8_w8a8"
+    assert w1_scale is None, "current MoE does not support w1_scale"
+    assert w2_scale is None, "current MoE does not support w2_scale"
+    assert a1_scale is None, "current MoE does not support a1_scale"
+    assert a2_scale is None, "current MoE does not support a2_scale"
+    assert block_shape is None, "current MoE does not support block_shape"
+
+    # type check
+    assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
+    assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
+    assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"
+
+    # Shape check
+    assert hidden_states.ndim == 2, "hidden_states must be 2D"
+    assert (
+        hidden_states.shape[-1] == w1.shape[-2]
+    ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
+    assert (
+        2 * w2.shape[1] == w1.shape[2]
+    ), f"w2 shape[1] {w2.shape[1]} must be half of w1 shape[2] {w1.shape[2]}"
+    assert (topk_ids.shape == topk_weights.shape) and (
+        topk_ids.shape[0] == hidden_states.shape[0]
+    ), f"topk_ids shape {topk_ids.shape} and topk_weights shape {topk_weights.shape} must be equal and match hidden_states shape[0] {hidden_states.shape[0]}"
+
+    num_tokens, _ = hidden_states.shape
+    E, K, _ = w1.shape
+    _, N, OutK = w2.shape
+
+    M = num_tokens
+    TopK = topk_ids.shape[1]
+
+    # One flat buffer backs both grouped-GEMM outputs; cache1 is no longer
+    # needed once the activation has been applied, so cache3 reuses its storage.
+    cache = torch.empty(
+        M * TopK * max(2 * N, OutK),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache1 = cache[: M * TopK * 2 * N].view(M * TopK, 2 * N)
+    intermediate_cache2 = torch.empty(
+        (M * TopK, N),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache3 = cache[: M * TopK * OutK].view(M * TopK, OutK)
+
+    if no_combine:
+        assert not inplace
+        out_hidden_states = torch.empty(
+            (num_tokens, OutK),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+    elif inplace:
+        out_hidden_states = hidden_states
+    else:
+        out_hidden_states = torch.zeros_like(hidden_states)
+
+    # Sort the flattened (token, expert) assignments by expert id so that each
+    # expert's rows are contiguous for the grouped GEMM.
+    idxs = topk_ids.view(-1).argsort()
+    counts = topk_ids.view(-1).to(torch.long).bincount(minlength=E).cpu().numpy()
+    tokens_per_expert = counts.cumsum()
+    num_per_tok = TopK
+    token_idxs = idxs // num_per_tok
+    offset = []
+    input_A = torch.empty(
+        (num_tokens * TopK, K), device=hidden_states.device, dtype=hidden_states.dtype
+    )
+    for expert_id, end_idx in enumerate(tokens_per_expert):
+        start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1]
+        offset.append(end_idx - start_idx)
+        if start_idx == end_idx:
+            continue
+        exp_token_idxs = token_idxs[start_idx:end_idx]
+        input_A[start_idx:end_idx, :].copy_(hidden_states[exp_token_idxs])
+    offset = torch.tensor(offset, device="cpu", dtype=torch.int32)
+
+    torch.ops.sgl_kernel.moe_grouped_mm_nn(
+        intermediate_cache1,
+        input_A,
+        w1,
+        offset,
+        E,
+    )
+
+    gate, up_ = torch.split(intermediate_cache1, N, dim=1)
+    act = torch.nn.SiLU()
+    intermediate_cache2 = act(gate) * up_
+
+    torch.ops.sgl_kernel.moe_grouped_mm_nn(
+        intermediate_cache3,
+        intermediate_cache2.contiguous(),
+        w2,
+        offset,
+        E,
+    )
+    for expert_id, end_idx in enumerate(tokens_per_expert):
+        start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1]
+        if start_idx == end_idx:
+            continue
+
+        exp_token_idxs = token_idxs[start_idx:end_idx]
+        expert_out = intermediate_cache3[start_idx:end_idx]
+        expert_out.mul_(topk_weights.view(-1)[idxs[start_idx:end_idx]].unsqueeze(-1))
+        out_hidden_states.scatter_reduce_(
+            0, exp_token_idxs.view(-1, 1).repeat(1, OutK), expert_out, reduce="sum"
+        )
+
+    return out_hidden_states
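
For reference, a minimal call into the new fused_experts wrapper could look like the sketch below. The torch.topk-based routing, all shapes and sizes, and the XPU device are illustrative assumptions, not values from this commit.

import torch
from sgl_kernel import fused_experts

num_tokens, hidden_dim, inter_dim, num_experts, topk = 16, 256, 512, 8, 2
device = "xpu"

hidden_states = torch.randn(num_tokens, hidden_dim, dtype=torch.bfloat16, device=device)
# w1 fuses the gate and up projections: [num_experts, hidden_dim, 2 * inter_dim]
w1 = torch.randn(num_experts, hidden_dim, 2 * inter_dim, dtype=torch.bfloat16, device=device)
# w2 is the down projection: [num_experts, inter_dim, hidden_dim]
w2 = torch.randn(num_experts, inter_dim, hidden_dim, dtype=torch.bfloat16, device=device)

# Toy routing for the example: softmax over random logits, then top-k.
router_probs = torch.randn(num_tokens, num_experts, device=device).softmax(dim=-1)
topk_weights, topk_ids = torch.topk(router_probs, k=topk, dim=-1)

out = fused_experts(hidden_states, w1, w2, topk_weights.to(torch.bfloat16), topk_ids)
print(out.shape)  # (num_tokens, hidden_dim)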

src/sycl/GroupGemm.cpp

Lines changed: 2 additions & 2 deletions

@@ -171,7 +171,7 @@ struct MoERunner {
   }
 };

-void moe_grouped_mm_nt(
+void moe_grouped_mm_nn(
     torch::Tensor& output,
     const torch::Tensor& activations,
     const torch::Tensor& weights,
@@ -195,7 +195,7 @@ void moe_grouped_mm_nt(
       activations.scalar_type() == weights.scalar_type(), "activations and weights must have the same data type");
   TORCH_CHECK(
       activations.scalar_type() == at::ScalarType::Half || activations.scalar_type() == at::ScalarType::BFloat16,
-      "Only float16 and bfloat16 are supported in moe_grouped_mm_nt");
+      "Only float16 and bfloat16 are supported in moe_grouped_mm_nn");

   if (activations.scalar_type() == at::ScalarType::BFloat16) {
     auto stream = at::xpu::getCurrentXPUStream();

src/torch_extension_sycl.cc

Lines changed: 2 additions & 2 deletions

@@ -63,9 +63,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("moe_sum", torch::kXPU, &moe_sum);

   m.def(
-      "moe_grouped_mm_nt(Tensor output, Tensor activations, Tensor weights, Tensor total_rows_for_experts, int "
+      "moe_grouped_mm_nn(Tensor output, Tensor activations, Tensor weights, Tensor total_rows_for_experts, int "
       "n_experts) -> ()");
-  m.impl("moe_grouped_mm_nt", torch::kXPU, &moe_grouped_mm_nt);
+  m.impl("moe_grouped_mm_nn", torch::kXPU, &moe_grouped_mm_nn);

   // m.def(
   //     "fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype,

tests/test_moe_align.py

Lines changed: 5 additions & 2 deletions

@@ -4,7 +4,7 @@
 import torch
 import triton
 import triton.language as tl
-from sgl_kernel import moe_align_block_size, moe_sum
+from sgl_kernel import moe_align_block_size_impl, moe_sum


 def ceil_div(a, b):
@@ -180,7 +180,7 @@ def test_moe_align_block_size_compare_implementations(
     expert_ids_triton = torch.zeros_like(expert_ids_xpu)
     num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_xpu)

-    moe_align_block_size(
+    moe_align_block_size_impl(
         topk_ids,
         num_experts + 1,
         block_size,
@@ -233,6 +233,9 @@ def test_moe_align_block_size_compare_implementations(
         block_sorted_start:block_sorted_end
     ].sort()[0]

+    import pdb
+
+    pdb.set_trace()
     assert torch.allclose(
         selected_sorted_ids_xpu,
         selected_sorted_ids_triton,
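
The test above drives the low-level moe_align_block_size_impl with caller-provided buffers, while the new moe_align_block_size wrapper in moe.py allocates those buffers itself. Below is a minimal sketch of the wrapper, assuming an XPU build of sgl_kernel; the sizes and int32 topk_ids dtype are illustrative assumptions.

import torch
from sgl_kernel import moe_align_block_size

num_tokens, topk, num_experts, block_size = 16, 2, 8, 64
topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="xpu")

sorted_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
    topk_ids, num_experts, block_size
)
# sorted_ids has room for topk_ids.numel() + (num_experts + 1) * (block_size - 1) entries;
# unused slots are filled with topk_ids.numel() when pad_sorted_token_ids is False.
print(sorted_ids.shape, expert_ids.shape, num_tokens_post_pad.item())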
