1 change: 1 addition & 0 deletions tests/special_sanity/check_device_api_usage.py
@@ -48,6 +48,7 @@
"verl/workers/rollout/trtllm_rollout/trtllm_async_server.py", # appear in config.cudagraph_capture_sizes
"verl/workers/rollout/replica.py", # appear in default device_name
"verl/checkpoint_engine", # checkpoint engine backend are device specific
"verl/models/mcore/patch.py", # checkpoint patch only on cuda
]

# directory or file path must contain keyword "nccl"
64 changes: 64 additions & 0 deletions verl/models/mcore/patch.py
@@ -362,3 +362,67 @@ def get_tensor_model_parallel_group_if_none(tp_group, is_expert=False, check_ini
return tp_group

megatron.core.utils.get_tensor_model_parallel_group_if_none = get_tensor_model_parallel_group_if_none


# When activation checkpointing is used with MoE models (like Qwen3-30B-A3B and Qwen3-VL-30B-A3B),
# input tensors and their grads stay in GPU memory after forward_backward completes.
def apply_patch_checkpoint():
import megatron.core.tensor_parallel.random as rd
import torch

_fork_rng = rd._fork_rng
_set_all_rng_states = rd._set_all_rng_states
detach_variable = rd.detach_variable
gather_split_1d_tensor = rd.gather_split_1d_tensor
safely_set_viewless_tensor_data = rd.safely_set_viewless_tensor_data

@staticmethod
def patch_backward(ctx, *args):
"""Backward pass."""
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
inputs = ctx.saved_tensors
if ctx.distribute_saved_activations:
safely_set_viewless_tensor_data(inputs[0], gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))

with _fork_rng():
            # Set the RNG states to what they were before the forward pass.
_set_all_rng_states(*ctx.rng_states)

# Compute the forward pass.
detached_inputs = detach_variable(inputs)

with torch.enable_grad():
outputs = ctx.run_function(*detached_inputs)

if isinstance(outputs, torch.Tensor):
outputs = (outputs,)

        # Filter out non-tensor outputs and outputs that do not require grad before calling backward
outputs, args = zip(
*filter(lambda x: torch.is_tensor(x[0]) and x[0].requires_grad, zip(outputs, args, strict=False)),
strict=False,
)
torch.autograd.backward(outputs, args)
        # Clone grads to return; keep None for tensors without grads and pass non-tensor inputs through unchanged
grads = tuple(
inp.grad.clone()
if isinstance(inp, torch.Tensor) and inp.grad is not None
else inp.grad
if isinstance(inp, torch.Tensor)
else inp
for inp in detached_inputs
)
cur_stream = torch.cuda.current_stream()
        # Release the original input and grad storages; record_stream keeps the caching allocator
        # from reusing the memory until pending work on the current stream finishes
for t in detached_inputs:
if isinstance(t, torch.Tensor) and t.requires_grad:
t.record_stream(cur_stream)
t.untyped_storage().resize_(0)
if t.grad is not None:
t.grad.record_stream(cur_stream)
t.grad.untyped_storage().resize_(0)
# ctx.saved_tensors = None
return (None, None) + grads

rd.CheckpointFunction.backward = patch_backward
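
The patched backward frees each checkpointed input (and its grad) by shrinking the underlying storage to zero bytes once the grads have been cloned, instead of waiting for Python garbage collection. A minimal standalone sketch of that release pattern, outside this PR and with illustrative tensor names, assuming a CUDA device is available:

import torch

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
    loss = (x * 2).sum()
    loss.backward()
    grad_copy = x.grad.clone()  # keep a copy of the grad to hand back to the caller
    stream = torch.cuda.current_stream()
    # Mark both tensors as used on the current stream so the caching allocator
    # does not hand their memory out before in-flight kernels finish ...
    x.record_stream(stream)
    x.untyped_storage().resize_(0)  # ... then drop the input's storage
    x.grad.record_stream(stream)
    x.grad.untyped_storage().resize_(0)  # and the grad's storage
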
7 changes: 7 additions & 0 deletions verl/workers/engine/megatron/transformer_impl.py
@@ -98,6 +98,13 @@ def __init__(
logger.info(f"enable_routing_replay in MegatronEngine: {self.enable_routing_replay}")
if self.enable_routing_replay:
apply_router_replay_patch()
# Apply checkpoint patch for MoE models
from verl.utils.device import is_cuda_available

if is_cuda_available:
from verl.models.mcore.patch import apply_patch_checkpoint

apply_patch_checkpoint()

def _init_device_mesh(self):
# TODO: set different parallelism for actor, critic, ref