
[Kernel Gap Analysis] DeepSeek-V4 AMD Porting — Block-by-Block Kernel Inventory #807


Cross-posted from an internal kernel-team analysis by @carlushuang for broader visibility. The source thread tracks per-kernel ownership and ETAs internally; this issue is the public-facing technical inventory only.


Executive Summary

This issue inventories every kernel/operation needed to run DeepSeek-V4 (Flash + Pro) on AMD MI3xx/MI350X/MI355X, cross-referenced against:

  • HF reference (deepseek-ai/DeepSeek-V4-Flash) — TileLang kernels
  • vLLM implementation — FlashMLA/FlashInfer + custom fused kernels + multi-stream
  • Current AMD stack — ATOM (model), aiter (kernels), CK (GEMM)

Bottom line: 5 kernels are functionally missing on AMD (sparse attention, compressor, Hadamard rotation, FP4 activation quant, FP4 upcast-at-load GEMM; see §8). 3 fusion opportunities are proven by vLLM but not yet implemented in aiter. FP4 expert GEMM works via FusedMoE but needs a native/upcast path for peak throughput.

Model Config Summary

| Parameter | V4-Flash | V4-Pro |
|---|---|---|
| Total / Active params | 284B / 13B | 1.6T / 49B |
| hidden_size (dim) | 4096 | 7168 |
| num_hidden_layers | 43 | 61 |
| num_attention_heads | 64 | 128 |
| head_dim | 512 | 512 |
| q_lora_rank | 1024 | 1536 |
| o_lora_rank / o_groups | 1024 / 8 | 1024 / 16 |
| qk_rope_head_dim | 64 | 64 |
| sliding_window | 128 | 128 |
| n_routed_experts / n_shared | 256 / 1 | 384 / 1 |
| num_experts_per_tok | 6 | 6 |
| moe_intermediate_size | 2048 | 3072 |
| expert_dtype | FP4 (e2m1) | FP4 (e2m1) |
| Attention/dense dtype | FP8 (e4m3, block 128×128, E8M0 scale) | FP8 (e4m3, block 128×128, E8M0 scale) |
| scoring_func | sqrtsoftplus | sqrtsoftplus |
| topk_method | noaux_tc | noaux_tc |
| swiglu_limit | 10.0 | 10.0 |
| hc_mult / hc_sinkhorn_iters | 4 / 20 | 4 / 20 |
| index_n_heads / index_head_dim | 64 / 128 | 64 / 128 |
| index_topk | 512 | 1024 |
| compress_ratios | [0,0,4,128,...,4,128,4,0] (44 entries) | [128,128,4,128,...,4,128,4,0] (62 entries) |
| routed_scaling_factor | 1.5 | 2.5 |
| num_hash_layers | 3 | 3 |
| max_position_embeddings | 1,048,576 (1M) | 1,048,576 (1M) |
| RoPE | theta=10K (SWA), 160K+YaRN (compressed) | theta=10K (SWA), 160K+YaRN (compressed) |
| MTP layers | 1 | 1 |
| Min hardware (vLLM) | 4× B200/B300 | 8× B200/B300 |

Notation: In all kernel tables below:
- T = number of input tokens (sequence length)
- H = number of attention heads (H=64 for Flash, H=128 for Pro)
- G = number of O-projection groups (G=8 for Flash, G=16 for Pro)
- SK = number of gathered KV tokens per query (see sparse_attn row for case-by-case formulas)
- max_B = maximum batch size (number of concurrent sequences the pre-allocated buffers can hold)
- max_S = maximum sequence length (up to 1,048,576 for 1M context; determines KV cache and compressor buffer sizes at init)
Where Flash and Pro differ, both shapes are shown. Where only one shape is shown, both variants use the same dimensions.
Note: max_B and max_S appear in statically pre-allocated buffers (HF/ATOM eager mode). vLLM/SGLang use paged KV cache with dynamic block allocation instead.

For per-query KV token counts, see ATOM#664 Section 7.


1. MLA / Attention Block

| # | Operation | Flash Shapes | Pro Shapes | Comments |
|---|---|---|---|---|
| 1 | wq_a GEMM (FP8 weight) | [T,4096] @ [4096,1024] → [T,1024] | [T,7168] @ [7168,1536] → [T,1536] | NV: FlashMLA fused pre-attn. AMD: ✅ CK GemmMXFP8 |
| 2 | Q RMSNorm (BF16) | [T,1024] | [T,1536] | NV: Fused in pre-attn. AMD: ✅ aiter RMSNorm |
| 3 | wq_b GEMM (FP8 weight) | [T,1024] @ [1024,32768] → [T,32768] | [T,1536] @ [1536,65536] → [T,65536] | NV: FlashMLA fused pre-attn. AMD: ✅ CK GemmMXFP8 |
| 4 | Per-head Q Norm (BF16) | [T,H,512], H=64 | [T,H,512], H=128 | NV: Fused in pre-attn. AMD: ⚠️ PyTorch fallback |
| 5 | Q RoPE (BF16) | [T,H,64], H=64 | [T,H,64], H=128 | NV: FlashMLA. AMD: ✅ Reuse existing RoPE |
| 6 | wkv GEMM (FP8 weight) | [T,4096] @ [4096,512] → [T,512] | [T,7168] @ [7168,512] → [T,512] | NV: Fused Q/KV pre-attn. AMD: ✅ CK GemmMXFP8 |
| 7 | KV RMSNorm (BF16) | [T,512] | [T,512] | NV: Fused in pre-attn. AMD: ✅ aiter RMSNorm |
| 8 | KV RoPE (BF16) | [T,64] (last dims) | [T,64] (last dims) | NV: Fused Q norm + KV RoPE + K insert (~10–20× speedup). AMD: ✅ Reuse existing |
| 9 | KV act_quant (BF16→FP8) | [T,448] → FP8 e4m3 | [T,448] → FP8 e4m3 | NV: Fused in pre-attn. AMD: ✅ aiter quant |
| 10 | sparse_attn (Q: BF16, KV: mixed BF16/FP8) | Q [T,H,512] × KV [SK,512] → [T,H,512]; CSA: SK = 128 + min(512, T/4); HCA: SK = 128 + T/128 (H=64) | Q [T,H,512] × KV [SK,512] → [T,H,512]; CSA: SK = 128 + min(1024, T/4); HCA: SK = 128 + T/128 (H=128) | NV: FlashMLA/FlashInfer custom sparse backend. AMD: ❌ Torch fallback (atom/model_ops/sparse_attn_v4.py) |
| 11 | attn_sink bias (FP32) | [H], H=64 | [H], H=128 | NV: Part of sparse attn kernel. AMD: ❌ Part of torch fallback |
| 12 | Inverse RoPE (BF16) | [T,H,64], H=64 | [T,H,64], H=128 | NV: Fused inverse RoPE + FP8 quant (~2–3×). AMD: ⚠️ PyTorch fallback |
| 13 | wo_a grouped LoRA (BF16) | einsum [T,G,4096] × [G,1024,4096] → [T,G,1024], G=8 | einsum [T,G,4096] × [G,1024,4096] → [T,G,1024], G=16 | NV: Custom grouped matmul. AMD: ⚠️ PyTorch einsum |
| 14 | wo_b GEMM (FP8 weight) | [T,8192] @ [8192,4096] → [T,4096] | [T,16384] @ [16384,7168] → [T,7168] | NV: FlashMLA/standard GEMM. AMD: ✅ CK GemmMXFP8 |
| 15 | Input/output RMSNorm (BF16) | [T,4096] | [T,7168] | NV: Standard. AMD: ✅ aiter RMSNorm |

Row 13 requires a view() first: attn output [T,H,512] → [T,G,(H/G)×512] = [T,G,4096] (H/G = 8 heads/group for both Flash and Pro).
Row 14 requires flatten(2) first: wo_a output [T,G,1024] → [T,G×1024] = [T,8192] (Flash) / [T,16384] (Pro).
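
The reshape chain around rows 13–14 is easy to get wrong, so here is a minimal PyTorch sketch with hypothetical Flash-sized random tensors; the model's einsum "bsgd,grd->bsgr" carries separate batch and sequence dims, folded into a single T here:

import torch

# Hypothetical Flash dims: H=64, G=8, head_dim=512, rank=1024, dim=4096.
# Random BF16 stand-ins everywhere (wo_b is an FP8 weight in the real model).
T = 16
attn_out = torch.randn(T, 64, 512, dtype=torch.bfloat16)   # row 10 output [T,H,512]
wo_a = torch.randn(8, 1024, 4096, dtype=torch.bfloat16)    # [G, rank, (H/G)*head_dim]
wo_b = torch.randn(8192, 4096, dtype=torch.bfloat16)       # [G*rank, dim]

x = attn_out.reshape(T, 8, 8 * 512)                        # view(): [T,H,512] -> [T,G,4096]
lora = torch.einsum("tgd,grd->tgr", x, wo_a)               # row 13: grouped LoRA -> [T,G,1024]
out = lora.flatten(1) @ wo_b                               # flatten + row 14: [T,8192]@[8192,4096] -> [T,4096]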

1.1 sparse_attn Deep Dive (P0 — Biggest Gap)

What it does: Index-gather attention — NOT standard FlashAttention.

# For each (batch, seq_pos):
#   1. Gather KV[topk_idxs] by index (topk = 640 Flash CSA, 160 Flash HCA at 4K)
#   2. Q @ KV_gathered^T → scores
#   3. Online softmax (running max + sum exp)
#   4. Add attn_sink bias AFTER normalization
#   5. Softmax scores @ KV_gathered → output

Why generic FA won't work: Dense FA mask is O(S²) memory — prohibitive for 1M context. The gather indices are sparse and query-dependent.

vLLM approach: Integrates FlashMLA/FlashInfer backends but still requires custom sparse-gather logic.

AMD need: Custom aiter kernel with:

  • Index gather by topk_idxs (can be -1 for invalid)
  • Online softmax over gathered KV
  • attn_sink as pseudo-token added to denominator
  • Support both Flash (topk=512) and Pro (topk=1024) for CSA; HCA uses topk=8192 (effectively all entries for ≤1M)

Perf target: the kernel is memory-bound on the KV gather; prioritize gather coalescing over compute optimization. A reference-semantics sketch follows.
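
To pin down the semantics an aiter kernel must reproduce, here is a hedged, non-fused PyTorch reference. The function name and plain-softmax simplification are mine (a real kernel uses the online variant), and the attn_sink handling follows the pseudo-token-in-denominator bullet above; exact bias placement should be checked against atom/model_ops/sparse_attn_v4.py.

import torch

def sparse_attn_ref(q, kv, topk_idxs, attn_sink, scale):
    # q: [T,H,D] BF16; kv: [S,D] (already dequantized); topk_idxs: [T,SK] int64,
    # -1 = invalid; attn_sink: [H] FP32.
    valid = topk_idxs >= 0                                   # mask out invalid slots
    kv_g = kv[topk_idxs.clamp(min=0)]                        # [T,SK,D] index gather
    scores = torch.einsum("thd,tsd->ths", q.float(), kv_g.float()) * scale
    scores = scores.masked_fill(~valid[:, None, :], float("-inf"))
    # attn_sink as a pseudo-token: it contributes to the softmax denominator
    # but has no value row (per the spec bullets above)
    m = torch.maximum(scores.amax(dim=-1), attn_sink[None, :])  # [T,H] max incl. sink
    p = torch.exp(scores - m[..., None])                     # [T,H,SK]
    denom = p.sum(dim=-1) + torch.exp(attn_sink[None, :] - m)
    out = torch.einsum("ths,tsd->thd", p / denom[..., None], kv_g.float())
    return out.to(q.dtype)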


2. Indexer Block

The Indexer runs on a separate CUDA stream in vLLM, overlapped with main attention prep.

| # | Operation | Flash Shapes | Pro Shapes | Comments |
|---|---|---|---|---|
| 1 | Indexer Compressor (FP4) | See Compressor block (head_dim=128, rotate=True, FP4 output) | See Compressor block (head_dim=128, rotate=True, FP4 output) | NV: Fused Compressor+Norm+RoPE+insert. AMD: ❌ PyTorch impl |
| 2 | wq_b GEMM (FP8) | [T,1024] @ [1024,8192] → [T,8192] | [T,1536] @ [1536,8192] → [T,8192] | NV: On separate stream. AMD: ✅ CK GemmMXFP8 |
| 3 | Unflatten + RoPE (BF16) | [T,64,128] → RoPE on last 64 dims | [T,64,128] → RoPE on last 64 dims | NV: On separate stream. AMD: ✅ Reuse |
| 4 | rotate_activation (Hadamard) (BF16) | [T,64,128] | [T,64,128] | NV: On separate stream. AMD: ❌ Torch fallback (sketch below) |
| 5 | fp4_act_quant (BF16→FP4) | [T,64,128] → FP4 | [T,64,128] → FP4 | NV: On separate stream. AMD: ❌ Torch fallback |
| 6 | weights_proj GEMM (BF16) | [T,4096] @ [4096,64] → [T,64] | [T,7168] @ [7168,64] → [T,64] | NV: On separate stream. AMD: ✅ CK GEMM |
| 7 | Einsum scoring (FP4 Q, BF16 KV) | [T,64,128] @ [T/4,128]^T → [T,64,T/4] | [T,64,128] @ [T/4,128]^T → [T,64,T/4] | NV: On separate stream. AMD: ⚠️ PyTorch matmul (FP4 upcast first) |
| 8 | ReLU + weighted sum (FP32) | [T,64,T/4] × [T,64,1], sum(dim=-2) → [T,T/4] | [T,64,T/4] × [T,64,1], sum(dim=-2) → [T,T/4] | NV: On separate stream. AMD: ⚠️ PyTorch |
| 9 | Causal mask + top-k (FP32) | [T,T/4] → [T,512] indices | [T,T/4] → [T,1024] indices | NV: On separate stream. AMD: ⚠️ PyTorch topk |
| 10 | Top-k offset + mask (int32) | [T,512] → offset, invalidate lookahead | [T,1024] → offset, invalidate lookahead | NV: On separate stream. AMD: ⚠️ PyTorch |
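
rotate_activation (row 4) has a simple matmul-equivalent form. A minimal sketch, assuming the Sylvester construction and the scale = dim^-0.5 documented in §9(8); the HF reference uses the fast_hadamard_transform library rather than an explicit matrix:

import torch

def hadamard_matrix(n: int) -> torch.Tensor:
    # Sylvester construction; n must be a power of two (128 for the indexer)
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1),
                       torch.cat([H, -H], dim=1)], dim=0)
    return H

def rotate_activation_ref(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1]
    H = (hadamard_matrix(d) * d ** -0.5).to(x.dtype)   # scale = dim^-0.5 (orthonormal)
    return x @ H                                        # rotate along the last dim

q = rotate_activation_ref(torch.randn(4096, 64, 128, dtype=torch.bfloat16))  # row 4 shape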

2.1 Indexer Memory at 1M Context

Important: At full 1M prefill, index_score = [1, 1M, 250K] at FP32 ≈ 1 TB — infeasible in a single pass. Chunked prefill is mandatory. The table below shows per-chunk shapes at a typical chunk size of T=4K (matching ATOM#664 examples), plus the full-context column for reference.

| Tensor | Shape (T=4K chunk) | Bytes (chunk) | Shape (T=1M, full) | Bytes (full, theoretical) |
|---|---|---|---|---|
| Q (FP4) | [1,4096,64,128] | ~16 MB | [1,1M,64,128] | ~4 GB |
| index_score (FP32) | [1,4096,1024] | ~16 MB | [1,1M,250K] | ~1 TB |
| topk_idxs (int32) | [1,4096,512] | ~8 MB | [1,1M,512] | ~2 GB |

index_score is the largest intermediate. FP4 precision for Q/KV is essential, and chunked prefill caps the per-chunk index_score to manageable sizes.
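
The byte counts follow directly from the shapes; a quick arithmetic check (the chunk-column index_score is apparently against chunk-local compressed KV, i.e. 4096/4 = 1024 entries):

# Arithmetic behind the table above (1M = 2^20 tokens)
T_chunk, T_full = 4096, 1 << 20
q_chunk     = T_chunk * 64 * 128 // 2        # FP4 = 0.5 B/elem       -> 16 MiB
score_chunk = T_chunk * (T_chunk // 4) * 4   # FP32, chunk-local KV   -> 16 MiB
score_full  = T_full * (T_full // 4) * 4     # FP32, 262,144 (~250K) KV entries -> ~1.1e12 B ~ 1 TB
idx_full    = T_full * 512 * 4               # int32                  -> ~2.1e9 B ~ 2 GB
print(q_chunk, score_chunk, score_full, idx_full)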


3. Compressor Block

There are two Compressor instances per CSA layer (Attention + Indexer) and one per HCA layer.

Three compressor variants share the same pipeline (project → ape → reshape → pool → norm → RoPE → quant → cache write) but differ in head_dim, compress_ratio, overlap, rotation, and output quantization. All shapes are shown in one table below.

| # | Operation | CSA Attention Compressor (ratio=4, overlap=True, head_dim=512) | CSA Indexer Compressor (ratio=4, overlap=True, head_dim=128) | HCA Compressor (ratio=128, overlap=False, head_dim=512) |
|---|---|---|---|---|
| — | Config | coff=2, compress_ratio=4, rotate=False | coff=2, compress_ratio=4, rotate=True | coff=1, compress_ratio=128, rotate=False |
| 1 | wkv GEMM (FP32) | [T,dim] @ [dim,coff×hd] → [T,1024] (F: [T,4096]@[4096,1024]; P: [T,7168]@[7168,1024]) | [T,dim] @ [dim,coff×hd] → [T,256] (F: [T,4096]@[4096,256]; P: [T,7168]@[7168,256]) | [T,dim] @ [dim,hd] → [T,512] (F: [T,4096]@[4096,512]; P: [T,7168]@[7168,512]) |
| 2 | wgate GEMM (FP32) | [T,dim] @ [dim,1024] → [T,1024] | [T,dim] @ [dim,256] → [T,256] | [T,dim] @ [dim,512] → [T,512] |
| 3 | + ape (FP32) | [T,1024] += broadcast [4,1024] | [T,256] += broadcast [4,256] | [T,512] += broadcast [128,512] |
| 4 | Unflatten | [T/4,4,1024] | [T/4,4,256] | [T/128,128,512] |
| 5 | Overlap transform (FP32) | [T/4,4,1024] → [T/4,8,512] (4 prev first-halves + 4 curr second-halves) | [T/4,4,256] → [T/4,8,128] (4 prev first-halves + 4 curr second-halves) | N/A (no overlap, coff=1) |
| 6 | Gated pooling (FP32) | score.softmax(dim=-2) over 8 positions; (kv * score).sum(dim=-2) → [T/4,512] | score.softmax(dim=-2) over 8 positions; (kv * score).sum(dim=-2) → [T/4,128] | score.softmax(dim=-2) over 128 tokens; (kv * score).sum(dim=-2) → [T/128,512] |
| 7 | RMSNorm (BF16) | [T/4,512] | [T/4,128] | [T/128,512] |
| 8 | RoPE (last 64) (BF16) | [T/4,64] at position 4j | [T/4,64] at position 4j | [T/128,64] at position 128j |
| 9a | Hadamard rotation (BF16) | N/A (rotate=False) | [T/4,128] → Hadamard → [T/4,128] (full vector before FP4 quant) | N/A (rotate=False) |
| 9b | Output quantization | act_quant on non-RoPE dims: [T/4,448] → FP8 e4m3 (RoPE 64 dims stay BF16) | fp4_act_quant on full vector: [T/4,128] → FP4 e2m1 (no RoPE/non-RoPE split) | act_quant on non-RoPE dims: [T/128,448] → FP8 e4m3 (RoPE 64 dims stay BF16) |
| 10 | Write to kv_cache | [T/4,512] → Attn kv_cache | [T/4,128] → Indexer kv_cache | [T/128,512] → Attn kv_cache |
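
A minimal PyTorch sketch of rows 5–6 for the CSA attention compressor (Flash dims); torch.roll is a stand-in for the kv_state double-buffer of §3.2, and all values are random placeholders:

import torch

Tc = 1024                                   # number of compressed positions (= T/4)
kv   = torch.randn(Tc, 4, 1024)             # after wkv GEMM, +ape, unflatten
gate = torch.randn(Tc, 4, 1024)             # after wgate GEMM, +ape, unflatten

def overlap(x):
    # [T/4,4,1024] -> [T/4,8,512]; roll approximates "previous window"
    # (the first row is garbage here, which the real pipeline avoids)
    prev = torch.roll(x[..., :512], shifts=1, dims=0)   # 4 prev-window first-halves
    curr = x[..., 512:]                                  # 4 curr-window second-halves
    return torch.cat([prev, curr], dim=1)                # [T/4,8,512]

kv8, gate8 = overlap(kv), overlap(gate)
score  = gate8.softmax(dim=-2)              # softmax over the 8 overlapped positions
pooled = (kv8 * score).sum(dim=-2)          # [T/4,512], then RMSNorm -> RoPE -> quant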

3.1 KV Cache Shapes

| Compressor kv_cache | Shape | Entry Size | Entries at 1M Context |
|---|---|---|---|
| CSA Attention | [max_B, max_S//4, 512] | 512 (448 FP8 + 64 BF16) | 262,144 |
| CSA Indexer | [max_B, max_S//4, 128] | 128 (FP4 packed) | 262,144 |
| HCA Attention | [max_B, max_S//128, 512] | 512 (448 FP8 + 64 BF16) | 8,192 |

3.2 Persistent Buffer (kv_state / score_state) — Decode Only

| Compressor | Buffer Shape | sliding_window | Purpose |
|---|---|---|---|
| CSA Attention | [max_B, 8, 1024] | 8 (= coff × ratio = 2×4) | Double-buffer: slots 0–3 = previous window, 4–7 = current window |
| CSA Indexer | [max_B, 8, 256] | 8 (= coff × ratio = 2×4) | Same double-buffer, smaller head_dim |
| HCA | [max_B, 128, 512] | 128 (= coff × ratio = 1×128) | Rolling 128-token window for non-overlapping compression |

vLLM treats compressor state as sliding-window KV with sliding_window = coff * compress_ratio, registered under the same hybrid KV cache manager. Shift operation after each compression: kv_state[:, :ratio] = kv_state[:, ratio:] (current becomes previous).
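
A minimal sketch of that double-buffer mechanic (function and variable names hypothetical):

import torch

max_B, ratio, d = 8, 4, 1024                 # sliding_window = coff*ratio = 8 slots
kv_state = torch.zeros(max_B, 2 * ratio, d)  # slots 0-3 previous, 4-7 current

def on_new_token(pos_in_window: int, new_kv: torch.Tensor):
    kv_state[:, ratio + pos_in_window] = new_kv      # fill a current-window slot
    if pos_in_window == ratio - 1:                   # window complete
        # ... compress kv_state (overlap transform + gated pooling), append to kv_cache ...
        kv_state[:, :ratio] = kv_state[:, ratio:]    # documented shift: current becomes previous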


4. MoE Block

| # | Operation | Flash Shapes | Pro Shapes | Comments |
|---|---|---|---|---|
| 1 | Gate scoring (FP8) | [B,T,4096] @ [4096,256] → [B,T,256] | [B,T,7168] @ [7168,384] → [B,T,384] | NV: FlashInfer fused dispatch. AMD: ✅ aiter select_experts |
| 2 | sqrtsoftplus (FP32) | [B,T,256] | [B,T,384] | NV: Fused in dispatch. AMD: ✅ aiter |
| 3 | e_score_correction_bias (FP32) | [256] added | [384] added | NV: Fused in dispatch. AMD: ✅ aiter |
| 4 | Top-k select (FP32) | [1,256] → top-6 | [1,384] → top-6 | NV: Fused in dispatch. AMD: ✅ aiter |
| 5 | Hash routing (layers 0–2) (int32) | tid2eid[input_ids] lookup | tid2eid[input_ids] lookup | NV: Embedding gather. AMD: ✅ PyTorch gather |
| 6 | noaux_tc (—) | Load-balancing constraint | Load-balancing constraint | NV: Fused in dispatch. AMD: ✅ aiter |
| 7 | Expert dispatch (FP4) | Gather 6 experts' weights | Gather 6 experts' weights | NV: FlashInfer. AMD: ✅ aiter FusedMoE |
| 8 | FP4 GEMM w1/w3 (FP8 act × FP4 weight) | [6,4096] @ [4096,2048] (per expert) | [6,7168] @ [7168,3072] (per expert) | NV: Custom FP4 backend. AMD: ⚠️ aiter FusedMoE (triton/CK path) |
| 9 | FP4 GEMM w2 (FP8 act × FP4 weight) | [6,2048] @ [2048,4096] (per expert) | [6,3072] @ [3072,7168] (per expert) | NV: Custom FP4 backend. AMD: ⚠️ aiter FusedMoE (triton/CK path) |
| 10 | SwiGLU + clamp (BF16) | silu(clamp(gate, max=10)) * clamp(up, -10, 10) | silu(clamp(gate, max=10)) * clamp(up, -10, 10) | NV: Fused in expert FFN. AMD: ⚠️ triton post-kernel clamp (gfx950 workaround) |
| 11 | Shared expert (FP4) | Always-active expert | Always-active expert | NV: Same as routed. AMD: ✅ aiter |
| 12 | Accumulator reduce (BF16) | Sum 6 expert outputs + shared expert | Sum 6 expert outputs + shared expert | NV: Fused in dispatch. AMD: ✅ aiter |

Note on row 10: silu(x) = x * sigmoid(x); the gate clamp is one-sided (max only). A sketch follows.

4.1 FP4 GEMM Deep Dive (P0)

HF reference:

# C[M,N] = A_fp8[M,K] @ B_fp4[N,K]^T
# Act scale:  per-1×128 on K-dim, E8M0/FP32
# Weight scale: per-1×32 on K-dim, E8M0
# B stored as [N, K//2] — 2 FP4 values packed per byte along K

vLLM: ships native FP4 MoE weights and uses a custom FP4 backend (likely upcast-at-load).

AMD current: aiter FusedMoE works via:

  • gfx942 (MI300X): ASM/CK path
  • gfx950 (MI355X): Triton path with CDNA4MXScaleLayout
  • swiglu_limit applied in triton post-kernel (not in fused SwiGLU — 9× amplitude loss bug)

AMD gap: native FP4 tensor-core support on MI3xx is unconfirmed. Options (a dequant sketch for option 1 follows the list):

  1. Upcast-at-load (like TileLang): load FP4 sub-blocks → upcast to FP8 in shared memory → run CK GemmMXFP8. Recommended for bring-up.
  2. Pre-upcast weights to FP8: ~2× weight memory overhead.
  3. Native FP4: verify MI350+ sub-byte support first.
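
A host-side dequant reference in the spirit of ATOM's dequant_fp4_e2m1() utility (§6 row 4), useful for validating an upcast path. The LUT is the standard e2m1 value set; the nibble order (low nibble = first value along K) and scale layout are assumptions. A real upcast-at-load kernel would do this per sub-block in LDS and feed the FP8 MFMA path instead of materializing BF16 weights.

import torch

# All 16 e2m1 code points: sign bit, 2 exponent bits, 1 mantissa bit
E2M1_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])

def dequant_fp4_ref(packed: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # packed: uint8 [N, K//2], 2 FP4 values per byte along K (§4.1 layout);
    # scale: FP32 [N, K//32] holding the decoded E8M0 (power-of-two) factors
    lo = E2M1_LUT[(packed & 0xF).long()]
    hi = E2M1_LUT[(packed >> 4).long()]
    w = torch.stack([lo, hi], dim=-1).reshape(packed.shape[0], -1)   # [N,K]
    N, K = w.shape
    w = w.view(N, K // 32, 32) * scale.unsqueeze(-1)                 # per-1x32 scale
    return w.view(N, K).to(torch.bfloat16)

# Example with w2-like Flash dims: N=4096, K=2048 -> packed [4096,1024]
wq = torch.randint(0, 256, (4096, 1024), dtype=torch.uint8)
w_bf16 = dequant_fp4_ref(wq, torch.ones(4096, 2048 // 32))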

Weight shapes (V4-Flash):

w13_weight:       (257, 4096, 2048)   torch.float4_e2m1fn_x2   # 256 routed + 1 shared
w13_weight_scale: (257, 4096, 128)    torch.float8_e8m0fnu     # per-1×32 along K
w2_weight:        (257, 4096, 1024)   torch.float4_e2m1fn_x2
w2_weight_scale:  (257, 4096, 64)     torch.float8_e8m0fnu

Weight shapes (V4-Pro):

w13_weight:       (385, 6144, 3584)   torch.float4_e2m1fn_x2   # 384 routed + 1 shared
w13_weight_scale: (385, 6144, 224)    torch.float8_e8m0fnu     # per-1×32 along K
w2_weight:        (385, 7168, 1536)   torch.float4_e2m1fn_x2
w2_weight_scale:  (385, 7168, 96)     torch.float8_e8m0fnu

5. mHC (Manifold-Constrained Hyper-Connections) Block

| # | Operation | Flash Shapes | Pro Shapes | Comments |
|---|---|---|---|---|
| 1 | hc_pre inline norm (FP32) | rsqrt(mean(x²) + eps) on [B,T,hc×dim] = [B,T,16384] | [B,T,28672] | NV: Standard. AMD: ⚠️ PyTorch inline (not a separate RMSNorm module; operates on the flattened hc_mult × dim vector) |
| 2 | hc_fn Linear (FP32) | [B,T,16384] @ [16384,24] → [B,T,24] | [B,T,28672] @ [28672,24] → [B,T,24] | NV: Standard. AMD: ⚠️ PyTorch (input [B,T,hc_mult,dim] flattened to [B,T,hc_mult×dim]; hc_fn is nn.Parameter([24, hc_mult×dim]); the tiny output dim makes it a poor CK GEMM candidate) |
| 3 | hc_split_sinkhorn (FP32) | 24-element mixes → pre[4], post[4], comb[4,4] | 24-element mixes → pre[4], post[4], comb[4,4] | NV: Torch fallback. AMD: ⚠️ aiter mhc_pre (needs #2916) |
| 4 | 20-iter Sinkhorn (FP32) | Row/col normalize 4×4 matrix, 20× | Row/col normalize 4×4 matrix, 20× | NV: Inside mhc_pre. AMD: ✅ aiter mhc_pre (HIP kernel) |
| 5 | hc_post combine (BF16) | post * x + comb * residual (4 copies) | post * x + comb * residual (4 copies) | NV: Standard. AMD: ✅ aiter mhc_post |

5.1 Sinkhorn Bottleneck

Cost per forward: 20 iterations × 43 layers (Flash) = 860 Sinkhorn iterations.

  • PyTorch fallback: ~100–500µs per iteration → 86–430ms total
  • aiter mhc_pre kernel: ~1–5µs per iteration → 0.9–4.3ms total

aiter's mhc_pre is already kernelized. The main blocker was a device-mismatch bug (#2916, now fixed). No new kernel is needed here; only the bugfix merge. A sketch of the loop being kernelized follows.
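
A minimal sketch, assuming the 24-element mix's comb[4,4] slice is exponentiated before normalization; the exact parameterization is whatever aiter's mhc_pre implements:

import torch

def sinkhorn_4x4(mix: torch.Tensor, iters: int = 20) -> torch.Tensor:
    # Alternating row/column normalization toward a doubly stochastic 4x4 matrix
    m = mix.exp()                             # positivity via exp() is an assumption
    for _ in range(iters):
        m = m / m.sum(dim=-1, keepdim=True)   # row normalize
        m = m / m.sum(dim=-2, keepdim=True)   # column normalize
    return m

comb = sinkhorn_4x4(torch.randn(4, 4))        # the comb[4,4] mix from hc_split_sinkhorn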

5.2 Fused Block Kernel (Future Optimization)

vLLM has not yet fused the full block, but it's the next logical step:

__global__ void block_forward(...) {
    // 1. hc_pre: Sinkhorn → 1 copy
    // 2. RMSNorm
    // 3. Attention or MoE
    // 4. hc_post: combine with 4-copy residual
}

This would eliminate 4 kernel launches per layer.


6. Quantization & Utility Kernels

| # | Kernel | Input | Output | Block Size | Comments |
|---|---|---|---|---|---|
| 1 | act_quant | BF16 | FP8 e4m3 | per-1×128 | NV: Fused in pre-attn. AMD: ✅ aiter quant |
| 2 | fp4_act_quant | BF16 | FP4 e2m1 | per-1×32 | NV: Fused in indexer stream. AMD: ❌ Torch fallback |
| 3 | rotate_activation (Hadamard) | BF16 | BF16 | — | NV: Fused in indexer stream. AMD: ❌ Torch fallback |
| 4 | dequant_fp4_e2m1 | uint8 packed + E8M0 scale | BF16 | per-1×32 | NV: ATOM utility. AMD: ✅ ATOM dequant_fp4_e2m1() |
| 5 | Inverse RoPE + FP8 quant | BF16 | FP8 | — | NV: Fused (~2–3× speedup). AMD: ⚠️ Separate ops |
| 6 | YaRN RoPE scaling | — | — | — | NV: For 1M context. AMD: ✅ Reuse existing YaRN |

7. vLLM-Proven Fusions (Not Yet in aiter)

vLLM has deployed these fusions with measured speedups. They are targets for aiter/CK on AMD:

| Fusion | Stages | vLLM Speedup | AMD Status |
|---|---|---|---|
| Compressor fusion | Compressor + RMSNorm + RoPE + cache insert | ~1.4–3× | ❌ Not in aiter |
| Inverse RoPE + FP8 quant | De-rotate + quant before o_lora | ~2–3× | ❌ Not in aiter |
| Q/KV pre-attn fusion | Q norm + KV RoPE + K insert (warp-dispatched) | ~10–20× | ❌ Not in aiter |
| Indexer multi-stream | Indexer pipeline runs concurrently with main KV compression + SWA insert | — | ❌ Not in aiter |

8. AMD Gap Summary & Priority Matrix

Missing Kernels (Need New Implementation)

| Kernel | Current Fallback | Perf Impact | Effort | Owner Suggestion |
|---|---|---|---|---|
| sparse_attn | Torch index-gather + online softmax | 🔴 Major — defeats V4's purpose at 1M context | 3–4 weeks | AITER team |
| Compressor (CSA + Indexer) | PyTorch gated pooling + overlap | 🟡 Medium — bandwidth bound, but layers add up | 2–3 weeks | AMD kernel + aiter |
| fp4_act_quant | PyTorch rounding + pack | 🟡 Medium — on indexer critical path | 1 week | AMD kernel |
| rotate_activation (Hadamard) | PyTorch matmul with Hadamard matrix | 🟢 Low — not on hottest path | 3–5 days | AMD kernel |
| FP4 upcast-at-load GEMM | FusedMoE triton/CK (pre-upcast) | 🟡 Medium — 2× weight bandwidth vs native | 2–3 weeks | AMD kernel |

Working but Need Hardening

| Component | Status | Notes |
|---|---|---|
| CK GemmMXFP8 | ✅ Working | FP8 act × FP8 weight with E8M0 scales |
| aiter FusedMoE | ✅ Working | FP4 experts via triton (gfx950) or CK (gfx942) |
| aiter mhc_pre/post | ✅ Working | Needs the #2916 bugfix merged |
| aiter RMSNorm | ✅ Working | Standard path |
| ATOM model | ✅ Working | Single-sequence eager mode, 512-token coherent generation |

9. Kernel Implementation Inventory

Implementation approach legend:

  • ASM — Pure GPU assembly (hand-optimized microkernels) - contact: @niels-zhang
  • flyDSL — AMD DSL for declarative patterns (elementwise, reduction, transforms) - contact: @coderfeli
  • Triton/Gluon — Triton-based kernels (aiter’s fused-op path) - contact: @vgokhale
  • HIP/C++ — Raw HIP kernel, Opus C++ kernel, or Composable Kernel (CK) - contact: @valarLip @carlushuang

(1) Attention

  • sparse_attn — ❌ Torch fallback (P0 gap)
    • Index-gather + online softmax + attn_sink bias over variable topk gathered KV entries
    • NOT FlashAttention — sparse, query-dependent gather indices
    • Must handle -1 invalid indices, attn_sink as pseudo-token in softmax denominator
    • See §1.1 for full spec
    • ☑️ ASM    ☑️ flyDSL    ☑️ Triton/Gluon    ☑️ HIP/C++

(2) Compression

  • Indexer/Compressor kernel

    • CSA Flash: Q[T,64,512] × KV[640,512], CSA Pro: Q[T,128,512] × KV[1152,512]
    • HCA: Q[T,H,512] × KV[128+T/128, 512] (grows with context)
    • ⬜ ASM    ☑️ flyDSL    ☑️ Triton/Gluon    ☑️ HIP/C++
  • Indexer/Compressor misc (linear, elem-wise, fusions) — ❌ Torch fallback

    • CSA (ratio=4, overlap=True): wkv/wgate FP32 GEMM → +ape broadcast → overlap transform [T/4,4,1024]→[T/4,8,512] → softmax gated pooling → RMSNorm → RoPE → act_quant FP8 → cache write (§3.1)
    • HCA (ratio=128, overlap=False): same but no overlap transform, pool over 128 tokens (§3)
    • Indexer compressor (head_dim=128, rotate=True): same pipeline but Hadamard → FP4 quant instead of split FP8 (§3)
    • vLLM fuses entire pipeline into one kernel (~1.4–3× speedup)
    • ⬜ ASM    ☑️ flyDSL    ☑️ Triton/Gluon    ☑️ HIP/C++

(3) KV Cache Codesign

DeepSeek-V4 uses a hybrid KV cache with three storage paths:

| Path | Compression | Token Sparsity | Cache Shape (Flash) |
|---|---|---|---|
| CSA Attention | 4:1 (overlap) | topk = 512–640 | [max_B, max_S//4, 512] |
| CSA Indexer | 4:1 (overlap) | topk = 512 | [max_B, max_S//4, 128] |
| HCA | 128:1 (no overlap) | topk = 8192 (≈all) | [max_B, max_S//128, 512] |

Buffer mechanics:

  • kv_state / score_state double-buffer for CSA: [max_B, 8, 256] or [max_B, 8, 1024]
  • HCA rolling window: [max_B, 128, 512]
  • vLLM integrates both under a single hybrid KV cache manager with sliding_window = coff * compress_ratio
  • ⬜ ASM    ⬜ flyDSL    ⬜ Triton/Gluon    ☑️ HIP/C++

Source: §3.1–3.2 for detailed shape traces and vLLM integration

(4) MoE

  • A8W4 (FP8 act × FP4 weight) expert GEMM, includes SwiGLU + clamp — ⚠️ aiter FusedMoE

    • Expert w1/w3/w2 GEMM + fused SwiGLU + clamp (§4 rows 8-10)
    • Flash: [6,4096]@[4096,2048] / [6,2048]@[2048,4096] per expert
    • Pro: [6,7168]@[7168,3072] / [6,3072]@[3072,7168] per expert
    • ⬜ ASM    ☑️ flyDSL    ☑️ Triton/Gluon    ⬜ HIP/C++
  • sqrtsoftplus scoring — ✅ aiter

    • sqrt(softplus(x)) on [B,T,256/384] (§4 row 2)
    • ⬜ ASM    ⬜ flyDSL    ☑️ Triton/Gluon    ☑️ HIP/C++
  • Hash routing (layers 0–2) — ✅ PyTorch gather

    • tid2eid[input_ids] lookup (§4 row 5)
    • ⬜ ASM    ⬜ flyDSL    ☑️ Triton/Gluon    ☑️ HIP/C++
  • MegaMoE (Expert-Parallel All-to-All + FP4 GEMM) — ❌ Not available on AMD (mid-term target)

    • vLLM's DeepseekV4MegaMoEExperts: fused all-to-all dispatch + fp8_fp4_mega_moe() via DeepGEMM, symmetric NCCL buffer management
    • Combines inter-GPU expert routing (NCCL-based) with per-expert FP8×FP4 GEMM in a single fused op
    • NV: DeepGEMM fp8_fp4_mega_moe() + get_symm_buffer_for_mega_moe() (requires SM100/Blackwell)
    • AMD dependency: mori's NCCL-GIN equivalent for the symmetric communication buffer / all-to-all overlap
    • AMD GEMM dependency: aiter/CK A8W4 expert GEMM (see above)
    • ⬜ ASM    ☑️ flyDSL    ☑️ Triton/Gluon    ☑️ HIP/C++

(5) GEMM

  • A8W8 (FP8 act × FP8 weight) — ✅ CK GemmMXFP8

    • Attention: wq_a, wq_b, wkv, wo_b (§1 rows 1,3,6,14)
    • Indexer: wq_b (§2 row 2)
    • MoE: gate scoring projection (§4 row 1)
    • ⬜ ASM    ☑️ flyDSL    ☑️ Triton/Gluon    ☑️ HIP/C++
  • Grouped matmul (BF16) — ⚠️ PyTorch einsum

    • Attention: wo_a grouped LoRA — einsum("bsgd,grd->bsgr") (§1 row 13)
    • [T,G,4096] × [G,1024,4096] → [T,G,1024], G=8 (Flash) / G=16 (Pro)
    • ⬜ ASM    ☑️ flyDSL    ☑️ Triton/Gluon    ☑️ HIP/C++
  • Small BF16 matmul — ⚠️ PyTorch

    • Indexer: weights_proj [T,dim]@[dim,64] (§2 row 6) — tiny output dim
    • mHC: hc_fn [T,hc×dim]@[hc×dim,24] (§5 row 2) — tiny output dim
    • ⬜ ASM    ☑️ flyDSL    ☑️ Triton/Gluon    ☑️ HIP/C++

(6) Norm

  • RMSNorm — ✅ aiter

    • Attention: Q norm [T,1024/1536], KV norm [T,512], input/output norms [T,dim] (§1 rows 2,7,15)
    • Compressor: post-pool norm [T/4,512] (§3 row 7)
    • ⬜ ASM    ⬜ flyDSL    ☑️ Triton/Gluon    ☑️ HIP/C++
  • Per-head Q Norm (inline) — ⚠️ PyTorch fallback

    • Attention: q *= rsqrt(mean(q², dim=-1) + eps) on [T,H,512] (§1 row 4)
    • vLLM fuses into Q/KV pre-attn kernel (~10–20× speedup)
    • ⬜ ASM    ⬜ flyDSL    ☑️ Triton/Gluon    ☑️ HIP/C++
  • mHC inline norm (FP32) — ⚠️ PyTorch

    • rsqrt(mean(x²) + eps) on flattened [T, hc×dim] = [T,16384/28672] (§5 row 1)
    • ⬜ ASM    ⬜ flyDSL    ☑️ Triton/Gluon    ☑️ HIP/C++

(7) mHC

  • hc_split_sinkhorn — ✅ aiter mhc_pre (needs #2916 bugfix merge)

    • 20-iter Sinkhorn on 4×4 matrix, 860 iterations per forward (20 iters × 43 layers)
    • ⬜ ASM    ⬜ flyDSL    ⬜ Triton/Gluon    ☑️ HIP/C++
  • hc_post combine — ✅ aiter mhc_post

    • post * x + comb * residual with 4 HC copies
    • ⬜ ASM    ⬜ flyDSL    ⬜ Triton/Gluon    ☑️ HIP/C++

(8) Quantization & Transform

  • act_quant (BF16→FP8 e4m3, per-1×128) — ✅ aiter

    • Attention: KV non-RoPE dims [T,448] (§1 row 9)
    • Compressor: compressed KV non-RoPE dims [T/4,448] (§3 row 9b)
    • ⬜ ASM    ⬜ flyDSL    ☑️ Triton/Gluon    ☑️ HIP/C++
  • fp4_act_quant (BF16→FP4 e2m1, per-1×32) — ❌ Torch fallback

    • Indexer: Q after Hadamard [T,64,128] (§2 row 5)
    • Indexer compressor: compressed KV [T/4,128] (§3 row 9b)
    • ⬜ ASM    ⬜ flyDSL    ☑️ Triton/Gluon    ☑️ HIP/C++
  • rotate_activation (Hadamard transform) — ❌ Torch fallback

    • Indexer: Q before FP4 quant [T,64,128] (§2 row 4)
    • Indexer compressor: KV before FP4 quant [T/4,128] (§3 row 9a)
    • Uses fast_hadamard_transform library, scale = dim^{-0.5}
    • ⬜ ASM    ⬜ flyDSL    ☑️ Triton/Gluon    ☑️ HIP/C++
  • Inverse RoPE + FP8 quant (fused) — ⚠️ Separate ops on AMD

    • Attention: de-rotate o[..., -64:] then FP8-quant before wo_a (§1 row 12)
    • vLLM fuses these (~2–3× speedup)
    • ⬜ ASM    ⬜ flyDSL    ☑️ Triton/Gluon    ☑️ HIP/C++

10. Reference Links

| Resource | URL | Relevance |
|---|---|---|
| HF V4-Flash | https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash | Config, weights, TileLang kernels |
| HF V4-Pro | https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro | Pro config (384 experts, 7168 hidden) |
| vLLM V4 Blog | https://blog.vllm.ai/2026/04/24/deepseek-v4.html | Implementation details, fusions, block design |
| ATOM PR#650 | https://github.com/ROCm/ATOM/pull/650 | AMD eager-mode model (13 commits, single-seq) |
| AITER PR#2916 | https://github.com/ROCm/aiter/pull/2916 | mhc_pre device bugfix |
| ATOM Issue #664 | https://github.com/ROCm/ATOM/issues/664 | Full tensor shape traces for attention pipeline |

Compiled from: HF reference (inference/kernel.py, inference/model.py), vLLM blog (Apr 24 2026), ATOM PR#650, AITER PR#2916, and AMD kernel team analysis.
