A high-performance FlashAttention-style implementation built with Triton and PyTorch. This repository provides optimized forward and backward attention kernels, benchmarking utilities, and numerical correctness validation against PyTorch reference implementations.
Repository layout:

```
.
├── README.md
├── .gitignore
├── results/
└── triton/
    ├── benchmark_flash2.py
    ├── check_correctness_flash2.py
    ├── flash2-triton.py
    └── requirements.txt
```
Features:

- Triton-based FlashAttention-style forward and backward kernels
- Custom PyTorch autograd integration
- Numerical correctness validation against PyTorch reference implementations
- Performance benchmarking with CUDA event timing
- Support for causal and non-causal attention
- Configurable batch sizes, sequence lengths, head counts, and head dimensions
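For causal attention, block-tiled FlashAttention-2 style kernels commonly skip tiles that lie entirely above the diagonal and apply the mask only to tiles that straddle it. The pure-Python sketch below classifies tiles for a square sequence. The function name and the block sizes `block_m`/`block_n` are illustrative assumptions and are not taken from `flash2-triton.py`.

```python
def classify_causal_blocks(seq_len, block_m, block_n):
    """Classify (query block, key block) tiles for causal attention.

    Query i may attend to key j only when j <= i. Hypothetical helper for
    illustration; not part of this repository.
    """
    full, partial, skipped = [], [], []
    for qb in range(0, seq_len, block_m):
        q_lo, q_hi = qb, min(qb + block_m, seq_len) - 1
        for kb in range(0, seq_len, block_n):
            k_lo, k_hi = kb, min(kb + block_n, seq_len) - 1
            tile = (qb // block_m, kb // block_n)
            if k_lo > q_hi:
                skipped.append(tile)   # entirely above the diagonal: never computed
            elif k_hi <= q_lo:
                full.append(tile)      # entirely on or below the diagonal: no mask needed
            else:
                partial.append(tile)   # straddles the diagonal: apply the causal mask
    return full, partial, skipped
```

For `seq_len=256` with 64 × 64 tiles, this yields 6 unmasked tiles, 4 diagonal tiles, and 6 skipped tiles, so 10 of 16 tiles are computed. As the number of tiles grows, the computed fraction approaches one half, which is why causal FLOP estimates are commonly halved.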
Install dependencies using:

```bash
pip install -r triton/requirements.txt
```

Requirements:

- Python 3.10+
- NVIDIA GPU with CUDA support
- CUDA Toolkit (compatible with your installed PyTorch version)
- PyTorch
- Triton
Verify that the NVIDIA driver detects the GPU:

```bash
nvidia-smi
```

Verify that PyTorch can access the GPU:

```bash
python -c "import torch; print(torch.cuda.is_available())"
```

`triton/flash2-triton.py`: core implementation containing:

- Forward attention kernel
- Backward attention kernels
- Triton kernel launch logic
- `TritonAttention` autograd wrapper
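The central idea behind a FlashAttention-2 style forward kernel is online softmax. Keys and values are streamed in blocks while each query row keeps a running maximum `m`, a running sum `l`, and an unnormalized accumulator, so the full score row is never stored. The pure-Python sketch below illustrates this technique and is not the repository's Triton kernel. The function names, `block_n`, and the list-of-lists layout are assumptions. It processes one query row at a time, whereas a Triton kernel would typically handle a block of query rows per program instance.

```python
import math

def naive_attention(q, k, v, scale):
    """Reference: materialize every score, softmax each row, then weight V."""
    out = []
    for qi in q:
        s = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        z = sum(p)
        out.append([sum(p[j] * v[j][c] for j in range(len(v))) / z
                    for c in range(len(v[0]))])
    return out

def tiled_attention(q, k, v, scale, block_n=2):
    """Online-softmax forward pass over K/V blocks (illustrative sketch)."""
    d_v = len(v[0])
    out, lse = [], []
    for qi in q:
        m = -math.inf        # running row maximum
        l = 0.0              # running sum of exp(score - m)
        acc = [0.0] * d_v    # unnormalized output accumulator
        for start in range(0, len(k), block_n):
            kb, vb = k[start:start + block_n], v[start:start + block_n]
            s = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in kb]
            m_new = max(m, max(s))
            alpha = math.exp(m - m_new)  # rescales old statistics to the new max (0.0 on the first block)
            p = [math.exp(x - m_new) for x in s]
            l = alpha * l + sum(p)
            acc = [alpha * acc[c] + sum(p[j] * vb[j][c] for j in range(len(vb)))
                   for c in range(d_v)]
            m = m_new
        out.append([a / l for a in acc])  # normalize once, after the last block
        lse.append(m + math.log(l))       # logsumexp, kept for the backward pass
    return out, lse
```

For any `block_n`, `tiled_attention` should match `naive_attention` up to rounding. Dividing by `l` only once after the loop, and saving a single logsumexp per row instead of both `m` and `l`, are among the FlashAttention-2 changes that reduce non-matmul work.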
`triton/benchmark_flash2.py`: benchmarking utility that:
- Measures latency using CUDA events
- Reports mean, median, minimum, and maximum execution times
- Estimates throughput in TFLOPS
- Compares against PyTorch Scaled Dot Product Attention (SDPA) when available
- Falls back to Triton-only benchmarking when SDPA is unavailable
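The sketch below shows one way to derive the reported statistics and TFLOPS from a list of per-iteration latencies. The FLOP count follows the convention used in Triton's fused-attention tutorial: 4·B·H·N²·d for the forward pass, halved for causal attention, and multiplied by 2.5 for the backward pass. Whether `benchmark_flash2.py` uses the same convention is an assumption, and the function names are hypothetical.

```python
import statistics

def attention_flops(batch, heads, seq_len, head_dim, causal=False, mode="fwd"):
    """Estimate attention FLOPs (convention from Triton's fused-attention tutorial)."""
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
    total = 2 * flops_per_matmul  # Q @ K^T and P @ V
    if causal:
        total *= 0.5              # about half of the score matrix is masked out
    if mode == "bwd":
        total *= 2.5              # five matmuls (incl. recomputing Q @ K^T) vs two
    return total

def summarize(times_ms, flops):
    """Summarize per-iteration latencies in milliseconds and derive throughput."""
    mean = statistics.mean(times_ms)
    return {
        "mean_ms": mean,
        "median_ms": statistics.median(times_ms),
        "min_ms": min(times_ms),
        "max_ms": max(times_ms),
        "tflops": flops / (mean * 1e-3) / 1e12,  # FLOP/s scaled to TFLOP/s
    }
```

On a GPU, per-iteration times in milliseconds can be collected with pairs of `torch.cuda.Event(enable_timing=True)`. Call `start.record()` before the kernel launch and `end.record()` after it, then call `torch.cuda.synchronize()` before reading `start.elapsed_time(end)`.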
`triton/check_correctness_flash2.py`: validation script that compares Triton outputs against PyTorch reference implementations for:

- Forward pass outputs
- Query gradients (`dQ`)
- Key gradients (`dK`)
- Value gradients (`dV`)
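The gradients being validated follow from standard softmax-attention backpropagation. With S = scale·QKᵀ, P = softmax(S), and O = PV, FlashAttention-2 style backward passes use dV = PᵀdO, dP = dO Vᵀ, dS = P ∘ (dP − D) with D = rowsum(dO ∘ O), dQ = scale·dS K, and dK = scale·dSᵀQ. The pure-Python sketch below writes these equations out for tiny matrices as an illustration. It is an assumption about what a dense reference computes, not the validation script's code, and a PyTorch reference can obtain the same gradients through autograd. The helper names are hypothetical.

```python
import math

def matmul(a, b):
    """Dense product of list-of-lists matrices."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(row) for row in zip(*a)]

def softmax_rows(s):
    out = []
    for row in s:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out

def attention_forward(q, k, v, scale):
    """Return O = softmax(scale * Q K^T) V and the probabilities P."""
    s = [[scale * x for x in row] for row in matmul(q, transpose(k))]
    p = softmax_rows(s)
    return matmul(p, v), p

def attention_backward(q, k, v, o, p, d_o, scale):
    """Gradients of a scalar loss L given dL/dO = d_o."""
    d_v = matmul(transpose(p), d_o)  # dV = P^T dO
    d_p = matmul(d_o, transpose(v))  # dP = dO V^T
    # D_i = rowsum(dO * O) equals rowsum(P * dP), so it needs no extra pass over K/V
    d_rows = [sum(a * b for a, b in zip(d_o[i], o[i])) for i in range(len(o))]
    d_s = [[p[i][j] * (d_p[i][j] - d_rows[i]) for j in range(len(p[0]))]
           for i in range(len(p))]  # dS = P * (dP - D)
    d_q = [[scale * x for x in row] for row in matmul(d_s, k)]             # dQ = scale * dS K
    d_k = [[scale * x for x in row] for row in matmul(transpose(d_s), q)]  # dK = scale * dS^T Q
    return d_q, d_k, d_v
```

Checking such analytic gradients against central finite differences of `sum(O * dO)` on small inputs is a quick way to confirm a reference before comparing it with kernel outputs.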
Default settings are intentionally conservative to accommodate memory-constrained GPUs.
Basic benchmark:

```bash
python triton/benchmark_flash2.py
```

Causal attention benchmark:

```bash
python triton/benchmark_flash2.py --causal --seq-lens 1024,2048 --head-dims 64
```

Larger-scale benchmark:

```bash
python triton/benchmark_flash2.py --batch 8 --heads 16 --seq-lens 4096 --head-dims 64
```

Default validation:

```bash
python triton/check_correctness_flash2.py
```

Minimal validation configuration:

```bash
python triton/check_correctness_flash2.py --batch 1 --heads 1 --seq-lens 128 --head-dims 64
```

Extended validation sweep:

```bash
python triton/check_correctness_flash2.py --seq-lens 128,256,512 --head-dims 64 --causal-modes false,true
```

For Kaggle or other resource-constrained environments:
- Begin with sequence lengths of `128` or `256`
- Run correctness checks before large-scale benchmarks
- Avoid dense PyTorch reference attention for very large sequence lengths, since it materializes the full N × N attention matrix
- Benchmark Triton kernels independently when GPU memory is limited
- Expect minor autotuning differences across GPU architectures
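The following rough estimate shows why the dense reference becomes the bottleneck. A dense path materializes a (batch, heads, N, N) score tensor, which grows quadratically in N, while Q, K, V, and O grow only linearly. The helper names and the 2-byte (fp16/bf16) element size are assumptions. Actual usage is higher, because a reference typically holds both scores and probabilities and autograd saves intermediates for the backward pass.

```python
def dense_scores_bytes(batch, heads, seq_len, bytes_per_elem=2):
    """Bytes for one materialized (batch, heads, N, N) score or probability tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

def qkvo_bytes(batch, heads, seq_len, head_dim, bytes_per_elem=2):
    """Bytes for Q, K, V and O, each shaped (batch, heads, N, head_dim)."""
    return 4 * batch * heads * seq_len * head_dim * bytes_per_elem

GiB = 1024 ** 3
for n in (1024, 4096, 16384):
    scores = dense_scores_bytes(8, 16, n) / GiB
    io = qkvo_bytes(8, 16, n, 64) / GiB
    print(f"N={n:>6}: dense scores {scores:8.2f} GiB  vs  Q/K/V/O {io:6.2f} GiB")
```

For the larger-scale benchmark configuration (batch 8, 16 heads, N = 4096, head dimension 64), a single 2-byte score tensor alone takes 4 GiB, while Q, K, V, and O together take 0.25 GiB.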
Recommended workflow:
- Install dependencies
- Run correctness validation on small problem sizes
- Benchmark on target sequence lengths and head dimensions
Planned improvements include:
- Native CUDA implementation alongside the Triton implementation
- Additional kernel optimizations and autotuning strategies
- Support for a wider range of attention configurations
- Extended benchmarking across different GPU architectures
- Comprehensive performance comparisons between CUDA, Triton, and PyTorch SDPA
- Improved testing and validation coverage
- Multi-GPU experimentation and scaling studies
The implementation and understanding of FlashAttention concepts were informed by the following resources:
- Tri Dao et al., FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- Tri Dao, FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- Umar Jamil's educational content and implementation walkthroughs on Attention, FlashAttention, and Triton programming
- Official Triton Documentation
- Official PyTorch Documentation
Special thanks to Tri Dao and the FlashAttention research team for their pioneering work on efficient attention algorithms, and to Umar Jamil for providing accessible educational resources that helped deepen understanding of FlashAttention and Triton kernel development.
This repository is an independent educational and research implementation. It is not affiliated with or endorsed by the FlashAttention authors, Triton developers, or any associated organizations.