
NVIDIA-NeMo/Megatron-Bridge



Overview

NeMo Megatron Bridge is a PyTorch-native library within the NeMo Framework that provides pretraining, SFT, and LoRA for popular LLM and VLM models. It serves as a bridge, conversion, and verification layer between 🤗 Hugging Face and Megatron Core, providing bidirectional checkpoint conversion between the two formats. This lets other projects leverage Megatron Core's parallelism capabilities or export models for various inference engines. Built-in verification mechanisms ensure conversion accuracy and checkpoint integrity across model formats.

On top of the bridge, NeMo Megatron Bridge provides a performant and scalable PyTorch-native training loop that leverages Megatron Core to deliver state-of-the-art training throughput. It supports pretraining and fine-tuning with features such as tensor and pipeline parallelism and mixed precision (FP8, BF16, FP4, etc.). Users can either start from existing 🤗 Hugging Face models or write custom PyTorch model definitions for flexible end-to-end workflows.

NeMo Megatron Bridge is a refactor of the previous NeMo training stack that adopts a PyTorch-native training loop to provide greater flexibility and customizability for developers.


🔧 Installation

🐳 NeMo Framework container

The best experience, highest performance, and full feature support are provided by the NeMo Framework container. Set TAG to the most recent container tag and run the following to start a container:

docker run --rm -it -w /workdir -v $(pwd):/workdir \
  --entrypoint bash \
  --gpus all \
  nvcr.io/nvidia/nemo:${TAG}

For development installation and additional details, please refer to our Contribution guide.

⚡ Quickstart

To get started, install Megatron Bridge or download a NeMo Framework container as described above.

Log in to Hugging Face Hub:

huggingface-cli login --token <your token>

Conversion-only quickstart (✅ Core):

from megatron.bridge import AutoBridge

# 1) Create a bridge from a Hugging Face model (hub or local path)
bridge = AutoBridge.from_hf_pretrained("meta-llama/Llama-3.2-1B", trust_remote_code=True)

# 2) Get a Megatron provider and configure parallelism before instantiation
provider = bridge.to_megatron_provider()
provider.tensor_model_parallel_size = 1
provider.pipeline_model_parallel_size = 1
provider.finalize()

# 3) Materialize Megatron Core model(s)
model = provider.provide_distributed_model(wrap_with_ddp=False)

# 4a) Export Megatron → Hugging Face (full HF folder with config/tokenizer/weights)
bridge.save_hf_pretrained(model, "./hf_exports/llama32_1b")

# 4b) Or stream only weights (Megatron → HF)
for name, weight in bridge.export_hf_weights(model, cpu=True):
    print(name, tuple(weight.shape))
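Because `export_hf_weights` yields one `(name, weight)` pair at a time, a consumer can process the stream without first collecting every tensor. The sketch below illustrates that pattern under this assumption: `summarize_stream` is a hypothetical helper (not part of Megatron Bridge) that keeps only running totals and accepts any weight object exposing a `.shape` tuple, such as a `torch.Tensor`.

```python
from math import prod
from typing import Iterable, Protocol, Tuple


class HasShape(Protocol):
    """Anything exposing a .shape tuple, such as a torch.Tensor (torch.Size is a tuple)."""
    shape: Tuple[int, ...]


def summarize_stream(pairs: Iterable[Tuple[str, HasShape]]) -> Tuple[int, int]:
    """Hypothetical helper: consume (name, weight) pairs one at a time and
    return (num_tensors, num_params).

    The loop keeps only two integer counters, never a list of tensors, so it
    adds no memory beyond whatever the producer holds for the current weight.
    """
    num_tensors = 0
    num_params = 0
    for _name, weight in pairs:
        num_tensors += 1
        num_params += prod(weight.shape)  # elements in this tensor
    return num_tensors, num_params
```

With the bridge and model from the quickstart above, it could be called as `summarize_stream(bridge.export_hf_weights(model, cpu=True))`.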

Training quickstart using pre-configured recipes:

from megatron.bridge.recipes.llama import llama32_1b_pretrain_config
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.pretrain import pretrain

if __name__ == "__main__":
    # The recipe uses the Llama 3.2 1B model configuration from Hugging Face
    cfg = llama32_1b_pretrain_config(seq_length=1024)

    # Override training parameters
    cfg.train.train_iters = 10
    cfg.scheduler.lr_decay_iters = 10000
    cfg.model.vocab_size = 8192
    cfg.tokenizer.vocab_size = cfg.model.vocab_size

    pretrain(cfg, forward_step)

You can launch the above script with:

torchrun --nproc-per-node=<num devices> /path/to/script.py
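`torchrun` starts one process per device and exports `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` to each of them. The recipe's `pretrain` entry point is assumed to handle distributed setup on its own; the hypothetical helpers below only sketch how your own script could read those variables, for example to print or save something once on rank 0. Without `torchrun`, they fall back to a single process.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class LaunchInfo:
    rank: int        # global rank across all nodes
    local_rank: int  # rank within this node (typically the GPU index)
    world_size: int  # total number of processes


def launch_info(env: Optional[Mapping[str, str]] = None) -> LaunchInfo:
    """Hypothetical helper: read the rank variables torchrun exports.

    Defaults describe a single-process run when the variables are absent.
    """
    env = os.environ if env is None else env
    return LaunchInfo(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )


def is_main_process(env: Optional[Mapping[str, str]] = None) -> bool:
    """True only on global rank 0, a common guard for logging or saving once."""
    return launch_info(env).rank == 0
```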

More examples:

For a deeper dive into conversion design and advanced usage, see the models README.

🚀 Key Features

  • Bridge with 🤗 Hugging Face: Seamless bidirectional conversion between 🤗 Hugging Face and Megatron formats for interoperability (model bridges, auto bridge, conversion examples)
    • Online import/export without intermediate full checkpoints
    • Parallelism-aware (TP/PP/VPP/CP/EP/ETP) during conversion
    • Memory-efficient per-parameter streaming
    • Simple high-level AutoBridge API with architecture auto-detection
    • Optimized paths when Transformer Engine is available
  • Flexible to Customize: A lightweight custom training loop that makes it easy to add custom logic for data loading, distributed training, checkpointing, evaluation, and logging (training framework, training utilities)
  • Supervised & Parameter-Efficient Finetuning: SFT & PEFT implementation tailored for Megatron-based models that supports LoRA, DoRA, and user-defined PEFT methods (PEFT implementations, finetune module, SFT dataset)
  • SOTA Training Recipes: Pre-configured production-ready training recipes for popular models like Llama 3, with optimized hyperparameters and distributed training configuration (Llama recipes, recipe examples)
  • Performance Optimization: Built-in support for FP8 training, model parallelism, and memory-efficient techniques to offer high utilization and near-linear scalability to thousands of nodes. (mixed precision, communication overlap, optimizer utilities)
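The parallelism sizes set on the provider (such as `tensor_model_parallel_size` and `pipeline_model_parallel_size`) must fit the number of GPUs launched. In Megatron Core, the data-parallel size is what remains after tensor, pipeline, and context parallelism are taken out of the world size. The sketch below illustrates only that arithmetic; `data_parallel_size` is a hypothetical helper, and virtual pipeline stages (VPP) and expert parallelism (EP/ETP) are deliberately left out.

```python
def data_parallel_size(world_size: int, tp: int = 1, pp: int = 1, cp: int = 1) -> int:
    """Hypothetical helper: data-parallel replicas left after tensor (tp),
    pipeline (pp), and context (cp) parallelism are carved out of world_size GPUs."""
    if min(world_size, tp, pp, cp) < 1:
        raise ValueError("all sizes must be positive")
    model_parallel = tp * pp * cp  # GPUs holding one full model replica
    if world_size % model_parallel:
        raise ValueError(
            f"world_size={world_size} is not divisible by tp*pp*cp={model_parallel}"
        )
    return world_size // model_parallel
```

For example, 64 GPUs with `tp=4`, `pp=2`, `cp=2` leave 64 / 16 = 4 data-parallel replicas.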

Supported Models

Megatron Bridge provides out-of-the-box bridges and training recipes for a wide range of models, built on top of base model architectures from Megatron Core. Refer to the models directory for the most up-to-date list of model bridges.

Supported Models Overview

For more details on supported models, see our documentation:

| Model | Checkpoint Conversion | Pretrain Recipes | SFT & LoRA Recipes |
|---|---|---|---|
| DeepSeek V2 | ✅ | ✅ (v2) | Coming soon |
| DeepSeek V2 Lite | ✅ | ✅ (v2-lite) | Coming soon |
| DeepSeek V3 | ✅ | ✅ (v3) | Coming soon |
| Gemma | ✅ | Coming soon | Coming soon |
| Gemma 2 | ✅ | Coming soon | Coming soon |
| Gemma 3 | ✅ | ✅ (1B) | ✅ (1B) |
| Gemma 3-VL | ✅ | Coming soon | ✅ (4B/12B/27B) |
| GLM-4.5 | ✅ | ✅ (106B-Air/355B) | ✅ (106B-Air/355B) |
| GPT-oss | ✅ | ✅ (20B/120B) | ✅ (20B/120B) |
| Llama 2 | ✅ | ✅ (7B) | Coming soon |
| Llama 3 | ✅ | ✅ (8B/70B) | ✅ (8B/70B) |
| Llama 3.1 | ✅ | ✅ (8B/70B/405B) | ✅ (8B/70B/405B) |
| Llama 3.2 | ✅ | ✅ (1B/3B) | ✅ (1B/3B) |
| Llama 3.3 | ✅ | Coming soon | Coming soon |
| Llama Nemotron | ✅ | Coming soon | Coming soon |
| Mistral | ✅ | Coming soon | Coming soon |
| Ministral | ✅ | ✅ (3B/8B/14B) | ✅ (3B/8B/14B) |
| Moonlight | ✅ | ✅ (16B) | ✅ (16B) |
| Nemotron | ✅ | Coming soon | Coming soon |
| Nemotron-3 | ✅ | ✅ (A3B) | ✅ (A3B) |
| Nemotron-H | ✅ | ✅ (4B/8B/47B/56B) | Coming soon |
| Nemotron Nano v2 | ✅ | ✅ (9B/12B) | Coming soon |
| Nemotron Nano v2 VL | ✅ | Coming soon | ✅ (9B/12B) |
| OlMoE | ✅ | ✅ (7B) | ✅ (7B) |
| Qwen2 | ✅ | ✅ (500M/1.5B/7B/72B) | ✅ (500M/1.5B/7B/72B) |
| Qwen2.5 | ✅ | ✅ (500M/1.5B/7B/14B/32B/72B) | ✅ (500M/1.5B/7B/14B/32B/72B) |
| Qwen2.5-VL | ✅ | Coming soon | ✅ (3B/7B/32B/72B) |
| Qwen3 | ✅ | ✅ (600M/1.7B/4B/8B/14B/32B) | ✅ (600M/1.7B/4B/8B/14B/32B) |
| Qwen3-MoE | ✅ | ✅ (A3B/A22B) | ✅ (A3B/A22B) |
| Qwen3 Next | ✅ | ✅ (80B-A3B) | ✅ (80B-A3B) |
| Qwen3-VL | ✅ | Coming soon | ✅ (8B/A3B-A30B-MoE) |

Launching Recipes

For a conceptual overview of how recipes are structured, overridden, and launched with either torchrun or NeMo-Run, read the Using Recipes guide.

Runnable tutorials in tutorials/recipes/llama cover:

  • 00_quickstart_pretrain.py for mock-data pretraining
  • 01_quickstart_finetune.py + LoRA configs
  • YAML-driven flows and launch helpers
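For orientation when reading the LoRA configs, here is the LoRA update as it is commonly formulated: a frozen weight `W` is augmented with a low-rank product `B A` of rank `r`, scaled by `alpha / r`. This plain-Python sketch shows that formula only; it is not Megatron Bridge's PEFT implementation, whose module structure and defaults may differ, and `matvec` and `lora_forward` are hypothetical names.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(w: Matrix, a: Matrix, b: Matrix, x: Vector, alpha: float) -> Vector:
    """Compute y = W x + (alpha / r) * B (A x).

    w: frozen base weight, shape (d_out, d_in)
    a: trainable down-projection, shape (r, d_in)
    b: trainable up-projection, shape (d_out, r)
    """
    r = len(a)                       # LoRA rank = number of rows in A
    base = matvec(w, x)              # frozen path
    update = matvec(b, matvec(a, x))  # low-rank path, never materializes B @ A
    scale = alpha / r
    return [y0 + scale * u for y0, u in zip(base, update)]
```

Because `B` is typically initialized to zeros, the adapted layer initially reproduces the base layer exactly, and only the small `A` and `B` matrices are trained.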

Performance Benchmarks

For detailed performance benchmarks including throughput metrics across different GPU systems (DGX-GB200, DGX-B200, DGX-H100) and model configurations, see the Performance Summary in our documentation.
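When comparing runs, a common normalization is tokens processed per second per GPU. The sketch below assumes a fixed global batch size and sequence length per step; the Performance Summary may define its metrics differently, and its definitions take precedence. The estimate of about 6 FLOPs per parameter per token is a widely used approximation for dense transformer training that ignores attention FLOPs; both helpers are hypothetical.

```python
def tokens_per_sec_per_gpu(global_batch_size: int, seq_length: int,
                           step_time_s: float, num_gpus: int) -> float:
    """Tokens processed per second per GPU, given the wall-clock time of one step."""
    if step_time_s <= 0 or num_gpus <= 0:
        raise ValueError("step_time_s and num_gpus must be positive")
    tokens_per_step = global_batch_size * seq_length
    return tokens_per_step / (step_time_s * num_gpus)


def approx_model_tflops_per_gpu(num_params: float, tok_per_sec_per_gpu: float) -> float:
    """Approximate achieved model TFLOP/s per GPU, assuming ~6 FLOPs per parameter
    per token (forward + backward) for a dense transformer; attention is ignored."""
    return 6.0 * num_params * tok_per_sec_per_gpu / 1e12
```

For example, a global batch of 128 sequences of 4096 tokens at 2.0 s per step on 64 GPUs gives 4096 tokens/s/GPU.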

Project Structure

Megatron-Bridge/
├── examples/
│   ├── models/                  # Bridge usage examples
│   └── recipes/                 # Training examples
├── src/megatron/bridge/
│   ├── data/                    # Dataloaders and iterators
│   ├── models/                  # Hugging Face bridge infrastructure and model-specific implementations
│   │   ├── llama/               # Llama model providers
│   │   └── .../                 # Other models (gpt, t5, etc.)
│   ├── peft/                    # PEFT transformations and wrappers
│   ├── recipes/                 # Complete training recipes
│   ├── training/                # Training loop components
│   │   ├── tokenizers/          # Tokenizer library
│   │   └── utils/               # Training-specific utilities
│   └── utils/                   # Generic utilities for repo-wide usage
└── tests/                       # Comprehensive test suite

Acknowledgement & Contributing

Megatron-Bridge is the continuation of MBridge by Yan Bai. We appreciate all the contributions and adoption by our community partners:

  • Mind Lab used Megatron-Bridge and veRL to train GRPO LoRA for a trillion-parameter model on 64 H800 GPUs; see their tech blog.
  • veRL has adopted MBridge as a connector to Megatron-Core.
  • slime has adopted MBridge as Megatron-Core checkpoint converter.
  • SkyRL has adopted MBridge as Megatron-Core connector and is migrating to Megatron-Bridge.
  • Nemo-RL has adopted Megatron-Bridge as Megatron-Core connector.
  • Community contributions: Special thanks to Guanyou He and Junyu Wu from Weixin Group Infrastructure Center.

Please see our Contributor Guidelines for more information on how to get involved.

About

HuggingFace conversion and training library for Megatron-based models
