[Megatron] Support routing replay on NPU with performance and compatibility enhancements #5298
verl/utils/megatron/router_replay_patch.py
@@ -15,6 +15,9 @@

```python
from enum import Enum

import torch
import types
import inspect
from functools import wraps

try:
    from megatron.core.transformer.moe.moe_utils import (
```
@@ -236,7 +239,27 @@

```python
    return routing_probs, routing_map


def _get_aux_loss_coeff(_self, aux_loss_type: str) -> float:
    """Get the coefficient for the given auxiliary loss type."""
```
Collaborator: It would be better to write all comments in English; comments that are not relevant should be removed.
```python
    # Logic unchanged
    if isinstance(_self.routing_type, str):
        if _self.routing_type == aux_loss_type:
            return _self.config.moe_aux_loss_coeff
    if isinstance(_self.routing_type, list):
        try:
            idx = _self.routing_type.index(aux_loss_type)
            return _self.config.moe_aux_loss_coeff[idx]
        except (ValueError, IndexError):
            return 0.0
    return 0.0


def _is_aux_loss_enabled(_self) -> bool:
    """Check whether any auxiliary loss is enabled."""
    for aux_loss_type in ["aux_loss", "seq_aux_loss", "global_aux_loss"]:
        # Note: this calls another helper function defined at module level
        if _get_aux_loss_coeff(_self, aux_loss_type) > 0:
            return True
    return False
def patched_routing(self, logits: torch.Tensor, *args, **kwargs):
    """Top-k routing function
```
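For illustration only (not part of the PR), here is a minimal sketch of how `_get_aux_loss_coeff` above resolves coefficients when `routing_type` is a single string versus a list. The `SimpleNamespace` objects are hypothetical stand-ins for the real router and its config, and the helper is assumed to be in scope from the hunk above.

```python
from types import SimpleNamespace

# Hypothetical stand-ins for a router instance and its config (illustration only).
single = SimpleNamespace(
    routing_type="aux_loss",
    config=SimpleNamespace(moe_aux_loss_coeff=0.01),
)
multi = SimpleNamespace(
    routing_type=["aux_loss", "seq_aux_loss"],
    config=SimpleNamespace(moe_aux_loss_coeff=[0.01, 0.001]),
)

print(_get_aux_loss_coeff(single, "aux_loss"))        # 0.01
print(_get_aux_loss_coeff(multi, "seq_aux_loss"))     # 0.001
print(_get_aux_loss_coeff(multi, "global_aux_loss"))  # 0.0 (not configured)
```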
@@ -282,6 +305,8 @@

```python
        pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
    )

    if not hasattr(self, "is_aux_loss_enabled"):
        self.is_aux_loss_enabled = types.MethodType(_is_aux_loss_enabled, self)
    # Apply each aux loss type and attach aux loss autograd function to probs
    if self.training and torch.is_grad_enabled() and self.is_aux_loss_enabled():
        # Calculate scores and routing_map for aux loss
```
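For context, a minimal standalone sketch of the `types.MethodType` idiom this hunk relies on to attach a module-level helper as a bound method on an existing instance. The `Router` class and `aux_coeff` field below are invented for the example and are not Megatron code.

```python
import types

class Router:
    """Toy stand-in for a router instance (illustration only)."""
    def __init__(self):
        self.aux_coeff = 0.01

def _is_aux_loss_enabled(_self) -> bool:
    # Module-level helper; `_self` is bound to the instance by MethodType.
    return _self.aux_coeff > 0

router = Router()
if not hasattr(router, "is_aux_loss_enabled"):
    # Bind the helper so it can be called like a normal method.
    router.is_aux_loss_enabled = types.MethodType(_is_aux_loss_enabled, router)

print(router.is_aux_loss_enabled())  # True
```

Binding only when the attribute is missing keeps the patch idempotent: re-running the routing code does not rebind the method.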
@@ -311,26 +336,48 @@

```python
    # Clear router instances to avoid state leakage between model initializations.
    RouterReplay.router_instances.clear()
    # Step 1: Patch TransformerConfig to include the feature flag
    if not hasattr(TransformerConfig, "enable_routing_replay"):
        # Add class attribute with default value
        TransformerConfig.enable_routing_replay = False
    try:
        global_args = get_args()
```
Collaborator: `verl/utils/megatron/router_replay_patch.py:340:23: F821 Undefined name`
```python
    except Exception:
        global_args = None

    try:
        sig = inspect.signature(TransformerConfig.__init__)
        native_params = sig.parameters
```
Collaborator: First save a copy of the original TransformerConfig parameters.
```python
    except Exception:
        native_params = []

    ext_attrs = ["enable_routing_replay", "moe_router_fusion"]

    for attr in ext_attrs:
        val = getattr(global_args, attr, False) if global_args else False

        if not hasattr(TransformerConfig, attr):
            setattr(TransformerConfig, attr, val)

    if not hasattr(TransformerConfig, "_verl_router_patched"):
        # Store original __init__ method
        original_tf_config_init = TransformerConfig.__init__

        # Define new __init__ method that safely handles enable_routing_replay parameter
        @wraps(original_tf_config_init)
        def patched_tf_config_init(self, *args, **kwargs):
            # Simple solution: remove the unknown parameter before calling original constructor
            enable_routing_replay = kwargs.pop("enable_routing_replay", TransformerConfig.enable_routing_replay)
            if "enable_routing_replay" not in native_params:
                enable_routing_replay = kwargs.pop("enable_routing_replay", TransformerConfig.enable_routing_replay)
```
Collaborator: If the original TransformerConfig does not have this parameter, replace it with the newly injected parameter value.
| if "moe_router_fusion" not in native_params: | ||
| moe_router_fusion = kwargs.pop("moe_router_fusion", TransformerConfig.moe_router_fusion) | ||
| # Call original constructor with remaining kwargs | ||
| original_tf_config_init(self, *args, **kwargs) | ||
|
|
||
| # Set the instance attribute | ||
| self.enable_routing_replay = enable_routing_replay | ||
| self.moe_router_fusion = moe_router_fusion | ||
|
|
||
| # Apply the patch | ||
| TransformerConfig.__init__ = patched_tf_config_init | ||
| TransformerConfig._verl_router_patched = True | ||
|
|
||
| # Step 2: Patch TopKRouter only once to ensure idempotency. | ||
| if hasattr(TopKRouter, "_router_replay_patched"): | ||
|
|
||
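For readers unfamiliar with the pattern used in this hunk, here is a minimal generic sketch of wrapping a constructor so that a non-native keyword argument is stripped before the original `__init__` runs and then attached as an instance attribute afterwards. The `Config` class and its field are invented for the example; this is not the verl or Megatron implementation.

```python
from dataclasses import dataclass
from functools import wraps

@dataclass
class Config:
    hidden_size: int = 16  # stand-in for a native config field

_original_init = Config.__init__

@wraps(_original_init)
def _patched_init(self, *args, **kwargs):
    # Pop the non-native flag so the original __init__ never sees it.
    enable_routing_replay = kwargs.pop("enable_routing_replay", False)
    _original_init(self, *args, **kwargs)
    # Re-attach it as a plain instance attribute afterwards.
    self.enable_routing_replay = enable_routing_replay

Config.__init__ = _patched_init

cfg = Config(hidden_size=32, enable_routing_replay=True)
print(cfg.hidden_size, cfg.enable_routing_replay)  # 32 True
```

Guarding the real patch with a `_verl_router_patched` marker, as the diff does, prevents wrapping `__init__` twice if the patch function is called again.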
Second changed file:

@@ -21,6 +21,7 @@

```python
from typing import Optional

import torch
import inspect

try:
    from megatron.core.pipeline_parallel.utils import is_vp_first_stage, is_vp_last_stage
```
@@ -330,7 +331,13 @@ def get_current_rank_layer_info(tf_config, vp_rank=None):

```python
    if vp_rank is None:
        vp_rank = 0
    num_layers_to_build = get_num_layers_to_build(tf_config, vp_stage=vp_rank)
    offset = get_transformer_layer_offset(tf_config, vp_stage=vp_rank)
```
Collaborator: Why does this part need a separate check for whether the vp_stage argument exists?

Contributor: The Megatron version currently supported by the NPU stack only goes up to 0.12.0, and the get_num_layers_to_build function imported from that Megatron has no vp_stage parameter, so for now there is no other way to use this parameter.
```python
    sig = inspect.signature(get_transformer_layer_offset)

    if 'vp_stage' in sig.parameters:
        offset = get_transformer_layer_offset(tf_config, vp_stage=vp_rank)
    else:
        offset = get_transformer_layer_offset(tf_config)
    local = {}
    local["start"] = offset
    local["end"] = offset + num_layers_to_build
```
Review comment: Why was this not passed in here before either? Is this a bugfix, or something NPU-specific?
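As background on the compatibility guard used in this hunk, here is a minimal generic sketch of calling a function with an optional keyword argument only when its signature accepts it. The helper and the two `offset_*` functions are invented for illustration and stand in for different Megatron versions of the same API.

```python
import inspect

def call_with_optional_vp_stage(fn, config, vp_stage):
    """Pass vp_stage only if fn's signature declares it (illustrative helper)."""
    if "vp_stage" in inspect.signature(fn).parameters:
        return fn(config, vp_stage=vp_stage)
    return fn(config)

# Hypothetical stand-ins for an older and a newer version of the same API.
def offset_old(config):
    return 0

def offset_new(config, vp_stage=0):
    return vp_stage * 4

print(call_with_optional_vp_stage(offset_old, None, 1))  # 0
print(call_with_optional_vp_stage(offset_new, None, 1))  # 4
```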