MegatronApp: Toolchain built around Megatron-LM for Distributed Training
An extension for performance tuning, slow-node detection, and training-process visualization.
- [2025.10.17] 🔥🔥🔥 We provide user-friendly docker guidance for all four features of MegatronApp. Please try it out!
- [2025.07.27] 📢📢📢 The MegatronApp technical report has been released! See here.
- [2025.07.04] 🔥🔥🔥 MegatronApp is officially launched at WAIC 2025! Our code is available here. Come and try it out!
MegatronApp is a toolchain built around the Megatron-LM training framework, designed to give practitioners a suite of value-added capabilities such as performance tuning, slow-node detection, and training-process visualization.
The project currently offers four core modules:
- MegaScan is a low-overhead tracing and anomaly detection system designed on Megatron-LM for large-scale distributed training. Detecting and locating hardware performance anomalies, such as GPU downclocking, is extremely challenging in large distributed environments. A single slow GPU can cause a cascading delay, degrading the performance of the entire cluster and making it difficult to pinpoint the source. This module aims to solve this problem by capturing and analyzing runtime trace data. By providing a global, high-precision view of all operations across all GPUs, MegaScan can identify specific patterns caused by hardware anomalies, allowing for accurate detection and root cause localization.
- MegaFBD (Forward-Backward Decoupling): Automatically splits the forward and backward phases onto different devices to resolve imbalances in compute, communication, and memory usage between the two stages, optimizing resource allocation and boosting overall utilization.
- MegaDPP (Dynamic Pipeline Planning): Dynamically optimizes pipeline-parallel scheduling during training, allowing each device to adjust its schedule in real time according to its progress and to defer selected compute or transfer steps to alleviate network pressure.
- MegaScope: Dynamically captures, processes, and caches intermediate results during training according to user-defined metrics, then displays them through an interactive visualization interface. MegaScope aims to make the "black box" of Large Language Models transparent. With this tool, users can observe and analyze what happens inside a model as it processes text, such as how attention scores and output probabilities are distributed and how vector representations vary across tokens and prompts.
The four modules are fully isolated from one another and are integrated into the Megatron-LM codebase as plugins; users can flexibly enable or disable any of them at launch via control flags.
The technical report of MegatronApp can be seen here.
MegaScan key features:
🚀 Low-Overhead Tracing: Uses CUDA Events for high-precision, asynchronous timing of operations with minimal impact on training performance (roughly 10% overhead in our tests); a minimal timing sketch follows this list.
🛠️ Automated Data Pipeline: Automatically aggregates trace files from all distributed ranks, reconstructs communication dependencies, and aligns the scattered timelines into a single, globally consistent view.
🧠 Heuristic Detection Algorithm: A multi-stage heuristic algorithm detects and locates faults such as GPU downclocking by comparing peer operations across parallel dimensions and analyzing communication behavior.
🖥️ Rich Visualization: Generates trace files in the Chrome Tracing Format, allowing intuitive, interactive visualization and analysis of complex distributed training runs with standard tools such as chrome://tracing and the Perfetto UI.
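To illustrate the idea behind CUDA-event-based tracing, here is a minimal sketch, not MegaScan's implementation: it times a GPU operation asynchronously with PyTorch's torch.cuda.Event so the host never has to synchronize per operation; the `records` list and `trace_op` helper are illustrative names.

```python
import torch

records = []  # (name, start_event, end_event), resolved later at a sync point

def trace_op(name, fn, *args, **kwargs):
    """Run fn and record its GPU time with CUDA events (illustrative sketch)."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()              # enqueued on the current CUDA stream
    out = fn(*args, **kwargs)   # the traced operation
    end.record()
    records.append((name, start, end))
    return out

# Run a traced operation (placeholder workload).
x = torch.randn(4096, 4096, device="cuda")
y = trace_op("matmul", torch.matmul, x, x)

# Resolve timings once, at an existing synchronization point,
# so tracing adds little overhead to the training loop itself.
torch.cuda.synchronize()
for name, start, end in records:
    print(f"{name}: {start.elapsed_time(end):.3f} ms")
```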
MegaScope key features:
🚀 Real-time generation and visualization: Input any prompt and watch the model generate text token by token, with its internal states displayed in sync.
🛠️ Intermediate result visualization:
- Display key intermediate variables like QKV vectors and MLP layer outputs as heatmaps.
- Attention matrix analysis: Freely select any layer and attention head to view its dynamic attention weight distribution.
- Output probability visualization: At each next-token generation step, show the sampled token and its probability, along with the other top-k candidates, revealing the model's decisions.
🧠 Interactive analysis:
- A rich set of interactive controls allows users to easily switch between different visualization dimensions.
- PCA dimensionality reduction: Project high-dimensional vector representations onto a 2D space to analyze the similarities and differences between tokens and prompts (a minimal sketch follows this list).
🖥️ Model perturbation injection: To facilitate in-depth research on model robustness, we provide several model perturbation features.
- Storage perturbation: Inject noise into critical model parameters to simulate errors in storage devices.
- Calculation perturbation: Inject noise during the model's forward pass (e.g., at the output of an MLP layer).
- System perturbation: Simulate a constant error between transformer layers.
Through the UI, users can precisely control the location, activation, type, and extent of the perturbations.
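As a rough illustration of the PCA view mentioned above, the sketch below projects a batch of hidden-state vectors to 2D with scikit-learn; the `hidden_states` tensor and token labels are made-up placeholders, not MegaScope's actual data path.

```python
import torch
from sklearn.decomposition import PCA

# Placeholder: one hidden-state vector per token, shape [num_tokens, hidden_size].
hidden_states = torch.randn(12, 4096)
tokens = [f"tok{i}" for i in range(hidden_states.shape[0])]

# Project the high-dimensional representations onto 2 principal components.
coords = PCA(n_components=2).fit_transform(hidden_states.numpy())

for tok, (x, y) in zip(tokens, coords):
    print(f"{tok}: ({x:+.3f}, {y:+.3f})")  # these 2D points are what the UI scatter-plots
```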
MegaDPP key features:
🚀 A dynamic pipeline-parallel scheduling algorithm: It selects the next microbatch to compute via a customized greedy rule based on user requirements (a simplified sketch follows below):
- Depth-first computation: give priority to computing the same data on different model chunks, for lower GPU memory usage
- Breadth-first computation: give priority to computing different data on the same model chunk, for lower communication contention
🛠️ An efficient shared-memory-based communication library:
- Concurrent asynchronous send/recv operations
- Dynamically track the completion status of operations
For more details, see README_Megatron.md
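To make the depth-first/breadth-first trade-off concrete, here is a toy greedy selector; it is not MegaDPP's actual scheduler. Depth-first mode prefers advancing the same data to the deepest available model chunk, while breadth-first mode prefers starting new data on the current chunk. The `Task` structure and scoring rule are illustrative assumptions.

```python
from dataclasses import dataclass

@dataclass
class Task:
    microbatch: int   # which piece of data
    chunk: int        # which model chunk on this device

def pick_next(ready, mode="depth_first"):
    """Greedy choice among ready tasks (toy illustration of the two policies)."""
    if mode == "depth_first":
        # Push the same data as deep as possible first -> fewer live activations.
        key = lambda t: (t.microbatch, -t.chunk)
    else:  # "breadth_first"
        # Finish a chunk across many microbatches first -> less transfer contention.
        key = lambda t: (t.chunk, t.microbatch)
    return min(ready, key=key)

ready = [Task(0, 1), Task(1, 0), Task(2, 0)]
print(pick_next(ready, "depth_first"))    # prefers microbatch 0 on the deeper chunk
print(pick_next(ready, "breadth_first"))  # prefers chunk-0 work across microbatches
```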
MegaFBD key features:
🚀 Instance-Level Decoupled Scheduling: The forward and backward phases are split into two logical processes, each assigned a different rank and bound to separate resources to reduce coupling.
🛠️ Heterogeneous Resource Mapping Optimization: The forward phase can be deployed on lightly loaded devices or CPUs, alleviating GPU pressure.
🧠 Differentiated Parallelism Configuration: Considering factors like activation reuse and communication volume, the forward phase is assigned a lower degree of parallelism to reduce communication overhead.
🖥️ Thread-Level Coordination Mechanism: A communication coordinator ensures the necessary data synchronization between the forward and backward phases, avoiding deadlocks and redundant communication.
MegatronApp uses a decoupled frontend-backend architecture with WebSockets to enable low-latency, real-time data communication between the model backend and the visualization frontend.
- Frontend: Based on Vite+Vue+TypeScript, rendering all interactive charts and controls.
- Backend: Based on Megatron and responsible for hosting the LLM. It uses flags to control the extraction of intermediate results during a forward pass, keeping time overhead low when the visualization features are not enabled.
- Communication: The frontend and backend are connected via a WebSocket.
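For intuition, the sketch below shows how a client could consume such a WebSocket stream in Python; the port and JSON message fields (`token`, `probability`) are invented for illustration and do not reflect MegaScope's actual protocol (the real frontend is the Vue app described above).

```python
import asyncio
import json
import websockets  # pip install websockets

async def watch_generation(uri="ws://localhost:5000"):
    # Connect to a (hypothetical) generation stream and print token updates.
    async with websockets.connect(uri) as ws:
        async for message in ws:
            update = json.loads(message)
            # Assumed message shape: {"token": "...", "probability": 0.42}
            print(f"{update['token']!r} (p={update['probability']:.2f})")

if __name__ == "__main__":
    asyncio.run(watch_generation())
```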
We strongly recommend installing inside a PyTorch NGC Container release. This container comes with all dependencies pre-installed, with compatible versions and optimized configurations for NVIDIA GPUs.
# Run container with mounted directories
docker run --runtime=nvidia --gpus all -it --rm \
-v /path/to/megatron:/workspace/megatron \
-v /path/to/dataset:/workspace/dataset \
-v /path/to/checkpoints:/workspace/checkpoints \
nvcr.io/nvidia/pytorch:25.04-py3
To install additional required packages, run
pip install -r requirements.txt
We provide a basic repro for you to quickly get started with MegaScan.
- Data preparation:
Please refer to the "Dataset Preparation" section of README_Megatron.md and NVIDIA's Megatron-LM for more details.
- Run NVIDIA's Megatron-LM training with MegaScan enabled by adding the following command-line arguments:
--trace
--trace-dir trace_output
--trace-interval 5 # optional, default is 5 iterations
--continuous-trace-iterations 2 # optional, default is 2 iterations
--trace-granularity full # optional, default is full
--transformer-impl local # currently only the local transformer implementation is supported
examples/gpt3/train_gpt3_345m_distributed.sh is an example script. You can modify it to suit your needs.
If you want to train on multiple nodes, change the GPU_PER_NODE, NUM_NODES, MASTER_ADDR, MASTER_PORT, NODE_RANK, WORLD_SIZE in the script accordingly.
Alternatively, you can use elastic training. See torchrun for more details.
- After training, you will find separate trace files in the current directory. The trace files are named
benchmark-data-{}-pipeline-{}-tensor-{}.json, where {} is the rank number. Now aggregate the trace files into a single trace file:
python scripts/aggregate.py -b trace_output --output benchmark.json
- You can visualize the trace file using Chrome Tracing (or Perfetto UI). Open the trace file in Chrome Tracing by navigating to chrome://tracing in your browser (or https://ui.perfetto.dev/). Now you can explore the trace data, zoom in on specific events, and analyze the performance characteristics of your distributed training run.
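If you prefer to inspect the aggregated file programmatically, a small script like the one below can list the longest operations; it assumes the file follows the standard Chrome Tracing layout (a bare event list or a dict with a "traceEvents" list, where complete events carry "name" and "dur" in microseconds), which is what chrome://tracing and Perfetto consume.

```python
import json

with open("benchmark.json") as f:
    data = json.load(f)

# Chrome Tracing files are either a bare event list or {"traceEvents": [...]}.
events = data["traceEvents"] if isinstance(data, dict) else data

# Keep complete ("X") events that have a duration, and show the slowest ones.
timed = [e for e in events if e.get("ph") == "X" and "dur" in e]
for e in sorted(timed, key=lambda e: e["dur"], reverse=True)[:10]:
    print(f'{e["dur"] / 1000:9.3f} ms  {e.get("name", "?")}  (pid={e.get("pid")}, tid={e.get("tid")})')
```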
- To illustrate the detection algorithm, we can manually inject a fault into the training process. We provide a script scripts/gpu_control.sh to simulate GPU downclocking.
- Run the script to inject a fault into the training process:
# Inject a fault into GPU 0, downclocking it to 900 MHz
bash scripts/gpu_control.sh limit 0 900
- Run the training script. Then aggregate the trace files as described above, with an additional command-line argument to enable the detection algorithm:
# -b: equivalent to --bench-dir
# -d: enable the detection algorithm, equivalent to --detect
python scripts/aggregate.py \
  -b . \
  -d
We can see output indicating that GPU 0 may be abnormal. (A toy sketch of this kind of peer comparison follows.)
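The actual MegaScan detection algorithm is multi-stage (see the technical report); the toy sketch below only conveys the basic idea of comparing the same operation across peer ranks and flagging outliers. The per-rank duration dictionary and the tolerance threshold are made up for illustration.

```python
from statistics import median

def flag_slow_ranks(durations_ms, tolerance=1.3):
    """durations_ms: {rank: duration} of the *same* operation on peer ranks.

    A rank is flagged when it runs noticeably slower than the peer median
    (toy rule; the real algorithm also analyzes communication behavior)."""
    baseline = median(durations_ms.values())
    return [rank for rank, d in durations_ms.items() if d > tolerance * baseline]

# Example: the same attention forward on four peers; rank 0 is downclocked.
print(flag_slow_ranks({0: 31.2, 1: 18.4, 2: 18.9, 3: 18.1}))  # -> [0]
```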
To use MegaScope, first start the backend and frontend servers.
Backend (Megatron): For inference mode, run the text generation server script, pointing it to your model and tokenizer paths, and make sure to turn on the --enable-ws-server switch in the arguments.
bash examples/inference/a_text_generation_server_bash_script.sh /path/to/model /path/to/tokenizer
For example:
bash examples/inference/llama_mistral/run_text_generation_llama3.sh /gfshome/llama3-ckpts/Meta-Llama-3-8B-Instruct-megatron-core-v0.12.0-TP1PP1 /root/llama3-ckpts/Meta-Llama-3-8B-Instruct
For training mode, run the training script and add --training-ws-port XXX (e.g. --training-ws-port 5000) to the arguments. A typical command is
bash a_pretrain_script.sh $RANK
For example:
bash pretrain_gpt.sh 0
Frontend (Vue): Navigate to the frontend directory and start the development server.
cd transformer-visualize
npm run dev
After launching both, open your browser to the specified address (usually http://localhost:5173). You will see the main interface.
In the input prompts area, enter one or more prompts. Each text box represents a separate batch, allowing for parallel processing and comparison.
In the control panel, set the desired number of tokens to generate. You can also enable or disable the real-time display of specific internal states, such as QKV vectors and MLP outputs; this helps manage performance and keeps the view focused on relevant data. Vector filter expressions can be customized in the input box below.
After starting generation, the visualization results will update token-by-token. In the first tab, the intermediate vector heatmaps are displayed and the output probabilities are shown in the expandable sections.
The second tab contains attention matrices. Use the dropdown menus to select the layer and attention head you wish to inspect.
The third tab provides the PCA dimensionality-reduction view, where you can visually inspect the clustering of tokens and see how the model groups similar concepts. The displayed layer can also be selected.
The expandable perturbation control panel introduces controlled noise into the model's forward pass. Each kind of perturbation has an independent switch controlling its noise type and intensity.
The currently supported noise types include:
- Additive Gaussian Noise (noise1): output = input + N(0, coef²), where N(0, coef²) is a random value drawn from a Gaussian (normal) distribution with mean 0 and standard deviation coef.
- Multiplicative Uniform Noise (noise2): output = input * U(1 - val, 1 + val), where U(1 - val, 1 + val) is a random value drawn from a uniform distribution over [1 - val, 1 + val].
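As a minimal illustration of these two formulas (not MegaScope's code), the snippet below applies them to a tensor; `coef` and `val` correspond to the intensity values set in the UI.

```python
import torch

def additive_gaussian(x, coef):
    # output = input + N(0, coef^2)
    return x + coef * torch.randn_like(x)

def multiplicative_uniform(x, val):
    # output = input * U(1 - val, 1 + val)
    return x * ((2 * val) * torch.rand_like(x) + (1 - val))

x = torch.ones(3)
print(additive_gaussian(x, coef=0.1))
print(multiplicative_uniform(x, val=0.05))
```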
Similar visualization support is provided during training. The overall controls are the same, and the training process itself can be controlled from the frontend page. Critical intermediate results and perturbations are also supported during training.
- The following is the pod configuration:
ContainerImage: ngc.nju.edu.cn/nvidia/pytorch:25.03-py3
GPU: RTX4090
NVMEStorage: 50G
Limits:
  CPU: 28
  memory: 100Gi
  GPU: 4
UseShm: true
ShmSize: 16Gi
UseIB: true
- The Python environment in the image automatically includes almost all of the required packages. To install additional required packages, run
pip install -r requirements.txt
- Install the InfiniBand prerequisites:
bash prerequisite.sh
- Build the shm_tensor_new_rdma (for multinode) and shm_tensor_new_rdma_pre_alloc modules:
cd megatron/shm_tensor_new_rdma
pip install -e .
cd megatron/shm_tensor_new_rdma_pre_alloc
pip install -e .
The dataset preparation step follows largely from the Megatron framework.
First, prepare your dataset in the following .json format, with one sample per line:
{"src": "bloomberg", "text": "BRIEF-Coach Inc launches tender offer to acquire Kate Spade & Co for $18.50 per share in cash. May 26 (Reuters) - Coach Inc: * Coach Inc launches tender offer to acquire Kate Spade & Company for $18.50 per share in cash * Coach Inc launches tender offer to acquire kate spade & company for $18.50 per share in cash * Coach Inc - tender offer will expire at 11:59 P.M. Edt on June 23, 2017, unless extended * Coach Inc - Chelsea Merger Sub Inc, has commenced a tender offer for all of outstanding shares of common stock, par value $1.00 per share, of Kate Spade & Company Source text for Eikon: Further company coverage: May 26 (Reuters) - Coach Inc: * Coach Inc launches tender offer to acquire Kate Spade & Company for $18.50 per share in cash * Coach Inc launches tender offer to acquire kate spade & company for $18.50 per share in cash * Coach Inc - tender offer will expire at 11:59 P.M. Edt on June 23, 2017, unless extended * Coach Inc - Chelsea Merger Sub Inc, has commenced a tender offer for all of outstanding shares of common stock, par value $1.00 per share, of Kate Spade & Company Source text for Eikon: Further company coverage:", "type": "Eng", "id": "0", "title": "BRIEF-Coach Inc launches tender offer to acquire Kate Spade & Co for $18.50 per share in cash. "}
{"src": "bloomberg", "text": "Var Energi agrees to buy Exxonmobil's Norway assets for $4.5 bln. MILAN, Sept 26 (Reuters) - Var Energi AS, the Norwegian oil and gas group 69.6% owned by Italian major Eni, has agreed to buy the Norwegian upstream assets of ExxonMobil for $4.5 billion. The deal is expected to be completed in the final quarter of this year, Var Energi said on Thursday. Reporting by Stephen Jewkes; editing by Francesca Landini MILAN, Sept 26 (Reuters) - Var Energi AS, the Norwegian oil and gas group 69.6% owned by Italian major Eni, has agreed to buy the Norwegian upstream assets of ExxonMobil for $4.5 billion. The deal is expected to be completed in the final quarter of this year, Var Energi said on Thursday. Reporting by Stephen Jewkes; editing by Francesca Landini", "type": "Eng", "id": "1", "title": "Var Energi agrees to buy Exxonmobil's Norway assets for $4.5 bln. "}
{"src": "bloomberg", "text": "Trump says 'incorrect' he is willing to meet Iran with 'no conditions'. WASHINGTON (Reuters) - U.S. President Donald Trump on Sunday appeared to play down the chances that he might be willing to meet with Iranian officials, saying reports that he would do so without conditions were not accurate. \u201cThe Fake News is saying that I am willing to meet with Iran, \u2018No Conditions.\u2019 That is an incorrect statement (as usual!),\u201d Trump said on Twitter. In fact, as recently as on Sept. 10, U.S. Secretary of State Mike Pompeo said \u201cHe (Trump) is prepared to meet with no preconditions.\u201d Reporting By Arshad Mohammed; Editing by Shri Navaratnam WASHINGTON (Reuters) - U.S. President Donald Trump on Sunday appeared to play down the chances that he might be willing to meet with Iranian officials, saying reports that he would do so without conditions were not accurate. \u201cThe Fake News is saying that I am willing to meet with Iran, \u2018No Conditions.\u2019 That is an incorrect statement (as usual!),\u201d Trump said on Twitter. In fact, as recently as on Sept. 10, U.S. Secretary of State Mike Pompeo said \u201cHe (Trump) is prepared to meet with no preconditions.\u201d Reporting By Arshad Mohammed; Editing by Shri Navaratnam", "type": "Eng", "id": "2", "title": "Trump says 'incorrect' he is willing to meet Iran with 'no conditions'. "}note that we have provided a sample dataset under datasets_gpt/ and datasets_bert/.
Then, prepare the vocab file (GPT and BERT) and the merges file (GPT only). We have provided these in the respective directories.
For BERT, run the following:
cd datasets
python ../tools/preprocess_data.py \
--input ../datasets_bert/dataset.json \
--output-prefix bert \
--vocab-file ../datasets_bert/vocab.txt \
--tokenizer-type BertWordPieceLowerCase \
--split-sentences \
--workers $(nproc)
where the paths can be changed according to the location of your files and where you want the generated files to be placed.
For GPT, run the following:
cd datasets
python ../tools/preprocess_data.py \
--input ../datasets_gpt/dataset.json \
--output-prefix gpt \
--vocab-file ../datasets_gpt/vocab.json \
--tokenizer-type GPT2BPETokenizer \
--merge-file ../datasets_gpt/merges.txt \
--append-eod \
--workers $(nproc)
For other models, please refer to nvidia/megatron for the corresponding datasets.
To run distributed training on a single node, go to the project root directory and run
bash run_single_gpt.sh
for GPT and
bash run_single_bert.sh
for BERT.
The run_single_<model>.sh files have the following structure:
- Parameters include pipeline_parallel, model_chunks and tensor_parallel.
- The virtual_stage_layer parameter specifies how many layers there are in a single virtual pipeline stage. It is calculated as $$ \frac{\text{total layers of the model}}{\text{pipeline parallel}\times\text{model chunks}} $$ where the total number of layers is set under examples/ for the corresponding model.
- It gets the IP address of the pod and writes it into the shell script.
- Finally, it runs the shell script for the corresponding model under examples/.
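As a worked example of this formula (the numbers are hypothetical): a 24-layer model with pipeline_parallel = 4 and model_chunks = 2 gives virtual_stage_layer = 24 / (4 × 2) = 3, i.e. three layers per virtual pipeline stage.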
There are also several critical parameters in examples/gpt3/train_gpt3_175b_distributed.sh (the BERT model is under the corresponding bert/ directory):
- --use-dpp switches to the DPP algorithm
- --workload specifies the workload of each single thread, and hence determines the number of threads used in P2P communication
- --num-gpus specifies the number of GPUs on the current node (single-node training)
- Other critical parameters include the number of layers of the model, the global batch size and the sequence length
- Note that the global batch size is currently fixed at 16 in run_single_<model>.sh; if you adjust the number of layers, run_single_<model>.sh needs to be modified at the same time.
For the remaining models, you can either directly run
bash examples/<model>/<train_file>.sh
or write a file similar to run_{single,master,worker}_<model>.sh that sets up the configuration and runs the shell script under examples/.
To run distributed training on multiple nodes, go to the root directory. First run
bash run_master_<model>.sh
and then start another pod and run
bash run_worker_<model>.sh
run_master_<model>.sh does the following:
- Similar to run_single_<model>.sh, it takes pipeline_parallel, model_chunks and tensor_parallel.
- It writes the master pod IP to examples/gpt3/train_gpt3_175b_distributed_master.sh and to train_gpt3_175b_distributed_worker.sh (BERT in the corresponding directory).
- It sets the number of nodes to 2, with the master node having rank 0.
- It starts the shell script under examples/.
run_worker_<model>.sh does the following:
- It sets the number of nodes to 2, with the worker node having rank 1.
- It starts the shell script under examples/.
examples/gpt3/train_gpt3_175b_distributed_master.sh and examples/gpt3/train_gpt3_175b_distributed_worker.sh are similar to the single-node version, except that --node-ips is mandatory (the InfiniBand IPs of the pods in the order of their GPU ranks) and the --multi-node flag should be turned on.
Each run will generate a trace directory under benchmark. Go to the profiling directory and run
python aggregate.py --benchmark_dir benchmark/your-benchmark-dir
in the root dir to produce an aggregated trace file.
- Install the InfiniBand prerequisites.
- Build the RDMA C++ extension modules shm_tensor_new_rdma (for multinode) and shm_tensor_new_rdma_pre_alloc. Just follow the installation instructions above.
Then run
bash pretrain_gpt.sh $RANK
Here pretrain_gpt.sh is an example pretraining Bash script.
There are two extra options in TRAINING_ARGS: --forward-backward-disaggregating and --ignore-forward-tensor-parallel.
- --forward-backward-disaggregating: Splits each rank into two, one for the forward pass and one for the backward pass. After doing this, your DP size will be halved, so make sure your DP size is even before adding this option.
- --ignore-forward-tensor-parallel: Enables merging forward ranks within the same TP group. After doing this, your number of ranks will be multiplied by $\frac{TP+1}{2TP}$. Be sure you are using the correct number of ranks.
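As a worked example of this factor (with hypothetical numbers): with TP = 4, the rank count is multiplied by (4 + 1) / (2 × 4) = 5/8, so 16 ranks become 16 × 5/8 = 10 ranks once this option is enabled.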
Currently, Context Parallel and Expert Parallel are not supported, and --transformer-impl should be local.
If you find a security issue with our project, report the vulnerability privately to OpenSQZ. It is critical to avoid public disclosure.
An overview of the vulnerability handling process is:
- The reporter reports the vulnerability privately to OpenSQZ.
- The appropriate project's security team works privately with the reporter to resolve the vulnerability.
- The project creates a new release of the package the vulnerability affects to deliver its fix.
- The project publicly announces the vulnerability and describes how to apply the fix.
Contributions and collaborations are welcome and highly appreciated. Check out the Contributor Guide and get involved.
This project is licensed under the Apache 2.0 License, see the LICENSE file for details.
Use WeChat to scan the QR code below.