
NVIDIA-NeMo/Megatron-Bridge


Overview

NeMo Megatron Bridge is a PyTorch-native library within the NeMo Framework that serves as a bridge, conversion, and verification layer between 🤗 Hugging Face and Megatron Core. It provides bidirectional checkpoint conversion between the two formats. This lets other projects use Megatron Core's parallelism capabilities or export models for various inference engines. Built-in verification mechanisms help ensure conversion accuracy and checkpoint integrity across model formats.

On top of the bridge, NeMo Megatron Bridge provides a performant, scalable, PyTorch-native training loop that uses Megatron Core to deliver state-of-the-art training throughput. It supports pretraining and fine-tuning with features such as tensor and pipeline parallelism and mixed precision (FP8, BF16, FP4, etc.). Users can start from existing 🤗 Hugging Face models or write custom PyTorch model definitions for flexible end-to-end workflows.

NeMo Megatron Bridge is a refactor of the previous NeMo training stack that adopts a PyTorch-native training loop to provide greater flexibility and customizability for developers.


πŸ”§ Installation

🐳 NeMo Framework container

The NeMo Framework container provides the best experience, the highest performance, and full feature support. Set TAG to the most recent container release tag and run the following to start a container:

docker run --rm -it -w /workdir -v $(pwd):/workdir \
  --entrypoint bash \
  --gpus all \
  nvcr.io/nvidia/nemo:${TAG}

πŸ“¦ Bare-metal installation with Transformer Engine

Transformer Engine is a required dependency for Megatron Bridge. To install on bare metal (without a container), the following system requirements must be met:

  • Python >= 3.10
  • PyTorch >= 2.7
  • CUDA >= 12.8
  • cuDNN >= 9.3
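Before building Transformer Engine, you may want to confirm that the environment meets these minimums. The following is a minimal sketch, not part of Megatron Bridge; `parse_version`, `meets_minimum`, and `check_environment` are hypothetical helpers. It checks the Python version. If PyTorch is already importable, it also checks the PyTorch version and the CUDA version PyTorch was built with. cuDNN is not covered; use the dpkg commands below for that.

```python
import re
import sys

# Minimum versions from the list above (cuDNN is checked separately with dpkg).
MINIMUMS = {"python": (3, 10), "torch": (2, 7), "cuda": (12, 8)}


def parse_version(text):
    """Extract the leading numeric components, e.g. '2.7.0+cu128' -> (2, 7, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(version, minimum):
    """Compare version tuples component-wise; missing components count as 0."""
    width = max(len(version), len(minimum))
    found = tuple(version) + (0,) * (width - len(version))
    required = tuple(minimum) + (0,) * (width - len(minimum))
    return found >= required


def check_environment():
    """Return {component: (found_version_or_None, ok)} for the local install."""
    py_version = tuple(sys.version_info[:3])
    results = {"python": (py_version, meets_minimum(py_version, MINIMUMS["python"]))}
    try:
        import torch  # PyTorch may not be installed yet at this point
    except ImportError:
        results["torch"] = (None, False)
        results["cuda"] = (None, False)
        return results
    torch_version = parse_version(torch.__version__)
    results["torch"] = (torch_version, meets_minimum(torch_version, MINIMUMS["torch"]))
    cuda = torch.version.cuda  # None for CPU-only builds
    cuda_version = parse_version(cuda) if cuda else None
    cuda_ok = cuda_version is not None and meets_minimum(cuda_version, MINIMUMS["cuda"])
    results["cuda"] = (cuda_version, cuda_ok)
    return results


if __name__ == "__main__":
    for name, (found, ok) in check_environment().items():
        print(f"{name:>6}: {found} {'OK' if ok else 'missing or below minimum'}")
```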

We recommend installing the same versions that are present in the latest NGC PyTorch containers. The versions of these components for each container release are listed in the PyTorch and CUDA container release notes.

Please see the instructions for installing cuDNN for your target platform. You can check if the CUDA toolkit and cuDNN are installed with:

dpkg -l | grep 'cuda-toolkit'
dpkg -l | grep 'cudnn.*cuda'

Then install Megatron Bridge:

pip install torch setuptools pybind11 wheel_stub  # Required for TE
pip install --no-build-isolation megatron-bridge

Using uv

uv pip install torch --torch-backend=auto
uv pip install --no-build-isolation "transformer_engine[pytorch]"
uv pip install megatron-bridge

For development installation and additional details, please refer to our Contribution guide.

⚡ Quickstart

To get started, install Megatron Bridge or download a NeMo Framework container as described above.

Log in to Hugging Face Hub:

huggingface-cli login --token <your token>

Conversion-only quickstart (✅ Core):

from megatron.bridge import AutoBridge

# 1) Create a bridge from a Hugging Face model (hub or local path)
bridge = AutoBridge.from_hf_pretrained("meta-llama/Llama-3.2-1B", trust_remote_code=True)

# 2) Get a Megatron provider and configure parallelism before instantiation
provider = bridge.to_megatron_provider()
provider.tensor_model_parallel_size = 1
provider.pipeline_model_parallel_size = 1
provider.finalize()
# 3) Materialize Megatron Core model(s)
model = provider.provide_distributed_model(wrap_with_ddp=False)

# 4a) Export Megatron → Hugging Face (full HF folder with config/tokenizer/weights)
bridge.save_hf_pretrained(model, "./hf_exports/llama32_1b")

# 4b) Or stream only the weights (Megatron → HF)
for name, weight in bridge.export_hf_weights(model, cpu=True):
    print(name, tuple(weight.shape))
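Because `export_hf_weights` yields one `(name, weight)` pair at a time, a consumer can process a large model without holding every exported tensor in memory at once. The helper below is a hypothetical sketch, not part of the library, that tallies parameter counts from such a stream. It only assumes that each yielded weight has a `.shape`, so a stand-in stream (`fake_export_stream`, also hypothetical) replaces a real Megatron model here.

```python
import math
from collections import namedtuple


def summarize_weight_stream(stream):
    """Consume (name, weight) pairs one at a time and tally parameter counts.

    Only each weight's `.shape` is read, so each tensor can be released after
    its iteration instead of all tensors being held in memory together.
    """
    total = 0
    per_name = {}
    for name, weight in stream:
        count = math.prod(weight.shape)
        per_name[name] = count
        total += count
    return total, per_name


# Hypothetical stand-in for `bridge.export_hf_weights(model, cpu=True)`:
# any object with a `.shape` attribute works, including torch.Tensor.
FakeWeight = namedtuple("FakeWeight", ["shape"])


def fake_export_stream():
    yield "model.embed_tokens.weight", FakeWeight((128, 16))
    yield "model.layers.0.self_attn.q_proj.weight", FakeWeight((16, 16))
    yield "model.norm.weight", FakeWeight((16,))


total, per_name = summarize_weight_stream(fake_export_stream())
print(total)  # 2320
```

With a real bridge, pass `bridge.export_hf_weights(model, cpu=True)` in place of `fake_export_stream()`.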

Training quickstart:

from megatron.bridge import AutoBridge

import megatron.bridge.recipes.llama.llama32_1b as llama32_1b
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.pretrain import pretrain

if __name__ == "__main__":
    # Load Llama from Hugging Face Hub and convert to Megatron
    bridge = AutoBridge.from_hf_pretrained("meta-llama/Llama-3.2-1B")
    model_provider = bridge.to_megatron_provider()

    # Get defaults for other configuration from an existing Llama 3.2 recipe
    cfg = llama32_1b.pretrain_config()
    cfg.model = model_provider
    cfg.train.train_iters = 10

    cfg.dataset.seq_length = cfg.model.seq_length
    cfg.tokenizer.vocab_size = cfg.model.vocab_size

    pretrain(cfg, forward_step)

You can launch the above script with:

torchrun --nproc-per-node=<num devices> /path/to/script.py
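The number of processes launched by `torchrun` has to fit the parallelism settings in the config. Megatron Core factors the world size as tensor-parallel × pipeline-parallel × context-parallel × data-parallel, so the world size must be divisible by the product of the model-parallel sizes. The helper below is a hypothetical sketch of that arithmetic for dense models; expert parallelism for MoE models is not modeled.

```python
def data_parallel_size(world_size, tensor_parallel=1, pipeline_parallel=1, context_parallel=1):
    """Return the data-parallel size implied by a launch (dense models only).

    Assumes world_size = TP * PP * CP * DP, the decomposition Megatron Core
    uses for non-expert parameters.
    """
    model_parallel = tensor_parallel * pipeline_parallel * context_parallel
    if world_size % model_parallel != 0:
        raise ValueError(
            f"world size {world_size} is not divisible by TP*PP*CP = {model_parallel}"
        )
    return world_size // model_parallel


# e.g. torchrun --nproc-per-node=8 on a single node with TP=2 and PP=2
print(data_parallel_size(8, tensor_parallel=2, pipeline_parallel=2))  # 2
```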

More examples:

For a deeper dive into conversion design and advanced usage, see the models README.

🚀 Key Features

  • Bridge with 🤗 Hugging Face: Seamless bidirectional conversion between 🤗 Hugging Face and Megatron formats for interoperability (model bridges, auto bridge, conversion examples)
    • Online import/export without intermediate full checkpoints
    • Parallelism-aware (TP/PP/VPP/CP/EP/ETP) during conversion
    • Memory-efficient per-parameter streaming
    • Simple high-level AutoBridge API with architecture auto-detection
    • Optimized paths when Transformer Engine is available
  • Flexible Customization: Lightweight custom training loop that makes it easy to add custom logic for data loading, distributed training, checkpointing, evaluation, and logging (training framework, training utilities)
  • Supervised & Parameter-Efficient Finetuning: SFT & PEFT implementation tailored for Megatron-based models that supports LoRA, DoRA, and user-defined PEFT methods (PEFT implementations, finetune module, SFT dataset)
  • SOTA Training Recipes: Pre-configured production-ready training recipes for popular models like Llama 3, with optimized hyperparameters and distributed training configuration (Llama recipes, recipe examples)
  • Performance Optimization: Built-in support for FP8 training, model parallelism, and memory-efficient techniques to offer high utilization and near-linear scalability to thousands of nodes (mixed precision, communication overlap, optimizer utilities)
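Here is a rough illustration of what parallelism-aware conversion involves. With tensor parallelism, a column-parallel linear layer's weight of shape [out_features, in_features] is split along the output dimension, so each TP rank holds a contiguous block of rows. Exporting back to Hugging Face concatenates those blocks in rank order. This pure-Python sketch (`shard_rows` and `gather_rows` are hypothetical names) shows only that split/concatenate step. The actual bridges also handle row-parallel layers (split along the input dimension), fused QKV layouts, and pipeline stages.

```python
def shard_rows(weight, tp_size):
    """Split a 2-D weight (list of rows) into tp_size contiguous row blocks."""
    out_features = len(weight)
    if out_features % tp_size != 0:
        raise ValueError(f"{out_features} rows cannot be split evenly across {tp_size} ranks")
    block = out_features // tp_size
    return [weight[rank * block:(rank + 1) * block] for rank in range(tp_size)]


def gather_rows(shards):
    """Inverse of shard_rows: concatenate per-rank row blocks in rank order."""
    return [row for shard in shards for row in shard]


# A toy [4, 3] weight split across 2 tensor-parallel ranks.
full = [[r * 3 + c for c in range(3)] for r in range(4)]
shards = shard_rows(full, tp_size=2)
assert gather_rows(shards) == full  # the round trip is lossless
print(len(shards), len(shards[0]))  # 2 2
```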

Supported Models

Megatron Bridge provides out-of-the-box bridges and training recipes for a wide range of models, built on top of base model architectures from Megatron Core. Refer to the models directory for the most up-to-date list of model bridges.

Supported Models Overview

| Model | Bridge Conversion | Pretrain Recipes | SFT & LoRA Recipes |
|---|---|---|---|
| Llama 3 | ✅ | ✅ (8b, 70b) | Coming soon |
| Llama 3.1 | ✅ | ✅ (8b, 70b, 405b) | Coming soon |
| Llama 3.2 | ✅ | ✅ (1b, 3b) | Coming soon |
| Llama 3.3 | ✅ | Coming soon | Coming soon |
| Qwen2 | ✅ | ✅ (500m, 1.5b, 7b, 72b) | Coming soon |
| Qwen2.5 | ✅ | ✅ (500m, 1.5b, 7b, 14b, 32b, 72b) | Coming soon |
| Qwen3 | ✅ | ✅ (600m, 1.7b, 4b, 8b, 14b, 32b) | Coming soon |
| Qwen3-MoE | ✅ | ✅ (A3B, A22B) | Coming soon |
| Qwen2.5-VL | ✅ | Coming soon | Coming soon |
| DeepSeek V2 Lite | ✅ | ✅ (v2-lite) | Coming soon |
| DeepSeek V2 | ✅ | ✅ (v2) | Coming soon |
| DeepSeek V3 | ✅ | ✅ (v3) | Coming soon |
| Moonlight | ✅ | Coming soon | Coming soon |

Launching Recipes

All recipes are ready to train out of the box and use mock data by default. For an example of how to override the default configuration through YAML or Hydra-style CLI overrides, see this script. The script can be launched with torchrun, for example:

torchrun --nproc-per-node=2 pretrain_llama3_8b.py model.tensor_model_parallel_size=1 <additional overrides ...>
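Each `key.subkey=value` argument in the command above is a Hydra-style override of a nested config field. The example script uses its own override handling. The function below is only a simplified, hypothetical sketch of the idea applied to plain nested dictionaries. It parses values as Python literals when possible and otherwise keeps them as strings.

```python
import ast


def apply_overrides(config, overrides):
    """Apply 'a.b.c=value' overrides to a nested dict in place and return it."""
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"override must look like key=value: {item!r}")
        try:
            value = ast.literal_eval(raw)  # e.g. '1' -> 1, '1e-4' -> 0.0001, 'True' -> True
        except (ValueError, SyntaxError):
            value = raw  # fall back to a plain string, e.g. a file path
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            if part not in node:
                raise KeyError(f"unknown config section {part!r} in {key!r}")
            node = node[part]
        if leaf not in node:
            raise KeyError(f"unknown config field {key!r}")
        node[leaf] = value
    return config


cfg = {"model": {"tensor_model_parallel_size": 2}, "train": {"train_iters": 1000}}
apply_overrides(cfg, ["model.tensor_model_parallel_size=1", "train.train_iters=10"])
print(cfg)  # {'model': {'tensor_model_parallel_size': 1}, 'train': {'train_iters': 10}}
```

Rejecting unknown keys instead of creating them catches typos such as `model.tensor_parallel_size=1`, which would otherwise be silently ignored.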

Optionally, Megatron Bridge also supports launching with NeMo-Run. See the following examples for reference on launching with NeMo-Run:

These examples can also be run as-is with the Llama 3 8B recipe (with NeMo-Run installed).

Launch Llama 3 8B pretraining with NeMo-Run's run.Script:

uv run python pretrain_llama3_8b_nemo_run_script.py \
    --nproc-per-node=2 \
    model.pipeline_model_parallel_size=1 \
    train.train_iters=10 # this script passes Hydra-style overrides to the target script

Launch Llama 3 8B pretraining with NeMo-Run's run.Partial:

uv run python pretrain_llama3_8b_nemo_run_partial.py \
    --nproc-per-node=2

Performance Benchmarks

Coming soon ...

Project Structure

Megatron-Bridge/
├── examples/
│   ├── models/                  # Bridge usage examples
│   └── recipes/                 # Training examples
├── src/megatron/bridge/
│   ├── data/                    # Dataloaders and iterators
│   ├── models/                  # Hugging Face bridge infrastructure and model-specific implementations
│   │   ├── llama/               # Llama model providers
│   │   └── .../                 # Other models (gpt, t5, etc.)
│   ├── peft/                    # PEFT transformations and wrappers
│   ├── recipes/                 # Complete training recipes
│   ├── training/                # Training loop components
│   │   ├── tokenizers/          # Tokenizer library
│   │   └── utils/               # Training-specific utilities
│   └── utils/                   # Generic utilities for repo-wide usage
└── tests/                       # Comprehensive test suite

Contributing

We welcome community contributions! Please see our Contributor Guidelines for more information on how to get involved.

About

Training library for Megatron-based models
