From d82146d00e5bb93490bc5be6aefcee6187f1345a Mon Sep 17 00:00:00 2001 From: chai-xiaonan <3072824838@qq.com> Date: Fri, 9 Jan 2026 09:52:16 +0800 Subject: [PATCH 1/3] add nemo_bridge --- .../nemo_bridge/.requirements.txt.swp | Bin 0 -> 12288 bytes .../train/megatron/nemo_bridge/__init__.py | 8 + .../megatron/nemo_bridge/models/__init__.py | 99 + .../megatron/nemo_bridge/models/config.py | 340 ++++ .../nemo_bridge/models/conversion/__init__.py | 32 + .../models/conversion/auto_bridge.py | 572 ++++++ .../models/conversion/mapping_registry.py | 266 +++ .../models/conversion/model_bridge.py | 1032 ++++++++++ .../models/conversion/param_mapping.py | 1785 +++++++++++++++++ .../nemo_bridge/models/conversion/utils.py | 287 +++ .../nemo_bridge/models/decorators/__init__.py | 9 + .../nemo_bridge/models/decorators/dispatch.py | 348 ++++ .../nemo_bridge/models/decorators/torchrun.py | 42 + .../nemo_bridge/models/deepseek/__init__.py | 31 + .../nemo_bridge/models/deepseek/common.py | 137 ++ .../models/deepseek/deepseek_provider.py | 309 +++ .../models/deepseek/deepseek_v2_bridge.py | 48 + .../models/deepseek/deepseek_v3_bridge.py | 64 + .../models/gpt_full_te_layer_autocast_spec.py | 347 ++++ .../nemo_bridge/models/gpt_provider.py | 430 ++++ .../models/hf_pretrained/__init__.py | 8 + .../nemo_bridge/models/hf_pretrained/base.py | 237 +++ .../models/hf_pretrained/causal_lm.py | 657 ++++++ .../hf_pretrained/safe_config_loader.py | 136 ++ .../nemo_bridge/models/hf_pretrained/state.py | 850 ++++++++ .../nemo_bridge/models/hf_pretrained/vlm.py | 603 ++++++ .../nemo_bridge/models/model_provider.py | 710 +++++++ .../nemo_bridge/models/qwen/__init__.py | 56 + .../nemo_bridge/models/qwen/qwen2_bridge.py | 110 + .../nemo_bridge/models/qwen/qwen3_bridge.py | 106 + .../models/qwen/qwen3_moe_bridge.py | 113 ++ .../nemo_bridge/models/qwen/qwen_provider.py | 393 ++++ .../nemo_bridge/models/transformer_config.py | 96 + .../megatron/nemo_bridge/utils/__init__.py | 3 + .../nemo_bridge/utils/common_utils.py | 147 ++ .../megatron/nemo_bridge/utils/decorators.py | 26 + .../megatron/nemo_bridge/utils/fusions.py | 175 ++ .../nemo_bridge/utils/import_utils.py | 409 ++++ .../nemo_bridge/utils/instantiate_utils.py | 418 ++++ .../megatron/nemo_bridge/utils/path_utils.py | 10 + .../megatron/nemo_bridge/utils/vocab_utils.py | 64 + .../megatron/nemo_bridge/utils/yaml_utils.py | 203 ++ .../train/megatron/training/arguments.py | 20 + .../train/megatron/training/checkpointing.py | 29 + .../train/megatron/training/yaml_arguments.py | 4 +- 45 files changed, 11768 insertions(+), 1 deletion(-) create mode 100644 flagscale/train/megatron/nemo_bridge/.requirements.txt.swp create mode 100644 flagscale/train/megatron/nemo_bridge/__init__.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/__init__.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/config.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/conversion/__init__.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/conversion/auto_bridge.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/conversion/mapping_registry.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/conversion/model_bridge.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/conversion/param_mapping.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/conversion/utils.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/decorators/__init__.py create mode 100644 
flagscale/train/megatron/nemo_bridge/models/decorators/dispatch.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/decorators/torchrun.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/deepseek/__init__.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/deepseek/common.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_provider.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v2_bridge.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v3_bridge.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/gpt_full_te_layer_autocast_spec.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/gpt_provider.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/hf_pretrained/__init__.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/hf_pretrained/base.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/hf_pretrained/causal_lm.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/hf_pretrained/safe_config_loader.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/hf_pretrained/state.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/hf_pretrained/vlm.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/model_provider.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/qwen/__init__.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/qwen/qwen2_bridge.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_bridge.py create mode 100755 flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_moe_bridge.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/qwen/qwen_provider.py create mode 100644 flagscale/train/megatron/nemo_bridge/models/transformer_config.py create mode 100644 flagscale/train/megatron/nemo_bridge/utils/__init__.py create mode 100644 flagscale/train/megatron/nemo_bridge/utils/common_utils.py create mode 100644 flagscale/train/megatron/nemo_bridge/utils/decorators.py create mode 100644 flagscale/train/megatron/nemo_bridge/utils/fusions.py create mode 100644 flagscale/train/megatron/nemo_bridge/utils/import_utils.py create mode 100644 flagscale/train/megatron/nemo_bridge/utils/instantiate_utils.py create mode 100644 flagscale/train/megatron/nemo_bridge/utils/path_utils.py create mode 100644 flagscale/train/megatron/nemo_bridge/utils/vocab_utils.py create mode 100644 flagscale/train/megatron/nemo_bridge/utils/yaml_utils.py diff --git a/flagscale/train/megatron/nemo_bridge/.requirements.txt.swp b/flagscale/train/megatron/nemo_bridge/.requirements.txt.swp new file mode 100644 index 0000000000000000000000000000000000000000..ba527123a01890cc58a79cd9f2a4743f4f9a8587 GIT binary patch literal 12288 zcmeI%u}%Up9LMqEZZz@&sB`IoA>j#39Gr~q?WmdWSMfoVjn4 zw*P;+^!M4C&a$Vw`@wWL7R@+sm!I+R-WSpPwC@k?L59UvzHX(_8&?%_nPizuJoA%rKnw{X3?3#s%Y2GF4~o? zQawzYf8FWOeK|v*n*yieW<1K;?AKS9gZ{;1HyKEc00IagfB*srAb>ze1ybmVwD!0tg_0 Q00IagfB*srAb T: + """Load a pretrained model configuration from a directory or file.""" + ... + + def save_hf_pretrained( + self, + save_directory: Union[str, Path], + config_format: ConfigFormat | None = None, + config_name: Optional[str] = None, + **kwargs, + ) -> None: + """Save the model configuration to a directory.""" + ... 
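+
+# NOTE: illustrative usage sketch for the module-level helpers defined below.
+# `MyModelConfig` is a hypothetical, importable dataclass used only to show the
+# save/load round trip; it is not part of this module.
+#
+#     from dataclasses import dataclass
+#
+#     @dataclass
+#     class MyModelConfig:
+#         hidden_size: int = 768
+#         num_layers: int = 12
+#
+#     cfg = MyModelConfig()
+#     save_hf_pretrained(cfg, "./saved_cfg", config_format="yaml")
+#     # The saved config embeds a `_target_` class reference, so loading it back
+#     # requires trust_remote_code=True.
+#     restored = from_hf_pretrained(MyModelConfig, "./saved_cfg", trust_remote_code=True)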
+
+
+def from_hf_pretrained(
+    cls: Type[T],
+    pretrained_model_name_or_path: Union[str, Path],
+    trust_remote_code: bool = False,
+    mode: InstantiationMode = InstantiationMode.LENIENT,
+    config_name: str = "config",
+    **kwargs,
+) -> T:
+    """
+    Load a pretrained model configuration from a directory or file.
+
+    Args:
+        cls: The class to instantiate
+        pretrained_model_name_or_path: Path to a directory containing a config file,
+            or direct path to a config file (yaml/json/toml)
+        trust_remote_code: Whether to trust and execute code references (classes/functions)
+            found in the configuration. Required to be True if the config
+            contains any class or function references. Default: False
+        mode: Instantiation mode (STRICT or LENIENT) for the instantiate function
+        config_name: Base name of the config file (without extension)
+        **kwargs: Additional keyword arguments to override loaded configuration
+
+    Returns:
+        Instance of the class with loaded configuration
+
+    Example:
+        ```python
+        # Load from directory (looks for config.yaml, config.json, or config.toml)
+        model = from_hf_pretrained(MyModel, "./saved_model/")
+
+        # Load from specific file
+        model = from_hf_pretrained(MyModel, "./saved_model/config.yaml")
+
+        # With code references
+        model = from_hf_pretrained(MyModel, "./saved_model/", trust_remote_code=True)
+
+        # Override configuration values
+        model = from_hf_pretrained(MyModel, "./saved_model/", temperature=0.8)
+        ```
+    """
+    path = Path(pretrained_model_name_or_path)
+
+    # Determine the config file path
+    if path.is_dir():
+        # Look for config files in order of preference
+        config_file = None
+        for ext in [".yaml", ".yml", ".json", ".toml"]:
+            candidate = path / f"{config_name}{ext}"
+            if candidate.exists():
+                config_file = candidate
+                break
+
+        if config_file is None:
+            raise FileNotFoundError(
+                f"No configuration file found in {path}. "
+                f"Expected {config_name}.yaml, {config_name}.json, or {config_name}.toml"
+            )
+    else:
+        config_file = path
+
+    if not config_file.exists():
+        raise FileNotFoundError(f"Configuration file not found at {config_file}")
+
+    # Load the configuration based on file extension
+    file_ext = config_file.suffix.lower()
+
+    if file_ext in [".yaml", ".yml"]:
+        with open(config_file, "r", encoding="utf-8") as f:
+            config_dict = yaml.safe_load(f)
+    elif file_ext == ".json":
+        with open(config_file, "r", encoding="utf-8") as f:
+            config_dict = json.load(f)
+    elif file_ext == ".toml":
+        if not HAS_TOML:
+            raise ImportError(
+                "TOML support requires the 'toml' package. Install it with: pip install toml"
+            )
+        with open(config_file, "r", encoding="utf-8") as f:
+            config_dict = toml.load(f)
+    else:
+        raise ValueError(
+            f"Unsupported file format: {file_ext}. Supported formats: .yaml, .yml, .json, .toml"
+        )
+
+    # Check for trust_remote_code requirement
+    if not trust_remote_code and _contains_code_references(config_dict):
+        raise ValueError(
+            "This configuration contains class or function references. "
+            "Loading it requires trust_remote_code=True to prevent arbitrary code execution."
+ ) + + # Convert to OmegaConf for compatibility with instantiate + omega_conf = OmegaConf.create(config_dict) + + # Merge with kwargs + if kwargs: + override_conf = OmegaConf.create(kwargs) + omega_conf = OmegaConf.merge(omega_conf, override_conf) + + # Add _target_ if not present + if "_target_" not in omega_conf: + omega_conf["_target_"] = f"{cls.__module__}.{cls.__qualname__}" + + # Convert back to container for instantiate + final_config = OmegaConf.to_container(omega_conf, resolve=True) + + # Use instantiate to create the object + return instantiate(final_config, mode=mode) + + +def save_hf_pretrained( + obj: Any, + save_directory: Union[str, Path], + config_format: ConfigFormat = "json", + config_name: str = "config", + **kwargs, +) -> None: + """ + Save the model configuration to a directory. + + Args: + obj: The object to save + save_directory: Directory where to save the configuration + config_format: Format to save in ("yaml", "json", or "toml"). Default: "json" + config_name: Name for the config file (without extension) + **kwargs: Additional metadata to save alongside the configuration + + Example: + ```python + # Save as JSON (default) + save_hf_pretrained(model, "./saved_model/") + + # Save as YAML + save_hf_pretrained(model, "./saved_model/", config_format="yaml") + + # Save with custom name + save_hf_pretrained(model, "./saved_model/", config_name="my_config") + ``` + """ + save_path = Path(save_directory) + save_path.mkdir(parents=True, exist_ok=True) + + # Determine file extension + format_to_ext = {"yaml": ".yaml", "yml": ".yaml", "json": ".json", "toml": ".toml"} + + config_format = config_format.lower() + if config_format not in format_to_ext: + raise ValueError( + f"Unsupported format: {config_format}. Supported formats: {list(format_to_ext.keys())}" + ) + + if config_format == "toml" and not HAS_TOML: + raise ImportError( + "TOML support requires the 'toml' package. Install it with: pip install toml" + ) + + config_file = save_path / f"{config_name}{format_to_ext[config_format]}" + + # Get the configuration dictionary + config_dict = _to_dict(obj) + + # Add any additional metadata + if kwargs: + config_dict.update(kwargs) + + # Save based on format + if config_format in ["yaml", "yml"]: + with safe_yaml_representers(): + with open(config_file, "w", encoding="utf-8") as f: + yaml.safe_dump(config_dict, f, default_flow_style=False, sort_keys=False) + elif config_format == "json": + # First convert to YAML string to use the custom representers + with safe_yaml_representers(): + yaml_str = yaml.safe_dump(config_dict, default_flow_style=False) + # Then parse and save as JSON + yaml_dict = yaml.safe_load(yaml_str) + with open(config_file, "w", encoding="utf-8") as f: + json.dump(yaml_dict, f, indent=2, ensure_ascii=False) + elif config_format == "toml": + # First convert to YAML string to use the custom representers + with safe_yaml_representers(): + yaml_str = yaml.safe_dump(config_dict, default_flow_style=False) + # Then parse and save as TOML + yaml_dict = yaml.safe_load(yaml_str) + with open(config_file, "w", encoding="utf-8") as f: + toml.dump(yaml_dict, f) + + print(f"Configuration saved to {config_file}") + + +def _to_dict(obj: Any) -> Dict[str, Any]: + """ + Convert an object to a dictionary representation. 
+ + Args: + obj: The object to convert + + Returns: + Dictionary representation of the object + """ + # Check if this is a ConfigContainer (has to_dict method) + if hasattr(obj, "to_dict") and callable(obj.to_dict): + return obj.to_dict() + + # Otherwise, build dict from dataclass fields or attributes + result = {} + result["_target_"] = f"{obj.__class__.__module__}.{obj.__class__.__qualname__}" + + if is_dataclass(obj): + # Handle dataclass + for field in dataclass_fields(obj): + if field.name.startswith("_"): + continue + value = getattr(obj, field.name) + result[field.name] = _convert_value_to_dict(value) + else: + # Handle regular class + for key, value in obj.__dict__.items(): + if not key.startswith("_"): + result[key] = _convert_value_to_dict(value) + + return result + + +def _convert_value_to_dict(value: Any) -> Any: + """ + Recursively convert a value to a dictionary representation. + + Args: + value: The value to convert + + Returns: + The converted value + """ + if hasattr(value, "_to_dict"): + return value._to_dict() + elif hasattr(value, "to_dict") and callable(value.to_dict): + return value.to_dict() + elif is_dataclass(value) and not isinstance(value, type): + # Handle regular dataclasses + result = {"_target_": f"{value.__class__.__module__}.{value.__class__.__qualname__}"} + for field in dataclass_fields(value): + if not field.name.startswith("_"): + result[field.name] = _convert_value_to_dict(getattr(value, field.name)) + return result + elif isinstance(value, (list, tuple)): + return [_convert_value_to_dict(item) for item in value] + elif isinstance(value, dict): + return {k: _convert_value_to_dict(v) for k, v in value.items()} + else: + return value + + +def _contains_code_references(config_dict: Dict[str, Any]) -> bool: + """ + Check if a configuration dictionary contains code references. + + Args: + config_dict: The configuration dictionary to check + + Returns: + True if code references are found, False otherwise + """ + if isinstance(config_dict, dict): + for key, value in config_dict.items(): + # Check for _target_ that's not a built-in type + if key == "_target_" and isinstance(value, str): + # Consider it a code reference if it's not a basic type + if not value.startswith( + ("builtins.", "str", "int", "float", "bool", "list", "dict", "tuple") + ): + return True + # Check for _call_ = False which indicates a code reference + if key == "_call_" and value is False: + return True + # Recursively check nested structures + if _contains_code_references(value): + return True + elif isinstance(config_dict, (list, tuple)): + for item in config_dict: + if _contains_code_references(item): + return True + + return False diff --git a/flagscale/train/megatron/nemo_bridge/models/conversion/__init__.py b/flagscale/train/megatron/nemo_bridge/models/conversion/__init__.py new file mode 100644 index 0000000000..e7a20e1f97 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/conversion/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +# Import model providers for easy access +from megatron.nemo_bridge.models.conversion.auto_bridge import AutoBridge +from megatron.nemo_bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.nemo_bridge.models.conversion.param_mapping import ( + AutoMapping, + ColumnParallelMapping, + GatedMLPMapping, + MegatronParamMapping, + QKVMapping, + ReplicatedMapping, + RowParallelMapping, +) +from megatron.nemo_bridge.models.conversion.utils import weights_verification_table + +__all__ = [ + "AutoBridge", + "MegatronMappingRegistry", + "MegatronModelBridge", + "ColumnParallelMapping", + "GatedMLPMapping", + "MegatronParamMapping", + "QKVMapping", + "ReplicatedMapping", + "RowParallelMapping", + "AutoMapping", + "weights_verification_table", +] diff --git a/flagscale/train/megatron/nemo_bridge/models/conversion/auto_bridge.py b/flagscale/train/megatron/nemo_bridge/models/conversion/auto_bridge.py new file mode 100644 index 0000000000..88a9a7f9b9 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/conversion/auto_bridge.py @@ -0,0 +1,572 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import dataclasses + +from functools import cached_property, partial +from pathlib import Path +from typing import Any, Generic, Iterable, List, Optional, Type, TypeVar, Union + +import torch.distributed as dist +import transformers + +from transformers import AutoModelForCausalLM +from transformers.configuration_utils import PretrainedConfig +from typing_extensions import Unpack + +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import MLATransformerConfig, TransformerConfig + +from megatron.nemo_bridge.models.conversion import model_bridge +from megatron.nemo_bridge.models.conversion.model_bridge import ( + HFWeightTuple, + MegatronModelBridge, + WeightConversionTask, +) +from megatron.nemo_bridge.models.conversion.utils import get_causal_lm_class_via_auto_map + +# from megatron.nemo_bridge.models.gpt_provider import GPTModelProvider +from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from megatron.nemo_bridge.models.hf_pretrained.safe_config_loader import ( + safe_load_config_with_retry, +) +from megatron.nemo_bridge.models.hf_pretrained.state import SafeTensorsStateSource + +# from megatron.nemo_bridge.models.model_provider import GetModelKwargs, ModelParallelKwargs, ModelProviderMixin + + +MegatronModelT = TypeVar("MegatronModelT", bound=MegatronModule) +DataclassT = TypeVar("DataclassT") + + +class AutoBridge(Generic[MegatronModelT]): + """ + Automatically select and instantiate the appropriate bridge for a model. + + This unified bridge class combines automatic model detection with full bridge + functionality for converting models between HuggingFace and Megatron formats. + It handles the conversion of causal language models (e.g., GPT, Llama, Phi) + between HuggingFace's transformers library format and Megatron-Core's distributed + training format. It manages weight mapping, tensor parallelism distribution, and + configuration translation. 
+ + The bridge supports both directions of conversion: + - HuggingFace → Megatron: For training or inference with Megatron + - Megatron → HuggingFace: For saving trained models in HF format + + Args: + hf_pretrained: Either a PreTrainedCausalLM instance with loaded model, + or a PretrainedConfig for configuration-only operations + + Example: + >>> # Load and convert a model to Megatron format + >>> bridge = AutoBridge.from_hf_pretrained("meta-llama/Meta-Llama-3-8B") + >>> provider = bridge.to_megatron_provider() + >>> megatron_model = provider.provide_distributed_model(wrap_with_ddp=False) + + >>> # Export a Megatron model back to HuggingFace format + >>> bridge.save_hf_pretrained(megatron_model, "./exported_model") + + >>> # Convert weights with custom settings + >>> for name, weight in bridge.export_hf_weights( + ... megatron_model, + ... cpu=True + ... ): + ... print(f"Exported {name}: {weight.shape}") + + >>> # Check if a model is supported before loading + >>> if AutoBridge.can_handle("microsoft/phi-2"): + ... bridge = AutoBridge.from_hf_pretrained("microsoft/phi-2") + + Note: + The bridge automatically detects the model architecture and applies + the appropriate weight mappings. Custom architectures require implementing + a MegatronModelBridge subclass. + """ + + def __init__(self, hf_pretrained: PreTrainedCausalLM | PretrainedConfig): + if not isinstance(hf_pretrained, (PreTrainedCausalLM, PretrainedConfig)): + raise ValueError( + "hf_pretrained must be a PreTrainedCausalLM or PretrainedConfig instance" + ) + self.hf_pretrained: PreTrainedCausalLM | PretrainedConfig = hf_pretrained + + @classmethod + def list_supported_models(cls) -> list[str]: + """ + List all model architectures currently supported by the bridge system. + + Returns: + List of supported HuggingFace model architecture names + """ + # Get all registered implementations from the dispatch system + supported = [] + + # Access the dispatch registry to find all registered types + + if hasattr(model_bridge.get_model_bridge, "_exact_types"): + for arch_type in model_bridge.get_model_bridge._exact_types.keys(): + # Support both type and string registrations + if isinstance(arch_type, str): + supported.append(arch_type) + elif hasattr(arch_type, "__name__"): + supported.append(arch_type.__name__) + + return sorted(supported) + + @classmethod + def supports(cls, config: Any) -> bool: + """ + Check if this bridge supports the given model configuration. + + A model is supported if it has at least one architecture ending with 'ForCausalLM' or 'ForConditionalGeneration' + or 'NemotronH_Nano_VL_V2'. + + Args: + config: HuggingFace model config object + + Returns: + True if this bridge can handle the model, False otherwise + """ + architectures = getattr(config, "architectures", []) + if not architectures: + return False + return any( + arch.endswith(("ForCausalLM", "ForConditionalGeneration", "NemotronH_Nano_VL_V2")) + for arch in architectures + ) + + @classmethod + def from_hf_config(cls, config: PretrainedConfig) -> "AutoBridge": + """ + Create an AutoBridge from a HuggingFace configuration. + + This method creates a bridge instance from just a model configuration, + without loading any weights. 
This is useful for: + - Creating Megatron models with random initialization + - Working with model architectures without downloading weights + - Testing and development scenarios + + Args: + config: HuggingFace PretrainedConfig instance containing model + architecture information + + Returns: + AutoBridge: Bridge instance configured for the architecture + + Raises: + ValueError: If the configuration is not for a supported CausalLM model + + Example: + >>> from transformers import AutoConfig + >>> + >>> # Load just the configuration + >>> config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B") + >>> + >>> # Create bridge from config (no weights) + >>> bridge = AutoBridge.from_hf_config(config) + >>> + >>> # Create Megatron model with random initialization + >>> provider = bridge.to_megatron_provider(load_weights=False) + >>> model = provider.provide_distributed_model(wrap_with_ddp=False) + + >>> # Or use for architecture exploration + >>> transformer_config = bridge.transformer_config + >>> print(f"Hidden size: {transformer_config.hidden_size}") + >>> print(f"Num layers: {transformer_config.num_layers}") + + See Also: + from_hf_pretrained: Create bridge with loaded weights + transformer_config: Access the Megatron TransformerConfig + """ + cls._validate_config(config) + model = PreTrainedCausalLM() + model.config = config + import torch + + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + + with init_empty_weights(): + hf_model = AutoModelForCausalLM.from_config(model.config) + + for name, param in hf_model.named_parameters(): + set_module_tensor_to_device( + hf_model, name, "cpu", torch.empty(*param.size(), dtype=model.config.torch_dtype) + ) + model.model = hf_model + return cls(model) + + @classmethod + def from_hf_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoBridge": + """ + Load an AutoBridge from a pretrained model, automatically detecting the model type. + + This method loads a model from HuggingFace Hub or a local directory and + creates a bridge instance ready for conversion operations. The model + architecture is validated to ensure compatibility. + + Args: + path: HuggingFace model ID or path to model directory + Examples: "meta-llama/Meta-Llama-3-8B", "./my_model" + **kwargs: Additional arguments passed to HuggingFace from_hf_pretrained + Common options include: + - torch_dtype: Model precision (torch.float16, torch.bfloat16) + - device_map: Device placement strategy ("auto", "cuda:0", etc.) + - trust_remote_code: Allow custom model code execution + - attn_implementation: Attention implementation ("flash_attention_2", etc.) + + Returns: + AutoBridge: Bridge instance with loaded model + + Raises: + ValueError: If the model architecture is not supported + + Example: + >>> # Basic loading + >>> bridge = AutoBridge.from_hf_pretrained("gpt2") + + >>> # Load with specific settings + >>> bridge = AutoBridge.from_hf_pretrained( + ... "meta-llama/Meta-Llama-3-8B", + ... torch_dtype=torch.float16, + ... device_map="auto" + ... 
) + + >>> # Works with local paths too + >>> bridge = AutoBridge.from_hf_pretrained("/path/to/model") + """ + # First load just the config to check architecture support + # Use thread-safe config loading to prevent race conditions + config = safe_load_config_with_retry( + path, trust_remote_code=kwargs.get("trust_remote_code", False) + ) + + cls._validate_config(config, str(path)) + + try: + return cls(PreTrainedCausalLM.from_pretrained(path, **kwargs)) + except Exception as e: + raise ValueError(f"Failed to load model with AutoBridge: {e}") from e + + def load_hf_weights( + self, model: list[MegatronModelT], hf_path: str | Path | None = None + ) -> None: + """ + Load HuggingFace weights into a Megatron model. + + This method handles the conversion and distribution of weights from + HuggingFace format to Megatron's distributed format, including proper + tensor parallel and pipeline parallel distribution. + + Args: + model: List of Megatron model instances (one per virtual pipeline stage) + hf_path: Optional path to load weights from. If None, uses weights + from the bridge's hf_pretrained instance + + Returns: + The input model with loaded weights + + Raises: + ValueError: If hf_path is None and bridge was created without weights + + Example: + >>> # Load weights from bridge's pretrained model + >>> bridge = AutoBridge.from_hf_pretrained("gpt2") + >>> megatron_model = create_megatron_model() # Your model creation + >>> bridge.load_hf_weights(megatron_model) + + >>> # Load weights from a different checkpoint + >>> bridge.load_hf_weights(megatron_model, "./finetuned_model") + """ + if hf_path is None: + if not isinstance(self.hf_pretrained, PreTrainedCausalLM): + raise ValueError( + "hf_path is required when hf_pretrained is not a PreTrainedCausalLM instance" + ) + pre_trained = self.hf_pretrained + else: + pre_trained = PreTrainedCausalLM.from_pretrained(hf_path) + # Preserve trust_remote_code setting from the original bridge instance + trust_remote_code = getattr(self.hf_pretrained, 'trust_remote_code', False) + pre_trained = PreTrainedCausalLM.from_pretrained( + hf_path, trust_remote_code=trust_remote_code + ) + # self._model_bridge.load_weights_hf_to_megatron(model, pre_trained) + self._model_bridge.load_weights_hf_to_megatron(pre_trained, model) + + return model + + def save_hf_pretrained( + self, + model: list[MegatronModelT], + path: str | Path, + show_progress: bool = True, + strict: bool = True, + ) -> None: + """ + Save a Megatron model in HuggingFace format. + + This method exports the complete model including configuration, tokenizer, + and weights to a directory that can be loaded with HuggingFace's + from_pretrained methods. + + If the original model was loaded with trust_remote_code=True, any custom + modeling files (e.g., modeling_*.py, configuration_*.py) will be preserved + to ensure the saved model can be loaded properly. + + Args: + model: Megatron model instance or list of instances + path: Directory path to save the model + show_progress: Display progress bar during weight export + + Example: + >>> # Save model after training + >>> bridge.save_hf_pretrained(megatron_model, "./my_finetuned_model") + + >>> # Load the saved model with HuggingFace + >>> from transformers import AutoModelForCausalLM + >>> hf_model = AutoModelForCausalLM.from_pretrained("./my_finetuned_model") + + Note: + This method is collective - all ranks must call it. Only rank 0 + saves the configuration files, while weight saving is coordinated + across all ranks. 
+ """ + if dist.is_available() and dist.is_initialized(): + # Distributed training, only rank 0 saves artifacts + if dist.get_rank() == 0: + self.hf_pretrained.save_artifacts(path) + else: + # No distributed training, save artifacts + self.hf_pretrained.save_artifacts(path) + self.save_hf_weights(model, path, show_progress, strict) + + def save_hf_weights( + self, + model: list[MegatronModelT], + path: str | Path, + show_progress: bool = True, + strict: bool = True, + ) -> None: + """ + Save Megatron model weights in HuggingFace safetensors format. + + This method exports only the model weights (not configuration or tokenizer) + to safetensors files compatible with HuggingFace. It uses streaming save + to handle large models efficiently without requiring all weights in memory + at once. + + The weights are gathered from distributed ranks and saved in the standard + HuggingFace sharded format when the model is large. + + Args: + model: Megatron model instance or list of instances + path: Directory path where weight files will be saved + show_progress: Display progress bar during export + + Raises: + ValueError: If the state source doesn't support streaming save + + Example: + >>> # Save just the weights + >>> bridge.save_hf_weights(megatron_model, "./model_weights") + + >>> # Save without progress bar (useful in scripts) + >>> bridge.save_hf_weights(megatron_model, "./weights", show_progress=False) + + Note: + - This method is collective and must be called by all ranks + - Uses safetensors format for efficient loading and security + - Automatically handles model sharding for large models + - The saved weights can be loaded with HuggingFace's from_pretrained + """ + if dist.is_available() and dist.is_initialized(): + dist.barrier() + dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model)) + generator = model_bridge.stream_weights_megatron_to_hf( + dispatch_instance, model, self.hf_pretrained, cpu=True, show_progress=show_progress + ) + source = SafeTensorsStateSource(path) + # Check if the state source is SafeTensorsStateSource for streaming save. + if ( + hasattr(self.hf_pretrained, "state") + and hasattr(self.hf_pretrained.state, "source") + # and isinstance(self.hf_pretrained.state.source, SafeTensorsStateSource) + ): + # self.hf_pretrained.state.source.save_generator(generator, path, strict=strict) + source.save_generator(generator, path, strict=strict) + else: + raise ValueError( + "The state source is not a SafeTensorsStateSource, cannot save in streaming mode." + ) + + if dist.is_available() and dist.is_initialized(): + dist.barrier() + + @property + def _model_bridge(self) -> "MegatronModelBridge": + return model_bridge.get_model_bridge(self._causal_lm_architecture) + + @cached_property + def _causal_lm_architecture(self): + """Resolve the model's CausalLM architecture for dispatch. + + Behavior: + - If the model can be imported from transformers directly, return the actual transformers class object. + - Otherwise, if the model uses HuggingFace auto_map, return the architecture's class name as a string (e.g., + "DeepseekV2ForCausalLM"). + + Returns: + str | type: The transformers class for the CausalLM architecture or the architecture's class name as a + string for auto_map models. + + Raises: + ValueError: If no CausalLM architecture is found or cannot be resolved. 
+ """ + if isinstance(self.hf_pretrained, PreTrainedCausalLM): + config = self.hf_pretrained.config + model_name_or_path = getattr(config, "_name_or_path", None) or getattr( + self.hf_pretrained, "model_name_or_path", None + ) + else: + config = self.hf_pretrained + model_name_or_path = getattr(config, "_name_or_path", None) + + architectures = getattr(config, "architectures", []) + + if not architectures: + raise ValueError( + "\n✗ No architectures found in model config\n\n" + "The model configuration does not specify any architectures.\n" + "This is required for determining the model type." + ) + + causal_lm_arch = None + for architecture_name in architectures: + # TODO: Can we improve this? + if architecture_name.endswith( + ("ForCausalLM", "ForConditionalGeneration", "NemotronH_Nano_VL_V2") + ): + causal_lm_arch = architecture_name + break + + if not causal_lm_arch: + raise ValueError( + f"\n✗ No CausalLM architecture found\n\n" + f"Model architectures: {architectures}\n\n" + f"None of the architectures end with 'ForCausalLM' or 'ForConditionalGeneration' or" + f"'NemotronH_Nano_VL_V2'.\n" + f"This bridge only supports causal language models.\n" + f"For other model types, use a different bridge class." + ) + + # Try auto_map first + cls = get_causal_lm_class_via_auto_map(model_name_or_path=model_name_or_path, config=config) + if cls is not None: + # For auto_map models, return the class name as a string + return getattr(cls, "__name__", str(cls)) + + try: + return getattr(transformers, causal_lm_arch) + except AttributeError: + raise ValueError( + f"\n✗ Architecture class '{causal_lm_arch}' not found in transformers\n\n" + f"This could mean:\n" + f"1. The model requires a newer version of transformers\n" + f"2. The model uses a custom modeling file not in the standard library\n" + f"3. There's a typo in the architecture name\n\n" + f"Please verify your transformers installation and the model requirements." + ) + + @classmethod + def _validate_config(cls, config: PretrainedConfig, path: str | None = None) -> None: + # Check if this is a causal LM model + if not cls.supports(config): + architectures = getattr(config, "architectures", []) + raise ValueError( + f"\n✗ Model architecture not supported by AutoBridge\n\n" + f"Model: {path}\n" + f"Architectures: {architectures}\n\n" + f"AutoBridge only supports models with architectures ending in 'ForCausalLM' or" + f"'ForConditionalGeneration' or 'NemotronH_Nano_VL_V2'.\n" + f"Found architectures that don't match this pattern.\n\n" + f"If this is a different model type (e.g., Vision, Sequence-to-Sequence),\n" + f"you may need to use a different bridge class." 
+ ) + + # Check if we have an implementation for this specific architecture + architecture = None + for arch_name in config.architectures: + if arch_name.endswith( + ("ForCausalLM", "ForConditionalGeneration", "NemotronH_Nano_VL_V2") + ): + architecture = arch_name + break + + if architecture: + # Try auto_map first + arch_class = ( + get_causal_lm_class_via_auto_map(model_name_or_path=path, config=config) + if path + else None + ) + if arch_class is not None: + # For auto_map models, use class-name string + arch_key = getattr(arch_class, "__name__", str(arch_class)) + else: + try: + arch_class = getattr(transformers, architecture) + arch_key = arch_class + except AttributeError: + # Fall back to name-based registration + arch_key = architecture + + # Test if we have a registered implementation (type or class-name string) + has_implementation = False + if hasattr(model_bridge.get_model_bridge, "_exact_types"): + registry = model_bridge.get_model_bridge._exact_types + if isinstance(arch_key, str): + has_implementation = arch_key in registry + else: + has_implementation = (arch_key in registry) or ( + getattr(arch_key, "__name__", None) in registry + ) + + if not has_implementation: + # Get list of supported models + supported_models = cls.list_supported_models() + + raise ValueError( + f"\n✗ Model architecture '{architecture}' is not yet supported\n\n" + f"Model: {path}\n" + f"Architecture: {architecture}\n\n" + f"Currently supported architectures:\n" + + "\n".join(f" • {model}" for model in supported_models) + + f"\n\nTo add support for {architecture}, you need to:\n" + f"1. Create a new bridge class that inherits from MegatronModelBridge\n" + f"2. Implement the required methods (provider_bridge, mapping_registry)\n" + f"3. Register it with @MegatronModelBridge.register_bridge decorator\n\n" + f"Example implementation:\n" + f" from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge\n" + f" from transformers import {architecture}\n" + f" from megatron.core.models.gpt import GPTModel\n\n" + f" @MegatronModelBridge.register_bridge(source={architecture}, target=GPTModel)\n" + f" class Megatron{architecture.replace('ForCausalLM', '')}Bridge(MegatronModelBridge):\n" + f" def provider_bridge(self, hf_pretrained):\n" + f" # Return a ModelProvider instance\n" + f" ...\n\n" + f" def mapping_registry(self):\n" + f" # Return a MegatronMappingRegistry with weight mappings\n" + f" ...\n\n" + f"For reference implementations, see:\n" + f" • src/megatron/bridge/models/llama/llama_bridge.py\n" + f" • src/megatron/bridge/models/qwen/qwen_2_causal_bridge.py" + ) from None + + def _get_model_instance(self, model: list[MegatronModelT]) -> MegatronModelT: + model_instance = model[0] + while hasattr(model_instance, "module"): + model_instance = model_instance.module + return model_instance diff --git a/flagscale/train/megatron/nemo_bridge/models/conversion/mapping_registry.py b/flagscale/train/megatron/nemo_bridge/models/conversion/mapping_registry.py new file mode 100644 index 0000000000..58e154eb3c --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/conversion/mapping_registry.py @@ -0,0 +1,266 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import re + +from typing import List, Optional + +from megatron.nemo_bridge.models.conversion.param_mapping import MegatronParamMapping + + +class MegatronMappingRegistry: + """ + Registry for weight mappings between model formats with pattern matching support. 
+ + This class serves as a registry of weight mappings between Megatron and external + (typically HuggingFace) model formats. It provides efficient pattern matching + for parameter names using glob-like wildcards (*) and supports both forward + (Megatron → HF) and reverse (HF → Megatron) lookups. + + The registry pre-compiles regex patterns for efficient repeated lookups and + handles the resolution of wildcards in parameter names. + + Args: + *mappings: Variable number of MegatronParamMapping objects defining + the individual weight mappings + + Example: + >>> # Create a mapping registry with various mappings + >>> mapping_registry = MegatronMappingRegistry( + ... AutoMapping( + ... megatron_param="embedding.word_embeddings.weight", + ... hf_param="model.embed_tokens.weight", + ... ), + ... QKVMapping( + ... megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", + ... q="model.layers.*.self_attn.q_proj.weight", + ... k="model.layers.*.self_attn.k_proj.weight", + ... v="model.layers.*.self_attn.v_proj.weight", + ... ), + ... ) + + >>> # Query for a specific layer (wildcards are resolved) + >>> mapping = mapping_registry.megatron_to_hf_lookup("decoder.layers.0.self_attention.linear_qkv.weight") + >>> print(mapping.hf_param) # Will show resolved HF names for layer 0 + + >>> # Reverse lookup from HF name + >>> mapping = mapping_registry.hf_to_megatron_lookup("model.layers.5.self_attn.q_proj.weight") + >>> print(mapping.megatron_param) # Shows corresponding Megatron name + + >>> # Build from a list + >>> mappings = [bridge1, bridge2, bridge3] + >>> mapping_registry = MegatronMappingRegistry(*mappings) + + Note: + Wildcard patterns support: + - '*' matches any sequence of digits (0-9) - designed for layer indices + - '**' matches any sequence of characters - designed for nested paths + """ + + def _convert_pattern_to_regex(self, pattern: str) -> str: + """Convert a pattern with wildcards to regex pattern. + + Args: + pattern: Pattern string with * and ** wildcards + + Returns: + Regex pattern string + + Note: + ** must be processed before * to avoid conflicts. + ** becomes (.*) - matches any characters including dots + * becomes (\\d+) - matches digits only for layer indices + """ + # Escape the pattern first + regex_pattern = re.escape(pattern) + + # Process ** before * to avoid conflicts + # Replace \*\* with (.*) + regex_pattern = regex_pattern.replace(r"\*\*", r"(.*)") + + # Replace remaining \* with (\d+) + regex_pattern = regex_pattern.replace(r"\*", r"(\d+)") + + return regex_pattern + + def __init__(self, *mappings: MegatronParamMapping): + """ + Initialize MegatronMappingRegistry with weight mappings. 
+ + Args: + *mappings: MegatronParamMapping objects + """ + self.mappings = list(mappings) + + # Pre-compile patterns for efficiency + self._compiled_patterns = [] + self._reverse_patterns = [] # For hf_param -> megatron lookups + + for mapping in mappings: + # Compile source patterns + if "*" in mapping.megatron_param: + # Convert glob pattern to regex with support for * and ** + pattern = self._convert_pattern_to_regex(mapping.megatron_param) + self._compiled_patterns.append((re.compile(f"^{pattern}$"), mapping)) + else: + self._compiled_patterns.append((None, mapping)) + + # Compile destination patterns for reverse lookups + if isinstance(mapping.hf_param, str): + if "*" in mapping.hf_param: + pattern = self._convert_pattern_to_regex(mapping.hf_param) + self._reverse_patterns.append((re.compile(f"^{pattern}$"), mapping)) + else: + self._reverse_patterns.append((None, mapping)) + else: + # For dict destinations, compile patterns for each value + reverse_dict_patterns = {} + for key, hf_pattern in mapping.hf_param.items(): + if "*" in hf_pattern: + pattern = self._convert_pattern_to_regex(hf_pattern) + reverse_dict_patterns[key] = re.compile(f"^{pattern}$") + else: + reverse_dict_patterns[key] = None + self._reverse_patterns.append((reverse_dict_patterns, mapping)) + + def megatron_to_hf_lookup(self, megatron_param_name: str) -> Optional[MegatronParamMapping]: + """ + Get mapping for a Megatron parameter name. + + This method performs efficient lookups by first checking for exact matches, + then falling back to pattern matching using pre-compiled regex patterns. + When a pattern match is found, wildcards are automatically resolved. + + Args: + megatron_param_name: Megatron parameter name to look up + Example: "decoder.layers.0.self_attention.linear_qkv.weight" + + Returns: + MegatronParamMapping: Bridge instance with resolved wildcards, or None + if no matching mapping is found. The returned bridge will have + all wildcards replaced with actual values. + + Example: + >>> # Query with exact layer number + >>> bridge = state_map.megatron_to_hf_lookup("decoder.layers.5.mlp.linear_fc1.weight") + >>> if bridge: + ... print(f"Maps to: {bridge.hf_param}") # Shows HF name for layer 5 + """ + for pattern, mapping in self._compiled_patterns: + if pattern is None: + # Direct match + if mapping.megatron_param == megatron_param_name: + return mapping + else: + # Pattern match + match = pattern.match(megatron_param_name) + if match: + # Return resolved mapping with wildcards replaced + return mapping.resolve(match.groups()) + return None + + def hf_to_megatron_lookup(self, hf_param_name: str) -> Optional[MegatronParamMapping]: + """ + Get mapping for a destination parameter name (reverse lookup). + + This is useful when you have a destination name and want to find + the corresponding megatron name. 
+ + Args: + hf_param_name: Destination parameter name to look up + + Returns: + MegatronParamMapping with resolved wildcards, or None if no match found + """ + for pattern_info, mapping in self._reverse_patterns: + if isinstance(mapping.hf_param, str): + # Simple string destination + pattern = pattern_info + if pattern is None: + # Direct match + if mapping.hf_param == hf_param_name: + return mapping + else: + # Pattern match + match = pattern.match(hf_param_name) + if match: + return mapping.resolve(match.groups()) + else: + # Dict destination - check each pattern + patterns_dict = pattern_info + for key, pattern in patterns_dict.items(): + if pattern is None: + # Direct match + if mapping.hf_param[key] == hf_param_name: + # Create a simplified mapping for this specific key + return mapping.resolve(()) + else: + # Pattern match + match = pattern.match(hf_param_name) + if match: + return mapping.resolve(match.groups()) + return None + + def get_all_mappings(self) -> List[MegatronParamMapping]: + """Get all mappings in this MegatronMappingRegistry.""" + return self.mappings.copy() + + def get_mappings_by_pattern(self, pattern: str) -> List[MegatronParamMapping]: + """ + Get all mappings that match a given pattern. + + Args: + pattern: Pattern to match (supports * and ** wildcards) + + Returns: + List of matching MegatronParamMapping objects + """ + # Convert pattern to regex using the same logic as _convert_pattern_to_regex + # but for this method we want both * and ** to match anything for search purposes + regex_pattern = re.escape(pattern) + regex_pattern = regex_pattern.replace(r"\*\*", r".*") + regex_pattern = regex_pattern.replace(r"\*", r".*") + compiled_pattern = re.compile(f"^{regex_pattern}$") + + matches = [] + for mapping in self.mappings: + if compiled_pattern.match(mapping.megatron_param): + matches.append(mapping) + + return matches + + def __len__(self) -> int: + """Return number of mappings.""" + return len(self.mappings) + + def __iter__(self): + """Iterate over mappings.""" + return iter(self.mappings) + + def __repr__(self) -> str: + """String representation of MegatronMappingRegistry.""" + return f"MegatronMappingRegistry({len(self.mappings)} mappings)" + + def describe(self) -> str: + """ + Get a human-readable description of all mappings. + + Returns: + Formatted string describing all weight mappings + """ + lines = [f"MegatronMappingRegistry with {len(self.mappings)} mappings:"] + for i, mapping in enumerate(self.mappings): + lines.append(f"\n{i + 1}. {mapping.megatron_param}") + if isinstance(mapping.hf_param, str): + lines.append(f" → {mapping.hf_param}") + else: + lines.append(" → {") + for key, value in mapping.hf_param.items(): + lines.append(f" {key}: {value}") + lines.append(" }") + + # Show bridge type + lines.append(f" bridge: {type(mapping).__name__}") + + return "\n".join(lines) diff --git a/flagscale/train/megatron/nemo_bridge/models/conversion/model_bridge.py b/flagscale/train/megatron/nemo_bridge/models/conversion/model_bridge.py new file mode 100644 index 0000000000..6858337ef4 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/conversion/model_bridge.py @@ -0,0 +1,1032 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+# +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import abc +import itertools +import logging +import re + +from dataclasses import dataclass +from typing import ( + Callable, + Generic, + Iterable, + List, + Mapping, + NamedTuple, + Optional, + Type, + TypeVar, + Union, +) + +import torch + +from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn +from transformers.modeling_utils import PreTrainedModel + +from megatron.core import parallel_state +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import get_pg_size, unwrap_model + +from megatron.nemo_bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.nemo_bridge.models.conversion.param_mapping import MegatronParamMapping +from megatron.nemo_bridge.models.conversion.utils import ( + extract_sort_key, + get_module_and_param_from_name, + persistent_buffers, +) +from megatron.nemo_bridge.models.decorators.dispatch import dispatch +from megatron.nemo_bridge.utils.common_utils import print_rank_0 + +logger = logging.getLogger(__name__) + +MappingT = TypeVar("MappingT", bound=MegatronParamMapping) +HFPreTrained = TypeVar("HFPreTrained") +MegatronModel = TypeVar("MegatronModel", bound=MegatronModule) +_BridgeImplClass = TypeVar("_BridgeImplClass", bound="MegatronModelBridge") + + +def padding_embedd_size(mcore_weight: torch.Tensor, hf_vocab_size: int): + hf_size = hf_vocab_size + mcore_size = mcore_weight.shape[0] + full_word = {} + is_rank0 = torch.distributed.get_rank() == 0 + # Cut out extra padding we don't need + if mcore_size > hf_size: + full_word = mcore_weight[0:hf_size, :] + if is_rank0: + print(f"> padding embedding size mcore {mcore_size} to hf {hf_size}") + + # Expanding embedding to larger size by replicating final entry + elif mcore_size < hf_size: + padding_size = hf_size - mcore_size + + full_word = torch.cat( + (mcore_weight, mcore_weight[-1].unsqueeze(0).expand(padding_size, -1)) + ) + if is_rank0: + print(f"> padding embedding size mcore {mcore_size} to hf {hf_size}") + # Same size! + else: + full_word = mcore_weight + return full_word + + +class MegatronWeightTuple(NamedTuple): + """Tuple representing a Megatron model weight with its metadata.""" + + param_name: str + weight: torch.Tensor + vp_stage: int + + +class HFWeightTuple(NamedTuple): + """Tuple representing a HuggingFace model weight with its metadata.""" + + param_name: str + weight: torch.Tensor + + +@dataclass(frozen=True) +class WeightConversionTask(Generic[MappingT]): + """A unified task for converting weights between HuggingFace and Megatron formats. + + This class combines both HF->Megatron and Megatron->HF conversion tasks since they + have different method names (hf_to_megatron vs megatron_to_hf) and can coexist safely. + + The task encapsulates all information needed for weight conversion in either direction, + with different fields being relevant depending on the conversion type. + + Attributes: + param_name (str): *unwrapped, local* parameter name (no ``module.`` prefixes). + mapping (MappingT): Concrete :pyclass:`MegatronParamMapping` instance responsible + for weight transformation and distribution. + + pp_rank (Optional[int]): Pipeline-parallel rank that owns the parameter (required for saves). + vp_stage (Optional[int]): Virtual-pipeline stage index (required for loads). 
+ megatron_module (Optional[torch.nn.Module]): Reference to the Megatron model or + sub-module that owns the parameter (required for loads). + param_weight (Optional[torch.Tensor]): The actual parameter tensor that will + receive the converted weight (required for loads). + + """ + + param_name: str + mapping: MappingT + pp_rank: Optional[int] = None + vp_stage: Optional[int] = None + megatron_module: Optional[torch.nn.Module] = None + param_weight: Optional[torch.Tensor] = None + + +def _megatron_local_name_to_global( + models: MegatronModule | List[MegatronModule], + config: TransformerConfig, + param_name: str, + vp_stage: Optional[int] = None, +) -> str: + """Adjust layer number and expert number from local to global numbering.""" + # PP + pp_group = parallel_state.get_pipeline_model_parallel_group() + if "layers." in param_name and get_pg_size(pp_group) > 1: + match = re.match(r"^(.+?\.layers\.\d+)", param_name) + assert match is not None + layer_prefix = match.group(1) + _, layer_module = get_module_and_param_from_name( + models=models, param_name=layer_prefix, vp_stage=vp_stage + ) + + local_layer_number = int(param_name.split("layers.")[1].split(".")[0]) + global_layer_number = layer_module.layer_number - 1 + param_name = param_name.replace( + f"layers.{local_layer_number}.", f"layers.{global_layer_number}." + ) + + # EP + ep_group = parallel_state.get_expert_model_parallel_group() + if ".mlp.experts.linear_fc" in param_name and get_pg_size(ep_group) > 1: + num_experts = config.num_moe_experts + num_experts_per_rank = num_experts // ep_group.size() + + def _update_expert_number(param_name: str, param_type: str) -> str: + """Update expert number from local to global for weight or bias parameters.""" + local_expert_number = int(param_name.split(f".{param_type}")[-1]) + global_expert_number = num_experts_per_rank * ep_group.rank() + local_expert_number + return param_name.replace( + f".{param_type}{local_expert_number}", f".{param_type}{global_expert_number}" + ) + + # Handle weight and bias parameters + if ".weight" in param_name: + param_name = _update_expert_number(param_name, "weight") + elif ".bias" in param_name: + param_name = _update_expert_number(param_name, "bias") + return param_name + + +# class MegatronModelBridge(Generic[HFPreTrained, ModelProviderTarget, MegatronModel]): +class MegatronModelBridge(Generic[HFPreTrained, MegatronModel]): + """ + High-level orchestrator for HuggingFace ↔ Megatron model conversions. + + This abstract base class provides the framework for converting models between + HuggingFace and Megatron formats. It acts as an orchestrator that coordinates + the conversion process without directly handling the complex details of + tensor parallelism or weight transformations. + + The bridge pattern separates concerns: + - MegatronModelBridge: Orchestrates the overall conversion process + - MegatronMappingRegistry: Manages parameter name mappings + - MegatronParamMapping: Handles actual weight transformations and distribution + + Key responsibilities: + 1. Build conversion tasks that map each parameter to its appropriate bridge + 2. Execute tasks with proper error handling and progress tracking + 3. Provide utilities for configuration translation + 4. Handle virtual pipeline parallelism (VP) complexities + + To implement a bridge for a new model architecture: + + 1. Create a subclass decorated with @MegatronModelBridge.register_bridge: + + .. 
code-block:: python + + @MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel) + class MegatronCausalLlamaBridge(MegatronModelBridge): + pass + + 2. Implement provider_bridge to create Megatron configurations: + + .. code-block:: python + + def provider_bridge(self, hf_pretrained) -> LlamaModelProvider: + return LlamaModelProvider( + num_layers=hf_pretrained.config.num_hidden_layers, + hidden_size=hf_pretrained.config.hidden_size, + ... + ) + + 3. Implement mapping_registry to define weight mappings: + + .. code-block:: python + + def mapping_registry(self) -> MegatronMappingRegistry: + return MegatronMappingRegistry( + AutoMapping( + megatron_param="embedding.word_embeddings.weight", + hf_param="model.embed_tokens.weight" + ), + ... + ) + + Example: + .. code-block:: python + + # The bridge is typically not instantiated directly + # Instead, use AutoBridge or AutoBridge which handle this + bridge = AutoBridge.from_hf_pretrained("meta-llama/Meta-Llama-3-8B") + provider = bridge.to_megatron_provider() + + Note: + This class uses generic type parameters to ensure type safety: + - HFPreTrained: The HuggingFace model type + - ModelProviderTarget: The Megatron model provider type + - MegatronModel: The Megatron model type + """ + + @abc.abstractmethod + def mapping_registry(self) -> MegatronMappingRegistry: + """Define weight mappings between HuggingFace and Megatron formats. + + This abstract method must be implemented by subclasses to specify how + parameters map between the two formats. The returned MegatronMappingRegistry + contains all param mappings needed for the model architecture. + + Returns: + MegatronMappingRegistry: MegatronMappingRegistry containing all weight + mapping definitions. + + Example: + .. code-block:: python + + def mapping_registry(self): + return MegatronMappingRegistry( + AutoMapping( + megatron_param="embedding.word_embeddings.weight", + hf_param="model.embed_tokens.weight" + ), + QKVMapping( + megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", + q="model.layers.*.self_attn.q_proj.weight", + k="model.layers.*.self_attn.k_proj.weight", + v="model.layers.*.self_attn.v_proj.weight" + ), + # ... 
more param mappings + ) + """ + raise NotImplementedError("Subclass must implement mapping_registry method") + + def _megatron_global_param_names_all_pp_ranks( + self, megatron_model: Union[MegatronModel, List[MegatronModel]] + ) -> List[str]: + """Get all parameter names across all pipeline parallel ranks.""" + # Cache the result after first call + if hasattr(self, "_cached_param_names"): + return self._cached_param_names + + # Compute the result + pp_group = parallel_state.get_pipeline_model_parallel_group() + model_config = unwrap_model(megatron_model)[0].config + global_param_names = [] + + # Ensure megatron_model is a list for consistent handling + models_list = megatron_model if isinstance(megatron_model, list) else [megatron_model] + + for vp_stage, model in enumerate(models_list): + # persistent buffers are part of the model's state_dict, but not the named_parameters, so we must include them here separately + for local_param_name, _ in itertools.chain( + model.named_parameters(), persistent_buffers(model) + ): + if "_extra_state" in local_param_name: + continue + local_param_name = self._unwrap_name(local_param_name) + global_param_name = _megatron_local_name_to_global( + models_list, model_config, local_param_name, vp_stage + ) + global_param_names.append(global_param_name) + + gathered_global_param_names = [None] * pp_group.size() + torch.distributed.all_gather_object( + gathered_global_param_names, global_param_names, group=pp_group + ) + + # flatten the list, sort it and remove duplicates + # the order matters here, casually re-order will cause a hang. + # e.g. decoder.layers.0.mlp.experts.linear_fc1.weight100 + flattened_names = list(set(sum(gathered_global_param_names, []))) + + # the order cannot be changed, this sync for all ranks for conversion + # change this might cause a hang + gathered_global_param_names = sorted(flattened_names, key=extract_sort_key) + + # Cache the result + self._cached_param_names = gathered_global_param_names + + return self._cached_param_names + + def _with_progress_tracking(self, tasks, description: str, show_progress: bool = True): + """Helper method to wrap an iterable with progress tracking. + + Args: + tasks: Iterable of tasks to process + description: Description for the progress bar + show_progress: Whether to show progress (defaults to True) + + Yields: + Items from the tasks iterable while updating progress + """ + is_main_rank = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + bridge_name = self.__class__.__name__ + + if show_progress: + with Progress( + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeRemainingColumn(), + TextColumn("({task.completed}/{task.total})"), + TextColumn("{task.fields[bridge]}"), + disable=not is_main_rank, + ) as progress: + task_id = progress.add_task(description, total=len(tasks), bridge=bridge_name) + + for task in tasks: + yield task + progress.update(task_id, advance=1) + else: + # not using disable above because we notice it will dump some empty progress bar, + # even when disable is set to True + for task in tasks: + yield task + + def load_weights_hf_to_megatron( + self, hf_pretrained: HFPreTrained, megatron_model: Union[MegatronModel, List[MegatronModel]] + ) -> List[MegatronModel]: + """Load HuggingFace weights into Megatron models. + + This method orchestrates the complete weight loading process from HuggingFace + format to Megatron's distributed format. 
It builds a conversion task and + executes it with proper progress tracking and error handling. + + The actual weight transformations and distribution are delegated to the + appropriate MegatronParamMapping instances based on the state mappings. + + Args: + hf_pretrained (HFPreTrained): HuggingFace model or state source containing the + weights to load. + megatron_model (Union[MegatronModel, List[MegatronModel]]): Megatron model instance + or list of model instances (one per virtual pipeline stage). + + Returns: + List[MegatronModel]: The input megatron_model as a list with loaded weights. + + Process: + 1. Build a task mapping each Megatron parameter to its source + 2. For each parameter in the task: + - Fetch source weights from HuggingFace state + - Apply format transformation via the param mapping + - Distribute to appropriate TP/PP ranks + - Copy into the Megatron parameter + + Example: + .. code-block:: python + + hf_model = PreTrainedCausalLM.from_pretrained("gpt2") + megatron_model = create_megatron_model() # Single model or list + bridge.load_weights_hf_to_megatron(hf_model, megatron_model) + + Note: + Progress is shown only on rank 0 to avoid cluttered output in + distributed environments. + + Raises: + ValueError: If hf_pretrained doesn't have state attribute or if weight shapes don't match. + AttributeError: If required HF weights are missing. + """ + if not isinstance(megatron_model, list): + megatron_model = [megatron_model] + + hf_to_megatron_tasks = self.build_conversion_tasks(hf_pretrained, megatron_model) + hf_state_dict: Mapping[str, torch.Tensor] = ( + hf_pretrained.state if hasattr(hf_pretrained, "state") else {} + ) + + description = f"Loading from {hf_pretrained.model_name_or_path}" + for task in self._with_progress_tracking(hf_to_megatron_tasks, description): + # None means megatron module not on current rank, skip if this task is not going to happen + if task.megatron_module is None: + continue + # 1) Fetch source tensor(s) from HF state dict + if isinstance(task.mapping.hf_param, str): + hf_weights = hf_state_dict[task.mapping.hf_param] + else: + hf_weights = {k: hf_state_dict[v] for k, v in task.mapping.hf_param.items()} + + # 2) Delegate conversion & distribution to the bridge + converted_weights = task.mapping.hf_to_megatron(hf_weights, task.megatron_module) + + # 3) Copy into Megatron param if this rank received a shard + if converted_weights is not None: + # Assert that param_weight is not None for HF->Megatron tasks + assert ( + task.param_weight is not None + ), "param_weight is required for HF->Megatron conversion" + + # Check shape compatibility before copying + if converted_weights.shape != task.param_weight.shape: + raise ValueError( + f"Shape mismatch for megatron param {task.mapping.megatron_param}:\n" + f" Expected shape: {task.param_weight.shape}\n" + f" Got shape: {converted_weights.shape}\n" + f" Bridge type: {type(task.mapping).__name__}\n" + f" HF mapping: {task.mapping.hf_param}" + ) + task.param_weight.data.copy_(converted_weights) + + self._broadcast_shared_embeddings(megatron_model) + return megatron_model + + def stream_weights_hf_to_megatron( + self, + hf_pretrained: HFPreTrained, + megatron_model: Union[MegatronModel, List[MegatronModel]], + conversion_tasks: Optional[List[WeightConversionTask]] = None, + ) -> Iterable[MegatronWeightTuple]: + """Generator variant of load_weights_hf_to_megatron for streaming weight conversion. 
+ + This method provides a memory-efficient way to convert weights by yielding + them one at a time instead of loading all at once. Useful for processing + very large models or when implementing custom weight handling logic. + + Args: + hf_pretrained (HFPreTrained): HuggingFace model or state source containing + the weights. + megatron_model (Union[MegatronModel, List[MegatronModel]]): Megatron model instance + or list of model instances to extract configuration from. + conversion_tasks (Optional[List[WeightConversionTask]]): Pre-built conversion tasks. + If not provided, tasks will be built automatically from the models. + + Yields: + MegatronWeightTuple: Named tuples containing: + - vp_stage: Index of the model in megatron_model list + - param_name: Name of the parameter + - weight: Transformed weight tensor for this rank + + Example: + .. code-block:: python + + # Process weights one by one + for weight_tuple in bridge.stream_weights_hf_to_megatron(hf_model, megatron_model): + print(f"Processing {weight_tuple.param_name}: {weight_tuple.weight.shape}") + # Custom processing logic here + + # Or use pre-built conversion tasks + tasks = bridge.build_conversion_tasks(hf_model, megatron_model) + for weight_tuple in bridge.stream_weights_hf_to_megatron(hf_model, megatron_model, tasks): + print(f"Processing {weight_tuple.param_name}: {weight_tuple.weight.shape}") + + Note: + Only yields weights that belong to the current rank after TP/PP distribution. + + Raises: + ValueError: If input parameters are invalid. + """ + + if not isinstance(megatron_model, list): + megatron_model = [megatron_model] + + # Use provided conversion tasks or build them + if conversion_tasks is None: + conversion_tasks = self.build_conversion_tasks(hf_pretrained, megatron_model) + + for task in conversion_tasks: + # None means megatron module not on current rank, skip if this task is not going to happen + if task.megatron_module is None: + continue + hf_state_dict: Mapping[str, torch.Tensor] = hf_pretrained.state + if isinstance(task.mapping.hf_param, str): + hf_weights = hf_state_dict[task.mapping.hf_param] + else: + hf_weights = {k: hf_state_dict[v] for k, v in task.mapping.hf_param.items()} + + converted_weights = task.mapping.hf_to_megatron(hf_weights, task.megatron_module) + if converted_weights is not None: + # Assert that vp_stage is not None for HF->Megatron tasks + yield MegatronWeightTuple(task.param_name, converted_weights, task.vp_stage) + + def stream_weights_megatron_to_hf( + self, + megatron_model: Union[MegatronModel, List[MegatronModel]], + hf_pretrained: HFPreTrained, + cpu: bool = True, + show_progress: bool = True, + conversion_tasks: Optional[List[WeightConversionTask]] = None, + ) -> Iterable[HFWeightTuple]: + """Export Megatron weights to HuggingFace format. + + This method orchestrates the conversion of weights from Megatron's distributed + format back to HuggingFace format. It handles gathering from tensor parallel + ranks, broadcasting across pipeline parallel ranks, and format conversions. + All ranks receive the full tensors. + + The export order is determined automatically: + - First tries safetensors order (if key_to_filename_map is available) + - Falls back to HuggingFace state dict order + + Args: + megatron_model (Union[MegatronModel, List[MegatronModel]]): Megatron model instance + or list of model instances (one per virtual pipeline stage). + hf_pretrained (HFPreTrained): HuggingFace model/config for metadata + and mapping info. 
+ cpu (bool, optional): Whether to move tensors to CPU before yielding. + Defaults to True. + show_progress (bool, optional): Display progress bar during export. + Defaults to True. + conversion_tasks (Optional[List[WeightConversionTask]]): Pre-built conversion tasks. + If not provided, tasks will be built automatically from the models. + + Yields: + HFWeightTuple: Named tuples of (param_name, weight_tensor) in HF format. + + Example: + .. code-block:: python + + # Export weights + for name, weight in bridge.stream_weights_megatron_to_hf(megatron_model, hf_config): + print(f"Exported {name}: {weight.shape}") + + # Or use pre-built conversion tasks + tasks = bridge.build_conversion_tasks(hf_config, megatron_model) + for name, weight in bridge.stream_weights_megatron_to_hf( + megatron_model, hf_config, conversion_tasks=tasks + ): + print(f"Exported {name}: {weight.shape}") + + Raises: + ValueError: If input parameters are invalid. + + Note: + All ranks yield the full tensors after gathering from distributed format. + """ + + if not isinstance(megatron_model, list): + megatron_model = [megatron_model] + # Use provided conversion tasks or build them + if conversion_tasks is None: + conversion_tasks = self.build_conversion_tasks(hf_pretrained, megatron_model) + + megatron_to_hf_tasks = conversion_tasks + model_config = unwrap_model(megatron_model)[0].config + # embeddings_are_tied = model_config.share_embeddings_and_output_weights + embeddings_are_tied = not model_config.untie_embeddings_and_output_weights + for task in self._with_progress_tracking( + megatron_to_hf_tasks, "Converting to HuggingFace", show_progress + ): + converted_weights_dict = task.mapping.megatron_to_hf( + task.param_weight, task.megatron_module + ) + + # All ranks get the full tensor + for hf_name, tensor in converted_weights_dict.items(): + # Honor the `cpu` flag instead of unconditionally moving to CPU + final_tensor = tensor.cpu() if cpu else tensor + + if hf_name == "model.embed_tokens.weight" or hf_name == "lm_head.weight": + final_tensor = padding_embedd_size( + final_tensor, hf_pretrained.config.vocab_size + ) + + # Handle tied embeddings case + # TODO(yuya): fix this hard coded naming + if embeddings_are_tied and hf_name == "model.embed_tokens.weight": + # Yield the embedding weight + yield HFWeightTuple(hf_name, final_tensor) + + # Also yield as lm_head.weight if it's expected + if hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source"): + expected_keys = hf_pretrained.state.source.get_all_keys() + if "lm_head.weight" in expected_keys: + final_tensor = final_tensor.detach().clone() + yield HFWeightTuple("lm_head.weight", final_tensor) + elif embeddings_are_tied and hf_name == "lm_head.weight": + # This should not happen when embeddings are tied; raise to surface the mapping error + raise ValueError( + "Encountered lm_head.weight when embeddings are tied. This indicates a mapping error." + ) + else: + # Regular case - yield the tensor normally + yield HFWeightTuple(hf_name, final_tensor) + + def dtype_from_hf(self, config, default=None): + """Extract torch dtype from a HuggingFace config. + + This utility method handles the conversion of dtype specifications in + HuggingFace configs to PyTorch dtype objects. Supports both direct + torch.dtype objects and string representations. + + Args: + config: HuggingFace configuration object with a torch_dtype attribute. + default (Any, optional): Default value to return if torch_dtype is + not str or torch.dtype. Defaults to None. + + Returns: + torch.dtype: The corresponding PyTorch dtype.
+ + Raises: + AssertionError: If config doesn't have torch_dtype attribute. + ValueError: If torch_dtype is neither a string nor torch.dtype. + + Example: + .. code-block:: python + + dtype = bridge.dtype_from_hf(hf_config) + print(dtype) # torch.float16 + """ + assert hasattr(config, "torch_dtype"), "Expected config to have attr `torch_dtype`" + torch_dtype = config.torch_dtype + if isinstance(torch_dtype, torch.dtype): + return torch_dtype + elif isinstance(torch_dtype, str): + return self.dtype_from_str(torch_dtype) + elif default is not None: + return default + + raise ValueError("torch_dtype is not of type str/torch.dtype") + + def dtype_from_str(self, dtype: str) -> torch.dtype: + """Convert a string precision identifier to equivalent torch dtype. + + This utility method handles various string representations of PyTorch + data types, including common abbreviations and mixed precision formats. + + Args: + dtype (str): String representation of dtype (e.g., "float16", "fp16", + "bf16-mixed"). + + Returns: + torch.dtype: Corresponding PyTorch dtype (defaults to float32 if unknown). + + Supported formats: + - float16/fp16/16/16-mixed → torch.float16 + - bfloat16/bf16-mixed → torch.bfloat16 + - Others → torch.float32 (default) + + Example: + .. code-block:: python + + dtype = bridge.dtype_from_str("fp16") + print(dtype) # torch.float16 + + dtype = bridge.dtype_from_str("bf16-mixed") + print(dtype) # torch.bfloat16 + """ + assert isinstance(dtype, str) + if dtype in ["float16", "fp16", "16", "16-mixed"]: + return torch.float16 + elif dtype in ["bfloat16", "bf16-mixed"]: + return torch.bfloat16 + else: + return torch.float32 + + def make_vocab_size_divisible_by(self, vocab_size: int) -> int: + """Calculate an appropriate divisor for vocabulary size padding. + + Megatron requires vocabulary sizes to be divisible by certain values for + efficient tensor parallelism. This method finds the largest power of 2 + (up to 128) that evenly divides the vocabulary size. + + Args: + vocab_size (int): Original vocabulary size from the model. + + Returns: + int: Largest power of 2 (≤ 128) that divides vocab_size. + + Example: + .. code-block:: python + + # For vocab_size=50257 (GPT-2) + divisor = bridge.make_vocab_size_divisible_by(50257) + print(divisor) # 1 (50257 is prime) + + # For vocab_size=32000 (Llama) + divisor = bridge.make_vocab_size_divisible_by(32000) + print(divisor) # 128 + + Note: + The returned value is used by Megatron to potentially pad the + vocabulary to ensure efficient parallelization. + """ + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + + # def _get_provider_from_model(self, model: MegatronModule) -> ModelProviderTarget: + # """Extract provider/config from model.""" + # model = unwrap_model(model) + # return model.config + + def _unwrap_name(self, name: str) -> str: + """Unwrap name from DDP or other wrappers. + + Args: + name: Parameter name that may have 'module.' prefixes + + Returns: + Unwrapped parameter name with 'module.' prefixes removed + + Example: + 'module.module.decoder.weight' -> 'decoder.weight' + """ + if not isinstance(name, str): + raise ValueError(f"name must be a string, got {type(name)}") + + while name.startswith("module."): + name = name[len("module.") :] + return name + + def _broadcast_shared_embeddings( + self, megatron_model: Union[MegatronModel, List[MegatronModel]] + ) -> None: + """Broadcast shared embeddings and output weights across embedding group. 
+ + When embeddings and output weights are shared and pipeline parallelism is enabled, + this method ensures all ranks in the embedding group have the same weights by + broadcasting from rank 0. + + Args: + megatron_model: Megatron model instance or list of model instances. + """ + unwrapped_model = unwrap_model(megatron_model)[0] + # Hack for VLMs: operate on the wrapped language model + if ( + hasattr(unwrapped_model, "language_model") + and unwrapped_model.language_model is not None + ): + unwrapped_model = unwrapped_model.language_model + model_config = unwrapped_model.config + if ( + not model_config.untie_embeddings_and_output_weights + and model_config.pipeline_model_parallel_size > 1 + ): + # Broadcast embeddings and output weights from rank 0 to embedding group + embd_group = parallel_state.get_embedding_group() + # Guard against a missing embedding group before querying its ranks + embd_group_ranks = torch.distributed.get_process_group_ranks(embd_group) if embd_group is not None else [] + if embd_group is not None and torch.distributed.get_rank() in embd_group_ranks: + # Get embeddings and output weights from rank 0 + if hasattr(unwrapped_model, "embedding") and hasattr( + unwrapped_model.embedding, "word_embeddings" + ): + embd_weights = unwrapped_model.embedding.word_embeddings.weight.data + else: + assert hasattr(unwrapped_model, "output_layer"), "Output layer not found" + embd_weights = torch.empty_like(unwrapped_model.output_layer.weight.data) + torch.distributed.broadcast(embd_weights, src=embd_group_ranks[0], group=embd_group) + if hasattr(unwrapped_model, "output_layer"): + unwrapped_model.output_layer.weight.data.copy_(embd_weights) + + def build_conversion_tasks( + self, hf_pretrained: HFPreTrained, megatron_model: List[MegatronModel] + ) -> List[None | WeightConversionTask]: + """Construct the conversion tasks between HF and Megatron. + + The algorithm walks over every parameter of every destination model, + asks the :class:`MegatronMappingRegistry` whether it has a mapping for that + parameter, and – if the corresponding HF weights actually exist – builds + a :class:`WeightConversionTask` describing exactly how that parameter will be + populated.
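+
+        The sketch below shows one way the returned task list is typically
+        consumed; ``bridge``, ``hf_pretrained`` and ``megatron_model`` are
+        assumed to already exist and are illustrative only:
+
+        .. code-block:: python
+
+            tasks = bridge.build_conversion_tasks(hf_pretrained, megatron_model)
+            for task in tasks:
+                # Placeholder tasks for parameters owned by other PP ranks carry
+                # no local module; skip them here.
+                if task is None or task.megatron_module is None:
+                    continue
+                print(task.param_name, type(task.mapping).__name__)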
+ """ + + # Ensure hf_pretrained has the required state structure + if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")): + raise ValueError("hf_pretrained.state.source is required for weight ordering") + + hf_keys: Iterable[str] = hf_pretrained.state.source.get_all_keys() + mapping_registry = self.mapping_registry() + model_config = unwrap_model(megatron_model)[0].config + # embeddings_are_tied = model_config.share_embeddings_and_output_weights + embeddings_are_tied = not model_config.untie_embeddings_and_output_weights + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks( + megatron_model + ) + + # Filter out output_layer related parameters if embeddings are tied + if embeddings_are_tied: + sorted_global_param_names_all_pp_ranks = [ + name + for name in sorted_global_param_names_all_pp_ranks + if "output_layer" not in name + ] + + global_names_index_dict = { + name: idx for idx, name in enumerate(sorted_global_param_names_all_pp_ranks) + } + + tasks = [None] * len(sorted_global_param_names_all_pp_ranks) + for vp_stage, model in enumerate(megatron_model): + # persistent buffers are part of the model's state_dict, but not the named_parameters, so we must include them here separately + for local_name, _ in itertools.chain( + model.named_parameters(), persistent_buffers(model) + ): + if "_extra_state" in local_name: + continue + + local_name = self._unwrap_name(local_name) + global_name = _megatron_local_name_to_global( + megatron_model, model_config, local_name, vp_stage + ) + # if name removed due to some reason, continue. e.g. embeddings_are_tied + if global_name not in global_names_index_dict: + print_rank_0(f"WARNING: {global_name} not in global_names_index_dict") + continue + global_name_idx = global_names_index_dict[global_name] + mapping = mapping_registry.megatron_to_hf_lookup(global_name) + if not mapping: + logger.warning(f"WARNING: No mapping found for megatron_param: {global_name}") + continue + # ensure hf weights exist + if isinstance(mapping.hf_param, str): + if mapping.hf_param not in hf_keys: + prefix = '.'.join(mapping.hf_param.split('.')[:-2]) + if not (('q_proj.weight' in mapping.hf_param) and ( + f'{prefix}.q_a_layernorm.weight' in hf_keys + and f'{prefix}.q_a_proj.weight' in hf_keys + and f'{prefix}.q_b_proj.weight' in hf_keys + )): + logger.warning(f"WARNING: Can't find {mapping.hf_param} in hf_keys") + continue + else: + missing_params = [ + hf_param + for hf_param in mapping.hf_param.values() + if hf_param not in hf_keys + ] + if missing_params: + logger.warning( + f"WARNING: Can't find the following HF parameters in hf_keys: {missing_params}" + ) + continue + + local_module, local_weights = get_module_and_param_from_name( + megatron_model, local_name, vp_stage + ) + tasks[global_name_idx] = WeightConversionTask( + pp_rank=pp_rank, + vp_stage=vp_stage, + param_name=local_name, + megatron_module=local_module, + param_weight=local_weights, + mapping=mapping, + ) + + # Fill the remaining ones for pp communications + for idx, global_name in enumerate(sorted_global_param_names_all_pp_ranks): + mapping = mapping_registry.megatron_to_hf_lookup(global_name) + if tasks[idx] is None: + # This is an exception here we pass in global name + # we are not using global_name to extract module and weights + # only use it for param mapping auto dispatch checks + tasks[idx] = WeightConversionTask( + pp_rank=pp_rank, + vp_stage=None, + param_name=global_name, 
+ megatron_module=None, + param_weight=None, + mapping=mapping, + ) + + return tasks + + @classmethod + def register_bridge( + cls, *, source: Type[PreTrainedModel] | str, target: Type[MegatronModel] + ) -> Callable[[_BridgeImplClass], _BridgeImplClass]: + """Class decorator for registering bridge implementations. + + This decorator registers a MegatronModelBridge subclass with the dispatch + system, enabling automatic routing of conversions based on the source + HuggingFace model type and target Megatron model type. + + Args: + source (Type[PreTrainedModel] | str): HuggingFace PreTrainedModel class + (e.g., LlamaForCausalLM) or the class name as a string. Using a + string allows registering bridges for architectures that are only + available via auto_map. + target (Type[MegatronModel]): Megatron model class (e.g., GPTModel). + + Returns: + Callable[[_BridgeImplClass], _BridgeImplClass]: Decorator function + that registers the bridge implementation. + + Example: + .. code-block:: python + + @MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel) + class MegatronCausalLlamaBridge(MegatronModelBridge): + def provider_bridge(self, hf_pretrained): + # Implementation + pass + + def mapping_registry(self): + # Implementation + pass + + String-based registration is also supported: + + .. code-block:: python + + @MegatronModelBridge.register_bridge(source="DeepseekV3ForCausalLM", target=GPTModel) + class MegatronDeepseekV3Bridge(MegatronModelBridge): + ... + + Note: + The decorated class is registered with multiple dispatchers to handle + different conversion scenarios. The registration is automatic when the + class is defined. + """ + + return create_bridge_decorator(source=source, target=target) + + +def is_tensor_parallel(param) -> bool: + """Check if a parameter is tensor parallel distributed.""" + return hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel + + +# Core dispatch functions +@dispatch +def get_model_bridge(hf_architecture) -> "MegatronModelBridge": + """Get the appropriate model bridge for a given HuggingFace architecture.""" + ... + + +@dispatch +def stream_weights_megatron_to_hf( + dispatch_instance: MegatronModel, + megatron_model: Union[MegatronModel, List[MegatronModel]], + hf_pretrained: HFPreTrained, + cpu: bool = True, + show_progress: bool = True, + conversion_tasks: Optional[List[WeightConversionTask]] = None, +) -> Iterable[HFWeightTuple]: + """Bridge Megatron model state to HuggingFace format.""" + ... + + +def register_bridge_implementation( + *, + source: Type["PreTrainedModel"] | str, + target: Type["MegatronModule"], + bridge_class: Type["MegatronModelBridge"], +) -> None: + """Register a bridge implementation with the dispatch system. + + Args: + source: HuggingFace PreTrainedModel class or the class name as a string. + Using a string allows registering bridges for architectures that are + available only via auto_map. 
+ target: Megatron model class (e.g., GPTModel) + bridge_class: MegatronModelBridge implementation class + """ + bridge_class_name = bridge_class.__name__ + + @get_model_bridge.impl(source) + def _get_model_bridge_impl(_) -> "MegatronModelBridge": + bridge = bridge_class() + return bridge + + @stream_weights_megatron_to_hf.impl((source, target)) + def _megatron_to_hf_registered_impl( + _, + megatron_model: Union[MegatronModel, List[MegatronModel]], + hf_pretrained: HFPreTrained, + cpu: bool = True, + show_progress: bool = True, + conversion_tasks: Optional[List[WeightConversionTask]] = None, + ) -> Iterable[HFWeightTuple]: + bridge = bridge_class() + return bridge.stream_weights_megatron_to_hf( + megatron_model, + hf_pretrained, + cpu=cpu, + show_progress=show_progress, + conversion_tasks=conversion_tasks, + ) + + # Set meaningful names for debugging + _get_model_bridge_impl.__name__ = f"_bridge_with_{bridge_class_name}" + _megatron_to_hf_registered_impl.__name__ = f"_megatron_to_hf_with_{bridge_class_name}" + + +def create_bridge_decorator( + *, source: Type["PreTrainedModel"] | str, target: Type["MegatronModule"] +) -> Callable[[Type["MegatronModelBridge"]], Type["MegatronModelBridge"]]: + """Create a decorator for registering bridge implementations. + + Args: + source: HuggingFace PreTrainedModel class or the class name as a string + (useful for auto_map architectures) + target: Megatron model class + + Returns: + Decorator function that registers the bridge implementation + """ + + def decorator(bridge_class: Type["MegatronModelBridge"]) -> Type["MegatronModelBridge"]: + register_bridge_implementation(source=source, target=target, bridge_class=bridge_class) + return bridge_class + + return decorator diff --git a/flagscale/train/megatron/nemo_bridge/models/conversion/param_mapping.py b/flagscale/train/megatron/nemo_bridge/models/conversion/param_mapping.py new file mode 100644 index 0000000000..b5bb8f2a2b --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/conversion/param_mapping.py @@ -0,0 +1,1785 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+# +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import json +import re + +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union + +import torch +import torch.distributed +import torch.nn as nn + +from megatron.core import mpu +from megatron.core.fp8_utils import FP8_TENSOR_CLASS, HAVE_TE_FP8_TENSOR_CLASS +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import get_pg_rank, get_pg_size + +from megatron.nemo_bridge.models.conversion.utils import ( + get_module_and_param_from_name, + remove_non_pickleables, +) + +WeightType = TypeVar("WeightType", torch.Tensor, Dict[str, torch.Tensor]) + +import logging + +logger = logging.getLogger(__name__) + + +def col_padding_size(hf_weight: torch.Tensor, mcore_weight: torch.Tensor, tp_size: int): + hf_size = hf_weight.shape[0] + mcore_size = mcore_weight.shape[0] * tp_size + full_word = {} + is_rank0 = torch.distributed.get_rank() == 0 + # Cut out extra padding we don't need + if hf_size > mcore_size: + full_word = hf_weight[0:mcore_size, :] + if is_rank0: + print(f"> padding TP-ColumnParallelfrom {hf_size} to {mcore_size}") + + # Expanding embedding to larger size by replicating final entry + elif hf_size < mcore_size: + padding_size = mcore_size - hf_size + + full_word = torch.cat((hf_weight, hf_weight[-1].unsqueeze(0).expand(padding_size, -1))) + if is_rank0: + print(f"> padding TP-ColumnParallelfrom {hf_size} to {mcore_size}") + # Same size! + else: + full_word = hf_weight + return full_word + + +class MegatronParamMapping(ABC, Generic[WeightType]): + """ + Abstract base class for weight conversion between Megatron and external formats. + + This class provides the foundation for all weight mappings, handling the complex + conversions between Megatron-Core's distributed tensor formats and standard + (typically HuggingFace) formats. Each concrete mapping implements specific + transformation logic while inheriting common parallel communication patterns. + + Key responsibilities: + - Format transformation (e.g., QKV merging/splitting, gated MLP handling) + - Tensor parallel (TP) distribution and gathering across GPUs + - Pipeline parallel (PP) broadcasting between pipeline stages + - Wildcard pattern resolution for layer-wise mappings + + The mapping abstraction ensures that higher-level code doesn't need to know + about the parallel topology or format differences - it just requests a + conversion and the mapping handles all the complexity. + + Public helper methods for subclasses: + - broadcast_from_pp_rank: Broadcast tensors across pipeline stages + - broadcast_obj_from_pp_rank: Broadcast Python objects across PP ranks + - broadcast_tensor_to_tp_ranks: Broadcast within TP group + - scatter_to_tp_ranks: Distribute tensor shards to TP ranks + - gather_from_tp_ranks: Collect tensor shards from TP ranks + + Example: + .. code-block:: python + + class MyCustomMapping(MegatronParamMapping[torch.Tensor]): + def hf_to_megatron(self, hf_weights, megatron_module): + # Custom transformation logic + transformed = hf_weights.t() # Example: transpose + # Use helpers for distribution + return self.scatter_to_tp_ranks(...) 
+ + def megatron_to_hf(self, megatron_weights, megatron_module): + # Broadcast from owning PP rank + weight = self.broadcast_from_pp_rank(megatron_weights) + # Gather from TP ranks and transform + gathered = self.gather_from_tp_ranks(weight) + return {"custom_weight": gathered[0].t()} + """ + + def __init__(self, megatron_param: str, hf_param: Union[str, Dict[str, str]]): + """Initialize the weight mapping. + + Args: + megatron_param (str): Megatron parameter name pattern (supports * + wildcards). + hf_param (Union[str, Dict[str, str]]): External format name pattern(s). + """ + self.megatron_param = megatron_param + self.hf_param = hf_param + self._validate_patterns() + + # Cache for metadata and tensor_spec_output + self._broadcast_obj_cache = {} + self._tensor_spec_output_cache = {} + + if mpu.is_initialized(): + self.pp_group = mpu.get_pipeline_model_parallel_group() + self.ep_group = mpu.get_expert_model_parallel_group() + self._tp_group = mpu.get_tensor_model_parallel_group() + self._etp_group = mpu.get_expert_tensor_parallel_group() + else: + self.pp_group = None + self.ep_group = None + self._tp_group = None + self._etp_group = None + + @property + def tp_group(self): + """Get the tensor model parallel group.""" + if self.is_expert: + return self._etp_group + return self._tp_group + + @property + def tp_rank(self) -> int: + """Get the tensor model parallel rank.""" + return get_pg_rank(self.tp_group) + + @property + def tp_size(self) -> int: + """Get the tensor model parallel size.""" + return get_pg_size(self.tp_group) + + @property + def pp_rank(self) -> int: + """Get the pipeline model parallel rank.""" + return get_pg_rank(self.pp_group) + + @property + def pp_size(self) -> int: + """Get the pipeline model parallel size.""" + return get_pg_size(self.pp_group) + + @property + def ep_rank(self) -> int: + """Get the expert model parallel rank.""" + return get_pg_rank(self.ep_group) + + @property + def ep_size(self) -> int: + """Get the expert model parallel size.""" + return get_pg_size(self.ep_group) + + @property + def etp_rank(self) -> int: + """Get the expert tensor parallel rank.""" + return get_pg_rank(self.etp_group) + + @property + def etp_size(self) -> int: + """Get the expert tensor parallel size.""" + return get_pg_size(self.etp_group) + + @property + def is_expert(self) -> bool: + """Check if this mapping is for an expert parameter.""" + return ".mlp.experts.linear_fc" in self.megatron_param + + def _resolve_names(self, captures: Tuple[str, ...]) -> Tuple[str, Union[str, Dict[str, str]]]: + """Resolve wildcard patterns with captured values. + + Handles both ** (any characters) and * (digits) wildcards in order. + ** patterns are processed before * patterns to avoid conflicts. 
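+
+        Example (illustrative; the parameter names below are made up):
+
+        .. code-block:: python
+
+            mapping = AutoMapping(
+                megatron_param="decoder.layers.*.mlp.linear_fc1.weight",
+                hf_param="model.layers.*.mlp.gate_proj.weight",
+            )
+            mapping._resolve_names(("3",))
+            # -> ("decoder.layers.3.mlp.linear_fc1.weight",
+            #     "model.layers.3.mlp.gate_proj.weight")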
+ """ + resolved_megatron_param = self.megatron_param + capture_index = 0 + + # First pass: resolve ** wildcards + while "**" in resolved_megatron_param and capture_index < len(captures): + resolved_megatron_param = resolved_megatron_param.replace( + "**", captures[capture_index], 1 + ) + capture_index += 1 + + # Second pass: resolve * wildcards + while "*" in resolved_megatron_param and capture_index < len(captures): + resolved_megatron_param = resolved_megatron_param.replace( + "*", captures[capture_index], 1 + ) + capture_index += 1 + + if isinstance(self.hf_param, str): + resolved_hf_param = self.hf_param + capture_index = 0 + + # First pass: resolve ** wildcards + while "**" in resolved_hf_param and capture_index < len(captures): + resolved_hf_param = resolved_hf_param.replace("**", captures[capture_index], 1) + capture_index += 1 + + # Second pass: resolve * wildcards + while "*" in resolved_hf_param and capture_index < len(captures): + resolved_hf_param = resolved_hf_param.replace("*", captures[capture_index], 1) + capture_index += 1 + else: + resolved_hf_param = {} + for k, v in self.hf_param.items(): + resolved_v = v + capture_index = 0 + + # First pass: resolve ** wildcards + while "**" in resolved_v and capture_index < len(captures): + resolved_v = resolved_v.replace("**", captures[capture_index], 1) + capture_index += 1 + + # Second pass: resolve * wildcards + while "*" in resolved_v and capture_index < len(captures): + resolved_v = resolved_v.replace("*", captures[capture_index], 1) + capture_index += 1 + + resolved_hf_param[k] = resolved_v + + return resolved_megatron_param, resolved_hf_param + + def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": + """Create a new mapping with resolved wildcards. + + This default implementation works for mappings with a + (megatron_param, hf_param) constructor. + + Args: + captures (Tuple[str, ...]): Captured wildcard values. + + Returns: + MegatronParamMapping: A new mapping instance with resolved names. + """ + resolved_megatron_param, resolved_hf_param = self._resolve_names(captures) + return type(self)(resolved_megatron_param, resolved_hf_param) + + @abstractmethod + def hf_to_megatron(self, hf_weights: WeightType, megatron_module: nn.Module) -> torch.Tensor: + """Convert hf_weights TO Megatron format. + + This method handles: + 1. Format transformation (if needed) + 2. Tensor parallel distribution (if self.tp_size > 1) + + Args: + hf_weights (WeightType): Source hf_weights in external format. + megatron_module (nn.Module): Target Megatron module (for config + access). + + Returns: + torch.Tensor: Weight tensor ready for the current TP rank. + """ + ... + + @abstractmethod + def megatron_to_hf( + self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] + ) -> Dict[str, torch.Tensor]: + """Convert weights FROM Megatron format. + + This method handles: + 1. Pipeline parallel broadcasting (if weight is on different PP rank) + 2. Tensor parallel gathering (if needed) + 3. Format transformation + + Args: + megatron_weights (Optional[torch.Tensor]): Weight tensor from current + rank (None if on different PP rank). + megatron_module (Optional[nn.Module]): Module for config access + (None if on different PP rank). + + Returns: + Dict[str, torch.Tensor]: Converted weights (empty dict if not on + TP rank 0). + """ + ... 
+ + def broadcast_from_pp_rank( + self, tensor: Optional[torch.Tensor], cache_key: Optional[str] = None + ) -> Optional[torch.Tensor]: + """Broadcast a tensor from the pipeline-parallel rank that owns it. + + Broadcasts to **all** PP ranks. This mirrors the behaviour of + `broadcast_from_megatron_pp` in the original MMapping implementation and + additionally keeps the tensor-parallel metadata (`tensor_model_parallel`, + `partition_dim`) consistent on every rank. + + Args: + tensor (Optional[torch.Tensor]): The local tensor if the current PP + rank owns it. ``None`` otherwise. + + Returns: + Optional[torch.Tensor]: The broadcasted tensor on every PP rank, or + ``None`` if *no* PP rank owned the tensor (which indicates a bug + in the calling code). + """ + + # Fast-path when we are not using pipeline parallelism. + if self.pp_size == 1: + return tensor + + # ------------------------------------------------------------------ + # 1. Gather (shape, dtype, tensor_parallel flag, partition_dim) from + # every PP rank so that we can find the source rank. + # ------------------------------------------------------------------ + if cache_key is not None and cache_key in self._tensor_spec_output_cache: + tensor_spec_output = self._tensor_spec_output_cache[cache_key] + else: + if tensor is not None: + shape = tensor.shape + dtype = tensor.dtype + tensor_parallel = getattr(tensor, "tensor_model_parallel", None) + partition_dim = getattr(tensor, "partition_dim", None) + tensor_spec = (shape, dtype, tensor_parallel, partition_dim) + else: + tensor_spec = None + + tensor_spec_output: list[Optional[tuple]] = [None] * self.pp_size + torch.distributed.all_gather_object( + tensor_spec_output, tensor_spec, group=self.pp_group + ) + self._tensor_spec_output_cache[cache_key] = tensor_spec_output + + # ------------------------------------------------------------------ + # 2. Identify the owning rank (the only rank with a non-None spec). + # ------------------------------------------------------------------ + target_tensor_spec = None + src_rank = None # Rank *inside* the PP group. + for rank, spec in enumerate(tensor_spec_output): + if spec is not None: + if target_tensor_spec is not None: + raise ValueError( + f"Tensor exists on more than one PP rank. Found on ranks {src_rank} and {rank}." + ) + target_tensor_spec = spec + src_rank = rank + + if target_tensor_spec is None: + # No rank had the tensor – this is an error in the caller. + raise ValueError("Object must exist on at least one PP rank") + + # ------------------------------------------------------------------ + # 3. Ensure every rank has an allocated tensor with the right shape + # and dtype before the broadcast. + # ------------------------------------------------------------------ + if tensor is None: + shape, dtype, tensor_parallel, partition_dim = target_tensor_spec + # Use CPU by default, unless CUDA is available + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + tensor = torch.empty(shape, dtype=dtype, device=device) + if tensor_parallel is not None: + tensor.tensor_model_parallel = tensor_parallel + if partition_dim is not None: + tensor.partition_dim = partition_dim + + # ------------------------------------------------------------------ + # 4. Broadcast from the source PP rank to all other PP ranks. 
+ # ------------------------------------------------------------------ + global_src = torch.distributed.get_global_rank(group=self.pp_group, group_rank=src_rank) + torch.distributed.broadcast(tensor, src=global_src, group=self.pp_group) + + return tensor + + def broadcast_obj_from_pp_rank( + self, obj: Optional[Any], cache_key: Optional[str] = None + ) -> Any: + """Broadcast any Python object from the PP rank that owns it. + + This method is useful for broadcasting configuration objects or + other metadata across pipeline parallel ranks. Results are cached + after the first call to avoid redundant broadcasts. + + Args: + obj (Optional[Any]): Object to broadcast (None on non-owning ranks). + cache_key (Optional[str]): Optional cache key. If not provided, + no caching will be performed. + + Returns: + Any: Broadcasted object on all ranks. + + Raises: + ValueError: If object exists on multiple ranks or no ranks. + """ + if self.pp_size == 1: + return obj + + # Check if we already have a cached result (only if cache_key is provided) + if cache_key is not None and cache_key in self._broadcast_obj_cache: + return self._broadcast_obj_cache[cache_key] + + # ------------------------------------------------------------------ + # 1. Gather presence flags from all PP ranks to find the source rank + # ------------------------------------------------------------------ + has_obj = obj is not None + obj_flags = [None] * self.pp_size + torch.distributed.all_gather_object(obj_flags, has_obj, group=self.pp_group) + + # ------------------------------------------------------------------ + # 2. Identify the owning rank (the only rank with True flag) + # ------------------------------------------------------------------ + src_rank = None # Rank *inside* the PP group + for rank, flag in enumerate(obj_flags): + if flag: + src_rank = rank + + if src_rank is None: + raise ValueError("Object must exist on at least one PP rank") + + # ------------------------------------------------------------------ + # 3. Broadcast the object from the source rank to all ranks + # ------------------------------------------------------------------ + if src_rank is None: + raise ValueError("Could not determine source rank") + + # Use broadcast_object_list which is more robust than all_gather_object + obj_list = [obj] + pp_ranks = torch.distributed.get_process_group_ranks(self.pp_group) + global_src = pp_ranks[src_rank] + torch.distributed.broadcast_object_list(obj_list, src=global_src, group=self.pp_group) + + result = obj_list[0] + + # Cache the result for future calls (only if cache_key is provided) + if cache_key is not None: + self._broadcast_obj_cache[cache_key] = result + + return result + + def clear_broadcast_cache(self): + """Clear the broadcast object cache. + + This can be useful for testing or if the objects being broadcast + might change during the lifetime of the mapping. + """ + self._broadcast_obj_cache.clear() + + def clear_tensor_spec_output_cache(self): + """Clear the tensor spec output cache. + + This can be useful for testing or if the tensor spec output + might change during the lifetime of the mapping. + """ + self._tensor_spec_output_cache.clear() + + def broadcast_tensor_to_tp_ranks(self, tensor: torch.Tensor, src_rank: int = 0) -> torch.Tensor: + """Broadcast a tensor to all TP ranks. + + Args: + tensor (torch.Tensor): The tensor to broadcast. + src_rank (int, optional): The source rank within the TP group. + Defaults to 0. + + Returns: + torch.Tensor: The broadcasted tensor. 
+ """ + if self.tp_size == 1: + return tensor + + global_src = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank) + torch.distributed.broadcast(tensor, src=global_src, group=self.tp_group) + return tensor + + def scatter_to_tp_ranks( + self, + splits: Optional[List[torch.Tensor]], + output_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + src_rank: int = 0, + ) -> torch.Tensor: + """Scatter tensor splits to TP ranks. + + Args: + splits (Optional[List[torch.Tensor]]): A list of tensor shards to + scatter. Only rank `src_rank` needs this. + output_shape (torch.Size): The shape of the output tensor on each rank. + dtype (torch.dtype): The data type of the output tensor. + device (torch.device): The device for the output tensor. + src_rank (int, optional): The source rank for the scatter operation. + Defaults to 0. + + Returns: + torch.Tensor: The scattered tensor shard on the current rank. + """ + if self.tp_size == 1: + return splits[0].to(device=device) if splits else None + + output = torch.empty(output_shape, dtype=dtype, device=device) + global_src = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank) + + scatter_list = None + if self.tp_rank == src_rank and splits: + scatter_list = [s.to(device=device) for s in splits] + + torch.distributed.scatter(output, scatter_list, src=global_src, group=self.tp_group) + return output + + def gather_from_tp_ranks(self, tensor: torch.Tensor) -> List[torch.Tensor]: + """Gather tensors from all TP ranks. + + Args: + tensor (torch.Tensor): The tensor shard to be gathered from the + current rank. + + Returns: + List[torch.Tensor]: A list of tensor shards from all TP ranks. + """ + if self.tp_size == 1: + return [tensor] + + gathered = [torch.empty_like(tensor) for _ in range(self.tp_size)] + torch.distributed.all_gather(gathered, tensor, group=self.tp_group) + return gathered + + def _count_wildcard_groups(self, pattern: str) -> int: + """Count the number of wildcard capture groups in a pattern. + + Args: + pattern: Pattern string with * and ** wildcards + + Returns: + Number of capture groups that will be generated + + Note: + ** counts as 1 group, * counts as 1 group + ** must be counted before * to avoid double-counting + """ + count = 0 + remaining = pattern + + # Count ** patterns first + while "**" in remaining: + count += 1 + remaining = remaining.replace("**", "", 1) + + # Count remaining * patterns + count += remaining.count("*") + + return count + + def _validate_patterns(self): + """Validate wildcard consistency between patterns.""" + megatron_param_wildcards = self._count_wildcard_groups(self.megatron_param) + if isinstance(self.hf_param, str): + hf_param_wildcards = self._count_wildcard_groups(self.hf_param) + if megatron_param_wildcards != hf_param_wildcards: + raise ValueError( + f"Wildcard count mismatch: megatron_param='{self.megatron_param}' has " + f"{megatron_param_wildcards} wildcards, hf_param='{self.hf_param}' has {hf_param_wildcards}" + ) + else: + for key, pattern in self.hf_param.items(): + hf_param_wildcards = self._count_wildcard_groups(pattern) + if megatron_param_wildcards != hf_param_wildcards: + raise ValueError( + f"Wildcard count mismatch: megatron_param='{self.megatron_param}' has " + f"{megatron_param_wildcards} wildcards, hf_param['{key}']='{pattern}' has {hf_param_wildcards}" + ) + + def _normalize_expert_param_name(self, param_name: str) -> str: + """Normalize expert parameter name by replacing trailing numbers with 0. + e.g. 
experts.weight15 -> experts.weight0, experts.bias15 -> experts.bias0 + + Args: + param_name (str): Parameter name that may end with a number. + + Returns: + str: Parameter name with trailing number replaced by 0. + """ + # Use regex to replace any trailing number with 0 + return re.sub(r"\d+$", "0", param_name) + + def _get_config(self, module: nn.Module) -> Any: + """Extract configuration from module hierarchy.""" + current = module + while current is not None: + if hasattr(current, "config"): + return current.config + # Try parent module + if hasattr(current, "_parent"): + current = current._parent + else: + # Walk up the module tree + for parent_module in module.modules(): + for child_name, child_module in parent_module.named_children(): + if child_module is current: + current = parent_module + break + else: + continue + break + else: + current = None + + raise ValueError( + f"Could not find config in module hierarchy for {module.__class__.__name__}. " + f"Ensure the module or its parent has a 'config' attribute." + ) + + def gather_from_ep_ranks( + self, + megatron_weights: Optional[torch.Tensor], + megatron_module: Optional[MegatronModule], + hf_param_name: Optional[str], + ) -> Dict[str, torch.Tensor]: + """Handle expert parallel weight gathering for MoE models. + + This method gathers expert weights across expert-parallel (EP) ranks and + returns a mapping from HF parameter names to the corresponding tensors + from each EP rank. Call this only for confirmed expert parameters + (self.is_expert is True), typically after TP gathering/concatenation in + the export path (Megatron → HF). + + Behavior and notation: + - Let E be the total number of experts (e.g., config.num_moe_experts) and + S be the expert-parallel size (ep_size). We assume E % S == 0. + - Each EP rank owns E/S experts. For a given parameter name, we infer a + local expert index L (0 ≤ L < E/S) on the current EP rank from the + global expert id embedded in the name (works for both .weight and .bias). + - The set of global expert ids that correspond to this local index L + across all EP ranks is: {L + k * (E/S) | k ∈ [0, S-1]}. + + Communication and outputs: + - We perform an all_gather over the EP group to collect the tensor from + every EP rank into a list ordered by EP rank id. + - For each EP rank k, we construct the HF parameter name by replacing the + expert id in `hf_param_name` with (L + k * (E/S)), preserving the rest + of the path, and map that name to the gathered tensor from rank k. + + Example: + - E = 8, S = 2 → E/S = 4. Experts are distributed as: + Rank 0: [0, 1, 2, 3], Rank 1: [4, 5, 6, 7]. + If the local index L = 0 (derived from the param name), this returns: + {"...experts.0.weight": tensor_from_rank0, "...experts.4.weight": tensor_from_rank1} + + Args: + megatron_weights (Optional[torch.Tensor]): The local expert weight tensor + (after any TP handling) on this EP rank. + megatron_module (Optional[MegatronModule]): The Megatron module containing + configuration (used to determine E and E/S). Can be None on non-owning PP + ranks; values will be broadcast across PP. + hf_param_name (Optional[str]): HF parameter name template for the current + (local) expert on this rank. The expert id within this string is replaced + with the appropriate global expert ids for each EP rank. + + Returns: + Dict[str, torch.Tensor]: Mapping from HF parameter names (one per EP rank) + to the corresponding expert tensors gathered from each EP rank. 
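+
+        The name rewriting alone can be sketched as follows (illustrative values:
+        E=8 experts, ep_size=2, local expert index 0; the parameter name is made up):
+
+        .. code-block:: python
+
+            import re
+
+            hf_param_name = "model.layers.0.mlp.experts.0.down_proj.weight"
+            names = [
+                re.sub(r"experts\.(\d+)", f"experts.{0 + 4 * k}", hf_param_name)
+                for k in range(2)
+            ]
+            # -> ["model.layers.0.mlp.experts.0.down_proj.weight",
+            #     "model.layers.0.mlp.experts.4.down_proj.weight"]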
+ """ + if megatron_module is None: + num_experts_per_rank = self.broadcast_obj_from_pp_rank(None, "num_experts_per_rank") + else: + model_config = self._get_config(megatron_module) + num_experts = model_config.num_moe_experts + num_experts_per_rank = num_experts // self.ep_size + num_experts_per_rank = self.broadcast_obj_from_pp_rank( + num_experts_per_rank, "num_experts_per_rank" + ) + + # Extract local expert number from parameter name + # Handle both .weight and .bias suffixes + local_expert_number = None + for key in (".weight", ".bias"): + if key in self.megatron_param: + global_expert_number = int(self.megatron_param.split(key)[-1]) + local_expert_number = global_expert_number % num_experts_per_rank + + # Compute global expert numbers for all EP ranks + # use regex to replace the local expert number with the global expert number + gathered_expert_param_names = [ + re.sub( + r"experts\.(\d+)", + f"experts.{int(local_expert_number) + num_experts_per_rank * i}", + str(hf_param_name), + ) + for i in range(self.ep_size) + ] + assert ( + hf_param_name in gathered_expert_param_names + ), f"hf_param_name {hf_param_name} not in gathered_expert_param_names {gathered_expert_param_names}" + + # Gather weights from all EP ranks + gathered_weights = [torch.empty_like(megatron_weights) for _ in range(self.ep_size)] + torch.distributed.all_gather(gathered_weights, megatron_weights, group=self.ep_group) + + # Return dictionary mapping HF parameter names to weights + return { + param_name: gathered_weights[i] + for i, param_name in enumerate(gathered_expert_param_names) + } + + def maybe_dequantize(self, tensor: torch.Tensor) -> torch.Tensor: + """Dequantize FP8 tensor if needed.""" + if HAVE_TE_FP8_TENSOR_CLASS and isinstance(tensor, FP8_TENSOR_CLASS): + return tensor.dequantize(dtype=tensor.dtype) + return tensor + + +class DirectMapping(MegatronParamMapping[torch.Tensor]): + """Direct 1:1 weight mapping with no transformation or tensor parallelism.""" + + def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: + """Direct copy - no transformation or distribution.""" + return hf_weights + + def megatron_to_hf( + self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] + ) -> Dict[str, torch.Tensor]: + """Direct copy with PP broadcast.""" + # Handle cross-PP broadcast + megatron_weights = self.broadcast_from_pp_rank( + megatron_weights, cache_key=str(self.hf_param) + ) + + if megatron_weights is None: + return {} + + # Dequantize if needed + megatron_weights = self.maybe_dequantize(megatron_weights) + + return {str(self.hf_param): megatron_weights} + + +class ColumnParallelMapping(MegatronParamMapping[torch.Tensor]): + """ + Mapping for column-parallel linear and embedding weights. + + Column-parallel layers in Megatron split the output dimension across tensor + parallel ranks. This is used for layers where each rank computes a portion + of the output features independently, such as: + - Embedding layers (split vocabulary) + - Linear layers producing hidden states (e.g., QKV projections, MLP up projections) + + The weight matrix is partitioned along dimension 0 (rows), so each TP rank + holds a subset of output features while maintaining all input features. + + **Sharding pattern** + - Original weight: `[output_features, input_features]` + - Rank 0: `[output_features/tp_size, input_features]` + - Rank 1: `[output_features/tp_size, input_features]` + - ... + + **Forward path (HuggingFace → Megatron)** + 1. 
Validate divisibility: output dimension must be divisible by tp_size + 2. Split: Chunk tensor along dim 0 into tp_size equal parts + 3. Scatter: Distribute chunks to respective TP ranks + + **Reverse path (Megatron → HuggingFace)** + 1. Broadcast: Ensure all PP ranks have the tensor + 2. Gather: Collect chunks from all TP ranks + 3. Concatenate: Reassemble along dim 0 on rank 0 + + Example: + .. code-block:: python + + # For a weight of shape [4096, 1024] with tp_size=4: + # Each rank gets [1024, 1024] after column-parallel split + mapping = ColumnParallelMapping("linear.weight", "transformer.linear.weight") + megatron_weights = mapping.hf_to_megatron(hf_weight, megatron_module) + # megatron_weights.shape = [1024, 1024] on each rank + + Note: + This mapping also handles bias terms, which are 1D tensors split + along their only dimension following the same pattern. + """ + + def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: + """Split weight along dim 0 and distribute to TP ranks.""" + # if self.tp_size == 1: + # return hf_weights + + # Some parameters are named with global expert number, e.g. experts.weight15, + # normalize it to experts.weight0, note we are only use the shape, dtype, device info, + # not the actual value, so it is safe to do this. + normalized_param = self._normalize_expert_param_name(self.megatron_param) + _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) + + if self.tp_size == 1: + full_weight = col_padding_size(hf_weights, target_param, self.tp_size) + return full_weight + + # On rank 0, check for divisibility and split + if self.tp_rank == 0: + if hf_weights is None: + raise ValueError("hf_weights should not be None on rank 0") + + # For MCore MambaMixer, A_log is initialized in FP32 but cast to BF16 when + # saving ckpts, including the ckpt uploaded to HF. Without this cast, + # self.scatter_to_tp_ranks will try to scatter the HF A_log weights in BF16 to + # the Megatron tensor which is in FP32. This will error. So we cast before the scatter. + if hf_weights.dtype != target_param.dtype: + logger.warning( + f"WARNING: Dtype mismatch between HuggingFace weights and Megatron module. " + f"HF dtype: {hf_weights.dtype}. Megatron dtype: {target_param.dtype}. " + f"Casting HF weights to Megatron dtype. THIS MAY RESULT IN A LOSS OF PRECISION. " + ) + hf_weights = hf_weights.to(target_param.dtype) + + # For bias (1D), we still split along dim 0 + # For weight (2D), we split along dim 0 (output dimension) + # full_size = hf_weights.shape[0] + # if full_size % self.tp_size != 0: + # raise ValueError( + # f"Cannot evenly split dimension 0 size {full_size} across {self.tp_size} TP ranks" + # ) + # splits = torch.chunk(hf_weights, self.tp_size, dim=0) + full_weight = col_padding_size(hf_weights, target_param, self.tp_size) + full_size = full_weight.shape[0] + if full_size % self.tp_size != 0: + raise ValueError( + f"Cannot evenly split dimension 0 size {full_size} across {self.tp_size} TP ranks" + ) + splits = torch.chunk(full_weight, self.tp_size, dim=0) + else: + splits = None + + # Scatter to all ranks. Each rank gets its sharded shape from its module. 
+ return self.scatter_to_tp_ranks( + splits, target_param.shape, target_param.dtype, target_param.device + ) + + def megatron_to_hf( + self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] + ) -> Dict[str, torch.Tensor]: + """Gather from all TP ranks and concatenate.""" + # Handle cross-PP broadcast + megatron_weights = self.broadcast_from_pp_rank( + megatron_weights, cache_key=str(self.hf_param) + ) + + if megatron_weights is None: + return {} + + # Dequantize if needed + megatron_weights = self.maybe_dequantize(megatron_weights) + + if self.tp_size == 1: + full_weights = megatron_weights + else: + # Gather from all TP ranks + gathered = self.gather_from_tp_ranks(megatron_weights) + full_weights = torch.cat(gathered, dim=0) + + if self.is_expert: + return self.gather_from_ep_ranks(full_weights, megatron_module, self.hf_param) + + return {str(self.hf_param): full_weights} + + +class RowParallelMapping(MegatronParamMapping[torch.Tensor]): + """Mapping for **row-parallel** linear weights. + + Megatron shards row-parallel tensors along **dimension 1** (the *input* + dimension of a linear layer). + + **Forward path (external → Megatron)** + 1. Rank 0 validates that the *second* dimension is divisible by `tp_size`. + 2. Rank 0 splits the tensor with `torch.chunk(..., dim=1)` producing + `tp_size` equally-sized shards. + 3. The shards are **scattered** so that every TP rank receives exactly one + shard matching the shape of its local Megatron parameter. + + **Reverse path (Megatron → external)** + 1. The local Megatron parameter (which may live on any PP rank) is + broadcast to all PP ranks so that the gather step can be collective. + 2. All TP ranks **gather** their shard. + 3. The gathered shards are concatenated along dim 1 to reconstruct the + original unsharded weight, which is emitted under the external (HF) name. + """ + + def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: + """Split weight along dim 1 and distribute to TP ranks.""" + if self.tp_size == 1: + return hf_weights + + # Some parameters are named with global expert number, e.g. experts.weight15, + # normalize it to experts.weight0; note that we only use the shape, dtype, and device info, + # not the actual value, so it is safe to do this. + normalized_param = self._normalize_expert_param_name(self.megatron_param) + _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) + + # On rank 0, check for divisibility and split + if self.tp_rank == 0: + if hf_weights is None: + raise ValueError("hf_weights should not be None on rank 0") + + # For bias (1D), we still split along dim 0 + # For weight (2D), we split along dim 1 + if hf_weights.ndim == 1: + full_size = hf_weights.shape[0] + if full_size % self.tp_size != 0: + raise ValueError( + f"Cannot evenly split dimension 0 size {full_size} across {self.tp_size} TP ranks" + ) + splits = torch.chunk(hf_weights, self.tp_size, dim=0) + else: + assert hf_weights.ndim == 2 + full_size = hf_weights.shape[1] + if full_size % self.tp_size != 0: + raise ValueError( + f"Cannot evenly split dimension 1 size {full_size} across {self.tp_size} TP ranks" + ) + splits = torch.chunk(hf_weights, self.tp_size, dim=1) + + else: + splits = None + + # Scatter to all ranks. Each rank gets its sharded shape from its module.
+ return self.scatter_to_tp_ranks( + splits, target_param.shape, target_param.dtype, target_param.device + ) + + def megatron_to_hf( + self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] + ) -> Dict[str, torch.Tensor]: + """Gather from all TP ranks and concatenate.""" + # Handle cross-PP broadcast + megatron_weights = self.broadcast_from_pp_rank( + megatron_weights, cache_key=str(self.hf_param) + ) + + if megatron_weights is None: + return {} + + # Dequantize if needed + megatron_weights = self.maybe_dequantize(megatron_weights) + + if self.tp_size == 1: + full_weights = megatron_weights + else: + gathered = self.gather_from_tp_ranks(megatron_weights) + full_weights = torch.cat(gathered, dim=1) + + if self.is_expert: + return self.gather_from_ep_ranks(full_weights, megatron_module, self.hf_param) + + return {str(self.hf_param): full_weights} + + +class ReplicatedMapping(MegatronParamMapping[torch.Tensor]): + """Mapping for weights that are **fully replicated** across TP ranks. + + Examples: layer-norm scales, biases, router weights in MoE, etc. + + These tensors exist in exactly the same form on *every* TP rank, so the + mapping logic is trivial – but we still need to broadcast across TP ranks + during *load* (HF → Megatron) and ensure we do **not** emit duplicates + during *export* (Megatron → HF). + """ + + def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: + """Replicate weight to all TP ranks.""" + try: + target_device = megatron_module.weight.device + except AttributeError: + # the parameter may not be called "weight" + target_device = next(megatron_module.parameters()).device + hf_weights = hf_weights.to(device=target_device) + if self.tp_size == 1: + return hf_weights + + # TODO(yuya): router.weight is on device cpu, need to check. + if target_device.index != torch.cuda.current_device(): + hf_weights = hf_weights.to(torch.cuda.current_device()) + + # All ranks need the full weight + if self.tp_rank > 0: + # Create empty tensor of correct shape + hf_weights = torch.empty_like(hf_weights) + + # Broadcast from rank 0 to all TP ranks + return self.broadcast_tensor_to_tp_ranks(hf_weights, src_rank=0) + + def megatron_to_hf( + self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] + ) -> Dict[str, torch.Tensor]: + """Return weight only from rank 0 to avoid duplication.""" + # Handle cross-PP broadcast + megatron_weights = self.broadcast_from_pp_rank( + megatron_weights, cache_key=str(self.hf_param) + ) + + if megatron_weights is None: + return {} + + # Dequantize if needed + megatron_weights = self.maybe_dequantize(megatron_weights) + + if self.is_expert: + return self.gather_from_ep_ranks(megatron_weights, megatron_module, self.hf_param) + + return {str(self.hf_param): megatron_weights} + + +class AutoMapping(MegatronParamMapping[torch.Tensor]): + """ + Smart mapping that automatically detects and applies the correct parallelism strategy. + + This mapping eliminates the need to manually specify whether a layer is + column-parallel, row-parallel, or replicated. It examines the Megatron + module at runtime and delegates to the appropriate specialized mapping. + + **Detection strategy** + 1. Check module class name against a registry of known types + 2. If unknown, examine module attributes (tensor_model_parallel, partition_dim) + 3. 
Delegate to appropriate mapping: ColumnParallel, RowParallel, or Replicated + + This abstraction is particularly useful for model-agnostic code where you + don't know the parallelism type ahead of time, or when working with models + that mix different parallelism strategies. + + **Built-in module recognition** + - Column-parallel: `ColumnParallelLinear`, `VocabParallelEmbedding`, etc. + - Row-parallel: `RowParallelLinear`, `TERowParallelLinear` + - Replicated: `LayerNorm`, `RMSNorm`, and other normalization layers + + Example: + .. code-block:: python + + # Automatically handles any weight type + mapping = AutoMapping( + megatron_param="decoder.layers.*.mlp.linear_fc1.weight", + hf_param="model.layers.*.mlp.gate_proj.weight" + ) + + # Works with column-parallel layers + megatron_weights = mapping.hf_to_megatron(hf_weight, column_parallel_module) + + # Also works with normalization layers + norm_weight = mapping.hf_to_megatron(hf_norm, layer_norm_module) + + # Register custom module types + AutoMapping.register_module_type("MyCustomLinear", "column") + + Note: + If the parallelism type cannot be determined, the mapping will raise + a descriptive error suggesting how to fix the issue. + """ + + # Module type registry + _MODULE_TYPE_REGISTRY: Dict[str, set] = { + "column": { + "ColumnParallelLinear", + "TEColumnParallelLinear", + "TELayerNormColumnParallelLinear", + "TEColumnParallelGroupedLinear", + "VocabParallelEmbedding", + }, + "row": {"RowParallelLinear", "TERowParallelLinear", "TERowParallelGroupedLinear"}, + "replicated": { + # Normalization layers + "TENorm", + "FusedLayerNorm", + "WrappedTorchNorm", + "LayerNorm", + "RMSNorm", + "L2Norm", + # Other non-parallel modules + "IdentityOp", + "DotProductAttention", + "TEDotProductAttention", + "TopKRouter", + }, + } + + @classmethod + def register_module_type(cls, module_name: str, parallelism_type: str): + """Register a new module type for automatic parallelism detection. + + Args: + module_name (str): The name of the module class (e.g., + 'MyColumnLinear'). + parallelism_type (str): One of 'column', 'row', or 'replicated'. + """ + if parallelism_type not in cls._MODULE_TYPE_REGISTRY: + raise ValueError( + f"Invalid parallelism_type '{parallelism_type}'. 
" + f"Must be one of {list(cls._MODULE_TYPE_REGISTRY.keys())}" + ) + cls._MODULE_TYPE_REGISTRY[parallelism_type].add(module_name) + + def __init__(self, megatron_param: str, hf_param: str): + """Initialize TP-aware mapping.""" + super().__init__(megatron_param, hf_param) + + # Cache for detected parallelism type and delegate mapping + self._detected_type: Optional[str] = None + self._mapping: Optional[MegatronParamMapping[torch.Tensor]] = None + + def _get_or_create_mapping(self, parallelism_type: str) -> MegatronParamMapping[torch.Tensor]: + """Get or create the appropriate mapping for the given type.""" + if parallelism_type == "column": + return ColumnParallelMapping(self.megatron_param, self.hf_param) + elif parallelism_type == "row": + return RowParallelMapping(self.megatron_param, self.hf_param) + elif parallelism_type == "replicated": + return ReplicatedMapping(self.megatron_param, self.hf_param) + else: + raise ValueError(f"Unknown parallelism type: {parallelism_type}") + + def _detect_parallelism_type(self, module: nn.Module) -> str: + """Detect parallelism type from module.""" + module_type = type(module).__name__ + + # Handle fused modules like TELayerNormColumnParallelLinear + # These modules have both column-parallel weights (weight, bias) + # and replicated layer norm weights (layer_norm_weight, layer_norm_bias) + if module_type == "TELayerNormColumnParallelLinear": + # Check the actual parameter name to determine the correct parallelism type + if self.megatron_param and ( + self.megatron_param.endswith("layer_norm_weight") + or self.megatron_param.endswith("layer_norm_bias") + ): + return "replicated" + # All other parameters (weight, bias) are column-parallel + return "column" + + # Check registry first + for parallelism, types in self._MODULE_TYPE_REGISTRY.items(): + if module_type in types: + return parallelism + + # Fallback to inspecting module attributes + if hasattr(module, "tensor_model_parallel"): + if not module.tensor_model_parallel: + return "replicated" + + # Check partition dimension + partition_dim = getattr(module, "partition_dim", None) + if partition_dim == 0: + return "column" + elif partition_dim == 1: + return "row" + + # Fallback for normalization layers + if any(norm in module_type for norm in ["Norm", "Normalization"]): + return "replicated" + + # Check parallel_mode for TELinear + if module_type == "TELinear": + if module.parallel_mode == "column": + return "column" + elif module.parallel_mode == "row": + return "row" + else: + return "replicated" + + # Cannot determine - raise informative error + known_types = {p: sorted(list(t)) for p, t in self._MODULE_TYPE_REGISTRY.items()} + + raise ValueError( + f"Cannot determine parallelism type for module '{module_type}' " + f"at weight '{self.megatron_param}'.\n" + f"Please use an explicit mapping type (e.g., ColumnParallelMapping) " + f"or register the module type using:\n" + f" AutoMapping.register_module_type('{module_type}', 'column|row|replicated')\n\n" + f"Currently known module types:\n{json.dumps(known_types, indent=2)}" + ) + + def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: + """Delegate to appropriate mapping based on module type.""" + # Detect type and create delegate on first use + if self._mapping is None: + self._detected_type = self._detect_parallelism_type(megatron_module) + self._mapping = self._get_or_create_mapping(self._detected_type) + + return self._mapping.hf_to_megatron(hf_weights, megatron_module) + + def megatron_to_hf( + self, 
megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] + ) -> Dict[str, torch.Tensor]: + """Delegate to appropriate mapping based on module type.""" + # Need to determine type even if module is None (different PP rank) + assert self.megatron_param is not None, "`megatron_param` is required for AutoMapping." + + if self._mapping is None: + if megatron_module is not None: + self._detected_type = self._detect_parallelism_type(megatron_module) + # Broadcast to other ranks + self._detected_type = self.broadcast_obj_from_pp_rank( + self._detected_type, "detected_type" + ) + else: + # Receive from owning rank + self._detected_type = self.broadcast_obj_from_pp_rank(None, "detected_type") + self._mapping = self._get_or_create_mapping(self._detected_type) + + return self._mapping.megatron_to_hf(megatron_weights, megatron_module) + + +class QKVMapping(MegatronParamMapping[Dict[str, torch.Tensor]]): + """ + Mapping for interleaved Query/Key/Value attention projection weights. + + This mapping handles the conversion between separate Q, K, V matrices used in + standard transformers and Megatron's optimized interleaved format. The + interleaving pattern groups queries with their corresponding key-value pairs + to maximize GEMM efficiency during attention computation. + + **External format (HuggingFace)** + - Separate tensors: `q_proj`, `k_proj`, `v_proj` + - Each of shape `[hidden_size, hidden_size]` or `[hidden_size, head_dim * num_heads]` + + **Megatron format** + - Single interleaved tensor following grouped query attention (GQA) pattern + - Interleaving order: `[q1...qn, k1, v1, q1...qn, k2, v2, ...]` + - Where `n = num_attention_heads / num_query_groups` + + **Key features** + 1. Format conversion: Handles merging/splitting with proper interleaving + 2. Grouped Query Attention: Supports different numbers of Q and KV heads + 3. Tensor parallelism: Delegates to AutoMapping for distribution + + Example: + .. code-block:: python + + # Create mapping for attention weights + mapping = QKVMapping( + megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", + q="model.layers.*.self_attn.q_proj.weight", + k="model.layers.*.self_attn.k_proj.weight", + v="model.layers.*.self_attn.v_proj.weight" + ) + + # Convert from HuggingFace to Megatron + qkv_weights = {"q": q_tensor, "k": k_tensor, "v": v_tensor} + megatron_qkv = mapping.hf_to_megatron(qkv_weights, megatron_module) + + # Convert from Megatron to HuggingFace + hf_weights = mapping.megatron_to_hf(megatron_qkv, megatron_module) + # Returns: {"q_proj.weight": ..., "k_proj.weight": ..., "v_proj.weight": ...} + + Note: + This mapping automatically handles both regular multi-head attention + (same number of Q, K, V heads) and grouped query attention (fewer + KV heads than Q heads) based on the model configuration. + """ + + def __init__(self, megatron_param: str, q: str, k: str, v: str): + """Initialize QKV mapping. + + Args: + megatron_param (str): Megatron QKV parameter name pattern. + q (str): Query weight name pattern. + k (str): Key weight name pattern. + v (str): Value weight name pattern. + """ + super().__init__(megatron_param, {"q": q, "k": k, "v": v}) + # Delegate all tensor-parallel logic to the smart TP-aware mapping so we + # do not hard-code the assumption that QKV projections are column-parallel. + # This keeps the format-handling (merge/split) concerns separate from + # TP/PP distribution mechanics. 
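+        # Note: the delegate reuses the Megatron name for both sides; its HF-side
+        # name only serves as an internal dictionary/cache key, since this class
+        # re-emits results under its own q/k/v names.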
+ self._tp_mapping = AutoMapping(megatron_param, megatron_param) + + def hf_to_megatron( + self, hf_weights: Dict[str, torch.Tensor], megatron_module: nn.Module + ) -> torch.Tensor: + """Merge Q, K, V into interleaved format and distribute.""" + if self.tp_rank == 0: + config = self._get_config(megatron_module) + + # Check if we're dealing with biases (1D tensors) or hf_weights (2D tensors) + if hf_weights["q"].ndim == 1: + # For biases, use the bias-specific merge function + merged = merge_qkv_biases(config, hf_weights["q"], hf_weights["k"], hf_weights["v"]) + else: + # For hf_weights, use the standard merge function + merged = merge_qkv_weights( + config, hf_weights["q"], hf_weights["k"], hf_weights["v"] + ) + else: + merged = None + + # Delegate the actual sharding/broadcasting to the TP-aware mapping. + return self._tp_mapping.hf_to_megatron(merged, megatron_module) + + def megatron_to_hf( + self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] + ) -> Dict[str, torch.Tensor]: + """Gather QKV shards and split into Q, K, V.""" + # Dequantize if needed + if megatron_weights is not None: + megatron_weights = self.maybe_dequantize(megatron_weights) + + # ------------------------------------------------------------------ + # Broadcast / retrieve the transformer configuration so that every PP + # rank (also the ones that will early-return) participates in the + # collective communication. + # ------------------------------------------------------------------ + if megatron_module is None: + config = self.broadcast_obj_from_pp_rank(None, "qkv_config") + else: + config = self._get_config(megatron_module) + # create shallow copy and remove non-picklable objects with max depth=2 + config = remove_non_pickleables(config, max_depth=2) + config = self.broadcast_obj_from_pp_rank(config, "qkv_config") + + # Delegate TP/PP gathering. + packed_dict = self._tp_mapping.megatron_to_hf(megatron_weights, megatron_module) + + if not packed_dict: + return {} + + packed_qkv = next(iter(packed_dict.values())) + + # Check if we're dealing with biases (1D) or weights (2D) + if packed_qkv.ndim == 1: + # Split biases + q, k, v = split_qkv_biases(config, packed_qkv) + else: + # Split weights + q, k, v = split_qkv_weights(config, packed_qkv) + + return {self.hf_param["q"]: q, self.hf_param["k"]: k, self.hf_param["v"]: v} + + def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": + """Return a new *resolved* QKVMapping instance.""" + resolved_megatron_param, resolved_hf_param = self._resolve_names(captures) + + return type(self)( + resolved_megatron_param, + resolved_hf_param["q"], + resolved_hf_param["k"], + resolved_hf_param["v"], + ) + + +class ConcatenatedQKVMapping(MegatronParamMapping[Dict[str, torch.Tensor]]): + """ + Mapping for interleaved Query/Key/Value attention projection weights. + + This mapping handles the conversion between Concatenated Q, K, V matrices used in + some transformers models and Megatron's optimized interleaved format. The + interleaving pattern groups queries with their corresponding key-value pairs + to maximize GEMM efficiency during attention computation. 
+
+    **External format (HuggingFace)**
+    - One tensor with concatenated query, key, value: `qkv`, with shape
+      `[hidden_size, head_dim * num_heads + 2 * head_dim * num_query_groups]`
+
+    **Megatron format**
+    - Single interleaved tensor following grouped query attention (GQA) pattern
+    - Interleaving order: `[q1...qn, k1, v1, q1...qn, k2, v2, ...]`
+    - Where `n = num_attention_heads / num_query_groups`
+
+    **Key features**
+    1. Format conversion: Handles merging/splitting with proper interleaving
+    2. Grouped Query Attention: Supports different numbers of Q and KV heads
+    3. Tensor parallelism: Delegates to AutoMapping for distribution
+
+    Example:
+        .. code-block:: python
+
+            # Create mapping for attention weights
+            mapping = ConcatenatedQKVMapping(
+                megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
+                hf_param="model.layers.*.self_attn.qkv.weight",
+            )
+
+            # Convert from HuggingFace to Megatron
+            megatron_qkv = mapping.hf_to_megatron(qkv_weights, megatron_module)
+
+            # Convert from Megatron to HuggingFace
+            hf_weights = mapping.megatron_to_hf(megatron_qkv, megatron_module)
+
+    Note:
+        This mapping automatically handles both regular multi-head attention
+        (same number of Q, K, V heads) and grouped query attention (fewer
+        KV heads than Q heads) based on the model configuration.
+    """
+
+    def __init__(self, megatron_param: str, hf_param: str):
+        """Initialize concatenated QKV mapping.
+
+        Args:
+            megatron_param (str): Megatron interleaved QKV parameter name pattern.
+            hf_param (str): HF concatenated QKV parameter name pattern.
+        """
+        super().__init__(megatron_param, hf_param)
+        # Delegate all tensor-parallel logic to the smart TP-aware mapping so we
+        # do not hard-code the assumption that QKV projections are column-parallel.
+        # This keeps the format-handling (merge/split) concerns separate from
+        # TP/PP distribution mechanics.
+        self._tp_mapping = AutoMapping(megatron_param, megatron_param)
+
+    def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor:
+        """Split the concatenated Q, K, V, merge into interleaved format, and distribute."""
+        if self.tp_rank == 0:
+            config = self._get_config(megatron_module)
+            head_num = config.num_attention_heads
+            head_size = config.kv_channels
+            num_query_groups = config.num_query_groups
+            q, k, v = hf_weights.split(
+                [head_num * head_size, num_query_groups * head_size, num_query_groups * head_size],
+                dim=0,
+            )
+            # Check if we're dealing with biases (1D tensors) or weights (2D tensors)
+            if q.ndim == 1:
+                # For biases, use the bias-specific merge function
+                merged = merge_qkv_biases(config, q, k, v)
+            else:
+                # For weights, use the standard merge function
+                merged = merge_qkv_weights(config, q, k, v)
+        else:
+            merged = None
+
+        # Delegate the actual sharding/broadcasting to the TP-aware mapping.
+        return self._tp_mapping.hf_to_megatron(merged, megatron_module)
+
+    def megatron_to_hf(
+        self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module]
+    ) -> Dict[str, torch.Tensor]:
+        """Gather QKV shards and repack into the HF concatenated layout."""
+        # Dequantize if needed
+        if megatron_weights is not None:
+            megatron_weights = self.maybe_dequantize(megatron_weights)
+
+        # ------------------------------------------------------------------
+        # Broadcast / retrieve the transformer configuration so that every PP
+        # rank (also the ones that will early-return) participates in the
+        # collective communication.
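+        # Non-picklable members (functions, bound methods, partials) are stripped
+        # from the config via remove_non_pickleables before it crosses ranks.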
+ # ------------------------------------------------------------------ + if megatron_module is None: + config = self.broadcast_obj_from_pp_rank(None, "qkv_config") + else: + config = self._get_config(megatron_module) + # create shallow copy and remove non-picklable objects with max depth=2 + config = remove_non_pickleables(config, max_depth=2) + config = self.broadcast_obj_from_pp_rank(config, "qkv_config") + + # Delegate TP/PP gathering. + packed_dict = self._tp_mapping.megatron_to_hf(megatron_weights, megatron_module) + + if not packed_dict: + return {} + + packed_qkv = next(iter(packed_dict.values())) + + # Check if we're dealing with biases (1D) or weights (2D) + if packed_qkv.ndim == 1: + # Split biases + q, k, v = split_qkv_biases(config, packed_qkv) + else: + # Split weights + q, k, v = split_qkv_weights(config, packed_qkv) + + return {str(self.hf_param): torch.cat((q, k, v), dim=0)} + + def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": + """Return a new *resolved* QKVMapping instance.""" + resolved_megatron_param, resolved_hf_param = self._resolve_names(captures) + + return type(self)(resolved_megatron_param, resolved_hf_param) + + +class GatedMLPMapping(MegatronParamMapping[Dict[str, torch.Tensor]]): + r"""Mapping for **gated-MLP** projection weights (SwiGLU / GeGLU). + + Checkpoint formats expose two independent matrices: + + - **G** – gate projection + - **U** – up projection + + Megatron concatenates them row-wise (`[G; U]`) so that a single GEMM can + produce both activations. + + **Responsibilities handled by this mapping** + 1. **Concatenate / split** – convert between `[G; U]` (Megatron) and the + separate `{G, U}` matrices (external). + 2. **Tensor-parallel distribution** – correctly splits gate and up + projections separately before concatenating corresponding shards, + ensuring each TP rank gets the proper [gate_shard; up_shard] format. + + **TP Distribution Strategy** + For tensor parallelism, this mapping: + - Splits gate and up matrices separately along output dimension (dim 0) + - Concatenates corresponding shards: [gate_shard_i; up_shard_i] for rank i + - This ensures each rank's concatenated tensor matches the expected shape + """ + + def __init__(self, megatron_param: str, gate: str, up: str): + """Initialize gated MLP mapping. + + Args: + megatron_param (str): Megatron MLP parameter name pattern. + gate (str): Gate projection weight name pattern. + up (str): Up projection weight name pattern. + """ + super().__init__(megatron_param, {"gate": gate, "up": up}) + + def hf_to_megatron( + self, hf_weights: Dict[str, torch.Tensor], megatron_module: nn.Module + ) -> torch.Tensor: + """Split gate and up separately, then concatenate corresponding shards.""" + # For single TP, just concatenate and return + if self.tp_size == 1: + return torch.cat([hf_weights["gate"], hf_weights["up"]], dim=0) + + # Get target parameter info from megatron module + # Some parameters are named with global expert number, e.g. experts.weight15, + # normalize it to experts.weight0, note we are only use the shape, dtype, device info, + # not the actual value, so it is safe to do this. 
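+        # Shape sketch (hypothetical values): gate and up each [8192, 4096] with
+        # tp_size=2 -> per-rank shards of [4096, 4096] each, so rank i receives
+        # cat([gate_i, up_i], dim=0) of shape [8192, 4096], matching its local
+        # linear_fc1 partition.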
+ normalized_param = self._normalize_expert_param_name(self.megatron_param) + _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) + + # On rank 0, split gate and up separately, then concatenate corresponding pieces + if self.tp_rank == 0: + gate = hf_weights["gate"] + up = hf_weights["up"] + + # Verify shapes match + assert gate.shape == up.shape, "Gate and up weights must have the same shape" + + # Check divisibility for TP splitting + gate_output_size = gate.shape[0] + if gate_output_size % self.tp_size != 0: + raise ValueError( + f"Cannot evenly split gate dimension 0 size {gate_output_size} across {self.tp_size} TP ranks" + ) + + # Split gate and up separately along output dimension (dim 0) + # This works for both bias (1D) and weight (2D) tensors + gate_splits = torch.chunk(gate, self.tp_size, dim=0) + up_splits = torch.chunk(up, self.tp_size, dim=0) + + # Concatenate corresponding pieces: [gate_shard_i; up_shard_i] for each rank i + splits = [torch.cat([gate_splits[i], up_splits[i]], dim=0) for i in range(self.tp_size)] + else: + splits = None + + # Scatter the concatenated shards to each rank + return self.scatter_to_tp_ranks( + splits, target_param.shape, target_param.dtype, target_param.device + ) + + def megatron_to_hf( + self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] + ) -> Dict[str, torch.Tensor]: + """Gather concatenated shards and split into gate and up.""" + # Handle cross-PP broadcast first + megatron_weights = self.broadcast_from_pp_rank( + megatron_weights, cache_key=str(self.hf_param) + ) + + if megatron_weights is None: + return {} + + # Dequantize if needed + megatron_weights = self.maybe_dequantize(megatron_weights) + + # Handle TP gathering + if self.tp_size == 1: + # No TP, just split the concatenated tensor + fused_mlp = megatron_weights + gate, up = torch.chunk(fused_mlp, 2, dim=0) + + else: + # Gather shards from all TP ranks + gathered_shards = self.gather_from_tp_ranks(megatron_weights) + + # Split each shard back into gate and up parts + gate_parts = [] + up_parts = [] + for shard in gathered_shards: + # Each shard is [gate_shard; up_shard] concatenated along dim 0 + # This works for both bias (1D) and weight (2D) tensors + gate_shard, up_shard = torch.chunk(shard, 2, dim=0) + gate_parts.append(gate_shard) + up_parts.append(up_shard) + + # Concatenate all gate parts and all up parts separately + gate = torch.cat(gate_parts, dim=0) + up = torch.cat(up_parts, dim=0) + + if self.is_expert: + gathered_gate_weights_dict = self.gather_from_ep_ranks( + gate, megatron_module, self.hf_param["gate"] + ) + gathered_up_weights_dict = self.gather_from_ep_ranks( + up, megatron_module, self.hf_param["up"] + ) + return {**gathered_gate_weights_dict, **gathered_up_weights_dict} + + return {self.hf_param["gate"]: gate, self.hf_param["up"]: up} + + def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": + """Return a new *resolved* GatedMLPMapping instance.""" + resolved_megatron_param, resolved_hf_param = self._resolve_names(captures) + + return type(self)( + resolved_megatron_param, resolved_hf_param["gate"], resolved_hf_param["up"] + ) + + +def merge_qkv_biases( + config: TransformerConfig, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor +) -> torch.Tensor: + """Merge separate Q, K, V bias vectors into Megatron's interleaved QKV format. + + Args: + config (TransformerConfig): Transformer configuration. + q (torch.Tensor): Query projection biases [hidden_size]. 
+ k (torch.Tensor): Key projection biases [kv_hidden_size]. + v (torch.Tensor): Value projection biases [kv_hidden_size]. + + Returns: + torch.Tensor: Interleaved QKV biases in Megatron format as 1D tensor. + """ + head_num = config.num_attention_heads + num_query_groups = config.num_query_groups + heads_per_group = head_num // num_query_groups + head_size = config.kv_channels or (config.hidden_size // head_num) + + # Reshape biases to expose head dimension + q = q.view(head_num, head_size) + k = k.view(num_query_groups, head_size) + v = v.view(num_query_groups, head_size) + + # Interleave in Megatron pattern: [q1...qn, k1, v1, q1...qn, k2, v2, ...] + qkv_biases = [] + for i in range(num_query_groups): + qkv_biases.append(q[i * heads_per_group : (i + 1) * heads_per_group, :]) + qkv_biases.append(k[i : i + 1, :]) + qkv_biases.append(v[i : i + 1, :]) + + # Concatenate and flatten back to 1D + qkv = torch.cat(qkv_biases) + return qkv.flatten() + + +def split_qkv_biases( + config: TransformerConfig, qkv: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Split Megatron's interleaved QKV bias into separate Q, K, V biases. + + Args: + config (TransformerConfig): Transformer configuration. + qkv (torch.Tensor): Interleaved QKV biases in Megatron format (1D + tensor). + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of (Q, K, V) bias vectors. + """ + head_num = config.num_attention_heads + num_query_groups = config.num_query_groups + heads_per_group = head_num // num_query_groups + head_size = config.kv_channels or (config.hidden_size // head_num) + qkv_total_dim = head_num + 2 * num_query_groups + + # Reshape to expose interleaved structure + qkv = qkv.reshape(qkv_total_dim, head_size) + + # Extract Q, K, V from interleaved pattern + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, heads_per_group + 2) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, heads_per_group + 2) + + q = qkv[q_slice].flatten() + k = qkv[k_slice].flatten() + v = qkv[v_slice].flatten() + + return q, k, v + + +def merge_qkv_weights( + provider: TransformerConfig, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor +) -> torch.Tensor: + """Merge separate Q, K, V weight matrices into Megatron's interleaved QKV format. + + Args: + provider (TransformerConfig): Model configuration provider. + q (torch.Tensor): Query projection weights [hidden_size, hidden_size] or + bias [hidden_size]. + k (torch.Tensor): Key projection weights [kv_hidden_size, hidden_size] + or bias [kv_hidden_size]. + v (torch.Tensor): Value projection weights [kv_hidden_size, + hidden_size] or bias [kv_hidden_size]. + + Returns: + torch.Tensor: Interleaved QKV weights in Megatron format. 
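+
+    Example:
+        With a hypothetical GQA config of 8 attention heads and 2 query groups
+        (heads_per_group=4), the returned rows are ordered
+        ``[q0..q3, k0, v0, q4..q7, k1, v1]``, with each listed head spanning
+        ``head_size`` rows.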
+ """ + head_num = provider.num_attention_heads + num_query_groups = provider.num_query_groups + heads_per_group = head_num // num_query_groups + head_size = provider.kv_channels or (provider.hidden_size // head_num) + hidden_size = provider.hidden_size + is_bias = q.ndim == 1 + + # Reshape to expose head dimension + if is_bias: + q_reshaped = q.view(head_num, head_size) + k_reshaped = k.view(num_query_groups, head_size) + v_reshaped = v.view(num_query_groups, head_size) + else: + q_reshaped = q.view(head_num, head_size, hidden_size) + k_reshaped = k.view(num_query_groups, head_size, hidden_size) + v_reshaped = v.view(num_query_groups, head_size, hidden_size) + + # Interleave in Megatron pattern: [q1...qn, k1, v1, q1...qn, k2, v2, ...] + qkv_weights = [] + for i in range(num_query_groups): + q_group = q_reshaped[i * heads_per_group : (i + 1) * heads_per_group] + k_group = k_reshaped[i : i + 1] + v_group = v_reshaped[i : i + 1] + qkv_weights.extend([q_group, k_group, v_group]) + + qkv = torch.cat(qkv_weights, dim=0) + + # Final reshape + if is_bias: + return qkv.reshape(-1) + else: + return qkv.reshape([-1, hidden_size]) + + +def split_qkv_weights( + provider: TransformerConfig, qkv: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Split Megatron's interleaved QKV tensor into separate Q, K, V matrices. + + Args: + provider (TransformerConfig): Model configuration provider. + qkv (torch.Tensor): Interleaved QKV weights in Megatron format. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of (Q, K, V) + weight matrices. + """ + head_num = provider.num_attention_heads + num_query_groups = provider.num_query_groups + heads_per_group = head_num // num_query_groups + head_size = provider.kv_channels or (provider.hidden_size // head_num) + qkv_total_dim = head_num + 2 * num_query_groups + is_bias = qkv.ndim == 1 + + if is_bias: + hidden_size = 1 + qkv_reshaped = qkv.view(qkv_total_dim, head_size) + else: + hidden_size = qkv.shape[-1] + qkv_reshaped = qkv.view(qkv_total_dim, head_size, hidden_size) + + # Extract Q, K, V from interleaved pattern + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, heads_per_group + 2) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, heads_per_group + 2) + + q = qkv_reshaped[q_slice] + k = qkv_reshaped[k_slice] + v = qkv_reshaped[v_slice] + + if is_bias: + q = q.reshape(-1) + k = k.reshape(-1) + v = v.reshape(-1) + else: + q = q.reshape(-1, hidden_size) + k = k.reshape(-1, hidden_size) + v = v.reshape(-1, hidden_size) + + return q, k, v diff --git a/flagscale/train/megatron/nemo_bridge/models/conversion/utils.py b/flagscale/train/megatron/nemo_bridge/models/conversion/utils.py new file mode 100644 index 0000000000..66d68aee66 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/conversion/utils.py @@ -0,0 +1,287 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import copy +import functools +import re +import types + +from typing import Iterable, List, Optional, Tuple + +import torch + +from rich.table import Table +from transformers.configuration_utils import PretrainedConfig + +from megatron.core.transformer.module import MegatronModule +from megatron.core.utils import unwrap_model + + +def weights_verification_table(bridge, megatron_model) -> Table: + """ + Returns a table comparing weights between a Hugging Face model and a Megatron-LM model. + + Args: + bridge (AutoBridge): The bridge object containing model information. + megatron_model: The Megatron-LM model instance. + + Returns: + Table: A rich Table object with the comparison. + """ + table = Table(title="Hugging Face Weights Verification") + table.add_column("Weight Name", style="cyan") + table.add_column("Shape") + table.add_column("DType") + table.add_column("Device") + table.add_column("Matches Original", justify="center") + + # Check each weight against the original HF-model + for name, param in bridge.export_hf_weights(megatron_model, show_progress=True): + original_param = bridge.hf_pretrained.state[name] + table.add_row( + name, + str(tuple(param.shape)), + str(param.dtype).replace("torch.", ""), + str(param.device), + "✅" if torch.allclose(param, original_param.to(param.device), atol=1e-6) else "❌", + ) + + return table + + +def get_module_and_param_from_name( + models: MegatronModule | List[MegatronModule], param_name: str, vp_stage: Optional[int] = None +) -> Tuple[torch.nn.Module, torch.Tensor] | Tuple[torch.nn.Module, torch.Tensor, Tuple]: + """ + Get parameter from specific VP stage, ensuring that parameter + attributes are preserved. Supports both absolute and relative parameter names. + + Args: + models: List of Megatron model instances or a submodule + param_name: Dot-separated parameter name (can be absolute or relative to models) + vp_stage: Virtual pipeline stage index (None for single stage) + + Returns: + Tuple of (module, parameter) where module owns the parameter + + Raises: + ValueError: If vp_stage is out of range or parameter doesn't exist + + Examples: + Basic usage with full model: + >>> module, param = get_module_and_param_from_name( + ... models=full_model, + ... param_name="transformer.layers.0.attention.query.weight" + ... ) + + Usage with model list and VP stage: + >>> module, param = get_module_and_param_from_name( + ... models=[model1, model2, model3], + ... param_name="layers.0.mlp.dense.bias", + ... vp_stage=1 + ... ) + + Usage with submodule and relative path: + >>> linear_module = model.transformer.layers[0].mlp.dense + >>> module, param = get_module_and_param_from_name( + ... models=linear_module, + ... param_name="weight" + ... ) + + Usage with submodule and absolute path (automatic suffix matching): + >>> linear_module = model.transformer.layers[0].mlp.dense + >>> module, param = get_module_and_param_from_name( + ... models=linear_module, + ... param_name="transformer.layers.0.mlp.dense.weight" + ... ) + # Automatically matches "weight" suffix and returns the parameter + + Edge case with partial path matching: + >>> attention_module = model.transformer.layers[0].attention + >>> module, param = get_module_and_param_from_name( + ... models=attention_module, + ... param_name="layers.0.attention.query.weight" + ... 
) + # Matches "query.weight" suffix within the attention module + """ + + if isinstance(models, list): + if vp_stage is None: + model = models[0] + else: + if vp_stage >= len(models): + raise ValueError(f"VP stage {vp_stage} out of range (max: {len(models) - 1})") + model = models[vp_stage] + else: + model = models + + module = unwrap_model(model) + splitted_name = param_name.split(".") + + # Try to find the parameter using the given parts + def try_get_param(parts): + param = module + temp_module = module + + for i, part in enumerate(parts): + if not hasattr(param, part): + return None + param = getattr(param, part) + if i < len(parts) - 1: + temp_module = getattr(temp_module, part) + + return temp_module, param + + # First try the full parameter name (current behavior) + result = try_get_param(splitted_name) + if result is not None: + return result + + # If full name doesn't work, try suffixes of the parameter name + # This handles cases where models is a submodule but param_name is absolute + for start_idx in range(1, len(splitted_name)): + suffix_parts = splitted_name[start_idx:] + result = try_get_param(suffix_parts) + if result is not None: + return result + + # If no approach works, raise an error + raise ValueError(f"Parameter '{param_name}' not found in model at VP stage {vp_stage}") + + +def remove_non_pickleables(obj, max_depth: int = 2, current_depth: int = 0): + """Remove non-pickleable objects from a configuration object recursively. + + This utility function identifies and removes objects that cannot be pickled for + inter-process communication, including functions, bound methods, partial + functions, and other problematic callables. + + Args: + obj: The object to clean + max_depth: Maximum recursion depth (default: 2) + current_depth: Current recursion depth (internal use) + + Returns: + The cleaned object with non-pickleables removed + """ + + # Stop recursion if max depth reached + if current_depth >= max_depth: + return obj + + # Handle None + if obj is None: + return obj + + # Check if object is a problematic callable + if callable(obj): + # Allow classes/types but remove function objects, methods, partials + if isinstance(obj, type): + return obj + elif hasattr(obj, "__call__") and ( + isinstance(obj, (types.FunctionType, types.MethodType, functools.partial)) + or hasattr(obj, "__self__") + ): # bound methods + return None + + # Handle dataclass/object with attributes + if hasattr(obj, "__dict__"): + # Create a copy to avoid modifying the original + cleaned_obj = copy.copy(obj) + + for attr_name in list(vars(cleaned_obj).keys()): + attr_value = getattr(cleaned_obj, attr_name) + + # Recursively clean attribute + cleaned_value = remove_non_pickleables(attr_value, max_depth, current_depth + 1) + + # Set the cleaned value (or None if it was removed) + setattr(cleaned_obj, attr_name, cleaned_value) + + return cleaned_obj + + # Handle lists + elif isinstance(obj, list): + return [remove_non_pickleables(item, max_depth, current_depth + 1) for item in obj] + + # Handle tuples + elif isinstance(obj, tuple): + return tuple(remove_non_pickleables(item, max_depth, current_depth + 1) for item in obj) + + # Handle dictionaries + elif isinstance(obj, dict): + return { + key: remove_non_pickleables(value, max_depth, current_depth + 1) + for key, value in obj.items() + } + + # For primitive types and other safe objects, return as-is + return obj + + +def extract_sort_key(param_name: str): + """Extract sorting key based on layer and expert numbers.""" + + # Extract at most 2 numbers: 
layer number and expert number + # Pattern: *layers.d+.*d+ (layer number and potentially expert number) + numbers = [] + # Find layer number + layer_match = re.search(r"layers\.(\d+)", param_name) + if layer_match: + numbers.append(int(layer_match.group(1))) + # Find expert number after bias or weight + expert_match = re.search(r"(?:bias|weight)(\d+)", param_name) + if expert_match: + numbers.append(int(expert_match.group(1))) + # Pad to ensure consistent comparison (max 2 numbers) + while len(numbers) < 2: + numbers.append(-1) + numbers = numbers[:2] # Keep at most 2 numbers + return numbers, param_name + + +def get_causal_lm_class_via_auto_map( + model_name_or_path: str, config: PretrainedConfig +) -> type | None: + """Return CausalLM class via config.auto_map if available; otherwise None. + + If auto_map["AutoModelForCausalLM"] is present in the config, returns the dynamically loaded class. + Returns None when auto_map is absent or loading fails. Does not download weights. + """ + auto_map = getattr(config, "auto_map", None) + if auto_map and "AutoModelForCausalLM" in auto_map: + auto_map_class = auto_map["AutoModelForCausalLM"] + repo_id = model_name_or_path or getattr(config, "_name_or_path", None) + if not repo_id: + return None + try: + from transformers.dynamic_module_utils import get_class_from_dynamic_module + + return get_class_from_dynamic_module( + class_reference=auto_map_class, + pretrained_model_name_or_path=repo_id, + cache_dir=None, + force_download=False, + resume_download=True, + proxies=None, + use_auth_token=None, + revision=None, + local_files_only=False, + repo_id=repo_id, + ) + except Exception: + return None + + return None + + +def persistent_buffers(model: torch.nn.Module) -> Iterable[Tuple[str, torch.Tensor]]: + """Return an iterator over persistent module buffers, yielding both the name of the buffer as well as the buffer itself.""" + + for mod_prefix, mod in model.named_modules(): + # only local buffers; we'll add the prefix ourselves + for local_name, buffer in mod.named_buffers(recurse=False): + if local_name not in getattr(mod, "_non_persistent_buffers_set", set()): + full_name = f"{mod_prefix + '.' if mod_prefix else ''}{local_name}" + yield full_name, buffer diff --git a/flagscale/train/megatron/nemo_bridge/models/decorators/__init__.py b/flagscale/train/megatron/nemo_bridge/models/decorators/__init__.py new file mode 100644 index 0000000000..744d700763 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/decorators/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from megatron.nemo_bridge.models.decorators.dispatch import dispatch +#from megatron.nemo_bridge.models.decorators.torchrun import torchrun_main + +#__all__ = ["dispatch", "torchrun_main"] +__all__ = ["dispatch"] diff --git a/flagscale/train/megatron/nemo_bridge/models/decorators/dispatch.py b/flagscale/train/megatron/nemo_bridge/models/decorators/dispatch.py new file mode 100644 index 0000000000..7e02855d66 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/decorators/dispatch.py @@ -0,0 +1,348 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +"""Simplified dispatch system for Python, based on classes' typeclass implementation. + +This module provides a dispatch-based polymorphism system allowing extensible +behavior for different types using the `impl` decorator. 
+""" + +from functools import _find_impl # type: ignore +from typing import Any, Callable, Dict, Optional, TypeVar + +_SignatureType = TypeVar("_SignatureType", bound=Callable) + + +class _Dispatch: + """Internal dispatch representation with type-based routing logic.""" + + __slots__ = ("_signature", "_name", "_exact_types", "_dispatch_cache", "_doc", "_module") + + def __init__(self, signature: Callable) -> None: + self._signature = signature + self._name = signature.__name__ + self._exact_types: Dict[Any, Callable] = {} + self._dispatch_cache: Dict[Any, Callable] = {} + + # Extract docstring and module info for rich repr + self._doc = signature.__doc__ + self._module = signature.__module__ + + def __call__(self, instance: Any, *args, **kwargs) -> Any: + """Dispatch to the appropriate implementation based on instance type.""" + # Special case for tuple-based keys. + if isinstance(instance, tuple): + key = tuple(v if isinstance(v, (type, str)) else type(v) for v in instance) + + # Direct match + impl = self._exact_types.get(key) + if impl is not None: + # NOTE: This path is not cached for simplicity + return impl(instance, *args, **kwargs) + + # Subclass match for tuples of types + for registered_key, callback in self._exact_types.items(): + if ( + not isinstance(registered_key, tuple) + or len(registered_key) != len(key) + or not all(isinstance(t, type) for t in registered_key) + ): + continue + + try: + # For subclass checks, operate on the instance types only + key_types = tuple(v if isinstance(v, type) else type(v) for v in instance) + if all(issubclass(k, rk) for k, rk in zip(key_types, registered_key)): + # NOTE: not caching tuple subclass matches for simplicity + return callback(instance, *args, **kwargs) + except TypeError: + continue # issubclass can fail + + # Normalize both sides to names so tuples of types and/or strings can match. + def _name(obj): + return obj if isinstance(obj, str) else getattr(obj, "__name__", None) or str(obj) + + key_names = tuple(_name(v) for v in key) + for registered_key, callback in self._exact_types.items(): + if not isinstance(registered_key, tuple) or len(registered_key) != len(key): + continue + reg_names = tuple(_name(rk) for rk in registered_key) + if reg_names == key_names: + return callback(instance, *args, **kwargs) + + # No implementation found for this tuple, raise a specific error. + error_msg = self._format_no_implementation_error(instance) + raise NotImplementedError(error_msg) + + # For class dispatch, we use the class (or string of class name) itself as the key + if isinstance(instance, type): + cache_key = instance + instance_type = instance + elif isinstance(instance, str): + cache_key = instance + instance_type = str + else: + cache_key = type(instance) + instance_type = cache_key + + # Try cache + impl = self._dispatch_cache.get(cache_key) + if impl is None: + impl = self._dispatch(instance, instance_type) + if impl is None: + error_msg = self._format_no_implementation_error(instance) + raise NotImplementedError(error_msg) + self._dispatch_cache[cache_key] = impl + + return impl(instance, *args, **kwargs) + + def impl(self, *target_types: Any) -> Callable[[Callable], Callable]: + """Register an implementation for one or more types. 
+ + Usage: + @mydispatch.impl(int) # Register for a single type + @mydispatch.impl(int, str) # Register for multiple types + @mydispatch.impl((list, str)) # Register for a tuple of types as a key + """ + if not target_types: + raise ValueError( + "\n✗ Missing argument to .impl()\n\n" + "You must specify at least one target type.\n\n" + "Examples:\n" + f" @{self._name}.impl(str) # Single type\n" + f" @{self._name}.impl(int, float) # Multiple types\n" + f" @{self._name}.impl((list, str)) # Tuple key\n" + ) + + def decorator(func: Callable) -> Callable: + if len(target_types) == 1: + # This handles both `@impl(int)` and `@impl((int, str))` + self._exact_types[target_types[0]] = func + else: + # This handles `@impl(int, str)` + for typ in target_types: + self._exact_types[typ] = func + + self._dispatch_cache.clear() + return func + + return decorator + + def __repr__(self) -> str: + """Rich representation showing all implementations.""" + # Build signature string + import inspect + + sig = inspect.signature(self._signature) + sig_str = f"{self._name}{sig}" + + lines = [f"Dispatch({sig_str})("] + + # Add regular implementations + for typ, impl in self._exact_types.items(): + if isinstance(typ, tuple): + type_name = ( + f"({', '.join(t.__name__ if hasattr(t, '__name__') else str(t) for t in typ)})" + ) + else: + type_name = typ.__name__ if hasattr(typ, "__name__") else str(typ) + impl_loc = self._format_location(impl) + lines.append(f" ({type_name}): {impl.__name__} at {impl_loc}") + + lines.append(")") + return "\n".join(lines) + + def _dispatch(self, instance: Any, instance_type: type) -> Optional[Callable]: + """Find the implementation for a given type. + + Fallback order: + 1) Exact type match + 2) issubclass match (when instance is a type) + 3) MRO-based match via functools._find_impl + 4) Name-based fallback: match by class __name__ for dynamically generated + classes (e.g., HF transformers auto_map dynamic modules) + """ + # Direct type match + impl = self._exact_types.get(instance_type, None) + if impl is not None: + return impl + + # For class dispatch, check issubclass relationships + if isinstance(instance, type): + for registered_type, callback in self._exact_types.items(): + if not isinstance(registered_type, type): + continue + try: + if issubclass(instance, registered_type): + return callback + except TypeError: + # issubclass can fail for some types + pass + + # Use functools._find_impl for MRO-based dispatch, only for single types + single_type_impls = {k: v for k, v in self._exact_types.items() if isinstance(k, type)} + impl = _find_impl(instance_type, single_type_impls) + if impl is not None: + return impl + + # Name-based fallback for dynamic HF classes and string registrations. 
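+        # Example (hypothetical): an implementation registered under the string
+        # "DeepseekV3ForCausalLM" matches an instance of a dynamically loaded HF
+        # class with that same __name__, even though the class objects differ.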
+ def _name(obj): + return obj if isinstance(obj, str) else getattr(obj, "__name__", None) + + if isinstance(instance, str): + inst_name = instance + elif isinstance(instance, type): + inst_name = _name(instance) + else: + inst_name = _name(type(instance)) + + if inst_name: + for registered_type, callback in self._exact_types.items(): + reg_name = _name(registered_type) + if reg_name and str(reg_name) == inst_name: + return callback + + return None + + def _format_location(self, func: Callable) -> str: + """Format the location of a function for display.""" + try: + import inspect + + filename = inspect.getfile(func) + _, lineno = inspect.getsourcelines(func) + # Shorten the path to be more readable + import os + + filename = os.path.relpath(filename) + return f"{filename}:{lineno}" + except Exception: + return "" + + def _format_no_implementation_error(self, instance: Any) -> str: + """Format a helpful error message when no implementation is found.""" + type_name_for_header: str + type_name_for_suggestion: str + type_name_for_func: str + instance_type_hint: str + + if isinstance(instance, tuple): + instance_types = tuple(v if isinstance(v, type) else type(v) for v in instance) + type_names_str = ", ".join( + t.__qualname__ if hasattr(t, "__qualname__") else str(t) for t in instance_types + ) + type_name_for_header = f"tuple of types ({type_names_str})" + + suggestion_names = ", ".join( + t.__name__ if hasattr(t, "__name__") else str(t) for t in instance_types + ) + type_name_for_suggestion = f"({suggestion_names})" + type_name_for_func = "tuple" + instance_type_hint = f"Tuple[{', '.join(t.__name__ for t in instance_types)}]" + else: + instance_type = instance if isinstance(instance, type) else type(instance) + qualname = ( + instance_type.__qualname__ + if hasattr(instance_type, "__qualname__") + else str(instance_type) + ) + type_name_for_header = f"type '{qualname}'" + type_name_for_suggestion = ( + instance_type.__name__ if hasattr(instance_type, "__name__") else str(instance_type) + ) + type_name_for_func = type_name_for_suggestion.lower().replace(".", "_") + instance_type_hint = type_name_for_suggestion + + # Build error message + lines = [ + f"\n✗ No implementation found for {type_name_for_header}", + "", + f"The dispatch function '{self._name}' has no implementation for this type.", + "", + ] + + # Add available implementations + if self._exact_types: + lines.append("Available implementations:") + + # Add registered types + sorted_keys = sorted(self._exact_types.keys(), key=str) + for typ in sorted_keys: + if isinstance(typ, tuple): + type_display = f"({', '.join(t.__name__ if hasattr(t, '__name__') else str(t) for t in typ)})" + else: + type_display = typ.__name__ if hasattr(typ, "__name__") else str(typ) + lines.append(f" • {type_display}") + else: + lines.append("No implementations registered yet.") + + # Generate help based on existing implementations + if self._exact_types: + # Get a sample implementation to show the pattern + _, sample_impl = next(iter(self._exact_types.items())) + + lines.extend( + [ + "", + "To add support for this type, register an implementation:", + f" @{self._name}.impl({type_name_for_suggestion})", + f" def _{self._name}_{type_name_for_func}(instance: {instance_type_hint}) -> ...:", + " # Your implementation here", + ] + ) + + # Try to extract parameter info from the sample implementation + import inspect + + try: + sig = inspect.signature(sample_impl) + params = list(sig.parameters.keys())[1:] # Skip first param (instance) + if params: + param_hints = 
", ".join(params) + lines.append(f" # Expected parameters: {param_hints}") + except Exception: + pass + else: + lines.extend( + [ + "", + "To add support for this type:", + f" @{self._name}.impl({type_name_for_suggestion})", + f" def _{self._name}_{type_name_for_func}(instance: {instance_type_hint}, ...) -> ...:", + " # Your implementation here", + ] + ) + + return "\n".join(lines) + + +def dispatch(func: _SignatureType) -> _Dispatch: + """Create a new dispatch function from a signature. + + Args: + func: Function defining the dispatch signature and default behavior + + Returns: + A dispatch object that can be extended with implementations + + Example: + >>> @dispatch + ... def to_string(instance) -> str: + ... '''Convert instance to string representation.''' + ... + >>> @to_string.impl(int) + ... def _to_string_int(instance: int) -> str: + ... return str(instance) + ... + >>> @to_string.impl(list, tuple) + ... def _to_string_sequence(instance) -> str: + ... return ', '.join(map(str, instance)) + ... + >>> assert to_string(42) == "42" + >>> assert to_string([1, 2, 3]) == "1, 2, 3" + """ + return _Dispatch(func) + + +__all__ = ["dispatch"] diff --git a/flagscale/train/megatron/nemo_bridge/models/decorators/torchrun.py b/flagscale/train/megatron/nemo_bridge/models/decorators/torchrun.py new file mode 100644 index 0000000000..80fa77dcc5 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/decorators/torchrun.py @@ -0,0 +1,42 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import os +import traceback + +from functools import wraps + +import torch + +from torch.distributed.elastic.multiprocessing.errors import record + + +def torchrun_main(fn): + """ + A decorator that wraps the main function of a torchrun script. It uses + the `torch.distributed.elastic.multiprocessing.errors.record` decorator + to record any exceptions and ensures that the distributed process group + is properly destroyed on successful completion. In case of an exception, + it prints the traceback and performs a hard exit, allowing torchrun to + terminate all other processes. + """ + recorded_fn = record(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + try: + return_value = recorded_fn(*args, **kwargs) + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + return return_value + except Exception: + # The 'record' decorator might only log the exception to a file. + # Print it to stderr as well to make sure it's visible. + traceback.print_exc() + # Use os._exit(1) for a hard exit. A regular sys.exit(1) might + # not be enough to terminate a process stuck in a bad C++ state + # (e.g., after a NCCL error), which can cause the job to hang. + os._exit(1) + + return wrapper diff --git a/flagscale/train/megatron/nemo_bridge/models/deepseek/__init__.py b/flagscale/train/megatron/nemo_bridge/models/deepseek/__init__.py new file mode 100644 index 0000000000..f2b27048b5 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/deepseek/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from megatron.nemo_bridge.models.deepseek.deepseek_provider import ( + DeepSeekModelProvider, + DeepSeekProvider, + DeepSeekV2LiteModelProvider, + DeepSeekV2LiteProvider, + DeepSeekV2ModelProvider, + DeepSeekV2Provider, + DeepSeekV3ModelProvider, + DeepSeekV3Provider, + MoonlightModelProvider16B, + MoonlightProvider, +) +from megatron.nemo_bridge.models.deepseek.deepseek_v2_bridge import DeepSeekV2Bridge # noqa: F401 +from megatron.nemo_bridge.models.deepseek.deepseek_v3_bridge import DeepSeekV3Bridge # noqa: F401 + +__all__ = [ + "DeepSeekModelProvider", + "DeepSeekV2LiteModelProvider", + "DeepSeekV2ModelProvider", + "DeepSeekV3ModelProvider", + "MoonlightModelProvider16B", + "DeepSeekProvider", + "DeepSeekV2LiteProvider", + "DeepSeekV2Provider", + "DeepSeekV3Provider", + "MoonlightProvider", +] diff --git a/flagscale/train/megatron/nemo_bridge/models/deepseek/common.py b/flagscale/train/megatron/nemo_bridge/models/deepseek/common.py new file mode 100644 index 0000000000..b8a660c957 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/deepseek/common.py @@ -0,0 +1,137 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from megatron.nemo_bridge.models.conversion.param_mapping import AutoMapping, GatedMLPMapping +from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM + +try: + import apex # noqa: F401 + + HAVE_APEX = True +except ImportError: + HAVE_APEX = False + + +def get_common_configs(hf_pretrained: PreTrainedCausalLM) -> dict: + """ + Returns a dictionary of common configurations for the DeepSeek family of models. + """ + hf_config = hf_pretrained.config + + configs = {} + + if not HAVE_APEX: + configs["gradient_accumulation_fusion"] = False + + if hasattr(hf_config, "rope_scaling") and hf_config.rope_scaling is not None: + configs["rotary_scaling_factor"] = hf_config.rope_scaling["factor"] + configs["mscale"] = hf_config.rope_scaling["mscale"] + configs["mscale_all_dim"] = hf_config.rope_scaling["mscale_all_dim"] + else: + configs["rotary_scaling_factor"] = 1.0 + configs["mscale"] = 1.0 + configs["mscale_all_dim"] = 1.0 + + configs["num_layers"] = hf_config.num_hidden_layers + configs["hidden_size"] = hf_config.hidden_size + configs["ffn_hidden_size"] = hf_config.intermediate_size + configs["num_attention_heads"] = hf_config.num_attention_heads + configs["kv_channels"] = hf_config.num_key_value_heads + configs["q_lora_rank"] = hf_config.q_lora_rank + configs["num_moe_experts"] = hf_config.n_routed_experts + configs["moe_ffn_hidden_size"] = hf_config.moe_intermediate_size + configs["moe_shared_expert_intermediate_size"] = ( + hf_config.moe_intermediate_size * hf_config.n_shared_experts + ) + configs["moe_layer_freq"] = [0] * hf_config.first_k_dense_replace + [1] * ( + hf_config.num_hidden_layers - hf_config.first_k_dense_replace + ) + configs["moe_router_topk"] = hf_config.num_experts_per_tok + configs["moe_router_num_groups"] = hf_config.n_group + configs["moe_router_group_topk"] = hf_config.topk_group + configs["moe_router_topk_scaling_factor"] = hf_config.routed_scaling_factor + configs["kv_lora_rank"] = hf_config.kv_lora_rank + configs["qk_head_dim"] = hf_config.qk_nope_head_dim + configs["qk_pos_emb_head_dim"] = hf_config.qk_rope_head_dim + configs["v_head_dim"] = hf_config.v_head_dim + + # Ensure MLA is enabled + configs["multi_latent_attention"] = True + configs["generation_config"] = 
hf_pretrained.generation_config + configs["vocab_size"] = hf_config.vocab_size + configs["rotary_base"] = hf_config.rope_theta + configs["init_method_std"] = hf_config.initializer_range + configs["layernorm_epsilon"] = hf_config.rms_norm_eps + + return configs + + +def get_common_mapping_list() -> list: + """ + Returns a list of common parameter mappings for the DeepSeek family of models. + """ + param_mappings = { + # Embed + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + # Attention + "decoder.layers.*.input_layernorm.weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + # Reference: https://github.com/NVIDIA/NeMo/blob/50cceb9c90ea1f440d1e14074fa13bd45f60a1c4/nemo/collections/llm/gpt/model/deepseek.py#L637-L650 + # In deepseek, HF weight `model.layers.*.post_attention_layernorm.weight` is mapped to the following mcore weights depending on the layer type: + # (a) `decoder.layers.*.pre_mlp_layernorm.weight`, if the layer is MoE + # (b) `decoder.layers.*.mlp.linear_fc1.layer_norm_weight`, if the layer is dense + "decoder.layers.*.pre_mlp_layernorm.weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.layers.*.self_attention.linear_kv_down_proj.weight": "model.layers.*.self_attn.kv_a_proj_with_mqa.weight", + "decoder.layers.*.self_attention.linear_kv_up_proj.weight": "model.layers.*.self_attn.kv_b_proj.weight", + "decoder.layers.*.self_attention.linear_kv_up_proj.layer_norm_weight": "model.layers.*.self_attn.kv_a_layernorm.weight", + # Mcore local spec + "decoder.layers.*.self_attention.kv_layernorm.weight": "model.layers.*.self_attn.kv_a_layernorm.weight", + # Dense MLP + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + # MoE + "decoder.layers.*.mlp.router.weight": "model.layers.*.mlp.gate.weight", + "decoder.layers.*.mlp.experts.linear_fc2.weight*": "model.layers.*.mlp.experts.*.down_proj.weight", + "decoder.layers.*.mlp.shared_experts.linear_fc2.weight": "model.layers.*.mlp.shared_experts.down_proj.weight", + # LM Head + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + # MLA + "decoder.layers.*.self_attention.linear_q_down_proj.weight": "model.layers.*.self_attn.q_a_proj.weight", + "decoder.layers.*.self_attention.linear_q_up_proj.weight": "model.layers.*.self_attn.q_b_proj.weight", + "decoder.layers.*.self_attention.linear_q_up_proj.layer_norm_weight": "model.layers.*.self_attn.q_a_layernorm.weight", + # Mcore local spec + "decoder.layers.*.self_attention.q_layernorm.weight": "model.layers.*.self_attn.q_a_layernorm.weight", + # For models without MLA + "decoder.layers.*.self_attention.linear_q_proj.weight": "model.layers.*.self_attn.q_proj.weight", + } + + # TODO: mtp layers + + mapping_list = [] + # Convert each dictionary entry to AutoMapping(hf_param, megatron_param) + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + + mapping_list.extend( + [ + GatedMLPMapping( + megatron_param="decoder.layers.*.mlp.linear_fc1.weight", + gate="model.layers.*.mlp.gate_proj.weight", + up="model.layers.*.mlp.up_proj.weight", + ), + GatedMLPMapping( + megatron_param="decoder.layers.*.mlp.experts.linear_fc1.weight*", + gate="model.layers.*.mlp.experts.*.gate_proj.weight", + 
up="model.layers.*.mlp.experts.*.up_proj.weight", + ), + GatedMLPMapping( + megatron_param="decoder.layers.*.mlp.shared_experts.linear_fc1.weight", + gate="model.layers.*.mlp.shared_experts.gate_proj.weight", + up="model.layers.*.mlp.shared_experts.up_proj.weight", + ), + ] + ) + + return mapping_list diff --git a/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_provider.py b/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_provider.py new file mode 100644 index 0000000000..f429df4279 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_provider.py @@ -0,0 +1,309 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import warnings + +from dataclasses import dataclass, field +from functools import partial +from typing import TYPE_CHECKING, Callable, List, Optional, Union + +import torch +import torch.nn.functional as F + +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec + +from megatron.nemo_bridge.models.gpt_provider import GPTModelProvider +from megatron.nemo_bridge.models.transformer_config import MLATransformerConfig +from megatron.nemo_bridge.utils.common_utils import get_rank_safe + +try: + import transformer_engine # type: ignore # noqa: F401 + + HAVE_TE = True +except (ImportError, ModuleNotFoundError): + HAVE_TE = False + +if TYPE_CHECKING: + from megatron.core.transformer import ModuleSpec + +if HAVE_TE: + from megatron.core.utils import is_te_min_version + + +@dataclass +class DeepSeekModelProvider(MLATransformerConfig, GPTModelProvider): + """ + Base config for DeepSeek V2 and V3 models. + """ + + transformer_layer_spec: Union["ModuleSpec", Callable[["GPTModelProvider"], "ModuleSpec"]] = ( + partial(get_gpt_decoder_block_spec, use_transformer_engine=HAVE_TE) + ) + + # Model + normalization: str = "RMSNorm" + activation_func: Callable = F.silu + gated_linear_unit: bool = True # swiglu + position_embedding_type: str = "rope" + add_bias_linear: bool = False + share_embeddings_and_output_weights: bool = False + num_attention_heads: int = 128 + kv_channels: int = 128 + max_position_embeddings: int = 4096 + seq_length: int = 4096 + rotary_base: float = 10000.0 + make_vocab_size_divisible_by: int = 3200 + mtp_num_layers: Optional[int] = None + mtp_loss_scaling_factor: Optional[float] = None + + # Regularization + attention_dropout: float = 0.0 + hidden_dropout: float = 0.0 + qk_layernorm: bool = True + + # MoE + moe_grouped_gemm: bool = True + moe_router_pre_softmax: bool = True + moe_token_dispatcher_type: str = "alltoall" + moe_router_load_balancing_type: str = "seq_aux_loss" + moe_shared_expert_overlap: bool = True + moe_router_dtype: Optional[str] = "fp32" + + # MLA + q_lora_rank: int = 1536 + kv_lora_rank: int = 512 + qk_head_dim: int = 128 + qk_pos_emb_head_dim: int = 64 + v_head_dim: int = 128 + rotary_scaling_factor: float = 40 + mscale: float = 1.0 + mscale_all_dim: float = 1.0 + + # Miscellaneous + init_method_std: float = 0.006 + layernorm_epsilon: float = 1e-6 + bf16: bool = True + params_dtype: torch.dtype = torch.bfloat16 + async_tensor_model_parallel_allreduce: bool = True + attention_softmax_in_fp32: bool = False + persist_layer_norm: bool = True + num_layers_in_first_pipeline_stage: Optional[int] = None + num_layers_in_last_pipeline_stage: Optional[int] = None + account_for_embedding_in_pipeline_split: bool = False + account_for_loss_in_pipeline_split: bool = False + + # MLA specific + multi_latent_attention: bool = 
True + + # fusions + apply_rope_fusion: bool = False + bias_activation_fusion: bool = True + bias_dropout_fusion: bool = True + masked_softmax_fusion: bool = True + cross_entropy_loss_fusion: bool = True + cross_entropy_fusion_impl: str = "te" + moe_permute_fusion: bool = is_te_min_version("2.1.0") if HAVE_TE else False + + +@dataclass +class DeepSeekV2ModelProvider(DeepSeekModelProvider): + """ + DeepSeek-V2 Model: https://github.com/deepseek-ai/DeepSeek-V2 + """ + + num_layers: int = 60 + hidden_size: int = 5120 + ffn_hidden_size: int = 12288 + num_moe_experts: int = 160 + moe_ffn_hidden_size: int = 1536 + moe_shared_expert_intermediate_size: int = 3072 # 1536 * 2 shared experts + moe_layer_freq: Union[int, List[int]] = field( + default_factory=lambda: [0] + [1] * 59 + ) # first layer is dense + moe_router_topk: int = 6 + moe_router_num_groups: int = 8 + moe_router_group_topk: int = 3 + moe_router_topk_scaling_factor: float = 16.0 + moe_aux_loss_coeff: float = 1e-3 + mscale: float = 0.707 + mscale_all_dim: float = 0.707 + vocab_size: int = 102400 + + +@dataclass +class DeepSeekV2LiteModelProvider(DeepSeekV2ModelProvider): + """ + DeepSeek-V2-Lite Model: https://github.com/deepseek-ai/DeepSeek-V2 + HuggingFace: https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite + """ + + num_layers: int = 27 + hidden_size: int = 2048 + ffn_hidden_size: int = 10944 + num_attention_heads: int = 16 + kv_channels: int = 16 + q_lora_rank: int = None + num_moe_experts: int = 64 + moe_ffn_hidden_size: int = 1408 + moe_shared_expert_intermediate_size: int = 2816 # 1408 * 2 shared experts + moe_layer_freq: Union[int, List[int]] = field( + default_factory=lambda: [0] + [1] * 26 + ) # first layer is dense + moe_router_topk: int = 6 + moe_router_num_groups: int = 1 + moe_router_group_topk: int = 1 + moe_router_topk_scaling_factor: float = 1.0 + vocab_size: int = 102400 + + +@dataclass +class DeepSeekV3ModelProvider(DeepSeekModelProvider): + """ + DeepSeek-V3 Model: https://github.com/deepseek-ai/DeepSeek-V3 + """ + + num_layers: int = 61 + hidden_size: int = 7168 + ffn_hidden_size: int = 18432 + num_moe_experts: int = 256 + moe_ffn_hidden_size: int = 2048 + moe_shared_expert_intermediate_size: int = 2048 # 2048 * 1 shared expert + moe_layer_freq: Union[int, List[int]] = field( + default_factory=lambda: [0] * 3 + [1] * 58 + ) # first three layers are dense + moe_router_topk: int = 8 + moe_router_num_groups: int = 8 + moe_router_group_topk: int = 4 + moe_router_topk_scaling_factor: float = 2.5 + make_vocab_size_divisible_by: int = 1280 + moe_router_score_function: str = "sigmoid" + moe_router_enable_expert_bias: bool = True + moe_router_bias_update_rate: float = 1e-3 + mscale: float = 1.0 + mscale_all_dim: float = 1.0 + vocab_size: int = 129280 + + +@dataclass +class MoonlightModelProvider16B(DeepSeekModelProvider): + """ + Moonlight-16B-A3B Model: https://github.com/moonshotai/Moonlight-16B-A3B + + Moonlight is based on DeepSeek-V3. 
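# Illustrative sketch: the moe_layer_freq defaults in the providers above encode the same
# "first k layers dense, remaining layers MoE" layout that get_common_configs() derives from
# the HF fields num_hidden_layers and first_k_dense_replace. The helper name below is
# hypothetical and exists only to make the convention explicit (0 = dense layer, 1 = MoE layer).
def _moe_layer_freq_from_hf(num_hidden_layers: int, first_k_dense_replace: int) -> list:
    return [0] * first_k_dense_replace + [1] * (num_hidden_layers - first_k_dense_replace)

assert _moe_layer_freq_from_hf(61, 3) == [0] * 3 + [1] * 58  # DeepSeek-V3 default above
assert _moe_layer_freq_from_hf(27, 1) == [0] * 1 + [1] * 26  # V2-Lite / Moonlight default above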
+ """ + + max_position_embeddings: int = 4096 + num_layers: int = 27 + hidden_size: int = 2048 + ffn_hidden_size: int = 11264 + num_attention_heads: int = 16 + kv_channels: int = 16 + num_moe_experts: int = 64 + moe_ffn_hidden_size: int = 1408 + moe_shared_expert_intermediate_size: int = 2816 # 1408 * 2 shared expert + moe_layer_freq: Union[int, List[int]] = field( + default_factory=lambda: [0] * 1 + [1] * 26 + ) # first layer is dense + moe_router_topk: int = 6 + moe_router_num_groups: int = 1 + moe_router_group_topk: int = 1 + moe_router_topk_scaling_factor: float = 2.446 + moe_aux_loss_coeff: float = 0.001 + make_vocab_size_divisible_by: int = 1280 + moe_router_score_function: str = "sigmoid" + moe_router_enable_expert_bias: bool = True + rotary_scaling_factor: float = 1.0 + mscale: float = 1.0 + mscale_all_dim: float = 1.0 + rotary_base: float = 50000 + layernorm_epsilon: float = 1e-5 + q_lora_rank: int = None + init_method_std: float = 0.02 + moe_router_bias_update_rate: float = 1e-3 + rotary_percent: float = 1.0 + vocab_size: int = 163840 + + +# ----------------------------------------------------------------------------- +# Deprecated aliases (to be removed in a future release) +# ----------------------------------------------------------------------------- + + +def _warn_deprecated(old_cls: str, new_cls: str) -> None: + if get_rank_safe() == 0: + warnings.warn( + f"{old_cls} is deprecated and will be removed in a future release. Use {new_cls} instead.", + DeprecationWarning, + stacklevel=2, + ) + + +@dataclass +class DeepSeekProvider(DeepSeekModelProvider): + """Deprecated alias for ``DeepSeekModelProvider``. + + Deprecated: + This alias remains for backward compatibility and will be removed in a + future release. Import and use ``DeepSeekModelProvider`` instead. + """ + + def __post_init__(self) -> None: + _warn_deprecated("DeepSeekProvider", "DeepSeekModelProvider") + super().__post_init__() + + +@dataclass +class DeepSeekV2Provider(DeepSeekV2ModelProvider): + """Deprecated alias for ``DeepSeekV2ModelProvider``. + + Deprecated: + This alias remains for backward compatibility and will be removed in a + future release. Import and use ``DeepSeekV2ModelProvider`` instead. + """ + + def __post_init__(self) -> None: + _warn_deprecated("DeepSeekV2Provider", "DeepSeekV2ModelProvider") + super().__post_init__() + + +@dataclass +class DeepSeekV2LiteProvider(DeepSeekV2LiteModelProvider): + """Deprecated alias for ``DeepSeekV2LiteModelProvider``. + + Deprecated: + This alias remains for backward compatibility and will be removed in a + future release. Import and use ``DeepSeekV2LiteModelProvider`` instead. + """ + + def __post_init__(self) -> None: + _warn_deprecated("DeepSeekV2LiteProvider", "DeepSeekV2LiteModelProvider") + super().__post_init__() + + +@dataclass +class DeepSeekV3Provider(DeepSeekV3ModelProvider): + """Deprecated alias for ``DeepSeekV3ModelProvider``. + + Deprecated: + This alias remains for backward compatibility and will be removed in a + future release. Import and use ``DeepSeekV3ModelProvider`` instead. + """ + + def __post_init__(self) -> None: + _warn_deprecated("DeepSeekV3Provider", "DeepSeekV3ModelProvider") + super().__post_init__() + + +@dataclass +class MoonlightProvider(MoonlightModelProvider16B): + """Deprecated alias for ``MoonlightModelProvider16B``. + + Deprecated: + This alias remains for backward compatibility and will be removed in a + future release. Import and use ``MoonlightModelProvider16B`` instead. 
+ """ + + def __post_init__(self) -> None: + _warn_deprecated("MoonlightProvider", "MoonlightModelProvider16B") + super().__post_init__() diff --git a/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v2_bridge.py b/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v2_bridge.py new file mode 100644 index 0000000000..75b72847b7 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v2_bridge.py @@ -0,0 +1,48 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import torch + +from megatron.core.models.gpt.gpt_model import GPTModel + +from megatron.nemo_bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.nemo_bridge.models.deepseek.common import ( + get_common_configs, + get_common_mapping_list, +) +from megatron.nemo_bridge.models.deepseek.deepseek_provider import DeepSeekV2ModelProvider +from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM + + +@MegatronModelBridge.register_bridge(source="DeepseekV2ForCausalLM", target=GPTModel) +class DeepSeekV2Bridge(MegatronModelBridge): + """ + Megatron Bridge for DeepSeek-V2. + + As a user you would not use this bridge directly, but through `AutoBridge`. + + Example: + >>> from megatron.nemo_bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("deepseek-ai/DeepSeek-V2", trust_remote_code=True) + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> DeepSeekV2ModelProvider: + hf_config = hf_pretrained.config + configs = get_common_configs(hf_pretrained) + + configs["fp16"] = self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16 + configs["bf16"] = self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16 + configs["params_dtype"] = self.dtype_from_hf(hf_config, default=torch.float32) + + configs["make_vocab_size_divisible_by"] = 3200 + configs["moe_aux_loss_coeff"] = hf_config.aux_loss_alpha + + provider = DeepSeekV2ModelProvider(**configs) + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + mapping_list = get_common_mapping_list() + return MegatronMappingRegistry(*mapping_list) diff --git a/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v3_bridge.py b/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v3_bridge.py new file mode 100644 index 0000000000..7c19cfb0ab --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v3_bridge.py @@ -0,0 +1,64 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import torch + +from megatron.core.models.gpt.gpt_model import GPTModel + +from megatron.nemo_bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.nemo_bridge.models.conversion.param_mapping import AutoMapping +from megatron.nemo_bridge.models.deepseek.common import ( + get_common_configs, + get_common_mapping_list, +) +from megatron.nemo_bridge.models.deepseek.deepseek_provider import DeepSeekV3ModelProvider +from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM + + +@MegatronModelBridge.register_bridge(source="DeepseekV3ForCausalLM", target=GPTModel) +class DeepSeekV3Bridge(MegatronModelBridge): + """ + Megatron Bridge for DeepSeek-V3. + + As a user you would not use this bridge directly, but through `AutoBridge`. + + Example: + >>> from megatron.nemo_bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("deepseek-ai/DeepSeek-V3-Base", trust_remote_code=True) + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> DeepSeekV3ModelProvider: + hf_config = hf_pretrained.config + configs = get_common_configs(hf_pretrained) + + configs["fp16"] = self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16 + configs["bf16"] = self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16 + configs["params_dtype"] = self.dtype_from_hf(hf_config, default=torch.float32) + + configs["make_vocab_size_divisible_by"] = 1280 + configs["moe_router_score_function"] = "sigmoid" + configs["moe_router_enable_expert_bias"] = True + # aux_loss_alpha is not set in all DSv3 HF configs + if hasattr(hf_config, "aux_loss_alpha"): + configs["moe_aux_loss_coeff"] = hf_config.aux_loss_alpha + + # TODO: mtp + + provider = DeepSeekV3ModelProvider(**configs) + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + mapping_list = get_common_mapping_list() + + param_mappings = { + # expert bias + "decoder.layers.*.mlp.router.expert_bias": "model.layers.*.mlp.gate.e_score_correction_bias" + } + + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + + return MegatronMappingRegistry(*mapping_list) diff --git a/flagscale/train/megatron/nemo_bridge/models/gpt_full_te_layer_autocast_spec.py b/flagscale/train/megatron/nemo_bridge/models/gpt_full_te_layer_autocast_spec.py new file mode 100644 index 0000000000..7409349ce5 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/gpt_full_te_layer_autocast_spec.py @@ -0,0 +1,347 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from importlib.metadata import version +from typing import Any, Callable, Optional, Union + +import packaging +import torch + +from transformer_engine.pytorch import TransformerLayer + +from megatron.core import parallel_state, tensor_parallel +from megatron.core.fusions.fused_layer_norm import FusedLayerNorm +from megatron.core.transformer.cuda_graphs import CudaGraphManager +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_block import ( + TransformerBlockSubmodules, + get_num_layers_to_build, +) +from megatron.core.transformer.transformer_layer import BaseTransformerLayer +from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint + + +# Copied from nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py +class AutocastTransformerLayer(TransformerLayer): + """ + Wrapper of te.pytorch.TransformerLayer: a single transformerlayer + that takes input with size [s, b, h] and returns an output of + the same size. + """ + + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + layernorm_epsilon: float, + num_attention_heads: int, + init_method: Callable, + output_layer_init_method: Callable, + hidden_dropout: float, + attention_dropout: float, + layer_number: Optional[int] = None, + kv_channels: Optional[int] = None, + self_attn_mask_type: str = "causal", + tp_group: Optional[Any] = None, + tp_size: int = 1, + params_dtype: torch.dtype = torch.float32, + get_rng_state_tracker: Optional[Callable] = None, + fuse_wgrad_accumulation: bool = False, + seq_length: Optional[int] = None, + micro_batch_size: Optional[int] = None, + sequence_parallel: bool = False, + apply_residual_connection_post_layernorm: bool = False, + output_layernorm: bool = False, + layer_type: str = "encoder", + drop_path_rate: float = 0, + use_emha: bool = False, + ub_tp_comm_overlap: bool = False, + ub_bulk_wgrad: bool = True, + ub_bulk_dgrad: bool = True, + autocast_dtype: Any = 16, + zero_centered_gamma: bool = False, + device: str = "cuda", + **kwargs, + ) -> None: + transformer_layer_args = { + "hidden_size": hidden_size, + "ffn_hidden_size": ffn_hidden_size, + "layernorm_epsilon": layernorm_epsilon, + "num_attention_heads": num_attention_heads, + "init_method": init_method, + "output_layer_init_method": output_layer_init_method, + "hidden_dropout": hidden_dropout, + "attention_dropout": attention_dropout, + "layer_number": layer_number, + "kv_channels": kv_channels, + "self_attn_mask_type": self_attn_mask_type, + "tp_group": tp_group, + "tp_size": tp_size, + "params_dtype": params_dtype, + "get_rng_state_tracker": get_rng_state_tracker, + "fuse_wgrad_accumulation": fuse_wgrad_accumulation, + "seq_length": seq_length, + "micro_batch_size": micro_batch_size, + "sequence_parallel": sequence_parallel, + "apply_residual_connection_post_layernorm": apply_residual_connection_post_layernorm, + "output_layernorm": output_layernorm, + "layer_type": layer_type, + "drop_path_rate": drop_path_rate, + "set_parallel_mode": tp_size > 1, + "fuse_qkv_params": True, + "zero_centered_gamma": zero_centered_gamma, + "ub_tp_comm_overlap": ub_tp_comm_overlap, + "ub_bulk_wgrad": ub_bulk_wgrad, + "ub_bulk_dgrad": ub_bulk_dgrad, + "device": device, + } + te_version = packaging.version.Version(version("transformer-engine")) + if te_version > packaging.version.Version("1.5.0"): + for comm in ["ag", 
"rs"]: + ub_overlap_flag = "ub_overlap_" + comm + split_gemm_flag = "ub_split_" + comm + atomic_gemm_flag = "ub_atomic_gemm_" + comm + # Use old overlap flags if they were supplied instead + if ub_overlap_flag in kwargs: + transformer_layer_args[ub_overlap_flag] = kwargs[ub_overlap_flag] + else: + transformer_layer_args[ub_overlap_flag] = kwargs.get( + split_gemm_flag, True + ) or kwargs.get(atomic_gemm_flag, False) + if te_version > packaging.version.Version("1.6.0.dev0"): + transformer_layer_args["ub_overlap_rs_dgrad"] = kwargs.get( + "ub_overlap_rs_dgrad", False + ) + else: + transformer_layer_args["ub_split_ag"] = kwargs.get("ub_split_ag", True) + transformer_layer_args["ub_split_rs"] = kwargs.get("ub_split_rs", True) + transformer_layer_args["ub_atomic_gemm_ag"] = kwargs.get("ub_atomic_gemm_ag", False) + transformer_layer_args["ub_atomic_gemm_rs"] = kwargs.get("ub_atomic_gemm_rs", False) + super().__init__(**transformer_layer_args) + + # Dtype for forward pass + self.dtype = torch_dtype_from_precision(autocast_dtype) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor = None, + encoder_output: Optional[torch.Tensor] = None, + enc_dec_attn_mask: Optional[torch.Tensor] = None, + inference_params: Optional[Any] = None, + is_first_microbatch: Optional[bool] = None, + checkpoint_core_attention: Optional[bool] = False, + ) -> torch.Tensor: + """ + Perform a forward pass through the transformer layer. + """ + if self.dtype == torch.float32: + return super().forward( + hidden_states, + attention_mask, + encoder_output=encoder_output, + enc_dec_attn_mask=enc_dec_attn_mask, + inference_params=inference_params, + is_first_microbatch=is_first_microbatch, + checkpoint_core_attention=checkpoint_core_attention, + ) + with torch.autocast(device_type="cuda", dtype=self.dtype): + return super().forward( + hidden_states, + attention_mask=attention_mask, + encoder_output=encoder_output, + enc_dec_attn_mask=enc_dec_attn_mask, + inference_params=inference_params, + is_first_microbatch=is_first_microbatch, + checkpoint_core_attention=checkpoint_core_attention, + ) + + +class TETransformerLayerAutocast(MegatronModule, BaseTransformerLayer): # type: ignore + """ + A MegatronModule that wraps the AutocastTransformerLayer. 
+ """ + + def __init__(self, config, layer_number=1, hidden_dropout=None, **kwargs): + super().__init__(config=config) + self.layer_number = layer_number + self._get_layer_offset() + + self.config = config + self.is_first_microbatch = True + precision = "bf16" if config.bf16 else 16 + + transformer_layer_args = { + "hidden_size": config.hidden_size, + "ffn_hidden_size": config.ffn_hidden_size, + "layernorm_epsilon": config.layernorm_epsilon, + "num_attention_heads": config.num_attention_heads, + "init_method": config.init_method, + "output_layer_init_method": config.output_layer_init_method, + "hidden_dropout": config.hidden_dropout, + "attention_dropout": config.attention_dropout, + "layer_number": layer_number + self._get_layer_offset(), + "kv_channels": config.kv_channels, + "tp_size": parallel_state.get_tensor_model_parallel_world_size(), + "params_dtype": config.params_dtype, + "get_rng_state_tracker": tensor_parallel.random.get_cuda_rng_tracker, + "fuse_wgrad_accumulation": config.gradient_accumulation_fusion, + "seq_length": None, # used for jit warmup + "micro_batch_size": None, # used for jit warmup + "sequence_parallel": config.sequence_parallel, + "apply_residual_connection_post_layernorm": config.apply_residual_connection_post_layernorm, + "autocast_dtype": precision, + "ub_tp_comm_overlap": config.tp_comm_overlap, + "ub_bulk_wgrad": config.tp_comm_bulk_wgrad, + "ub_bulk_dgrad": config.tp_comm_bulk_dgrad, + "zero_centered_gamma": config.layernorm_zero_centered_gamma, + "device": "cpu" if config.use_cpu_initialization else "cuda", + } + te_version = packaging.version.Version(version("transformer-engine")) + if te_version > packaging.version.Version("1.5.0"): + # Use old overlap flags if they were supplied instead + transformer_layer_args["ub_overlap_ag"] = ( + config.tp_comm_overlap_ag + if hasattr(config, "tp_comm_overlap_ag") + else config.tp_comm_split_ag or config.tp_comm_atomic_ag + ) + transformer_layer_args["ub_overlap_rs"] = ( + config.tp_comm_overlap_rs + if hasattr(config, "tp_comm_overlap_rs") + else config.tp_comm_split_rs or config.tp_comm_atomic_rs + ) + if te_version > packaging.version.Version("1.6.0.dev0"): + transformer_layer_args["ub_overlap_rs_dgrad"] = ( + config.tp_comm_overlap_rs_dgrad + if hasattr(config, "tp_comm_overlap_rs_dgrad") + else False + ) + else: + transformer_layer_args["ub_split_ag"] = config.tp_comm_split_ag + transformer_layer_args["ub_split_rs"] = config.tp_comm_split_rs + transformer_layer_args["ub_atomic_gemm_ag"] = config.tp_comm_atomic_ag + transformer_layer_args["ub_atomic_gemm_rs"] = config.tp_comm_atomic_rs + self.transformer_layer = AutocastTransformerLayer(**transformer_layer_args) + + if self.config.enable_cuda_graph and self.training: + assert ( + not config.cpu_offloading and config.recompute_granularity is None + ), "Cudagraphs not supported" + self.add_module("cudagraph_manager", CudaGraphManager(config)) + + # Called by MCore's TransformerBlock.forward + # megatron/core/transformer/transformer_block.py + def forward( + self, + hidden_states, + is_first_microbatch=None, + attention_mask=None, + context=None, + context_mask=None, + inference_params=None, + **kwargs, + ): + """Forward function of TETransformerLayerAutocast. Called by MCore's TransformerBlock.forward.""" + # Use is_first_microbatch argument during CUDA graph capture. Use self.is_first_microbatch otherwise. 
+ hidden_states = self.transformer_layer.forward( + hidden_states, + attention_mask=attention_mask, + encoder_output=context, + enc_dec_attn_mask=context_mask, + inference_params=inference_params, + is_first_microbatch=( + is_first_microbatch if is_first_microbatch is not None else self.is_first_microbatch + ), + # checkpoint_core_attention, + ) + self.is_first_microbatch = False + context = None + + # External CUDA graph requires returned values to be Tensors + if ( + hasattr(self.config, "external_cuda_graph") + and self.config.external_cuda_graph + and self.training + ): + return hidden_states + return hidden_states, context + + def _get_layer_offset(self): + pipeline_rank = parallel_state.get_pipeline_model_parallel_rank() + + num_layers_per_pipeline_rank = ( + self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size() + ) + + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank() + vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() + + total_num_layers = self.config.num_layers + num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size + total_virtual_chunks = total_num_layers // vp_size + offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank) + + else: + # Each stage gets a contiguous set of layers. + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + offset = pipeline_rank * num_layers_per_pipeline_rank + else: + offset = 0 + + return offset + + def sharded_state_dict(self, prefix: str = "", sharded_offsets: tuple = (), metadata=None): + """Get the sharded state dict for the transformer layer.""" + TENSOR_PARALLEL_LAYERS_AXIS_MAP = { + "self_attention.layernorm_qkv.weight": 0, + "self_attention.layernorm_qkv.bias": 0, + "self_attention.proj.weight": 1, + "layernorm_mlp.fc1_weight": 0, + "layernorm_mlp.fc1_bias": 0, + "layernorm_mlp.fc2_weight": 1, + } + + state_dict = self.state_dict(prefix="", keep_vars=True) + sharded_state_dict = make_sharded_tensors_for_checkpoint( + state_dict, prefix, TENSOR_PARALLEL_LAYERS_AXIS_MAP, sharded_offsets + ) + + # TODO: we need to add sharded_state_dict_keys_map to the config. 
Like in TransformerLayer submodules config + # prefixed_map = { + # f'{prefix}{k}': f'{prefix}{v}' + # for k, v in self.config.sharded_state_dict_keys_map.items() + # } + + # if prefixed_map: + # apply_prefix_mapping(sharded_state_dict, prefixed_map) + + return sharded_state_dict + + def __call__(self, *args, **kwargs): + if hasattr(self, "cudagraph_manager"): + return self.cudagraph_manager(self, args, kwargs) + return super().__call__(*args, **kwargs) + + +# Use this spec to use the full Transformer layer from Transformer Engine +def get_gpt_full_te_layer_autocast_spec(transformer_config) -> ModuleSpec: + """Get the ModuleSpec for full Transformer layer from Transformer Engine.""" + num_layers = get_num_layers_to_build(transformer_config) + return TransformerBlockSubmodules( + layer_specs=[ModuleSpec(module=TETransformerLayerAutocast)] * num_layers, + layer_norm=FusedLayerNorm, + ) + + +def torch_dtype_from_precision(precision: Union[int, str]) -> torch.dtype: + """Mapping from precision types to corresponding PyTorch parameter datatype.""" + if precision in ("bf16", "bf16-mixed"): + return torch.bfloat16 + elif precision in (16, "16", "16-mixed"): + return torch.float16 + elif precision in (32, "32", "32-true"): + return torch.float32 + else: + raise ValueError(f"Could not parse the precision of `{precision}` to a valid torch.dtype") diff --git a/flagscale/train/megatron/nemo_bridge/models/gpt_provider.py b/flagscale/train/megatron/nemo_bridge/models/gpt_provider.py new file mode 100644 index 0000000000..322c661097 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/gpt_provider.py @@ -0,0 +1,430 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import contextlib +import inspect +import logging + +from dataclasses import dataclass, field +from functools import partial +from typing import Any, Callable, Literal, Optional, Union + +import torch + +from megatron.core import parallel_state +from megatron.core.models.gpt import GPTModel as MCoreGPTModel +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) +from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec +from megatron.core.transformer import ModuleSpec + +from megatron.nemo_bridge.models.model_provider import ModelProviderMixin +from megatron.nemo_bridge.models.transformer_config import TransformerConfig +from megatron.nemo_bridge.utils import fusions +from megatron.nemo_bridge.utils.vocab_utils import calculate_padded_vocab_size + +logger = logging.getLogger(__name__) + + +def transformer_engine_layer_spec(config: "GPTModelProvider") -> ModuleSpec: + """Create a Transformer Engine layer specification based on the provided config.""" + if ( + "use_te_op_fuser" + in inspect.signature(get_gpt_layer_with_transformer_engine_spec).parameters + ): + kwargs = {"use_te_op_fuser": config.use_transformer_engine_op_fuser} + else: + kwargs = {} + return get_gpt_layer_with_transformer_engine_spec( + num_experts=config.num_moe_experts, + moe_grouped_gemm=config.moe_grouped_gemm, + qk_layernorm=config.qk_layernorm, + fp8=bool(config.num_moe_experts and (config.fp8 is not None)), + **kwargs, + ) + + +def transformer_engine_full_layer_spec(config: "GPTModelProvider") -> ModuleSpec: + """Create a full Transformer Engine layer specification with autocast support. 
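# Sketch of the signature-based feature detection used by transformer_engine_layer_spec
# above (and again for vp_stage / mtp_block_spec handling later in this file): an optional
# keyword is forwarded only when the installed Megatron-Core accepts it. Helper name is
# hypothetical.
import inspect

def _supported_kwargs(fn, **candidates):
    params = inspect.signature(fn).parameters
    return {name: value for name, value in candidates.items() if name in params}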
+ + Args: + config: GPT configuration object + + Returns: + ModuleSpec: Module specification for full TE layers + """ + from megatron.nemo_bridge.models.gpt_full_te_layer_autocast_spec import ( + get_gpt_full_te_layer_autocast_spec, + ) + + return get_gpt_full_te_layer_autocast_spec(transformer_config=config) + + +def local_layer_spec(config: "GPTModelProvider") -> ModuleSpec: + """Create a local layer specification without Transformer Engine. + + Args: + config: GPT configuration object + + Returns: + ModuleSpec: Module specification for local implementation layers + """ + return get_gpt_layer_local_spec( + num_experts=config.num_moe_experts, + moe_grouped_gemm=config.moe_grouped_gemm, + qk_layernorm=config.qk_layernorm, + normalization=config.normalization, + ) + + +def quantization_layer_spec(config: "GPTModelProvider") -> ModuleSpec: + """Layer specification for quantization with ModelOpt.""" + return get_gpt_modelopt_spec( + config=config, + local_core_attention=False, + remap_te_layernorm=True, + real_quant_cfg="None", + use_arbitrary_attention_mask=True, + ) + + +def default_layer_spec(config: "GPTModelProvider") -> ModuleSpec: + """Determine the most appropriate layer specification based on availability.""" + if config.restore_modelopt_state: + return quantization_layer_spec(config) + elif config.use_transformer_engine_full_layer_spec: + return transformer_engine_full_layer_spec(config) + else: + return transformer_engine_layer_spec(config) + + +@dataclass +class GPTModelProvider(TransformerConfig, ModelProviderMixin[MCoreGPTModel]): + """Configuration and provider for Megatron Core GPT models. + + This class extends TransformerConfig with GPT-specific parameters and + provides a method to instantiate configured GPT models. + """ + + # Model configuration + fp16_lm_cross_entropy: bool = False + parallel_output: bool = True + share_embeddings_and_output_weights: bool = True + make_vocab_size_divisible_by: int = 128 + position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute" + rotary_base: int = 10000 + rotary_percent: float = 1.0 + seq_len_interpolation_factor: Optional[float] = None + seq_length: int = 1024 + attention_softmax_in_fp32: bool = False + deallocate_pipeline_outputs: bool = True + scatter_embedding_sequence_parallel: bool = True + tp_only_amax_red: bool = False + tp_comm_overlap_cfg: Optional[Union[str, dict[str, Any]]] = None + """Config file when tp_comm_overlap is enabled.""" + + use_transformer_engine_full_layer_spec: bool = False + use_transformer_engine_op_fuser: bool = False + transformer_layer_spec: Union[ModuleSpec, Callable[["GPTModelProvider"], ModuleSpec]] = ( + default_layer_spec + ) + + generation_config: Optional[Any] = None + + # This represents the unpadded vocab size + # The padded vocab size is automatically calculated in the provide() method. + vocab_size: Optional[int] = None + # Set if the tokenizer provides the vocab size. 
In this case, the vocab size will be padded + # Controls whether vocab size should be padded for tensor parallelism + should_pad_vocab: bool = False + + # MoE / FP8 + num_moe_experts: Optional[int] = None + moe_grouped_gemm: bool = False + qk_layernorm: bool = False + fp8: Optional[str] = None + normalization: str = "LayerNorm" + + # Multi-token prediction + mtp_enabled: bool = False + + # Additional parameters that might be needed + init_model_with_meta_device: bool = False + use_te_rng_tracker: bool = False + enable_cuda_graph: bool = False + virtual_pipeline_model_parallel_size: Optional[int] = None + account_for_embedding_in_pipeline_split: bool = False + account_for_loss_in_pipeline_split: bool = False + + # Fusions + masked_softmax_fusion: bool = field(default_factory=fusions.can_enable_masked_softmax_fusion) + cross_entropy_loss_fusion: bool = True # Generally beneficial, no specific dependencies + gradient_accumulation_fusion: bool = field( + default_factory=fusions.can_enable_gradient_accumulation_fusion + ) + bias_activation_fusion: bool = ( + False # Disabled by default as it can interfere with certain architectures + ) + persist_layer_norm: bool = False + bias_dropout_fusion: bool = field(default_factory=fusions.can_enable_bias_dropout_fusion) + apply_rope_fusion: bool = field(default_factory=fusions.can_enable_apply_rope_fusion) + + # If True, restore the modelopt_state that contains quantization, sparsity, speculative decoding transformation state. + # When resuming modelopt_state, we also change the transformer_layer_spec to `megatron.core.post_training.modelopt.gpt.model_specs` which is a combination of local spec + TEDotProductAttention. + + restore_modelopt_state: bool = False + + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel: + """Configure and instantiate a Megatron Core GPT model based on this configuration. + + Args: + pre_process: Whether to include pre-processing in the model, defaults to first pipeline stage + post_process: Whether to include post-processing in the model, defaults to last pipeline stage + vp_stage: Virtual pipeline stage + + Returns: + MCoreGPTModel: Configured Megatron Core GPT model instance + """ + # Validate fusion configurations + if not fusions.validate_rope_fusion_compatibility(self): + self.apply_rope_fusion = False + + if self.enable_cuda_graph: + assert getattr(self, "use_te_rng_tracker", False), ( + "Transformer engine's RNG tracker is required for cudagraphs, it can be " + "enabled with use_te_rng_tracker=True'." + ) + + vp_size = self.virtual_pipeline_model_parallel_size + is_pipeline_asymmetric = getattr( + self, "account_for_embedding_in_pipeline_split", False + ) or getattr(self, "account_for_loss_in_pipeline_split", False) + is_pipeline_asymmetric |= ( + getattr(self, "num_layers_in_first_pipeline_stage", None) + or getattr(self, "num_layers_in_last_pipeline_stage", None) + ) is not None + is_flexible_pp_layout = is_pipeline_asymmetric or ( + getattr(self, "pipeline_model_parallel_layout", None) is not None + ) + if vp_size and not is_flexible_pp_layout: + p_size = self.pipeline_model_parallel_size + assert ( + self.num_layers // p_size + ) % vp_size == 0, ( + "Make sure the number of model chunks is the same across all pipeline stages." 
+ ) + + transformer_layer_spec = self.transformer_layer_spec + if not isinstance(transformer_layer_spec, ModuleSpec): + # Check if the transformer_layer_spec function accepts vp_stage parameter + if "vp_stage" in inspect.signature(transformer_layer_spec).parameters: + transformer_layer_spec = transformer_layer_spec(self, vp_stage=vp_stage) + else: + transformer_layer_spec = transformer_layer_spec(self) + + assert self.vocab_size is not None, "vocab_size must be configured before calling provide()" + if self.should_pad_vocab: + padded_vocab_size = calculate_padded_vocab_size( + self.vocab_size, self.make_vocab_size_divisible_by, self.tensor_model_parallel_size + ) + else: + padded_vocab_size = self.vocab_size + + # Initialize model as meta data instead of allocating data on a device + model_init_device_context = contextlib.nullcontext + if self.init_model_with_meta_device: + model_init_device_context = partial(torch.device, device="meta") + + # Check if mtp_block_spec parameter is supported + kwargs = {} + if "mtp_block_spec" in inspect.signature(MCoreGPTModel.__init__).parameters: + kwargs["mtp_block_spec"] = mtp_block_spec(self, vp_stage=vp_stage) + + with model_init_device_context(): + model = MCoreGPTModel( + self, + transformer_layer_spec=transformer_layer_spec, + vocab_size=padded_vocab_size, + max_sequence_length=self.seq_length, + fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, + parallel_output=self.parallel_output, + share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, + position_embedding_type=self.position_embedding_type, + rotary_percent=self.rotary_percent, + rotary_base=self.rotary_base, + seq_len_interpolation_factor=self.seq_len_interpolation_factor, + pre_process=pre_process + or parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage), + post_process=post_process + or parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage), + scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel, + vp_stage=vp_stage, + **kwargs, + ) + + # If using full TE layer, need to set TP, CP group since the module call + # is not routed through megatron core, which normally handles passing the + # TP, CP group to the TE modules. + # Deep iterate but skip self to avoid infinite recursion. + if self.use_transformer_engine_full_layer_spec: + # Copied from: + # https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py + if parallel_state.get_tensor_model_parallel_world_size() > 1: + for index, child in enumerate(model.modules()): + if index == 0: + continue + if hasattr(child, "set_tensor_parallel_group"): + tp_group = parallel_state.get_tensor_model_parallel_group() + child.set_tensor_parallel_group(tp_group) + + if parallel_state.get_context_parallel_world_size() > 1: + cp_stream = torch.cuda.Stream() + for index, child in enumerate(model.modules()): + if index == 0: + continue + if hasattr(child, "set_context_parallel_group"): + child.set_context_parallel_group( + parallel_state.get_context_parallel_group(), + parallel_state.get_context_parallel_global_ranks(), + cp_stream, + ) + + return model + + +def mtp_block_spec( + config: "GPTModelProvider", vp_stage: Optional[int] = None +) -> Optional[ModuleSpec]: + """Pass in the MTP block spec if model has MTP layers. 
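# Illustrative sketch of the vocab padding performed in provide() above. The real helper is
# calculate_padded_vocab_size from utils.vocab_utils; the stand-in below only assumes the
# usual Megatron rule of rounding up to a multiple of
# make_vocab_size_divisible_by * tensor_model_parallel_size.
def _padded_vocab_size(vocab_size: int, make_vocab_size_divisible_by: int, tp_size: int) -> int:
    multiple = make_vocab_size_divisible_by * tp_size
    return ((vocab_size + multiple - 1) // multiple) * multiple

assert _padded_vocab_size(32000, 128, 8) == 32768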
+ + Args: + config: GPT configuration object + + Returns: + ModuleSpec: The MTP module specification + """ + if getattr(config, "mtp_num_layers", None): + from megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec + + if isinstance(config.transformer_layer_spec, Callable): + if "vp_stage" in inspect.signature(config.transformer_layer_spec).parameters: + spec = config.transformer_layer_spec(config, vp_stage=vp_stage) + else: + spec = config.transformer_layer_spec(config) + else: + spec = config.transformer_layer_spec + if hasattr(spec, "layer_specs") and len(spec.layer_specs) == 0: + # Get the decoder layer spec explicitly if no decoder layer in the last stage, + # Only happens with block spec (TransformerBlockSubmodules) when using MoE. + spec = default_layer_spec(config) + return get_gpt_mtp_block_spec(config, spec, use_transformer_engine=True, vp_stage=vp_stage) + else: + return None + + +@dataclass +class GPTProvider126M(GPTModelProvider): + """Configuration for a 126M parameter GPT model. + + Predefined configuration for a small GPT model with 12 layers, + 768 hidden size, and 12 attention heads. + """ + + seq_length: int = 2048 + num_layers: int = 12 + hidden_size: int = 768 + ffn_hidden_size: int = 3072 + num_attention_heads: int = 12 + bias_activation_fusion: bool = True + bias_dropout_add_fusion: bool = True + + +@dataclass +class GPTProvider5B(GPTModelProvider): + """Configuration for a 5B parameter GPT model. + + Predefined configuration for a medium-sized GPT model with 24 layers, + 4096 hidden size, and 32 attention heads. + """ + + seq_length: int = 2048 + num_layers: int = 24 + hidden_size: int = 4096 + ffn_hidden_size: int = 16384 + num_attention_heads: int = 32 + bias_activation_fusion: bool = True + bias_dropout_add_fusion: bool = True + + +@dataclass +class GPTProvider7B(GPTModelProvider): + """Configuration for a 7B parameter GPT model. + + Predefined configuration for a medium-sized GPT model with 32 layers, + 4096 hidden size, and 32 attention heads. + """ + + seq_length: int = 2048 + num_layers: int = 32 + hidden_size: int = 4096 + ffn_hidden_size: int = 10880 + num_attention_heads: int = 32 + bias_activation_fusion: bool = True + bias_dropout_add_fusion: bool = True + + +@dataclass +class GPTProvider20B(GPTModelProvider): + """Configuration for a 20B parameter GPT model. + + Predefined configuration for a large GPT model with 44 layers, + 6144 hidden size, and 48 attention heads. + """ + + seq_length: int = 2048 + num_layers: int = 44 + hidden_size: int = 6144 + ffn_hidden_size: int = 24576 + num_attention_heads: int = 48 + bias_activation_fusion: bool = True + bias_dropout_add_fusion: bool = True + + +@dataclass +class GPTProvider40B(GPTModelProvider): + """Configuration for a 40B parameter GPT model. + + Predefined configuration for a large GPT model with 48 layers, + 8192 hidden size, and 64 attention heads. + """ + + seq_length: int = 2048 + num_layers: int = 48 + hidden_size: int = 8192 + ffn_hidden_size: int = 32768 + num_attention_heads: int = 64 + bias_activation_fusion: bool = True + bias_dropout_add_fusion: bool = True + + +@dataclass +class GPTProvider175B(GPTModelProvider): + """Configuration for a 175B parameter GPT model. + + Predefined configuration for a massive GPT model with 96 layers, + 12288 hidden size, and 96 attention heads. 
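# Hypothetical usage of the predefined providers above (single process, no model
# parallelism). vocab_size must be set before provide(), per the assertion in
# GPTModelProvider.provide(); the provide() call itself also needs an initialized
# megatron.core.parallel_state, so it is left commented out here.
provider = GPTProvider126M(vocab_size=50257, should_pad_vocab=True)
# model = provider.provide()  # would return an MCoreGPTModel for this configuration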
+ """ + + seq_length: int = 2048 + num_layers: int = 96 + hidden_size: int = 12288 + ffn_hidden_size: int = 49152 + num_attention_heads: int = 96 + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + bias_activation_fusion: bool = True + bias_dropout_add_fusion: bool = True + layernorm_zero_centered_gamma: bool = True diff --git a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/__init__.py b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/__init__.py new file mode 100644 index 0000000000..81f80fd7ac --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from megatron.nemo_bridge.models.hf_pretrained.vlm import PreTrainedVLM + +__all__ = ["PreTrainedCausalLM", "PreTrainedVLM"] diff --git a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/base.py b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/base.py new file mode 100644 index 0000000000..ad8b3e332d --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/base.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import shutil + +from abc import ABC, abstractmethod +from fnmatch import fnmatch +from pathlib import Path +from typing import ClassVar, Dict, List, Optional, Union + +import torch + +from transformers import AutoConfig, PreTrainedModel + +from megatron.nemo_bridge.models.hf_pretrained.state import ( + SafeTensorsStateSource, + StateDict, + StateSource, +) + + +class PreTrainedBase(ABC): + """ + Abstract base class for all pretrained models. + + This class provides a generic mechanism for managing model artifacts + (e.g., config, tokenizer) with lazy loading. Subclasses that are + decorated with `@dataclass` can define artifacts as fields with metadata + specifying a loader method. The `model` itself is handled via a + dedicated property that relies on the abstract `_load_model` method. + + Example: + @dataclass + class MyModel(PreTrainedBase): + config: AutoConfig = field( + init=False, + metadata=artifact(loader="_load_config") + ) + + def _load_model(self) -> "PreTrainedModel": + # Implementation for the loading logic + ... + """ + + model_name_or_path: Union[str, Path] + ARTIFACTS: ClassVar[List[str]] = [] + OPTIONAL_ARTIFACTS: ClassVar[List[str]] = [] + + def __init__(self, **kwargs): + self._state_dict_accessor: Optional[StateDict] = None + self.init_kwargs = kwargs + # Store the original source path for custom modeling file preservation + self._original_source_path: Optional[Union[str, Path]] = None + + def get_artifacts(self) -> Dict[str, str]: + """Get the artifacts dictionary mapping artifact names to their attribute names.""" + return {artifact: f"_{artifact}" for artifact in self.ARTIFACTS} + + def _copy_custom_modeling_files( + self, source_path: Union[str, Path], target_path: Union[str, Path] + ) -> None: + """Copy custom modeling files from source to target directory. + + This preserves custom modeling files that were used during model loading + with trust_remote_code=True, ensuring the saved model can be loaded properly. 
+ + Args: + source_path: Source directory containing custom modeling files + target_path: Target directory to copy files to + """ + source_path = Path(source_path) + target_path = Path(target_path) + + # Common custom modeling file patterns + custom_file_patterns = ["*.py", "*.json", "*.jpeg", "*.png", "*.jpg", "*.mp4"] + copied_files = [] + + # First, try to copy from local directory if it exists + if source_path.exists() and source_path.is_dir(): + for pattern in custom_file_patterns: + for file_path in source_path.glob(pattern): + if file_path.is_file(): + target_file = target_path / file_path.name + try: + shutil.copy2(file_path, target_file) + copied_files.append(file_path.name) + except (OSError, IOError): + # Silently skip files that can't be copied + pass + + # If no files were copied and source_path looks like a HuggingFace Hub ID, + # try to download the custom modeling files directly from the Hub + if not copied_files and "/" in str(source_path) and not source_path.exists(): + try: + from huggingface_hub import hf_hub_download, list_repo_files + + # Get list of Python files in the repository + repo_files = list_repo_files(str(source_path)) + print("repo_files: ", repo_files) + for file in repo_files: + # Check if it matches our custom file patterns + if any(fnmatch(file, pattern) for pattern in custom_file_patterns): + try: + downloaded_file = hf_hub_download( + repo_id=str(source_path), + filename=file, + local_dir=target_path, + local_dir_use_symlinks=False, + ) + copied_files.append(file) + except Exception as e: + print("Error downloading file: ", e, "Skipping file...") + # Silently skip files that can't be downloaded + pass + + except Exception as e: + print( + "Error downloading custom modeling files: ", + e, + "Skipping custom modeling files...", + ) + # If HuggingFace Hub operations fail, silently continue + pass + + return copied_files + + def save_artifacts(self, save_directory: Union[str, Path]): + """ + Saves all loaded, generic artifacts that have a `save_pretrained` method + to the specified directory. Note: This does not save the `model` attribute. + + If the model was loaded with trust_remote_code=True, this method will also + attempt to preserve any custom modeling files to ensure the saved model + can be loaded properly. 
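# Hypothetical usage of save_artifacts as implemented below (the model id is illustrative
# and pulling it requires network access plus trusting its remote code). With
# trust_remote_code=True this writes the config, any optional artifacts such as the
# generation config, and the repo's custom *.py modeling files into the target directory.
lm = PreTrainedCausalLM.from_pretrained("deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)
lm.save_artifacts("./exported_hf_artifacts")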
+ """ + save_path = Path(save_directory) + save_path.mkdir(parents=True, exist_ok=True) + + _ = getattr(self, "config") # trigger lazy loading of config + if hasattr(self, "_config") and self._config is not None: + self._config.save_pretrained(save_path) + + # Iterate over required artifacts to save them in a predictable order + # for name in self.ARTIFACTS: + # # Access the public property to trigger lazy loading if needed + # artifact = getattr(self, name) + # attr_name = f"_{name}" + # if hasattr(self, attr_name): + # if artifact is not None and hasattr(artifact, "save_pretrained"): + # artifact.save_pretrained(save_path) + + # Iterate over optional artifacts - only save if they exist and have save_pretrained + for name in self.OPTIONAL_ARTIFACTS: + artifact = getattr(self, name, None) + if artifact is not None and hasattr(artifact, "save_pretrained"): + artifact.save_pretrained(save_path) + + # Preserve custom modeling files if trust_remote_code was used + if hasattr(self, 'trust_remote_code') and self.trust_remote_code: + # Try original source path first, then fallback to model_name_or_path + source_paths = [] + if hasattr(self, '_original_source_path') and self._original_source_path: + source_paths.append(self._original_source_path) + if hasattr(self, 'model_name_or_path') and self.model_name_or_path: + source_paths.append(self.model_name_or_path) + + for source_path in source_paths: + copied_files = self._copy_custom_modeling_files(source_path, save_path) + if copied_files: + # Successfully copied files, no need to try other paths + break + + @abstractmethod + def _load_model(self) -> PreTrainedModel: + """Subclasses must implement this to load the main model.""" + pass + + @abstractmethod + def _load_config(self) -> AutoConfig: + """Subclasses must implement this to load the model config.""" + pass + + @property + def model(self) -> PreTrainedModel: + """Lazily loads and returns the underlying model.""" + if not hasattr(self, "_model"): + self._model = self._load_model() + return self._model + + @model.setter + def model(self, value: PreTrainedModel): + """Manually set the model.""" + self._model = value + + @property + def config(self) -> AutoConfig: + """Lazy load and return the model config.""" + if not hasattr(self, "_config"): + self._config = self._load_config() + return self._config + + @config.setter + def config(self, value: AutoConfig): + """Set the config manually.""" + self._config = value + + @property + def state(self) -> StateDict: + """ + Get the state dict accessor for pandas-like querying. + + This accessor can be backed by either a fully loaded model in memory + or a ".safetensors" checkpoint on disk, enabling lazy loading of tensors. + + Examples: + model.state() # Get full state dict + model.state["key"] # Get single entry + model.state[["key1", "key2"]] # Get multiple entries + model.state["*.weight"] # Glob pattern + model.state.regex(r".*\\.bias$") # Regex pattern + """ + if self._state_dict_accessor is None: + source: Optional[Union[Dict[str, torch.Tensor], StateSource]] = None + # Prioritize the loaded model's state_dict if available + if hasattr(self, "_model") and self._model is not None: + source = self.model.state_dict() + elif hasattr(self, "model_name_or_path") and self.model_name_or_path: + source = SafeTensorsStateSource(self.model_name_or_path) + + if source is None: + raise ValueError( + "Cannot create StateDict accessor: model is not loaded and model_name_or_path is not set." 
+ ) + self._state_dict_accessor = StateDict(source) + return self._state_dict_accessor diff --git a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/causal_lm.py b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/causal_lm.py new file mode 100644 index 0000000000..e5383f11fc --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/causal_lm.py @@ -0,0 +1,657 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import sys + +from pathlib import Path +from typing import Dict, Generic, List, Optional, TypeVar, Union + +import torch + +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + GenerationConfig, + PreTrainedTokenizer, +) +from transformers.generation.utils import GenerateOutput + +from megatron.nemo_bridge.models.hf_pretrained.base import PreTrainedBase +from megatron.nemo_bridge.models.hf_pretrained.safe_config_loader import ( + safe_load_config_with_retry, +) + +# Python 3.12+ supports PEP 692 (TypedDict Unpack) +if sys.version_info >= (3, 12): + from typing import TypedDict, Unpack +else: + from typing_extensions import TypedDict, Unpack + + +CausalLMType = TypeVar("CausalLMType", bound=AutoModelForCausalLM) + + +class PreTrainedCausalLM(PreTrainedBase, Generic[CausalLMType]): + """ + A generic class for Pretrained Causal Language Models with lazy loading. + + Allows type-safe access to specific model implementations like LlamaForCausalLM. + + Examples: + Basic usage with lazy loading: + >>> from mbridge.pretrained import PreTrainedCausalLM + >>> # Create instance - no model loading happens yet + >>> model = PreTrainedCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + >>> # Components are loaded on first access + >>> config = model.config # Loads config + >>> tokenizer = model.tokenizer # Loads tokenizer + >>> # Generate text - model is loaded here + >>> inputs = model.encode("Hello, how are you?") + >>> outputs = model.generate(**inputs, max_length=50) + >>> print(model.decode(outputs[0], skip_special_tokens=True)) + + Using specific model types with type hints: + >>> from transformers import LlamaForCausalLM + >>> from mbridge.pretrained import PreTrainedCausalLM + >>> # Type-safe access to Llama-specific features + >>> llama_model: PreTrainedCausalLM[LlamaForCausalLM] = PreTrainedCausalLM.from_pretrained( + ... "meta-llama/Llama-2-7b-chat-hf", + ... torch_dtype=torch.float16, + ... device="cuda" + ... ) + >>> # Access Llama-specific attributes + >>> model_instance = llama_model.model # Type is LlamaForCausalLM + + Loading with custom configurations: + >>> # Load model with specific settings + >>> model = PreTrainedCausalLM.from_pretrained( + ... "gpt2", + ... device="cuda:0", + ... torch_dtype=torch.bfloat16, + ... attn_implementation="flash_attention_2", + ... load_in_8bit=True + ... ) + >>> # Override generation config + >>> from transformers import GenerationConfig + >>> model.generation_config = GenerationConfig( + ... max_length=100, + ... temperature=0.7, + ... top_p=0.9, + ... do_sample=True + ... 
) + + Manual component management: + >>> # Create empty instance + >>> model = PreTrainedCausalLM() + >>> # Manually set components + >>> from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM + >>> model.config = AutoConfig.from_pretrained("microsoft/phi-2") + >>> model.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2") + >>> model.model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2") + >>> # Save all components + >>> model.save_artifacts("./my_model") + + Batch processing example: + >>> # Process multiple prompts + >>> prompts = [ + ... "The capital of France is", + ... "Machine learning is", + ... "Python programming language was created by" + ... ] + >>> # Encode all prompts + >>> inputs = model.encode(prompts, padding=True, truncation=True) + >>> # Generate completions + >>> outputs = model.generate(**inputs, max_new_tokens=20) + >>> # Decode results + >>> for i, output in enumerate(outputs): + ... print(f"Prompt {i+1}: {model.decode(output, skip_special_tokens=True)}") + """ + + ARTIFACTS = ["tokenizer"] + OPTIONAL_ARTIFACTS = ["generation_config"] + + def __init__( + self, + model_name_or_path: Optional[Union[str, Path]] = None, + device: Optional[Union[str, torch.device]] = None, + torch_dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + **kwargs, + ): + """ + Initialize a Pretrained Causal LM with lazy loading. + + Args: + model_name_or_path: HuggingFace model identifier or local path + device: Device to load model on (e.g., 'cuda', 'cpu') + torch_dtype: Data type to load model in (e.g., torch.float16) + trust_remote_code: Whether to trust remote code when loading + **kwargs: Additional arguments passed to from_pretrained methods + """ + self._model_name_or_path = model_name_or_path + # self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self.device = "cpu" + self.torch_dtype = torch_dtype + self.trust_remote_code = trust_remote_code + super().__init__(**kwargs) + # Store the original source path for custom modeling file preservation + if model_name_or_path and trust_remote_code: + self._original_source_path = model_name_or_path + + def _load_model(self) -> CausalLMType: + """Load the model.""" + if self.model_name_or_path is None: + raise ValueError("model_name_or_path must be provided to load model") + + model_kwargs = {"trust_remote_code": self.trust_remote_code, **self.init_kwargs} + if self.torch_dtype is not None: + model_kwargs["torch_dtype"] = self.torch_dtype + config = getattr(self, "_config", None) + if config is not None: + model_kwargs["config"] = config + + model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, **model_kwargs) + model = model.to(self.device) + + generation_config = getattr(self, "_generation_config", None) + if generation_config is not None and hasattr(model, "generation_config"): + model.generation_config = generation_config + return model + + def _load_config(self) -> AutoConfig: + """Load the model config with thread-safety protection.""" + if self.model_name_or_path is None: + raise ValueError("model_name_or_path must be provided to load config") + return safe_load_config_with_retry( + self.model_name_or_path, trust_remote_code=self.trust_remote_code, **self.init_kwargs + ) + + def _load_tokenizer(self) -> PreTrainedTokenizer: + """Load the tokenizer.""" + if self.model_name_or_path is None: + raise ValueError("model_name_or_path must be provided to load tokenizer") + tokenizer = AutoTokenizer.from_pretrained( + self.model_name_or_path, 
trust_remote_code=self.trust_remote_code, **self.init_kwargs + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + def _load_generation_config(self) -> Optional[GenerationConfig]: + """Load the generation config.""" + if self.model_name_or_path is not None: + try: + return GenerationConfig.from_pretrained( + self.model_name_or_path, + trust_remote_code=self.trust_remote_code, + **self.init_kwargs, + ) + except Exception: + # Not all models have generation configs + pass + return None + + @property + def generation_config(self) -> Optional[GenerationConfig]: + """Lazy load and return the generation config.""" + if not hasattr(self, "_generation_config"): + self._generation_config = self._load_generation_config() + return self._generation_config + + @generation_config.setter + def generation_config(self, value: GenerationConfig): + """Set the generation config manually.""" + self._generation_config = value + # Update model's generation config if model is already loaded + model = getattr(self, "_model", None) + if model is not None and hasattr(model, "generation_config"): + model.generation_config = value + + @property + def tokenizer(self) -> PreTrainedTokenizer: + """Lazy load and return the tokenizer.""" + if not hasattr(self, "_tokenizer"): + self._tokenizer = self._load_tokenizer() + return self._tokenizer + + @tokenizer.setter + def tokenizer(self, value: PreTrainedTokenizer): + """Set the tokenizer manually.""" + self._tokenizer = value + + @property + def model_name_or_path(self) -> Optional[Union[str, Path]]: + """Return the model name or path.""" + return self._model_name_or_path + + @property + def has_model(self) -> bool: + """Check if model has been loaded.""" + return hasattr(self, "_model") and self._model is not None + + @property + def model(self) -> CausalLMType: + """Lazy load and return the underlying model.""" + return super().model + + @model.setter + def model(self, value: CausalLMType): + """Set the model manually and move it to the appropriate device.""" + self._model = value + if self._model is not None: + self._model = self._model.to(self.device) + + @classmethod + def from_pretrained( + cls, + model_name_or_path: Union[str, Path], + device: Optional[Union[str, torch.device]] = None, + torch_dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + **kwargs, + ) -> "PreTrainedCausalLM[CausalLMType]": + """ + Create a PreTrainedCausalLM instance for lazy loading. + + Args: + model_name_or_path: HuggingFace model identifier or local path + device: Device to load model on + torch_dtype: Data type to load model in + trust_remote_code: Whether to trust remote code + **kwargs: Additional arguments for from_pretrained methods + + Returns: + PreTrainedCausalLM instance configured for lazy loading + """ + return cls( + model_name_or_path=model_name_or_path, + device=device, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + def generate( + self, input_ids: Optional[torch.LongTensor] = None, **kwargs: Unpack["GenerateKwargs"] + ) -> Union[torch.LongTensor, GenerateOutput]: + """ + Generate text using the underlying language model. + + This method forwards all arguments to the model's generate method, + supporting all generation strategies provided by the transformers library. + + Common parameters include: + inputs (torch.LongTensor, optional): Input token IDs. If not provided, + will generate from the beginning of sequence token. 
+ max_length (int, optional): Maximum length of generated sequence. + Defaults to model's max_length configuration. + min_length (int, optional): Minimum length of generated sequence. + max_new_tokens (int, optional): Maximum number of tokens to generate, + ignoring the number of tokens in the prompt. + do_sample (bool, optional): Whether to use sampling. Defaults to False + (greedy decoding). + temperature (float, optional): Temperature for sampling. Higher values + produce more random outputs. Typical range: 0.1-2.0. + top_p (float, optional): Nucleus sampling threshold. Only tokens with + cumulative probability up to top_p are considered. Range: 0.0-1.0. + top_k (int, optional): Only consider the top k tokens for sampling. + num_beams (int, optional): Number of beams for beam search. 1 means + no beam search. + repetition_penalty (float, optional): Penalty for repeating tokens. + Values > 1.0 discourage repetition. + pad_token_id (int, optional): ID of padding token. + eos_token_id (int or List[int], optional): ID(s) of end-of-sequence token(s). + use_cache (bool, optional): Whether to use past key values to speed up + generation. Defaults to True. + + Returns: + torch.LongTensor or transformers.generation.utils.GenerateOutput: + Generated token IDs. If return_dict_in_generate=True, returns a + GenerateOutput object containing generated sequences and additional + information like scores. + + Examples: + >>> # Basic generation + >>> model = PreTrainedCausalLM.from_pretrained("gpt2") + >>> inputs = model.encode("Hello, how are") + >>> outputs = model.generate(inputs["input_ids"], max_length=20) + >>> print(model.decode(outputs[0])) + + >>> # Generation with sampling + >>> outputs = model.generate( + ... inputs["input_ids"], + ... max_length=50, + ... do_sample=True, + ... temperature=0.8, + ... top_p=0.9 + ... ) + + >>> # Beam search + >>> outputs = model.generate( + ... inputs["input_ids"], + ... max_length=50, + ... num_beams=5, + ... early_stopping=True + ... ) + + Note: + For detailed documentation of all parameters, see the transformers + library documentation for generation methods. + """ + model = self.model # Ensures model is loaded + # Sync generation config if it has been set on the wrapper + generation_config = getattr(self, "_generation_config", None) + if generation_config is not None and hasattr(model, "generation_config"): + model.generation_config = generation_config + return model.generate(input_ids, **kwargs) + + def __call__(self, *args, **kwargs): + """Forward call to model.""" + return self.model(*args, **kwargs) + + def encode( + self, text: Union[str, List[str]], **kwargs: Unpack["EncodeKwargs"] + ) -> Dict[str, torch.Tensor]: + """ + Encode text into token IDs using the model's tokenizer. + + This method tokenizes input text and returns tensors ready for model input. + The output is automatically moved to the same device as the model. + + Args: + text (str or List[str]): Input text to encode. Can be a single string + or a list of strings for batch encoding. + **kwargs: Additional arguments passed to the tokenizer. Common options: + padding (bool or str, optional): Padding strategy. + - True or 'longest': Pad to longest sequence in batch + - 'max_length': Pad to max_length + - False or 'do_not_pad': No padding (default) + truncation (bool or str, optional): Truncation strategy. 
+ - True or 'longest_first': Truncate to max_length + - 'only_first': Truncate only first sequence (for pairs) + - False: No truncation + max_length (int, optional): Maximum length of returned sequences. + Defaults to model's max_length. + add_special_tokens (bool, optional): Whether to add special tokens + (e.g., [CLS], [SEP]). Defaults to True. + return_attention_mask (bool, optional): Whether to return attention + mask. Defaults to True. + return_token_type_ids (bool, optional): Whether to return token + type IDs (for models like BERT). Defaults to True if model + expects them. + + Returns: + Dict[str, torch.Tensor]: Dictionary containing: + - input_ids: Token IDs tensor of shape (batch_size, sequence_length) + - attention_mask: Attention mask tensor of same shape (if applicable) + - token_type_ids: Token type IDs tensor (if applicable) + Additional keys may be present depending on the tokenizer. + + Examples: + >>> model = PreTrainedCausalLM.from_pretrained("gpt2") + >>> # Single text encoding + >>> tokens = model.encode("Hello world!") + >>> print(tokens["input_ids"].shape) # torch.Size([1, 3]) + + >>> # Batch encoding with padding + >>> texts = ["Hello!", "How are you doing today?"] + >>> tokens = model.encode(texts, padding=True) + >>> print(tokens["input_ids"].shape) # torch.Size([2, 6]) + + >>> # Encoding with truncation + >>> tokens = model.encode( + ... "This is a very long text that might exceed the maximum length", + ... truncation=True, + ... max_length=10 + ... ) + + Note: + The returned tensors are on the same device as the model, ready + for immediate use in forward passes or generation. + """ + # Only set return_tensors default if not provided + if "return_tensors" not in kwargs: + kwargs["return_tensors"] = "pt" + + return self.tokenizer(text, **kwargs).to(self.device) + + def decode( + self, token_ids: Union[int, List[int], torch.Tensor], **kwargs: Unpack["DecodeKwargs"] + ) -> str: + """ + Decode token IDs back into text using the model's tokenizer. + + This method converts token IDs (from model output or encode method) + back into human-readable text. + + Args: + token_ids (int, List[int], or torch.Tensor): Token IDs to decode. + Can be: + - Single token ID (int) + - List of token IDs + - 1D tensor of token IDs + - 2D tensor (will decode the first sequence) + **kwargs: Additional arguments passed to the tokenizer's decode method: + skip_special_tokens (bool, optional): Whether to remove special + tokens (e.g., [PAD], [CLS], [SEP]) from output. Defaults to True. + clean_up_tokenization_spaces (bool, optional): Whether to clean up + tokenization artifacts (extra spaces, etc.). Defaults to True. + + Returns: + str: Decoded text string. + + Examples: + >>> model = PreTrainedCausalLM.from_pretrained("gpt2") + >>> # Encode and decode round-trip + >>> text = "Hello, world!" + >>> tokens = model.encode(text) + >>> decoded = model.decode(tokens["input_ids"][0]) + >>> print(decoded) # "Hello, world!" + + >>> # Decode generated tokens + >>> inputs = model.encode("The weather is") + >>> outputs = model.generate(inputs["input_ids"], max_length=10) + >>> decoded = model.decode(outputs[0]) + >>> print(decoded) # "The weather is nice today..." + + >>> # Decode without special tokens + >>> token_ids = [101, 7592, 1010, 2088, 999, 102] # BERT-style tokens + >>> decoded = model.decode(token_ids, skip_special_tokens=True) + >>> print(decoded) # "Hello, world!" 
+ + >>> # Decode keeping special tokens + >>> decoded = model.decode(token_ids, skip_special_tokens=False) + >>> print(decoded) # "[CLS] Hello, world! [SEP]" + + Note: + If a 2D tensor is provided (batch of sequences), only the first + sequence is decoded. For batch decoding, use tokenizer.batch_decode() + directly or iterate over the sequences. + """ + return self.tokenizer.decode(token_ids, **kwargs) + + def to(self, device: Union[str, torch.device]): + """Move model to specified device.""" + self.device = device + if self.has_model: + self._model = self._model.to(device) + return self + + def half(self): + """Convert model to half precision (float16).""" + if self.has_model: + self._model = self._model.half() + return self + + def float(self): + """Convert model to full precision (float32).""" + if self.has_model: + self._model = self._model.float() + return self + + def save_pretrained(self, save_directory: Union[str, Path]): + """ + Save all components (model, tokenizer, config, generation_config) to a directory. + + This method saves: + - Model weights and config + - Tokenizer files + - Generation config (if available) + + Args: + save_directory: Path to directory where components will be saved + """ + save_path = Path(save_directory) + save_path.mkdir(parents=True, exist_ok=True) + + # Save model if loaded + if hasattr(self, "_model") and self._model is not None: + self._model.save_pretrained(save_path) + + # Use the base class save_artifacts to save config and all artifacts + self.save_artifacts(save_path) + + @property + def dtype(self) -> Optional[torch.dtype]: + """Get model's dtype if loaded.""" + if self.has_model: + try: + return next(self.model.parameters()).dtype + except StopIteration: + return None + return None + + @property + def num_parameters(self) -> Optional[int]: + """Get total number of parameters if model is loaded.""" + if self.has_model: + return sum(p.numel() for p in self.model.parameters()) + return None + + def __repr__(self) -> str: + """Return a string representation of the PreTrainedCausalLM instance.""" + try: + # Access config to trigger lazy loading for a richer repr + _ = self.config + except Exception: + # If loading fails, repr shouldn't crash. 
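+            # The loop below will simply report any unloaded component as "not loaded".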
+ pass + + lines = [f"{self.__class__.__name__}("] + for name, attr_name in sorted(self.get_artifacts().items()): + is_loaded = hasattr(self, attr_name) + artifact_instance = getattr(self, attr_name, None) if is_loaded else None + + type_name = "N/A" + details = "not loaded" + if is_loaded and artifact_instance is not None: + type_name = artifact_instance.__class__.__name__ + if name == "tokenizer": + vocab = getattr(artifact_instance, "vocab_size", "N/A") + details = f"vocab_size={vocab}" + elif name == "config": + m_type = getattr(artifact_instance, "model_type", "N/A") + details = f"model_type={m_type}" + else: + details = "loaded" + lines.append(f" ({name}): {type_name} [{details}]") + + # Manually add model repr + model_repr_content: str + if self.has_model: + model_class_name = self.model.__class__.__name__ + # Assuming self.config is loaded or available here due to earlier attempt + config = self.config + layers = getattr(config, "num_hidden_layers", "N/A") + hidden_size = getattr(config, "hidden_size", "N/A") + model_repr_content = ( + f"{model_class_name} [layers={layers}, hidden_size={hidden_size}, loaded]" + ) + elif "config" in self.__dict__: # Model not loaded, but config is + config = self.config + model_class_name_from_hf_config = "CausalLM" # Default + if hasattr(config, "architectures") and config.architectures: + model_class_name_from_hf_config = config.architectures[0] + elif getattr(config, "model_type", None): + mt = config.model_type + model_class_name_from_hf_config = f"{mt.capitalize()}Model" if mt else "CausalLM" + + details_parts = [] + if getattr(config, "num_hidden_layers", None) is not None: + details_parts.append(f"layers={config.num_hidden_layers}") + if getattr(config, "hidden_size", None) is not None: + details_parts.append(f"hidden_size={config.hidden_size}") + + details_str = ", ".join(details_parts) + status_suffix = "not loaded" + if details_str: + model_repr_content = ( + f"{model_class_name_from_hf_config}({details_str}) [{status_suffix}]" + ) + else: + model_repr_content = f"{model_class_name_from_hf_config} [{status_suffix}]" + else: # Model and Config also not loaded + model_repr_content = "AutoModelForCausalLM [not loaded]" + + lines.append(f" (model): {model_repr_content}") + + lines.sort() + + params_str = f"{self.num_parameters:,}" if self.num_parameters is not None else "N/A" + dtype_str = str(self.dtype).replace("torch.", "") if self.dtype is not None else "N/A" + lines.extend( + [ + f" (parameters): {params_str}", + f" (device): {str(self.device)}", + f" (dtype): {dtype_str}", + ")", + ] + ) + return "\n".join(lines) + + +# TypedDict definitions for method parameters +class GenerateKwargs(TypedDict, total=False): + """TypedDict for generate method parameters.""" + + attention_mask: Optional[torch.Tensor] + max_length: Optional[int] + max_new_tokens: Optional[int] + min_length: Optional[int] + do_sample: Optional[bool] + temperature: Optional[float] + top_k: Optional[int] + top_p: Optional[float] + repetition_penalty: Optional[float] + pad_token_id: Optional[int] + eos_token_id: Optional[Union[int, List[int]]] + bos_token_id: Optional[int] + num_beams: Optional[int] + num_return_sequences: Optional[int] + early_stopping: Optional[bool] + use_cache: Optional[bool] + return_dict_in_generate: Optional[bool] + output_scores: Optional[bool] + output_attentions: Optional[bool] + + +class EncodeKwargs(TypedDict, total=False): + """TypedDict for encode method parameters.""" + + padding: Union[bool, str] + truncation: Union[bool, str] + max_length: 
Optional[int] + add_special_tokens: bool + return_attention_mask: bool + return_token_type_ids: Optional[bool] + return_tensors: str + + +class DecodeKwargs(TypedDict, total=False): + """TypedDict for decode method parameters.""" + + skip_special_tokens: bool + clean_up_tokenization_spaces: bool diff --git a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/safe_config_loader.py b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/safe_config_loader.py new file mode 100644 index 0000000000..9d5e9490aa --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/safe_config_loader.py @@ -0,0 +1,136 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +""" +Thread-safe configuration loading utilities. + +This module provides utilities for safely loading HuggingFace model configurations +in multi-threaded environments, preventing race conditions that can occur when +multiple threads try to download and cache the same model simultaneously. +""" + +import hashlib +import os +import time + +from pathlib import Path +from typing import Union + +import filelock + +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig + + +def safe_load_config_with_retry( + path: Union[str, Path], + trust_remote_code: bool = False, + max_retries: int = 3, + base_delay: float = 1.0, + **kwargs, +) -> PretrainedConfig: + """ + Thread-safe and process-safe configuration loading with retry logic. + + This function prevents race conditions when multiple threads/processes + try to download and cache the same model configuration simultaneously. + Uses file locking (if filelock is available) to coordinate access across + processes. + + Args: + path: HuggingFace model ID or path to model directory + trust_remote_code: Whether to trust remote code when loading config + max_retries: Maximum number of retry attempts (default: 3) + base_delay: Base delay in seconds for exponential backoff (default: 1.0) + **kwargs: Additional arguments passed to AutoConfig.from_pretrained + + Returns: + PretrainedConfig: The loaded model configuration + + Raises: + ValueError: If config loading fails after all retries + + Environment Variables: + MEGATRON_CONFIG_LOCK_DIR: Override the directory where lock files are created. + Default: ~/.cache/huggingface/ + Useful for multi-node setups where a shared lock directory is needed. + + Example: + >>> config = safe_load_config_with_retry("meta-llama/Meta-Llama-3-8B") + >>> print(config.model_type) + + >>> # With custom retry settings + >>> config = safe_load_config_with_retry( + ... "gpt2", + ... max_retries=5, + ... base_delay=0.5, + ... trust_remote_code=True + ... 
) + + >>> # Multi-node setup with shared lock directory + >>> import os + >>> os.environ["MEGATRON_CONFIG_LOCK_DIR"] = "/shared/locks" + >>> config = safe_load_config_with_retry("meta-llama/Meta-Llama-3-8B") + """ + last_exception = None + + for attempt in range(max_retries + 1): + try: + # Use file locking for process-safe access + # Create a lock file based on the path hash to avoid conflicts + path_hash = hashlib.md5(str(path).encode()).hexdigest() + + # Allow override of lock directory via environment variable + # This is useful for multi-node setups where a shared lock directory is needed + lock_dir = os.getenv("MEGATRON_CONFIG_LOCK_DIR") + if lock_dir: + lock_file = Path(lock_dir) / f".megatron_config_lock_{path_hash}" + else: + lock_file = ( + Path.home() / ".cache" / "huggingface" / f".megatron_config_lock_{path_hash}" + ) + + lock_file.parent.mkdir(parents=True, exist_ok=True) + + with filelock.FileLock(str(lock_file) + ".lock", timeout=60): + return AutoConfig.from_pretrained( + path, trust_remote_code=trust_remote_code, **kwargs + ) + + except Exception as e: + last_exception = e + + # Don't retry on certain types of errors + error_msg = str(e).lower() + if any( + phrase in error_msg + for phrase in [ + "does not appear to have a file named config.json", + "repository not found", + "entry not found", + "401 client error", + "403 client error", + ] + ): + # Model doesn't exist or access denied, no point retrying + raise ValueError( + f"Failed to load configuration from {path}. " + f"Ensure the path is valid and contains a config.json file. " + f"Error: {e}" + ) from e + + if attempt < max_retries: + # Exponential backoff with jitter + delay = base_delay * (2**attempt) + (time.time() % 1) * 0.1 + time.sleep(delay) + else: + # Final attempt failed + break + + # All retries exhausted + raise ValueError( + f"Failed to load configuration from {path} after {max_retries + 1} attempts. " + f"This might be due to network issues or concurrent access conflicts. " + f"Last error: {last_exception}" + ) from last_exception diff --git a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/state.py b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/state.py new file mode 100644 index 0000000000..01c401ec52 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/state.py @@ -0,0 +1,850 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import fnmatch +import json +import re + +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Mapping +from functools import lru_cache +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Pattern, Tuple, Union, overload + +import torch + + +class StateDict(Mapping[str, torch.Tensor]): + """ + A state dict accessor that provides a unified interface for querying model + checkpoints. + + `StateDict` allows for efficient and flexible access to tensor data from + various sources, such as in-memory dictionaries or directories of + `.safetensors` files. A key feature is its ability to query and load only + the required tensors without loading the entire checkpoint into memory, + making it highly memory-efficient for large models. + + It supports a flexible, pandas-like querying interface that allows for + accessing tensors by exact name, a list of names, glob patterns, or regular + expressions. This makes it easy to inspect and manipulate model + checkpoints. 
+ + Examples: + >>> # Setup an example StateDict from an in-memory dictionary + >>> import torch + >>> import re + >>> d = { + ... "model.layer.0.weight": torch.randn(10, 10), + ... "model.layer.0.bias": torch.randn(10), + ... "model.layer.1.weight": torch.randn(10, 10), + ... "model.layer.1.bias": torch.randn(10), + ... } + >>> state = StateDict(d) + >>> + >>> # 1. Access a single tensor by exact key + >>> state["model.layer.0.weight"].shape + torch.Size([10, 10]) + >>> + >>> # 2. Access multiple tensors with a list of strings + >>> list(state[["model.layer.0.weight", "model.layer.1.weight"]].keys()) + ['model.layer.0.weight', 'model.layer.1.weight'] + >>> + >>> # 3. Access with a glob pattern + >>> sorted(list(state.glob("model.layer.*.bias").keys())) + ['model.layer.0.bias', 'model.layer.1.bias'] + >>> + >>> # 4. Access with a compiled regex pattern + >>> regex = re.compile(r"model\\\\.layer\\\\.0\\\\..*") + >>> sorted(list(state[regex].keys())) + ['model.layer.0.bias', 'model.layer.0.weight'] + + The same querying flexibility applies to checkpoints on disk. The following + is a conceptual example of using `StateDict` with a `SafetensorsStateSource` + to query a sharded checkpoint without loading all of it into memory. + + .. code-block:: python + + # Assume SafetensorsStateSource is available + # from megatron.nemo_bridge.models.state import SafetensorsStateSource + + # Imagine a directory 'my_model_checkpoint/' with sharded weights. + state_from_disk = StateDict(SafetensorsStateSource('my_model_checkpoint/')) + + # You can query it just like the in-memory dictionary. Only the required + # tensors (e.g., all weight tensors) will be loaded from disk. + weights = state_from_disk.glob("model.layer.*.weight") + """ + + source: "StateSource" + + def __init__(self, source: Dict[str, torch.Tensor] | "StateSource"): + """ + Initializes the StateDict query accessor. + + Args: + source: The source of the tensor data. This can be a standard + Python dictionary mapping tensor names to `torch.Tensor` objects, + or an instance of a `StateSource` subclass (e.g., + `SafetensorsStateSource`) for more advanced, out-of-memory + access. + """ + if isinstance(source, dict): + source = DictStateSource(source) + + if not isinstance(source, StateSource): + raise TypeError(f"StateDict source must be a dict or a StateSource, got {type(source)}") + + self.source = source + + def _get_all_keys(self) -> List[str]: + """ + Get all available tensor keys from the underlying source. + """ + return self.source.get_all_keys() + + def _load_tensors(self, keys_to_load: List[str]) -> Dict[str, torch.Tensor]: + """ + Load specified tensors from the underlying source. + """ + return self.source.load_tensors(keys_to_load) + + def _match_keys(self, pattern: Union[str, Pattern]) -> List[str]: + """Match keys against a glob pattern or regex.""" + all_keys = self._get_all_keys() + + if isinstance(pattern, Pattern): + # Regex pattern + return [k for k in all_keys if pattern.search(k)] + elif "*" in pattern or "?" in pattern or "[" in pattern: + # Glob pattern + return [k for k in all_keys if fnmatch.fnmatch(k, pattern)] + else: + # Exact match + return [pattern] if pattern in all_keys else [] + + @overload + def __getitem__(self, key: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ... + + @overload + def __getitem__(self, key: List[str]) -> Dict[str, torch.Tensor]: ... + + @overload + def __getitem__(self, key: Pattern) -> Dict[str, torch.Tensor]: ... 
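+    # Typing note: a plain string key may resolve to either a single tensor (exact
+    # match) or a dict (glob pattern), while list and regex keys always yield a dict.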
+ + def __getitem__( + self, key: Union[str, List[str], Pattern] + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Accesses state dict entries using various key types. + + This method allows for retrieving tensors using: + - A single string for an exact key match. + - A list of strings for multiple exact key matches. + - A string with glob-style wildcards (`*`, `?`, `[]`). + - A compiled regular expression object. + + Args: + key: A single key string, a list of keys, a glob pattern string, or a + compiled regular expression. + + Returns: + - A single `torch.Tensor` if `key` is a string that matches exactly one key + and does not contain wildcards. + - A `Dict[str, torch.Tensor]` for all other cases (list of keys, glob + pattern, or regex), mapping the matched keys to their corresponding + tensors. + + Raises: + KeyError: If the key (or any key in a list) is not found, or if a + pattern matches no keys. + + Examples: + >>> d = { + ... "model.embed_tokens.weight": torch.randn(10, 1), + ... "model.layers.0.mlp.weight": torch.randn(10, 1), + ... "model.layers.0.self_attn.q_proj.weight": torch.randn(10, 1), + ... "lm_head.weight": torch.randn(10, 1), + ... } + >>> state = StateDict(d) + >>> + >>> # Exact match (returns a single tensor) + >>> tensor = state["model.embed_tokens.weight"] + >>> isinstance(tensor, torch.Tensor) + True + >>> + >>> # List of keys (returns a dict of tensors) + >>> tensors = state[["model.embed_tokens.weight", "lm_head.weight"]] + >>> sorted(tensors.keys()) + ['lm_head.weight', 'model.embed_tokens.weight'] + >>> + >>> # Glob pattern (returns a dict of tensors) + >>> layer_0_weights = state["model.layers.0.*.weight"] + >>> sorted(layer_0_weights.keys()) + ['model.layers.0.mlp.weight', 'model.layers.0.self_attn.q_proj.weight'] + >>> + >>> # Regex pattern (returns a dict of tensors) + >>> import re + >>> attn_weights = state[re.compile(r".*self_attn.*")] + >>> list(attn_weights.keys()) + ['model.layers.0.self_attn.q_proj.weight'] + """ + if isinstance(key, Pattern): + matched_keys = self._match_keys(key) + if not matched_keys: + raise KeyError(f"No keys match regex pattern: {key.pattern}") + return self._load_tensors(matched_keys) + elif isinstance(key, str): + if "*" in key or "?" in key or "[" in key: + matched_keys = self._match_keys(key) + if not matched_keys: + raise KeyError(f"No keys match pattern: {key}") + return self._load_tensors(matched_keys) + else: + if key not in self._get_all_keys(): + raise KeyError(f"Key not found: {key}") + return self._load_tensors([key])[key] + elif isinstance(key, list): + all_keys_set = set(self._get_all_keys()) + missing_keys = [k for k in key if k not in all_keys_set] + if missing_keys: + raise KeyError(f"Keys not found: {missing_keys}") + return self._load_tensors(key) + else: + raise TypeError(f"Key must be str, list of str, or compiled regex, got {type(key)}") + + def regex(self, pattern: str) -> Dict[str, torch.Tensor]: + """ + Queries the state dict with a regular expression pattern. + + This is a convenience method that compiles the pattern string and uses it + to retrieve all matching tensors. + + Args: + pattern: The regular expression string to match against tensor keys. + + Returns: + A dictionary mapping matching tensor names to their `torch.Tensor` objects. + + Examples: + >>> d = { + ... "model.layers.0.self_attn.weight": torch.randn(1, 1), + ... "model.layers.1.self_attn.weight": torch.randn(1, 1), + ... "model.layers.1.mlp.weight": torch.randn(1, 1) + ... 
} + >>> state = StateDict(d) + >>> # Get all attention-related weights + >>> attention_weights = state.regex(r"model\\.layers\\.\\d+\\.self_attn.*") + >>> sorted(attention_weights.keys()) + ['model.layers.0.self_attn.weight', 'model.layers.1.self_attn.weight'] + """ + return self[re.compile(pattern)] + + def glob(self, pattern: str) -> Dict[str, torch.Tensor]: + """ + Queries the state dict with a glob pattern. + + This is a convenience method for pattern matching using Unix shell-style + wildcards. + + Args: + pattern: The glob pattern string to match against tensor keys. + + Returns: + A dictionary mapping matching tensor names to their `torch.Tensor` objects. + + Examples: + >>> d = { + ... "model.layers.0.mlp.weight": torch.randn(1, 1), + ... "model.layers.0.mlp.bias": torch.randn(1, 1), + ... "model.layers.1.mlp.weight": torch.randn(1, 1) + ... } + >>> state = StateDict(d) + >>> # Get all mlp weights and biases from the first layer + >>> layer_0_mlp = state.glob("model.layers.0.mlp.*") + >>> sorted(layer_0_mlp.keys()) + ['model.layers.0.mlp.bias', 'model.layers.0.mlp.weight'] + """ + return self[pattern] + + def __call__(self) -> Dict[str, torch.Tensor]: + """ + Loads and returns the entire state dict as a dictionary. + + Note: + This method loads all tensors from the source into memory. For large + models, this can be memory-intensive. Prefer using pattern-based + or single-key lookups for more efficient access if you only need a + subset of the state dict. + + Returns: + A dictionary containing all tensor names and their corresponding + `torch.Tensor` objects. + """ + all_keys = self._get_all_keys() + return self._load_tensors(all_keys) + + def keys(self) -> List[str]: + """Get all state dict keys.""" + return self._get_all_keys() + + def items(self) -> List[tuple]: + """Get all state dict items.""" + return list(self().items()) + + def __contains__(self, key: str) -> bool: + """Check if a key exists in the state dict.""" + return key in self._get_all_keys() + + def __repr__(self) -> str: + """String representation.""" + try: + num_params = len(self) + return f"" + except Exception: + return "" + + def get(self, key: str, default=None) -> Optional[torch.Tensor]: + """ + Gets a tensor from the state dict. + Returns `default` if the key is not found. + Note: This method is for single key lookup and does not support patterns. + """ + if key in self._get_all_keys(): + return self._load_tensors([key])[key] + return default + + def __iter__(self) -> Iterable[str]: + """Iterate over state dict keys.""" + return iter(self.keys()) + + def __len__(self) -> int: + """Get number of entries in the state dict.""" + return len(self.keys()) + + def has_glob(self, pattern: str) -> bool: + """ + Efficiently checks if any tensor key matches the given glob pattern. + This is forwarded to the underlying StateSource which may have an + optimized implementation that avoids iterating over all keys. + + Args: + pattern: The glob pattern to match against tensor keys. + + Returns: + True if a matching key is found, False otherwise. + """ + return self.source.has_glob(pattern) + + +class StateSource(ABC, Mapping[str, torch.Tensor]): + """ + Abstract base class for a source of model state. + + This class defines a standard interface for `StateDict` to access tensor + data, abstracting away the details of how and where the data is stored. + Subclasses can implement loading from different storage backends, such as + in-memory dictionaries or files on disk. 
This allows `StateDict` to handle + various checkpoint formats in a uniform way. + """ + + @abstractmethod + def get_all_keys(self) -> List[str]: + """Returns a list of all available tensor keys in the source.""" + pass + + @abstractmethod + def load_tensors(self, keys: List[str]) -> Dict[str, torch.Tensor]: + """Loads the specified tensors from the source.""" + pass + + def __getitem__(self, key: str) -> torch.Tensor: + """Loads a single tensor by key.""" + tensors = self.load_tensors([key]) + if key not in tensors: + raise KeyError(f"Key not found in source: {key}") + return tensors[key] + + def __iter__(self) -> Iterable[str]: + """Iterates over all tensor keys.""" + return iter(self.get_all_keys()) + + def __len__(self) -> int: + """Returns the total number of tensors in the source.""" + return len(self.get_all_keys()) + + def has_glob(self, pattern: str) -> bool: + """ + Checks if any tensor key matches the given glob pattern. + This default implementation is not efficient for all sources, as it may + load all keys. Subclasses should override this method if a more + performant implementation is available. + """ + import fnmatch + + for key in self.get_all_keys(): + if fnmatch.fnmatch(key, pattern): + return True + return False + + +class DictStateSource(StateSource): + """ + A state source backed by an in-memory Python dictionary. + + This is the simplest `StateSource` implementation. It's used when the entire + model state dict is already loaded into a dictionary in memory. + + Args: + state_dict: A dictionary mapping tensor names (str) to `torch.Tensor` objects. + """ + + def __init__(self, state_dict: Dict[str, torch.Tensor]): + self._dict = state_dict + self._keys_cache: Optional[List[str]] = None + + def get_all_keys(self) -> List[str]: + if self._keys_cache is None: + self._keys_cache = sorted(list(self._dict.keys())) + return self._keys_cache + + def load_tensors(self, keys: List[str]) -> Dict[str, torch.Tensor]: + return {key: self._dict[key] for key in keys if key in self._dict} + + +class SafeTensorsStateSource(StateSource): + """ + A state source backed by a directory of .safetensors files. + + This source is designed for efficiently loading tensors from checkpoints saved + in the Safetensors format, which is common for large models that are often + "sharded" into multiple files. + + It can handle two common scenarios: + 1. A directory containing multiple `.safetensors` files. + 2. A directory containing a `model.safetensors.index.json` file, which maps + tensor names to the specific `.safetensors` file they reside in. This is + the standard format used by Hugging Face Transformers. + + Using this source allows `StateDict` to query for tensor keys and load only + the necessary files and tensors from disk, avoiding high memory usage. + + Args: + path: The path to the directory containing the `.safetensors` files + and/or the index file. Can also be a Hugging Face Hub model ID. + """ + + def __init__(self, path: Union[str, Path]): + self.model_name_or_path = path + self._resolved_path_cache: Optional[Path] = None + self._keys_cache: Optional[List[str]] = None + self._key_to_filename_map_cache: Optional[Dict[str, str]] = None + + @property + def path(self) -> Path: + """ + The local path to the checkpoint files. + If the initial path is a Hugging Face Hub model ID, this property + will handle downloading the necessary files and return the local + cache path. 
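+
+        The resolved local path is cached on first access, so repeated lookups do
+        not trigger additional downloads.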
+ """ + if self._resolved_path_cache is None: + self._resolved_path_cache = self._resolve_path(self.model_name_or_path) + return self._resolved_path_cache + + @property + def key_to_filename_map(self) -> Dict[str, str]: + """ + Provides a mapping from tensor keys to the safetensor filename they + are stored in. + + This map is constructed either from `model.safetensors.index.json` if + it exists, or by scanning all `.safetensors` files in the directory. + The result is cached for efficiency. + """ + if self._key_to_filename_map_cache is not None: + return self._key_to_filename_map_cache + + # First, try to load from the index file. + key_map = self._cached_get_key_to_filename_map(self.path) + if key_map: + self._key_to_filename_map_cache = key_map + return key_map + + # If no index, scan the directory. + import os + + from glob import glob as file_glob + + from safetensors import safe_open + + key_map = {} + safetensor_files = file_glob(str(self.path / "*.safetensors")) + for file_path in safetensor_files: + filename = os.path.basename(file_path) + try: + with safe_open(file_path, framework="pt", device="cpu") as f: + for key in f.keys(): + if key in key_map: + # This is an issue. Same key in multiple files, and no index. + # How to resolve ambiguity? Let's just warn and overwrite. Last one wins. + print( + f"Warning: duplicate key '{key}' found in '{filename}' and '{key_map[key]}'. Using '{filename}'." + ) + key_map[key] = filename + except Exception as e: + # Can be not a safetensor file, etc. + print(f"Warning: could not open {filename} as a safetensors file: {e}") + + self._key_to_filename_map_cache = key_map + return key_map + + @staticmethod + def _resolve_path(model_name_or_path: Union[str, Path]) -> Path: + """ + Resolves a model name or path to a local directory. + If the path is not a local directory, it is treated as a Hugging + Face Hub model ID, and the corresponding files are downloaded. + """ + local_path = Path(model_name_or_path) + if local_path.is_dir(): + return local_path + + try: + from huggingface_hub import snapshot_download + from huggingface_hub.utils import HfHubHTTPError + + # Not a local directory, so we assume it's a model ID + # on the Hugging Face Hub. + return Path( + snapshot_download( + repo_id=str(model_name_or_path), + allow_patterns=["*.safetensors", "model.safetensors.index.json"], + # Ignore other large files. + ignore_patterns=["*.bin", "*.pt", "*.pth"], + ) + ) + except (ImportError, HfHubHTTPError, ValueError): + # If huggingface_hub is not installed, or if it's not a + # valid model ID, we return the original path and let the + # subsequent logic handle the file not found error. 
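+            # (get_all_keys() will then raise FileNotFoundError if nothing usable
+            # exists at this path.)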
+ return local_path + + def get_all_keys(self) -> List[str]: + if self._keys_cache is not None: + return self._keys_cache + + from glob import glob as file_glob + + from safetensors import safe_open + + all_keys = set() + key_to_filename_map = self.key_to_filename_map + if key_to_filename_map: + all_keys.update(key_to_filename_map.keys()) + + if not all_keys: + safetensor_files = file_glob(str(self.path / "*.safetensors")) + if not safetensor_files and not key_to_filename_map: + raise FileNotFoundError( + f"No .safetensors files or index found in {self.model_name_or_path}" + ) + for safetensor_file in safetensor_files: + with safe_open(safetensor_file, framework="pt", device="cpu") as f: + all_keys.update(f.keys()) + + self._keys_cache = sorted(list(all_keys)) + return self._keys_cache + + def load_tensors(self, keys_to_load: List[str]) -> Dict[str, torch.Tensor]: + if not keys_to_load: + return {} + + from glob import glob as file_glob + + from safetensors import safe_open + + loaded_tensors = {} + remaining_keys = set(keys_to_load) + key_to_filename_map = self.key_to_filename_map + + if key_to_filename_map: + file_to_keys_map = defaultdict(list) + for key in list(remaining_keys): + if key in key_to_filename_map: + filename = key_to_filename_map[key] + file_to_keys_map[filename].append(key) + + for filename, keys_in_file in file_to_keys_map.items(): + file_path = self.path / filename + if file_path.exists(): + with safe_open(file_path, framework="pt", device="cpu") as f: + for key in keys_in_file: + if key in f.keys(): + loaded_tensors[key] = f.get_tensor(key) + remaining_keys.discard(key) + + if remaining_keys: + safetensor_files = file_glob(str(self.path / "*.safetensors")) + if not safetensor_files and not key_to_filename_map and not loaded_tensors: + raise FileNotFoundError( + f"No .safetensors files found in {self.model_name_or_path} to load keys: {remaining_keys}" + ) + for safetensor_file_path in safetensor_files: + if not remaining_keys: + break + with safe_open(safetensor_file_path, framework="pt", device="cpu") as f: + current_file_keys = f.keys() + for key in list(remaining_keys): + if key in current_file_keys: + loaded_tensors[key] = f.get_tensor(key) + remaining_keys.remove(key) + + if remaining_keys: + raise KeyError( + f"Keys not found in safetensors from {self.model_name_or_path}: {remaining_keys}" + ) + + return loaded_tensors + + def has_glob(self, pattern: str) -> bool: + """ + Efficiently checks if any tensor key matches the given glob pattern. + + This method avoids loading all tensor keys into memory at once. It scans + the checkpoint index or file headers and returns as soon as a match is + found. + + Args: + pattern: The glob pattern to match against tensor keys. + + Returns: + True if a matching key is found, False otherwise. + """ + import fnmatch + + from glob import glob as file_glob + + from safetensors import safe_open + + key_to_filename_map = self.key_to_filename_map + if key_to_filename_map: + for key in key_to_filename_map.keys(): + if fnmatch.fnmatch(key, pattern): + return True + return False + + # If no index map, scan the files directly. 
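+        # safe_open only reads each file's header, so this scan inspects key names
+        # without loading any tensor data into memory.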
+ safetensor_files = file_glob(str(self.path / "*.safetensors")) + if not safetensor_files: + return False + + for safetensor_file in safetensor_files: + try: + with safe_open(safetensor_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if fnmatch.fnmatch(key, pattern): + return True + except Exception: + # Ignore files that are not valid safetensors + continue + + return False + + def save_generator( + self, + generator: Iterable[Tuple[str, torch.Tensor]], + output_path: Union[str, Path], + strict: bool = True, + ): + """ + Saves tensors from a generator to `.safetensors` files, preserving the + original sharding structure in a memory-efficient, streaming fashion. + + This method reads the sharding information (which tensor belongs to which + file) from the source checkpoint. It then consumes a generator of tensors, + buffering them in memory only until a complete file shard can be written to + disk. This approach minimizes peak memory usage compared to collecting all + tensors first. + + If the original checkpoint had a `model.safetensors.index.json` file, a new + one will be created for the saved tensors. + + Args: + generator: An iterable of (tensor_name, tensor) tuples. + output_path: The directory where the new safetensor files and index + will be saved. + strict: If True (default), raises a KeyError if the generator + yields a tensor name not found in the original model's + sharding structure. If False, it prints a warning and + skips the tensor. + """ + # In a distributed environment, only rank 0 should write to disk. + # Other ranks must still exhaust the generator to participate in collectives. + is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() + rank = torch.distributed.get_rank() if is_distributed else 0 + + if rank != 0: + # Other ranks must exhaust the generator to avoid hangs in collectives. + for _ in generator: + pass + return + + # Rank 0 proceeds with saving. + from safetensors.torch import save_file + + output_path = Path(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + key_to_filename_map = self.key_to_filename_map + all_expected_keys = set(key_to_filename_map.keys()) + + if not key_to_filename_map: + buffered_tensors = dict(generator) + if buffered_tensors: + save_file(buffered_tensors, output_path / "model.safetensors") + return + + filename_to_keys_map = defaultdict(set) + for key, filename in key_to_filename_map.items(): + filename_to_keys_map[filename].add(key) + + files_to_save = dict(filename_to_keys_map) + buffered_tensors = {} + all_yielded_keys = set() + all_saved_keys = set() + + for name, tensor in generator: + all_yielded_keys.add(name) + if name not in all_expected_keys: + if strict: + raise KeyError( + f"Tensor '{name}' from generator not found in the original model structure. " + "To ignore, set strict=False." + ) + else: + print( + f"Warning: tensor '{name}' from generator not found in original model structure. Skipping." + ) + continue + + buffered_tensors[name] = tensor + + # Check if any file is complete and can be saved. + # Iterate over a copy of keys since we might modify the dict. + for filename in list(files_to_save.keys()): + keys_for_file = files_to_save[filename] + if keys_for_file.issubset(buffered_tensors.keys()): + # This shard is complete, save it. 
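+                    # Flushing finished shards here keeps the buffer small when the
+                    # generator yields tensors roughly grouped by shard.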
+ tensors_to_save = {key: buffered_tensors[key] for key in keys_for_file} + + output_file_path = output_path / filename + save_file(tensors_to_save, output_file_path) + + # Free memory by removing saved tensors from the buffer. + for key in keys_for_file: + del buffered_tensors[key] + + all_saved_keys.update(keys_for_file) + del files_to_save[filename] + + # --- Final Reporting --- + if files_to_save: + if strict: + print( + "Warning: The following files could not be saved because the generator did not yield all of their tensors:" + ) + else: + print( + "Warning: The following files are different from the source because the generator did not yield all " + "of their tensors. However they are still saved because strict=False." + ) + for filename, keys_for_file in files_to_save.items(): + missing_for_file = keys_for_file - all_yielded_keys + if missing_for_file: + print(f" - {filename}: missing {len(missing_for_file)} tensors:") + for key in sorted(list(missing_for_file)): + print(f" - {key}") + if not strict: + for filename in list(files_to_save.keys()): + keys_for_file = files_to_save[filename] + tensors_to_save = { + key: buffered_tensors[key] + for key in keys_for_file + if key in buffered_tensors + } + # missing_keys = set(keys_for_file) - tensors_to_save.keys() + # if missing_keys: + # print(f" - {filename}: missing {len(missing_keys)} tensors:") + # for key in sorted(list(missing_keys)): + # print(f" - {key}") + output_file_path = output_path / filename + save_file(tensors_to_save, output_file_path) + + # Free memory by removing saved tensors from the buffer. + for key in tensors_to_save.keys(): + del buffered_tensors[key] + + all_saved_keys.update(keys_for_file) + del files_to_save[filename] + + if buffered_tensors: + print( + f"Warning: {len(buffered_tensors)} tensors were yielded but not saved because their corresponding file shards were incomplete." + ) + + # Final check on whether all original tensors were written. + unsaved_keys = all_expected_keys - all_saved_keys + if not unsaved_keys: + extra_keys = all_yielded_keys - all_expected_keys + if extra_keys: + print( + f"\nSuccess: All tensors from the original checkpoint were written. " + f"({len(extra_keys)} extra tensors from generator were ignored as per strict=False)." + ) + else: + print("\nSuccess: All tensors from the original checkpoint were written.") + else: + print( + f"\nError: {len(unsaved_keys)} tensors from the original checkpoint were not written. See warnings above for details." + ) + + # Create index file for the saved shards. 
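+        # The regenerated index keeps the original metadata block but maps only the
+        # keys that were actually written above.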
+ original_index_file = self.path / "model.safetensors.index.json" + if original_index_file.exists(): + with open(original_index_file, "r") as f: + original_index_data = json.load(f) + + new_weight_map = {key: key_to_filename_map[key] for key in all_saved_keys} + + new_index_data = { + "metadata": original_index_data.get("metadata", {}), + "weight_map": new_weight_map, + } + + output_index_file = output_path / "model.safetensors.index.json" + if new_weight_map: + with open(output_index_file, "w") as f: + json.dump(new_index_data, f, indent=4) + + def _get_key_to_filename_map(self) -> Optional[Dict[str, str]]: + return self._cached_get_key_to_filename_map(self.path) + + @staticmethod + @lru_cache(maxsize=None) + def _cached_get_key_to_filename_map( + model_name_or_path: Union[str, Path] + ) -> Optional[Dict[str, str]]: + """Static, cached method to get the key-to-filename map.""" + index_file = Path(model_name_or_path) / "model.safetensors.index.json" + if index_file.exists(): + with open(index_file, "r") as f: + try: + index_data = json.load(f) + if "weight_map" in index_data and isinstance(index_data["weight_map"], dict): + return index_data["weight_map"] + except json.JSONDecodeError: + return None + return None diff --git a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/vlm.py b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/vlm.py new file mode 100644 index 0000000000..7ad431f6f9 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/vlm.py @@ -0,0 +1,603 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from pathlib import Path +from typing import Any, Dict, Generic, List, Optional, TypeVar, Union + +import torch + +from transformers import ( + AutoConfig, + AutoImageProcessor, + AutoModel, + AutoProcessor, + AutoTokenizer, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizer, + ProcessorMixin, +) +from transformers.generation.utils import GenerateOutput + +from megatron.nemo_bridge.models.hf_pretrained.base import PreTrainedBase +from megatron.nemo_bridge.models.hf_pretrained.safe_config_loader import ( + safe_load_config_with_retry, +) + +# Type variable for generic model type +VLMType = TypeVar("VLMType", bound=PreTrainedModel) + + +class PreTrainedVLM(PreTrainedBase, Generic[VLMType]): + """ + A generic class for Pretrained Vision-Language Models with lazy loading. + + Allows type-safe access to specific VLM implementations like LlavaForConditionalGeneration. + + Examples: + Basic usage with image and text: + >>> from megatron.nemo_bridge.models.hf_pretrained.vlm import PreTrainedVLM + >>> from PIL import Image + >>> + >>> # Create instance - no model loading happens yet + >>> vlm = PreTrainedVLM.from_pretrained("llava-hf/llava-1.5-7b-hf") + >>> + >>> # Load an image + >>> image = Image.open("cat.jpg") + >>> + >>> # Process image and text together - processor and model load here + >>> inputs = vlm.process_images_and_text( + ... images=image, + ... text="What do you see in this image?" + ... ) + >>> + >>> # Generate response + >>> outputs = vlm.generate(**inputs, max_new_tokens=100) + >>> print(vlm.decode(outputs[0], skip_special_tokens=True)) + + Batch processing with multiple images: + >>> # Process multiple images with questions + >>> images = [Image.open(f"image_{i}.jpg") for i in range(3)] + >>> questions = [ + ... "What is the main object in this image?", + ... "Describe the scene", + ... "What colors do you see?" + ... 
] + >>> + >>> # Process batch + >>> inputs = vlm.process_images_and_text( + ... images=images, + ... text=questions, + ... padding=True + ... ) + >>> + >>> # Generate responses + >>> outputs = vlm.generate(**inputs, max_new_tokens=50) + >>> for i, output in enumerate(outputs): + ... print(f"Image {i+1}: {vlm.decode(output, skip_special_tokens=True)}") + + Using specific VLM types with type hints: + >>> from transformers import LlavaForConditionalGeneration + >>> from megatron.nemo_bridge.models.hf_pretrained.vlm import PreTrainedVLM + >>> + >>> # Type-safe access to Llava-specific features + >>> llava: PreTrainedVLM[LlavaForConditionalGeneration] = PreTrainedVLM.from_pretrained( + ... "llava-hf/llava-1.5-7b-hf", + ... torch_dtype=torch.float16, + ... device="cuda" + ... ) + >>> + >>> # Access model-specific attributes + >>> vision_tower = llava.model.vision_tower # Type-safe access + + Text-only generation (for multimodal models that support it): + >>> # Some VLMs can also work with text-only inputs + >>> text_inputs = vlm.encode_text("Explain what a neural network is.") + >>> outputs = vlm.generate(**text_inputs, max_length=100) + >>> print(vlm.decode(outputs[0], skip_special_tokens=True)) + + Custom preprocessing and generation: + >>> # Load with custom settings + >>> vlm = PreTrainedVLM.from_pretrained( + ... "Qwen/Qwen-VL-Chat", + ... trust_remote_code=True, + ... device_map="auto", + ... load_in_4bit=True + ... ) + >>> + >>> # Custom generation config + >>> from transformers import GenerationConfig + >>> vlm.generation_config = GenerationConfig( + ... max_new_tokens=200, + ... temperature=0.8, + ... top_p=0.95, + ... do_sample=True + ... ) + >>> + >>> # Process with custom parameters + >>> inputs = vlm.process_images_and_text( + ... images=image, + ... text="\\nDescribe this image in detail.", + ... max_length=512 + ... ) + + Manual component setup: + >>> # Create empty instance + >>> vlm = PreTrainedVLM() + >>> + >>> # Load components separately + >>> from transformers import AutoProcessor, AutoModel + >>> vlm.processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base") + >>> vlm.model = AutoModel.from_pretrained("microsoft/Florence-2-base") + >>> + >>> # Use for various vision tasks + >>> task_prompt = "" # Object detection task + >>> inputs = vlm.process_images_and_text(images=image, text=task_prompt) + >>> outputs = vlm.generate(**inputs) + + Conversational VLM usage: + >>> # Multi-turn conversation with images + >>> conversation = [] + >>> + >>> # First turn + >>> image1 = Image.open("chart.png") + >>> inputs = vlm.process_images_and_text( + ... images=image1, + ... text="What type of chart is this?" + ... ) + >>> response = vlm.generate(**inputs) + >>> conversation.append(("user", "What type of chart is this?")) + >>> conversation.append(("assistant", vlm.decode(response[0]))) + >>> + >>> # Follow-up question + >>> follow_up = "What is the highest value shown?" 
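+        >>> # (format_conversation is an illustrative user-defined helper,
+        >>> #  not part of this module's API.)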
+ >>> # Format conversation history + new question + >>> full_prompt = format_conversation(conversation) + f"\\nUser: {follow_up}" + >>> inputs = vlm.process_images_and_text(images=image1, text=full_prompt) + >>> response = vlm.generate(**inputs) + """ + + ARTIFACTS = ["processor", "tokenizer", "image_processor"] + OPTIONAL_ARTIFACTS = ["generation_config"] + + def __init__( + self, + model_name_or_path: Optional[Union[str, Path]] = None, + device: Optional[Union[str, torch.device]] = None, + torch_dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + **kwargs, + ): + """ + Initialize a Pretrained VLM with lazy loading. + + Args: + model_name_or_path: HuggingFace model identifier or local path + device: Device to load model on (e.g., 'cuda', 'cpu') + torch_dtype: Data type to load model in (e.g., torch.float16) + trust_remote_code: Whether to trust remote code when loading + **kwargs: Additional arguments passed to component loaders + """ + self._model_name_or_path = model_name_or_path + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self.torch_dtype = torch_dtype + self.trust_remote_code = trust_remote_code + super().__init__(**kwargs) + + def _load_model(self) -> VLMType: + """Lazy load and return the model.""" + if self.model_name_or_path is None: + raise ValueError("model_name_or_path must be provided to load model") + + model_kwargs = {"trust_remote_code": self.trust_remote_code, **self.init_kwargs} + + if self.torch_dtype is not None: + model_kwargs["torch_dtype"] = self.torch_dtype + + # Use provided config if already loaded + config = getattr(self, "_config", None) + if config is not None: + model_kwargs["config"] = config + + # Try AutoModel first for VLMs + model = AutoModel.from_pretrained(self.model_name_or_path, **model_kwargs) + + # Move to device + model = model.to(self.device) + + # Set generation config if available + generation_config = getattr(self, "_generation_config", None) + if generation_config is not None and hasattr(model, "generation_config"): + model.generation_config = generation_config + return model + + def _load_config(self) -> AutoConfig: + """Lazy load and return the model config with thread-safety protection.""" + if self.model_name_or_path is None: + raise ValueError("model_name_or_path must be provided to load config") + + return safe_load_config_with_retry( + self.model_name_or_path, trust_remote_code=self.trust_remote_code, **self.init_kwargs + ) + + def _load_processor(self) -> ProcessorMixin: + """Lazy load and return the processor.""" + if self.model_name_or_path is None: + raise ValueError("model_name_or_path must be provided to load processor") + + try: + return AutoProcessor.from_pretrained( + self.model_name_or_path, + trust_remote_code=self.trust_remote_code, + **self.init_kwargs, + ) + except Exception: + # Some VLMs might not have a processor, fall back to manual loading + raise ValueError( + f"Could not load processor for {self.model_name_or_path}. " + "This model might require manual processor setup." + ) + + def _load_tokenizer(self) -> Optional[PreTrainedTokenizer]: + """ + Lazy load and return the tokenizer. + For VLMs, the tokenizer might be included in the processor. 
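+        If no processor-attached tokenizer is found, this falls back to AutoTokenizer
+        and, when the loaded tokenizer lacks a pad_token, reuses eos_token as the pad
+        token; None is returned if no tokenizer can be loaded.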
+ """ + # Check if tokenizer is available through processor first + processor = getattr(self, "_processor", None) + if processor is not None and hasattr(processor, "tokenizer"): + return processor.tokenizer + + # Try to load tokenizer separately + if self.model_name_or_path is not None: + try: + tokenizer = AutoTokenizer.from_pretrained( + self.model_name_or_path, + trust_remote_code=self.trust_remote_code, + **self.init_kwargs, + ) + + # Set padding token if not present + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + except Exception: + # Some VLMs include tokenizer only in processor + pass + return None + + def _load_image_processor(self) -> Optional[Any]: + """ + Lazy load and return the image processor. + For VLMs, the image processor might be included in the processor. + """ + # Check if image processor is available through processor first + processor = getattr(self, "_processor", None) + if processor is not None and hasattr(processor, "image_processor"): + return processor.image_processor + + # Try to load image processor separately + if self.model_name_or_path is not None: + try: + return AutoImageProcessor.from_pretrained( + self.model_name_or_path, + trust_remote_code=self.trust_remote_code, + **self.init_kwargs, + ) + except Exception: + # Some VLMs include image processor only in processor + pass + return None + + def _load_generation_config(self) -> Optional[GenerationConfig]: + """Lazy load and return the generation config.""" + if self.model_name_or_path is not None: + try: + return GenerationConfig.from_pretrained( + self.model_name_or_path, + trust_remote_code=self.trust_remote_code, + **self.init_kwargs, + ) + except Exception: + # Not all models have generation configs + pass + return None + + @property + def model_name_or_path(self) -> Optional[Union[str, Path]]: + """Return the model name or path.""" + return self._model_name_or_path + + @property + def model(self) -> VLMType: + """Lazy load and return the underlying model.""" + if not hasattr(self, "_model"): + self._model = self._load_model() + else: + # Ensure model is on the right device when accessed + if hasattr(self._model, "device") and hasattr(self._model.device, "type"): + current_device = str(self._model.device) + target_device = str(self.device) + if current_device != target_device: + self._model = self._model.to(self.device) + return self._model + + @model.setter + def model(self, value: VLMType): + """Set the model manually.""" + self._model = value + + @property + def processor(self) -> ProcessorMixin: + """Lazy load and return the processor.""" + if not hasattr(self, "_processor"): + self._processor = self._load_processor() + return self._processor + + @processor.setter + def processor(self, value: ProcessorMixin): + """Set the processor manually.""" + self._processor = value + + @property + def tokenizer(self) -> Optional[PreTrainedTokenizer]: + """Lazy load and return the tokenizer.""" + if not hasattr(self, "_tokenizer"): + self._tokenizer = self._load_tokenizer() + return self._tokenizer + + @tokenizer.setter + def tokenizer(self, value: PreTrainedTokenizer): + """Set the tokenizer manually.""" + self._tokenizer = value + + @property + def image_processor(self) -> Optional[Any]: + """Lazy load and return the image processor.""" + if not hasattr(self, "_image_processor"): + self._image_processor = self._load_image_processor() + return self._image_processor + + @image_processor.setter + def image_processor(self, value: Any): + """Set the image 
processor manually.""" + self._image_processor = value + + @property + def generation_config(self) -> Optional[GenerationConfig]: + """Lazy load and return the generation config.""" + if not hasattr(self, "_generation_config"): + self._generation_config = self._load_generation_config() + return self._generation_config + + @generation_config.setter + def generation_config(self, value: GenerationConfig): + """Set the generation config manually.""" + self._generation_config = value + # Update model's generation config if model is loaded + if ( + hasattr(self, "_model") + and self._model is not None + and hasattr(self._model, "generation_config") + ): + self._model.generation_config = value + + @property + def kwargs(self) -> Dict[str, Any]: + """Additional initialization kwargs.""" + return self.init_kwargs + + @classmethod + def from_pretrained( + cls, + model_name_or_path: Union[str, Path], + device: Optional[Union[str, torch.device]] = None, + torch_dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + **kwargs, + ) -> "PreTrainedVLM[VLMType]": + """ + Create a PreTrainedVLM instance for lazy loading. + + Args: + model_name_or_path: HuggingFace model identifier or local path + device: Device to load model on + torch_dtype: Data type to load model in + trust_remote_code: Whether to trust remote code + **kwargs: Additional arguments for from_pretrained methods + + Returns: + PreTrainedVLM instance configured for lazy loading + """ + return cls( + model_name_or_path=model_name_or_path, + device=device, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + def generate(self, **kwargs) -> Union[torch.LongTensor, GenerateOutput]: + """ + Generate sequences using the model. + + Args: + **kwargs: Arguments for the generate method + + Returns: + Generated sequences + """ + return self.model.generate(**kwargs) + + def __call__(self, *args, **kwargs): + """Forward pass through the model.""" + return self.model(*args, **kwargs) + + def encode_text(self, text: Union[str, List[str]], **kwargs) -> Dict[str, torch.Tensor]: + """ + Encode text input using the tokenizer. + + Args: + text: Input text or list of texts + **kwargs: Additional tokenizer arguments + + Returns: + Encoded inputs ready for the model + """ + if self.tokenizer is None: + raise ValueError( + "No tokenizer available. Set tokenizer manually or ensure model has one." + ) + return self.tokenizer(text, return_tensors="pt", **kwargs).to(self.device) + + def decode(self, token_ids: torch.Tensor, **kwargs) -> str: + """ + Decode token IDs to text. + + Args: + token_ids: Token IDs to decode + **kwargs: Additional decoding arguments + + Returns: + Decoded text + """ + if self.tokenizer is None: + raise ValueError( + "No tokenizer available. Set tokenizer manually or ensure model has one." + ) + return self.tokenizer.decode(token_ids, **kwargs) + + def process_images_and_text( + self, images: Optional[Any] = None, text: Optional[Union[str, List[str]]] = None, **kwargs + ) -> Dict[str, torch.Tensor]: + """ + Process images and text together using the processor. 
+ + Args: + images: Input images + text: Input text + **kwargs: Additional processor arguments + + Returns: + Processed inputs ready for the model + """ + inputs = self.processor(images=images, text=text, return_tensors="pt", **kwargs) + # Move all tensors in the dict to the device + if isinstance(inputs, dict): + for key, value in inputs.items(): + if hasattr(value, "to"): + inputs[key] = value.to(self.device) + return inputs + + def save_pretrained(self, save_directory: Union[str, Path]): + """ + Save the model and all components to a directory. + + Args: + save_directory: Directory to save to + """ + save_path = Path(save_directory) + save_path.mkdir(parents=True, exist_ok=True) + + # Save model + if hasattr(self, "_model") and self._model is not None: + self._model.save_pretrained(save_path) + + # Save artifacts through base class + self.save_artifacts(save_path) + + def to(self, device: Union[str, torch.device]) -> "PreTrainedVLM[VLMType]": + """ + Move model to a device. + + Args: + device: Target device + + Returns: + Self for chaining + """ + self.device = device + if hasattr(self, "_model") and self._model is not None: + self._model = self._model.to(device) + return self + + def half(self) -> "PreTrainedVLM[VLMType]": + """ + Convert model to half precision. + + Returns: + Self for chaining + """ + if hasattr(self, "_model") and self._model is not None: + self._model = self._model.half() + self.torch_dtype = torch.float16 + return self + + def float(self) -> "PreTrainedVLM[VLMType]": + """ + Convert model to full precision. + + Returns: + Self for chaining + """ + if hasattr(self, "_model") and self._model is not None: + self._model = self._model.float() + self.torch_dtype = torch.float32 + return self + + @property + def dtype(self) -> Optional[torch.dtype]: + """Return the dtype of the model.""" + if hasattr(self, "_model") and self._model is not None: + return next(self._model.parameters()).dtype + return self.torch_dtype + + def num_parameters(self, only_trainable: bool = False) -> int: + """ + Get the number of parameters in the model. 
+ + Args: + only_trainable: Whether to count only trainable parameters + + Returns: + Number of parameters + """ + if not hasattr(self, "_model") or self._model is None: + return 0 + + if only_trainable: + return sum(p.numel() for p in self._model.parameters() if p.requires_grad) + return sum(p.numel() for p in self._model.parameters()) + + def __repr__(self) -> str: + """String representation.""" + parts = [f"{self.__class__.__name__}("] + + if self._model_name_or_path: + parts.append(f" model_name_or_path='{self._model_name_or_path}',") + + parts.append(f" device='{self.device}',") + + if self.torch_dtype: + parts.append(f" torch_dtype={self.torch_dtype},") + + if self.trust_remote_code: + parts.append(f" trust_remote_code={self.trust_remote_code},") + + # Show loaded components + loaded = [] + if hasattr(self, "_model") and self._model is not None: + loaded.append("model") + if hasattr(self, "_processor") and self._processor is not None: + loaded.append("processor") + if hasattr(self, "_tokenizer") and self._tokenizer is not None: + loaded.append("tokenizer") + if hasattr(self, "_config") and self._config is not None: + loaded.append("config") + + if loaded: + parts.append(f" loaded_components={loaded},") + + parts.append(")") + return "\n".join(parts) diff --git a/flagscale/train/megatron/nemo_bridge/models/model_provider.py b/flagscale/train/megatron/nemo_bridge/models/model_provider.py new file mode 100644 index 0000000000..d11f868488 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/model_provider.py @@ -0,0 +1,710 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import abc +import os + +from pathlib import Path +from typing import Callable, Generic, TypedDict, TypeVar, Union + +try: + from typing import Unpack +except ImportError: + try: + from typing_extensions import Unpack + except ImportError: + from unittest.mock import MagicMock + + Unpack = MagicMock() + + +from typing import Callable + +import torch + +from megatron.core import parallel_state, tensor_parallel +from megatron.core.distributed import ( + DistributedDataParallel, + DistributedDataParallelConfig, + FullyShardedDataParallel, + TorchFullyShardedDataParallel, +) +from megatron.core.enums import ModelType +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.module import Float16Module, MegatronModule +from megatron.core.utils import get_model_config + +from megatron.nemo_bridge.models.config import from_hf_pretrained, save_hf_pretrained +from megatron.nemo_bridge.utils.common_utils import get_local_rank_preinit +from megatron.nemo_bridge.utils.instantiate_utils import InstantiationMode + +try: + from megatron.core.fp8_utils import correct_amax_history_if_needed +except ImportError: + correct_amax_history_if_needed = None + + +ModelT = TypeVar("ModelT", bound=MegatronModule) + + +class ModelProviderMixin(abc.ABC, Generic[ModelT]): + """A mixin that implements the ModelProvider pattern for Megatron Bridge. + + The ModelProvider pattern solves ecosystem fragmentation by providing a standardized + way to instantiate models. This mixin provides a consistent `provide_distributed_model()` method + that handles the complexity of distributed training setup, along with HuggingFace-inspired + `.from_hf_pretrained()` and `.save_hf_pretrained()` for interoperability. 
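+
+    A minimal end-to-end sketch (illustrative only: `MyGPTProvider` stands in for any
+    concrete subclass of this mixin and the paths are placeholders):
+
+        >>> provider = MyGPTProvider.from_hf_pretrained("./hf_checkpoint_or_hub_id")
+        >>> model_chunks = provider.provide_distributed_model(
+        ...     ddp_config=DistributedDataParallelConfig(), wrap_with_ddp=True
+        ... )
+        >>> provider.save_hf_pretrained("./exported_provider_config")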
+ + For advanced customization, multiple hooks can be registered via `register_pre_wrap_hook` + and `register_post_wrap_hook`. These hooks allow modifying the model before and after + it's wrapped for distributed training (e.g., freezing layers, logging). The composed + hooks can be accessed via the `pre_wrap_hook` and `post_wrap_hook` properties. + + Subclasses must implement the `provide` method to define the model architecture. + """ + + CONFIG_NAME = "mhub_model.json" + DEFAULT_CONFIG_FORMAT = "json" + + @abc.abstractmethod + def provide( + self, + pre_process: bool | None = None, + post_process: bool | None = None, + vp_stage: int | None = None, + ) -> ModelT: + """Abstract method to provide the model instance. + + Subclasses must implement this method to return the specific Megatron model + (e.g., `GPTModel`) with its configuration. This method is called by `get_model` + to obtain the base model before it is wrapped for distributed training. + + Args: + pre_process (bool, optional): Whether to include the embedding layer (used with pipeline parallelism). + post_process (bool, optional): Whether to include the output layer (used with pipeline parallelism). + vp_stage (int, optional): The virtual pipeline stage of the model. + + Returns: + ModelT: The Megatron model instance. + """ + pass + + def provide_distributed_model( + self, + ddp_config: DistributedDataParallelConfig | None = None, + model_type=ModelType.encoder_or_decoder, + overlap_param_gather_with_optimizer_step: bool = False, + fp16: bool | None = None, + bf16: bool | None = None, + use_megatron_fsdp: bool = False, + use_torch_fsdp2: bool = False, + wrap_with_ddp: bool = True, + data_parallel_random_init: bool = True, + use_cpu_initialization: None | bool = False, + init_model_with_meta_device: bool | None = None, + pre_wrap_hook: ( + Union[ + Callable[[list[MegatronModule]], list[MegatronModule]], + list[Callable[[list[MegatronModule]], list[MegatronModule]]], + ] + | None + ) = None, + post_wrap_hook: Callable[[list[MegatronModule]], list[MegatronModule]] | None = None, + ) -> list[ModelT]: + """Instantiate and wrap the model for distributed training. + + This method retrieves the model from `provide` and sets up the distributed + environment, including data-parallel and model-parallel configurations. + It's the primary entry point for creating a model that's ready for use + in the Megatron ecosystem. + + Args: + ddp_config: Configuration for distributed data parallel. + model_type: Type of model (encoder, decoder, or both). + overlap_param_gather_with_optimizer_step: Whether to overlap param gathering. + fp16: Override FP16 setting. + bf16: Override BF16 setting. + use_megatron_fsdp: Use Megatron's Fully Sharded Data Parallel + use_torch_fsdp2: Use PyTorch FSDP2 instead of custom DDP. + wrap_with_ddp: Whether to wrap model with DDP. + data_parallel_random_init: Initialize parameters randomly across data parallel ranks. + use_cpu_initialization: Initialize model on CPU. + init_model_with_meta_device: Initialize model on meta device. + pre_wrap_hook: A single callable or list of callables to modify the model before it's wrapped. + If provided, this will override all hooks registered via `register_pre_wrap_hook`. + If a list is provided, hooks will be executed in order. + post_wrap_hook: A single callable to modify the model after it's wrapped. If provided, + this will override all hooks registered via `register_post_wrap_hook`. + + Returns: + A list containing the wrapped model instance. 
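+
+        Example (a sketch; `freeze_embeddings` is a hypothetical hook and the DDP
+        config is left at its defaults):
+
+            >>> def freeze_embeddings(chunks):
+            ...     for chunk in chunks:
+            ...         for name, param in chunk.named_parameters():
+            ...             if "word_embeddings" in name:
+            ...                 param.requires_grad = False
+            ...     return chunks
+            >>> model = provider.provide_distributed_model(
+            ...     ddp_config=DistributedDataParallelConfig(),
+            ...     pre_wrap_hook=[freeze_embeddings],
+            ... )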
+ """ + if wrap_with_ddp and not ddp_config: + raise ValueError("ddp_config is required when wrap_with_ddp is True") + + if not torch.distributed.is_initialized(): + os.environ["RANK"] = os.environ.get("RANK", "0") + os.environ["WORLD_SIZE"] = os.environ.get("WORLD_SIZE", "1") + os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost") + os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12355") + torch.cuda.set_device(get_local_rank_preinit()) + torch.distributed.init_process_group("nccl") + + if not parallel_state.is_initialized(): + print("Model parallel not initialized, initializing...") + self.initialize_model_parallel(seed=0) + + # Convert list of hooks to a single composed callable + if isinstance(pre_wrap_hook, list): + + def composed_pre_wrap_hook(model: list[MegatronModule]) -> list[MegatronModule]: + for hook in pre_wrap_hook: + model = hook(model) + return model + + final_pre_wrap_hook = composed_pre_wrap_hook + else: + final_pre_wrap_hook = pre_wrap_hook or self.pre_wrap_hook + final_post_wrap_hook = post_wrap_hook or self.post_wrap_hook + + model = get_model( + self, + ddp_config=ddp_config, + model_type=model_type, + overlap_param_gather_with_optimizer_step=overlap_param_gather_with_optimizer_step, + fp16=fp16, + bf16=bf16, + use_megatron_fsdp=use_megatron_fsdp, + use_torch_fsdp2=use_torch_fsdp2, + wrap_with_ddp=wrap_with_ddp, + data_parallel_random_init=data_parallel_random_init, + use_cpu_initialization=use_cpu_initialization, + init_model_with_meta_device=init_model_with_meta_device, + pre_wrap_hook=final_pre_wrap_hook, + ) + + if final_post_wrap_hook: + _model = final_post_wrap_hook(model) + if _model is not None: + model = _model + + return model + + def initialize_model_parallel( + self, seed: int | None = None, seed_kwargs: dict | None = None, **model_parallel_kwargs + ) -> None: + """Initializes model parallelism and sets the random seed. + + This is a convenience method that sets up tensor, pipeline, and other + forms of model parallelism based on the attributes of the provider instance. + + Args: + seed: The random seed for model parallel RNG. + seed_kwargs: Additional arguments for `model_parallel_cuda_manual_seed`. + **model_parallel_kwargs: Additional arguments for `parallel_state.initialize_model_parallel`. + """ + if not torch.distributed.is_initialized(): + torch.cuda.set_device(get_local_rank_preinit()) + torch.distributed.init_process_group("nccl") + + parallel_state.initialize_model_parallel( + tensor_model_parallel_size=getattr(self, "tensor_model_parallel_size", 1), + pipeline_model_parallel_size=getattr(self, "pipeline_model_parallel_size", 1), + virtual_pipeline_model_parallel_size=getattr( + self, "virtual_pipeline_model_parallel_size", None + ), + context_parallel_size=getattr(self, "context_parallel_size", 1) or 1, + expert_model_parallel_size=getattr(self, "expert_model_parallel_size", 1) or 1, + expert_tensor_parallel_size=getattr(self, "expert_tensor_parallel_size", None), + **model_parallel_kwargs, + ) + if seed is not None: + model_parallel_cuda_manual_seed(seed, **(seed_kwargs or {})) + + @property + def meta_model(self) -> list[ModelT]: + """Returns the model instantiated on the meta device for inspection. + + This is useful for examining the model architecture without allocating + GPU memory. 
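+
+        Example (sketch):
+
+            >>> chunks = provider.meta_model
+            >>> total_params = sum(p.numel() for m in chunks for p in m.parameters())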
+ """ + return self(wrap_with_ddp=False, init_model_with_meta_device=True) + + @property + def pre_wrap_hook(self) -> Callable[[list[MegatronModule]], list[MegatronModule]] | None: + """A composed callable of all registered pre-wrap hooks. + + This read-only property returns a single function that executes all registered + pre-wrap hooks in order. The hook is applied before the model is passed to the DDP + wrapper and can be used for tasks like freezing layers or altering model structure. + + Use `register_pre_wrap_hook` to add a hook to the execution chain. + + Returns: + A callable that executes all registered pre-wrap hooks in order, or None if no + hooks are registered. + """ + if not hasattr(self, "_pre_wrap_hooks") or not self._pre_wrap_hooks: + return None + + def composed_hook(model: list[MegatronModule]) -> list[MegatronModule]: + for hook in self._pre_wrap_hooks: + model = hook(model) + return model + + return composed_hook + + def register_pre_wrap_hook( + self, hook: Callable[[list[MegatronModule]], list[MegatronModule]], prepend: bool = False + ) -> None: + """Registers a hook to be executed before the model is wrapped. + + The hook should be a callable that accepts a list of `MegatronModule` instances + and returns a (potentially modified) list of `MegatronModule` instances. + + Args: + hook: The hook to register. + prepend: If True, the hook is inserted at the beginning of the execution + chain. Otherwise, it is appended to the end. + """ + if not hasattr(self, "_pre_wrap_hooks"): + self._pre_wrap_hooks = [] + if prepend: + self._pre_wrap_hooks.insert(0, hook) + else: + self._pre_wrap_hooks.append(hook) + + @property + def post_wrap_hook(self) -> Callable[[list[MegatronModule]], list[MegatronModule]] | None: + """A composed callable of all registered post-wrap hooks. + + This read-only property returns a single function that executes all registered + post-wrap hooks in order. The hook is applied after the model has been wrapped by + DDP and is useful for tasks like logging or attaching custom attributes. + + Use `register_post_wrap_hook` to add a hook to the execution chain. + + Returns: + A callable that executes all registered post-wrap hooks in order, or None if no + hooks are registered. + """ + if not hasattr(self, "_post_wrap_hooks") or not self._post_wrap_hooks: + return None + + def composed_hook(model: list[MegatronModule]) -> list[MegatronModule]: + for hook in self._post_wrap_hooks: + model = hook(model) + return model + + return composed_hook + + def register_post_wrap_hook( + self, hook: Callable[[list[MegatronModule]], list[MegatronModule]], prepend: bool = False + ) -> None: + """Registers a hook to be executed after the model is wrapped. + + The hook should be a callable that accepts a list of `MegatronModule` instances + and returns a (potentially modified) list of `MegatronModule` instances. + + Args: + hook: The hook to register. + prepend: If True, the hook is inserted at the beginning of the execution + chain. Otherwise, it is appended to the end. + """ + if not hasattr(self, "_post_wrap_hooks"): + self._post_wrap_hooks = [] + if prepend: + self._post_wrap_hooks.insert(0, hook) + else: + self._post_wrap_hooks.append(hook) + + @classmethod + def from_hf_pretrained( + cls, + pretrained_model_name_or_path: str | Path, + trust_remote_code: bool = False, + mode: InstantiationMode | None = None, + config_name: str | None = None, + **kwargs, + ): + """Load a pretrained model configuration from a directory or HuggingFace Hub. 
+ + This method provides a HuggingFace-inspired interface for loading model + configurations, enabling interoperability. + + Args: + pretrained_model_name_or_path: The path to a local directory or a + HuggingFace model identifier. + trust_remote_code: Whether to trust remote code when loading. + mode: The instantiation mode (e.g., `LENIENT`). + config_name: The name of the configuration file (without extension). + **kwargs: Additional keyword arguments for `from_hf_pretrained`. + + Returns: + An instance of the model provider with the loaded configuration. + """ + if config_name is None: + config_name = cls.CONFIG_NAME.rsplit(".", 1)[0] + if mode is None: + mode = InstantiationMode.LENIENT + return from_hf_pretrained( + cls, + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + mode=mode, + config_name=config_name, + **kwargs, + ) + + def save_hf_pretrained( + self, + save_directory: str | Path, + config_format: str | None = None, + config_name: str | None = None, + **kwargs, + ): + """Save the model configuration to a directory. + + This method provides a HuggingFace-inspired interface for saving model + configurations, enabling interoperability. + + Args: + save_directory: The directory where the configuration will be saved. + config_format: The format for the configuration file (e.g., `json` or `yaml`). + config_name: The name of the configuration file (without extension). + **kwargs: Additional keyword arguments for `save_hf_pretrained`. + """ + if config_name is None: + config_name = self.CONFIG_NAME.rsplit(".", 1)[0] + if config_format is None: + config_format = self.DEFAULT_CONFIG_FORMAT + return save_hf_pretrained( + self, save_directory, config_format=config_format, config_name=config_name, **kwargs + ) + + +class GetModelKwargs(TypedDict, total=False): + """Keyword arguments for the `provide_distributed_model` method. + + Attributes: + ddp_config: Configuration for distributed data parallel. + model_type: Type of model (encoder, decoder, or both). + overlap_param_gather_with_optimizer_step: Whether to overlap param gathering. + fp16: Override FP16 setting. + bf16: Override BF16 setting. + use_megatron_fsdp: Use Megatron's Fully Sharded Data Parallel + use_torch_fsdp2: Use PyTorch FSDP2 instead of custom DDP. + wrap_with_ddp: Whether to wrap model with DDP. + data_parallel_random_init: Initialize parameters randomly across data parallel ranks. + use_cpu_initialization: Initialize model on CPU. + init_model_with_meta_device: Initialize model on meta device. + pre_wrap_hook: A single callable or list of callables that overrides all registered pre-wrap hooks. + post_wrap_hook: A single callable that overrides all registered post-wrap hooks. + """ + + ddp_config: DistributedDataParallelConfig | None + model_type: ModelType + overlap_param_gather_with_optimizer_step: bool + fp16: bool | None + bf16: bool | None + use_megatron_fsdp: bool + use_torch_fsdp2: bool + wrap_with_ddp: bool + data_parallel_random_init: bool + use_cpu_initialization: bool | None + init_model_with_meta_device: bool | None + pre_wrap_hook: ( + Union[ + Callable[[list[MegatronModule]], list[MegatronModule]], + list[Callable[[list[MegatronModule]], list[MegatronModule]]], + ] + | None + ) + post_wrap_hook: Callable[[list[MegatronModule]], list[MegatronModule]] | None + + +class ModelParallelKwargs(TypedDict, total=False): + """Model-parallel override kwargs. + + Attributes map to `TransformerConfig`/provider fields that control parallelism. + Only provided values are applied as overrides. 
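+
+    Example (a sketch; the sizes shown are arbitrary):
+
+        >>> overrides: ModelParallelKwargs = {
+        ...     "tensor_model_parallel_size": 2,
+        ...     "pipeline_model_parallel_size": 1,
+        ...     "sequence_parallel": True,
+        ... }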
+ """ + + tensor_model_parallel_size: int + pipeline_model_parallel_size: int + context_parallel_size: int + expert_model_parallel_size: int + expert_tensor_parallel_size: int + moe_extended_tp: bool + sequence_parallel: bool + virtual_pipeline_model_parallel_size: int | None + hierarchical_context_parallel_sizes: list[int] | None + + +def get_model( + model_provider: ModelProviderMixin, + ddp_config: DistributedDataParallelConfig, + model_type=ModelType.encoder_or_decoder, + overlap_param_gather_with_optimizer_step: bool = False, + fp16: bool | None = None, + bf16: bool | None = None, + use_megatron_fsdp: bool = False, + use_torch_fsdp2: bool = False, + wrap_with_ddp: bool = True, + data_parallel_random_init: bool = True, + use_cpu_initialization: None | bool = False, + init_model_with_meta_device: bool | None = None, + pre_wrap_hook: ( + Union[ + Callable[[list[MegatronModule]], list[MegatronModule]], + list[Callable[[list[MegatronModule]], list[MegatronModule]]], + ] + | None + ) = None, +) -> list[MegatronModule]: + """Create and configure a model for distributed training. + + This function handles the complete model creation pipeline including: + - Model instantiation with proper pipeline parallel configuration + - GPU memory allocation + - Mixed precision (FP16/BF16) wrapping + - Float8 tensor correction + - Distributed Data Parallel (DDP) wrapping + + Args: + model_provider: ModelProviderMixin instance that creates the model. + Uses the provide() method with optional pre_process(bool), post_process(bool), + vp_stage(int) arguments for pipeline parallelism + ddp_config: Configuration for distributed data parallel training + model_type: Type of model (encoder, decoder, or encoder_and_decoder) + overlap_param_gather_with_optimizer_step: Whether to overlap parameter + gathering with optimizer step for performance optimization + fp16: Enable FP16 mixed precision training. If None, uses model config + bf16: Enable BF16 mixed precision training. If None, uses model config + use_megatron_fsdp: Use Megatron's Fully Sharded Data Parallel + use_torch_fsdp2: Use PyTorch's Fully Sharded Data Parallel v2 + wrap_with_ddp: Whether to wrap the model with DDP + data_parallel_random_init: Whether to use random initialization for + data parallel ranks (vs broadcasting from rank 0) + use_cpu_initialization: Whether to initialize model on CPU to save GPU memory + init_model_with_meta_device: Whether to initialize the model on the meta device + pre_wrap_hook: A callable or list of callables that takes a list of `MegatronModule` + and returns a modified list, or `None` to clear the hook. If a list is provided, + hooks will be executed in order. + + Returns: + list[MegatronModule]: List of model modules. 
Contains multiple modules + when using virtual pipeline parallelism, otherwise a single module + """ + if fp16: + model_provider.fp16 = fp16 + if bf16: + model_provider.bf16 = bf16 + + model_provider.use_cpu_initialization = ( + use_cpu_initialization if use_cpu_initialization else False + ) + if init_model_with_meta_device: + model_provider.init_model_with_meta_device = True + with torch.device("meta"): + model = _create_model(model_provider, model_type) + else: + model = _create_model(model_provider, model_type) + + if pre_wrap_hook: + if isinstance(pre_wrap_hook, list): + # Execute hooks in order + for hook in pre_wrap_hook: + if not callable(hook): + raise RuntimeError("All elements in pre_wrap_hook list must be callable") + _model = hook(model) + if _model is not None: + model = _model + else: + if not callable(pre_wrap_hook): + raise RuntimeError("pre_wrap_hook must be a callable or a list of callables") + _model = pre_wrap_hook(model) + if _model is not None: + model = _model + + # Set tensor model parallel attributes if not set + # In case pre_wrap_hook augmented the model (e.g. adding PEFT adapters) + for model_module in model: + for param in model_module.parameters(): + tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) + + _print_num_params(model) + + model_config = get_model_config(model[0]) + + # GPU allocation. + # For FSDP2, we don't allocate GPU memory here. We allocate GPU memory + # in the fully_shard function of FSDP2 instead. + if ( + not use_torch_fsdp2 + and not model_config.use_cpu_initialization + and not model_config.init_model_with_meta_device + ): + for model_module in model: + model_module.cuda(torch.cuda.current_device()) + + if model_config.fp16 or model_config.bf16: + model = [Float16Module(model_config, model_module) for model_module in model] + + if correct_amax_history_if_needed is not None: + correct_amax_history_if_needed(model) + + if wrap_with_ddp: + model = _ddp_wrap( + model, + data_parallel_random_init, + ddp_config, + overlap_param_gather_with_optimizer_step, + use_megatron_fsdp=use_megatron_fsdp, + use_torch_fsdp2=use_torch_fsdp2, + ) + + return model + + +def _create_model( + model_provider: ModelProviderMixin, model_type: ModelType +) -> list[MegatronModule]: + """Create model instances with appropriate pipeline parallel configuration. + + Handles virtual pipeline parallelism (VPP) by creating multiple model + instances when needed. Sets pre_process and post_process flags based on + pipeline parallel rank. + + Args: + model_provider: ModelProviderMixin instance that creates the model + model_type: ModelType enum indicating encoder, decoder, or both + + Returns: + list: List of model instances. 
Multiple instances for VPP, otherwise single + """ + + if ( + parallel_state.get_pipeline_model_parallel_world_size() > 1 + and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None + ): + assert ( + model_type != ModelType.encoder_and_decoder + ), "Interleaved schedule not supported for model with both encoder and decoder" + model = [] + for i in range(parallel_state.get_virtual_pipeline_model_parallel_world_size()): + pre_process = parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i) + post_process = parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i) + this_model = model_provider.provide( + pre_process=pre_process, post_process=post_process, vp_stage=i + ) + this_model.model_type = model_type + model.append(this_model) + else: + pre_process = parallel_state.is_pipeline_first_stage() + post_process = parallel_state.is_pipeline_last_stage() + if model_type == ModelType.encoder_and_decoder: + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + rank = parallel_state.get_pipeline_model_parallel_rank() + first_decoder_rank = parallel_state.get_pipeline_model_parallel_decoder_start() + world_size = parallel_state.get_pipeline_model_parallel_world_size() + pre_process = rank == 0 or rank == first_decoder_rank + post_process = (rank == (first_decoder_rank - 1)) or (rank == (world_size - 1)) + model = model_provider.provide() + else: + model = model_provider.provide(pre_process=pre_process, post_process=post_process) + model.model_type = model_type + + if not isinstance(model, list): + model = [model] + + # Set tensor model parallel attributes if not set. + # Only parameters that are already tensor model parallel have these + # attributes set for them. We should make sure the default attributes + # are set for all params so the optimizer can use them. + for model_module in model: + for param in model_module.parameters(): + tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) + + return model + + +def _ddp_wrap( + model: list[MegatronModule], + data_parallel_random_init: bool, + ddp_config: DistributedDataParallelConfig, + overlap_param_gather_with_optimizer_step: bool, + use_megatron_fsdp: bool = False, + use_torch_fsdp2: bool = False, +) -> list[MegatronModule]: + """Wrap model with Distributed Data Parallel (DDP) or Fully Sharded Data Parallel (FSDP). + + Args: + model: List of model modules to wrap + use_torch_fsdp2: Whether to use PyTorch FSDP v2 instead of DDP + data_parallel_random_init: Whether to broadcast parameters from rank 0 + ddp_config: Configuration for distributed data parallel + overlap_param_gather_with_optimizer_step: Whether to disable bucketing + for overlapping parameter gathering with optimizer step + + Returns: + list[MegatronModule]: List of DDP/FSDP wrapped model modules + """ + if use_megatron_fsdp: + DP = FullyShardedDataParallel + if use_torch_fsdp2: + raise ValueError( + "Using use_megatron_fsdp and use_torch_fsdp2 at the same time is not supported." + ) + elif use_torch_fsdp2: + DP = TorchFullyShardedDataParallel + else: + DP = DistributedDataParallel + + model = [ + DP( + config=get_model_config(model_chunk), + ddp_config=ddp_config, + module=model_chunk, + # Turn off bucketing for model_chunk 2 onwards, since communication for these + # model chunks is overlapped with compute anyway. 
+ disable_bucketing=(model_chunk_idx > 0) or overlap_param_gather_with_optimizer_step, + ) + for (model_chunk_idx, model_chunk) in enumerate(model) + ] + + # Broadcast params from data parallel src rank to other data parallel ranks. + if data_parallel_random_init: + for model_module in model: + model_module.broadcast_params() + + return model + + +def _print_num_params(model: list[MegatronModule]) -> None: + """Print the number of parameters in the model on rank 0. + + Only prints on data parallel rank 0 to avoid duplicate output. + Shows parameter count per (tensor parallel, pipeline parallel) rank. + + Args: + model: List of model modules to count parameters from + """ + if ( + parallel_state.get_data_parallel_rank() == 0 + and parallel_state.get_context_parallel_rank() == 0 + ): + print( + " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format( + parallel_state.get_tensor_model_parallel_rank(), + parallel_state.get_pipeline_model_parallel_rank(), + sum( + [ + sum([p.nelement() for p in model_module.parameters()]) + for model_module in model + ] + ), + ), + flush=True, + ) diff --git a/flagscale/train/megatron/nemo_bridge/models/qwen/__init__.py b/flagscale/train/megatron/nemo_bridge/models/qwen/__init__.py new file mode 100644 index 0000000000..34cefc11d9 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/qwen/__init__.py @@ -0,0 +1,56 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from megatron.nemo_bridge.models.qwen.qwen2_bridge import Qwen2Bridge # noqa: F401 +from megatron.nemo_bridge.models.qwen.qwen3_bridge import Qwen3Bridge # noqa: F401 +from megatron.nemo_bridge.models.qwen.qwen3_moe_bridge import Qwen3MoEBridge # noqa: F401 +from megatron.nemo_bridge.models.qwen.qwen_provider import ( + Qwen2ModelProvider, + Qwen2ModelProvider1P5B, + Qwen2ModelProvider7B, + Qwen2ModelProvider72B, + Qwen2ModelProvider500M, + Qwen3ModelProvider, + Qwen3ModelProvider1P7B, + Qwen3ModelProvider4B, + Qwen3ModelProvider8B, + Qwen3ModelProvider14B, + Qwen3ModelProvider32B, + Qwen3ModelProvider600M, + Qwen3MoEModelProvider, + Qwen3MoEModelProvider30B_A3B, + Qwen3MoEModelProvider235B_A22B, + Qwen25ModelProvider1P5B, + Qwen25ModelProvider3B, + Qwen25ModelProvider7B, + Qwen25ModelProvider14B, + Qwen25ModelProvider32B, + Qwen25ModelProvider72B, + Qwen25ModelProvider500M, +) + +__all__ = [ + "Qwen2ModelProvider", + "Qwen2ModelProvider500M", + "Qwen2ModelProvider1P5B", + "Qwen2ModelProvider7B", + "Qwen2ModelProvider72B", + "Qwen25ModelProvider500M", + "Qwen25ModelProvider1P5B", + "Qwen25ModelProvider3B", + "Qwen25ModelProvider7B", + "Qwen25ModelProvider14B", + "Qwen25ModelProvider32B", + "Qwen25ModelProvider72B", + "Qwen3ModelProvider", + "Qwen3ModelProvider600M", + "Qwen3ModelProvider1P7B", + "Qwen3ModelProvider4B", + "Qwen3ModelProvider8B", + "Qwen3ModelProvider14B", + "Qwen3ModelProvider32B", + "Qwen3MoEModelProvider", + "Qwen3MoEModelProvider30B_A3B", + "Qwen3MoEModelProvider235B_A22B", +] diff --git a/flagscale/train/megatron/nemo_bridge/models/qwen/qwen2_bridge.py b/flagscale/train/megatron/nemo_bridge/models/qwen/qwen2_bridge.py new file mode 100644 index 0000000000..84d6890ce5 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/qwen/qwen2_bridge.py @@ -0,0 +1,110 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import torch + +from transformers import Qwen2ForCausalLM + +from megatron.core.models.gpt.gpt_model import GPTModel + +from megatron.nemo_bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.nemo_bridge.models.conversion.param_mapping import ( + AutoMapping, + GatedMLPMapping, + QKVMapping, +) +from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from megatron.nemo_bridge.models.qwen.qwen_provider import Qwen2ModelProvider + + +@MegatronModelBridge.register_bridge(source=Qwen2ForCausalLM, target=GPTModel) +class Qwen2Bridge(MegatronModelBridge): + """ + Megatron Bridge for Qwen2 Causal LM. + + This bridge handles the conversion between HuggingFace Qwen2ForCausalLM + and Megatron-Core GPTModel formats, including weight mappings and + configuration translation. + + Example: + >>> from megatron.nemo_bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen2-7B") + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> Qwen2ModelProvider: + hf_config = hf_pretrained.config + + provider = Qwen2ModelProvider( + num_layers=hf_config.num_hidden_layers, + hidden_size=hf_config.hidden_size, + ffn_hidden_size=hf_config.intermediate_size, + num_attention_heads=hf_config.num_attention_heads, + num_query_groups=hf_config.num_key_value_heads, + init_method_std=hf_config.initializer_range, + layernorm_epsilon=hf_config.rms_norm_eps, + gated_linear_unit=True, + make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.vocab_size), + rotary_base=hf_config.rope_theta, + share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False), + vocab_size=hf_config.vocab_size, + seq_length=hf_config.max_position_embeddings, + fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16), + bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16), + params_dtype=self.dtype_from_hf(hf_config, default=torch.float32), + generation_config=hf_pretrained.generation_config, + add_qkv_bias=True, # Qwen2 has bias in QKV projections + ) + + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format + # First create simple 1:1 parameter mappings using a dictionary for readability + + # Dictionary maps Megatron parameter names -> HF parameter names + # Supports wildcard (*) patterns for layer-specific parameters + param_mappings = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "output_layer.weight": "lm_head.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + } + + mapping_list = [] + # Convert each dictionary entry to AutoMapping(megatron_param, hf_param) + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + + # Add special mappings 
that require parameter concatenation/transformation
+        mapping_list.extend(
+            [
+                # QKV: Combine separate Q, K, V matrices into single QKV matrix
+                QKVMapping(
+                    megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
+                    q="model.layers.*.self_attn.q_proj.weight",
+                    k="model.layers.*.self_attn.k_proj.weight",
+                    v="model.layers.*.self_attn.v_proj.weight",
+                ),
+                # QKV bias: Combine separate Q, K, V biases into single QKV bias (Qwen2 specific)
+                QKVMapping(
+                    megatron_param="decoder.layers.*.self_attention.linear_qkv.bias",
+                    q="model.layers.*.self_attn.q_proj.bias",
+                    k="model.layers.*.self_attn.k_proj.bias",
+                    v="model.layers.*.self_attn.v_proj.bias",
+                ),
+                # Gated MLP: Combine gate and up projection matrices into single FC1 matrix
+                GatedMLPMapping(
+                    megatron_param="decoder.layers.*.mlp.linear_fc1.weight",
+                    gate="model.layers.*.mlp.gate_proj.weight",
+                    up="model.layers.*.mlp.up_proj.weight",
+                ),
+            ]
+        )
+
+        return MegatronMappingRegistry(*mapping_list)
diff --git a/flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_bridge.py b/flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_bridge.py
new file mode 100644
index 0000000000..263fe26a32
--- /dev/null
+++ b/flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_bridge.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2025, BAAI. All rights reserved.
+#
+# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge
+
+import torch
+
+from transformers import Qwen3ForCausalLM
+
+from megatron.core.models.gpt.gpt_model import GPTModel
+
+from megatron.nemo_bridge.models.conversion.mapping_registry import MegatronMappingRegistry
+from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge
+from megatron.nemo_bridge.models.conversion.param_mapping import (
+    AutoMapping,
+    GatedMLPMapping,
+    QKVMapping,
+)
+from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
+from megatron.nemo_bridge.models.qwen.qwen_provider import Qwen3ModelProvider
+
+
+@MegatronModelBridge.register_bridge(source=Qwen3ForCausalLM, target=GPTModel)
+class Qwen3Bridge(MegatronModelBridge):
+    """
+    Megatron Bridge for Qwen3 Causal LM.
+
+    This bridge handles the conversion between HuggingFace Qwen3ForCausalLM
+    and Megatron-Core GPTModel formats. Qwen3 differs from Qwen2 mainly by
+    using QK layernorm and omitting the QKV projection bias.
+ + Example: + >>> from megatron.nemo_bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen3-1.7B") + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> Qwen3ModelProvider: + hf_config = hf_pretrained.config + + provider = Qwen3ModelProvider( + num_layers=hf_config.num_hidden_layers, + hidden_size=hf_config.hidden_size, + ffn_hidden_size=hf_config.intermediate_size, + num_attention_heads=hf_config.num_attention_heads, + num_query_groups=hf_config.num_key_value_heads, + init_method_std=hf_config.initializer_range, + layernorm_epsilon=hf_config.rms_norm_eps, + gated_linear_unit=True, + make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.vocab_size), + rotary_base=hf_config.rope_theta, + share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False), + vocab_size=hf_config.vocab_size, + seq_length=hf_config.max_position_embeddings, + fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16), + bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16), + params_dtype=self.dtype_from_hf(hf_config, default=torch.float32), + generation_config=hf_pretrained.generation_config, + qk_layernorm=True, # Qwen3 uses QK layernorm + ) + + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format + # First create simple 1:1 parameter mappings using a dictionary for readability + + # Dictionary maps Megatron parameter names -> HF parameter names + # Supports wildcard (*) patterns for layer-specific parameters + param_mappings = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "output_layer.weight": "lm_head.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.layers.*.self_attention.q_layernorm.weight": "model.layers.*.self_attn.q_norm.weight", # Qwen3 specific + "decoder.layers.*.self_attention.k_layernorm.weight": "model.layers.*.self_attn.k_norm.weight", # Qwen3 specific + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + } + + mapping_list = [] + # Convert each dictionary entry to AutoMapping(megatron_param, hf_param) + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + + # Add special mappings that require parameter concatenation/transformation + mapping_list.extend( + [ + # QKV: Combine separate Q, K, V matrices into single QKV matrix + # Note: Qwen3 does NOT have bias in QKV projections (unlike Qwen2) + QKVMapping( + megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", + q="model.layers.*.self_attn.q_proj.weight", + k="model.layers.*.self_attn.k_proj.weight", + v="model.layers.*.self_attn.v_proj.weight", + ), + # Gated MLP: Combine gate and up projection matrices into single FC1 matrix + GatedMLPMapping( + megatron_param="decoder.layers.*.mlp.linear_fc1.weight", + gate="model.layers.*.mlp.gate_proj.weight", + up="model.layers.*.mlp.up_proj.weight", + ), + ] + ) + + return MegatronMappingRegistry(*mapping_list) diff --git 
a/flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_moe_bridge.py b/flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_moe_bridge.py new file mode 100755 index 0000000000..f9cf6fabde --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_moe_bridge.py @@ -0,0 +1,113 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import torch + +from transformers import Qwen3MoeForCausalLM + +from megatron.core.models.gpt.gpt_model import GPTModel + +from megatron.nemo_bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.nemo_bridge.models.conversion.param_mapping import ( + AutoMapping, + GatedMLPMapping, + QKVMapping, +) +from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from megatron.nemo_bridge.models.qwen.qwen_provider import Qwen3MoEModelProvider + + +@MegatronModelBridge.register_bridge(source=Qwen3MoeForCausalLM, target=GPTModel) +class Qwen3MoEBridge(MegatronModelBridge): + """ + Megatron Bridge for Qwen3 MoE Causal LM. + + This bridge handles the conversion between HuggingFace Qwen3MoeForCausalLM + and Megatron-Core GPTModel formats. Qwen3 MoE models use mixture of experts + architecture with QK layernorm. + + Example: + >>> from megatron.nemo_bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen3-235B-A22B") + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> Qwen3MoEModelProvider: + hf_config = hf_pretrained.config + + provider = Qwen3MoEModelProvider( + num_layers=hf_config.num_hidden_layers, + hidden_size=hf_config.hidden_size, + ffn_hidden_size=hf_config.intermediate_size, + moe_ffn_hidden_size=hf_config.moe_intermediate_size, # Maps to moe_intermediate_size in HF + num_attention_heads=hf_config.num_attention_heads, + num_query_groups=hf_config.num_key_value_heads, + num_moe_experts=hf_config.num_experts, + moe_router_topk=hf_config.num_experts_per_tok, # Maps to num_experts_per_tok in HF + init_method_std=hf_config.initializer_range, + layernorm_epsilon=hf_config.rms_norm_eps, + gated_linear_unit=True, + make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.vocab_size), + rotary_base=hf_config.rope_theta, + share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False), + vocab_size=hf_config.vocab_size, + seq_length=hf_config.max_position_embeddings, + fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16), + bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16), + params_dtype=self.dtype_from_hf(hf_config, default=torch.float32), + generation_config=hf_pretrained.generation_config, + qk_layernorm=True, # Qwen3 MoE uses QK layernorm + moe_grouped_gemm=True, + ) + + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format + # First create simple 1:1 parameter mappings using a dictionary for readability + + # Dictionary maps Megatron parameter names -> HF parameter names + # Supports wildcard (*) patterns for layer-specific parameters + param_mappings = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "output_layer.weight": "lm_head.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + 
"decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.router.weight": "model.layers.*.mlp.gate.weight", + "decoder.layers.*.pre_mlp_layernorm.weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.layers.*.self_attention.q_layernorm.weight": "model.layers.*.self_attn.q_norm.weight", + "decoder.layers.*.self_attention.k_layernorm.weight": "model.layers.*.self_attn.k_norm.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + } + + mapping_list = [] + # Convert each dictionary entry to AutoMapping(megatron_param, hf_param) + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + + # Add special mappings that require parameter concatenation/transformation + mapping_list.extend( + [ + # QKV: Combine separate Q, K, V matrices into single QKV matrix + # Note: Qwen3 MoE does NOT have bias in QKV projections + QKVMapping( + megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", + q="model.layers.*.self_attn.q_proj.weight", + k="model.layers.*.self_attn.k_proj.weight", + v="model.layers.*.self_attn.v_proj.weight", + ), + GatedMLPMapping( + megatron_param="decoder.layers.*.mlp.experts.linear_fc1.weight*", + gate="model.layers.*.mlp.experts.*.gate_proj.weight", + up="model.layers.*.mlp.experts.*.up_proj.weight", + ), + AutoMapping( + megatron_param="decoder.layers.*.mlp.experts.linear_fc2.weight*", + hf_param="model.layers.*.mlp.experts.*.down_proj.weight", + ), + ] + ) + + return MegatronMappingRegistry(*mapping_list) diff --git a/flagscale/train/megatron/nemo_bridge/models/qwen/qwen_provider.py b/flagscale/train/megatron/nemo_bridge/models/qwen/qwen_provider.py new file mode 100644 index 0000000000..efc7b6ee0c --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/qwen/qwen_provider.py @@ -0,0 +1,393 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import logging + +from dataclasses import dataclass +from typing import Callable, Optional + +import torch +import torch.nn.functional as F + +from megatron.nemo_bridge.models.gpt_provider import GPTModelProvider + +logger = logging.getLogger(__name__) + + +@dataclass +class Qwen2ModelProvider(GPTModelProvider): + """Base model provider for Qwen 2 Models.""" + + normalization: str = "RMSNorm" + activation_func: Callable = F.silu + gated_linear_unit: bool = True + add_bias_linear: bool = False + add_qkv_bias: bool = True + seq_length: int = 4096 + init_method_std: int = 0.02 + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + vocab_size: int = 151936 + share_embeddings_and_output_weights: Optional[bool] = False + layernorm_epsilon: float = 1e-6 + rotary_base: float = 1000000.0 + position_embedding_type: str = "rope" + autocast_dtype: torch.dtype = torch.bfloat16 + params_dtype: torch.dtype = torch.bfloat16 + bf16: bool = True + + +# ============================================================================= +# Qwen 2 Model Providers +# ============================================================================= + + +@dataclass +class Qwen2ModelProvider500M(Qwen2ModelProvider): + """ + Config for Qwen 2 0.5B: https://huggingface.co/Qwen/Qwen2-0.5B + """ + + num_layers: int = 24 + hidden_size: int = 896 + num_attention_heads: int = 14 + num_query_groups: int = 2 + ffn_hidden_size: int = 4864 + share_embeddings_and_output_weights: bool = True + seq_length: int = 32768 + + +@dataclass +class Qwen2ModelProvider1P5B(Qwen2ModelProvider): + """ + Config for Qwen 2 1.5B: https://huggingface.co/Qwen/Qwen2-1.5B + """ + + num_layers: int = 28 + hidden_size: int = 1536 + num_attention_heads: int = 12 + num_query_groups: int = 2 + ffn_hidden_size: int = 8960 + seq_length: int = 32768 + share_embeddings_and_output_weights: bool = True + + +@dataclass +class Qwen2ModelProvider7B(Qwen2ModelProvider): + """ + Config for Qwen 2 7B: https://huggingface.co/Qwen/Qwen2-7B + """ + + num_layers: int = 28 + hidden_size: int = 3584 + num_attention_heads: int = 28 + num_query_groups: int = 4 + ffn_hidden_size: int = 18944 + vocab_size: int = 152064 + seq_length: int = 32768 + + +@dataclass +class Qwen2ModelProvider72B(Qwen2ModelProvider): + """ + Config for Qwen 2 72B: https://huggingface.co/Qwen/Qwen2-72B + """ + + num_layers: int = 80 + hidden_size: int = 8192 + num_attention_heads: int = 64 + num_query_groups: int = 8 + ffn_hidden_size: int = 29568 + vocab_size: int = 152064 + layernorm_epsilon: float = 1e-6 + seq_length: int = 32768 + + +# ============================================================================= +# Qwen 2.5 Model Providers +# ============================================================================= + + +@dataclass +class Qwen25ModelProvider500M(Qwen2ModelProvider): + """ + Config for Qwen 2.5 0.5B: https://huggingface.co/Qwen/Qwen2.5-0.5B + """ + + num_layers: int = 24 + hidden_size: int = 896 + num_attention_heads: int = 14 + num_query_groups: int = 2 + ffn_hidden_size: int = 4864 + share_embeddings_and_output_weights: bool = True + seq_length: int = 32768 + + +@dataclass +class Qwen25ModelProvider1P5B(Qwen2ModelProvider): + """ + Config for Qwen 2.5 1.5B: https://huggingface.co/Qwen/Qwen2.5-1.5B + """ + + num_layers: int = 28 + hidden_size: int = 1536 + num_attention_heads: int = 12 + num_query_groups: int = 2 + ffn_hidden_size: int = 8960 + seq_length: int = 32768 + share_embeddings_and_output_weights: 
bool = True + + +@dataclass +class Qwen25ModelProvider3B(Qwen2ModelProvider): + """ + Config for Qwen 2.5 3B: https://huggingface.co/Qwen/Qwen2.5-3B + """ + + num_layers: int = 36 + hidden_size: int = 2048 + num_attention_heads: int = 16 + num_query_groups: int = 2 + ffn_hidden_size: int = 11008 + vocab_size: int = 151936 + share_embeddings_and_output_weights: bool = True + seq_length: int = 32768 + + +@dataclass +class Qwen25ModelProvider7B(Qwen2ModelProvider): + """ + Config for Qwen 2.5 7B: https://huggingface.co/Qwen/Qwen2.5-7B + """ + + num_layers: int = 28 + hidden_size: int = 3584 + num_attention_heads: int = 28 + num_query_groups: int = 4 + ffn_hidden_size: int = 18944 + vocab_size: int = 152064 + seq_length: int = 32768 + + +@dataclass +class Qwen25ModelProvider14B(Qwen2ModelProvider): + """ + Config for Qwen 2.5 14B: https://huggingface.co/Qwen/Qwen2.5-14B + """ + + num_layers: int = 48 + hidden_size: int = 5120 + num_attention_heads: int = 40 + num_query_groups: int = 8 + ffn_hidden_size: int = 13824 + vocab_size: int = 152064 + layernorm_epsilon: float = 1e-6 + seq_length: int = 32768 + + +@dataclass +class Qwen25ModelProvider32B(Qwen2ModelProvider): + """ + Config for Qwen 2.5 32B: https://huggingface.co/Qwen/Qwen2.5-32B + """ + + num_layers: int = 64 + hidden_size: int = 5120 + num_attention_heads: int = 40 + num_query_groups: int = 8 + ffn_hidden_size: int = 27648 + vocab_size: int = 152064 + layernorm_epsilon: float = 1e-6 + seq_length: int = 32768 + + +@dataclass +class Qwen25ModelProvider72B(Qwen2ModelProvider): + """ + Config for Qwen 2.5 72B: https://huggingface.co/Qwen/Qwen2.5-72B + """ + + num_layers: int = 80 + hidden_size: int = 8192 + num_attention_heads: int = 64 + num_query_groups: int = 8 + ffn_hidden_size: int = 29568 + vocab_size: int = 152064 + layernorm_epsilon: float = 1e-6 + seq_length: int = 32768 + + +# ============================================================================= +# Qwen 3 Model Provider (based on GPTProvider) +# ============================================================================= + + +@dataclass +class Qwen3ModelProvider(GPTModelProvider): + """Base model provider for Qwen 3 Models.""" + + normalization: str = "RMSNorm" + activation_func: Callable = F.silu + gated_linear_unit: bool = True + add_bias_linear: bool = False + add_qkv_bias: bool = False + qk_layernorm: bool = True + kv_channels: Optional[int] = 128 + num_query_groups: int = 8 + seq_length: int = 40960 + init_method_std: int = 0.02 + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + vocab_size: int = 151936 + share_embeddings_and_output_weights: Optional[bool] = False + layernorm_epsilon: float = 1e-6 + rotary_base: float = 1000000.0 + position_embedding_type: str = "rope" + autocast_dtype: torch.dtype = torch.bfloat16 + params_dtype: torch.dtype = torch.bfloat16 + bf16: bool = True + + +@dataclass +class Qwen3ModelProvider600M(Qwen3ModelProvider): + """ + Config for Qwen 3 0.6B: https://huggingface.co/Qwen/Qwen3-0.6B + """ + + num_layers: int = 28 + hidden_size: int = 1024 + num_attention_heads: int = 16 + ffn_hidden_size: int = 3072 + share_embeddings_and_output_weights: bool = True + + +@dataclass +class Qwen3ModelProvider1P7B(Qwen3ModelProvider): + """ + Config for Qwen 3 1.7B: https://huggingface.co/Qwen/Qwen3-1.7B + """ + + num_layers: int = 28 + hidden_size: int = 2048 + num_attention_heads: int = 16 + ffn_hidden_size: int = 6144 + share_embeddings_and_output_weights: bool = True + + +@dataclass +class Qwen3ModelProvider4B(Qwen3ModelProvider): + 
""" + Config for Qwen 3 4B: https://huggingface.co/Qwen/Qwen3-4B + """ + + num_layers: int = 36 + hidden_size: int = 2560 + num_attention_heads: int = 32 + ffn_hidden_size: int = 9728 + share_embeddings_and_output_weights: bool = True + + +@dataclass +class Qwen3ModelProvider8B(Qwen3ModelProvider): + """ + Config for Qwen 3 8B: https://huggingface.co/Qwen/Qwen3-8B + """ + + num_layers: int = 36 + hidden_size: int = 4096 + num_attention_heads: int = 32 + ffn_hidden_size: int = 12288 + + +@dataclass +class Qwen3ModelProvider14B(Qwen3ModelProvider): + """ + Config for Qwen 3 14B: https://huggingface.co/Qwen/Qwen3-14B + """ + + num_layers: int = 40 + hidden_size: int = 5120 + num_attention_heads: int = 40 + ffn_hidden_size: int = 17408 + + +@dataclass +class Qwen3ModelProvider32B(Qwen3ModelProvider): + """ + Config for Qwen 3 32B: https://huggingface.co/Qwen/Qwen3-32B + """ + + num_layers: int = 64 + hidden_size: int = 5120 + num_attention_heads: int = 64 + ffn_hidden_size: int = 25600 + + +# ============================================================================= +# Qwen 3 MoE Model Provider (based on GPTProvider) +# ============================================================================= + + +@dataclass +class Qwen3MoEModelProvider(GPTModelProvider): + """Base provider for Qwen 3 MoE Models.""" + + normalization: str = "RMSNorm" + activation_func: Callable = F.silu + gated_linear_unit: bool = True + add_bias_linear: bool = False + add_qkv_bias: bool = False + qk_layernorm: bool = True + kv_channels: Optional[int] = 128 + num_query_groups: int = 8 + seq_length: int = 40960 + init_method_std: int = 0.02 + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + vocab_size: int = 151936 + share_embeddings_and_output_weights: Optional[bool] = False + layernorm_epsilon: float = 1e-6 + rotary_base: float = 1000000.0 + position_embedding_type: str = "rope" + autocast_dtype: torch.dtype = torch.bfloat16 + params_dtype: torch.dtype = torch.bfloat16 + bf16: bool = True + + # MoE specific parameters + num_moe_experts: int = 128 + moe_router_load_balancing_type: str = "aux_loss" + moe_aux_loss_coeff: float = 1e-3 + moe_router_topk: int = 8 + moe_router_pre_softmax: bool = False + moe_grouped_gemm: bool = True + moe_token_dispatcher_type: str = "alltoall" + moe_permute_fusion: bool = True + + +@dataclass +class Qwen3MoEModelProvider30B_A3B(Qwen3MoEModelProvider): + """ + Provider for Qwen 3 30B-A3B: https://huggingface.co/Qwen/Qwen3-30B-A3B + """ + + num_layers: int = 48 + hidden_size: int = 2048 + num_attention_heads: int = 32 + num_query_groups: int = 4 + ffn_hidden_size: int = 6144 + moe_ffn_hidden_size: int = 768 + + +@dataclass +class Qwen3MoEModelProvider235B_A22B(Qwen3MoEModelProvider): + """ + Provider for Qwen 3 235B-A22B: https://huggingface.co/Qwen/Qwen3-235B-A22B + """ + + num_layers: int = 94 + hidden_size: int = 4096 + num_attention_heads: int = 64 + num_query_groups: int = 4 + ffn_hidden_size: int = 12288 + moe_ffn_hidden_size: int = 1536 diff --git a/flagscale/train/megatron/nemo_bridge/models/transformer_config.py b/flagscale/train/megatron/nemo_bridge/models/transformer_config.py new file mode 100644 index 0000000000..4a3daf77fc --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/transformer_config.py @@ -0,0 +1,96 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +"""Bridge wrapper classes for Megatron Core transformer configurations. 
+ +These classes provide deferred post-initialization to support the Bridge configuration +override system while maintaining compatibility with Megatron Core's post_init behavior. +""" + +from dataclasses import dataclass + +from megatron.core.transformer.transformer_config import ( + MLATransformerConfig as MCoreMLATransformerConfig, + TransformerConfig as MCoreTransformerConfig, +) + + +@dataclass +class TransformerConfig(MCoreTransformerConfig): + """Megatron Core TransformerConfig with deferred post-init. + + This class inherits from Megatron Core's TransformerConfig but defers the + execution of post_init() until finalize() is explicitly called. This allows + for field modifications after construction but before computed fields are + calculated. + + Usage: + # Create config with deferred post-init + config = TransformerConfig(num_layers=32, hidden_size=4096) + + # Modify fields as needed + config.seq_length = 8192 + config.tensor_model_parallel_size = 2 + + # Finalize to compute derived fields + config.finalize() + """ + + def __post_init__(self) -> None: + """Skip MCore post_init during initial construction. + + The original post_init logic is deferred until finalize() is called. + This allows for field modifications after construction without + invalidating computed fields. + """ + pass + + def finalize(self) -> None: + """Execute the deferred MCore post-init logic. + + This method calls the original Megatron Core TransformerConfig.__post_init__() + to compute derived fields based on the current field values. It can be + called multiple times safely. + """ + MCoreTransformerConfig.__post_init__(self) + + +@dataclass +class MLATransformerConfig(TransformerConfig, MCoreMLATransformerConfig): + """Megatron Core MLATransformerConfig with deferred post-init. + + This class inherits from Megatron Core's MLATransformerConfig but defers the + execution of post_init() until finalize() is explicitly called. This allows + for field modifications after construction but before computed fields are + calculated. + + Usage: + # Create config with deferred post-init + config = MLATransformerConfig(num_layers=32, hidden_size=4096) + + # Modify fields as needed + config.q_lora_rank = 1536 + config.kv_lora_rank = 512 + + # Finalize to compute derived fields + config.finalize() + """ + + def __post_init__(self) -> None: + """Skip MCore post_init during initial construction. + + The original post_init logic is deferred until finalize() is called. + This allows for field modifications after construction without + invalidating computed fields. + """ + pass + + def finalize(self) -> None: + """Execute the deferred MCore post-init logic. + + This method calls the original Megatron Core MLATransformerConfig.__post_init__() + to compute derived fields based on the current field values. It can be + called multiple times safely. + """ + MCoreMLATransformerConfig.__post_init__(self) diff --git a/flagscale/train/megatron/nemo_bridge/utils/__init__.py b/flagscale/train/megatron/nemo_bridge/utils/__init__.py new file mode 100644 index 0000000000..3bfe2ab7d3 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/utils/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
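The provider presets above are plain dataclasses; as a quick orientation, a minimal sketch of how one might be selected and tweaked before building a model. The call that actually materializes the Megatron model lives in model_provider.py and is only assumed here, not shown in this excerpt:

    from megatron.nemo_bridge.models.qwen.qwen_provider import Qwen25ModelProvider7B

    # The presets are ordinary dataclasses, so any field can be overridden at
    # construction time; unspecified fields keep the preset values shown above.
    provider = Qwen25ModelProvider7B(seq_length=8192)
    print(provider.num_layers, provider.ffn_hidden_size)  # 28 18944

    # Building the distributed Megatron model from the provider (something like
    # provider.provide_distributed_model(...)) is an assumed API defined in
    # model_provider.py, not in this excerpt.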
+# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge diff --git a/flagscale/train/megatron/nemo_bridge/utils/common_utils.py b/flagscale/train/megatron/nemo_bridge/utils/common_utils.py new file mode 100644 index 0000000000..de4e4e17e4 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/utils/common_utils.py @@ -0,0 +1,147 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import os +import types +import warnings + +import torch +import torch.distributed + +from megatron.core import DistributedDataParallel as DDP +from megatron.core.transformer.module import Float16Module + +try: + from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP + + ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, torch_FSDP, Float16Module) +except ImportError: + ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module) + + +def get_rank_safe() -> int: + """Get the distributed rank safely, even if torch.distributed is not initialized. + + Returns: + The current process rank. + """ + # In megatron init, args.rank comes from the torchrun env var. + # Once init has been done, args.rank is updated to value of torch get_rank() + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + else: + return int(os.getenv("RANK", "0")) + + +def get_world_size_safe() -> int: + """Get the distributed world size safely, even if torch.distributed is not initialized. + + Returns: + The total number of processes in the distributed job. + """ + # In megatron init, args.world_size comes from the torchrun env var. + # Once init has been done, args.world_size is updated to value of torch get_world_size() + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + else: + return int(os.getenv("WORLD_SIZE", "1")) + + +def get_last_rank() -> int: + """Get the last rank in the distributed group""" + if not torch.distributed.is_initialized(): + return 0 + return torch.distributed.get_world_size() - 1 + + +def get_local_rank_preinit() -> int: + """Get the local rank from the environment variable, intended for use before full init. + + Returns: + The local rank of the current process. + """ + return int(os.getenv("LOCAL_RANK", "0")) + + +def print_rank_0(message: str) -> None: + """Print a message only on global rank 0. + + Args: + message: The message string to print. + """ + rank = get_rank_safe() + if rank == 0: + print(message, flush=True) + + +def warn_rank_0(message): + """Warn only on rank 0.""" + rank = get_rank_safe() + if rank == 0: + warnings.warn(message) + + +def is_last_rank() -> bool: + """Check if the current rank is the last rank in the default process group. + + Returns: + True if the current rank is the last one, False otherwise. + """ + return torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1) + + +def print_rank_last(message: str) -> None: + """Print a message only on the last rank of the default process group. + + Args: + message: The message string to print. + """ + if torch.distributed.is_initialized(): + if is_last_rank(): + print(message, flush=True) + else: + print(message, flush=True) + + +def hook_hf_module_setattr_for_tp_grad_sync(module: torch.nn.Module) -> torch.nn.Module: + """Mark params for TP grad sync and hook __setattr__ on a module and its children. 
+ + This ensures that all existing parameters under the provided module have the + attribute ``average_gradients_across_tp_domain=True`` and that any future + submodules assigned onto this module (or any of its current children) will + also have their parameters marked automatically. + + Args: + module: The root module (typically a Hugging Face module instance). + + Returns: + The same module instance for convenience. + """ + if module is None: + return module + + # Mark all existing parameters recursively + for param in module.parameters(recurse=True): + setattr(param, "average_gradients_across_tp_domain", True) + + def _wrap_setattr(original_setattr): + def _wrapped(self, name, value): + original_setattr(name, value) + if isinstance(value, torch.nn.Module): + for p in value.parameters(recurse=True): + setattr(p, "average_gradients_across_tp_domain", True) + + return _wrapped + + # Hook __setattr__ on the module and all existing submodules to catch + # future dynamic assignments anywhere in the hierarchy. + for submodule in module.modules(): + if getattr(submodule, "_tp_grad_sync_setattr_wrapped", False): + continue + original_setattr = submodule.__setattr__ + wrapped = _wrap_setattr(original_setattr) + submodule.__setattr__ = types.MethodType(wrapped, submodule) + setattr(submodule, "_tp_grad_sync_setattr_wrapped", True) + + return module diff --git a/flagscale/train/megatron/nemo_bridge/utils/decorators.py b/flagscale/train/megatron/nemo_bridge/utils/decorators.py new file mode 100644 index 0000000000..437db3b4f6 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/utils/decorators.py @@ -0,0 +1,26 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import functools +import logging +import warnings + +from typing import Any, Callable, TypeVar + +logger = logging.getLogger(__name__) + +# Define a TypeVar for generic return types +R = TypeVar("R") + + +def experimental_fn(func: Callable[..., R]) -> Callable[..., R]: + """Decorator to mark a function as experimental and issue a warning upon its call.""" + warning_message = f"Function '{func.__name__}' is experimental. APIs in this module are subject to change without notice." + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> R: + warnings.warn(warning_message, stacklevel=2) + return func(*args, **kwargs) + + return wrapper diff --git a/flagscale/train/megatron/nemo_bridge/utils/fusions.py b/flagscale/train/megatron/nemo_bridge/utils/fusions.py new file mode 100644 index 0000000000..1f7d6f52a6 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/utils/fusions.py @@ -0,0 +1,175 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +"""Fusion capability checks for Megatron models. + +This module provides functions to check if various fusion optimizations +can be enabled based on the current environment and dependencies. +""" + +import logging +import os + +from megatron.core.transformer.transformer_config import TransformerConfig + +logger = logging.getLogger(__name__) + +# Control whether to log warnings when fusions are disabled +# Set environment variable MEGATRON_SUPPRESS_FUSION_WARNINGS=1 to disable warnings +LOG_FUSION_DISABLE = os.environ.get("MEGATRON_SUPPRESS_FUSION_WARNINGS", "0") != "1" + + +def can_enable_apply_rope_fusion() -> bool: + """Check if RoPE (Rotary Position Embedding) fusion can be enabled. 
+ + Returns: + bool: True if RoPE fusion is available and compatible. + """ + # Check for Transformer Engine availability + try: + import transformer_engine # noqa: F401 + + from megatron.core.utils import get_te_version, is_te_min_version + + if not is_te_min_version("2.2.0.dev0"): + if LOG_FUSION_DISABLE: + logger.warning( + "apply_rope_fusion requires Transformer Engine >= 2.2.0.dev0. " + f"Current version: {get_te_version()}. Fusion disabled." + ) + return False + except ImportError: + if LOG_FUSION_DISABLE: + logger.warning( + "apply_rope_fusion requires Transformer Engine but it is not installed. Fusion disabled." + ) + return False + + # Check for RoPE fusion kernel availability + try: + from megatron.core.models.common.embeddings.rope_utils import ( + fused_apply_rotary_pos_emb, + fused_apply_rotary_pos_emb_thd, + ) + + if fused_apply_rotary_pos_emb is None and fused_apply_rotary_pos_emb_thd is None: + if LOG_FUSION_DISABLE: + logger.warning( + "apply_rope_fusion kernels are not available in megatron.core. Fusion disabled." + ) + return False + return True + except ImportError: + if LOG_FUSION_DISABLE: + logger.warning( + "apply_rope_fusion requires RoPE fusion kernels from megatron.core but they are not available. " + "Fusion disabled." + ) + return False + + +def can_enable_gradient_accumulation_fusion() -> bool: + """Check if gradient accumulation fusion can be enabled. + + Returns: + bool: True if gradient accumulation fusion is available. + """ + try: + import fused_weight_gradient_mlp_cuda # noqa: F401 + + return True + except ImportError: + if LOG_FUSION_DISABLE: + logger.warning( + "gradient_accumulation_fusion requires FusedLayerNorm from megatron.core.fusions " + "but it is not available. Fusion disabled." + ) + return False + + +def can_enable_bias_dropout_fusion() -> bool: + """Check if bias dropout fusion can be enabled. + + Returns: + bool: True if bias dropout fusion is available. + """ + try: + from megatron.core.fusions.fused_bias_dropout import ( # noqa: F401 + bias_dropout_add_fused_train, + ) + + return True + except ImportError: + if LOG_FUSION_DISABLE: + logger.warning( + "bias_dropout_fusion requires fused_bias_dropout from megatron.core.fusions " + "but it is not available. Fusion disabled." + ) + return False + + +def can_enable_masked_softmax_fusion() -> bool: + """Check if masked softmax fusion can be enabled. + + Returns: + bool: True if masked softmax fusion kernels are available. + """ + try: + # Try to import the CUDA kernels that are required for masked softmax fusion + import scaled_masked_softmax_cuda # noqa: F401 + import scaled_upper_triang_masked_softmax_cuda # noqa: F401 + + return True + except ImportError: + if LOG_FUSION_DISABLE: + logger.warning( + "masked_softmax_fusion requires CUDA kernels (scaled_masked_softmax_cuda) " + "but they are not available. This typically happens when Megatron-Core is not " + "built with CUDA extensions. Fusion disabled." + ) + return False + + +def validate_rope_fusion_compatibility(config: TransformerConfig) -> bool: + """Validate if RoPE fusion is compatible with the current model configuration. + + Args: + model_provider: The GPTModelProvider instance to validate. + + Returns: + bool: True if RoPE fusion is compatible, False otherwise. 
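Because these checks are plain predicates, callers can use them to gate optional fusion flags up front instead of failing at model build time. An illustrative sketch; the keys mirror the corresponding Megatron Core TransformerConfig field names:

    from megatron.nemo_bridge.utils.fusions import (
        can_enable_apply_rope_fusion,
        can_enable_bias_dropout_fusion,
        can_enable_gradient_accumulation_fusion,
    )

    # Illustrative only: enable each fusion only if the current environment
    # actually provides the required kernels / Transformer Engine version.
    fusion_overrides = {
        "apply_rope_fusion": can_enable_apply_rope_fusion(),
        "bias_dropout_fusion": can_enable_bias_dropout_fusion(),
        "gradient_accumulation_fusion": can_enable_gradient_accumulation_fusion(),
    }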
+ """ + if not config.apply_rope_fusion: + return True + + # Check for multi_latent_attention incompatibility + if getattr(config, "multi_latent_attention", False): + if LOG_FUSION_DISABLE: + logger.warning( + "apply_rope_fusion for multi-latent attention only supports training. " + "It is experimental and may change in future versions." + ) + return True + + # Check TE version for rotary_interleaved + if getattr(config, "rotary_interleaved", False): + try: + from megatron.core.utils import get_te_version, is_te_min_version + + if not is_te_min_version("2.2.0.dev0"): + if LOG_FUSION_DISABLE: + logger.warning( + "apply_rope_fusion with rotary_interleaved requires TE >= 2.2.0.dev0. " + f"Current TE version: {get_te_version()}. Consider disabling apply_rope_fusion." + ) + return False + except ImportError: + if LOG_FUSION_DISABLE: + logger.warning( + "apply_rope_fusion with rotary_interleaved requires Transformer Engine. " + "Consider disabling apply_rope_fusion." + ) + return False + + return True diff --git a/flagscale/train/megatron/nemo_bridge/utils/import_utils.py b/flagscale/train/megatron/nemo_bridge/utils/import_utils.py new file mode 100644 index 0000000000..33d1dd4edf --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/utils/import_utils.py @@ -0,0 +1,409 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import importlib +import logging +import traceback + +from contextlib import contextmanager +from typing import Tuple + +import torch + +from packaging.version import Version as PkgVersion + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logger.addHandler(logging.StreamHandler()) + +GPU_INSTALL_STRING = ( + """Install GPU packages via `pip install --extra-index-url """ + """https://pypi.nvidia.com nemo-curator[cuda12x]` +or use `pip install --extra-index-url https://pypi.nvidia.com ".[cuda12x]"` if installing from source""" +) +MISSING_NEMO_EXPORT_DEPLOY_MSG = ( + "nemo-export-deploy is not available. Please install it with `pip install nemo-export-deploy`." +) +MISSING_NVRX_MSG = "nvidia-resiliency-ext is not available. Please install it with `pip install nvidia-resiliency-ext`." +MISSING_NEMO_RUN_MSG = "nemo-run is not available. Please install it with `pip install nemo-run`." + + +class UnavailableError(Exception): + """Error thrown if a symbol is unavailable due to an issue importing it""" + + +@contextmanager +def null_decorator(*args, **kwargs): + """null_decorator""" + if len(kwargs) == 0 and len(args) == 1 and callable(args[0]): + return args[0] + else: + + def inner(func): + return func + + return inner + + +class UnavailableMeta(type): + """A metaclass for generating placeholder objects for unavailable symbols + + This metaclass allows errors to be deferred from import time to the time + that a symbol is actually used in order to streamline the usage of optional + dependencies. This is particularly useful for attempted imports of GPU-only + modules which will only be invoked if GPU-only functionality is + specifically used. + + If an attempt to import a symbol fails, this metaclass is used to generate + a class which stands in for that symbol. Any attempt to call the symbol + (instantiate the class) or access its attributes will throw an + UnavailableError exception. Furthermore, this class can be used in + e.g. isinstance checks, since it will (correctly) fail to match any + instance it is compared against. 
+ + In addition to calls and attribute access, a number of dunder methods are + implemented so that other common usages of imported symbols (e.g. + arithmetic) throw an UnavailableError, but this is not guaranteed for + all possible uses. In such cases, other exception types (typically + TypeErrors) will be thrown instead. + """ + + def __new__(meta, name, bases, dct): + if dct.get("_msg", None) is None: + dct["_msg"] = f"{name} could not be imported" + name = f"MISSING{name}" + return super(UnavailableMeta, meta).__new__(meta, name, bases, dct) + + def __call__(cls, *args, **kwargs): + raise UnavailableError(cls._msg) + + def __getattr__(cls, name): + # Special handling for unittest.mock which tries to access __func_ + # and other attributes during its operations + if name in ("__func__", "__wrapped__", "__name__", "__qualname__"): + raise AttributeError(f"'{cls.__name__}' has no attribute '{name}'") + raise UnavailableError(cls._msg) + + def __eq__(cls, other): + raise UnavailableError(cls._msg) + + def __lt__(cls, other): + raise UnavailableError(cls._msg) + + def __gt__(cls, other): + raise UnavailableError(cls._msg) + + def __le__(cls, other): + raise UnavailableError(cls._msg) + + def __ge__(cls, other): + raise UnavailableError(cls._msg) + + def __ne__(cls, other): + raise UnavailableError(cls._msg) + + def __abs__(cls): + raise UnavailableError(cls._msg) + + def __add__(cls, other): + raise UnavailableError(cls._msg) + + def __radd__(cls, other): + raise UnavailableError(cls._msg) + + def __iadd__(cls, other): + raise UnavailableError(cls._msg) + + def __floordiv__(cls, other): + raise UnavailableError(cls._msg) + + def __rfloordiv__(cls, other): + raise UnavailableError(cls._msg) + + def __ifloordiv__(cls, other): + raise UnavailableError(cls._msg) + + def __lshift__(cls, other): + raise UnavailableError(cls._msg) + + def __rlshift__(cls, other): + raise UnavailableError(cls._msg) + + def __mul__(cls, other): + raise UnavailableError(cls._msg) + + def __rmul__(cls, other): + raise UnavailableError(cls._msg) + + def __imul__(cls, other): + raise UnavailableError(cls._msg) + + def __ilshift__(cls, other): + raise UnavailableError(cls._msg) + + def __pow__(cls, other): + raise UnavailableError(cls._msg) + + def __rpow__(cls, other): + raise UnavailableError(cls._msg) + + def __ipow__(cls, other): + raise UnavailableError(cls._msg) + + def __rshift__(cls, other): + raise UnavailableError(cls._msg) + + def __rrshift__(cls, other): + raise UnavailableError(cls._msg) + + def __irshift__(cls, other): + raise UnavailableError(cls._msg) + + def __sub__(cls, other): + raise UnavailableError(cls._msg) + + def __rsub__(cls, other): + raise UnavailableError(cls._msg) + + def __isub__(cls, other): + raise UnavailableError(cls._msg) + + def __truediv__(cls, other): + raise UnavailableError(cls._msg) + + def __rtruediv__(cls, other): + raise UnavailableError(cls._msg) + + def __itruediv__(cls, other): + raise UnavailableError(cls._msg) + + def __divmod__(cls, other): + raise UnavailableError(cls._msg) + + def __rdivmod__(cls, other): + raise UnavailableError(cls._msg) + + def __neg__(cls): + raise UnavailableError(cls._msg) + + def __invert__(cls): + raise UnavailableError(cls._msg) + + def __hash__(cls): + raise UnavailableError(cls._msg) + + def __index__(cls): + raise UnavailableError(cls._msg) + + def __iter__(cls): + raise UnavailableError(cls._msg) + + def __delitem__(cls, name): + raise UnavailableError(cls._msg) + + def __setitem__(cls, name, value): + raise UnavailableError(cls._msg) + + 
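For context on how these placeholders are meant to be consumed, a usage sketch of the safe_import helpers defined later in this module (the imported package names are only examples):

    from megatron.nemo_bridge.utils.import_utils import safe_import, safe_import_from

    # The import itself never raises: `te` is either the real module or an
    # UnavailableMeta-derived placeholder whose use raises UnavailableError.
    te, HAVE_TE = safe_import("transformer_engine")
    flash_attn_func, HAVE_FLASH_ATTN = safe_import_from("flash_attn", "flash_attn_func")

    if HAVE_TE:
        pass  # take the Transformer Engine code path
    # Touching a failed import surfaces the deferred error, e.g.
    #   te.something  ->  UnavailableError: transformer_engine could not be imported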
def __enter__(cls, *args, **kwargs): + raise UnavailableError(cls._msg) + + def __get__(cls, *args, **kwargs): + raise UnavailableError(cls._msg) + + def __delete__(cls, *args, **kwargs): + raise UnavailableError(cls._msg) + + def __len__(cls): + raise UnavailableError(cls._msg) + + +def is_unavailable(obj): + """Helper to check if given symbol is actually a placeholder""" + return type(obj) is UnavailableMeta + + +class UnavailableNullContext: + """A placeholder class for unavailable context managers + + This context manager will return a value which will throw an + UnavailableError if used in any way, but the context manager itself can be + safely invoked. + """ + + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return UnavailableMeta( + "MissingContextValue", + (), + {"_msg": "Attempted to make use of placeholder context return value."}, + ) + + def __exit__(self, *args, **kwargs): + pass + + +def safe_import(module, *, msg=None, alt=None) -> Tuple[object, bool]: + """A function used to import modules that may not be available. + + This function will attempt to import a module with the given name, but it + will not throw an ImportError if the module is not found. Instead, it will + return a placeholder object which will raise an exception only if used. + + Args: + module (str): The name of the module to import. + msg (str, optional): An error message to be displayed if this module is used + after a failed import. Defaults to None. + alt (object, optional): A module to be used in place of the given module if it + fails to import. Defaults to None. + + Returns: + tuple: A tuple containing two elements. The first element is the imported module, + the given alternate, or a class derived from UnavailableMeta. The second element + is a boolean indicating whether the intended import was successful. + """ + try: + return importlib.import_module(module), True + except ImportError: + exception_text = traceback.format_exc() + logger.debug(f"Import of {module} failed with: {exception_text}") + except Exception: + exception_text = traceback.format_exc() + raise + if msg is None: + msg = f"{module} could not be imported" + if alt is None: + return UnavailableMeta(module.rsplit(".")[-1], (), {"_msg": msg}), False + else: + return alt, False + + +def safe_import_from( + module, symbol, *, msg=None, alt=None, fallback_module=None +) -> Tuple[object, bool]: + """A function used to import symbols from modules that may not be available. + + This function will attempt to import a symbol with the given name from + the given module, but it will not throw an ImportError if the symbol is not + found. Instead, it will return a placeholder object which will raise an + exception only if used. + + Args: + module (str): The name of the module in which the symbol is defined. + symbol (str): The name of the symbol to import. + msg (str, optional): An error message to be displayed if this symbol is used + after a failed import. Defaults to None. + alt (object, optional): An object to be used in place of the given symbol if it fails + to import. Defaults to None. + fallback_module (str, optional): Alternative name of the model in which the symbol is defined. + The function will first try to import using the `module` value and if that fails + will also try the `fallback_module`. Defaults to None. + + Returns: + tuple: A tuple containing two elements. The first element is the imported symbol, + the given alternate, or a class derived from UnavailableMeta. 
The second element + is a boolean indicating whether the intended import was successful. + """ + try: + imported_module = importlib.import_module(module) + return getattr(imported_module, symbol), True + except ImportError: + exception_text = traceback.format_exc() + logger.debug(f"Import of {module} failed with: {exception_text}") + except AttributeError: + # if there is a fallback module try it. + if fallback_module is not None: + return safe_import_from(fallback_module, symbol, msg=msg, alt=alt, fallback_module=None) + exception_text = traceback.format_exc() + logger.info(f"Import of {symbol} from {module} failed with: {exception_text}") + except Exception: + exception_text = traceback.format_exc() + raise + if msg is None: + msg = f"{module}.{symbol} could not be imported" + if alt is None: + return UnavailableMeta(symbol, (), {"_msg": msg}), False + else: + return alt, False + + +def gpu_only_import(module, *, alt=None) -> Tuple[object, bool]: + """A function used to import modules required only in GPU installs. + + This function will attempt to import a module with the given name. + This function will attempt to import a symbol with the given name from + the given module, but it will not throw an ImportError if the symbol is not + found. Instead, it will return a placeholder object which will raise an + exception only if used with instructions on installing a GPU build. + + Args: + module (str): The name of the module to import. + alt (object, optional): A module to be used in place of the given module if it + fails to import in a non-GPU-enabled install. Defaults to None. + + Returns: + tuple: A tuple containing two elements. The first element is the imported module, + the given alternate, or a class derived from UnavailableMeta. The second element + is a boolean indicating whether the intended import was successful. + """ + + return safe_import( + module, + msg=f"{module} is not enabled in non GPU-enabled installations or environemnts. {GPU_INSTALL_STRING}", + alt=alt, + ) + + +def gpu_only_import_from(module, symbol, *, alt=None) -> Tuple[object, bool]: + """A function used to import symbols required only in GPU installs. + + This function will attempt to import a module with the given name. + This function will attempt to import a symbol with the given name from + the given module, but it will not throw an ImportError if the symbol is not + found. Instead, it will return a placeholder object which will raise an + exception only if used with instructions on installing a GPU build. + + Args: + module (str): The name of the module to import. + symbol (str): The name of the symbol to import. + alt (object, optional): An object to be used in place of the given symbol if it fails + to import in a non-GPU-enabled install. Defaults to None. + + Returns: + tuple: A tuple containing two elements. The first element is the imported symbol, + the given alternate, or a class derived from UnavailableMeta. The second element + is a boolean indicating whether the intended import was successful. + """ + return safe_import_from( + module, + symbol, + msg=f"{module}.{symbol} is not enabled in non GPU-enabled installations or environments. {GPU_INSTALL_STRING}", + alt=alt, + ) + + +def get_torch_version(): + """Returns the installed PyTorch version as a packaging.version.Version object. + + Handles potential exceptions during version parsing, returning a dummy version + ("0.0.0") if parsing fails (e.g., during documentation builds where torch + might not be fully imported or available). 
+ + Returns: + packaging.version.Version: The parsed PyTorch version, or Version("0.0.0") on error. + """ + try: + _torch_version = PkgVersion(torch.__version__) + except Exception: + # This is a WAR for building docs, where torch is not actually imported + _torch_version = PkgVersion("0.0.0") + return _torch_version + + +def is_torch_min_version(version, check_equality=True): + """Check if minimum version of `torch` is installed.""" + if check_equality: + return get_torch_version() >= PkgVersion(version) + return get_torch_version() > PkgVersion(version) diff --git a/flagscale/train/megatron/nemo_bridge/utils/instantiate_utils.py b/flagscale/train/megatron/nemo_bridge/utils/instantiate_utils.py new file mode 100644 index 0000000000..2bbf9a7eb3 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/utils/instantiate_utils.py @@ -0,0 +1,418 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import copy +import functools +import logging + +from enum import Enum +from textwrap import dedent +from typing import Any, Callable, Sequence, Union + +from omegaconf import OmegaConf +from omegaconf._utils import is_structured_config + + +class InstantiationException(Exception): + """Custom exception type for instantiation errors.""" + + ... + + +class InstantiationMode(Enum): + """Enum for instantiation modes.""" + + STRICT = "strict" + LENIENT = "lenient" + + +class _Keys(str, Enum): + """Special keys in configs used by instantiate.""" + + TARGET = "_target_" + PARTIAL = "_partial_" + CALL = "_call_" + ARGS = "_args_" + + +def instantiate( + config: Any, *args: Any, mode: InstantiationMode = InstantiationMode.LENIENT, **kwargs: Any +) -> Any: + """Instantiate an object or callable from a config object. + + This function takes a configuration object (dictionary, list, OmegaConf config, + or Structured Config instance) and instantiates the target specified within it. + + The config object must contain: + _target_ (str): The fully qualified name of the class or callable to instantiate. + + The config object may also contain: + _args_ (list): Positional arguments for the target. + _partial_ (bool): If True, return a functools.partial object instead of calling + the target. Defaults to False. + _call_ (bool): If False, simply resolves and returns the target without calling it. + Defaults to True. + Additional keyword arguments to pass to the target. + + Args: + config: The configuration object describing the target and its parameters. + *args: Optional positional arguments that will override _args_ in the config + if provided. + mode: Instantiation mode (STRICT or LENIENT). In LENIENT mode (default), + errors during instantiation of parameters are logged as warnings, + and None is used instead. In STRICT mode, errors are raised. + **kwargs: Optional keyword arguments that will override parameters in the config. + Note: Dataclass instances in kwargs are treated as nested configs. + + Returns: + The instantiated object or the return value of the callable. + If config._partial_ is True, returns a functools.partial object. + If config._call_ is False, returns the resolved target callable/class itself. + Returns None if the input config is None. + + Raises: + InstantiationException: If the config is invalid, the target cannot be resolved, + or instantiation fails in STRICT mode. + TypeError: If the _partial_ flag is not a boolean. 
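A short sketch of the config convention this function consumes; the target chosen here is just an illustrative example:

    from megatron.nemo_bridge.utils.instantiate_utils import instantiate

    # "_target_" names the callable; remaining keys become keyword arguments.
    # With "_partial_": True the call is deferred via functools.partial.
    optimizer_cfg = {
        "_target_": "torch.optim.AdamW",
        "_partial_": True,
        "lr": 1e-4,
        "weight_decay": 0.01,
    }
    make_optimizer = instantiate(optimizer_cfg)
    # later, once parameters exist:
    # optimizer = make_optimizer(model.parameters())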
+ """ + + # Return None if config is None + if config is None: + return None + + if isinstance(config, (dict, list)): + config = _prepare_input_dict_or_list(config) + + kwargs = _prepare_input_dict_or_list(kwargs) + + # Structured Config always converted first to OmegaConf + if is_structured_config(config) or isinstance(config, (dict, list)): + config = OmegaConf.structured(config, flags={"allow_objects": True}) + + if OmegaConf.is_dict(config): + # Finalize config (convert targets to strings, merge with kwargs) + config_copy = copy.deepcopy(config) + config_copy._set_flag( + flags=["allow_objects", "struct", "readonly"], values=[True, False, False] + ) + config_copy._set_parent(config._get_parent()) + config = config_copy + + if kwargs: + config = OmegaConf.merge(config, kwargs) + + OmegaConf.resolve(config) + + _partial_ = config.pop(_Keys.PARTIAL, False) + + return instantiate_node(config, *args, partial=_partial_, mode=mode) + elif OmegaConf.is_list(config): + # Finalize config (convert targets to strings, merge with kwargs) + config_copy = copy.deepcopy(config) + config_copy._set_flag( + flags=["allow_objects", "struct", "readonly"], values=[True, False, False] + ) + config_copy._set_parent(config._get_parent()) + config = config_copy + + OmegaConf.resolve(config) + + _partial_ = kwargs.pop(_Keys.PARTIAL, False) + + if _partial_: + raise InstantiationException( + "The _partial_ keyword is not compatible with top-level list instantiation" + ) + + return instantiate_node(config, *args, partial=_partial_, mode=mode) + else: + raise InstantiationException( + dedent( + f"""\ + Cannot instantiate config of type {type(config).__name__}. + Top level config must be an OmegaConf DictConfig/ListConfig object, + a plain dict/list, or a Structured Config class or instance.""" + ) + ) + + +def instantiate_node( + node: Any, + *args: Any, + partial: bool = False, + mode: InstantiationMode = InstantiationMode.LENIENT, +) -> Any: + """Recursively instantiates a node within a configuration structure. + + This function handles the instantiation of individual nodes (dictionaries, + lists, or primitive values) within a larger configuration tree, typically + managed by OmegaConf. + + If the node is a dictionary containing a `_target_` key, it resolves and + instantiates the target callable/class using the other items in the + dictionary as keyword arguments. Nested nodes are recursively instantiated. + + If the node is a list, it recursively instantiates each item in the list. + + If the node is not an OmegaConf config node (e.g., a primitive type), it's + returned directly. + + Args: + node: The configuration node to instantiate (can be DictConfig, ListConfig, + or a primitive type). + *args: Positional arguments passed down from the top-level `instantiate` call, + used primarily for the final target call if the node is a dictionary + with `_target_`. + partial: Boolean flag indicating whether to return a `functools.partial` object + instead of calling the target. This can be overridden by a + `_partial_` key within the node itself. + mode: Instantiation mode (STRICT or LENIENT). Determines error handling + behavior for nested instantiations. + + Returns: + The instantiated object, list, or the original node if it wasn't a config. + Returns None if the input node is None or represents a None value in OmegaConf. + + Raises: + InstantiationException: If instantiation fails in STRICT mode, or if there are + issues like incompatible arguments or non-callable targets. 
+ TypeError: If a `_partial_` flag within the config is not a boolean. + """ + # Return None if config is None + if node is None or (OmegaConf.is_config(node) and node._is_none()): + return None + + if not OmegaConf.is_config(node): + return node + + # Override parent modes from config if specified + if OmegaConf.is_dict(node): + # using getitem instead of get(key, default) because OmegaConf will raise an exception + # if the key type is incompatible on get. + partial = node[_Keys.PARTIAL] if _Keys.PARTIAL in node else partial + + full_key = node._get_full_key(None) + + if not isinstance(partial, bool): + msg = f"Instantiation: _partial_ flag must be a bool, got {type(partial)}" + if node and full_key: + msg += f"\nfull_key: {full_key}" + raise TypeError(msg) + + if OmegaConf.is_list(node): + items = [instantiate_node(item, mode=mode) for item in node._iter_ex(resolve=True)] + + return items + elif OmegaConf.is_dict(node): + exclude_keys = set(item.value for item in _Keys if item != _Keys.ARGS) + if _is_target(node): + should_call_target = node.get("_call_", True) + _target_ = _resolve_target( + node.get(_Keys.TARGET), full_key, check_callable=should_call_target + ) + kwargs = {} + is_partial = node.get("_partial_", False) or partial + + if not should_call_target: + if len(set(node.keys()) - {"_target_", "_call_"}) != 0: + extra_keys = set(node.keys()) - {"_target_", "_call_"} + raise InstantiationException( + f"_call_ was set to False for target {_convert_target_to_string(_target_)}," + f" but extra keys were found: {extra_keys}" + ) + else: + return _target_ + + for key in node.keys(): + if key not in exclude_keys: + if OmegaConf.is_missing(node, key) and is_partial: + continue + value = node[key] + try: + value = instantiate_node(value, mode=mode) + except (ImportError, InstantiationException) as e: + if mode == InstantiationMode.STRICT: + raise InstantiationException( + f"Error instantiating {value} for key {full_key}.{key}: {e}" + ) from e + else: + value = None + logging.warning( + f"Error instantiating {value} for key {full_key}.{key}. " + f"Using None instead in lenient mode." + ) + kwargs[key] = _convert_node(value) + + assert callable(_target_) + return _call_target(_target_, partial, args, kwargs, full_key) + else: + dict_items = {} + for key, value in node.items(): + dict_items[key] = instantiate_node(value, mode=mode) + return dict_items + + else: + assert False, f"Unexpected config type : {type(node).__name__}" + + +def _locate(path: str) -> Any: + """ + Locate an object by name or dotted path, importing as necessary. + This function attempts to import modules starting from the most specific path + (back to front), making it possible to import objects where the final component + could be either a module or an attribute of the previous module. + """ + if path == "": + raise ImportError("Empty path") + from importlib import import_module + + parts = [part for part in path.split(".")] + for part in parts: + if not len(part): + raise ValueError( + f"Error loading '{path}': invalid dotstring." + + "\nRelative imports are not supported." 
+ ) + assert len(parts) > 0 + + # Try importing from the most specific path first (back to front) + for i in range(len(parts), 0, -1): + module_path = ".".join(parts[:i]) + try: + obj = import_module(module_path) + + # If this isn't the full path, get the remaining attributes + remaining_parts = parts[i:] + for part in remaining_parts: + try: + obj = getattr(obj, part) + except AttributeError as exc_attr: + raise ImportError( + f"Error loading '{path}':\n{repr(exc_attr)}" + + f"\nAre you sure that '{part}' is an attribute of '{module_path}'?" + ) from exc_attr + + # Successfully found the object + return obj + + except ModuleNotFoundError: + # Module not found, try a less specific path + continue + except Exception as exc_import: + # If we hit a different exception, it's likely an issue with the module itself + raise ImportError(f"Error loading '{path}':\n{repr(exc_import)}") from exc_import + + # If we've tried all paths and nothing worked, report failure with the base module + raise ImportError( + f"Error loading '{path}': Unable to import any module in the path. " + f"Are you sure that module '{parts[0]}' is installed?" + ) + + +def _is_target(x: Any) -> bool: + if isinstance(x, dict): + return "_target_" in x + if OmegaConf.is_dict(x): + return "_target_" in x + return False + + +def _call_target( + _target_: Callable[..., Any], + _partial_: bool, + args: tuple[Any, ...], + kwargs: dict[str, Any], + full_key: str, +) -> Any: + """Call target (type) with args and kwargs.""" + args, kwargs = _extract_pos_args(args, kwargs) + if _partial_: + try: + return functools.partial(_target_, *args, **kwargs) + except Exception as e: + msg = ( + f"Error in creating partial({_convert_target_to_string(_target_)}, ...) object:" + + f"\n{repr(e)}" + ) + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) from e + else: + try: + return _target_(*args, **kwargs) + except Exception as e: + msg = f"Error in call to target '{_convert_target_to_string(_target_)}':\n{repr(e)}" + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) from e + + +def _convert_target_to_string(t: Any) -> Any: + if callable(t): + return f"{t.__module__}.{t.__qualname__}" + else: + return t + + +def _prepare_input_dict_or_list(d: Union[dict[Any, Any], list[Any]]) -> Any: + res: Any + if isinstance(d, dict): + res = {} + for k, v in d.items(): + if k == "_target_": + v = _convert_target_to_string(d["_target_"]) + elif isinstance(v, (dict, list)): + v = _prepare_input_dict_or_list(v) + res[k] = v + elif isinstance(d, list): + res = [] + for v in d: + if isinstance(v, (list, dict)): + v = _prepare_input_dict_or_list(v) + res.append(v) + else: + assert False + return res + + +def _resolve_target( + target: Union[str, type, Callable[..., Any]], full_key: str, check_callable: bool = True +) -> Union[type, Callable[..., Any], object]: + """Resolve target string, type or callable into type or callable.""" + if isinstance(target, str): + try: + target = _locate(target) + except Exception as e: + msg = f"Error locating target '{target}'." 
+ if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) from e + if check_callable and not callable(target): + msg = f"Expected a callable target, got '{target}' of type '{type(target).__name__}'" + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) + return target + + +def _extract_pos_args(input_args: Any, kwargs: Any) -> tuple[Any, Any]: + config_args = kwargs.pop(_Keys.ARGS, ()) + output_args = config_args + + if isinstance(config_args, Sequence): + if len(input_args) > 0: + output_args = input_args + else: + raise InstantiationException( + f"Unsupported _args_ type: '{type(config_args).__name__}'. value: '{config_args}'" + ) + + return output_args, kwargs + + +def _convert_node(node: Any) -> Any: + if OmegaConf.is_config(node): + node = OmegaConf.to_container(node, resolve=True) + + return node diff --git a/flagscale/train/megatron/nemo_bridge/utils/path_utils.py b/flagscale/train/megatron/nemo_bridge/utils/path_utils.py new file mode 100644 index 0000000000..0fe9c30ee8 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/utils/path_utils.py @@ -0,0 +1,10 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from pathlib import Path + + +def resolve_path(path: str) -> Path: + """Resolve a path to an absolute path.""" + return Path(path).expanduser().absolute().resolve() diff --git a/flagscale/train/megatron/nemo_bridge/utils/vocab_utils.py b/flagscale/train/megatron/nemo_bridge/utils/vocab_utils.py new file mode 100644 index 0000000000..85b68e1683 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/utils/vocab_utils.py @@ -0,0 +1,64 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import math + +from functools import lru_cache + +from megatron.nemo_bridge.utils.common_utils import print_rank_0 + + +def calculate_padded_vocab_size( + vocab_size: int, + make_vocab_size_divisible_by: int, + tensor_model_parallel_size: int, + logging_enabled: bool = True, +) -> int: + """Calculate padded vocab size for tensor parallelism. + + This function pads the vocabulary size to ensure it's divisible by the required + multiple for efficient tensor parallel operations. 
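As a concrete illustration of the padding rule, with numbers chosen to match the Qwen2 vocabulary used elsewhere in this patch:

    from megatron.nemo_bridge.utils.vocab_utils import calculate_padded_vocab_size

    # multiple = 128 * 8 = 1024; ceil(151936 / 1024) = 149, so the padded size
    # is 149 * 1024 = 152576, i.e. 640 dummy tokens are appended.
    padded = calculate_padded_vocab_size(
        vocab_size=151936,
        make_vocab_size_divisible_by=128,
        tensor_model_parallel_size=8,
        logging_enabled=False,
    )
    assert padded == 152576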
+ + Args: + vocab_size: The original (unpadded) vocabulary size + make_vocab_size_divisible_by: Base divisibility requirement (e.g., 128) + tensor_model_parallel_size: Number of tensor parallel ranks + logging_enabled: Whether to log the padding information + + Returns: + int: The padded vocabulary size + """ + padded_size = _calculate_padded_vocab_size_cached( + vocab_size, make_vocab_size_divisible_by, tensor_model_parallel_size + ) + + # Handle logging separately to avoid affecting cache behavior + if logging_enabled: + print_rank_0( + " > padded vocab (size: {}) with {} dummy tokens (new size: {})".format( + vocab_size, padded_size - vocab_size, padded_size + ) + ) + + return padded_size + + +@lru_cache(maxsize=128) +def _calculate_padded_vocab_size_cached( + vocab_size: int, make_vocab_size_divisible_by: int, tensor_model_parallel_size: int +) -> int: + """Cached computation of padded vocab size.""" + if vocab_size <= 0: + raise ValueError(f"vocab_size must be positive, got {vocab_size}") + if make_vocab_size_divisible_by <= 0: + raise ValueError( + f"make_vocab_size_divisible_by must be positive, got {make_vocab_size_divisible_by}" + ) + if tensor_model_parallel_size <= 0: + raise ValueError( + f"tensor_model_parallel_size must be positive, got {tensor_model_parallel_size}" + ) + + multiple = make_vocab_size_divisible_by * tensor_model_parallel_size + return int(math.ceil(vocab_size / multiple) * multiple) diff --git a/flagscale/train/megatron/nemo_bridge/utils/yaml_utils.py b/flagscale/train/megatron/nemo_bridge/utils/yaml_utils.py new file mode 100644 index 0000000000..f38553d6a7 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/utils/yaml_utils.py @@ -0,0 +1,203 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import enum +import functools +import inspect + +from contextlib import contextmanager +from typing import Any, Generator, Optional + +import yaml + + +@contextmanager +def safe_yaml_representers() -> Generator[None, None, None]: + """ + Context manager for safely adding and removing custom YAML representers. + + Temporarily adds custom representers for functions, classes, and other objects + to the YAML SafeDumper, and restores the original representers when exiting + the context. 
+ + Usage: + with safe_yaml_representers(): + yaml_str = yaml.safe_dump(my_complex_object) + """ + # Save original representers + original_representers = yaml.SafeDumper.yaml_representers.copy() + original_multi_representers = yaml.SafeDumper.yaml_multi_representers.copy() + + try: + # Register custom representers + + # Partial representer + yaml.SafeDumper.add_representer(functools.partial, _partial_representer) + + # Enum representer + yaml.SafeDumper.add_multi_representer(enum.Enum, _enum_representer) + + # Function representer + yaml.SafeDumper.add_representer(type(lambda: ...), _function_representer) + yaml.SafeDumper.add_representer(type(object), _function_representer) + + # Try to add torch dtype representer if available + try: + import torch + + yaml.SafeDumper.add_representer(torch.dtype, _torch_dtype_representer) + except ModuleNotFoundError: + pass + + # Try to add GenerationConfig representer if available + try: + from transformers import GenerationConfig + + yaml.SafeDumper.add_representer(GenerationConfig, _generation_config_representer) + except ModuleNotFoundError: + pass + + # Try to add PretrainedConfig representer if available (generic for HF configs) + try: + from transformers import PretrainedConfig + + # Use multi-representer so subclasses of PretrainedConfig are also handled + yaml.SafeDumper.add_multi_representer(PretrainedConfig, _pretrained_config_representer) + except ModuleNotFoundError: + pass + + # General object representer + yaml.SafeDumper.add_multi_representer(object, _safe_object_representer) + + yield + finally: + # Restore original representers + yaml.SafeDumper.yaml_representers = original_representers + yaml.SafeDumper.yaml_multi_representers = original_multi_representers + + +def dump_dataclass_to_yaml(obj: Any, filename: Optional[str] = None) -> Optional[str]: + """Dump a dataclass object or other Python object to a YAML file or string. + + Uses safe representers to handle common types. + + Args: + obj: The object to dump. + filename: If provided, the path to the file where YAML should be written. + If None, returns the YAML string directly. + + Returns: + If filename is None, returns the YAML string representation of the object. + Otherwise, returns None. + """ + with safe_yaml_representers(): + if filename is not None: + with open(filename, "w+") as f: + yaml.safe_dump(obj, f) + else: + return yaml.safe_dump(obj) + + +def _function_representer(dumper, data): + """Represent functions in YAML.""" + value = { + "_target_": f"{inspect.getmodule(data).__name__}.{data.__qualname__}", # type: ignore + "_call_": False, + } + return dumper.represent_data(value) + + +def _torch_dtype_representer(dumper, data): + """Represent torch dtypes in YAML.""" + value = {"_target_": str(data), "_call_": False} + return dumper.represent_data(value) + + +def _safe_object_representer(dumper, data): + """ + General object representer for YAML. + + This function is a fallback for objects that don't have specific representers. + If the object has __qualname__ attr, + the _target_ is set to f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}". + If the object does not have a __qualname__ attr, the _target_ is set from its __class__ attr. + The _call_ key is used to indicate whether the target should be called to create an instance. + + Args: + dumper (yaml.Dumper): The YAML dumper to use for serialization. + data (Any): The data to serialize. + + Returns: + The YAML representation of the data. 
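To make the "_target_" convention concrete, a sketch of what the dumper produces for an otherwise-unserializable value (the key ordering shown is yaml.safe_dump's default alphabetical sort):

    import yaml
    import torch

    from megatron.nemo_bridge.utils.yaml_utils import safe_yaml_representers

    # Without the context manager, yaml.safe_dump(torch.bfloat16) raises a
    # RepresenterError; with it, the dtype is dumped as a "_target_" mapping
    # that instantiate() can later resolve back to the original object.
    with safe_yaml_representers():
        print(yaml.safe_dump({"params_dtype": torch.bfloat16}))
    # params_dtype:
    #   _call_: false
    #   _target_: torch.bfloat16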
+ """ + try: + obj = data + target = f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}" + call = False + except AttributeError: + obj = data.__class__ + target = f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}" + call = True + + value = {"_target_": target, "_call_": call} # type: ignore + return dumper.represent_data(value) + + +def _partial_representer(dumper, data): + """Represent functools.partial objects in YAML.""" + # Get the underlying function + func = data.func + + # Create a dictionary representation + value = { + "_target_": f"{inspect.getmodule(func).__name__}.{func.__qualname__}", + "_partial_": True, + "_args_": list(data.args) if data.args else [], + } + + # Add keyword arguments if any exist + if data.keywords: + for k, v in data.keywords.items(): + value[k] = v + + return dumper.represent_data(value) + + +def _enum_representer(dumper, data): + """Represent enums in YAML.""" + # Create a dictionary representation + enum_class = data.__class__ + value = { + "_target_": f"{inspect.getmodule(enum_class).__name__}.{enum_class.__qualname__}", + "_call_": True, + "_args_": [data.value], + } + + return dumper.represent_data(value) + + +def _generation_config_representer(dumper, data): + """Represent transformers GenerationConfig objects in YAML.""" + cls = data.__class__ + value = { + "_target_": f"{inspect.getmodule(cls).__name__}.{cls.__qualname__}.from_dict", + "_call_": True, + "config_dict": data.to_dict(), + } + + return dumper.represent_data(value) + + +def _pretrained_config_representer(dumper, data): + """Represent transformers PretrainedConfig objects in YAML generically. + + Uses the class's from_dict/to_dict methods to ensure full round-trip of all fields. + """ + cls = data.__class__ + value = { + "_target_": f"{inspect.getmodule(cls).__name__}.{cls.__qualname__}.from_dict", + "_call_": True, + "config_dict": data.to_dict(), + } + return dumper.represent_data(value) diff --git a/flagscale/train/megatron/training/arguments.py b/flagscale/train/megatron/training/arguments.py index 9d4ffc2131..f29ede5fe5 100644 --- a/flagscale/train/megatron/training/arguments.py +++ b/flagscale/train/megatron/training/arguments.py @@ -889,6 +889,17 @@ def validate_args(args, defaults={}): if args.save_retain_interval is not None: assert args.save_retain_interval > 0 assert args.save_retain_interval % args.save_interval == 0 + + if args.save_hf is not None: + assert args.save is not None + assert args.save_interval is not None + assert args.save_interval > 0 + assert args.save_hf_interval is not None + assert args.save_hf_interval > 0 + assert args.save_hf_interval > args.save_interval and args.save_hf_interval % args.save_interval == 0, \ + "save_hf_interval must be greater than save_interval and be an integer multiple of it" + assert args.hf_config_path is not None + # Mixed precision checks. if args.fp16_lm_cross_entropy: assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.' @@ -2573,6 +2584,15 @@ def _add_checkpointing_args(parser): ' rank for saving. Turn on only if experiencing host or device memory' ' issues. 
Has affect only with `--dist-ckpt-optim-fully-reshardable`' ' flag.') + group.add_argument('--load-hf', action='store_true',default=None, + help='Use the HF format for warm start, and save it in the torch_dict' + 'format while also saving it in the HF format.') + group.add_argument('--save-hf', action='store_true',default=None, + help='Save as Hugging Face format checkpoint.') + group.add_argument('--hf-config-path', default=None, + help='Load the HF model from config.') + group.add_argument('--save-hf-interval', type=int, default=None, + help='Number of iterations between hf checkpoint saves.') return parser diff --git a/flagscale/train/megatron/training/checkpointing.py b/flagscale/train/megatron/training/checkpointing.py index 9489b47632..c14032dfcc 100644 --- a/flagscale/train/megatron/training/checkpointing.py +++ b/flagscale/train/megatron/training/checkpointing.py @@ -471,6 +471,24 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati # Save dataloader state if the dataloader supports it (currently only Megatron Energon). maybe_save_dataloader_state(train_data_iterator, iteration, getattr(args, "dataloader_save", None)) + # save hf format model weight + hf_checkpoint_name = get_checkpoint_name(save_dir, iteration, release=False, pipeline_parallel=pipeline_parallel, + tensor_rank=tensor_rank, pipeline_rank=pipeline_rank, expert_parallel=expert_parallel, expert_rank=expert_rank, return_base_dir=True) + if args.save_hf and hasattr(args,'hf_config_path') and args.save_hf_interval : + assert args.hf_config_path is not None, "hf_config_path should not be None" + if iteration % args.save_hf_interval == 0 or iteration == args.train_iters: + #use megatron bridge + from megatron.nemo_bridge.models import AutoBridge + from megatron.nemo_bridge.models.hf_pretrained.safe_config_loader import safe_load_config_with_retry + from transformers import AutoConfig + #Load the HF model from config + config_load = args.hf_config_path + config = safe_load_config_with_retry(config_load, trust_remote_code=False) + bridge = AutoBridge.from_hf_config(config) + #Save the HF model weights in the corresponding iteration's safetensor folder. + safe_save = os.path.join(hf_checkpoint_name, 'safetensor') + bridge.save_hf_pretrained(model=model,path=safe_save) + # Save distributed optimizer's custom parameter state. 
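The save and load hooks added in this file reduce to a small amount of bridge plumbing. A minimal standalone sketch of the same two paths is given here for reference; the helper names (`export_hf_weights`, `warm_start_from_hf`) and all paths are illustrative placeholders, `megatron_model` / `ddp_model` are assumed to be the usual lists of Megatron module instances, and the import locations follow the hunks in this patch (the second commit moves `safe_config_loader` into the pip-installed `megatron.bridge` package, so the import path may differ there).

import os

from megatron.nemo_bridge.models import AutoBridge
from megatron.nemo_bridge.models.hf_pretrained.safe_config_loader import safe_load_config_with_retry


def export_hf_weights(megatron_model, hf_config_path, iteration_dir):
    # --save-hf path: build a bridge from the HF config alone (no HF weights on
    # disk are needed) and write safetensors next to the Megatron checkpoint.
    config = safe_load_config_with_retry(hf_config_path, trust_remote_code=False)
    bridge = AutoBridge.from_hf_config(config)
    bridge.save_hf_pretrained(model=megatron_model, path=os.path.join(iteration_dir, "safetensor"))


def warm_start_from_hf(ddp_model, hf_checkpoint_dir):
    # --load-hf path: copy HF weights into the Megatron modules; optimizer state
    # is not restored, so training restarts from iteration 0.
    bridge = AutoBridge.from_hf_pretrained(hf_checkpoint_dir)
    bridge.load_hf_weights(ddp_model)
    return 0, 0  # iteration, num_floating_point_operations_so_far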
if ( args.use_distributed_optimizer @@ -1408,6 +1426,17 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', args = get_args() load_dir = getattr(args, load_arg) + # load hf format + if args.load_hf: + # use megatron bridge + from megatron.nemo_bridge.models import AutoBridge + bridge=AutoBridge.from_hf_pretrained(load_dir) + bridge.load_hf_weights(ddp_model) + # no optimizer weight + iteration=0 + num_floating_point_operations_so_far=0 + return iteration, num_floating_point_operations_so_far + # Finetuning directories pretrained_dir = getattr(args, 'pretrained_checkpoint', None) if pretrained_dir is not None and not checkpoint_exists(load_dir): diff --git a/flagscale/train/megatron/training/yaml_arguments.py b/flagscale/train/megatron/training/yaml_arguments.py index 405d7b70fa..c8ad21e255 100644 --- a/flagscale/train/megatron/training/yaml_arguments.py +++ b/flagscale/train/megatron/training/yaml_arguments.py @@ -409,7 +409,9 @@ def core_transformer_config_from_yaml(args, transfomer_key = "language_model"): # Hardcoded kw_args['deallocate_pipeline_outputs'] = True kw_args['pipeline_dtype'] = kw_args['params_dtype'] - kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm + kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm + + kw_args['untie_embeddings_and_output_weights'] = args.untie_embeddings_and_output_weights assert args.activation_func in ["swiglu","squaredrelu","gelu"], f"{args.activation_func} is not a supported activation function" if args.activation_func == "swiglu": From 90c2ed0128a2275d575e4f84788fd33af2d449d6 Mon Sep 17 00:00:00 2001 From: chai-xiaonan <3072824838@qq.com> Date: Thu, 15 Jan 2026 17:09:55 +0800 Subject: [PATCH 2/3] Reconstruct the code, and some functions use pip megatron-bridge --- .../train/megatron/nemo_bridge/README.md | 9 + .../train/megatron/nemo_bridge/__init__.py | 3 +- .../megatron/nemo_bridge/models/__init__.py | 90 +- .../megatron/nemo_bridge/models/config.py | 340 ---- .../nemo_bridge/models/conversion/__init__.py | 25 - .../models/conversion/auto_bridge.py | 556 +----- .../models/conversion/mapping_registry.py | 266 --- .../models/conversion/model_bridge.py | 787 +------- .../models/conversion/param_mapping.py | 1706 +---------------- .../nemo_bridge/models/conversion/utils.py | 287 --- .../nemo_bridge/models/decorators/__init__.py | 9 - .../nemo_bridge/models/decorators/dispatch.py | 348 ---- .../nemo_bridge/models/decorators/torchrun.py | 42 - .../nemo_bridge/models/deepseek/__init__.py | 27 - .../nemo_bridge/models/deepseek/common.py | 2 +- .../models/deepseek/deepseek_provider.py | 309 --- .../models/deepseek/deepseek_v2_bridge.py | 48 - .../models/deepseek/deepseek_v3_bridge.py | 6 +- .../models/gpt_full_te_layer_autocast_spec.py | 347 ---- .../nemo_bridge/models/gpt_provider.py | 430 ----- .../models/hf_pretrained/__init__.py | 5 +- .../nemo_bridge/models/hf_pretrained/base.py | 237 --- .../models/hf_pretrained/causal_lm.py | 662 +------ .../hf_pretrained/safe_config_loader.py | 136 -- .../nemo_bridge/models/hf_pretrained/state.py | 850 -------- .../nemo_bridge/models/hf_pretrained/vlm.py | 603 ------ .../nemo_bridge/models/model_provider.py | 710 ------- .../nemo_bridge/models/qwen/__init__.py | 52 - .../nemo_bridge/models/qwen/qwen2_bridge.py | 110 -- .../nemo_bridge/models/qwen/qwen3_bridge.py | 8 +- .../models/qwen/qwen3_moe_bridge.py | 113 -- .../nemo_bridge/models/qwen/qwen_provider.py | 393 ---- .../nemo_bridge/models/transformer_config.py | 96 - .../megatron/nemo_bridge/utils/__init__.py | 3 - 
.../nemo_bridge/utils/common_utils.py | 147 -- .../megatron/nemo_bridge/utils/decorators.py | 26 - .../megatron/nemo_bridge/utils/fusions.py | 175 -- .../nemo_bridge/utils/import_utils.py | 409 ---- .../nemo_bridge/utils/instantiate_utils.py | 418 ---- .../megatron/nemo_bridge/utils/path_utils.py | 10 - .../megatron/nemo_bridge/utils/vocab_utils.py | 64 - .../megatron/nemo_bridge/utils/yaml_utils.py | 203 -- .../train/megatron/training/checkpointing.py | 2 +- .../train/megatron/training/yaml_arguments.py | 2 - 44 files changed, 225 insertions(+), 10846 deletions(-) create mode 100644 flagscale/train/megatron/nemo_bridge/README.md delete mode 100644 flagscale/train/megatron/nemo_bridge/models/config.py delete mode 100644 flagscale/train/megatron/nemo_bridge/models/conversion/mapping_registry.py delete mode 100644 flagscale/train/megatron/nemo_bridge/models/conversion/utils.py delete mode 100644 flagscale/train/megatron/nemo_bridge/models/decorators/__init__.py delete mode 100644 flagscale/train/megatron/nemo_bridge/models/decorators/dispatch.py delete mode 100644 flagscale/train/megatron/nemo_bridge/models/decorators/torchrun.py delete mode 100644 flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_provider.py delete mode 100644 flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v2_bridge.py delete mode 100644 flagscale/train/megatron/nemo_bridge/models/gpt_full_te_layer_autocast_spec.py delete mode 100644 flagscale/train/megatron/nemo_bridge/models/gpt_provider.py delete mode 100644 flagscale/train/megatron/nemo_bridge/models/hf_pretrained/base.py delete mode 100644 flagscale/train/megatron/nemo_bridge/models/hf_pretrained/safe_config_loader.py delete mode 100644 flagscale/train/megatron/nemo_bridge/models/hf_pretrained/state.py delete mode 100644 flagscale/train/megatron/nemo_bridge/models/hf_pretrained/vlm.py delete mode 100644 flagscale/train/megatron/nemo_bridge/models/model_provider.py delete mode 100644 flagscale/train/megatron/nemo_bridge/models/qwen/qwen2_bridge.py delete mode 100755 flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_moe_bridge.py delete mode 100644 flagscale/train/megatron/nemo_bridge/models/qwen/qwen_provider.py delete mode 100644 flagscale/train/megatron/nemo_bridge/models/transformer_config.py delete mode 100644 flagscale/train/megatron/nemo_bridge/utils/__init__.py delete mode 100644 flagscale/train/megatron/nemo_bridge/utils/common_utils.py delete mode 100644 flagscale/train/megatron/nemo_bridge/utils/decorators.py delete mode 100644 flagscale/train/megatron/nemo_bridge/utils/fusions.py delete mode 100644 flagscale/train/megatron/nemo_bridge/utils/import_utils.py delete mode 100644 flagscale/train/megatron/nemo_bridge/utils/instantiate_utils.py delete mode 100644 flagscale/train/megatron/nemo_bridge/utils/path_utils.py delete mode 100644 flagscale/train/megatron/nemo_bridge/utils/vocab_utils.py delete mode 100644 flagscale/train/megatron/nemo_bridge/utils/yaml_utils.py diff --git a/flagscale/train/megatron/nemo_bridge/README.md b/flagscale/train/megatron/nemo_bridge/README.md new file mode 100644 index 0000000000..d2a66b98dd --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/README.md @@ -0,0 +1,9 @@ +#Before using this function, you need to install megatron-bridge +git clone https://github.com/NVIDIA-NeMo/Megatron-Bridge.git +cd Megatron-Bridge +pip install --no-build-isolation megatron-bridge + +#You must install Megatron-Bridge first, and then install the Megatron-LM-FL version of Megatron-Core. 
+git clone https://github.com/flagos-ai/Megatron-LM-FL.git +cd Megatron-LM-FL +pip install --no-build-isolation . diff --git a/flagscale/train/megatron/nemo_bridge/__init__.py b/flagscale/train/megatron/nemo_bridge/__init__.py index 95db0a4890..713df8c977 100644 --- a/flagscale/train/megatron/nemo_bridge/__init__.py +++ b/flagscale/train/megatron/nemo_bridge/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) 2025, BAAI. All rights reserved. -# -# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + """Megatron Bridge - A component of the Megatron ecosystem.""" from megatron.nemo_bridge.models.conversion.auto_bridge import AutoBridge diff --git a/flagscale/train/megatron/nemo_bridge/models/__init__.py b/flagscale/train/megatron/nemo_bridge/models/__init__.py index 14986c4bef..3d2aa52d7e 100644 --- a/flagscale/train/megatron/nemo_bridge/models/__init__.py +++ b/flagscale/train/megatron/nemo_bridge/models/__init__.py @@ -1,99 +1,21 @@ # Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge from megatron.nemo_bridge.models.conversion.auto_bridge import AutoBridge -from megatron.nemo_bridge.models.conversion.mapping_registry import MegatronMappingRegistry from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge from megatron.nemo_bridge.models.conversion.param_mapping import ( AutoMapping, - ColumnParallelMapping, - GatedMLPMapping, - MegatronParamMapping, QKVMapping, - ReplicatedMapping, - RowParallelMapping, -) -from megatron.nemo_bridge.models.deepseek import ( - DeepSeekModelProvider, - DeepSeekProvider, - DeepSeekV2LiteModelProvider, - DeepSeekV2LiteProvider, - DeepSeekV2ModelProvider, - DeepSeekV2Provider, - DeepSeekV3ModelProvider, - DeepSeekV3Provider, - MoonlightModelProvider16B, - MoonlightProvider, -) -from megatron.nemo_bridge.models.gpt_provider import GPTModelProvider -from megatron.nemo_bridge.models.qwen import ( - Qwen2ModelProvider, - Qwen2ModelProvider1P5B, - Qwen2ModelProvider7B, - Qwen2ModelProvider72B, - Qwen2ModelProvider500M, - Qwen3ModelProvider, - Qwen3ModelProvider1P7B, - Qwen3ModelProvider4B, - Qwen3ModelProvider8B, - Qwen3ModelProvider14B, - Qwen3ModelProvider32B, - Qwen3ModelProvider600M, - Qwen3MoEModelProvider, - Qwen3MoEModelProvider30B_A3B, - Qwen3MoEModelProvider235B_A22B, - Qwen25ModelProvider1P5B, - Qwen25ModelProvider3B, - Qwen25ModelProvider7B, - Qwen25ModelProvider14B, - Qwen25ModelProvider32B, - Qwen25ModelProvider72B, - Qwen25ModelProvider500M, ) +from megatron.nemo_bridge.models.deepseek.deepseek_v3_bridge import DeepSeekV3Bridge +from megatron.nemo_bridge.models.qwen.qwen3_bridge import Qwen3Bridge +from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM __all__ = [ "AutoBridge", - "MegatronMappingRegistry", "MegatronModelBridge", - "ColumnParallelMapping", - "GatedMLPMapping", - "MegatronParamMapping", "QKVMapping", - "ReplicatedMapping", - "RowParallelMapping", "AutoMapping", - "GPTModelProvider", - "Qwen2ModelProvider", - "Qwen2ModelProvider500M", - "Qwen2ModelProvider1P5B", - "Qwen2ModelProvider7B", - "Qwen2ModelProvider72B", - "Qwen25ModelProvider500M", - "Qwen25ModelProvider1P5B", - "Qwen25ModelProvider3B", - "Qwen25ModelProvider7B", - "Qwen25ModelProvider14B", - "Qwen25ModelProvider32B", - "Qwen25ModelProvider72B", - "Qwen3ModelProvider", - "Qwen3ModelProvider600M", - "Qwen3ModelProvider1P7B", - "Qwen3ModelProvider4B", - "Qwen3ModelProvider8B", - "Qwen3ModelProvider14B", - "Qwen3ModelProvider32B", - "Qwen3MoEModelProvider", 
- "Qwen3MoEModelProvider30B_A3B", - "Qwen3MoEModelProvider235B_A22B", - "DeepSeekModelProvider", - "DeepSeekProvider", - "DeepSeekV2LiteModelProvider", - "DeepSeekV2LiteProvider", - "DeepSeekV2ModelProvider", - "DeepSeekV2Provider", - "DeepSeekV3ModelProvider", - "DeepSeekV3Provider", - "MoonlightModelProvider16B", - "MoonlightProvider", + "DeepSeekV3Bridge", + "Qwen3Bridge", + "PreTrainedCausalLM", ] diff --git a/flagscale/train/megatron/nemo_bridge/models/config.py b/flagscale/train/megatron/nemo_bridge/models/config.py deleted file mode 100644 index 6e421ee5b6..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/config.py +++ /dev/null @@ -1,340 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import json - -from dataclasses import fields as dataclass_fields, is_dataclass -from pathlib import Path -from typing import Any, Dict, Literal, Optional, Protocol, Type, TypeVar, Union, runtime_checkable - -import yaml - -from omegaconf import OmegaConf - -from megatron.nemo_bridge.utils.instantiate_utils import InstantiationMode, instantiate -from megatron.nemo_bridge.utils.yaml_utils import safe_yaml_representers - -# For TOML support -try: - import toml - - HAS_TOML = True -except ImportError: - HAS_TOML = False - - -T = TypeVar("T") -ConfigFormat = Literal["yaml", "json", "toml"] - - -@runtime_checkable -class ConfigProtocol(Protocol): - """Protocol defining the configuration interface for model providers.""" - - @classmethod - def from_hf_pretrained( - cls: Type[T], - pretrained_model_name_or_path: Union[str, Path], - trust_remote_code: bool = False, - mode: InstantiationMode = InstantiationMode.LENIENT, - **kwargs, - ) -> T: - """Load a pretrained model configuration from a directory or file.""" - ... - - def save_hf_pretrained( - self, - save_directory: Union[str, Path], - config_format: ConfigFormat | None = None, - config_name: Optional[str] = None, - **kwargs, - ) -> None: - """Save the model configuration to a directory.""" - ... - - -def from_hf_pretrained( - cls: Type[T], - pretrained_model_name_or_path: Union[str, Path], - trust_remote_code: bool = False, - mode: InstantiationMode = InstantiationMode.LENIENT, - config_name: str = "config", - **kwargs, -) -> T: - """ - Load a pretrained model configuration from a directory or file. - - Args: - cls: The class to instantiate - pretrained_model_name_or_path: Path to a directory containing a config file, - or direct path to a config file (yaml/json/toml) - trust_remote_code: Whether to trust and execute code references (classes/functions) - found in the configuration. Required to be True if the config - contains any class or function references. 
Default: False - mode: Instantiation mode (STRICT or LENIENT) for the instantiate function - config_name: Base name of the config file (without extension) - **kwargs: Additional keyword arguments to override loaded configuration - - Returns: - Instance of the class with loaded configuration - - Example: - ```python - # Load from directory (looks for config.yaml, config.json, or config.toml) - model = from_hf_pretrained(MyModel, "./saved_model/") - - # Load from specific file - model = from_hf_pretrained(MyModel, "./saved_model/config.yaml") - - # With code references - model = from_pretrained(MyModel, "./saved_model/", trust_remote_code=True) - - # Override configuration values - model = from_pretrained(MyModel, "./saved_model/", temperature=0.8) - ``` - """ - path = Path(pretrained_model_name_or_path) - - # Determine the config file path - if path.is_dir(): - # Look for config files in order of preference - config_file = None - for ext in [".yaml", ".yml", ".json", ".toml"]: - candidate = path / f"{config_name}{ext}" - if candidate.exists(): - config_file = candidate - break - - if config_file is None: - raise FileNotFoundError( - f"No configuration file found in {path}. " - f"Expected {config_name}.yaml, {config_name}.json, or {config_name}.toml" - ) - else: - config_file = path - - if not config_file.exists(): - raise FileNotFoundError(f"Configuration file not found at {config_file}") - - # Load the configuration based on file extension - file_ext = config_file.suffix.lower() - - if file_ext in [".yaml", ".yml"]: - with open(config_file, "r", encoding="utf-8") as f: - config_dict = yaml.safe_load(f) - elif file_ext == ".json": - with open(config_file, "r", encoding="utf-8") as f: - config_dict = json.load(f) - elif file_ext == ".toml": - if not HAS_TOML: - raise ImportError( - "TOML support requires the 'toml' package. Install it with: pip install toml" - ) - with open(config_file, "r", encoding="utf-8") as f: - config_dict = toml.load(f) - else: - raise ValueError( - f"Unsupported file format: {file_ext}. Supported formats: .yaml, .yml, .json, .toml" - ) - - # Check for trust_remote_code requirement - if not trust_remote_code and _contains_code_references(config_dict): - raise ValueError( - "This configuration contains class or function references. " - "Loading it requires trust_remote_code=True to prevent arbitrary code execution." - ) - - # Convert to OmegaConf for compatibility with instantiate - omega_conf = OmegaConf.create(config_dict) - - # Merge with kwargs - if kwargs: - override_conf = OmegaConf.create(kwargs) - omega_conf = OmegaConf.merge(omega_conf, override_conf) - - # Add _target_ if not present - if "_target_" not in omega_conf: - omega_conf["_target_"] = f"{cls.__module__}.{cls.__qualname__}" - - # Convert back to container for instantiate - final_config = OmegaConf.to_container(omega_conf, resolve=True) - - # Use instantiate to create the object - return instantiate(final_config, mode=mode) - - -def save_hf_pretrained( - obj: Any, - save_directory: Union[str, Path], - config_format: ConfigFormat = "json", - config_name: str = "config", - **kwargs, -) -> None: - """ - Save the model configuration to a directory. - - Args: - obj: The object to save - save_directory: Directory where to save the configuration - config_format: Format to save in ("yaml", "json", or "toml"). 
Default: "json" - config_name: Name for the config file (without extension) - **kwargs: Additional metadata to save alongside the configuration - - Example: - ```python - # Save as JSON (default) - save_hf_pretrained(model, "./saved_model/") - - # Save as YAML - save_hf_pretrained(model, "./saved_model/", config_format="yaml") - - # Save with custom name - save_hf_pretrained(model, "./saved_model/", config_name="my_config") - ``` - """ - save_path = Path(save_directory) - save_path.mkdir(parents=True, exist_ok=True) - - # Determine file extension - format_to_ext = {"yaml": ".yaml", "yml": ".yaml", "json": ".json", "toml": ".toml"} - - config_format = config_format.lower() - if config_format not in format_to_ext: - raise ValueError( - f"Unsupported format: {config_format}. Supported formats: {list(format_to_ext.keys())}" - ) - - if config_format == "toml" and not HAS_TOML: - raise ImportError( - "TOML support requires the 'toml' package. Install it with: pip install toml" - ) - - config_file = save_path / f"{config_name}{format_to_ext[config_format]}" - - # Get the configuration dictionary - config_dict = _to_dict(obj) - - # Add any additional metadata - if kwargs: - config_dict.update(kwargs) - - # Save based on format - if config_format in ["yaml", "yml"]: - with safe_yaml_representers(): - with open(config_file, "w", encoding="utf-8") as f: - yaml.safe_dump(config_dict, f, default_flow_style=False, sort_keys=False) - elif config_format == "json": - # First convert to YAML string to use the custom representers - with safe_yaml_representers(): - yaml_str = yaml.safe_dump(config_dict, default_flow_style=False) - # Then parse and save as JSON - yaml_dict = yaml.safe_load(yaml_str) - with open(config_file, "w", encoding="utf-8") as f: - json.dump(yaml_dict, f, indent=2, ensure_ascii=False) - elif config_format == "toml": - # First convert to YAML string to use the custom representers - with safe_yaml_representers(): - yaml_str = yaml.safe_dump(config_dict, default_flow_style=False) - # Then parse and save as TOML - yaml_dict = yaml.safe_load(yaml_str) - with open(config_file, "w", encoding="utf-8") as f: - toml.dump(yaml_dict, f) - - print(f"Configuration saved to {config_file}") - - -def _to_dict(obj: Any) -> Dict[str, Any]: - """ - Convert an object to a dictionary representation. - - Args: - obj: The object to convert - - Returns: - Dictionary representation of the object - """ - # Check if this is a ConfigContainer (has to_dict method) - if hasattr(obj, "to_dict") and callable(obj.to_dict): - return obj.to_dict() - - # Otherwise, build dict from dataclass fields or attributes - result = {} - result["_target_"] = f"{obj.__class__.__module__}.{obj.__class__.__qualname__}" - - if is_dataclass(obj): - # Handle dataclass - for field in dataclass_fields(obj): - if field.name.startswith("_"): - continue - value = getattr(obj, field.name) - result[field.name] = _convert_value_to_dict(value) - else: - # Handle regular class - for key, value in obj.__dict__.items(): - if not key.startswith("_"): - result[key] = _convert_value_to_dict(value) - - return result - - -def _convert_value_to_dict(value: Any) -> Any: - """ - Recursively convert a value to a dictionary representation. 
- - Args: - value: The value to convert - - Returns: - The converted value - """ - if hasattr(value, "_to_dict"): - return value._to_dict() - elif hasattr(value, "to_dict") and callable(value.to_dict): - return value.to_dict() - elif is_dataclass(value) and not isinstance(value, type): - # Handle regular dataclasses - result = {"_target_": f"{value.__class__.__module__}.{value.__class__.__qualname__}"} - for field in dataclass_fields(value): - if not field.name.startswith("_"): - result[field.name] = _convert_value_to_dict(getattr(value, field.name)) - return result - elif isinstance(value, (list, tuple)): - return [_convert_value_to_dict(item) for item in value] - elif isinstance(value, dict): - return {k: _convert_value_to_dict(v) for k, v in value.items()} - else: - return value - - -def _contains_code_references(config_dict: Dict[str, Any]) -> bool: - """ - Check if a configuration dictionary contains code references. - - Args: - config_dict: The configuration dictionary to check - - Returns: - True if code references are found, False otherwise - """ - if isinstance(config_dict, dict): - for key, value in config_dict.items(): - # Check for _target_ that's not a built-in type - if key == "_target_" and isinstance(value, str): - # Consider it a code reference if it's not a basic type - if not value.startswith( - ("builtins.", "str", "int", "float", "bool", "list", "dict", "tuple") - ): - return True - # Check for _call_ = False which indicates a code reference - if key == "_call_" and value is False: - return True - # Recursively check nested structures - if _contains_code_references(value): - return True - elif isinstance(config_dict, (list, tuple)): - for item in config_dict: - if _contains_code_references(item): - return True - - return False diff --git a/flagscale/train/megatron/nemo_bridge/models/conversion/__init__.py b/flagscale/train/megatron/nemo_bridge/models/conversion/__init__.py index e7a20e1f97..b8a5672155 100644 --- a/flagscale/train/megatron/nemo_bridge/models/conversion/__init__.py +++ b/flagscale/train/megatron/nemo_bridge/models/conversion/__init__.py @@ -1,32 +1,7 @@ # Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge -# Import model providers for easy access from megatron.nemo_bridge.models.conversion.auto_bridge import AutoBridge -from megatron.nemo_bridge.models.conversion.mapping_registry import MegatronMappingRegistry -from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge -from megatron.nemo_bridge.models.conversion.param_mapping import ( - AutoMapping, - ColumnParallelMapping, - GatedMLPMapping, - MegatronParamMapping, - QKVMapping, - ReplicatedMapping, - RowParallelMapping, -) -from megatron.nemo_bridge.models.conversion.utils import weights_verification_table __all__ = [ "AutoBridge", - "MegatronMappingRegistry", - "MegatronModelBridge", - "ColumnParallelMapping", - "GatedMLPMapping", - "MegatronParamMapping", - "QKVMapping", - "ReplicatedMapping", - "RowParallelMapping", - "AutoMapping", - "weights_verification_table", ] diff --git a/flagscale/train/megatron/nemo_bridge/models/conversion/auto_bridge.py b/flagscale/train/megatron/nemo_bridge/models/conversion/auto_bridge.py index 88a9a7f9b9..8112c58950 100644 --- a/flagscale/train/megatron/nemo_bridge/models/conversion/auto_bridge.py +++ b/flagscale/train/megatron/nemo_bridge/models/conversion/auto_bridge.py @@ -1,247 +1,41 @@ # Copyright (c) 2025, BAAI. All rights reserved. 
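The rewritten auto_bridge.py that follows keeps only a thin subclass over the AutoBridge shipped in the pip-installed megatron.bridge package (see the README above). A rough end-to-end usage sketch, assuming Megatron-Bridge is installed and that `megatron_model` is a list of Megatron module instances; the model id and export directory below are illustrative placeholders, not part of this patch.

from pathlib import Path

from megatron.nemo_bridge.models import AutoBridge


def hf_round_trip(megatron_model: list, hf_model_id: str, export_dir: str) -> None:
    # HF -> Megatron: detect the architecture, then copy HF weights into the
    # already-constructed Megatron modules (one entry per virtual pipeline stage).
    bridge = AutoBridge.from_hf_pretrained(hf_model_id)
    bridge.load_hf_weights(megatron_model)

    # Megatron -> HF: stream the weights back out as sharded safetensors.
    bridge.save_hf_weights(megatron_model, Path(export_dir), show_progress=True)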
-# -# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge -import dataclasses - -from functools import cached_property, partial -from pathlib import Path -from typing import Any, Generic, Iterable, List, Optional, Type, TypeVar, Union - -import torch.distributed as dist +from megatron.bridge import AutoBridge as OriginalAutoBridge import transformers - +import torch.distributed as dist from transformers import AutoModelForCausalLM from transformers.configuration_utils import PretrainedConfig -from typing_extensions import Unpack from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.transformer_config import MLATransformerConfig, TransformerConfig - from megatron.nemo_bridge.models.conversion import model_bridge -from megatron.nemo_bridge.models.conversion.model_bridge import ( - HFWeightTuple, - MegatronModelBridge, - WeightConversionTask, -) -from megatron.nemo_bridge.models.conversion.utils import get_causal_lm_class_via_auto_map +from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge -# from megatron.nemo_bridge.models.gpt_provider import GPTModelProvider from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM -from megatron.nemo_bridge.models.hf_pretrained.safe_config_loader import ( - safe_load_config_with_retry, -) -from megatron.nemo_bridge.models.hf_pretrained.state import SafeTensorsStateSource - -# from megatron.nemo_bridge.models.model_provider import GetModelKwargs, ModelParallelKwargs, ModelProviderMixin - +from megatron.bridge.models.hf_pretrained.state import SafeTensorsStateSource +from megatron.bridge.models.hf_pretrained.safe_config_loader import safe_load_config_with_retry +from megatron.bridge.models.conversion.utils import get_causal_lm_class_via_auto_map +from typing import TypeVar, Union +from pathlib import Path MegatronModelT = TypeVar("MegatronModelT", bound=MegatronModule) -DataclassT = TypeVar("DataclassT") - - -class AutoBridge(Generic[MegatronModelT]): - """ - Automatically select and instantiate the appropriate bridge for a model. - - This unified bridge class combines automatic model detection with full bridge - functionality for converting models between HuggingFace and Megatron formats. - It handles the conversion of causal language models (e.g., GPT, Llama, Phi) - between HuggingFace's transformers library format and Megatron-Core's distributed - training format. It manages weight mapping, tensor parallelism distribution, and - configuration translation. - - The bridge supports both directions of conversion: - - HuggingFace → Megatron: For training or inference with Megatron - - Megatron → HuggingFace: For saving trained models in HF format - - Args: - hf_pretrained: Either a PreTrainedCausalLM instance with loaded model, - or a PretrainedConfig for configuration-only operations - - Example: - >>> # Load and convert a model to Megatron format - >>> bridge = AutoBridge.from_hf_pretrained("meta-llama/Meta-Llama-3-8B") - >>> provider = bridge.to_megatron_provider() - >>> megatron_model = provider.provide_distributed_model(wrap_with_ddp=False) - - >>> # Export a Megatron model back to HuggingFace format - >>> bridge.save_hf_pretrained(megatron_model, "./exported_model") - >>> # Convert weights with custom settings - >>> for name, weight in bridge.export_hf_weights( - ... megatron_model, - ... cpu=True - ... ): - ... 
print(f"Exported {name}: {weight.shape}") - >>> # Check if a model is supported before loading - >>> if AutoBridge.can_handle("microsoft/phi-2"): - ... bridge = AutoBridge.from_hf_pretrained("microsoft/phi-2") - - Note: - The bridge automatically detects the model architecture and applies - the appropriate weight mappings. Custom architectures require implementing - a MegatronModelBridge subclass. - """ +class AutoBridge(OriginalAutoBridge): def __init__(self, hf_pretrained: PreTrainedCausalLM | PretrainedConfig): if not isinstance(hf_pretrained, (PreTrainedCausalLM, PretrainedConfig)): - raise ValueError( - "hf_pretrained must be a PreTrainedCausalLM or PretrainedConfig instance" - ) + raise ValueError("hf_pretrained must be a PreTrainedCausalLM or PretrainedConfig instance") self.hf_pretrained: PreTrainedCausalLM | PretrainedConfig = hf_pretrained - - @classmethod - def list_supported_models(cls) -> list[str]: - """ - List all model architectures currently supported by the bridge system. - - Returns: - List of supported HuggingFace model architecture names - """ - # Get all registered implementations from the dispatch system - supported = [] - - # Access the dispatch registry to find all registered types - - if hasattr(model_bridge.get_model_bridge, "_exact_types"): - for arch_type in model_bridge.get_model_bridge._exact_types.keys(): - # Support both type and string registrations - if isinstance(arch_type, str): - supported.append(arch_type) - elif hasattr(arch_type, "__name__"): - supported.append(arch_type.__name__) - - return sorted(supported) - - @classmethod - def supports(cls, config: Any) -> bool: - """ - Check if this bridge supports the given model configuration. - - A model is supported if it has at least one architecture ending with 'ForCausalLM' or 'ForConditionalGeneration' - or 'NemotronH_Nano_VL_V2'. - - Args: - config: HuggingFace model config object - - Returns: - True if this bridge can handle the model, False otherwise - """ - architectures = getattr(config, "architectures", []) - if not architectures: - return False - return any( - arch.endswith(("ForCausalLM", "ForConditionalGeneration", "NemotronH_Nano_VL_V2")) - for arch in architectures - ) - - @classmethod - def from_hf_config(cls, config: PretrainedConfig) -> "AutoBridge": - """ - Create an AutoBridge from a HuggingFace configuration. - - This method creates a bridge instance from just a model configuration, - without loading any weights. 
This is useful for: - - Creating Megatron models with random initialization - - Working with model architectures without downloading weights - - Testing and development scenarios - - Args: - config: HuggingFace PretrainedConfig instance containing model - architecture information - - Returns: - AutoBridge: Bridge instance configured for the architecture - - Raises: - ValueError: If the configuration is not for a supported CausalLM model - - Example: - >>> from transformers import AutoConfig - >>> - >>> # Load just the configuration - >>> config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B") - >>> - >>> # Create bridge from config (no weights) - >>> bridge = AutoBridge.from_hf_config(config) - >>> - >>> # Create Megatron model with random initialization - >>> provider = bridge.to_megatron_provider(load_weights=False) - >>> model = provider.provide_distributed_model(wrap_with_ddp=False) - - >>> # Or use for architecture exploration - >>> transformer_config = bridge.transformer_config - >>> print(f"Hidden size: {transformer_config.hidden_size}") - >>> print(f"Num layers: {transformer_config.num_layers}") - - See Also: - from_hf_pretrained: Create bridge with loaded weights - transformer_config: Access the Megatron TransformerConfig - """ - cls._validate_config(config) - model = PreTrainedCausalLM() - model.config = config - import torch - - from accelerate import init_empty_weights - from accelerate.utils import set_module_tensor_to_device - - with init_empty_weights(): - hf_model = AutoModelForCausalLM.from_config(model.config) - - for name, param in hf_model.named_parameters(): - set_module_tensor_to_device( - hf_model, name, "cpu", torch.empty(*param.size(), dtype=model.config.torch_dtype) - ) - model.model = hf_model - return cls(model) + super().__init__(hf_pretrained) @classmethod def from_hf_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoBridge": """ Load an AutoBridge from a pretrained model, automatically detecting the model type. - - This method loads a model from HuggingFace Hub or a local directory and - creates a bridge instance ready for conversion operations. The model - architecture is validated to ensure compatibility. - - Args: - path: HuggingFace model ID or path to model directory - Examples: "meta-llama/Meta-Llama-3-8B", "./my_model" - **kwargs: Additional arguments passed to HuggingFace from_hf_pretrained - Common options include: - - torch_dtype: Model precision (torch.float16, torch.bfloat16) - - device_map: Device placement strategy ("auto", "cuda:0", etc.) - - trust_remote_code: Allow custom model code execution - - attn_implementation: Attention implementation ("flash_attention_2", etc.) - - Returns: - AutoBridge: Bridge instance with loaded model - - Raises: - ValueError: If the model architecture is not supported - - Example: - >>> # Basic loading - >>> bridge = AutoBridge.from_hf_pretrained("gpt2") - - >>> # Load with specific settings - >>> bridge = AutoBridge.from_hf_pretrained( - ... "meta-llama/Meta-Llama-3-8B", - ... torch_dtype=torch.float16, - ... device_map="auto" - ... 
) - - >>> # Works with local paths too - >>> bridge = AutoBridge.from_hf_pretrained("/path/to/model") """ # First load just the config to check architecture support # Use thread-safe config loading to prevent race conditions - config = safe_load_config_with_retry( - path, trust_remote_code=kwargs.get("trust_remote_code", False) - ) + config = safe_load_config_with_retry(path, trust_remote_code=kwargs.get("trust_remote_code", False)) cls._validate_config(config, str(path)) @@ -250,243 +44,13 @@ def from_hf_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoBridge": except Exception as e: raise ValueError(f"Failed to load model with AutoBridge: {e}") from e - def load_hf_weights( - self, model: list[MegatronModelT], hf_path: str | Path | None = None - ) -> None: - """ - Load HuggingFace weights into a Megatron model. - - This method handles the conversion and distribution of weights from - HuggingFace format to Megatron's distributed format, including proper - tensor parallel and pipeline parallel distribution. - - Args: - model: List of Megatron model instances (one per virtual pipeline stage) - hf_path: Optional path to load weights from. If None, uses weights - from the bridge's hf_pretrained instance - - Returns: - The input model with loaded weights - - Raises: - ValueError: If hf_path is None and bridge was created without weights - - Example: - >>> # Load weights from bridge's pretrained model - >>> bridge = AutoBridge.from_hf_pretrained("gpt2") - >>> megatron_model = create_megatron_model() # Your model creation - >>> bridge.load_hf_weights(megatron_model) - - >>> # Load weights from a different checkpoint - >>> bridge.load_hf_weights(megatron_model, "./finetuned_model") - """ - if hf_path is None: - if not isinstance(self.hf_pretrained, PreTrainedCausalLM): - raise ValueError( - "hf_path is required when hf_pretrained is not a PreTrainedCausalLM instance" - ) - pre_trained = self.hf_pretrained - else: - pre_trained = PreTrainedCausalLM.from_pretrained(hf_path) - # Preserve trust_remote_code setting from the original bridge instance - trust_remote_code = getattr(self.hf_pretrained, 'trust_remote_code', False) - pre_trained = PreTrainedCausalLM.from_pretrained( - hf_path, trust_remote_code=trust_remote_code - ) - # self._model_bridge.load_weights_hf_to_megatron(model, pre_trained) - self._model_bridge.load_weights_hf_to_megatron(pre_trained, model) - - return model - - def save_hf_pretrained( - self, - model: list[MegatronModelT], - path: str | Path, - show_progress: bool = True, - strict: bool = True, - ) -> None: - """ - Save a Megatron model in HuggingFace format. - - This method exports the complete model including configuration, tokenizer, - and weights to a directory that can be loaded with HuggingFace's - from_pretrained methods. - - If the original model was loaded with trust_remote_code=True, any custom - modeling files (e.g., modeling_*.py, configuration_*.py) will be preserved - to ensure the saved model can be loaded properly. - - Args: - model: Megatron model instance or list of instances - path: Directory path to save the model - show_progress: Display progress bar during weight export - - Example: - >>> # Save model after training - >>> bridge.save_hf_pretrained(megatron_model, "./my_finetuned_model") - - >>> # Load the saved model with HuggingFace - >>> from transformers import AutoModelForCausalLM - >>> hf_model = AutoModelForCausalLM.from_pretrained("./my_finetuned_model") - - Note: - This method is collective - all ranks must call it. 
Only rank 0 - saves the configuration files, while weight saving is coordinated - across all ranks. - """ - if dist.is_available() and dist.is_initialized(): - # Distributed training, only rank 0 saves artifacts - if dist.get_rank() == 0: - self.hf_pretrained.save_artifacts(path) - else: - # No distributed training, save artifacts - self.hf_pretrained.save_artifacts(path) - self.save_hf_weights(model, path, show_progress, strict) - - def save_hf_weights( - self, - model: list[MegatronModelT], - path: str | Path, - show_progress: bool = True, - strict: bool = True, - ) -> None: - """ - Save Megatron model weights in HuggingFace safetensors format. - - This method exports only the model weights (not configuration or tokenizer) - to safetensors files compatible with HuggingFace. It uses streaming save - to handle large models efficiently without requiring all weights in memory - at once. - - The weights are gathered from distributed ranks and saved in the standard - HuggingFace sharded format when the model is large. - - Args: - model: Megatron model instance or list of instances - path: Directory path where weight files will be saved - show_progress: Display progress bar during export - - Raises: - ValueError: If the state source doesn't support streaming save - - Example: - >>> # Save just the weights - >>> bridge.save_hf_weights(megatron_model, "./model_weights") - - >>> # Save without progress bar (useful in scripts) - >>> bridge.save_hf_weights(megatron_model, "./weights", show_progress=False) - - Note: - - This method is collective and must be called by all ranks - - Uses safetensors format for efficient loading and security - - Automatically handles model sharding for large models - - The saved weights can be loaded with HuggingFace's from_pretrained - """ - if dist.is_available() and dist.is_initialized(): - dist.barrier() - dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model)) - generator = model_bridge.stream_weights_megatron_to_hf( - dispatch_instance, model, self.hf_pretrained, cpu=True, show_progress=show_progress - ) - source = SafeTensorsStateSource(path) - # Check if the state source is SafeTensorsStateSource for streaming save. - if ( - hasattr(self.hf_pretrained, "state") - and hasattr(self.hf_pretrained.state, "source") - # and isinstance(self.hf_pretrained.state.source, SafeTensorsStateSource) - ): - # self.hf_pretrained.state.source.save_generator(generator, path, strict=strict) - source.save_generator(generator, path, strict=strict) - else: - raise ValueError( - "The state source is not a SafeTensorsStateSource, cannot save in streaming mode." - ) - - if dist.is_available() and dist.is_initialized(): - dist.barrier() - - @property - def _model_bridge(self) -> "MegatronModelBridge": - return model_bridge.get_model_bridge(self._causal_lm_architecture) - - @cached_property - def _causal_lm_architecture(self): - """Resolve the model's CausalLM architecture for dispatch. - - Behavior: - - If the model can be imported from transformers directly, return the actual transformers class object. - - Otherwise, if the model uses HuggingFace auto_map, return the architecture's class name as a string (e.g., - "DeepseekV2ForCausalLM"). - - Returns: - str | type: The transformers class for the CausalLM architecture or the architecture's class name as a - string for auto_map models. - - Raises: - ValueError: If no CausalLM architecture is found or cannot be resolved. 
- """ - if isinstance(self.hf_pretrained, PreTrainedCausalLM): - config = self.hf_pretrained.config - model_name_or_path = getattr(config, "_name_or_path", None) or getattr( - self.hf_pretrained, "model_name_or_path", None - ) - else: - config = self.hf_pretrained - model_name_or_path = getattr(config, "_name_or_path", None) - - architectures = getattr(config, "architectures", []) - - if not architectures: - raise ValueError( - "\n✗ No architectures found in model config\n\n" - "The model configuration does not specify any architectures.\n" - "This is required for determining the model type." - ) - - causal_lm_arch = None - for architecture_name in architectures: - # TODO: Can we improve this? - if architecture_name.endswith( - ("ForCausalLM", "ForConditionalGeneration", "NemotronH_Nano_VL_V2") - ): - causal_lm_arch = architecture_name - break - - if not causal_lm_arch: - raise ValueError( - f"\n✗ No CausalLM architecture found\n\n" - f"Model architectures: {architectures}\n\n" - f"None of the architectures end with 'ForCausalLM' or 'ForConditionalGeneration' or" - f"'NemotronH_Nano_VL_V2'.\n" - f"This bridge only supports causal language models.\n" - f"For other model types, use a different bridge class." - ) - - # Try auto_map first - cls = get_causal_lm_class_via_auto_map(model_name_or_path=model_name_or_path, config=config) - if cls is not None: - # For auto_map models, return the class name as a string - return getattr(cls, "__name__", str(cls)) - - try: - return getattr(transformers, causal_lm_arch) - except AttributeError: - raise ValueError( - f"\n✗ Architecture class '{causal_lm_arch}' not found in transformers\n\n" - f"This could mean:\n" - f"1. The model requires a newer version of transformers\n" - f"2. The model uses a custom modeling file not in the standard library\n" - f"3. There's a typo in the architecture name\n\n" - f"Please verify your transformers installation and the model requirements." - ) - @classmethod def _validate_config(cls, config: PretrainedConfig, path: str | None = None) -> None: # Check if this is a causal LM model if not cls.supports(config): architectures = getattr(config, "architectures", []) raise ValueError( - f"\n✗ Model architecture not supported by AutoBridge\n\n" + f"\n�~\~W Model architecture not supported by AutoBridge\n\n" f"Model: {path}\n" f"Architectures: {architectures}\n\n" f"AutoBridge only supports models with architectures ending in 'ForCausalLM' or" @@ -522,7 +86,7 @@ def _validate_config(cls, config: PretrainedConfig, path: str | None = None) -> except AttributeError: # Fall back to name-based registration arch_key = architecture - + # Test if we have a registered implementation (type or class-name string) has_implementation = False if hasattr(model_bridge.get_model_bridge, "_exact_types"): @@ -535,15 +99,10 @@ def _validate_config(cls, config: PretrainedConfig, path: str | None = None) -> ) if not has_implementation: - # Get list of supported models - supported_models = cls.list_supported_models() - raise ValueError( - f"\n✗ Model architecture '{architecture}' is not yet supported\n\n" + f"\n�~\~W Model architecture '{architecture}' is not yet supported\n\n" f"Model: {path}\n" f"Architecture: {architecture}\n\n" - f"Currently supported architectures:\n" - + "\n".join(f" • {model}" for model in supported_models) + f"\n\nTo add support for {architecture}, you need to:\n" f"1. Create a new bridge class that inherits from MegatronModelBridge\n" f"2. 
Implement the required methods (provider_bridge, mapping_registry)\n" @@ -561,12 +120,83 @@ def _validate_config(cls, config: PretrainedConfig, path: str | None = None) -> f" # Return a MegatronMappingRegistry with weight mappings\n" f" ...\n\n" f"For reference implementations, see:\n" - f" • src/megatron/bridge/models/llama/llama_bridge.py\n" - f" • src/megatron/bridge/models/qwen/qwen_2_causal_bridge.py" + f" �~@� src/megatron/bridge/models/llama/llama_bridge.py\n" + f" �~@� src/megatron/bridge/models/qwen/qwen_2_causal_bridge.py" ) from None - def _get_model_instance(self, model: list[MegatronModelT]) -> MegatronModelT: - model_instance = model[0] - while hasattr(model_instance, "module"): - model_instance = model_instance.module - return model_instance + + @classmethod + def from_hf_config(cls, config: PretrainedConfig) -> "AutoBridge": + cls._validate_config(config) + model = PreTrainedCausalLM() + model.config = config + import torch + + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + + with init_empty_weights(): + hf_model = AutoModelForCausalLM.from_config(model.config) + + for name, param in hf_model.named_parameters(): + set_module_tensor_to_device( + hf_model, name, "cpu", torch.empty(*param.size(), dtype=model.config.torch_dtype) + ) + model.model = hf_model + return cls(model) + + def load_hf_weights( + self, model: list[MegatronModelT], hf_path: str | Path | None = None + ) -> None: + if hf_path is None: + if not isinstance(self.hf_pretrained, PreTrainedCausalLM): + raise ValueError( + "hf_path is required when hf_pretrained is not a PreTrainedCausalLM instance" + ) + pre_trained = self.hf_pretrained + else: + pre_trained = PreTrainedCausalLM.from_pretrained(hf_path) + # Preserve trust_remote_code setting from the original bridge instance + trust_remote_code = getattr(self.hf_pretrained, 'trust_remote_code', False) + pre_trained = PreTrainedCausalLM.from_pretrained( + hf_path, trust_remote_code=trust_remote_code + ) + # self._model_bridge.load_weights_hf_to_megatron(model, pre_trained) + self._model_bridge.load_weights_hf_to_megatron(pre_trained, model) + + return model + + def save_hf_weights( + self, + model: list[MegatronModelT], + path: str | Path, + show_progress: bool = True, + strict: bool = True, + ) -> None: + + if dist.is_available() and dist.is_initialized(): + dist.barrier() + dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model)) + generator = model_bridge.stream_weights_megatron_to_hf( + dispatch_instance, model, self.hf_pretrained, cpu=True, show_progress=show_progress + ) + source = SafeTensorsStateSource(path) + # Check if the state source is SafeTensorsStateSource for streaming save. + if ( + hasattr(self.hf_pretrained, "state") + and hasattr(self.hf_pretrained.state, "source") + # and isinstance(self.hf_pretrained.state.source, SafeTensorsStateSource) + ): + # self.hf_pretrained.state.source.save_generator(generator, path, strict=strict) + source.save_generator(generator, path, strict=strict) + else: + raise ValueError( + "The state source is not a SafeTensorsStateSource, cannot save in streaming mode." 
+ ) + + if dist.is_available() and dist.is_initialized(): + dist.barrier() + + @property + def _model_bridge(self) -> "MegatronModelBridge": + return model_bridge.get_model_bridge(self._causal_lm_architecture) diff --git a/flagscale/train/megatron/nemo_bridge/models/conversion/mapping_registry.py b/flagscale/train/megatron/nemo_bridge/models/conversion/mapping_registry.py deleted file mode 100644 index 58e154eb3c..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/conversion/mapping_registry.py +++ /dev/null @@ -1,266 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import re - -from typing import List, Optional - -from megatron.nemo_bridge.models.conversion.param_mapping import MegatronParamMapping - - -class MegatronMappingRegistry: - """ - Registry for weight mappings between model formats with pattern matching support. - - This class serves as a registry of weight mappings between Megatron and external - (typically HuggingFace) model formats. It provides efficient pattern matching - for parameter names using glob-like wildcards (*) and supports both forward - (Megatron → HF) and reverse (HF → Megatron) lookups. - - The registry pre-compiles regex patterns for efficient repeated lookups and - handles the resolution of wildcards in parameter names. - - Args: - *mappings: Variable number of MegatronParamMapping objects defining - the individual weight mappings - - Example: - >>> # Create a mapping registry with various mappings - >>> mapping_registry = MegatronMappingRegistry( - ... AutoMapping( - ... megatron_param="embedding.word_embeddings.weight", - ... hf_param="model.embed_tokens.weight", - ... ), - ... QKVMapping( - ... megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", - ... q="model.layers.*.self_attn.q_proj.weight", - ... k="model.layers.*.self_attn.k_proj.weight", - ... v="model.layers.*.self_attn.v_proj.weight", - ... ), - ... ) - - >>> # Query for a specific layer (wildcards are resolved) - >>> mapping = mapping_registry.megatron_to_hf_lookup("decoder.layers.0.self_attention.linear_qkv.weight") - >>> print(mapping.hf_param) # Will show resolved HF names for layer 0 - - >>> # Reverse lookup from HF name - >>> mapping = mapping_registry.hf_to_megatron_lookup("model.layers.5.self_attn.q_proj.weight") - >>> print(mapping.megatron_param) # Shows corresponding Megatron name - - >>> # Build from a list - >>> mappings = [bridge1, bridge2, bridge3] - >>> mapping_registry = MegatronMappingRegistry(*mappings) - - Note: - Wildcard patterns support: - - '*' matches any sequence of digits (0-9) - designed for layer indices - - '**' matches any sequence of characters - designed for nested paths - """ - - def _convert_pattern_to_regex(self, pattern: str) -> str: - """Convert a pattern with wildcards to regex pattern. - - Args: - pattern: Pattern string with * and ** wildcards - - Returns: - Regex pattern string - - Note: - ** must be processed before * to avoid conflicts. 
- ** becomes (.*) - matches any characters including dots - * becomes (\\d+) - matches digits only for layer indices - """ - # Escape the pattern first - regex_pattern = re.escape(pattern) - - # Process ** before * to avoid conflicts - # Replace \*\* with (.*) - regex_pattern = regex_pattern.replace(r"\*\*", r"(.*)") - - # Replace remaining \* with (\d+) - regex_pattern = regex_pattern.replace(r"\*", r"(\d+)") - - return regex_pattern - - def __init__(self, *mappings: MegatronParamMapping): - """ - Initialize MegatronMappingRegistry with weight mappings. - - Args: - *mappings: MegatronParamMapping objects - """ - self.mappings = list(mappings) - - # Pre-compile patterns for efficiency - self._compiled_patterns = [] - self._reverse_patterns = [] # For hf_param -> megatron lookups - - for mapping in mappings: - # Compile source patterns - if "*" in mapping.megatron_param: - # Convert glob pattern to regex with support for * and ** - pattern = self._convert_pattern_to_regex(mapping.megatron_param) - self._compiled_patterns.append((re.compile(f"^{pattern}$"), mapping)) - else: - self._compiled_patterns.append((None, mapping)) - - # Compile destination patterns for reverse lookups - if isinstance(mapping.hf_param, str): - if "*" in mapping.hf_param: - pattern = self._convert_pattern_to_regex(mapping.hf_param) - self._reverse_patterns.append((re.compile(f"^{pattern}$"), mapping)) - else: - self._reverse_patterns.append((None, mapping)) - else: - # For dict destinations, compile patterns for each value - reverse_dict_patterns = {} - for key, hf_pattern in mapping.hf_param.items(): - if "*" in hf_pattern: - pattern = self._convert_pattern_to_regex(hf_pattern) - reverse_dict_patterns[key] = re.compile(f"^{pattern}$") - else: - reverse_dict_patterns[key] = None - self._reverse_patterns.append((reverse_dict_patterns, mapping)) - - def megatron_to_hf_lookup(self, megatron_param_name: str) -> Optional[MegatronParamMapping]: - """ - Get mapping for a Megatron parameter name. - - This method performs efficient lookups by first checking for exact matches, - then falling back to pattern matching using pre-compiled regex patterns. - When a pattern match is found, wildcards are automatically resolved. - - Args: - megatron_param_name: Megatron parameter name to look up - Example: "decoder.layers.0.self_attention.linear_qkv.weight" - - Returns: - MegatronParamMapping: Bridge instance with resolved wildcards, or None - if no matching mapping is found. The returned bridge will have - all wildcards replaced with actual values. - - Example: - >>> # Query with exact layer number - >>> bridge = state_map.megatron_to_hf_lookup("decoder.layers.5.mlp.linear_fc1.weight") - >>> if bridge: - ... print(f"Maps to: {bridge.hf_param}") # Shows HF name for layer 5 - """ - for pattern, mapping in self._compiled_patterns: - if pattern is None: - # Direct match - if mapping.megatron_param == megatron_param_name: - return mapping - else: - # Pattern match - match = pattern.match(megatron_param_name) - if match: - # Return resolved mapping with wildcards replaced - return mapping.resolve(match.groups()) - return None - - def hf_to_megatron_lookup(self, hf_param_name: str) -> Optional[MegatronParamMapping]: - """ - Get mapping for a destination parameter name (reverse lookup). - - This is useful when you have a destination name and want to find - the corresponding megatron name. 
- - Args: - hf_param_name: Destination parameter name to look up - - Returns: - MegatronParamMapping with resolved wildcards, or None if no match found - """ - for pattern_info, mapping in self._reverse_patterns: - if isinstance(mapping.hf_param, str): - # Simple string destination - pattern = pattern_info - if pattern is None: - # Direct match - if mapping.hf_param == hf_param_name: - return mapping - else: - # Pattern match - match = pattern.match(hf_param_name) - if match: - return mapping.resolve(match.groups()) - else: - # Dict destination - check each pattern - patterns_dict = pattern_info - for key, pattern in patterns_dict.items(): - if pattern is None: - # Direct match - if mapping.hf_param[key] == hf_param_name: - # Create a simplified mapping for this specific key - return mapping.resolve(()) - else: - # Pattern match - match = pattern.match(hf_param_name) - if match: - return mapping.resolve(match.groups()) - return None - - def get_all_mappings(self) -> List[MegatronParamMapping]: - """Get all mappings in this MegatronMappingRegistry.""" - return self.mappings.copy() - - def get_mappings_by_pattern(self, pattern: str) -> List[MegatronParamMapping]: - """ - Get all mappings that match a given pattern. - - Args: - pattern: Pattern to match (supports * and ** wildcards) - - Returns: - List of matching MegatronParamMapping objects - """ - # Convert pattern to regex using the same logic as _convert_pattern_to_regex - # but for this method we want both * and ** to match anything for search purposes - regex_pattern = re.escape(pattern) - regex_pattern = regex_pattern.replace(r"\*\*", r".*") - regex_pattern = regex_pattern.replace(r"\*", r".*") - compiled_pattern = re.compile(f"^{regex_pattern}$") - - matches = [] - for mapping in self.mappings: - if compiled_pattern.match(mapping.megatron_param): - matches.append(mapping) - - return matches - - def __len__(self) -> int: - """Return number of mappings.""" - return len(self.mappings) - - def __iter__(self): - """Iterate over mappings.""" - return iter(self.mappings) - - def __repr__(self) -> str: - """String representation of MegatronMappingRegistry.""" - return f"MegatronMappingRegistry({len(self.mappings)} mappings)" - - def describe(self) -> str: - """ - Get a human-readable description of all mappings. - - Returns: - Formatted string describing all weight mappings - """ - lines = [f"MegatronMappingRegistry with {len(self.mappings)} mappings:"] - for i, mapping in enumerate(self.mappings): - lines.append(f"\n{i + 1}. 
{mapping.megatron_param}")
-            if isinstance(mapping.hf_param, str):
-                lines.append(f"    → {mapping.hf_param}")
-            else:
-                lines.append("    → {")
-                for key, value in mapping.hf_param.items():
-                    lines.append(f"        {key}: {value}")
-                lines.append("    }")
-
-            # Show bridge type
-            lines.append(f"    bridge: {type(mapping).__name__}")
-
-        return "\n".join(lines)
diff --git a/flagscale/train/megatron/nemo_bridge/models/conversion/model_bridge.py b/flagscale/train/megatron/nemo_bridge/models/conversion/model_bridge.py
index 6858337ef4..c9554ffab6 100644
--- a/flagscale/train/megatron/nemo_bridge/models/conversion/model_bridge.py
+++ b/flagscale/train/megatron/nemo_bridge/models/conversion/model_bridge.py
@@ -2,19 +2,13 @@
 #
 # Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge
 
-import abc
 import itertools
 import logging
-import re
-from dataclasses import dataclass
 
 from typing import (
     Callable,
-    Generic,
     Iterable,
     List,
-    Mapping,
-    NamedTuple,
     Optional,
     Type,
     TypeVar,
@@ -22,33 +16,34 @@
 )
 
 import torch
-
-from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
 from transformers.modeling_utils import PreTrainedModel
 
 from megatron.core import parallel_state
 from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.utils import get_pg_size, unwrap_model
+from megatron.core.utils import unwrap_model
 
-from megatron.nemo_bridge.models.conversion.mapping_registry import MegatronMappingRegistry
-from megatron.nemo_bridge.models.conversion.param_mapping import MegatronParamMapping
-from megatron.nemo_bridge.models.conversion.utils import (
-    extract_sort_key,
+from megatron.bridge.models.conversion.param_mapping import MegatronParamMapping
+from megatron.bridge.models.conversion.utils import (
     get_module_and_param_from_name,
     persistent_buffers,
 )
-from megatron.nemo_bridge.models.decorators.dispatch import dispatch
-from megatron.nemo_bridge.utils.common_utils import print_rank_0
+from megatron.bridge.utils.common_utils import print_rank_0
+from megatron.bridge.models.decorators.dispatch import dispatch
 
 logger = logging.getLogger(__name__)
 
+from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge as OriginalMegatronModelBridge
+from megatron.bridge.models.conversion.model_bridge import (
+    _megatron_local_name_to_global,
+    stream_weights_megatron_to_hf,
+    HFWeightTuple,
+    WeightConversionTask,
+)
 
 MappingT = TypeVar("MappingT", bound=MegatronParamMapping)
 HFPreTrained = TypeVar("HFPreTrained")
 MegatronModel = TypeVar("MegatronModel", bound=MegatronModule)
 _BridgeImplClass = TypeVar("_BridgeImplClass", bound="MegatronModelBridge")
 
-
 def padding_embedd_size(mcore_weight: torch.Tensor, hf_vocab_size: int):
     hf_size = hf_vocab_size
     mcore_size = mcore_weight.shape[0]
@@ -74,438 +69,53 @@ def padding_embedd_size(mcore_weight: torch.Tensor, hf_vocab_size: int):
         full_word = mcore_weight
     return full_word
 
+class MegatronModelBridge(OriginalMegatronModelBridge):
 
-class MegatronWeightTuple(NamedTuple):
-    """Tuple representing a Megatron model weight with its metadata."""
-
-    param_name: str
-    weight: torch.Tensor
-    vp_stage: int
-
-
-class HFWeightTuple(NamedTuple):
-    """Tuple representing a HuggingFace model weight with its metadata."""
-
-    param_name: str
-    weight: torch.Tensor
-
-
-@dataclass(frozen=True)
-class WeightConversionTask(Generic[MappingT]):
-    """A unified task for converting weights between HuggingFace and Megatron formats.
- - This class combines both HF->Megatron and Megatron->HF conversion tasks since they - have different method names (hf_to_megatron vs megatron_to_hf) and can coexist safely. - - The task encapsulates all information needed for weight conversion in either direction, - with different fields being relevant depending on the conversion type. - - Attributes: - param_name (str): *unwrapped, local* parameter name (no ``module.`` prefixes). - mapping (MappingT): Concrete :pyclass:`MegatronParamMapping` instance responsible - for weight transformation and distribution. - - pp_rank (Optional[int]): Pipeline-parallel rank that owns the parameter (required for saves). - vp_stage (Optional[int]): Virtual-pipeline stage index (required for loads). - megatron_module (Optional[torch.nn.Module]): Reference to the Megatron model or - sub-module that owns the parameter (required for loads). - param_weight (Optional[torch.Tensor]): The actual parameter tensor that will - receive the converted weight (required for loads). - - """ - - param_name: str - mapping: MappingT - pp_rank: Optional[int] = None - vp_stage: Optional[int] = None - megatron_module: Optional[torch.nn.Module] = None - param_weight: Optional[torch.Tensor] = None - - -def _megatron_local_name_to_global( - models: MegatronModule | List[MegatronModule], - config: TransformerConfig, - param_name: str, - vp_stage: Optional[int] = None, -) -> str: - """Adjust layer number and expert number from local to global numbering.""" - # PP - pp_group = parallel_state.get_pipeline_model_parallel_group() - if "layers." in param_name and get_pg_size(pp_group) > 1: - match = re.match(r"^(.+?\.layers\.\d+)", param_name) - assert match is not None - layer_prefix = match.group(1) - _, layer_module = get_module_and_param_from_name( - models=models, param_name=layer_prefix, vp_stage=vp_stage - ) - - local_layer_number = int(param_name.split("layers.")[1].split(".")[0]) - global_layer_number = layer_module.layer_number - 1 - param_name = param_name.replace( - f"layers.{local_layer_number}.", f"layers.{global_layer_number}." - ) - - # EP - ep_group = parallel_state.get_expert_model_parallel_group() - if ".mlp.experts.linear_fc" in param_name and get_pg_size(ep_group) > 1: - num_experts = config.num_moe_experts - num_experts_per_rank = num_experts // ep_group.size() - - def _update_expert_number(param_name: str, param_type: str) -> str: - """Update expert number from local to global for weight or bias parameters.""" - local_expert_number = int(param_name.split(f".{param_type}")[-1]) - global_expert_number = num_experts_per_rank * ep_group.rank() + local_expert_number - return param_name.replace( - f".{param_type}{local_expert_number}", f".{param_type}{global_expert_number}" - ) - - # Handle weight and bias parameters - if ".weight" in param_name: - param_name = _update_expert_number(param_name, "weight") - elif ".bias" in param_name: - param_name = _update_expert_number(param_name, "bias") - return param_name - - -# class MegatronModelBridge(Generic[HFPreTrained, ModelProviderTarget, MegatronModel]): -class MegatronModelBridge(Generic[HFPreTrained, MegatronModel]): - """ - High-level orchestrator for HuggingFace ↔ Megatron model conversions. - - This abstract base class provides the framework for converting models between - HuggingFace and Megatron formats. It acts as an orchestrator that coordinates - the conversion process without directly handling the complex details of - tensor parallelism or weight transformations. 
- - The bridge pattern separates concerns: - - MegatronModelBridge: Orchestrates the overall conversion process - - MegatronMappingRegistry: Manages parameter name mappings - - MegatronParamMapping: Handles actual weight transformations and distribution - - Key responsibilities: - 1. Build conversion tasks that map each parameter to its appropriate bridge - 2. Execute tasks with proper error handling and progress tracking - 3. Provide utilities for configuration translation - 4. Handle virtual pipeline parallelism (VP) complexities - - To implement a bridge for a new model architecture: - - 1. Create a subclass decorated with @MegatronModelBridge.register_bridge: - - .. code-block:: python - - @MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel) - class MegatronCausalLlamaBridge(MegatronModelBridge): - pass - - 2. Implement provider_bridge to create Megatron configurations: - - .. code-block:: python - - def provider_bridge(self, hf_pretrained) -> LlamaModelProvider: - return LlamaModelProvider( - num_layers=hf_pretrained.config.num_hidden_layers, - hidden_size=hf_pretrained.config.hidden_size, - ... - ) - - 3. Implement mapping_registry to define weight mappings: - - .. code-block:: python - - def mapping_registry(self) -> MegatronMappingRegistry: - return MegatronMappingRegistry( - AutoMapping( - megatron_param="embedding.word_embeddings.weight", - hf_param="model.embed_tokens.weight" - ), - ... - ) - - Example: - .. code-block:: python - - # The bridge is typically not instantiated directly - # Instead, use AutoBridge or AutoBridge which handle this - bridge = AutoBridge.from_hf_pretrained("meta-llama/Meta-Llama-3-8B") - provider = bridge.to_megatron_provider() - - Note: - This class uses generic type parameters to ensure type safety: - - HFPreTrained: The HuggingFace model type - - ModelProviderTarget: The Megatron model provider type - - MegatronModel: The Megatron model type - """ - - @abc.abstractmethod - def mapping_registry(self) -> MegatronMappingRegistry: - """Define weight mappings between HuggingFace and Megatron formats. - - This abstract method must be implemented by subclasses to specify how - parameters map between the two formats. The returned MegatronMappingRegistry - contains all param mappings needed for the model architecture. - - Returns: - MegatronMappingRegistry: MegatronMappingRegistry containing all weight - mapping definitions. - - Example: - .. code-block:: python - - def mapping_registry(self): - return MegatronMappingRegistry( - AutoMapping( - megatron_param="embedding.word_embeddings.weight", - hf_param="model.embed_tokens.weight" - ), - QKVMapping( - megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", - q="model.layers.*.self_attn.q_proj.weight", - k="model.layers.*.self_attn.k_proj.weight", - v="model.layers.*.self_attn.v_proj.weight" - ), - # ... 
more param mappings - ) - """ - raise NotImplementedError("Subclass must implement mapping_registry method") - - def _megatron_global_param_names_all_pp_ranks( + def _broadcast_shared_embeddings( self, megatron_model: Union[MegatronModel, List[MegatronModel]] - ) -> List[str]: - """Get all parameter names across all pipeline parallel ranks.""" - # Cache the result after first call - if hasattr(self, "_cached_param_names"): - return self._cached_param_names - - # Compute the result - pp_group = parallel_state.get_pipeline_model_parallel_group() - model_config = unwrap_model(megatron_model)[0].config - global_param_names = [] - - # Ensure megatron_model is a list for consistent handling - models_list = megatron_model if isinstance(megatron_model, list) else [megatron_model] - - for vp_stage, model in enumerate(models_list): - # persistent buffers are part of the model's state_dict, but not the named_parameters, so we must include them here separately - for local_param_name, _ in itertools.chain( - model.named_parameters(), persistent_buffers(model) - ): - if "_extra_state" in local_param_name: - continue - local_param_name = self._unwrap_name(local_param_name) - global_param_name = _megatron_local_name_to_global( - models_list, model_config, local_param_name, vp_stage - ) - global_param_names.append(global_param_name) - - gathered_global_param_names = [None] * pp_group.size() - torch.distributed.all_gather_object( - gathered_global_param_names, global_param_names, group=pp_group - ) - - # flatten the list, sort it and remove duplicates - # the order matters here, casually re-order will cause a hang. - # e.g. decoder.layers.0.mlp.experts.linear_fc1.weight100 - flattened_names = list(set(sum(gathered_global_param_names, []))) - - # the order cannot be changed, this sync for all ranks for conversion - # change this might cause a hang - gathered_global_param_names = sorted(flattened_names, key=extract_sort_key) - - # Cache the result - self._cached_param_names = gathered_global_param_names - - return self._cached_param_names - - def _with_progress_tracking(self, tasks, description: str, show_progress: bool = True): - """Helper method to wrap an iterable with progress tracking. - - Args: - tasks: Iterable of tasks to process - description: Description for the progress bar - show_progress: Whether to show progress (defaults to True) - - Yields: - Items from the tasks iterable while updating progress - """ - is_main_rank = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 - bridge_name = self.__class__.__name__ - - if show_progress: - with Progress( - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), - TimeRemainingColumn(), - TextColumn("({task.completed}/{task.total})"), - TextColumn("{task.fields[bridge]}"), - disable=not is_main_rank, - ) as progress: - task_id = progress.add_task(description, total=len(tasks), bridge=bridge_name) - - for task in tasks: - yield task - progress.update(task_id, advance=1) - else: - # not using disable above because we notice it will dump some empty progress bar, - # even when disable is set to True - for task in tasks: - yield task - - def load_weights_hf_to_megatron( - self, hf_pretrained: HFPreTrained, megatron_model: Union[MegatronModel, List[MegatronModel]] - ) -> List[MegatronModel]: - """Load HuggingFace weights into Megatron models. 
- - This method orchestrates the complete weight loading process from HuggingFace - format to Megatron's distributed format. It builds a conversion task and - executes it with proper progress tracking and error handling. - - The actual weight transformations and distribution are delegated to the - appropriate MegatronParamMapping instances based on the state mappings. - - Args: - hf_pretrained (HFPreTrained): HuggingFace model or state source containing the - weights to load. - megatron_model (Union[MegatronModel, List[MegatronModel]]): Megatron model instance - or list of model instances (one per virtual pipeline stage). - - Returns: - List[MegatronModel]: The input megatron_model as a list with loaded weights. - - Process: - 1. Build a task mapping each Megatron parameter to its source - 2. For each parameter in the task: - - Fetch source weights from HuggingFace state - - Apply format transformation via the param mapping - - Distribute to appropriate TP/PP ranks - - Copy into the Megatron parameter - - Example: - .. code-block:: python - - hf_model = PreTrainedCausalLM.from_pretrained("gpt2") - megatron_model = create_megatron_model() # Single model or list - bridge.load_weights_hf_to_megatron(hf_model, megatron_model) - - Note: - Progress is shown only on rank 0 to avoid cluttered output in - distributed environments. - - Raises: - ValueError: If hf_pretrained doesn't have state attribute or if weight shapes don't match. - AttributeError: If required HF weights are missing. - """ - if not isinstance(megatron_model, list): - megatron_model = [megatron_model] - - hf_to_megatron_tasks = self.build_conversion_tasks(hf_pretrained, megatron_model) - hf_state_dict: Mapping[str, torch.Tensor] = ( - hf_pretrained.state if hasattr(hf_pretrained, "state") else {} - ) - - description = f"Loading from {hf_pretrained.model_name_or_path}" - for task in self._with_progress_tracking(hf_to_megatron_tasks, description): - # None means megatron module not on current rank, skip if this task is not going to happen - if task.megatron_module is None: - continue - # 1) Fetch source tensor(s) from HF state dict - if isinstance(task.mapping.hf_param, str): - hf_weights = hf_state_dict[task.mapping.hf_param] - else: - hf_weights = {k: hf_state_dict[v] for k, v in task.mapping.hf_param.items()} - - # 2) Delegate conversion & distribution to the bridge - converted_weights = task.mapping.hf_to_megatron(hf_weights, task.megatron_module) - - # 3) Copy into Megatron param if this rank received a shard - if converted_weights is not None: - # Assert that param_weight is not None for HF->Megatron tasks - assert ( - task.param_weight is not None - ), "param_weight is required for HF->Megatron conversion" - - # Check shape compatibility before copying - if converted_weights.shape != task.param_weight.shape: - raise ValueError( - f"Shape mismatch for megatron param {task.mapping.megatron_param}:\n" - f" Expected shape: {task.param_weight.shape}\n" - f" Got shape: {converted_weights.shape}\n" - f" Bridge type: {type(task.mapping).__name__}\n" - f" HF mapping: {task.mapping.hf_param}" - ) - task.param_weight.data.copy_(converted_weights) - - self._broadcast_shared_embeddings(megatron_model) - return megatron_model - - def stream_weights_hf_to_megatron( - self, - hf_pretrained: HFPreTrained, - megatron_model: Union[MegatronModel, List[MegatronModel]], - conversion_tasks: Optional[List[WeightConversionTask]] = None, - ) -> Iterable[MegatronWeightTuple]: - """Generator variant of load_weights_hf_to_megatron for streaming weight 
conversion. + ) -> None: + """Broadcast shared embeddings and output weights across embedding group. - This method provides a memory-efficient way to convert weights by yielding - them one at a time instead of loading all at once. Useful for processing - very large models or when implementing custom weight handling logic. + When embeddings and output weights are shared and pipeline parallelism is enabled, + this method ensures all ranks in the embedding group have the same weights by + broadcasting from rank 0. Args: - hf_pretrained (HFPreTrained): HuggingFace model or state source containing - the weights. - megatron_model (Union[MegatronModel, List[MegatronModel]]): Megatron model instance - or list of model instances to extract configuration from. - conversion_tasks (Optional[List[WeightConversionTask]]): Pre-built conversion tasks. - If not provided, tasks will be built automatically from the models. - - Yields: - MegatronWeightTuple: Named tuples containing: - - vp_stage: Index of the model in megatron_model list - - param_name: Name of the parameter - - weight: Transformed weight tensor for this rank - - Example: - .. code-block:: python - - # Process weights one by one - for weight_tuple in bridge.stream_weights_hf_to_megatron(hf_model, megatron_model): - print(f"Processing {weight_tuple.param_name}: {weight_tuple.weight.shape}") - # Custom processing logic here - - # Or use pre-built conversion tasks - tasks = bridge.build_conversion_tasks(hf_model, megatron_model) - for weight_tuple in bridge.stream_weights_hf_to_megatron(hf_model, megatron_model, tasks): - print(f"Processing {weight_tuple.param_name}: {weight_tuple.weight.shape}") - - Note: - Only yields weights that belong to the current rank after TP/PP distribution. - - Raises: - ValueError: If input parameters are invalid. + megatron_model: Megatron model instance or list of model instances. 
""" + unwrapped_model = unwrap_model(megatron_model)[0] + # hack for vlm to work properly + if ( + hasattr(unwrapped_model, "language_model") + and unwrapped_model.language_model is not None + ): + unwrapped_model = unwrapped_model.language_model + model_config = unwrapped_model.config + if ( + not model_config.untie_embeddings_and_output_weights + and model_config.pipeline_model_parallel_size > 1 + ): + # Broadcast embeddings and output weights from rank 0 to embedding group + embd_group = parallel_state.get_embedding_group() + embd_group_ranks = torch.distributed.get_process_group_ranks(embd_group) + if embd_group is not None and torch.distributed.get_rank() in embd_group_ranks: + # Get embeddings and output weights from rank 0 + if hasattr(unwrapped_model, "embedding") and hasattr( + unwrapped_model.embedding, "word_embeddings" + ): + embd_weights = unwrapped_model.embedding.word_embeddings.weight.data + else: + assert hasattr(unwrapped_model, "output_layer"), "Output layer not found" + embd_weights = torch.empty_like(unwrapped_model.output_layer.weight.data) + torch.distributed.broadcast(embd_weights, src=embd_group_ranks[0], group=embd_group) + if hasattr(unwrapped_model, "output_layer"): + unwrapped_model.output_layer.weight.data.copy_(embd_weights) - if not isinstance(megatron_model, list): - megatron_model = [megatron_model] - - # Use provided conversion tasks or build them - if conversion_tasks is None: - conversion_tasks = self.build_conversion_tasks(hf_pretrained, megatron_model) - - for task in conversion_tasks: - # None means megatron module not on current rank, skip if this task is not going to happen - if task.megatron_module is None: - continue - hf_state_dict: Mapping[str, torch.Tensor] = hf_pretrained.state - if isinstance(task.mapping.hf_param, str): - hf_weights = hf_state_dict[task.mapping.hf_param] - else: - hf_weights = {k: hf_state_dict[v] for k, v in task.mapping.hf_param.items()} - - converted_weights = task.mapping.hf_to_megatron(hf_weights, task.megatron_module) - if converted_weights is not None: - # Assert that vp_stage is not None for HF->Megatron tasks - yield MegatronWeightTuple(task.param_name, converted_weights, task.vp_stage) + @classmethod + def register_bridge( + cls, *, source: Type[PreTrainedModel] | str, target: Type[MegatronModel] + ) -> Callable[[_BridgeImplClass], _BridgeImplClass]: + return create_bridge_decorator(source=source, target=target) def stream_weights_megatron_to_hf( self, @@ -516,50 +126,6 @@ def stream_weights_megatron_to_hf( conversion_tasks: Optional[List[WeightConversionTask]] = None, ) -> Iterable[HFWeightTuple]: """Export Megatron weights to HuggingFace format. - - This method orchestrates the conversion of weights from Megatron's distributed - format back to HuggingFace format. It handles gathering from tensor parallel - ranks, broadcasting across pipeline parallel ranks, and format conversions. - All ranks receive the full tensors. - - The export order is determined automatically: - - First tries safetensors order (if key_to_filename_map is available) - - Falls back to HuggingFace state dict order - - Args: - megatron_model (Union[MegatronModel, List[MegatronModel]]): Megatron model instance - or list of model instances (one per virtual pipeline stage). - hf_pretrained (HFPreTrained): HuggingFace model/config for metadata - and mapping info. - cpu (bool, optional): Whether to move tensors to CPU before yielding. - Defaults to True. - show_progress (bool, optional): Display progress bar during export. - Defaults to True. 
- conversion_tasks (Optional[List[WeightConversionTask]]): Pre-built conversion tasks. - If not provided, tasks will be built automatically from the models. - - Yields: - HFWeightTuple: Named tuples of (param_name, weight_tensor) in HF format. - - Example: - .. code-block:: python - - # Export weights - for name, weight in bridge.stream_weights_megatron_to_hf(megatron_model, hf_config): - print(f"Exported {name}: {weight.shape}") - - # Or use pre-built conversion tasks - tasks = bridge.build_conversion_tasks(hf_config, megatron_model) - for name, weight in bridge.stream_weights_megatron_to_hf( - megatron_model, hf_config, conversion_tasks=tasks - ): - print(f"Exported {name}: {weight.shape}") - - Raises: - ValueError: If input parameters are invalid. - - Note: - All ranks yield the full tensors after gathering from distributed format. """ if not isinstance(megatron_model, list): @@ -609,174 +175,6 @@ def stream_weights_megatron_to_hf( # Regular case - yield the tensor normally yield HFWeightTuple(hf_name, final_tensor) - def dtype_from_hf(self, config, default=None): - """Extract torch dtype from a HuggingFace config. - - This utility method handles the conversion of dtype specifications in - HuggingFace configs to PyTorch dtype objects. Supports both direct - torch.dtype objects and string representations. - - Args: - config: HuggingFace configuration object with a torch_dtype attribute. - default (Any, optional): Default value to return if torch_dtype is - not str or torch.dtype. Defaults to None. - - Returns: - torch.dtype: The corresponding PyTorch dtype. - - Raises: - AssertionError: If config doesn't have torch_dtype attribute. - ValueError: If torch_dtype is neither a string nor torch.dtype. - - Example: - .. code-block:: python - - dtype = bridge.dtype_from_hf(hf_config) - print(dtype) # torch.float16 - """ - assert hasattr(config, "torch_dtype"), "Expected config to have attr `torch_dtype`" - torch_dtype = config.torch_dtype - if isinstance(torch_dtype, torch.dtype): - return torch_dtype - elif isinstance(torch_dtype, str): - return self.dtype_from_str(torch_dtype) - elif default is not None: - return default - - raise ValueError("torch_dtype is not of type str/torch.dtype") - - def dtype_from_str(self, dtype: str) -> torch.dtype: - """Convert a string precision identifier to equivalent torch dtype. - - This utility method handles various string representations of PyTorch - data types, including common abbreviations and mixed precision formats. - - Args: - dtype (str): String representation of dtype (e.g., "float16", "fp16", - "bf16-mixed"). - - Returns: - torch.dtype: Corresponding PyTorch dtype (defaults to float32 if unknown). - - Supported formats: - - float16/fp16/16/16-mixed → torch.float16 - - bfloat16/bf16-mixed → torch.bfloat16 - - Others → torch.float32 (default) - - Example: - .. code-block:: python - - dtype = bridge.dtype_from_str("fp16") - print(dtype) # torch.float16 - - dtype = bridge.dtype_from_str("bf16-mixed") - print(dtype) # torch.bfloat16 - """ - assert isinstance(dtype, str) - if dtype in ["float16", "fp16", "16", "16-mixed"]: - return torch.float16 - elif dtype in ["bfloat16", "bf16-mixed"]: - return torch.bfloat16 - else: - return torch.float32 - - def make_vocab_size_divisible_by(self, vocab_size: int) -> int: - """Calculate an appropriate divisor for vocabulary size padding. - - Megatron requires vocabulary sizes to be divisible by certain values for - efficient tensor parallelism. 
This method finds the largest power of 2 - (up to 128) that evenly divides the vocabulary size. - - Args: - vocab_size (int): Original vocabulary size from the model. - - Returns: - int: Largest power of 2 (≤ 128) that divides vocab_size. - - Example: - .. code-block:: python - - # For vocab_size=50257 (GPT-2) - divisor = bridge.make_vocab_size_divisible_by(50257) - print(divisor) # 1 (50257 is prime) - - # For vocab_size=32000 (Llama) - divisor = bridge.make_vocab_size_divisible_by(32000) - print(divisor) # 128 - - Note: - The returned value is used by Megatron to potentially pad the - vocabulary to ensure efficient parallelization. - """ - base = 128 - while vocab_size % base != 0: - base //= 2 - return base - - # def _get_provider_from_model(self, model: MegatronModule) -> ModelProviderTarget: - # """Extract provider/config from model.""" - # model = unwrap_model(model) - # return model.config - - def _unwrap_name(self, name: str) -> str: - """Unwrap name from DDP or other wrappers. - - Args: - name: Parameter name that may have 'module.' prefixes - - Returns: - Unwrapped parameter name with 'module.' prefixes removed - - Example: - 'module.module.decoder.weight' -> 'decoder.weight' - """ - if not isinstance(name, str): - raise ValueError(f"name must be a string, got {type(name)}") - - while name.startswith("module."): - name = name[len("module.") :] - return name - - def _broadcast_shared_embeddings( - self, megatron_model: Union[MegatronModel, List[MegatronModel]] - ) -> None: - """Broadcast shared embeddings and output weights across embedding group. - - When embeddings and output weights are shared and pipeline parallelism is enabled, - this method ensures all ranks in the embedding group have the same weights by - broadcasting from rank 0. - - Args: - megatron_model: Megatron model instance or list of model instances. 
- """ - unwrapped_model = unwrap_model(megatron_model)[0] - # hack for vlm to work properly - if ( - hasattr(unwrapped_model, "language_model") - and unwrapped_model.language_model is not None - ): - unwrapped_model = unwrapped_model.language_model - model_config = unwrapped_model.config - if ( - not model_config.untie_embeddings_and_output_weights - and model_config.pipeline_model_parallel_size > 1 - ): - # Broadcast embeddings and output weights from rank 0 to embedding group - embd_group = parallel_state.get_embedding_group() - embd_group_ranks = torch.distributed.get_process_group_ranks(embd_group) - if embd_group is not None and torch.distributed.get_rank() in embd_group_ranks: - # Get embeddings and output weights from rank 0 - if hasattr(unwrapped_model, "embedding") and hasattr( - unwrapped_model.embedding, "word_embeddings" - ): - embd_weights = unwrapped_model.embedding.word_embeddings.weight.data - else: - assert hasattr(unwrapped_model, "output_layer"), "Output layer not found" - embd_weights = torch.empty_like(unwrapped_model.output_layer.weight.data) - torch.distributed.broadcast(embd_weights, src=embd_group_ranks[0], group=embd_group) - if hasattr(unwrapped_model, "output_layer"): - unwrapped_model.output_layer.weight.data.copy_(embd_weights) - def build_conversion_tasks( self, hf_pretrained: HFPreTrained, megatron_model: List[MegatronModel] ) -> List[None | WeightConversionTask]: @@ -890,82 +288,11 @@ def build_conversion_tasks( return tasks - @classmethod - def register_bridge( - cls, *, source: Type[PreTrainedModel] | str, target: Type[MegatronModel] - ) -> Callable[[_BridgeImplClass], _BridgeImplClass]: - """Class decorator for registering bridge implementations. - - This decorator registers a MegatronModelBridge subclass with the dispatch - system, enabling automatic routing of conversions based on the source - HuggingFace model type and target Megatron model type. - - Args: - source (Type[PreTrainedModel] | str): HuggingFace PreTrainedModel class - (e.g., LlamaForCausalLM) or the class name as a string. Using a - string allows registering bridges for architectures that are only - available via auto_map. - target (Type[MegatronModel]): Megatron model class (e.g., GPTModel). - - Returns: - Callable[[_BridgeImplClass], _BridgeImplClass]: Decorator function - that registers the bridge implementation. - - Example: - .. code-block:: python - - @MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel) - class MegatronCausalLlamaBridge(MegatronModelBridge): - def provider_bridge(self, hf_pretrained): - # Implementation - pass - - def mapping_registry(self): - # Implementation - pass - - String-based registration is also supported: - - .. code-block:: python - - @MegatronModelBridge.register_bridge(source="DeepseekV3ForCausalLM", target=GPTModel) - class MegatronDeepseekV3Bridge(MegatronModelBridge): - ... - - Note: - The decorated class is registered with multiple dispatchers to handle - different conversion scenarios. The registration is automatic when the - class is defined. - """ - - return create_bridge_decorator(source=source, target=target) - - -def is_tensor_parallel(param) -> bool: - """Check if a parameter is tensor parallel distributed.""" - return hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel - - -# Core dispatch functions @dispatch def get_model_bridge(hf_architecture) -> "MegatronModelBridge": """Get the appropriate model bridge for a given HuggingFace architecture.""" ... 
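# A minimal usage sketch of the dispatch entry point above, assuming `hf_architecture`
# is the HuggingFace model class a bridge was registered for and that `hf_pretrained`
# and `megatron_model` have already been constructed elsewhere.
def _example_stream_to_hf(hf_architecture, hf_pretrained, megatron_model):
    # Resolve the bridge implementation registered for this architecture via @dispatch.
    bridge = get_model_bridge(hf_architecture)
    # The generator yields fully gathered HF-format tensors on every rank.
    for name, weight in bridge.stream_weights_megatron_to_hf(
        megatron_model, hf_pretrained, cpu=True, show_progress=True
    ):
        print_rank_0(f"{name}: {tuple(weight.shape)}")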
- -@dispatch -def stream_weights_megatron_to_hf( - dispatch_instance: MegatronModel, - megatron_model: Union[MegatronModel, List[MegatronModel]], - hf_pretrained: HFPreTrained, - cpu: bool = True, - show_progress: bool = True, - conversion_tasks: Optional[List[WeightConversionTask]] = None, -) -> Iterable[HFWeightTuple]: - """Bridge Megatron model state to HuggingFace format.""" - ... - - def register_bridge_implementation( *, source: Type["PreTrainedModel"] | str, @@ -998,12 +325,12 @@ def _megatron_to_hf_registered_impl( conversion_tasks: Optional[List[WeightConversionTask]] = None, ) -> Iterable[HFWeightTuple]: bridge = bridge_class() + + # allow bridge to access model config + bridge.hf_config = hf_pretrained.config + return bridge.stream_weights_megatron_to_hf( - megatron_model, - hf_pretrained, - cpu=cpu, - show_progress=show_progress, - conversion_tasks=conversion_tasks, + megatron_model, hf_pretrained, cpu=cpu, show_progress=show_progress, conversion_tasks=conversion_tasks ) # Set meaningful names for debugging diff --git a/flagscale/train/megatron/nemo_bridge/models/conversion/param_mapping.py b/flagscale/train/megatron/nemo_bridge/models/conversion/param_mapping.py index b5bb8f2a2b..e402a0d958 100644 --- a/flagscale/train/megatron/nemo_bridge/models/conversion/param_mapping.py +++ b/flagscale/train/megatron/nemo_bridge/models/conversion/param_mapping.py @@ -1,35 +1,22 @@ # Copyright (c) 2025, BAAI. All rights reserved. -# -# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import json -import re - -from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union import torch -import torch.distributed import torch.nn as nn - -from megatron.core import mpu -from megatron.core.fp8_utils import FP8_TENSOR_CLASS, HAVE_TE_FP8_TENSOR_CLASS -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import get_pg_rank, get_pg_size - -from megatron.nemo_bridge.models.conversion.utils import ( - get_module_and_param_from_name, - remove_non_pickleables, +from megatron.bridge.models.conversion.utils import get_module_and_param_from_name +from megatron.bridge.models.conversion.param_mapping import ColumnParallelMapping as OriginalColumnParallelMapping +from megatron.bridge.models.conversion.param_mapping import AutoMapping as OriginalAutoMapping +from megatron.bridge.models.conversion.param_mapping import QKVMapping as OriginalQKVMapping +from megatron.bridge.models.conversion.param_mapping import ( + MegatronParamMapping, + RowParallelMapping, + ReplicatedMapping, + GatedMLPMapping, ) -WeightType = TypeVar("WeightType", torch.Tensor, Dict[str, torch.Tensor]) import logging - logger = logging.getLogger(__name__) - def col_padding_size(hf_weight: torch.Tensor, mcore_weight: torch.Tensor, tp_size: int): hf_size = hf_weight.shape[0] mcore_size = mcore_weight.shape[0] * tp_size @@ -54,735 +41,13 @@ def col_padding_size(hf_weight: torch.Tensor, mcore_weight: torch.Tensor, tp_siz return full_word -class MegatronParamMapping(ABC, Generic[WeightType]): - """ - Abstract base class for weight conversion between Megatron and external formats. - - This class provides the foundation for all weight mappings, handling the complex - conversions between Megatron-Core's distributed tensor formats and standard - (typically HuggingFace) formats. 
Each concrete mapping implements specific - transformation logic while inheriting common parallel communication patterns. - - Key responsibilities: - - Format transformation (e.g., QKV merging/splitting, gated MLP handling) - - Tensor parallel (TP) distribution and gathering across GPUs - - Pipeline parallel (PP) broadcasting between pipeline stages - - Wildcard pattern resolution for layer-wise mappings - - The mapping abstraction ensures that higher-level code doesn't need to know - about the parallel topology or format differences - it just requests a - conversion and the mapping handles all the complexity. - - Public helper methods for subclasses: - - broadcast_from_pp_rank: Broadcast tensors across pipeline stages - - broadcast_obj_from_pp_rank: Broadcast Python objects across PP ranks - - broadcast_tensor_to_tp_ranks: Broadcast within TP group - - scatter_to_tp_ranks: Distribute tensor shards to TP ranks - - gather_from_tp_ranks: Collect tensor shards from TP ranks - - Example: - .. code-block:: python - - class MyCustomMapping(MegatronParamMapping[torch.Tensor]): - def hf_to_megatron(self, hf_weights, megatron_module): - # Custom transformation logic - transformed = hf_weights.t() # Example: transpose - # Use helpers for distribution - return self.scatter_to_tp_ranks(...) - - def megatron_to_hf(self, megatron_weights, megatron_module): - # Broadcast from owning PP rank - weight = self.broadcast_from_pp_rank(megatron_weights) - # Gather from TP ranks and transform - gathered = self.gather_from_tp_ranks(weight) - return {"custom_weight": gathered[0].t()} - """ - - def __init__(self, megatron_param: str, hf_param: Union[str, Dict[str, str]]): - """Initialize the weight mapping. - - Args: - megatron_param (str): Megatron parameter name pattern (supports * - wildcards). - hf_param (Union[str, Dict[str, str]]): External format name pattern(s). 
- """ - self.megatron_param = megatron_param - self.hf_param = hf_param - self._validate_patterns() - - # Cache for metadata and tensor_spec_output - self._broadcast_obj_cache = {} - self._tensor_spec_output_cache = {} - - if mpu.is_initialized(): - self.pp_group = mpu.get_pipeline_model_parallel_group() - self.ep_group = mpu.get_expert_model_parallel_group() - self._tp_group = mpu.get_tensor_model_parallel_group() - self._etp_group = mpu.get_expert_tensor_parallel_group() - else: - self.pp_group = None - self.ep_group = None - self._tp_group = None - self._etp_group = None - - @property - def tp_group(self): - """Get the tensor model parallel group.""" - if self.is_expert: - return self._etp_group - return self._tp_group - - @property - def tp_rank(self) -> int: - """Get the tensor model parallel rank.""" - return get_pg_rank(self.tp_group) - - @property - def tp_size(self) -> int: - """Get the tensor model parallel size.""" - return get_pg_size(self.tp_group) - - @property - def pp_rank(self) -> int: - """Get the pipeline model parallel rank.""" - return get_pg_rank(self.pp_group) - - @property - def pp_size(self) -> int: - """Get the pipeline model parallel size.""" - return get_pg_size(self.pp_group) - - @property - def ep_rank(self) -> int: - """Get the expert model parallel rank.""" - return get_pg_rank(self.ep_group) - - @property - def ep_size(self) -> int: - """Get the expert model parallel size.""" - return get_pg_size(self.ep_group) - - @property - def etp_rank(self) -> int: - """Get the expert tensor parallel rank.""" - return get_pg_rank(self.etp_group) - - @property - def etp_size(self) -> int: - """Get the expert tensor parallel size.""" - return get_pg_size(self.etp_group) - - @property - def is_expert(self) -> bool: - """Check if this mapping is for an expert parameter.""" - return ".mlp.experts.linear_fc" in self.megatron_param - - def _resolve_names(self, captures: Tuple[str, ...]) -> Tuple[str, Union[str, Dict[str, str]]]: - """Resolve wildcard patterns with captured values. - - Handles both ** (any characters) and * (digits) wildcards in order. - ** patterns are processed before * patterns to avoid conflicts. 
- """ - resolved_megatron_param = self.megatron_param - capture_index = 0 - - # First pass: resolve ** wildcards - while "**" in resolved_megatron_param and capture_index < len(captures): - resolved_megatron_param = resolved_megatron_param.replace( - "**", captures[capture_index], 1 - ) - capture_index += 1 - - # Second pass: resolve * wildcards - while "*" in resolved_megatron_param and capture_index < len(captures): - resolved_megatron_param = resolved_megatron_param.replace( - "*", captures[capture_index], 1 - ) - capture_index += 1 - - if isinstance(self.hf_param, str): - resolved_hf_param = self.hf_param - capture_index = 0 - - # First pass: resolve ** wildcards - while "**" in resolved_hf_param and capture_index < len(captures): - resolved_hf_param = resolved_hf_param.replace("**", captures[capture_index], 1) - capture_index += 1 - - # Second pass: resolve * wildcards - while "*" in resolved_hf_param and capture_index < len(captures): - resolved_hf_param = resolved_hf_param.replace("*", captures[capture_index], 1) - capture_index += 1 - else: - resolved_hf_param = {} - for k, v in self.hf_param.items(): - resolved_v = v - capture_index = 0 - - # First pass: resolve ** wildcards - while "**" in resolved_v and capture_index < len(captures): - resolved_v = resolved_v.replace("**", captures[capture_index], 1) - capture_index += 1 - - # Second pass: resolve * wildcards - while "*" in resolved_v and capture_index < len(captures): - resolved_v = resolved_v.replace("*", captures[capture_index], 1) - capture_index += 1 - - resolved_hf_param[k] = resolved_v - - return resolved_megatron_param, resolved_hf_param - - def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": - """Create a new mapping with resolved wildcards. - - This default implementation works for mappings with a - (megatron_param, hf_param) constructor. - - Args: - captures (Tuple[str, ...]): Captured wildcard values. - - Returns: - MegatronParamMapping: A new mapping instance with resolved names. - """ - resolved_megatron_param, resolved_hf_param = self._resolve_names(captures) - return type(self)(resolved_megatron_param, resolved_hf_param) - - @abstractmethod - def hf_to_megatron(self, hf_weights: WeightType, megatron_module: nn.Module) -> torch.Tensor: - """Convert hf_weights TO Megatron format. - - This method handles: - 1. Format transformation (if needed) - 2. Tensor parallel distribution (if self.tp_size > 1) - - Args: - hf_weights (WeightType): Source hf_weights in external format. - megatron_module (nn.Module): Target Megatron module (for config - access). - - Returns: - torch.Tensor: Weight tensor ready for the current TP rank. - """ - ... - - @abstractmethod - def megatron_to_hf( - self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] - ) -> Dict[str, torch.Tensor]: - """Convert weights FROM Megatron format. - - This method handles: - 1. Pipeline parallel broadcasting (if weight is on different PP rank) - 2. Tensor parallel gathering (if needed) - 3. Format transformation - - Args: - megatron_weights (Optional[torch.Tensor]): Weight tensor from current - rank (None if on different PP rank). - megatron_module (Optional[nn.Module]): Module for config access - (None if on different PP rank). - - Returns: - Dict[str, torch.Tensor]: Converted weights (empty dict if not on - TP rank 0). - """ - ... 
- - def broadcast_from_pp_rank( - self, tensor: Optional[torch.Tensor], cache_key: Optional[str] = None - ) -> Optional[torch.Tensor]: - """Broadcast a tensor from the pipeline-parallel rank that owns it. - - Broadcasts to **all** PP ranks. This mirrors the behaviour of - `broadcast_from_megatron_pp` in the original MMapping implementation and - additionally keeps the tensor-parallel metadata (`tensor_model_parallel`, - `partition_dim`) consistent on every rank. - - Args: - tensor (Optional[torch.Tensor]): The local tensor if the current PP - rank owns it. ``None`` otherwise. - - Returns: - Optional[torch.Tensor]: The broadcasted tensor on every PP rank, or - ``None`` if *no* PP rank owned the tensor (which indicates a bug - in the calling code). - """ - - # Fast-path when we are not using pipeline parallelism. - if self.pp_size == 1: - return tensor - - # ------------------------------------------------------------------ - # 1. Gather (shape, dtype, tensor_parallel flag, partition_dim) from - # every PP rank so that we can find the source rank. - # ------------------------------------------------------------------ - if cache_key is not None and cache_key in self._tensor_spec_output_cache: - tensor_spec_output = self._tensor_spec_output_cache[cache_key] - else: - if tensor is not None: - shape = tensor.shape - dtype = tensor.dtype - tensor_parallel = getattr(tensor, "tensor_model_parallel", None) - partition_dim = getattr(tensor, "partition_dim", None) - tensor_spec = (shape, dtype, tensor_parallel, partition_dim) - else: - tensor_spec = None - - tensor_spec_output: list[Optional[tuple]] = [None] * self.pp_size - torch.distributed.all_gather_object( - tensor_spec_output, tensor_spec, group=self.pp_group - ) - self._tensor_spec_output_cache[cache_key] = tensor_spec_output - - # ------------------------------------------------------------------ - # 2. Identify the owning rank (the only rank with a non-None spec). - # ------------------------------------------------------------------ - target_tensor_spec = None - src_rank = None # Rank *inside* the PP group. - for rank, spec in enumerate(tensor_spec_output): - if spec is not None: - if target_tensor_spec is not None: - raise ValueError( - f"Tensor exists on more than one PP rank. Found on ranks {src_rank} and {rank}." - ) - target_tensor_spec = spec - src_rank = rank - - if target_tensor_spec is None: - # No rank had the tensor – this is an error in the caller. - raise ValueError("Object must exist on at least one PP rank") - - # ------------------------------------------------------------------ - # 3. Ensure every rank has an allocated tensor with the right shape - # and dtype before the broadcast. - # ------------------------------------------------------------------ - if tensor is None: - shape, dtype, tensor_parallel, partition_dim = target_tensor_spec - # Use CPU by default, unless CUDA is available - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - tensor = torch.empty(shape, dtype=dtype, device=device) - if tensor_parallel is not None: - tensor.tensor_model_parallel = tensor_parallel - if partition_dim is not None: - tensor.partition_dim = partition_dim - - # ------------------------------------------------------------------ - # 4. Broadcast from the source PP rank to all other PP ranks. 
- # ------------------------------------------------------------------ - global_src = torch.distributed.get_global_rank(group=self.pp_group, group_rank=src_rank) - torch.distributed.broadcast(tensor, src=global_src, group=self.pp_group) - - return tensor - - def broadcast_obj_from_pp_rank( - self, obj: Optional[Any], cache_key: Optional[str] = None - ) -> Any: - """Broadcast any Python object from the PP rank that owns it. - - This method is useful for broadcasting configuration objects or - other metadata across pipeline parallel ranks. Results are cached - after the first call to avoid redundant broadcasts. - - Args: - obj (Optional[Any]): Object to broadcast (None on non-owning ranks). - cache_key (Optional[str]): Optional cache key. If not provided, - no caching will be performed. - - Returns: - Any: Broadcasted object on all ranks. - - Raises: - ValueError: If object exists on multiple ranks or no ranks. - """ - if self.pp_size == 1: - return obj - - # Check if we already have a cached result (only if cache_key is provided) - if cache_key is not None and cache_key in self._broadcast_obj_cache: - return self._broadcast_obj_cache[cache_key] - - # ------------------------------------------------------------------ - # 1. Gather presence flags from all PP ranks to find the source rank - # ------------------------------------------------------------------ - has_obj = obj is not None - obj_flags = [None] * self.pp_size - torch.distributed.all_gather_object(obj_flags, has_obj, group=self.pp_group) - - # ------------------------------------------------------------------ - # 2. Identify the owning rank (the only rank with True flag) - # ------------------------------------------------------------------ - src_rank = None # Rank *inside* the PP group - for rank, flag in enumerate(obj_flags): - if flag: - src_rank = rank - - if src_rank is None: - raise ValueError("Object must exist on at least one PP rank") - - # ------------------------------------------------------------------ - # 3. Broadcast the object from the source rank to all ranks - # ------------------------------------------------------------------ - if src_rank is None: - raise ValueError("Could not determine source rank") - - # Use broadcast_object_list which is more robust than all_gather_object - obj_list = [obj] - pp_ranks = torch.distributed.get_process_group_ranks(self.pp_group) - global_src = pp_ranks[src_rank] - torch.distributed.broadcast_object_list(obj_list, src=global_src, group=self.pp_group) - - result = obj_list[0] - - # Cache the result for future calls (only if cache_key is provided) - if cache_key is not None: - self._broadcast_obj_cache[cache_key] = result - - return result - - def clear_broadcast_cache(self): - """Clear the broadcast object cache. - - This can be useful for testing or if the objects being broadcast - might change during the lifetime of the mapping. - """ - self._broadcast_obj_cache.clear() - - def clear_tensor_spec_output_cache(self): - """Clear the tensor spec output cache. - - This can be useful for testing or if the tensor spec output - might change during the lifetime of the mapping. - """ - self._tensor_spec_output_cache.clear() - - def broadcast_tensor_to_tp_ranks(self, tensor: torch.Tensor, src_rank: int = 0) -> torch.Tensor: - """Broadcast a tensor to all TP ranks. - - Args: - tensor (torch.Tensor): The tensor to broadcast. - src_rank (int, optional): The source rank within the TP group. - Defaults to 0. - - Returns: - torch.Tensor: The broadcasted tensor. 
- """ - if self.tp_size == 1: - return tensor - - global_src = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank) - torch.distributed.broadcast(tensor, src=global_src, group=self.tp_group) - return tensor - - def scatter_to_tp_ranks( - self, - splits: Optional[List[torch.Tensor]], - output_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - src_rank: int = 0, - ) -> torch.Tensor: - """Scatter tensor splits to TP ranks. - - Args: - splits (Optional[List[torch.Tensor]]): A list of tensor shards to - scatter. Only rank `src_rank` needs this. - output_shape (torch.Size): The shape of the output tensor on each rank. - dtype (torch.dtype): The data type of the output tensor. - device (torch.device): The device for the output tensor. - src_rank (int, optional): The source rank for the scatter operation. - Defaults to 0. - - Returns: - torch.Tensor: The scattered tensor shard on the current rank. - """ - if self.tp_size == 1: - return splits[0].to(device=device) if splits else None - - output = torch.empty(output_shape, dtype=dtype, device=device) - global_src = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank) - - scatter_list = None - if self.tp_rank == src_rank and splits: - scatter_list = [s.to(device=device) for s in splits] - - torch.distributed.scatter(output, scatter_list, src=global_src, group=self.tp_group) - return output - - def gather_from_tp_ranks(self, tensor: torch.Tensor) -> List[torch.Tensor]: - """Gather tensors from all TP ranks. - - Args: - tensor (torch.Tensor): The tensor shard to be gathered from the - current rank. - - Returns: - List[torch.Tensor]: A list of tensor shards from all TP ranks. - """ - if self.tp_size == 1: - return [tensor] - - gathered = [torch.empty_like(tensor) for _ in range(self.tp_size)] - torch.distributed.all_gather(gathered, tensor, group=self.tp_group) - return gathered - - def _count_wildcard_groups(self, pattern: str) -> int: - """Count the number of wildcard capture groups in a pattern. - - Args: - pattern: Pattern string with * and ** wildcards - - Returns: - Number of capture groups that will be generated - - Note: - ** counts as 1 group, * counts as 1 group - ** must be counted before * to avoid double-counting - """ - count = 0 - remaining = pattern - - # Count ** patterns first - while "**" in remaining: - count += 1 - remaining = remaining.replace("**", "", 1) - - # Count remaining * patterns - count += remaining.count("*") - - return count - - def _validate_patterns(self): - """Validate wildcard consistency between patterns.""" - megatron_param_wildcards = self._count_wildcard_groups(self.megatron_param) - if isinstance(self.hf_param, str): - hf_param_wildcards = self._count_wildcard_groups(self.hf_param) - if megatron_param_wildcards != hf_param_wildcards: - raise ValueError( - f"Wildcard count mismatch: megatron_param='{self.megatron_param}' has " - f"{megatron_param_wildcards} wildcards, hf_param='{self.hf_param}' has {hf_param_wildcards}" - ) - else: - for key, pattern in self.hf_param.items(): - hf_param_wildcards = self._count_wildcard_groups(pattern) - if megatron_param_wildcards != hf_param_wildcards: - raise ValueError( - f"Wildcard count mismatch: megatron_param='{self.megatron_param}' has " - f"{megatron_param_wildcards} wildcards, hf_param['{key}']='{pattern}' has {hf_param_wildcards}" - ) - - def _normalize_expert_param_name(self, param_name: str) -> str: - """Normalize expert parameter name by replacing trailing numbers with 0. - e.g. 
experts.weight15 -> experts.weight0, experts.bias15 -> experts.bias0 - - Args: - param_name (str): Parameter name that may end with a number. - - Returns: - str: Parameter name with trailing number replaced by 0. - """ - # Use regex to replace any trailing number with 0 - return re.sub(r"\d+$", "0", param_name) - - def _get_config(self, module: nn.Module) -> Any: - """Extract configuration from module hierarchy.""" - current = module - while current is not None: - if hasattr(current, "config"): - return current.config - # Try parent module - if hasattr(current, "_parent"): - current = current._parent - else: - # Walk up the module tree - for parent_module in module.modules(): - for child_name, child_module in parent_module.named_children(): - if child_module is current: - current = parent_module - break - else: - continue - break - else: - current = None - - raise ValueError( - f"Could not find config in module hierarchy for {module.__class__.__name__}. " - f"Ensure the module or its parent has a 'config' attribute." - ) - - def gather_from_ep_ranks( - self, - megatron_weights: Optional[torch.Tensor], - megatron_module: Optional[MegatronModule], - hf_param_name: Optional[str], - ) -> Dict[str, torch.Tensor]: - """Handle expert parallel weight gathering for MoE models. - - This method gathers expert weights across expert-parallel (EP) ranks and - returns a mapping from HF parameter names to the corresponding tensors - from each EP rank. Call this only for confirmed expert parameters - (self.is_expert is True), typically after TP gathering/concatenation in - the export path (Megatron → HF). - - Behavior and notation: - - Let E be the total number of experts (e.g., config.num_moe_experts) and - S be the expert-parallel size (ep_size). We assume E % S == 0. - - Each EP rank owns E/S experts. For a given parameter name, we infer a - local expert index L (0 ≤ L < E/S) on the current EP rank from the - global expert id embedded in the name (works for both .weight and .bias). - - The set of global expert ids that correspond to this local index L - across all EP ranks is: {L + k * (E/S) | k ∈ [0, S-1]}. - - Communication and outputs: - - We perform an all_gather over the EP group to collect the tensor from - every EP rank into a list ordered by EP rank id. - - For each EP rank k, we construct the HF parameter name by replacing the - expert id in `hf_param_name` with (L + k * (E/S)), preserving the rest - of the path, and map that name to the gathered tensor from rank k. - - Example: - - E = 8, S = 2 → E/S = 4. Experts are distributed as: - Rank 0: [0, 1, 2, 3], Rank 1: [4, 5, 6, 7]. - If the local index L = 0 (derived from the param name), this returns: - {"...experts.0.weight": tensor_from_rank0, "...experts.4.weight": tensor_from_rank1} - - Args: - megatron_weights (Optional[torch.Tensor]): The local expert weight tensor - (after any TP handling) on this EP rank. - megatron_module (Optional[MegatronModule]): The Megatron module containing - configuration (used to determine E and E/S). Can be None on non-owning PP - ranks; values will be broadcast across PP. - hf_param_name (Optional[str]): HF parameter name template for the current - (local) expert on this rank. The expert id within this string is replaced - with the appropriate global expert ids for each EP rank. - - Returns: - Dict[str, torch.Tensor]: Mapping from HF parameter names (one per EP rank) - to the corresponding expert tensors gathered from each EP rank. 
- """ - if megatron_module is None: - num_experts_per_rank = self.broadcast_obj_from_pp_rank(None, "num_experts_per_rank") - else: - model_config = self._get_config(megatron_module) - num_experts = model_config.num_moe_experts - num_experts_per_rank = num_experts // self.ep_size - num_experts_per_rank = self.broadcast_obj_from_pp_rank( - num_experts_per_rank, "num_experts_per_rank" - ) - - # Extract local expert number from parameter name - # Handle both .weight and .bias suffixes - local_expert_number = None - for key in (".weight", ".bias"): - if key in self.megatron_param: - global_expert_number = int(self.megatron_param.split(key)[-1]) - local_expert_number = global_expert_number % num_experts_per_rank - - # Compute global expert numbers for all EP ranks - # use regex to replace the local expert number with the global expert number - gathered_expert_param_names = [ - re.sub( - r"experts\.(\d+)", - f"experts.{int(local_expert_number) + num_experts_per_rank * i}", - str(hf_param_name), - ) - for i in range(self.ep_size) - ] - assert ( - hf_param_name in gathered_expert_param_names - ), f"hf_param_name {hf_param_name} not in gathered_expert_param_names {gathered_expert_param_names}" - - # Gather weights from all EP ranks - gathered_weights = [torch.empty_like(megatron_weights) for _ in range(self.ep_size)] - torch.distributed.all_gather(gathered_weights, megatron_weights, group=self.ep_group) - - # Return dictionary mapping HF parameter names to weights - return { - param_name: gathered_weights[i] - for i, param_name in enumerate(gathered_expert_param_names) - } - - def maybe_dequantize(self, tensor: torch.Tensor) -> torch.Tensor: - """Dequantize FP8 tensor if needed.""" - if HAVE_TE_FP8_TENSOR_CLASS and isinstance(tensor, FP8_TENSOR_CLASS): - return tensor.dequantize(dtype=tensor.dtype) - return tensor - - -class DirectMapping(MegatronParamMapping[torch.Tensor]): - """Direct 1:1 weight mapping with no transformation or tensor parallelism.""" - - def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: - """Direct copy - no transformation or distribution.""" - return hf_weights - - def megatron_to_hf( - self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] - ) -> Dict[str, torch.Tensor]: - """Direct copy with PP broadcast.""" - # Handle cross-PP broadcast - megatron_weights = self.broadcast_from_pp_rank( - megatron_weights, cache_key=str(self.hf_param) - ) - - if megatron_weights is None: - return {} - - # Dequantize if needed - megatron_weights = self.maybe_dequantize(megatron_weights) - - return {str(self.hf_param): megatron_weights} - - -class ColumnParallelMapping(MegatronParamMapping[torch.Tensor]): +class ColumnParallelMapping(OriginalColumnParallelMapping): """ Mapping for column-parallel linear and embedding weights. - Column-parallel layers in Megatron split the output dimension across tensor - parallel ranks. This is used for layers where each rank computes a portion - of the output features independently, such as: - - Embedding layers (split vocabulary) - - Linear layers producing hidden states (e.g., QKV projections, MLP up projections) - - The weight matrix is partitioned along dimension 0 (rows), so each TP rank - holds a subset of output features while maintaining all input features. - - **Sharding pattern** - - Original weight: `[output_features, input_features]` - - Rank 0: `[output_features/tp_size, input_features]` - - Rank 1: `[output_features/tp_size, input_features]` - - ... 
- - **Forward path (HuggingFace → Megatron)** - 1. Validate divisibility: output dimension must be divisible by tp_size - 2. Split: Chunk tensor along dim 0 into tp_size equal parts - 3. Scatter: Distribute chunks to respective TP ranks - - **Reverse path (Megatron → HuggingFace)** - 1. Broadcast: Ensure all PP ranks have the tensor - 2. Gather: Collect chunks from all TP ranks - 3. Concatenate: Reassemble along dim 0 on rank 0 - - Example: - .. code-block:: python - - # For a weight of shape [4096, 1024] with tp_size=4: - # Each rank gets [1024, 1024] after column-parallel split - mapping = ColumnParallelMapping("linear.weight", "transformer.linear.weight") - megatron_weights = mapping.hf_to_megatron(hf_weight, megatron_module) - # megatron_weights.shape = [1024, 1024] on each rank - - Note: - This mapping also handles bias terms, which are 1D tensors split - along their only dimension following the same pattern. """ - def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: """Split weight along dim 0 and distribute to TP ranks.""" - # if self.tp_size == 1: - # return hf_weights - # Some parameters are named with global expert number, e.g. experts.weight15, # normalize it to experts.weight0, note we are only use the shape, dtype, device info, # not the actual value, so it is safe to do this. @@ -798,10 +63,6 @@ def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) - if hf_weights is None: raise ValueError("hf_weights should not be None on rank 0") - # For MCore MambaMixer, A_log is initialized in FP32 but cast to BF16 when - # saving ckpts, including the ckpt uploaded to HF. Without this cast, - # self.scatter_to_tp_ranks will try to scatter the HF A_log weights in BF16 to - # the Megatron tensor which is in FP32. This will error. So we cast before the scatter. if hf_weights.dtype != target_param.dtype: logger.warning( f"WARNING: Dtype mismatch between HuggingFace weights and Megatron module. 
" @@ -810,14 +71,6 @@ def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) - ) hf_weights = hf_weights.to(target_param.dtype) - # For bias (1D), we still split along dim 0 - # For weight (2D), we split along dim 0 (output dimension) - # full_size = hf_weights.shape[0] - # if full_size % self.tp_size != 0: - # raise ValueError( - # f"Cannot evenly split dimension 0 size {full_size} across {self.tp_size} TP ranks" - # ) - # splits = torch.chunk(hf_weights, self.tp_size, dim=0) full_weight = col_padding_size(hf_weights, target_param, self.tp_size) full_size = full_weight.shape[0] if full_size % self.tp_size != 0: @@ -833,274 +86,7 @@ def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) - splits, target_param.shape, target_param.dtype, target_param.device ) - def megatron_to_hf( - self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] - ) -> Dict[str, torch.Tensor]: - """Gather from all TP ranks and concatenate.""" - # Handle cross-PP broadcast - megatron_weights = self.broadcast_from_pp_rank( - megatron_weights, cache_key=str(self.hf_param) - ) - - if megatron_weights is None: - return {} - - # Dequantize if needed - megatron_weights = self.maybe_dequantize(megatron_weights) - - if self.tp_size == 1: - full_weights = megatron_weights - else: - # Gather from all TP ranks - gathered = self.gather_from_tp_ranks(megatron_weights) - full_weights = torch.cat(gathered, dim=0) - - if self.is_expert: - return self.gather_from_ep_ranks(full_weights, megatron_module, self.hf_param) - - return {str(self.hf_param): full_weights} - - -class RowParallelMapping(MegatronParamMapping[torch.Tensor]): - """Mapping for **row-parallel** linear weights. - - Megatron shards row-parallel tensors along **dimension 1** (the *input* - dimension of a linear layer). - - **Forward path (external → Megatron)** - 1. Rank 0 validates that the *second* dimension is divisible by `tp_size`. - 2. Rank 0 splits the tensor with `torch.chunk(..., dim=1)` producing - `tp_size` equally-sized shards. - 3. The shards are **scattered** so that every TP rank receives exactly one - shard matching the shape of its local Megatron parameter. - - **Reverse path (Megatron → external)** - 1. The local Megatron parameter (which may live on any PP rank) is - broadcast to all PP ranks so that the gather step can be collective. - 2. All TP ranks **gather** their shard. - 3. Rank 0 concatenates the gathered list along dim 1 to reconstruct the - original unsharded weight and emits it under the external (HF) name. - """ - - def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: - """Split weight along dim 1 and distribute to TP ranks.""" - if self.tp_size == 1: - return hf_weights - - # Some parameters are named with global expert number, e.g. experts.weight15, - # normalize it to experts.weight0, note we are only use the shape, dtype, device info, - # not the actual value, so it is safe to do this. 
- normalized_param = self._normalize_expert_param_name(self.megatron_param) - _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) - - # On rank 0, check for divisibility and split - if self.tp_rank == 0: - if hf_weights is None: - raise ValueError("hf_weights should not be None on rank 0") - - # For bias (1D), we still split along dim 0 - # For weight (2D), we split along dim 1 - if hf_weights.ndim == 1: - full_size = hf_weights.shape[0] - if full_size % self.tp_size != 0: - raise ValueError( - f"Cannot evenly split dimension 0 size {full_size} across {self.tp_size} TP ranks" - ) - splits = torch.chunk(hf_weights, self.tp_size, dim=0) - else: - assert hf_weights.ndim == 2 - full_size = hf_weights.shape[1] - if full_size % self.tp_size != 0: - raise ValueError( - f"Cannot evenly split dimension 0 size {full_size} across {self.tp_size} TP ranks" - ) - splits = torch.chunk(hf_weights, self.tp_size, dim=1) - - else: - splits = None - - # Scatter to all ranks. Each rank gets its sharded shape from its module. - return self.scatter_to_tp_ranks( - splits, target_param.shape, target_param.dtype, target_param.device - ) - - def megatron_to_hf( - self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] - ) -> Dict[str, torch.Tensor]: - """Gather from all TP ranks and concatenate.""" - # Handle cross-PP broadcast - megatron_weights = self.broadcast_from_pp_rank( - megatron_weights, cache_key=str(self.hf_param) - ) - - if megatron_weights is None: - return {} - - # Dequantize if needed - megatron_weights = self.maybe_dequantize(megatron_weights) - - if self.tp_size == 1: - full_weights = megatron_weights - else: - gathered = self.gather_from_tp_ranks(megatron_weights) - full_weights = torch.cat(gathered, dim=1) - - if self.is_expert: - return self.gather_from_ep_ranks(full_weights, megatron_module, self.hf_param) - - return {str(self.hf_param): full_weights} - - -class ReplicatedMapping(MegatronParamMapping[torch.Tensor]): - """Mapping for weights that are **fully replicated** across TP ranks. - - Examples: layer-norm scales, biases, router weights in MoE, etc. - - These tensors exist in exactly the same form on *every* TP rank, so the - mapping logic is trivial – but we still need to broadcast across TP ranks - during *load* (HF → Megatron) and ensure we do **not** emit duplicates - during *export* (Megatron → HF). - """ - - def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: - """Replicate weight to all TP ranks.""" - try: - target_device = megatron_module.weight.device - except AttributeError: - # the parameter may not be called "weight" - target_device = next(megatron_module.parameters()).device - hf_weights = hf_weights.to(device=target_device) - if self.tp_size == 1: - return hf_weights - - # TODO(yuya): router.weight is on device cpu, need to check. 
- if target_device.index != torch.cuda.current_device(): - hf_weights = hf_weights.to(torch.cuda.current_device()) - - # All ranks need the full weight - if self.tp_rank > 0: - # Create empty tensor of correct shape - hf_weights = torch.empty_like(hf_weights) - - # Broadcast from rank 0 to all TP ranks - return self.broadcast_tensor_to_tp_ranks(hf_weights, src_rank=0) - - def megatron_to_hf( - self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] - ) -> Dict[str, torch.Tensor]: - """Return weight only from rank 0 to avoid duplication.""" - # Handle cross-PP broadcast - megatron_weights = self.broadcast_from_pp_rank( - megatron_weights, cache_key=str(self.hf_param) - ) - - if megatron_weights is None: - return {} - - # Dequantize if needed - megatron_weights = self.maybe_dequantize(megatron_weights) - - if self.is_expert: - return self.gather_from_ep_ranks(megatron_weights, megatron_module, self.hf_param) - - return {str(self.hf_param): megatron_weights} - - -class AutoMapping(MegatronParamMapping[torch.Tensor]): - """ - Smart mapping that automatically detects and applies the correct parallelism strategy. - - This mapping eliminates the need to manually specify whether a layer is - column-parallel, row-parallel, or replicated. It examines the Megatron - module at runtime and delegates to the appropriate specialized mapping. - - **Detection strategy** - 1. Check module class name against a registry of known types - 2. If unknown, examine module attributes (tensor_model_parallel, partition_dim) - 3. Delegate to appropriate mapping: ColumnParallel, RowParallel, or Replicated - - This abstraction is particularly useful for model-agnostic code where you - don't know the parallelism type ahead of time, or when working with models - that mix different parallelism strategies. - - **Built-in module recognition** - - Column-parallel: `ColumnParallelLinear`, `VocabParallelEmbedding`, etc. - - Row-parallel: `RowParallelLinear`, `TERowParallelLinear` - - Replicated: `LayerNorm`, `RMSNorm`, and other normalization layers - - Example: - .. code-block:: python - - # Automatically handles any weight type - mapping = AutoMapping( - megatron_param="decoder.layers.*.mlp.linear_fc1.weight", - hf_param="model.layers.*.mlp.gate_proj.weight" - ) - - # Works with column-parallel layers - megatron_weights = mapping.hf_to_megatron(hf_weight, column_parallel_module) - - # Also works with normalization layers - norm_weight = mapping.hf_to_megatron(hf_norm, layer_norm_module) - - # Register custom module types - AutoMapping.register_module_type("MyCustomLinear", "column") - - Note: - If the parallelism type cannot be determined, the mapping will raise - a descriptive error suggesting how to fix the issue. - """ - - # Module type registry - _MODULE_TYPE_REGISTRY: Dict[str, set] = { - "column": { - "ColumnParallelLinear", - "TEColumnParallelLinear", - "TELayerNormColumnParallelLinear", - "TEColumnParallelGroupedLinear", - "VocabParallelEmbedding", - }, - "row": {"RowParallelLinear", "TERowParallelLinear", "TERowParallelGroupedLinear"}, - "replicated": { - # Normalization layers - "TENorm", - "FusedLayerNorm", - "WrappedTorchNorm", - "LayerNorm", - "RMSNorm", - "L2Norm", - # Other non-parallel modules - "IdentityOp", - "DotProductAttention", - "TEDotProductAttention", - "TopKRouter", - }, - } - - @classmethod - def register_module_type(cls, module_name: str, parallelism_type: str): - """Register a new module type for automatic parallelism detection. 
- - Args: - module_name (str): The name of the module class (e.g., - 'MyColumnLinear'). - parallelism_type (str): One of 'column', 'row', or 'replicated'. - """ - if parallelism_type not in cls._MODULE_TYPE_REGISTRY: - raise ValueError( - f"Invalid parallelism_type '{parallelism_type}'. " - f"Must be one of {list(cls._MODULE_TYPE_REGISTRY.keys())}" - ) - cls._MODULE_TYPE_REGISTRY[parallelism_type].add(module_name) - - def __init__(self, megatron_param: str, hf_param: str): - """Initialize TP-aware mapping.""" - super().__init__(megatron_param, hf_param) - - # Cache for detected parallelism type and delegate mapping - self._detected_type: Optional[str] = None - self._mapping: Optional[MegatronParamMapping[torch.Tensor]] = None - +class AutoMapping(OriginalAutoMapping): def _get_or_create_mapping(self, parallelism_type: str) -> MegatronParamMapping[torch.Tensor]: """Get or create the appropriate mapping for the given type.""" if parallelism_type == "column": @@ -1112,674 +98,8 @@ def _get_or_create_mapping(self, parallelism_type: str) -> MegatronParamMapping[ else: raise ValueError(f"Unknown parallelism type: {parallelism_type}") - def _detect_parallelism_type(self, module: nn.Module) -> str: - """Detect parallelism type from module.""" - module_type = type(module).__name__ - - # Handle fused modules like TELayerNormColumnParallelLinear - # These modules have both column-parallel weights (weight, bias) - # and replicated layer norm weights (layer_norm_weight, layer_norm_bias) - if module_type == "TELayerNormColumnParallelLinear": - # Check the actual parameter name to determine the correct parallelism type - if self.megatron_param and ( - self.megatron_param.endswith("layer_norm_weight") - or self.megatron_param.endswith("layer_norm_bias") - ): - return "replicated" - # All other parameters (weight, bias) are column-parallel - return "column" - - # Check registry first - for parallelism, types in self._MODULE_TYPE_REGISTRY.items(): - if module_type in types: - return parallelism - - # Fallback to inspecting module attributes - if hasattr(module, "tensor_model_parallel"): - if not module.tensor_model_parallel: - return "replicated" - - # Check partition dimension - partition_dim = getattr(module, "partition_dim", None) - if partition_dim == 0: - return "column" - elif partition_dim == 1: - return "row" - - # Fallback for normalization layers - if any(norm in module_type for norm in ["Norm", "Normalization"]): - return "replicated" - - # Check parallel_mode for TELinear - if module_type == "TELinear": - if module.parallel_mode == "column": - return "column" - elif module.parallel_mode == "row": - return "row" - else: - return "replicated" - - # Cannot determine - raise informative error - known_types = {p: sorted(list(t)) for p, t in self._MODULE_TYPE_REGISTRY.items()} - - raise ValueError( - f"Cannot determine parallelism type for module '{module_type}' " - f"at weight '{self.megatron_param}'.\n" - f"Please use an explicit mapping type (e.g., ColumnParallelMapping) " - f"or register the module type using:\n" - f" AutoMapping.register_module_type('{module_type}', 'column|row|replicated')\n\n" - f"Currently known module types:\n{json.dumps(known_types, indent=2)}" - ) - - def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: - """Delegate to appropriate mapping based on module type.""" - # Detect type and create delegate on first use - if self._mapping is None: - self._detected_type = self._detect_parallelism_type(megatron_module) - self._mapping = 
self._get_or_create_mapping(self._detected_type) - - return self._mapping.hf_to_megatron(hf_weights, megatron_module) - - def megatron_to_hf( - self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] - ) -> Dict[str, torch.Tensor]: - """Delegate to appropriate mapping based on module type.""" - # Need to determine type even if module is None (different PP rank) - assert self.megatron_param is not None, "`megatron_param` is required for AutoMapping." - - if self._mapping is None: - if megatron_module is not None: - self._detected_type = self._detect_parallelism_type(megatron_module) - # Broadcast to other ranks - self._detected_type = self.broadcast_obj_from_pp_rank( - self._detected_type, "detected_type" - ) - else: - # Receive from owning rank - self._detected_type = self.broadcast_obj_from_pp_rank(None, "detected_type") - self._mapping = self._get_or_create_mapping(self._detected_type) - - return self._mapping.megatron_to_hf(megatron_weights, megatron_module) - - -class QKVMapping(MegatronParamMapping[Dict[str, torch.Tensor]]): - """ - Mapping for interleaved Query/Key/Value attention projection weights. - - This mapping handles the conversion between separate Q, K, V matrices used in - standard transformers and Megatron's optimized interleaved format. The - interleaving pattern groups queries with their corresponding key-value pairs - to maximize GEMM efficiency during attention computation. - - **External format (HuggingFace)** - - Separate tensors: `q_proj`, `k_proj`, `v_proj` - - Each of shape `[hidden_size, hidden_size]` or `[hidden_size, head_dim * num_heads]` - - **Megatron format** - - Single interleaved tensor following grouped query attention (GQA) pattern - - Interleaving order: `[q1...qn, k1, v1, q1...qn, k2, v2, ...]` - - Where `n = num_attention_heads / num_query_groups` - - **Key features** - 1. Format conversion: Handles merging/splitting with proper interleaving - 2. Grouped Query Attention: Supports different numbers of Q and KV heads - 3. Tensor parallelism: Delegates to AutoMapping for distribution - - Example: - .. code-block:: python - - # Create mapping for attention weights - mapping = QKVMapping( - megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", - q="model.layers.*.self_attn.q_proj.weight", - k="model.layers.*.self_attn.k_proj.weight", - v="model.layers.*.self_attn.v_proj.weight" - ) - - # Convert from HuggingFace to Megatron - qkv_weights = {"q": q_tensor, "k": k_tensor, "v": v_tensor} - megatron_qkv = mapping.hf_to_megatron(qkv_weights, megatron_module) - - # Convert from Megatron to HuggingFace - hf_weights = mapping.megatron_to_hf(megatron_qkv, megatron_module) - # Returns: {"q_proj.weight": ..., "k_proj.weight": ..., "v_proj.weight": ...} - - Note: - This mapping automatically handles both regular multi-head attention - (same number of Q, K, V heads) and grouped query attention (fewer - KV heads than Q heads) based on the model configuration. - """ - +class QKVMapping(OriginalQKVMapping): def __init__(self, megatron_param: str, q: str, k: str, v: str): - """Initialize QKV mapping. - - Args: - megatron_param (str): Megatron QKV parameter name pattern. - q (str): Query weight name pattern. - k (str): Key weight name pattern. - v (str): Value weight name pattern. - """ - super().__init__(megatron_param, {"q": q, "k": k, "v": v}) - # Delegate all tensor-parallel logic to the smart TP-aware mapping so we - # do not hard-code the assumption that QKV projections are column-parallel. 
- # This keeps the format-handling (merge/split) concerns separate from - # TP/PP distribution mechanics. - self._tp_mapping = AutoMapping(megatron_param, megatron_param) - - def hf_to_megatron( - self, hf_weights: Dict[str, torch.Tensor], megatron_module: nn.Module - ) -> torch.Tensor: - """Merge Q, K, V into interleaved format and distribute.""" - if self.tp_rank == 0: - config = self._get_config(megatron_module) - - # Check if we're dealing with biases (1D tensors) or hf_weights (2D tensors) - if hf_weights["q"].ndim == 1: - # For biases, use the bias-specific merge function - merged = merge_qkv_biases(config, hf_weights["q"], hf_weights["k"], hf_weights["v"]) - else: - # For hf_weights, use the standard merge function - merged = merge_qkv_weights( - config, hf_weights["q"], hf_weights["k"], hf_weights["v"] - ) - else: - merged = None - - # Delegate the actual sharding/broadcasting to the TP-aware mapping. - return self._tp_mapping.hf_to_megatron(merged, megatron_module) - - def megatron_to_hf( - self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] - ) -> Dict[str, torch.Tensor]: - """Gather QKV shards and split into Q, K, V.""" - # Dequantize if needed - if megatron_weights is not None: - megatron_weights = self.maybe_dequantize(megatron_weights) - - # ------------------------------------------------------------------ - # Broadcast / retrieve the transformer configuration so that every PP - # rank (also the ones that will early-return) participates in the - # collective communication. - # ------------------------------------------------------------------ - if megatron_module is None: - config = self.broadcast_obj_from_pp_rank(None, "qkv_config") - else: - config = self._get_config(megatron_module) - # create shallow copy and remove non-picklable objects with max depth=2 - config = remove_non_pickleables(config, max_depth=2) - config = self.broadcast_obj_from_pp_rank(config, "qkv_config") - - # Delegate TP/PP gathering. - packed_dict = self._tp_mapping.megatron_to_hf(megatron_weights, megatron_module) - - if not packed_dict: - return {} - - packed_qkv = next(iter(packed_dict.values())) - - # Check if we're dealing with biases (1D) or weights (2D) - if packed_qkv.ndim == 1: - # Split biases - q, k, v = split_qkv_biases(config, packed_qkv) - else: - # Split weights - q, k, v = split_qkv_weights(config, packed_qkv) - - return {self.hf_param["q"]: q, self.hf_param["k"]: k, self.hf_param["v"]: v} - - def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": - """Return a new *resolved* QKVMapping instance.""" - resolved_megatron_param, resolved_hf_param = self._resolve_names(captures) - - return type(self)( - resolved_megatron_param, - resolved_hf_param["q"], - resolved_hf_param["k"], - resolved_hf_param["v"], - ) - - -class ConcatenatedQKVMapping(MegatronParamMapping[Dict[str, torch.Tensor]]): - """ - Mapping for interleaved Query/Key/Value attention projection weights. - - This mapping handles the conversion between Concatenated Q, K, V matrices used in - some transformers models and Megatron's optimized interleaved format. The - interleaving pattern groups queries with their corresponding key-value pairs - to maximize GEMM efficiency during attention computation. 
- - **External format (HuggingFace)** - - One tensor with concatenated query, key, value: `qkv`, with shape - `[hidden_size, head_dim * num_heads + 2 * head_dim * num_query_groups]` - - **Megatron format** - - Single interleaved tensor following grouped query attention (GQA) pattern - - Interleaving order: `[q1...qn, k1, v1, q1...qn, k2, v2, ...]` - - Where `n = num_attention_heads / num_query_groups` - - **Key features** - 1. Format conversion: Handles merging/splitting with proper interleaving - 2. Grouped Query Attention: Supports different numbers of Q and KV heads - 3. Tensor parallelism: Delegates to AutoMapping for distribution - - Example: - .. code-block:: python - - # Create mapping for attention weights - mapping = QKVMapping( - megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", - qkv="model.layers.*.self_attn.qkv.weight", - ) - - # Convert from HuggingFace to Megatron - megatron_qkv = mapping.hf_to_megatron(qkv_weights, megatron_module) - - # Convert from Megatron to HuggingFace - hf_weights = mapping.megatron_to_hf(megatron_qkv, megatron_module) - - Note: - This mapping automatically handles both regular multi-head attention - (same number of Q, K, V heads) and grouped query attention (fewer - KV heads than Q heads) based on the model configuration. - """ - - def __init__(self, megatron_param: str, hf_param: str): - """Initialize QKV mapping. - - Args: - megatron_param (str): Megatron interleaved QKV parameter name pattern. - hf_param (str): HF concatenated QKV parameter name pattern. - """ - super().__init__(megatron_param, hf_param) - # Delegate all tensor-parallel logic to the smart TP-aware mapping so we - # do not hard-code the assumption that QKV projections are column-parallel. - # This keeps the format-handling (merge/split) concerns separate from - # TP/PP distribution mechanics. + super().__init__(megatron_param, q, k, v) self._tp_mapping = AutoMapping(megatron_param, megatron_param) - def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: - """Merge Q, K, V into interleaved format and distribute.""" - if self.tp_rank == 0: - config = self._get_config(megatron_module) - head_num = config.num_attention_heads - head_size = config.kv_channels - num_query_groups = config.num_query_groups - q, k, v = hf_weights.split( - [head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], - dim=0, - ) - # Check if we're dealing with biases (1D tensors) or hf_weights (2D tensors) - if q.ndim == 1: - # For biases, use the bias-specific merge function - merged = merge_qkv_biases(config, q, k, v) - else: - # For hf_weights, use the standard merge function - merged = merge_qkv_weights(config, q, k, v) - else: - merged = None - - # Delegate the actual sharding/broadcasting to the TP-aware mapping. - return self._tp_mapping.hf_to_megatron(merged, megatron_module) - - def megatron_to_hf( - self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] - ) -> Dict[str, torch.Tensor]: - """Gather QKV shards and split into Q, K, V.""" - # Dequantize if needed - if megatron_weights is not None: - megatron_weights = self.maybe_dequantize(megatron_weights) - - # ------------------------------------------------------------------ - # Broadcast / retrieve the transformer configuration so that every PP - # rank (also the ones that will early-return) participates in the - # collective communication. 
- # ------------------------------------------------------------------ - if megatron_module is None: - config = self.broadcast_obj_from_pp_rank(None, "qkv_config") - else: - config = self._get_config(megatron_module) - # create shallow copy and remove non-picklable objects with max depth=2 - config = remove_non_pickleables(config, max_depth=2) - config = self.broadcast_obj_from_pp_rank(config, "qkv_config") - - # Delegate TP/PP gathering. - packed_dict = self._tp_mapping.megatron_to_hf(megatron_weights, megatron_module) - - if not packed_dict: - return {} - - packed_qkv = next(iter(packed_dict.values())) - - # Check if we're dealing with biases (1D) or weights (2D) - if packed_qkv.ndim == 1: - # Split biases - q, k, v = split_qkv_biases(config, packed_qkv) - else: - # Split weights - q, k, v = split_qkv_weights(config, packed_qkv) - - return {str(self.hf_param): torch.cat((q, k, v), dim=0)} - - def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": - """Return a new *resolved* QKVMapping instance.""" - resolved_megatron_param, resolved_hf_param = self._resolve_names(captures) - - return type(self)(resolved_megatron_param, resolved_hf_param) - - -class GatedMLPMapping(MegatronParamMapping[Dict[str, torch.Tensor]]): - r"""Mapping for **gated-MLP** projection weights (SwiGLU / GeGLU). - - Checkpoint formats expose two independent matrices: - - - **G** – gate projection - - **U** – up projection - - Megatron concatenates them row-wise (`[G; U]`) so that a single GEMM can - produce both activations. - - **Responsibilities handled by this mapping** - 1. **Concatenate / split** – convert between `[G; U]` (Megatron) and the - separate `{G, U}` matrices (external). - 2. **Tensor-parallel distribution** – correctly splits gate and up - projections separately before concatenating corresponding shards, - ensuring each TP rank gets the proper [gate_shard; up_shard] format. - - **TP Distribution Strategy** - For tensor parallelism, this mapping: - - Splits gate and up matrices separately along output dimension (dim 0) - - Concatenates corresponding shards: [gate_shard_i; up_shard_i] for rank i - - This ensures each rank's concatenated tensor matches the expected shape - """ - - def __init__(self, megatron_param: str, gate: str, up: str): - """Initialize gated MLP mapping. - - Args: - megatron_param (str): Megatron MLP parameter name pattern. - gate (str): Gate projection weight name pattern. - up (str): Up projection weight name pattern. - """ - super().__init__(megatron_param, {"gate": gate, "up": up}) - - def hf_to_megatron( - self, hf_weights: Dict[str, torch.Tensor], megatron_module: nn.Module - ) -> torch.Tensor: - """Split gate and up separately, then concatenate corresponding shards.""" - # For single TP, just concatenate and return - if self.tp_size == 1: - return torch.cat([hf_weights["gate"], hf_weights["up"]], dim=0) - - # Get target parameter info from megatron module - # Some parameters are named with global expert number, e.g. experts.weight15, - # normalize it to experts.weight0, note we are only use the shape, dtype, device info, - # not the actual value, so it is safe to do this. 
- normalized_param = self._normalize_expert_param_name(self.megatron_param) - _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) - - # On rank 0, split gate and up separately, then concatenate corresponding pieces - if self.tp_rank == 0: - gate = hf_weights["gate"] - up = hf_weights["up"] - - # Verify shapes match - assert gate.shape == up.shape, "Gate and up weights must have the same shape" - - # Check divisibility for TP splitting - gate_output_size = gate.shape[0] - if gate_output_size % self.tp_size != 0: - raise ValueError( - f"Cannot evenly split gate dimension 0 size {gate_output_size} across {self.tp_size} TP ranks" - ) - - # Split gate and up separately along output dimension (dim 0) - # This works for both bias (1D) and weight (2D) tensors - gate_splits = torch.chunk(gate, self.tp_size, dim=0) - up_splits = torch.chunk(up, self.tp_size, dim=0) - - # Concatenate corresponding pieces: [gate_shard_i; up_shard_i] for each rank i - splits = [torch.cat([gate_splits[i], up_splits[i]], dim=0) for i in range(self.tp_size)] - else: - splits = None - - # Scatter the concatenated shards to each rank - return self.scatter_to_tp_ranks( - splits, target_param.shape, target_param.dtype, target_param.device - ) - - def megatron_to_hf( - self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] - ) -> Dict[str, torch.Tensor]: - """Gather concatenated shards and split into gate and up.""" - # Handle cross-PP broadcast first - megatron_weights = self.broadcast_from_pp_rank( - megatron_weights, cache_key=str(self.hf_param) - ) - - if megatron_weights is None: - return {} - - # Dequantize if needed - megatron_weights = self.maybe_dequantize(megatron_weights) - - # Handle TP gathering - if self.tp_size == 1: - # No TP, just split the concatenated tensor - fused_mlp = megatron_weights - gate, up = torch.chunk(fused_mlp, 2, dim=0) - - else: - # Gather shards from all TP ranks - gathered_shards = self.gather_from_tp_ranks(megatron_weights) - - # Split each shard back into gate and up parts - gate_parts = [] - up_parts = [] - for shard in gathered_shards: - # Each shard is [gate_shard; up_shard] concatenated along dim 0 - # This works for both bias (1D) and weight (2D) tensors - gate_shard, up_shard = torch.chunk(shard, 2, dim=0) - gate_parts.append(gate_shard) - up_parts.append(up_shard) - - # Concatenate all gate parts and all up parts separately - gate = torch.cat(gate_parts, dim=0) - up = torch.cat(up_parts, dim=0) - - if self.is_expert: - gathered_gate_weights_dict = self.gather_from_ep_ranks( - gate, megatron_module, self.hf_param["gate"] - ) - gathered_up_weights_dict = self.gather_from_ep_ranks( - up, megatron_module, self.hf_param["up"] - ) - return {**gathered_gate_weights_dict, **gathered_up_weights_dict} - - return {self.hf_param["gate"]: gate, self.hf_param["up"]: up} - - def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": - """Return a new *resolved* GatedMLPMapping instance.""" - resolved_megatron_param, resolved_hf_param = self._resolve_names(captures) - - return type(self)( - resolved_megatron_param, resolved_hf_param["gate"], resolved_hf_param["up"] - ) - - -def merge_qkv_biases( - config: TransformerConfig, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor -) -> torch.Tensor: - """Merge separate Q, K, V bias vectors into Megatron's interleaved QKV format. - - Args: - config (TransformerConfig): Transformer configuration. - q (torch.Tensor): Query projection biases [hidden_size]. 
- k (torch.Tensor): Key projection biases [kv_hidden_size]. - v (torch.Tensor): Value projection biases [kv_hidden_size]. - - Returns: - torch.Tensor: Interleaved QKV biases in Megatron format as 1D tensor. - """ - head_num = config.num_attention_heads - num_query_groups = config.num_query_groups - heads_per_group = head_num // num_query_groups - head_size = config.kv_channels or (config.hidden_size // head_num) - - # Reshape biases to expose head dimension - q = q.view(head_num, head_size) - k = k.view(num_query_groups, head_size) - v = v.view(num_query_groups, head_size) - - # Interleave in Megatron pattern: [q1...qn, k1, v1, q1...qn, k2, v2, ...] - qkv_biases = [] - for i in range(num_query_groups): - qkv_biases.append(q[i * heads_per_group : (i + 1) * heads_per_group, :]) - qkv_biases.append(k[i : i + 1, :]) - qkv_biases.append(v[i : i + 1, :]) - - # Concatenate and flatten back to 1D - qkv = torch.cat(qkv_biases) - return qkv.flatten() - - -def split_qkv_biases( - config: TransformerConfig, qkv: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Split Megatron's interleaved QKV bias into separate Q, K, V biases. - - Args: - config (TransformerConfig): Transformer configuration. - qkv (torch.Tensor): Interleaved QKV biases in Megatron format (1D - tensor). - - Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of (Q, K, V) bias vectors. - """ - head_num = config.num_attention_heads - num_query_groups = config.num_query_groups - heads_per_group = head_num // num_query_groups - head_size = config.kv_channels or (config.hidden_size // head_num) - qkv_total_dim = head_num + 2 * num_query_groups - - # Reshape to expose interleaved structure - qkv = qkv.reshape(qkv_total_dim, head_size) - - # Extract Q, K, V from interleaved pattern - q_slice = torch.cat( - [ - torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) - for i in range(num_query_groups) - ] - ) - k_slice = torch.arange(heads_per_group, qkv_total_dim, heads_per_group + 2) - v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, heads_per_group + 2) - - q = qkv[q_slice].flatten() - k = qkv[k_slice].flatten() - v = qkv[v_slice].flatten() - - return q, k, v - - -def merge_qkv_weights( - provider: TransformerConfig, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor -) -> torch.Tensor: - """Merge separate Q, K, V weight matrices into Megatron's interleaved QKV format. - - Args: - provider (TransformerConfig): Model configuration provider. - q (torch.Tensor): Query projection weights [hidden_size, hidden_size] or - bias [hidden_size]. - k (torch.Tensor): Key projection weights [kv_hidden_size, hidden_size] - or bias [kv_hidden_size]. - v (torch.Tensor): Value projection weights [kv_hidden_size, - hidden_size] or bias [kv_hidden_size]. - - Returns: - torch.Tensor: Interleaved QKV weights in Megatron format. 
- """ - head_num = provider.num_attention_heads - num_query_groups = provider.num_query_groups - heads_per_group = head_num // num_query_groups - head_size = provider.kv_channels or (provider.hidden_size // head_num) - hidden_size = provider.hidden_size - is_bias = q.ndim == 1 - - # Reshape to expose head dimension - if is_bias: - q_reshaped = q.view(head_num, head_size) - k_reshaped = k.view(num_query_groups, head_size) - v_reshaped = v.view(num_query_groups, head_size) - else: - q_reshaped = q.view(head_num, head_size, hidden_size) - k_reshaped = k.view(num_query_groups, head_size, hidden_size) - v_reshaped = v.view(num_query_groups, head_size, hidden_size) - - # Interleave in Megatron pattern: [q1...qn, k1, v1, q1...qn, k2, v2, ...] - qkv_weights = [] - for i in range(num_query_groups): - q_group = q_reshaped[i * heads_per_group : (i + 1) * heads_per_group] - k_group = k_reshaped[i : i + 1] - v_group = v_reshaped[i : i + 1] - qkv_weights.extend([q_group, k_group, v_group]) - - qkv = torch.cat(qkv_weights, dim=0) - - # Final reshape - if is_bias: - return qkv.reshape(-1) - else: - return qkv.reshape([-1, hidden_size]) - - -def split_qkv_weights( - provider: TransformerConfig, qkv: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Split Megatron's interleaved QKV tensor into separate Q, K, V matrices. - - Args: - provider (TransformerConfig): Model configuration provider. - qkv (torch.Tensor): Interleaved QKV weights in Megatron format. - - Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of (Q, K, V) - weight matrices. - """ - head_num = provider.num_attention_heads - num_query_groups = provider.num_query_groups - heads_per_group = head_num // num_query_groups - head_size = provider.kv_channels or (provider.hidden_size // head_num) - qkv_total_dim = head_num + 2 * num_query_groups - is_bias = qkv.ndim == 1 - - if is_bias: - hidden_size = 1 - qkv_reshaped = qkv.view(qkv_total_dim, head_size) - else: - hidden_size = qkv.shape[-1] - qkv_reshaped = qkv.view(qkv_total_dim, head_size, hidden_size) - - # Extract Q, K, V from interleaved pattern - q_slice = torch.cat( - [ - torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) - for i in range(num_query_groups) - ] - ) - k_slice = torch.arange(heads_per_group, qkv_total_dim, heads_per_group + 2) - v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, heads_per_group + 2) - - q = qkv_reshaped[q_slice] - k = qkv_reshaped[k_slice] - v = qkv_reshaped[v_slice] - - if is_bias: - q = q.reshape(-1) - k = k.reshape(-1) - v = v.reshape(-1) - else: - q = q.reshape(-1, hidden_size) - k = k.reshape(-1, hidden_size) - v = v.reshape(-1, hidden_size) - - return q, k, v diff --git a/flagscale/train/megatron/nemo_bridge/models/conversion/utils.py b/flagscale/train/megatron/nemo_bridge/models/conversion/utils.py deleted file mode 100644 index 66d68aee66..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/conversion/utils.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. 
-# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import copy -import functools -import re -import types - -from typing import Iterable, List, Optional, Tuple - -import torch - -from rich.table import Table -from transformers.configuration_utils import PretrainedConfig - -from megatron.core.transformer.module import MegatronModule -from megatron.core.utils import unwrap_model - - -def weights_verification_table(bridge, megatron_model) -> Table: - """ - Returns a table comparing weights between a Hugging Face model and a Megatron-LM model. - - Args: - bridge (AutoBridge): The bridge object containing model information. - megatron_model: The Megatron-LM model instance. - - Returns: - Table: A rich Table object with the comparison. - """ - table = Table(title="Hugging Face Weights Verification") - table.add_column("Weight Name", style="cyan") - table.add_column("Shape") - table.add_column("DType") - table.add_column("Device") - table.add_column("Matches Original", justify="center") - - # Check each weight against the original HF-model - for name, param in bridge.export_hf_weights(megatron_model, show_progress=True): - original_param = bridge.hf_pretrained.state[name] - table.add_row( - name, - str(tuple(param.shape)), - str(param.dtype).replace("torch.", ""), - str(param.device), - "✅" if torch.allclose(param, original_param.to(param.device), atol=1e-6) else "❌", - ) - - return table - - -def get_module_and_param_from_name( - models: MegatronModule | List[MegatronModule], param_name: str, vp_stage: Optional[int] = None -) -> Tuple[torch.nn.Module, torch.Tensor] | Tuple[torch.nn.Module, torch.Tensor, Tuple]: - """ - Get parameter from specific VP stage, ensuring that parameter - attributes are preserved. Supports both absolute and relative parameter names. - - Args: - models: List of Megatron model instances or a submodule - param_name: Dot-separated parameter name (can be absolute or relative to models) - vp_stage: Virtual pipeline stage index (None for single stage) - - Returns: - Tuple of (module, parameter) where module owns the parameter - - Raises: - ValueError: If vp_stage is out of range or parameter doesn't exist - - Examples: - Basic usage with full model: - >>> module, param = get_module_and_param_from_name( - ... models=full_model, - ... param_name="transformer.layers.0.attention.query.weight" - ... ) - - Usage with model list and VP stage: - >>> module, param = get_module_and_param_from_name( - ... models=[model1, model2, model3], - ... param_name="layers.0.mlp.dense.bias", - ... vp_stage=1 - ... ) - - Usage with submodule and relative path: - >>> linear_module = model.transformer.layers[0].mlp.dense - >>> module, param = get_module_and_param_from_name( - ... models=linear_module, - ... param_name="weight" - ... ) - - Usage with submodule and absolute path (automatic suffix matching): - >>> linear_module = model.transformer.layers[0].mlp.dense - >>> module, param = get_module_and_param_from_name( - ... models=linear_module, - ... param_name="transformer.layers.0.mlp.dense.weight" - ... ) - # Automatically matches "weight" suffix and returns the parameter - - Edge case with partial path matching: - >>> attention_module = model.transformer.layers[0].attention - >>> module, param = get_module_and_param_from_name( - ... models=attention_module, - ... param_name="layers.0.attention.query.weight" - ... 
) - # Matches "query.weight" suffix within the attention module - """ - - if isinstance(models, list): - if vp_stage is None: - model = models[0] - else: - if vp_stage >= len(models): - raise ValueError(f"VP stage {vp_stage} out of range (max: {len(models) - 1})") - model = models[vp_stage] - else: - model = models - - module = unwrap_model(model) - splitted_name = param_name.split(".") - - # Try to find the parameter using the given parts - def try_get_param(parts): - param = module - temp_module = module - - for i, part in enumerate(parts): - if not hasattr(param, part): - return None - param = getattr(param, part) - if i < len(parts) - 1: - temp_module = getattr(temp_module, part) - - return temp_module, param - - # First try the full parameter name (current behavior) - result = try_get_param(splitted_name) - if result is not None: - return result - - # If full name doesn't work, try suffixes of the parameter name - # This handles cases where models is a submodule but param_name is absolute - for start_idx in range(1, len(splitted_name)): - suffix_parts = splitted_name[start_idx:] - result = try_get_param(suffix_parts) - if result is not None: - return result - - # If no approach works, raise an error - raise ValueError(f"Parameter '{param_name}' not found in model at VP stage {vp_stage}") - - -def remove_non_pickleables(obj, max_depth: int = 2, current_depth: int = 0): - """Remove non-pickleable objects from a configuration object recursively. - - This utility function identifies and removes objects that cannot be pickled for - inter-process communication, including functions, bound methods, partial - functions, and other problematic callables. - - Args: - obj: The object to clean - max_depth: Maximum recursion depth (default: 2) - current_depth: Current recursion depth (internal use) - - Returns: - The cleaned object with non-pickleables removed - """ - - # Stop recursion if max depth reached - if current_depth >= max_depth: - return obj - - # Handle None - if obj is None: - return obj - - # Check if object is a problematic callable - if callable(obj): - # Allow classes/types but remove function objects, methods, partials - if isinstance(obj, type): - return obj - elif hasattr(obj, "__call__") and ( - isinstance(obj, (types.FunctionType, types.MethodType, functools.partial)) - or hasattr(obj, "__self__") - ): # bound methods - return None - - # Handle dataclass/object with attributes - if hasattr(obj, "__dict__"): - # Create a copy to avoid modifying the original - cleaned_obj = copy.copy(obj) - - for attr_name in list(vars(cleaned_obj).keys()): - attr_value = getattr(cleaned_obj, attr_name) - - # Recursively clean attribute - cleaned_value = remove_non_pickleables(attr_value, max_depth, current_depth + 1) - - # Set the cleaned value (or None if it was removed) - setattr(cleaned_obj, attr_name, cleaned_value) - - return cleaned_obj - - # Handle lists - elif isinstance(obj, list): - return [remove_non_pickleables(item, max_depth, current_depth + 1) for item in obj] - - # Handle tuples - elif isinstance(obj, tuple): - return tuple(remove_non_pickleables(item, max_depth, current_depth + 1) for item in obj) - - # Handle dictionaries - elif isinstance(obj, dict): - return { - key: remove_non_pickleables(value, max_depth, current_depth + 1) - for key, value in obj.items() - } - - # For primitive types and other safe objects, return as-is - return obj - - -def extract_sort_key(param_name: str): - """Extract sorting key based on layer and expert numbers.""" - - # Extract at most 2 numbers: 
layer number and expert number - # Pattern: *layers.d+.*d+ (layer number and potentially expert number) - numbers = [] - # Find layer number - layer_match = re.search(r"layers\.(\d+)", param_name) - if layer_match: - numbers.append(int(layer_match.group(1))) - # Find expert number after bias or weight - expert_match = re.search(r"(?:bias|weight)(\d+)", param_name) - if expert_match: - numbers.append(int(expert_match.group(1))) - # Pad to ensure consistent comparison (max 2 numbers) - while len(numbers) < 2: - numbers.append(-1) - numbers = numbers[:2] # Keep at most 2 numbers - return numbers, param_name - - -def get_causal_lm_class_via_auto_map( - model_name_or_path: str, config: PretrainedConfig -) -> type | None: - """Return CausalLM class via config.auto_map if available; otherwise None. - - If auto_map["AutoModelForCausalLM"] is present in the config, returns the dynamically loaded class. - Returns None when auto_map is absent or loading fails. Does not download weights. - """ - auto_map = getattr(config, "auto_map", None) - if auto_map and "AutoModelForCausalLM" in auto_map: - auto_map_class = auto_map["AutoModelForCausalLM"] - repo_id = model_name_or_path or getattr(config, "_name_or_path", None) - if not repo_id: - return None - try: - from transformers.dynamic_module_utils import get_class_from_dynamic_module - - return get_class_from_dynamic_module( - class_reference=auto_map_class, - pretrained_model_name_or_path=repo_id, - cache_dir=None, - force_download=False, - resume_download=True, - proxies=None, - use_auth_token=None, - revision=None, - local_files_only=False, - repo_id=repo_id, - ) - except Exception: - return None - - return None - - -def persistent_buffers(model: torch.nn.Module) -> Iterable[Tuple[str, torch.Tensor]]: - """Return an iterator over persistent module buffers, yielding both the name of the buffer as well as the buffer itself.""" - - for mod_prefix, mod in model.named_modules(): - # only local buffers; we'll add the prefix ourselves - for local_name, buffer in mod.named_buffers(recurse=False): - if local_name not in getattr(mod, "_non_persistent_buffers_set", set()): - full_name = f"{mod_prefix + '.' if mod_prefix else ''}{local_name}" - yield full_name, buffer diff --git a/flagscale/train/megatron/nemo_bridge/models/decorators/__init__.py b/flagscale/train/megatron/nemo_bridge/models/decorators/__init__.py deleted file mode 100644 index 744d700763..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/decorators/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -from megatron.nemo_bridge.models.decorators.dispatch import dispatch -#from megatron.nemo_bridge.models.decorators.torchrun import torchrun_main - -#__all__ = ["dispatch", "torchrun_main"] -__all__ = ["dispatch"] diff --git a/flagscale/train/megatron/nemo_bridge/models/decorators/dispatch.py b/flagscale/train/megatron/nemo_bridge/models/decorators/dispatch.py deleted file mode 100644 index 7e02855d66..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/decorators/dispatch.py +++ /dev/null @@ -1,348 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -"""Simplified dispatch system for Python, based on classes' typeclass implementation. - -This module provides a dispatch-based polymorphism system allowing extensible -behavior for different types using the `impl` decorator. 
-""" - -from functools import _find_impl # type: ignore -from typing import Any, Callable, Dict, Optional, TypeVar - -_SignatureType = TypeVar("_SignatureType", bound=Callable) - - -class _Dispatch: - """Internal dispatch representation with type-based routing logic.""" - - __slots__ = ("_signature", "_name", "_exact_types", "_dispatch_cache", "_doc", "_module") - - def __init__(self, signature: Callable) -> None: - self._signature = signature - self._name = signature.__name__ - self._exact_types: Dict[Any, Callable] = {} - self._dispatch_cache: Dict[Any, Callable] = {} - - # Extract docstring and module info for rich repr - self._doc = signature.__doc__ - self._module = signature.__module__ - - def __call__(self, instance: Any, *args, **kwargs) -> Any: - """Dispatch to the appropriate implementation based on instance type.""" - # Special case for tuple-based keys. - if isinstance(instance, tuple): - key = tuple(v if isinstance(v, (type, str)) else type(v) for v in instance) - - # Direct match - impl = self._exact_types.get(key) - if impl is not None: - # NOTE: This path is not cached for simplicity - return impl(instance, *args, **kwargs) - - # Subclass match for tuples of types - for registered_key, callback in self._exact_types.items(): - if ( - not isinstance(registered_key, tuple) - or len(registered_key) != len(key) - or not all(isinstance(t, type) for t in registered_key) - ): - continue - - try: - # For subclass checks, operate on the instance types only - key_types = tuple(v if isinstance(v, type) else type(v) for v in instance) - if all(issubclass(k, rk) for k, rk in zip(key_types, registered_key)): - # NOTE: not caching tuple subclass matches for simplicity - return callback(instance, *args, **kwargs) - except TypeError: - continue # issubclass can fail - - # Normalize both sides to names so tuples of types and/or strings can match. - def _name(obj): - return obj if isinstance(obj, str) else getattr(obj, "__name__", None) or str(obj) - - key_names = tuple(_name(v) for v in key) - for registered_key, callback in self._exact_types.items(): - if not isinstance(registered_key, tuple) or len(registered_key) != len(key): - continue - reg_names = tuple(_name(rk) for rk in registered_key) - if reg_names == key_names: - return callback(instance, *args, **kwargs) - - # No implementation found for this tuple, raise a specific error. - error_msg = self._format_no_implementation_error(instance) - raise NotImplementedError(error_msg) - - # For class dispatch, we use the class (or string of class name) itself as the key - if isinstance(instance, type): - cache_key = instance - instance_type = instance - elif isinstance(instance, str): - cache_key = instance - instance_type = str - else: - cache_key = type(instance) - instance_type = cache_key - - # Try cache - impl = self._dispatch_cache.get(cache_key) - if impl is None: - impl = self._dispatch(instance, instance_type) - if impl is None: - error_msg = self._format_no_implementation_error(instance) - raise NotImplementedError(error_msg) - self._dispatch_cache[cache_key] = impl - - return impl(instance, *args, **kwargs) - - def impl(self, *target_types: Any) -> Callable[[Callable], Callable]: - """Register an implementation for one or more types. 
- - Usage: - @mydispatch.impl(int) # Register for a single type - @mydispatch.impl(int, str) # Register for multiple types - @mydispatch.impl((list, str)) # Register for a tuple of types as a key - """ - if not target_types: - raise ValueError( - "\n✗ Missing argument to .impl()\n\n" - "You must specify at least one target type.\n\n" - "Examples:\n" - f" @{self._name}.impl(str) # Single type\n" - f" @{self._name}.impl(int, float) # Multiple types\n" - f" @{self._name}.impl((list, str)) # Tuple key\n" - ) - - def decorator(func: Callable) -> Callable: - if len(target_types) == 1: - # This handles both `@impl(int)` and `@impl((int, str))` - self._exact_types[target_types[0]] = func - else: - # This handles `@impl(int, str)` - for typ in target_types: - self._exact_types[typ] = func - - self._dispatch_cache.clear() - return func - - return decorator - - def __repr__(self) -> str: - """Rich representation showing all implementations.""" - # Build signature string - import inspect - - sig = inspect.signature(self._signature) - sig_str = f"{self._name}{sig}" - - lines = [f"Dispatch({sig_str})("] - - # Add regular implementations - for typ, impl in self._exact_types.items(): - if isinstance(typ, tuple): - type_name = ( - f"({', '.join(t.__name__ if hasattr(t, '__name__') else str(t) for t in typ)})" - ) - else: - type_name = typ.__name__ if hasattr(typ, "__name__") else str(typ) - impl_loc = self._format_location(impl) - lines.append(f" ({type_name}): {impl.__name__} at {impl_loc}") - - lines.append(")") - return "\n".join(lines) - - def _dispatch(self, instance: Any, instance_type: type) -> Optional[Callable]: - """Find the implementation for a given type. - - Fallback order: - 1) Exact type match - 2) issubclass match (when instance is a type) - 3) MRO-based match via functools._find_impl - 4) Name-based fallback: match by class __name__ for dynamically generated - classes (e.g., HF transformers auto_map dynamic modules) - """ - # Direct type match - impl = self._exact_types.get(instance_type, None) - if impl is not None: - return impl - - # For class dispatch, check issubclass relationships - if isinstance(instance, type): - for registered_type, callback in self._exact_types.items(): - if not isinstance(registered_type, type): - continue - try: - if issubclass(instance, registered_type): - return callback - except TypeError: - # issubclass can fail for some types - pass - - # Use functools._find_impl for MRO-based dispatch, only for single types - single_type_impls = {k: v for k, v in self._exact_types.items() if isinstance(k, type)} - impl = _find_impl(instance_type, single_type_impls) - if impl is not None: - return impl - - # Name-based fallback for dynamic HF classes and string registrations. 
- def _name(obj): - return obj if isinstance(obj, str) else getattr(obj, "__name__", None) - - if isinstance(instance, str): - inst_name = instance - elif isinstance(instance, type): - inst_name = _name(instance) - else: - inst_name = _name(type(instance)) - - if inst_name: - for registered_type, callback in self._exact_types.items(): - reg_name = _name(registered_type) - if reg_name and str(reg_name) == inst_name: - return callback - - return None - - def _format_location(self, func: Callable) -> str: - """Format the location of a function for display.""" - try: - import inspect - - filename = inspect.getfile(func) - _, lineno = inspect.getsourcelines(func) - # Shorten the path to be more readable - import os - - filename = os.path.relpath(filename) - return f"{filename}:{lineno}" - except Exception: - return "" - - def _format_no_implementation_error(self, instance: Any) -> str: - """Format a helpful error message when no implementation is found.""" - type_name_for_header: str - type_name_for_suggestion: str - type_name_for_func: str - instance_type_hint: str - - if isinstance(instance, tuple): - instance_types = tuple(v if isinstance(v, type) else type(v) for v in instance) - type_names_str = ", ".join( - t.__qualname__ if hasattr(t, "__qualname__") else str(t) for t in instance_types - ) - type_name_for_header = f"tuple of types ({type_names_str})" - - suggestion_names = ", ".join( - t.__name__ if hasattr(t, "__name__") else str(t) for t in instance_types - ) - type_name_for_suggestion = f"({suggestion_names})" - type_name_for_func = "tuple" - instance_type_hint = f"Tuple[{', '.join(t.__name__ for t in instance_types)}]" - else: - instance_type = instance if isinstance(instance, type) else type(instance) - qualname = ( - instance_type.__qualname__ - if hasattr(instance_type, "__qualname__") - else str(instance_type) - ) - type_name_for_header = f"type '{qualname}'" - type_name_for_suggestion = ( - instance_type.__name__ if hasattr(instance_type, "__name__") else str(instance_type) - ) - type_name_for_func = type_name_for_suggestion.lower().replace(".", "_") - instance_type_hint = type_name_for_suggestion - - # Build error message - lines = [ - f"\n✗ No implementation found for {type_name_for_header}", - "", - f"The dispatch function '{self._name}' has no implementation for this type.", - "", - ] - - # Add available implementations - if self._exact_types: - lines.append("Available implementations:") - - # Add registered types - sorted_keys = sorted(self._exact_types.keys(), key=str) - for typ in sorted_keys: - if isinstance(typ, tuple): - type_display = f"({', '.join(t.__name__ if hasattr(t, '__name__') else str(t) for t in typ)})" - else: - type_display = typ.__name__ if hasattr(typ, "__name__") else str(typ) - lines.append(f" • {type_display}") - else: - lines.append("No implementations registered yet.") - - # Generate help based on existing implementations - if self._exact_types: - # Get a sample implementation to show the pattern - _, sample_impl = next(iter(self._exact_types.items())) - - lines.extend( - [ - "", - "To add support for this type, register an implementation:", - f" @{self._name}.impl({type_name_for_suggestion})", - f" def _{self._name}_{type_name_for_func}(instance: {instance_type_hint}) -> ...:", - " # Your implementation here", - ] - ) - - # Try to extract parameter info from the sample implementation - import inspect - - try: - sig = inspect.signature(sample_impl) - params = list(sig.parameters.keys())[1:] # Skip first param (instance) - if params: - param_hints = 
", ".join(params) - lines.append(f" # Expected parameters: {param_hints}") - except Exception: - pass - else: - lines.extend( - [ - "", - "To add support for this type:", - f" @{self._name}.impl({type_name_for_suggestion})", - f" def _{self._name}_{type_name_for_func}(instance: {instance_type_hint}, ...) -> ...:", - " # Your implementation here", - ] - ) - - return "\n".join(lines) - - -def dispatch(func: _SignatureType) -> _Dispatch: - """Create a new dispatch function from a signature. - - Args: - func: Function defining the dispatch signature and default behavior - - Returns: - A dispatch object that can be extended with implementations - - Example: - >>> @dispatch - ... def to_string(instance) -> str: - ... '''Convert instance to string representation.''' - ... - >>> @to_string.impl(int) - ... def _to_string_int(instance: int) -> str: - ... return str(instance) - ... - >>> @to_string.impl(list, tuple) - ... def _to_string_sequence(instance) -> str: - ... return ', '.join(map(str, instance)) - ... - >>> assert to_string(42) == "42" - >>> assert to_string([1, 2, 3]) == "1, 2, 3" - """ - return _Dispatch(func) - - -__all__ = ["dispatch"] diff --git a/flagscale/train/megatron/nemo_bridge/models/decorators/torchrun.py b/flagscale/train/megatron/nemo_bridge/models/decorators/torchrun.py deleted file mode 100644 index 80fa77dcc5..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/decorators/torchrun.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import os -import traceback - -from functools import wraps - -import torch - -from torch.distributed.elastic.multiprocessing.errors import record - - -def torchrun_main(fn): - """ - A decorator that wraps the main function of a torchrun script. It uses - the `torch.distributed.elastic.multiprocessing.errors.record` decorator - to record any exceptions and ensures that the distributed process group - is properly destroyed on successful completion. In case of an exception, - it prints the traceback and performs a hard exit, allowing torchrun to - terminate all other processes. - """ - recorded_fn = record(fn) - - @wraps(fn) - def wrapper(*args, **kwargs): - try: - return_value = recorded_fn(*args, **kwargs) - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() - return return_value - except Exception: - # The 'record' decorator might only log the exception to a file. - # Print it to stderr as well to make sure it's visible. - traceback.print_exc() - # Use os._exit(1) for a hard exit. A regular sys.exit(1) might - # not be enough to terminate a process stuck in a bad C++ state - # (e.g., after a NCCL error), which can cause the job to hang. - os._exit(1) - - return wrapper diff --git a/flagscale/train/megatron/nemo_bridge/models/deepseek/__init__.py b/flagscale/train/megatron/nemo_bridge/models/deepseek/__init__.py index f2b27048b5..bee2b1aee3 100644 --- a/flagscale/train/megatron/nemo_bridge/models/deepseek/__init__.py +++ b/flagscale/train/megatron/nemo_bridge/models/deepseek/__init__.py @@ -1,31 +1,4 @@ # Copyright (c) 2025, BAAI. All rights reserved. 
-# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge -from megatron.nemo_bridge.models.deepseek.deepseek_provider import ( - DeepSeekModelProvider, - DeepSeekProvider, - DeepSeekV2LiteModelProvider, - DeepSeekV2LiteProvider, - DeepSeekV2ModelProvider, - DeepSeekV2Provider, - DeepSeekV3ModelProvider, - DeepSeekV3Provider, - MoonlightModelProvider16B, - MoonlightProvider, -) -from megatron.nemo_bridge.models.deepseek.deepseek_v2_bridge import DeepSeekV2Bridge # noqa: F401 from megatron.nemo_bridge.models.deepseek.deepseek_v3_bridge import DeepSeekV3Bridge # noqa: F401 -__all__ = [ - "DeepSeekModelProvider", - "DeepSeekV2LiteModelProvider", - "DeepSeekV2ModelProvider", - "DeepSeekV3ModelProvider", - "MoonlightModelProvider16B", - "DeepSeekProvider", - "DeepSeekV2LiteProvider", - "DeepSeekV2Provider", - "DeepSeekV3Provider", - "MoonlightProvider", -] diff --git a/flagscale/train/megatron/nemo_bridge/models/deepseek/common.py b/flagscale/train/megatron/nemo_bridge/models/deepseek/common.py index b8a660c957..ee257f4a0c 100644 --- a/flagscale/train/megatron/nemo_bridge/models/deepseek/common.py +++ b/flagscale/train/megatron/nemo_bridge/models/deepseek/common.py @@ -1,6 +1,6 @@ # Copyright (c) 2025, BAAI. All rights reserved. # -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge from megatron.nemo_bridge.models.conversion.param_mapping import AutoMapping, GatedMLPMapping from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM diff --git a/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_provider.py b/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_provider.py deleted file mode 100644 index f429df4279..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_provider.py +++ /dev/null @@ -1,309 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import warnings - -from dataclasses import dataclass, field -from functools import partial -from typing import TYPE_CHECKING, Callable, List, Optional, Union - -import torch -import torch.nn.functional as F - -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec - -from megatron.nemo_bridge.models.gpt_provider import GPTModelProvider -from megatron.nemo_bridge.models.transformer_config import MLATransformerConfig -from megatron.nemo_bridge.utils.common_utils import get_rank_safe - -try: - import transformer_engine # type: ignore # noqa: F401 - - HAVE_TE = True -except (ImportError, ModuleNotFoundError): - HAVE_TE = False - -if TYPE_CHECKING: - from megatron.core.transformer import ModuleSpec - -if HAVE_TE: - from megatron.core.utils import is_te_min_version - - -@dataclass -class DeepSeekModelProvider(MLATransformerConfig, GPTModelProvider): - """ - Base config for DeepSeek V2 and V3 models. 
- """ - - transformer_layer_spec: Union["ModuleSpec", Callable[["GPTModelProvider"], "ModuleSpec"]] = ( - partial(get_gpt_decoder_block_spec, use_transformer_engine=HAVE_TE) - ) - - # Model - normalization: str = "RMSNorm" - activation_func: Callable = F.silu - gated_linear_unit: bool = True # swiglu - position_embedding_type: str = "rope" - add_bias_linear: bool = False - share_embeddings_and_output_weights: bool = False - num_attention_heads: int = 128 - kv_channels: int = 128 - max_position_embeddings: int = 4096 - seq_length: int = 4096 - rotary_base: float = 10000.0 - make_vocab_size_divisible_by: int = 3200 - mtp_num_layers: Optional[int] = None - mtp_loss_scaling_factor: Optional[float] = None - - # Regularization - attention_dropout: float = 0.0 - hidden_dropout: float = 0.0 - qk_layernorm: bool = True - - # MoE - moe_grouped_gemm: bool = True - moe_router_pre_softmax: bool = True - moe_token_dispatcher_type: str = "alltoall" - moe_router_load_balancing_type: str = "seq_aux_loss" - moe_shared_expert_overlap: bool = True - moe_router_dtype: Optional[str] = "fp32" - - # MLA - q_lora_rank: int = 1536 - kv_lora_rank: int = 512 - qk_head_dim: int = 128 - qk_pos_emb_head_dim: int = 64 - v_head_dim: int = 128 - rotary_scaling_factor: float = 40 - mscale: float = 1.0 - mscale_all_dim: float = 1.0 - - # Miscellaneous - init_method_std: float = 0.006 - layernorm_epsilon: float = 1e-6 - bf16: bool = True - params_dtype: torch.dtype = torch.bfloat16 - async_tensor_model_parallel_allreduce: bool = True - attention_softmax_in_fp32: bool = False - persist_layer_norm: bool = True - num_layers_in_first_pipeline_stage: Optional[int] = None - num_layers_in_last_pipeline_stage: Optional[int] = None - account_for_embedding_in_pipeline_split: bool = False - account_for_loss_in_pipeline_split: bool = False - - # MLA specific - multi_latent_attention: bool = True - - # fusions - apply_rope_fusion: bool = False - bias_activation_fusion: bool = True - bias_dropout_fusion: bool = True - masked_softmax_fusion: bool = True - cross_entropy_loss_fusion: bool = True - cross_entropy_fusion_impl: str = "te" - moe_permute_fusion: bool = is_te_min_version("2.1.0") if HAVE_TE else False - - -@dataclass -class DeepSeekV2ModelProvider(DeepSeekModelProvider): - """ - DeepSeek-V2 Model: https://github.com/deepseek-ai/DeepSeek-V2 - """ - - num_layers: int = 60 - hidden_size: int = 5120 - ffn_hidden_size: int = 12288 - num_moe_experts: int = 160 - moe_ffn_hidden_size: int = 1536 - moe_shared_expert_intermediate_size: int = 3072 # 1536 * 2 shared experts - moe_layer_freq: Union[int, List[int]] = field( - default_factory=lambda: [0] + [1] * 59 - ) # first layer is dense - moe_router_topk: int = 6 - moe_router_num_groups: int = 8 - moe_router_group_topk: int = 3 - moe_router_topk_scaling_factor: float = 16.0 - moe_aux_loss_coeff: float = 1e-3 - mscale: float = 0.707 - mscale_all_dim: float = 0.707 - vocab_size: int = 102400 - - -@dataclass -class DeepSeekV2LiteModelProvider(DeepSeekV2ModelProvider): - """ - DeepSeek-V2-Lite Model: https://github.com/deepseek-ai/DeepSeek-V2 - HuggingFace: https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite - """ - - num_layers: int = 27 - hidden_size: int = 2048 - ffn_hidden_size: int = 10944 - num_attention_heads: int = 16 - kv_channels: int = 16 - q_lora_rank: int = None - num_moe_experts: int = 64 - moe_ffn_hidden_size: int = 1408 - moe_shared_expert_intermediate_size: int = 2816 # 1408 * 2 shared experts - moe_layer_freq: Union[int, List[int]] = field( - default_factory=lambda: [0] + [1] * 
26 - ) # first layer is dense - moe_router_topk: int = 6 - moe_router_num_groups: int = 1 - moe_router_group_topk: int = 1 - moe_router_topk_scaling_factor: float = 1.0 - vocab_size: int = 102400 - - -@dataclass -class DeepSeekV3ModelProvider(DeepSeekModelProvider): - """ - DeepSeek-V3 Model: https://github.com/deepseek-ai/DeepSeek-V3 - """ - - num_layers: int = 61 - hidden_size: int = 7168 - ffn_hidden_size: int = 18432 - num_moe_experts: int = 256 - moe_ffn_hidden_size: int = 2048 - moe_shared_expert_intermediate_size: int = 2048 # 2048 * 1 shared expert - moe_layer_freq: Union[int, List[int]] = field( - default_factory=lambda: [0] * 3 + [1] * 58 - ) # first three layers are dense - moe_router_topk: int = 8 - moe_router_num_groups: int = 8 - moe_router_group_topk: int = 4 - moe_router_topk_scaling_factor: float = 2.5 - make_vocab_size_divisible_by: int = 1280 - moe_router_score_function: str = "sigmoid" - moe_router_enable_expert_bias: bool = True - moe_router_bias_update_rate: float = 1e-3 - mscale: float = 1.0 - mscale_all_dim: float = 1.0 - vocab_size: int = 129280 - - -@dataclass -class MoonlightModelProvider16B(DeepSeekModelProvider): - """ - Moonlight-16B-A3B Model: https://github.com/moonshotai/Moonlight-16B-A3B - - Moonlight is based on DeepSeek-V3. - """ - - max_position_embeddings: int = 4096 - num_layers: int = 27 - hidden_size: int = 2048 - ffn_hidden_size: int = 11264 - num_attention_heads: int = 16 - kv_channels: int = 16 - num_moe_experts: int = 64 - moe_ffn_hidden_size: int = 1408 - moe_shared_expert_intermediate_size: int = 2816 # 1408 * 2 shared expert - moe_layer_freq: Union[int, List[int]] = field( - default_factory=lambda: [0] * 1 + [1] * 26 - ) # first layer is dense - moe_router_topk: int = 6 - moe_router_num_groups: int = 1 - moe_router_group_topk: int = 1 - moe_router_topk_scaling_factor: float = 2.446 - moe_aux_loss_coeff: float = 0.001 - make_vocab_size_divisible_by: int = 1280 - moe_router_score_function: str = "sigmoid" - moe_router_enable_expert_bias: bool = True - rotary_scaling_factor: float = 1.0 - mscale: float = 1.0 - mscale_all_dim: float = 1.0 - rotary_base: float = 50000 - layernorm_epsilon: float = 1e-5 - q_lora_rank: int = None - init_method_std: float = 0.02 - moe_router_bias_update_rate: float = 1e-3 - rotary_percent: float = 1.0 - vocab_size: int = 163840 - - -# ----------------------------------------------------------------------------- -# Deprecated aliases (to be removed in a future release) -# ----------------------------------------------------------------------------- - - -def _warn_deprecated(old_cls: str, new_cls: str) -> None: - if get_rank_safe() == 0: - warnings.warn( - f"{old_cls} is deprecated and will be removed in a future release. Use {new_cls} instead.", - DeprecationWarning, - stacklevel=2, - ) - - -@dataclass -class DeepSeekProvider(DeepSeekModelProvider): - """Deprecated alias for ``DeepSeekModelProvider``. - - Deprecated: - This alias remains for backward compatibility and will be removed in a - future release. Import and use ``DeepSeekModelProvider`` instead. - """ - - def __post_init__(self) -> None: - _warn_deprecated("DeepSeekProvider", "DeepSeekModelProvider") - super().__post_init__() - - -@dataclass -class DeepSeekV2Provider(DeepSeekV2ModelProvider): - """Deprecated alias for ``DeepSeekV2ModelProvider``. - - Deprecated: - This alias remains for backward compatibility and will be removed in a - future release. Import and use ``DeepSeekV2ModelProvider`` instead. 
- """ - - def __post_init__(self) -> None: - _warn_deprecated("DeepSeekV2Provider", "DeepSeekV2ModelProvider") - super().__post_init__() - - -@dataclass -class DeepSeekV2LiteProvider(DeepSeekV2LiteModelProvider): - """Deprecated alias for ``DeepSeekV2LiteModelProvider``. - - Deprecated: - This alias remains for backward compatibility and will be removed in a - future release. Import and use ``DeepSeekV2LiteModelProvider`` instead. - """ - - def __post_init__(self) -> None: - _warn_deprecated("DeepSeekV2LiteProvider", "DeepSeekV2LiteModelProvider") - super().__post_init__() - - -@dataclass -class DeepSeekV3Provider(DeepSeekV3ModelProvider): - """Deprecated alias for ``DeepSeekV3ModelProvider``. - - Deprecated: - This alias remains for backward compatibility and will be removed in a - future release. Import and use ``DeepSeekV3ModelProvider`` instead. - """ - - def __post_init__(self) -> None: - _warn_deprecated("DeepSeekV3Provider", "DeepSeekV3ModelProvider") - super().__post_init__() - - -@dataclass -class MoonlightProvider(MoonlightModelProvider16B): - """Deprecated alias for ``MoonlightModelProvider16B``. - - Deprecated: - This alias remains for backward compatibility and will be removed in a - future release. Import and use ``MoonlightModelProvider16B`` instead. - """ - - def __post_init__(self) -> None: - _warn_deprecated("MoonlightProvider", "MoonlightModelProvider16B") - super().__post_init__() diff --git a/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v2_bridge.py b/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v2_bridge.py deleted file mode 100644 index 75b72847b7..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v2_bridge.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import torch - -from megatron.core.models.gpt.gpt_model import GPTModel - -from megatron.nemo_bridge.models.conversion.mapping_registry import MegatronMappingRegistry -from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge -from megatron.nemo_bridge.models.deepseek.common import ( - get_common_configs, - get_common_mapping_list, -) -from megatron.nemo_bridge.models.deepseek.deepseek_provider import DeepSeekV2ModelProvider -from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM - - -@MegatronModelBridge.register_bridge(source="DeepseekV2ForCausalLM", target=GPTModel) -class DeepSeekV2Bridge(MegatronModelBridge): - """ - Megatron Bridge for DeepSeek-V2. - - As a user you would not use this bridge directly, but through `AutoBridge`. 
- - Example: - >>> from megatron.nemo_bridge import AutoBridge - >>> bridge = AutoBridge.from_hf_pretrained("deepseek-ai/DeepSeek-V2", trust_remote_code=True) - >>> provider = bridge.to_megatron_provider() - """ - - def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> DeepSeekV2ModelProvider: - hf_config = hf_pretrained.config - configs = get_common_configs(hf_pretrained) - - configs["fp16"] = self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16 - configs["bf16"] = self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16 - configs["params_dtype"] = self.dtype_from_hf(hf_config, default=torch.float32) - - configs["make_vocab_size_divisible_by"] = 3200 - configs["moe_aux_loss_coeff"] = hf_config.aux_loss_alpha - - provider = DeepSeekV2ModelProvider(**configs) - return provider - - def mapping_registry(self) -> MegatronMappingRegistry: - mapping_list = get_common_mapping_list() - return MegatronMappingRegistry(*mapping_list) diff --git a/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v3_bridge.py b/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v3_bridge.py index 7c19cfb0ab..b83b90b11d 100644 --- a/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v3_bridge.py +++ b/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v3_bridge.py @@ -1,19 +1,19 @@ # Copyright (c) 2025, BAAI. All rights reserved. # -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge import torch from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.nemo_bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge from megatron.nemo_bridge.models.conversion.param_mapping import AutoMapping from megatron.nemo_bridge.models.deepseek.common import ( get_common_configs, get_common_mapping_list, ) -from megatron.nemo_bridge.models.deepseek.deepseek_provider import DeepSeekV3ModelProvider +from megatron.bridge.models.deepseek.deepseek_provider import DeepSeekV3ModelProvider from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM diff --git a/flagscale/train/megatron/nemo_bridge/models/gpt_full_te_layer_autocast_spec.py b/flagscale/train/megatron/nemo_bridge/models/gpt_full_te_layer_autocast_spec.py deleted file mode 100644 index 7409349ce5..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/gpt_full_te_layer_autocast_spec.py +++ /dev/null @@ -1,347 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. 
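As the deleted DeepSeekV2Bridge docstring indicates, these bridges are not instantiated directly but reached through `AutoBridge`, and the retained DeepSeekV3Bridge is registered the same way via `register_bridge`. A minimal sketch along the lines of that docstring, assuming the V3 bridge mirrors it (checkpoint id shown only for illustration):

from megatron.nemo_bridge import AutoBridge

# AutoBridge resolves the registered bridge from the HF architecture name
# (here, the DeepSeek-V3 architecture) and builds a Megatron model provider
# from the HF config.
bridge = AutoBridge.from_hf_pretrained("deepseek-ai/DeepSeek-V3", trust_remote_code=True)
provider = bridge.to_megatron_provider()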
-# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -from importlib.metadata import version -from typing import Any, Callable, Optional, Union - -import packaging -import torch - -from transformer_engine.pytorch import TransformerLayer - -from megatron.core import parallel_state, tensor_parallel -from megatron.core.fusions.fused_layer_norm import FusedLayerNorm -from megatron.core.transformer.cuda_graphs import CudaGraphManager -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_block import ( - TransformerBlockSubmodules, - get_num_layers_to_build, -) -from megatron.core.transformer.transformer_layer import BaseTransformerLayer -from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint - - -# Copied from nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py -class AutocastTransformerLayer(TransformerLayer): - """ - Wrapper of te.pytorch.TransformerLayer: a single transformerlayer - that takes input with size [s, b, h] and returns an output of - the same size. - """ - - def __init__( - self, - hidden_size: int, - ffn_hidden_size: int, - layernorm_epsilon: float, - num_attention_heads: int, - init_method: Callable, - output_layer_init_method: Callable, - hidden_dropout: float, - attention_dropout: float, - layer_number: Optional[int] = None, - kv_channels: Optional[int] = None, - self_attn_mask_type: str = "causal", - tp_group: Optional[Any] = None, - tp_size: int = 1, - params_dtype: torch.dtype = torch.float32, - get_rng_state_tracker: Optional[Callable] = None, - fuse_wgrad_accumulation: bool = False, - seq_length: Optional[int] = None, - micro_batch_size: Optional[int] = None, - sequence_parallel: bool = False, - apply_residual_connection_post_layernorm: bool = False, - output_layernorm: bool = False, - layer_type: str = "encoder", - drop_path_rate: float = 0, - use_emha: bool = False, - ub_tp_comm_overlap: bool = False, - ub_bulk_wgrad: bool = True, - ub_bulk_dgrad: bool = True, - autocast_dtype: Any = 16, - zero_centered_gamma: bool = False, - device: str = "cuda", - **kwargs, - ) -> None: - transformer_layer_args = { - "hidden_size": hidden_size, - "ffn_hidden_size": ffn_hidden_size, - "layernorm_epsilon": layernorm_epsilon, - "num_attention_heads": num_attention_heads, - "init_method": init_method, - "output_layer_init_method": output_layer_init_method, - "hidden_dropout": hidden_dropout, - "attention_dropout": attention_dropout, - "layer_number": layer_number, - "kv_channels": kv_channels, - "self_attn_mask_type": self_attn_mask_type, - "tp_group": tp_group, - "tp_size": tp_size, - "params_dtype": params_dtype, - "get_rng_state_tracker": get_rng_state_tracker, - "fuse_wgrad_accumulation": fuse_wgrad_accumulation, - "seq_length": seq_length, - "micro_batch_size": micro_batch_size, - "sequence_parallel": sequence_parallel, - "apply_residual_connection_post_layernorm": apply_residual_connection_post_layernorm, - "output_layernorm": output_layernorm, - "layer_type": layer_type, - "drop_path_rate": drop_path_rate, - "set_parallel_mode": tp_size > 1, - "fuse_qkv_params": True, - "zero_centered_gamma": zero_centered_gamma, - "ub_tp_comm_overlap": ub_tp_comm_overlap, - "ub_bulk_wgrad": ub_bulk_wgrad, - "ub_bulk_dgrad": ub_bulk_dgrad, - "device": device, - } - te_version = packaging.version.Version(version("transformer-engine")) - if te_version > packaging.version.Version("1.5.0"): - for comm in ["ag", 
"rs"]: - ub_overlap_flag = "ub_overlap_" + comm - split_gemm_flag = "ub_split_" + comm - atomic_gemm_flag = "ub_atomic_gemm_" + comm - # Use old overlap flags if they were supplied instead - if ub_overlap_flag in kwargs: - transformer_layer_args[ub_overlap_flag] = kwargs[ub_overlap_flag] - else: - transformer_layer_args[ub_overlap_flag] = kwargs.get( - split_gemm_flag, True - ) or kwargs.get(atomic_gemm_flag, False) - if te_version > packaging.version.Version("1.6.0.dev0"): - transformer_layer_args["ub_overlap_rs_dgrad"] = kwargs.get( - "ub_overlap_rs_dgrad", False - ) - else: - transformer_layer_args["ub_split_ag"] = kwargs.get("ub_split_ag", True) - transformer_layer_args["ub_split_rs"] = kwargs.get("ub_split_rs", True) - transformer_layer_args["ub_atomic_gemm_ag"] = kwargs.get("ub_atomic_gemm_ag", False) - transformer_layer_args["ub_atomic_gemm_rs"] = kwargs.get("ub_atomic_gemm_rs", False) - super().__init__(**transformer_layer_args) - - # Dtype for forward pass - self.dtype = torch_dtype_from_precision(autocast_dtype) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor = None, - encoder_output: Optional[torch.Tensor] = None, - enc_dec_attn_mask: Optional[torch.Tensor] = None, - inference_params: Optional[Any] = None, - is_first_microbatch: Optional[bool] = None, - checkpoint_core_attention: Optional[bool] = False, - ) -> torch.Tensor: - """ - Perform a forward pass through the transformer layer. - """ - if self.dtype == torch.float32: - return super().forward( - hidden_states, - attention_mask, - encoder_output=encoder_output, - enc_dec_attn_mask=enc_dec_attn_mask, - inference_params=inference_params, - is_first_microbatch=is_first_microbatch, - checkpoint_core_attention=checkpoint_core_attention, - ) - with torch.autocast(device_type="cuda", dtype=self.dtype): - return super().forward( - hidden_states, - attention_mask=attention_mask, - encoder_output=encoder_output, - enc_dec_attn_mask=enc_dec_attn_mask, - inference_params=inference_params, - is_first_microbatch=is_first_microbatch, - checkpoint_core_attention=checkpoint_core_attention, - ) - - -class TETransformerLayerAutocast(MegatronModule, BaseTransformerLayer): # type: ignore - """ - A MegatronModule that wraps the AutocastTransformerLayer. 
- """ - - def __init__(self, config, layer_number=1, hidden_dropout=None, **kwargs): - super().__init__(config=config) - self.layer_number = layer_number + self._get_layer_offset() - - self.config = config - self.is_first_microbatch = True - precision = "bf16" if config.bf16 else 16 - - transformer_layer_args = { - "hidden_size": config.hidden_size, - "ffn_hidden_size": config.ffn_hidden_size, - "layernorm_epsilon": config.layernorm_epsilon, - "num_attention_heads": config.num_attention_heads, - "init_method": config.init_method, - "output_layer_init_method": config.output_layer_init_method, - "hidden_dropout": config.hidden_dropout, - "attention_dropout": config.attention_dropout, - "layer_number": layer_number + self._get_layer_offset(), - "kv_channels": config.kv_channels, - "tp_size": parallel_state.get_tensor_model_parallel_world_size(), - "params_dtype": config.params_dtype, - "get_rng_state_tracker": tensor_parallel.random.get_cuda_rng_tracker, - "fuse_wgrad_accumulation": config.gradient_accumulation_fusion, - "seq_length": None, # used for jit warmup - "micro_batch_size": None, # used for jit warmup - "sequence_parallel": config.sequence_parallel, - "apply_residual_connection_post_layernorm": config.apply_residual_connection_post_layernorm, - "autocast_dtype": precision, - "ub_tp_comm_overlap": config.tp_comm_overlap, - "ub_bulk_wgrad": config.tp_comm_bulk_wgrad, - "ub_bulk_dgrad": config.tp_comm_bulk_dgrad, - "zero_centered_gamma": config.layernorm_zero_centered_gamma, - "device": "cpu" if config.use_cpu_initialization else "cuda", - } - te_version = packaging.version.Version(version("transformer-engine")) - if te_version > packaging.version.Version("1.5.0"): - # Use old overlap flags if they were supplied instead - transformer_layer_args["ub_overlap_ag"] = ( - config.tp_comm_overlap_ag - if hasattr(config, "tp_comm_overlap_ag") - else config.tp_comm_split_ag or config.tp_comm_atomic_ag - ) - transformer_layer_args["ub_overlap_rs"] = ( - config.tp_comm_overlap_rs - if hasattr(config, "tp_comm_overlap_rs") - else config.tp_comm_split_rs or config.tp_comm_atomic_rs - ) - if te_version > packaging.version.Version("1.6.0.dev0"): - transformer_layer_args["ub_overlap_rs_dgrad"] = ( - config.tp_comm_overlap_rs_dgrad - if hasattr(config, "tp_comm_overlap_rs_dgrad") - else False - ) - else: - transformer_layer_args["ub_split_ag"] = config.tp_comm_split_ag - transformer_layer_args["ub_split_rs"] = config.tp_comm_split_rs - transformer_layer_args["ub_atomic_gemm_ag"] = config.tp_comm_atomic_ag - transformer_layer_args["ub_atomic_gemm_rs"] = config.tp_comm_atomic_rs - self.transformer_layer = AutocastTransformerLayer(**transformer_layer_args) - - if self.config.enable_cuda_graph and self.training: - assert ( - not config.cpu_offloading and config.recompute_granularity is None - ), "Cudagraphs not supported" - self.add_module("cudagraph_manager", CudaGraphManager(config)) - - # Called by MCore's TransformerBlock.forward - # megatron/core/transformer/transformer_block.py - def forward( - self, - hidden_states, - is_first_microbatch=None, - attention_mask=None, - context=None, - context_mask=None, - inference_params=None, - **kwargs, - ): - """Forward function of TETransformerLayerAutocast. Called by MCore's TransformerBlock.forward.""" - # Use is_first_microbatch argument during CUDA graph capture. Use self.is_first_microbatch otherwise. 
- hidden_states = self.transformer_layer.forward( - hidden_states, - attention_mask=attention_mask, - encoder_output=context, - enc_dec_attn_mask=context_mask, - inference_params=inference_params, - is_first_microbatch=( - is_first_microbatch if is_first_microbatch is not None else self.is_first_microbatch - ), - # checkpoint_core_attention, - ) - self.is_first_microbatch = False - context = None - - # External CUDA graph requires returned values to be Tensors - if ( - hasattr(self.config, "external_cuda_graph") - and self.config.external_cuda_graph - and self.training - ): - return hidden_states - return hidden_states, context - - def _get_layer_offset(self): - pipeline_rank = parallel_state.get_pipeline_model_parallel_rank() - - num_layers_per_pipeline_rank = ( - self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size() - ) - - if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: - vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank() - vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() - - total_num_layers = self.config.num_layers - num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size - total_virtual_chunks = total_num_layers // vp_size - offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank) - - else: - # Each stage gets a contiguous set of layers. - if parallel_state.get_pipeline_model_parallel_world_size() > 1: - offset = pipeline_rank * num_layers_per_pipeline_rank - else: - offset = 0 - - return offset - - def sharded_state_dict(self, prefix: str = "", sharded_offsets: tuple = (), metadata=None): - """Get the sharded state dict for the transformer layer.""" - TENSOR_PARALLEL_LAYERS_AXIS_MAP = { - "self_attention.layernorm_qkv.weight": 0, - "self_attention.layernorm_qkv.bias": 0, - "self_attention.proj.weight": 1, - "layernorm_mlp.fc1_weight": 0, - "layernorm_mlp.fc1_bias": 0, - "layernorm_mlp.fc2_weight": 1, - } - - state_dict = self.state_dict(prefix="", keep_vars=True) - sharded_state_dict = make_sharded_tensors_for_checkpoint( - state_dict, prefix, TENSOR_PARALLEL_LAYERS_AXIS_MAP, sharded_offsets - ) - - # TODO: we need to add sharded_state_dict_keys_map to the config. 
Like in TransformerLayer submodules config - # prefixed_map = { - # f'{prefix}{k}': f'{prefix}{v}' - # for k, v in self.config.sharded_state_dict_keys_map.items() - # } - - # if prefixed_map: - # apply_prefix_mapping(sharded_state_dict, prefixed_map) - - return sharded_state_dict - - def __call__(self, *args, **kwargs): - if hasattr(self, "cudagraph_manager"): - return self.cudagraph_manager(self, args, kwargs) - return super().__call__(*args, **kwargs) - - -# Use this spec to use the full Transformer layer from Transformer Engine -def get_gpt_full_te_layer_autocast_spec(transformer_config) -> ModuleSpec: - """Get the ModuleSpec for full Transformer layer from Transformer Engine.""" - num_layers = get_num_layers_to_build(transformer_config) - return TransformerBlockSubmodules( - layer_specs=[ModuleSpec(module=TETransformerLayerAutocast)] * num_layers, - layer_norm=FusedLayerNorm, - ) - - -def torch_dtype_from_precision(precision: Union[int, str]) -> torch.dtype: - """Mapping from precision types to corresponding PyTorch parameter datatype.""" - if precision in ("bf16", "bf16-mixed"): - return torch.bfloat16 - elif precision in (16, "16", "16-mixed"): - return torch.float16 - elif precision in (32, "32", "32-true"): - return torch.float32 - else: - raise ValueError(f"Could not parse the precision of `{precision}` to a valid torch.dtype") diff --git a/flagscale/train/megatron/nemo_bridge/models/gpt_provider.py b/flagscale/train/megatron/nemo_bridge/models/gpt_provider.py deleted file mode 100644 index 322c661097..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/gpt_provider.py +++ /dev/null @@ -1,430 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import contextlib -import inspect -import logging - -from dataclasses import dataclass, field -from functools import partial -from typing import Any, Callable, Literal, Optional, Union - -import torch - -from megatron.core import parallel_state -from megatron.core.models.gpt import GPTModel as MCoreGPTModel -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, -) -from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec -from megatron.core.transformer import ModuleSpec - -from megatron.nemo_bridge.models.model_provider import ModelProviderMixin -from megatron.nemo_bridge.models.transformer_config import TransformerConfig -from megatron.nemo_bridge.utils import fusions -from megatron.nemo_bridge.utils.vocab_utils import calculate_padded_vocab_size - -logger = logging.getLogger(__name__) - - -def transformer_engine_layer_spec(config: "GPTModelProvider") -> ModuleSpec: - """Create a Transformer Engine layer specification based on the provided config.""" - if ( - "use_te_op_fuser" - in inspect.signature(get_gpt_layer_with_transformer_engine_spec).parameters - ): - kwargs = {"use_te_op_fuser": config.use_transformer_engine_op_fuser} - else: - kwargs = {} - return get_gpt_layer_with_transformer_engine_spec( - num_experts=config.num_moe_experts, - moe_grouped_gemm=config.moe_grouped_gemm, - qk_layernorm=config.qk_layernorm, - fp8=bool(config.num_moe_experts and (config.fp8 is not None)), - **kwargs, - ) - - -def transformer_engine_full_layer_spec(config: "GPTModelProvider") -> ModuleSpec: - """Create a full Transformer Engine layer specification with autocast support. 
- - Args: - config: GPT configuration object - - Returns: - ModuleSpec: Module specification for full TE layers - """ - from megatron.nemo_bridge.models.gpt_full_te_layer_autocast_spec import ( - get_gpt_full_te_layer_autocast_spec, - ) - - return get_gpt_full_te_layer_autocast_spec(transformer_config=config) - - -def local_layer_spec(config: "GPTModelProvider") -> ModuleSpec: - """Create a local layer specification without Transformer Engine. - - Args: - config: GPT configuration object - - Returns: - ModuleSpec: Module specification for local implementation layers - """ - return get_gpt_layer_local_spec( - num_experts=config.num_moe_experts, - moe_grouped_gemm=config.moe_grouped_gemm, - qk_layernorm=config.qk_layernorm, - normalization=config.normalization, - ) - - -def quantization_layer_spec(config: "GPTModelProvider") -> ModuleSpec: - """Layer specification for quantization with ModelOpt.""" - return get_gpt_modelopt_spec( - config=config, - local_core_attention=False, - remap_te_layernorm=True, - real_quant_cfg="None", - use_arbitrary_attention_mask=True, - ) - - -def default_layer_spec(config: "GPTModelProvider") -> ModuleSpec: - """Determine the most appropriate layer specification based on availability.""" - if config.restore_modelopt_state: - return quantization_layer_spec(config) - elif config.use_transformer_engine_full_layer_spec: - return transformer_engine_full_layer_spec(config) - else: - return transformer_engine_layer_spec(config) - - -@dataclass -class GPTModelProvider(TransformerConfig, ModelProviderMixin[MCoreGPTModel]): - """Configuration and provider for Megatron Core GPT models. - - This class extends TransformerConfig with GPT-specific parameters and - provides a method to instantiate configured GPT models. - """ - - # Model configuration - fp16_lm_cross_entropy: bool = False - parallel_output: bool = True - share_embeddings_and_output_weights: bool = True - make_vocab_size_divisible_by: int = 128 - position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute" - rotary_base: int = 10000 - rotary_percent: float = 1.0 - seq_len_interpolation_factor: Optional[float] = None - seq_length: int = 1024 - attention_softmax_in_fp32: bool = False - deallocate_pipeline_outputs: bool = True - scatter_embedding_sequence_parallel: bool = True - tp_only_amax_red: bool = False - tp_comm_overlap_cfg: Optional[Union[str, dict[str, Any]]] = None - """Config file when tp_comm_overlap is enabled.""" - - use_transformer_engine_full_layer_spec: bool = False - use_transformer_engine_op_fuser: bool = False - transformer_layer_spec: Union[ModuleSpec, Callable[["GPTModelProvider"], ModuleSpec]] = ( - default_layer_spec - ) - - generation_config: Optional[Any] = None - - # This represents the unpadded vocab size - # The padded vocab size is automatically calculated in the provide() method. - vocab_size: Optional[int] = None - # Set if the tokenizer provides the vocab size. 
In this case, the vocab size will be padded - # Controls whether vocab size should be padded for tensor parallelism - should_pad_vocab: bool = False - - # MoE / FP8 - num_moe_experts: Optional[int] = None - moe_grouped_gemm: bool = False - qk_layernorm: bool = False - fp8: Optional[str] = None - normalization: str = "LayerNorm" - - # Multi-token prediction - mtp_enabled: bool = False - - # Additional parameters that might be needed - init_model_with_meta_device: bool = False - use_te_rng_tracker: bool = False - enable_cuda_graph: bool = False - virtual_pipeline_model_parallel_size: Optional[int] = None - account_for_embedding_in_pipeline_split: bool = False - account_for_loss_in_pipeline_split: bool = False - - # Fusions - masked_softmax_fusion: bool = field(default_factory=fusions.can_enable_masked_softmax_fusion) - cross_entropy_loss_fusion: bool = True # Generally beneficial, no specific dependencies - gradient_accumulation_fusion: bool = field( - default_factory=fusions.can_enable_gradient_accumulation_fusion - ) - bias_activation_fusion: bool = ( - False # Disabled by default as it can interfere with certain architectures - ) - persist_layer_norm: bool = False - bias_dropout_fusion: bool = field(default_factory=fusions.can_enable_bias_dropout_fusion) - apply_rope_fusion: bool = field(default_factory=fusions.can_enable_apply_rope_fusion) - - # If True, restore the modelopt_state that contains quantization, sparsity, speculative decoding transformation state. - # When resuming modelopt_state, we also change the transformer_layer_spec to `megatron.core.post_training.modelopt.gpt.model_specs` which is a combination of local spec + TEDotProductAttention. - - restore_modelopt_state: bool = False - - def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel: - """Configure and instantiate a Megatron Core GPT model based on this configuration. - - Args: - pre_process: Whether to include pre-processing in the model, defaults to first pipeline stage - post_process: Whether to include post-processing in the model, defaults to last pipeline stage - vp_stage: Virtual pipeline stage - - Returns: - MCoreGPTModel: Configured Megatron Core GPT model instance - """ - # Validate fusion configurations - if not fusions.validate_rope_fusion_compatibility(self): - self.apply_rope_fusion = False - - if self.enable_cuda_graph: - assert getattr(self, "use_te_rng_tracker", False), ( - "Transformer engine's RNG tracker is required for cudagraphs, it can be " - "enabled with use_te_rng_tracker=True'." - ) - - vp_size = self.virtual_pipeline_model_parallel_size - is_pipeline_asymmetric = getattr( - self, "account_for_embedding_in_pipeline_split", False - ) or getattr(self, "account_for_loss_in_pipeline_split", False) - is_pipeline_asymmetric |= ( - getattr(self, "num_layers_in_first_pipeline_stage", None) - or getattr(self, "num_layers_in_last_pipeline_stage", None) - ) is not None - is_flexible_pp_layout = is_pipeline_asymmetric or ( - getattr(self, "pipeline_model_parallel_layout", None) is not None - ) - if vp_size and not is_flexible_pp_layout: - p_size = self.pipeline_model_parallel_size - assert ( - self.num_layers // p_size - ) % vp_size == 0, ( - "Make sure the number of model chunks is the same across all pipeline stages." 
- ) - - transformer_layer_spec = self.transformer_layer_spec - if not isinstance(transformer_layer_spec, ModuleSpec): - # Check if the transformer_layer_spec function accepts vp_stage parameter - if "vp_stage" in inspect.signature(transformer_layer_spec).parameters: - transformer_layer_spec = transformer_layer_spec(self, vp_stage=vp_stage) - else: - transformer_layer_spec = transformer_layer_spec(self) - - assert self.vocab_size is not None, "vocab_size must be configured before calling provide()" - if self.should_pad_vocab: - padded_vocab_size = calculate_padded_vocab_size( - self.vocab_size, self.make_vocab_size_divisible_by, self.tensor_model_parallel_size - ) - else: - padded_vocab_size = self.vocab_size - - # Initialize model as meta data instead of allocating data on a device - model_init_device_context = contextlib.nullcontext - if self.init_model_with_meta_device: - model_init_device_context = partial(torch.device, device="meta") - - # Check if mtp_block_spec parameter is supported - kwargs = {} - if "mtp_block_spec" in inspect.signature(MCoreGPTModel.__init__).parameters: - kwargs["mtp_block_spec"] = mtp_block_spec(self, vp_stage=vp_stage) - - with model_init_device_context(): - model = MCoreGPTModel( - self, - transformer_layer_spec=transformer_layer_spec, - vocab_size=padded_vocab_size, - max_sequence_length=self.seq_length, - fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, - parallel_output=self.parallel_output, - share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, - position_embedding_type=self.position_embedding_type, - rotary_percent=self.rotary_percent, - rotary_base=self.rotary_base, - seq_len_interpolation_factor=self.seq_len_interpolation_factor, - pre_process=pre_process - or parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage), - post_process=post_process - or parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage), - scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel, - vp_stage=vp_stage, - **kwargs, - ) - - # If using full TE layer, need to set TP, CP group since the module call - # is not routed through megatron core, which normally handles passing the - # TP, CP group to the TE modules. - # Deep iterate but skip self to avoid infinite recursion. - if self.use_transformer_engine_full_layer_spec: - # Copied from: - # https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py - if parallel_state.get_tensor_model_parallel_world_size() > 1: - for index, child in enumerate(model.modules()): - if index == 0: - continue - if hasattr(child, "set_tensor_parallel_group"): - tp_group = parallel_state.get_tensor_model_parallel_group() - child.set_tensor_parallel_group(tp_group) - - if parallel_state.get_context_parallel_world_size() > 1: - cp_stream = torch.cuda.Stream() - for index, child in enumerate(model.modules()): - if index == 0: - continue - if hasattr(child, "set_context_parallel_group"): - child.set_context_parallel_group( - parallel_state.get_context_parallel_group(), - parallel_state.get_context_parallel_global_ranks(), - cp_stream, - ) - - return model - - -def mtp_block_spec( - config: "GPTModelProvider", vp_stage: Optional[int] = None -) -> Optional[ModuleSpec]: - """Pass in the MTP block spec if model has MTP layers. 
- - Args: - config: GPT configuration object - - Returns: - ModuleSpec: The MTP module specification - """ - if getattr(config, "mtp_num_layers", None): - from megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec - - if isinstance(config.transformer_layer_spec, Callable): - if "vp_stage" in inspect.signature(config.transformer_layer_spec).parameters: - spec = config.transformer_layer_spec(config, vp_stage=vp_stage) - else: - spec = config.transformer_layer_spec(config) - else: - spec = config.transformer_layer_spec - if hasattr(spec, "layer_specs") and len(spec.layer_specs) == 0: - # Get the decoder layer spec explicitly if no decoder layer in the last stage, - # Only happens with block spec (TransformerBlockSubmodules) when using MoE. - spec = default_layer_spec(config) - return get_gpt_mtp_block_spec(config, spec, use_transformer_engine=True, vp_stage=vp_stage) - else: - return None - - -@dataclass -class GPTProvider126M(GPTModelProvider): - """Configuration for a 126M parameter GPT model. - - Predefined configuration for a small GPT model with 12 layers, - 768 hidden size, and 12 attention heads. - """ - - seq_length: int = 2048 - num_layers: int = 12 - hidden_size: int = 768 - ffn_hidden_size: int = 3072 - num_attention_heads: int = 12 - bias_activation_fusion: bool = True - bias_dropout_add_fusion: bool = True - - -@dataclass -class GPTProvider5B(GPTModelProvider): - """Configuration for a 5B parameter GPT model. - - Predefined configuration for a medium-sized GPT model with 24 layers, - 4096 hidden size, and 32 attention heads. - """ - - seq_length: int = 2048 - num_layers: int = 24 - hidden_size: int = 4096 - ffn_hidden_size: int = 16384 - num_attention_heads: int = 32 - bias_activation_fusion: bool = True - bias_dropout_add_fusion: bool = True - - -@dataclass -class GPTProvider7B(GPTModelProvider): - """Configuration for a 7B parameter GPT model. - - Predefined configuration for a medium-sized GPT model with 32 layers, - 4096 hidden size, and 32 attention heads. - """ - - seq_length: int = 2048 - num_layers: int = 32 - hidden_size: int = 4096 - ffn_hidden_size: int = 10880 - num_attention_heads: int = 32 - bias_activation_fusion: bool = True - bias_dropout_add_fusion: bool = True - - -@dataclass -class GPTProvider20B(GPTModelProvider): - """Configuration for a 20B parameter GPT model. - - Predefined configuration for a large GPT model with 44 layers, - 6144 hidden size, and 48 attention heads. - """ - - seq_length: int = 2048 - num_layers: int = 44 - hidden_size: int = 6144 - ffn_hidden_size: int = 24576 - num_attention_heads: int = 48 - bias_activation_fusion: bool = True - bias_dropout_add_fusion: bool = True - - -@dataclass -class GPTProvider40B(GPTModelProvider): - """Configuration for a 40B parameter GPT model. - - Predefined configuration for a large GPT model with 48 layers, - 8192 hidden size, and 64 attention heads. - """ - - seq_length: int = 2048 - num_layers: int = 48 - hidden_size: int = 8192 - ffn_hidden_size: int = 32768 - num_attention_heads: int = 64 - bias_activation_fusion: bool = True - bias_dropout_add_fusion: bool = True - - -@dataclass -class GPTProvider175B(GPTModelProvider): - """Configuration for a 175B parameter GPT model. - - Predefined configuration for a massive GPT model with 96 layers, - 12288 hidden size, and 96 attention heads. 
- """ - - seq_length: int = 2048 - num_layers: int = 96 - hidden_size: int = 12288 - ffn_hidden_size: int = 49152 - num_attention_heads: int = 96 - hidden_dropout: float = 0.0 - attention_dropout: float = 0.0 - bias_activation_fusion: bool = True - bias_dropout_add_fusion: bool = True - layernorm_zero_centered_gamma: bool = True diff --git a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/__init__.py b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/__init__.py index 81f80fd7ac..f42a8bb6ae 100644 --- a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/__init__.py +++ b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/__init__.py @@ -1,8 +1,5 @@ # Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM -from megatron.nemo_bridge.models.hf_pretrained.vlm import PreTrainedVLM -__all__ = ["PreTrainedCausalLM", "PreTrainedVLM"] +__all__ = ["PreTrainedCausalLM"] diff --git a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/base.py b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/base.py deleted file mode 100644 index ad8b3e332d..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/base.py +++ /dev/null @@ -1,237 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import shutil - -from abc import ABC, abstractmethod -from fnmatch import fnmatch -from pathlib import Path -from typing import ClassVar, Dict, List, Optional, Union - -import torch - -from transformers import AutoConfig, PreTrainedModel - -from megatron.nemo_bridge.models.hf_pretrained.state import ( - SafeTensorsStateSource, - StateDict, - StateSource, -) - - -class PreTrainedBase(ABC): - """ - Abstract base class for all pretrained models. - - This class provides a generic mechanism for managing model artifacts - (e.g., config, tokenizer) with lazy loading. Subclasses that are - decorated with `@dataclass` can define artifacts as fields with metadata - specifying a loader method. The `model` itself is handled via a - dedicated property that relies on the abstract `_load_model` method. - - Example: - @dataclass - class MyModel(PreTrainedBase): - config: AutoConfig = field( - init=False, - metadata=artifact(loader="_load_config") - ) - - def _load_model(self) -> "PreTrainedModel": - # Implementation for the loading logic - ... - """ - - model_name_or_path: Union[str, Path] - ARTIFACTS: ClassVar[List[str]] = [] - OPTIONAL_ARTIFACTS: ClassVar[List[str]] = [] - - def __init__(self, **kwargs): - self._state_dict_accessor: Optional[StateDict] = None - self.init_kwargs = kwargs - # Store the original source path for custom modeling file preservation - self._original_source_path: Optional[Union[str, Path]] = None - - def get_artifacts(self) -> Dict[str, str]: - """Get the artifacts dictionary mapping artifact names to their attribute names.""" - return {artifact: f"_{artifact}" for artifact in self.ARTIFACTS} - - def _copy_custom_modeling_files( - self, source_path: Union[str, Path], target_path: Union[str, Path] - ) -> None: - """Copy custom modeling files from source to target directory. - - This preserves custom modeling files that were used during model loading - with trust_remote_code=True, ensuring the saved model can be loaded properly. 
- - Args: - source_path: Source directory containing custom modeling files - target_path: Target directory to copy files to - """ - source_path = Path(source_path) - target_path = Path(target_path) - - # Common custom modeling file patterns - custom_file_patterns = ["*.py", "*.json", "*.jpeg", "*.png", "*.jpg", "*.mp4"] - copied_files = [] - - # First, try to copy from local directory if it exists - if source_path.exists() and source_path.is_dir(): - for pattern in custom_file_patterns: - for file_path in source_path.glob(pattern): - if file_path.is_file(): - target_file = target_path / file_path.name - try: - shutil.copy2(file_path, target_file) - copied_files.append(file_path.name) - except (OSError, IOError): - # Silently skip files that can't be copied - pass - - # If no files were copied and source_path looks like a HuggingFace Hub ID, - # try to download the custom modeling files directly from the Hub - if not copied_files and "/" in str(source_path) and not source_path.exists(): - try: - from huggingface_hub import hf_hub_download, list_repo_files - - # Get list of Python files in the repository - repo_files = list_repo_files(str(source_path)) - print("repo_files: ", repo_files) - for file in repo_files: - # Check if it matches our custom file patterns - if any(fnmatch(file, pattern) for pattern in custom_file_patterns): - try: - downloaded_file = hf_hub_download( - repo_id=str(source_path), - filename=file, - local_dir=target_path, - local_dir_use_symlinks=False, - ) - copied_files.append(file) - except Exception as e: - print("Error downloading file: ", e, "Skipping file...") - # Silently skip files that can't be downloaded - pass - - except Exception as e: - print( - "Error downloading custom modeling files: ", - e, - "Skipping custom modeling files...", - ) - # If HuggingFace Hub operations fail, silently continue - pass - - return copied_files - - def save_artifacts(self, save_directory: Union[str, Path]): - """ - Saves all loaded, generic artifacts that have a `save_pretrained` method - to the specified directory. Note: This does not save the `model` attribute. - - If the model was loaded with trust_remote_code=True, this method will also - attempt to preserve any custom modeling files to ensure the saved model - can be loaded properly. 
- """ - save_path = Path(save_directory) - save_path.mkdir(parents=True, exist_ok=True) - - _ = getattr(self, "config") # trigger lazy loading of config - if hasattr(self, "_config") and self._config is not None: - self._config.save_pretrained(save_path) - - # Iterate over required artifacts to save them in a predictable order - # for name in self.ARTIFACTS: - # # Access the public property to trigger lazy loading if needed - # artifact = getattr(self, name) - # attr_name = f"_{name}" - # if hasattr(self, attr_name): - # if artifact is not None and hasattr(artifact, "save_pretrained"): - # artifact.save_pretrained(save_path) - - # Iterate over optional artifacts - only save if they exist and have save_pretrained - for name in self.OPTIONAL_ARTIFACTS: - artifact = getattr(self, name, None) - if artifact is not None and hasattr(artifact, "save_pretrained"): - artifact.save_pretrained(save_path) - - # Preserve custom modeling files if trust_remote_code was used - if hasattr(self, 'trust_remote_code') and self.trust_remote_code: - # Try original source path first, then fallback to model_name_or_path - source_paths = [] - if hasattr(self, '_original_source_path') and self._original_source_path: - source_paths.append(self._original_source_path) - if hasattr(self, 'model_name_or_path') and self.model_name_or_path: - source_paths.append(self.model_name_or_path) - - for source_path in source_paths: - copied_files = self._copy_custom_modeling_files(source_path, save_path) - if copied_files: - # Successfully copied files, no need to try other paths - break - - @abstractmethod - def _load_model(self) -> PreTrainedModel: - """Subclasses must implement this to load the main model.""" - pass - - @abstractmethod - def _load_config(self) -> AutoConfig: - """Subclasses must implement this to load the model config.""" - pass - - @property - def model(self) -> PreTrainedModel: - """Lazily loads and returns the underlying model.""" - if not hasattr(self, "_model"): - self._model = self._load_model() - return self._model - - @model.setter - def model(self, value: PreTrainedModel): - """Manually set the model.""" - self._model = value - - @property - def config(self) -> AutoConfig: - """Lazy load and return the model config.""" - if not hasattr(self, "_config"): - self._config = self._load_config() - return self._config - - @config.setter - def config(self, value: AutoConfig): - """Set the config manually.""" - self._config = value - - @property - def state(self) -> StateDict: - """ - Get the state dict accessor for pandas-like querying. - - This accessor can be backed by either a fully loaded model in memory - or a ".safetensors" checkpoint on disk, enabling lazy loading of tensors. - - Examples: - model.state() # Get full state dict - model.state["key"] # Get single entry - model.state[["key1", "key2"]] # Get multiple entries - model.state["*.weight"] # Glob pattern - model.state.regex(r".*\\.bias$") # Regex pattern - """ - if self._state_dict_accessor is None: - source: Optional[Union[Dict[str, torch.Tensor], StateSource]] = None - # Prioritize the loaded model's state_dict if available - if hasattr(self, "_model") and self._model is not None: - source = self.model.state_dict() - elif hasattr(self, "model_name_or_path") and self.model_name_or_path: - source = SafeTensorsStateSource(self.model_name_or_path) - - if source is None: - raise ValueError( - "Cannot create StateDict accessor: model is not loaded and model_name_or_path is not set." 
- ) - self._state_dict_accessor = StateDict(source) - return self._state_dict_accessor diff --git a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/causal_lm.py b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/causal_lm.py index e5383f11fc..c8e3c35006 100644 --- a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/causal_lm.py +++ b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/causal_lm.py @@ -1,7 +1,5 @@ #!/usr/bin/env python3 # Copyright (c) 2025, BAAI. All rights reserved. -# -# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge import sys @@ -10,109 +8,9 @@ import torch -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoTokenizer, - GenerationConfig, - PreTrainedTokenizer, -) -from transformers.generation.utils import GenerateOutput - -from megatron.nemo_bridge.models.hf_pretrained.base import PreTrainedBase -from megatron.nemo_bridge.models.hf_pretrained.safe_config_loader import ( - safe_load_config_with_retry, -) - -# Python 3.12+ supports PEP 692 (TypedDict Unpack) -if sys.version_info >= (3, 12): - from typing import TypedDict, Unpack -else: - from typing_extensions import TypedDict, Unpack - - -CausalLMType = TypeVar("CausalLMType", bound=AutoModelForCausalLM) - - -class PreTrainedCausalLM(PreTrainedBase, Generic[CausalLMType]): - """ - A generic class for Pretrained Causal Language Models with lazy loading. - - Allows type-safe access to specific model implementations like LlamaForCausalLM. - - Examples: - Basic usage with lazy loading: - >>> from mbridge.pretrained import PreTrainedCausalLM - >>> # Create instance - no model loading happens yet - >>> model = PreTrainedCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - >>> # Components are loaded on first access - >>> config = model.config # Loads config - >>> tokenizer = model.tokenizer # Loads tokenizer - >>> # Generate text - model is loaded here - >>> inputs = model.encode("Hello, how are you?") - >>> outputs = model.generate(**inputs, max_length=50) - >>> print(model.decode(outputs[0], skip_special_tokens=True)) - - Using specific model types with type hints: - >>> from transformers import LlamaForCausalLM - >>> from mbridge.pretrained import PreTrainedCausalLM - >>> # Type-safe access to Llama-specific features - >>> llama_model: PreTrainedCausalLM[LlamaForCausalLM] = PreTrainedCausalLM.from_pretrained( - ... "meta-llama/Llama-2-7b-chat-hf", - ... torch_dtype=torch.float16, - ... device="cuda" - ... ) - >>> # Access Llama-specific attributes - >>> model_instance = llama_model.model # Type is LlamaForCausalLM - - Loading with custom configurations: - >>> # Load model with specific settings - >>> model = PreTrainedCausalLM.from_pretrained( - ... "gpt2", - ... device="cuda:0", - ... torch_dtype=torch.bfloat16, - ... attn_implementation="flash_attention_2", - ... load_in_8bit=True - ... ) - >>> # Override generation config - >>> from transformers import GenerationConfig - >>> model.generation_config = GenerationConfig( - ... max_length=100, - ... temperature=0.7, - ... top_p=0.9, - ... do_sample=True - ... 
) - - Manual component management: - >>> # Create empty instance - >>> model = PreTrainedCausalLM() - >>> # Manually set components - >>> from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM - >>> model.config = AutoConfig.from_pretrained("microsoft/phi-2") - >>> model.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2") - >>> model.model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2") - >>> # Save all components - >>> model.save_artifacts("./my_model") - - Batch processing example: - >>> # Process multiple prompts - >>> prompts = [ - ... "The capital of France is", - ... "Machine learning is", - ... "Python programming language was created by" - ... ] - >>> # Encode all prompts - >>> inputs = model.encode(prompts, padding=True, truncation=True) - >>> # Generate completions - >>> outputs = model.generate(**inputs, max_new_tokens=20) - >>> # Decode results - >>> for i, output in enumerate(outputs): - ... print(f"Prompt {i+1}: {model.decode(output, skip_special_tokens=True)}") - """ - - ARTIFACTS = ["tokenizer"] - OPTIONAL_ARTIFACTS = ["generation_config"] +from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM as OriginalPreTrainedCausalLM +class PreTrainedCausalLM(OriginalPreTrainedCausalLM): def __init__( self, model_name_or_path: Optional[Union[str, Path]] = None, @@ -121,537 +19,49 @@ def __init__( trust_remote_code: bool = False, **kwargs, ): - """ - Initialize a Pretrained Causal LM with lazy loading. - - Args: - model_name_or_path: HuggingFace model identifier or local path - device: Device to load model on (e.g., 'cuda', 'cpu') - torch_dtype: Data type to load model in (e.g., torch.float16) - trust_remote_code: Whether to trust remote code when loading - **kwargs: Additional arguments passed to from_pretrained methods - """ - self._model_name_or_path = model_name_or_path - # self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.device = "cpu" - self.torch_dtype = torch_dtype - self.trust_remote_code = trust_remote_code - super().__init__(**kwargs) - # Store the original source path for custom modeling file preservation - if model_name_or_path and trust_remote_code: - self._original_source_path = model_name_or_path - - def _load_model(self) -> CausalLMType: - """Load the model.""" - if self.model_name_or_path is None: - raise ValueError("model_name_or_path must be provided to load model") - - model_kwargs = {"trust_remote_code": self.trust_remote_code, **self.init_kwargs} - if self.torch_dtype is not None: - model_kwargs["torch_dtype"] = self.torch_dtype - config = getattr(self, "_config", None) - if config is not None: - model_kwargs["config"] = config - - model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, **model_kwargs) - model = model.to(self.device) - - generation_config = getattr(self, "_generation_config", None) - if generation_config is not None and hasattr(model, "generation_config"): - model.generation_config = generation_config - return model - - def _load_config(self) -> AutoConfig: - """Load the model config with thread-safety protection.""" - if self.model_name_or_path is None: - raise ValueError("model_name_or_path must be provided to load config") - return safe_load_config_with_retry( - self.model_name_or_path, trust_remote_code=self.trust_remote_code, **self.init_kwargs - ) - - def _load_tokenizer(self) -> PreTrainedTokenizer: - """Load the tokenizer.""" - if self.model_name_or_path is None: - raise ValueError("model_name_or_path must be provided to load 
tokenizer") - tokenizer = AutoTokenizer.from_pretrained( - self.model_name_or_path, trust_remote_code=self.trust_remote_code, **self.init_kwargs - ) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - return tokenizer - - def _load_generation_config(self) -> Optional[GenerationConfig]: - """Load the generation config.""" - if self.model_name_or_path is not None: - try: - return GenerationConfig.from_pretrained( - self.model_name_or_path, - trust_remote_code=self.trust_remote_code, - **self.init_kwargs, - ) - except Exception: - # Not all models have generation configs - pass - return None - - @property - def generation_config(self) -> Optional[GenerationConfig]: - """Lazy load and return the generation config.""" - if not hasattr(self, "_generation_config"): - self._generation_config = self._load_generation_config() - return self._generation_config - - @generation_config.setter - def generation_config(self, value: GenerationConfig): - """Set the generation config manually.""" - self._generation_config = value - # Update model's generation config if model is already loaded - model = getattr(self, "_model", None) - if model is not None and hasattr(model, "generation_config"): - model.generation_config = value - - @property - def tokenizer(self) -> PreTrainedTokenizer: - """Lazy load and return the tokenizer.""" - if not hasattr(self, "_tokenizer"): - self._tokenizer = self._load_tokenizer() - return self._tokenizer - - @tokenizer.setter - def tokenizer(self, value: PreTrainedTokenizer): - """Set the tokenizer manually.""" - self._tokenizer = value - - @property - def model_name_or_path(self) -> Optional[Union[str, Path]]: - """Return the model name or path.""" - return self._model_name_or_path - - @property - def has_model(self) -> bool: - """Check if model has been loaded.""" - return hasattr(self, "_model") and self._model is not None - - @property - def model(self) -> CausalLMType: - """Lazy load and return the underlying model.""" - return super().model - - @model.setter - def model(self, value: CausalLMType): - """Set the model manually and move it to the appropriate device.""" - self._model = value - if self._model is not None: - self._model = self._model.to(self.device) - - @classmethod - def from_pretrained( - cls, - model_name_or_path: Union[str, Path], - device: Optional[Union[str, torch.device]] = None, - torch_dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - **kwargs, - ) -> "PreTrainedCausalLM[CausalLMType]": - """ - Create a PreTrainedCausalLM instance for lazy loading. - - Args: - model_name_or_path: HuggingFace model identifier or local path - device: Device to load model on - torch_dtype: Data type to load model in - trust_remote_code: Whether to trust remote code - **kwargs: Additional arguments for from_pretrained methods - - Returns: - PreTrainedCausalLM instance configured for lazy loading - """ - return cls( + super().__init__( model_name_or_path=model_name_or_path, - device=device, + device=self.device, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, - **kwargs, + **kwargs ) + #if hasattr(self, '_model') and self._model is not None: + # self._model.to("cpu") - def generate( - self, input_ids: Optional[torch.LongTensor] = None, **kwargs: Unpack["GenerateKwargs"] - ) -> Union[torch.LongTensor, GenerateOutput]: - """ - Generate text using the underlying language model. 
- - This method forwards all arguments to the model's generate method, - supporting all generation strategies provided by the transformers library. - - Common parameters include: - inputs (torch.LongTensor, optional): Input token IDs. If not provided, - will generate from the beginning of sequence token. - max_length (int, optional): Maximum length of generated sequence. - Defaults to model's max_length configuration. - min_length (int, optional): Minimum length of generated sequence. - max_new_tokens (int, optional): Maximum number of tokens to generate, - ignoring the number of tokens in the prompt. - do_sample (bool, optional): Whether to use sampling. Defaults to False - (greedy decoding). - temperature (float, optional): Temperature for sampling. Higher values - produce more random outputs. Typical range: 0.1-2.0. - top_p (float, optional): Nucleus sampling threshold. Only tokens with - cumulative probability up to top_p are considered. Range: 0.0-1.0. - top_k (int, optional): Only consider the top k tokens for sampling. - num_beams (int, optional): Number of beams for beam search. 1 means - no beam search. - repetition_penalty (float, optional): Penalty for repeating tokens. - Values > 1.0 discourage repetition. - pad_token_id (int, optional): ID of padding token. - eos_token_id (int or List[int], optional): ID(s) of end-of-sequence token(s). - use_cache (bool, optional): Whether to use past key values to speed up - generation. Defaults to True. - - Returns: - torch.LongTensor or transformers.generation.utils.GenerateOutput: - Generated token IDs. If return_dict_in_generate=True, returns a - GenerateOutput object containing generated sequences and additional - information like scores. - - Examples: - >>> # Basic generation - >>> model = PreTrainedCausalLM.from_pretrained("gpt2") - >>> inputs = model.encode("Hello, how are") - >>> outputs = model.generate(inputs["input_ids"], max_length=20) - >>> print(model.decode(outputs[0])) - - >>> # Generation with sampling - >>> outputs = model.generate( - ... inputs["input_ids"], - ... max_length=50, - ... do_sample=True, - ... temperature=0.8, - ... top_p=0.9 - ... ) - - >>> # Beam search - >>> outputs = model.generate( - ... inputs["input_ids"], - ... max_length=50, - ... num_beams=5, - ... early_stopping=True - ... ) - - Note: - For detailed documentation of all parameters, see the transformers - library documentation for generation methods. + def save_artifacts(self, save_directory: Union[str, Path]): """ - model = self.model # Ensures model is loaded - # Sync generation config if it has been set on the wrapper - generation_config = getattr(self, "_generation_config", None) - if generation_config is not None and hasattr(model, "generation_config"): - model.generation_config = generation_config - return model.generate(input_ids, **kwargs) - - def __call__(self, *args, **kwargs): - """Forward call to model.""" - return self.model(*args, **kwargs) + Saves all loaded, generic artifacts that have a `save_pretrained` method + to the specified directory. Note: This does not save the `model` attribute. - def encode( - self, text: Union[str, List[str]], **kwargs: Unpack["EncodeKwargs"] - ) -> Dict[str, torch.Tensor]: - """ - Encode text into token IDs using the model's tokenizer. - - This method tokenizes input text and returns tensors ready for model input. - The output is automatically moved to the same device as the model. - - Args: - text (str or List[str]): Input text to encode. Can be a single string - or a list of strings for batch encoding. 
- **kwargs: Additional arguments passed to the tokenizer. Common options: - padding (bool or str, optional): Padding strategy. - - True or 'longest': Pad to longest sequence in batch - - 'max_length': Pad to max_length - - False or 'do_not_pad': No padding (default) - truncation (bool or str, optional): Truncation strategy. - - True or 'longest_first': Truncate to max_length - - 'only_first': Truncate only first sequence (for pairs) - - False: No truncation - max_length (int, optional): Maximum length of returned sequences. - Defaults to model's max_length. - add_special_tokens (bool, optional): Whether to add special tokens - (e.g., [CLS], [SEP]). Defaults to True. - return_attention_mask (bool, optional): Whether to return attention - mask. Defaults to True. - return_token_type_ids (bool, optional): Whether to return token - type IDs (for models like BERT). Defaults to True if model - expects them. - - Returns: - Dict[str, torch.Tensor]: Dictionary containing: - - input_ids: Token IDs tensor of shape (batch_size, sequence_length) - - attention_mask: Attention mask tensor of same shape (if applicable) - - token_type_ids: Token type IDs tensor (if applicable) - Additional keys may be present depending on the tokenizer. - - Examples: - >>> model = PreTrainedCausalLM.from_pretrained("gpt2") - >>> # Single text encoding - >>> tokens = model.encode("Hello world!") - >>> print(tokens["input_ids"].shape) # torch.Size([1, 3]) - - >>> # Batch encoding with padding - >>> texts = ["Hello!", "How are you doing today?"] - >>> tokens = model.encode(texts, padding=True) - >>> print(tokens["input_ids"].shape) # torch.Size([2, 6]) - - >>> # Encoding with truncation - >>> tokens = model.encode( - ... "This is a very long text that might exceed the maximum length", - ... truncation=True, - ... max_length=10 - ... ) - - Note: - The returned tensors are on the same device as the model, ready - for immediate use in forward passes or generation. - """ - # Only set return_tensors default if not provided - if "return_tensors" not in kwargs: - kwargs["return_tensors"] = "pt" - - return self.tokenizer(text, **kwargs).to(self.device) - - def decode( - self, token_ids: Union[int, List[int], torch.Tensor], **kwargs: Unpack["DecodeKwargs"] - ) -> str: - """ - Decode token IDs back into text using the model's tokenizer. - - This method converts token IDs (from model output or encode method) - back into human-readable text. - - Args: - token_ids (int, List[int], or torch.Tensor): Token IDs to decode. - Can be: - - Single token ID (int) - - List of token IDs - - 1D tensor of token IDs - - 2D tensor (will decode the first sequence) - **kwargs: Additional arguments passed to the tokenizer's decode method: - skip_special_tokens (bool, optional): Whether to remove special - tokens (e.g., [PAD], [CLS], [SEP]) from output. Defaults to True. - clean_up_tokenization_spaces (bool, optional): Whether to clean up - tokenization artifacts (extra spaces, etc.). Defaults to True. - - Returns: - str: Decoded text string. - - Examples: - >>> model = PreTrainedCausalLM.from_pretrained("gpt2") - >>> # Encode and decode round-trip - >>> text = "Hello, world!" - >>> tokens = model.encode(text) - >>> decoded = model.decode(tokens["input_ids"][0]) - >>> print(decoded) # "Hello, world!" - - >>> # Decode generated tokens - >>> inputs = model.encode("The weather is") - >>> outputs = model.generate(inputs["input_ids"], max_length=10) - >>> decoded = model.decode(outputs[0]) - >>> print(decoded) # "The weather is nice today..." 
- - >>> # Decode without special tokens - >>> token_ids = [101, 7592, 1010, 2088, 999, 102] # BERT-style tokens - >>> decoded = model.decode(token_ids, skip_special_tokens=True) - >>> print(decoded) # "Hello, world!" - - >>> # Decode keeping special tokens - >>> decoded = model.decode(token_ids, skip_special_tokens=False) - >>> print(decoded) # "[CLS] Hello, world! [SEP]" - - Note: - If a 2D tensor is provided (batch of sequences), only the first - sequence is decoded. For batch decoding, use tokenizer.batch_decode() - directly or iterate over the sequences. - """ - return self.tokenizer.decode(token_ids, **kwargs) - - def to(self, device: Union[str, torch.device]): - """Move model to specified device.""" - self.device = device - if self.has_model: - self._model = self._model.to(device) - return self - - def half(self): - """Convert model to half precision (float16).""" - if self.has_model: - self._model = self._model.half() - return self - - def float(self): - """Convert model to full precision (float32).""" - if self.has_model: - self._model = self._model.float() - return self - - def save_pretrained(self, save_directory: Union[str, Path]): - """ - Save all components (model, tokenizer, config, generation_config) to a directory. - - This method saves: - - Model weights and config - - Tokenizer files - - Generation config (if available) - - Args: - save_directory: Path to directory where components will be saved + If the model was loaded with trust_remote_code=True, this method will also + attempt to preserve any custom modeling files to ensure the saved model + can be loaded properly. """ save_path = Path(save_directory) save_path.mkdir(parents=True, exist_ok=True) - # Save model if loaded - if hasattr(self, "_model") and self._model is not None: - self._model.save_pretrained(save_path) - - # Use the base class save_artifacts to save config and all artifacts - self.save_artifacts(save_path) - - @property - def dtype(self) -> Optional[torch.dtype]: - """Get model's dtype if loaded.""" - if self.has_model: - try: - return next(self.model.parameters()).dtype - except StopIteration: - return None - return None - - @property - def num_parameters(self) -> Optional[int]: - """Get total number of parameters if model is loaded.""" - if self.has_model: - return sum(p.numel() for p in self.model.parameters()) - return None - - def __repr__(self) -> str: - """Return a string representation of the PreTrainedCausalLM instance.""" - try: - # Access config to trigger lazy loading for a richer repr - _ = self.config - except Exception: - # If loading fails, repr shouldn't crash. 
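# --- Example sketch (illustrative, not part of the patch) ---------------------
# Sketch of the save/inspect flow that the removed save_pretrained implemented
# and that the wrapper now inherits from megatron.bridge (assuming the upstream
# implementation matches the code deleted here): force the lazy model load,
# write weights plus tokenizer/config artifacts into one directory, then read
# the convenience properties. Paths and checkpoint id are illustrative.
from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM

lm = PreTrainedCausalLM(model_name_or_path="gpt2")
_ = lm.model                            # trigger the lazy model load
lm.save_pretrained("./exported_gpt2")   # weights + loaded artifacts in one folder
print(lm.num_parameters, lm.dtype)      # e.g. ~124M parameters, torch.float32
# ------------------------------------------------------------------------------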
- pass - - lines = [f"{self.__class__.__name__}("] - for name, attr_name in sorted(self.get_artifacts().items()): - is_loaded = hasattr(self, attr_name) - artifact_instance = getattr(self, attr_name, None) if is_loaded else None - - type_name = "N/A" - details = "not loaded" - if is_loaded and artifact_instance is not None: - type_name = artifact_instance.__class__.__name__ - if name == "tokenizer": - vocab = getattr(artifact_instance, "vocab_size", "N/A") - details = f"vocab_size={vocab}" - elif name == "config": - m_type = getattr(artifact_instance, "model_type", "N/A") - details = f"model_type={m_type}" - else: - details = "loaded" - lines.append(f" ({name}): {type_name} [{details}]") - - # Manually add model repr - model_repr_content: str - if self.has_model: - model_class_name = self.model.__class__.__name__ - # Assuming self.config is loaded or available here due to earlier attempt - config = self.config - layers = getattr(config, "num_hidden_layers", "N/A") - hidden_size = getattr(config, "hidden_size", "N/A") - model_repr_content = ( - f"{model_class_name} [layers={layers}, hidden_size={hidden_size}, loaded]" - ) - elif "config" in self.__dict__: # Model not loaded, but config is - config = self.config - model_class_name_from_hf_config = "CausalLM" # Default - if hasattr(config, "architectures") and config.architectures: - model_class_name_from_hf_config = config.architectures[0] - elif getattr(config, "model_type", None): - mt = config.model_type - model_class_name_from_hf_config = f"{mt.capitalize()}Model" if mt else "CausalLM" - - details_parts = [] - if getattr(config, "num_hidden_layers", None) is not None: - details_parts.append(f"layers={config.num_hidden_layers}") - if getattr(config, "hidden_size", None) is not None: - details_parts.append(f"hidden_size={config.hidden_size}") - - details_str = ", ".join(details_parts) - status_suffix = "not loaded" - if details_str: - model_repr_content = ( - f"{model_class_name_from_hf_config}({details_str}) [{status_suffix}]" - ) - else: - model_repr_content = f"{model_class_name_from_hf_config} [{status_suffix}]" - else: # Model and Config also not loaded - model_repr_content = "AutoModelForCausalLM [not loaded]" - - lines.append(f" (model): {model_repr_content}") - - lines.sort() - - params_str = f"{self.num_parameters:,}" if self.num_parameters is not None else "N/A" - dtype_str = str(self.dtype).replace("torch.", "") if self.dtype is not None else "N/A" - lines.extend( - [ - f" (parameters): {params_str}", - f" (device): {str(self.device)}", - f" (dtype): {dtype_str}", - ")", - ] - ) - return "\n".join(lines) - - -# TypedDict definitions for method parameters -class GenerateKwargs(TypedDict, total=False): - """TypedDict for generate method parameters.""" - - attention_mask: Optional[torch.Tensor] - max_length: Optional[int] - max_new_tokens: Optional[int] - min_length: Optional[int] - do_sample: Optional[bool] - temperature: Optional[float] - top_k: Optional[int] - top_p: Optional[float] - repetition_penalty: Optional[float] - pad_token_id: Optional[int] - eos_token_id: Optional[Union[int, List[int]]] - bos_token_id: Optional[int] - num_beams: Optional[int] - num_return_sequences: Optional[int] - early_stopping: Optional[bool] - use_cache: Optional[bool] - return_dict_in_generate: Optional[bool] - output_scores: Optional[bool] - output_attentions: Optional[bool] - - -class EncodeKwargs(TypedDict, total=False): - """TypedDict for encode method parameters.""" - - padding: Union[bool, str] - truncation: Union[bool, str] - max_length: 
Optional[int] - add_special_tokens: bool - return_attention_mask: bool - return_token_type_ids: Optional[bool] - return_tensors: str - - -class DecodeKwargs(TypedDict, total=False): - """TypedDict for decode method parameters.""" - - skip_special_tokens: bool - clean_up_tokenization_spaces: bool + _ = getattr(self, "config") # trigger lazy loading of config + if hasattr(self, "_config") and self._config is not None: + self._config.save_pretrained(save_path) + + for name in self.OPTIONAL_ARTIFACTS: + artifact = getattr(self, name, None) + if artifact is not None and hasattr(artifact, "save_pretrained"): + artifact.save_pretrained(save_path) + + # Preserve custom modeling files if trust_remote_code was used + if hasattr(self, 'trust_remote_code') and self.trust_remote_code: + # Try original source path first, then fallback to model_name_or_path + source_paths = [] + if hasattr(self, '_original_source_path') and self._original_source_path: + source_paths.append(self._original_source_path) + if hasattr(self, 'model_name_or_path') and self.model_name_or_path: + source_paths.append(self.model_name_or_path) + + for source_path in source_paths: + copied_files = self._copy_custom_modeling_files(source_path, save_path) + if copied_files: + # Successfully copied files, no need to try other paths + break diff --git a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/safe_config_loader.py b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/safe_config_loader.py deleted file mode 100644 index 9d5e9490aa..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/safe_config_loader.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -""" -Thread-safe configuration loading utilities. - -This module provides utilities for safely loading HuggingFace model configurations -in multi-threaded environments, preventing race conditions that can occur when -multiple threads try to download and cache the same model simultaneously. -""" - -import hashlib -import os -import time - -from pathlib import Path -from typing import Union - -import filelock - -from transformers import AutoConfig -from transformers.configuration_utils import PretrainedConfig - - -def safe_load_config_with_retry( - path: Union[str, Path], - trust_remote_code: bool = False, - max_retries: int = 3, - base_delay: float = 1.0, - **kwargs, -) -> PretrainedConfig: - """ - Thread-safe and process-safe configuration loading with retry logic. - - This function prevents race conditions when multiple threads/processes - try to download and cache the same model configuration simultaneously. - Uses file locking (if filelock is available) to coordinate access across - processes. - - Args: - path: HuggingFace model ID or path to model directory - trust_remote_code: Whether to trust remote code when loading config - max_retries: Maximum number of retry attempts (default: 3) - base_delay: Base delay in seconds for exponential backoff (default: 1.0) - **kwargs: Additional arguments passed to AutoConfig.from_pretrained - - Returns: - PretrainedConfig: The loaded model configuration - - Raises: - ValueError: If config loading fails after all retries - - Environment Variables: - MEGATRON_CONFIG_LOCK_DIR: Override the directory where lock files are created. - Default: ~/.cache/huggingface/ - Useful for multi-node setups where a shared lock directory is needed. 
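# --- Example sketch (illustrative, not part of the patch) ---------------------
# Hedged re-creation of the locking scheme the docstring above describes: hash
# the model id, take a filelock under MEGATRON_CONFIG_LOCK_DIR (or the HF cache
# dir), and call AutoConfig.from_pretrained while holding the lock so concurrent
# processes do not race on the download cache. Model id is illustrative.
import hashlib
import os
from pathlib import Path

import filelock
from transformers import AutoConfig

model_id = "gpt2"
lock_dir = Path(os.getenv("MEGATRON_CONFIG_LOCK_DIR") or Path.home() / ".cache" / "huggingface")
lock_dir.mkdir(parents=True, exist_ok=True)
lock_file = lock_dir / f".megatron_config_lock_{hashlib.md5(model_id.encode()).hexdigest()}.lock"
with filelock.FileLock(str(lock_file), timeout=60):
    config = AutoConfig.from_pretrained(model_id)
# ------------------------------------------------------------------------------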
- - Example: - >>> config = safe_load_config_with_retry("meta-llama/Meta-Llama-3-8B") - >>> print(config.model_type) - - >>> # With custom retry settings - >>> config = safe_load_config_with_retry( - ... "gpt2", - ... max_retries=5, - ... base_delay=0.5, - ... trust_remote_code=True - ... ) - - >>> # Multi-node setup with shared lock directory - >>> import os - >>> os.environ["MEGATRON_CONFIG_LOCK_DIR"] = "/shared/locks" - >>> config = safe_load_config_with_retry("meta-llama/Meta-Llama-3-8B") - """ - last_exception = None - - for attempt in range(max_retries + 1): - try: - # Use file locking for process-safe access - # Create a lock file based on the path hash to avoid conflicts - path_hash = hashlib.md5(str(path).encode()).hexdigest() - - # Allow override of lock directory via environment variable - # This is useful for multi-node setups where a shared lock directory is needed - lock_dir = os.getenv("MEGATRON_CONFIG_LOCK_DIR") - if lock_dir: - lock_file = Path(lock_dir) / f".megatron_config_lock_{path_hash}" - else: - lock_file = ( - Path.home() / ".cache" / "huggingface" / f".megatron_config_lock_{path_hash}" - ) - - lock_file.parent.mkdir(parents=True, exist_ok=True) - - with filelock.FileLock(str(lock_file) + ".lock", timeout=60): - return AutoConfig.from_pretrained( - path, trust_remote_code=trust_remote_code, **kwargs - ) - - except Exception as e: - last_exception = e - - # Don't retry on certain types of errors - error_msg = str(e).lower() - if any( - phrase in error_msg - for phrase in [ - "does not appear to have a file named config.json", - "repository not found", - "entry not found", - "401 client error", - "403 client error", - ] - ): - # Model doesn't exist or access denied, no point retrying - raise ValueError( - f"Failed to load configuration from {path}. " - f"Ensure the path is valid and contains a config.json file. " - f"Error: {e}" - ) from e - - if attempt < max_retries: - # Exponential backoff with jitter - delay = base_delay * (2**attempt) + (time.time() % 1) * 0.1 - time.sleep(delay) - else: - # Final attempt failed - break - - # All retries exhausted - raise ValueError( - f"Failed to load configuration from {path} after {max_retries + 1} attempts. " - f"This might be due to network issues or concurrent access conflicts. " - f"Last error: {last_exception}" - ) from last_exception diff --git a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/state.py b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/state.py deleted file mode 100644 index 01c401ec52..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/state.py +++ /dev/null @@ -1,850 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import fnmatch -import json -import re - -from abc import ABC, abstractmethod -from collections import defaultdict -from collections.abc import Mapping -from functools import lru_cache -from pathlib import Path -from typing import Dict, Iterable, List, Optional, Pattern, Tuple, Union, overload - -import torch - - -class StateDict(Mapping[str, torch.Tensor]): - """ - A state dict accessor that provides a unified interface for querying model - checkpoints. - - `StateDict` allows for efficient and flexible access to tensor data from - various sources, such as in-memory dictionaries or directories of - `.safetensors` files. 
A key feature is its ability to query and load only - the required tensors without loading the entire checkpoint into memory, - making it highly memory-efficient for large models. - - It supports a flexible, pandas-like querying interface that allows for - accessing tensors by exact name, a list of names, glob patterns, or regular - expressions. This makes it easy to inspect and manipulate model - checkpoints. - - Examples: - >>> # Setup an example StateDict from an in-memory dictionary - >>> import torch - >>> import re - >>> d = { - ... "model.layer.0.weight": torch.randn(10, 10), - ... "model.layer.0.bias": torch.randn(10), - ... "model.layer.1.weight": torch.randn(10, 10), - ... "model.layer.1.bias": torch.randn(10), - ... } - >>> state = StateDict(d) - >>> - >>> # 1. Access a single tensor by exact key - >>> state["model.layer.0.weight"].shape - torch.Size([10, 10]) - >>> - >>> # 2. Access multiple tensors with a list of strings - >>> list(state[["model.layer.0.weight", "model.layer.1.weight"]].keys()) - ['model.layer.0.weight', 'model.layer.1.weight'] - >>> - >>> # 3. Access with a glob pattern - >>> sorted(list(state.glob("model.layer.*.bias").keys())) - ['model.layer.0.bias', 'model.layer.1.bias'] - >>> - >>> # 4. Access with a compiled regex pattern - >>> regex = re.compile(r"model\\\\.layer\\\\.0\\\\..*") - >>> sorted(list(state[regex].keys())) - ['model.layer.0.bias', 'model.layer.0.weight'] - - The same querying flexibility applies to checkpoints on disk. The following - is a conceptual example of using `StateDict` with a `SafetensorsStateSource` - to query a sharded checkpoint without loading all of it into memory. - - .. code-block:: python - - # Assume SafetensorsStateSource is available - # from megatron.nemo_bridge.models.state import SafetensorsStateSource - - # Imagine a directory 'my_model_checkpoint/' with sharded weights. - state_from_disk = StateDict(SafetensorsStateSource('my_model_checkpoint/')) - - # You can query it just like the in-memory dictionary. Only the required - # tensors (e.g., all weight tensors) will be loaded from disk. - weights = state_from_disk.glob("model.layer.*.weight") - """ - - source: "StateSource" - - def __init__(self, source: Dict[str, torch.Tensor] | "StateSource"): - """ - Initializes the StateDict query accessor. - - Args: - source: The source of the tensor data. This can be a standard - Python dictionary mapping tensor names to `torch.Tensor` objects, - or an instance of a `StateSource` subclass (e.g., - `SafetensorsStateSource`) for more advanced, out-of-memory - access. - """ - if isinstance(source, dict): - source = DictStateSource(source) - - if not isinstance(source, StateSource): - raise TypeError(f"StateDict source must be a dict or a StateSource, got {type(source)}") - - self.source = source - - def _get_all_keys(self) -> List[str]: - """ - Get all available tensor keys from the underlying source. - """ - return self.source.get_all_keys() - - def _load_tensors(self, keys_to_load: List[str]) -> Dict[str, torch.Tensor]: - """ - Load specified tensors from the underlying source. - """ - return self.source.load_tensors(keys_to_load) - - def _match_keys(self, pattern: Union[str, Pattern]) -> List[str]: - """Match keys against a glob pattern or regex.""" - all_keys = self._get_all_keys() - - if isinstance(pattern, Pattern): - # Regex pattern - return [k for k in all_keys if pattern.search(k)] - elif "*" in pattern or "?" 
in pattern or "[" in pattern: - # Glob pattern - return [k for k in all_keys if fnmatch.fnmatch(k, pattern)] - else: - # Exact match - return [pattern] if pattern in all_keys else [] - - @overload - def __getitem__(self, key: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ... - - @overload - def __getitem__(self, key: List[str]) -> Dict[str, torch.Tensor]: ... - - @overload - def __getitem__(self, key: Pattern) -> Dict[str, torch.Tensor]: ... - - def __getitem__( - self, key: Union[str, List[str], Pattern] - ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - """ - Accesses state dict entries using various key types. - - This method allows for retrieving tensors using: - - A single string for an exact key match. - - A list of strings for multiple exact key matches. - - A string with glob-style wildcards (`*`, `?`, `[]`). - - A compiled regular expression object. - - Args: - key: A single key string, a list of keys, a glob pattern string, or a - compiled regular expression. - - Returns: - - A single `torch.Tensor` if `key` is a string that matches exactly one key - and does not contain wildcards. - - A `Dict[str, torch.Tensor]` for all other cases (list of keys, glob - pattern, or regex), mapping the matched keys to their corresponding - tensors. - - Raises: - KeyError: If the key (or any key in a list) is not found, or if a - pattern matches no keys. - - Examples: - >>> d = { - ... "model.embed_tokens.weight": torch.randn(10, 1), - ... "model.layers.0.mlp.weight": torch.randn(10, 1), - ... "model.layers.0.self_attn.q_proj.weight": torch.randn(10, 1), - ... "lm_head.weight": torch.randn(10, 1), - ... } - >>> state = StateDict(d) - >>> - >>> # Exact match (returns a single tensor) - >>> tensor = state["model.embed_tokens.weight"] - >>> isinstance(tensor, torch.Tensor) - True - >>> - >>> # List of keys (returns a dict of tensors) - >>> tensors = state[["model.embed_tokens.weight", "lm_head.weight"]] - >>> sorted(tensors.keys()) - ['lm_head.weight', 'model.embed_tokens.weight'] - >>> - >>> # Glob pattern (returns a dict of tensors) - >>> layer_0_weights = state["model.layers.0.*.weight"] - >>> sorted(layer_0_weights.keys()) - ['model.layers.0.mlp.weight', 'model.layers.0.self_attn.q_proj.weight'] - >>> - >>> # Regex pattern (returns a dict of tensors) - >>> import re - >>> attn_weights = state[re.compile(r".*self_attn.*")] - >>> list(attn_weights.keys()) - ['model.layers.0.self_attn.q_proj.weight'] - """ - if isinstance(key, Pattern): - matched_keys = self._match_keys(key) - if not matched_keys: - raise KeyError(f"No keys match regex pattern: {key.pattern}") - return self._load_tensors(matched_keys) - elif isinstance(key, str): - if "*" in key or "?" in key or "[" in key: - matched_keys = self._match_keys(key) - if not matched_keys: - raise KeyError(f"No keys match pattern: {key}") - return self._load_tensors(matched_keys) - else: - if key not in self._get_all_keys(): - raise KeyError(f"Key not found: {key}") - return self._load_tensors([key])[key] - elif isinstance(key, list): - all_keys_set = set(self._get_all_keys()) - missing_keys = [k for k in key if k not in all_keys_set] - if missing_keys: - raise KeyError(f"Keys not found: {missing_keys}") - return self._load_tensors(key) - else: - raise TypeError(f"Key must be str, list of str, or compiled regex, got {type(key)}") - - def regex(self, pattern: str) -> Dict[str, torch.Tensor]: - """ - Queries the state dict with a regular expression pattern. 
- - This is a convenience method that compiles the pattern string and uses it - to retrieve all matching tensors. - - Args: - pattern: The regular expression string to match against tensor keys. - - Returns: - A dictionary mapping matching tensor names to their `torch.Tensor` objects. - - Examples: - >>> d = { - ... "model.layers.0.self_attn.weight": torch.randn(1, 1), - ... "model.layers.1.self_attn.weight": torch.randn(1, 1), - ... "model.layers.1.mlp.weight": torch.randn(1, 1) - ... } - >>> state = StateDict(d) - >>> # Get all attention-related weights - >>> attention_weights = state.regex(r"model\\.layers\\.\\d+\\.self_attn.*") - >>> sorted(attention_weights.keys()) - ['model.layers.0.self_attn.weight', 'model.layers.1.self_attn.weight'] - """ - return self[re.compile(pattern)] - - def glob(self, pattern: str) -> Dict[str, torch.Tensor]: - """ - Queries the state dict with a glob pattern. - - This is a convenience method for pattern matching using Unix shell-style - wildcards. - - Args: - pattern: The glob pattern string to match against tensor keys. - - Returns: - A dictionary mapping matching tensor names to their `torch.Tensor` objects. - - Examples: - >>> d = { - ... "model.layers.0.mlp.weight": torch.randn(1, 1), - ... "model.layers.0.mlp.bias": torch.randn(1, 1), - ... "model.layers.1.mlp.weight": torch.randn(1, 1) - ... } - >>> state = StateDict(d) - >>> # Get all mlp weights and biases from the first layer - >>> layer_0_mlp = state.glob("model.layers.0.mlp.*") - >>> sorted(layer_0_mlp.keys()) - ['model.layers.0.mlp.bias', 'model.layers.0.mlp.weight'] - """ - return self[pattern] - - def __call__(self) -> Dict[str, torch.Tensor]: - """ - Loads and returns the entire state dict as a dictionary. - - Note: - This method loads all tensors from the source into memory. For large - models, this can be memory-intensive. Prefer using pattern-based - or single-key lookups for more efficient access if you only need a - subset of the state dict. - - Returns: - A dictionary containing all tensor names and their corresponding - `torch.Tensor` objects. - """ - all_keys = self._get_all_keys() - return self._load_tensors(all_keys) - - def keys(self) -> List[str]: - """Get all state dict keys.""" - return self._get_all_keys() - - def items(self) -> List[tuple]: - """Get all state dict items.""" - return list(self().items()) - - def __contains__(self, key: str) -> bool: - """Check if a key exists in the state dict.""" - return key in self._get_all_keys() - - def __repr__(self) -> str: - """String representation.""" - try: - num_params = len(self) - return f"" - except Exception: - return "" - - def get(self, key: str, default=None) -> Optional[torch.Tensor]: - """ - Gets a tensor from the state dict. - Returns `default` if the key is not found. - Note: This method is for single key lookup and does not support patterns. - """ - if key in self._get_all_keys(): - return self._load_tensors([key])[key] - return default - - def __iter__(self) -> Iterable[str]: - """Iterate over state dict keys.""" - return iter(self.keys()) - - def __len__(self) -> int: - """Get number of entries in the state dict.""" - return len(self.keys()) - - def has_glob(self, pattern: str) -> bool: - """ - Efficiently checks if any tensor key matches the given glob pattern. - This is forwarded to the underlying StateSource which may have an - optimized implementation that avoids iterating over all keys. - - Args: - pattern: The glob pattern to match against tensor keys. 
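# --- Example sketch (illustrative, not part of the patch) ---------------------
# Small sketch contrasting the two lookup styles documented above: __getitem__
# raises KeyError for missing keys or empty pattern matches, while get()
# returns a default. The import path is an assumption that the upstream
# megatron.bridge package mirrors the module removed here.
import torch

from megatron.bridge.models.hf_pretrained.state import StateDict

state = StateDict({"lm_head.weight": torch.randn(8, 8)})
bias = state.get("lm_head.bias")   # -> None, no exception raised
try:
    _ = state["lm_head.bias"]      # exact key is missing -> KeyError
except KeyError:
    pass
# ------------------------------------------------------------------------------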
- - Returns: - True if a matching key is found, False otherwise. - """ - return self.source.has_glob(pattern) - - -class StateSource(ABC, Mapping[str, torch.Tensor]): - """ - Abstract base class for a source of model state. - - This class defines a standard interface for `StateDict` to access tensor - data, abstracting away the details of how and where the data is stored. - Subclasses can implement loading from different storage backends, such as - in-memory dictionaries or files on disk. This allows `StateDict` to handle - various checkpoint formats in a uniform way. - """ - - @abstractmethod - def get_all_keys(self) -> List[str]: - """Returns a list of all available tensor keys in the source.""" - pass - - @abstractmethod - def load_tensors(self, keys: List[str]) -> Dict[str, torch.Tensor]: - """Loads the specified tensors from the source.""" - pass - - def __getitem__(self, key: str) -> torch.Tensor: - """Loads a single tensor by key.""" - tensors = self.load_tensors([key]) - if key not in tensors: - raise KeyError(f"Key not found in source: {key}") - return tensors[key] - - def __iter__(self) -> Iterable[str]: - """Iterates over all tensor keys.""" - return iter(self.get_all_keys()) - - def __len__(self) -> int: - """Returns the total number of tensors in the source.""" - return len(self.get_all_keys()) - - def has_glob(self, pattern: str) -> bool: - """ - Checks if any tensor key matches the given glob pattern. - This default implementation is not efficient for all sources, as it may - load all keys. Subclasses should override this method if a more - performant implementation is available. - """ - import fnmatch - - for key in self.get_all_keys(): - if fnmatch.fnmatch(key, pattern): - return True - return False - - -class DictStateSource(StateSource): - """ - A state source backed by an in-memory Python dictionary. - - This is the simplest `StateSource` implementation. It's used when the entire - model state dict is already loaded into a dictionary in memory. - - Args: - state_dict: A dictionary mapping tensor names (str) to `torch.Tensor` objects. - """ - - def __init__(self, state_dict: Dict[str, torch.Tensor]): - self._dict = state_dict - self._keys_cache: Optional[List[str]] = None - - def get_all_keys(self) -> List[str]: - if self._keys_cache is None: - self._keys_cache = sorted(list(self._dict.keys())) - return self._keys_cache - - def load_tensors(self, keys: List[str]) -> Dict[str, torch.Tensor]: - return {key: self._dict[key] for key in keys if key in self._dict} - - -class SafeTensorsStateSource(StateSource): - """ - A state source backed by a directory of .safetensors files. - - This source is designed for efficiently loading tensors from checkpoints saved - in the Safetensors format, which is common for large models that are often - "sharded" into multiple files. - - It can handle two common scenarios: - 1. A directory containing multiple `.safetensors` files. - 2. A directory containing a `model.safetensors.index.json` file, which maps - tensor names to the specific `.safetensors` file they reside in. This is - the standard format used by Hugging Face Transformers. - - Using this source allows `StateDict` to query for tensor keys and load only - the necessary files and tensors from disk, avoiding high memory usage. - - Args: - path: The path to the directory containing the `.safetensors` files - and/or the index file. Can also be a Hugging Face Hub model ID. 
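# --- Example sketch (illustrative, not part of the patch) ---------------------
# Hedged sketch of a custom StateSource: the abstract interface above only
# requires get_all_keys() and load_tensors(); membership, iteration, length and
# has_glob() fall out of the base class. The import path and class name are
# assumptions for illustration only.
from typing import Dict, List, Tuple

import torch

from megatron.bridge.models.hf_pretrained.state import StateDict, StateSource


class ZerosStateSource(StateSource):
    """Serves all-zero tensors for a fixed set of keys (handy for dry runs)."""

    def __init__(self, shapes: Dict[str, Tuple[int, ...]]):
        self._shapes = shapes

    def get_all_keys(self) -> List[str]:
        return sorted(self._shapes)

    def load_tensors(self, keys: List[str]) -> Dict[str, torch.Tensor]:
        return {k: torch.zeros(self._shapes[k]) for k in keys if k in self._shapes}


state = StateDict(ZerosStateSource({"model.layers.0.mlp.weight": (16, 16)}))
assert state.has_glob("model.layers.*")
# ------------------------------------------------------------------------------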
- """ - - def __init__(self, path: Union[str, Path]): - self.model_name_or_path = path - self._resolved_path_cache: Optional[Path] = None - self._keys_cache: Optional[List[str]] = None - self._key_to_filename_map_cache: Optional[Dict[str, str]] = None - - @property - def path(self) -> Path: - """ - The local path to the checkpoint files. - If the initial path is a Hugging Face Hub model ID, this property - will handle downloading the necessary files and return the local - cache path. - """ - if self._resolved_path_cache is None: - self._resolved_path_cache = self._resolve_path(self.model_name_or_path) - return self._resolved_path_cache - - @property - def key_to_filename_map(self) -> Dict[str, str]: - """ - Provides a mapping from tensor keys to the safetensor filename they - are stored in. - - This map is constructed either from `model.safetensors.index.json` if - it exists, or by scanning all `.safetensors` files in the directory. - The result is cached for efficiency. - """ - if self._key_to_filename_map_cache is not None: - return self._key_to_filename_map_cache - - # First, try to load from the index file. - key_map = self._cached_get_key_to_filename_map(self.path) - if key_map: - self._key_to_filename_map_cache = key_map - return key_map - - # If no index, scan the directory. - import os - - from glob import glob as file_glob - - from safetensors import safe_open - - key_map = {} - safetensor_files = file_glob(str(self.path / "*.safetensors")) - for file_path in safetensor_files: - filename = os.path.basename(file_path) - try: - with safe_open(file_path, framework="pt", device="cpu") as f: - for key in f.keys(): - if key in key_map: - # This is an issue. Same key in multiple files, and no index. - # How to resolve ambiguity? Let's just warn and overwrite. Last one wins. - print( - f"Warning: duplicate key '{key}' found in '{filename}' and '{key_map[key]}'. Using '{filename}'." - ) - key_map[key] = filename - except Exception as e: - # Can be not a safetensor file, etc. - print(f"Warning: could not open {filename} as a safetensors file: {e}") - - self._key_to_filename_map_cache = key_map - return key_map - - @staticmethod - def _resolve_path(model_name_or_path: Union[str, Path]) -> Path: - """ - Resolves a model name or path to a local directory. - If the path is not a local directory, it is treated as a Hugging - Face Hub model ID, and the corresponding files are downloaded. - """ - local_path = Path(model_name_or_path) - if local_path.is_dir(): - return local_path - - try: - from huggingface_hub import snapshot_download - from huggingface_hub.utils import HfHubHTTPError - - # Not a local directory, so we assume it's a model ID - # on the Hugging Face Hub. - return Path( - snapshot_download( - repo_id=str(model_name_or_path), - allow_patterns=["*.safetensors", "model.safetensors.index.json"], - # Ignore other large files. - ignore_patterns=["*.bin", "*.pt", "*.pth"], - ) - ) - except (ImportError, HfHubHTTPError, ValueError): - # If huggingface_hub is not installed, or if it's not a - # valid model ID, we return the original path and let the - # subsequent logic handle the file not found error. 
- return local_path - - def get_all_keys(self) -> List[str]: - if self._keys_cache is not None: - return self._keys_cache - - from glob import glob as file_glob - - from safetensors import safe_open - - all_keys = set() - key_to_filename_map = self.key_to_filename_map - if key_to_filename_map: - all_keys.update(key_to_filename_map.keys()) - - if not all_keys: - safetensor_files = file_glob(str(self.path / "*.safetensors")) - if not safetensor_files and not key_to_filename_map: - raise FileNotFoundError( - f"No .safetensors files or index found in {self.model_name_or_path}" - ) - for safetensor_file in safetensor_files: - with safe_open(safetensor_file, framework="pt", device="cpu") as f: - all_keys.update(f.keys()) - - self._keys_cache = sorted(list(all_keys)) - return self._keys_cache - - def load_tensors(self, keys_to_load: List[str]) -> Dict[str, torch.Tensor]: - if not keys_to_load: - return {} - - from glob import glob as file_glob - - from safetensors import safe_open - - loaded_tensors = {} - remaining_keys = set(keys_to_load) - key_to_filename_map = self.key_to_filename_map - - if key_to_filename_map: - file_to_keys_map = defaultdict(list) - for key in list(remaining_keys): - if key in key_to_filename_map: - filename = key_to_filename_map[key] - file_to_keys_map[filename].append(key) - - for filename, keys_in_file in file_to_keys_map.items(): - file_path = self.path / filename - if file_path.exists(): - with safe_open(file_path, framework="pt", device="cpu") as f: - for key in keys_in_file: - if key in f.keys(): - loaded_tensors[key] = f.get_tensor(key) - remaining_keys.discard(key) - - if remaining_keys: - safetensor_files = file_glob(str(self.path / "*.safetensors")) - if not safetensor_files and not key_to_filename_map and not loaded_tensors: - raise FileNotFoundError( - f"No .safetensors files found in {self.model_name_or_path} to load keys: {remaining_keys}" - ) - for safetensor_file_path in safetensor_files: - if not remaining_keys: - break - with safe_open(safetensor_file_path, framework="pt", device="cpu") as f: - current_file_keys = f.keys() - for key in list(remaining_keys): - if key in current_file_keys: - loaded_tensors[key] = f.get_tensor(key) - remaining_keys.remove(key) - - if remaining_keys: - raise KeyError( - f"Keys not found in safetensors from {self.model_name_or_path}: {remaining_keys}" - ) - - return loaded_tensors - - def has_glob(self, pattern: str) -> bool: - """ - Efficiently checks if any tensor key matches the given glob pattern. - - This method avoids loading all tensor keys into memory at once. It scans - the checkpoint index or file headers and returns as soon as a match is - found. - - Args: - pattern: The glob pattern to match against tensor keys. - - Returns: - True if a matching key is found, False otherwise. - """ - import fnmatch - - from glob import glob as file_glob - - from safetensors import safe_open - - key_to_filename_map = self.key_to_filename_map - if key_to_filename_map: - for key in key_to_filename_map.keys(): - if fnmatch.fnmatch(key, pattern): - return True - return False - - # If no index map, scan the files directly. 
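# --- Example sketch (illustrative, not part of the patch) ---------------------
# Sketch of the sharded-checkpoint index that get_all_keys/load_tensors lean on:
# the Hugging Face model.safetensors.index.json carries a "weight_map" from
# tensor name to shard file name. Directory and key are illustrative.
import json
from pathlib import Path

index_file = Path("my_model_checkpoint") / "model.safetensors.index.json"
if index_file.exists():
    weight_map = json.loads(index_file.read_text()).get("weight_map", {})
    shard = weight_map.get("lm_head.weight")   # e.g. "model-00002-of-00002.safetensors"
# ------------------------------------------------------------------------------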
- safetensor_files = file_glob(str(self.path / "*.safetensors")) - if not safetensor_files: - return False - - for safetensor_file in safetensor_files: - try: - with safe_open(safetensor_file, framework="pt", device="cpu") as f: - for key in f.keys(): - if fnmatch.fnmatch(key, pattern): - return True - except Exception: - # Ignore files that are not valid safetensors - continue - - return False - - def save_generator( - self, - generator: Iterable[Tuple[str, torch.Tensor]], - output_path: Union[str, Path], - strict: bool = True, - ): - """ - Saves tensors from a generator to `.safetensors` files, preserving the - original sharding structure in a memory-efficient, streaming fashion. - - This method reads the sharding information (which tensor belongs to which - file) from the source checkpoint. It then consumes a generator of tensors, - buffering them in memory only until a complete file shard can be written to - disk. This approach minimizes peak memory usage compared to collecting all - tensors first. - - If the original checkpoint had a `model.safetensors.index.json` file, a new - one will be created for the saved tensors. - - Args: - generator: An iterable of (tensor_name, tensor) tuples. - output_path: The directory where the new safetensor files and index - will be saved. - strict: If True (default), raises a KeyError if the generator - yields a tensor name not found in the original model's - sharding structure. If False, it prints a warning and - skips the tensor. - """ - # In a distributed environment, only rank 0 should write to disk. - # Other ranks must still exhaust the generator to participate in collectives. - is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() - rank = torch.distributed.get_rank() if is_distributed else 0 - - if rank != 0: - # Other ranks must exhaust the generator to avoid hangs in collectives. - for _ in generator: - pass - return - - # Rank 0 proceeds with saving. - from safetensors.torch import save_file - - output_path = Path(output_path) - output_path.mkdir(parents=True, exist_ok=True) - - key_to_filename_map = self.key_to_filename_map - all_expected_keys = set(key_to_filename_map.keys()) - - if not key_to_filename_map: - buffered_tensors = dict(generator) - if buffered_tensors: - save_file(buffered_tensors, output_path / "model.safetensors") - return - - filename_to_keys_map = defaultdict(set) - for key, filename in key_to_filename_map.items(): - filename_to_keys_map[filename].add(key) - - files_to_save = dict(filename_to_keys_map) - buffered_tensors = {} - all_yielded_keys = set() - all_saved_keys = set() - - for name, tensor in generator: - all_yielded_keys.add(name) - if name not in all_expected_keys: - if strict: - raise KeyError( - f"Tensor '{name}' from generator not found in the original model structure. " - "To ignore, set strict=False." - ) - else: - print( - f"Warning: tensor '{name}' from generator not found in original model structure. Skipping." - ) - continue - - buffered_tensors[name] = tensor - - # Check if any file is complete and can be saved. - # Iterate over a copy of keys since we might modify the dict. - for filename in list(files_to_save.keys()): - keys_for_file = files_to_save[filename] - if keys_for_file.issubset(buffered_tensors.keys()): - # This shard is complete, save it. 
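# --- Example sketch (illustrative, not part of the patch) ---------------------
# Sketch of what the optimized has_glob above enables: probing a checkpoint's
# layout (here, fused vs. split QKV projections) without materialising any
# tensors. The path, key patterns, and import path are illustrative assumptions.
from megatron.bridge.models.hf_pretrained.state import SafeTensorsStateSource

source = SafeTensorsStateSource("my_model_checkpoint")
fused_qkv = source.has_glob("*.self_attn.qkv_proj.weight")
split_qkv = source.has_glob("*.self_attn.q_proj.weight")
# ------------------------------------------------------------------------------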
- tensors_to_save = {key: buffered_tensors[key] for key in keys_for_file} - - output_file_path = output_path / filename - save_file(tensors_to_save, output_file_path) - - # Free memory by removing saved tensors from the buffer. - for key in keys_for_file: - del buffered_tensors[key] - - all_saved_keys.update(keys_for_file) - del files_to_save[filename] - - # --- Final Reporting --- - if files_to_save: - if strict: - print( - "Warning: The following files could not be saved because the generator did not yield all of their tensors:" - ) - else: - print( - "Warning: The following files are different from the source because the generator did not yield all " - "of their tensors. However they are still saved because strict=False." - ) - for filename, keys_for_file in files_to_save.items(): - missing_for_file = keys_for_file - all_yielded_keys - if missing_for_file: - print(f" - {filename}: missing {len(missing_for_file)} tensors:") - for key in sorted(list(missing_for_file)): - print(f" - {key}") - if not strict: - for filename in list(files_to_save.keys()): - keys_for_file = files_to_save[filename] - tensors_to_save = { - key: buffered_tensors[key] - for key in keys_for_file - if key in buffered_tensors - } - # missing_keys = set(keys_for_file) - tensors_to_save.keys() - # if missing_keys: - # print(f" - {filename}: missing {len(missing_keys)} tensors:") - # for key in sorted(list(missing_keys)): - # print(f" - {key}") - output_file_path = output_path / filename - save_file(tensors_to_save, output_file_path) - - # Free memory by removing saved tensors from the buffer. - for key in tensors_to_save.keys(): - del buffered_tensors[key] - - all_saved_keys.update(keys_for_file) - del files_to_save[filename] - - if buffered_tensors: - print( - f"Warning: {len(buffered_tensors)} tensors were yielded but not saved because their corresponding file shards were incomplete." - ) - - # Final check on whether all original tensors were written. - unsaved_keys = all_expected_keys - all_saved_keys - if not unsaved_keys: - extra_keys = all_yielded_keys - all_expected_keys - if extra_keys: - print( - f"\nSuccess: All tensors from the original checkpoint were written. " - f"({len(extra_keys)} extra tensors from generator were ignored as per strict=False)." - ) - else: - print("\nSuccess: All tensors from the original checkpoint were written.") - else: - print( - f"\nError: {len(unsaved_keys)} tensors from the original checkpoint were not written. See warnings above for details." - ) - - # Create index file for the saved shards. 
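# --- Example sketch (illustrative, not part of the patch) ---------------------
# Hedged usage sketch for save_generator: it consumes (name, tensor) pairs,
# re-shards them exactly like the source checkpoint, and writes each shard as
# soon as all of its tensors have arrived. The dtype conversion, paths, and
# import path are illustrative; only rank 0 writes in a distributed run.
import torch

from megatron.bridge.models.hf_pretrained.state import SafeTensorsStateSource, StateDict

source = SafeTensorsStateSource("my_model_checkpoint")
state = StateDict(source)


def converted_weights():
    for name in state.keys():
        yield name, state[name].to(torch.bfloat16)   # load one tensor at a time


source.save_generator(converted_weights(), "converted_checkpoint", strict=True)
# ------------------------------------------------------------------------------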
- original_index_file = self.path / "model.safetensors.index.json" - if original_index_file.exists(): - with open(original_index_file, "r") as f: - original_index_data = json.load(f) - - new_weight_map = {key: key_to_filename_map[key] for key in all_saved_keys} - - new_index_data = { - "metadata": original_index_data.get("metadata", {}), - "weight_map": new_weight_map, - } - - output_index_file = output_path / "model.safetensors.index.json" - if new_weight_map: - with open(output_index_file, "w") as f: - json.dump(new_index_data, f, indent=4) - - def _get_key_to_filename_map(self) -> Optional[Dict[str, str]]: - return self._cached_get_key_to_filename_map(self.path) - - @staticmethod - @lru_cache(maxsize=None) - def _cached_get_key_to_filename_map( - model_name_or_path: Union[str, Path] - ) -> Optional[Dict[str, str]]: - """Static, cached method to get the key-to-filename map.""" - index_file = Path(model_name_or_path) / "model.safetensors.index.json" - if index_file.exists(): - with open(index_file, "r") as f: - try: - index_data = json.load(f) - if "weight_map" in index_data and isinstance(index_data["weight_map"], dict): - return index_data["weight_map"] - except json.JSONDecodeError: - return None - return None diff --git a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/vlm.py b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/vlm.py deleted file mode 100644 index 7ad431f6f9..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/vlm.py +++ /dev/null @@ -1,603 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -from pathlib import Path -from typing import Any, Dict, Generic, List, Optional, TypeVar, Union - -import torch - -from transformers import ( - AutoConfig, - AutoImageProcessor, - AutoModel, - AutoProcessor, - AutoTokenizer, - GenerationConfig, - PreTrainedModel, - PreTrainedTokenizer, - ProcessorMixin, -) -from transformers.generation.utils import GenerateOutput - -from megatron.nemo_bridge.models.hf_pretrained.base import PreTrainedBase -from megatron.nemo_bridge.models.hf_pretrained.safe_config_loader import ( - safe_load_config_with_retry, -) - -# Type variable for generic model type -VLMType = TypeVar("VLMType", bound=PreTrainedModel) - - -class PreTrainedVLM(PreTrainedBase, Generic[VLMType]): - """ - A generic class for Pretrained Vision-Language Models with lazy loading. - - Allows type-safe access to specific VLM implementations like LlavaForConditionalGeneration. - - Examples: - Basic usage with image and text: - >>> from megatron.nemo_bridge.models.hf_pretrained.vlm import PreTrainedVLM - >>> from PIL import Image - >>> - >>> # Create instance - no model loading happens yet - >>> vlm = PreTrainedVLM.from_pretrained("llava-hf/llava-1.5-7b-hf") - >>> - >>> # Load an image - >>> image = Image.open("cat.jpg") - >>> - >>> # Process image and text together - processor and model load here - >>> inputs = vlm.process_images_and_text( - ... images=image, - ... text="What do you see in this image?" - ... ) - >>> - >>> # Generate response - >>> outputs = vlm.generate(**inputs, max_new_tokens=100) - >>> print(vlm.decode(outputs[0], skip_special_tokens=True)) - - Batch processing with multiple images: - >>> # Process multiple images with questions - >>> images = [Image.open(f"image_{i}.jpg") for i in range(3)] - >>> questions = [ - ... "What is the main object in this image?", - ... "Describe the scene", - ... "What colors do you see?" - ... 
] - >>> - >>> # Process batch - >>> inputs = vlm.process_images_and_text( - ... images=images, - ... text=questions, - ... padding=True - ... ) - >>> - >>> # Generate responses - >>> outputs = vlm.generate(**inputs, max_new_tokens=50) - >>> for i, output in enumerate(outputs): - ... print(f"Image {i+1}: {vlm.decode(output, skip_special_tokens=True)}") - - Using specific VLM types with type hints: - >>> from transformers import LlavaForConditionalGeneration - >>> from megatron.nemo_bridge.models.hf_pretrained.vlm import PreTrainedVLM - >>> - >>> # Type-safe access to Llava-specific features - >>> llava: PreTrainedVLM[LlavaForConditionalGeneration] = PreTrainedVLM.from_pretrained( - ... "llava-hf/llava-1.5-7b-hf", - ... torch_dtype=torch.float16, - ... device="cuda" - ... ) - >>> - >>> # Access model-specific attributes - >>> vision_tower = llava.model.vision_tower # Type-safe access - - Text-only generation (for multimodal models that support it): - >>> # Some VLMs can also work with text-only inputs - >>> text_inputs = vlm.encode_text("Explain what a neural network is.") - >>> outputs = vlm.generate(**text_inputs, max_length=100) - >>> print(vlm.decode(outputs[0], skip_special_tokens=True)) - - Custom preprocessing and generation: - >>> # Load with custom settings - >>> vlm = PreTrainedVLM.from_pretrained( - ... "Qwen/Qwen-VL-Chat", - ... trust_remote_code=True, - ... device_map="auto", - ... load_in_4bit=True - ... ) - >>> - >>> # Custom generation config - >>> from transformers import GenerationConfig - >>> vlm.generation_config = GenerationConfig( - ... max_new_tokens=200, - ... temperature=0.8, - ... top_p=0.95, - ... do_sample=True - ... ) - >>> - >>> # Process with custom parameters - >>> inputs = vlm.process_images_and_text( - ... images=image, - ... text="\\nDescribe this image in detail.", - ... max_length=512 - ... ) - - Manual component setup: - >>> # Create empty instance - >>> vlm = PreTrainedVLM() - >>> - >>> # Load components separately - >>> from transformers import AutoProcessor, AutoModel - >>> vlm.processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base") - >>> vlm.model = AutoModel.from_pretrained("microsoft/Florence-2-base") - >>> - >>> # Use for various vision tasks - >>> task_prompt = "" # Object detection task - >>> inputs = vlm.process_images_and_text(images=image, text=task_prompt) - >>> outputs = vlm.generate(**inputs) - - Conversational VLM usage: - >>> # Multi-turn conversation with images - >>> conversation = [] - >>> - >>> # First turn - >>> image1 = Image.open("chart.png") - >>> inputs = vlm.process_images_and_text( - ... images=image1, - ... text="What type of chart is this?" - ... ) - >>> response = vlm.generate(**inputs) - >>> conversation.append(("user", "What type of chart is this?")) - >>> conversation.append(("assistant", vlm.decode(response[0]))) - >>> - >>> # Follow-up question - >>> follow_up = "What is the highest value shown?" 
- >>> # Format conversation history + new question - >>> full_prompt = format_conversation(conversation) + f"\\nUser: {follow_up}" - >>> inputs = vlm.process_images_and_text(images=image1, text=full_prompt) - >>> response = vlm.generate(**inputs) - """ - - ARTIFACTS = ["processor", "tokenizer", "image_processor"] - OPTIONAL_ARTIFACTS = ["generation_config"] - - def __init__( - self, - model_name_or_path: Optional[Union[str, Path]] = None, - device: Optional[Union[str, torch.device]] = None, - torch_dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - **kwargs, - ): - """ - Initialize a Pretrained VLM with lazy loading. - - Args: - model_name_or_path: HuggingFace model identifier or local path - device: Device to load model on (e.g., 'cuda', 'cpu') - torch_dtype: Data type to load model in (e.g., torch.float16) - trust_remote_code: Whether to trust remote code when loading - **kwargs: Additional arguments passed to component loaders - """ - self._model_name_or_path = model_name_or_path - self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") - self.torch_dtype = torch_dtype - self.trust_remote_code = trust_remote_code - super().__init__(**kwargs) - - def _load_model(self) -> VLMType: - """Lazy load and return the model.""" - if self.model_name_or_path is None: - raise ValueError("model_name_or_path must be provided to load model") - - model_kwargs = {"trust_remote_code": self.trust_remote_code, **self.init_kwargs} - - if self.torch_dtype is not None: - model_kwargs["torch_dtype"] = self.torch_dtype - - # Use provided config if already loaded - config = getattr(self, "_config", None) - if config is not None: - model_kwargs["config"] = config - - # Try AutoModel first for VLMs - model = AutoModel.from_pretrained(self.model_name_or_path, **model_kwargs) - - # Move to device - model = model.to(self.device) - - # Set generation config if available - generation_config = getattr(self, "_generation_config", None) - if generation_config is not None and hasattr(model, "generation_config"): - model.generation_config = generation_config - return model - - def _load_config(self) -> AutoConfig: - """Lazy load and return the model config with thread-safety protection.""" - if self.model_name_or_path is None: - raise ValueError("model_name_or_path must be provided to load config") - - return safe_load_config_with_retry( - self.model_name_or_path, trust_remote_code=self.trust_remote_code, **self.init_kwargs - ) - - def _load_processor(self) -> ProcessorMixin: - """Lazy load and return the processor.""" - if self.model_name_or_path is None: - raise ValueError("model_name_or_path must be provided to load processor") - - try: - return AutoProcessor.from_pretrained( - self.model_name_or_path, - trust_remote_code=self.trust_remote_code, - **self.init_kwargs, - ) - except Exception: - # Some VLMs might not have a processor, fall back to manual loading - raise ValueError( - f"Could not load processor for {self.model_name_or_path}. " - "This model might require manual processor setup." - ) - - def _load_tokenizer(self) -> Optional[PreTrainedTokenizer]: - """ - Lazy load and return the tokenizer. - For VLMs, the tokenizer might be included in the processor. 
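As a brief illustration of the tokenizer fallback described just above, the sketch below shows attaching a tokenizer by hand when a checkpoint does not bundle one with its processor. The model id is only an example, and the call pattern is an assumption based on this class's documented setters, not part of the patch itself.

from transformers import AutoTokenizer

from megatron.nemo_bridge.models.hf_pretrained.vlm import PreTrainedVLM

# Illustrative model id; nothing is downloaded until a component is first accessed.
vlm = PreTrainedVLM.from_pretrained("llava-hf/llava-1.5-7b-hf")

# If the processor does not expose a tokenizer, one can be attached manually
# through the `tokenizer` setter defined on this class.
vlm.tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")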
- """ - # Check if tokenizer is available through processor first - processor = getattr(self, "_processor", None) - if processor is not None and hasattr(processor, "tokenizer"): - return processor.tokenizer - - # Try to load tokenizer separately - if self.model_name_or_path is not None: - try: - tokenizer = AutoTokenizer.from_pretrained( - self.model_name_or_path, - trust_remote_code=self.trust_remote_code, - **self.init_kwargs, - ) - - # Set padding token if not present - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - return tokenizer - except Exception: - # Some VLMs include tokenizer only in processor - pass - return None - - def _load_image_processor(self) -> Optional[Any]: - """ - Lazy load and return the image processor. - For VLMs, the image processor might be included in the processor. - """ - # Check if image processor is available through processor first - processor = getattr(self, "_processor", None) - if processor is not None and hasattr(processor, "image_processor"): - return processor.image_processor - - # Try to load image processor separately - if self.model_name_or_path is not None: - try: - return AutoImageProcessor.from_pretrained( - self.model_name_or_path, - trust_remote_code=self.trust_remote_code, - **self.init_kwargs, - ) - except Exception: - # Some VLMs include image processor only in processor - pass - return None - - def _load_generation_config(self) -> Optional[GenerationConfig]: - """Lazy load and return the generation config.""" - if self.model_name_or_path is not None: - try: - return GenerationConfig.from_pretrained( - self.model_name_or_path, - trust_remote_code=self.trust_remote_code, - **self.init_kwargs, - ) - except Exception: - # Not all models have generation configs - pass - return None - - @property - def model_name_or_path(self) -> Optional[Union[str, Path]]: - """Return the model name or path.""" - return self._model_name_or_path - - @property - def model(self) -> VLMType: - """Lazy load and return the underlying model.""" - if not hasattr(self, "_model"): - self._model = self._load_model() - else: - # Ensure model is on the right device when accessed - if hasattr(self._model, "device") and hasattr(self._model.device, "type"): - current_device = str(self._model.device) - target_device = str(self.device) - if current_device != target_device: - self._model = self._model.to(self.device) - return self._model - - @model.setter - def model(self, value: VLMType): - """Set the model manually.""" - self._model = value - - @property - def processor(self) -> ProcessorMixin: - """Lazy load and return the processor.""" - if not hasattr(self, "_processor"): - self._processor = self._load_processor() - return self._processor - - @processor.setter - def processor(self, value: ProcessorMixin): - """Set the processor manually.""" - self._processor = value - - @property - def tokenizer(self) -> Optional[PreTrainedTokenizer]: - """Lazy load and return the tokenizer.""" - if not hasattr(self, "_tokenizer"): - self._tokenizer = self._load_tokenizer() - return self._tokenizer - - @tokenizer.setter - def tokenizer(self, value: PreTrainedTokenizer): - """Set the tokenizer manually.""" - self._tokenizer = value - - @property - def image_processor(self) -> Optional[Any]: - """Lazy load and return the image processor.""" - if not hasattr(self, "_image_processor"): - self._image_processor = self._load_image_processor() - return self._image_processor - - @image_processor.setter - def image_processor(self, value: Any): - """Set the image 
processor manually.""" - self._image_processor = value - - @property - def generation_config(self) -> Optional[GenerationConfig]: - """Lazy load and return the generation config.""" - if not hasattr(self, "_generation_config"): - self._generation_config = self._load_generation_config() - return self._generation_config - - @generation_config.setter - def generation_config(self, value: GenerationConfig): - """Set the generation config manually.""" - self._generation_config = value - # Update model's generation config if model is loaded - if ( - hasattr(self, "_model") - and self._model is not None - and hasattr(self._model, "generation_config") - ): - self._model.generation_config = value - - @property - def kwargs(self) -> Dict[str, Any]: - """Additional initialization kwargs.""" - return self.init_kwargs - - @classmethod - def from_pretrained( - cls, - model_name_or_path: Union[str, Path], - device: Optional[Union[str, torch.device]] = None, - torch_dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - **kwargs, - ) -> "PreTrainedVLM[VLMType]": - """ - Create a PreTrainedVLM instance for lazy loading. - - Args: - model_name_or_path: HuggingFace model identifier or local path - device: Device to load model on - torch_dtype: Data type to load model in - trust_remote_code: Whether to trust remote code - **kwargs: Additional arguments for from_pretrained methods - - Returns: - PreTrainedVLM instance configured for lazy loading - """ - return cls( - model_name_or_path=model_name_or_path, - device=device, - torch_dtype=torch_dtype, - trust_remote_code=trust_remote_code, - **kwargs, - ) - - def generate(self, **kwargs) -> Union[torch.LongTensor, GenerateOutput]: - """ - Generate sequences using the model. - - Args: - **kwargs: Arguments for the generate method - - Returns: - Generated sequences - """ - return self.model.generate(**kwargs) - - def __call__(self, *args, **kwargs): - """Forward pass through the model.""" - return self.model(*args, **kwargs) - - def encode_text(self, text: Union[str, List[str]], **kwargs) -> Dict[str, torch.Tensor]: - """ - Encode text input using the tokenizer. - - Args: - text: Input text or list of texts - **kwargs: Additional tokenizer arguments - - Returns: - Encoded inputs ready for the model - """ - if self.tokenizer is None: - raise ValueError( - "No tokenizer available. Set tokenizer manually or ensure model has one." - ) - return self.tokenizer(text, return_tensors="pt", **kwargs).to(self.device) - - def decode(self, token_ids: torch.Tensor, **kwargs) -> str: - """ - Decode token IDs to text. - - Args: - token_ids: Token IDs to decode - **kwargs: Additional decoding arguments - - Returns: - Decoded text - """ - if self.tokenizer is None: - raise ValueError( - "No tokenizer available. Set tokenizer manually or ensure model has one." - ) - return self.tokenizer.decode(token_ids, **kwargs) - - def process_images_and_text( - self, images: Optional[Any] = None, text: Optional[Union[str, List[str]]] = None, **kwargs - ) -> Dict[str, torch.Tensor]: - """ - Process images and text together using the processor. 
- - Args: - images: Input images - text: Input text - **kwargs: Additional processor arguments - - Returns: - Processed inputs ready for the model - """ - inputs = self.processor(images=images, text=text, return_tensors="pt", **kwargs) - # Move all tensors in the dict to the device - if isinstance(inputs, dict): - for key, value in inputs.items(): - if hasattr(value, "to"): - inputs[key] = value.to(self.device) - return inputs - - def save_pretrained(self, save_directory: Union[str, Path]): - """ - Save the model and all components to a directory. - - Args: - save_directory: Directory to save to - """ - save_path = Path(save_directory) - save_path.mkdir(parents=True, exist_ok=True) - - # Save model - if hasattr(self, "_model") and self._model is not None: - self._model.save_pretrained(save_path) - - # Save artifacts through base class - self.save_artifacts(save_path) - - def to(self, device: Union[str, torch.device]) -> "PreTrainedVLM[VLMType]": - """ - Move model to a device. - - Args: - device: Target device - - Returns: - Self for chaining - """ - self.device = device - if hasattr(self, "_model") and self._model is not None: - self._model = self._model.to(device) - return self - - def half(self) -> "PreTrainedVLM[VLMType]": - """ - Convert model to half precision. - - Returns: - Self for chaining - """ - if hasattr(self, "_model") and self._model is not None: - self._model = self._model.half() - self.torch_dtype = torch.float16 - return self - - def float(self) -> "PreTrainedVLM[VLMType]": - """ - Convert model to full precision. - - Returns: - Self for chaining - """ - if hasattr(self, "_model") and self._model is not None: - self._model = self._model.float() - self.torch_dtype = torch.float32 - return self - - @property - def dtype(self) -> Optional[torch.dtype]: - """Return the dtype of the model.""" - if hasattr(self, "_model") and self._model is not None: - return next(self._model.parameters()).dtype - return self.torch_dtype - - def num_parameters(self, only_trainable: bool = False) -> int: - """ - Get the number of parameters in the model. 
- - Args: - only_trainable: Whether to count only trainable parameters - - Returns: - Number of parameters - """ - if not hasattr(self, "_model") or self._model is None: - return 0 - - if only_trainable: - return sum(p.numel() for p in self._model.parameters() if p.requires_grad) - return sum(p.numel() for p in self._model.parameters()) - - def __repr__(self) -> str: - """String representation.""" - parts = [f"{self.__class__.__name__}("] - - if self._model_name_or_path: - parts.append(f" model_name_or_path='{self._model_name_or_path}',") - - parts.append(f" device='{self.device}',") - - if self.torch_dtype: - parts.append(f" torch_dtype={self.torch_dtype},") - - if self.trust_remote_code: - parts.append(f" trust_remote_code={self.trust_remote_code},") - - # Show loaded components - loaded = [] - if hasattr(self, "_model") and self._model is not None: - loaded.append("model") - if hasattr(self, "_processor") and self._processor is not None: - loaded.append("processor") - if hasattr(self, "_tokenizer") and self._tokenizer is not None: - loaded.append("tokenizer") - if hasattr(self, "_config") and self._config is not None: - loaded.append("config") - - if loaded: - parts.append(f" loaded_components={loaded},") - - parts.append(")") - return "\n".join(parts) diff --git a/flagscale/train/megatron/nemo_bridge/models/model_provider.py b/flagscale/train/megatron/nemo_bridge/models/model_provider.py deleted file mode 100644 index d11f868488..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/model_provider.py +++ /dev/null @@ -1,710 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import abc -import os - -from pathlib import Path -from typing import Callable, Generic, TypedDict, TypeVar, Union - -try: - from typing import Unpack -except ImportError: - try: - from typing_extensions import Unpack - except ImportError: - from unittest.mock import MagicMock - - Unpack = MagicMock() - - -from typing import Callable - -import torch - -from megatron.core import parallel_state, tensor_parallel -from megatron.core.distributed import ( - DistributedDataParallel, - DistributedDataParallelConfig, - FullyShardedDataParallel, - TorchFullyShardedDataParallel, -) -from megatron.core.enums import ModelType -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.module import Float16Module, MegatronModule -from megatron.core.utils import get_model_config - -from megatron.nemo_bridge.models.config import from_hf_pretrained, save_hf_pretrained -from megatron.nemo_bridge.utils.common_utils import get_local_rank_preinit -from megatron.nemo_bridge.utils.instantiate_utils import InstantiationMode - -try: - from megatron.core.fp8_utils import correct_amax_history_if_needed -except ImportError: - correct_amax_history_if_needed = None - - -ModelT = TypeVar("ModelT", bound=MegatronModule) - - -class ModelProviderMixin(abc.ABC, Generic[ModelT]): - """A mixin that implements the ModelProvider pattern for Megatron Bridge. - - The ModelProvider pattern solves ecosystem fragmentation by providing a standardized - way to instantiate models. This mixin provides a consistent `provide_distributed_model()` method - that handles the complexity of distributed training setup, along with HuggingFace-inspired - `.from_hf_pretrained()` and `.save_hf_pretrained()` for interoperability. 
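To make the pattern concrete, here is a minimal, hypothetical sketch of driving one of the providers defined in this package through `provide_distributed_model()`. The parallel sizes, the default DDP config, and the single-GPU assumption are illustrative choices, not something this code prescribes.

from megatron.core.distributed import DistributedDataParallelConfig

from megatron.nemo_bridge.models.qwen.qwen_provider import Qwen2ModelProvider500M

# Illustrative settings; assumes a CUDA device and a (single-process) NCCL group,
# which provide_distributed_model() will initialize if they are missing.
provider = Qwen2ModelProvider500M(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
)
model_list = provider.provide_distributed_model(
    ddp_config=DistributedDataParallelConfig(),
    wrap_with_ddp=True,
    bf16=True,
)
# One module per (virtual) pipeline chunk; a single entry in the common case.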
- - For advanced customization, multiple hooks can be registered via `register_pre_wrap_hook` - and `register_post_wrap_hook`. These hooks allow modifying the model before and after - it's wrapped for distributed training (e.g., freezing layers, logging). The composed - hooks can be accessed via the `pre_wrap_hook` and `post_wrap_hook` properties. - - Subclasses must implement the `provide` method to define the model architecture. - """ - - CONFIG_NAME = "mhub_model.json" - DEFAULT_CONFIG_FORMAT = "json" - - @abc.abstractmethod - def provide( - self, - pre_process: bool | None = None, - post_process: bool | None = None, - vp_stage: int | None = None, - ) -> ModelT: - """Abstract method to provide the model instance. - - Subclasses must implement this method to return the specific Megatron model - (e.g., `GPTModel`) with its configuration. This method is called by `get_model` - to obtain the base model before it is wrapped for distributed training. - - Args: - pre_process (bool, optional): Whether to include the embedding layer (used with pipeline parallelism). - post_process (bool, optional): Whether to include the output layer (used with pipeline parallelism). - vp_stage (int, optional): The virtual pipeline stage of the model. - - Returns: - ModelT: The Megatron model instance. - """ - pass - - def provide_distributed_model( - self, - ddp_config: DistributedDataParallelConfig | None = None, - model_type=ModelType.encoder_or_decoder, - overlap_param_gather_with_optimizer_step: bool = False, - fp16: bool | None = None, - bf16: bool | None = None, - use_megatron_fsdp: bool = False, - use_torch_fsdp2: bool = False, - wrap_with_ddp: bool = True, - data_parallel_random_init: bool = True, - use_cpu_initialization: None | bool = False, - init_model_with_meta_device: bool | None = None, - pre_wrap_hook: ( - Union[ - Callable[[list[MegatronModule]], list[MegatronModule]], - list[Callable[[list[MegatronModule]], list[MegatronModule]]], - ] - | None - ) = None, - post_wrap_hook: Callable[[list[MegatronModule]], list[MegatronModule]] | None = None, - ) -> list[ModelT]: - """Instantiate and wrap the model for distributed training. - - This method retrieves the model from `provide` and sets up the distributed - environment, including data-parallel and model-parallel configurations. - It's the primary entry point for creating a model that's ready for use - in the Megatron ecosystem. - - Args: - ddp_config: Configuration for distributed data parallel. - model_type: Type of model (encoder, decoder, or both). - overlap_param_gather_with_optimizer_step: Whether to overlap param gathering. - fp16: Override FP16 setting. - bf16: Override BF16 setting. - use_megatron_fsdp: Use Megatron's Fully Sharded Data Parallel - use_torch_fsdp2: Use PyTorch FSDP2 instead of custom DDP. - wrap_with_ddp: Whether to wrap model with DDP. - data_parallel_random_init: Initialize parameters randomly across data parallel ranks. - use_cpu_initialization: Initialize model on CPU. - init_model_with_meta_device: Initialize model on meta device. - pre_wrap_hook: A single callable or list of callables to modify the model before it's wrapped. - If provided, this will override all hooks registered via `register_pre_wrap_hook`. - If a list is provided, hooks will be executed in order. - post_wrap_hook: A single callable to modify the model after it's wrapped. If provided, - this will override all hooks registered via `register_post_wrap_hook`. - - Returns: - A list containing the wrapped model instance. 
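The hook mechanism documented here can be illustrated with a small, hypothetical pair of pre-wrap hooks; the freezing criterion is invented for the example, and `provider` and `ddp_config` are assumed to exist as in the earlier sketch.

def freeze_embeddings(model_chunks):
    # Freeze embedding weights before the DDP/FSDP wrapper sees the model.
    for chunk in model_chunks:
        for name, param in chunk.named_parameters():
            if "word_embeddings" in name:
                param.requires_grad = False
    return model_chunks


def report_trainable(model_chunks):
    trainable = sum(
        p.numel() for chunk in model_chunks for p in chunk.parameters() if p.requires_grad
    )
    print(f"trainable parameters before wrapping: {trainable}")
    return model_chunks


# Passing a list overrides any hooks registered via register_pre_wrap_hook()
# and runs the hooks in order, as described in the Args section above.
model_list = provider.provide_distributed_model(
    ddp_config=ddp_config,
    pre_wrap_hook=[freeze_embeddings, report_trainable],
)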
- """ - if wrap_with_ddp and not ddp_config: - raise ValueError("ddp_config is required when wrap_with_ddp is True") - - if not torch.distributed.is_initialized(): - os.environ["RANK"] = os.environ.get("RANK", "0") - os.environ["WORLD_SIZE"] = os.environ.get("WORLD_SIZE", "1") - os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost") - os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12355") - torch.cuda.set_device(get_local_rank_preinit()) - torch.distributed.init_process_group("nccl") - - if not parallel_state.is_initialized(): - print("Model parallel not initialized, initializing...") - self.initialize_model_parallel(seed=0) - - # Convert list of hooks to a single composed callable - if isinstance(pre_wrap_hook, list): - - def composed_pre_wrap_hook(model: list[MegatronModule]) -> list[MegatronModule]: - for hook in pre_wrap_hook: - model = hook(model) - return model - - final_pre_wrap_hook = composed_pre_wrap_hook - else: - final_pre_wrap_hook = pre_wrap_hook or self.pre_wrap_hook - final_post_wrap_hook = post_wrap_hook or self.post_wrap_hook - - model = get_model( - self, - ddp_config=ddp_config, - model_type=model_type, - overlap_param_gather_with_optimizer_step=overlap_param_gather_with_optimizer_step, - fp16=fp16, - bf16=bf16, - use_megatron_fsdp=use_megatron_fsdp, - use_torch_fsdp2=use_torch_fsdp2, - wrap_with_ddp=wrap_with_ddp, - data_parallel_random_init=data_parallel_random_init, - use_cpu_initialization=use_cpu_initialization, - init_model_with_meta_device=init_model_with_meta_device, - pre_wrap_hook=final_pre_wrap_hook, - ) - - if final_post_wrap_hook: - _model = final_post_wrap_hook(model) - if _model is not None: - model = _model - - return model - - def initialize_model_parallel( - self, seed: int | None = None, seed_kwargs: dict | None = None, **model_parallel_kwargs - ) -> None: - """Initializes model parallelism and sets the random seed. - - This is a convenience method that sets up tensor, pipeline, and other - forms of model parallelism based on the attributes of the provider instance. - - Args: - seed: The random seed for model parallel RNG. - seed_kwargs: Additional arguments for `model_parallel_cuda_manual_seed`. - **model_parallel_kwargs: Additional arguments for `parallel_state.initialize_model_parallel`. - """ - if not torch.distributed.is_initialized(): - torch.cuda.set_device(get_local_rank_preinit()) - torch.distributed.init_process_group("nccl") - - parallel_state.initialize_model_parallel( - tensor_model_parallel_size=getattr(self, "tensor_model_parallel_size", 1), - pipeline_model_parallel_size=getattr(self, "pipeline_model_parallel_size", 1), - virtual_pipeline_model_parallel_size=getattr( - self, "virtual_pipeline_model_parallel_size", None - ), - context_parallel_size=getattr(self, "context_parallel_size", 1) or 1, - expert_model_parallel_size=getattr(self, "expert_model_parallel_size", 1) or 1, - expert_tensor_parallel_size=getattr(self, "expert_tensor_parallel_size", None), - **model_parallel_kwargs, - ) - if seed is not None: - model_parallel_cuda_manual_seed(seed, **(seed_kwargs or {})) - - @property - def meta_model(self) -> list[ModelT]: - """Returns the model instantiated on the meta device for inspection. - - This is useful for examining the model architecture without allocating - GPU memory. 
- """ - return self(wrap_with_ddp=False, init_model_with_meta_device=True) - - @property - def pre_wrap_hook(self) -> Callable[[list[MegatronModule]], list[MegatronModule]] | None: - """A composed callable of all registered pre-wrap hooks. - - This read-only property returns a single function that executes all registered - pre-wrap hooks in order. The hook is applied before the model is passed to the DDP - wrapper and can be used for tasks like freezing layers or altering model structure. - - Use `register_pre_wrap_hook` to add a hook to the execution chain. - - Returns: - A callable that executes all registered pre-wrap hooks in order, or None if no - hooks are registered. - """ - if not hasattr(self, "_pre_wrap_hooks") or not self._pre_wrap_hooks: - return None - - def composed_hook(model: list[MegatronModule]) -> list[MegatronModule]: - for hook in self._pre_wrap_hooks: - model = hook(model) - return model - - return composed_hook - - def register_pre_wrap_hook( - self, hook: Callable[[list[MegatronModule]], list[MegatronModule]], prepend: bool = False - ) -> None: - """Registers a hook to be executed before the model is wrapped. - - The hook should be a callable that accepts a list of `MegatronModule` instances - and returns a (potentially modified) list of `MegatronModule` instances. - - Args: - hook: The hook to register. - prepend: If True, the hook is inserted at the beginning of the execution - chain. Otherwise, it is appended to the end. - """ - if not hasattr(self, "_pre_wrap_hooks"): - self._pre_wrap_hooks = [] - if prepend: - self._pre_wrap_hooks.insert(0, hook) - else: - self._pre_wrap_hooks.append(hook) - - @property - def post_wrap_hook(self) -> Callable[[list[MegatronModule]], list[MegatronModule]] | None: - """A composed callable of all registered post-wrap hooks. - - This read-only property returns a single function that executes all registered - post-wrap hooks in order. The hook is applied after the model has been wrapped by - DDP and is useful for tasks like logging or attaching custom attributes. - - Use `register_post_wrap_hook` to add a hook to the execution chain. - - Returns: - A callable that executes all registered post-wrap hooks in order, or None if no - hooks are registered. - """ - if not hasattr(self, "_post_wrap_hooks") or not self._post_wrap_hooks: - return None - - def composed_hook(model: list[MegatronModule]) -> list[MegatronModule]: - for hook in self._post_wrap_hooks: - model = hook(model) - return model - - return composed_hook - - def register_post_wrap_hook( - self, hook: Callable[[list[MegatronModule]], list[MegatronModule]], prepend: bool = False - ) -> None: - """Registers a hook to be executed after the model is wrapped. - - The hook should be a callable that accepts a list of `MegatronModule` instances - and returns a (potentially modified) list of `MegatronModule` instances. - - Args: - hook: The hook to register. - prepend: If True, the hook is inserted at the beginning of the execution - chain. Otherwise, it is appended to the end. - """ - if not hasattr(self, "_post_wrap_hooks"): - self._post_wrap_hooks = [] - if prepend: - self._post_wrap_hooks.insert(0, hook) - else: - self._post_wrap_hooks.append(hook) - - @classmethod - def from_hf_pretrained( - cls, - pretrained_model_name_or_path: str | Path, - trust_remote_code: bool = False, - mode: InstantiationMode | None = None, - config_name: str | None = None, - **kwargs, - ): - """Load a pretrained model configuration from a directory or HuggingFace Hub. 
- - This method provides a HuggingFace-inspired interface for loading model - configurations, enabling interoperability. - - Args: - pretrained_model_name_or_path: The path to a local directory or a - HuggingFace model identifier. - trust_remote_code: Whether to trust remote code when loading. - mode: The instantiation mode (e.g., `LENIENT`). - config_name: The name of the configuration file (without extension). - **kwargs: Additional keyword arguments for `from_hf_pretrained`. - - Returns: - An instance of the model provider with the loaded configuration. - """ - if config_name is None: - config_name = cls.CONFIG_NAME.rsplit(".", 1)[0] - if mode is None: - mode = InstantiationMode.LENIENT - return from_hf_pretrained( - cls, - pretrained_model_name_or_path, - trust_remote_code=trust_remote_code, - mode=mode, - config_name=config_name, - **kwargs, - ) - - def save_hf_pretrained( - self, - save_directory: str | Path, - config_format: str | None = None, - config_name: str | None = None, - **kwargs, - ): - """Save the model configuration to a directory. - - This method provides a HuggingFace-inspired interface for saving model - configurations, enabling interoperability. - - Args: - save_directory: The directory where the configuration will be saved. - config_format: The format for the configuration file (e.g., `json` or `yaml`). - config_name: The name of the configuration file (without extension). - **kwargs: Additional keyword arguments for `save_hf_pretrained`. - """ - if config_name is None: - config_name = self.CONFIG_NAME.rsplit(".", 1)[0] - if config_format is None: - config_format = self.DEFAULT_CONFIG_FORMAT - return save_hf_pretrained( - self, save_directory, config_format=config_format, config_name=config_name, **kwargs - ) - - -class GetModelKwargs(TypedDict, total=False): - """Keyword arguments for the `provide_distributed_model` method. - - Attributes: - ddp_config: Configuration for distributed data parallel. - model_type: Type of model (encoder, decoder, or both). - overlap_param_gather_with_optimizer_step: Whether to overlap param gathering. - fp16: Override FP16 setting. - bf16: Override BF16 setting. - use_megatron_fsdp: Use Megatron's Fully Sharded Data Parallel - use_torch_fsdp2: Use PyTorch FSDP2 instead of custom DDP. - wrap_with_ddp: Whether to wrap model with DDP. - data_parallel_random_init: Initialize parameters randomly across data parallel ranks. - use_cpu_initialization: Initialize model on CPU. - init_model_with_meta_device: Initialize model on meta device. - pre_wrap_hook: A single callable or list of callables that overrides all registered pre-wrap hooks. - post_wrap_hook: A single callable that overrides all registered post-wrap hooks. - """ - - ddp_config: DistributedDataParallelConfig | None - model_type: ModelType - overlap_param_gather_with_optimizer_step: bool - fp16: bool | None - bf16: bool | None - use_megatron_fsdp: bool - use_torch_fsdp2: bool - wrap_with_ddp: bool - data_parallel_random_init: bool - use_cpu_initialization: bool | None - init_model_with_meta_device: bool | None - pre_wrap_hook: ( - Union[ - Callable[[list[MegatronModule]], list[MegatronModule]], - list[Callable[[list[MegatronModule]], list[MegatronModule]]], - ] - | None - ) - post_wrap_hook: Callable[[list[MegatronModule]], list[MegatronModule]] | None - - -class ModelParallelKwargs(TypedDict, total=False): - """Model-parallel override kwargs. - - Attributes map to `TransformerConfig`/provider fields that control parallelism. - Only provided values are applied as overrides. 
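As a sketch of how such a partial override dict might be consumed, the helper below is hypothetical and not part of this module; only the keys actually present in the dict are applied.

from typing_extensions import Unpack

from megatron.nemo_bridge.models.model_provider import ModelParallelKwargs

overrides: ModelParallelKwargs = {
    "tensor_model_parallel_size": 2,
    "sequence_parallel": True,
}


def apply_parallel_overrides(provider, **kwargs: Unpack[ModelParallelKwargs]) -> None:
    # Apply only the keys that were actually provided.
    for key, value in kwargs.items():
        setattr(provider, key, value)


apply_parallel_overrides(provider, **overrides)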
- """ - - tensor_model_parallel_size: int - pipeline_model_parallel_size: int - context_parallel_size: int - expert_model_parallel_size: int - expert_tensor_parallel_size: int - moe_extended_tp: bool - sequence_parallel: bool - virtual_pipeline_model_parallel_size: int | None - hierarchical_context_parallel_sizes: list[int] | None - - -def get_model( - model_provider: ModelProviderMixin, - ddp_config: DistributedDataParallelConfig, - model_type=ModelType.encoder_or_decoder, - overlap_param_gather_with_optimizer_step: bool = False, - fp16: bool | None = None, - bf16: bool | None = None, - use_megatron_fsdp: bool = False, - use_torch_fsdp2: bool = False, - wrap_with_ddp: bool = True, - data_parallel_random_init: bool = True, - use_cpu_initialization: None | bool = False, - init_model_with_meta_device: bool | None = None, - pre_wrap_hook: ( - Union[ - Callable[[list[MegatronModule]], list[MegatronModule]], - list[Callable[[list[MegatronModule]], list[MegatronModule]]], - ] - | None - ) = None, -) -> list[MegatronModule]: - """Create and configure a model for distributed training. - - This function handles the complete model creation pipeline including: - - Model instantiation with proper pipeline parallel configuration - - GPU memory allocation - - Mixed precision (FP16/BF16) wrapping - - Float8 tensor correction - - Distributed Data Parallel (DDP) wrapping - - Args: - model_provider: ModelProviderMixin instance that creates the model. - Uses the provide() method with optional pre_process(bool), post_process(bool), - vp_stage(int) arguments for pipeline parallelism - ddp_config: Configuration for distributed data parallel training - model_type: Type of model (encoder, decoder, or encoder_and_decoder) - overlap_param_gather_with_optimizer_step: Whether to overlap parameter - gathering with optimizer step for performance optimization - fp16: Enable FP16 mixed precision training. If None, uses model config - bf16: Enable BF16 mixed precision training. If None, uses model config - use_megatron_fsdp: Use Megatron's Fully Sharded Data Parallel - use_torch_fsdp2: Use PyTorch's Fully Sharded Data Parallel v2 - wrap_with_ddp: Whether to wrap the model with DDP - data_parallel_random_init: Whether to use random initialization for - data parallel ranks (vs broadcasting from rank 0) - use_cpu_initialization: Whether to initialize model on CPU to save GPU memory - init_model_with_meta_device: Whether to initialize the model on the meta device - pre_wrap_hook: A callable or list of callables that takes a list of `MegatronModule` - and returns a modified list, or `None` to clear the hook. If a list is provided, - hooks will be executed in order. - - Returns: - list[MegatronModule]: List of model modules. 
Contains multiple modules - when using virtual pipeline parallelism, otherwise a single module - """ - if fp16: - model_provider.fp16 = fp16 - if bf16: - model_provider.bf16 = bf16 - - model_provider.use_cpu_initialization = ( - use_cpu_initialization if use_cpu_initialization else False - ) - if init_model_with_meta_device: - model_provider.init_model_with_meta_device = True - with torch.device("meta"): - model = _create_model(model_provider, model_type) - else: - model = _create_model(model_provider, model_type) - - if pre_wrap_hook: - if isinstance(pre_wrap_hook, list): - # Execute hooks in order - for hook in pre_wrap_hook: - if not callable(hook): - raise RuntimeError("All elements in pre_wrap_hook list must be callable") - _model = hook(model) - if _model is not None: - model = _model - else: - if not callable(pre_wrap_hook): - raise RuntimeError("pre_wrap_hook must be a callable or a list of callables") - _model = pre_wrap_hook(model) - if _model is not None: - model = _model - - # Set tensor model parallel attributes if not set - # In case pre_wrap_hook augmented the model (e.g. adding PEFT adapters) - for model_module in model: - for param in model_module.parameters(): - tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) - - _print_num_params(model) - - model_config = get_model_config(model[0]) - - # GPU allocation. - # For FSDP2, we don't allocate GPU memory here. We allocate GPU memory - # in the fully_shard function of FSDP2 instead. - if ( - not use_torch_fsdp2 - and not model_config.use_cpu_initialization - and not model_config.init_model_with_meta_device - ): - for model_module in model: - model_module.cuda(torch.cuda.current_device()) - - if model_config.fp16 or model_config.bf16: - model = [Float16Module(model_config, model_module) for model_module in model] - - if correct_amax_history_if_needed is not None: - correct_amax_history_if_needed(model) - - if wrap_with_ddp: - model = _ddp_wrap( - model, - data_parallel_random_init, - ddp_config, - overlap_param_gather_with_optimizer_step, - use_megatron_fsdp=use_megatron_fsdp, - use_torch_fsdp2=use_torch_fsdp2, - ) - - return model - - -def _create_model( - model_provider: ModelProviderMixin, model_type: ModelType -) -> list[MegatronModule]: - """Create model instances with appropriate pipeline parallel configuration. - - Handles virtual pipeline parallelism (VPP) by creating multiple model - instances when needed. Sets pre_process and post_process flags based on - pipeline parallel rank. - - Args: - model_provider: ModelProviderMixin instance that creates the model - model_type: ModelType enum indicating encoder, decoder, or both - - Returns: - list: List of model instances. 
Multiple instances for VPP, otherwise single - """ - - if ( - parallel_state.get_pipeline_model_parallel_world_size() > 1 - and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None - ): - assert ( - model_type != ModelType.encoder_and_decoder - ), "Interleaved schedule not supported for model with both encoder and decoder" - model = [] - for i in range(parallel_state.get_virtual_pipeline_model_parallel_world_size()): - pre_process = parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i) - post_process = parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i) - this_model = model_provider.provide( - pre_process=pre_process, post_process=post_process, vp_stage=i - ) - this_model.model_type = model_type - model.append(this_model) - else: - pre_process = parallel_state.is_pipeline_first_stage() - post_process = parallel_state.is_pipeline_last_stage() - if model_type == ModelType.encoder_and_decoder: - if parallel_state.get_pipeline_model_parallel_world_size() > 1: - rank = parallel_state.get_pipeline_model_parallel_rank() - first_decoder_rank = parallel_state.get_pipeline_model_parallel_decoder_start() - world_size = parallel_state.get_pipeline_model_parallel_world_size() - pre_process = rank == 0 or rank == first_decoder_rank - post_process = (rank == (first_decoder_rank - 1)) or (rank == (world_size - 1)) - model = model_provider.provide() - else: - model = model_provider.provide(pre_process=pre_process, post_process=post_process) - model.model_type = model_type - - if not isinstance(model, list): - model = [model] - - # Set tensor model parallel attributes if not set. - # Only parameters that are already tensor model parallel have these - # attributes set for them. We should make sure the default attributes - # are set for all params so the optimizer can use them. - for model_module in model: - for param in model_module.parameters(): - tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) - - return model - - -def _ddp_wrap( - model: list[MegatronModule], - data_parallel_random_init: bool, - ddp_config: DistributedDataParallelConfig, - overlap_param_gather_with_optimizer_step: bool, - use_megatron_fsdp: bool = False, - use_torch_fsdp2: bool = False, -) -> list[MegatronModule]: - """Wrap model with Distributed Data Parallel (DDP) or Fully Sharded Data Parallel (FSDP). - - Args: - model: List of model modules to wrap - use_torch_fsdp2: Whether to use PyTorch FSDP v2 instead of DDP - data_parallel_random_init: Whether to broadcast parameters from rank 0 - ddp_config: Configuration for distributed data parallel - overlap_param_gather_with_optimizer_step: Whether to disable bucketing - for overlapping parameter gathering with optimizer step - - Returns: - list[MegatronModule]: List of DDP/FSDP wrapped model modules - """ - if use_megatron_fsdp: - DP = FullyShardedDataParallel - if use_torch_fsdp2: - raise ValueError( - "Using use_megatron_fsdp and use_torch_fsdp2 at the same time is not supported." - ) - elif use_torch_fsdp2: - DP = TorchFullyShardedDataParallel - else: - DP = DistributedDataParallel - - model = [ - DP( - config=get_model_config(model_chunk), - ddp_config=ddp_config, - module=model_chunk, - # Turn off bucketing for model_chunk 2 onwards, since communication for these - # model chunks is overlapped with compute anyway. 
- disable_bucketing=(model_chunk_idx > 0) or overlap_param_gather_with_optimizer_step, - ) - for (model_chunk_idx, model_chunk) in enumerate(model) - ] - - # Broadcast params from data parallel src rank to other data parallel ranks. - if data_parallel_random_init: - for model_module in model: - model_module.broadcast_params() - - return model - - -def _print_num_params(model: list[MegatronModule]) -> None: - """Print the number of parameters in the model on rank 0. - - Only prints on data parallel rank 0 to avoid duplicate output. - Shows parameter count per (tensor parallel, pipeline parallel) rank. - - Args: - model: List of model modules to count parameters from - """ - if ( - parallel_state.get_data_parallel_rank() == 0 - and parallel_state.get_context_parallel_rank() == 0 - ): - print( - " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format( - parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_pipeline_model_parallel_rank(), - sum( - [ - sum([p.nelement() for p in model_module.parameters()]) - for model_module in model - ] - ), - ), - flush=True, - ) diff --git a/flagscale/train/megatron/nemo_bridge/models/qwen/__init__.py b/flagscale/train/megatron/nemo_bridge/models/qwen/__init__.py index 34cefc11d9..e5dfd6c221 100644 --- a/flagscale/train/megatron/nemo_bridge/models/qwen/__init__.py +++ b/flagscale/train/megatron/nemo_bridge/models/qwen/__init__.py @@ -1,56 +1,4 @@ # Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge -from megatron.nemo_bridge.models.qwen.qwen2_bridge import Qwen2Bridge # noqa: F401 from megatron.nemo_bridge.models.qwen.qwen3_bridge import Qwen3Bridge # noqa: F401 -from megatron.nemo_bridge.models.qwen.qwen3_moe_bridge import Qwen3MoEBridge # noqa: F401 -from megatron.nemo_bridge.models.qwen.qwen_provider import ( - Qwen2ModelProvider, - Qwen2ModelProvider1P5B, - Qwen2ModelProvider7B, - Qwen2ModelProvider72B, - Qwen2ModelProvider500M, - Qwen3ModelProvider, - Qwen3ModelProvider1P7B, - Qwen3ModelProvider4B, - Qwen3ModelProvider8B, - Qwen3ModelProvider14B, - Qwen3ModelProvider32B, - Qwen3ModelProvider600M, - Qwen3MoEModelProvider, - Qwen3MoEModelProvider30B_A3B, - Qwen3MoEModelProvider235B_A22B, - Qwen25ModelProvider1P5B, - Qwen25ModelProvider3B, - Qwen25ModelProvider7B, - Qwen25ModelProvider14B, - Qwen25ModelProvider32B, - Qwen25ModelProvider72B, - Qwen25ModelProvider500M, -) -__all__ = [ - "Qwen2ModelProvider", - "Qwen2ModelProvider500M", - "Qwen2ModelProvider1P5B", - "Qwen2ModelProvider7B", - "Qwen2ModelProvider72B", - "Qwen25ModelProvider500M", - "Qwen25ModelProvider1P5B", - "Qwen25ModelProvider3B", - "Qwen25ModelProvider7B", - "Qwen25ModelProvider14B", - "Qwen25ModelProvider32B", - "Qwen25ModelProvider72B", - "Qwen3ModelProvider", - "Qwen3ModelProvider600M", - "Qwen3ModelProvider1P7B", - "Qwen3ModelProvider4B", - "Qwen3ModelProvider8B", - "Qwen3ModelProvider14B", - "Qwen3ModelProvider32B", - "Qwen3MoEModelProvider", - "Qwen3MoEModelProvider30B_A3B", - "Qwen3MoEModelProvider235B_A22B", -] diff --git a/flagscale/train/megatron/nemo_bridge/models/qwen/qwen2_bridge.py b/flagscale/train/megatron/nemo_bridge/models/qwen/qwen2_bridge.py deleted file mode 100644 index 84d6890ce5..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/qwen/qwen2_bridge.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. 
-# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import torch - -from transformers import Qwen2ForCausalLM - -from megatron.core.models.gpt.gpt_model import GPTModel - -from megatron.nemo_bridge.models.conversion.mapping_registry import MegatronMappingRegistry -from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge -from megatron.nemo_bridge.models.conversion.param_mapping import ( - AutoMapping, - GatedMLPMapping, - QKVMapping, -) -from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM -from megatron.nemo_bridge.models.qwen.qwen_provider import Qwen2ModelProvider - - -@MegatronModelBridge.register_bridge(source=Qwen2ForCausalLM, target=GPTModel) -class Qwen2Bridge(MegatronModelBridge): - """ - Megatron Bridge for Qwen2 Causal LM. - - This bridge handles the conversion between HuggingFace Qwen2ForCausalLM - and Megatron-Core GPTModel formats, including weight mappings and - configuration translation. - - Example: - >>> from megatron.nemo_bridge import AutoBridge - >>> bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen2-7B") - >>> provider = bridge.to_megatron_provider() - """ - - def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> Qwen2ModelProvider: - hf_config = hf_pretrained.config - - provider = Qwen2ModelProvider( - num_layers=hf_config.num_hidden_layers, - hidden_size=hf_config.hidden_size, - ffn_hidden_size=hf_config.intermediate_size, - num_attention_heads=hf_config.num_attention_heads, - num_query_groups=hf_config.num_key_value_heads, - init_method_std=hf_config.initializer_range, - layernorm_epsilon=hf_config.rms_norm_eps, - gated_linear_unit=True, - make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.vocab_size), - rotary_base=hf_config.rope_theta, - share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False), - vocab_size=hf_config.vocab_size, - seq_length=hf_config.max_position_embeddings, - fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16), - bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16), - params_dtype=self.dtype_from_hf(hf_config, default=torch.float32), - generation_config=hf_pretrained.generation_config, - add_qkv_bias=True, # Qwen2 has bias in QKV projections - ) - - return provider - - def mapping_registry(self) -> MegatronMappingRegistry: - # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format - # First create simple 1:1 parameter mappings using a dictionary for readability - - # Dictionary maps Megatron parameter names -> HF parameter names - # Supports wildcard (*) patterns for layer-specific parameters - param_mappings = { - "embedding.word_embeddings.weight": "model.embed_tokens.weight", - "output_layer.weight": "lm_head.weight", - "decoder.final_layernorm.weight": "model.norm.weight", - "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", - "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", - "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", - "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", - } - - mapping_list = [] - # Convert each dictionary entry to AutoMapping(megatron_param, hf_param) - for megatron_param, hf_param in param_mappings.items(): - mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) - - # Add special mappings 
that require parameter concatenation/transformation - mapping_list.extend( - [ - # QKV: Combine separate Q, K, V matrices into single QKV matrix - QKVMapping( - megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", - q="model.layers.*.self_attn.q_proj.weight", - k="model.layers.*.self_attn.k_proj.weight", - v="model.layers.*.self_attn.v_proj.weight", - ), - # QKV bias: Combine separate Q, K, V biases into single QKV bias (Qwen2 specific) - QKVMapping( - megatron_param="decoder.layers.*.self_attention.linear_qkv.bias", - q="model.layers.*.self_attn.q_proj.bias", - k="model.layers.*.self_attn.k_proj.bias", - v="model.layers.*.self_attn.v_proj.bias", - ), - # Gated MLP: Combine gate and up projection matrices into single FC1 matrix - GatedMLPMapping( - megatron_param="decoder.layers.*.mlp.linear_fc1.weight", - gate="model.layers.*.mlp.gate_proj.weight", - up="model.layers.*.mlp.up_proj.weight", - ), - ] - ) - - return MegatronMappingRegistry(*mapping_list) diff --git a/flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_bridge.py b/flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_bridge.py index 263fe26a32..dbfaabf40d 100644 --- a/flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_bridge.py +++ b/flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_bridge.py @@ -1,6 +1,6 @@ # Copyright (c) 2025, BAAI. All rights reserved. # -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge import torch @@ -8,15 +8,15 @@ from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.nemo_bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge from megatron.nemo_bridge.models.conversion.param_mapping import ( AutoMapping, - GatedMLPMapping, QKVMapping, ) +from megatron.bridge.models.conversion.param_mapping import GatedMLPMapping from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM -from megatron.nemo_bridge.models.qwen.qwen_provider import Qwen3ModelProvider +from megatron.bridge.models.qwen.qwen_provider import Qwen3ModelProvider @MegatronModelBridge.register_bridge(source=Qwen3ForCausalLM, target=GPTModel) diff --git a/flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_moe_bridge.py b/flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_moe_bridge.py deleted file mode 100755 index f9cf6fabde..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_moe_bridge.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. 
-# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import torch - -from transformers import Qwen3MoeForCausalLM - -from megatron.core.models.gpt.gpt_model import GPTModel - -from megatron.nemo_bridge.models.conversion.mapping_registry import MegatronMappingRegistry -from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge -from megatron.nemo_bridge.models.conversion.param_mapping import ( - AutoMapping, - GatedMLPMapping, - QKVMapping, -) -from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM -from megatron.nemo_bridge.models.qwen.qwen_provider import Qwen3MoEModelProvider - - -@MegatronModelBridge.register_bridge(source=Qwen3MoeForCausalLM, target=GPTModel) -class Qwen3MoEBridge(MegatronModelBridge): - """ - Megatron Bridge for Qwen3 MoE Causal LM. - - This bridge handles the conversion between HuggingFace Qwen3MoeForCausalLM - and Megatron-Core GPTModel formats. Qwen3 MoE models use mixture of experts - architecture with QK layernorm. - - Example: - >>> from megatron.nemo_bridge import AutoBridge - >>> bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen3-235B-A22B") - >>> provider = bridge.to_megatron_provider() - """ - - def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> Qwen3MoEModelProvider: - hf_config = hf_pretrained.config - - provider = Qwen3MoEModelProvider( - num_layers=hf_config.num_hidden_layers, - hidden_size=hf_config.hidden_size, - ffn_hidden_size=hf_config.intermediate_size, - moe_ffn_hidden_size=hf_config.moe_intermediate_size, # Maps to moe_intermediate_size in HF - num_attention_heads=hf_config.num_attention_heads, - num_query_groups=hf_config.num_key_value_heads, - num_moe_experts=hf_config.num_experts, - moe_router_topk=hf_config.num_experts_per_tok, # Maps to num_experts_per_tok in HF - init_method_std=hf_config.initializer_range, - layernorm_epsilon=hf_config.rms_norm_eps, - gated_linear_unit=True, - make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.vocab_size), - rotary_base=hf_config.rope_theta, - share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False), - vocab_size=hf_config.vocab_size, - seq_length=hf_config.max_position_embeddings, - fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16), - bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16), - params_dtype=self.dtype_from_hf(hf_config, default=torch.float32), - generation_config=hf_pretrained.generation_config, - qk_layernorm=True, # Qwen3 MoE uses QK layernorm - moe_grouped_gemm=True, - ) - - return provider - - def mapping_registry(self) -> MegatronMappingRegistry: - # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format - # First create simple 1:1 parameter mappings using a dictionary for readability - - # Dictionary maps Megatron parameter names -> HF parameter names - # Supports wildcard (*) patterns for layer-specific parameters - param_mappings = { - "embedding.word_embeddings.weight": "model.embed_tokens.weight", - "output_layer.weight": "lm_head.weight", - "decoder.final_layernorm.weight": "model.norm.weight", - "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", - "decoder.layers.*.mlp.router.weight": "model.layers.*.mlp.gate.weight", - "decoder.layers.*.pre_mlp_layernorm.weight": "model.layers.*.post_attention_layernorm.weight", - "decoder.layers.*.self_attention.q_layernorm.weight": "model.layers.*.self_attn.q_norm.weight", - 
"decoder.layers.*.self_attention.k_layernorm.weight": "model.layers.*.self_attn.k_norm.weight", - "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", - } - - mapping_list = [] - # Convert each dictionary entry to AutoMapping(megatron_param, hf_param) - for megatron_param, hf_param in param_mappings.items(): - mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) - - # Add special mappings that require parameter concatenation/transformation - mapping_list.extend( - [ - # QKV: Combine separate Q, K, V matrices into single QKV matrix - # Note: Qwen3 MoE does NOT have bias in QKV projections - QKVMapping( - megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", - q="model.layers.*.self_attn.q_proj.weight", - k="model.layers.*.self_attn.k_proj.weight", - v="model.layers.*.self_attn.v_proj.weight", - ), - GatedMLPMapping( - megatron_param="decoder.layers.*.mlp.experts.linear_fc1.weight*", - gate="model.layers.*.mlp.experts.*.gate_proj.weight", - up="model.layers.*.mlp.experts.*.up_proj.weight", - ), - AutoMapping( - megatron_param="decoder.layers.*.mlp.experts.linear_fc2.weight*", - hf_param="model.layers.*.mlp.experts.*.down_proj.weight", - ), - ] - ) - - return MegatronMappingRegistry(*mapping_list) diff --git a/flagscale/train/megatron/nemo_bridge/models/qwen/qwen_provider.py b/flagscale/train/megatron/nemo_bridge/models/qwen/qwen_provider.py deleted file mode 100644 index efc7b6ee0c..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/qwen/qwen_provider.py +++ /dev/null @@ -1,393 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import logging - -from dataclasses import dataclass -from typing import Callable, Optional - -import torch -import torch.nn.functional as F - -from megatron.nemo_bridge.models.gpt_provider import GPTModelProvider - -logger = logging.getLogger(__name__) - - -@dataclass -class Qwen2ModelProvider(GPTModelProvider): - """Base model provider for Qwen 2 Models.""" - - normalization: str = "RMSNorm" - activation_func: Callable = F.silu - gated_linear_unit: bool = True - add_bias_linear: bool = False - add_qkv_bias: bool = True - seq_length: int = 4096 - init_method_std: int = 0.02 - hidden_dropout: float = 0.0 - attention_dropout: float = 0.0 - vocab_size: int = 151936 - share_embeddings_and_output_weights: Optional[bool] = False - layernorm_epsilon: float = 1e-6 - rotary_base: float = 1000000.0 - position_embedding_type: str = "rope" - autocast_dtype: torch.dtype = torch.bfloat16 - params_dtype: torch.dtype = torch.bfloat16 - bf16: bool = True - - -# ============================================================================= -# Qwen 2 Model Providers -# ============================================================================= - - -@dataclass -class Qwen2ModelProvider500M(Qwen2ModelProvider): - """ - Config for Qwen 2 0.5B: https://huggingface.co/Qwen/Qwen2-0.5B - """ - - num_layers: int = 24 - hidden_size: int = 896 - num_attention_heads: int = 14 - num_query_groups: int = 2 - ffn_hidden_size: int = 4864 - share_embeddings_and_output_weights: bool = True - seq_length: int = 32768 - - -@dataclass -class Qwen2ModelProvider1P5B(Qwen2ModelProvider): - """ - Config for Qwen 2 1.5B: https://huggingface.co/Qwen/Qwen2-1.5B - """ - - num_layers: int = 28 - hidden_size: int = 1536 - num_attention_heads: int = 12 - num_query_groups: int = 2 - ffn_hidden_size: int = 8960 - seq_length: int = 32768 - 
share_embeddings_and_output_weights: bool = True - - -@dataclass -class Qwen2ModelProvider7B(Qwen2ModelProvider): - """ - Config for Qwen 2 7B: https://huggingface.co/Qwen/Qwen2-7B - """ - - num_layers: int = 28 - hidden_size: int = 3584 - num_attention_heads: int = 28 - num_query_groups: int = 4 - ffn_hidden_size: int = 18944 - vocab_size: int = 152064 - seq_length: int = 32768 - - -@dataclass -class Qwen2ModelProvider72B(Qwen2ModelProvider): - """ - Config for Qwen 2 72B: https://huggingface.co/Qwen/Qwen2-72B - """ - - num_layers: int = 80 - hidden_size: int = 8192 - num_attention_heads: int = 64 - num_query_groups: int = 8 - ffn_hidden_size: int = 29568 - vocab_size: int = 152064 - layernorm_epsilon: float = 1e-6 - seq_length: int = 32768 - - -# ============================================================================= -# Qwen 2.5 Model Providers -# ============================================================================= - - -@dataclass -class Qwen25ModelProvider500M(Qwen2ModelProvider): - """ - Config for Qwen 2.5 0.5B: https://huggingface.co/Qwen/Qwen2.5-0.5B - """ - - num_layers: int = 24 - hidden_size: int = 896 - num_attention_heads: int = 14 - num_query_groups: int = 2 - ffn_hidden_size: int = 4864 - share_embeddings_and_output_weights: bool = True - seq_length: int = 32768 - - -@dataclass -class Qwen25ModelProvider1P5B(Qwen2ModelProvider): - """ - Config for Qwen 2.5 1.5B: https://huggingface.co/Qwen/Qwen2.5-1.5B - """ - - num_layers: int = 28 - hidden_size: int = 1536 - num_attention_heads: int = 12 - num_query_groups: int = 2 - ffn_hidden_size: int = 8960 - seq_length: int = 32768 - share_embeddings_and_output_weights: bool = True - - -@dataclass -class Qwen25ModelProvider3B(Qwen2ModelProvider): - """ - Config for Qwen 2.5 3B: https://huggingface.co/Qwen/Qwen2.5-3B - """ - - num_layers: int = 36 - hidden_size: int = 2048 - num_attention_heads: int = 16 - num_query_groups: int = 2 - ffn_hidden_size: int = 11008 - vocab_size: int = 151936 - share_embeddings_and_output_weights: bool = True - seq_length: int = 32768 - - -@dataclass -class Qwen25ModelProvider7B(Qwen2ModelProvider): - """ - Config for Qwen 2.5 7B: https://huggingface.co/Qwen/Qwen2.5-7B - """ - - num_layers: int = 28 - hidden_size: int = 3584 - num_attention_heads: int = 28 - num_query_groups: int = 4 - ffn_hidden_size: int = 18944 - vocab_size: int = 152064 - seq_length: int = 32768 - - -@dataclass -class Qwen25ModelProvider14B(Qwen2ModelProvider): - """ - Config for Qwen 2.5 14B: https://huggingface.co/Qwen/Qwen2.5-14B - """ - - num_layers: int = 48 - hidden_size: int = 5120 - num_attention_heads: int = 40 - num_query_groups: int = 8 - ffn_hidden_size: int = 13824 - vocab_size: int = 152064 - layernorm_epsilon: float = 1e-6 - seq_length: int = 32768 - - -@dataclass -class Qwen25ModelProvider32B(Qwen2ModelProvider): - """ - Config for Qwen 2.5 32B: https://huggingface.co/Qwen/Qwen2.5-32B - """ - - num_layers: int = 64 - hidden_size: int = 5120 - num_attention_heads: int = 40 - num_query_groups: int = 8 - ffn_hidden_size: int = 27648 - vocab_size: int = 152064 - layernorm_epsilon: float = 1e-6 - seq_length: int = 32768 - - -@dataclass -class Qwen25ModelProvider72B(Qwen2ModelProvider): - """ - Config for Qwen 2.5 72B: https://huggingface.co/Qwen/Qwen2.5-72B - """ - - num_layers: int = 80 - hidden_size: int = 8192 - num_attention_heads: int = 64 - num_query_groups: int = 8 - ffn_hidden_size: int = 29568 - vocab_size: int = 152064 - layernorm_epsilon: float = 1e-6 - seq_length: int = 32768 - - -# 
============================================================================= -# Qwen 3 Model Provider (based on GPTProvider) -# ============================================================================= - - -@dataclass -class Qwen3ModelProvider(GPTModelProvider): - """Base model provider for Qwen 3 Models.""" - - normalization: str = "RMSNorm" - activation_func: Callable = F.silu - gated_linear_unit: bool = True - add_bias_linear: bool = False - add_qkv_bias: bool = False - qk_layernorm: bool = True - kv_channels: Optional[int] = 128 - num_query_groups: int = 8 - seq_length: int = 40960 - init_method_std: int = 0.02 - hidden_dropout: float = 0.0 - attention_dropout: float = 0.0 - vocab_size: int = 151936 - share_embeddings_and_output_weights: Optional[bool] = False - layernorm_epsilon: float = 1e-6 - rotary_base: float = 1000000.0 - position_embedding_type: str = "rope" - autocast_dtype: torch.dtype = torch.bfloat16 - params_dtype: torch.dtype = torch.bfloat16 - bf16: bool = True - - -@dataclass -class Qwen3ModelProvider600M(Qwen3ModelProvider): - """ - Config for Qwen 3 0.6B: https://huggingface.co/Qwen/Qwen3-0.6B - """ - - num_layers: int = 28 - hidden_size: int = 1024 - num_attention_heads: int = 16 - ffn_hidden_size: int = 3072 - share_embeddings_and_output_weights: bool = True - - -@dataclass -class Qwen3ModelProvider1P7B(Qwen3ModelProvider): - """ - Config for Qwen 3 1.7B: https://huggingface.co/Qwen/Qwen3-1.7B - """ - - num_layers: int = 28 - hidden_size: int = 2048 - num_attention_heads: int = 16 - ffn_hidden_size: int = 6144 - share_embeddings_and_output_weights: bool = True - - -@dataclass -class Qwen3ModelProvider4B(Qwen3ModelProvider): - """ - Config for Qwen 3 4B: https://huggingface.co/Qwen/Qwen3-4B - """ - - num_layers: int = 36 - hidden_size: int = 2560 - num_attention_heads: int = 32 - ffn_hidden_size: int = 9728 - share_embeddings_and_output_weights: bool = True - - -@dataclass -class Qwen3ModelProvider8B(Qwen3ModelProvider): - """ - Config for Qwen 3 8B: https://huggingface.co/Qwen/Qwen3-8B - """ - - num_layers: int = 36 - hidden_size: int = 4096 - num_attention_heads: int = 32 - ffn_hidden_size: int = 12288 - - -@dataclass -class Qwen3ModelProvider14B(Qwen3ModelProvider): - """ - Config for Qwen 3 14B: https://huggingface.co/Qwen/Qwen3-14B - """ - - num_layers: int = 40 - hidden_size: int = 5120 - num_attention_heads: int = 40 - ffn_hidden_size: int = 17408 - - -@dataclass -class Qwen3ModelProvider32B(Qwen3ModelProvider): - """ - Config for Qwen 3 32B: https://huggingface.co/Qwen/Qwen3-32B - """ - - num_layers: int = 64 - hidden_size: int = 5120 - num_attention_heads: int = 64 - ffn_hidden_size: int = 25600 - - -# ============================================================================= -# Qwen 3 MoE Model Provider (based on GPTProvider) -# ============================================================================= - - -@dataclass -class Qwen3MoEModelProvider(GPTModelProvider): - """Base provider for Qwen 3 MoE Models.""" - - normalization: str = "RMSNorm" - activation_func: Callable = F.silu - gated_linear_unit: bool = True - add_bias_linear: bool = False - add_qkv_bias: bool = False - qk_layernorm: bool = True - kv_channels: Optional[int] = 128 - num_query_groups: int = 8 - seq_length: int = 40960 - init_method_std: int = 0.02 - hidden_dropout: float = 0.0 - attention_dropout: float = 0.0 - vocab_size: int = 151936 - share_embeddings_and_output_weights: Optional[bool] = False - layernorm_epsilon: float = 1e-6 - rotary_base: float = 1000000.0 - 
position_embedding_type: str = "rope" - autocast_dtype: torch.dtype = torch.bfloat16 - params_dtype: torch.dtype = torch.bfloat16 - bf16: bool = True - - # MoE specific parameters - num_moe_experts: int = 128 - moe_router_load_balancing_type: str = "aux_loss" - moe_aux_loss_coeff: float = 1e-3 - moe_router_topk: int = 8 - moe_router_pre_softmax: bool = False - moe_grouped_gemm: bool = True - moe_token_dispatcher_type: str = "alltoall" - moe_permute_fusion: bool = True - - -@dataclass -class Qwen3MoEModelProvider30B_A3B(Qwen3MoEModelProvider): - """ - Provider for Qwen 3 30B-A3B: https://huggingface.co/Qwen/Qwen3-30B-A3B - """ - - num_layers: int = 48 - hidden_size: int = 2048 - num_attention_heads: int = 32 - num_query_groups: int = 4 - ffn_hidden_size: int = 6144 - moe_ffn_hidden_size: int = 768 - - -@dataclass -class Qwen3MoEModelProvider235B_A22B(Qwen3MoEModelProvider): - """ - Provider for Qwen 3 235B-A22B: https://huggingface.co/Qwen/Qwen3-235B-A22B - """ - - num_layers: int = 94 - hidden_size: int = 4096 - num_attention_heads: int = 64 - num_query_groups: int = 4 - ffn_hidden_size: int = 12288 - moe_ffn_hidden_size: int = 1536 diff --git a/flagscale/train/megatron/nemo_bridge/models/transformer_config.py b/flagscale/train/megatron/nemo_bridge/models/transformer_config.py deleted file mode 100644 index 4a3daf77fc..0000000000 --- a/flagscale/train/megatron/nemo_bridge/models/transformer_config.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -"""Bridge wrapper classes for Megatron Core transformer configurations. - -These classes provide deferred post-initialization to support the Bridge configuration -override system while maintaining compatibility with Megatron Core's post_init behavior. -""" - -from dataclasses import dataclass - -from megatron.core.transformer.transformer_config import ( - MLATransformerConfig as MCoreMLATransformerConfig, - TransformerConfig as MCoreTransformerConfig, -) - - -@dataclass -class TransformerConfig(MCoreTransformerConfig): - """Megatron Core TransformerConfig with deferred post-init. - - This class inherits from Megatron Core's TransformerConfig but defers the - execution of post_init() until finalize() is explicitly called. This allows - for field modifications after construction but before computed fields are - calculated. - - Usage: - # Create config with deferred post-init - config = TransformerConfig(num_layers=32, hidden_size=4096) - - # Modify fields as needed - config.seq_length = 8192 - config.tensor_model_parallel_size = 2 - - # Finalize to compute derived fields - config.finalize() - """ - - def __post_init__(self) -> None: - """Skip MCore post_init during initial construction. - - The original post_init logic is deferred until finalize() is called. - This allows for field modifications after construction without - invalidating computed fields. - """ - pass - - def finalize(self) -> None: - """Execute the deferred MCore post-init logic. - - This method calls the original Megatron Core TransformerConfig.__post_init__() - to compute derived fields based on the current field values. It can be - called multiple times safely. - """ - MCoreTransformerConfig.__post_init__(self) - - -@dataclass -class MLATransformerConfig(TransformerConfig, MCoreMLATransformerConfig): - """Megatron Core MLATransformerConfig with deferred post-init. 
- - This class inherits from Megatron Core's MLATransformerConfig but defers the - execution of post_init() until finalize() is explicitly called. This allows - for field modifications after construction but before computed fields are - calculated. - - Usage: - # Create config with deferred post-init - config = MLATransformerConfig(num_layers=32, hidden_size=4096) - - # Modify fields as needed - config.q_lora_rank = 1536 - config.kv_lora_rank = 512 - - # Finalize to compute derived fields - config.finalize() - """ - - def __post_init__(self) -> None: - """Skip MCore post_init during initial construction. - - The original post_init logic is deferred until finalize() is called. - This allows for field modifications after construction without - invalidating computed fields. - """ - pass - - def finalize(self) -> None: - """Execute the deferred MCore post-init logic. - - This method calls the original Megatron Core MLATransformerConfig.__post_init__() - to compute derived fields based on the current field values. It can be - called multiple times safely. - """ - MCoreMLATransformerConfig.__post_init__(self) diff --git a/flagscale/train/megatron/nemo_bridge/utils/__init__.py b/flagscale/train/megatron/nemo_bridge/utils/__init__.py deleted file mode 100644 index 3bfe2ab7d3..0000000000 --- a/flagscale/train/megatron/nemo_bridge/utils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge diff --git a/flagscale/train/megatron/nemo_bridge/utils/common_utils.py b/flagscale/train/megatron/nemo_bridge/utils/common_utils.py deleted file mode 100644 index de4e4e17e4..0000000000 --- a/flagscale/train/megatron/nemo_bridge/utils/common_utils.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import os -import types -import warnings - -import torch -import torch.distributed - -from megatron.core import DistributedDataParallel as DDP -from megatron.core.transformer.module import Float16Module - -try: - from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP - - ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, torch_FSDP, Float16Module) -except ImportError: - ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module) - - -def get_rank_safe() -> int: - """Get the distributed rank safely, even if torch.distributed is not initialized. - - Returns: - The current process rank. - """ - # In megatron init, args.rank comes from the torchrun env var. - # Once init has been done, args.rank is updated to value of torch get_rank() - if torch.distributed.is_initialized(): - return torch.distributed.get_rank() - else: - return int(os.getenv("RANK", "0")) - - -def get_world_size_safe() -> int: - """Get the distributed world size safely, even if torch.distributed is not initialized. - - Returns: - The total number of processes in the distributed job. - """ - # In megatron init, args.world_size comes from the torchrun env var. 
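For reference, a minimal sketch of how the rank-safe helpers above behave before torch.distributed.init_process_group has been called (values fall back to the torchrun environment variables):

    rank = get_rank_safe()              # RANK env var, or 0 if unset
    world_size = get_world_size_safe()  # WORLD_SIZE env var, or 1 if unset
    if rank == 0:
        print(f"world_size={world_size}", flush=True)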
- # Once init has been done, args.world_size is updated to value of torch get_world_size() - if torch.distributed.is_initialized(): - return torch.distributed.get_world_size() - else: - return int(os.getenv("WORLD_SIZE", "1")) - - -def get_last_rank() -> int: - """Get the last rank in the distributed group""" - if not torch.distributed.is_initialized(): - return 0 - return torch.distributed.get_world_size() - 1 - - -def get_local_rank_preinit() -> int: - """Get the local rank from the environment variable, intended for use before full init. - - Returns: - The local rank of the current process. - """ - return int(os.getenv("LOCAL_RANK", "0")) - - -def print_rank_0(message: str) -> None: - """Print a message only on global rank 0. - - Args: - message: The message string to print. - """ - rank = get_rank_safe() - if rank == 0: - print(message, flush=True) - - -def warn_rank_0(message): - """Warn only on rank 0.""" - rank = get_rank_safe() - if rank == 0: - warnings.warn(message) - - -def is_last_rank() -> bool: - """Check if the current rank is the last rank in the default process group. - - Returns: - True if the current rank is the last one, False otherwise. - """ - return torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1) - - -def print_rank_last(message: str) -> None: - """Print a message only on the last rank of the default process group. - - Args: - message: The message string to print. - """ - if torch.distributed.is_initialized(): - if is_last_rank(): - print(message, flush=True) - else: - print(message, flush=True) - - -def hook_hf_module_setattr_for_tp_grad_sync(module: torch.nn.Module) -> torch.nn.Module: - """Mark params for TP grad sync and hook __setattr__ on a module and its children. - - This ensures that all existing parameters under the provided module have the - attribute ``average_gradients_across_tp_domain=True`` and that any future - submodules assigned onto this module (or any of its current children) will - also have their parameters marked automatically. - - Args: - module: The root module (typically a Hugging Face module instance). - - Returns: - The same module instance for convenience. - """ - if module is None: - return module - - # Mark all existing parameters recursively - for param in module.parameters(recurse=True): - setattr(param, "average_gradients_across_tp_domain", True) - - def _wrap_setattr(original_setattr): - def _wrapped(self, name, value): - original_setattr(name, value) - if isinstance(value, torch.nn.Module): - for p in value.parameters(recurse=True): - setattr(p, "average_gradients_across_tp_domain", True) - - return _wrapped - - # Hook __setattr__ on the module and all existing submodules to catch - # future dynamic assignments anywhere in the hierarchy. - for submodule in module.modules(): - if getattr(submodule, "_tp_grad_sync_setattr_wrapped", False): - continue - original_setattr = submodule.__setattr__ - wrapped = _wrap_setattr(original_setattr) - submodule.__setattr__ = types.MethodType(wrapped, submodule) - setattr(submodule, "_tp_grad_sync_setattr_wrapped", True) - - return module diff --git a/flagscale/train/megatron/nemo_bridge/utils/decorators.py b/flagscale/train/megatron/nemo_bridge/utils/decorators.py deleted file mode 100644 index 437db3b4f6..0000000000 --- a/flagscale/train/megatron/nemo_bridge/utils/decorators.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. 
-# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import functools -import logging -import warnings - -from typing import Any, Callable, TypeVar - -logger = logging.getLogger(__name__) - -# Define a TypeVar for generic return types -R = TypeVar("R") - - -def experimental_fn(func: Callable[..., R]) -> Callable[..., R]: - """Decorator to mark a function as experimental and issue a warning upon its call.""" - warning_message = f"Function '{func.__name__}' is experimental. APIs in this module are subject to change without notice." - - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> R: - warnings.warn(warning_message, stacklevel=2) - return func(*args, **kwargs) - - return wrapper diff --git a/flagscale/train/megatron/nemo_bridge/utils/fusions.py b/flagscale/train/megatron/nemo_bridge/utils/fusions.py deleted file mode 100644 index 1f7d6f52a6..0000000000 --- a/flagscale/train/megatron/nemo_bridge/utils/fusions.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -"""Fusion capability checks for Megatron models. - -This module provides functions to check if various fusion optimizations -can be enabled based on the current environment and dependencies. -""" - -import logging -import os - -from megatron.core.transformer.transformer_config import TransformerConfig - -logger = logging.getLogger(__name__) - -# Control whether to log warnings when fusions are disabled -# Set environment variable MEGATRON_SUPPRESS_FUSION_WARNINGS=1 to disable warnings -LOG_FUSION_DISABLE = os.environ.get("MEGATRON_SUPPRESS_FUSION_WARNINGS", "0") != "1" - - -def can_enable_apply_rope_fusion() -> bool: - """Check if RoPE (Rotary Position Embedding) fusion can be enabled. - - Returns: - bool: True if RoPE fusion is available and compatible. - """ - # Check for Transformer Engine availability - try: - import transformer_engine # noqa: F401 - - from megatron.core.utils import get_te_version, is_te_min_version - - if not is_te_min_version("2.2.0.dev0"): - if LOG_FUSION_DISABLE: - logger.warning( - "apply_rope_fusion requires Transformer Engine >= 2.2.0.dev0. " - f"Current version: {get_te_version()}. Fusion disabled." - ) - return False - except ImportError: - if LOG_FUSION_DISABLE: - logger.warning( - "apply_rope_fusion requires Transformer Engine but it is not installed. Fusion disabled." - ) - return False - - # Check for RoPE fusion kernel availability - try: - from megatron.core.models.common.embeddings.rope_utils import ( - fused_apply_rotary_pos_emb, - fused_apply_rotary_pos_emb_thd, - ) - - if fused_apply_rotary_pos_emb is None and fused_apply_rotary_pos_emb_thd is None: - if LOG_FUSION_DISABLE: - logger.warning( - "apply_rope_fusion kernels are not available in megatron.core. Fusion disabled." - ) - return False - return True - except ImportError: - if LOG_FUSION_DISABLE: - logger.warning( - "apply_rope_fusion requires RoPE fusion kernels from megatron.core but they are not available. " - "Fusion disabled." - ) - return False - - -def can_enable_gradient_accumulation_fusion() -> bool: - """Check if gradient accumulation fusion can be enabled. - - Returns: - bool: True if gradient accumulation fusion is available. - """ - try: - import fused_weight_gradient_mlp_cuda # noqa: F401 - - return True - except ImportError: - if LOG_FUSION_DISABLE: - logger.warning( - "gradient_accumulation_fusion requires FusedLayerNorm from megatron.core.fusions " - "but it is not available. 
Fusion disabled." - ) - return False - - -def can_enable_bias_dropout_fusion() -> bool: - """Check if bias dropout fusion can be enabled. - - Returns: - bool: True if bias dropout fusion is available. - """ - try: - from megatron.core.fusions.fused_bias_dropout import ( # noqa: F401 - bias_dropout_add_fused_train, - ) - - return True - except ImportError: - if LOG_FUSION_DISABLE: - logger.warning( - "bias_dropout_fusion requires fused_bias_dropout from megatron.core.fusions " - "but it is not available. Fusion disabled." - ) - return False - - -def can_enable_masked_softmax_fusion() -> bool: - """Check if masked softmax fusion can be enabled. - - Returns: - bool: True if masked softmax fusion kernels are available. - """ - try: - # Try to import the CUDA kernels that are required for masked softmax fusion - import scaled_masked_softmax_cuda # noqa: F401 - import scaled_upper_triang_masked_softmax_cuda # noqa: F401 - - return True - except ImportError: - if LOG_FUSION_DISABLE: - logger.warning( - "masked_softmax_fusion requires CUDA kernels (scaled_masked_softmax_cuda) " - "but they are not available. This typically happens when Megatron-Core is not " - "built with CUDA extensions. Fusion disabled." - ) - return False - - -def validate_rope_fusion_compatibility(config: TransformerConfig) -> bool: - """Validate if RoPE fusion is compatible with the current model configuration. - - Args: - model_provider: The GPTModelProvider instance to validate. - - Returns: - bool: True if RoPE fusion is compatible, False otherwise. - """ - if not config.apply_rope_fusion: - return True - - # Check for multi_latent_attention incompatibility - if getattr(config, "multi_latent_attention", False): - if LOG_FUSION_DISABLE: - logger.warning( - "apply_rope_fusion for multi-latent attention only supports training. " - "It is experimental and may change in future versions." - ) - return True - - # Check TE version for rotary_interleaved - if getattr(config, "rotary_interleaved", False): - try: - from megatron.core.utils import get_te_version, is_te_min_version - - if not is_te_min_version("2.2.0.dev0"): - if LOG_FUSION_DISABLE: - logger.warning( - "apply_rope_fusion with rotary_interleaved requires TE >= 2.2.0.dev0. " - f"Current TE version: {get_te_version()}. Consider disabling apply_rope_fusion." - ) - return False - except ImportError: - if LOG_FUSION_DISABLE: - logger.warning( - "apply_rope_fusion with rotary_interleaved requires Transformer Engine. " - "Consider disabling apply_rope_fusion." - ) - return False - - return True diff --git a/flagscale/train/megatron/nemo_bridge/utils/import_utils.py b/flagscale/train/megatron/nemo_bridge/utils/import_utils.py deleted file mode 100644 index 33d1dd4edf..0000000000 --- a/flagscale/train/megatron/nemo_bridge/utils/import_utils.py +++ /dev/null @@ -1,409 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. 
-# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import importlib -import logging -import traceback - -from contextlib import contextmanager -from typing import Tuple - -import torch - -from packaging.version import Version as PkgVersion - -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) -logger.addHandler(logging.StreamHandler()) - -GPU_INSTALL_STRING = ( - """Install GPU packages via `pip install --extra-index-url """ - """https://pypi.nvidia.com nemo-curator[cuda12x]` -or use `pip install --extra-index-url https://pypi.nvidia.com ".[cuda12x]"` if installing from source""" -) -MISSING_NEMO_EXPORT_DEPLOY_MSG = ( - "nemo-export-deploy is not available. Please install it with `pip install nemo-export-deploy`." -) -MISSING_NVRX_MSG = "nvidia-resiliency-ext is not available. Please install it with `pip install nvidia-resiliency-ext`." -MISSING_NEMO_RUN_MSG = "nemo-run is not available. Please install it with `pip install nemo-run`." - - -class UnavailableError(Exception): - """Error thrown if a symbol is unavailable due to an issue importing it""" - - -@contextmanager -def null_decorator(*args, **kwargs): - """null_decorator""" - if len(kwargs) == 0 and len(args) == 1 and callable(args[0]): - return args[0] - else: - - def inner(func): - return func - - return inner - - -class UnavailableMeta(type): - """A metaclass for generating placeholder objects for unavailable symbols - - This metaclass allows errors to be deferred from import time to the time - that a symbol is actually used in order to streamline the usage of optional - dependencies. This is particularly useful for attempted imports of GPU-only - modules which will only be invoked if GPU-only functionality is - specifically used. - - If an attempt to import a symbol fails, this metaclass is used to generate - a class which stands in for that symbol. Any attempt to call the symbol - (instantiate the class) or access its attributes will throw an - UnavailableError exception. Furthermore, this class can be used in - e.g. isinstance checks, since it will (correctly) fail to match any - instance it is compared against. - - In addition to calls and attribute access, a number of dunder methods are - implemented so that other common usages of imported symbols (e.g. - arithmetic) throw an UnavailableError, but this is not guaranteed for - all possible uses. In such cases, other exception types (typically - TypeErrors) will be thrown instead. 
- """ - - def __new__(meta, name, bases, dct): - if dct.get("_msg", None) is None: - dct["_msg"] = f"{name} could not be imported" - name = f"MISSING{name}" - return super(UnavailableMeta, meta).__new__(meta, name, bases, dct) - - def __call__(cls, *args, **kwargs): - raise UnavailableError(cls._msg) - - def __getattr__(cls, name): - # Special handling for unittest.mock which tries to access __func_ - # and other attributes during its operations - if name in ("__func__", "__wrapped__", "__name__", "__qualname__"): - raise AttributeError(f"'{cls.__name__}' has no attribute '{name}'") - raise UnavailableError(cls._msg) - - def __eq__(cls, other): - raise UnavailableError(cls._msg) - - def __lt__(cls, other): - raise UnavailableError(cls._msg) - - def __gt__(cls, other): - raise UnavailableError(cls._msg) - - def __le__(cls, other): - raise UnavailableError(cls._msg) - - def __ge__(cls, other): - raise UnavailableError(cls._msg) - - def __ne__(cls, other): - raise UnavailableError(cls._msg) - - def __abs__(cls): - raise UnavailableError(cls._msg) - - def __add__(cls, other): - raise UnavailableError(cls._msg) - - def __radd__(cls, other): - raise UnavailableError(cls._msg) - - def __iadd__(cls, other): - raise UnavailableError(cls._msg) - - def __floordiv__(cls, other): - raise UnavailableError(cls._msg) - - def __rfloordiv__(cls, other): - raise UnavailableError(cls._msg) - - def __ifloordiv__(cls, other): - raise UnavailableError(cls._msg) - - def __lshift__(cls, other): - raise UnavailableError(cls._msg) - - def __rlshift__(cls, other): - raise UnavailableError(cls._msg) - - def __mul__(cls, other): - raise UnavailableError(cls._msg) - - def __rmul__(cls, other): - raise UnavailableError(cls._msg) - - def __imul__(cls, other): - raise UnavailableError(cls._msg) - - def __ilshift__(cls, other): - raise UnavailableError(cls._msg) - - def __pow__(cls, other): - raise UnavailableError(cls._msg) - - def __rpow__(cls, other): - raise UnavailableError(cls._msg) - - def __ipow__(cls, other): - raise UnavailableError(cls._msg) - - def __rshift__(cls, other): - raise UnavailableError(cls._msg) - - def __rrshift__(cls, other): - raise UnavailableError(cls._msg) - - def __irshift__(cls, other): - raise UnavailableError(cls._msg) - - def __sub__(cls, other): - raise UnavailableError(cls._msg) - - def __rsub__(cls, other): - raise UnavailableError(cls._msg) - - def __isub__(cls, other): - raise UnavailableError(cls._msg) - - def __truediv__(cls, other): - raise UnavailableError(cls._msg) - - def __rtruediv__(cls, other): - raise UnavailableError(cls._msg) - - def __itruediv__(cls, other): - raise UnavailableError(cls._msg) - - def __divmod__(cls, other): - raise UnavailableError(cls._msg) - - def __rdivmod__(cls, other): - raise UnavailableError(cls._msg) - - def __neg__(cls): - raise UnavailableError(cls._msg) - - def __invert__(cls): - raise UnavailableError(cls._msg) - - def __hash__(cls): - raise UnavailableError(cls._msg) - - def __index__(cls): - raise UnavailableError(cls._msg) - - def __iter__(cls): - raise UnavailableError(cls._msg) - - def __delitem__(cls, name): - raise UnavailableError(cls._msg) - - def __setitem__(cls, name, value): - raise UnavailableError(cls._msg) - - def __enter__(cls, *args, **kwargs): - raise UnavailableError(cls._msg) - - def __get__(cls, *args, **kwargs): - raise UnavailableError(cls._msg) - - def __delete__(cls, *args, **kwargs): - raise UnavailableError(cls._msg) - - def __len__(cls): - raise UnavailableError(cls._msg) - - -def is_unavailable(obj): - """Helper to 
check if given symbol is actually a placeholder""" - return type(obj) is UnavailableMeta - - -class UnavailableNullContext: - """A placeholder class for unavailable context managers - - This context manager will return a value which will throw an - UnavailableError if used in any way, but the context manager itself can be - safely invoked. - """ - - def __init__(self, *args, **kwargs): - pass - - def __enter__(self): - return UnavailableMeta( - "MissingContextValue", - (), - {"_msg": "Attempted to make use of placeholder context return value."}, - ) - - def __exit__(self, *args, **kwargs): - pass - - -def safe_import(module, *, msg=None, alt=None) -> Tuple[object, bool]: - """A function used to import modules that may not be available. - - This function will attempt to import a module with the given name, but it - will not throw an ImportError if the module is not found. Instead, it will - return a placeholder object which will raise an exception only if used. - - Args: - module (str): The name of the module to import. - msg (str, optional): An error message to be displayed if this module is used - after a failed import. Defaults to None. - alt (object, optional): A module to be used in place of the given module if it - fails to import. Defaults to None. - - Returns: - tuple: A tuple containing two elements. The first element is the imported module, - the given alternate, or a class derived from UnavailableMeta. The second element - is a boolean indicating whether the intended import was successful. - """ - try: - return importlib.import_module(module), True - except ImportError: - exception_text = traceback.format_exc() - logger.debug(f"Import of {module} failed with: {exception_text}") - except Exception: - exception_text = traceback.format_exc() - raise - if msg is None: - msg = f"{module} could not be imported" - if alt is None: - return UnavailableMeta(module.rsplit(".")[-1], (), {"_msg": msg}), False - else: - return alt, False - - -def safe_import_from( - module, symbol, *, msg=None, alt=None, fallback_module=None -) -> Tuple[object, bool]: - """A function used to import symbols from modules that may not be available. - - This function will attempt to import a symbol with the given name from - the given module, but it will not throw an ImportError if the symbol is not - found. Instead, it will return a placeholder object which will raise an - exception only if used. - - Args: - module (str): The name of the module in which the symbol is defined. - symbol (str): The name of the symbol to import. - msg (str, optional): An error message to be displayed if this symbol is used - after a failed import. Defaults to None. - alt (object, optional): An object to be used in place of the given symbol if it fails - to import. Defaults to None. - fallback_module (str, optional): Alternative name of the model in which the symbol is defined. - The function will first try to import using the `module` value and if that fails - will also try the `fallback_module`. Defaults to None. - - Returns: - tuple: A tuple containing two elements. The first element is the imported symbol, - the given alternate, or a class derived from UnavailableMeta. The second element - is a boolean indicating whether the intended import was successful. 
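For example, a minimal sketch of the deferred-failure pattern described above; the module name is hypothetical and assumed not to be installed:

    maybe_backend, found = safe_import(
        "some_optional_backend",
        msg="some_optional_backend is required for this optional feature",
    )
    if not found:
        assert is_unavailable(maybe_backend)  # placeholder class, not a module
    # The placeholder can be stored and passed around freely; calling it or
    # accessing its attributes raises UnavailableError with the message above.
    # maybe_backend.run()  # -> UnavailableError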
- """ - try: - imported_module = importlib.import_module(module) - return getattr(imported_module, symbol), True - except ImportError: - exception_text = traceback.format_exc() - logger.debug(f"Import of {module} failed with: {exception_text}") - except AttributeError: - # if there is a fallback module try it. - if fallback_module is not None: - return safe_import_from(fallback_module, symbol, msg=msg, alt=alt, fallback_module=None) - exception_text = traceback.format_exc() - logger.info(f"Import of {symbol} from {module} failed with: {exception_text}") - except Exception: - exception_text = traceback.format_exc() - raise - if msg is None: - msg = f"{module}.{symbol} could not be imported" - if alt is None: - return UnavailableMeta(symbol, (), {"_msg": msg}), False - else: - return alt, False - - -def gpu_only_import(module, *, alt=None) -> Tuple[object, bool]: - """A function used to import modules required only in GPU installs. - - This function will attempt to import a module with the given name. - This function will attempt to import a symbol with the given name from - the given module, but it will not throw an ImportError if the symbol is not - found. Instead, it will return a placeholder object which will raise an - exception only if used with instructions on installing a GPU build. - - Args: - module (str): The name of the module to import. - alt (object, optional): A module to be used in place of the given module if it - fails to import in a non-GPU-enabled install. Defaults to None. - - Returns: - tuple: A tuple containing two elements. The first element is the imported module, - the given alternate, or a class derived from UnavailableMeta. The second element - is a boolean indicating whether the intended import was successful. - """ - - return safe_import( - module, - msg=f"{module} is not enabled in non GPU-enabled installations or environemnts. {GPU_INSTALL_STRING}", - alt=alt, - ) - - -def gpu_only_import_from(module, symbol, *, alt=None) -> Tuple[object, bool]: - """A function used to import symbols required only in GPU installs. - - This function will attempt to import a module with the given name. - This function will attempt to import a symbol with the given name from - the given module, but it will not throw an ImportError if the symbol is not - found. Instead, it will return a placeholder object which will raise an - exception only if used with instructions on installing a GPU build. - - Args: - module (str): The name of the module to import. - symbol (str): The name of the symbol to import. - alt (object, optional): An object to be used in place of the given symbol if it fails - to import in a non-GPU-enabled install. Defaults to None. - - Returns: - tuple: A tuple containing two elements. The first element is the imported symbol, - the given alternate, or a class derived from UnavailableMeta. The second element - is a boolean indicating whether the intended import was successful. - """ - return safe_import_from( - module, - symbol, - msg=f"{module}.{symbol} is not enabled in non GPU-enabled installations or environments. {GPU_INSTALL_STRING}", - alt=alt, - ) - - -def get_torch_version(): - """Returns the installed PyTorch version as a packaging.version.Version object. - - Handles potential exceptions during version parsing, returning a dummy version - ("0.0.0") if parsing fails (e.g., during documentation builds where torch - might not be fully imported or available). - - Returns: - packaging.version.Version: The parsed PyTorch version, or Version("0.0.0") on error. 
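For instance, a small sketch using the is_torch_min_version helper defined just below:

    if is_torch_min_version("2.1.0"):
        ...  # safe to rely on features introduced in torch 2.1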
- """ - try: - _torch_version = PkgVersion(torch.__version__) - except Exception: - # This is a WAR for building docs, where torch is not actually imported - _torch_version = PkgVersion("0.0.0") - return _torch_version - - -def is_torch_min_version(version, check_equality=True): - """Check if minimum version of `torch` is installed.""" - if check_equality: - return get_torch_version() >= PkgVersion(version) - return get_torch_version() > PkgVersion(version) diff --git a/flagscale/train/megatron/nemo_bridge/utils/instantiate_utils.py b/flagscale/train/megatron/nemo_bridge/utils/instantiate_utils.py deleted file mode 100644 index 2bbf9a7eb3..0000000000 --- a/flagscale/train/megatron/nemo_bridge/utils/instantiate_utils.py +++ /dev/null @@ -1,418 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import copy -import functools -import logging - -from enum import Enum -from textwrap import dedent -from typing import Any, Callable, Sequence, Union - -from omegaconf import OmegaConf -from omegaconf._utils import is_structured_config - - -class InstantiationException(Exception): - """Custom exception type for instantiation errors.""" - - ... - - -class InstantiationMode(Enum): - """Enum for instantiation modes.""" - - STRICT = "strict" - LENIENT = "lenient" - - -class _Keys(str, Enum): - """Special keys in configs used by instantiate.""" - - TARGET = "_target_" - PARTIAL = "_partial_" - CALL = "_call_" - ARGS = "_args_" - - -def instantiate( - config: Any, *args: Any, mode: InstantiationMode = InstantiationMode.LENIENT, **kwargs: Any -) -> Any: - """Instantiate an object or callable from a config object. - - This function takes a configuration object (dictionary, list, OmegaConf config, - or Structured Config instance) and instantiates the target specified within it. - - The config object must contain: - _target_ (str): The fully qualified name of the class or callable to instantiate. - - The config object may also contain: - _args_ (list): Positional arguments for the target. - _partial_ (bool): If True, return a functools.partial object instead of calling - the target. Defaults to False. - _call_ (bool): If False, simply resolves and returns the target without calling it. - Defaults to True. - Additional keyword arguments to pass to the target. - - Args: - config: The configuration object describing the target and its parameters. - *args: Optional positional arguments that will override _args_ in the config - if provided. - mode: Instantiation mode (STRICT or LENIENT). In LENIENT mode (default), - errors during instantiation of parameters are logged as warnings, - and None is used instead. In STRICT mode, errors are raised. - **kwargs: Optional keyword arguments that will override parameters in the config. - Note: Dataclass instances in kwargs are treated as nested configs. - - Returns: - The instantiated object or the return value of the callable. - If config._partial_ is True, returns a functools.partial object. - If config._call_ is False, returns the resolved target callable/class itself. - Returns None if the input config is None. - - Raises: - InstantiationException: If the config is invalid, the target cannot be resolved, - or instantiation fails in STRICT mode. - TypeError: If the _partial_ flag is not a boolean. 
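As a minimal sketch of the config format described above, using a stdlib target so the example stays self-contained:

    config = {"_target_": "collections.OrderedDict", "a": 1, "b": 2}
    od = instantiate(config)  # an OrderedDict with keys 'a' and 'b'

    # With _partial_, the target is not called; a functools.partial is returned
    # so remaining arguments can be supplied later.
    factory = instantiate({**config, "_partial_": True})
    od_plus = factory(c=3)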
- """ - - # Return None if config is None - if config is None: - return None - - if isinstance(config, (dict, list)): - config = _prepare_input_dict_or_list(config) - - kwargs = _prepare_input_dict_or_list(kwargs) - - # Structured Config always converted first to OmegaConf - if is_structured_config(config) or isinstance(config, (dict, list)): - config = OmegaConf.structured(config, flags={"allow_objects": True}) - - if OmegaConf.is_dict(config): - # Finalize config (convert targets to strings, merge with kwargs) - config_copy = copy.deepcopy(config) - config_copy._set_flag( - flags=["allow_objects", "struct", "readonly"], values=[True, False, False] - ) - config_copy._set_parent(config._get_parent()) - config = config_copy - - if kwargs: - config = OmegaConf.merge(config, kwargs) - - OmegaConf.resolve(config) - - _partial_ = config.pop(_Keys.PARTIAL, False) - - return instantiate_node(config, *args, partial=_partial_, mode=mode) - elif OmegaConf.is_list(config): - # Finalize config (convert targets to strings, merge with kwargs) - config_copy = copy.deepcopy(config) - config_copy._set_flag( - flags=["allow_objects", "struct", "readonly"], values=[True, False, False] - ) - config_copy._set_parent(config._get_parent()) - config = config_copy - - OmegaConf.resolve(config) - - _partial_ = kwargs.pop(_Keys.PARTIAL, False) - - if _partial_: - raise InstantiationException( - "The _partial_ keyword is not compatible with top-level list instantiation" - ) - - return instantiate_node(config, *args, partial=_partial_, mode=mode) - else: - raise InstantiationException( - dedent( - f"""\ - Cannot instantiate config of type {type(config).__name__}. - Top level config must be an OmegaConf DictConfig/ListConfig object, - a plain dict/list, or a Structured Config class or instance.""" - ) - ) - - -def instantiate_node( - node: Any, - *args: Any, - partial: bool = False, - mode: InstantiationMode = InstantiationMode.LENIENT, -) -> Any: - """Recursively instantiates a node within a configuration structure. - - This function handles the instantiation of individual nodes (dictionaries, - lists, or primitive values) within a larger configuration tree, typically - managed by OmegaConf. - - If the node is a dictionary containing a `_target_` key, it resolves and - instantiates the target callable/class using the other items in the - dictionary as keyword arguments. Nested nodes are recursively instantiated. - - If the node is a list, it recursively instantiates each item in the list. - - If the node is not an OmegaConf config node (e.g., a primitive type), it's - returned directly. - - Args: - node: The configuration node to instantiate (can be DictConfig, ListConfig, - or a primitive type). - *args: Positional arguments passed down from the top-level `instantiate` call, - used primarily for the final target call if the node is a dictionary - with `_target_`. - partial: Boolean flag indicating whether to return a `functools.partial` object - instead of calling the target. This can be overridden by a - `_partial_` key within the node itself. - mode: Instantiation mode (STRICT or LENIENT). Determines error handling - behavior for nested instantiations. - - Returns: - The instantiated object, list, or the original node if it wasn't a config. - Returns None if the input node is None or represents a None value in OmegaConf. - - Raises: - InstantiationException: If instantiation fails in STRICT mode, or if there are - issues like incompatible arguments or non-callable targets. 
- TypeError: If a `_partial_` flag within the config is not a boolean. - """ - # Return None if config is None - if node is None or (OmegaConf.is_config(node) and node._is_none()): - return None - - if not OmegaConf.is_config(node): - return node - - # Override parent modes from config if specified - if OmegaConf.is_dict(node): - # using getitem instead of get(key, default) because OmegaConf will raise an exception - # if the key type is incompatible on get. - partial = node[_Keys.PARTIAL] if _Keys.PARTIAL in node else partial - - full_key = node._get_full_key(None) - - if not isinstance(partial, bool): - msg = f"Instantiation: _partial_ flag must be a bool, got {type(partial)}" - if node and full_key: - msg += f"\nfull_key: {full_key}" - raise TypeError(msg) - - if OmegaConf.is_list(node): - items = [instantiate_node(item, mode=mode) for item in node._iter_ex(resolve=True)] - - return items - elif OmegaConf.is_dict(node): - exclude_keys = set(item.value for item in _Keys if item != _Keys.ARGS) - if _is_target(node): - should_call_target = node.get("_call_", True) - _target_ = _resolve_target( - node.get(_Keys.TARGET), full_key, check_callable=should_call_target - ) - kwargs = {} - is_partial = node.get("_partial_", False) or partial - - if not should_call_target: - if len(set(node.keys()) - {"_target_", "_call_"}) != 0: - extra_keys = set(node.keys()) - {"_target_", "_call_"} - raise InstantiationException( - f"_call_ was set to False for target {_convert_target_to_string(_target_)}," - f" but extra keys were found: {extra_keys}" - ) - else: - return _target_ - - for key in node.keys(): - if key not in exclude_keys: - if OmegaConf.is_missing(node, key) and is_partial: - continue - value = node[key] - try: - value = instantiate_node(value, mode=mode) - except (ImportError, InstantiationException) as e: - if mode == InstantiationMode.STRICT: - raise InstantiationException( - f"Error instantiating {value} for key {full_key}.{key}: {e}" - ) from e - else: - value = None - logging.warning( - f"Error instantiating {value} for key {full_key}.{key}. " - f"Using None instead in lenient mode." - ) - kwargs[key] = _convert_node(value) - - assert callable(_target_) - return _call_target(_target_, partial, args, kwargs, full_key) - else: - dict_items = {} - for key, value in node.items(): - dict_items[key] = instantiate_node(value, mode=mode) - return dict_items - - else: - assert False, f"Unexpected config type : {type(node).__name__}" - - -def _locate(path: str) -> Any: - """ - Locate an object by name or dotted path, importing as necessary. - This function attempts to import modules starting from the most specific path - (back to front), making it possible to import objects where the final component - could be either a module or an attribute of the previous module. - """ - if path == "": - raise ImportError("Empty path") - from importlib import import_module - - parts = [part for part in path.split(".")] - for part in parts: - if not len(part): - raise ValueError( - f"Error loading '{path}': invalid dotstring." - + "\nRelative imports are not supported." 
- ) - assert len(parts) > 0 - - # Try importing from the most specific path first (back to front) - for i in range(len(parts), 0, -1): - module_path = ".".join(parts[:i]) - try: - obj = import_module(module_path) - - # If this isn't the full path, get the remaining attributes - remaining_parts = parts[i:] - for part in remaining_parts: - try: - obj = getattr(obj, part) - except AttributeError as exc_attr: - raise ImportError( - f"Error loading '{path}':\n{repr(exc_attr)}" - + f"\nAre you sure that '{part}' is an attribute of '{module_path}'?" - ) from exc_attr - - # Successfully found the object - return obj - - except ModuleNotFoundError: - # Module not found, try a less specific path - continue - except Exception as exc_import: - # If we hit a different exception, it's likely an issue with the module itself - raise ImportError(f"Error loading '{path}':\n{repr(exc_import)}") from exc_import - - # If we've tried all paths and nothing worked, report failure with the base module - raise ImportError( - f"Error loading '{path}': Unable to import any module in the path. " - f"Are you sure that module '{parts[0]}' is installed?" - ) - - -def _is_target(x: Any) -> bool: - if isinstance(x, dict): - return "_target_" in x - if OmegaConf.is_dict(x): - return "_target_" in x - return False - - -def _call_target( - _target_: Callable[..., Any], - _partial_: bool, - args: tuple[Any, ...], - kwargs: dict[str, Any], - full_key: str, -) -> Any: - """Call target (type) with args and kwargs.""" - args, kwargs = _extract_pos_args(args, kwargs) - if _partial_: - try: - return functools.partial(_target_, *args, **kwargs) - except Exception as e: - msg = ( - f"Error in creating partial({_convert_target_to_string(_target_)}, ...) object:" - + f"\n{repr(e)}" - ) - if full_key: - msg += f"\nfull_key: {full_key}" - raise InstantiationException(msg) from e - else: - try: - return _target_(*args, **kwargs) - except Exception as e: - msg = f"Error in call to target '{_convert_target_to_string(_target_)}':\n{repr(e)}" - if full_key: - msg += f"\nfull_key: {full_key}" - raise InstantiationException(msg) from e - - -def _convert_target_to_string(t: Any) -> Any: - if callable(t): - return f"{t.__module__}.{t.__qualname__}" - else: - return t - - -def _prepare_input_dict_or_list(d: Union[dict[Any, Any], list[Any]]) -> Any: - res: Any - if isinstance(d, dict): - res = {} - for k, v in d.items(): - if k == "_target_": - v = _convert_target_to_string(d["_target_"]) - elif isinstance(v, (dict, list)): - v = _prepare_input_dict_or_list(v) - res[k] = v - elif isinstance(d, list): - res = [] - for v in d: - if isinstance(v, (list, dict)): - v = _prepare_input_dict_or_list(v) - res.append(v) - else: - assert False - return res - - -def _resolve_target( - target: Union[str, type, Callable[..., Any]], full_key: str, check_callable: bool = True -) -> Union[type, Callable[..., Any], object]: - """Resolve target string, type or callable into type or callable.""" - if isinstance(target, str): - try: - target = _locate(target) - except Exception as e: - msg = f"Error locating target '{target}'." 
- if full_key: - msg += f"\nfull_key: {full_key}" - raise InstantiationException(msg) from e - if check_callable and not callable(target): - msg = f"Expected a callable target, got '{target}' of type '{type(target).__name__}'" - if full_key: - msg += f"\nfull_key: {full_key}" - raise InstantiationException(msg) - return target - - -def _extract_pos_args(input_args: Any, kwargs: Any) -> tuple[Any, Any]: - config_args = kwargs.pop(_Keys.ARGS, ()) - output_args = config_args - - if isinstance(config_args, Sequence): - if len(input_args) > 0: - output_args = input_args - else: - raise InstantiationException( - f"Unsupported _args_ type: '{type(config_args).__name__}'. value: '{config_args}'" - ) - - return output_args, kwargs - - -def _convert_node(node: Any) -> Any: - if OmegaConf.is_config(node): - node = OmegaConf.to_container(node, resolve=True) - - return node diff --git a/flagscale/train/megatron/nemo_bridge/utils/path_utils.py b/flagscale/train/megatron/nemo_bridge/utils/path_utils.py deleted file mode 100644 index 0fe9c30ee8..0000000000 --- a/flagscale/train/megatron/nemo_bridge/utils/path_utils.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -from pathlib import Path - - -def resolve_path(path: str) -> Path: - """Resolve a path to an absolute path.""" - return Path(path).expanduser().absolute().resolve() diff --git a/flagscale/train/megatron/nemo_bridge/utils/vocab_utils.py b/flagscale/train/megatron/nemo_bridge/utils/vocab_utils.py deleted file mode 100644 index 85b68e1683..0000000000 --- a/flagscale/train/megatron/nemo_bridge/utils/vocab_utils.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import math - -from functools import lru_cache - -from megatron.nemo_bridge.utils.common_utils import print_rank_0 - - -def calculate_padded_vocab_size( - vocab_size: int, - make_vocab_size_divisible_by: int, - tensor_model_parallel_size: int, - logging_enabled: bool = True, -) -> int: - """Calculate padded vocab size for tensor parallelism. - - This function pads the vocabulary size to ensure it's divisible by the required - multiple for efficient tensor parallel operations. 
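As a worked sketch: with make_vocab_size_divisible_by=128 and tensor_model_parallel_size=8, the padding multiple is 128 * 8 = 1024, so the 151936-token Qwen vocabulary used above is rounded up to ceil(151936 / 1024) * 1024 = 152576, i.e. 640 dummy tokens are appended:

    padded = calculate_padded_vocab_size(
        vocab_size=151936,
        make_vocab_size_divisible_by=128,
        tensor_model_parallel_size=8,
        logging_enabled=False,
    )
    assert padded == 152576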
- - Args: - vocab_size: The original (unpadded) vocabulary size - make_vocab_size_divisible_by: Base divisibility requirement (e.g., 128) - tensor_model_parallel_size: Number of tensor parallel ranks - logging_enabled: Whether to log the padding information - - Returns: - int: The padded vocabulary size - """ - padded_size = _calculate_padded_vocab_size_cached( - vocab_size, make_vocab_size_divisible_by, tensor_model_parallel_size - ) - - # Handle logging separately to avoid affecting cache behavior - if logging_enabled: - print_rank_0( - " > padded vocab (size: {}) with {} dummy tokens (new size: {})".format( - vocab_size, padded_size - vocab_size, padded_size - ) - ) - - return padded_size - - -@lru_cache(maxsize=128) -def _calculate_padded_vocab_size_cached( - vocab_size: int, make_vocab_size_divisible_by: int, tensor_model_parallel_size: int -) -> int: - """Cached computation of padded vocab size.""" - if vocab_size <= 0: - raise ValueError(f"vocab_size must be positive, got {vocab_size}") - if make_vocab_size_divisible_by <= 0: - raise ValueError( - f"make_vocab_size_divisible_by must be positive, got {make_vocab_size_divisible_by}" - ) - if tensor_model_parallel_size <= 0: - raise ValueError( - f"tensor_model_parallel_size must be positive, got {tensor_model_parallel_size}" - ) - - multiple = make_vocab_size_divisible_by * tensor_model_parallel_size - return int(math.ceil(vocab_size / multiple) * multiple) diff --git a/flagscale/train/megatron/nemo_bridge/utils/yaml_utils.py b/flagscale/train/megatron/nemo_bridge/utils/yaml_utils.py deleted file mode 100644 index f38553d6a7..0000000000 --- a/flagscale/train/megatron/nemo_bridge/utils/yaml_utils.py +++ /dev/null @@ -1,203 +0,0 @@ -# Copyright (c) 2025, BAAI. All rights reserved. -# -# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge - -import enum -import functools -import inspect - -from contextlib import contextmanager -from typing import Any, Generator, Optional - -import yaml - - -@contextmanager -def safe_yaml_representers() -> Generator[None, None, None]: - """ - Context manager for safely adding and removing custom YAML representers. - - Temporarily adds custom representers for functions, classes, and other objects - to the YAML SafeDumper, and restores the original representers when exiting - the context. 
- - Usage: - with safe_yaml_representers(): - yaml_str = yaml.safe_dump(my_complex_object) - """ - # Save original representers - original_representers = yaml.SafeDumper.yaml_representers.copy() - original_multi_representers = yaml.SafeDumper.yaml_multi_representers.copy() - - try: - # Register custom representers - - # Partial representer - yaml.SafeDumper.add_representer(functools.partial, _partial_representer) - - # Enum representer - yaml.SafeDumper.add_multi_representer(enum.Enum, _enum_representer) - - # Function representer - yaml.SafeDumper.add_representer(type(lambda: ...), _function_representer) - yaml.SafeDumper.add_representer(type(object), _function_representer) - - # Try to add torch dtype representer if available - try: - import torch - - yaml.SafeDumper.add_representer(torch.dtype, _torch_dtype_representer) - except ModuleNotFoundError: - pass - - # Try to add GenerationConfig representer if available - try: - from transformers import GenerationConfig - - yaml.SafeDumper.add_representer(GenerationConfig, _generation_config_representer) - except ModuleNotFoundError: - pass - - # Try to add PretrainedConfig representer if available (generic for HF configs) - try: - from transformers import PretrainedConfig - - # Use multi-representer so subclasses of PretrainedConfig are also handled - yaml.SafeDumper.add_multi_representer(PretrainedConfig, _pretrained_config_representer) - except ModuleNotFoundError: - pass - - # General object representer - yaml.SafeDumper.add_multi_representer(object, _safe_object_representer) - - yield - finally: - # Restore original representers - yaml.SafeDumper.yaml_representers = original_representers - yaml.SafeDumper.yaml_multi_representers = original_multi_representers - - -def dump_dataclass_to_yaml(obj: Any, filename: Optional[str] = None) -> Optional[str]: - """Dump a dataclass object or other Python object to a YAML file or string. - - Uses safe representers to handle common types. - - Args: - obj: The object to dump. - filename: If provided, the path to the file where YAML should be written. - If None, returns the YAML string directly. - - Returns: - If filename is None, returns the YAML string representation of the object. - Otherwise, returns None. - """ - with safe_yaml_representers(): - if filename is not None: - with open(filename, "w+") as f: - yaml.safe_dump(obj, f) - else: - return yaml.safe_dump(obj) - - -def _function_representer(dumper, data): - """Represent functions in YAML.""" - value = { - "_target_": f"{inspect.getmodule(data).__name__}.{data.__qualname__}", # type: ignore - "_call_": False, - } - return dumper.represent_data(value) - - -def _torch_dtype_representer(dumper, data): - """Represent torch dtypes in YAML.""" - value = {"_target_": str(data), "_call_": False} - return dumper.represent_data(value) - - -def _safe_object_representer(dumper, data): - """ - General object representer for YAML. - - This function is a fallback for objects that don't have specific representers. - If the object has __qualname__ attr, - the _target_ is set to f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}". - If the object does not have a __qualname__ attr, the _target_ is set from its __class__ attr. - The _call_ key is used to indicate whether the target should be called to create an instance. - - Args: - dumper (yaml.Dumper): The YAML dumper to use for serialization. - data (Any): The data to serialize. - - Returns: - The YAML representation of the data. 
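As a minimal sketch of how these representers compose with dump_dataclass_to_yaml above (the partial below is purely illustrative):

    import functools

    spec = {"activation_func": functools.partial(max, 0)}
    text = dump_dataclass_to_yaml(spec)
    # activation_func is serialized as a mapping along the lines of
    #   {_target_: builtins.max, _partial_: true, _args_: [0]}
    # which instantiate() from instantiate_utils can resolve back to a partial.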
- """ - try: - obj = data - target = f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}" - call = False - except AttributeError: - obj = data.__class__ - target = f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}" - call = True - - value = {"_target_": target, "_call_": call} # type: ignore - return dumper.represent_data(value) - - -def _partial_representer(dumper, data): - """Represent functools.partial objects in YAML.""" - # Get the underlying function - func = data.func - - # Create a dictionary representation - value = { - "_target_": f"{inspect.getmodule(func).__name__}.{func.__qualname__}", - "_partial_": True, - "_args_": list(data.args) if data.args else [], - } - - # Add keyword arguments if any exist - if data.keywords: - for k, v in data.keywords.items(): - value[k] = v - - return dumper.represent_data(value) - - -def _enum_representer(dumper, data): - """Represent enums in YAML.""" - # Create a dictionary representation - enum_class = data.__class__ - value = { - "_target_": f"{inspect.getmodule(enum_class).__name__}.{enum_class.__qualname__}", - "_call_": True, - "_args_": [data.value], - } - - return dumper.represent_data(value) - - -def _generation_config_representer(dumper, data): - """Represent transformers GenerationConfig objects in YAML.""" - cls = data.__class__ - value = { - "_target_": f"{inspect.getmodule(cls).__name__}.{cls.__qualname__}.from_dict", - "_call_": True, - "config_dict": data.to_dict(), - } - - return dumper.represent_data(value) - - -def _pretrained_config_representer(dumper, data): - """Represent transformers PretrainedConfig objects in YAML generically. - - Uses the class's from_dict/to_dict methods to ensure full round-trip of all fields. - """ - cls = data.__class__ - value = { - "_target_": f"{inspect.getmodule(cls).__name__}.{cls.__qualname__}.from_dict", - "_call_": True, - "config_dict": data.to_dict(), - } - return dumper.represent_data(value) diff --git a/flagscale/train/megatron/training/checkpointing.py b/flagscale/train/megatron/training/checkpointing.py index c14032dfcc..b76ba3f510 100644 --- a/flagscale/train/megatron/training/checkpointing.py +++ b/flagscale/train/megatron/training/checkpointing.py @@ -479,7 +479,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati if iteration % args.save_hf_interval == 0 or iteration == args.train_iters: #use megatron bridge from megatron.nemo_bridge.models import AutoBridge - from megatron.nemo_bridge.models.hf_pretrained.safe_config_loader import safe_load_config_with_retry + from megatron.bridge.models.hf_pretrained.safe_config_loader import safe_load_config_with_retry from transformers import AutoConfig #Load the HF model from config config_load = args.hf_config_path diff --git a/flagscale/train/megatron/training/yaml_arguments.py b/flagscale/train/megatron/training/yaml_arguments.py index c8ad21e255..417be3758c 100644 --- a/flagscale/train/megatron/training/yaml_arguments.py +++ b/flagscale/train/megatron/training/yaml_arguments.py @@ -410,8 +410,6 @@ def core_transformer_config_from_yaml(args, transfomer_key = "language_model"): kw_args['deallocate_pipeline_outputs'] = True kw_args['pipeline_dtype'] = kw_args['params_dtype'] kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm - - kw_args['untie_embeddings_and_output_weights'] = args.untie_embeddings_and_output_weights assert args.activation_func in ["swiglu","squaredrelu","gelu"], f"{args.activation_func} is not a supported activation function" if args.activation_func == "swiglu": From 
2900956eadf2a45579388705ee69af1b51807bc0 Mon Sep 17 00:00:00 2001 From: chai-xiaonan <3072824838@qq.com> Date: Wed, 21 Jan 2026 09:11:22 +0800 Subject: [PATCH 3/3] delete readme and swp file --- .../megatron/nemo_bridge/.requirements.txt.swp | Bin 12288 -> 0 bytes flagscale/train/megatron/nemo_bridge/README.md | 9 --------- 2 files changed, 9 deletions(-) delete mode 100644 flagscale/train/megatron/nemo_bridge/.requirements.txt.swp delete mode 100644 flagscale/train/megatron/nemo_bridge/README.md diff --git a/flagscale/train/megatron/nemo_bridge/.requirements.txt.swp b/flagscale/train/megatron/nemo_bridge/.requirements.txt.swp deleted file mode 100644 index ba527123a01890cc58a79cd9f2a4743f4f9a8587..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001