LLM inference engine. Rust+CUDA on GPU, JAX+XLA on TPU.
Three Gemma 4 models on TPU v6e-4: E4B (16,794 tok/s peak, 78.3 tok/s B=1, PPL 5.87), 26B-A4B MoE (14,899 tok/s peak), 31B (9,600 tok/s peak, 128K context). GPU: 31B on H100 at 8,786 tok/s (FP8, CUDA graph, PPL 14.75). Zero custom kernels on TPU -- ~500 lines of JAX. Native Rust binary on GPU -- zero Python in the serving path.
| | E4B (4B) | 26B-A4B (MoE) | 31B TPU | 31B GPU | vLLM H100 |
|---|---|---|---|---|---|
| B=1 tok/s | 78.3 | 52.9 | 44.2 | 53 | 66.9 |
| Peak tok/s | 16,794 | 14,899 | 9,600 | 8,786 | 3,848 |
| PPL | 5.87 | 90.21 | 24.76 | 14.75 | - |
| Cached TTFT | 25.9 ms | 35.3 ms | 73.3 ms | 63 ms | - |
| Peak tok/s/$ | 3,230 | 2,865 | 1,846 | 4,576 | 2,004 |
TPU: v6e-4, $5.20/hr, int8, max-ctx 2048. GPU: H100 SXM, $1.92/hr, FP8. Peak tok/s/$ is peak tok/s divided by the hourly price. All measured.
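The Peak tok/s/$ row can be reproduced from the peak throughput and hourly prices above. A minimal Python check of that arithmetic (the dictionary and function names are illustrative, not part of the repo):

```python
# Peak throughput (tok/s) and hourly price ($/hr) from the table and note above.
SYSTEMS = {
    "E4B, TPU v6e-4": (16_794, 5.20),
    "26B-A4B, TPU v6e-4": (14_899, 5.20),
    "31B, TPU v6e-4": (9_600, 5.20),
    "31B, H100 (rvLLM)": (8_786, 1.92),
    "vLLM, H100": (3_848, 1.92),
}


def tok_s_per_dollar_hr(peak_tok_s: float, usd_per_hr: float) -> float:
    """Peak throughput divided by the hourly instance price."""
    return peak_tok_s / usd_per_hr


for name, (tok_s, price) in SYSTEMS.items():
    print(f"{name:<20} {tok_s_per_dollar_hr(tok_s, price):>6,.0f}")
```

Multiplying the result by 3,600 converts tok/s per $/hr into tokens per dollar.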
Pure JAX + XLA. No custom kernels. XLA compiles the entire forward pass to TPU machine code from a ~500-line JAX script. Three models, one codebase.
| Property | E4B (4B) | 26B-A4B (MoE) | 31B |
|---|---|---|---|
| Total / active params | ~4B / 4B | 26B / ~4B | 31B / 31B |
| Layers | 42 | 30 | 60 |
| Hidden size | 2,560 | 2,816 | 5,376 |
| Q / KV heads (sliding) | 8 / 2 | 16 / 8 | 32 / 16 |
| Q / KV heads (global) | 8 / 2 | 16 / 2 (V=K) | 32 / 4 (V=K) |
| Head dim (sliding / global) | 256 / 512 | 256 / 512 | 256 / 512 |
| Sliding window | 512 | 1,024 | 1,024 |
| MoE | none | 128 experts, top-8 | none |
| KV-shared layers | 18 (of 42) | 0 | 0 |
| Per-layer input injection | 256-d gated (5.6 GB embed) | none | none |
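The sliding-window and V=K rows matter most for KV-cache memory at long context. A rough per-layer sizing sketch, assuming 2 bytes per element (bf16), that a V=K global layer caches only K, and that a sliding layer never holds more than its window; paging overhead and the actual sliding/global layer mix are left out, and the function names are hypothetical:

```python
from typing import Optional

MIB = 1024 ** 2


def kv_bytes_per_token(kv_heads: int, head_dim: int, v_equals_k: bool = False,
                       bytes_per_elem: int = 2) -> int:
    """KV bytes one token adds to one layer (K and V, or only K when V is tied to K)."""
    tensors = 1 if v_equals_k else 2
    return tensors * kv_heads * head_dim * bytes_per_elem


def kv_bytes_per_layer(context: int, kv_heads: int, head_dim: int,
                       window: Optional[int] = None, v_equals_k: bool = False,
                       bytes_per_elem: int = 2) -> int:
    """KV bytes for one sequence in one layer; sliding layers keep at most `window` tokens."""
    cached = context if window is None else min(context, window)
    return cached * kv_bytes_per_token(kv_heads, head_dim, v_equals_k, bytes_per_elem)


# 31B at 128K context, bf16, using the head counts from the table above.
ctx = 128 * 1024
sliding = kv_bytes_per_layer(ctx, kv_heads=16, head_dim=256, window=1024)
global_ = kv_bytes_per_layer(ctx, kv_heads=4, head_dim=512, v_equals_k=True)
print(f"sliding layer: {sliding / MIB:.0f} MiB, global layer: {global_ / MIB:.0f} MiB")
```

Under these assumptions a 31B sliding layer stays at 16 MiB regardless of context, while each global layer grows to 512 MiB at 128K in bf16, which is consistent with the 64K and 128K paths switching the KV cache to int8.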
| Batch | E4B tok/s | 26B-A4B tok/s | 31B TPU tok/s | vLLM H100 tok/s |
|---|---|---|---|---|
| 1 | 78 | 53 | 44 | 66.9 |
| 8 | 542 | 390 | 318 | 515 |
| 64 | 3,661 | 2,662 | 2,112 | 2,794 |
| 128 | 6,298 | 4,915 | 3,853 | 3,848 |
| 256 | 10,214 | 8,192 | 6,246 | 3,709 |
| 512 | 13,773 | 12,390 | 8,550 | 3,788 |
| 768 | 15,514 | 14,899 | 9,600 | 3,671 |
| 1024 | 16,794 | - | - | - |
| Context | ms/step | tok/s | Architecture | KV type |
|---|---|---|---|---|
| 512 | 12.79 | 78.2 | Single-scan, 60-layer scan + cond | bf16 |
| 2,048 | 22.6 | 44.2 | Single-scan | bf16 |
| 32K | ~66 | ~15 | Single-scan | bf16 |
| 64K | ~91 | ~11 | Split-cache, 10 groups x 6 | int8 |
| 128K | 40.56 | 24.7 | Split-cache + blockwise global | int8 |
Dual-path architecture auto-switches at the 32K boundary.
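A sketch of that switch, assuming the boundary is inclusive at 32K = 32,768 tokens (the table lists 32K under single-scan) and is keyed on max-ctx; the names are hypothetical, and the blockwise-global refinement used at 128K is not modeled:

```python
from dataclasses import dataclass

SINGLE_SCAN_MAX_CTX = 32 * 1024  # assumed inclusive boundary: 32K still runs single-scan


@dataclass(frozen=True)
class DecodePath:
    layout: str    # how the layer loop and KV cache are organized
    kv_dtype: str  # KV-cache element type


def select_decode_path(max_ctx: int) -> DecodePath:
    """Pick the decode path from max-ctx, mirroring the context table above."""
    if max_ctx <= SINGLE_SCAN_MAX_CTX:
        return DecodePath(layout="single-scan", kv_dtype="bf16")
    return DecodePath(layout="split-cache", kv_dtype="int8")


for ctx in (2_048, 32_768, 65_536, 131_072):
    print(ctx, select_decode_path(ctx))
```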
```bash
# Create TPU v6e-4 ($5.20/hr)
gcloud compute tpus tpu-vm create rvllm-gemma4 \
  --zone=us-east5-b --accelerator-type=v6e-4 --version=v2-alpha-tpuv6e \
  --boot-disk-size=200

# Install (30 seconds)
pip3 install 'jax[tpu]' huggingface_hub tokenizers \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Download model
huggingface-cli download google/gemma-4-E4B-it --local-dir ~/models/gemma-4-E4B-it

# Run E4B (78.3 tok/s B=1)
python3 tpu/harness/gemma4_tpu_infer.py \
  --model-dir ~/models/gemma-4-E4B-it --max-tokens 200 --max-ctx 2048

# Run 31B batched (9,600 tok/s B=768)
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_tpu_scoped_vmem_limit_kib=131072" \
python3 tpu/harness/gemma4_tpu_infer.py \
  --model-dir ~/models/gemma-4-31B-it --fused --max-tokens 200 --max-ctx 2048 --batch 768

# 128K context (24.7 tok/s)
python3 tpu/harness/gemma4_tpu_infer.py \
  --model-dir ~/models/gemma-4-31B-it --fused --max-tokens 200 --max-ctx 131072

# API server (OpenAI-compatible)
python3 tpu/harness/api_server.py --model-dir ~/models/gemma-4-31B-it --port 8080

# Perplexity
python3 tpu/harness/gemma4_tpu_infer.py \
  --model-dir ~/models/gemma-4-31B-it --perplexity --max-ctx 2048
```

No Docker. No conda. No torch. No vLLM. One pip install, one Python file, one command.
450M-param draft head proposes K=5 tokens per cycle; the full 31B verifies K+1=6 in one forward pass. Lossless for greedy decode.
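The accept rule that makes this lossless is the standard greedy speculative-decoding rule: keep draft tokens up to the first disagreement with the verifier, then append the verifier's own token. A minimal Python sketch with toy next-token functions standing in for the 450M draft head and the 31B model (all names are hypothetical; a real engine scores the K+1 positions in one batched forward pass rather than one call per position):

```python
from typing import Callable, List, Sequence

NextToken = Callable[[Sequence[int]], int]  # greedy next-token function: prefix -> token id


def speculative_cycle(prefix: List[int], draft: NextToken, target: NextToken,
                      k: int = 5) -> List[int]:
    """One draft/verify cycle; returns 1 to k+1 tokens, all equal to target's greedy picks."""
    proposal: List[int] = []
    for _ in range(k):                          # draft proposes k tokens autoregressively
        proposal.append(draft(prefix + proposal))

    # Target's greedy token at each of the k+1 positions.
    verified = [target(prefix + proposal[:i]) for i in range(k + 1)]

    accepted: List[int] = []
    for drafted, expected in zip(proposal, verified):
        if drafted != expected:                 # first disagreement ends acceptance
            break
        accepted.append(drafted)
    accepted.append(verified[len(accepted)])    # correction, or bonus token if all k matched
    return accepted


def speculative_generate(prefix: List[int], draft: NextToken, target: NextToken,
                         n: int, k: int = 5) -> List[int]:
    out = list(prefix)
    while len(out) - len(prefix) < n:
        out += speculative_cycle(out, draft, target, k)
    return out[len(prefix):len(prefix) + n]


def greedy_generate(prefix: List[int], target: NextToken, n: int) -> List[int]:
    out = list(prefix)
    for _ in range(n):
        out.append(target(out))
    return out[len(prefix):]


# Toy models: the draft agrees with the target except when len(prefix) is a multiple of 3.
def target(p: Sequence[int]) -> int:
    return (7 * sum(p) + len(p)) % 11


def draft(p: Sequence[int]) -> int:
    return target(p) if len(p) % 3 else (target(p) + 1) % 11


print(speculative_generate([1, 2], draft, target, n=20) == greedy_generate([1, 2], target, n=20))
```

Because every emitted token is one the verifier would have chosen greedily, the output matches plain greedy decoding; only the number of verifier passes changes.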
| Metric | Value |
|---|---|
| Baseline (B=1, 512 ctx) | 78.2 tok/s, 12.79 ms/step |
| EAGLE-3 fused cycle | 31.0 ms/cycle |
| Projected @ tau=3.5 | ~145 tok/s (1.8x) |
| Hardware ceiling | ~300 tok/s (3.8x) |
Reaching a production-level tau requires 50K+ training examples. Current state: 2K examples, loss 7.1; the pipeline is validated end-to-end. See tpu/harness/EAGLE3_SPEC.md.
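The projected figure is consistent with each 31.0 ms cycle emitting tau accepted draft tokens plus one token from the verifier. A back-of-the-envelope sketch under that assumption (it is an interpretation of the table, not a formula taken from EAGLE3_SPEC.md; names are illustrative):

```python
CYCLE_S = 31.0e-3        # EAGLE-3 fused cycle time, from the table
BASELINE_TOK_S = 78.2    # B=1, 512 ctx baseline, from the table


def spec_tok_s(tau: float, cycle_s: float = CYCLE_S) -> float:
    """Assumption: each cycle emits tau accepted draft tokens plus one verifier token."""
    return (tau + 1.0) / cycle_s


def breakeven_tau(baseline_tok_s: float = BASELINE_TOK_S, cycle_s: float = CYCLE_S) -> float:
    """Mean acceptance length at which a spec cycle only matches plain decoding."""
    return baseline_tok_s * cycle_s - 1.0


tok_s = spec_tok_s(3.5)
print(f"tau=3.5 -> {tok_s:.0f} tok/s ({tok_s / BASELINE_TOK_S:.2f}x baseline)")
print(f"break-even tau = {breakeven_tau():.2f}")
```

Under the same assumption, the draft head only pays off once it averages more than about 1.42 accepted tokens per cycle.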
Rust + CUDA on H100 SXM 80GB. FP8 weights with per-channel scales + CUTLASS channelscale epilogue, F16 KV cache, F16 paged attention (FA3 SM90). All 60 layers captured in a single CUDA graph (~935 nodes). 8,786 tok/s peak (B=512), PPL 14.75, TTFT 63 ms.
| Batch | tok/s | ms/step | Scaling |
|---|---|---|---|
| 1 | 53 | 18.7 | 1.0x |
| 8 | 434 | 18.4 | 8.2x |
| 32 | 1,743 | 18.4 | 32.9x |
| 64 | 3,265 | 19.6 | 61.6x |
| 128 | 5,802 | 22.1 | 109.5x |
| 256 | 7,808 | 32.8 | 147.3x |
| 512 | 8,786 | 58.3 | 165.8x |
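Both derived columns follow from batch size and throughput: a decode step emits one token per sequence, so ms/step = 1000 * B / tok/s, and Scaling is tok/s relative to B=1. A short Python check (names are illustrative):

```python
# (batch, tok/s) pairs from the H100 table above.
H100 = [(1, 53), (8, 434), (32, 1_743), (64, 3_265), (128, 5_802), (256, 7_808), (512, 8_786)]


def ms_per_step(batch: int, tok_s: float) -> float:
    """Each decode step emits one token per sequence, so step time = batch / throughput."""
    return 1000.0 * batch / tok_s


base_tok_s = H100[0][1]
for batch, tok_s in H100:
    print(f"B={batch:>3}  {ms_per_step(batch, tok_s):5.1f} ms/step  {tok_s / base_tok_s:6.1f}x")
```

Small differences from the table, such as 18.9 vs 18.7 ms at B=1, come from the rounded 53 tok/s input.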
| Batch | rvLLM tok/s | vLLM tok/s | Delta |
|---|---|---|---|
| 1 | 53 | 69 | -23% |
| 32 | 1,743 | 1,748 | ~0% |
| 64 | 3,265 | 3,130 | +4% |
| 128 | 5,802 | 4,689 | +24% |
| 256 | 7,808 | 7,077 | +10% |
| 512 | 8,786 | 8,243 | +7% |
rvLLM overtakes vLLM at B=64 and leads by 24% at B=128.
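The Delta column is (rvLLM - vLLM) / vLLM. A small sketch that recomputes it and finds the crossover and largest-lead batch sizes from the table (names are illustrative):

```python
# batch -> tok/s, from the comparison table above.
RVLLM = {1: 53, 32: 1_743, 64: 3_265, 128: 5_802, 256: 7_808, 512: 8_786}
VLLM = {1: 69, 32: 1_748, 64: 3_130, 128: 4_689, 256: 7_077, 512: 8_243}


def delta_pct(ours: float, theirs: float) -> float:
    """Relative throughput difference in percent, positive when `ours` is faster."""
    return 100.0 * (ours - theirs) / theirs


deltas = {b: delta_pct(RVLLM[b], VLLM[b]) for b in sorted(RVLLM)}
crossover = min(b for b, d in deltas.items() if d > 0)  # first batch where rvLLM leads
best = max(deltas, key=deltas.get)                       # batch with the largest lead

for b, d in deltas.items():
    print(f"B={b:>3}: {d:+.1f}%")
print("crossover:", crossover, "largest lead at:", best)
```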
| Weight path | KV cache | PPL | tok/s (B=1) |
|---|---|---|---|
| FP8-Dynamic + CUTLASS channelscale epilogue | F16 | 14.75 | 53 |
| BF16 split QKV per-tensor FP8 | F16 | 17.96 | 37.9 |
| F16 weights (no FP8) | F16 | 19.79 | 37.9 |
| HuggingFace BF16 reference | -- | 19.62 | -- |
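Perplexity here is presumably the usual exp of the mean per-token negative log-likelihood; the evaluation text and context windowing used by the --perplexity harness are not specified in this section. A minimal reference implementation of that definition (function names are illustrative):

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax (subtracts the max before exponentiating)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def perplexity(step_logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """exp(mean negative log-likelihood of each target token under its step's logits)."""
    if len(step_logits) != len(targets) or not targets:
        raise ValueError("need one logits row per target token")
    nll = -sum(log_softmax(logits)[t] for logits, t in zip(step_logits, targets))
    return math.exp(nll / len(targets))


# Uniform logits over 4 tokens: perplexity is 4 (up to float rounding).
print(perplexity([[0.0] * 4] * 3, [0, 1, 2]))
```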
```
For each layer in 0..60:
   1. fused_rmsnorm_fp8_quant         input layernorm + FP8 quantize
   2. cutlass_fp8_gemm_channelscale   fused Q||K||V + channelscale epilogue
   3. fused_qkv_rmsnorm               Q/K norm (learned) + V norm (parameter-free)
   4. fused_rope_partial_f16kv        partial RoPE + F16 KV cache write
   5. paged_decode (FA3 SM90)         attention (head_dim=256 sliding, 512 global)
   6. quantize_fp8_per_token          attn output to FP8
   7. fp8_gemm                        O projection
   8. fused_norm_add_residual         channelscale + rmsnorm + residual add
   9. fused_rmsnorm_fp8_quant         pre-FFN layernorm + FP8 quantize
  10. cutlass_fp8_gemm_channelscale   fused gate||up + channelscale epilogue
  11. fused_gelu_mul_fp8_quant        GELU(tanh)(gate) * up to FP8
  12. fp8_gemm                        down projection
  13. fused_norm_add_residual         channelscale + rmsnorm + residual + layer_scalar
Sampling tail:
  fused_rmsnorm   final layernorm
  f16_gemm_f32    lm_head
  logit_softcap   30 * tanh(logits / 30)
  argmax_kernel   token selection
```
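Because tanh is strictly increasing, the softcap cannot change which logit is largest, so for greedy argmax decoding it only matters through floating-point saturation; it does change the softmax distribution, and therefore perplexity. A tiny Python illustration (helper names are illustrative, not the kernel names above):

```python
import math
from typing import List, Sequence


def logit_softcap(logits: Sequence[float], cap: float = 30.0) -> List[float]:
    """cap * tanh(x / cap): squashes every logit into (-cap, cap), preserving order."""
    return [cap * math.tanh(x / cap) for x in logits]


def argmax(xs: Sequence[float]) -> int:
    """Index of the largest value (first one on ties)."""
    return max(range(len(xs)), key=lambda i: xs[i])


logits = [12.0, 45.0, -3.0, 44.0]
capped = logit_softcap(logits)
print([round(c, 2) for c in capped], argmax(logits), argmax(capped))
```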
Four rounds of fusion + custom CUTLASS epilogue reduced graph nodes from 1776 to ~935 (47% reduction):
| Fusion | Kernels eliminated | Nodes saved |
|---|---|---|
| f32_to_bf16 + rmsnorm + vector_add -> fused_norm_add_residual | 3 -> 1 (x2/layer) | 240 |
| scale_cols_f32 fused into norm+add kernel (O-proj, down) | 1 -> 0 (x2/layer) | 120 |
| residual_scale_f16 fused into post-ff norm+add | 1 -> 0 (x1/layer) | 60 |
| vnorm_f16 fused into qk_rmsnorm -> fused_qkv_rmsnorm | 2 -> 1 (x1/layer) | 60 |
| CUTLASS channelscale epilogue (QKV, gate_up) | 3 -> 1 (x2/layer) | 240+ |
The CUTLASS channelscale kernel uses a custom SM90 EVT epilogue that applies per-token activation scale (ColBroadcast) and per-channel weight scale (RowBroadcast) directly in the GEMM epilogue while the accumulator is still F32, then casts to F16.
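The math the epilogue applies can be modeled in plain Python: quantize activations with one scale per token (row) and weights with one scale per output channel (column), accumulate the FP8 products in wide precision, then multiply each accumulator by its row scale and column scale. This is a sketch of the arithmetic only, not the CUTLASS EVT code; E4M3 rounding is approximated (3 mantissa bits, subnormals, saturation at 448, no NaN handling) and all function names are hypothetical:

```python
import math
from typing import List, Sequence, Tuple

FP8_E4M3_MAX = 448.0  # largest finite E4M3 value
Matrix = List[List[float]]


def round_to_e4m3(v: float) -> float:
    """Approximate E4M3 rounding: 3 mantissa bits, subnormals below 2**-6, saturate at 448."""
    if v == 0.0:
        return 0.0
    _, exp = math.frexp(abs(v))            # abs(v) = m * 2**exp with 0.5 <= m < 1
    step = 2.0 ** (max(exp - 1, -6) - 3)   # spacing of representable values near v
    q = round(abs(v) / step) * step
    return math.copysign(min(q, FP8_E4M3_MAX), v)


def absmax_scale(values: Sequence[float]) -> float:
    amax = max(abs(v) for v in values)
    return amax / FP8_E4M3_MAX if amax > 0 else 1.0


def quantize_rows(x: Matrix) -> Tuple[Matrix, List[float]]:
    """Per-token activation quantization: one scale per row of x (M x K)."""
    scales = [absmax_scale(row) for row in x]
    q = [[round_to_e4m3(v / s) for v in row] for row, s in zip(x, scales)]
    return q, scales


def quantize_cols(w: Matrix) -> Tuple[Matrix, List[float]]:
    """Per-channel weight quantization: one scale per output column of w (K x N)."""
    scales = [absmax_scale(col) for col in zip(*w)]
    q = [[round_to_e4m3(v / scales[n]) for n, v in enumerate(row)] for row in w]
    return q, scales


def fp8_gemm_channelscale(qa: Matrix, sa: List[float], qw: Matrix, sw: List[float]) -> Matrix:
    """Wide accumulation of FP8 operands, then the epilogue acc[m][n] * sa[m] * sw[n]."""
    out = []
    for m, a_row in enumerate(qa):
        row = []
        for n in range(len(qw[0])):
            acc = sum(a_row[k] * qw[k][n] for k in range(len(qw)))  # accumulator stays wide
            row.append(acc * sa[m] * sw[n])  # per-token (ColBroadcast) x per-channel (RowBroadcast)
        out.append(row)
    return out


x = [[0.5, -1.25, 3.0], [10.0, 0.1, -2.0]]    # 2 tokens, K = 3
w = [[1.0, -0.5], [0.25, 2.0], [-3.0, 0.75]]   # K x N weight, N = 2
qa, sa = quantize_rows(x)
qw, sw = quantize_cols(w)
print(fp8_gemm_channelscale(qa, sa, qw, sw))
print([[sum(x[m][k] * w[k][n] for k in range(3)) for n in range(2)] for m in range(2)])
```

Applying both scales after accumulation is what lets the GEMM mainloop run entirely on FP8 operands while the result still lands in the original units.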
Help wanted: the current CUTLASS kernel uses a 128x128x128 tile, which is suboptimal for low-batch decode (M <= 16). A smaller tile variant (e.g. 64x64x128) would improve B=1-8 throughput. PRs welcome for additional tile shapes with autotune selection.
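One way to see the low-batch penalty is M-dimension padding: with a 128-row tile, a decode GEMM with M = 8 tokens fills only 8 of the 128 rows. A sketch of that utilization and of a naive tile picker; the picker is a hypothetical heuristic, and a real autotuner would benchmark the candidates rather than count padded rows:

```python
import math
from typing import Sequence


def m_tile_utilization(m: int, tile_m: int) -> float:
    """Fraction of rows across the M tiles that carry real tokens (the rest is padding)."""
    tiles = math.ceil(m / tile_m)
    return m / (tiles * tile_m)


def pick_tile_m(m: int, candidates: Sequence[int] = (64, 128)) -> int:
    """Hypothetical heuristic: fewest padded rows, ties broken toward the larger tile."""
    return min(candidates, key=lambda t: (math.ceil(m / t) * t - m, -t))


for m in (1, 8, 16, 64, 128):
    util = {t: f"{m_tile_utilization(m, t):.1%}" for t in (128, 64)}
    print(f"M={m:>3}  utilization {util}  pick {pick_tile_m(m)}")
```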
```bash
# One-time on H100 box (~15 min)
bash kernels/build.sh              # fused PTX
bash kernels/build_cutlass_so.sh   # libcutlass_kernels.so
bash kernels/build_fa3.sh          # libfa3_kernels.so

# Build
cargo build --release --features cuda --manifest-path v3/Cargo.toml -p rvllm-bench

# Run
RVLLM_MODEL_DIR=/workspace/models/gemma-4-31B-it \
RVLLM_KERNELS_DIR=/workspace/rvllm/kernels/sm_90 \
RVLLM_CUTLASS_SO=/workspace/rvllm/kernels/sm_90/libcutlass_kernels.so \
RVLLM_FA3_SO=/workspace/rvllm/kernels/sm_90/libfa3_kernels.so \
RVLLM_POLICY=/workspace/rvllm/kernels/sm_90/policy.json \
RVLLM_BATCH=128 RVLLM_ITERS=30 RVLLM_WARMUP=5 \
./v3/target/release/rvllm-bench
```

The server is a Rust-only Gemma 4 path with an OpenAI-compatible HTTP surface.
It keeps CUDA execution on a single engine owner thread and accepts requests
through /v1/chat/completions.
For the solidSF agents production shape, including 256K context, four-seat
admission, the paid-plan busy response, CAD harness prompting, systemd service
shape, and verification scripts, see
docs/solidsf-agent-serving.md.
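The single-owner pattern means only one thread ever touches the engine, so request handlers never need locks around it; they hand work to the owner and wait for the reply. A Python analogy of that pattern (the server itself is Rust, and these class and method names are hypothetical, not rvLLM types):

```python
import queue
import threading
from typing import Callable, Optional


class EngineOwner:
    """Owns all engine state on one thread; request handlers only enqueue and wait."""

    def __init__(self, generate: Callable[[str], str]):
        self._generate = generate        # only ever called on the owner thread
        self._inbox = queue.Queue()      # items: (prompt, reply_queue), or None to stop
        self._thread = threading.Thread(target=self._run, name="engine-owner", daemon=True)
        self._thread.start()

    def _run(self) -> None:
        while True:
            item = self._inbox.get()
            if item is None:             # shutdown sentinel
                return
            prompt, reply = item
            try:
                reply.put((True, self._generate(prompt)))
            except Exception as exc:     # surface engine errors to the waiting caller
                reply.put((False, exc))

    def submit(self, prompt: str, timeout: Optional[float] = None) -> str:
        """Thread-safe: callable from any handler thread; blocks until the owner replies."""
        reply: queue.Queue = queue.Queue(maxsize=1)
        self._inbox.put((prompt, reply))
        ok, value = reply.get(timeout=timeout)
        if not ok:
            raise value
        return value

    def shutdown(self) -> None:
        self._inbox.put(None)
        self._thread.join()


if __name__ == "__main__":
    engine = EngineOwner(lambda prompt: prompt.upper())  # stand-in for real decoding
    print(engine.submit("hello", timeout=5))
    engine.shutdown()
```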
```bash
export CUDA_ARCH=sm_90
export RVLLM_MODEL_DIR=/workspace/models/gemma-4-31B-it
export RVLLM_KERNELS_DIR=/workspace/rvllm/kernels/sm_90
export RVLLM_CUTLASS_SO=/workspace/rvllm/kernels/sm_90/libcutlass_kernels.so
export RVLLM_FA3_SO=/workspace/rvllm/kernels/sm_90/libfa3_kernels.so
export RVLLM_POLICY=/workspace/rvllm/kernels/sm_90/policy.json
export RVLLM_SERVED_MODEL_NAME=gemma4-31b
export RUST_LOG=info

bash kernels/build.sh sm_90
bash kernels/build_cutlass_so.sh sm_90
bash kernels/build_fa3.sh
cargo build --release --features cuda,cublaslt --manifest-path v3/Cargo.toml -p rvllm-serve

./v3/target/release/rvllm-server \
  --host 127.0.0.1 \
  --port 8080 \
  --max-model-len 8192 \
  --max-num-seqs 1 \
  --max-num-batched-tokens 2048 \
  --max-prefill-chunk 128
```

The server exposes GET /health, GET /v1/models, and
POST /v1/chat/completions with non-stream and SSE streaming responses.
Only greedy Gemma 4 decoding is currently enabled; set temperature: 0.
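A minimal stdlib Python client for the streaming endpoint, assuming the server emits OpenAI-style chat.completion.chunk events (choices[].delta.content) on `data:` lines and may end the stream with `data: [DONE]`; the function names are illustrative:

```python
import json
import urllib.request
from typing import Iterable, Iterator


def iter_sse_content(lines: Iterable[str]) -> Iterator[str]:
    """Yield delta text from OpenAI-style SSE lines, stopping at `data: [DONE]`."""
    for raw in lines:
        line = raw.strip()
        if not line.startswith("data:"):
            continue                      # blank separators, comments, other SSE fields
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        for choice in chunk.get("choices", []):
            text = choice.get("delta", {}).get("content")
            if text:
                yield text


def stream_chat(prompt: str, url: str = "http://127.0.0.1:8080/v1/chat/completions",
                model: str = "gemma4-31b") -> str:
    body = json.dumps({
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 16,
        "temperature": 0,                 # only greedy decoding is enabled
        "stream": True,
    }).encode("utf-8")
    req = urllib.request.Request(url, data=body, headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req) as resp:   # iterating the response yields byte lines
        return "".join(iter_sse_content(b.decode("utf-8") for b in resp))
```

With the server from the previous step running, `print(stream_chat("hi"))` prints the concatenated streamed text.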
Smoke:

```bash
curl -fsS http://127.0.0.1:8080/health
curl -fsS http://127.0.0.1:8080/v1/models
curl -fsS http://127.0.0.1:8080/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d '{"model":"gemma4-31b","messages":[{"role":"user","content":"Reply exactly: RVLLM_RUST_OK"}],"max_tokens":16,"temperature":0}'
curl -fsS --no-buffer http://127.0.0.1:8080/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d '{"model":"gemma4-31b","messages":[{"role":"user","content":"hi"}],"max_tokens":16,"temperature":0,"stream":true}'
```

For bind-only local checks without CUDA:
```bash
RVLLM_DRY_RUN=1 cargo run --manifest-path v3/Cargo.toml -p rvllm-serve -- \
  --host 127.0.0.1 \
  --port 8080
```

Every kernel has a known purpose, a pinned variant, and a workspace contract. No dispatch fallback chains.
| Kernel | Purpose |
|---|---|
| `cutlass_fp8_gemm_channelscale` | SM90 FP8 GEMM with EVT channelscale epilogue (QKV, gate_up) |
| `fused_rmsnorm_fp8_quant` | layernorm + FP8 quantize in one launch |
| `fused_qkv_rmsnorm` | per-head RMSNorm on Q, K (learned) and V (parameter-free) |
| `fused_rope_partial_f16kv` | partial RoPE + F16 KV cache write |
| `fused_gelu_mul_fp8_quant` | GELU(tanh)(gate) * up to FP8 |
| `fused_norm_add_residual` | channelscale + RMSNorm + residual add (+ optional layer_scalar) |
| `logit_softcap` | 30 * tanh(logits / 30) |
| `quantize_fp8_per_token` | activation to FP8 with per-token scale |
| `argmax` | f32 logits to i32 token |
No fallbacks. Missing kernel .so = engine refuses to start.
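A fail-fast startup check of this kind can be sketched as a sha-pinned manifest lookup that reports every problem and refuses to start rather than picking a substitute. The manifest shape ({filename: sha256}) and the function names are hypothetical; rvllm-kernels does this in Rust with its own manifest format:

```python
import hashlib
from pathlib import Path
from typing import Dict, List


def sha256_file(path: Path) -> str:
    """Hex sha256 of a file, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()


def check_kernels(kernel_dir: Path, manifest: Dict[str, str]) -> List[str]:
    """Return every problem with the pinned artifacts; empty means all present and matching."""
    problems = []
    for name, expected_sha in sorted(manifest.items()):
        path = kernel_dir / name
        if not path.is_file():
            problems.append(f"missing {name}")
        elif sha256_file(path) != expected_sha:
            problems.append(f"sha256 mismatch for {name}")
    return problems


def require_kernels(kernel_dir: Path, manifest: Dict[str, str]) -> None:
    """Fail fast: refuse to start instead of falling back to another kernel."""
    problems = check_kernels(kernel_dir, manifest)
    if problems:
        raise SystemExit("refusing to start: " + "; ".join(problems))
```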
```
v3/crates/
  rvllm-core        typed errors, IDs, dtype, shape, config, env
  rvllm-mem         HbmArena, Region, Stream, Event, PinnedBuf, CudaContextHandle
  rvllm-kernels     manifest (sha-pinned), PTX loader, kernel catalog
  rvllm-fused       8 fused-kernel launchers + pure-Rust f32 references
  rvllm-attention   FA3 SM90 paged decode/prefill dlopen
  rvllm-cutlass     FP8 variant catalog + schedule pairing trait + cuBLASLt wrapper
  rvllm-metadata    frozen-layout metadata per bucket (one upload path)
  rvllm-loader      safetensors mmap -> HBM + CPU-path FP8 quant + clamp gate
  rvllm-sampling    argmax tail, pinned DtoH
  rvllm-graph       captured-graph pool keyed on MetaLayoutHash
  rvllm-runtime     Engine, scheduler, layer_exec, bring_up
  rvllm-bench       RVLLM_* env-driven bench binary
  rvllm-invariants  DAG-dep test, no-megakernel gate
```
- No fallbacks. Missing autotune entry = engine panic. Missing .so = refuse start. No silent degradation.
- Graph-capture invariant. Metadata buffer layout frozen per (bucket, max_blocks_per_seq). Captured graphs bind exact offsets.
- CUTLASS schedule/epilogue pairing. Mainloop and epilogue schedules must match. Enforced via `static_assert`.
- No `unwrap()` in libraries. `Result<T, RvllmError>` end-to-end with structured context.
- Real block-change detection. Scheduler emits block table updates; missing signals = stale KV reads caught at the type level.
Apache-2.0.
- `docs/bench.html` - interactive benchmark results with charts
- `v3/GEMMA4_SPEC.md` - 31B Gemma 4 architecture details and weight shapes
- `v3/SPEC.md`, `v3/IMPL_PLAN.md` - v3 rewrite plan, 16 agent specs
- `tpu/harness/EAGLE3_SPEC.md` - EAGLE-3 speculative decoding spec
- `docs/arch.md` - full crate architecture