Flex RNN compiles linear attention mechanisms like RWKV and Mamba from simple PyTorch code. For example, the Delta Rule kernel can be implemented as

    import flex_rnn

    @flex_rnn.jit
    def delta_rule(q, k, v, beta, S):
        # Update the state, then read it out with q
        S = S + beta * k * (v - (S*k).sum(-1,True))
        return (S*q).sum(-1,True), S

See tests/examples.py for example implementations of RWKV-7, RWKV-6, RWKV-4, Mamba-2, Mamba, Delta Rule, Gated Delta Rule, Retention, GLA, GSA, HGRN, Longhorn, mLSTM and S4D.
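
To make the broadcasting in `delta_rule` concrete, here is a minimal pure-Python sketch of the same update for a single batch element, head and time step, written one state row at a time. The name `delta_rule_step` and the list-based layout are illustrative assumptions and are not part of flex_rnn.

```python
def delta_rule_step(q, k, v, beta, S):
    """One delta-rule update, row by row (illustrative, not the flex_rnn API).

    q, k: lists of length N; v: list of length M; beta: float;
    S: state with M rows of length N.
    """
    y, new_S = [], []
    for v_m, row in zip(v, S):
        # (S*k).sum(-1): what this row currently predicts for key k
        pred = sum(s * k_n for s, k_n in zip(row, k))
        # S + beta * k * (v - pred): move the row toward the target value
        row = [s + beta * k_n * (v_m - pred) for s, k_n in zip(row, k)]
        new_S.append(row)
        # (S*q).sum(-1): read the updated row out with the query
        y.append(sum(s * q_n for s, q_n in zip(row, q)))
    return y, new_S


S0 = [[0.0, 0.0], [0.0, 0.0]]
y, S1 = delta_rule_step(q=[1.0, 0.0], k=[1.0, 0.0], v=[3.0, 5.0], beta=1.0, S=S0)
print(y)   # [3.0, 5.0]
print(S1)  # [[3.0, 0.0], [5.0, 0.0]]
```

With `beta = 1` and a unit-norm key, the state reproduces `v` exactly after one update, so repeating the same step leaves the state unchanged.
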
The compiler traces the code, generates the backward pass, compiles the result to CUDA code, and returns a callable function equivalent to the `naive` reference described below. Regardless of input and output dtypes, all internal calculations are performed using 32-bit floats. The backward pass fully recalculates all states, so precision should be similar to PyTorch autograd. See below for how the speed compares to other kernels.
For batch size B, sequence length T, number of heads H and state dimensions MxN, the resulting function expects inputs with shapes [B,T,H,M,N] followed by initial state(s) with shape(s) [B,H,M,N]. Inputs with dimensions of size 1 will be broadcast, so for the Delta Rule, q should have shape [B,T,H,1,N] and v have shape [B,T,H,M,1]. Mamba-2's A is a headwise parameter with no batch or time dependence, so it would have shape [1,1,H,1,1]. Initial states are not broadcast.
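
The shape convention above can be sketched as a small pre-flight check. This is a hypothetical helper (the name `check_shapes` and its dictionary arguments are assumptions, not part of flex_rnn) that works on plain shape tuples: inputs must have five dimensions, each either 1 or the full size, and initial states must match `[B,H,M,N]` exactly.

```python
def check_shapes(inputs, states, B, T, H, M, N):
    """Hypothetical pre-flight check; raises ValueError on a shape mismatch."""
    full = (B, T, H, M, N)
    for name, shape in inputs.items():
        # Each input dimension must be 1 (broadcast) or the full size.
        ok = len(shape) == 5 and all(s in (1, f) for s, f in zip(shape, full))
        if not ok:
            raise ValueError(f"{name}: {tuple(shape)} does not broadcast to {full}")
    for name, shape in states.items():
        # Initial states are not broadcast, so they must match exactly.
        if tuple(shape) != (B, H, M, N):
            raise ValueError(f"{name}: initial state must be exactly {(B, H, M, N)}")


B, T, H, M, N = 2, 16, 4, 8, 8
check_shapes(
    inputs={
        "q": (B, T, H, 1, N),
        "k": (B, T, H, 1, N),
        "v": (B, T, H, M, 1),
        "beta": (B, T, H, 1, 1),
    },
    states={"S": (B, H, M, N)},
    B=B, T=T, H=H, M=M, N=N,
)
```

Running such a check on `tensor.shape` values before calling a compiled kernel may give clearer error messages than a failure inside the kernel itself.
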
Let's see how to implement an equivalent to `fla.ops.delta_rule.chunk_delta_rule(q, k, v, beta, scale = 1, initial_state = S, output_final_state = True)`. That function expects q, k and v to have shape [B,T,H,M], beta to have shape [B,T,H] and the initial state to have shape [B,H,M,N]. Hence, we can implement it by unsqueezing and transposing as follows:

    def matching_delta_rule(q, k, v, beta, S):
        # q, k, v: [B,T,H,M] -> [B,T,H,1,M]; beta: [B,T,H] -> [B,T,H,1,1]
        q, k, v, beta = [i.unsqueeze(-2) for i in (q, k, v, beta.unsqueeze(-1))]
        # v.mT has shape [B,T,H,M,1]; the state's last two dimensions are swapped on the way in and out
        y, S = delta_rule(q, k, v.mT, beta, S.mT)
        return y, S.mT

In general, the compiled function behaves like `jit.naive` from flex_rnn/jit.py. It is often helpful to first debug using the naive function (for example, use `y, S = delta_rule.naive(q, k, v.mT, beta, S.mT)` instead of `y, S = delta_rule(q, k, v.mT, beta, S.mT)` in the code above), before working with the fast compiled kernels.
When there's a single state variable and a single output variable, and disregarding strides of outputs, the compiled function behaves like

    import torch

    def simplified_naive(step, *args):
        inputs, state = args[:-1], args[-1]
        T = max(i.shape[1] for i in inputs)
        output = []
        for t in range(T):
            # Broadcast over time, select step t, and compute in float32
            inputs_t = [i.expand(-1,T,-1,-1,-1)[:,t].float() for i in inputs]
            output_t, state = step(*inputs_t, state)
            output.append(output_t)
        return torch.stack(output, dim=1).squeeze(-1), state

To install from source:

    git clone https://github.com/johanwind/flex_rnn
    cd flex_rnn
    pip install .

You should then be able to run examples like

    python tests/examples.py --op rwkv7 --check-flex

The speed comparison below uses batch size B = 8, sequence length T = 1024, number of heads H = 4096 / 128 = 32 and state dimensions 128x128, measured on an NVIDIA 4070 mobile GPU.

| Op | Reference | flex_rnn / reference time | flex_rnn / reference memory |
|---|---|---|---|
| RWKV-4 | fla fused_recurrent_rwkv4 | 3 ms / 6 ms | 0.7 GB / 0.9 GB |
| RWKV-6 | fla chunk_rwkv6 | 44 ms / 34 ms | 3.2 GB / 2.5 GB |
| RWKV-7 | fla chunk_rwkv7 | 55 ms / 82 ms | 3.2 GB / 3.4 GB |
| HGRN | fla chunk_hgrn | 3 ms / 11 ms | 0.6 GB / 0.9 GB |
| Retention | fla chunk_retention | 31 ms / 13 ms | 1.9 GB / 1.4 GB |
| GLA | fla chunk_gla | 32 ms / 32 ms | 2.2 GB / 2.4 GB |
| GSA | fla chunk_gsa | 105 ms / 67 ms | 3.1 GB / 2.5 GB |
| Delta Rule | fla chunk_delta_rule | 45 ms / 19 ms | 2.8 GB / 1.2 GB |
| Gated Delta Rule | fla chunk_gated_delta_rule | 50 ms / 24 ms | 2.8 GB / 1.3 GB |
| Mamba | mamba_ssm selective_scan_fn | 32 ms / 17 ms | 1.5 GB / 0.8 GB |
| Mamba-2 | mamba_ssm mamba_chunk_scan_combined | 39 ms / 10 ms | 2.5 GB / 0.6 GB |
| Longhorn | github.com/Cranial-XIX/longhorn_cuda | 120 ms / 115 ms | 3.2 GB / 0.8 GB |
| S4D | github.com/state-spaces/s4 (with log_vandermonde_cuda) | 40 ms / 60 ms | 1.6 GB / 2.1 GB |
| mLSTM | mlstm_kernels mlstm_chunkwise__xl_chunk | 49 ms / 12 ms | 3.3 GB / 1.1 GB |

All inputs and initial states are in bfloat16, except for S4D and Mamba, whose reference implementations only support float32.

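
For a sense of scale, the following back-of-the-envelope arithmetic estimates how much memory the recurrent state would take at the benchmark settings if it were kept in float32. This is an illustrative calculation and not taken from the benchmark script; the names `MODEL_DIM` and `HEAD_DIM` are assumed labels for the 4096 and 128 in the head count above.

```python
# Benchmark settings quoted above
B, T = 8, 1024
MODEL_DIM, HEAD_DIM = 4096, 128   # assumed names for the 4096 / 128 above
H = MODEL_DIM // HEAD_DIM         # 32 heads
M = N = 128
BYTES_FP32 = 4

# State for one time step across the whole batch and all heads
one_step_bytes = B * H * M * N * BYTES_FP32
# State for every time step, if all of them were stored
all_steps_bytes = one_step_bytes * T

print(one_step_bytes / 2**20, "MiB per time step")    # 16.0 MiB per time step
print(all_steps_bytes / 2**30, "GiB for all T steps")  # 16.0 GiB for all T steps
```

Storing every intermediate state in float32 would therefore take several times more memory than any figure in the table, which is consistent with the backward pass recalculating states rather than keeping them.
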
Features:

- Compatible with `torch.compile` without graph breaks
- Handles non-contiguous input strides without copying
- float32, bfloat16, float16 input and output types
- High precision because of float32 internal calculations

Limitations:

- The backend computes the state row-wise. Specifically, that means we only support reductions in the last dimension, i.e. `torch.sum(x, dim=-1, keepdim=True)`.
- The supported operations are essentially `+`, `-`, `*`, `/`, `<`, `>`, `exp`, `log`, `sqrt` and `x.sum(dim=-1, keepdim=True)`.
- Input dimensions have only been tested for powers of 2.
- Speed and memory usage are often worse than specialized implementations (see speed comparison above).
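
The list of supported operations can be turned into a rough static check before handing a step function to the compiler. The sketch below is a hypothetical helper built on Python's standard `ast` module; it is not part of flex_rnn, it only looks at operator and function names, and it does not verify that `sum` is called with `dim=-1, keepdim=True`.

```python
import ast

# Operation set taken from the list above; anything else is reported.
ALLOWED_BINOPS = (ast.Add, ast.Sub, ast.Mult, ast.Div)
ALLOWED_COMPARES = (ast.Lt, ast.Gt)
ALLOWED_CALLS = {"exp", "log", "sqrt", "sum"}


def find_unsupported(source):
    """Return sorted names of operators and calls outside the listed set."""
    problems = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.BinOp):
            if not isinstance(node.op, ALLOWED_BINOPS):
                problems.add(type(node.op).__name__)
        elif isinstance(node, ast.Compare):
            for op in node.ops:
                if not isinstance(op, ALLOWED_COMPARES):
                    problems.add(type(op).__name__)
        elif isinstance(node, ast.Call):
            func = node.func
            # Handles both x.exp() / torch.exp(x) and bare exp(x)
            name = func.attr if isinstance(func, ast.Attribute) else getattr(func, "id", "?")
            if name not in ALLOWED_CALLS:
                problems.add(name)
    return sorted(problems)


delta_rule_src = """
def delta_rule(q, k, v, beta, S):
    S = S + beta * k * (v - (S*k).sum(-1, True))
    return (S*q).sum(-1, True), S
"""
print(find_unsupported(delta_rule_src))  # []

bad_src = "def f(x, S):\n    return torch.tanh(x) @ S, S ** 2\n"
print(find_unsupported(bad_src))  # ['MatMult', 'Pow', 'tanh']
```

Because the list above is introduced with "essentially", a clean result from this check does not guarantee that the compiler accepts the function, and a flagged operation may still be supported.
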