# GRPO cuBLAS and Triton Kernels

FlashGRPO is an implementation of Group Relative Policy Optimization (GRPO) for large language model training that combines cuBLAS and Triton kernels.
## Installation

```bash
git clone https://github.com/KayneWest/flashgrpo.git
cd flashgrpo
pip install -e .
```

## Quick Start

```python
import torch
from flashgrpo import GRPOCuBLASxTritonPacked
# Initialize GRPO operator
grpo_op = GRPOCuBLASxTritonPacked(
    temperature=1.0,
    epsilon_low=0.2,
    epsilon_high=0.2,
    beta=0.04,  # KL regularization
)
# Your model forward pass
with torch.cuda.amp.autocast(False):
    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        output_hidden_states=True,
    )
# Extract components
H = outputs.hidden_states[-1][:, :-1, :].to(torch.bfloat16) # [B,T,K]
W = model.get_output_embeddings().weight # [V,K]
targets = batch["input_ids"][:, 1:] # [B,T]
old_logps = batch["old_logps"][:, :-1] # [B,T]
completion_mask = batch["completion_mask"][:, :-1] # [B,T]
advantages = batch["advantages"] # [B]
# Compute GRPO loss
loss = grpo_op(H, W, targets, old_logps, completion_mask, advantages)
loss.backward()
optimizer.step()
```

FlashGRPO implements the standard GRPO objective (given in full below), but changes how the per-token log-probabilities and gradients are computed to minimize memory, maximize throughput, and keep numerics stable at BF16/FP16 scale.
1) Streamed vocab math (no full V matrix materialization).
Vanilla GRPO typically forms the full logits tensor and gathers the target log-probs from it, which costs $O(B \cdot T \cdot V)$ memory. FlashGRPO instead:

- Packs tokens to keep only valid rows $N \le B \cdot T$ (mask + ignore index).
- Streams the vocabulary in tiles of width $C$ (auto-tuned by a memory budget) and runs a cuBLAS GEMM $X[N,K] \times W_{v_0:v_1}^T[C,K] \to [N,C]$.
- Consumes each tile immediately with a Triton reducer that performs online log-sum-exp (tracks per-row $(m,s)$) and gathers the target logit $z$ if it falls in the current tile. No global $[N,V]$ tensor is ever written.

This keeps temporary memory at $O(N \cdot C)$ per tile rather than $O(N \cdot V)$.
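For intuition, here is a minimal eager PyTorch sketch of the packed-row, tile-streamed log-prob computation (illustrative only; the library performs the GEMM with cuBLAS and the reduction with the `reduce_chunk_online_lse` Triton kernel):

```python
import torch

def streamed_target_logprobs(X, W, targets, chunk=4096):
    """Reference (eager) version of the streamed per-row log-probs.

    X:       [N, K] packed hidden states (only valid rows)
    W:       [V, K] output embedding / LM-head weight
    targets: [N]    target token id per packed row
    Returns  [N]    log p(target), computed without a full [N, V] tensor.
    """
    N, V = X.shape[0], W.shape[0]
    m = torch.full((N,), float("-inf"), device=X.device)  # running max
    s = torch.zeros(N, device=X.device)                   # running sum of exp(logit - m)
    z = torch.zeros(N, device=X.device)                   # target logit, filled when its tile arrives

    for v0 in range(0, V, chunk):
        v1 = min(v0 + chunk, V)
        logits = X.float() @ W[v0:v1].float().t()          # [N, C] tile (cuBLAS GEMM in the real kernel)

        # Online log-sum-exp update: fold the tile into (m, s).
        m_tile = logits.max(dim=1).values
        m_new = torch.maximum(m, m_tile)
        s = s * torch.exp(m - m_new) + torch.exp(logits - m_new[:, None]).sum(dim=1)
        m = m_new

        # Gather the target logit if it falls inside this tile.
        in_tile = (targets >= v0) & (targets < v1)
        if in_tile.any():
            rows = in_tile.nonzero(as_tuple=True)[0]
            z[rows] = logits[rows, targets[rows] - v0]

    return z - (m + torch.log(s))  # log p(target) = z - logsumexp(logits)
```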
2) Online LSE + softcap for numerics.
The Triton reducer maintains a per-row running maximum $m$ and running sum $s$ in FP32 and folds each tile into them online, so the exact log-sum-exp over the full vocabulary is recovered one tile at a time. An optional softcap bounds extreme logits before the reduction, keeping BF16/FP16 inputs well-conditioned.
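Merging a tile's statistics $(m_{\mathrm{tile}}, s_{\mathrm{tile}})$ into the running state uses the standard online log-sum-exp update:

$$
m' = \max(m, m_{\mathrm{tile}}), \qquad s' = s\, e^{m - m'} + s_{\mathrm{tile}}\, e^{m_{\mathrm{tile}} - m'}, \qquad \mathrm{LSE} = m' + \log s'.
$$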
3) Overlapped execution (dual streams).
The cuBLAS GEMM for tile $i+1$ runs on one CUDA stream while the Triton reducer consumes tile $i$ on another; ping-pong buffers and CUDA events keep the two in step, so GEMM and reduction time overlap rather than add.
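Roughly, the scheduling looks like the following sketch (illustrative; the library drives this with its own ping-pong buffers and events, and the consumer is the Triton reducer rather than a Python callback):

```python
import torch

def overlapped_tiles(X, W, consume_tile, chunk=4096):
    """Sketch of the dual-stream pipeline: the GEMM for the next tile overlaps
    with consumption (online-LSE reduction) of the current one."""
    V = W.shape[0]
    gemm_stream = torch.cuda.Stream()
    reduce_stream = torch.cuda.Stream()

    for v0 in range(0, V, chunk):
        v1 = min(v0 + chunk, V)
        ready = torch.cuda.Event()
        with torch.cuda.stream(gemm_stream):
            tile = X @ W[v0:v1].t()            # cuBLAS GEMM for this tile
            ready.record(gemm_stream)          # signal: tile has been computed
        with torch.cuda.stream(reduce_stream):
            reduce_stream.wait_event(ready)    # block only on this tile's GEMM
            tile.record_stream(reduce_stream)  # keep the tile alive for the consumer stream
            consume_tile(tile, v0, v1)         # Triton reducer in the real kernels
    torch.cuda.synchronize()
```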
4) Custom backward that never forms $[N,V]$.
The backward re-streams the vocabulary, rebuilds each tile's $\partial \ell/\partial \mathrm{logits}$, then contracts locally to parameters:

- dX accumulation: $dX += (\partial \ell/\partial \mathrm{logits}) \times W_{\mathrm{tile}}$
- dW accumulation: $dW_{\mathrm{tile}} += (\partial \ell/\partial \mathrm{logits})^T \times X$
- dB accumulation (optional): row-wise sums over $\partial \ell/\partial \mathrm{logits}$

The kernels compute these contractions per tile and stream across the vocabulary, so no $[N,V]$ logits or gradient tensor is ever materialized.
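A minimal eager sketch of those per-tile contractions (the helper `make_dlogits_tile` is hypothetical and stands in for the tile-local $\partial \ell/\partial \mathrm{logits}$ that the real kernels rebuild from the saved LSE statistics and per-token GRPO weights):

```python
import torch

def backward_from_tiles(X, W, make_dlogits_tile, chunk=4096):
    """Accumulate dX and dW one vocab tile at a time; no [N, V] tensor is formed.

    make_dlogits_tile(v0, v1) -> [N, C] gradient of the loss w.r.t. that tile's
    logits. Illustrative helper, not part of the library API.
    """
    N, K = X.shape
    V = W.shape[0]
    dX = torch.zeros(N, K, dtype=torch.float32, device=X.device)
    dW = torch.zeros(V, K, dtype=torch.float32, device=X.device)

    for v0 in range(0, V, chunk):
        v1 = min(v0 + chunk, V)
        dlogits = make_dlogits_tile(v0, v1)       # [N, C], FP32
        dX += dlogits @ W[v0:v1].float()          # [N, C] x [C, K] -> [N, K]
        dW[v0:v1] += dlogits.t() @ X.float()      # [C, N] x [N, K] -> [C, K]
    return dX, dW
```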
5) Ratio/clip gradient is selective and stable.
FlashGRPO computes the token-wise ratio $r = \exp(\log \pi - \log \pi_{\mathrm{old}})$ and the clipped surrogate in a dedicated Triton kernel (finalize_grpo_1d), applying the advantage and the clip gate per token so only the active branch of the min/clip contributes to the gradient.
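A tiny numeric illustration of that selectivity (with a positive advantage, a token whose ratio exceeds $1+\varepsilon_2$ hits the clipped branch and receives no gradient):

```python
import torch

# Selective clip gradient demo (advantage A = 1, eps = 0.2).
logps = torch.tensor([0.0, 0.5, -0.5], requires_grad=True)     # current log-probs
old_logps = torch.zeros(3)                                      # old log-probs
ratio = torch.exp(logps - old_logps)                            # ~[1.00, 1.65, 0.61]
surrogate = torch.minimum(ratio, torch.clamp(ratio, 0.8, 1.2))  # min(r*A, clip(r)*A)
(-surrogate.sum()).backward()
print(logps.grad)  # middle token (ratio > 1.2) -> gradient 0; the others still flow
```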
6) BF16/FP16 friendly with FP32 critical paths.
Inputs can be BF16/FP16; the reduction/LSE and accumulators are FP32. Chunk size is auto-selected from the number of packed rows, the vocabulary size, and the max_temp_mb memory budget when chunk_size is not fixed.
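The auto-selection amounts to fitting the widest tile within the budget; a rough sketch (not the library's exact `choose_chunk_size` logic):

```python
def pick_chunk_size(n_rows, vocab_size, dtype_bytes=2, max_temp_mb=768,
                    multiple_of=256):
    """Pick the widest vocab tile C such that an [N, C] tile in the input dtype,
    plus an FP32 workspace of the same shape, stays within the memory budget."""
    budget_bytes = max_temp_mb * 1024 * 1024
    bytes_per_col = n_rows * (dtype_bytes + 4)   # tile column + FP32 shadow column
    c = max(budget_bytes // bytes_per_col, multiple_of)
    c = (c // multiple_of) * multiple_of         # round down to a friendly multiple
    return min(c, vocab_size)
```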
The Triton kernels `reduce_chunk_online_lse` (forward LSE + target gather), `finalize_grpo_1d` (ratio/clip + advantage), and the backward pair `bwd_dx_from_tile` / `bwd_dwdb_from_tile` handle all row-packing, softcap gating, and FP32 accumulation; scheduling uses ping-pong buffers with CUDA events; chunk sizing is chosen via `choose_chunk_size` given a memory budget.
NOTE: inspired heavily by both Liger Kernel and Flash Attention.

## Configuration

```python
grpo_op = GRPOCuBLASxTritonPacked(
    temperature=1.0,     # Temperature scaling
    epsilon_low=0.2,     # Lower clipping bound
    epsilon_high=0.2,    # Upper clipping bound
    delta=None,          # Optional delta clipping
    softcap=30.0,        # Logit clamping
    chunk_size=4096,     # Fixed chunk size (auto if None)
    max_temp_mb=768,     # Memory budget for auto chunking
    beta=0.04,           # KL regularization coefficient
)
```

To include the KL term against a reference model, pass its per-token log-probs:

```python
# Include reference model log probabilities
ref_logps = get_reference_logps(batch) # [B,T]
loss = grpo_op(
    H, W, targets, old_logps, completion_mask, advantages,
    ref_per_token_logps=ref_logps,
)
```

Helpers are provided for memory estimation and chunk sizing:

```python
from flashgrpo.utils import estimate_memory_usage, choose_chunk_size
# Estimate memory requirements
memory_stats = estimate_memory_usage(B=4, T=512, K=4096, V=32000)
print(f"Estimated peak memory: {memory_stats['backward_peak']:.1f} MB")
# Choose optimal chunk size
chunk_size = choose_chunk_size(N_rows=1024, V=32000, dtype_bytes=2)
```

FlashGRPO includes comprehensive benchmarking tools:

```python
from flashgrpo.utils import run_performance_comparison, profile_memory_usage
# Performance comparison across configurations
results = run_performance_comparison(
    B=4, T=512, K=4096, V=32000,
    chunk_sizes=[2048, 4096, 8192, None],
    num_trials=50
)
# Memory profiling
memory_stats = profile_memory_usage(B=2, T=256, K=2048, V=16000)
```

FlashGRPO implements GRPO with the following objective:

$$
\mathcal{L}_{\mathrm{GRPO}} = -\,\mathbb{E}_t\Big[\min\big(r_t A,\ \mathrm{clip}(r_t,\, 1-\varepsilon_1,\, 1+\varepsilon_2)\, A\big) - \beta\, \mathbb{D}_{\mathrm{KL}}\big[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\big]\Big]
$$

where the expectation is a masked mean over completion tokens and:

- $r = \exp(\log \pi - \log \pi_{\mathrm{old}})$ is the probability ratio
- $A$ are the advantages
- $\varepsilon_1, \varepsilon_2$ are the clipping bounds
- $\beta$ is the KL regularization coefficient
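For reference, a plain eager implementation of this objective (a sketch: it assumes the exp-based per-token KL estimator $e^{\log\pi_{\mathrm{ref}}-\log\pi} - (\log\pi_{\mathrm{ref}}-\log\pi) - 1$ and simple masked-mean normalization; the fused kernels may differ in these details):

```python
import torch

def grpo_loss_reference(logps, old_logps, ref_logps, advantages, mask,
                        eps_low=0.2, eps_high=0.2, beta=0.04):
    """Eager reference of the GRPO objective above.

    logps, old_logps, ref_logps, mask: [B, T]; advantages: [B].
    """
    ratio = torch.exp(logps - old_logps)                    # r = pi / pi_old per token
    adv = advantages.unsqueeze(1)                           # broadcast [B] -> [B, 1]
    surrogate = torch.minimum(
        ratio * adv,
        torch.clamp(ratio, 1 - eps_low, 1 + eps_high) * adv,
    )
    # Per-token KL(pi || pi_ref) estimator (non-negative).
    kl = torch.exp(ref_logps - logps) - (ref_logps - logps) - 1.0
    per_token = -(surrogate - beta * kl)
    return (per_token * mask).sum() / mask.sum().clamp(min=1)
```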
## Testing

```bash
# Run all tests with coverage
python run_tests.py all --coverage
# Or use pytest directly
pytest tests/ -v --cov=flashgrpo
```