docs/advance/mtp.md (8 changes: 5 additions & 3 deletions)
@@ -2,19 +2,21 @@

**Author**: `https://github.com/meituan-search`

Last updated: 01/30/2026
Last updated: 02/15/2026

# 1. Scope of Support

RL training can currently be performed on mimo-7B-RL, Qwen-next, and DeepSeek series models built on the MTP architecture. The support rules for the training and inference engines are as follows (a brief config sketch follows this list):

- **Training Engine**: Only supports the `mbridge + megatron` combination; other training engines are not compatible at this time;
- **Training Engine**: Only supports the `mbridge/Megatron-Bridge + megatron` combination; other training engines are not compatible at this time;

- **Inference Engine**: Compatible with all engines, but the model must be in the corresponding engine's compatibility list;

- **Dependency Versions**:

- mbridge: Use the specified branch: [https://github.com/ArronHZG/mbridge/tree/feature/verl_mtp](https://github.com/ArronHZG/mbridge/tree/feature/verl_mtp) (will be merged into the main branch in the future);
- mbridge: Apply the patches and review suggestions from PR: [#62](https://github.com/ISEEKYAN/mbridge/pull/62) (will be merged into the main branch in the future);

- Megatron-Bridge: To try out mimo-7B-RL, apply the patches and review suggestions from PR [#2387](https://github.com/NVIDIA-NeMo/Megatron-Bridge/pull/2387) (will be merged into the main branch in the future);

- megatron: Use the latest dev version (commit: [23e092f41ec8bc659020e401ddac9576c1cfed7e](https://github.com/NVIDIA/Megatron-LM/tree/23e092f41ec8bc659020e401ddac9576c1cfed7e)), which supports MTP + CP (context parallel) training.

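For orientation, here is a minimal sketch of the MTP-related configuration knobs checked elsewhere in this PR. The `model.mtp.*` and `use_fused_kernels` keys mirror the worker/actor changes below; the `megatron.use_mbridge` nesting and the values shown are assumptions for illustration and may differ in your recipe.

```python
# Illustrative only: key names mirror the checks in this PR's worker/actor changes;
# the exact nesting (notably the "megatron" sub-tree) and all values are assumptions.
from omegaconf import OmegaConf

actor_config = OmegaConf.create(
    {
        "model": {
            "mtp": {
                "enable": True,  # corresponds to mtp_config.enable
                "mtp_loss_scaling_factor": 0.1,  # placeholder; forwarded into override_transformer_config
            }
        },
        "megatron": {"use_mbridge": True},  # MTP requires the mbridge-based model bridge
        "use_fused_kernels": False,  # MTP currently auto-disables fused kernels anyway
    }
)

# Mirror the worker-side sanity checks kept in this PR.
assert actor_config.model.mtp.enable, "MTP requires mtp_config.enable to be True"
assert actor_config.megatron.use_mbridge, "MTP requires use_mbridge to be True"
```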
verl/models/mcore/mtp_patch.py (167 changes: 107 additions & 60 deletions)
@@ -20,11 +20,7 @@
import torch
from megatron.core import parallel_state
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.transformer.multi_token_prediction import (
MTPLossAutoScaler,
MTPLossLoggingHelper,
roll_tensor,
)
from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler, MTPLossLoggingHelper, roll_tensor

try:
from megatron.core.utils import unwrap_model
@@ -78,19 +74,45 @@ def _megatron_gptmodel_postprocess(
runtime_gather_output=None,
extra_block_kwargs=None,
inference_context=None,
**kwargs,
):
"""Postprocesses decoder hidden states to generate logits or compute loss.
"""Compatibility patch for GPTModel._postprocess.

Applies Multi-Token Prediction if enabled, generates output logits through
the output layer, and computes language model loss when labels are provided.
For inference (`labels is None`), delegate to the upstream implementation to stay
aligned with Megatron-Core updates.

For training (`labels is not None`), keep VERL's MTP behavior and always return
logits (instead of CE loss) so PPO paths can compute custom losses from logits.
"""
# Keep inference path aligned with whatever upstream Megatron currently expects.
if labels is None:
return self._postprocess_backup(
hidden_states=hidden_states,
input_ids=input_ids,
position_ids=position_ids,
labels=labels,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
mtp_in_postprocess=mtp_in_postprocess,
loss_mask=loss_mask,
decoder_input=decoder_input,
attention_mask=attention_mask,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
runtime_gather_output=runtime_gather_output,
extra_block_kwargs=extra_block_kwargs,
inference_context=inference_context,
**kwargs,
)

# logits and loss
# Training path: keep logits for external loss computation.
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()

if mtp_in_postprocess and labels is not None:
if mtp_in_postprocess:
hidden_states = self.mtp(
input_ids=input_ids,
position_ids=position_ids,
@@ -109,60 +131,85 @@ def _megatron_gptmodel_postprocess(
if not self.post_process:
return hidden_states

# Skip when mtp_num_layers is None or 0
if self.config.mtp_num_layers and labels is not None:
mtp_labels = labels.clone()

hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
hidden_states = hidden_states_list[0]
if loss_mask is None:
# if loss_mask is not provided, use all ones as loss_mask
loss_mask = torch.ones_like(mtp_labels)
for mtp_layer_number in range(self.config.mtp_num_layers):
# Calc loss for the current Multi-Token Prediction (MTP) layers.
mtp_labels, _ = roll_tensor(
mtp_labels,
shifts=-1,
dims=-1,
cp_group=self.cp_group,
# Skip when mtp_num_layers is None or 0.
if self.config.mtp_num_layers:
cp_group = None
if getattr(self, "pg_collection", None) is not None:
cp_group = self.pg_collection.cp
elif hasattr(self, "cp_group"):
cp_group = self.cp_group

# Prefer upstream helper when available (newer Megatron-LM).
try:
from megatron.core.transformer.multi_token_prediction import process_mtp_loss

hidden_states = process_mtp_loss(
hidden_states=hidden_states,
labels=labels,
loss_mask=loss_mask,
output_layer=self.output_layer,
output_weight=output_weight,
runtime_gather_output=runtime_gather_output,
is_training=self.training,
compute_language_model_loss=self.compute_language_model_loss,
config=self.config,
cp_group=cp_group,
packed_seq_params=packed_seq_params,
)
loss_mask, num_tokens = roll_tensor(
loss_mask,
shifts=-1,
dims=-1,
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)

# Compute mtp loss without storing logits to save memory.
mtp_loss = self.compute_output_layer_and_language_model_loss(
hidden_states_list[mtp_layer_number + 1],
labels=mtp_labels,
weight=self.shared_embedding_or_output_weight(),
sequence_parallel_enabled=self.output_layer.sequence_parallel,
column_parallel_linear=self.output_layer,
col_linear_kwargs={
"weight": output_weight,
"runtime_gather_output": runtime_gather_output,
},
)
except (ImportError, AttributeError, TypeError):
# Fallback for older Megatron-LM versions without process_mtp_loss API.
mtp_labels = labels.clone()

hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
hidden_states = hidden_states_list[0]
if loss_mask is None:
# if loss_mask is not provided, use all ones as loss_mask
loss_mask = torch.ones_like(mtp_labels)
for mtp_layer_number in range(self.config.mtp_num_layers):
# Calc loss for the current Multi-Token Prediction (MTP) layers.
mtp_labels, _ = roll_tensor(
mtp_labels,
shifts=-1,
dims=-1,
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)
loss_mask, num_tokens = roll_tensor(
loss_mask,
shifts=-1,
dims=-1,
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)

mtp_loss = loss_mask * mtp_loss
if self.training:
# TODO(shifangx): remove the use of parallel_state here
# after moving loss logging to loss_func in pretrain_gpt.py
MTPLossLoggingHelper.save_loss_to_tracker(
torch.sum(mtp_loss) / num_tokens,
mtp_layer_number,
self.config.mtp_num_layers,
avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
# Compute mtp loss without storing logits to save memory.
mtp_loss = self.compute_output_layer_and_language_model_loss(
hidden_states_list[mtp_layer_number + 1],
labels=mtp_labels,
weight=self.shared_embedding_or_output_weight(),
sequence_parallel_enabled=self.output_layer.sequence_parallel,
column_parallel_linear=self.output_layer,
col_linear_kwargs={
"weight": output_weight,
"runtime_gather_output": runtime_gather_output,
},
)
mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
if self.config.calculate_per_token_loss:
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
else:
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)

mtp_loss = loss_mask * mtp_loss
if self.training:
# TODO(shifangx): remove the use of parallel_state here
# after moving loss logging to loss_func in pretrain_gpt.py
MTPLossLoggingHelper.save_loss_to_tracker(
torch.sum(mtp_loss) / num_tokens,
mtp_layer_number,
self.config.mtp_num_layers,
avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
)
mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
if self.config.calculate_per_token_loss:
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
else:
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)

logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)
# [s b h] => [b s h]
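To connect the pieces above: the patched `_megatron_gptmodel_postprocess` expects the original method to be preserved as `_postprocess_backup`, so that calls with `labels is None` can delegate to upstream Megatron-Core while training calls return logits for external loss computation. Below is a minimal sketch of how such a patch could be installed; the helper name and exact patching logic in `mtp_patch.py` are assumptions, not the file's actual implementation.

```python
# A sketch only: the real installer in verl/models/mcore/mtp_patch.py may differ.
from megatron.core.models.gpt.gpt_model import GPTModel


def apply_mtp_postprocess_patch() -> None:
    """Monkey-patch GPTModel._postprocess once, keeping the upstream method around.

    The patched training path returns logits (for external PPO losses), while
    inference calls (labels is None) fall through to the saved _postprocess_backup.
    """
    if not hasattr(GPTModel, "_postprocess_backup"):
        GPTModel._postprocess_backup = GPTModel._postprocess
    GPTModel._postprocess = _megatron_gptmodel_postprocess
```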
verl/workers/actor/megatron_actor.py (5 changes: 5 additions & 0 deletions)
@@ -139,6 +139,11 @@ def __init__(
assert self.mtp_config.enable, "MTP requires mtp_config.enable to be True"

self.use_fused_kernels = self.config.get("use_fused_kernels", False)
if getattr(self.mtp_config, "enable", False) and self.use_fused_kernels:
self.use_fused_kernels = False
logger.warning_once(
"MTP is not compatible with fused kernels for now. Automatically disable use_fused_kernels."
)
if self.use_fused_kernels and not getattr(self.config, "overlap_moe_expert_parallel_comm", False):
# do not patch if overlap_moe_expert_parallel_comm is enabled
logger.warning_once(
verl/workers/megatron_workers.py (1 change: 0 additions & 1 deletion)
@@ -155,7 +155,6 @@ def _init_hf_config_and_tf_config(
if enable_mtp:
assert hf_config.num_nextn_predict_layers > 0, "MTP requires at least one nextn_predict_layer"
assert megatron_config.use_mbridge, "MTP requires use_mbridge to be True"
assert megatron_config.vanilla_mbridge, "MTP requires vanilla_mbridge to be True"
override_transformer_config["mtp_loss_scaling_factor"] = self.config.model.mtp.mtp_loss_scaling_factor
else:
if hasattr(hf_config, "num_nextn_predict_layers"):