FlashGRPO

GRPO cuBLAS and Triton Kernels

FlashGRPO is an implementation of Group Relative Policy Optimization (GRPO) for large language model training that combines cuBLAS and Triton kernels.

Quick Start

Installation

git clone https://github.com/KayneWest/flashgrpo.git
cd flashgrpo
pip install -e .

Basic Usage

import torch
from flashgrpo import GRPOCuBLASxTritonPacked

# Initialize GRPO operator
grpo_op = GRPOCuBLASxTritonPacked(
    temperature=1.0,
    epsilon_low=0.2,
    epsilon_high=0.2,
    beta=0.04,  # KL regularization
)

# Your model forward pass
with torch.autocast("cuda", enabled=False):
    outputs = model(input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    output_hidden_states=True)
    
    # Extract components
    H = outputs.hidden_states[-1][:, :-1, :].to(torch.bfloat16)  # [B,T,K]
    W = model.get_output_embeddings().weight                       # [V,K] 
    targets = batch["input_ids"][:, 1:]                           # [B,T]
    old_logps = batch["old_logps"][:, :-1]                        # [B,T]
    completion_mask = batch["completion_mask"][:, :-1]            # [B,T]
    advantages = batch["advantages"]                              # [B]
    
    # Compute GRPO loss
    loss = grpo_op(H, W, targets, old_logps, completion_mask, advantages)

loss.backward()
optimizer.step()

Architecture

Technical: What FlashGRPO is vs. “normal” GRPO

FlashGRPO implements the standard GRPO objective:

$$ \mathcal{L} = -\mathbb{E}\left[\min\left(r \cdot A,\ \mathrm{clip}(r, 1-\varepsilon_1, 1+\varepsilon_2) \cdot A\right)\right] + \beta \cdot \mathrm{KL}(\pi_{\mathrm{ref}} \,\|\, \pi) $$

but changes how the per‑token log-probabilities and gradients are computed to minimize memory, maximize throughput, and keep numerics stable at BF16/FP16 scale.

What makes FlashGRPO different

1) Streamed vocab math (no full V matrix materialization). Vanilla GRPO typically forms full logits $[N,V]$ (or full log-softmax) per step and then gathers target log‑probs, which costs $O(N \cdot V)$ memory traffic and forces a large temporary. FlashGRPO instead:

  • Packs tokens to keep only valid rows $N \le B \cdot T$ (mask + ignore index).
  • Streams the vocabulary in tiles of width $C$ (auto‑tuned by a memory budget) and runs cuBLAS GEMM $X[N,K] \times W_{v_0:v_1}^T[C,K] \to [N,C]$.
  • Consumes each tile immediately with a Triton reducer that performs online log-sum-exp (tracks per-row $(m,s)$) and gathers the target logit $z$ if it falls in the current tile. No global $[N,V]$ tensor is ever written.

This keeps temporary memory to $O(N \cdot C)$ instead of $O(N \cdot V)$ while preserving full‑vocab accuracy.
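
For illustration, here is a minimal plain-PyTorch sketch of the streamed forward pass. The function name and the Python loop are illustrative only; the actual path uses cuBLAS GEMMs feeding the Triton reducer, but the memory behavior is the same: only one $[N,C]$ tile of logits is ever alive.

import torch

def streamed_target_logprobs(X, W, targets, chunk=4096, temperature=1.0):
    # X: [N, K] packed hidden states (valid rows only)
    # W: [V, K] output embedding weight
    # targets: [N] target token ids
    # Returns per-row target log-probs without writing an [N, V] tensor.
    N, K = X.shape
    V = W.shape[0]
    m = torch.full((N,), float("-inf"), device=X.device)   # running max
    s = torch.zeros(N, device=X.device)                    # running sum of exp(z - m)
    z_tgt = torch.zeros(N, device=X.device)                # gathered target logit

    for v0 in range(0, V, chunk):
        v1 = min(v0 + chunk, V)
        # Tile GEMM: only an [N, C] temporary is alive at a time.
        logits = (X @ W[v0:v1].T).float() / temperature

        # Online log-sum-exp: rescale the running sum to the new max.
        m_new = torch.maximum(m, logits.max(dim=1).values)
        s = s * torch.exp(m - m_new) + torch.exp(logits - m_new[:, None]).sum(dim=1)
        m = m_new

        # Gather the target logit if it falls inside this tile.
        in_tile = (targets >= v0) & (targets < v1)
        rows = in_tile.nonzero(as_tuple=True)[0]
        z_tgt[rows] = logits[rows, targets[rows] - v0]

    return z_tgt - (m + torch.log(s))   # log p(target) per row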

2) Online LSE + softcap for numerics. The Triton reducer maintains $(m, s)$ so that $\mathrm{LSE} = m + \log s$ is exact over all tiles. Optional softcap clamps logits to $\pm \mathrm{softcap}$ during both forward and backward, with a proper straight‑through gate (derivative zero outside the cap). Temperature is applied in‑place during reduction to avoid extra GEMMs.
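
A small sketch of the softcap gate, assuming the hard-clamp semantics described above (illustrative, not the kernel code):

import torch

def softcap_forward_and_gate(z, softcap):
    # Forward: hard-clamp logits to [-softcap, +softcap].
    # Backward: the gate g is 1 inside the cap and 0 outside, so gradients
    # through capped logits are zeroed rather than leaked.
    gate = (z.abs() < softcap).to(z.dtype)
    return z.clamp(-softcap, softcap), gate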

3) Overlapped execution (dual streams). cuBLAS GEMM for tile $t+1$ runs concurrently with Triton reduction for tile $t$ using ping‑pong buffers and CUDA events. This hides reduction latency and keeps GEMM saturated.
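
The schedule, sketched with plain PyTorch streams and events (the buffer, stream, and reduce_tile names are assumptions for illustration; the real pipeline lives inside the packed operator):

import torch

N, K, V, C = 1024, 4096, 32768, 4096          # toy sizes; C divides V here
X = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
W = torch.randn(V, K, device="cuda", dtype=torch.bfloat16)

def reduce_tile(buf, v0, v1):
    pass  # stand-in for the Triton online-LSE reducer

gemm_stream, reduce_stream = torch.cuda.Stream(), torch.cuda.Stream()
gemm_stream.wait_stream(torch.cuda.current_stream())   # X, W ready before GEMMs start
buffers = [torch.empty(N, C, device="cuda", dtype=torch.bfloat16) for _ in range(2)]
gemm_done = [torch.cuda.Event() for _ in range(2)]
reduce_done = [torch.cuda.Event() for _ in range(2)]

for t, v0 in enumerate(range(0, V, C)):
    i = t % 2
    buf = buffers[i]
    with torch.cuda.stream(gemm_stream):
        gemm_stream.wait_event(reduce_done[i])     # don't clobber a tile still being reduced
        torch.matmul(X, W[v0:v0 + C].T, out=buf)   # GEMM for tile t
        gemm_done[i].record(gemm_stream)
    with torch.cuda.stream(reduce_stream):
        reduce_stream.wait_event(gemm_done[i])     # reducer waits only for its own tile
        reduce_tile(buf, v0, v0 + C)               # consume tile t while tile t+1's GEMM runs
        reduce_done[i].record(reduce_stream)
torch.cuda.synchronize()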

4) Custom backward that never forms $p(\cdot)$ explicitly. Instead of backpropagating through a full log-softmax, FlashGRPO reconstructs per‑tile gradients:

$$ \frac{\partial \ell}{\partial \mathrm{logits}_{i,v}} = \underbrace{\frac{\partial \ell}{\partial \log p_i}}_{\text{grad\_row}} \cdot \left( \mathbb{1}[v = y_i] \cdot g_t - p_{i,v} \cdot g_{i,v} \right) $$ (here $g_t$ and $g_{i,v}$ are the softcap straight-through gates from above, equal to 1 where the corresponding logit is uncapped)

then contracts locally to parameters:

  • dX accumulation: $dX += (\partial \ell/\partial \mathrm{logits}) \times W_{\mathrm{tile}}$
  • dW accumulation: $dW_{\mathrm{tile}} += (\partial \ell/\partial \mathrm{logits})^T \times X$
  • dB accumulation (optional): row‑wise sums over $\partial \ell/\partial \mathrm{logits}$

The kernels compute these contractions per tile and stream across $V$, so no $[N,V]$ probability matrix is ever realized. Gradients are accumulated in FP32 and cast back to parameter dtype at the end.
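
A plain-PyTorch sketch of the per-tile backward contraction (illustrative only: FP32 accumulation as described, with temperature and softcap gates omitted; lse is assumed to be the row-wise log-sum-exp saved from the forward pass):

import torch

def grpo_backward_tiles(X, W, targets, grad_row, lse, chunk=4096):
    # X: [N, K] packed hidden states, W: [V, K], targets: [N]
    # grad_row: [N] = d(loss)/d(log p_i); lse: [N] row-wise log-sum-exp from the forward.
    N, K = X.shape
    V = W.shape[0]
    dX = torch.zeros(N, K, device=X.device, dtype=torch.float32)
    dW = torch.zeros(V, K, device=X.device, dtype=torch.float32)

    for v0 in range(0, V, chunk):
        v1 = min(v0 + chunk, V)
        logits = (X @ W[v0:v1].T).float()                 # recompute the [N, C] tile
        p = torch.exp(logits - lse[:, None])              # softmax probs for this tile only
        dlogits = -grad_row[:, None] * p                  # -grad_row * p_{i,v}
        in_tile = (targets >= v0) & (targets < v1)
        rows = in_tile.nonzero(as_tuple=True)[0]
        dlogits[rows, targets[rows] - v0] += grad_row[rows]   # + grad_row * 1[v = y_i]

        dX += dlogits @ W[v0:v1].float()                  # dX accumulation
        dW[v0:v1] += dlogits.T @ X.float()                # dW accumulation for this tile

    return dX.to(X.dtype), dW.to(W.dtype)                 # cast back to parameter dtype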

5) Ratio/clip gradient is selective and stable. FlashGRPO computes token‑wise $\log p - \log p_{\mathrm{old}}$ in FP32 and clamps it to $[-10,10]$ before exponentiating. The gradient flows only through the chosen branch (the unclipped ratio branch when it wins and isn't hard‑clamped); if everything is clipped, a small fallback signal is used to avoid dead zones. The optional KL adds a closed‑form k3 term $\exp(\mathrm{ref}-\log p) - (\mathrm{ref}-\log p) - 1$ with derivative $1 - \exp(\mathrm{ref}-\log p)$.
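
A reference sketch of this finalize step in plain PyTorch (illustrative; the shipped version is the finalize_grpo_1d Triton kernel, and the fallback signal for the fully-clipped case and the optional delta clipping are omitted here):

import torch

def grpo_token_loss(logp, old_logp, adv, eps_low=0.2, eps_high=0.2,
                    ref_logp=None, beta=0.04):
    # logp, old_logp, ref_logp: per packed token; adv: advantage broadcast to each token.
    # Ratio in FP32, clamped before exponentiating for stability.
    log_ratio = (logp - old_logp).float().clamp(-10.0, 10.0)
    ratio = torch.exp(log_ratio)

    unclipped = ratio * adv
    clipped = ratio.clamp(1.0 - eps_low, 1.0 + eps_high) * adv
    loss = -torch.minimum(unclipped, clipped)

    if ref_logp is not None:
        # k3 KL estimator: exp(ref - logp) - (ref - logp) - 1
        d = (ref_logp - logp).float()
        loss = loss + beta * (torch.exp(d) - d - 1.0)
    return loss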

6) BF16/FP16 friendly with FP32 critical paths. Inputs can be BF16/FP16; the reduction/LSE and accumulators are FP32. Chunk size is auto‑selected given $N$, $V$, dtype‑bytes, and a MB budget; everything is 128‑aligned for coalescing.
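
A plausible shape for that heuristic (an assumption for illustration, not the library's exact choose_chunk_size): pick the widest 128-aligned vocab tile whose $[N,C]$ temporary fits the budget.

def choose_chunk_size_sketch(N_rows, V, dtype_bytes=2, max_temp_mb=768):
    budget = max_temp_mb * 1024 * 1024
    c = budget // (N_rows * dtype_bytes)    # widest tile whose [N, C] temp fits the budget
    c = max(128, (c // 128) * 128)          # 128-align for coalesced access
    return min(c, V)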

The Triton kernels reduce_chunk_online_lse (forward LSE + target gather), finalize_grpo_1d (ratio/clip + advantage), and the backward pair bwd_dx_from_tile / bwd_dwdb_from_tile handle all row-packing, softcap gating, and FP32 accumulation. Scheduling uses ping-pong buffers with CUDA events, and chunk sizing is chosen via choose_chunk_size given a memory budget.

NOTE: Inspired heavily by both Liger Kernel and FlashAttention.

Usage

Configuration Options

grpo_op = GRPOCuBLASxTritonPacked(
    temperature=1.0,        # Temperature scaling
    epsilon_low=0.2,        # Lower clipping bound  
    epsilon_high=0.2,       # Upper clipping bound
    delta=None,             # Optional delta clipping
    softcap=30.0,           # Logit clamping
    chunk_size=4096,        # Fixed chunk size (auto if None)
    max_temp_mb=768,        # Memory budget for auto chunking
    beta=0.04,              # KL regularization coefficient
)

With KL Regularization

# Include reference model log probabilities
ref_logps = get_reference_logps(batch)  # [B,T]

loss = grpo_op(
    H, W, targets, old_logps, completion_mask, advantages,
    ref_per_token_logps=ref_logps
)

Memory-Efficient Streaming

from flashgrpo.utils import estimate_memory_usage, choose_chunk_size

# Estimate memory requirements
memory_stats = estimate_memory_usage(B=4, T=512, K=4096, V=32000)
print(f"Estimated peak memory: {memory_stats['backward_peak']:.1f} MB")

# Choose optimal chunk size
chunk_size = choose_chunk_size(N_rows=1024, V=32000, dtype_bytes=2)

Benchmarking

FlashGRPO includes comprehensive benchmarking tools:

from flashgrpo.utils import run_performance_comparison, profile_memory_usage

# Performance comparison across configurations
results = run_performance_comparison(
    B=4, T=512, K=4096, V=32000,
    chunk_sizes=[2048, 4096, 8192, None],
    num_trials=50
)

# Memory profiling
memory_stats = profile_memory_usage(B=2, T=256, K=2048, V=16000)

Algorithm

FlashGRPO implements GRPO with the following objective:

$$ L = -\mathbb{E}\left[\min\left(r \cdot A,\ \mathrm{clip}(r, 1-\varepsilon_1, 1+\varepsilon_2) \cdot A\right)\right] + \beta \cdot \mathrm{KL}(\pi_{\mathrm{ref}} \,\|\, \pi) $$

Where:

  • $r = \exp(\log \pi - \log \pi_{\mathrm{old}})$ is the probability ratio
  • $A$ are the advantages
  • $\varepsilon_1, \varepsilon_2$ are clipping bounds
  • $\beta$ is the KL regularization coefficient
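
For reference, a naive implementation of the same objective that materializes the full log-softmax FlashGRPO streams away; useful as a correctness baseline (the function name and the masked-mean reduction are assumptions, not part of the library):

import torch
import torch.nn.functional as F

def grpo_loss_reference(logits, targets, old_logps, mask, advantages,
                        eps_low=0.2, eps_high=0.2, temperature=1.0):
    # logits: [B, T, V] full logits (exactly the tensor FlashGRPO avoids building)
    # targets, old_logps, mask: [B, T]; advantages: [B]
    logps = F.log_softmax(logits.float() / temperature, dim=-1)
    per_tok = logps.gather(-1, targets.unsqueeze(-1)).squeeze(-1)   # [B, T]

    ratio = torch.exp(per_tok - old_logps)
    unclipped = ratio * advantages[:, None]
    clipped = ratio.clamp(1 - eps_low, 1 + eps_high) * advantages[:, None]
    per_tok_loss = -torch.minimum(unclipped, clipped)

    return (per_tok_loss * mask).sum() / mask.sum().clamp_min(1)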

Running Tests

# Run all tests with coverage
python run_tests.py all --coverage

# Or use pytest directly
pytest tests/ -v --cov=flashgrpo
