Cross-posting from internal kernel-team analysis by @carlushuang for broader visibility. The source thread tracks per-kernel ownership and ETAs internally; this issue is the public-facing technical inventory only.
Executive Summary
This issue inventories every kernel/operation needed to run DeepSeek-V4 (Flash + Pro) on AMD MI3xx/MI350X/MI355X, cross-referenced against:
- HF reference (deepseek-ai/DeepSeek-V4-Flash) — TileLang kernels
- vLLM implementation — FlashMLA/FlashInfer + custom fused kernels + multi-stream
- Current AMD stack — ATOM (model), aiter (kernels), CK (GEMM)
Bottom line: 5 kernels are functionally missing on AMD (sparse attention, compressor, FP4 activation quant, Hadamard rotation, FP4 upcast-at-load GEMM). 3 fusion opportunities are proven by vLLM but not yet implemented in aiter. FP4 expert GEMM works via FusedMoE but needs a native/upcast path for peak throughput.
Model Config Summary
| Parameter | V4-Flash | V4-Pro |
|---|---|---|
| Total / Active params | 284B / 13B | 1.6T / 49B |
| `hidden_size` (dim) | 4096 | 7168 |
| `num_hidden_layers` | 43 | 61 |
| `num_attention_heads` | 64 | 128 |
| `head_dim` | 512 | 512 |
| `q_lora_rank` | 1024 | 1536 |
| `o_lora_rank` / `o_groups` | 1024 / 8 | 1024 / 16 |
| `qk_rope_head_dim` | 64 | 64 |
| `sliding_window` | 128 | 128 |
| `n_routed_experts` / `n_shared` | 256 / 1 | 384 / 1 |
| `num_experts_per_tok` | 6 | 6 |
| `moe_intermediate_size` | 2048 | 3072 |
| `expert_dtype` | FP4 (e2m1) | FP4 (e2m1) |
| Attention/dense dtype | FP8 (e4m3, block 128×128, E8M0 scale) | FP8 (e4m3, block 128×128, E8M0 scale) |
| `scoring_func` | sqrtsoftplus | sqrtsoftplus |
| `topk_method` | noaux_tc | noaux_tc |
| `swiglu_limit` | 10.0 | 10.0 |
| `hc_mult` / `hc_sinkhorn_iters` | 4 / 20 | 4 / 20 |
| `index_n_heads` / `index_head_dim` | 64 / 128 | 64 / 128 |
| `index_topk` | 512 | 1024 |
| `compress_ratios` | [0,0,4,128,...,4,128,4,0] (44 entries) | [128,128,4,128,...,4,128,4,0] (62 entries) |
| `routed_scaling_factor` | 1.5 | 2.5 |
| `num_hash_layers` | 3 | 3 |
| `max_position_embeddings` | 1,048,576 (1M) | 1,048,576 (1M) |
| RoPE | theta=10K (SWA), 160K+YaRN (compressed) | theta=10K (SWA), 160K+YaRN (compressed) |
| MTP layers | 1 | 1 |
| Min hardware (vLLM) | 4× B200/B300 | 8× B200/B300 |
Notation: In all kernel tables below:
- T = number of input tokens (sequence length)
- H = number of attention heads (H=64 for Flash, H=128 for Pro)
- G = number of O-projection groups (G=8 for Flash, G=16 for Pro)
- SK = number of gathered KV tokens per query (see sparse_attn row for case-by-case formulas)
- max_B = maximum batch size (number of concurrent sequences the pre-allocated buffers can hold)
- max_S = maximum sequence length (up to 1,048,576 for 1M context; determines KV cache and compressor buffer sizes at init)
Where Flash and Pro differ, both shapes are shown. Where only one shape is shown, both variants use the same dimensions.
Note: max_B and max_S appear in statically pre-allocated buffers (HF/ATOM eager mode). vLLM/SGLang use paged KV cache with dynamic block allocation instead.
For per-query KV token counts, see ATOM#664 Section 7.
1. MLA / Attention Block
| # | Operation | Flash Shapes | Pro Shapes | Comments |
|---|---|---|---|---|
| 1 | `wq_a` GEMM (FP8 weight) | `[T,4096] @ [4096,1024] → [T,1024]` | `[T,7168] @ [7168,1536] → [T,1536]` | NV: FlashMLA fused pre-attn<br>AMD: ✅ CK `GemmMXFP8` |
| 2 | Q RMSNorm (BF16) | `[T,1024]` | `[T,1536]` | NV: Fused in pre-attn<br>AMD: ✅ aiter RMSNorm |
| 3 | `wq_b` GEMM (FP8 weight) | `[T,1024] @ [1024,32768] → [T,32768]` | `[T,1536] @ [1536,65536] → [T,65536]` | NV: FlashMLA fused pre-attn<br>AMD: ✅ CK `GemmMXFP8` |
| 4 | Per-head Q Norm (BF16) | `[T,H,512]` H=64 | `[T,H,512]` H=128 | NV: Fused in pre-attn<br>AMD: ⚠️ PyTorch fallback |
| 5 | Q RoPE (BF16) | `[T,H,64]` H=64 | `[T,H,64]` H=128 | NV: FlashMLA<br>AMD: ✅ Reuse existing RoPE |
| 6 | `wkv` GEMM (FP8 weight) | `[T,4096] @ [4096,512] → [T,512]` | `[T,7168] @ [7168,512] → [T,512]` | NV: Fused Q/KV pre-attn<br>AMD: ✅ CK `GemmMXFP8` |
| 7 | KV RMSNorm (BF16) | `[T,512]` | `[T,512]` | NV: Fused in pre-attn<br>AMD: ✅ aiter RMSNorm |
| 8 | KV RoPE (BF16) | `[T,64]` (last dims) | `[T,64]` (last dims) | NV: Fused Q norm + KV RoPE + K insert (~10–20× speedup)<br>AMD: ✅ Reuse existing |
| 9 | KV `act_quant` (BF16→FP8) | `[T,448]` → FP8 e4m3 | `[T,448]` → FP8 e4m3 | NV: Fused in pre-attn<br>AMD: ✅ aiter quant |
| 10 | `sparse_attn` (Q: BF16, KV: mixed BF16/FP8) | Q `[T,H,512]` × KV `[SK,512]` → `[T,H,512]`<br>CSA: `SK = 128 + min(512, T/4)`<br>HCA: `SK = 128 + T/128` (H=64) | Q `[T,H,512]` × KV `[SK,512]` → `[T,H,512]`<br>CSA: `SK = 128 + min(1024, T/4)`<br>HCA: `SK = 128 + T/128` (H=128) | NV: FlashMLA/FlashInfer custom sparse backend<br>AMD: ❌ Torch fallback (`atom/model_ops/sparse_attn_v4.py`) |
| 11 | `attn_sink` bias (FP32) | `[H]` H=64 | `[H]` H=128 | NV: Part of sparse attn kernel<br>AMD: ❌ Part of torch fallback |
| 12 | Inverse RoPE (BF16) | `[T,H,64]` H=64 | `[T,H,64]` H=128 | NV: Fused inverse RoPE + FP8 quant (~2–3×)<br>AMD: ⚠️ PyTorch fallback |
| 13 | `wo_a` grouped LoRA (BF16) | einsum `[T,G,4096] × [G,1024,4096] → [T,G,1024]` G=8 | einsum `[T,G,4096] × [G,1024,4096] → [T,G,1024]` G=16 | NV: Custom grouped matmul<br>AMD: ⚠️ PyTorch einsum<br>Requires `view()` before: attn output `[T,H,512] → [T,G,H/G×512] = [T,G,4096]` (H/G = 8 heads/group for both Flash and Pro); see the sketch below the table |
| 14 | `wo_b` GEMM (FP8 weight) | `[T,8192] @ [8192,4096] → [T,4096]` | `[T,16384] @ [16384,7168] → [T,7168]` | NV: FlashMLA/standard GEMM<br>AMD: ✅ CK `GemmMXFP8`<br>Requires `flatten(2)` before: `wo_a` output `[T,G,1024] → [T,G×1024]` = `[T,8192]` (Flash) / `[T,16384]` (Pro) |
| 15 | Input/output RMSNorm (BF16) | `[T,4096]` | `[T,7168]` | NV: Standard<br>AMD: ✅ aiter RMSNorm |
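The `view()`/`flatten(2)` plumbing around rows 13–14, as a minimal PyTorch sketch with the batch dim folded into T (shape values from the Flash column; tensor names are illustrative):

```python
# Reshape chain around wo_a/wo_b (Flash: H=64, G=8, head_dim=512)
import torch

T, H, G = 4096, 64, 8
attn_out = torch.randn(T, H, 512, dtype=torch.bfloat16)
x = attn_out.view(T, G, (H // G) * 512)        # [T,8,4096]: group 8 heads per O-group
w_a = torch.randn(G, 1024, 4096, dtype=torch.bfloat16)
y = torch.einsum("tgd,grd->tgr", x, w_a)       # row 13: grouped LoRA -> [T,8,1024]
y = y.flatten(1)                               # [T,8192]; flatten(2) in the [B,T,...] layout
# y now feeds the wo_b GEMM: [T,8192] @ [8192,4096] -> [T,4096]
```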
1.1 sparse_attn Deep Dive (P0 — Biggest Gap)
What it does: Index-gather attention — NOT standard FlashAttention.
```python
# For each (batch, seq_pos):
#   1. Gather KV[topk_idxs] by index (topk = 640 Flash CSA, 160 Flash HCA at 4K)
#   2. Q @ KV_gathered^T → scores
#   3. Online softmax (running max + sum exp)
#   4. Add attn_sink bias AFTER normalization
#   5. Softmax scores @ KV_gathered → output
```
Why generic FA won't work: Dense FA mask is O(S²) memory — prohibitive for 1M context. The gather indices are sparse and query-dependent.
vLLM approach: Integrates FlashMLA/FlashInfer backends but still requires custom sparse-gather logic.
AMD need: Custom aiter kernel with:
- Index gather by `topk_idxs` (can be `-1` for invalid)
- Online softmax over gathered KV
- `attn_sink` as pseudo-token added to the denominator
- Support for both Flash (`topk=512`) and Pro (`topk=1024`) CSA; HCA uses `topk=8192` (effectively all entries for ≤1M)
Perf target: Memory-bound on the KV gather; prioritize gather coalescing over compute optimizations.
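For concreteness, a minimal single-pass PyTorch sketch of these semantics, following the pseudo-token-in-denominator reading from the list above (function name and the dense-gather framing are illustrative; the real kernel needs a streaming online softmax and mixed BF16/FP8 KV handling, and the ATOM torch fallback remains the reference):

```python
import torch

def sparse_attn_ref(q, kv, topk_idxs, attn_sink, scale):
    """q: [T,H,D], kv: [S,D] (latent KV serves as both K and V),
    topk_idxs: [T,K] with -1 = invalid, attn_sink: [H]."""
    valid = topk_idxs >= 0                                   # [T,K]
    idx = topk_idxs.clamp(min=0)                             # safe gather index
    kv_g = kv[idx]                                           # [T,K,D] index gather
    scores = torch.einsum("thd,tkd->thk", q, kv_g) * scale   # [T,H,K]
    scores = scores.masked_fill(~valid[:, None, :], float("-inf"))
    m = scores.amax(dim=-1, keepdim=True)                    # running max (one pass here)
    e = torch.exp(scores - m)
    # attn_sink enters the denominator as a pseudo-token with no value vector
    denom = e.sum(dim=-1, keepdim=True) + torch.exp(attn_sink[None, :, None] - m)
    p = e / denom
    return torch.einsum("thk,tkd->thd", p, kv_g)             # [T,H,D]
```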
2. Indexer Block
The Indexer runs on a separate CUDA stream in vLLM, overlapped with main attention prep.
| # | Operation | Flash Shapes | Pro Shapes | Comments |
|---|---|---|---|---|
| 1 | Indexer Compressor (FP4) | See Compressor block (head_dim=128, rotate=True, FP4 output) | See Compressor block (head_dim=128, rotate=True, FP4 output) | NV: Fused Compressor+Norm+RoPE+insert<br>AMD: ❌ PyTorch impl |
| 2 | `wq_b` GEMM (FP8) | `[T,1024] @ [1024,8192] → [T,8192]` | `[T,1536] @ [1536,8192] → [T,8192]` | NV: On separate stream<br>AMD: ✅ CK `GemmMXFP8` |
| 3 | Unflatten + RoPE (BF16) | `[T,64,128]` → RoPE on last 64 dims | `[T,64,128]` → RoPE on last 64 dims | NV: On separate stream<br>AMD: ✅ Reuse |
| 4 | `rotate_activation` (Hadamard) (BF16) | `[T,64,128]` | `[T,64,128]` | NV: On separate stream<br>AMD: ❌ Torch fallback |
| 5 | `fp4_act_quant` (BF16→FP4) | `[T,64,128]` → FP4 | `[T,64,128]` → FP4 | NV: On separate stream<br>AMD: ❌ Torch fallback |
| 6 | `weights_proj` GEMM (BF16) | `[T,4096] @ [4096,64] → [T,64]` | `[T,7168] @ [7168,64] → [T,64]` | NV: On separate stream<br>AMD: ✅ CK GEMM |
| 7 | Einsum scoring (FP4 Q, BF16 KV) | `[T,64,128] @ [T/4,128]^T → [T,64,T/4]` | `[T,64,128] @ [T/4,128]^T → [T,64,T/4]` | NV: On separate stream<br>AMD: ⚠️ PyTorch matmul (FP4 upcast first) |
| 8 | ReLU + weighted sum (FP32) | `[T,64,T/4] × [T,64,1]`, sum over head dim → `[T,T/4]` | `[T,64,T/4] × [T,64,1]`, sum over head dim → `[T,T/4]` | NV: On separate stream<br>AMD: ⚠️ PyTorch |
| 9 | Causal mask + top-k (FP32) | `[T,T/4] → [T,512]` indices | `[T,T/4] → [T,1024]` indices | NV: On separate stream<br>AMD: ⚠️ PyTorch topk |
| 10 | Top-k offset + mask (int32) | `[T,512]` → offset, invalidate lookahead | `[T,1024]` → offset, invalidate lookahead | NV: On separate stream<br>AMD: ⚠️ PyTorch |
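Rows 7–9 strung together at a 4K chunk, as the current PyTorch fallback would express them (a sketch only: FP4 operands are shown already upcast to BF16-equivalent floats, the causal mask is elided, and tensor names are illustrative):

```python
import torch

T, Hq = 4096, 64
q = torch.randn(T, Hq, 128)                   # indexer Q after Hadamard (FP4, upcast here)
kv = torch.randn(T // 4, 128)                 # indexer-compressor KV (ratio 4)
w = torch.randn(T, Hq)                        # weights_proj output, one weight per head

score = torch.einsum("thd,kd->thk", q, kv)    # row 7: [T,64,T/4]
score = torch.relu(score) * w[..., None]      # row 8: ReLU, then per-head weighting
score = score.sum(dim=1)                      # sum over heads -> [T,T/4]
# row 9: causal mask over the T/4 compressed positions would be applied here
topk_idxs = score.topk(512, dim=-1).indices   # [T,512] gather indices for sparse_attn
```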
2.1 Indexer Memory at 1M Context
Important: At full 1M prefill, index_score = [1, 1M, 250K] at FP32 ≈ 1 TB — infeasible in a single pass. Chunked prefill is mandatory. The table below shows per-chunk shapes at a typical chunk size of T=4K (matching ATOM#664 examples), plus the full-context column for reference.
| Tensor | Shape (T=4K chunk) | Bytes (chunk) | Shape (T=1M, full) | Bytes (full, theoretical) |
|---|---|---|---|---|
| Q (FP4) | `[1,4096,64,128]` | ~16 MB | `[1,1M,64,128]` | ~4 GB |
| `index_score` | `[1,4096,1024]` | ~16 MB (FP32) | `[1,1M,250K]` | ~1 TB (FP32) |
| `topk_idxs` | `[1,4096,512]` | ~8 MB (int32) | `[1,1M,512]` | ~2 GB (int32) |
`index_score` is the largest intermediate. FP4 precision for Q/KV is essential, and chunked prefill caps the per-chunk `index_score` to manageable sizes.
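The byte counts reduce to straightforward arithmetic (1M taken as 2^20 tokens; FP4 = 0.5 bytes/element):

```python
def mib(n_bytes):
    return n_bytes / 2**20

T, K = 4096, 4096 // 4                # chunk tokens, compressed keys at ratio 4
print(mib(T * 64 * 128 * 0.5))        # Q (FP4):            16.0 MiB
print(mib(T * K * 4))                 # index_score (FP32): 16.0 MiB
print(mib(T * 512 * 4))               # topk_idxs (int32):   8.0 MiB

S = 2**20                             # full 1M context
print(S * (S // 4) * 4 / 2**40)       # index_score, full:   1.0 TiB
```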
3. Compressor Block
There are two Compressor instances per CSA layer (Attention + Indexer) and one per HCA layer.
Three compressor variants share the same pipeline (project → ape → reshape → pool → norm → RoPE → quant → cache write) but differ in head_dim, compress_ratio, overlap, rotation, and output quantization. All shapes are shown in one table below.
| # | Operation | CSA Attention Compressor (ratio=4, overlap=True, head_dim=512) | CSA Indexer Compressor (ratio=4, overlap=True, head_dim=128) | HCA Compressor (ratio=128, overlap=False, head_dim=512) |
|---|---|---|---|---|
| — | Config | `coff=2`, `compress_ratio=4`, `rotate=False` | `coff=2`, `compress_ratio=4`, `rotate=True` | `coff=1`, `compress_ratio=128`, `rotate=False` |
| 1 | `wkv` GEMM (FP32) | `[T,dim] @ [dim,coff×hd] → [T,1024]`<br>F: `[T,4096]@[4096,1024]`<br>P: `[T,7168]@[7168,1024]` | `[T,dim] @ [dim,coff×hd] → [T,256]`<br>F: `[T,4096]@[4096,256]`<br>P: `[T,7168]@[7168,256]` | `[T,dim] @ [dim,hd] → [T,512]`<br>F: `[T,4096]@[4096,512]`<br>P: `[T,7168]@[7168,512]` |
| 2 | `wgate` GEMM (FP32) | `[T,dim] @ [dim,1024] → [T,1024]` | `[T,dim] @ [dim,256] → [T,256]` | `[T,dim] @ [dim,512] → [T,512]` |
| 3 | + `ape` (FP32) | `[T,1024]` += broadcast `[4,1024]` | `[T,256]` += broadcast `[4,256]` | `[T,512]` += broadcast `[128,512]` |
| 4 | Unflatten | `[T/4, 4, 1024]` | `[T/4, 4, 256]` | `[T/128, 128, 512]` |
| 5 | Overlap transform (FP32) | `[T/4,4,1024] → [T/4,8,512]` (4 prev first-halves + 4 curr second-halves) | `[T/4,4,256] → [T/4,8,128]` (4 prev first-halves + 4 curr second-halves) | N/A (no overlap, `coff=1`) |
| 6 | Gated pooling (FP32) | `score.softmax(dim=-2)` over 8 positions;<br>`(kv * score).sum(dim=-2) → [T/4,512]` | `score.softmax(dim=-2)` over 8 positions;<br>`(kv * score).sum(dim=-2) → [T/4,128]` | `score.softmax(dim=-2)` over 128 tokens;<br>`(kv * score).sum(dim=-2) → [T/128,512]` |
| 7 | RMSNorm (BF16) | `[T/4,512]` | `[T/4,128]` | `[T/128,512]` |
| 8 | RoPE (last 64) (BF16) | `[T/4,64]` at position `4j` | `[T/4,64]` at position `4j` | `[T/128,64]` at position `128j` |
| 9a | Hadamard rotation (BF16) | N/A (`rotate=False`) | `[T/4,128]` → Hadamard → `[T/4,128]` (full-vector, before FP4 quant) | N/A (`rotate=False`) |
| 9b | Output quantization | `act_quant` on non-RoPE dims:<br>`[T/4,448]` → FP8 e4m3 (RoPE 64 dims stay BF16) | `fp4_act_quant` on full vector:<br>`[T/4,128]` → FP4 e2m1 (no RoPE/non-RoPE split) | `act_quant` on non-RoPE dims:<br>`[T/128,448]` → FP8 e4m3 (RoPE 64 dims stay BF16) |
| 10 | Write to kv_cache | `[T/4,512]` → Attn kv_cache | `[T/4,128]` → Indexer kv_cache | `[T/128,512]` → Attn kv_cache |
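Row 6 is the only non-standard op in the pipeline: softmax the gate over the window axis, then take the gate-weighted sum of the KV over that same axis. A sketch with illustrative shapes for the CSA attention variant:

```python
import torch

kv = torch.randn(16, 8, 512)          # [T/4, 8, head_dim] after the overlap transform
gate = torch.randn(16, 8, 512)        # wgate output in the same layout
score = gate.softmax(dim=-2)          # normalize over the 8 window positions
pooled = (kv * score).sum(dim=-2)     # [T/4, 512]: one compressed KV entry per window
```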
3.1 KV Cache Shapes
| Compressor | kv_cache Shape | Entry Size (dims) | Entries at 1M Context |
|---|---|---|---|
| CSA Attention | `[max_B, max_S//4, 512]` | 512 (448 FP8 + 64 BF16) | 262,144 |
| CSA Indexer | `[max_B, max_S//4, 128]` | 128 (FP4 packed) | 262,144 |
| HCA Attention | `[max_B, max_S//128, 512]` | 512 (448 FP8 + 64 BF16) | 8,192 |
3.2 Persistent Buffer (kv_state / score_state) — Decode Only
| Compressor | Buffer Shape | sliding_window | Purpose |
|---|---|---|---|
| CSA Attention | `[max_B, 8, 1024]` | 8 (= coff × ratio = 2×4) | Double-buffer: slots 0–3 = previous window, 4–7 = current window |
| CSA Indexer | `[max_B, 8, 256]` | 8 (= coff × ratio = 2×4) | Same double-buffer, smaller head_dim |
| HCA | `[max_B, 128, 512]` | 128 (= coff × ratio = 1×128) | Rolling 128-token window for non-overlapping compression |
vLLM treats compressor state as sliding-window KV with `sliding_window = coff * compress_ratio`, registered under the same hybrid KV cache manager. Shift operation after each compression: `kv_state[:, :ratio] = kv_state[:, ratio:]` (current becomes previous).
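A decode-step sketch of that double-buffer mechanic for the CSA attention compressor (function and variable names are assumptions; the pooling itself is elided):

```python
import torch

ratio, coff, width = 4, 2, 1024
kv_state = torch.zeros(1, coff * ratio, width)   # [max_B, 8, 1024]

def on_decode_token(kv_state, kv_new, pos):
    kv_state[:, ratio + pos % ratio] = kv_new    # fill current-window slots 4-7
    if pos % ratio == ratio - 1:                 # window full: compress, then shift
        # ... gated pooling over all 8 slots -> write one entry to kv_cache ...
        kv_state[:, :ratio] = kv_state[:, ratio:]  # current becomes previous
    return kv_state
```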
4. MoE Block
| # | Operation | Flash Shapes | Pro Shapes | Comments |
|---|---|---|---|---|
| 1 | Gate scoring (FP8) | `[B,T,4096] @ [4096,256] → [B,T,256]` | `[B,T,7168] @ [7168,384] → [B,T,384]` | NV: FlashInfer fused dispatch<br>AMD: ✅ aiter `select_experts` |
| 2 | `sqrtsoftplus` (FP32) | `[B,T,256]` | `[B,T,384]` | NV: Fused in dispatch<br>AMD: ✅ aiter |
| 3 | `e_score_correction_bias` add (FP32) | `[256]` | `[384]` | NV: Fused in dispatch<br>AMD: ✅ aiter |
| 4 | Top-k select (FP32) | `[1,256]` → top-6 | `[1,384]` → top-6 | NV: Fused in dispatch<br>AMD: ✅ aiter |
| 5 | Hash routing (layers 0–2) (int32) | `tid2eid[input_ids]` lookup | `tid2eid[input_ids]` lookup | NV: Embedding gather<br>AMD: ✅ PyTorch gather |
| 6 | `noaux_tc` (—) | Load-balancing constraint | Load-balancing constraint | NV: Fused in dispatch<br>AMD: ✅ aiter |
| 7 | Expert dispatch (FP4) | Gather 6 experts' weights | Gather 6 experts' weights | NV: FlashInfer<br>AMD: ✅ aiter FusedMoE |
| 8 | FP4 GEMM (w1/w3) (FP8 act × FP4 weight) | `[6,4096] @ [4096,2048]` (per expert) | `[6,7168] @ [7168,3072]` (per expert) | NV: Custom FP4 backend<br>AMD: ⚠️ aiter FusedMoE (triton/CK path) |
| 9 | FP4 GEMM (w2) (FP8 act × FP4 weight) | `[6,2048] @ [2048,4096]` (per expert) | `[6,3072] @ [3072,7168]` (per expert) | NV: Custom FP4 backend<br>AMD: ⚠️ aiter FusedMoE (triton/CK path) |
| 10 | SwiGLU + clamp (BF16) | `silu(clamp(gate, max=10)) * clamp(up, -10, 10)` | `silu(clamp(gate, max=10)) * clamp(up, -10, 10)` | NV: Fused in expert FFN<br>AMD: ⚠️ triton post-kernel clamp (gfx950 workaround)<br>Note: `silu(x) = x * sigmoid(x)`; gate clamp is one-sided (max only) |
| 11 | Shared expert (FP4) | Always-active expert | Always-active expert | NV: Same as routed<br>AMD: ✅ aiter |
| 12 | Accumulator reduce (BF16) | Sum 6 expert outputs + shared expert | Sum 6 expert outputs + shared expert | NV: Fused in dispatch<br>AMD: ✅ aiter |
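Rows 2 and 10 as standalone references; both are one-liners, shown here in PyTorch (`swiglu_limit=10.0` comes from the config table; function names are illustrative):

```python
import torch
import torch.nn.functional as F

def sqrtsoftplus(x):                       # row 2: scoring_func
    return F.softplus(x).sqrt()

def swiglu_clamped(gate, up, limit=10.0):  # row 10: swiglu_limit = 10.0
    gate = gate.clamp(max=limit)           # one-sided clamp on the gate
    up = up.clamp(min=-limit, max=limit)   # two-sided clamp on the up projection
    return F.silu(gate) * up               # silu(x) = x * sigmoid(x)
```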
4.1 FP4 GEMM Deep Dive (P0)
HF reference:
```python
# C[M,N] = A_fp8[M,K] @ B_fp4[N,K]^T
# Act scale:    per-1×128 on K-dim, E8M0/FP32
# Weight scale: per-1×32 on K-dim, E8M0
# B stored as [N, K//2] — 2 FP4 values packed per byte along K
```
vLLM: Ships with native FP4 MoE weights. Uses custom FP4 backend (likely upcast-at-load).
AMD current: aiter FusedMoE works via:
- gfx942 (MI300X): ASM/CK path
- gfx950 (MI355X): Triton path with `CDNA4MXScaleLayout`
- `swiglu_limit` applied in a triton post-kernel (not in the fused SwiGLU — 9× amplitude-loss bug)
AMD gap: Native FP4 tensor-core support on MI3xx is not confirmed. Options:
- Upcast-at-load (like TileLang): load FP4 sub-blocks → upcast to FP8 in shared mem → run CK `GemmMXFP8`. Recommended for bring-up; see the sketch below.
- Pre-upcast weights to FP8: ~2× weight memory overhead.
- Native FP4: verify MI350+ sub-byte support first.
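A host-side sketch of the e2m1 decode that upcast-at-load would perform per sub-block, in lookup-table form (nibble order within the packed byte and the scale layout are assumptions; the real path would do this in shared memory inside the GEMM, emitting FP8 rather than float):

```python
import torch

# The 8 non-negative e2m1 values; codes 8-15 are their negatives (sign bit set)
E2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])

def dequant_fp4_block(packed, scale):
    """packed: [N, K//2] uint8 (2 FP4 per byte); scale: [N, K//32] float (E8M0 values)."""
    lo = (packed & 0x0F).long()                        # first nibble along K (assumed order)
    hi = (packed >> 4).long()
    codes = torch.stack([lo, hi], dim=-1).flatten(-2)  # [N, K]
    w = E2M1[codes]
    return w * scale.repeat_interleave(32, dim=-1)     # apply per-1×32 block scale
```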
Weight shapes (V4-Flash):

```
w13_weight:       (257, 4096, 2048) torch.float4_e2m1fn_x2  # 256 routed + 1 shared
w13_weight_scale: (257, 4096, 128)  torch.float8_e8m0fnu    # per-1×32 along K
w2_weight:        (257, 4096, 1024) torch.float4_e2m1fn_x2
w2_weight_scale:  (257, 4096, 64)   torch.float8_e8m0fnu
```

Weight shapes (V4-Pro):

```
w13_weight:       (385, 6144, 3584) torch.float4_e2m1fn_x2  # 384 routed + 1 shared
w13_weight_scale: (385, 6144, 224)  torch.float8_e8m0fnu    # per-1×32 along K
w2_weight:        (385, 7168, 1536) torch.float4_e2m1fn_x2
w2_weight_scale:  (385, 7168, 96)   torch.float8_e8m0fnu
```
5. mHC (Manifold-Constrained Hyper-Connections) Block
| # | Operation | Flash Shapes | Pro Shapes | Comments |
|---|---|---|---|---|
| 1 | `hc_pre` inline norm (FP32) | `rsqrt(mean(x²) + eps)` on `[B,T,hc×dim]` = `[B,T,16384]` | `[B,T,28672]` | NV: Standard<br>AMD: ⚠️ PyTorch inline (not a separate RMSNorm module — operates on the flattened `hc_mult × dim` vector) |
| 2 | `hc_fn` Linear (FP32) | `[B,T,16384] @ [16384,24] → [B,T,24]` | `[B,T,28672] @ [28672,24] → [B,T,24]` | NV: Standard<br>AMD: ⚠️ PyTorch (input is `[B,T,hc_mult,dim]` flattened to `[B,T,hc_mult×dim]`; `hc_fn` is `nn.Parameter([24, hc_mult×dim])`, not a CK GEMM candidate due to tiny output dim) |
| 3 | `hc_split_sinkhorn` (FP32) | 24-element mixes → `pre[4]`, `post[4]`, `comb[4,4]` | 24-element mixes → `pre[4]`, `post[4]`, `comb[4,4]` | NV: Torch fallback<br>AMD: ⚠️ aiter `mhc_pre` (needs #2916) |
| 4 | 20-iter Sinkhorn (FP32) | Row/col normalize 4×4 matrix, 20× | Row/col normalize 4×4 matrix, 20× | NV: Inside `mhc_pre`<br>AMD: ✅ aiter `mhc_pre` (HIP kernel) |
| 5 | `hc_post` combine (BF16) | `post * x + comb * residual` (4 copies) | `post * x + comb * residual` (4 copies) | NV: Standard<br>AMD: ✅ aiter `mhc_post` |
5.1 Sinkhorn Bottleneck
Cost per forward: 20 iterations × 43 layers (Flash) = 860 loops.
- PyTorch fallback: ~100–500µs per call → 86–430ms total
- aiter `mhc_pre` kernel: ~1–5µs per call → 0.9–4.3ms total

aiter's `mhc_pre` is already kernelized. The main blocker was a device-mismatch bug (#2916, now fixed). No new kernel is needed here — just the bugfix merge.
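For reference, the 4×4 Sinkhorn normalization itself is tiny; the cost is purely the 860 serialized iterations. A sketch (batch shape and epsilon are illustrative):

```python
import torch

def sinkhorn(comb, n_iters=20, eps=1e-9):
    """comb: [..., 4, 4] positive mixing weights -> approximately doubly stochastic."""
    for _ in range(n_iters):
        comb = comb / (comb.sum(dim=-1, keepdim=True) + eps)  # row normalize
        comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)  # column normalize
    return comb
```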
5.2 Fused Block Kernel (Future Optimization)
vLLM has not yet fused the full block, but it's the next logical step:
```cpp
__global__ void block_forward(...) {
    // 1. hc_pre: Sinkhorn → 1 copy
    // 2. RMSNorm
    // 3. Attention or MoE
    // 4. hc_post: combine with 4-copy residual
}
```
This would eliminate 4 kernel launches per layer.
6. Quantization & Utility Kernels
| # | Kernel | Input | Output | Block Size | Comments |
|---|---|---|---|---|---|
| 1 | `act_quant` | BF16 | FP8 e4m3 | per-1×128 | NV: Fused in pre-attn<br>AMD: ✅ aiter quant |
| 2 | `fp4_act_quant` | BF16 | FP4 e2m1 | per-1×32 | NV: Fused in indexer stream<br>AMD: ❌ Torch fallback |
| 3 | `rotate_activation` (Hadamard) | BF16 | BF16 | — | NV: Fused in indexer stream<br>AMD: ❌ Torch fallback |
| 4 | `dequant_fp4_e2m1` | uint8 packed + E8M0 scale | BF16 | per-1×32 | NV: ATOM utility<br>AMD: ✅ ATOM `dequant_fp4_e2m1()` |
| 5 | Inverse RoPE + FP8 quant | BF16 | FP8 | — | NV: Fused (~2–3× speedup)<br>AMD: ⚠️ Separate ops |
| 6 | YaRN RoPE scaling | — | — | — | NV: For 1M context<br>AMD: ✅ Reuse existing YaRN |
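What the missing `fp4_act_quant` kernel has to do, as a sketch: per-1×32 amax scaling with a power-of-two E8M0 scale, then round to the nearest e2m1 code (packing two codes per byte and the exact rounding mode are left out; all names are illustrative):

```python
import torch

E2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fp4_act_quant(x):
    """x: [..., K] with K % 32 == 0 -> (codes 0-15 as uint8, per-block scales)."""
    blocks = x.float().unflatten(-1, (-1, 32))              # [..., K/32, 32]
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=2**-126)
    scale = torch.exp2(torch.ceil(torch.log2(amax / 6.0)))  # E8M0: power of two, covers amax
    scaled = blocks / scale
    codes = (scaled.abs().unsqueeze(-1) - E2M1).abs().argmin(dim=-1)
    codes = codes + 8 * (scaled < 0)                        # set the sign bit
    return codes.flatten(-2).to(torch.uint8), scale.squeeze(-1)
```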
7. vLLM-Proven Fusions (Not Yet in aiter)
vLLM has deployed these fusions with measured speedups. They are targets for aiter/CK on AMD:
| Fusion | Stages | vLLM Speedup | AMD Status |
|---|---|---|---|
| Compressor fusion | Compressor + RMSNorm + RoPE + cache insert | ~1.4–3× | ❌ Not in aiter |
| Inverse RoPE + FP8 quant | De-rotate + quant before `o_lora` | ~2–3× | ❌ Not in aiter |
| Q/KV pre-attn fusion | Q norm + KV RoPE + K insert (warp-dispatched) | ~10–20× | ❌ Not in aiter |
| Indexer multi-stream | Indexer pipeline overlapped with main KV compression + SWA insert | — | ❌ Not in aiter |
8. AMD Gap Summary & Priority Matrix
Missing Kernels (Need New Implementation)
| Kernel | Current Fallback | Perf Impact | Effort | Owner Suggestion |
|---|---|---|---|---|
| `sparse_attn` | Torch index-gather + online softmax | 🔴 Major — defeats V4's purpose at 1M context | 3–4 weeks | AITER team |
| Compressor (CSA + Indexer) | PyTorch gated pooling + overlap | 🟡 Medium — bandwidth bound, but layers add up | 2–3 weeks | AMD kernel + aiter |
| `fp4_act_quant` | PyTorch rounding + pack | 🟡 Medium — on indexer critical path | 1 week | AMD kernel |
| `rotate_activation` (Hadamard) | PyTorch matmul with Hadamard matrix | 🟢 Low — not on hottest path | 3–5 days | AMD kernel |
| FP4 upcast-at-load GEMM | FusedMoE triton/CK (pre-upcast) | 🟡 Medium — 2× weight bandwidth vs native | 2–3 weeks | AMD kernel |
Working but Need Hardening
| Component | Status | Notes |
|---|---|---|
| CK `GemmMXFP8` | ✅ Working | FP8 act × FP8 weight with E8M0 scales |
| aiter `FusedMoE` | ✅ Working | FP4 experts via triton (gfx950) or CK (gfx942) |
| aiter `mhc_pre`/`mhc_post` | ✅ Working | Needs #2916 bugfix merged |
| aiter RMSNorm | ✅ Working | Standard path |
| ATOM model | ✅ Working | Single-sequence eager mode, 512-token coherent gen |
9. Kernel Implementation Inventory
Implementation approach legend:
- ASM — Pure GPU assembly (hand-optimized microkernels) - contact: @niels-zhang
- flyDSL — AMD DSL for declarative patterns (elementwise, reduction, transforms) - contact: @coderfeli
- Triton/Gluon — Triton-based kernels (aiter’s fused-op path) - contact: @vgokhale
- HIP/C++ — Raw HIP kernel, Opus C++ kernel, or Composable Kernel (CK) - contact: @valarLip @carlushuang
(1) Attention
- `sparse_attn` — ❌ Torch fallback (P0 gap)
  - Index-gather + online softmax + `attn_sink` bias over variable `topk` gathered KV entries
  - NOT FlashAttention — sparse, query-dependent gather indices
  - Must handle `-1` invalid indices, `attn_sink` as pseudo-token in softmax denominator
  - See §1.1 for full spec
  - ☑️ ASM ☑️ flyDSL ☑️ Triton/Gluon ☑️ HIP/C++
(2) Compression
- Indexer/Compressor kernel
  - CSA Flash: `Q[T,64,512] × KV[640,512]`; CSA Pro: `Q[T,128,512] × KV[1152,512]`
  - HCA: `Q[T,H,512] × KV[128+T/128, 512]` (grows with context)
  - ⬜ ASM ☑️ flyDSL ☑️ Triton/Gluon ☑️ HIP/C++
- Indexer/Compressor misc (linear, elem-wise, fusions) — ❌ Torch fallback
  - CSA (ratio=4, overlap=True): `wkv`/`wgate` FP32 GEMM → `+ape` broadcast → overlap transform `[T/4,4,1024]→[T/4,8,512]` → softmax gated pooling → RMSNorm → RoPE → `act_quant` FP8 → cache write (§3.1)
  - HCA (ratio=128, overlap=False): same but no overlap transform; pool over 128 tokens (§3.3)
  - Indexer compressor (head_dim=128, rotate=True): same pipeline but Hadamard → FP4 quant instead of split FP8 (§3.2)
  - vLLM fuses the entire pipeline into one kernel (~1.4–3× speedup)
  - ⬜ ASM ☑️ flyDSL ☑️ Triton/Gluon ☑️ HIP/C++
(3) KV Cache Codesign
DeepSeek-V4 uses a hybrid KV cache with three storage paths:
| Path | Compression | Token Sparsity | Cache Shape (Flash) |
|---|---|---|---|
| CSA Attention | 4:1 (overlap) | `topk = 512–640` | `[max_B, max_S//4, 512]` |
| CSA Indexer | 4:1 (overlap) | `topk = 512` | `[max_B, max_S//4, 128]` |
| HCA | 128:1 (no overlap) | `topk = 8192` (≈all) | `[max_B, max_S//128, 512]` |
Buffer mechanics:
- `kv_state` / `score_state` double-buffer for CSA: `[max_B, 8, 256]` or `[max_B, 8, 1024]`
- HCA rolling window: `[max_B, 128, 512]`
- vLLM integrates both under a single hybrid KV cache manager with `sliding_window = coff * compress_ratio`
- ⬜ ASM ⬜ flyDSL ⬜ Triton/Gluon ☑️ HIP/C++
Source: §3.3–3.4 for detailed shape traces and vLLM integration
(4) MoE
- A8W4 (FP8 act × FP4 weight) expert GEMM — includes SwiGLU + clamp — ⚠️ aiter FusedMoE
  - Expert `w1`/`w3`/`w2` GEMM + fused SwiGLU + clamp (§4 rows 8–10)
  - Flash: `[6,4096]@[4096,2048]` / `[6,2048]@[2048,4096]` per expert
  - Pro: `[6,7168]@[7168,3072]` / `[6,3072]@[3072,7168]` per expert
  - ⬜ ASM ☑️ flyDSL ☑️ Triton/Gluon ⬜ HIP/C++
- `sqrtsoftplus` scoring — ✅ aiter
  - `sqrt(softplus(x))` on `[B,T,256/384]` (§4 row 2)
  - ⬜ ASM ⬜ flyDSL ☑️ Triton/Gluon ☑️ HIP/C++
- Hash routing (layers 0–2) — ✅ PyTorch gather
  - `tid2eid[input_ids]` lookup (§4 row 5)
  - ⬜ ASM ⬜ flyDSL ☑️ Triton/Gluon ☑️ HIP/C++
- MegaMoE (Expert-Parallel All-to-All + FP4 GEMM) — ❌ Not available on AMD (mid-term target)
  - vLLM's `DeepseekV4MegaMoEExperts`: fused all-to-all dispatch + `fp8_fp4_mega_moe()` via DeepGEMM, symmetric NCCL buffer management
  - Combines inter-GPU expert routing (NCCL-based) with per-expert FP8×FP4 GEMM in a single fused op
  - NV: DeepGEMM `fp8_fp4_mega_moe()` + `get_symm_buffer_for_mega_moe()` (requires SM100/Blackwell)
  - AMD dependency: mori's NCCL-GIN equivalent for the symmetric communication buffer / all-to-all overlap
  - AMD GEMM dependency: aiter/CK A8W4 expert GEMM (see above)
  - ⬜ ASM ☑️ flyDSL ☑️ Triton/Gluon ☑️ HIP/C++
(5) GEMM
- A8W8 (FP8 act × FP8 weight) — ✅ CK `GemmMXFP8`
  - Attention: `wq_a`, `wq_b`, `wkv`, `wo_b` (§1 rows 1,3,6,14)
  - Indexer: `wq_b` (§2 row 2)
  - MoE: gate scoring projection (§4 row 1)
  - ⬜ ASM ☑️ flyDSL ☑️ Triton/Gluon ☑️ HIP/C++
- Grouped matmul (BF16) — ⚠️ PyTorch einsum
  - Attention: `wo_a` grouped LoRA — `einsum("bsgd,grd->bsgr")` (§1 row 13); see the reshape/einsum sketch after the §1 table
  - `[T,G,4096] × [G,1024,4096] → [T,G,1024]`, G=8 (Flash) / G=16 (Pro)
  - ⬜ ASM ☑️ flyDSL ☑️ Triton/Gluon ☑️ HIP/C++
- Small BF16 matmul — ⚠️ PyTorch
  - Indexer: `weights_proj` `[T,dim]@[dim,64]` (§2 row 6) — tiny output dim
  - mHC: `hc_fn` `[T,hc×dim]@[hc×dim,24]` (§5 row 2) — tiny output dim
  - ⬜ ASM ☑️ flyDSL ☑️ Triton/Gluon ☑️ HIP/C++
(6) Norm

- RMSNorm — ✅ aiter
  - Q norm `[T,1024/1536]`, KV norm `[T,512]`, input/output norms `[T,dim]` (§1 rows 2,7,15)
  - Compressor RMSNorm `[T/4,512]` (§3.1 row 6)
- Per-head Q Norm (inline) — ⚠️ PyTorch fallback
  - `q *= rsqrt(mean(q², dim=-1) + eps)` on `[T,H,512]` (§1 row 4)
- mHC inline norm (FP32) — ⚠️ PyTorch
  - `rsqrt(mean(x²) + eps)` on flattened `[T, hc×dim]` = `[T,16384/28672]` (§5 row 1); see the sketch below
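The two inline variants share one RMS pattern at different shapes; a sketch (the eps value is an assumption):

```python
import torch

def rms_normalize(x, eps=1e-6):           # inline form: no learned weight
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

q = rms_normalize(torch.randn(4096, 64, 512))   # per-head Q norm, [T,H,512]
h = rms_normalize(torch.randn(4096, 16384))     # mHC inline norm, [T, hc×dim]
```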
(7) mHC

- `hc_split_sinkhorn` — ✅ aiter `mhc_pre` (needs #2916 bugfix merge)
- `hc_post` combine — ✅ aiter `mhc_post`
  - `post * x + comb * residual` with 4 HC copies
(8) Quantization & Transform
- `act_quant` (BF16→FP8 e4m3, per-1×128) — ✅ aiter
  - Attention: KV non-RoPE dims `[T,448]` (§1 row 9)
  - Compressor: compressed KV non-RoPE dims `[T/4,448]` (§3.1 row 8)
  - ⬜ ASM ⬜ flyDSL ☑️ Triton/Gluon ☑️ HIP/C++
- `fp4_act_quant` (BF16→FP4 e2m1, per-1×32) — ❌ Torch fallback (see the sketch after §6)
  - Indexer: Q after Hadamard `[T,64,128]` (§2 row 5)
  - Indexer compressor: compressed KV `[T/4,128]` (§3.2)
  - ⬜ ASM ⬜ flyDSL ☑️ Triton/Gluon ☑️ HIP/C++
- `rotate_activation` (Hadamard transform) — ❌ Torch fallback (see the sketch after this list)
  - Indexer: Q before FP4 quant `[T,64,128]` (§2 row 4)
  - Indexer compressor: KV before FP4 quant `[T/4,128]` (§3.2)
  - Uses `fast_hadamard_transform` library, scale = `dim^{-0.5}`
  - ⬜ ASM ⬜ flyDSL ☑️ Triton/Gluon ☑️ HIP/C++
- Inverse RoPE + FP8 quant (fused) — ⚠️ Separate ops on AMD
  - Attention: de-rotate `o[..., -64:]` then FP8-quant before `wo_a` (§1 row 12)
  - vLLM fuses these (~2–3× speedup)
  - ⬜ ASM ⬜ flyDSL ☑️ Triton/Gluon ☑️ HIP/C++
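The Hadamard rotation, written out the way a torch fallback would do it (Sylvester construction with the `dim^{-0.5}` scale from above; the `fast_hadamard_transform` library computes the same product in O(d log d) without materializing the matrix):

```python
import torch

def hadamard_matrix(d):                    # d must be a power of two
    H = torch.ones(1, 1)
    while H.shape[0] < d:
        H = torch.cat([torch.cat([H, H], dim=1),
                       torch.cat([H, -H], dim=1)], dim=0)
    return H

def rotate_activation(x):                  # x: [..., d]; orthonormal rotation
    d = x.shape[-1]
    return (x @ hadamard_matrix(d).to(x.dtype)) * d**-0.5
```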
10. Reference Links
| Resource | URL | Relevance |
|---|---|---|
| HF V4-Flash | https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash | Config, weights, TileLang kernels |
| HF V4-Pro | https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro | Pro config (384 experts, 7168 hidden) |
| vLLM V4 Blog | https://blog.vllm.ai/2026/04/24/deepseek-v4.html | Implementation details, fusions, block design |
| ATOM PR#650 | https://github.com/ROCm/ATOM/pull/650 | AMD eager-mode model (13 commits, single-seq) |
| AITER PR#2916 | https://github.com/ROCm/aiter/pull/2916 | mhc_pre device bugfix |
| ATOM Issue #664 | https://github.com/ROCm/ATOM/issues/664 | Full tensor shape traces for attention pipeline |
Compiled from: HF reference (inference/kernel.py, inference/model.py), vLLM blog (Apr 24 2026), ATOM PR#650, AITER PR#2916, and AMD kernel team analysis.