1 change: 1 addition & 0 deletions tests/special_sanity/check_device_api_usage.py
@@ -48,6 +48,7 @@
"verl/workers/rollout/trtllm_rollout/trtllm_async_server.py", # appear in config.cudagraph_capture_sizes
"verl/workers/rollout/replica.py", # appear in default device_name
"verl/checkpoint_engine", # checkpoint engine backend are device specific
"verl/models/mcore/patch.py", # checkpoint patch only on cuda
]

# directory or file path must contain keyword "nccl"
65 changes: 65 additions & 0 deletions verl/models/mcore/patch.py
@@ -362,3 +362,68 @@ def get_tensor_model_parallel_group_if_none(tp_group, is_expert=False, check_ini
    return tp_group

megatron.core.utils.get_tensor_model_parallel_group_if_none = get_tensor_model_parallel_group_if_none


# When activation checkpointing is used with MoE models (like Qwen3-30B-A3B and Qwen3-VL-30B-A3B),
# input tensors and their grads stay in GPU memory after forward_backward completes.
# See https://github.com/NVIDIA/Megatron-LM/pull/3267
def apply_patch_checkpoint():
    import megatron.core.tensor_parallel.random as rd
    import torch

    _fork_rng = rd._fork_rng
    _set_all_rng_states = rd._set_all_rng_states
    detach_variable = rd.detach_variable
    gather_split_1d_tensor = rd.gather_split_1d_tensor
    safely_set_viewless_tensor_data = rd.safely_set_viewless_tensor_data

    @staticmethod
    def patch_backward(ctx, *args):
        """Backward pass."""
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = ctx.saved_tensors
        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(inputs[0], gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))

        with _fork_rng():
            # Set the states to what it used to be before the forward pass.
            _set_all_rng_states(*ctx.rng_states)

            # Compute the forward pass.
            detached_inputs = detach_variable(inputs)

            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # filter out non-tensor outputs for backward pass
        outputs, args = zip(
            *filter(lambda x: torch.is_tensor(x[0]) and x[0].requires_grad, zip(outputs, args, strict=False)),
            strict=False,
        )
        torch.autograd.backward(outputs, args)
        # Clone grads to return, so the original grad tensors can be released below
        grads = tuple(
            inp.grad.clone()
            if isinstance(inp, torch.Tensor) and inp.grad is not None
            else inp.grad
            if isinstance(inp, torch.Tensor)
            else inp
            for inp in detached_inputs
        )
        cur_stream = torch.cuda.current_stream()
        # Release original input and grad tensors
        for t in detached_inputs:
            if isinstance(t, torch.Tensor) and t.requires_grad:
                t.record_stream(cur_stream)
                t.untyped_storage().resize_(0)
                if t.grad is not None:
                    t.grad.record_stream(cur_stream)
                    t.grad.untyped_storage().resize_(0)
        # ctx.saved_tensors = None
        return (None, None) + grads

    rd.CheckpointFunction.backward = patch_backward
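
The memory saving in `patch_backward` comes from the final loop: grads are cloned, then each detached input and its original `.grad` are recorded on the current CUDA stream and their storages shrunk to zero bytes so the caching allocator can reclaim them. Below is a minimal, standalone PyTorch sketch of that release pattern (an illustration only, not part of this PR; `release_tensor` is a hypothetical helper name):

```python
# Minimal sketch of the release pattern used above (assumes a CUDA device;
# release_tensor is a made-up helper, not part of verl or Megatron).
import torch


def release_tensor(t: torch.Tensor) -> None:
    # Mark the tensor as in use on the current stream so the caching allocator
    # does not recycle its block before queued work finishes, then shrink the
    # underlying storage to zero bytes so the memory can be reclaimed.
    t.record_stream(torch.cuda.current_stream())
    t.untyped_storage().resize_(0)


if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
    loss = (x * 2).sum()
    loss.backward()

    grad_copy = x.grad.clone()  # keep a clone, as patch_backward does
    release_tensor(x)       # drop the original input's storage
    release_tensor(x.grad)  # drop the original grad's storage

    assert x.untyped_storage().size() == 0   # original storage is empty now
    assert grad_copy.untyped_storage().size() > 0  # the clone still holds data
```

Calling `record_stream` before shrinking the storage tells the caching allocator not to hand the block to a new allocation until work already queued on that stream has finished.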
7 changes: 7 additions & 0 deletions verl/workers/engine/megatron/transformer_impl.py
@@ -98,6 +98,13 @@ def __init__(
logger.info(f"enable_routing_replay in MegatronEngine: {self.enable_routing_replay}")
if self.enable_routing_replay:
apply_router_replay_patch()
# Apply checkpoint patch for MoE models
from verl.utils.device import is_cuda_available

if is_cuda_available:
from verl.models.mcore.patch import apply_patch_checkpoint

apply_patch_checkpoint()

def _init_device_mesh(self):
# TODO: set different parallelism for actor, critic, ref
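
For a quick sanity check that the patch is actually installed on a CUDA build, a hypothetical smoke test (not part of this PR) could assert that Megatron's `CheckpointFunction.backward` changes after `apply_patch_checkpoint()` runs:

```python
# Hypothetical smoke check, assuming megatron.core and verl are importable and
# mirroring the is_cuda_available gate used in transformer_impl.py above.
import megatron.core.tensor_parallel.random as rd

from verl.models.mcore.patch import apply_patch_checkpoint
from verl.utils.device import is_cuda_available

if is_cuda_available:
    original_backward = rd.CheckpointFunction.backward
    apply_patch_checkpoint()
    # CheckpointFunction.backward should now be the patched implementation.
    assert rd.CheckpointFunction.backward is not original_backward
```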