m0at/rvllm

rvLLM

LLM inference engine. Rust+CUDA on GPU, JAX+XLA on TPU.

Three Gemma 4 models on TPU v6e-4: E4B (16,794 tok/s peak, 78.3 tok/s at B=1, PPL 5.87), 26B-A4B MoE (14,899 tok/s peak), and 31B (9,600 tok/s peak, 128K context). On GPU, the 31B model runs on H100 at 8,786 tok/s peak (FP8, CUDA graph, PPL 14.75). The TPU path uses zero custom kernels, only ~500 lines of JAX. The GPU path is a native Rust binary with zero Python in the serving path.

Full benchmarks

At a glance

| Metric | E4B (4B) | 26B-A4B (MoE) | 31B TPU | 31B GPU | vLLM H100 |
|---|---|---|---|---|---|
| B=1 tok/s | 78.3 | 52.9 | 44.2 | 53 | 66.9 |
| Peak tok/s | 16,794 | 14,899 | 9,600 | 8,786 | 3,848 |
| PPL | 5.87 | 90.21 | 24.76 | 14.75 | - |
| Cached TTFT | 25.9 ms | 35.3 ms | 73.3 ms | 63 ms | - |
| Peak tok/s/$ | 3,230 | 2,865 | 1,846 | 4,576 | 2,004 |

TPU: v6e-4, $5.20/hr, int8, max-ctx 2048. GPU: H100 SXM, $1.92/hr, FP8. All measured.
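
The Peak tok/s/$ row is consistent with dividing each peak throughput by the hourly price quoted above. A minimal sketch of that arithmetic (the function name is illustrative and not part of rvLLM):

```python
# Hourly prices and peak throughputs as quoted in the table and footnote above.
TPU_V6E4_USD_PER_HR = 5.20
H100_USD_PER_HR = 1.92


def peak_tok_s_per_dollar(peak_tok_s: float, usd_per_hr: float) -> int:
    """Peak tokens/second per dollar of hourly cost, rounded to the nearest integer."""
    return round(peak_tok_s / usd_per_hr)


ROWS = {
    "E4B (TPU)": (16_794, TPU_V6E4_USD_PER_HR),
    "26B-A4B (TPU)": (14_899, TPU_V6E4_USD_PER_HR),
    "31B (TPU)": (9_600, TPU_V6E4_USD_PER_HR),
    "31B (GPU)": (8_786, H100_USD_PER_HR),
    "vLLM H100": (3_848, H100_USD_PER_HR),
}

for name, (peak, price) in ROWS.items():
    print(f"{name:>14}: {peak_tok_s_per_dollar(peak, price):,} tok/s per $/hr")
```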

TPU: Gemma 4 on v6e-4

Pure JAX + XLA. No custom kernels. XLA compiles the entire forward pass to TPU machine code from a ~500-line JAX script. Three models, one codebase.

Models supported

| Property | E4B (4B) | 26B-A4B (MoE) | 31B |
|---|---|---|---|
| Total / active params | ~4B / 4B | 26B / ~4B | 31B / 31B |
| Layers | 42 | 30 | 60 |
| Hidden size | 2,560 | 2,816 | 5,376 |
| Q / KV heads (sliding) | 8 / 2 | 16 / 8 | 32 / 16 |
| Q / KV heads (global) | 8 / 2 | 16 / 2 (V=K) | 32 / 4 (V=K) |
| Head dim (sliding / global) | 256 / 512 | 256 / 512 | 256 / 512 |
| Sliding window | 512 | 1,024 | 1,024 |
| MoE | none | 128 experts, top-8 | none |
| KV-shared layers | 18 (of 42) | 0 | 0 |
| Per-layer input injection | 256-d gated (5.6 GB embed) | none | none |

Batch scaling (max-ctx 2048)

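| Batch | E4B tok/s | 26B-A4B tok/s | 31B tok/s | vLLM H100 tok/s |
|---|---|---|---|---|
| 1 | 78 | 53 | 44 | 66.9 |
| 8 | 542 | 390 | 318 | 515 |
| 64 | 3,661 | 2,662 | 2,112 | 2,794 |
| 128 | 6,298 | 4,915 | 3,853 | 3,848 |
| 256 | 10,214 | 8,192 | 6,246 | 3,709 |
| 512 | 13,773 | 12,390 | 8,550 | 3,788 |
| 768 | 15,514 | 14,899 | 9,600 | 3,671 |
| 1024 | 16,794 | - | - | - |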

31B context scaling (B=1)

| Context | ms/step | tok/s | Architecture | KV type |
|---|---|---|---|---|
| 512 | 12.79 | 78.2 | Single-scan, 60-layer scan + cond | bf16 |
| 2,048 | 22.6 | 44.2 | Single-scan | bf16 |
| 32K | ~66 | ~15 | Single-scan | bf16 |
| 64K | ~91 | ~11 | Split-cache, 10 groups x 6 | int8 |
| 128K | 40.56 | 24.7 | Split-cache + blockwise global | int8 |

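The dual-path architecture auto-switches at the 32K boundary: single-scan with a bf16 KV cache up to 32K context, split-cache with an int8 KV cache beyond it.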
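
As a rough illustration of the table above, the sketch below picks a cache path from the context length and converts ms/step to tok/s. `cache_path` and its return labels are hypothetical names, not the harness's API, and treating the 32K boundary as inclusive is an assumption read off the 32K row.

```python
SINGLE_SCAN_MAX_CTX = 32 * 1024  # boundary from the text; inclusive by assumption


def cache_path(max_ctx: int) -> str:
    """Hypothetical selector: single-scan bf16 up to 32K, split-cache int8 beyond."""
    if max_ctx <= SINGLE_SCAN_MAX_CTX:
        return "single-scan/bf16"
    return "split-cache/int8"


def tok_per_s(ms_per_step: float, batch: int = 1) -> float:
    """Decode throughput implied by a step time: one token per sequence per step."""
    return batch * 1000.0 / ms_per_step


for ctx, ms in [(512, 12.79), (2048, 22.6), (131072, 40.56)]:
    print(f"ctx={ctx:>6}  path={cache_path(ctx):<17}  {tok_per_s(ms):.1f} tok/s")
```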

TPU deployment

# Create TPU v6e-4 ($5.20/hr)
gcloud compute tpus tpu-vm create rvllm-gemma4 \
  --zone=us-east5-b --accelerator-type=v6e-4 --version=v2-alpha-tpuv6e \
  --boot-disk-size=200

# Install (30 seconds)
pip3 install 'jax[tpu]' huggingface_hub tokenizers \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Download model
huggingface-cli download google/gemma-4-E4B-it --local-dir ~/models/gemma-4-E4B-it

# Run E4B (78.3 tok/s B=1)
python3 tpu/harness/gemma4_tpu_infer.py \
  --model-dir ~/models/gemma-4-E4B-it --max-tokens 200 --max-ctx 2048

# Run 31B batched (9,600 tok/s B=768)
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \
  --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
  --xla_tpu_enable_async_collective_fusion_multiple_steps=true \
  --xla_tpu_overlap_compute_collective_tc=true \
  --xla_tpu_scoped_vmem_limit_kib=131072" \
python3 tpu/harness/gemma4_tpu_infer.py \
  --model-dir ~/models/gemma-4-31B-it --fused --max-tokens 200 --max-ctx 2048 --batch 768

# 128K context (24.7 tok/s)
python3 tpu/harness/gemma4_tpu_infer.py \
  --model-dir ~/models/gemma-4-31B-it --fused --max-tokens 200 --max-ctx 131072

# API server (OpenAI-compatible)
python3 tpu/harness/api_server.py --model-dir ~/models/gemma-4-31B-it --port 8080

# Perplexity
python3 tpu/harness/gemma4_tpu_infer.py \
  --model-dir ~/models/gemma-4-31B-it --perplexity --max-ctx 2048

No Docker. No conda. No torch. No vLLM. One pip install, one Python file, one command.

EAGLE-3 Speculative Decoding (TPU, experimental)

450M-param draft head proposes K=5 tokens per cycle; the full 31B verifies K+1=6 in one forward pass. Lossless for greedy decode.

| Metric | Value |
|---|---|
| Baseline (B=1, 512 ctx) | 78.2 tok/s, 12.79 ms/step |
| EAGLE-3 fused cycle | 31.0 ms/cycle |
| Projected @ tau=3.5 | ~145 tok/s (1.8x) |
| Hardware ceiling | ~300 tok/s (3.8x) |

Reaching a production-grade tau requires 50K+ training examples. The current draft head was trained on 2K examples (loss 7.1), and the pipeline is validated end-to-end. See tpu/harness/EAGLE3_SPEC.md.
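
One reading that reproduces the ~145 tok/s projection: if tau is the average number of draft tokens accepted per cycle and each cycle also emits one token from the verifier pass, a cycle yields tau + 1 tokens. That accounting is an assumption here; EAGLE3_SPEC.md is the authority. A small sketch under that assumption:

```python
BASELINE_MS_PER_STEP = 12.79  # B=1, 512 ctx, from the table above
CYCLE_MS = 31.0               # EAGLE-3 fused cycle, from the table above


def projected_tok_s(tau: float, cycle_ms: float = CYCLE_MS) -> float:
    """Tokens/second if each cycle yields tau accepted drafts plus one verifier token."""
    return (tau + 1.0) * 1000.0 / cycle_ms


def break_even_tau(cycle_ms: float = CYCLE_MS,
                   step_ms: float = BASELINE_MS_PER_STEP) -> float:
    """Smallest tau at which a cycle beats plain decoding: (tau + 1) / cycle > 1 / step."""
    return cycle_ms / step_ms - 1.0


baseline_tok_s = 1000.0 / BASELINE_MS_PER_STEP
proj = projected_tok_s(3.5)
print(f"projected: {proj:.0f} tok/s ({proj / baseline_tok_s:.2f}x baseline)")
print(f"break-even tau: {break_even_tau():.2f}")
```

Under the same assumption, speculative decoding only pays off once tau exceeds the break-even value, which is why the quality of the draft head matters more than the cycle time.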

GPU: 31B Gemma 4 on H100

Rust + CUDA on H100 SXM 80GB. FP8 weights with per-channel scales + CUTLASS channelscale epilogue, F16 KV cache, F16 paged attention (FA3 SM90). All 60 layers captured in a single CUDA graph (~935 nodes). 8,786 tok/s peak (B=512), PPL 14.75, TTFT 63 ms.

GPU batch scaling

| Batch | tok/s | ms/step | Scaling |
|---|---|---|---|
| 1 | 53 | 18.7 | 1.0x |
| 8 | 434 | 18.4 | 8.2x |
| 32 | 1,743 | 18.4 | 32.9x |
| 64 | 3,265 | 19.6 | 61.6x |
| 128 | 5,802 | 22.1 | 109.5x |
| 256 | 7,808 | 32.8 | 147.3x |
| 512 | 8,786 | 58.3 | 165.8x |
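
The Scaling column is throughput relative to the B=1 row, and each tok/s figure is close to batch / step time. A short sketch that recomputes both and adds a per-sequence efficiency figure (the efficiency metric is an illustration, not something the table reports):

```python
# (batch, tok/s, ms/step) rows copied from the GPU batch scaling table above.
ROWS = [
    (1, 53, 18.7), (8, 434, 18.4), (32, 1_743, 18.4), (64, 3_265, 19.6),
    (128, 5_802, 22.1), (256, 7_808, 32.8), (512, 8_786, 58.3),
]
BASELINE_TOK_S = ROWS[0][1]


def scaling(tok_s: float) -> float:
    """Throughput as a multiple of the B=1 row."""
    return tok_s / BASELINE_TOK_S


def efficiency(batch: int, tok_s: float) -> float:
    """Fraction of ideal linear scaling achieved at this batch size."""
    return scaling(tok_s) / batch


def tok_s_from_step(batch: int, ms_per_step: float) -> float:
    """Throughput implied by the step time: one token per sequence per step."""
    return batch * 1000.0 / ms_per_step


for batch, tok_s, ms in ROWS:
    print(f"B={batch:<4} scaling={scaling(tok_s):6.1f}x "
          f"efficiency={efficiency(batch, tok_s):6.1%} "
          f"implied={tok_s_from_step(batch, ms):7.0f} tok/s")
```

Step time is nearly flat from B=1 to B=64 (18.4 to 19.6 ms), so throughput grows almost linearly there. By B=512 a step takes 58.3 ms and efficiency against the B=1 baseline falls to about a third.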

rvLLM vs vLLM on H100 (measured)

| Batch | rvLLM tok/s | vLLM tok/s | Delta |
|---|---|---|---|
| 1 | 53 | 69 | -23% |
| 32 | 1,743 | 1,748 | ~0% |
| 64 | 3,265 | 3,130 | +4% |
| 128 | 5,802 | 4,689 | +24% |
| 256 | 7,808 | 7,077 | +10% |
| 512 | 8,786 | 8,243 | +7% |

rvLLM overtakes vLLM at B=64 and leads by 24% at B=128.

GPU perplexity

| Weight path | KV cache | PPL | tok/s (B=1) |
|---|---|---|---|
| FP8-Dynamic + CUTLASS channelscale epilogue | F16 | 14.75 | 53 |
| BF16 split QKV per-tensor FP8 | F16 | 17.96 | 37.9 |
| F16 weights (no FP8) | F16 | 19.79 | 37.9 |
| HuggingFace BF16 reference | -- | 19.62 | -- |

Gemma 4 forward pass (13 launches per layer)

For each layer in 0..60:
  1.  fused_rmsnorm_fp8_quant           input layernorm + FP8 quantize
  2.  cutlass_fp8_gemm_channelscale     fused Q||K||V + channelscale epilogue
  3.  fused_qkv_rmsnorm                 Q/K norm (learned) + V norm (parameter-free)
  4.  fused_rope_partial_f16kv          partial RoPE + F16 KV cache write
  5.  paged_decode (FA3 SM90)           attention (head_dim=256 sliding, 512 global)
  6.  quantize_fp8_per_token            attn output to FP8
  7.  fp8_gemm                          O projection
  8.  fused_norm_add_residual           channelscale + rmsnorm + residual add
  9.  fused_rmsnorm_fp8_quant           pre-FFN layernorm + FP8 quantize
  10. cutlass_fp8_gemm_channelscale     fused gate||up + channelscale epilogue
  11. fused_gelu_mul_fp8_quant          GELU(tanh)(gate) * up to FP8
  12. fp8_gemm                          down projection
  13. fused_norm_add_residual           channelscale + rmsnorm + residual + layer_scalar

Sampling tail:
  fused_rmsnorm                       final layernorm
  f16_gemm_f32                        lm_head
  logit_softcap                       30 * tanh(logits / 30)
  argmax_kernel                       token selection
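
Two of the elementwise steps listed above, GELU(tanh)(gate) * up and the logit softcap, are easy to state as scalar reference code. The sketch below is a plain-Python reference for the math only. It is not the CUDA kernels and omits the FP8 quantization that step 11 fuses in.

```python
import math


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def gelu_mul(gate: list[float], up: list[float]) -> list[float]:
    """GELU(tanh)(gate) * up, elementwise."""
    return [gelu_tanh(g) * u for g, u in zip(gate, up)]


def logit_softcap(logits: list[float], cap: float = 30.0) -> list[float]:
    """cap * tanh(logits / cap): squashes every logit into (-cap, cap)."""
    return [cap * math.tanh(v / cap) for v in logits]


def argmax(values: list[float]) -> int:
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


logits = [1.0, 250.0, -40.0, 29.0]
capped = logit_softcap(logits)
print(capped, argmax(capped))
```

Because tanh is strictly increasing, the softcap does not change which token greedy argmax selects as long as the capped values stay distinct at the working precision.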

Kernel fusion summary

Four rounds of fusion + custom CUTLASS epilogue reduced graph nodes from 1776 to ~935 (47% reduction):

| Fusion | Kernels eliminated | Nodes saved |
|---|---|---|
| f32_to_bf16 + rmsnorm + vector_add -> fused_norm_add_residual | 3 -> 1 (x2/layer) | 240 |
| scale_cols_f32 fused into norm+add kernel (O-proj, down) | 1 -> 0 (x2/layer) | 120 |
| residual_scale_f16 fused into post-ff norm+add | 1 -> 0 (x1/layer) | 60 |
| vnorm_f16 fused into qk_rmsnorm -> fused_qkv_rmsnorm | 2 -> 1 (x1/layer) | 60 |
| CUTLASS channelscale epilogue (QKV, gate_up) | 3 -> 1 (x2/layer) | 240+ |
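
The Nodes saved column follows from kernels removed per site x sites per layer x 60 layers. A quick check of that arithmetic, using only the figures in the table and the 1776 -> ~935 totals quoted above (the short fusion labels are abbreviations chosen here):

```python
LAYERS = 60

# (fusion, kernels removed per site, sites per layer), read off the table above.
FUSIONS = [
    ("fused_norm_add_residual", 2, 2),   # 3 -> 1, twice per layer
    ("scale_cols_f32", 1, 2),            # 1 -> 0, twice per layer
    ("residual_scale_f16", 1, 1),        # 1 -> 0, once per layer
    ("fused_qkv_rmsnorm", 1, 1),         # 2 -> 1, once per layer
    ("cutlass_channelscale", 2, 2),      # 3 -> 1, twice per layer (table says 240+)
]

itemized = {name: removed * sites * LAYERS for name, removed, sites in FUSIONS}
itemized_total = sum(itemized.values())

NODES_BEFORE, NODES_AFTER = 1776, 935
reduction = (NODES_BEFORE - NODES_AFTER) / NODES_BEFORE

print(itemized)
print(f"itemized savings: {itemized_total} nodes (lower bound)")
print(f"overall reduction: {NODES_BEFORE - NODES_AFTER} nodes, {reduction:.0%}")
```

The itemized rows account for at least 720 of the roughly 841 nodes removed. The table does not break down the remainder.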

The CUTLASS channelscale kernel uses a custom SM90 EVT epilogue that applies per-token activation scale (ColBroadcast) and per-channel weight scale (RowBroadcast) directly in the GEMM epilogue while the accumulator is still F32, then casts to F16.
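
Numerically, the epilogue described above computes D[m][n] = acc[m][n] * a_scale[m] * w_scale[n] and then rounds to F16. The sketch below is a plain-Python reference for that formula only. Small integers stand in for FP8 codes, a Python float stands in for the F32 accumulator, and the function names are illustrative.

```python
import struct


def to_f16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def gemm_channelscale(a_q, w_q, a_scale, w_scale):
    """Reference GEMM with per-token (row) and per-channel (column) scales.

    a_q:     M x K quantized activations
    w_q:     K x N quantized weights
    a_scale: length-M per-token scales (one per output row)
    w_scale: length-N per-channel scales (one per output column)
    """
    m_dim, k_dim, n_dim = len(a_q), len(w_q), len(w_q[0])
    out = []
    for m in range(m_dim):
        row = []
        for n in range(n_dim):
            # Accumulate in high precision first, scale afterwards.
            acc = sum(a_q[m][k] * w_q[k][n] for k in range(k_dim))
            row.append(to_f16(acc * a_scale[m] * w_scale[n]))
        out.append(row)
    return out


a_q = [[1, -2, 3], [4, 0, -1]]
w_q = [[2, 1], [0, -3], [1, 2]]
out = gemm_channelscale(a_q, w_q, a_scale=[0.5, 0.25], w_scale=[0.125, 2.0])
print(out)
```

Applying both scales once per output element, after accumulation, avoids a separate dequantize pass over the GEMM output.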

Help wanted: The current CUTLASS kernel uses a 128x128x128 tile, which is suboptimal for low-batch decode (M <= 16). A smaller tile variant (e.g. 64x64x128) would improve B=1-8 throughput. PRs welcome for additional tile shapes with autotune selection.

GPU build and run

# One-time on H100 box (~15 min)
bash kernels/build.sh               # fused PTX
bash kernels/build_cutlass_so.sh    # libcutlass_kernels.so
bash kernels/build_fa3.sh           # libfa3_kernels.so

# Build
cargo build --release --features cuda --manifest-path v3/Cargo.toml -p rvllm-bench

# Run
RVLLM_MODEL_DIR=/workspace/models/gemma-4-31B-it \
RVLLM_KERNELS_DIR=/workspace/rvllm/kernels/sm_90 \
RVLLM_CUTLASS_SO=/workspace/rvllm/kernels/sm_90/libcutlass_kernels.so \
RVLLM_FA3_SO=/workspace/rvllm/kernels/sm_90/libfa3_kernels.so \
RVLLM_POLICY=/workspace/rvllm/kernels/sm_90/policy.json \
RVLLM_BATCH=128 RVLLM_ITERS=30 RVLLM_WARMUP=5 \
  ./v3/target/release/rvllm-bench

OpenAI-compatible Gemma 4 server

The server is a Rust-only Gemma 4 path with an OpenAI-compatible HTTP surface. It keeps CUDA execution on a single engine owner thread and accepts requests through /v1/chat/completions.

For the solidSF agents production shape, including 256K context, four-seat admission, the paid-plan busy response, CAD harness prompting, systemd service shape, and verification scripts, see docs/solidsf-agent-serving.md.

export CUDA_ARCH=sm_90
export RVLLM_MODEL_DIR=/workspace/models/gemma-4-31B-it
export RVLLM_KERNELS_DIR=/workspace/rvllm/kernels/sm_90
export RVLLM_CUTLASS_SO=/workspace/rvllm/kernels/sm_90/libcutlass_kernels.so
export RVLLM_FA3_SO=/workspace/rvllm/kernels/sm_90/libfa3_kernels.so
export RVLLM_POLICY=/workspace/rvllm/kernels/sm_90/policy.json
export RVLLM_SERVED_MODEL_NAME=gemma4-31b
export RUST_LOG=info

bash kernels/build.sh sm_90
bash kernels/build_cutlass_so.sh sm_90
bash kernels/build_fa3.sh
cargo build --release --features cuda,cublaslt --manifest-path v3/Cargo.toml -p rvllm-serve

./v3/target/release/rvllm-server \
  --host 127.0.0.1 \
  --port 8080 \
  --max-model-len 8192 \
  --max-num-seqs 1 \
  --max-num-batched-tokens 2048 \
  --max-prefill-chunk 128

The server exposes GET /health, GET /v1/models, and POST /v1/chat/completions with non-stream and SSE streaming responses. Only greedy Gemma 4 decoding is currently enabled; set temperature: 0.

Smoke:

curl -fsS http://127.0.0.1:8080/health
curl -fsS http://127.0.0.1:8080/v1/models
curl -fsS http://127.0.0.1:8080/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d '{"model":"gemma4-31b","messages":[{"role":"user","content":"Reply exactly: RVLLM_RUST_OK"}],"max_tokens":16,"temperature":0}'
curl -fsS --no-buffer http://127.0.0.1:8080/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d '{"model":"gemma4-31b","messages":[{"role":"user","content":"hi"}],"max_tokens":16,"temperature":0,"stream":true}'
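
The same non-streaming smoke request can be issued from Python with only the standard library. This is a sketch: the helper names are made up here, and reading the reply from `choices[0].message.content` assumes the server follows the usual OpenAI chat-completions response shape.

```python
import json
import urllib.request


def build_chat_request(prompt: str, model: str = "gemma4-31b",
                       max_tokens: int = 16, stream: bool = False) -> dict:
    """Request body matching the curl examples above (greedy: temperature 0)."""
    body = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0,
    }
    if stream:
        body["stream"] = True
    return body


def make_request(base_url: str, body: dict) -> urllib.request.Request:
    """Build (but do not send) the POST to /v1/chat/completions."""
    return urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def chat(base_url: str, prompt: str) -> str:
    """Send one non-streaming request and return the assistant text."""
    req = make_request(base_url, build_chat_request(prompt))
    with urllib.request.urlopen(req, timeout=60) as resp:
        reply = json.load(resp)
    # Assumed OpenAI-style response layout.
    return reply["choices"][0]["message"]["content"]
```

With the server running, `chat("http://127.0.0.1:8080", "Reply exactly: RVLLM_RUST_OK")` mirrors the first chat curl call above.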

For bind-only local checks without CUDA:

RVLLM_DRY_RUN=1 cargo run --manifest-path v3/Cargo.toml -p rvllm-serve -- \
  --host 127.0.0.1 \
  --port 8080

Kernels

Every kernel has a known purpose, a pinned variant, and a workspace contract. No dispatch fallback chains.

| Kernel | Purpose |
|---|---|
| cutlass_fp8_gemm_channelscale | SM90 FP8 GEMM with EVT channelscale epilogue (QKV, gate_up) |
| fused_rmsnorm_fp8_quant | layernorm + FP8 quantize in one launch |
| fused_qkv_rmsnorm | per-head RMSNorm on Q, K (learned) and V (parameter-free) |
| fused_rope_partial_f16kv | partial RoPE + F16 KV cache write |
| fused_gelu_mul_fp8_quant | GELU(tanh)(gate) * up to FP8 |
| fused_norm_add_residual | channelscale + RMSNorm + residual add (+ optional layer_scalar) |
| logit_softcap | 30 * tanh(logits / 30) |
| quantize_fp8_per_token | activation to FP8 with per-token scale |
| argmax | f32 logits to i32 token |

No fallbacks. Missing kernel .so = engine refuses to start.

v3 crate map

v3/crates/
  rvllm-core         typed errors, IDs, dtype, shape, config, env
  rvllm-mem          HbmArena, Region, Stream, Event, PinnedBuf, CudaContextHandle
  rvllm-kernels      manifest (sha-pinned), PTX loader, kernel catalog
  rvllm-fused        8 fused-kernel launchers + pure-Rust f32 references
  rvllm-attention    FA3 SM90 paged decode/prefill dlopen
  rvllm-cutlass      FP8 variant catalog + schedule pairing trait + cuBLASLt wrapper
  rvllm-metadata     frozen-layout metadata per bucket (one upload path)
  rvllm-loader       safetensors mmap -> HBM + CPU-path FP8 quant + clamp gate
  rvllm-sampling     argmax tail, pinned DtoH
  rvllm-graph        captured-graph pool keyed on MetaLayoutHash
  rvllm-runtime      Engine, scheduler, layer_exec, bring_up
  rvllm-bench        RVLLM_* env-driven bench binary
  rvllm-invariants   DAG-dep test, no-megakernel gate

Correctness discipline

  1. No fallbacks. Missing autotune entry = engine panic. Missing .so = refuse start. No silent degradation.
  2. Graph-capture invariant. Metadata buffer layout frozen per (bucket, max_blocks_per_seq). Captured graphs bind exact offsets.
  3. CUTLASS schedule/epilogue pairing. Mainloop and epilogue schedules must match. Enforced via static_assert.
  4. No unwrap() in libraries. Result<T, RvllmError> end-to-end with structured context.
  5. Real block-change detection. Scheduler emits block table updates; missing signals = stale KV reads caught at the type level.
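
Rules 1 and 2 can be pictured as a lookup that either returns the graph captured for an exact layout or fails loudly. The sketch below is a toy model in Python, not the rvllm-graph crate. The class names are hypothetical, and a plain (bucket, max_blocks_per_seq) tuple stands in for the MetaLayoutHash key.

```python
class MissingGraphError(KeyError):
    """Raised instead of silently falling back to an uncaptured path."""


class FrozenLayoutError(ValueError):
    """Raised when a layout that already has a captured graph is re-captured."""


class GraphPool:
    """Toy captured-graph pool keyed on (bucket, max_blocks_per_seq)."""

    def __init__(self) -> None:
        self._graphs: dict[tuple[int, int], object] = {}

    def insert(self, bucket: int, max_blocks_per_seq: int, graph: object) -> None:
        key = (bucket, max_blocks_per_seq)
        if key in self._graphs:
            # The layout is frozen once captured; captured graphs bind exact offsets.
            raise FrozenLayoutError(f"layout {key} already captured")
        self._graphs[key] = graph

    def get(self, bucket: int, max_blocks_per_seq: int) -> object:
        key = (bucket, max_blocks_per_seq)
        try:
            return self._graphs[key]
        except KeyError:
            # No fallback chain: a miss is an error, not a slower path.
            raise MissingGraphError(f"no captured graph for layout {key}") from None


pool = GraphPool()
pool.insert(bucket=128, max_blocks_per_seq=64, graph="graph-128-64")
print(pool.get(128, 64))
```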

License

Apache-2.0.

About

rvLLM: High-performance LLM inference in Rust. Drop-in vLLM replacement.
