diff --git a/flagscale/train/megatron/nemo_bridge/__init__.py b/flagscale/train/megatron/nemo_bridge/__init__.py new file mode 100644 index 000000000..713df8c97 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025, BAAI. All rights reserved. + +"""Megatron Bridge - A component of the Megatron ecosystem.""" + +from megatron.nemo_bridge.models.conversion.auto_bridge import AutoBridge + +__all__ = ["AutoBridge"] diff --git a/flagscale/train/megatron/nemo_bridge/models/__init__.py b/flagscale/train/megatron/nemo_bridge/models/__init__.py new file mode 100644 index 000000000..3d2aa52d7 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2025, BAAI. All rights reserved. + +from megatron.nemo_bridge.models.conversion.auto_bridge import AutoBridge +from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.nemo_bridge.models.conversion.param_mapping import ( + AutoMapping, + QKVMapping, +) +from megatron.nemo_bridge.models.deepseek.deepseek_v3_bridge import DeepSeekV3Bridge +from megatron.nemo_bridge.models.qwen.qwen3_bridge import Qwen3Bridge +from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM + +__all__ = [ + "AutoBridge", + "MegatronModelBridge", + "QKVMapping", + "AutoMapping", + "DeepSeekV3Bridge", + "Qwen3Bridge", + "PreTrainedCausalLM", +] diff --git a/flagscale/train/megatron/nemo_bridge/models/conversion/__init__.py b/flagscale/train/megatron/nemo_bridge/models/conversion/__init__.py new file mode 100644 index 000000000..b8a567215 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/conversion/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025, BAAI. All rights reserved. + +from megatron.nemo_bridge.models.conversion.auto_bridge import AutoBridge + +__all__ = [ + "AutoBridge", +] diff --git a/flagscale/train/megatron/nemo_bridge/models/conversion/auto_bridge.py b/flagscale/train/megatron/nemo_bridge/models/conversion/auto_bridge.py new file mode 100644 index 000000000..8112c5895 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/conversion/auto_bridge.py @@ -0,0 +1,202 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
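+"""FlagScale wrapper around the upstream ``megatron.bridge.AutoBridge``. + +It adds construction from a bare HF config (``from_hf_config``) and the +``load_hf_weights``/``save_hf_weights`` helpers used by FlagScale checkpointing. + +Example (an illustrative sketch only; the checkpoint paths and the ``model_chunks`` +list of Megatron model chunks are placeholders, not values defined in this file):: + + from megatron.nemo_bridge.models import AutoBridge + + bridge = AutoBridge.from_hf_pretrained("/path/to/hf_checkpoint") + bridge.load_hf_weights(model_chunks) # HF -> Megatron + bridge.save_hf_weights(model_chunks, "/path/to/out") # Megatron -> HF +"""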
+ +from megatron.bridge import AutoBridge as OriginalAutoBridge +import transformers +import torch.distributed as dist +from transformers import AutoModelForCausalLM +from transformers.configuration_utils import PretrainedConfig + +from megatron.core.transformer.module import MegatronModule +from megatron.nemo_bridge.models.conversion import model_bridge +from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge + +from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from megatron.bridge.models.hf_pretrained.state import SafeTensorsStateSource +from megatron.bridge.models.hf_pretrained.safe_config_loader import safe_load_config_with_retry +from megatron.bridge.models.conversion.utils import get_causal_lm_class_via_auto_map + +from typing import TypeVar, Union +from pathlib import Path +MegatronModelT = TypeVar("MegatronModelT", bound=MegatronModule) + + +class AutoBridge(OriginalAutoBridge): + + def __init__(self, hf_pretrained: PreTrainedCausalLM | PretrainedConfig): + if not isinstance(hf_pretrained, (PreTrainedCausalLM, PretrainedConfig)): + raise ValueError("hf_pretrained must be a PreTrainedCausalLM or PretrainedConfig instance") + self.hf_pretrained: PreTrainedCausalLM | PretrainedConfig = hf_pretrained + super().__init__(hf_pretrained) + + @classmethod + def from_hf_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoBridge": + """ + Load an AutoBridge from a pretrained model, automatically detecting the model type. + """ + # First load just the config to check architecture support + # Use thread-safe config loading to prevent race conditions + config = safe_load_config_with_retry(path, trust_remote_code=kwargs.get("trust_remote_code", False)) + + cls._validate_config(config, str(path)) + + try: + return cls(PreTrainedCausalLM.from_pretrained(path, **kwargs)) + except Exception as e: + raise ValueError(f"Failed to load model with AutoBridge: {e}") from e + + @classmethod + def _validate_config(cls, config: PretrainedConfig, path: str | None = None) -> None: + # Check if this is a causal LM model + if not cls.supports(config): + architectures = getattr(config, "architectures", []) + raise ValueError( + f"\n✗ Model architecture not supported by AutoBridge\n\n" + f"Model: {path}\n" + f"Architectures: {architectures}\n\n" + f"AutoBridge only supports models with architectures ending in 'ForCausalLM' or " + f"'ForConditionalGeneration' or 'NemotronH_Nano_VL_V2'.\n" + f"Found architectures that don't match this pattern.\n\n" + f"If this is a different model type (e.g., Vision, Sequence-to-Sequence),\n" + f"you may need to use a different bridge class."
+ ) + + # Check if we have an implementation for this specific architecture + architecture = None + for arch_name in config.architectures: + if arch_name.endswith( + ("ForCausalLM", "ForConditionalGeneration", "NemotronH_Nano_VL_V2") + ): + architecture = arch_name + break + + if architecture: + # Try auto_map first + arch_class = ( + get_causal_lm_class_via_auto_map(model_name_or_path=path, config=config) + if path + else None + ) + if arch_class is not None: + # For auto_map models, use class-name string + arch_key = getattr(arch_class, "__name__", str(arch_class)) + else: + try: + arch_class = getattr(transformers, architecture) + arch_key = arch_class + except AttributeError: + # Fall back to name-based registration + arch_key = architecture + + # Test if we have a registered implementation (type or class-name string) + has_implementation = False + if hasattr(model_bridge.get_model_bridge, "_exact_types"): + registry = model_bridge.get_model_bridge._exact_types + if isinstance(arch_key, str): + has_implementation = arch_key in registry + else: + has_implementation = (arch_key in registry) or ( + getattr(arch_key, "__name__", None) in registry + ) + + if not has_implementation: + raise ValueError( + f"\n✗ Model architecture '{architecture}' is not yet supported\n\n" + f"Model: {path}\n" + f"Architecture: {architecture}\n\n" + f"To add support for {architecture}, you need to:\n" + f"1. Create a new bridge class that inherits from MegatronModelBridge\n" + f"2. Implement the required methods (provider_bridge, mapping_registry)\n" + f"3. Register it with @MegatronModelBridge.register_bridge decorator\n\n" + f"Example implementation:\n" + f" from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge\n" + f" from transformers import {architecture}\n" + f" from megatron.core.models.gpt import GPTModel\n\n" + f" @MegatronModelBridge.register_bridge(source={architecture}, target=GPTModel)\n" + f" class Megatron{architecture.replace('ForCausalLM', '')}Bridge(MegatronModelBridge):\n" + f" def provider_bridge(self, hf_pretrained):\n" + f" # Return a ModelProvider instance\n" + f" ...\n\n" + f" def mapping_registry(self):\n" + f" # Return a MegatronMappingRegistry with weight mappings\n" + f" ...\n\n" + f"For reference implementations, see:\n" + f" • src/megatron/bridge/models/llama/llama_bridge.py\n" + f" • src/megatron/bridge/models/qwen/qwen_2_causal_bridge.py" + ) from None + + + @classmethod + def from_hf_config(cls, config: PretrainedConfig) -> "AutoBridge": + cls._validate_config(config) + model = PreTrainedCausalLM() + model.config = config + import torch + + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + + with init_empty_weights(): + hf_model = AutoModelForCausalLM.from_config(model.config) + + for name, param in hf_model.named_parameters(): + set_module_tensor_to_device( + hf_model, name, "cpu", torch.empty(*param.size(), dtype=model.config.torch_dtype) + ) + model.model = hf_model + return cls(model) + + def load_hf_weights( + self, model: list[MegatronModelT], hf_path: str | Path | None = None + ) -> list[MegatronModelT]: + if hf_path is None: + if not isinstance(self.hf_pretrained, PreTrainedCausalLM): + raise ValueError( + "hf_path is required when hf_pretrained is not a PreTrainedCausalLM instance" + ) + pre_trained = self.hf_pretrained + else: + # Preserve trust_remote_code setting from the original bridge instance + trust_remote_code =
getattr(self.hf_pretrained, 'trust_remote_code', False) + pre_trained = PreTrainedCausalLM.from_pretrained( + hf_path, trust_remote_code=trust_remote_code + ) + self._model_bridge.load_weights_hf_to_megatron(pre_trained, model) + + return model + + def save_hf_weights( + self, + model: list[MegatronModelT], + path: str | Path, + show_progress: bool = True, + strict: bool = True, + ) -> None: + + if dist.is_available() and dist.is_initialized(): + dist.barrier() + dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model)) + generator = model_bridge.stream_weights_megatron_to_hf( + dispatch_instance, model, self.hf_pretrained, cpu=True, show_progress=show_progress + ) + source = SafeTensorsStateSource(path) + # Stream the converted weights into safetensors files under `path`; this requires + # the pretrained wrapper to expose a weight state source. + if ( + hasattr(self.hf_pretrained, "state") + and hasattr(self.hf_pretrained.state, "source") + ): + source.save_generator(generator, path, strict=strict) + else: + raise ValueError( + "hf_pretrained does not expose a state source, cannot save weights in streaming mode." + ) + + if dist.is_available() and dist.is_initialized(): + dist.barrier() + + @property + def _model_bridge(self) -> "MegatronModelBridge": + return model_bridge.get_model_bridge(self._causal_lm_architecture) diff --git a/flagscale/train/megatron/nemo_bridge/models/conversion/model_bridge.py b/flagscale/train/megatron/nemo_bridge/models/conversion/model_bridge.py new file mode 100644 index 000000000..c9554ffab --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/conversion/model_bridge.py @@ -0,0 +1,359 @@ +# Copyright (c) 2025, BAAI. All rights reserved.
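+# +# This module subclasses the upstream ``MegatronModelBridge`` so that exported +# embedding and LM-head weights are padded or truncated to the HF ``vocab_size``, +# tied embeddings are emitted once for both ``model.embed_tokens.weight`` and +# ``lm_head.weight``, and shared embeddings are broadcast across pipeline stages. +# Illustrative shapes (hypothetical numbers): a Megatron word-embedding matrix padded +# to 128256 rows is sliced back to an HF ``vocab_size`` of 128000 rows by +# ``padding_embedd_size`` before being yielded as ``model.embed_tokens.weight``.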
+# +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import itertools +import logging + +from typing import ( + Callable, + Iterable, + List, + Optional, + Type, + TypeVar, + Union, +) + +import torch +from transformers.modeling_utils import PreTrainedModel + +from megatron.core import parallel_state +from megatron.core.transformer.module import MegatronModule +from megatron.core.utils import unwrap_model + +from megatron.bridge.models.conversion.param_mapping import MegatronParamMapping +from megatron.bridge.models.conversion.utils import ( + get_module_and_param_from_name, + persistent_buffers, +) +from megatron.bridge.utils.common_utils import print_rank_0 +from megatron.bridge.models.decorators.dispatch import dispatch + +logger = logging.getLogger(__name__) + +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge as OriginalMegatronModelBridge +from megatron.bridge.models.conversion.model_bridge import ( + _megatron_local_name_to_global, + stream_weights_megatron_to_hf, + HFWeightTuple, + WeightConversionTask, +) +MappingT = TypeVar("MappingT", bound=MegatronParamMapping) +HFPreTrained = TypeVar("HFPreTrained") +MegatronModel = TypeVar("MegatronModel", bound=MegatronModule) +_BridgeImplClass = TypeVar("_BridgeImplClass", bound="MegatronModelBridge") + +def padding_embedd_size(mcore_weight: torch.Tensor, hf_vocab_size: int): + hf_size = hf_vocab_size + mcore_size = mcore_weight.shape[0] + full_word = {} + is_rank0 = torch.distributed.get_rank() == 0 + # Cut out extra padding we don't need + if mcore_size > hf_size: + full_word = mcore_weight[0:hf_size, :] + if is_rank0: + print(f"> padding embedding size mcore {mcore_size} to hf {hf_size}") + + # Expanding embedding to larger size by replicating final entry + elif mcore_size < hf_size: + padding_size = hf_size - mcore_size + + full_word = torch.cat( + (mcore_weight, mcore_weight[-1].unsqueeze(0).expand(padding_size, -1)) + ) + if is_rank0: + print(f"> padding embedding size mcore {mcore_size} to hf {hf_size}") + # Same size! + else: + full_word = mcore_weight + return full_word + +class MegatronModelBridge(OriginalMegatronModelBridge): + + def _broadcast_shared_embeddings( + self, megatron_model: Union[MegatronModel, List[MegatronModel]] + ) -> None: + """Broadcast shared embeddings and output weights across embedding group. + + When embeddings and output weights are shared and pipeline parallelism is enabled, + this method ensures all ranks in the embedding group have the same weights by + broadcasting from rank 0. + + Args: + megatron_model: Megatron model instance or list of model instances. 
+ """ + unwrapped_model = unwrap_model(megatron_model)[0] + # hack for vlm to work properly + if ( + hasattr(unwrapped_model, "language_model") + and unwrapped_model.language_model is not None + ): + unwrapped_model = unwrapped_model.language_model + model_config = unwrapped_model.config + if ( + not model_config.untie_embeddings_and_output_weights + and model_config.pipeline_model_parallel_size > 1 + ): + # Broadcast embeddings and output weights from rank 0 to embedding group + embd_group = parallel_state.get_embedding_group() + embd_group_ranks = torch.distributed.get_process_group_ranks(embd_group) + if embd_group is not None and torch.distributed.get_rank() in embd_group_ranks: + # Get embeddings and output weights from rank 0 + if hasattr(unwrapped_model, "embedding") and hasattr( + unwrapped_model.embedding, "word_embeddings" + ): + embd_weights = unwrapped_model.embedding.word_embeddings.weight.data + else: + assert hasattr(unwrapped_model, "output_layer"), "Output layer not found" + embd_weights = torch.empty_like(unwrapped_model.output_layer.weight.data) + torch.distributed.broadcast(embd_weights, src=embd_group_ranks[0], group=embd_group) + if hasattr(unwrapped_model, "output_layer"): + unwrapped_model.output_layer.weight.data.copy_(embd_weights) + + @classmethod + def register_bridge( + cls, *, source: Type[PreTrainedModel] | str, target: Type[MegatronModel] + ) -> Callable[[_BridgeImplClass], _BridgeImplClass]: + return create_bridge_decorator(source=source, target=target) + + def stream_weights_megatron_to_hf( + self, + megatron_model: Union[MegatronModel, List[MegatronModel]], + hf_pretrained: HFPreTrained, + cpu: bool = True, + show_progress: bool = True, + conversion_tasks: Optional[List[WeightConversionTask]] = None, + ) -> Iterable[HFWeightTuple]: + """Export Megatron weights to HuggingFace format. 
+ """ + + if not isinstance(megatron_model, list): + megatron_model = [megatron_model] + # Use provided conversion tasks or build them + if conversion_tasks is None: + conversion_tasks = self.build_conversion_tasks(hf_pretrained, megatron_model) + + megatron_to_hf_tasks = conversion_tasks + model_config = unwrap_model(megatron_model)[0].config + # embeddings_are_tied = model_config.share_embeddings_and_output_weights + embeddings_are_tied = not model_config.untie_embeddings_and_output_weights + for task in self._with_progress_tracking( + megatron_to_hf_tasks, "Converting to HuggingFace", show_progress + ): + converted_weights_dict = task.mapping.megatron_to_hf( + task.param_weight, task.megatron_module + ) + + # All ranks get the full tensor + for hf_name, tensor in converted_weights_dict.items(): + final_tensor = tensor.cpu() + + if hf_name == "model.embed_tokens.weight" or hf_name == "lm_head.weight": + final_tensor = padding_embedd_size( + final_tensor, hf_pretrained.config.vocab_size + ) + + # Handle tied embeddings case + # TODO(yuya): fix this hard coded naming + if embeddings_are_tied and hf_name == "model.embed_tokens.weight": + # Yield the embedding weight + yield HFWeightTuple(hf_name, final_tensor) + + # Also yield as lm_head.weight if it's expected + if hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source"): + expected_keys = hf_pretrained.state.source.get_all_keys() + if "lm_head.weight" in expected_keys: + final_tensor = final_tensor.detach().clone() + yield HFWeightTuple("lm_head.weight", final_tensor) + elif embeddings_are_tied and hf_name == "lm_head.weight": + # This should not happen when embeddings are tied - assert error + raise ValueError( + "Encountered lm_head.weight when embeddings are tied. This indicates a mapping error." + ) + else: + # Regular case - yield the tensor normally + yield HFWeightTuple(hf_name, final_tensor) + + def build_conversion_tasks( + self, hf_pretrained: HFPreTrained, megatron_model: List[MegatronModel] + ) -> List[None | WeightConversionTask]: + """Construct the conversion tasks between HF and megatron. + + The algorithm walks over every parameter of every destination model, + asks the :class:`MegatronMappingRegistry` whether it has a mapping for that + parameter, and – if the corresponding HF weights actually exist – yields + an :class:`_HFLoadTask` describing exactly how that parameter will be + populated. 
+ """ + + # Ensure hf_pretrained has the required state structure + if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")): + raise ValueError("hf_pretrained.state.source is required for weight ordering") + + hf_keys: Iterable[str] = hf_pretrained.state.source.get_all_keys() + mapping_registry = self.mapping_registry() + model_config = unwrap_model(megatron_model)[0].config + # embeddings_are_tied = model_config.share_embeddings_and_output_weights + embeddings_are_tied = not model_config.untie_embeddings_and_output_weights + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks( + megatron_model + ) + + # Filter out output_layer related parameters if embeddings are tied + if embeddings_are_tied: + sorted_global_param_names_all_pp_ranks = [ + name + for name in sorted_global_param_names_all_pp_ranks + if "output_layer" not in name + ] + + global_names_index_dict = { + name: idx for idx, name in enumerate(sorted_global_param_names_all_pp_ranks) + } + + tasks = [None] * len(sorted_global_param_names_all_pp_ranks) + for vp_stage, model in enumerate(megatron_model): + # persistent buffers are part of the model's state_dict, but not the named_parameters, so we must include them here separately + for local_name, _ in itertools.chain( + model.named_parameters(), persistent_buffers(model) + ): + if "_extra_state" in local_name: + continue + + local_name = self._unwrap_name(local_name) + global_name = _megatron_local_name_to_global( + megatron_model, model_config, local_name, vp_stage + ) + # if name removed due to some reason, continue. e.g. embeddings_are_tied + if global_name not in global_names_index_dict: + print_rank_0(f"WARNING: {global_name} not in global_names_index_dict") + continue + global_name_idx = global_names_index_dict[global_name] + mapping = mapping_registry.megatron_to_hf_lookup(global_name) + if not mapping: + logger.warning(f"WARNING: No mapping found for megatron_param: {global_name}") + continue + # ensure hf weights exist + if isinstance(mapping.hf_param, str): + if mapping.hf_param not in hf_keys: + prefix = '.'.join(mapping.hf_param.split('.')[:-2]) + if not (('q_proj.weight' in mapping.hf_param) and ( + f'{prefix}.q_a_layernorm.weight' in hf_keys + and f'{prefix}.q_a_proj.weight' in hf_keys + and f'{prefix}.q_b_proj.weight' in hf_keys + )): + logger.warning(f"WARNING: Can't find {mapping.hf_param} in hf_keys") + continue + else: + missing_params = [ + hf_param + for hf_param in mapping.hf_param.values() + if hf_param not in hf_keys + ] + if missing_params: + logger.warning( + f"WARNING: Can't find the following HF parameters in hf_keys: {missing_params}" + ) + continue + + local_module, local_weights = get_module_and_param_from_name( + megatron_model, local_name, vp_stage + ) + tasks[global_name_idx] = WeightConversionTask( + pp_rank=pp_rank, + vp_stage=vp_stage, + param_name=local_name, + megatron_module=local_module, + param_weight=local_weights, + mapping=mapping, + ) + + # Fill the remaining ones for pp communications + for idx, global_name in enumerate(sorted_global_param_names_all_pp_ranks): + mapping = mapping_registry.megatron_to_hf_lookup(global_name) + if tasks[idx] is None: + # This is an exception here we pass in global name + # we are not using global_name to extract module and weights + # only use it for param mapping auto dispatch checks + tasks[idx] = WeightConversionTask( + pp_rank=pp_rank, + vp_stage=None, + param_name=global_name, 
+ megatron_module=None, + param_weight=None, + mapping=mapping, + ) + + return tasks + +@dispatch +def get_model_bridge(hf_architecture) -> "MegatronModelBridge": + """Get the appropriate model bridge for a given HuggingFace architecture.""" + ... + +def register_bridge_implementation( + *, + source: Type["PreTrainedModel"] | str, + target: Type["MegatronModule"], + bridge_class: Type["MegatronModelBridge"], +) -> None: + """Register a bridge implementation with the dispatch system. + + Args: + source: HuggingFace PreTrainedModel class or the class name as a string. + Using a string allows registering bridges for architectures that are + available only via auto_map. + target: Megatron model class (e.g., GPTModel) + bridge_class: MegatronModelBridge implementation class + """ + bridge_class_name = bridge_class.__name__ + + @get_model_bridge.impl(source) + def _get_model_bridge_impl(_) -> "MegatronModelBridge": + bridge = bridge_class() + return bridge + + @stream_weights_megatron_to_hf.impl((source, target)) + def _megatron_to_hf_registered_impl( + _, + megatron_model: Union[MegatronModel, List[MegatronModel]], + hf_pretrained: HFPreTrained, + cpu: bool = True, + show_progress: bool = True, + conversion_tasks: Optional[List[WeightConversionTask]] = None, + ) -> Iterable[HFWeightTuple]: + bridge = bridge_class() + + # allow bridge to access model config + bridge.hf_config = hf_pretrained.config + + return bridge.stream_weights_megatron_to_hf( + megatron_model, hf_pretrained, cpu=cpu, show_progress=show_progress, conversion_tasks=conversion_tasks + ) + + # Set meaningful names for debugging + _get_model_bridge_impl.__name__ = f"_bridge_with_{bridge_class_name}" + _megatron_to_hf_registered_impl.__name__ = f"_megatron_to_hf_with_{bridge_class_name}" + + +def create_bridge_decorator( + *, source: Type["PreTrainedModel"] | str, target: Type["MegatronModule"] +) -> Callable[[Type["MegatronModelBridge"]], Type["MegatronModelBridge"]]: + """Create a decorator for registering bridge implementations. + + Args: + source: HuggingFace PreTrainedModel class or the class name as a string + (useful for auto_map architectures) + target: Megatron model class + + Returns: + Decorator function that registers the bridge implementation + """ + + def decorator(bridge_class: Type["MegatronModelBridge"]) -> Type["MegatronModelBridge"]: + register_bridge_implementation(source=source, target=target, bridge_class=bridge_class) + return bridge_class + + return decorator diff --git a/flagscale/train/megatron/nemo_bridge/models/conversion/param_mapping.py b/flagscale/train/megatron/nemo_bridge/models/conversion/param_mapping.py new file mode 100644 index 000000000..e402a0d95 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/conversion/param_mapping.py @@ -0,0 +1,105 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
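+"""Padding-aware overrides of the upstream megatron.bridge parameter mappings. + +``ColumnParallelMapping.hf_to_megatron`` first pads (or truncates) the HF weight along +dim 0 so its row count matches the Megatron target, then splits it across tensor-parallel +ranks. Illustrative shapes (hypothetical numbers): an HF embedding of shape +``[128000, hidden]`` loaded into per-rank Megatron parameters of ``[32064, hidden]`` with +``tp_size=4`` is padded to 128256 rows by replicating the last row and then chunked into +four ``[32064, hidden]`` shards. +"""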
+ +import torch +import torch.nn as nn +from megatron.bridge.models.conversion.utils import get_module_and_param_from_name +from megatron.bridge.models.conversion.param_mapping import ColumnParallelMapping as OriginalColumnParallelMapping +from megatron.bridge.models.conversion.param_mapping import AutoMapping as OriginalAutoMapping +from megatron.bridge.models.conversion.param_mapping import QKVMapping as OriginalQKVMapping +from megatron.bridge.models.conversion.param_mapping import ( + MegatronParamMapping, + RowParallelMapping, + ReplicatedMapping, + GatedMLPMapping, +) + + +import logging +logger = logging.getLogger(__name__) + +def col_padding_size(hf_weight: torch.Tensor, mcore_weight: torch.Tensor, tp_size: int): + hf_size = hf_weight.shape[0] + mcore_size = mcore_weight.shape[0] * tp_size + full_word = {} + is_rank0 = torch.distributed.get_rank() == 0 + # Cut out extra padding we don't need + if hf_size > mcore_size: + full_word = hf_weight[0:mcore_size, :] + if is_rank0: + print(f"> padding TP-ColumnParallelfrom {hf_size} to {mcore_size}") + + # Expanding embedding to larger size by replicating final entry + elif hf_size < mcore_size: + padding_size = mcore_size - hf_size + + full_word = torch.cat((hf_weight, hf_weight[-1].unsqueeze(0).expand(padding_size, -1))) + if is_rank0: + print(f"> padding TP-ColumnParallelfrom {hf_size} to {mcore_size}") + # Same size! + else: + full_word = hf_weight + return full_word + + +class ColumnParallelMapping(OriginalColumnParallelMapping): + """ + Mapping for column-parallel linear and embedding weights. + + """ + def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: + """Split weight along dim 0 and distribute to TP ranks.""" + # Some parameters are named with global expert number, e.g. experts.weight15, + # normalize it to experts.weight0, note we are only use the shape, dtype, device info, + # not the actual value, so it is safe to do this. + normalized_param = self._normalize_expert_param_name(self.megatron_param) + _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) + + if self.tp_size == 1: + full_weight = col_padding_size(hf_weights, target_param, self.tp_size) + return full_weight + + # On rank 0, check for divisibility and split + if self.tp_rank == 0: + if hf_weights is None: + raise ValueError("hf_weights should not be None on rank 0") + + if hf_weights.dtype != target_param.dtype: + logger.warning( + f"WARNING: Dtype mismatch between HuggingFace weights and Megatron module. " + f"HF dtype: {hf_weights.dtype}. Megatron dtype: {target_param.dtype}. " + f"Casting HF weights to Megatron dtype. THIS MAY RESULT IN A LOSS OF PRECISION. " + ) + hf_weights = hf_weights.to(target_param.dtype) + + full_weight = col_padding_size(hf_weights, target_param, self.tp_size) + full_size = full_weight.shape[0] + if full_size % self.tp_size != 0: + raise ValueError( + f"Cannot evenly split dimension 0 size {full_size} across {self.tp_size} TP ranks" + ) + splits = torch.chunk(full_weight, self.tp_size, dim=0) + else: + splits = None + + # Scatter to all ranks. Each rank gets its sharded shape from its module. 
+ return self.scatter_to_tp_ranks( + splits, target_param.shape, target_param.dtype, target_param.device + ) + +class AutoMapping(OriginalAutoMapping): + def _get_or_create_mapping(self, parallelism_type: str) -> MegatronParamMapping[torch.Tensor]: + """Get or create the appropriate mapping for the given type.""" + if parallelism_type == "column": + return ColumnParallelMapping(self.megatron_param, self.hf_param) + elif parallelism_type == "row": + return RowParallelMapping(self.megatron_param, self.hf_param) + elif parallelism_type == "replicated": + return ReplicatedMapping(self.megatron_param, self.hf_param) + else: + raise ValueError(f"Unknown parallelism type: {parallelism_type}") + +class QKVMapping(OriginalQKVMapping): + def __init__(self, megatron_param: str, q: str, k: str, v: str): + super().__init__(megatron_param, q, k, v) + self._tp_mapping = AutoMapping(megatron_param, megatron_param) + diff --git a/flagscale/train/megatron/nemo_bridge/models/deepseek/__init__.py b/flagscale/train/megatron/nemo_bridge/models/deepseek/__init__.py new file mode 100644 index 000000000..bee2b1aee --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/deepseek/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2025, BAAI. All rights reserved. + +from megatron.nemo_bridge.models.deepseek.deepseek_v3_bridge import DeepSeekV3Bridge # noqa: F401 + diff --git a/flagscale/train/megatron/nemo_bridge/models/deepseek/common.py b/flagscale/train/megatron/nemo_bridge/models/deepseek/common.py new file mode 100644 index 000000000..ee257f4a0 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/deepseek/common.py @@ -0,0 +1,137 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from megatron.nemo_bridge.models.conversion.param_mapping import AutoMapping, GatedMLPMapping +from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM + +try: + import apex # noqa: F401 + + HAVE_APEX = True +except ImportError: + HAVE_APEX = False + + +def get_common_configs(hf_pretrained: PreTrainedCausalLM) -> dict: + """ + Returns a dictionary of common configurations for the DeepSeek family of models. 
+ """ + hf_config = hf_pretrained.config + + configs = {} + + if not HAVE_APEX: + configs["gradient_accumulation_fusion"] = False + + if hasattr(hf_config, "rope_scaling") and hf_config.rope_scaling is not None: + configs["rotary_scaling_factor"] = hf_config.rope_scaling["factor"] + configs["mscale"] = hf_config.rope_scaling["mscale"] + configs["mscale_all_dim"] = hf_config.rope_scaling["mscale_all_dim"] + else: + configs["rotary_scaling_factor"] = 1.0 + configs["mscale"] = 1.0 + configs["mscale_all_dim"] = 1.0 + + configs["num_layers"] = hf_config.num_hidden_layers + configs["hidden_size"] = hf_config.hidden_size + configs["ffn_hidden_size"] = hf_config.intermediate_size + configs["num_attention_heads"] = hf_config.num_attention_heads + configs["kv_channels"] = hf_config.num_key_value_heads + configs["q_lora_rank"] = hf_config.q_lora_rank + configs["num_moe_experts"] = hf_config.n_routed_experts + configs["moe_ffn_hidden_size"] = hf_config.moe_intermediate_size + configs["moe_shared_expert_intermediate_size"] = ( + hf_config.moe_intermediate_size * hf_config.n_shared_experts + ) + configs["moe_layer_freq"] = [0] * hf_config.first_k_dense_replace + [1] * ( + hf_config.num_hidden_layers - hf_config.first_k_dense_replace + ) + configs["moe_router_topk"] = hf_config.num_experts_per_tok + configs["moe_router_num_groups"] = hf_config.n_group + configs["moe_router_group_topk"] = hf_config.topk_group + configs["moe_router_topk_scaling_factor"] = hf_config.routed_scaling_factor + configs["kv_lora_rank"] = hf_config.kv_lora_rank + configs["qk_head_dim"] = hf_config.qk_nope_head_dim + configs["qk_pos_emb_head_dim"] = hf_config.qk_rope_head_dim + configs["v_head_dim"] = hf_config.v_head_dim + + # Ensure MLA is enabled + configs["multi_latent_attention"] = True + configs["generation_config"] = hf_pretrained.generation_config + configs["vocab_size"] = hf_config.vocab_size + configs["rotary_base"] = hf_config.rope_theta + configs["init_method_std"] = hf_config.initializer_range + configs["layernorm_epsilon"] = hf_config.rms_norm_eps + + return configs + + +def get_common_mapping_list() -> list: + """ + Returns a list of common parameter mappings for the DeepSeek family of models. 
+ """ + param_mappings = { + # Embed + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + # Attention + "decoder.layers.*.input_layernorm.weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + # Reference: https://github.com/NVIDIA/NeMo/blob/50cceb9c90ea1f440d1e14074fa13bd45f60a1c4/nemo/collections/llm/gpt/model/deepseek.py#L637-L650 + # In deepseek, HF weight `model.layers.*.post_attention_layernorm.weight` is mapped to the following mcore weights depending on the layer type: + # (a) `decoder.layers.*.pre_mlp_layernorm.weight`, if the layer is MoE + # (b) `decoder.layers.*.mlp.linear_fc1.layer_norm_weight`, if the layer is dense + "decoder.layers.*.pre_mlp_layernorm.weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.layers.*.self_attention.linear_kv_down_proj.weight": "model.layers.*.self_attn.kv_a_proj_with_mqa.weight", + "decoder.layers.*.self_attention.linear_kv_up_proj.weight": "model.layers.*.self_attn.kv_b_proj.weight", + "decoder.layers.*.self_attention.linear_kv_up_proj.layer_norm_weight": "model.layers.*.self_attn.kv_a_layernorm.weight", + # Mcore local spec + "decoder.layers.*.self_attention.kv_layernorm.weight": "model.layers.*.self_attn.kv_a_layernorm.weight", + # Dense MLP + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + # MoE + "decoder.layers.*.mlp.router.weight": "model.layers.*.mlp.gate.weight", + "decoder.layers.*.mlp.experts.linear_fc2.weight*": "model.layers.*.mlp.experts.*.down_proj.weight", + "decoder.layers.*.mlp.shared_experts.linear_fc2.weight": "model.layers.*.mlp.shared_experts.down_proj.weight", + # LM Head + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + # MLA + "decoder.layers.*.self_attention.linear_q_down_proj.weight": "model.layers.*.self_attn.q_a_proj.weight", + "decoder.layers.*.self_attention.linear_q_up_proj.weight": "model.layers.*.self_attn.q_b_proj.weight", + "decoder.layers.*.self_attention.linear_q_up_proj.layer_norm_weight": "model.layers.*.self_attn.q_a_layernorm.weight", + # Mcore local spec + "decoder.layers.*.self_attention.q_layernorm.weight": "model.layers.*.self_attn.q_a_layernorm.weight", + # For models without MLA + "decoder.layers.*.self_attention.linear_q_proj.weight": "model.layers.*.self_attn.q_proj.weight", + } + + # TODO: mtp layers + + mapping_list = [] + # Convert each dictionary entry to AutoMapping(hf_param, megatron_param) + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + + mapping_list.extend( + [ + GatedMLPMapping( + megatron_param="decoder.layers.*.mlp.linear_fc1.weight", + gate="model.layers.*.mlp.gate_proj.weight", + up="model.layers.*.mlp.up_proj.weight", + ), + GatedMLPMapping( + megatron_param="decoder.layers.*.mlp.experts.linear_fc1.weight*", + gate="model.layers.*.mlp.experts.*.gate_proj.weight", + up="model.layers.*.mlp.experts.*.up_proj.weight", + ), + GatedMLPMapping( + megatron_param="decoder.layers.*.mlp.shared_experts.linear_fc1.weight", + gate="model.layers.*.mlp.shared_experts.gate_proj.weight", + up="model.layers.*.mlp.shared_experts.up_proj.weight", + ), + ] + ) + + return mapping_list diff --git a/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v3_bridge.py 
b/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v3_bridge.py new file mode 100644 index 000000000..b83b90b11 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/deepseek/deepseek_v3_bridge.py @@ -0,0 +1,64 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import torch + +from megatron.core.models.gpt.gpt_model import GPTModel + +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.nemo_bridge.models.conversion.param_mapping import AutoMapping +from megatron.nemo_bridge.models.deepseek.common import ( + get_common_configs, + get_common_mapping_list, +) +from megatron.bridge.models.deepseek.deepseek_provider import DeepSeekV3ModelProvider +from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM + + +@MegatronModelBridge.register_bridge(source="DeepseekV3ForCausalLM", target=GPTModel) +class DeepSeekV3Bridge(MegatronModelBridge): + """ + Megatron Bridge for DeepSeek-V3. + + As a user you would not use this bridge directly, but through `AutoBridge`. + + Example: + >>> from megatron.nemo_bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("deepseek-ai/DeepSeek-V3-Base", trust_remote_code=True) + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> DeepSeekV3ModelProvider: + hf_config = hf_pretrained.config + configs = get_common_configs(hf_pretrained) + + configs["fp16"] = self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16 + configs["bf16"] = self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16 + configs["params_dtype"] = self.dtype_from_hf(hf_config, default=torch.float32) + + configs["make_vocab_size_divisible_by"] = 1280 + configs["moe_router_score_function"] = "sigmoid" + configs["moe_router_enable_expert_bias"] = True + # aux_loss_alpha is not set in all DSv3 HF configs + if hasattr(hf_config, "aux_loss_alpha"): + configs["moe_aux_loss_coeff"] = hf_config.aux_loss_alpha + + # TODO: mtp + + provider = DeepSeekV3ModelProvider(**configs) + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + mapping_list = get_common_mapping_list() + + param_mappings = { + # expert bias + "decoder.layers.*.mlp.router.expert_bias": "model.layers.*.mlp.gate.e_score_correction_bias" + } + + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + + return MegatronMappingRegistry(*mapping_list) diff --git a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/__init__.py b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/__init__.py new file mode 100644 index 000000000..f42a8bb6a --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+ +from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM + +__all__ = ["PreTrainedCausalLM"] diff --git a/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/causal_lm.py b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/causal_lm.py new file mode 100644 index 000000000..c8e3c3500 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/hf_pretrained/causal_lm.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, BAAI. All rights reserved. + +import sys + +from pathlib import Path +from typing import Dict, Generic, List, Optional, TypeVar, Union + +import torch + +from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM as OriginalPreTrainedCausalLM + +class PreTrainedCausalLM(OriginalPreTrainedCausalLM): + def __init__( + self, + model_name_or_path: Optional[Union[str, Path]] = None, + device: Optional[Union[str, torch.device]] = None, + torch_dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + **kwargs, + ): + self.device = "cpu" + super().__init__( + model_name_or_path=model_name_or_path, + device=self.device, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + **kwargs + ) + #if hasattr(self, '_model') and self._model is not None: + # self._model.to("cpu") + + def save_artifacts(self, save_directory: Union[str, Path]): + """ + Saves all loaded, generic artifacts that have a `save_pretrained` method + to the specified directory. Note: This does not save the `model` attribute. + + If the model was loaded with trust_remote_code=True, this method will also + attempt to preserve any custom modeling files to ensure the saved model + can be loaded properly. + """ + save_path = Path(save_directory) + save_path.mkdir(parents=True, exist_ok=True) + + _ = getattr(self, "config") # trigger lazy loading of config + if hasattr(self, "_config") and self._config is not None: + self._config.save_pretrained(save_path) + + for name in self.OPTIONAL_ARTIFACTS: + artifact = getattr(self, name, None) + if artifact is not None and hasattr(artifact, "save_pretrained"): + artifact.save_pretrained(save_path) + + # Preserve custom modeling files if trust_remote_code was used + if hasattr(self, 'trust_remote_code') and self.trust_remote_code: + # Try original source path first, then fallback to model_name_or_path + source_paths = [] + if hasattr(self, '_original_source_path') and self._original_source_path: + source_paths.append(self._original_source_path) + if hasattr(self, 'model_name_or_path') and self.model_name_or_path: + source_paths.append(self.model_name_or_path) + + for source_path in source_paths: + copied_files = self._copy_custom_modeling_files(source_path, save_path) + if copied_files: + # Successfully copied files, no need to try other paths + break diff --git a/flagscale/train/megatron/nemo_bridge/models/qwen/__init__.py b/flagscale/train/megatron/nemo_bridge/models/qwen/__init__.py new file mode 100644 index 000000000..e5dfd6c22 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/qwen/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+ +from megatron.nemo_bridge.models.qwen.qwen3_bridge import Qwen3Bridge # noqa: F401 + diff --git a/flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_bridge.py b/flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_bridge.py new file mode 100644 index 000000000..dbfaabf40 --- /dev/null +++ b/flagscale/train/megatron/nemo_bridge/models/qwen/qwen3_bridge.py @@ -0,0 +1,106 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import torch + +from transformers import Qwen3ForCausalLM + +from megatron.core.models.gpt.gpt_model import GPTModel + +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.nemo_bridge.models.conversion.param_mapping import ( + AutoMapping, + QKVMapping, +) +from megatron.bridge.models.conversion.param_mapping import GatedMLPMapping +from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from megatron.bridge.models.qwen.qwen_provider import Qwen3ModelProvider + + +@MegatronModelBridge.register_bridge(source=Qwen3ForCausalLM, target=GPTModel) +class Qwen3Bridge(MegatronModelBridge): + """ + Megatron Bridge for Qwen3 Causal LM. + + This bridge handles the conversion between HuggingFace Qwen3ForCausalLM + and Megatron-Core GPTModel formats. Qwen3 differs + from Qwen2 by using QK layernorm. + + Example: + >>> from megatron.nemo_bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen3-1.7B") + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> Qwen3ModelProvider: + hf_config = hf_pretrained.config + + provider = Qwen3ModelProvider( + num_layers=hf_config.num_hidden_layers, + hidden_size=hf_config.hidden_size, + ffn_hidden_size=hf_config.intermediate_size, + num_attention_heads=hf_config.num_attention_heads, + num_query_groups=hf_config.num_key_value_heads, + init_method_std=hf_config.initializer_range, + layernorm_epsilon=hf_config.rms_norm_eps, + gated_linear_unit=True, + make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.vocab_size), + rotary_base=hf_config.rope_theta, + share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False), + vocab_size=hf_config.vocab_size, + seq_length=hf_config.max_position_embeddings, + fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16), + bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16), + params_dtype=self.dtype_from_hf(hf_config, default=torch.float32), + generation_config=hf_pretrained.generation_config, + qk_layernorm=True, # Qwen3 uses QK layernorm + ) + + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format + # First create simple 1:1 parameter mappings using a dictionary for readability + + # Dictionary maps Megatron parameter names -> HF parameter names + # Supports wildcard (*) patterns for layer-specific parameters + param_mappings = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "output_layer.weight": "lm_head.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight":
"model.layers.*.post_attention_layernorm.weight", + "decoder.layers.*.self_attention.q_layernorm.weight": "model.layers.*.self_attn.q_norm.weight", # Qwen3 specific + "decoder.layers.*.self_attention.k_layernorm.weight": "model.layers.*.self_attn.k_norm.weight", # Qwen3 specific + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + } + + mapping_list = [] + # Convert each dictionary entry to AutoMapping(megatron_param, hf_param) + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + + # Add special mappings that require parameter concatenation/transformation + mapping_list.extend( + [ + # QKV: Combine separate Q, K, V matrices into single QKV matrix + # Note: Qwen3 does NOT have bias in QKV projections (unlike Qwen2) + QKVMapping( + megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", + q="model.layers.*.self_attn.q_proj.weight", + k="model.layers.*.self_attn.k_proj.weight", + v="model.layers.*.self_attn.v_proj.weight", + ), + # Gated MLP: Combine gate and up projection matrices into single FC1 matrix + GatedMLPMapping( + megatron_param="decoder.layers.*.mlp.linear_fc1.weight", + gate="model.layers.*.mlp.gate_proj.weight", + up="model.layers.*.mlp.up_proj.weight", + ), + ] + ) + + return MegatronMappingRegistry(*mapping_list) diff --git a/flagscale/train/megatron/training/arguments.py b/flagscale/train/megatron/training/arguments.py index 9d4ffc213..f29ede5fe 100644 --- a/flagscale/train/megatron/training/arguments.py +++ b/flagscale/train/megatron/training/arguments.py @@ -889,6 +889,17 @@ def validate_args(args, defaults={}): if args.save_retain_interval is not None: assert args.save_retain_interval > 0 assert args.save_retain_interval % args.save_interval == 0 + + if args.save_hf is not None: + assert args.save is not None + assert args.save_interval is not None + assert args.save_interval > 0 + assert args.save_hf_interval is not None + assert args.save_hf_interval > 0 + assert args.save_hf_interval > args.save_interval and args.save_hf_interval % args.save_interval == 0, \ + "save_hf_interval must be greater than save_interval and be an integer multiple of it" + assert args.hf_config_path is not None + # Mixed precision checks. if args.fp16_lm_cross_entropy: assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.' @@ -2573,6 +2584,15 @@ def _add_checkpointing_args(parser): ' rank for saving. Turn on only if experiencing host or device memory' ' issues. 
Has affect only with `--dist-ckpt-optim-fully-reshardable`' ' flag.') + group.add_argument('--load-hf', action='store_true', default=None, + help='Warm start from a Hugging Face format checkpoint; checkpoints are still saved in ' + 'the Megatron format, and additionally in HF format when --save-hf is set.') + group.add_argument('--save-hf', action='store_true', default=None, + help='Also save checkpoints in the Hugging Face format.') + group.add_argument('--hf-config-path', default=None, + help='Path to the HF config used to build the HF model when saving HF format checkpoints.') + group.add_argument('--save-hf-interval', type=int, default=None, + help='Number of iterations between HF format checkpoint saves (a multiple of --save-interval).') return parser diff --git a/flagscale/train/megatron/training/checkpointing.py b/flagscale/train/megatron/training/checkpointing.py index 9489b4763..b76ba3f51 100644 --- a/flagscale/train/megatron/training/checkpointing.py +++ b/flagscale/train/megatron/training/checkpointing.py @@ -471,6 +471,24 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati # Save dataloader state if the dataloader supports it (currently only Megatron Energon). maybe_save_dataloader_state(train_data_iterator, iteration, getattr(args, "dataloader_save", None)) + # Save an HF-format copy of the model weights + hf_checkpoint_name = get_checkpoint_name(save_dir, iteration, release=False, pipeline_parallel=pipeline_parallel, + tensor_rank=tensor_rank, pipeline_rank=pipeline_rank, expert_parallel=expert_parallel, expert_rank=expert_rank, return_base_dir=True) + if args.save_hf and hasattr(args, 'hf_config_path') and args.save_hf_interval: + assert args.hf_config_path is not None, "hf_config_path should not be None" + if iteration % args.save_hf_interval == 0 or iteration == args.train_iters: + # Use the Megatron bridge to export HF-format weights + from megatron.nemo_bridge.models import AutoBridge + from megatron.bridge.models.hf_pretrained.safe_config_loader import safe_load_config_with_retry + from transformers import AutoConfig + # Build the HF model skeleton from the config + config_load = args.hf_config_path + config = safe_load_config_with_retry(config_load, trust_remote_code=False) + bridge = AutoBridge.from_hf_config(config) + # Save the HF model weights in the corresponding iteration's 'safetensor' folder + safe_save = os.path.join(hf_checkpoint_name, 'safetensor') + bridge.save_hf_pretrained(model=model, path=safe_save) + + # Save distributed optimizer's custom parameter state.
if ( args.use_distributed_optimizer @@ -1408,6 +1426,17 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', args = get_args() load_dir = getattr(args, load_arg) + # Load model weights from a Hugging Face format checkpoint + if args.load_hf: + # Use the Megatron bridge for HF -> Megatron conversion + from megatron.nemo_bridge.models import AutoBridge + bridge = AutoBridge.from_hf_pretrained(load_dir) + bridge.load_hf_weights(ddp_model) + # Optimizer state and iteration count are not restored from HF checkpoints + iteration = 0 + num_floating_point_operations_so_far = 0 + return iteration, num_floating_point_operations_so_far + + # Finetuning directories pretrained_dir = getattr(args, 'pretrained_checkpoint', None) if pretrained_dir is not None and not checkpoint_exists(load_dir): diff --git a/flagscale/train/megatron/training/yaml_arguments.py b/flagscale/train/megatron/training/yaml_arguments.py index 405d7b70f..417be3758 100644 --- a/flagscale/train/megatron/training/yaml_arguments.py +++ b/flagscale/train/megatron/training/yaml_arguments.py @@ -409,7 +409,7 @@ def core_transformer_config_from_yaml(args, transfomer_key = "language_model"): # Hardcoded kw_args['deallocate_pipeline_outputs'] = True kw_args['pipeline_dtype'] = kw_args['params_dtype'] - kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm + kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm assert args.activation_func in ["swiglu","squaredrelu","gelu"], f"{args.activation_func} is not a supported activation function" if args.activation_func == "swiglu":