2 releases
Uses new Rust 2024
| Version | Released |
|---|---|
| 0.1.1 | Apr 24, 2026 |
| 0.1.0 | Apr 17, 2026 |
kaio-candle
Candle bridge for KAIO — CustomOp bindings that let you call KAIO's tensor-core GPU kernels directly on candle_core::Tensor.
Ships eight ops: matmul_tc, matmul_tc_async, matmul_int4, matmul_int8, attention_tc, attention_tc_causal, qkv_project_int8, qkv_project_int4. matmul_tc and matmul_tc_async support backward (autograd); all other ops are forward-only.
Status — v0.1.0 (Sprint 7.4a–7.4d)
All ops are bit-exact verified against direct kaio-ops calls with the same input bits.
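As a rough illustration of what "bit-exact" means here, the sketch below compares two f32 buffers by their raw bit patterns rather than with `==` or a tolerance. The helper name `bit_exact` is hypothetical, and how the crate's own test suite performs the comparison is an assumption.

```rust
/// Returns true when both slices have the same length and every pair of
/// elements has an identical IEEE 754 bit pattern.
/// Hypothetical helper for illustration; not part of kaio-candle.
fn bit_exact(a: &[f32], b: &[f32]) -> bool {
    a.len() == b.len() && a.iter().zip(b).all(|(x, y)| x.to_bits() == y.to_bits())
}

fn main() {
    let reference = [1.0_f32, -2.5, 0.0];
    let bridged = [1.0_f32, -2.5, 0.0];
    assert!(bit_exact(&reference, &bridged));

    // 0.0 == -0.0 under `==`, but the sign bit differs.
    assert!(!bit_exact(&[0.0], &[-0.0]));

    // A tolerance-based check would accept this pair; a bit-exact check does not.
    assert!(!bit_exact(&[1.0], &[1.0 + f32::EPSILON]));
    println!("bit-exact checks passed");
}
```

A bit-level comparison is stricter than an epsilon check, which is why it only makes sense when both paths run the same kernel on the same input bits.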
Why a separate crate?
kaio-candle is not a member of the main KAIO workspace. cudarc rejects dynamic-loading + dynamic-linking as simultaneously active features:
- Main KAIO defaults to dynamic-loading: no CUDA toolkit is required to build, and host tests pass on bare GitHub runners.
- candle-core with its cuda feature activates dynamic-linking: it links against libcuda at compile time.
Cargo unions features across a workspace build, so including kaio-candle in the main workspace would force every main-workspace build to also carry candle's dynamic-linking, breaking no-CUDA CI. The standalone crate keeps the two worlds apart.
Consumers who already build candle with the cuda feature see no new system requirement beyond what candle itself needs.
Build
cd kaio-candle
cargo build --features cuda
The cuda feature is required for any actual bridge functionality. Without it, kaio-candle is an empty shell (matches candle-core's own opt-in cuda pattern) — attempting to call kaio_candle::matmul_tc(...) surfaces a "function not found" compile error pointing at the missing feature.
Build requirements with cuda:
- CUDA toolkit (candle-core's cudarc feature uses dynamic-linking).
- NVIDIA GPU with SM 8.0 or newer (Ampere, Ada, Hopper).
Quickstart
# Cargo.toml
[dependencies]
kaio-candle = { version = "0.1", features = ["cuda"] }
kaio = "0.4"
candle-core = { version = "0.10", features = ["cuda"] }
half = "2"
anyhow = "1" # needed for the anyhow::Result return type in the example below
use std::sync::Arc;

use candle_core::{Device, Tensor};
use half::f16;
use kaio::prelude::KaioDevice;

fn main() -> anyhow::Result<()> {
    // Both handles must refer to the same CUDA ordinal (0 here).
    let candle_dev = Device::new_cuda(0)?;
    let kaio_dev = Arc::new(KaioDevice::new(0)?);

    let m = 128usize;
    let k = 128usize;
    let n = 128usize;

    // Deterministic f16 input data, built on the host.
    let a_host: Vec<f16> = (0..m * k).map(|i| f16::from_f32((i % 17) as f32 * 0.01)).collect();
    let b_host: Vec<f16> = (0..k * n).map(|i| f16::from_f32((i % 13) as f32 * 0.02)).collect();

    // Create rank-2 tensors on the candle CUDA device.
    let a = Tensor::from_vec(a_host, (m, k), &candle_dev)?;
    let b = Tensor::from_vec(b_host, (k, n), &candle_dev)?;

    // f16[m,k] x f16[k,n] -> f32[m,n]
    let c = kaio_candle::matmul_tc(&kaio_dev, &a, &b)?;
    println!("output shape: {:?}", c.shape().dims());
    Ok(())
}
Three runnable examples ship in examples/:
cd kaio-candle
cargo run --release --features cuda --example matmul_tc_candle
cargo run --release --features cuda --example matmul_int8_candle
cargo run --release --features cuda --example attention_tc_candle
Op surface
| Op | Trait | Shapes | Dtype |
|---|---|---|---|
| matmul_tc(kd, a, b) | CustomOp2 | a: [M, K], b: [K, N] → [M, N] | f16 × f16 → f32 |
| matmul_tc_async(kd, a, b) | CustomOp2 | same | f16 × f16 → f32 |
| matmul_int4(kd, a, b_packed, scales) | CustomOp3 | a: [M, K], b_packed: [K/8, N], scales: [K/128, N] → [M, N] | f16 × u32 × f16 → f32 |
| matmul_int8(kd, a, b, scale) | CustomOp2 | a: [M, K], b: [K, N] → [M, N] | u8-as-i8 × u8-as-i8 → f32 (× f32 scale) |
| attention_tc(kd, q, k, v) | CustomOp3 | q: [seq_q, d_k], k: [seq_k, d_k], v: [seq_k, d_v] → [seq_q, d_v] | f16 × f16 × f16 → f32 |
| attention_tc_causal(kd, q, k, v) | CustomOp3 | same | f16 × f16 × f16 → f32 |
| qkv_project_int8(kd, x, wq, wk, wv, sq, sk, sv) | Direct-call | x: [M, K], wq/wk/wv: [K, N] → (Q, K, V) each [M, N] | f16 × u8-as-i8 → f16 |
| qkv_project_int4(kd, x, wq, wk, wv, sq, sk, sv) | Direct-call | x: [M, K], wq/wk/wv: [K/8, N], sq/sk/sv: [K/128, N] → (Q, K, V) each [M, N] | f16 × u32 × f16 → f16 |
matmul_int4 is GPTQ-style: group_size=128 is locked in by the kaio-ops kernel contract. K must be a multiple of 128, weights are packed 8 INT4 values per u32, one f16 scale per group of 128 elements.
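The packing arithmetic above can be sketched in plain Rust. This is a minimal illustration under stated assumptions: the helper names are hypothetical, values are treated as signed 4-bit integers (-8..=7), and element 0 is placed in the lowest nibble. The nibble order and signedness expected by the actual kernel are not specified in this document, so check the kaio-ops contract before relying on either.

```rust
/// Group size fixed by the kernel contract described above.
const GROUP_SIZE: usize = 128;
/// Number of INT4 values stored in each u32 word.
const NIBBLES_PER_WORD: usize = 8;

/// Shapes of `b_packed` and `scales` for a logical [K, N] weight matrix,
/// or `None` when K is not a positive multiple of the group size.
fn int4_shapes(k: usize, n: usize) -> Option<((usize, usize), (usize, usize))> {
    if k == 0 || k % GROUP_SIZE != 0 {
        return None;
    }
    Some(((k / NIBBLES_PER_WORD, n), (k / GROUP_SIZE, n)))
}

/// Packs eight signed 4-bit values into one u32.
/// ASSUMPTION: element 0 goes in the lowest nibble.
fn pack_int4(vals: [i8; 8]) -> u32 {
    vals.iter().enumerate().fold(0u32, |word, (i, &v)| {
        debug_assert!((-8..=7).contains(&v));
        word | ((((v as u8) & 0x0F) as u32) << (4 * i))
    })
}

/// Inverse of `pack_int4`, sign-extending each nibble back to i8.
fn unpack_int4(word: u32) -> [i8; 8] {
    let mut out = [0i8; 8];
    for (i, slot) in out.iter_mut().enumerate() {
        let nibble = ((word >> (4 * i)) & 0x0F) as u8;
        // Move the nibble to the high half, then arithmetic-shift back.
        *slot = ((nibble << 4) as i8) >> 4;
    }
    out
}

fn main() {
    let vals = [-8, -1, 0, 1, 2, 3, 7, -4];
    let word = pack_int4(vals);
    assert_eq!(unpack_int4(word), vals);
    println!("K=256, N=64 -> {:?}", int4_shapes(256, 64));
}
```

Returning `None` for a K that is not a multiple of 128 mirrors the constraint in the text; the real wrapper may report this differently.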
matmul_int8 is W8A8 symmetric quant. Candle has no DType::I8, so the convention is DType::U8 tensors whose bytes are interpreted as signed INT8 (-128..=127) by the kernel. The bridge reinterprets the storage via a same-layout transmute. scale is a scalar f32 applied in the accumulator; a typical realistic value is max_abs / 127.
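To make the u8-as-i8 convention concrete, here is a small host-side sketch of symmetric quantization. The function names are hypothetical and the rounding and clamping choices are assumptions; only the `max_abs / 127` scale and the "signed bytes stored as u8" convention come from the text above.

```rust
/// Symmetric per-tensor scale: the largest magnitude maps to 127.
fn symmetric_scale(values: &[f32]) -> f32 {
    let max_abs = values.iter().fold(0.0_f32, |m, v| m.max(v.abs()));
    max_abs / 127.0
}

/// Quantizes to signed INT8, then stores each value as the u8 with the
/// same bit pattern (the form a U8 tensor would hold).
/// ASSUMPTION: round-to-nearest and clamping to -127..=127.
fn quantize_to_u8(values: &[f32], scale: f32) -> Vec<u8> {
    values
        .iter()
        .map(|v| (v / scale).round().clamp(-127.0, 127.0) as i8 as u8)
        .collect()
}

/// Reads the bytes back as signed values; `u8 as i8` keeps the bit pattern.
fn as_signed(bytes: &[u8]) -> Vec<i8> {
    bytes.iter().map(|&b| b as i8).collect()
}

fn main() {
    let weights = [1.0_f32, -4.0, 3.0];
    let scale = symmetric_scale(&weights);
    let bytes = quantize_to_u8(&weights, scale);
    println!("scale = {scale}, bytes = {bytes:?}, signed = {:?}", as_signed(&bytes));
}
```

The resulting bytes are what would be uploaded as a U8 tensor; negative values show up as bytes of 128 and above.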
attention_tc uses a shared-memory scores buffer capped at seq_k ≤ 384. FlashAttention-TC will lift this cap in a later sprint.
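A caller-side pre-check for these constraints might look like the sketch below. `attention_output_shape` is a hypothetical helper, not part of kaio-candle's API; it only encodes the shapes from the op table and the seq_k cap stated above, and the bridge's own error messages will differ.

```rust
/// Cap on seq_k stated in the text above.
const MAX_SEQ_K: usize = 384;

/// Validates rank-2 attention shapes and returns the output shape [seq_q, d_v].
/// Hypothetical helper for illustration only.
fn attention_output_shape(
    q: (usize, usize),
    k: (usize, usize),
    v: (usize, usize),
) -> Result<(usize, usize), String> {
    let (seq_q, d_k) = q;
    let (seq_k, k_dim) = k;
    let (v_rows, d_v) = v;
    if k_dim != d_k {
        return Err(format!("q has d_k={d_k} but k has d_k={k_dim}"));
    }
    if v_rows != seq_k {
        return Err(format!("k has seq_k={seq_k} but v has {v_rows} rows"));
    }
    if seq_k > MAX_SEQ_K {
        return Err(format!("seq_k={seq_k} exceeds the cap of {MAX_SEQ_K}"));
    }
    Ok((seq_q, d_v))
}

fn main() {
    println!("{:?}", attention_output_shape((128, 64), (384, 64), (384, 64)));
    println!("{:?}", attention_output_shape((128, 64), (512, 64), (512, 64)));
}
```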
qkv_project_int8 and qkv_project_int4 are direct-call functions (not CustomOpN — candle's trait maxes at 3 inputs and single output). They return (Tensor, Tensor, Tensor) with DType::F16 output because the fused kernel performs the f32→f16 conversion internally as part of the projection fusion. Gradient-tracked inputs are rejected with a loud error requiring .detach() — these ops are forward-only.
Backward support
| Op | Backward | Notes |
|---|---|---|
| matmul_tc | Supported | dA = grad @ B^T, dB = A^T @ grad via forward kernel |
| matmul_tc_async | Supported | Same, uses cp.async variant in both directions |
| attention_tc / attention_tc_causal | Not yet | FlashAttention backward requires new PTX kernels (Phase 8) |
| matmul_int4 / matmul_int8 | No | Quantized inference ops: frozen weights, no backprop in practice |
| qkv_project_int8 / qkv_project_int4 | No | Direct-call ops, inference-only by design |
Numerically approximate: The backward implementation downcasts the f32 upstream gradient to f16 to reuse the existing tensor-core forward kernels, and casts the output gradients back to f16 to satisfy candle's dtype-matching constraint. This is an initial autograd integration proving the bwd() bridge pattern, not a final mixed-precision training stack.
Memory: The backward pass materializes transposed tensors in VRAM (.t()?.contiguous() = allocation + copy). Peak backward memory is approximately 2–3x the forward input size. Designed for integration testing and light training, not high-throughput training loops where allocator overhead matters.
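The gradient formulas from the table can be checked on the CPU with a small f32 reference. This is a sketch for intuition only: it uses plain row-major `Vec<f32>` buffers and hypothetical helper names, and it does not reproduce the f16 downcast or the GPU kernels. The explicit `transpose` copies play the role of the materialized transposed tensors mentioned above.

```rust
/// Row-major matrix product: a is [m, k], b is [k, n], result is [m, n].
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut c = vec![0.0; m * n];
    for i in 0..m {
        for p in 0..k {
            for j in 0..n {
                c[i * n + j] += a[i * k + p] * b[p * n + j];
            }
        }
    }
    c
}

/// Materializes the transpose of a row-major [rows, cols] matrix as a new buffer.
fn transpose(x: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut t = vec![0.0; rows * cols];
    for r in 0..rows {
        for c in 0..cols {
            t[c * rows + r] = x[r * cols + c];
        }
    }
    t
}

/// Gradients of C = A @ B with respect to A and B, given the upstream
/// gradient `grad` of shape [m, n]. Both products reuse the forward matmul.
fn matmul_backward(
    a: &[f32],
    b: &[f32],
    grad: &[f32],
    m: usize,
    k: usize,
    n: usize,
) -> (Vec<f32>, Vec<f32>) {
    let b_t = transpose(b, k, n); // [n, k]
    let a_t = transpose(a, m, k); // [k, m]
    let d_a = matmul(grad, &b_t, m, n, k); // [m, n] @ [n, k] -> [m, k]
    let d_b = matmul(&a_t, grad, k, m, n); // [k, m] @ [m, n] -> [k, n]
    (d_a, d_b)
}

fn main() {
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // [2, 3]
    let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // [3, 2]
    let grad = [1.0; 4]; // [2, 2], all ones
    let (d_a, d_b) = matmul_backward(&a, &b, &grad, 2, 3, 2);
    println!("dA = {d_a:?}");
    println!("dB = {d_b:?}");
}
```

Note that dA has the shape of A and dB has the shape of B, which is what lets the backward pass reuse a kernel written for the forward product.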
Device lifetime
The Arc<kaio::prelude::KaioDevice> you construct and pass to kaio-candle wrapper functions is independent of the candle_core::Device you use for your tensors. Both retain the same CUDA primary context via cuDevicePrimaryCtxRetain; neither owns the other. Drop order between them is unconstrained.
Every wrapper call checks that the KAIO device and candle device share the same CUDA ordinal; a mismatch is a loud error.
Candle version policy
kaio-candle = 0.1 pins candle-core = "=0.10.2" exactly. This is deliberate:
- candle 0.10.2 is the current release at the time of publishing.
- The CustomOp2 / CustomOp3 surface has changed between candle minor versions in the past.
- cudarc feature conventions change with candle releases.
We re-pin kaio-candle against each new candle minor release. Use kaio-candle 0.1.x with candle-core 0.10.x; kaio-candle 0.2 will target whichever candle minor is current when we publish.
A weekly GitHub Actions workflow (.github/workflows/candle-head.yml) builds kaio-candle against candle-core's git main branch every Monday. If that workflow's badge stays red for more than two consecutive weeks, either the pin moves to the new candle minor or this section documents the divergence.
Known limitations (v0.1)
- Non-contiguous tensors rejected. Call .contiguous()? upstream.
- Non-zero storage offset rejected (e.g. from .narrow(...) / .slice(...)). Call .contiguous()? to compact.
- Rank-2 only. Multi-head attention callers must reshape [heads, seq, d] to [heads * seq, d] or call per-head with rank-2 slices. Wrappers error with a concrete reshape hint for higher-rank inputs.
- CUDA Graph capture partially unblocked. Event-based sync (Sprint 7.4c) removes the prior cuCtxSynchronize blocker. However, full CUDA Graph capture requires non-default streams on both the candle and KAIO sides, which is not yet verified.
- f32 output (CustomOp ops) / f16 output (direct-call ops). matmul_tc, matmul_int4, matmul_int8, attention_tc return f32 matching the kaio-ops accumulator. qkv_project_int{4,8} return f16 because the fused kernel converts internally.
- No CPU fallback. cpu_fwd returns a loud error rather than silently routing to candle.matmul(). KAIO's value is GPU-specific PTX; a silent CPU fallback would mask every perf claim.
- Bench numbers vs direct-call gap. Each bridge call issues event-based stream sync (two join() calls per op). This replaced the heavier cuCtxSynchronize from v0.1 but still allocates a transient CudaEvent per call. KAIO's published %-of-cuBLAS numbers are measured via direct kaio-ops calls, not through the bridge.
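For the rank-2 limitation in the list above, the index arithmetic behind the two options can be sketched on a host buffer. The helpers are hypothetical and work on a plain row-major slice rather than a candle tensor; they only show that each head of a contiguous [heads, seq, d] buffer is itself a contiguous [seq, d] block, and what shape results from folding the head axis into the rows.

```rust
/// Returns head `h` of a contiguous row-major [heads, seq, d] buffer as a
/// [seq, d] slice. Each head is already contiguous, so no copy is made here.
fn head_slice(buf: &[f32], heads: usize, seq: usize, d: usize, h: usize) -> &[f32] {
    assert_eq!(buf.len(), heads * seq * d, "buffer does not match [heads, seq, d]");
    assert!(h < heads, "head index out of range");
    let per_head = seq * d;
    &buf[h * per_head..(h + 1) * per_head]
}

/// Rank-2 shape after folding the head axis into the row axis.
fn fold_heads(heads: usize, seq: usize, d: usize) -> (usize, usize) {
    (heads * seq, d)
}

fn main() {
    let (heads, seq, d) = (2, 3, 2);
    let buf: Vec<f32> = (0..heads * seq * d).map(|i| i as f32).collect();

    for h in 0..heads {
        println!("head {h}: {:?}", head_slice(&buf, heads, seq, d, h));
    }
    println!("folded shape: {:?}", fold_heads(heads, seq, d));
}
```

On the tensor side, a per-head view taken with a narrowing call has a non-zero storage offset for every head after the first, so per the list above it still needs `.contiguous()?` before being passed to a wrapper.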
License
Dual-licensed under MIT or Apache-2.0, at your option.