diff --git a/verl/models/mcore/qat_patch.py b/verl/models/mcore/qat_patch.py new file mode 100644 index 00000000000..ec381a0f5de --- /dev/null +++ b/verl/models/mcore/qat_patch.py @@ -0,0 +1,541 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runtime patches for QAT (Quantization-Aware Training) with Megatron-Core. + +This module provides four independent monkey-patches that fix issues in older +versions of megatron-core / megatron-bridge when running QAT workflows: + +1. **SwiGLU sharded-state-dict patch** (``apply_swiglu_sharded_factory_patch``) + Older megatron-core raises ``NotImplementedError`` inside + ``apply_swiglu_sharded_factory`` when ``singleton_local_shards=True``. + The patch adds correct handling by splitting the sharded tensor key into + separate ``{key}_w`` / ``{key}_v`` entries. + +2. **EP gather_from_ep_ranks patch** (``apply_ep_gather_patch``) + The original ``MegatronParamMapping.gather_from_ep_ranks`` only supports + the TEGroupedMLP naming pattern (``weight`` / ``bias``). The patch + additionally supports the SequentialMLP pattern (``local_experts.``) + and adds better error handling. + +3. **extract_sort_key patch** (``apply_extract_sort_key_patch``) + The original ``extract_sort_key`` in megatron-bridge utils only recognises + expert numbers in TEGroupedMLP format (``weight`` / ``bias``). The + patch adds fallback support for the SequentialMLP pattern + (``local_experts.``). + +4. **build_conversion_tasks patch** (``apply_build_conversion_tasks_patch``) + The original ``MegatronModelBridge.build_conversion_tasks`` may return + ``None`` entries in the task list (for PP ranks that don't own certain + parameters and have no mapping). The patch filters out ``None`` entries + before returning so that callers never need to guard against them. + +Convenience entry-point:: + + from verl.models.mcore.qat_patch import apply_qat_patch + apply_qat_patch() # applies all patches at once +""" + +import gc +import logging +import re +from typing import Dict, Iterable, List, Optional + +import torch + +logger = logging.getLogger(__name__) + +# ====================================================================== +# 1. SwiGLU sharded-state-dict patch +# ====================================================================== + + +def apply_swiglu_sharded_factory_patch(): + """Patch ``megatron.core.transformer.mlp.apply_swiglu_sharded_factory`` + to support ``singleton_local_shards`` for SwiGLU MLP tensors. + + Idempotent – safe to call multiple times. 
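+
+    A minimal usage sketch (the call site below is illustrative; the patch only
+    needs to run once, before any dist-checkpointing save/load that touches
+    SwiGLU MLP weights)::
+
+        from verl.models.mcore.qat_patch import apply_swiglu_sharded_factory_patch
+
+        apply_swiglu_sharded_factory_patch()  # later calls are no-ops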
+ """ + import megatron.core.transformer.mlp as mlp_module + from megatron.core.dist_checkpointing import ShardedTensor + from megatron.core.dist_checkpointing.mapping import ( + ReplicaId, + ShardedTensorFactory, + ) + + if getattr(mlp_module, "_swiglu_patched", False): + return + mlp_module._swiglu_patched = True + mlp_module._original_apply_swiglu_sharded_factory = mlp_module.apply_swiglu_sharded_factory + + def patched_apply_swiglu_sharded_factory( + original_sh_ten, sharded_offsets, singleton_local_shards: bool = False + ): + swiglu_shard_axis = 0 + prepend_axis_num = len(sharded_offsets) + original_shape = original_sh_ten.local_shape + local_axis_size = original_shape[swiglu_shard_axis] + assert ( + original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] + % local_axis_size + == 0 + ) + rank_offset = ( + original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] + // local_axis_size + ) + axis_frag = original_sh_ten.axis_fragmentations[ + swiglu_shard_axis + prepend_axis_num + ] + + @torch.no_grad() + def sh_ten_build_fn( + key: str, + t: torch.Tensor, + replica_id: ReplicaId, + flattened_range: Optional[slice], + ): + if singleton_local_shards: + offset_w = (swiglu_shard_axis + prepend_axis_num, rank_offset, axis_frag) + offset_v = (swiglu_shard_axis + prepend_axis_num, rank_offset, axis_frag) + w_key = f"{key}_w" + v_key = f"{key}_v" + else: + offset_w = (swiglu_shard_axis + prepend_axis_num, rank_offset, axis_frag * 2) + offset_v = ( + swiglu_shard_axis + prepend_axis_num, + rank_offset + axis_frag, + axis_frag * 2, + ) + w_key = key + v_key = key + + tensor_w, tensor_v = torch.chunk(t, 2, dim=swiglu_shard_axis) + return [ + ShardedTensor.from_rank_offsets( + w_key, tensor_w, *sharded_offsets, offset_w, + replica_id=replica_id, prepend_axis_num=prepend_axis_num, + ), + ShardedTensor.from_rank_offsets( + v_key, tensor_v, *sharded_offsets, offset_v, + replica_id=replica_id, prepend_axis_num=prepend_axis_num, + ), + ] + + def sh_ten_merge_fn(sub_state_dict): + with torch.no_grad(): + try: + return torch.cat(sub_state_dict) + except (RuntimeError, torch.cuda.OutOfMemoryError) as e: + logger.warning( + "CUDA OOM during tensor merge – falling back to CPU. (Error: %s)", e, + ) + merged = torch.cat([t.cpu() for t in sub_state_dict]) + gc.collect() + torch.cuda.empty_cache() + return merged + + return ShardedTensorFactory( + original_sh_ten.key, + original_sh_ten.data, + sh_ten_build_fn, + sh_ten_merge_fn, + original_sh_ten.replica_id, + flattened_range=original_sh_ten.flattened_range, + ) + + mlp_module.apply_swiglu_sharded_factory = patched_apply_swiglu_sharded_factory + logger.info("Applied QAT patch: apply_swiglu_sharded_factory now supports singleton_local_shards.") + + +def revert_swiglu_sharded_factory_patch(): + """Revert :func:`apply_swiglu_sharded_factory_patch`.""" + import megatron.core.transformer.mlp as mlp_module + + if not getattr(mlp_module, "_swiglu_patched", False): + return + mlp_module.apply_swiglu_sharded_factory = mlp_module._original_apply_swiglu_sharded_factory + mlp_module._swiglu_patched = False + logger.info("Reverted QAT patch: apply_swiglu_sharded_factory.") + + +# ====================================================================== +# 2. 
EP gather_from_ep_ranks patch +# ====================================================================== + + +def apply_ep_gather_patch(): + """Patch ``MegatronParamMapping.gather_from_ep_ranks`` in megatron-bridge + to support both SequentialMLP (``local_experts.``) and TEGroupedMLP + (``weight`` / ``bias``) naming patterns. + + Idempotent – safe to call multiple times. + """ + from megatron.bridge.models.conversion.param_mapping import MegatronParamMapping + + if getattr(MegatronParamMapping, "_ep_gather_patched", False): + return + MegatronParamMapping._ep_gather_patched = True + MegatronParamMapping._original_gather_from_ep_ranks = MegatronParamMapping.gather_from_ep_ranks + + def _patched_gather_from_ep_ranks( + self, + megatron_weights: Optional[torch.Tensor], + megatron_module, # Optional[MegatronModule] + hf_param_name: Optional[str], + ) -> Dict[str, torch.Tensor]: + """Gather expert weights across EP ranks (supports SequentialMLP + TEGroupedMLP).""" + if megatron_module is None: + num_experts_per_rank = self.broadcast_obj_from_pp_rank(None, "num_experts_per_rank") + else: + model_config = self._get_config(megatron_module) + num_experts = model_config.num_moe_experts + num_experts_per_rank = num_experts // self.ep_size + num_experts_per_rank = self.broadcast_obj_from_pp_rank( + num_experts_per_rank, "num_experts_per_rank" + ) + + # --- Extract the local expert index from the Megatron param name --- + local_expert_number = None + + # Try SequentialMLP pattern first: local_experts. + local_experts_match = re.search(r"local_experts\.(\d+)", self.megatron_param) + if local_experts_match: + global_expert_number = int(local_experts_match.group(1)) + local_expert_number = global_expert_number % num_experts_per_rank + else: + # Fallback: TEGroupedMLP pattern – weight or bias + for key in (".weight", ".bias"): + if key in self.megatron_param: + suffix = self.megatron_param.split(key)[-1] + if suffix: # only if there is actually a number after the suffix + global_expert_number = int(suffix) + local_expert_number = global_expert_number % num_experts_per_rank + break + + if local_expert_number is None: + raise ValueError( + f"Could not extract expert number from parameter name: {self.megatron_param}. " + f"Expected either TEGroupedMLP pattern (weight/bias) or " + f"SequentialMLP pattern (local_experts.)." 
+ ) + + # Build HF param names for every EP rank + gathered_expert_param_names = [ + re.sub( + r"experts\.(\d+)", + f"experts.{int(local_expert_number) + num_experts_per_rank * i}", + str(hf_param_name), + ) + for i in range(self.ep_size) + ] + assert str(hf_param_name) in gathered_expert_param_names, ( + f"hf_param_name {hf_param_name} not in gathered_expert_param_names " + f"{gathered_expert_param_names}" + ) + + # All-gather across the EP group + gathered_weights = [torch.empty_like(megatron_weights) for _ in range(self.ep_size)] + torch.distributed.all_gather(gathered_weights, megatron_weights, group=self.ep_group) + + # Assemble the result dict (handles duplicate names via concatenation) + weights_dict: Dict[str, torch.Tensor] = {} + for i, param_name in enumerate(gathered_expert_param_names): + if param_name in weights_dict: + weights_dict[param_name] = torch.cat( + [weights_dict[param_name], gathered_weights[i].unsqueeze(0)], dim=0 + ) + else: + weights_dict[param_name] = gathered_weights[i].unsqueeze(0) + for param_name in weights_dict: + weights_dict[param_name] = weights_dict[param_name].squeeze() + + return weights_dict + + MegatronParamMapping.gather_from_ep_ranks = _patched_gather_from_ep_ranks + logger.info( + "Applied QAT patch: MegatronParamMapping.gather_from_ep_ranks " + "now supports SequentialMLP pattern." + ) + + +def revert_ep_gather_patch(): + """Revert :func:`apply_ep_gather_patch`.""" + from megatron.bridge.models.conversion.param_mapping import MegatronParamMapping + + if not getattr(MegatronParamMapping, "_ep_gather_patched", False): + return + MegatronParamMapping.gather_from_ep_ranks = MegatronParamMapping._original_gather_from_ep_ranks + MegatronParamMapping._ep_gather_patched = False + logger.info("Reverted QAT patch: MegatronParamMapping.gather_from_ep_ranks.") + + +# ====================================================================== +# 3. extract_sort_key patch +# ====================================================================== + + +def apply_extract_sort_key_patch(): + """Patch ``megatron.bridge.models.conversion.utils.extract_sort_key`` + to support the SequentialMLP naming pattern (``local_experts.``) in + addition to the original TEGroupedMLP pattern (``weight`` / ``bias``). + + Idempotent – safe to call multiple times. 
+ """ + import megatron.bridge.models.conversion.utils as utils_module + + if getattr(utils_module, "_sort_key_patched", False): + return + utils_module._sort_key_patched = True + utils_module._original_extract_sort_key = utils_module.extract_sort_key + + def _patched_extract_sort_key(param_name: str): + """Extract sorting key based on layer and expert numbers.""" + numbers = [] + + # Find layer number + layer_match = re.search(r"layers\.(\d+)", param_name) + if layer_match: + numbers.append(int(layer_match.group(1))) + + # Find expert number – try multiple patterns + expert_number = None + + # Pattern 1: TEGroupedMLP format (e.g., weight15, bias15) + expert_match = re.search(r"(?:bias|weight)(\d+)", param_name) + if expert_match: + expert_number = int(expert_match.group(1)) + + # Pattern 2: SequentialMLP format (e.g., local_experts.15) + if expert_number is None: + local_experts_match = re.search(r"local_experts\.(\d+)", param_name) + if local_experts_match: + expert_number = int(local_experts_match.group(1)) + + if expert_number is not None: + numbers.append(expert_number) + + # Pad to ensure consistent comparison (max 2 numbers) + while len(numbers) < 2: + numbers.append(-1) + numbers = numbers[:2] + return numbers, param_name + + utils_module.extract_sort_key = _patched_extract_sort_key + logger.info( + "Applied QAT patch: extract_sort_key now supports SequentialMLP pattern." + ) + + +def revert_extract_sort_key_patch(): + """Revert :func:`apply_extract_sort_key_patch`.""" + import megatron.bridge.models.conversion.utils as utils_module + + if not getattr(utils_module, "_sort_key_patched", False): + return + utils_module.extract_sort_key = utils_module._original_extract_sort_key + utils_module._sort_key_patched = False + logger.info("Reverted QAT patch: extract_sort_key.") + + +# ====================================================================== +# 4. build_conversion_tasks patch +# ====================================================================== + + +def apply_build_conversion_tasks_patch(): + """Patch ``MegatronModelBridge.build_conversion_tasks`` to filter out + ``None`` entries before returning the task list. + + The original implementation can leave ``None`` slots for PP ranks that + don't own certain parameters and have no mapping. Downstream code that + iterates over the returned list may break on ``None``. This patch + ensures only valid :class:`WeightConversionTask` objects are returned. + + Idempotent – safe to call multiple times. + """ + import itertools + + from megatron.bridge.models.conversion.model_bridge import ( + MegatronModelBridge, + WeightConversionTask, + _megatron_local_name_to_global, + ) + from megatron.bridge.models.conversion.utils import ( + get_module_and_param_from_name, + persistent_buffers, + ) + from megatron.bridge.utils.common_utils import print_rank_0 + from megatron.core import parallel_state + from megatron.core.utils import unwrap_model + + if getattr(MegatronModelBridge, "_build_tasks_patched", False): + return + MegatronModelBridge._build_tasks_patched = True + MegatronModelBridge._original_build_conversion_tasks = ( + MegatronModelBridge.build_conversion_tasks + ) + + def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model): + """Construct conversion tasks between HF and Megatron (``None``-free). + + Returns a list of :class:`WeightConversionTask` objects — ``None`` + entries are filtered out before the list is returned so that callers + never need to guard against them. 
+ """ + # Ensure hf_pretrained has the required state structure + if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")): + raise ValueError("hf_pretrained.state.source is required for weight ordering") + + hf_keys: Iterable[str] = hf_pretrained.state.source.get_all_keys() + + mapping_registry = self.mapping_registry() + unwrapped_model = unwrap_model(megatron_model)[0] + model_config = unwrapped_model.config + embeddings_are_tied = self._share_embeddings_and_output_weights(model_config, unwrapped_model) + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks( + megatron_model + ) + + # Filter out output_layer related parameters if embeddings are tied + if embeddings_are_tied: + sorted_global_param_names_all_pp_ranks = [ + name for name in sorted_global_param_names_all_pp_ranks if "output_layer" not in name + ] + + global_names_index_dict = { + name: idx for idx, name in enumerate(sorted_global_param_names_all_pp_ranks) + } + + tasks = [None] * len(sorted_global_param_names_all_pp_ranks) + for vp_stage, model in enumerate(megatron_model): + for local_name, _ in itertools.chain( + model.named_parameters(), persistent_buffers(model) + ): + if "_extra_state" in local_name or self._is_adapter_param_name(local_name): + continue + + local_name = self._unwrap_name(local_name) + global_name = _megatron_local_name_to_global( + megatron_model, model_config, local_name, vp_stage + ) + if global_name not in global_names_index_dict: + print_rank_0(f"WARNING: {global_name} not in global_names_index_dict") + continue + global_name_idx = global_names_index_dict[global_name] + mapping = mapping_registry.megatron_to_hf_lookup( + self._get_lora_unwrapped_name(global_name) + ) + + if not mapping: + logger.warning(f"WARNING: No mapping found for megatron_param: {global_name}") + continue + + # Ensure HF weights exist + if not mapping.allow_hf_name_mismatch: + if isinstance(mapping.hf_param, str): + if mapping.hf_param not in hf_keys: + logger.warning(f"WARNING: Can't find {mapping.hf_param} in hf_keys") + continue + else: + missing_params = [ + hf_param + for hf_param in mapping.hf_param.values() + if hf_param not in hf_keys + ] + if missing_params: + logger.warning( + f"WARNING: Can't find the following HF parameters in hf_keys: " + f"{missing_params}" + ) + continue + + local_module, local_weights = get_module_and_param_from_name( + megatron_model, local_name, vp_stage + ) + if local_module is not None and not hasattr(local_module, "config"): + setattr(local_module, "config", model_config) + + tasks[global_name_idx] = WeightConversionTask( + pp_rank=pp_rank, + vp_stage=vp_stage, + param_name=local_name, + global_param_name=global_name, + megatron_module=local_module, + param_weight=local_weights, + mapping=mapping, + ) + + # Fill the remaining slots for PP communications + for idx, global_name in enumerate(sorted_global_param_names_all_pp_ranks): + if tasks[idx] is None: + mapping = mapping_registry.megatron_to_hf_lookup( + self._get_lora_unwrapped_name(global_name) + ) + if mapping is None: + continue + tasks[idx] = WeightConversionTask( + pp_rank=pp_rank, + vp_stage=None, + param_name=global_name, + global_param_name=global_name, + megatron_module=None, + param_weight=None, + mapping=mapping, + ) + + tasks = [task for task in tasks if task is not None] + return tasks + + MegatronModelBridge.build_conversion_tasks = _patched_build_conversion_tasks + logger.info( + "Applied QAT patch: 
MegatronModelBridge.build_conversion_tasks " + "now filters out None entries." + ) + + +def revert_build_conversion_tasks_patch(): + """Revert :func:`apply_build_conversion_tasks_patch`.""" + from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge + + if not getattr(MegatronModelBridge, "_build_tasks_patched", False): + return + MegatronModelBridge.build_conversion_tasks = ( + MegatronModelBridge._original_build_conversion_tasks + ) + MegatronModelBridge._build_tasks_patched = False + logger.info("Reverted QAT patch: MegatronModelBridge.build_conversion_tasks.") + + +# ====================================================================== +# Convenience: apply / revert all QAT patches at once +# ====================================================================== + + +def apply_qat_patch(): + """Apply **all** QAT-related patches. Idempotent.""" + apply_swiglu_sharded_factory_patch() + apply_ep_gather_patch() + apply_extract_sort_key_patch() + apply_build_conversion_tasks_patch() + + +def revert_qat_patch(): + """Revert **all** QAT-related patches.""" + revert_swiglu_sharded_factory_patch() + revert_ep_gather_patch() + revert_extract_sort_key_patch() + revert_build_conversion_tasks_patch() diff --git a/verl/trainer/config/engine/megatron.yaml b/verl/trainer/config/engine/megatron.yaml index b588a96c1b3..cbf2c53c733 100644 --- a/verl/trainer/config/engine/megatron.yaml +++ b/verl/trainer/config/engine/megatron.yaml @@ -72,6 +72,12 @@ override_transformer_config: # Attention backend to use (flash,fused,unfused,local,auto). Defaults to auto in mcore, flash in verl attention_backend: flash +# # Quantization method. None for no quantization, "nvfp4" for NVFP4 quantization +quantization: null + +# Whether to enable Quantization-Aware Training (QAT). Default False. +enable_qat: False + override_mcore_model_config: {} # oc.select: default val for ref.megatron.use_mbridge diff --git a/verl/utils/modelopt_qat_utils.py b/verl/utils/modelopt_qat_utils.py new file mode 100644 index 00000000000..7e8b63d401b --- /dev/null +++ b/verl/utils/modelopt_qat_utils.py @@ -0,0 +1,1056 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
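+
+# Illustrative end-to-end flow (only the names defined in this module and the
+# ``enable_qat`` / ``quantization`` engine options are real; the surrounding
+# trainer wiring is assumed, not shown here):
+#
+#   # engine config: enable_qat: True, quantization: nvfp4
+#   model = apply_qat(megatron_model, "nvfp4")            # insert ModelOpt fake-quantizers
+#   ...                                                   # QAT / RL training happens here
+#   post = QATWeightPostProcessor(actor_module=model_chunks, quantization_method="nvfp4")
+#   for name, tensor in post.process_weights_iterator(per_tensor_param):
+#       ...                                               # quantized weights plus *_scale / *_scale_2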
+ + +import re +from dataclasses import dataclass +from typing import Any, Iterator, Optional + +import torch +import torch.nn as nn + +import modelopt.torch.quantization as mtq +from modelopt.torch.export.quant_utils import ( + QUANTIZATION_NONE, + QUANTIZATION_NVFP4, + get_quantization_format, + get_weight_block_size, + to_quantized_weight, +) +from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor + +from verl.utils.megatron_utils import unwrap_model + +# --------------------------------------------------------------------------- +# NVFP4 quantization config +# --------------------------------------------------------------------------- + +NVFP4_WEIGHT_ONLY_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": {"enable": False}, + "nn.BatchNorm1d": {"*": {"enable": False}}, + "nn.BatchNorm2d": {"*": {"enable": False}}, + "nn.BatchNorm3d": {"*": {"enable": False}}, + "nn.LeakyReLU": {"*": {"enable": False}}, + "*lm_head*": {"enable": False}, + "*proj_out.*": {"enable": False}, # Whisper: lm_head has key name proj_out + "*block_sparse_moe.gate*": {"enable": False}, # Skip MOE router + "*router*": {"enable": False}, # Skip MOE router + "*mlp.gate.*": {"enable": False}, # Skip MOE router + "*mlp.shared_expert_gate.*": {"enable": False}, # Skip MOE router + "*linear_attn.conv1d*": {"enable": False}, + "*mixer.conv1d*": {"enable": False}, + "*output_layer*": {"enable": False}, + "output.*": {"enable": False}, + "default": {"enable": False}, + }, + "algorithm": "max", +} + +# --------------------------------------------------------------------------- +# QAT application +# --------------------------------------------------------------------------- + + +def apply_qat(model: nn.Module, quant_method: str): + """Apply Quantization-Aware Training to the model. + + Args: + model: The Megatron model to apply QAT to. + quant_method: Quantization method (currently only ``"nvfp4"`` is supported). + + Returns: + The quantized model. + """ + if quant_method != "nvfp4": + raise ValueError(f"Only 'nvfp4' is supported, got: {quant_method}") + + mtq.quantize(model, NVFP4_WEIGHT_ONLY_CFG) + return model + + +@dataclass +class QuantizationMetadata: + """Metadata for a quantized module.""" + + qformat: str + weight_quantizer: Any + input_quantizer: Any + module: torch.nn.Module + vpp_idx: int + block_size: int = 16 # Default NVFP4 block size + # Fields for EP synchronization - store amax values for non-local experts + weight_amax: Optional[torch.Tensor] = None + input_amax: Optional[torch.Tensor] = None + is_local: bool = True # Whether this expert is local to current EP rank + global_expert_idx: Optional[int] = None # Global expert index for MoE experts + local_expert_idx: Optional[int] = None # Local expert index on this EP rank + + +class QATWeightPostProcessor: + """ + Post-processor for extracting quantization info from QAT trained modules + and converting bf16 weights to quantized formats (e.g., NVFP4). + + Key Design: + 1. Collect quantization metadata (quantizers, amax, block_size) from QAT modules + 2. Process all_gathered bf16 weights to compute quantized weights and scaling factors + 3. 
The scaling factors are computed on the merged (all_gathered) weights to ensure + correct block boundaries for per-block quantization (NVFP4) + + Note on TP (Tensor Parallelism): + - For NVFP4, weight_scale_2 (global scale) should ideally be computed from the full + (all_gathered) weight to ensure consistency across TP ranks. + - If use_calibrated_scale_2=True (default), we use the QAT calibrated amax which may + only reflect the local shard's statistics. + - If use_calibrated_scale_2=False, we recompute weight_scale_2 from the merged weight. + Note on EP (Expert Parallelism): + - When EP is enabled, each rank only holds a subset of experts (local_experts) + - We synchronize metadata across all EP ranks to ensure complete metadata for all experts + - Local expert indices are converted to global expert indices for proper mapping + """ + + def __init__( + self, + actor_module: list, + quantization_method: str = "nvfp4", + dtype: torch.dtype = torch.bfloat16, + use_calibrated_scale_2: bool = False, + ): + """ + Initialize the QAT weight post-processor. + + Args: + actor_module: List of QAT trained model chunks (vpp chunks) + quantization_method: Quantization method (nvfp4, fp8, etc.) + dtype: Original data type (bf16) + use_calibrated_scale_2: If True, use QAT calibrated amax for weight_scale_2. + If False, recompute weight_scale_2 from merged weights. Recommended to set + False when using TP to ensure consistent global scale. + """ + self.actor_module = actor_module + self.quantization_method = quantization_method + self.dtype = dtype + self.use_calibrated_scale_2 = use_calibrated_scale_2 + self.quant_metadata: dict[str, QuantizationMetadata] = {} + self.ep_size, self.ep_rank, self.ep_group = self._get_ep_info() + self.pp_size, self.pp_rank, self.pp_group = self._get_pp_info() + self.num_local_experts = 0 # Will be determined during metadata building + + self._build_quantization_metadata() + + global_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + + # Synchronize metadata across EP ranks if EP is enabled + if self.ep_size > 1: + print(f"[QAT PostProcessor][Rank {global_rank}] Starting EP metadata sync...") + self._sync_quantization_metadata_across_ep() + print(f"[QAT PostProcessor][Rank {global_rank}] After EP sync: metadata_count={len(self.quant_metadata)}") + + # Synchronize metadata across PP ranks if PP is enabled + # This ensures all PP ranks have complete metadata for all layers + if self.pp_size > 1: + print(f"[QAT PostProcessor][Rank {global_rank}] Starting PP metadata sync...") + self._sync_quantization_metadata_across_pp() + print(f"[QAT PostProcessor][Rank {global_rank}] After PP sync: metadata_count={len(self.quant_metadata)}") + else: + print(f"[QAT PostProcessor][Rank {global_rank}] PP sync skipped: pp_size={self.pp_size}") + + self._log_initialization_info() + + def _get_ep_info(self) -> tuple[int, int, Any]: + """ + Get Expert Parallel information from Megatron parallel state. + + Returns: + (ep_size, ep_rank, ep_group): EP world size, rank, and process group + """ + try: + from megatron.core import parallel_state as mpu + + ep_size = mpu.get_expert_model_parallel_world_size() + if ep_size > 1: + ep_rank = mpu.get_expert_model_parallel_rank() + ep_group = mpu.get_expert_model_parallel_group() + return ep_size, ep_rank, ep_group + except Exception: + # EP not enabled or mpu not available + pass + return 1, 0, None + + def _get_pp_info(self) -> tuple[int, int, Any]: + """ + Get Pipeline Parallel information from Megatron parallel state. 
+ + Returns: + (pp_size, pp_rank, pp_group): PP world size, rank, and process group + """ + try: + from megatron.core import parallel_state as mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_group = mpu.get_pipeline_model_parallel_group() + + if torch.distributed.get_rank() == 0: + print(f"[QAT PostProcessor] PP info: pp_size={pp_size}, pp_rank={pp_rank}, pp_group={pp_group}") + + if pp_size > 1: + return pp_size, pp_rank, pp_group + else: + return pp_size, pp_rank, None + except Exception as e: + if torch.distributed.get_rank() == 0: + print(f"[QAT PostProcessor] Warning: Failed to get PP info: {e}") + pass + return 1, 0, None + + def _extract_layer_index(self, name: str) -> Optional[int]: + """ + Extract layer index from parameter name. + + For mcore format: decoder.layers.{layer_idx}.xxx + + Returns: + Layer index or None if not a layer parameter + """ + match = re.search(r"layers\.(\d+)\.", name) + if match: + return int(match.group(1)) + return None + + def _get_num_layers_per_pp_stage(self) -> int: + """ + Get the number of layers per PP stage from local metadata. + + This is calculated as max(local_layer_indices) + 1 + """ + max_layer_idx = -1 + for name in self.quant_metadata.keys(): + layer_idx = self._extract_layer_index(name) + if layer_idx is not None and layer_idx > max_layer_idx: + max_layer_idx = layer_idx + return max_layer_idx + 1 if max_layer_idx >= 0 else 0 + + def _convert_local_to_global_layer_name(self, name: str, source_pp_rank: int, num_layers_per_stage: int) -> str: + """ + Convert parameter name from local layer index to global layer index. + + Args: + name: Parameter name with local layer index (e.g., decoder.layers.0.xxx) + source_pp_rank: The PP rank this name came from + num_layers_per_stage: Number of layers per PP stage + + Returns: + Parameter name with global layer index + """ + local_layer_idx = self._extract_layer_index(name) + if local_layer_idx is None: + return name + + global_layer_idx = source_pp_rank * num_layers_per_stage + local_layer_idx + return re.sub(r"layers\.(\d+)\.", f"layers.{global_layer_idx}.", name, count=1) + + def _extract_local_expert_index(self, name: str) -> Optional[int]: + """ + Extract local expert index from parameter name. + + For SequentialMLP structure, the pattern is: + decoder.layers.{layer}.mlp.experts.local_experts.{local_idx}.linear_fc1/fc2.weight + + Args: + name: Parameter name in mcore format + + Returns: + Local expert index or None if not an expert parameter + """ + match = re.search(r"local_experts\.(\d+)\.", name) + if match: + return int(match.group(1)) + return None + + def _local_to_global_expert_index(self, local_idx: int) -> int: + """ + Convert local expert index to global expert index. + + Global index = ep_rank * num_local_experts + local_idx + + Args: + local_idx: Local expert index on this EP rank + + Returns: + Global expert index + """ + return self.ep_rank * self.num_local_experts + local_idx + + def _convert_name_to_global_index(self, name: str, local_idx: int, global_idx: int) -> str: + """ + Convert parameter name from local to global expert index. 
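+
+        For example, with ``local_idx=3`` and ``global_idx=11``,
+        ``...mlp.experts.local_experts.3.linear_fc1.weight`` becomes
+        ``...mlp.experts.local_experts.11.linear_fc1.weight``.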
+ + Args: + name: Original parameter name with local index + local_idx: Local expert index + global_idx: Global expert index + + Returns: + Parameter name with global expert index + """ + return name.replace(f"local_experts.{local_idx}.", f"local_experts.{global_idx}.") + + def _build_quantization_metadata(self): + """ + Extract quantization metadata from all modules in actor_module. + Stores: {param_name: QuantizationMetadata} + + For EP training with SequentialMLP: + - Detects local expert indices and computes global indices + - Stores metadata with global expert indices as keys + """ + # First pass: collect all local expert indices to determine num_local_experts + local_expert_indices = set() + + for vpp_idx, module in enumerate(self.actor_module): + model = unwrap_model(module) + for name, submodule in model.named_modules(): + local_idx = self._extract_local_expert_index(name) + if local_idx is not None: + local_expert_indices.add(local_idx) + + if local_expert_indices: + self.num_local_experts = max(local_expert_indices) + 1 + if torch.distributed.get_rank() == 0: + print(f"[QAT PostProcessor] Detected {self.num_local_experts} local experts per EP rank") + + # Second pass: build metadata with global indices + for vpp_idx, module in enumerate(self.actor_module): + model = unwrap_model(module) + + for name, submodule in model.named_modules(): + # Check if this module is quantized + qformat = get_quantization_format(submodule) + if qformat == QUANTIZATION_NONE: + continue + + block_size = get_weight_block_size(submodule) + if block_size == 0: + continue + + weight_quantizer = getattr(submodule, "weight_quantizer", None) + input_quantizer = getattr(submodule, "input_quantizer", None) + + # Extract amax values for synchronization + weight_amax = None + input_amax = None + if weight_quantizer is not None and hasattr(weight_quantizer, "_amax"): + weight_amax = weight_quantizer._amax.clone().cpu() if weight_quantizer._amax is not None else None + if input_quantizer is not None and hasattr(input_quantizer, "_amax"): + input_amax = input_quantizer._amax.clone().cpu() if input_quantizer._amax is not None else None + + # Determine global expert index for MoE experts + local_expert_idx = self._extract_local_expert_index(name) + global_expert_idx = None + if local_expert_idx is not None and self.ep_size > 1: + global_expert_idx = self._local_to_global_expert_index(local_expert_idx) + + metadata = QuantizationMetadata( + qformat=qformat, + weight_quantizer=weight_quantizer, + input_quantizer=input_quantizer, + module=submodule, + vpp_idx=vpp_idx, + block_size=block_size, + weight_amax=weight_amax, + input_amax=input_amax, + is_local=True, + global_expert_idx=global_expert_idx, + local_expert_idx=local_expert_idx, + ) + + for param_name, _ in submodule.named_parameters(recurse=False): + full_name = f"{name}.{param_name}" if name else param_name + + # For EP training, store with global expert index as key + if local_expert_idx is not None and self.ep_size > 1: + global_name = self._convert_name_to_global_index(full_name, local_expert_idx, global_expert_idx) + self.quant_metadata[global_name] = metadata + else: + self.quant_metadata[full_name] = metadata + + def _log_initialization_info(self): + """Log initialization information for debugging.""" + global_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + + print( + f"[QAT PostProcessor][Rank {global_rank}] Initialized with quantization method: {self.quantization_method}" + ) + print(f"[QAT PostProcessor][Rank 
{global_rank}] Found {len(self.quant_metadata)} quantized parameters") + if self.ep_size > 1: + print( + f"[QAT PostProcessor][Rank {global_rank}] EP enabled: ep_size={self.ep_size}, ep_rank={self.ep_rank}, " + f"num_local_experts={self.num_local_experts}" + ) + if self.pp_size > 1: + local_count = sum(1 for m in self.quant_metadata.values() if m.is_local) + remote_count = sum(1 for m in self.quant_metadata.values() if not m.is_local) + print( + f"[QAT PostProcessor][Rank {global_rank}] PP enabled: pp_size={self.pp_size}, pp_rank={self.pp_rank}, " + f"local_params={local_count}, remote_params={remote_count}" + ) + + # Log all metadata entries for debugging + for name, metadata in self.quant_metadata.items(): + extra_info = "" + if metadata.global_expert_idx is not None: + extra_info = f", global_expert_idx={metadata.global_expert_idx}" + if not metadata.is_local: + extra_info += ", is_local=False" + print( + f"[QAT PostProcessor][Rank {global_rank}] Metadata: {name}, qformat={metadata.qformat}, " + f"block_size={metadata.block_size}{extra_info}" + ) + + def _sync_quantization_metadata_across_ep(self): + """ + Synchronize quantization metadata across all EP (Expert Parallel) ranks. + + When EP is enabled, each rank only holds metadata for its local experts. + This method gathers metadata from all EP ranks and merges them so that + every rank has complete metadata for all experts. + + For SequentialMLP structure: + - Local expert indices are converted to global indices + - Metadata is gathered and merged using global indices as keys + - Non-local experts have is_local=False and module/quantizers set to None + """ + if self.ep_size <= 1 or self.ep_group is None: + return + + # Prepare serializable metadata info for all_gather + # We can't send module/quantizer objects, so we extract necessary info + local_metadata_info = {} + for name, metadata in self.quant_metadata.items(): + # Only sync MoE expert metadata (containing "local_experts") + if "local_experts" not in name: + continue + + local_metadata_info[name] = { + "qformat": metadata.qformat, + "block_size": metadata.block_size, + "vpp_idx": metadata.vpp_idx, + "weight_amax": metadata.weight_amax, + "input_amax": metadata.input_amax, + "global_expert_idx": metadata.global_expert_idx, + "local_expert_idx": metadata.local_expert_idx, + } + + # Also send num_local_experts for validation + sync_data = { + "metadata": local_metadata_info, + "num_local_experts": self.num_local_experts, + "ep_rank": self.ep_rank, + } + + # Gather metadata from all EP ranks + all_sync_data = [None] * self.ep_size + torch.distributed.all_gather_object(all_sync_data, sync_data, group=self.ep_group) + + # Validate that all ranks have the same num_local_experts + for rank_idx, data in enumerate(all_sync_data): + if data is not None and data["num_local_experts"] != self.num_local_experts: + print( + f"[QAT PostProcessor] Warning: EP rank {rank_idx} has " + f"{data['num_local_experts']} local experts, expected {self.num_local_experts}" + ) + + # Merge metadata from all ranks + for rank_idx, data in enumerate(all_sync_data): + if rank_idx == self.ep_rank: + # Skip local metadata (already have it) + continue + + if data is None: + continue + + rank_metadata = data["metadata"] + for name, info in rank_metadata.items(): + if name in self.quant_metadata: + # Already have this metadata (shouldn't happen with proper global indices) + continue + + # Create metadata entry for non-local experts + # Note: module and quantizers are not available for non-local experts + metadata = 
QuantizationMetadata( + qformat=info["qformat"], + weight_quantizer=None, # Not available for non-local + input_quantizer=None, # Not available for non-local + module=None, # Not available for non-local + vpp_idx=info["vpp_idx"], + block_size=info["block_size"], + weight_amax=info["weight_amax"], + input_amax=info["input_amax"], + is_local=False, # Mark as non-local + global_expert_idx=info["global_expert_idx"], + local_expert_idx=info["local_expert_idx"], + ) + self.quant_metadata[name] = metadata + + # Count local vs non-local experts + num_local = sum(1 for m in self.quant_metadata.values() if m.is_local and m.global_expert_idx is not None) + num_remote = sum(1 for m in self.quant_metadata.values() if not m.is_local and m.global_expert_idx is not None) + + if torch.distributed.get_rank() == 0: + print( + f"[QAT PostProcessor] EP metadata sync complete. " + f"EP size: {self.ep_size}, Local expert params: {num_local}, " + f"Remote expert params: {num_remote}, Total metadata entries: {len(self.quant_metadata)}" + ) + + def _sync_quantization_metadata_across_pp(self): + """ + Synchronize quantization metadata across all PP (Pipeline Parallel) ranks. + + When PP is enabled, each rank only holds layers for its pipeline stage. + This method gathers metadata from all PP ranks and merges them so that + every rank has complete metadata for all layers. + + IMPORTANT: In Megatron's PP mode, each PP rank uses LOCAL layer indices + (starting from 0), not global layer indices. For example: + - PP rank 0 has decoder.layers.0 (globally layer 0) + - PP rank 1 has decoder.layers.0 (globally layer 1) + + This method converts local layer indices to global layer indices during sync. + + For MoE SequentialMLP structure with PP: + - Different PP ranks hold different decoder layers + - Each PP rank builds metadata only for its local layers + - We gather and merge metadata from all PP ranks + - Layer indices are converted from local to global during merge + - Non-local layers have is_local=False and module/quantizers set to None + """ + global_rank = torch.distributed.get_rank() + + print( + f"[QAT PostProcessor][Rank {global_rank}] PP sync starting: " + f"pp_size={self.pp_size}, pp_rank={self.pp_rank}, pp_group={self.pp_group}, " + f"local_metadata_count={len(self.quant_metadata)}" + ) + + if self.pp_size <= 1: + print(f"[QAT PostProcessor][Rank {global_rank}] PP sync skipped: pp_size <= 1") + return + + if self.pp_group is None: + print(f"[QAT PostProcessor][Rank {global_rank}] PP sync skipped: pp_group is None") + return + + # Verify PP group size matches expected pp_size + actual_pp_group_size = torch.distributed.get_world_size(group=self.pp_group) + print( + f"[QAT PostProcessor][Rank {global_rank}] PP group size verification: " + f"expected={self.pp_size}, actual={actual_pp_group_size}" + ) + + # Calculate number of layers per PP stage (needed for global layer index conversion) + num_layers_per_stage = self._get_num_layers_per_pp_stage() + print(f"[QAT PostProcessor][Rank {global_rank}] Detected {num_layers_per_stage} layers per PP stage") + + # First, convert our local metadata to use global layer indices + # This is needed so we can properly merge with other PP ranks + local_metadata_with_global_indices = {} + for name, metadata in self.quant_metadata.items(): + global_name = self._convert_local_to_global_layer_name(name, self.pp_rank, num_layers_per_stage) + local_metadata_with_global_indices[global_name] = metadata + + # Update our metadata dict to use global layer indices + self.quant_metadata = 
local_metadata_with_global_indices + + # Prepare serializable metadata info for all_gather + # We can't send module/quantizer objects, so we extract necessary info + local_metadata_info = {} + for name, metadata in self.quant_metadata.items(): + local_metadata_info[name] = { + "qformat": metadata.qformat, + "block_size": metadata.block_size, + "vpp_idx": metadata.vpp_idx, + "weight_amax": metadata.weight_amax, + "input_amax": metadata.input_amax, + "global_expert_idx": metadata.global_expert_idx, + "local_expert_idx": metadata.local_expert_idx, + "is_local": metadata.is_local, + } + + # Include PP rank info and num_layers_per_stage for global index conversion + sync_data = { + "metadata": local_metadata_info, + "pp_rank": self.pp_rank, + "num_local_experts": self.num_local_experts, + "num_layers_per_stage": num_layers_per_stage, + "global_rank": global_rank, + } + + print( + f"[QAT PostProcessor][Rank {global_rank}] Preparing to sync {len(local_metadata_info)} metadata entries, " + f"sample keys (global indices): {list(local_metadata_info.keys())[:3]}" + ) + + # Gather metadata from all PP ranks + all_sync_data = [None] * actual_pp_group_size + torch.distributed.all_gather_object(all_sync_data, sync_data, group=self.pp_group) + + # Debug: print what we received + print(f"[QAT PostProcessor][Rank {global_rank}] Received data from {len(all_sync_data)} PP ranks") + for i, data in enumerate(all_sync_data): + if data is not None: + sample_keys = list(data.get("metadata", {}).keys())[:2] + print( + f"[QAT PostProcessor][Rank {global_rank}] PP rank {i}: " + f"received from global_rank={data.get('global_rank', 'unknown')}, " + f"pp_rank={data.get('pp_rank', 'unknown')}, " + f"metadata_count={len(data.get('metadata', {}))}, " + f"sample_keys={sample_keys}" + ) + + # Merge metadata from all PP ranks + local_metadata_before = len(self.quant_metadata) + for rank_idx, data in enumerate(all_sync_data): + if data is None: + print(f"[QAT PostProcessor][Rank {global_rank}] Skipping rank_idx={rank_idx}: data is None") + continue + + source_pp_rank = data.get("pp_rank") + + # Skip our own data - compare by pp_rank from the data, not by index + if source_pp_rank == self.pp_rank: + print( + f"[QAT PostProcessor][Rank {global_rank}] Skipping rank_idx={rank_idx}: same pp_rank={self.pp_rank}" + ) + continue + + rank_metadata = data["metadata"] + added_count = 0 + skipped_existing = 0 + + for name, info in rank_metadata.items(): + # The name already has global layer indices (converted by the sender) + if name in self.quant_metadata: + # Already have this metadata (shouldn't happen with correct global indices) + existing = self.quant_metadata[name] + if existing.is_local: + skipped_existing += 1 + continue + # If both are non-local, just keep existing + skipped_existing += 1 + continue + + # Create metadata entry for layers from other PP ranks + # Note: module and quantizers are not available for non-local layers + metadata = QuantizationMetadata( + qformat=info["qformat"], + weight_quantizer=None, # Not available for non-local PP rank + input_quantizer=None, # Not available for non-local PP rank + module=None, # Not available for non-local PP rank + vpp_idx=info["vpp_idx"], + block_size=info["block_size"], + weight_amax=info["weight_amax"], + input_amax=info["input_amax"], + is_local=False, # Mark as non-local (from other PP rank) + global_expert_idx=info["global_expert_idx"], + local_expert_idx=info["local_expert_idx"], + ) + self.quant_metadata[name] = metadata + added_count += 1 + + print( + f"[QAT 
PostProcessor][Rank {global_rank}] From pp_rank={source_pp_rank}: " + f"added {added_count} metadata entries, skipped {skipped_existing} existing" + ) + + # Log statistics + metadata_added = len(self.quant_metadata) - local_metadata_before + local_count = sum(1 for m in self.quant_metadata.values() if m.is_local) + remote_count = sum(1 for m in self.quant_metadata.values() if not m.is_local) + + print( + f"[QAT PostProcessor][Rank {global_rank}] PP metadata sync complete. " + f"PP size: {self.pp_size}, PP rank: {self.pp_rank}, " + f"Local params: {local_count}, Remote params: {remote_count}, " + f"Metadata added from other PP ranks: {metadata_added}, " + f"Total metadata entries: {len(self.quant_metadata)}" + ) + + def _find_matching_metadata(self, param_name: str) -> QuantizationMetadata | None: + """ + Find matching quantization metadata for a parameter name. + Handles potential name variations between training and export. + """ + # Direct match + if param_name in self.quant_metadata: + return self.quant_metadata[param_name] + + # Try removing common prefixes/suffixes + variations = [ + param_name, + param_name.replace("module.", ""), + param_name.replace("model.", ""), + ] + + for var in variations: + if var in self.quant_metadata: + return self.quant_metadata[var] + + return None + + def _quantize_weight( + self, + name: str, + weight: torch.Tensor, + metadata: QuantizationMetadata, + ) -> Iterator[tuple[str, torch.Tensor]]: + """ + Quantize a single weight parameter. + + Args: + name: Parameter name + weight: The all_gathered bf16 weight tensor + metadata: Quantization metadata + + Yields: + (param_name, param_tensor) for quantized weight and scaling factors + """ + qformat = metadata.qformat + + if qformat == QUANTIZATION_NVFP4: + yield from self._quantize_nvfp4(name, weight, metadata) + else: + # Unknown format, pass through with warning + print(f"[QAT PostProcessor] Warning: Unknown qformat {qformat} for {name}, passing through") + yield (name, weight) + + def _quantize_nvfp4( + self, + name: str, + weight: torch.Tensor, + metadata: QuantizationMetadata, + ) -> Iterator[tuple[str, torch.Tensor]]: + """ + NVFP4 quantization implementation. + + NVFP4 uses two-level scaling: + - weight_scale_2 (global): per-tensor scale = amax / (6.0 * 448.0) + - weight_scale (per-block): per-block scale in FP8 format + + The weight is packed into uint8 format (2 x FP4 values per byte). + + Yields: + (name, quantized_weight): Packed uint8 weight + (name + "_scale", weight_scale): Per-block FP8 scaling factors + (name + "_scale_2", weight_scale_2): Global scaling factor + (name + "_input_scale", input_scale): Input activation scale (if available) + """ + weight_quantizer = metadata.weight_quantizer + input_quantizer = metadata.input_quantizer + block_size = metadata.block_size + qformat = metadata.qformat + + # # Ensure weight is in float for quantization computation + # weight_float = weight.float() + + # Step 1: Compute weight_scale_2 (global scale) + # For TP sharding, we should recompute weight_scale_2 from merged weight + # to ensure consistent global scale across all TP ranks. 
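+        # (For NVFP4, weight_scale_2 = amax / (6.0 * 448.0), where 6.0 is the largest
+        # representable FP4 (E2M1) value and 448.0 the largest FP8 (E4M3) block-scale value.)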
+ if self.use_calibrated_scale_2 and weight_quantizer is not None and hasattr(weight_quantizer, "_amax"): + # Use QAT calibrated amax (may only reflect local shard statistics) + # weight_scale_2 = amax / (6.0 * 448.0) + weight_scale_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer) + elif metadata.weight_amax is not None: + # Non-local expert (EP): Use synchronized amax from metadata + weight_amax = metadata.weight_amax.to(weight.device) + weight_scale_2 = weight_amax.float() / (6.0 * 448.0) + else: + # Compute from all_gathered weight directly (recommended for TP) + # weight_scale_2 = max(abs(weight)) / (6.0 * 448.0) + weight_scale_2 = NVFP4QTensor.get_weights_scaling_factor_2(weight) + + # Step 2: Compute weight_scale (per-block scale) + # This MUST be computed on the all_gathered (merged) weight to ensure + # correct block boundaries + # weight_scale shape: [out_dim, in_dim / block_size], dtype: float8_e4m3fn + weight_scale = NVFP4QTensor.get_weights_scaling_factor( + weight, + block_size, + weights_scaling_factor_2=weight_scale_2.to(weight.device), + )[0] + + # Step 3: Quantize weight to NVFP4 packed format + quantized_weight = to_quantized_weight( + weight, + weight_scale, + qformat, + weight_scale_2, + block_size, + ) + + # Yield quantized weight + yield (name, quantized_weight) + + # Yield scaling factors + # Note: Use consistent naming convention with ModelOpt export + scale_name = name.replace(".weight", ".weight_scale") + if scale_name == name: + scale_name = name + "_scale" + yield (scale_name, weight_scale) + + scale_2_name = name.replace(".weight", ".weight_scale_2") + if scale_2_name == name: + scale_2_name = name + "_scale_2" + yield (scale_2_name, weight_scale_2) + + # Step 4: Export input_scale (activation quantization) if available + if input_quantizer is not None: + input_scale = self._get_input_scale(input_quantizer) + if input_scale is not None: + input_scale_name = name.replace(".weight", ".input_scale") + if input_scale_name == name: + input_scale_name = name + "_input_scale" + yield (input_scale_name, input_scale) + + def _get_input_scale(self, input_quantizer) -> torch.Tensor | None: + """ + Get input activation scaling factor from quantizer. + + Args: + input_quantizer: The input quantizer from the module + + Returns: + Input scaling factor tensor or None + """ + if input_quantizer is None: + return None + + if not hasattr(input_quantizer, "_amax"): + return None + + amax = input_quantizer._amax + if amax is None: + return None + + # For NVFP4, use the NVFP4QTensor method + if hasattr(NVFP4QTensor, "get_activation_scaling_factor"): + return NVFP4QTensor.get_activation_scaling_factor(input_quantizer) + + return amax.float() / (6.0 * 448.0) + + def process_weights_iterator( + self, + per_tensor_param: Iterator[tuple[str, torch.Tensor]], + ) -> Iterator[tuple[str, torch.Tensor]]: + """ + Process an iterator of weights and yield quantized results. + + This method wraps per_tensor_generator output and applies quantization + to each weight, yielding the quantized weights and scaling factors. + + Args: + per_tensor_param: Iterator of (name, bf16_weight) from per_tensor_generator + + Yields: + (name, tensor): Quantized weight and associated scaling factors + """ + for name, param in per_tensor_param: + # quantize_single_tensor returns a list of (name, tensor) tuples + # For NVFP4: [(name, quant_weight), (name_scale, scale), (name_scale_2, scale_2), ...] 
+ # For non-quantized: [(name, original_weight)] + quantized_results = self.quantize_single_tensor(name, param) + for q_name, q_tensor in quantized_results: + yield (q_name, q_tensor) + + def quantize_single_tensor( + self, + name: str, + weight: torch.Tensor, + ) -> list[tuple[str, torch.Tensor]]: + """ + Quantize a single tensor and return all related tensors as a list. + + This method is designed to be called AFTER weight_converter.convert_param, + so the name should already be in HF format (e.g., 'model.layers.0.self_attn.q_proj.weight'). + + Args: + name: Parameter name in HF format + weight: Single tensor to quantize + + Returns: + List of (param_name, param_tensor) tuples: + - (name, quantized_weight) + - (name.replace('.weight', '.weight_scale'), weight_scale) # for NVFP4 + - (name.replace('.weight', '.weight_scale_2'), weight_scale_2) # for NVFP4 + """ + # Find matching metadata using the original mcore name pattern + # Since name is now in HF format, we need to check if this layer type should be quantized + metadata = self._find_matching_metadata_by_hf_name(name) + + if metadata is None: + # Not quantized, return original tensor + return [(name, weight)] + + # Quantize this tensor + return list(self._quantize_weight(name, weight, metadata)) + + def _find_matching_metadata_by_hf_name(self, hf_name: str) -> QuantizationMetadata | None: + """ + Find matching quantization metadata for an HF-format parameter name. + + This maps HF names back to the original mcore names to find metadata. + E.g., 'model.layers.0.self_attn.q_proj.weight' -> check if qkv layer is quantized + + The mapping logic: + - HF q_proj/k_proj/v_proj.weight -> mcore linear_qkv.weight + - HF o_proj.weight -> mcore linear_proj.weight + - HF gate_proj/up_proj.weight -> mcore linear_fc1.weight + - HF down_proj.weight -> mcore linear_fc2.weight + - MoE experts: model.layers.X.mlp.experts.Y.gate_proj/up_proj/down_proj.weight + - MoE router (gate): model.layers.X.mlp.gate.weight -> NOT quantized (returns None) + """ + + # Only process weight parameters + if not hf_name.endswith(".weight") or hf_name.endswith("._amax") or "norm" in hf_name: + return None + + # Check for MoE router (gate) - should NOT be quantized + # HF formats: model.layers.X.mlp.gate.weight (Qwen) + # model.layers.X.block_sparse_moe.gate.weight (Mixtral) + if self._is_moe_router(hf_name): + return None + + # Extract layer number from HF name + layer_match = re.search(r"layers?\.(\d+)\.", hf_name) + if not layer_match: + # Not a layer parameter (e.g., embed_tokens, lm_head, norm) + # Check for direct matches + return self._find_non_layer_metadata(hf_name) + + layer_num = layer_match.group(1) + + # Determine the mcore module name based on HF name pattern + mcore_patterns = [] + + if "self_attn" in hf_name: + if any(proj in hf_name for proj in ["q_proj", "k_proj", "v_proj"]): + mcore_patterns.append(f"decoder.layers.{layer_num}.self_attention.linear_qkv.weight") + elif "o_proj" in hf_name: + mcore_patterns.append(f"decoder.layers.{layer_num}.self_attention.linear_proj.weight") + elif "mlp" in hf_name: + # Check for MoE experts first + # HF format: model.layers.X.mlp.experts.Y.gate_proj/up_proj/down_proj.weight + # HF Mixtral format: model.layers.X.block_sparse_moe.experts.Y.w1/w2/w3.weight + expert_match = re.search(r"\.experts\.(\d+)\.", hf_name) + if expert_match: + expert_id = expert_match.group(1) # This is the global expert ID in HF format + # MoE expert layers - use global expert ID for SequentialMLP + if any(proj in hf_name for proj in ["gate_proj", 
"up_proj", "w1", "w3"]): + # Try TEGroupedMLP pattern first (all experts share same linear layer) + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.experts.linear_fc1.weight") + # Try SequentialMLP pattern with global expert index + mcore_patterns.append( + f"decoder.layers.{layer_num}.mlp.experts.local_experts.{expert_id}.linear_fc1.weight" + ) + elif any(proj in hf_name for proj in ["down_proj", "w2"]): + # Try TEGroupedMLP pattern first + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.experts.linear_fc2.weight") + # Try SequentialMLP pattern with global expert index + mcore_patterns.append( + f"decoder.layers.{layer_num}.mlp.experts.local_experts.{expert_id}.linear_fc2.weight" + ) + # Check for shared_expert (Qwen2 MoE) + elif "shared_expert" in hf_name: + if any(proj in hf_name for proj in ["gate_proj", "up_proj"]): + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.shared_experts.linear_fc1.weight") + elif "down_proj" in hf_name: + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.shared_experts.linear_fc2.weight") + else: + # Dense MLP + if any(proj in hf_name for proj in ["gate_proj", "up_proj"]): + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.linear_fc1.weight") + elif "down_proj" in hf_name: + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.linear_fc2.weight") + + # Try to find matching metadata + for pattern in mcore_patterns: + if pattern in self.quant_metadata: + return self.quant_metadata[pattern] + + # # If no exact match, try to find any metadata from the same layer + # # This handles cases where the exact name might be slightly different + # for mcore_name, metadata in self.quant_metadata.items(): + # if f"layers.{layer_num}." in mcore_name: + # # Found a quantized module in the same layer + # # Skip router metadata - router should not be used for other layers + # if ".router." in mcore_name: + # continue + # # For QAT, if any module in the layer is quantized, all Linear layers should be + # if ".weight" in mcore_name: + # return metadata + + return None + + def _is_moe_router(self, hf_name: str) -> bool: + """ + Check if the HF parameter name corresponds to a MoE router (gate). + + MoE router should NOT be quantized to maintain routing precision. + + Router naming patterns: + - Qwen/Qwen2/Qwen3 MoE: model.layers.X.mlp.gate.weight + - Mixtral: model.layers.X.block_sparse_moe.gate.weight + - Shared expert gate (Qwen2 MoE): model.layers.X.mlp.shared_expert_gate.weight + + Note: gate_proj is NOT the router, it's part of the MLP expert. 
+ """ + + # Pattern 1: Qwen/Qwen3 MoE router - model.layers.X.mlp.gate.weight + # Must be exactly ".mlp.gate.weight" not ".mlp.gate_proj.weight" + if re.search(r"\.mlp\.gate\.weight$", hf_name): + return True + + # Pattern 2: Mixtral router - model.layers.X.block_sparse_moe.gate.weight + if re.search(r"\.block_sparse_moe\.gate\.weight$", hf_name): + return True + + # Pattern 3: Qwen2 MoE shared expert gate - model.layers.X.mlp.shared_expert_gate.weight + if re.search(r"\.mlp\.shared_expert_gate\.weight$", hf_name): + return True + + return False + + def _find_non_layer_metadata(self, hf_name: str) -> QuantizationMetadata | None: + """Find metadata for non-layer parameters (embed_tokens, lm_head, etc.).""" + # Map HF names to mcore names for non-layer parameters + name_mapping = { + "model.embed_tokens.weight": "embedding.word_embeddings.weight", + "lm_head.weight": "output_layer.weight", + "model.norm.weight": "decoder.final_layernorm.weight", + } + + mcore_name = name_mapping.get(hf_name) + if mcore_name and mcore_name in self.quant_metadata: + return self.quant_metadata[mcore_name] + + return None \ No newline at end of file diff --git a/verl/utils/modelopt_vllm_utils.py b/verl/utils/modelopt_vllm_utils.py new file mode 100644 index 00000000000..18bbaca5ced --- /dev/null +++ b/verl/utils/modelopt_vllm_utils.py @@ -0,0 +1,841 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Callable, Optional +from unittest.mock import patch + +import torch + +logger = logging.getLogger(__name__) +from torch.nn import Parameter + + +def generate_nvfp4_ignore_list(num_layers: int, is_moe: bool) -> list[str]: + """ + Generate the ignore list for NVFP4 quantization based on model configuration. + + Args: + num_layers: Number of hidden layers in the model (from hf_config.num_hidden_layers) + is_moe: Whether the model is a Mixture of Experts model + + Returns: + List of layer names to ignore during quantization + """ + ignore_list = [] + + # For MoE models, ignore the gate layers (routing layers) + if is_moe: + for layer_idx in range(num_layers): + ignore_list.append(f"model.layers.{layer_idx}.mlp.gate") + + # Always ignore lm_head for stability + ignore_list.append("lm_head") + + return ignore_list + + +def get_nvfp4_block_quant_kwargs(num_layers: int, is_moe: bool) -> dict: + """ + Generate complete NVFP4 quantization configuration based on model properties. 
+ Args: + num_layers: Number of hidden layers in the model (from hf_config.num_hidden_layers) + is_moe: Whether the model is a Mixture of Experts model + + Returns: + Complete quantization configuration dictionary compatible with ModelOpt + """ + ignore_list = generate_nvfp4_ignore_list(num_layers, is_moe) + + return { + "config_groups": { + "group_0": { + "input_activations": { + "dynamic": "false", + "num_bits": 4, + "type": "float", + "group_size": 16 + }, + "weights": { + "dynamic": "false", + "num_bits": 4, + "type": "float", + "group_size": 16 + }, + "targets": [ + "Linear" + ] + } + }, + "ignore": ignore_list, + "quant_algo": "NVFP4", + "producer": { + "name": "modelopt", + }, + "quant_method": "modelopt" + } + + + +def _create_param_from_subclass_attributes(custom_data: torch.Tensor, custom_weight) -> Parameter: + """ + Helper to preserve custom attributes from ModelWeightParameter and + PerTensorScaleParameter when creating new Parameters. + """ + param = Parameter(custom_data, requires_grad=False) + base_param_dir = dir(torch.nn.Parameter) + custom_weight_dir = dir(custom_weight) + # Find the attributes that are unique to the custom parameter + custom_attributes = [attr for attr in custom_weight_dir if attr not in base_param_dir and not attr.startswith("__")] + # Set the custom attributes into the base parameter object + for attr in custom_attributes: + setattr(param, attr, getattr(custom_weight, attr)) + return param + + +def process_weights_after_loading_modelopt(self, layer: torch.nn.Module) -> None: + import vllm._custom_ops as ops + from torch.nn import Parameter + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_make_workspace_new, + marlin_permute_bias, + marlin_permute_scales, + ) + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + mxfp4_marlin_process_scales, + nvfp4_marlin_process_global_scale, + nvfp4_marlin_process_scales, + ) + from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale + + def _create_param_from_subclass_attributes(custom_data, custom_weight): + param = Parameter(custom_data, requires_grad=False) + base_param_dir = dir(torch.nn.Parameter) + custom_weight_dir = dir(custom_weight) + # Find the attributes that are unique to the custom parameter + custom_attributes = [ + attr for attr in custom_weight_dir if attr not in base_param_dir and not attr.startswith("__") + ] + # Set the custom attributes into the base parameter object + for attr in custom_attributes: + setattr(param, attr, getattr(custom_weight, attr)) + + return param + + def prepare_fp4_layer_for_marlin(layer: torch.nn.Module, weight_scale_2_max: torch.Tensor) -> None: + logger.warning( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads." 
+ ) + + is_nvfp4 = hasattr(layer, "weight_scale_2") + group_size = 16 if is_nvfp4 else 32 + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + assert layer.weight.shape == (part_size_n, part_size_k // 2) + + device = layer.weight.device + + # WORKSPACE + if getattr(layer, "workspace", None) is None: + layer.workspace = marlin_make_workspace_new(device) + + # WEIGHT + # Repack weights to marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = layer.weight.view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4, + ) + layer.marlin_weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Permute scales + weight_scale = layer.weight_scale.T.contiguous() + + if not is_nvfp4: + weight_scale = weight_scale.view(torch.float8_e8m0fnu) + + weight_scale = weight_scale.to(param_dtype) + weight_scale = marlin_permute_scales( + s=weight_scale, size_k=part_size_k, size_n=part_size_n, group_size=group_size + ) + + if is_nvfp4: + weight_scale = nvfp4_marlin_process_scales(weight_scale) + layer.marlin_weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + weight_scale_2 = weight_scale_2_max.to(param_dtype) + weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2) + layer.marlin_weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False) + else: + weight_scale = mxfp4_marlin_process_scales(weight_scale) + layer.marlin_weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + if hasattr(layer, "bias") and layer.bias is not None: + assert layer.bias.shape == (part_size_n,) + bias = marlin_permute_bias(layer.bias) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) + + return + + # global scales: + input_scale_2 = layer.input_scale.data + layer.input_scale = _create_param_from_subclass_attributes(input_scale_2, layer.input_scale) + input_scale_2_max = input_scale_2.max().to(torch.float32) + + weight_scale_2 = layer.weight_scale_2.data + layer.weight_scale_2 = _create_param_from_subclass_attributes(weight_scale_2, layer.weight_scale_2) + weight_scale_2_max = weight_scale_2.max().to(torch.float32) + + layer.alpha = Parameter(input_scale_2_max * weight_scale_2_max, requires_grad=False) + + # Calculate `1 / input_scale` so that we don't need to do so at runtime + layer.input_scale_inv = Parameter((1 / layer.input_scale).to(torch.float32), requires_grad=False) + + # Swizzle the weight blockscale. + # contracting dimension is input dimension + # block_size = 16; + assert layer.weight_scale.dtype == torch.float8_e4m3fn, "Weight Block scale must be represented as FP8-E4M3" + + if self.backend == "marlin": + weight = layer.weight.data + weight_scale = layer.weight_scale.data + layer.weight = _create_param_from_subclass_attributes(weight, layer.weight) + layer.weight_scale = _create_param_from_subclass_attributes(weight_scale, layer.weight_scale) + prepare_fp4_layer_for_marlin(layer, weight_scale_2_max) + + del layer.alpha + # del layer.input_scale + elif self.backend == "flashinfer-trtllm": + # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. + # FlashInfer provides nvfp4_quantize to quantize + shuffle the + # layout but we use our own quantization so we have to call + # shuffles ourselves. 
+ from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a + + weight = layer.weight.data + weight_scale = layer.weight_scale.data + + epilogue_tile_m = 128 + weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m) + weight_scale = ( + shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m) + .reshape(weight_scale.shape) + .view(torch.float8_e4m3fn) + ) + + layer.weight_scale = _create_param_from_subclass_attributes(weight_scale, layer.weight_scale) + layer.weight = _create_param_from_subclass_attributes(weight, layer.weight) + else: + swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) + layer.weight_scale = _create_param_from_subclass_attributes(swizzled_weight_scale, layer.weight_scale) + layer.weight = _create_param_from_subclass_attributes(layer.weight.data, layer.weight) + +def apply_modelopt( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import apply_fp4_marlin_linear + from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm + + if self.backend == "marlin": + return apply_fp4_marlin_linear( + input=x, + weight=layer.marlin_weight, + weight_scale=layer.marlin_weight_scale, + weight_scale_2=layer.marlin_weight_scale_2, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) + + output_dtype = x.dtype + output_shape = [x.shape[0], layer.weight.shape[0]] + + # quantize BF16 or FP16 to (FP4 and interleaved block scale) + x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv) + + # validate dtypes of quantized input, input block scale, + # weight and weight_blockscale + assert x_fp4.dtype == torch.uint8 + assert layer.weight.dtype == torch.uint8 + assert x_blockscale.dtype == torch.float8_e4m3fn + assert layer.weight_scale.dtype == torch.float8_e4m3fn + assert layer.alpha.dtype == torch.float32 + + mm_args = ( + x_fp4, + layer.weight, + x_blockscale, + layer.weight_scale, + layer.alpha, + output_dtype, + ) + if self.backend == "flashinfer-trtllm": + out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm") + elif self.backend == "flashinfer-cutlass": + out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass") + else: + out = cutlass_scaled_fp4_mm(*mm_args) + + if bias is not None: + out = out + bias + return out.view(*output_shape) + + +# ============================================================================= +# ModelOptNvFp4FusedMoE Patches +# ============================================================================= + + +def process_weights_after_loading_moe(self, layer: torch.nn.Module) -> None: + """ + Patched process_weights_after_loading for ModelOptNvFp4FusedMoE. + + Key modifications compared to original: + 1. Preserves original weights in separate attributes (marlin_w13_weight, etc.) + 2. Uses _create_param_from_subclass_attributes to preserve parameter metadata + 3. 
Computes weight_scale_2_max before processing for Marlin + """ + import vllm._custom_ops as ops + from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( + prepare_static_weights_for_trtllm_fp4_moe, + reorder_w1w3_to_w3w1, + ) + from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + FlashinferMoeBackend, + is_flashinfer_supporting_global_sf, + ) + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_make_workspace_new, + marlin_permute_scales, + ) + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + nvfp4_marlin_process_global_scale, + nvfp4_marlin_process_scales, + ) + from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale + + def prepare_moe_fp4_layer_for_marlin_patched( + layer: torch.nn.Module, + w13_weight_scale_2_per_expert: torch.Tensor, + w2_weight_scale_2_per_expert: torch.Tensor, + ) -> None: + """ + Modified prepare_moe_fp4_layer_for_marlin that: + 1. Takes per-expert weight_scale_2 values (not max!) + 2. Saves to marlin_* attributes instead of overwriting originals + + Args: + w13_weight_scale_2_per_expert: shape (num_experts,) - per-expert scales + w2_weight_scale_2_per_expert: shape (num_experts,) - per-expert scales + """ + logger.warning("Using patched prepare_moe_fp4_layer_for_marlin for NVFP4 MoE") + + group_size = 16 # NVFP4 uses group_size=16 + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + + device = layer.w13_weight.device + param_dtype = layer.params_dtype + + # WORKSPACE + if getattr(layer, "workspace", None) is None: + layer.workspace = marlin_make_workspace_new(device, 4) + + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT - Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + assert weight.shape == (e, size_n, size_k // 2), ( + f"Weight shape mismatch for {name}: expected {(e, size_n, size_k // 2)}, got {weight.shape}" + ) + + for i in range(e): + qweight = weight[i].view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + tensor_list.append(marlin_qweight) + + marlin_weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + marlin_weight = Parameter(marlin_weight, requires_grad=False) + + # Save to marlin_* attribute instead of overwriting original + marlin_attr_name = "marlin_" + name + setattr(layer, marlin_attr_name, marlin_weight) + + # WEIGHT SCALES - Permute scales + for name, weight_scale_2_per_expert in [ + ("w13", w13_weight_scale_2_per_expert), + ("w2", w2_weight_scale_2_per_expert), + ]: + scales = getattr(layer, name + "_weight_scale") + scales = scales.to(param_dtype) + + # Convert per-expert global scale to param_dtype + global_scale = weight_scale_2_per_expert.to(param_dtype) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + for i in range(e): + scale = scales[i].T + + marlin_scales = marlin_permute_scales( + s=scale, + size_k=size_k, + size_n=size_n, + group_size=group_size, + ) + marlin_scales = nvfp4_marlin_process_scales(marlin_scales) + tensor_list.append(marlin_scales) + + marlin_scales_combined = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + marlin_scales_combined = Parameter(marlin_scales_combined, 
requires_grad=False) + + # Save to marlin_* attribute + setattr(layer, "marlin_" + name + "_weight_scale", marlin_scales_combined) + + # Process per-expert global scale (shape: num_experts) + global_scale = nvfp4_marlin_process_global_scale(global_scale) + global_scale = Parameter(global_scale, requires_grad=False) + setattr(layer, "marlin_" + name + "_weight_scale_2", global_scale) + + # ========== Main processing logic ========== + + # GEMM 1 processing + gemm1_weight = layer.w13_weight.data + gemm1_weight_scale = layer.w13_weight_scale.data + + if ( + self.allow_flashinfer + and ( + self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + ) + and self.moe.is_act_and_mul + ): + gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(gemm1_weight, gemm1_weight_scale, dim=-2) + + layer.w13_weight = _create_param_from_subclass_attributes(gemm1_weight, layer.w13_weight) + layer.w13_weight_scale = _create_param_from_subclass_attributes(gemm1_weight_scale, layer.w13_weight_scale) + + # Common processing for w13_weight_scale_2 + # IMPORTANT: Keep the original shape (num_experts, 2) for subsequent weight loading + # Only compute the max value for Marlin, but don't modify the original parameter shape + if self.moe.is_act_and_mul and not torch.allclose(layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]): + logger.warning("w1_weight_scale_2 must match w3_weight_scale_2. Accuracy may be affected.") + + # Keep original data and shape - DO NOT reduce dimension! + w13_weight_scale_2_data = layer.w13_weight_scale_2.data # Keep original shape: (num_experts, 2) + layer.w13_weight_scale_2 = _create_param_from_subclass_attributes(w13_weight_scale_2_data, layer.w13_weight_scale_2) + # Get per-expert scales (shape: num_experts) for Marlin - NOT the max! + # This is what the original code uses after reducing [:, 0] + w13_weight_scale_2_per_expert = layer.w13_weight_scale_2[:, 0].clone() + # Also keep a 1D version for g1_alphas calculation (following original logic) + w13_weight_scale_2_1d = layer.w13_weight_scale_2[:, 0] + + # Common processing for input scales and alphas + # IMPORTANT: Keep original input_scale shapes for subsequent weight loading + use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(self.flashinfer_moe_backend) + + # Keep original w13_input_scale data and shape + w13_input_scale_data = layer.w13_input_scale.data + layer.w13_input_scale = _create_param_from_subclass_attributes(w13_input_scale_data, layer.w13_input_scale) + + # Compute derived values for runtime use + if use_global_sf: + w13_input_scale_for_alpha = layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts) + else: + w13_input_scale_for_alpha = layer.w13_input_scale.max(dim=1).values.to(torch.float32) + + layer.g1_alphas = Parameter( + (w13_input_scale_for_alpha * w13_weight_scale_2_1d).to(torch.float32), + requires_grad=False, + ) + + # This is for quantization, so we need to invert it. + layer.w13_input_scale_quant = Parameter((1 / w13_input_scale_for_alpha).to(torch.float32), requires_grad=False) + + # GEMM 2 processing + # Keep original w2_weight_scale_2 data and shape + w2_weight_scale_2_data = layer.w2_weight_scale_2.data + layer.w2_weight_scale_2 = _create_param_from_subclass_attributes(w2_weight_scale_2_data, layer.w2_weight_scale_2) + # Get per-expert scales (shape: num_experts) for Marlin - NOT the max! 
+ w2_weight_scale_2_per_expert = layer.w2_weight_scale_2.clone() + + # Keep original w2_input_scale data and shape + w2_input_scale_data = layer.w2_input_scale.data + layer.w2_input_scale = _create_param_from_subclass_attributes(w2_input_scale_data, layer.w2_input_scale) + + if use_global_sf: + w2_input_scale_for_alpha = layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts) + else: + w2_input_scale_for_alpha = layer.w2_input_scale + layer.g2_alphas = Parameter( + (w2_input_scale_for_alpha * layer.w2_weight_scale_2).to(torch.float32), + requires_grad=False, + ) + + # This is for quantization, so we need to invert it. + layer.w2_input_scale_quant = Parameter((1 / w2_input_scale_for_alpha).to(torch.float32), requires_grad=False) + + # ========== Backend-specific processing ========== + + if self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + # TensorRT-LLM specific processing + ( + gemm1_weights_fp4_shuffled, + gemm1_scales_fp4_shuffled, + gemm2_weights_fp4_shuffled, + gemm2_scales_fp4_shuffled, + ) = prepare_static_weights_for_trtllm_fp4_moe( + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + layer.w2_weight.size(-2), # hidden_size + layer.w13_weight.size(-2) // 2, # intermediate_size + layer.w13_weight.size(0), # num_experts + ) + logger.debug("Finished shuffling weights for TRT-LLM MOE") + + layer.gemm1_weights_fp4_shuffled = Parameter(gemm1_weights_fp4_shuffled, requires_grad=False) + layer.gemm2_weights_fp4_shuffled = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False) + layer.gemm1_scales_fp4_shuffled = Parameter(gemm1_scales_fp4_shuffled, requires_grad=False) + layer.gemm2_scales_fp4_shuffled = Parameter(gemm2_scales_fp4_shuffled, requires_grad=False) + + # Additional parameter needed for TRT-LLM + layer.g1_scale_c = Parameter( + (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), + requires_grad=False, + ) + + # Clean up weights that won't be used by TRT-LLM + del layer.w2_weight + del layer.w2_weight_scale + del layer.w13_weight + del layer.w13_weight_scale + + elif self.use_marlin: + # Marlin processing - use patched version + # Pass per-expert scales (shape: num_experts), NOT scalar max values! 
+ prepare_moe_fp4_layer_for_marlin_patched(layer, w13_weight_scale_2_per_expert, w2_weight_scale_2_per_expert) + # Delete attributes not needed for Marlin + del layer.g1_alphas + del layer.g2_alphas + del layer.w13_input_scale_quant + del layer.w2_input_scale_quant + + else: + # Non-TRT-LLM processing (Cutlass or non-flashinfer) + w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale) + layer.w13_weight_scale = Parameter(w13_blockscale_swizzled, requires_grad=False) + + w13_weight = layer.w13_weight + intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1) + if intermediate_size_pad: + # padding gated activations will require to split w1 and w3 + # and pad them individually + assert not self.moe.is_act_and_mul, ( + "The intermediate size required padding, but padding is not implemented for gated activations" + ) + + layer.w13_weight = Parameter( + torch.nn.functional.pad(w13_weight, (0, 0, 0, intermediate_size_pad)), + requires_grad=False, + ) + layer.w2_weight = Parameter( + torch.nn.functional.pad(layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0)), + requires_grad=False, + ) + layer.w2_weight_scale = Parameter( + torch.nn.functional.pad(layer.w2_weight_scale, (0, intermediate_size_pad // 16)), + requires_grad=False, + ) + + w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale) + layer.w2_weight_scale = Parameter(w2_blockscale_swizzled, requires_grad=False) + + +def apply_moe( + self, + layer, # FusedMoE + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, +) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """ + Patched apply method for ModelOptNvFp4FusedMoE. + + Key modification for Marlin: Uses marlin_* attributes instead of originals. 
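+ + Backend dispatch below: FlashInfer TRT-LLM -> Marlin (marlin_* tensors) -> + FlashInfer CUTLASS / CUTEDSL -> cutlass_moe_fp4 (TP-only fallback).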
+ """ + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe + from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( + flashinfer_trtllm_fp4_moe, + ) + from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + FlashinferMoeBackend, + ) + from vllm.scalar_type import scalar_types + + if not self.moe.is_act_and_mul: + assert self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS, ( + "Non-gated activations are only supported by the flashinfer CUTLASS backend for modelopt checkpoints" + ) + + if self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + if enable_eplb: + raise NotImplementedError("EPLB not supported for `ModelOptNvFp4FusedMoE` yet.") + return flashinfer_trtllm_fp4_moe( + layer=layer, + x=x, + router_logits=router_logits, + top_k=top_k, + global_num_experts=global_num_experts, + num_expert_group=num_expert_group, + topk_group=topk_group, + custom_routing_function=custom_routing_function, + e_score_correction_bias=e_score_correction_bias, + ) + + topk_weights, topk_ids, _ = layer.select_experts( + hidden_states=x, + router_logits=router_logits, + ) + + if self.use_marlin: + # Use marlin_* attributes instead of original attributes + return fused_marlin_moe( + x, + layer.marlin_w13_weight, + layer.marlin_w2_weight, + None, # bias1 + None, # bias2 + layer.marlin_w13_weight_scale, + layer.marlin_w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=scalar_types.float4_e2m1f.id, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + global_scale1=layer.marlin_w13_weight_scale_2, + global_scale2=layer.marlin_w2_weight_scale_2, + workspace=layer.workspace, + input_dtype=self.marlin_input_dtype, + ) + + elif self.allow_flashinfer: + assert self.flashinfer_moe_backend in ( + FlashinferMoeBackend.CUTLASS, + FlashinferMoeBackend.CUTEDSL, + ) + if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + flashinfer_cutlass_moe_fp4, + ) + + flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4 + else: + from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( + flashinfer_cutedsl_moe_fp4, + ) + + flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4 + + assert self.moe_quant_config is not None + return flashinfer_fn_moe_fp4( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + quant_config=self.moe_quant_config, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: + # If no modular kernel is provided, use cutlass_moe_fp4 for TP case + # only (no EP). + from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 + + assert self.moe_quant_config is not None + return cutlass_moe_fp4( + a=x, + w1_fp4=layer.w13_weight, + w2_fp4=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + quant_config=self.moe_quant_config, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + ) + + +def process_weights_after_loading_kv(self, layer) -> None: + """Modified version of BaseKVCacheMethod.process_weights_after_loading. 
+ + Doesn't delete k_scale, v_scale, q_scale, and prob_scale parameters to allow + for dynamic updates during refit. + """ + # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 + # regardless of whether the kv-scale is available in the checkpoint. + # No need to process kv scales after loading if we are going to + # calculate them on the fly. + from vllm.platforms import current_platform + + if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + # We prefer to use separate k_scale and v_scale if present + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + if current_platform.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + # If no scales were loaded (both scales are invalid negative + # values), use the default value of 1.0 + k_scale = 1.0 + v_scale = 1.0 + else: + # If we find a single kv_scale in the checkpoint, we remap + # kv_scale to k_scale during weight loading, and duplicate + # k_scale to v_scale here + assert layer.k_scale > 0.0 + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale = scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + if current_platform.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 + + if not isinstance(k_scale, float) or not isinstance(v_scale, float): + raise ValueError("Only support per-tensor scaling factor for fp8 KV cache") + + if layer.q_scale < 0.0: + layer._q_scale.copy_(k_scale) + layer._q_scale_float = k_scale + + # These are used in the final Attention.forward() + layer._k_scale.copy_(k_scale) + layer._v_scale.copy_(v_scale) + layer._k_scale_float = k_scale + layer._v_scale_float = v_scale + + if layer.q_scale > 0.0: + q_scale = layer.q_scale + if current_platform.is_fp8_fnuz(): + q_scale *= 2 + layer.calculate_kv_scales = False + else: + q_scale = 1.0 + if layer.prob_scale > 0.0: + prob_scale = layer.prob_scale + if current_platform.is_fp8_fnuz(): + prob_scale *= 2 + else: + prob_scale = 1.0 + + is_singleton_float = ( + lambda x: isinstance(x, float) or isinstance(x, torch.Tensor) and x.numel() == 1 and x.is_floating_point() + ) + if not is_singleton_float(q_scale) or not is_singleton_float(prob_scale): + raise ValueError("Only support per-tensor scaling factor for fp8-quantized Q/prob") + + # These are used in the final Attention.forward() + layer._q_scale.copy_(q_scale) + layer._q_scale_float = q_scale.item() if isinstance(q_scale, torch.Tensor) else q_scale + + layer._prob_scale.copy_(prob_scale) + + +def apply_vllm_modelopt_patches(): + func1_path = ( + "vllm.model_executor.layers.quantization.modelopt.ModelOptNvFp4LinearMethod.process_weights_after_loading" + ) + patcher1 = patch(func1_path, process_weights_after_loading_modelopt) + patcher1.start() + func2_path = "vllm.model_executor.layers.quantization.modelopt.ModelOptNvFp4LinearMethod.apply" + patcher2 = patch(func2_path, apply_modelopt) + patcher2.start() + # Patch ModelOptNvFp4FusedMoE + func3_path = "vllm.model_executor.layers.quantization.modelopt.ModelOptNvFp4FusedMoE.process_weights_after_loading" + patcher3 = patch(func3_path, process_weights_after_loading_moe) + patcher3.start() + func4_path = "vllm.model_executor.layers.quantization.modelopt.ModelOptNvFp4FusedMoE.apply" + patcher4 = patch(func4_path, apply_moe) + patcher4.start() + # Static scales mode: patch process_weights_after_loading to preserve k_scale/v_scale for manual updates + func5_path = 
"vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading" + patcher5 = patch(func5_path, process_weights_after_loading_kv) + patcher5.start() \ No newline at end of file diff --git a/verl/workers/config/engine.py b/verl/workers/config/engine.py index e09dfb20a7f..86828322c81 100644 --- a/verl/workers/config/engine.py +++ b/verl/workers/config/engine.py @@ -102,6 +102,8 @@ class McoreEngineConfig(EngineConfig): override_transformer_config (dict[str, Any]): Override configuration for transformer. use_mbridge (bool): Whether to use MBridge for communication. dtype (str): Mixed precision training param dtype, default "bfloat16" + quantization (Optional[str]): Quantization method to use. None for no quantization, "nvfp4" for NVFP4 quantization. + enable_qat (bool): Whether to enable Quantization-Aware Training (QAT). Default False. """ # sequence_parallel is not listed as a frozen field for auto-correction purpose @@ -124,6 +126,8 @@ class McoreEngineConfig(EngineConfig): use_mbridge: bool = True vanilla_mbridge: bool = True strategy: str = "megatron" + quantization: Optional[str] = None + enable_qat: bool = False def __post_init__(self) -> None: super().__post_init__() diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index aa7613fbc78..c6182f168c1 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -73,6 +73,7 @@ simple_timer, ) from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max +from verl.utils.modelopt_qat_utils import apply_qat from verl.utils.ray_utils import get_event_loop from verl.utils.torch_functional import use_original_torch_compile from verl.workers.actor.megatron_actor import MegatronPPOActor @@ -218,6 +219,20 @@ def _init_hf_config_and_tf_config( provider.moe_token_dispatcher_type = "alltoall" provider.moe_router_load_balancing_type = "none" + enable_qat = self.config.actor.megatron.get("enable_qat", False) + if enable_qat: + from megatron.bridge.models.gpt_provider import quantization_layer_spec + provider.transformer_layer_spec = quantization_layer_spec + + # Patch megatron-core MLP to support singleton_local_shards + # in SwiGLU sharded state dict (required for QAT checkpointing) + from verl.models.mcore.qat_patch import apply_qat_patch + apply_qat_patch() + + from megatron.bridge.models.conversion.param_mapping import AutoMapping + AutoMapping.register_module_type('QuantColumnParallelLinear', 'column') + AutoMapping.register_module_type('QuantRowParallelLinear', 'row') + # Apply transformer config overrides for key, value in override_transformer_config.items(): setattr(provider, key, value) @@ -442,6 +457,11 @@ def _build_model_optimizer( if self.rank == 0: print_model_size(actor_module[0]) log_gpu_memory_usage("After MegatronPPOActor init", logger=logger) + quantization = self.config.actor.megatron.get("quantization", None) + enable_qat = self.config.actor.megatron.get("enable_qat", False) + if quantization is not None and enable_qat: + for i in range(len(actor_module)): + actor_module[i] = apply_qat(actor_module[i], quantization) elif self._is_ref: wrap_config = McoreModuleWrapperConfig( is_value_model=False, # ref is not value model @@ -717,6 +737,14 @@ async def rollout_mode(self): self.tf_config, self.layer_name_mapping, ) + if self.config.actor.megatron.get("enable_qat", False): + from verl.utils.modelopt_qat_utils import QATWeightPostProcessor + + qat_weight_post_processor = QATWeightPostProcessor( + self.actor.actor_module, 
"nvfp4" + ) + per_tensor_param = qat_weight_post_processor.process_weights_iterator(per_tensor_param) + if self.config.rollout.free_cache_engine: await self.rollout.resume(tags=["weights"]) diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py index cbfedb879fa..8f7d2173f27 100644 --- a/verl/workers/rollout/vllm_rollout/utils.py +++ b/verl/workers/rollout/vllm_rollout/utils.py @@ -29,7 +29,7 @@ from verl.utils.vllm import TensorLoRARequest, VLLMHijack from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches, is_fp8_model, load_quanted_weights - +from verl.utils.modelopt_vllm_utils import apply_vllm_modelopt_patches logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -138,6 +138,8 @@ def __new__(cls, **kwargs): # 2. patch online fp8 quant if os.environ.get("VERL_VLLM_FP8_QUANT_ENABLED", "0") == "1": apply_vllm_fp8_patches() + elif os.environ.get("VERL_VLLM_NVFP4_QUANT_ENABLED", "0") == "1": + apply_vllm_modelopt_patches() # TODO: For ascend NPU, when the corresponding vllm-ascend version is upgraded to v0.13.0, # please remove the VLLM_ASCEND_REQUIRED_ENV_VARS variable replacement action. @@ -225,6 +227,12 @@ def _update_weights(self, weights: list[tuple[str, torch.Tensor]], peft_config: logger.info("Loading standard weights (non-FP8, async)") self.model_runner.model.load_weights(weights) + from vllm.model_executor.model_loader.utils import process_weights_after_loading + model_config = self.model_runner.vllm_config.model_config + device = next(self.model_runner.model.parameters()).device + process_weights_after_loading(self.model_runner.model, model_config, device) + # from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4LinearMethod + def _get_zmq_handle(self) -> str: """Get ZMQ handle for communication.""" if not hasattr(self, "device_uuid") or not self.device_uuid: diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 196c72bc378..e32019d562b 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -21,6 +21,7 @@ from typing import Any, Callable, Optional import numpy as np +from numpy.random import f import ray import vllm.entrypoints.cli.serve from packaging import version @@ -60,9 +61,10 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser if _VLLM_VERSION == version.parse("0.12.0"): - from vllm.entrypoints.harmony_utils import get_encoding + pass + # from vllm.entrypoints.harmony_utils import get_encoding - get_encoding() + # get_encoding() elif _VLLM_VERSION >= version.parse("0.13.0"): from vllm.entrypoints.openai.parser.harmony_utils import get_encoding @@ -225,7 +227,7 @@ async def launch_server(self, master_address: str = None, master_port: int = Non quantization = self.config.quantization if quantization is not None: - _SUPPORTED_QUANTIZATION = ["fp8", "torchao"] + _SUPPORTED_QUANTIZATION = ["fp8", "torchao", "nvfp4"] if quantization not in _SUPPORTED_QUANTIZATION: raise ValueError(f"Currently only support {_SUPPORTED_QUANTIZATION} quantization, got: {quantization}") @@ -242,6 +244,21 @@ async def launch_server(self, master_address: str = None, master_port: int = Non apply_vllm_fp8_patches() # for subprocesses patching os.environ["VERL_VLLM_FP8_QUANT_ENABLED"] = "1" + elif quantization == "nvfp4": + from 
verl.utils.modelopt_vllm_utils import apply_vllm_modelopt_patches, get_nvfp4_block_quant_kwargs + + num_layers = getattr(self.model_config.hf_config, "num_hidden_layers") + + is_moe = ( + hasattr(self.model_config.hf_config, "num_experts") or + hasattr(self.model_config.hf_config, "num_local_experts") or + hasattr(self.model_config.hf_config, "moe_intermediate_size") + ) + + fp4_block_quant_kwargs = get_nvfp4_block_quant_kwargs(num_layers, is_moe) + + apply_vllm_modelopt_patches() + os.environ["VERL_VLLM_NVFP4_QUANT_ENABLED"] = "1" hf_overrides = {} if quantization is not None and self.config.quantization_config_file is not None: @@ -249,6 +266,9 @@ async def launch_server(self, master_address: str = None, master_port: int = Non if quantization == "fp8": hf_overrides["quantization_config"] = fp8_block_quant_kwargs + elif quantization == "nvfp4": + hf_overrides["quantization_config"] = fp4_block_quant_kwargs + quantization = "modelopt" args = { "dtype": self.config.dtype, diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py index ebbb6e19e48..3880d2756a2 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py @@ -169,6 +169,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None # model parameters are in fp32 full precision weight = weight.to(dtype, non_blocking=True) + # fill the tensor bucket if offset + weight.nbytes > bucket_size: get_torch_device().synchronize()