An optimized implementation of the Kimi Linear architecture - a hybrid linear attention mechanism that outperforms traditional full attention.
Installation • Quick Start • Documentation • Benchmarks • Contributing
- Overview
- Key Features
- Architecture
- Performance
- Installation
- Quick Start
- Project Structure
- Development
- Benchmarks
- Citation
- License
- Acknowledgments
Kimi Linear is a groundbreaking hybrid attention architecture that combines the best of both worlds: the efficiency of linear attention and the performance of full attention mechanisms. This implementation focuses on optimization, hardware efficiency, and production deployment.
Kimi Linear introduces Kimi Delta Attention (KDA), a linear attention mechanism with:
- Fine-grained gating: Channel-wise decay for precise memory control
- Hardware-efficient algorithms: Specialized DPLR variant optimized for modern GPUs
- Hybrid architecture: 3:1 KDA-to-MLA ratio for optimal performance/efficiency
- 🚀 6× faster decoding at 1M token contexts
- 💾 75% KV cache reduction for long sequences
- 📊 Superior accuracy: Matches or exceeds full attention on most benchmarks
- ⚡ Linear complexity: O(n) vs O(n²) for standard attention
- 🎯 Production-ready: vLLM integration, Docker support
This project aims to create a production-ready, optimized implementation of the Kimi Linear architecture for researchers and engineers working on:
- Long-Context Language Models: Process sequences up to 1M tokens efficiently
- Agentic AI Systems: Enable fast test-time scaling with RL training
- Resource-Constrained Deployment: Reduce memory and compute requirements
- Research & Development: Provide modular, well-documented codebase for experimentation
Why This Project Exists:
- 📚 Educational: Clear, documented implementation of cutting-edge attention mechanisms
- 🔬 Research: Modular architecture for experimentation with linear attention variants
- 🚀 Production: Optimized kernels and efficient memory management for deployment
- 🌐 Open Source: Community-driven development with transparent benchmarks
Kimi Delta Attention (KDA)
- Fine-grained channel-wise gating mechanism
- Hardware-efficient chunkwise parallelization
- Delta rule learning with online gradient descent
- Constrained DPLR formulation for numerical stability
Hybrid Architecture
- 3:1 KDA-to-MLA ratio (configurable)
- Multi-Head Latent Attention (MLA) for global context
- No Position Encoding (NoPE) design
- Seamless integration with existing frameworks
CUDA/Triton Kernels
- Fused attention kernels
- Memory-efficient tiling strategies
- 80% memory bandwidth utilization
- 2× faster than general DPLR implementations
Memory Management
- Fixed-size state (constant memory)
- Efficient buffer reuse
- Secondary chunking for numerical stability
- Mixed precision support (FP16, BF16, FP32)
Comprehensive Test Suite
- Unit tests (>95% coverage target)
- Synthetic tasks (Palindrome, MQAR, Stack)
- Integration tests
- Benchmark framework
Performance Profiling
- Kernel-level analysis (Nsight Compute)
- System-level profiling (Nsight Systems)
- Memory bandwidth monitoring
- Automated regression testing
%%{init: {'theme':'dark', 'themeVariables': { 'primaryColor':'#1e1e1e','primaryTextColor':'#fff','primaryBorderColor':'#7c3aed','lineColor':'#f39c12','secondaryColor':'#2c3e50','tertiaryColor':'#1e1e1e','background':'#1e1e1e','mainBkg':'#2c3e50','secondBkg':'#34495e','tertiaryBkg':'#2c3e50','textColor':'#ecf0f1','fontSize':'16px'}}}%%
graph TB
A["🔤 Input Token Embeddings<br/>(Batch × SeqLen × Dim)"] --> B["⚡ KDA Layer 1<br/>Fine-grained Gating + Delta Rule"]
B --> C["⚡ KDA Layer 2<br/>State Update: St ∈ R^(dk×dv)"]
C --> D["⚡ KDA Layer 3<br/>Chunkwise Parallelization"]
D --> E["🌐 MLA Layer 1<br/>Global Attention (NoPE)"]
E --> F["📊 Feed-Forward + MoE<br/>8 of 256 Experts Activated"]
F --> G{"More Layers?"}
G -->|Yes| B
G -->|No| H["📤 Output Logits<br/>(Batch × SeqLen × VocabSize)"]
style A fill:#2c3e50,stroke:#3498db,stroke-width:3px,color:#ecf0f1
style B fill:#2c3e50,stroke:#9b59b6,stroke-width:3px,color:#ecf0f1
style C fill:#2c3e50,stroke:#9b59b6,stroke-width:3px,color:#ecf0f1
style D fill:#2c3e50,stroke:#9b59b6,stroke-width:3px,color:#ecf0f1
style E fill:#2c3e50,stroke:#e74c3c,stroke-width:3px,color:#ecf0f1
style F fill:#2c3e50,stroke:#f39c12,stroke-width:3px,color:#ecf0f1
style G fill:#34495e,stroke:#95a5a6,stroke-width:2px,color:#ecf0f1
style H fill:#2c3e50,stroke:#27ae60,stroke-width:3px,color:#ecf0f1
%%{init: {'theme':'dark', 'themeVariables': { 'primaryColor':'#1e1e1e','primaryTextColor':'#fff','primaryBorderColor':'#7c3aed','lineColor':'#f39c12','secondaryColor':'#2c3e50','tertiaryColor':'#1e1e1e','background':'#1e1e1e','mainBkg':'#2c3e50','secondBkg':'#34495e','textColor':'#ecf0f1','fontSize':'14px'}}}%%
graph LR
A["📥 Input x<br/>(B×T×D)"] --> B["🔀 Q/K/V Projection<br/>Linear + ShortConv + Swish"]
B --> C["📏 L2Norm(Q, K)<br/>Eigenvalue Stability"]
C --> D["🎛️ FineGrainedGating<br/>α_t = σ(W↑W↓x)"]
D --> E["🔢 StateManager<br/>St ∈ R^(dk×dv)"]
E --> F["⚡ DPLR Transition<br/>Diag(α) - βkk^T"]
F --> G["📦 ChunkwiseKDA<br/>WY + UT Transform"]
G --> H["🎯 Output Gate<br/>σ(W↑W↓x) ⊙ RMSNorm"]
H --> I["📤 Output o<br/>(B×T×D)"]
style A fill:#2c3e50,stroke:#3498db,stroke-width:2px,color:#ecf0f1
style B fill:#2c3e50,stroke:#9b59b6,stroke-width:2px,color:#ecf0f1
style C fill:#2c3e50,stroke:#1abc9c,stroke-width:2px,color:#ecf0f1
style D fill:#2c3e50,stroke:#e67e22,stroke-width:2px,color:#ecf0f1
style E fill:#2c3e50,stroke:#e74c3c,stroke-width:2px,color:#ecf0f1
style F fill:#2c3e50,stroke:#f39c12,stroke-width:2px,color:#ecf0f1
style G fill:#2c3e50,stroke:#9b59b6,stroke-width:2px,color:#ecf0f1
style H fill:#2c3e50,stroke:#1abc9c,stroke-width:2px,color:#ecf0f1
style I fill:#2c3e50,stroke:#27ae60,stroke-width:2px,color:#ecf0f1
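The projection path in the diagram above (Linear + ShortConv + Swish, followed by L2 normalization of Q and K) can be sketched in PyTorch. This is a minimal illustration under assumed shapes; `ShortConvProjection` and `l2norm` are illustrative names, not the project's actual API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ShortConvProjection(nn.Module):
    """Illustrative Q/K/V projection: Linear -> depthwise causal conv (kernel=4) -> Swish."""
    def __init__(self, dim, proj_dim, kernel_size=4):
        super().__init__()
        self.linear = nn.Linear(dim, proj_dim, bias=False)
        self.conv = nn.Conv1d(proj_dim, proj_dim, kernel_size,
                              groups=proj_dim, padding=kernel_size - 1)

    def forward(self, x):                       # x: (batch, seq_len, dim)
        h = self.linear(x).transpose(1, 2)      # (batch, proj_dim, seq_len)
        h = self.conv(h)[..., : x.shape[1]]     # trim right padding -> causal convolution
        return F.silu(h).transpose(1, 2)        # Swish, back to (batch, seq_len, proj_dim)

def l2norm(t, eps=1e-6):
    """L2-normalize the last dimension (applied to Q and K for eigenvalue stability)."""
    return t / t.norm(dim=-1, keepdim=True).clamp_min(eps)
```

Applying `l2norm` to the projected Q and K keeps the transition's spectrum bounded, which is the stability property the L2Norm box in the diagram refers to.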
%%{init: {'theme':'dark', 'themeVariables': { 'primaryColor':'#1e1e1e','primaryTextColor':'#fff','primaryBorderColor':'#7c3aed','lineColor':'#f39c12','secondaryColor':'#2c3e50','tertiaryColor':'#1e1e1e','background':'#1e1e1e','mainBkg':'#2c3e50','secondBkg':'#34495e','textColor':'#ecf0f1','fontSize':'14px'}}}%%
stateDiagram-v2
[*] --> S0: Initialize State
S0 --> S1: Apply Diagonal Decay<br/>S' = Diag(α_t)·S_{t-1}
S1 --> S2: Rank-1 Correction<br/>S'' = (I - βk_tk_t^T)·S'
S2 --> S3: Add KV Association<br/>S_t = S'' + βk_tv_t^T
S3 --> Output: Compute Output<br/>o_t = q_t^T·S_t
Output --> S3: Next Token
S3 --> [*]: End Sequence
note right of S0
Constant Memory
O(dk × dv)
end note
note right of S2
Delta Rule
Online Gradient Descent
end note
%%{init: {'theme':'dark', 'themeVariables': { 'primaryColor':'#1e1e1e','primaryTextColor':'#fff','primaryBorderColor':'#7c3aed','lineColor':'#f39c12','secondaryColor':'#2c3e50','tertiaryColor':'#1e1e1e','background':'#1e1e1e','mainBkg':'#2c3e50','secondBkg':'#34495e','textColor':'#ecf0f1','fontSize':'14px'}}}%%
graph TD
subgraph "Block 1 (3:1 Ratio)"
A1["⚡ KDA Layer 1"] --> A2["⚡ KDA Layer 2"]
A2 --> A3["⚡ KDA Layer 3"]
A3 --> A4["🌐 MLA Layer 1"]
end
subgraph "Block 2 (3:1 Ratio)"
B1["⚡ KDA Layer 4"] --> B2["⚡ KDA Layer 5"]
B2 --> B3["⚡ KDA Layer 6"]
B3 --> B4["🌐 MLA Layer 2"]
end
subgraph "Block N (3:1 Ratio)"
N1["⚡ KDA Layer N-2"] --> N2["⚡ KDA Layer N-1"]
N2 --> N3["⚡ KDA Layer N"]
N3 --> N4["🌐 MLA Layer N/4"]
end
A4 --> B1
B4 --> C["..."]
C --> N1
style A1 fill:#2c3e50,stroke:#9b59b6,stroke-width:2px,color:#ecf0f1
style A2 fill:#2c3e50,stroke:#9b59b6,stroke-width:2px,color:#ecf0f1
style A3 fill:#2c3e50,stroke:#9b59b6,stroke-width:2px,color:#ecf0f1
style A4 fill:#2c3e50,stroke:#e74c3c,stroke-width:2px,color:#ecf0f1
style B1 fill:#2c3e50,stroke:#9b59b6,stroke-width:2px,color:#ecf0f1
style B2 fill:#2c3e50,stroke:#9b59b6,stroke-width:2px,color:#ecf0f1
style B3 fill:#2c3e50,stroke:#9b59b6,stroke-width:2px,color:#ecf0f1
style B4 fill:#2c3e50,stroke:#e74c3c,stroke-width:2px,color:#ecf0f1
style N1 fill:#2c3e50,stroke:#9b59b6,stroke-width:2px,color:#ecf0f1
style N2 fill:#2c3e50,stroke:#9b59b6,stroke-width:2px,color:#ecf0f1
style N3 fill:#2c3e50,stroke:#9b59b6,stroke-width:2px,color:#ecf0f1
style N4 fill:#2c3e50,stroke:#e74c3c,stroke-width:2px,color:#ecf0f1
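In code, the 3:1 interleaving shown above is simply a repeating pattern of three KDA layers followed by one MLA layer. A minimal sketch, assuming placeholder `kda_layer_cls` / `mla_layer_cls` constructors rather than the project's real classes:

```python
import torch.nn as nn

def build_hybrid_stack(num_layers, dim, kda_layer_cls, mla_layer_cls, hybrid_ratio=3):
    """Interleave layers at a (hybrid_ratio : 1) KDA-to-MLA ratio, e.g. K K K M K K K M ..."""
    layers = []
    for i in range(num_layers):
        if (i + 1) % (hybrid_ratio + 1) == 0:
            layers.append(mla_layer_cls(dim))   # every (ratio+1)-th layer: global MLA attention
        else:
            layers.append(kda_layer_cls(dim))   # remaining layers: linear-time KDA
    return nn.ModuleList(layers)
```

With `hybrid_ratio=3`, only one layer in four keeps a growing KV cache, which is where the 75% cache reduction quoted later comes from.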
# Simplified KDA forward pass (per-token recurrence; the production kernel uses the chunkwise form)
import torch

def kda_forward(q, k, v, alpha, beta, state):
    """
    KDA implements:
        S_t = (I - beta_t k_t k_t^T) Diag(alpha_t) S_{t-1} + beta_t k_t v_t^T
        o_t = q_t^T S_t
    Shapes: q, k: (d_k,)   v: (d_v,)   alpha: (d_k,)   beta: scalar   state: (d_k, d_v)
    """
    # Step 1: Fine-grained diagonal decay (channel-wise forgetting)
    state_decayed = alpha[:, None] * state            # == Diag(alpha) @ state
    # Step 2: Delta rule correction (Householder-style rank-1 update)
    state_corrected = state_decayed - beta * torch.outer(k, k @ state_decayed)
    # Step 3: Add the new key-value association
    state_new = state_corrected + beta * torch.outer(k, v)
    # Step 4: Read out with the query
    # (The chunkwise kernel splits this readout into a recurrent inter-chunk term and a
    #  parallel intra-chunk term; here it is shown in its plain per-token form.)
    output = q @ state_new
    return output, state_new
- Input Projections: Q, K, V via linear layers + short convolution (kernel=4)
- Gating: Channel-wise forget gate (α), scalar learning rate (β)
- Output: Low-rank gating + RMSNorm
- Normalization: L2Norm for Q/K (eigenvalue stability), RMSNorm for output
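The forget gate α and the output gate from the list above are both low-rank sigmoid gates, matching the α_t = σ(W↑W↓x) expression in the dataflow diagram. A minimal sketch with illustrative module names (not the project's API):

```python
import torch
import torch.nn as nn

class FineGrainedGate(nn.Module):
    """Channel-wise gate: alpha_t = sigmoid(W_up @ W_down @ x_t), one value per channel."""
    def __init__(self, dim, out_dim, rank=32):
        super().__init__()
        self.down = nn.Linear(dim, rank, bias=False)     # W_down: low-rank bottleneck
        self.up = nn.Linear(rank, out_dim, bias=False)   # W_up: expand to per-channel gates

    def forward(self, x):                                # x: (batch, seq_len, dim)
        return torch.sigmoid(self.up(self.down(x)))      # values in (0, 1)

class GatedOutput(nn.Module):
    """Output path: sigmoid gate applied to the RMS-normalized attention output."""
    def __init__(self, dim, rank=32):
        super().__init__()
        self.norm = nn.RMSNorm(dim)
        self.gate = FineGrainedGate(dim, dim, rank)

    def forward(self, attn_out, x):
        return self.gate(x) * self.norm(attn_out)
```

The low-rank factorization is what keeps the parameter overhead of channel-wise gating to roughly 1%, as noted in the design-decision table further down.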
| Technology | Version | Purpose | Why Chosen |
|---|---|---|---|
| PyTorch | ≥2.6 | Deep learning framework | • Industry standard for research & production • Excellent CUDA integration & autograd • Dynamic computation graphs for debugging • Native support for distributed training • Extensive ecosystem (TorchScript, ONNX) |
| CUDA | ≥12.0 | GPU acceleration | • Direct access to GPU hardware features • Custom kernel optimization for KDA • Tensor Core utilization for mixed precision • High memory bandwidth (>900 GB/s on A100) • Required for production-level performance |
| Triton | ≥2.2 | Kernel development | • Python-based GPU kernel programming • Automatic optimization & code generation • Easier to maintain than raw CUDA • Similar performance to hand-tuned CUDA • Rapid prototyping of custom operators |
| Flash Attention | ≥2.0 | Efficient attention | • Memory-efficient attention algorithm • IO-aware kernel design (minimizes HBM access) • Up to 3× speedup over naive attention • Industry-proven implementation • Baseline for comparison |
| vLLM | ≥0.6 | Inference engine | • PagedAttention for efficient KV cache • Continuous batching for high throughput • Production-grade serving infrastructure • Easy integration with existing models • Active community & regular updates |
| Docker | ≥24.0 | Containerization | • Reproducible development environment • Consistent CUDA/cuDNN versions • Easy deployment to cloud platforms • Isolation of dependencies • Multi-stage builds for size optimization |
| pytest | ≥8.0 | Testing framework | • Simple, Pythonic test syntax • Excellent fixture system • Parameterized testing support • Coverage integration • Industry standard for Python projects |
| Black | ≥24.0 | Code formatting | • Opinionated, consistent formatting • Reduces bikeshedding in reviews • Automatic via pre-commit hooks • Fast • PEP 8 compliant |
| NumPy | ≥1.24 | Numerical computing | • Efficient array operations • Foundation for scientific Python • Used for synthetic data generation • CPU-based testing utilities • Interoperability with PyTorch |
| Einops | ≥0.8 | Tensor manipulation | • Readable tensor reshaping/rearranging • Self-documenting dimension operations • Reduces bugs in shape transformations • Einstein notation support • Clear intent for reviewers |
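As a concrete example of the Einops entry above, the head split/merge used throughout attention code reads more clearly (and is harder to get wrong) with `rearrange` than with chained `view`/`transpose` calls; the shapes below are illustrative:

```python
import torch
from einops import rearrange

x = torch.randn(2, 4096, 16 * 128)              # (batch, seq_len, num_heads * head_dim)
q = rearrange(x, "b t (h d) -> b h t d", h=16)  # split heads -> (2, 16, 4096, 128)
y = rearrange(q, "b h t d -> b t (h d)")        # merge heads -> (2, 4096, 2048)
```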
%%{init: {'theme':'dark', 'themeVariables': { 'primaryColor':'#1e1e1e','primaryTextColor':'#fff','primaryBorderColor':'#7c3aed','lineColor':'#f39c12','secondaryColor':'#2c3e50','tertiaryColor':'#1e1e1e','background':'#1e1e1e','mainBkg':'#2c3e50','secondBkg':'#34495e','textColor':'#ecf0f1','fontSize':'12px'}}}%%
mindmap
root((Kimi Linear<br/>Architecture))
Core Modules
FineGrainedGating
Channel-wise decay
Low-rank projection
Sigmoid activation
StateManager
Fixed memory O(K×V)
Checkpointing
NaN/Inf handling
DPLRTransition
Specialized DPLR
Eigenvalue stability
2× faster vs general
Attention Layers
KDA Layer
Delta rule learning
Chunkwise parallel
WY representation
MLA Layer
Global attention
NoPE design
Multi-head latent
Optimization
CUDA Kernels
Fused operations
Tensor Core usage
Memory tiling
Triton Kernels
Auto-tuning
Python-based
Easy maintenance
Memory Management
Pre-allocated buffers
Efficient reuse
Secondary chunking
Mixed precision
| Component | Time Complexity | Space Complexity | Description |
|---|---|---|---|
| FineGrainedGating | O(B·T·D·rank) | O(D·rank) | Low-rank projection for channel-wise gates |
| StateManager | O(B·H·K·V) | O(B·H·K·V) | Constant per-head memory, scales with batch |
| DPLRTransition | O(B·H·K·V) | O(B·H·K·V) | 2× faster than general DPLR (O(K²·V)) |
| ChunkwiseKDA | O(B·T·K·V + T·C²) | O(B·H·K·V) | Parallel intra-chunk + recurrent inter-chunk |
| Full MLA | O(B·T²·D) | O(B·H·T·K) | Standard attention with linear KV cache growth |
| Hybrid Model | O(B·T·D·V + T²·D/4) | O(B·H·K·V + T·D/4) | 3:1 ratio reduces global attention cost by 75% |
| Decision | Rationale | Trade-offs |
|---|---|---|
| Channel-wise vs Head-wise Gating | More precise memory control, better long-context performance | Slightly higher parameter count (~1%) |
| 3:1 KDA-to-MLA Ratio | Optimal balance of speed and accuracy | Tunable for specific use cases |
| NoPE (No Position Encoding) | Simplifies long-context extension, KDA provides positional bias | Requires careful training schedule |
| Pre-allocated State Buffer | Eliminates allocation overhead, predictable memory | Fixed maximum batch size |
| WY Representation | Efficient Householder matrix products | More complex implementation |
| Secondary Chunking | Numerical stability in log-space | Additional memory overhead |
| Eigenvalue Monitoring | Early detection of training instabilities | Small runtime cost (<1%) |
| Low-rank Gate Projection | Reduces parameters while maintaining expressiveness | Slightly lower capacity |
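To make the WY Representation row concrete: within a chunk, the running product of Householder-style factors (I − β_t k_t k_tᵀ) can be accumulated as I + W Yᵀ, so applying a whole chunk to the state costs two matrix multiplies instead of a chain of rank-1 updates. The sketch below is a self-checking illustration of that identity, not the project's kernel:

```python
import torch

def wy_accumulate(ks, betas):
    """Accumulate (I - beta_t k_t k_t^T) factors (newest applied on the left) as I + W @ Y.T."""
    d = ks.shape[1]
    W, Y = torch.zeros(d, 0), torch.zeros(d, 0)
    for k, beta in zip(ks, betas):
        w_new = -beta * k                  # new column of W
        y_new = k + Y @ (W.T @ k)          # new column of Y (folds in earlier factors)
        W = torch.cat([W, w_new[:, None]], dim=1)
        Y = torch.cat([Y, y_new[:, None]], dim=1)
    return W, Y

# Sanity check against the explicit product of rank-1 factors
d, chunk = 8, 4
ks = torch.nn.functional.normalize(torch.randn(chunk, d), dim=-1)
betas = torch.rand(chunk)
W, Y = wy_accumulate(ks, betas)
P = torch.eye(d)
for k, b in zip(ks, betas):
    P = (torch.eye(d) - b * torch.outer(k, k)) @ P
assert torch.allclose(torch.eye(d) + W @ Y.T, P, atol=1e-5)
```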
%%{init: {'theme':'dark', 'themeVariables': { 'primaryColor':'#1e1e1e','mainBkg':'#2c3e50','secondBkg':'#34495e','textColor':'#ecf0f1','fontSize':'14px','primaryTextColor':'#ecf0f1','primaryBorderColor':'#3498db'}}}%%
graph TD
A[Context Length: 4K] -->|"MLA: 2.1ms<br/>Kimi: 2.0ms"| B[Speed: 1.05× faster]
C[Context Length: 128K] -->|"MLA: 45.2ms<br/>Kimi: 11.4ms"| D[Speed: 3.98× faster ⚡]
E[Context Length: 512K] -->|"MLA: 182.7ms<br/>Kimi: 79.4ms"| F[Speed: 2.30× faster]
G[Context Length: 1M] -->|"MLA: 365.4ms<br/>Kimi: 125.8ms"| H[Speed: 2.90× faster]
I[Memory @ 128K] -->|"MLA: 16GB<br/>Kimi: 4GB"| J[75% reduction 💾]
K[Memory @ 1M] -->|"MLA: 128GB<br/>Kimi: 32GB"| L[75% reduction 💾]
style A fill:#2c3e50,stroke:#3498db,stroke-width:2px
style C fill:#2c3e50,stroke:#3498db,stroke-width:2px
style E fill:#2c3e50,stroke:#3498db,stroke-width:2px
style G fill:#2c3e50,stroke:#3498db,stroke-width:2px
style I fill:#2c3e50,stroke:#9b59b6,stroke-width:2px
style K fill:#2c3e50,stroke:#9b59b6,stroke-width:2px
style B fill:#34495e,stroke:#27ae60,stroke-width:2px
style D fill:#34495e,stroke:#27ae60,stroke-width:3px
style F fill:#34495e,stroke:#27ae60,stroke-width:2px
style H fill:#34495e,stroke:#27ae60,stroke-width:2px
style J fill:#34495e,stroke:#e74c3c,stroke-width:2px
style L fill:#34495e,stroke:#e74c3c,stroke-width:2px
| Context Length | MLA (ms) | GDN-H (ms) | Kimi Linear (ms) | Speedup vs MLA | Winner |
|---|---|---|---|---|---|
| 4K | 2.1 | 2.0 | 2.0 | 1.05× | 🟰 Tie |
| 128K | 45.2 | 18.3 | 11.4 | 3.98× | ⚡ Kimi |
| 512K | 182.7 | 76.1 | 79.4 | 2.30× | ⚡ Kimi |
| 1M | 365.4 | 150.2 | 125.8 | 2.90× | ⚡ Kimi |
| Context Length | MLA TPOT | Kimi TPOT | Speedup | Insight |
|---|---|---|---|---|
| 4K | 1.85 ms | 1.84 ms | 1.01× | Minimal difference at short context |
| 128K | 4.28 ms | 1.91 ms | 2.24× ⚡ | Linear KV cache starts to dominate |
| 512K | 9.16 ms | 1.87 ms | 4.90× ⚡⚡ | Massive savings from O(1) state |
| 1M | 11.48 ms | 1.84 ms | 6.24× ⚡⚡⚡ | 6× faster decoding! |
Key Insight: Kimi Linear maintains constant TPOT (~1.84ms) regardless of context length, while MLA's TPOT grows linearly. This enables sub-2ms per-token generation even at 1M context!
| Metric | Full Attention (MLA) | Kimi Linear | Reduction | Impact |
|---|---|---|---|---|
| KV Cache @ 4K | 512 MB | 512 MB | 0% | No advantage at short context |
| KV Cache @ 128K | 16.0 GB | 4.0 GB | 75% 💾 | 4× larger batch size possible |
| Peak Memory @ 512K | 64.0 GB | 16.0 GB | 75% 💾 | Fits on single A100 40GB |
| Peak Memory @ 1M | 128.0 GB | 32.0 GB | 75% 💾 | Practical million-token inference |
| State Growth | O(n) per head | O(1) per head | N/A | Bounded memory even at ∞ context |
| Batch Throughput | Limited by KV cache | 4× higher @ 128K | 4× 📊 | Better hardware utilization |
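The 75% figures above follow from simple cache arithmetic: MLA's KV cache grows linearly with context, each KDA head keeps a fixed d_k × d_v state, and in the 3:1 hybrid only one layer in four still stores a growing cache. The sketch below reproduces the scaling; the layer/head/dimension values are illustrative assumptions, not the released model's exact configuration.

```python
def kv_cache_gib(seq_len, n_layers, n_kv_heads, head_dim, bytes_per=2):
    """KV cache of standard attention: K and V per token, per layer, per KV head (BF16)."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per / 2**30

def kda_state_gib(n_layers, n_heads, d_k, d_v, bytes_per=2):
    """Fixed per-head d_k x d_v state, independent of sequence length."""
    return n_layers * n_heads * d_k * d_v * bytes_per / 2**30

# Illustrative 32-layer config: 3:1 ratio -> 8 MLA layers keep a cache, 24 KDA layers do not
full = kv_cache_gib(131_072, n_layers=32, n_kv_heads=16, head_dim=128)
hybrid = (kv_cache_gib(131_072, n_layers=8, n_kv_heads=16, head_dim=128)
          + kda_state_gib(n_layers=24, n_heads=16, d_k=128, d_v=128))
print(f"128K context -- full: {full:.1f} GiB, hybrid: {hybrid:.1f} GiB "
      f"({1 - hybrid / full:.0%} smaller)")
```

Because the KDA state term does not depend on `seq_len`, the hybrid's saving approaches the 3/4 layer fraction at long context, which is exactly the 75% reduction in the table.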
%%{init: {'theme':'dark', 'themeVariables': { 'primaryColor':'#1e1e1e','mainBkg':'#2c3e50','secondBkg':'#34495e','textColor':'#ecf0f1','fontSize':'13px'}}}%%
graph LR
A[Full Attention<br/>O n² complexity] -->|"Short Context<br/>< 4K"| B[✅ Best Accuracy<br/>❌ Slow scaling]
C[Linear Attention<br/>O n complexity] -->|"Medium Context<br/>4K-128K"| D[✅ Fast<br/>❌ Accuracy loss]
E[Kimi Hybrid<br/>O n with sparse O n²] -->|"Long Context<br/>128K-1M"| F[✅ Fast + Accurate<br/>✅ Constant memory]
style A fill:#2c3e50,stroke:#e74c3c,stroke-width:2px
style C fill:#2c3e50,stroke:#f39c12,stroke-width:2px
style E fill:#2c3e50,stroke:#27ae60,stroke-width:3px
style B fill:#34495e,stroke:#3498db,stroke-width:2px
style D fill:#34495e,stroke:#3498db,stroke-width:2px
style F fill:#34495e,stroke:#27ae60,stroke-width:3px
| Task | Context | MLA (Full Attn) | GDN-H (Linear) | Kimi Linear | Winner |
|---|---|---|---|---|---|
| MMLU-Pro | 4K | 47.2 | 47.9 | 51.0 | ✅ Kimi (+3.8) |
| RULER | 128K | 81.3 | 80.5 | 84.3 | ✅ Kimi (+3.0) |
| MATH500 | 4K | 80.8 | 83.0 | 81.2 | 🥈 Kimi (+0.4) |
| AIME 2025 | 4K | 20.6 | 21.1 | 21.3 | ✅ Kimi (+0.7) |
| HumanEval | 4K | 71.3 | 72.0 | 73.2 | ✅ Kimi (+1.9) |
| GPQA | 4K | 44.2 | 43.1 | 43.8 | 🥈 Kimi (-0.4) |
Summary: Kimi Linear achieves better or comparable accuracy to full attention while being 2-6× faster at long context. The hybrid approach avoids the accuracy degradation typical of pure linear attention.
| Batch Size | Context | MLA Tokens/sec | Kimi Tokens/sec | Throughput Gain |
|---|---|---|---|---|
| 1 | 128K | 234 | 524 | 2.24× ⚡ |
| 4 | 128K | 890 | 1987 | 2.23× ⚡ |
| 8 | 128K | OOM | 3840 | ∞ 💥 |
| 1 | 1M | 87 | 543 | 6.24× ⚡⚡⚡ |
| 4 | 1M | OOM | 2048 | ∞ 💥 |
Hardware: A100 80GB, BF16, DeepSpeed ZeRO-3
Key Takeaway: At 1M context, Kimi Linear sustains batch sizes that cause OOM under MLA, unlocking previously impossible workloads.
- Python >= 3.10
- PyTorch >= 2.6
- CUDA >= 12.0 (for GPU acceleration)
- fla-core >= 0.4.0
# Clone the repository
git clone https://github.com/YOUR_USERNAME/kimi-linear.git
cd kimi-linear
# Install dependencies
pip install -r requirements.txt
# Install in development mode
pip install -e .
# Build the Docker image
docker build -t kimi-linear:latest -f docker/Dockerfile .
# Run the container
docker run --gpus all -it kimi-linear:latest

# Once published to PyPI
pip install kimi-linear

import torch
from kimi_linear import KimiLinearAttention
# Initialize model
model = KimiLinearAttention(
dim=1024,
num_heads=16,
head_dim=128,
hybrid_ratio=3, # 3 KDA layers per 1 MLA layer
)
# Forward pass
x = torch.randn(1, 4096, 1024) # (batch, seq_len, dim)
output = model(x)
print(f"Output shape: {output.shape}") # (1, 4096, 1024)from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "moonshotai/Kimi-Linear-48B-A3B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto",
trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Explain Kimi Linear in simple terms."}
]
input_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt"
).to(model.device)
generated_ids = model.generate(inputs=input_ids, max_new_tokens=500)
response = tokenizer.batch_decode(generated_ids)[0]
print(response)

# Run all benchmarks
python scripts/benchmark/run_benchmarks.py --model kimi-linear --baseline mla
# Run specific benchmark
python scripts/benchmark/run_benchmarks.py --task mmlu-pro --context-length 4096
# Profile performance
python scripts/profiling/profile_attention.py --kernel kda --chunk-size 64

# Run all tests
pytest tests/
# Run unit tests only
pytest tests/unit/
# Run with coverage
pytest --cov=src --cov-report=html tests/
# Run synthetic tasks
python tests/synthetic/test_palindrome.py
python tests/synthetic/test_mqar.py
python tests/synthetic/test_stack.py

kimi-linear/
├── src/ # Source code
│ ├── kda/ # Kimi Delta Attention implementation
│ │ ├── gating.py # Fine-grained gating mechanism
│ │ ├── state_manager.py # State tracking and updates
│ │ ├── wy_representation.py # WY representation for rank-1 updates
│ │ ├── ut_transform.py # UT transform
│ │ ├── chunk_update.py # Chunkwise state updates
│ │ └── dplr.py # DPLR variant implementation
│ ├── attention/ # Attention mechanisms
│ │ ├── linear_attention.py # Base linear attention
│ │ ├── delta_rule.py # Delta rule learning
│ │ └── mla.py # Multi-Head Latent Attention
│ ├── models/ # Model architectures
│ │ ├── kimi_linear.py # Hybrid Kimi Linear model
│ │ ├── projections.py # Input projections
│ │ ├── conv_layer.py # Short convolution
│ │ └── gating.py # Output/forget gates
│ ├── kernels/ # Optimized kernels
│ │ ├── kda_fused_kernel.cu # CUDA implementation
│ │ └── kda_triton.py # Triton implementation
│ ├── utils/ # Utility functions
│ │ ├── performance_logger.py
│ │ └── memory_monitor.py
│ └── benchmarks/ # Benchmark utilities
├── tests/ # Test suite
│ ├── unit/ # Unit tests
│ ├── integration/ # Integration tests
│ └── synthetic/ # Synthetic task tests
├── scripts/ # Scripts
│ ├── setup/ # Setup scripts
│ ├── benchmark/ # Benchmarking scripts
│ └── profiling/ # Profiling tools
├── docs/ # Documentation
│ ├── api/ # API documentation
│ ├── tutorials/ # Tutorials and guides
│ ├── architecture/ # Architecture docs
│ └── project-plan.md # Comprehensive project plan
├── data/ # Data directory
│ ├── synthetic/ # Synthetic test data
│ ├── benchmarks/ # Benchmark results
│ └── results/ # Experimental results
├── assets/ # Assets (figures, diagrams)
├── docker/ # Docker configurations
├── .github/ # GitHub-specific files
│ └── workflows/ # CI/CD workflows
├── .copilot/ # Copilot configurations
├── .vscode/ # VS Code settings
├── memory-bank/ # Memory bank system
│ ├── app-description.md # Project description
│ ├── change-log.md # Change log
│ └── implementation-plans/ # Implementation plans
├── configs/ # Configuration files
├── requirements.txt # Python dependencies
├── setup.py # Package setup
├── pyproject.toml # Project metadata
├── .gitignore # Git ignore rules
├── .editorconfig # Editor configuration
├── LICENSE # MIT License
└── README.md # This file
# Install development dependencies
pip install -r requirements-dev.txt
# Install pre-commit hooks
pre-commit install
# Run code formatting
black src/ tests/
isort src/ tests/
# Run linting
pylint src/
flake8 src/
# Run type checking
mypy src/

cd docs/
make html
# Documentation will be in docs/_build/html/

# Build development image
docker build -t kimi-linear:dev -f docker/Dockerfile.dev .
# Run with GPU and mounted source
docker run --gpus all -v $(pwd):/workspace -it kimi-linear:dev bash

- Python: Follow PEP 8, use Black formatter (88 char line length)
- C++: Follow Google C++ Style Guide (100 char line length)
- Java: Follow Google Java Style Guide
- Naming Conventions:
- Functions/methods: snake_case
- Classes: PascalCase
- Constants: UPPER_SNAKE_CASE
- Private members: _leading_underscore
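A short illustration of these conventions (all names are hypothetical examples, not project code):

```python
MAX_CONTEXT_LENGTH = 1_048_576            # constant: UPPER_SNAKE_CASE

class BenchmarkRunner:                    # class: PascalCase
    def __init__(self, output_dir: str):
        self._output_dir = output_dir     # private member: _leading_underscore

    def run_all_tasks(self) -> None:      # method: snake_case
        ...
```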
# Full benchmark suite (requires GPU with 24GB+ VRAM)
python scripts/benchmark/run_benchmarks.py \
--models kimi-linear mla gdn-h \
--context-lengths 4096 32768 131072 524288 1048576 \
--tasks all \
--output-dir data/benchmarks/results
# Quick benchmark (lighter tests)
python scripts/benchmark/run_benchmarks.py \
--models kimi-linear mla \
--context-lengths 4096 32768 \
--tasks mmlu-pro ruler \
--quick

# Palindrome test (sequence reversal)
python tests/synthetic/test_palindrome.py --lengths 256 512 1024 2048
# MQAR test (associative recall)
python tests/synthetic/test_mqar.py --num-queries 5 10 20
# Stack test (state tracking)
python tests/synthetic/test_stack.py --num-stacks 64 --sequence-length 1024

# Kernel profiling with Nsight Compute
ncu --set full python scripts/profiling/profile_kernels.py
# System profiling with Nsight Systems
nsys profile -o profile.nsys-rep python scripts/profiling/profile_system.py
# Memory profiling
python scripts/profiling/profile_memory.py --max-context 1048576

Comprehensive documentation is available in the docs/ directory:
- Quick Start Guide: Get started in 5 minutes
- API Reference: Complete API documentation
- Architecture Guide: Deep dive into the architecture
- Training Guide: Training your own models
- Advanced Usage: Custom kernels and optimizations
- Project Plan: Comprehensive development roadmap
- Research Paper: Kimi Linear Technical Report
- Original Implementation: MoonshotAI/Kimi-Linear
- FLA Kernels: fla-org/flash-linear-attention
- Pre-trained Models: HuggingFace Hub
We welcome contributions! Please see our Contributing Guidelines for details.
- Fork the repository
- Create a feature branch (git checkout -b feature/amazing-feature)
- Make your changes
- Run tests (pytest tests/)
- Commit your changes (git commit -m 'Add amazing feature')
- Push to branch (git push origin feature/amazing-feature)
- Open a Pull Request
- 🔴 Critical: Core functionality, bug fixes, performance regressions
- 🟠 High: New features, optimizations
- 🟡 Medium: Documentation improvements, refactoring
- 🟢 Low: Code style, minor enhancements
If you use Kimi Linear in your research, please cite:
@misc{team2025kimi,
title = {Kimi Linear: An Expressive, Efficient Attention Architecture},
author = {Zhang, Yu and Lin, Zongyu and Yao, Xingcheng and Hu, Jiaxi and others},
year = {2025},
eprint = {2510.26692},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}

This project is licensed under the MIT License - see the LICENSE file for details.
- Moonshot AI for the original Kimi Linear research and implementation
- FLA Team for the flash linear attention kernels
- DeepSeek for MLA architecture insights
- Community contributors for feedback and improvements
- Issues: GitHub Issues
- Discussions: GitHub Discussions
⭐ Star this repository if you find it useful!
Made with ❤️ by the Kimi Linear Optimization Team