64 changes: 64 additions & 0 deletions verl/models/mcore/patch.py
@@ -362,3 +362,67 @@ def get_tensor_model_parallel_group_if_none(tp_group, is_expert=False, check_ini
return tp_group

megatron.core.utils.get_tensor_model_parallel_group_if_none = get_tensor_model_parallel_group_if_none


# When activation checkpointing is used together with MoE models (such as Qwen3-30B-A3B
# and Qwen3-VL-30B-A3B), input tensors and their grads stay in GPU memory after
# forward_backward completes; this patch releases their storage in the backward pass.
def apply_patch_checkpoint():
    import megatron.core.tensor_parallel.random as rd
    import torch

    _fork_rng = rd._fork_rng
    _set_all_rng_states = rd._set_all_rng_states
    detach_variable = rd.detach_variable
    gather_split_1d_tensor = rd.gather_split_1d_tensor
    safely_set_viewless_tensor_data = rd.safely_set_viewless_tensor_data

    @staticmethod
    def patch_backward(ctx, *args):
        """Backward pass."""
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = ctx.saved_tensors
        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(inputs[0], gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))

        with _fork_rng():
            # Restore the RNG states to what they were before the forward pass.
            _set_all_rng_states(*ctx.rng_states)

            # Compute the forward pass.
            detached_inputs = detach_variable(inputs)

            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # Filter out non-tensor outputs for the backward pass.
        outputs, args = zip(
            *filter(lambda x: torch.is_tensor(x[0]) and x[0].requires_grad, zip(outputs, args, strict=False)),
            strict=False,
        )
        torch.autograd.backward(outputs, args)
        # Clone the grads to return, so they survive after the originals are released below.
        grads = tuple(
            inp.grad.clone()
            if isinstance(inp, torch.Tensor) and inp.grad is not None
            else (inp.grad if isinstance(inp, torch.Tensor) else inp)
            for inp in detached_inputs
        )
        cur_stream = torch.cuda.current_stream()
        # Release the original input and grad tensors: record_stream tells the caching
        # allocator not to reuse their memory until work already queued on the current
        # stream completes, and resizing the storage to 0 frees the underlying GPU memory.
        for t in detached_inputs:
            if isinstance(t, torch.Tensor) and t.requires_grad:
                t.record_stream(cur_stream)
                t.untyped_storage().resize_(0)
                if t.grad is not None:
                    t.grad.record_stream(cur_stream)
                    t.grad.untyped_storage().resize_(0)
        # ctx.saved_tensors = None
        return (None, None) + grads

    rd.CheckpointFunction.backward = patch_backward
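
For context on the storage-release trick used at the end of patch_backward, here is a minimal standalone sketch (not part of this PR): record_stream plus untyped_storage().resize_(0) frees a CUDA tensor's memory while the Python tensor object stays alive. The helper name release_tensor and the tensor size are illustrative assumptions.

import torch

def release_tensor(t):
    """Free a CUDA tensor's storage once it is safe for the allocator to reclaim it."""
    if not (isinstance(t, torch.Tensor) and t.is_cuda):
        return
    # Tell the caching allocator not to reuse this memory until work already
    # queued on the current stream has completed.
    t.record_stream(torch.cuda.current_stream())
    # Drop the underlying storage; the Python tensor object stays alive.
    t.untyped_storage().resize_(0)

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")  # ~4 MB of float32
    before = torch.cuda.memory_allocated()
    release_tensor(x)
    torch.cuda.synchronize()
    after = torch.cuda.memory_allocated()
    print(f"allocated before={before}, after={after}")  # expect roughly a 4 MB drop
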
4 changes: 4 additions & 0 deletions verl/workers/engine/megatron/transformer_impl.py
@@ -98,6 +98,10 @@ def __init__(
logger.info(f"enable_routing_replay in MegatronEngine: {self.enable_routing_replay}")
if self.enable_routing_replay:
apply_router_replay_patch()
# Apply checkpoint patch for MoE models
from verl.models.mcore.patch import apply_patch_checkpoint

apply_patch_checkpoint()
Contributor review comment (severity: high):
The comment "Apply checkpoint patch for MoE models" indicates this patch is specific to Mixture-of-Experts models. However, it's being applied unconditionally. This could introduce risks or unintended side effects for non-MoE models that also use checkpointing. It would be safer to apply this patch conditionally, only when an MoE model is detected. You can check for this using self.engine_config.expert_model_parallel_size > 1.

Suggested change:
- # Apply checkpoint patch for MoE models
- from verl.models.mcore.patch import apply_patch_checkpoint
- apply_patch_checkpoint()
+ # Apply checkpoint patch for MoE models
+ if self.engine_config.expert_model_parallel_size > 1:
+     from verl.models.mcore.patch import apply_patch_checkpoint
+     apply_patch_checkpoint()


    def _init_device_mesh(self):
        # TODO: set different parallelism for actor, critic, ref