diff --git a/verl/experimental/agent_loop/tool_agent_loop.py b/verl/experimental/agent_loop/tool_agent_loop.py
index f98485a6781..e96076d88fb 100644
--- a/verl/experimental/agent_loop/tool_agent_loop.py
+++ b/verl/experimental/agent_loop/tool_agent_loop.py
@@ -193,11 +193,14 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu
             multi_modal_data["images"] = agent_data.image_data
         if agent_data.video_data is not None:
             multi_modal_data["videos"] = agent_data.video_data
+
+        routed_experts = getattr(agent_data, "routed_experts", None)
         output = AgentLoopOutput(
             prompt_ids=prompt_ids,
             response_ids=response_ids[: self.response_length],
             response_mask=agent_data.response_mask[: self.response_length],
             multi_modal_data=multi_modal_data,
+            routed_experts=routed_experts,
             response_logprobs=agent_data.response_logprobs[: self.response_length]
             if agent_data.response_logprobs
             else None,
diff --git a/verl/utils/megatron/router_replay_patch.py b/verl/utils/megatron/router_replay_patch.py
index 50eb4a128d3..bec7282682a 100644
--- a/verl/utils/megatron/router_replay_patch.py
+++ b/verl/utils/megatron/router_replay_patch.py
@@ -15,6 +15,9 @@
 from enum import Enum
 
 import torch
+import types
+import inspect
+from functools import wraps
 
 try:
     from megatron.core.transformer.moe.moe_utils import (
@@ -236,7 +239,27 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
 
         return routing_probs, routing_map
 
-
+def _get_aux_loss_coeff(_self, aux_loss_type: str) -> float:
+    """Return the coefficient for the given auxiliary loss type."""
+    # Logic kept unchanged from the original implementation
+    if isinstance(_self.routing_type, str):
+        if _self.routing_type == aux_loss_type:
+            return _self.config.moe_aux_loss_coeff
+    if isinstance(_self.routing_type, list):
+        try:
+            idx = _self.routing_type.index(aux_loss_type)
+            return _self.config.moe_aux_loss_coeff[idx]
+        except (ValueError, IndexError):
+            return 0.0
+    return 0.0
+
+def _is_aux_loss_enabled(_self) -> bool:
+    """Check whether any auxiliary loss is enabled."""
+    for aux_loss_type in ["aux_loss", "seq_aux_loss", "global_aux_loss"]:
+        # Note: this calls the other module-level helper defined above
+        if _get_aux_loss_coeff(_self, aux_loss_type) > 0:
+            return True
+    return False
 
 def patched_routing(self, logits: torch.Tensor, *args, **kwargs):
     """Top-k routing function
@@ -282,6 +305,8 @@ def patched_routing(self, logits: torch.Tensor, *args, **kwargs):
             pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
         )
 
+    if not hasattr(self, "is_aux_loss_enabled"):
+        self.is_aux_loss_enabled = types.MethodType(_is_aux_loss_enabled, self)
     # Apply each aux loss type and attach aux loss autograd function to probs
     if self.training and torch.is_grad_enabled() and self.is_aux_loss_enabled():
         # Calculate scores and routing_map for aux loss
@@ -311,26 +336,50 @@ def apply_router_replay_patch():
     # Clear router instances to avoid state leakage between model initializations.
     RouterReplay.router_instances.clear()
     # Step 1: Patch TransformerConfig to include the feature flag
-    if not hasattr(TransformerConfig, "enable_routing_replay"):
-        # Add class attribute with default value
-        TransformerConfig.enable_routing_replay = False
+    try:
+        global_args = get_args()
+    except Exception:
+        global_args = None
+
+    try:
+        sig = inspect.signature(TransformerConfig.__init__)
+        native_params = sig.parameters
+    except Exception:
+        native_params = []
+
+    ext_attrs = ["enable_routing_replay", "moe_router_fusion"]
+
+    for attr in ext_attrs:
+        val = getattr(global_args, attr, False) if global_args else False
+
+        if not hasattr(TransformerConfig, attr):
+            setattr(TransformerConfig, attr, val)
+
     if not hasattr(TransformerConfig, "_verl_router_patched"):
         # Store original __init__ method
         original_tf_config_init = TransformerConfig.__init__
 
         # Define new __init__ method that safely handles enable_routing_replay parameter
+        @wraps(original_tf_config_init)
         def patched_tf_config_init(self, *args, **kwargs):
             # Simple solution: remove the unknown parameter before calling original constructor
-            enable_routing_replay = kwargs.pop("enable_routing_replay", TransformerConfig.enable_routing_replay)
+            if "enable_routing_replay" not in native_params:
+                enable_routing_replay = kwargs.pop("enable_routing_replay", TransformerConfig.enable_routing_replay)
+            if "moe_router_fusion" not in native_params:
+                moe_router_fusion = kwargs.pop("moe_router_fusion", TransformerConfig.moe_router_fusion)
 
             # Call original constructor with remaining kwargs
             original_tf_config_init(self, *args, **kwargs)
 
             # Set the instance attribute
-            self.enable_routing_replay = enable_routing_replay
+            if "enable_routing_replay" not in native_params:
+                self.enable_routing_replay = enable_routing_replay
+            if "moe_router_fusion" not in native_params:
+                self.moe_router_fusion = moe_router_fusion
 
         # Apply the patch
         TransformerConfig.__init__ = patched_tf_config_init
+        TransformerConfig._verl_router_patched = True
 
     # Step 2: Patch TopKRouter only once to ensure idempotency.
     if hasattr(TopKRouter, "_router_replay_patched"):
diff --git a/verl/utils/megatron/router_replay_utils.py b/verl/utils/megatron/router_replay_utils.py
index 3aec85c24b7..6b8c12ec07e 100644
--- a/verl/utils/megatron/router_replay_utils.py
+++ b/verl/utils/megatron/router_replay_utils.py
@@ -21,6 +21,7 @@
 from typing import Optional
 
 import torch
+import inspect
 
 try:
     from megatron.core.pipeline_parallel.utils import is_vp_first_stage, is_vp_last_stage
@@ -330,7 +331,13 @@ def get_current_rank_layer_info(tf_config, vp_rank=None):
     if vp_rank is None:
         vp_rank = 0
     num_layers_to_build = get_num_layers_to_build(tf_config, vp_stage=vp_rank)
-    offset = get_transformer_layer_offset(tf_config, vp_stage=vp_rank)
+
+    # Older Megatron-Core versions do not accept the vp_stage keyword
+    sig = inspect.signature(get_transformer_layer_offset)
+    if 'vp_stage' in sig.parameters:
+        offset = get_transformer_layer_offset(tf_config, vp_stage=vp_rank)
+    else:
+        offset = get_transformer_layer_offset(tf_config)
     local = {}
     local["start"] = offset
     local["end"] = offset + num_layers_to_build
diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
index cf5ab342888..7bbc3075cae 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -507,7 +507,7 @@ async def generate(
             max_tokens = sampling_params.pop("max_new_tokens")
         else:
-            # Default to a calculation that considers configured lengths
-            max_tokens = self.config.response_length + self.config.prompt_length - len(prompt_ids)
+            # Default to the configured response length
+            max_tokens = self.config.response_length
 
         # Clamp max_tokens to the valid range [0, max_possible_tokens]
         max_tokens = max(0, min(max_tokens, max_possible_tokens))
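
Note: the same version-compatibility idiom appears twice in this patch: probe a callable with `inspect.signature` and only pass a keyword argument when the installed Megatron-Core actually accepts it (`vp_stage` for `get_transformer_layer_offset`, and the extension flags for `TransformerConfig.__init__`). A minimal standalone sketch of the idiom, assuming nothing beyond the standard library (the helper name `call_with_supported_kwargs` is hypothetical, not part of the patch):

    import inspect

    def call_with_supported_kwargs(fn, *args, **optional_kwargs):
        """Call fn, dropping any keyword argument its signature does not accept."""
        try:
            params = inspect.signature(fn).parameters
        except (TypeError, ValueError):
            # Some callables are not introspectable; pass no optional kwargs
            params = {}
        supported = {k: v for k, v in optional_kwargs.items() if k in params}
        return fn(*args, **supported)

    # Equivalent in spirit to the vp_stage branch in router_replay_utils.py:
    #     offset = call_with_supported_kwargs(get_transformer_layer_offset, tf_config, vp_stage=vp_rank)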
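
Similarly, the `types.MethodType` line added to `patched_routing` covers Megatron versions whose `TopKRouter` predates `is_aux_loss_enabled`, by binding the module-level fallback to the instance on first use. A self-contained sketch of that binding technique (the `Router` class here is illustrative, not from the patch):

    import types

    class Router:
        """Stand-in for an older router class without is_aux_loss_enabled."""
        aux_loss_coeff = 0.0

    def _is_aux_loss_enabled(_self) -> bool:
        return _self.aux_loss_coeff > 0

    router = Router()
    if not hasattr(router, "is_aux_loss_enabled"):
        # Bind the free function so it receives this instance as _self
        router.is_aux_loss_enabled = types.MethodType(_is_aux_loss_enabled, router)

    assert router.is_aux_loss_enabled() is False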