3 changes: 3 additions & 0 deletions verl/experimental/agent_loop/tool_agent_loop.py
@@ -193,11 +193,14 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput
multi_modal_data["images"] = agent_data.image_data
if agent_data.video_data is not None:
multi_modal_data["videos"] = agent_data.video_data

        routed_experts = getattr(agent_data, "routed_experts", None)
        output = AgentLoopOutput(
            prompt_ids=prompt_ids,
            response_ids=response_ids[: self.response_length],
            response_mask=agent_data.response_mask[: self.response_length],
            multi_modal_data=multi_modal_data,
            routed_experts=routed_experts,
Collaborator:
Why wasn't this passed in here before either? Is this a bugfix, or is it NPU-specific?

            response_logprobs=agent_data.response_logprobs[: self.response_length]
            if agent_data.response_logprobs
            else None,
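
A small illustration (not part of the diff) of the getattr fallback added above: it keeps the loop working with AgentData objects that do not define a routed_experts field, in which case None is forwarded to AgentLoopOutput. Class names and values here are chosen for illustration only.

# Field name taken from the diff; classes below are hypothetical stand-ins for AgentData.
class _AgentDataWithoutExperts:
    pass

class _AgentDataWithExperts:
    routed_experts = [[0, 3], [1, 2]]  # e.g. routed expert ids per token

assert getattr(_AgentDataWithoutExperts(), "routed_experts", None) is None
assert getattr(_AgentDataWithExperts(), "routed_experts", None) == [[0, 3], [1, 2]]
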
57 changes: 52 additions & 5 deletions verl/utils/megatron/router_replay_patch.py
@@ -15,6 +15,9 @@
from enum import Enum

import torch
import types
import inspect
from functools import wraps

try:
from megatron.core.transformer.moe.moe_utils import (
@@ -236,7 +239,27 @@

    return routing_probs, routing_map


def _get_aux_loss_coeff(_self, aux_loss_type: str) -> float:
    """Return the coefficient for the given auxiliary loss type."""
Collaborator:
Comments should be written consistently in English; it's best to remove unrelated comments.

    # Logic kept unchanged.
    if isinstance(_self.routing_type, str):
        if _self.routing_type == aux_loss_type:
            return _self.config.moe_aux_loss_coeff
    if isinstance(_self.routing_type, list):
        try:
            idx = _self.routing_type.index(aux_loss_type)
            return _self.config.moe_aux_loss_coeff[idx]
        except (ValueError, IndexError):
            return 0.0
    return 0.0

def _is_aux_loss_enabled(_self) -> bool:
    """Check whether any auxiliary loss type is enabled."""
    for aux_loss_type in ["aux_loss", "seq_aux_loss", "global_aux_loss"]:
        # Note: this calls the other module-level helper defined above.
        if _get_aux_loss_coeff(_self, aux_loss_type) > 0:
            return True
    return False
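
# Illustration (not part of the diff): a self-contained check of the two helpers
# above. types.SimpleNamespace stands in for a real TopKRouter and its config;
# only the attribute names the helpers read are assumed, values are arbitrary.
_demo_router = types.SimpleNamespace(
    routing_type=["aux_loss", "seq_aux_loss"],
    config=types.SimpleNamespace(moe_aux_loss_coeff=[0.01, 0.0]),
)
assert _get_aux_loss_coeff(_demo_router, "aux_loss") == 0.01
assert _get_aux_loss_coeff(_demo_router, "global_aux_loss") == 0.0  # type absent -> 0.0
assert _is_aux_loss_enabled(_demo_router) is True  # at least one coefficient > 0
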
def patched_routing(self, logits: torch.Tensor, *args, **kwargs):
"""Top-k routing function

@@ -282,6 +305,8 @@
        pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
    )

    if not hasattr(self, "is_aux_loss_enabled"):
        self.is_aux_loss_enabled = types.MethodType(_is_aux_loss_enabled, self)
    # Apply each aux loss type and attach aux loss autograd function to probs
    if self.training and torch.is_grad_enabled() and self.is_aux_loss_enabled():
        # Calculate scores and routing_map for aux loss
@@ -311,26 +336,48 @@
    # Clear router instances to avoid state leakage between model initializations.
    RouterReplay.router_instances.clear()
    # Step 1: Patch TransformerConfig to include the feature flag
    if not hasattr(TransformerConfig, "enable_routing_replay"):
        # Add class attribute with default value
        TransformerConfig.enable_routing_replay = False
    try:
        global_args = get_args()

Check failure on line 340 in verl/utils/megatron/router_replay_patch.py
GitHub Actions / pre-commit (3.12), Ruff (F821)
verl/utils/megatron/router_replay_patch.py:340:23: F821 Undefined name `get_args`
Collaborator:
verl/utils/megatron/router_replay_patch.py:340:23: F821 Undefined name get_args

    except Exception:
        global_args = None
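
# Sketch (not part of the diff) of one way to resolve the F821 above: import
# get_args defensively near the top of the module. The module path
# megatron.training is an assumption and may differ across Megatron-LM versions.
try:
    from megatron.training import get_args
except ImportError:
    def get_args():  # fallback so the try/except around get_args() above still applies
        raise RuntimeError("Megatron global args are unavailable")
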

    try:
        sig = inspect.signature(TransformerConfig.__init__)
        native_params = sig.parameters
Collaborator:
First save a copy of the original TransformerConfig parameters.

    except Exception:
        native_params = []

    ext_attrs = ["enable_routing_replay", "moe_router_fusion"]

    for attr in ext_attrs:
        val = getattr(global_args, attr, False) if global_args else False

        if not hasattr(TransformerConfig, attr):
            setattr(TransformerConfig, attr, val)
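
# Illustration (not part of the diff): what native_params holds and how the
# membership checks in the patched __init__ below use it. _ToyConfig is a
# hypothetical stand-in for TransformerConfig.
from dataclasses import dataclass

@dataclass
class _ToyConfig:
    num_layers: int = 1

_toy_params = inspect.signature(_ToyConfig.__init__).parameters
assert "num_layers" in _toy_params
assert "enable_routing_replay" not in _toy_params  # so the patched __init__ would pop it
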

    if not hasattr(TransformerConfig, "_verl_router_patched"):
        # Store original __init__ method
        original_tf_config_init = TransformerConfig.__init__

        # Define new __init__ method that safely handles enable_routing_replay parameter
        @wraps(original_tf_config_init)
        def patched_tf_config_init(self, *args, **kwargs):
            # Simple solution: remove the unknown parameter before calling original constructor
            enable_routing_replay = kwargs.pop("enable_routing_replay", TransformerConfig.enable_routing_replay)
            if "enable_routing_replay" not in native_params:
                enable_routing_replay = kwargs.pop("enable_routing_replay", TransformerConfig.enable_routing_replay)
Collaborator:
If the original TransformerConfig does not have this parameter, use the newly injected parameter value instead.


if "moe_router_fusion" not in native_params:
moe_router_fusion = kwargs.pop("moe_router_fusion", TransformerConfig.moe_router_fusion)
# Call original constructor with remaining kwargs
original_tf_config_init(self, *args, **kwargs)

# Set the instance attribute
self.enable_routing_replay = enable_routing_replay
self.moe_router_fusion = moe_router_fusion

# Apply the patch
TransformerConfig.__init__ = patched_tf_config_init
TransformerConfig._verl_router_patched = True

    # Step 2: Patch TopKRouter only once to ensure idempotency.
    if hasattr(TopKRouter, "_router_replay_patched"):
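
A minimal usage sketch (not part of the diff) of the Step 1 patch above, assuming it has already been applied and that the running Megatron's TransformerConfig does not natively accept the two injected flags; the three required fields follow megatron.core and the values are arbitrary.

from megatron.core.transformer.transformer_config import TransformerConfig

# With the patched __init__ installed, unknown flags are popped before the original
# constructor runs and then re-attached as instance attributes.
cfg = TransformerConfig(
    num_layers=2,
    hidden_size=64,
    num_attention_heads=4,
    enable_routing_replay=True,
    moe_router_fusion=False,
)
assert cfg.enable_routing_replay is True
assert cfg.moe_router_fusion is False
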
9 changes: 8 additions & 1 deletion verl/utils/megatron/router_replay_utils.py
@@ -21,6 +21,7 @@
from typing import Optional

import torch
import inspect

try:
from megatron.core.pipeline_parallel.utils import is_vp_first_stage, is_vp_last_stage
@@ -330,7 +331,13 @@ def get_current_rank_layer_info(tf_config, vp_rank=None):
    if vp_rank is None:
        vp_rank = 0
    num_layers_to_build = get_num_layers_to_build(tf_config, vp_stage=vp_rank)
    offset = get_transformer_layer_offset(tf_config, vp_stage=vp_rank)
Collaborator:
Why does this section need a separate check for whether the vp_stage parameter is accepted?

Contributor:
The NPU software stack currently only supports Megatron up to 0.12.0, where the get_num_layers_to_build function imported from megatron has no vp_stage parameter, so for now there is no other way to use that parameter.


    sig = inspect.signature(get_transformer_layer_offset)

    if 'vp_stage' in sig.parameters:
        offset = get_transformer_layer_offset(tf_config, vp_stage=vp_rank)
    else:
        offset = get_transformer_layer_offset(tf_config)
    local = {}
    local["start"] = offset
    local["end"] = offset + num_layers_to_build
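
A sketch (not part of the diff) showing how the same signature guard could also cover get_num_layers_to_build, which the Contributor notes lacks vp_stage on Megatron 0.12.0; whether that call actually needs the guard here is an assumption, and _call_with_optional_vp_stage is a hypothetical helper name.

import inspect

def _call_with_optional_vp_stage(fn, tf_config, vp_rank):
    # Pass vp_stage only when the installed Megatron version accepts it.
    if "vp_stage" in inspect.signature(fn).parameters:
        return fn(tf_config, vp_stage=vp_rank)
    return fn(tf_config)

# Usage (illustrative):
# num_layers_to_build = _call_with_optional_vp_stage(get_num_layers_to_build, tf_config, vp_rank)
# offset = _call_with_optional_vp_stage(get_transformer_layer_offset, tf_config, vp_rank)
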
2 changes: 1 addition & 1 deletion verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -507,7 +507,7 @@ async def generate(
            max_tokens = sampling_params.pop("max_new_tokens")
        else:
            # Default to a calculation that considers configured lengths
            max_tokens = self.config.response_length + self.config.prompt_length - len(prompt_ids)
            max_tokens = self.config.response_length

        # Clamp max_tokens to the valid range [0, max_possible_tokens]
        max_tokens = max(0, min(max_tokens, max_possible_tokens))
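
A small worked example (not part of the diff) of the changed default together with the clamp above, using arbitrary values; deriving max_possible_tokens from the engine's maximum model length is an assumption.

# Assumed values, for illustration only.
response_length = 1024                          # self.config.response_length
prompt_length = 4096                            # self.config.prompt_length (no longer used in the default)
len_prompt_ids = 3500                           # len(prompt_ids)
max_possible_tokens = 8192 - len_prompt_ids     # e.g. max_model_len - len(prompt_ids) = 4692

old_default = response_length + prompt_length - len_prompt_ids  # 1620; shrinks as the prompt grows
new_default = response_length                                   # 1024; independent of the prompt size

max_tokens = max(0, min(new_default, max_possible_tokens))      # 1024
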