diff --git a/tests/special_sanity/check_device_api_usage.py b/tests/special_sanity/check_device_api_usage.py
index fbf9cf7e75a..bf3ba132262 100644
--- a/tests/special_sanity/check_device_api_usage.py
+++ b/tests/special_sanity/check_device_api_usage.py
@@ -48,6 +48,7 @@
     "verl/workers/rollout/trtllm_rollout/trtllm_async_server.py",  # appear in config.cudagraph_capture_sizes
     "verl/workers/rollout/replica.py",  # appear in default device_name
     "verl/checkpoint_engine",  # checkpoint engine backend are device specific
+    "verl/models/mcore/patch.py",  # checkpoint patch only on cuda
 ]
 
 # directory or file path must contain keyword "nccl"
diff --git a/verl/models/mcore/patch.py b/verl/models/mcore/patch.py
index 2968b3daace..fe232ef575b 100644
--- a/verl/models/mcore/patch.py
+++ b/verl/models/mcore/patch.py
@@ -362,3 +362,68 @@ def get_tensor_model_parallel_group_if_none(tp_group, is_expert=False, check_ini
     return tp_group
 
 megatron.core.utils.get_tensor_model_parallel_group_if_none = get_tensor_model_parallel_group_if_none
+
+
+# When using checkpoint + MoE models (like Qwen3-30B-A3B and Qwen3-VL-30B-A3B),
+# input tensors and their grads will stay in gpu memory after forward_backward completes.
+# see https://github.com/NVIDIA/Megatron-LM/pull/3267
+def apply_patch_checkpoint():
+    import megatron.core.tensor_parallel.random as rd
+    import torch
+
+    _fork_rng = rd._fork_rng
+    _set_all_rng_states = rd._set_all_rng_states
+    detach_variable = rd.detach_variable
+    gather_split_1d_tensor = rd.gather_split_1d_tensor
+    safely_set_viewless_tensor_data = rd.safely_set_viewless_tensor_data
+
+    @staticmethod
+    def patch_backward(ctx, *args):
+        """Backward pass."""
+        if not torch.autograd._is_checkpoint_valid():
+            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
+        inputs = ctx.saved_tensors
+        if ctx.distribute_saved_activations:
+            safely_set_viewless_tensor_data(inputs[0], gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
+
+        with _fork_rng():
+            # Set the states to what it used to be before the forward pass.
+            _set_all_rng_states(*ctx.rng_states)
+
+            # Compute the forward pass.
+            detached_inputs = detach_variable(inputs)
+
+            with torch.enable_grad():
+                outputs = ctx.run_function(*detached_inputs)
+
+        if isinstance(outputs, torch.Tensor):
+            outputs = (outputs,)
+
+        # filter out non tensor outputs for backward pass
+        outputs, args = zip(
+            *filter(lambda x: torch.is_tensor(x[0]) and x[0].requires_grad, zip(outputs, args, strict=False)),
+            strict=False,
+        )
+        torch.autograd.backward(outputs, args)
+        # Clone grads to return
+        grads = tuple(
+            inp.grad.clone()
+            if isinstance(inp, torch.Tensor) and inp.grad is not None
+            else inp.grad
+            if isinstance(inp, torch.Tensor)
+            else inp
+            for inp in detached_inputs
+        )
+        cur_stream = torch.cuda.current_stream()
+        # Release original input and grad tensors
+        for t in detached_inputs:
+            if isinstance(t, torch.Tensor) and t.requires_grad:
+                t.record_stream(cur_stream)
+                t.untyped_storage().resize_(0)
+                if t.grad is not None:
+                    t.grad.record_stream(cur_stream)
+                    t.grad.untyped_storage().resize_(0)
+        # ctx.saved_tensors = None
+        return (None, None) + grads
+
+    rd.CheckpointFunction.backward = patch_backward
diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py
index 0e3f7ff6a29..070695d6ca5 100644
--- a/verl/workers/engine/megatron/transformer_impl.py
+++ b/verl/workers/engine/megatron/transformer_impl.py
@@ -98,6 +98,13 @@ def __init__(
         logger.info(f"enable_routing_replay in MegatronEngine: {self.enable_routing_replay}")
         if self.enable_routing_replay:
             apply_router_replay_patch()
+        # Apply checkpoint patch for MoE models
+        from verl.utils.device import is_cuda_available
+
+        if is_cuda_available:
+            from verl.models.mcore.patch import apply_patch_checkpoint
+
+            apply_patch_checkpoint()
 
     def _init_device_mesh(self):
         # TODO: set different parallelism for actor, critic, ref
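
Note (standalone illustration, not part of the diff): the patched backward differs from the stock Megatron CheckpointFunction.backward in that, after recomputing the forward and running torch.autograd.backward, it clones the input grads and then releases the storages of the recomputed inputs and their grads, so they stop occupying GPU memory once forward_backward returns. The sketch below shows only that release step with plain PyTorch; the helper name release_after_backward and the toy tensor are illustrative assumptions, not code from this PR.

import torch


def release_after_backward(t: torch.Tensor) -> None:
    # Mirror the release step in patch_backward: record the tensor on the
    # current stream so its memory is not reused before work already queued
    # on that stream finishes, then shrink the backing storage of the tensor
    # (and of its grad) to zero bytes so the caching allocator can reclaim it.
    stream = torch.cuda.current_stream()
    t.record_stream(stream)
    t.untyped_storage().resize_(0)
    if t.grad is not None:
        t.grad.record_stream(stream)
        t.grad.untyped_storage().resize_(0)


if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda", requires_grad=True)
    (x * 2).sum().backward()
    grad_copy = x.grad.clone()  # clone first, as patch_backward does, before releasing
    before = torch.cuda.memory_allocated()
    release_after_backward(x)
    print(f"allocated bytes: {before} -> {torch.cuda.memory_allocated()}")

In the PR itself the patch is installed once, in MegatronEngine.__init__, and only when verl.utils.device.is_cuda_available is true, which is also why verl/models/mcore/patch.py is added to the device-API sanity-check whitelist.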