diff --git a/backends/intel_hpu/custom_ops/llama_infer/fused_gate_moe.cc b/backends/intel_hpu/custom_ops/llama_infer/fused_gate_moe.cc index 13d293b9f2..161641a1ef 100644 --- a/backends/intel_hpu/custom_ops/llama_infer/fused_gate_moe.cc +++ b/backends/intel_hpu/custom_ops/llama_infer/fused_gate_moe.cc @@ -416,8 +416,7 @@ void FusedGateMoeKernel( const phi::DenseTensor& gate_up_weights, const phi::DenseTensor& down_weights, const paddle::optional& hidden_states_scales, - const paddle::optional>& - intermediate_hidden_states_scales, + const paddle::optional& intermediate_hidden_states_scales, const paddle::optional& gate_up_weights_scales, const paddle::optional& down_weights_scales, phi::DenseTensor* final_hidden_states, @@ -468,9 +467,7 @@ void FusedGateMoeKernel( ct.AddN(down_weights); if (intermediate_hidden_states_scales) { - for (const auto& t : intermediate_hidden_states_scales.get()) { - ct.Add(t); - } + ct.AddN(intermediate_hidden_states_scales.get()); } if (gate_up_weights_scales) { ct.AddN(gate_up_weights_scales.get()); @@ -514,8 +511,7 @@ void CallFusedGateMoeKernel( const phi::DenseTensor& gate_up_weights, const phi::DenseTensor& down_weights, const paddle::optional& hidden_states_scales, - const paddle::optional>& - intermediate_hidden_states_scales, + const paddle::optional& intermediate_hidden_states_scales, const paddle::optional& gate_up_weights_scales, const paddle::optional& down_weights_scales, phi::DenseTensor* final_hidden_states, @@ -634,7 +630,7 @@ std::vector FusedGateMoeForward( *gate_up_weights_tensor, *down_weights_tensor, paddle::optional(), /* hidden_states_scale */ - paddle::optional>(), /* intermediate */ + paddle::optional(), /* intermediate */ paddle::optional(), /* gate_up_weights_scales */ paddle::optional(), /* down_weights_scales */ final_hidden_states.get(), @@ -660,8 +656,7 @@ std::vector FusedGateMoeFP8Forward( const paddle::Tensor& gate_up_weights, const paddle::Tensor& down_weights, const paddle::optional& hidden_states_scales, - const paddle::optional>& - intermediate_hidden_states_scales, + const paddle::optional& intermediate_hidden_states_scales, const paddle::Tensor& gate_up_weights_scales, const paddle::Tensor& down_weights_scales, const int top_k, @@ -701,14 +696,16 @@ std::vector FusedGateMoeFP8Forward( paddle::optional(*hidden_states_scales_dt); } + auto intermediate_hidden_states_scales_tensor = + paddle::optional(); bool dynamic_scale = true; - std::vector scales_vec; if (intermediate_hidden_states_scales) { dynamic_scale = false; - for (const auto& t : intermediate_hidden_states_scales.get()) { - scales_vec.push_back( - *static_cast(t.impl().get())); - } + auto intermediate_hidden_states_scales_dt = static_cast( + intermediate_hidden_states_scales->impl().get()); + intermediate_hidden_states_scales_tensor = + paddle::optional( + *intermediate_hidden_states_scales_dt); } auto gate_up_weights_scales_tensor = paddle::optional(); auto gate_up_weights_scales_dt = @@ -735,7 +732,7 @@ std::vector FusedGateMoeFP8Forward( *gate_up_weights_tensor, *down_weights_tensor, hidden_states_scales_tensor, - scales_vec, + intermediate_hidden_states_scales_tensor, gate_up_weights_scales_tensor, down_weights_scales_tensor, final_hidden_states.get(), @@ -816,7 +813,7 @@ std::vector FusedGateMoeBlockWiseFP8Forward( *gate_up_weights_tensor, *down_weights_tensor, paddle::optional(), /* hidden_states_scale */ - paddle::optional>(), /* intermediate */ + paddle::optional(), /* intermediate */ gate_up_weights_scales_tensor, down_weights_scales_tensor, 
final_hidden_states.get(), @@ -887,7 +884,7 @@ PD_BUILD_OP(fused_gate_moe_fp8) "gate_up_weights", "down_weights", paddle::Optional("hidden_states_scales"), - paddle::Optional(paddle::Vec("intermediate_hidden_states_scales")), + paddle::Optional("intermediate_hidden_states_scales"), "gate_up_weights_scales", "down_weights_scales"}) .Outputs({"final_hidden_states"}) diff --git a/backends/intel_hpu/kernels/funcs.h b/backends/intel_hpu/kernels/funcs.h index db164fbed3..13ec591598 100644 --- a/backends/intel_hpu/kernels/funcs.h +++ b/backends/intel_hpu/kernels/funcs.h @@ -352,6 +352,12 @@ class ConvertTensors { if (addr_offset % 0x80 != 0) { PADDLE_THROW("Tensor list offset is not algined."); } + if (dims.size() == 2 && dims[0] == 1) { + // [list_size][1] is padded for 0x80 alignment as + // [list_size][1][padded_num]. + // now addr_offset = 0x80; + dims.pop_back(); + } if (is_input) { for (int64_t tensor_idx = 0; tensor_idx < num_list; tensor_idx++) { diff --git a/backends/intel_hpu/tests/unittests/test_fused_fp8_qkv_rope.py b/backends/intel_hpu/tests/unittests/test_fused_fp8_qkv_rope.py index bb2db2c794..6e94a0388f 100644 --- a/backends/intel_hpu/tests/unittests/test_fused_fp8_qkv_rope.py +++ b/backends/intel_hpu/tests/unittests/test_fused_fp8_qkv_rope.py @@ -21,6 +21,7 @@ import os intel_hpus_module_id = os.environ.get("FLAGS_selected_intel_hpus", 4) +paddle.device.set_device(f"intel_hpu:{intel_hpus_module_id}") paddle.seed(2025) diff --git a/backends/intel_hpu/tests/unittests/test_fused_gate_moe.py b/backends/intel_hpu/tests/unittests/test_fused_gate_moe.py index e0b5c4021b..c78d2a0536 100644 --- a/backends/intel_hpu/tests/unittests/test_fused_gate_moe.py +++ b/backends/intel_hpu/tests/unittests/test_fused_gate_moe.py @@ -90,13 +90,6 @@ def setup_logging(ep_rank, tp_rank, enable_logging=False): def init_distributed(ep_size=1, tp_size=1): - - if not dist.is_initialized(): - try: - dist.init_parallel_env() - except Exception as e: - raise RuntimeError("Failed to initialize distributed environment") from e - global_rank = dist.get_rank() world_size = dist.get_world_size() @@ -150,20 +143,22 @@ def check_using_cosine_similarity( if norm1 == 0 or norm2 == 0: cos_sim = 1.0 if np.array_equal(vec1, vec2) else 0.0 + mag_sim = 1.0 if norm1 == norm2 else 0.0 else: cos_sim = np.dot(vec1, vec2) / (norm1 * norm2) + mag_sim = min(norm1 / norm2, norm2 / norm1) logger.info( f"Cosine similarity: {cos_sim}, \n" + f"Euclidean similarity: {mag_sim}, \n" f"required_similarity: {required_similarity}, ", extra={"ep_rank": ep_rank, "tp_rank": tp_rank}, ) - print(f"Cosine similarity: {cos_sim}") - return cos_sim >= required_similarity + print(f"Cosine similarity: {cos_sim}, Euclidean similarity: {mag_sim}") + return cos_sim >= required_similarity and mag_sim >= required_similarity def tensorwise_cast_to_fp8(tensor, scale): - scale = paddle.to_tensor(scale, dtype=tensor.dtype) x_scaled = (tensor * scale).cast(paddle.float8_e4m3fn) return x_scaled @@ -186,7 +181,7 @@ def channelwise_quant_to_fp8(tensor): x_amax = paddle.amax(x_abs, axis=0) # shape: [N] x_amax = paddle.clip(x_amax, min=1e-4) scale = x_amax / 240.0 # shape: [N] - scale = paddle.to_tensor(scale, dtype=paddle.bfloat16) + scale = paddle.cast(scale, dtype=paddle.bfloat16) x_scaled = (tensor / scale).astype(paddle.float8_e4m3fn) return x_scaled, scale @@ -244,13 +239,14 @@ def generate_tensors( else: raise ValueError(f"Unsupported dtype: {dtype}") - hidden_states = (paddle.rand([num_tokens, hidden_dim], dtype=paddle_dtype) * 10) - 5 + hidden_states = 
np.random.randn(num_tokens, hidden_dim).astype(float) * 4 - 2 + hidden_states = paddle.to_tensor(hidden_states, dtype=paddle_dtype) + # hidden_states = (paddle.randn([num_tokens, hidden_dim], dtype=paddle_dtype) * 4) - 2 route_gate_weight = ( - paddle.rand([hidden_dim, num_experts], dtype=paddle.float32) * 0.6 - ) - 0.3 - gate_correction_bias = ( - paddle.rand([1, num_experts], dtype=paddle.float32) * 128 - ) - 64 + paddle.rand([hidden_dim, num_experts], dtype=paddle.float32) * 0.06 + ) - 0.03 + gate_correction_bias = paddle.rand([1, num_experts], dtype=paddle.float32) * 0.001 + up_weights = [ (paddle.rand([hidden_dim, ffn_dim], dtype=paddle_dtype) * 0.6) - 0.3 for _ in range(num_experts) @@ -269,17 +265,10 @@ def generate_tensors( gate_weights = [w.transpose([1, 0]) for w in gate_weights] down_weights = [w.transpose([1, 0]) for w in down_weights] - if fused_weights: - up_gate_weights = [ - paddle.concat((w1, w2), axis=0) - if permuted_weights - else paddle.concat((w1, w2), axis=1) - for w1, w2 in zip(up_weights, gate_weights) - ] - # fp8 scale weights handling if dtype == "bfloat16": - d_scales_up_gate = None + d_scales_up = None + d_scales_gate = None d_scales_down = None d_scales_hidden_states = None d_scales_intermediate_hidden_states = None @@ -290,21 +279,14 @@ def generate_tensors( if weight_scale_type == "channelwise" else tensorwise_quant_to_fp8 ) - if fused_weights: - up_gate_weights, d_scales_up_gate = zip( - *[weight_quant_method(w) for w in up_gate_weights] - ) - up_gate_weights = list(up_gate_weights) - d_scales_up_gate = list(d_scales_up_gate) - else: - up_weights, d_scales_up = zip(*[weight_quant_method(w) for w in up_weights]) - up_weights = list(up_weights) - d_scales_up = list(d_scales_up) - gate_weights, d_scales_gate = zip( - *[weight_quant_method(w) for w in gate_weights] - ) - gate_weights = list(gate_weights) - d_scales_gate = list(d_scales_gate) + up_weights, d_scales_up = zip(*[weight_quant_method(w) for w in up_weights]) + up_weights = list(up_weights) + d_scales_up = list(d_scales_up) + gate_weights, d_scales_gate = zip( + *[weight_quant_method(w) for w in gate_weights] + ) + gate_weights = list(gate_weights) + d_scales_gate = list(d_scales_gate) down_weights, d_scales_down = zip( *[weight_quant_method(w) for w in down_weights] ) @@ -322,9 +304,6 @@ def generate_tensors( if hidden_states_dynamic_quant is False: _, d_scales_hidden_states = tensorwise_quant_to_fp8(hidden_states) - d_scales_hidden_states = paddle.to_tensor( - d_scales_hidden_states, dtype=paddle_dtype - ) d_scales_hidden_states = 1.0 / d_scales_hidden_states else: d_scales_hidden_states = None @@ -340,11 +319,13 @@ def generate_tensors( hidden_states, gate_correction_bias, route_gate_weight, - up_gate_weights, + up_weights, + gate_weights, down_weights, d_scales_hidden_states, d_scales_intermediate_hidden_states, - d_scales_up_gate, + d_scales_up, + d_scales_gate, d_scales_down, ) return paddle_data @@ -363,13 +344,15 @@ def __init__(self, dynamic_quant, dtype): def forward_fp8( self, hidden_states, - gate_weights, + route_gate_weight, gate_correction_bias, - up_gate_weights, + up_weights, + gate_weights, down_weights, hidden_states_scales, intermediate_hidden_states_scales, - gate_up_weights_scales, + up_weights_scales, + gate_weights_scales, down_weights_scales, top_k, norm_topk_prob, @@ -378,8 +361,20 @@ def forward_fp8( experts_max, chunk_size, ): - gate_out = paddle.matmul(hidden_states.cast("float32"), gate_weights) + up_gate_weights = [ + paddle.concat((w1, w2), axis=0) + if permuted_weights 
+ else paddle.concat((w1, w2), axis=1) + for w1, w2 in zip(up_weights, gate_weights) + ] + gate_up_weights_scales = [ + paddle.concat((w1, w2), axis=0) + if permuted_weights + else paddle.concat((w1, w2), axis=0) + for w1, w2 in zip(up_weights_scales, gate_weights_scales) + ] + gate_out = paddle.matmul(hidden_states.cast("float32"), route_gate_weight) weights = paddle.nn.functional.softmax(gate_out, axis=-1) if gate_correction_bias is not None: scores = weights + gate_correction_bias @@ -433,13 +428,15 @@ def forward_fp8( def forward_bf16( self, hidden_states, - gate_weights, + route_gate_weight, gate_correction_bias, - up_gate_weights, + up_weights, + gate_weights, down_weights, hidden_states_scales, intermediate_hidden_states_scales, - gate_up_weights_scales, + up_weights_scales, + gate_weights_scales, down_weights_scales, top_k, norm_topk_prob, @@ -448,7 +445,14 @@ def forward_bf16( experts_max, chunk_size, ): - gate_out = paddle.matmul(hidden_states.cast("float32"), gate_weights) + up_gate_weights = [ + paddle.concat((w1, w2), axis=0) + if permuted_weights + else paddle.concat((w1, w2), axis=1) + for w1, w2 in zip(up_weights, gate_weights) + ] + + gate_out = paddle.matmul(hidden_states.cast("float32"), route_gate_weight) weights = paddle.nn.functional.softmax(gate_out, axis=-1) if gate_correction_bias is not None: @@ -499,7 +503,6 @@ def __init__( dtype="fp8", intermediate_dynamic_scale=None, block_size=None, - chunk_size=0, ): self.num_experts = num_experts self.permuted_weights = permuted_weights @@ -516,7 +519,6 @@ def __init__( self.dtype = dtype self.block_size = block_size self.top_k = top_k - self.chunk_size = chunk_size if self.dtype == "bfloat16": self.fn = paddlenlp_ops.fused_gate_moe @@ -540,24 +542,56 @@ def __init__( 1, (self.experts_max - self.experts_min + 1) // self.expert_slice ) - def forward( + def prepare_inputs( self, hidden_states, - gate_weights, + route_gate_weight, gate_correction_bias, expert_weights, hidden_states_scale, intermediate_states_scales, weights_scales, - compute_amax=False, ): - common_inputs = (hidden_states, gate_weights, gate_correction_bias) + common_inputs = (hidden_states, route_gate_weight, gate_correction_bias) # final_hidden_states = paddle.zeros_like(hidden_states) - amax_per_expert = ( - paddle.zeros(self.num_experts, dtype="float32") if compute_amax else None - ) + up_weights = [ + w[ + :, + self.tp_rank + * (w.shape[1] // self.tp_size) : (self.tp_rank + 1) + * (w.shape[1] // self.tp_size), + ] + for w in expert_weights[0] + ] + gate_weights = [ + w[ + :, + self.tp_rank + * (w.shape[1] // self.tp_size) : (self.tp_rank + 1) + * (w.shape[1] // self.tp_size), + ] + for w in expert_weights[1] + ] + down_weights = [ + w[ + self.tp_rank + * (w.shape[0] // self.tp_size) : (self.tp_rank + 1) + * (w.shape[0] // self.tp_size), + :, + ] + for w in expert_weights[2] + ] + if self.fused_weights: + up_gate_weights = [ + paddle.concat((w1, w2), axis=0) + if self.permuted_weights + else paddle.concat((w1, w2), axis=1) + for w1, w2 in zip(up_weights, gate_weights) + ] + + assert self.expert_slice == 1 for idx in range(self.expert_slice): slice_experts_min = self.experts_min + (self.expert_chunk * idx) slice_experts_max = min( @@ -574,46 +608,72 @@ def forward( slice_weights = ( ( paddle.stack( - expert_weights[0][slice_experts_min : slice_experts_max + 1], + up_gate_weights[slice_experts_min : slice_experts_max + 1], axis=0, ), paddle.stack( - expert_weights[1][slice_experts_min : slice_experts_max + 1], + down_weights[slice_experts_min : 
slice_experts_max + 1], axis=0, ), ) if self.fused_weights else ( paddle.stack( - expert_weights[0][slice_experts_min : slice_experts_max + 1] - + expert_weights[1][slice_experts_min : slice_experts_max + 1], + up_weights[slice_experts_min : slice_experts_max + 1] + + gate_weights[slice_experts_min : slice_experts_max + 1], axis=0, ), paddle.stack( - expert_weights[2][slice_experts_min : slice_experts_max + 1], + down_weights[slice_experts_min : slice_experts_max + 1], axis=0, ), ) ) + if self.dtype == "fp8": + d_scales_up = [ + w[ + self.tp_rank + * (w.shape[0] // self.tp_size) : (self.tp_rank + 1) + * (w.shape[0] // self.tp_size) + ] + for w in weights_scales[0] + ] + d_scales_gate = [ + w[ + self.tp_rank + * (w.shape[0] // self.tp_size) : (self.tp_rank + 1) + * (w.shape[0] // self.tp_size) + ] + for w in weights_scales[1] + ] + d_scales_down = weights_scales[2] + + if self.fused_weights: + d_scales_up_gate = [ + paddle.concat((w1, w2), axis=0) + for w1, w2 in zip(d_scales_up, d_scales_gate) + ] + slice_scales = ( ( hidden_states_scale, None if self.intermediate_dynamic_scale - else intermediate_states_scales[ - slice_experts_min : slice_experts_max + 1 - ], - paddle.stack( - weights_scales[0][ + else paddle.stack( + intermediate_states_scales[ slice_experts_min : slice_experts_max + 1 ], axis=0, + ) + .unsqueeze(2) + .expand([-1, -1, 64]), + paddle.stack( + d_scales_up_gate[slice_experts_min : slice_experts_max + 1], + axis=0, ), paddle.stack( - weights_scales[1][ - slice_experts_min : slice_experts_max + 1 - ], + d_scales_down[slice_experts_min : slice_experts_max + 1], axis=0, ), ) @@ -626,16 +686,12 @@ def forward( slice_experts_min : slice_experts_max + 1 ], paddle.stack( - weights_scales[0][slice_experts_min : slice_experts_max + 1] - + weights_scales[1][ - slice_experts_min : slice_experts_max + 1 - ], + d_scales_up[slice_experts_min : slice_experts_max + 1] + + d_scales_gate[slice_experts_min : slice_experts_max + 1], axis=0, ), paddle.stack( - weights_scales[2][ - slice_experts_min : slice_experts_max + 1 - ], + d_scales_down[slice_experts_min : slice_experts_max + 1], axis=0, ), ) @@ -673,37 +729,43 @@ def forward( ), ) ) - - if self.dtype == "fp8": - slice_result = self.fn( - *common_inputs, - *slice_weights, - *slice_scales, - *common_params, - self.chunk_size, - ) - elif self.dtype == "blockwise_fp8": - slice_result = self.fn( - *common_inputs, - *slice_weights, - *slice_scales, - *common_params, - self.block_size, - self.chunk_size, - ) else: - slice_result = self.fn( - *common_inputs, - *slice_weights, - *common_params, - self.chunk_size, - ) - # paddlenlp_ops.fused_gate_moe no requirement to return amax - slice_amax = None - if compute_amax: - amax_per_expert[slice_experts_min : slice_experts_max + 1] = slice_amax + slice_scales = None + return common_inputs, slice_weights, slice_scales, common_params - final_hidden_states = slice_result + def forward( + self, + common_inputs, + slice_weights, + slice_scales, + common_params, + chunk_size=0, + ): + if self.dtype == "fp8": + slice_result = self.fn( + *common_inputs, + *slice_weights, + *slice_scales, + *common_params, + chunk_size, + ) + elif self.dtype == "blockwise_fp8": + slice_result = self.fn( + *common_inputs, + *slice_weights, + *slice_scales, + *common_params, + self.block_size, + chunk_size, + ) + else: + slice_result = self.fn( + *common_inputs, + *slice_weights, + *common_params, + chunk_size, + ) + final_hidden_states = slice_result # EP: All-reduce for final output if self.tp_size > 1: @@ -715,14 +777,6 @@ def 
forward( "TP All-reduce for MoE successfully.", extra={"ep_rank": self.ep_rank, "tp_rank": self.tp_rank}, ) - if compute_amax: - dist.all_reduce( - amax_per_expert, op=dist.ReduceOp.MAX, group=self.tp_group - ) - self.logger.info( - "TP All-reduce for AMax successfully.", - extra={"ep_rank": self.ep_rank, "tp_rank": self.tp_rank}, - ) except Exception as e: self.logger.error( f"Failed to perform TP All-reduce: {str(e)}", @@ -739,14 +793,6 @@ def forward( "EP All-reduce for MoE successfully.", extra={"ep_rank": self.ep_rank, "tp_rank": self.tp_rank}, ) - if compute_amax: - dist.all_reduce( - amax_per_expert, op=dist.ReduceOp.MAX, group=self.ep_group - ) - self.logger.info( - "EP All-reduce for AMax successfully.", - extra={"ep_rank": self.ep_rank, "tp_rank": self.tp_rank}, - ) except Exception as e: self.logger.error( f"Failed to perform EP All-reduce: {str(e)}", @@ -754,10 +800,10 @@ def forward( ) raise - return final_hidden_states, amax_per_expert + return final_hidden_states -DTYPES = ["bfloat16", "fp8"] # ["bfloat16", "fp8"] +DTYPES = ["bfloat16", "fp8"] NUM_TOKENS = [32] HIDDEN_DIMS = [4096] FFN_DIMS = [2560] @@ -767,8 +813,7 @@ def forward( FUSED_WEIGHTS = [True] # [True, False] ACTIVATIONS = ["silu"] # ["gelu", "relu", "silu"] PERMUTED_WEIGHTS = [False] # [True, False] -EP_SIZE = [world_size] -TP_SIZE = [1] +EP_TP_SIZE = [(world_size, 1), (1, world_size)] # for bfloat16 only COMPUTE_AMAX = [False] # [True, False] # for fp8 only @@ -808,8 +853,7 @@ class MoETest(unittest.TestCase): for fused_weights in FUSED_WEIGHTS for activation in ACTIVATIONS for permuted_weights in PERMUTED_WEIGHTS - for ep_size in EP_SIZE - for tp_size in TP_SIZE + for (ep_size, tp_size) in EP_TP_SIZE for dtype in DTYPES for intermediate_dynamic_scale in ( INTERMEDIATE_DYNAMIC_SCALE if dtype == "fp8" else [None] @@ -872,12 +916,14 @@ def test_fused_gate_moe( ( hidden_states, gate_correction_bias, + route_gate_weight, + up_weights, gate_weights, - up_gate_weights, down_weights, d_scales_hidden_states, d_scales_intermediate_hidden_states, - d_scales_up_gate, + d_scales_up, + d_scales_gate, d_scales_down, ) = out_tensors @@ -886,13 +932,15 @@ def test_fused_gate_moe( final_hidden_states_ref = mixtral_ref.forward( hidden_states, - gate_weights, + route_gate_weight, gate_correction_bias, - up_gate_weights, + up_weights, + gate_weights, down_weights, d_scales_hidden_states, d_scales_intermediate_hidden_states, - d_scales_up_gate, + d_scales_up, + d_scales_gate, d_scales_down, top_k, norm_topk_prob=True, @@ -930,24 +978,35 @@ def test_fused_gate_moe( tp_group=tp_group, dtype=dtype, block_size=None, - chunk_size=0, ) - - final_hidden_states, amax_per_expert = fused_gate_moe.forward( + ( + common_inputs, + slice_weights, + slice_scales, + common_params, + ) = fused_gate_moe.prepare_inputs( hidden_states=hidden_states, - gate_weights=gate_weights, + route_gate_weight=route_gate_weight, gate_correction_bias=gate_correction_bias, - expert_weights=(up_gate_weights, down_weights), + expert_weights=(up_weights, gate_weights, down_weights), hidden_states_scale=d_scales_hidden_states, intermediate_states_scales=d_scales_intermediate_hidden_states, - weights_scales=(d_scales_up_gate, d_scales_down), + weights_scales=(d_scales_up, d_scales_gate, d_scales_down), + ) + + final_hidden_states = fused_gate_moe.forward( + common_inputs, + slice_weights, + slice_scales, + common_params, + chunk_size=0, ) + logger.debug( "\n===== paddlenlp_ops.mixture_of_experts Output =====\n", extra={ "ep_rank": ep_rank, "tp_rank": tp_rank, - 
"amax_per_expert": amax_per_expert, "final_hidden_states": final_hidden_states, }, ) @@ -966,6 +1025,187 @@ def test_fused_gate_moe( assert similar, f"Cosine similarity check failed: {similar}" +def run_profile(): + + num_tokens = 128 + hidden_dim = 8192 + ffn_dim = 3584 + top_k = 8 + num_experts = 64 + slice_max_expert = 64 + TEST_EP_TP = "EP" # EP TP EPTP + if TEST_EP_TP == "EP": + ep_size = world_size + tp_size = world_size // ep_size + else: + ep_size = 1 + tp_size = world_size + + (ep_rank, ep_size, ep_group), (tp_rank, tp_size, tp_group) = init_distributed( + ep_size, tp_size + ) + logger = setup_logging(ep_rank=ep_rank, tp_rank=tp_rank) + + out_tensors_bf16 = generate_tensors( + num_tokens=num_tokens, + hidden_dim=hidden_dim, + ffn_dim=ffn_dim, + top_k=top_k, + num_experts=num_experts, + permuted_weights=False, + fused_weights=True, + intermediate_dynamic_scale=False, + hidden_states_dynamic_quant=False, + weight_scale_type="channelwise", + dtype="bfloat16", + ) + ( + hidden_states_bf16, + gate_correction_bias_bf16, + route_gate_weight_bf16, + up_weights_bf16, + gate_weights_bf16, + down_weights_bf16, + d_scales_hidden_states, + d_scales_intermediate_hidden_states, + d_scales_up, + d_scales_gate, + d_scales_down, + ) = out_tensors_bf16 + + out_tensors_fp8 = generate_tensors( + num_tokens=num_tokens, + hidden_dim=hidden_dim, + ffn_dim=ffn_dim, + top_k=top_k, + num_experts=num_experts, + permuted_weights=False, + fused_weights=True, + intermediate_dynamic_scale=False, + hidden_states_dynamic_quant=False, + weight_scale_type="channelwise", + dtype="fp8", + ) + ( + hidden_states_fp8, + gate_correction_bias_fp8, + route_gate_weight_fp8, + up_weights_fp8, + gate_weights_fp8, + down_weights_fp8, + d_scales_hidden_states, + d_scales_intermediate_hidden_states, + d_scales_up, + d_scales_gate, + d_scales_down, + ) = out_tensors_fp8 + + fused_gate_moe_bf16 = FusedGateMoE( + num_experts=num_experts, + top_k=top_k, + activation="silu", + permuted_weights=False, + fused_weights=True, + intermediate_dynamic_scale=False, + slice_max_expert=slice_max_expert, + logger=logger, + ep_rank=ep_rank, + ep_size=ep_size, + ep_group=ep_group, + tp_rank=tp_rank, + tp_size=tp_size, + tp_group=tp_group, + dtype="bfloat16", + block_size=None, + ) + + fused_gate_moe_fp8 = FusedGateMoE( + num_experts=num_experts, + top_k=top_k, + activation="silu", + permuted_weights=False, + fused_weights=True, + intermediate_dynamic_scale=False, + slice_max_expert=slice_max_expert, + logger=logger, + ep_rank=ep_rank, + ep_size=ep_size, + ep_group=ep_group, + tp_rank=tp_rank, + tp_size=tp_size, + tp_group=tp_group, + dtype="fp8", + block_size=None, + ) + + ( + common_inputs_bf16, + slice_weights_bf16, + slice_scales_bf16, + common_params_bf16, + ) = fused_gate_moe_bf16.prepare_inputs( + hidden_states=hidden_states_bf16, + route_gate_weight=route_gate_weight_bf16, + gate_correction_bias=gate_correction_bias_bf16, + expert_weights=(up_weights_bf16, gate_weights_bf16, down_weights_bf16), + hidden_states_scale=d_scales_hidden_states, + intermediate_states_scales=d_scales_intermediate_hidden_states, + weights_scales=(d_scales_up, d_scales_gate, d_scales_down), + ) + + ( + common_inputs_fp8, + slice_weights_fp8, + slice_scales_fp8, + common_params_fp8, + ) = fused_gate_moe_fp8.prepare_inputs( + hidden_states=hidden_states_fp8, + route_gate_weight=route_gate_weight_fp8, + gate_correction_bias=gate_correction_bias_fp8, + expert_weights=(up_weights_fp8, gate_weights_fp8, down_weights_fp8), + hidden_states_scale=d_scales_hidden_states, + 
intermediate_states_scales=d_scales_intermediate_hidden_states, + weights_scales=(d_scales_up, d_scales_gate, d_scales_down), + ) + + print("start profiling...") + + import paddle.profiler as profiler + + prof = profiler.Profiler( + targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.CUSTOM_DEVICE], + scheduler=(3, 5), + on_trace_ready=profiler.export_chrome_tracing("./profile"), + ) + prof.start() + for iter in range(5): + chunk_size = [0, 64, 128, 256] + for chunk in chunk_size: + with paddle.no_grad(): + final_hidden_states = fused_gate_moe_fp8.forward( + common_inputs_fp8, + slice_weights_fp8, + slice_scales_fp8, + common_params_fp8, + chunk_size=chunk, + ) + paddle.device.synchronize() + for chunk in chunk_size: + with paddle.no_grad(): + final_hidden_states = fused_gate_moe_bf16.forward( + common_inputs_bf16, + slice_weights_bf16, + slice_scales_bf16, + common_params_bf16, + chunk_size=chunk, + ) + paddle.device.synchronize() + prof.step() + prof.stop() + + print(f"profile finished. final_hidden_states.shape is {final_hidden_states.shape}") + + if __name__ == "__main__": # Set logging level to DEBUG to see debug messages logging.getLogger().setLevel(logging.WARNING) @@ -980,3 +1220,4 @@ def test_fused_gate_moe( # Run the test suite runner.run(suite) + # run_profile()
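Reviewer note: the test now gates correctness on both direction and magnitude (cosine similarity plus the norm-ratio check added to `check_using_cosine_similarity`). The sketch below is a minimal standalone restatement of that dual-metric check in NumPy; the function name and the default threshold are illustrative, not part of the patch.

```python
import numpy as np

def passes_similarity_check(ref, out, required_similarity=0.99):
    # Cosine similarity guards direction; the norm ratio guards magnitude,
    # so an output that is parallel to the reference but wrongly scaled
    # no longer passes by accident.
    vec1 = np.asarray(ref, dtype=np.float64).ravel()
    vec2 = np.asarray(out, dtype=np.float64).ravel()
    norm1, norm2 = np.linalg.norm(vec1), np.linalg.norm(vec2)
    if norm1 == 0 or norm2 == 0:
        # Degenerate case: two all-zero vectors count as a perfect match.
        cos_sim = 1.0 if np.array_equal(vec1, vec2) else 0.0
        mag_sim = 1.0 if norm1 == norm2 else 0.0
    else:
        cos_sim = float(np.dot(vec1, vec2)) / (norm1 * norm2)
        mag_sim = min(norm1 / norm2, norm2 / norm1)
    return cos_sim >= required_similarity and mag_sim >= required_similarity
```

Requiring both metrics to clear the same threshold mirrors the updated assertion `cos_sim >= required_similarity and mag_sim >= required_similarity` in the test.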
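The FP8 path of the test relies on the channelwise weight quantization helper (`channelwise_quant_to_fp8`), which produces one dequantization scale per output channel. Below is a hedged sketch of that scheme under the same assumptions the test makes (bf16 scales, `float8_e4m3fn` storage); the function name is illustrative, and the 240.0 divisor simply mirrors the test code, which appears to target the HPU FP8 dynamic range rather than the 448 maximum of standard E4M3FN.

```python
import paddle

def channelwise_quant_sketch(weight):
    # weight: [K, N]; one scale per output channel (amax over axis 0),
    # clipped so near-zero channels do not divide by zero.
    amax = paddle.clip(paddle.amax(paddle.abs(weight), axis=0), min=1e-4)  # [N]
    scale = paddle.cast(amax / 240.0, paddle.bfloat16)                     # dequant scale, [N]
    w_fp8 = (weight / scale).astype(paddle.float8_e4m3fn)                  # quantized weight
    return w_fp8, scale
```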