Write PTX kernels in Python. Launch them from `jax.jit`, PyTorch, and `torch.compile`.
pyptx is a Python DSL for handwritten PTX on NVIDIA Ampere (sm_80),
Ada (sm_89), Hopper (sm_90a), Blackwell datacenter (sm_100a), and
Blackwell workstation (sm_120). Pre-Ampere targets like Turing (sm_75,
T4) work for kernels that stay within the sm_75 ISA — anything using
cp.async, mbarrier, bf16, wgmma, tcgen05, or TMA needs an
Ampere-or-newer card.
One call = one instruction. No optimizer, no autotuner, no tensor IR between the Python function and the PTX it emits.
- explicit registers, predicates, barriers, shared memory
- Ampere: `mma.sync` (m16n8k{8,16,32}), `cp.async`, `ldmatrix`, SMEM staging
- Hopper: WGMMA, TMA 2D/3D with multicast, mbarriers, cluster launch
- Blackwell: `tcgen05.mma`/`.ld`, TMEM, SMEM descriptors, warp specialization
- callable from JAX, PyTorch eager, and `torch.compile`
- `arch="auto"` picks the right target for the current GPU at trace time (validated on T4, A100, L4, H100, B200, RTX Pro 6000 Blackwell)
- real PTX parser + emitter + transpiler — round-trips 218+ real PTX files byte-identical
Docs: pyptx.dev · Examples: `examples/ampere/`, `examples/hopper/`, `examples/blackwell/` · API: pyptx.dev/api
| Command | What you get |
|---|---|
| `pip install pyptx` | DSL, parser, emitter, transpiler (no GPU runtime) |
| `pip install 'pyptx[torch]'` | + PyTorch eager and `torch.compile` launch path |
| `pip install 'pyptx[jax]'` | + `jax.jit` launch path via typed FFI |
| `pip install 'pyptx[all]'` | + both PyTorch and JAX |
Tip: `pip install ninja` so the PyTorch C++ extension JIT-builds on first launch (drops dispatch overhead from ~34 µs to ~14 µs).
Blackwell (B200):

| Kernel | Shape | pyptx | cuBLAS | best / cuBLAS |
|---|---|---|---|---|
| GEMM (tcgen05.mma, 4-stage pipeline, 1SM) | 8192³ | 1240 TFLOPS | 1610 | 77% |
| GEMM (1SM) | 4096³ | 1194 TFLOPS | 1532 | 78% |
| GEMM 2SM (cta_group::2, 5-stage) | 2048³ | 649 TFLOPS (beats 1SM) | 1006 | 64% |
| Grouped GEMM (tcgen05, MoE) | G=4 M=2048 N=256 K=2048 | 401 TFLOPS | torch ref | ~10.0× |
| RMS norm / Layer norm / SwiGLU | maintained Blackwell ports | benchmarked | torch ref | see kernel suite |
Hopper (H100):

| Kernel | Shape | pyptx | vs reference |
|---|---|---|---|
| GEMM (wgmma, warp-specialized) | 8192³ | 815 TFLOPS | beats cuBLAS ≥ 6K |
| Grouped GEMM (bf16→f32) | G=8 M=K=2048 | 104 TFLOPS | — |
| RMS norm (f32) | B=2048 N=8192 | 2.6 TB/s (88% HBM) | 3.9× torch |
| Layer norm (f32) | B=2048 N=8192 | 2.5 TB/s (83% HBM) | 1.5× F.layer_norm |
| SwiGLU (f32) | M=2048 F=8192 | 2.8 TB/s (94% HBM) | 1.6× F.silu(g)*u |
| Softmax (f32, row-wise) | B=2048 N=8192 | 2.8 TB/s (95% HBM) | 1.16× torch.softmax |
| Flash attention (bf16) | M=N=4096, HD=64 | 88 µs | 3.0× naive torch |
Ampere (A100):

| Kernel | Shape | pyptx | vs reference |
|---|---|---|---|
| GEMM (ldmatrix.x4 + cp.async 4-stage + register frag double-buffer + XOR swizzle + serpentine mma.sync) | 4096³ bf16 | 162 TFLOPS | cuBLAS 223 TFLOPS (73%) |
| GEMM (same kernel) | 2048³ bf16 | 108 TFLOPS | cuBLAS 158 TFLOPS (68%) |
| GEMM (simple mma.sync + 2-stage pipeline, teaching kernel) | 4096³ bf16 | 64 TFLOPS | cuBLAS 230 TFLOPS (28%) |
| RMS norm (f32) | B=2048 N=8192 | 928 GB/s | 2.2× torch |
| SwiGLU (f32) | M=2048 F=8192 | 1.33 TB/s | 1.35× F.silu(g)*u |
| Layer norm (f32) | B=2048 N=8192 | 916 GB/s | 0.89× F.layer_norm (torch's fused kernel is hard to beat) |
A100 numbers reproduce via `python benchmarks/bench_ampere_kernels.py`.
The high-perf A100 GEMM follows the CUTLASS SM80 / MatmulTutorial v15 design pattern: 128×128×32 CTA tile, 4 warps in 2×2 owning a 64×64 output sub-tile each, warp-collective `ldmatrix.x4` for SMEM→register fragment loads, a 4-stage `cp.async` ring buffer (3 in flight), register fragment double-buffering that pre-loads the next K-iter's first K-block during the current iter's last `mma`, CUTLASS XOR swizzle (`atom ^= row & 3`) on all SMEM paths to eliminate `ldmatrix` bank conflicts, serpentine N-fragment order for adjacent-`mma` operand reuse, and per-thread offset hoisting so each inner-loop `ldmatrix` is one add instead of 5+ ops. That's 64 `mma.sync.m16n8k16` per warp per K-iter (256 per CTA per K-iter). We haven't spent much time tuning this kernel — the 27% remaining gap is addressable (persistent / stream-K scheduling, more aggressive instruction-level overlap, autotuned tile sizes). See `examples/ampere/gemm_highperf_ampere.py` for the full kernel.
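If the XOR swizzle is new to you, here is a standalone sketch of the address transform in plain Python (not pyptx API; the 8-atoms-per-row tile geometry is an assumption for illustration):

```python
# Standalone sketch of the CUTLASS-style XOR swizzle (plain Python, not pyptx
# API). A 16-byte "atom" is the unit one thread feeds to ldmatrix; the
# 8-atoms-per-row geometry below is an assumption for illustration.
ATOMS_PER_ROW = 8  # e.g. 32 bf16 elements * 2 B / 16 B per atom

def swizzled_offset(row: int, atom: int) -> int:
    """Byte offset of (row, atom) in SMEM after the `atom ^= row & 3` swizzle."""
    atom ^= row & 3
    return (row * ATOMS_PER_ROW + atom) * 16

# Unswizzled, atom 0 of rows 0..3 sits 128 bytes apart, i.e. in the same four
# banks, so a warp-collective ldmatrix.x4 serializes. Swizzled, the four rows
# land in four disjoint bank groups:
for row in range(4):
    off = swizzled_offset(row, 0)
    print(f"row {row}: byte {off:4d}, banks {(off // 4) % 32}-{(off // 4) % 32 + 3}")
```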
Full benchmark tables + reproduction commands: pyptx.dev/performance.
PyTorch dispatch tiers:
- CUDA graph replay: ~4 µs per launch
- Turbo eager: ~14 µs (cached C++ extension)
- `torch.compile`: ~14–22 µs (custom_op path)
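The graph tier is ordinary CUDA-graph capture and replay. A hand-rolled equivalent in stock PyTorch, using the `gemm` kernel defined just below (pyptx's built-in tier manages the capture for you; this sketch only shows why replay is cheap):

```python
import torch

# Minimal CUDA-graph capture around a pyptx kernel launch (a sketch, not the
# library's internal path). Capture pays the Python + dispatch cost once;
# each g.replay() afterwards is a single driver call.
a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

gemm(a, b)  # warm-up launch: JIT-builds the C++ extension, primes allocations

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = gemm(a, b)  # recorded into the graph, not executed

# Overwrite the *same* input buffers in place, then replay:
a.copy_(torch.randn_like(a))
b.copy_(torch.randn_like(b))
g.replay()  # `out` now holds the result for the new inputs
```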
```python
from pyptx import kernel, reg, smem, ptx, Tile
from pyptx.types import bf16, f32

@kernel(
    in_specs=(Tile("M", "K", bf16), Tile("K", "N", bf16)),
    out_specs=(Tile("M", "N", f32),),
    grid=lambda M, N, K: (N // 64, M // 64),
    block=(128, 1, 1),
    arch="sm_90a",
)
def gemm(A, B, C):
    sA = smem.wgmma_tile(bf16, (64, 16), major="K")
    sB = smem.wgmma_tile(bf16, (16, 64), major="MN")
    acc = reg.array(f32, 32)
    # ... TMA loads + ptx.wgmma.mma_async(...) — each call emits exactly one PTX instruction
```

Every `ptx.*` call is a single PTX instruction. `print(gemm.ptx())` shows exactly what you wrote.
The same kernel object works in JAX, PyTorch eager, and `torch.compile`:

```python
# PyTorch eager
out = gemm(a, b)

# torch.compile
out = torch.compile(gemm)(a, b)

# JAX jit (lowers through typed FFI)
out = jax.jit(gemm)(a, b)
```

Under the hood the PTX is JITed through `cuModuleLoadData`, registered with a ~150-line C++ launch shim, and dispatched from PyTorch via `torch.library.custom_op` or from JAX via `jax.ffi.ffi_call`.
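On the PyTorch side, the generated registration is shaped roughly like this sketch (`launch_gemm` is a hypothetical stand-in for the C++ launch shim; pyptx emits the real version for you):

```python
import torch

# Sketch of the custom_op registration pyptx generates (simplified;
# `launch_gemm` is a hypothetical stand-in for the C++ launch shim).
@torch.library.custom_op("pyptx::gemm", mutates_args=())
def gemm_op(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    out = torch.empty((a.shape[0], b.shape[1]), device=a.device, dtype=torch.float32)
    launch_gemm(a, b, out)  # hypothetical shim: cuLaunchKernel on the JITed cuModule
    return out

# The fake (meta) impl gives torch.compile shapes/dtypes to trace against,
# so the op composes with inductor without touching the GPU at trace time.
@gemm_op.register_fake
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a.new_empty((a.shape[0], b.shape[1]), dtype=torch.float32)
```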
pyptx is also a real PTX-to-Python transpiler. Feed it output from
nvcc, Triton, Pallas, or any other source:
```bash
python -m pyptx.codegen kernel.ptx --sugar --name my_kernel > my_kernel.py
```

`--sugar` demangles names, raises spin-loops into `ptx.loop(...)`, collapses mbarrier-wait blocks, and groups expression chains. Round-trips are byte-identical on 218+ corpus files (CUTLASS, Triton, fast.cu, DeepGEMM, ThunderKittens, LLVM tests).
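A quick self-check of that claim, assuming the generated module exposes the kernel object with the same `.ptx()` accessor shown earlier:

```python
import importlib, pathlib, subprocess, sys

# Transpile a PTX file via the documented CLI, then re-emit and compare.
src = pathlib.Path("kernel.ptx").read_text()
with open("my_kernel.py", "w") as f:
    subprocess.run(
        [sys.executable, "-m", "pyptx.codegen", "kernel.ptx",
         "--sugar", "--name", "my_kernel"],
        stdout=f, check=True,
    )

my_kernel = importlib.import_module("my_kernel").my_kernel  # assumes the module
assert my_kernel.ptx() == src, "round-trip is not byte-identical"  # exposes .ptx()
```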
The 815 TFLOPS Hopper GEMM in `examples/hopper/gemm_highperf_hopper.py` is exactly this workflow applied to fast.cu's kernel12.
Ampere (sm_80):
- `examples/ampere/rms_norm.py` / `layer_norm.py` / `swiglu.py` / `softmax.py` — maintained Hopper kernels retargeted to `sm_80`.
- `examples/ampere/gemm.py` — single-warp `mma.sync.aligned.m16n8k16` bf16 GEMM, no SMEM staging. The minimal end-to-end Ampere tensor-core path.
- `examples/ampere/gemm_pipelined.py` — `cp.async` 2-stage SMEM ring buffer + `mma.sync` on a 64×64 CTA tile (per-thread `ld.shared`, no `ldmatrix`). The first-step pipelined kernel (~64 TFLOPS at 4096³).
- `examples/ampere/gemm_highperf_ampere.py` — production-leaning A100 GEMM following CUTLASS SM80 + MatmulTutorial v15. 128×128×32 CTA tile, 4 warps in 2×2 owning 64×64 each, `ldmatrix.x4`, 4-stage `cp.async` pipeline, register frag double-buffering across K-iters, XOR swizzle + serpentine `mma`, 64 `mma.sync` per warp per K-iter. 162 TFLOPS at 4096³ bf16 = 73% of cuBLAS (2.5× the simpler `gemm_pipelined.py`). Bit-exact through 4096³.
- `benchmarks/bench_ampere_kernels.py` — A100 RMSNorm, LayerNorm, SwiGLU, and GEMM benchmark suite.

Hopper (sm_90a):
- `examples/hopper/rms_norm.py` — simplest real kernel, v4 loads + warp reduce
- `examples/hopper/grouped_gemm.py` — multi-k WGMMA for MoE shapes
- `examples/hopper/gemm_highperf_hopper.py` — warp-specialized 815 TFLOPS GEMM

Blackwell (sm_100a):
- `examples/blackwell/tcgen05_suite.py` — 13 isolated tcgen05 primitives (alloc, MMA, ld, commit/fence, GEMM probes). Run this first on a B200 to verify the runtime stack.
- `examples/blackwell/gemm_highperf_blackwell.py` — `build_gemm` (1SM, 4-stage ring buffer, 1.24 PFLOPS at 8192³ bf16) and `build_gemm_2sm` (2SM `cta_group::2` cooperative MMA, 5-stage).
- `examples/blackwell/gemm_experimental_blackwell.py` — persistent and Pallas-style experimental GEMM paths, plus the no-TMA tcgen05 debug GEMM.
- `examples/blackwell/grouped_gemm.py` — G-problem MoE grouped GEMM on top of the same `tcgen05.mma` mainloop, bit-exact against `einsum("gmk,gkn->gmn")` through G=8 M=1024 N=128 K=1024.
- `examples/blackwell/rms_norm.py` / `layer_norm.py` / `swiglu.py` — Hopper kernels re-targeted to `sm_100a`.
- `benchmarks/bench_blackwell_gemm.py` — reproduce the 1SM + 2SM + cuBLAS table above.
- `benchmarks/bench_blackwell_kernels.py` — Blackwell grouped GEMM, RMSNorm, LayerNorm, and SwiGLU benchmark suite.
Docs: pyptx.dev
0.1.0, pre-launch. Scope:
- handwritten PTX DSL with full Hopper ISA (wgmma, TMA 2D/3D, mbarriers, cluster)
- Blackwell `tcgen05` ISA (alloc, `mma.kind::f16/tf32/f8`, `ld`/`st`, commit, fence) with instruction-descriptor + SMEM-descriptor helpers
- PTX parser / emitter with 218+ corpus round-trip tests
- PTX → Python transpiler with sugar pass
- JAX runtime integration (typed FFI)
- PyTorch eager + `torch.compile` + CUDA graph replay
- C++ dispatch extension for low-overhead launches
- GMMA/UMMA SMEM swizzle helpers (B32 / B64 / B128, CuTe-compatible `Swizzle<B,4,3>`)
- PyTorch autograd via `differentiable_kernel`
Apache-2.0. See LICENSE.