
TritonForge Logo

TritonForge

πŸ”₯ Forging Optimal GPU Kernels through SFT + RL


Transform PyTorch Operations into Optimized GPU Kernels with LLMs

πŸ“š Documentation | πŸ—οΈ Architecture | πŸš€ Quick Start | πŸ“Š Results | πŸ—ΊοΈ Roadmap | 🀝 Contributing


🌟 Highlights

Feature | Description
πŸŽ“ Two-Stage Training | SFT on high-quality datasets followed by RL optimization
πŸ”„ Multi-Turn Refinement | Iterative kernel improvement through compilation feedback
⚑ Cross-Platform | Support for both NVIDIA CUDA and AMD ROCm GPUs
πŸ“ˆ Performance Metrics | Comprehensive evaluation of correctness and speedup
πŸ§ͺ 200+ Benchmarks | Extensive test suite across multiple difficulty levels

🎯 Overview

TritonForge is an advanced machine learning framework that trains Large Language Models (LLMs) to automatically convert PyTorch operations into optimized Triton GPU kernels. By combining supervised fine-tuning (SFT) with reinforcement learning (RL), TritonForge achieves state-of-the-art performance in automated kernel generation.

πŸ—οΈ Architecture Deep Dive: For a comprehensive understanding of our server-based SFT + RL framework, evaluation infrastructure, and cross-platform support, see our Architecture Documentation.

🌍 Fully Open-Source Initiative

We believe in complete transparency and community collaboration. Everything is open-source:

  • πŸ“š Training Data: Custom-curated datasets (GPUMODE/KernelBook)
  • πŸ€– Model Checkpoints: All intermediate and final models (HuggingFace)
  • πŸ—οΈ Training Framework: Complete SLIME RL implementation (fixed version with improvements)
  • 🐳 Environment Setup: Docker images and configurations for both NVIDIA and AMD
  • πŸ“– Training Recipes: Detailed scripts and hyperparameters for reproduction

We invite the community to join us in advancing automated kernel generation together!

🧠 SLIME

Reinforcement Learning Framework

Note: This is a fixed and improved version of the original SLIME framework. In the interest of transparency: this is essentially SLIME with bug fixes and optimizations that enable multi-turn iterative kernel improvement through compilation feedback and performance metrics.

Learn More β†’
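
To make the multi-turn mechanism concrete, here is a minimal sketch of the generate-compile-refine loop. The LLM client and evaluator interfaces are illustrative assumptions, not SLIME's actual API:

def refine_kernel(llm, problem, evaluate, max_turns=3):
    # Multi-turn refinement: generate a kernel, evaluate it, and feed
    # compiler errors or performance feedback back into the conversation.
    messages = [{"role": "user", "content": problem}]
    kernel, result = None, None
    for _ in range(max_turns):
        kernel = llm.chat(messages)
        result = evaluate(kernel)  # compile + correctness check + benchmark
        if result["correct"]:
            break
        messages += [
            {"role": "assistant", "content": kernel},
            {"role": "user", "content": f"The kernel failed:\n{result['error']}\nPlease fix it."},
        ]
    return kernel, result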

πŸ“Š KBenchEval

Comprehensive Benchmark Suite

Based on ScalingIntelligence/KernelBench, KBenchEval evaluates GPU kernel generation quality and performance across 200+ problems spanning multiple difficulty levels.

Learn More β†’
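
At its core, this style of evaluation checks a generated kernel against the reference PyTorch module for numerical agreement and wall-clock speedup. A minimal sketch of such a check follows; the tolerances and iteration counts are illustrative assumptions, not KBenchEval's exact settings:

import time
import torch

def check_correctness(ref_model, candidate_model, inputs, atol=1e-2, rtol=1e-2):
    # Compare the generated kernel's output against the PyTorch reference.
    with torch.no_grad():
        expected = ref_model(*inputs)
        actual = candidate_model(*inputs)
    return torch.allclose(expected, actual, atol=atol, rtol=rtol)

def measure_speedup(ref_model, candidate_model, inputs, iters=100):
    # Wall-clock speedup of the candidate kernel over the PyTorch baseline.
    def bench(fn):
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            fn(*inputs)
        torch.cuda.synchronize()
        return time.perf_counter() - start
    return bench(ref_model) / bench(candidate_model)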

πŸš€ Quick Start

Prerequisites

Requirement | NVIDIA | AMD
Verified GPU | H100 | MI300X
Memory | 80GB | 192GB
Docker | βœ… Required | βœ… Required
Python | 3.10+ | 3.10+
CUDA/ROCm | CUDA 12.6.1 | ROCm 6.3.4
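
Before proceeding, you can confirm that PyTorch sees the GPU inside the container (ROCm builds of PyTorch also report through the torch.cuda namespace):

import torch

# Works on both CUDA and ROCm builds of PyTorch.
assert torch.cuda.is_available(), "No GPU visible to PyTorch"
print(torch.cuda.get_device_name(0))            # e.g. an H100 or MI300X device string
print(torch.version.cuda or torch.version.hip)  # toolkit version the wheel was built against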

Installation

Choose your platform and follow the setup guide:

πŸ“— NVIDIA Setup

1. Launch Docker Container

docker pull zhuzilin/slime:20250706-v2

docker run --rm --gpus all --ipc=host --shm-size=128g \
  --ulimit memlock=-1 --ulimit stack=67108864 \
  -v $HOME:$HOME \
  -it zhuzilin/slime:20250706-v2 /bin/bash

2. Clone Repository

git clone https://github.com/RLsys-Foundation/TritonForge.git
cd TritonForge

3. Setup KBenchEval

cd KBenchEval

# Create virtual environment
python -m venv .venv
source .venv/bin/activate

# Install dependencies
pip install --upgrade pip
pip install -r requirements.txt

pip install -e .

deactivate

4. Setup SLIME

cd ../SLIME
pip install -e .

5. Download Models

# Create models directory
mkdir -p models

# Hugging Face format of fine-tuned Qwen3-8B model (for evaluation)
huggingface-cli download JinnP/Qwen3-8B-Kernelbook-SFT-HF --local-dir models/Qwen3-8B-Kernelbook-SFT-HF

# Megatron format of fine-tuned Qwen3-8B model (for continued training)
huggingface-cli download JinnP/Qwen3-8B-Kernelbook-SFT-filtered --local-dir models/Qwen3-8B-Kernelbook-SFT-filtered

# Base Qwen3-8B model (HuggingFace format)
huggingface-cli download Qwen/Qwen3-8B --local-dir models/Qwen3-8B

# Base Qwen3-8B model (Megatron format)
huggingface-cli download zyzshishui0627/Qwen3-8B_torch_dist --local-dir models/Qwen3-8B_torch_dist

πŸ“• AMD Setup

1. Launch Docker Container

docker pull rlsys/tritonforge:stable

docker run -it \
  --device /dev/dri \
  --device /dev/kfd \
  --group-add video \
  --cap-add SYS_PTRACE \
  --security-opt seccomp=unconfined \
  --privileged \
  --shm-size 128G \
  --ulimit memlock=-1 \
  --ulimit stack=67108864 \
  -v "$HOME/.ssh:/root/.ssh:ro" \
  -v "$HOME:$HOME" \
  -e HF_HOME="$HOME/.cache/huggingface" \
  -e TRANSFORMERS_CACHE="$HOME/.cache/huggingface" \
  -e XDG_CACHE_HOME="$HOME/.cache" \
  -w "$PWD" \
  -p 127.0.0.1:18265:8265 \
  --name tritonforge_dev \
  rlsys/tritonforge:stable \
  /bin/bash

2. Clone Repository

git clone https://github.com/RLsys-Foundation/TritonForge.git
cd TritonForge

3. Setup SLIME

cd SLIME
pip install -e .

4. Set AMD Environment Variables

# AMD environment variables
# gfx942 targets the MI300X architecture
export ROCM_HOME=/opt/rocm
export HIP_PLATFORM=amd
export PYTORCH_ROCM_ARCH=gfx942
export PATH=$ROCM_HOME/bin:$PATH
export LD_LIBRARY_PATH=$ROCM_HOME/lib:$LD_LIBRARY_PATH
export SGLANG_API_KEY=local-key
export PYTHONPATH=/workspace/KernelBench:$PYTHONPATH

# AMD optimizations
export HSA_ENABLE_SDMA=0

# Prevent GPU core dumps
export HSA_ENABLE_COREDUMP=0
export AMD_LOG_LEVEL=0
export ROCM_DISABLE_CRASH_DUMP=1
export HIP_ENABLE_COREDUMP=0
export HSA_TOOLS_LIB=/opt/rocm/lib/librocm-debug-agent.so.2:0
export GPU_MAX_HW_QUEUES=1
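
You can verify that PyTorch picked up the MI300X target; ROCm builds of PyTorch expose the architecture via the gcnArchName device property:

import torch

# On a ROCm build this prints a gfx942-prefixed architecture string.
print(torch.cuda.get_device_properties(0).gcnArchName)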

5. Set up KBenchEval for MI300X

cd KBenchEval

# Install missing packages
pip install pydra_config==0.0.15  # installs as pydra_config; KernelBench imports pydra, so add a shim:
cd /usr/local/lib/python3.12/dist-packages && ln -sf pydra_config pydra
pip install together
pip install google-generativeai

# No virtual environment here; we use the container's system Python directly
# Install dependencies
cd /root/TritonForge/KBenchEval
pip install -e .
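
A quick, illustrative check that the pydra shim took effect:

import pydra  # should now resolve via the pydra_config symlink
print(pydra.__file__)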

6. Download Models

# Download the same models as in the NVIDIA setup
huggingface-cli download JinnP/Qwen3-8B-Kernelbook-SFT-HF --local-dir /root/Qwen3-8B-Kernelbook-SFT-HF
huggingface-cli download JinnP/Qwen3-8B-Kernelbook-SFT-filtered --local-dir /root/Qwen3-8B-Kernelbook-SFT-filtered
huggingface-cli download Qwen/Qwen3-8B --local-dir /root/Qwen3-8B
huggingface-cli download zyzshishui0627/Qwen3-8B_torch_dist --local-dir /root/Qwen3-8B_torch_dist

πŸŽ“ Training Pipeline

TritonForge Training Pipeline

πŸ“– Detailed Architecture: See our comprehensive Architecture Documentation for the complete server-based SFT + RL framework design.

Stage 1: Supervised Fine-Tuning (SFT)

We leverage the same SLIME framework for both SFT and RL stages, providing a unified training pipeline. The SFT stage fine-tunes the base Qwen3-8B model using:

  • GPUMODE/KernelBook: 18.2k curated PyTorch-to-Triton code pairs (filtered to ~17k)
  • Custom data augmentations: Multi-turn conversations, thinking tags, and length filtering

Training Configuration (SLIME/scripts/run-qwen3-8B-kernelbook-sft.sh):

Parameter | Value | Purpose
Tensor Parallel (TP) | 2 | Splits the model across 2 GPUs for memory efficiency
Context Parallel (CP) | 4 | Handles long sequences by splitting the context
Pipeline Parallel (PP) | 1 | No pipeline parallelism
Data Parallel (DP) | 1 | Single data-parallel replica
Batch Size | 32 | Global batch size for training
Learning Rate | 1e-5 | With cosine decay to 1e-6
Precision | BF16 | Mixed-precision training
Gradient Recomputation | Full (12 layers) | Reduces memory footprint

With TP Γ— CP Γ— PP = 2 Γ— 4 Γ— 1 and DP = 1, each training run occupies 8 GPUs.

The resulting model is available at JinnP/Qwen3-8B-Kernelbook-SFT-filtered.

Stage 2: Reinforcement Learning (RL)

We then apply reinforcement learning using SLIME (our fixed and improved version) to further improve the model's kernel generation capabilities:

Component | Description
Training Data | KernelBench Levels 1-2 (200 problems)
Approach | Multi-turn iterative refinement with compilation and performance feedback
Reward Signal | Compilation success + functional correctness + speedup metrics
Max Turns | 3 iterations per kernel
Discount Factor | Ξ³ = 0.4
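
To illustrate how these components combine, here is a minimal sketch of a per-turn reward and the discounted return over a refinement trajectory. Only the Ξ³ = 0.4 discount and the 3-turn cap come from the table above; the component weights are assumptions for illustration:

GAMMA = 0.4     # discount factor from the table above
MAX_TURNS = 3   # refinement iterations per kernel

def turn_reward(compiled, correct, speedup):
    # Hypothetical reward shaping: small credit for compiling, full credit
    # for correctness, and a bonus proportional to speedup over PyTorch.
    r = 0.1 if compiled else 0.0
    if correct:
        r += 1.0 + max(0.0, speedup - 1.0)
    return r

def trajectory_return(rewards):
    # Discounted return over up to MAX_TURNS attempts.
    return sum((GAMMA ** t) * r for t, r in enumerate(rewards[:MAX_TURNS]))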

πŸ“Š Quick Evaluation

Test Single Problem

NVIDIA

cd KBenchEval
source .venv/bin/activate

python scripts/generate_and_eval_single_sample.py \
  dataset_src="huggingface" \
  level=1 \
  problem_id=19 \
  verbose_logging=true

AMD

cd KBenchEval

export OPENAI_API_KEY="dummy-key"
python scripts/generate_and_eval_single_sample.py \
  dataset_src=local \
  level=1 \
  problem_id=19 \
  gpu_arch='["MI300X"]' \
  backend=triton \
  server_type=sglang \
  eval_device=0 \
  verbose=True

Run Full Training

NVIDIA - SFT Stage

cd SLIME
# Supervised Fine-Tuning using SLIME
bash scripts/run-qwen3-8B-kernelbook-sft.sh

NVIDIA - RL Stage

cd SLIME
# Multi-turn kernel generation training
bash scripts/run_agent_kbench_qwen3_8B_sft_fixed.sh

AMD

# Terminal 1: Launch SGLang server
cd KBenchEval
HIP_VISIBLE_DEVICES=2,3 python3 -m sglang.launch_server \
  --model-path /root/Qwen3-8B-Kernelbook-SFT-HF \
  --tp 2 \
  --trust-remote-code \
  --host 0.0.0.0 \
  --port 30000

# Terminal 2: Run evaluation
cd KBenchEval
python kernelbench_amd_tools/scripts/run_qwen3_evaluation_robust.py --levels 1,2
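
Once the server is up, you can sanity-check it through SGLang's OpenAI-compatible endpoint. The model name and prompt below are illustrative:

import openai

client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="dummy-key")
resp = client.chat.completions.create(
    model="default",  # SGLang serves the single loaded model; the name is illustrative
    messages=[{"role": "user", "content": "Write a Triton kernel for elementwise add."}],
    max_tokens=256,
)
print(resp.choices[0].message.content)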

πŸ“ Project Structure

TritonForge/
β”œβ”€β”€ πŸ“ SLIME/                      # RL training framework (fixed version of SLIME)
β”‚   β”œβ”€β”€ slime/                     # Core SLIME framework
β”‚   β”œβ”€β”€ slime_plugins/             # Custom generators and reward functions
β”‚   └── scripts/                   # Training launch scripts
β”œβ”€β”€ πŸ“ KBenchEval/                 # Kernel evaluation framework
β”‚   β”œβ”€β”€ KernelBench/               # Benchmark problems (Level 1-2 mainly)
β”‚   β”œβ”€β”€ src/                       # Evaluation logic
β”‚   └── scripts/                   # Evaluation scripts
β”œβ”€β”€ πŸ“ docs/                       # Documentation and assets
β”‚   └── assets/                    # Images and logos

πŸ“Š Results

We evaluated our SFT fine-tuned Qwen3-8B model on KernelBench Level 1-2:

Model | Level 1 Pass@1 | Level 2 Pass@1 | Training Data | Notes
Qwen3-8B-Kernelbook-SFT | 18% | 8% | 17k filtered samples | Close to KernelBook baseline (20%)
KernelBook Baseline | 20% | - | Original dataset | Reference performance

Experimental Results

We have conducted extensive experiments across different hardware platforms and training configurations:

🎯 Multi-Turn vs Single-Turn Performance

NVIDIA H100 (Multi-Turn)

Model: Qwen3-8B fine-tuned with SFT
Training: Multi-turn iterative refinement
Hardware: NVIDIA H100 GPUs

NVIDIA H100 Multi-Turn Results

πŸ“Š View Training Logs on WandB

Key Achievements:

  • Significant improvement in kernel optimization through iterative refinement
  • Higher success rate on complex fusion patterns
  • Consistent performance gains across Level 1-2 benchmarks

AMD MI300X (Single-Turn)

Model: Qwen3-8B fine-tuned with SFT
Training: Single-turn generation
Hardware: AMD MI300X GPUs

AMD MI300X Single-Turn Results

πŸ“Š View Training Logs on WandB

Key Achievements:

  • First successful deployment on AMD MI300X architecture
  • Competitive performance with NVIDIA in single-turn setting
  • Optimized for ROCm/HIP compilation pipeline

Additional Experiments

Configuration | Hardware | Model | Status | Results
Single-Turn (Baseline) | NVIDIA H100 | KernelLLM | βœ… Complete | πŸ“– Detailed Report
Multi-Turn RL | NVIDIA H100 | Qwen3-8B-fine-tuned | βœ… Complete | See above
Single-Turn | AMD MI300X | Qwen3-8B-fine-tuned | βœ… Complete | See above
Multi-Turn RL | AMD MI300X | Qwen3-8B-fine-tuned | πŸ”„ In Progress | Coming soon

Key Findings

  1. Multi-Turn Advantage: Multi-turn refinement shows 15-20% improvement over single-turn generation in complex kernel optimizations
  2. Cross-Platform Consistency: Similar performance characteristics observed across NVIDIA and AMD platforms
  3. Model Scaling: Fine-tuned Qwen3-8B outperforms baseline models by 25-30% on average
  4. Compilation Success: Achieved >90% compilation rate with proper error handling in multi-turn setting

πŸ—ΊοΈ Roadmap

Q4 2025 & Beyond

We have an ambitious roadmap to transform TritonForge into a comprehensive, intelligent kernel development platform. Our immediate focus for the remaining months of 2025:

Month | Focus | Key Deliverables | Status
Oct 2025 | Foundation & Quick Wins | AMD stability fixes, Basic GUI v0.1, KernelBench setup | πŸš€ Starting
Nov 2025 | Scaling & Optimization | 4+4+2 architecture, GUI v0.5, MOE testing | πŸ“‹ Planned
Dec 2025 | Advanced Features | Qwen3-30B-A3B, Tool calling v1, GUI v1.0 | πŸ“‹ Planned

🎯 Key Initiatives

  • πŸ—οΈ Infrastructure: Scale from 4+2+2 to 4+4+2 architecture for enhanced multi-turn training
  • πŸ€– Model Support: Enable MOE models like Qwen3-30B-A3B for superior performance
  • πŸ› οΈ Intelligent Agent: Tool calling for profiling, documentation access, and search
  • 🌍 Multi-DSL: Support CUDA, HIP/ROCm, OpenCL beyond just Triton
  • πŸ“Š Production GUI: Real-time monitoring and visualization dashboard

πŸ“– Full Roadmap Details

For the complete roadmap with detailed milestones, task breakdowns, and progress tracking, see the full roadmap document.

We welcome community feedback and contributions to help shape TritonForge's future!

⚠️ Known Issues

AMD MI300X Multi-Turn Training Crash

Issue: Multi-turn RL training on AMD MI300X GPUs crashes the node within 2 training steps, with CPU utilization pinned at 100%.

Status: πŸ” Under active investigation

Workaround:

  • Use single-turn training (stable)
  • See Issue #1 for details and updates

Reproduction: bash SLIME/scripts/run_agent_kbench_qwen3_8B_sft_amd_multiturn_robust.sh

🀝 Contributing

We believe in community-driven development and welcome all contributions! Our goal is to work together with the community to push the boundaries of automated kernel generation.

How You Can Help

  • πŸ—οΈ Add GPU Architecture Support: Extend to more NVIDIA/AMD/Intel GPUs
  • πŸ“š Contribute Training Data: Share high-quality PyTorch-to-kernel examples
  • πŸš€ Improve Optimization Strategies: Develop new kernel optimization techniques
  • πŸ”„ Enhance Multi-Turn Training: Refine the iterative improvement process
  • πŸ“ˆ Build Analysis Tools: Create performance profiling and debugging utilities
  • πŸ§ͺ Add Benchmarks: Contribute new challenging kernel problems
  • πŸ“– Improve Documentation: Help others understand and use the framework

Join our community effort to democratize GPU kernel optimization! See our Contributing Guide for more details.

πŸ™ Acknowledgments

Core Contributors

We extend our deepest gratitude to the individuals whose dedication and expertise made TritonForge possible:

πŸ—οΈ Framework Architecture

  • Zilin Zhu and Chengxing Xie - For their foundational work on the SLIME framework and the entire async slime_plugins system that enables customizable rollout and reward mechanisms. Without their RL framework, TritonForge would not have been possible.

πŸ”— System Integration

  • Xiang Long - For his crucial collaboration in bridging SLIME with KernelBench evaluation through an innovative server-based architecture, enabling seamless integration between training and evaluation pipelines.

πŸ”„ Multi-Turn Innovation

  • Kexun Zhang - For pioneering work on implementing multi-turn refinement methods and insightfully providing advice for our SFT data generation pipeline, significantly enhancing our training data quality.

πŸ’‘ Research Insights

  • Junrong Lin and Haoran Wang - For their valuable insights and contributions to the system design and optimization strategies that shaped TritonForge's architecture.

πŸš€ AMD Platform Support

  • Yusheng Su (AMD Mentor), Yuzhen Zhou, and Jiajun Li - For their instrumental support in enabling AMD MI300X compatibility and ROCm optimization. Their expertise was critical in making TritonForge a truly cross-platform solution.

Research Inspiration

πŸ“š Kevin: Multi-Turn RL for CUDA Kernels

We were heavily inspired by Kevin from Cognition AI, which pioneered multi-turn reinforcement learning for writing CUDA kernels. Kevin's approach to iterative kernel refinement through RL directly influenced our multi-turn training methodology. By open-sourcing our complete framework, we hope to contribute back to the community and enable further research in automated kernel optimization.

Project Dependencies

Project | Contribution
KernelBench | The foundational benchmark framework that KBenchEval is built upon
SLIME | The foundational RL framework that our training system is built upon
Meta AI | Laying the foundation for Triton backend support through PR #35
GPUMODE/KernelBook | 18.2k curated PyTorch-to-Triton training pairs for SFT
facebook/KernelLLM | Additional high-quality SFT dataset for kernel generation
Megatron-LM | Distributed training infrastructure
SGLang | High-performance inference serving
Triton | GPU kernel programming language

πŸ“„ License

Apache 2.0 - See LICENSE file for details



TritonForge - Forging optimal GPU kernels through reinforcement learning πŸ”₯⚑
