Transform PyTorch Operations into Optimized GPU Kernels with LLMs
Documentation | Architecture | Quick Start | Results | Roadmap | Contributing
| Feature | Description |
|---|---|
| Two-Stage Training | SFT on high-quality datasets followed by RL optimization |
| Multi-Turn Refinement | Iterative kernel improvement through compilation feedback |
| Cross-Platform | Support for both NVIDIA CUDA and AMD ROCm GPUs |
| Performance Metrics | Comprehensive evaluation of correctness and speedup |
| 200+ Benchmarks | Extensive test suite across multiple difficulty levels |
- [2025/10/09] We gave a guest talk about TritonForge for the Li Lab @ CMU! Slides are available here if you're interested.
- [2025/09/29] We released the TritonForge tech blog in both English and Chinese! English version | Chinese version (中文版)
TritonForge is an advanced machine learning framework that trains Large Language Models (LLMs) to automatically convert PyTorch operations into optimized Triton GPU kernels. By combining supervised fine-tuning (SFT) with reinforcement learning (RL), TritonForge achieves state-of-the-art performance in automated kernel generation.
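To make the task concrete, here is a minimal hand-written sketch (illustrative only, not model output) of the transformation TritonForge learns: a PyTorch elementwise addition rewritten as an equivalent Triton kernel.

```python
# Illustrative sketch of the PyTorch -> Triton transformation TritonForge
# automates. Hand-written for this README; not generated by the model.
import torch
import triton
import triton.language as tl

# PyTorch reference: out = x + y

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                         # one program per block
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                         # guard the ragged tail
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = out.numel()
    grid = (triton.cdiv(n, 1024),)                      # enough blocks to cover n
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    return out
```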
Architecture Deep Dive: For a comprehensive understanding of our server-based SFT + RL framework, evaluation infrastructure, and cross-platform support, see our Architecture Documentation.
We believe in complete transparency and community collaboration. Everything is open-source:
- Training Data: Custom-curated datasets (GPUMODE/KernelBook)
- Model Checkpoints: All intermediate and final models (HuggingFace)
- Training Framework: Complete SLIME RL implementation (fixed version with improvements)
- Environment Setup: Docker images and configurations for both NVIDIA and AMD
- Training Recipes: Detailed scripts and hyperparameters for reproduction
We invite the community to join us in advancing automated kernel generation together!
Reinforcement Learning Framework Note: This is a fixed and improved version of the original SLIME framework. To be fully transparent: it is essentially SLIME with bug fixes and optimizations that enable multi-turn iterative kernel improvement through compilation feedback and performance metrics.

Comprehensive Benchmark Suite: Based on ScalingIntelligence/KernelBench, this suite evaluates GPU kernel generation quality and performance across 200+ problems of varying difficulty.
| Requirement | NVIDIA | AMD |
|---|---|---|
| Verified GPU | H100 | MI300X |
| Memory | 80GB | 192GB |
| Docker | Required | Required |
| Python | 3.10+ | 3.10+ |
| CUDA/ROCm | 12.6.1 | 6.3.4 |
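Once inside a container (see the setup guides below), a quick sanity check confirms the GPU is visible; this is a convenience sketch assuming a standard PyTorch build for your platform, not part of the official scripts:

```python
# Quick GPU visibility check; works on both CUDA and ROCm builds of PyTorch.
import torch

print(torch.cuda.is_available())                 # expect True
print(torch.cuda.get_device_name(0))             # e.g. an H100 or MI300X
print(torch.version.cuda or torch.version.hip)   # toolkit version in use
```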
Choose your platform and follow the setup guide:
NVIDIA Setup
docker pull zhuzilin/slime:20250706-v2
docker run --rm --gpus all --ipc=host --shm-size=128g \
--ulimit memlock=-1 --ulimit stack=67108864 \
-v $HOME:$HOME \
-it zhuzilin/slime:20250706-v2 /bin/bash
git clone https://github.com/RLsys-Foundation/TritonForge.git
cd TritonForge
cd KBenchEval
# Create virtual environment
python -m venv .venv
source .venv/bin/activate
# Install dependencies
pip install --upgrade pip
pip install -r requirements.txt
pip install -e .
deactivate
cd ../SLIME
pip install -e .
# Create models directory
mkdir -p models
# Hugging Face format of fine-tuned Qwen3-8B model (for evaluation)
huggingface-cli download JinnP/Qwen3-8B-Kernelbook-SFT-HF --local-dir models/Qwen3-8B-Kernelbook-SFT-HF
# Megatron format of fine-tuned Qwen3-8B model (for continued training)
huggingface-cli download JinnP/Qwen3-8B-Kernelbook-SFT-filtered --local-dir models/Qwen3-8B-Kernelbook-SFT-filtered
# Base Qwen3-8B model (HuggingFace format)
huggingface-cli download Qwen/Qwen3-8B --local-dir models/Qwen3-8B
# Base Qwen3-8B model (Megatron format)
huggingface-cli download zyzshishui0627/Qwen3-8B_torch_dist --local-dir models/Qwen3-8B_torch_dist
AMD Setup
docker pull rlsys/tritonforge:stable
docker run -it \
--device /dev/dri \
--device /dev/kfd \
--group-add video \
--cap-add SYS_PTRACE \
--security-opt seccomp=unconfined \
--privileged \
--shm-size 128G \
--ulimit memlock=-1 \
--ulimit stack=67108864 \
-v "$HOME/.ssh:/root/.ssh:ro" \
-v "$HOME:$HOME" \
-e HF_HOME="$HOME/.cache/huggingface" \
-e TRANSFORMERS_CACHE="$HOME/.cache/huggingface" \
-e XDG_CACHE_HOME="$HOME/.cache" \
-w "$PWD" \
-p 127.0.0.1:18265:8265 \
--name tritonforge_dev \
rlsys/tritonforge:stable \
/bin/bash
git clone https://github.com/RLsys-Foundation/TritonForge.git
cd TritonForge
cd SLIME
pip install -e .
# Set AMD environment variables
# gfx942 is the architecture for MI300X
export ROCM_HOME=/opt/rocm
export HIP_PLATFORM=amd
export PYTORCH_ROCM_ARCH=gfx942
export PATH=$ROCM_HOME/bin:$PATH
export LD_LIBRARY_PATH=$ROCM_HOME/lib:$LD_LIBRARY_PATH
export SGLANG_API_KEY=local-key
export PYTHONPATH=/workspace/KernelBench:$PYTHONPATH
# AMD optimizations
export HSA_ENABLE_SDMA=0
# Prevent GPU core dumps
export HSA_ENABLE_COREDUMP=0
export AMD_LOG_LEVEL=0
export ROCM_DISABLE_CRASH_DUMP=1
export HIP_ENABLE_COREDUMP=0
export HSA_TOOLS_LIB=/opt/rocm/lib/librocm-debug-agent.so.2:0
export GPU_MAX_HW_QUEUES=1
cd KBenchEval
# Install missing packages
pip install pydra_config==0.0.15  # may need a manual fix for pydra; see the symlink below
cd /usr/local/lib/python3.12/dist-packages && ln -sf pydra_config pydra
pip install together
pip install google-generativeai
# No virtual environment here; we use the Python environment provided by the Docker image
# Install dependencies
cd /root/TritonForge/KBenchEval
pip install -e .
# Download the same models as in the NVIDIA setup
huggingface-cli download JinnP/Qwen3-8B-Kernelbook-SFT-HF --local-dir /root/Qwen3-8B-Kernelbook-SFT-HF
huggingface-cli download JinnP/Qwen3-8B-Kernelbook-SFT-filtered --local-dir /root/Qwen3-8B-Kernelbook-SFT-filtered
huggingface-cli download Qwen/Qwen3-8B --local-dir /root/Qwen3-8B
huggingface-cli download zyzshishui0627/Qwen3-8B_torch_dist --local-dir /root/Qwen3-8B_torch_dist
Detailed Architecture: See our comprehensive Architecture Documentation for the complete server-based SFT + RL framework design.
We leverage the same SLIME framework for both SFT and RL stages, providing a unified training pipeline. The SFT stage fine-tunes the base Qwen3-8B model using:
- GPUMODE/KernelBook: 18.2k curated PyTorch-to-Triton code pairs (filtered to ~17k)
- Custom data augmentations: Multi-turn conversations, thinking tags, and length filtering
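For illustration, one augmented training sample might look like the sketch below; the field names and tag format are assumptions for this README, not the pipeline's actual schema.

```python
# Hypothetical shape of one augmented SFT sample (field names assumed).
sample = {
    "messages": [
        {
            "role": "user",
            "content": "Rewrite this PyTorch module as an optimized Triton kernel:\n<PyTorch source>",
        },
        {
            "role": "assistant",
            # Thinking tags let the model reason about tiling and masking
            # before emitting the kernel; overly long samples are filtered out.
            "content": "<think>analyze the op, pick block sizes, plan masking</think>\n<Triton kernel source>",
        },
    ]
}
```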
Training Configuration (SLIME/scripts/run-qwen3-8B-kernelbook-sft.sh):
| Parameter | Value | Purpose |
|---|---|---|
| Tensor Parallel (TP) | 2 | Splits model across 2 GPUs for memory efficiency |
| Context Parallel (CP) | 4 | Handles long sequences by splitting context |
| Pipeline Parallel (PP) | 1 | No pipeline parallelism |
| Data Parallel (DP) | 1 | Single data parallel replica |
| Batch Size | 32 | Global batch size for training |
| Learning Rate | 1e-5 | With cosine decay to 1e-6 |
| Precision | BF16 | Mixed precision training |
| Gradient Recomputation | Full (12 layers) | Reduces memory footprint |
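Note that these settings multiply out to TP × CP × PP × DP = 2 × 4 × 1 × 1 = 8 GPUs per training replica.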
The resulting model is available at JinnP/Qwen3-8B-Kernelbook-SFT-filtered.
We then apply reinforcement learning using SLIME (our fixed and improved version) to further improve the model's kernel generation capabilities:
| Component | Description |
|---|---|
| Training Data | KernelBench Level 1-2 (200 problems) |
| Approach | Multi-turn iterative refinement with compilation and performance feedback |
| Reward Signal | Compilation success + functional correctness + speedup metrics |
| Max Turns | 3 iterations per kernel |
| Discount Factor | γ = 0.4 |
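As a rough sketch of how these pieces could combine, a discounted multi-turn return might be computed as below. Only the discount factor and turn count come from the table above; the per-turn weights are assumptions, and the actual reward function lives in SLIME/slime_plugins and may weight terms differently.

```python
# Minimal sketch of a discounted multi-turn return (weights are assumptions).
GAMMA = 0.4      # discount factor from the table
MAX_TURNS = 3    # refinement iterations per kernel

def turn_reward(compiled: bool, correct: bool, speedup: float) -> float:
    """Hypothetical per-turn reward: correctness is gated on compilation,
    and speedup only counts once the kernel is functionally correct."""
    reward = 0.0
    if compiled:
        reward += 0.1
        if correct:
            reward += 0.3
            reward += max(0.0, speedup - 1.0)  # credit only real speedup vs. PyTorch
    return reward

def discounted_return(rewards_per_turn: list[float]) -> float:
    """Discount later turns so the policy prefers succeeding early."""
    return sum(GAMMA ** t * r for t, r in enumerate(rewards_per_turn))

# Example: fails turn 1, compiles turn 2, correct with 1.5x speedup turn 3.
print(discounted_return([0.0, 0.1, 0.9]))  # 0.0 + 0.4*0.1 + 0.16*0.9 = 0.184
```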
NVIDIA
cd KBenchEval
source .venv/bin/activate
python scripts/generate_and_eval_single_sample.py \
dataset_src="huggingface" \
level=1 \
problem_id=19 \
verbose_logging=true
AMD
cd KBenchEval
export OPENAI_API_KEY="dummy-key"
python scripts/generate_and_eval_single_sample.py \
dataset_src=local \
level=1 \
problem_id=19 \
gpu_arch='["MI300X"]' \
backend=triton \
server_type=sglang \
eval_device=0 \
verbose=True
NVIDIA - SFT Stage
cd SLIME
# Supervised Fine-Tuning using SLIME
bash scripts/run-qwen3-8B-kernelbook-sft.sh
NVIDIA - RL Stage
cd SLIME
# Multi-turn kernel generation training
bash scripts/run_agent_kbench_qwen3_8B_sft_fixed.sh
AMD
# Terminal 1: Launch SGLang server
cd KBenchEval
HIP_VISIBLE_DEVICES=2,3 python3 -m sglang.launch_server \
--model-path models/Qwen3-8B-Kernelbook-SFT-HF \
--tp 2 \
--trust-remote-code \
--host 0.0.0.0 \
--port 30000
# Terminal 2: Run evaluation
cd KBenchEval
python kernelbench_amd_tools/scripts/run_qwen3_evaluation_robust.py --levels 1,2
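To confirm the server is up before launching the evaluation, a quick request against SGLang's OpenAI-compatible endpoint works; this is a convenience sketch, not part of the official scripts:

```python
# Sanity-check the SGLang server launched in Terminal 1 (port 30000).
import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "models/Qwen3-8B-Kernelbook-SFT-HF",
        "messages": [{"role": "user", "content": "Write a Triton elementwise add kernel."}],
        "max_tokens": 64,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"][:200])
```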
TritonForge/
├── SLIME/                 # RL training framework (fixed version of SLIME)
│   ├── slime/             # Core SLIME framework
│   ├── slime_plugins/     # Custom generators and reward functions
│   └── scripts/           # Training launch scripts
├── KBenchEval/            # Kernel evaluation framework
│   ├── KernelBench/       # Benchmark problems (mainly Levels 1-2)
│   ├── src/               # Evaluation logic
│   └── scripts/           # Evaluation scripts
└── docs/                  # Documentation and assets
    └── assets/            # Images and logos
We evaluated our SFT fine-tuned Qwen3-8B model on KernelBench Level 1-2:
| Model | Level 1 Pass@1 | Level 2 Pass@1 | Training Data | Notes |
|---|---|---|---|---|
| Qwen3-8B-Kernelbook-SFT | 18% | 8% | 17k filtered samples | Close to KernelBook baseline (20%) |
| KernelBook Baseline | 20% | - | Original dataset | Reference performance |
We have conducted extensive experiments across different hardware platforms and training configurations:
Multi-Turn RL on NVIDIA H100
Model: Qwen3-8B fine-tuned with SFT
Training: Multi-turn iterative refinement
View Training Logs on WandB

Single-Turn on AMD MI300X
Model: Qwen3-8B fine-tuned with SFT
Training: Single-turn generation
View Training Logs on WandB
| Configuration | Hardware | Model | Status | Results |
|---|---|---|---|---|
| Single-Turn (Baseline) | NVIDIA H100 | KernelLLM | Complete | Detailed Report |
| Multi-Turn RL | NVIDIA H100 | Qwen3-8B-fine-tuned | Complete | See above |
| Single-Turn | AMD MI300X | Qwen3-8B-fine-tuned | Complete | See above |
| Multi-Turn RL | AMD MI300X | Qwen3-8B-fine-tuned | In Progress | Coming Soon |
- Multi-Turn Advantage: Multi-turn refinement shows 15-20% improvement over single-turn generation in complex kernel optimizations
- Cross-Platform Consistency: Similar performance characteristics observed across NVIDIA and AMD platforms
- Model Scaling: Fine-tuned Qwen3-8B outperforms baseline models by 25-30% on average
- Compilation Success: Achieved >90% compilation rate with proper error handling in multi-turn setting
We have an ambitious roadmap to transform TritonForge into a comprehensive, intelligent kernel development platform. Our immediate focus for the remaining months of 2025:
| Month | Focus | Key Deliverables | Status |
|---|---|---|---|
| Oct 2025 | Foundation & Quick Wins | AMD stability fixes, Basic GUI v0.1, KernelBench setup | Starting |
| Nov 2025 | Scaling & Optimization | 4+4+2 architecture, GUI v0.5, MoE testing | Planned |
| Dec 2025 | Advanced Features | Qwen3-30B-A3B, Tool calling v1, GUI v1.0 | Planned |
- Infrastructure: Scale from the 4+2+2 to a 4+4+2 architecture for enhanced multi-turn training
- Model Support: Enable MoE models like Qwen3-30B-A3B for superior performance
- Intelligent Agent: Tool calling for profiling, documentation access, and search
- Multi-DSL: Support CUDA, HIP/ROCm, and OpenCL beyond just Triton
- Production GUI: Real-time monitoring and visualization dashboard
For the complete roadmap with detailed milestones, task breakdowns, and progress tracking:
We welcome community feedback and contributions to help shape TritonForge's future!
Issue: Multi-turn RL training on AMD MI300X GPUs crashes the node within 2 steps, with CPU utilization hitting 100%.
Status: Under active investigation
Workaround:
- Use single-turn training (stable)
- See Issue #1 for details and updates
Reproduction: bash SLIME/scripts/run_agent_kbench_qwen3_8B_sft_amd_multiturn_robust.sh
We believe in community-driven development and welcome all contributions! Our goal is to work together with the community to push the boundaries of automated kernel generation.
- Add GPU Architecture Support: Extend to more NVIDIA/AMD/Intel GPUs
- Contribute Training Data: Share high-quality PyTorch-to-kernel examples
- Improve Optimization Strategies: Develop new kernel optimization techniques
- Enhance Multi-Turn Training: Refine the iterative improvement process
- Build Analysis Tools: Create performance profiling and debugging utilities
- Add Benchmarks: Contribute new challenging kernel problems
- Improve Documentation: Help others understand and use the framework
Join our community effort to democratize GPU kernel optimization! See our Contributing Guide for more details.
We extend our deepest gratitude to the individuals whose dedication and expertise made TritonForge possible:
- Zilin Zhu and Chengxing Xie - For their foundational work on the SLIME framework and the entire async slime_plugins system that enables customizable rollout and reward mechanisms. Without their RL framework, TritonForge would not have been possible.
- Xiang Long - For his crucial collaboration in bridging SLIME with KernelBench evaluation through an innovative server-based architecture, enabling seamless integration between training and evaluation pipelines.
- Kexun Zhang - For pioneering work on implementing multi-turn refinement methods and insightfully providing advice for our SFT data generation pipeline, significantly enhancing our training data quality.
- Junrong Lin and Haoran Wang - For their valuable insights and contributions to the system design and optimization strategies that shaped TritonForge's architecture.
- Yusheng Su (AMD Mentor), Yuzhen Zhou, and Jiajun Li - For their instrumental support in enabling AMD MI300X compatibility and ROCm optimization. Their expertise was critical in making TritonForge a truly cross-platform solution.
We were heavily inspired by Kevin from Cognition AI, a project that pioneered multi-turn reinforcement learning for writing CUDA kernels. Kevin's approach to iterative kernel refinement through RL directly influenced our multi-turn training methodology. By open-sourcing our complete framework, we hope to contribute back to the community and enable further research in automated kernel optimization.
| Project | Contribution |
|---|---|
| KernelBench | The foundational benchmark framework that KBenchEval is built upon |
| SLIME | The foundational RL framework that our training system is built upon |
| Meta AI | Laying the foundation for Triton backend support through PR #35 |
| GPUMODE/KernelBook | 18.2k curated PyTorch-to-Triton training pairs for SFT |
| facebook/KernelLLM | Additional high-quality SFT dataset for kernel generation |
| Megatron-LM | Distributed training infrastructure |
| SGLang | High-performance inference serving |
| Triton | GPU kernel programming language |
Apache 2.0 - See LICENSE file for details