This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and a top-k gating mechanism.

Parameters:
- hidden_states [num_tokens, hidden_dim] (torch.Tensor): The input tensor to the MoE layer.
- w1 [num_experts, hidden_dim, output_channel] (torch.Tensor): The first set of expert weights.
- w2 [num_experts, output_channel, hidden_dim] (torch.Tensor): The second set of expert weights.
- topk_weights [num_tokens, topk] (torch.Tensor): The routing weights of the top-k selected experts.
- topk_ids [num_tokens, topk] (torch.Tensor): The indices of the top-k selected experts.
- b1 (Optional[torch.Tensor]): Optional bias for w1.
- b2 (Optional[torch.Tensor]): Optional bias for w2.
- inplace (bool): If True, perform operations in-place to save memory. Defaults to False.
- activation (str): The activation function to use ('silu' or 'gelu'). Defaults to 'silu'.
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
- block_shape (Optional[List[int]]): Optional block size for block-wise quantization.
- no_combine (bool): If True, skip the combine step. Defaults to False.
- routed_scaling_factor (Optional[float]): Optional scaling factor for routed tokens, used by Llama4 only.
- gemm1_alpha (Optional[float]): Optional alpha parameter for the activation function.
- gemm1_limit (Optional[float]): Optional limit parameter for the swiglu activation function.

Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""

assert use_fp8_w8a8 is False, "current MoE does not support use_fp8_w8a8"
assert w1_scale is None, "current MoE does not support w1_scale"
assert w2_scale is None, "current MoE does not support w2_scale"
assert a1_scale is None, "current MoE does not support a1_scale"
assert a2_scale is None, "current MoE does not support a2_scale"
assert block_shape is None, "current MoE does not support block_shape"

# Type check
assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"

# Shape check
assert hidden_states.ndim == 2, "hidden_states must be 2D"
assert (
    hidden_states.shape[-1] == w1.shape[-2]
), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
assert (
    2 * w2.shape[1] == w1.shape[2]
), f"w2 shape[1] {w2.shape[1]} must be half of w1 shape[2] {w1.shape[2]}"
assert (topk_ids.shape == topk_weights.shape) and (
    topk_ids.shape[0] == hidden_states.shape[0]
), f"topk_ids shape {topk_ids.shape} and topk_weights shape {topk_weights.shape} must be equal and match hidden_states shape[0] {hidden_states.shape[0]}"