[None][feat] Fuse GDN elementwise ops and split/transpose kernels#12966
35a1b44 to abbd450
📝 Walkthrough
The PR introduces Triton kernel optimizations for Mamba/GDN Q/K/V materialization and gating operations. New kernels
Changes
Sequence Diagram(s)
sequenceDiagram
participant User as Application
participant Wrapper as Python Wrapper<br/>(transpose_and_split_qkv)
participant Kernel as Triton Kernel<br/>(_transpose_and_split_qkv_kernel)
participant Memory as GPU Memory
User->>Wrapper: Call transpose_and_split_qkv(prefill_t, decode, ...)
Wrapper->>Memory: Allocate q_flat, k_flat, v_flat<br/>(size: total_seq × dim)
Wrapper->>Kernel: Launch 3D grid<br/>(seq_blocks, dim_blocks, qkv_selector)
Kernel->>Memory: Read prefill from [D, T_p] layout
Kernel->>Memory: Read decode from [T_d, D] layout
Kernel->>Memory: Write contiguous Q/K/V<br/>to output buffers
Kernel-->>Wrapper: Kernel complete
Wrapper->>Memory: Reshape outputs to<br/>[1, total_seq, heads, head_dim]
Wrapper-->>User: Return (query, key, value)
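The data movement in the diagram above can be modelled in plain Python. This is an illustrative sketch only, not the Triton kernel: it uses nested lists instead of tensors, the name `ref_transpose_and_split_qkv` is hypothetical, and it assumes prefill tokens are placed before decode tokens in the output, which the diagram does not state.

```python
def ref_transpose_and_split_qkv(prefill_t, decode, q_dim, k_dim, v_dim):
    """Pure-Python model of a fused transpose + split (hypothetical helper).

    prefill_t: list of D rows with T_p entries each   (layout [D, T_p])
    decode:    list of T_d rows with D entries each   (layout [T_d, D])
    Returns (q, k, v); each is a list of T_p + T_d rows.
    Assumption: prefill tokens come first, then decode tokens.
    """
    dim = q_dim + k_dim + v_dim
    if len(prefill_t) != dim:
        raise ValueError(f"expected {dim} prefill rows, got {len(prefill_t)}")
    num_prefill = len(prefill_t[0]) if dim else 0

    # Transpose the prefill part from [D, T_p] to [T_p, D].
    rows = [[prefill_t[d][t] for d in range(dim)] for t in range(num_prefill)]

    # Decode rows are already [T_d, D]; append them unchanged.
    for row in decode:
        if len(row) != dim:
            raise ValueError(f"expected decode width {dim}, got {len(row)}")
        rows.append(list(row))

    # Slice each token row into its Q, K and V column ranges.
    q = [r[:q_dim] for r in rows]
    k = [r[q_dim:q_dim + k_dim] for r in rows]
    v = [r[q_dim + k_dim:] for r in rows]
    return q, k, v
```

A model like this is mainly useful as a readable statement of the expected layout when checking kernel output on small shapes.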
sequenceDiagram
participant User as Application
participant Wrapper as Python Wrapper<br/>(split_qkv_contiguous)
participant Kernel as Triton Kernel<br/>(_split_qkv_contiguous_kernel)
participant Memory as GPU Memory
User->>Wrapper: Call split_qkv_contiguous(mixed_qkv, ...)
Wrapper->>Memory: Allocate q_flat, k_flat, v_flat<br/>(contiguous)
Wrapper->>Kernel: Launch 3D grid
Kernel->>Memory: Read mixed_qkv [T, qkv_dim]<br/>(any stride)
Kernel->>Memory: Write contiguous Q, K, V
Kernel-->>Wrapper: Kernel complete
Wrapper->>Memory: Reshape to [1, T, heads, head_dim]
Wrapper-->>User: Return (query, key, value)
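The "any stride" read in the diagram above can be illustrated without a GPU. The sketch below is a hypothetical pure-Python model (`ref_split_qkv_strided` is not a function from the PR): it addresses a flat buffer through explicit sequence and dimension strides, so the same code handles a row-major `[T, qkv_dim]` source and a transposed view of `[qkv_dim, T]` storage.

```python
def ref_split_qkv_strided(buf, seq_len, stride_seq, stride_dim,
                          q_dim, k_dim, v_dim):
    """Split a strided [seq_len, q_dim + k_dim + v_dim] view of `buf`.

    buf is a flat list; element (t, c) lives at
    buf[t * stride_seq + c * stride_dim]. Outputs are fresh nested lists,
    i.e. "contiguous" regardless of how the source was laid out.
    """
    def gather(col_start, width):
        # Copy `width` columns starting at `col_start` for every token.
        return [
            [buf[t * stride_seq + (col_start + c) * stride_dim]
             for c in range(width)]
            for t in range(seq_len)
        ]

    q = gather(0, q_dim)
    k = gather(q_dim, k_dim)
    v = gather(q_dim + k_dim, v_dim)
    return q, k, v


# Row-major [T=2, D=3] storage: strides (3, 1).
row_major = [0, 1, 2, 3, 4, 5]
# The same logical matrix stored as [D=3, T=2] and viewed transposed:
# strides (1, 2).
transposed = [0, 3, 1, 4, 2, 5]

out_a = ref_split_qkv_strided(row_major, 2, 3, 1, 1, 1, 1)
out_b = ref_split_qkv_strided(transposed, 2, 1, 2, 1, 1, 1)
```

Both calls describe the same logical matrix, so `out_a` and `out_b` should agree; only the strides differ.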
sequenceDiagram
participant User as Application
participant Wrapper as Python Wrapper<br/>(fused_gdn_gating_with_sigmoid)
participant Kernel as Triton Kernel<br/>(fused_gdn_gating_with_sigmoid_kernel)
participant Memory as GPU Memory
User->>Wrapper: Call fused_gdn_gating_with_sigmoid<br/>(A_log, a, dt_bias, b, ...)
Wrapper->>Memory: Allocate g and beta_out buffers
Wrapper->>Kernel: Launch kernel
Kernel->>Kernel: Compute -exp(A_log) × softplus(a + dt_bias)<br/>→ g
Kernel->>Kernel: Compute sigmoid(b)<br/>→ beta_out
Kernel->>Memory: Write g and beta_out
Kernel-->>Wrapper: Kernel complete
Wrapper-->>User: Return (g, beta_out)
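The arithmetic in the diagram above, `g = -exp(A_log) × softplus(a + dt_bias)` and `beta_out = sigmoid(b)`, can be written out as a small reference. This is a hedged sketch using only the standard library: the function names are hypothetical, the shapes (`a` and `b` as per-token rows, `A_log` and `dt_bias` per head) are an assumption, and the real kernel may use a thresholded softplus rather than the numerically stable form shown here.

```python
import math


def softplus(x):
    # Stable form of log(1 + exp(x)); avoids overflow for large |x|.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x):
    # Branch on sign so exp() never receives a large positive argument.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def ref_gdn_gating_with_sigmoid(A_log, a, dt_bias, b):
    """Reference for the fused gating math (hypothetical helper).

    A_log, dt_bias: one value per head.
    a, b: one row per token, one value per head (assumed layout).
    Returns (g, beta_out) with the same row/column layout as a and b.
    """
    g = [
        [-math.exp(A_log[h]) * softplus(row[h] + dt_bias[h])
         for h in range(len(row))]
        for row in a
    ]
    beta_out = [[sigmoid(x) for x in row] for row in b]
    return g, beta_out
```

Computing both outputs in one pass is what lets the fused kernel skip the separate elementwise sigmoid launch mentioned in the PR description.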
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
f518d38 to 12c72c4
Actionable comments posted: 1
🧹 Nitpick comments (2)
tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py (1)
269-312: Fail fast on inconsistent QKV widths before launching these kernels.
Both wrappers assume the source width is exactly `q_dim + k_dim + v_dim` and that the head-count arguments multiply back to those dims. If a caller ever passes mismatched values, this becomes raw Triton pointer arithmetic instead of a clean Python exception.
🛡️ Suggested guardrails
 def transpose_and_split_qkv(
     prefill_t: torch.Tensor,
     decode: torch.Tensor,
@@
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@
     num_cols, num_prefill = prefill_t.shape
+    expected_cols = q_dim + k_dim + v_dim
+    if decode.ndim != 2 or num_cols != expected_cols or decode.shape[1] != expected_cols:
+        raise ValueError(
+            f"Expected prefill/decode widths to equal {expected_cols}, got "
+            f"{num_cols} and {decode.shape[1]}"
+        )
+    if q_dim != num_q_heads * head_k_dim or k_dim != num_q_heads * head_k_dim:
+        raise ValueError("Q/K dims must match num_q_heads * head_k_dim")
+    if v_dim != num_v_heads * head_v_dim:
+        raise ValueError("V dim must match num_v_heads * head_v_dim")
     num_decode = decode.shape[0]
     total_seq = num_prefill + num_decode
@@
 def split_qkv_contiguous(
     mixed_qkv: torch.Tensor,
@@
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@
-    seq_len = mixed_qkv.shape[0]
+    seq_len, num_cols = mixed_qkv.shape
+    expected_cols = q_dim + k_dim + v_dim
+    if num_cols != expected_cols:
+        raise ValueError(f"Expected mixed_qkv width {expected_cols}, got {num_cols}")
+    if q_dim != num_q_heads * head_k_dim or k_dim != num_q_heads * head_k_dim:
+        raise ValueError("Q/K dims must match num_q_heads * head_k_dim")
+    if v_dim != num_v_heads * head_v_dim:
+        raise ValueError("V dim must match num_v_heads * head_v_dim")
     src_stride_seq, src_stride_dim = mixed_qkv.stride()
Also applies to: 364-417
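Since the same guardrails are suggested for both wrappers, one option would be to factor them into a shared check. The sketch below is hypothetical (`check_qkv_dims` is not a function in the PR) and operates on plain integers so it can run without torch. It mirrors the reviewer's conditions, including the assumption that K uses the same head count as Q.

```python
def check_qkv_dims(num_cols, q_dim, k_dim, v_dim,
                   num_q_heads, head_k_dim, num_v_heads, head_v_dim):
    """Raise ValueError if the QKV layout arguments are inconsistent.

    Hypothetical shared helper; intended to run before any kernel launch
    so a mismatch surfaces as a Python exception.
    """
    expected_cols = q_dim + k_dim + v_dim
    if num_cols != expected_cols:
        raise ValueError(
            f"Expected source width {expected_cols}, got {num_cols}"
        )
    # Assumption taken from the suggestion above: K shares Q's head count.
    if q_dim != num_q_heads * head_k_dim or k_dim != num_q_heads * head_k_dim:
        raise ValueError("Q/K dims must match num_q_heads * head_k_dim")
    if v_dim != num_v_heads * head_v_dim:
        raise ValueError("V dim must match num_v_heads * head_v_dim")
```

Each wrapper could then call the helper with the width it reads from its own input, for example `prefill_t.shape[0]` in one case and `mixed_qkv.shape[1]` in the other.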
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py` around lines 269 - 312, Add preflight checks in transpose_and_split_qkv before launching _transpose_and_split_qkv_kernel: verify that num_cols == q_dim + k_dim + v_dim and that q_dim == num_q_heads * head_k_dim and v_dim == num_v_heads * head_v_dim (and similarly k_dim if applicable), and raise a clear ValueError if any mismatch occurs. Do the same guard checks in the analogous function at lines 364-417 so callers get a Python exception instead of invalid Triton pointer arithmetic. Include the unique symbols transpose_and_split_qkv, _transpose_and_split_qkv_kernel, q_dim, k_dim, v_dim, num_q_heads, head_k_dim, num_v_heads, head_v_dim when locating where to add these checks.
tests/unittest/_torch/modules/mamba/test_gdn_kernel_optimizations.py (1)
91-117: Add a non-contiguous source case for `split_qkv_contiguous()`.
This test only covers a contiguous `[T, D]` input, but `tensorrt_llm/_torch/modules/mamba/gdn_mixer.py` Lines 686-688 call the helper with `mixed_qkv_t.transpose(0, 1)`. The stride-sensitive path this kernel was added for is therefore untested right now.
🧪 Suggested test expansion
 @skip_no_cuda
 @pytest.mark.parametrize("seq_len", [1, 32, 128, 1024])
+@pytest.mark.parametrize("transpose_input", [False, True])
 @pytest.mark.parametrize(
     "q_dim,k_dim,v_dim,num_q_heads,head_k_dim,num_v_heads,head_v_dim",
@@
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
-def test_split_qkv_contiguous(seq_len, q_dim, k_dim, v_dim,
-        num_q_heads, head_k_dim, num_v_heads, head_v_dim, dtype):
+def test_split_qkv_contiguous(
+    seq_len,
+    transpose_input,
+    q_dim,
+    k_dim,
+    v_dim,
+    num_q_heads,
+    head_k_dim,
+    num_v_heads,
+    head_v_dim,
+    dtype,
+):
@@
-    mixed_qkv = torch.randn(seq_len, total_dim, dtype=dtype, device=device)
+    if transpose_input:
+        mixed_qkv = torch.randn(total_dim, seq_len, dtype=dtype, device=device).transpose(0, 1)
+        assert not mixed_qkv.is_contiguous()
+    else:
+        mixed_qkv = torch.randn(seq_len, total_dim, dtype=dtype, device=device)
@@
     assert q_out.shape == (1, seq_len, num_q_heads, head_k_dim)
     assert k_out.shape == (1, seq_len, num_q_heads, head_k_dim)
     assert v_out.shape == (1, seq_len, num_v_heads, head_v_dim)
+    assert q_out.is_contiguous()
+    assert k_out.is_contiguous()
+    assert v_out.is_contiguous()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/modules/mamba/test_gdn_kernel_optimizations.py` around lines 91 - 117, Update test_split_qkv_contiguous to also exercise the non-contiguous input path: create a transposed/non-contiguous source (e.g., mixed_qkv_t = mixed_qkv.transpose(0,1) or mixed_qkv.t()), call split_qkv_contiguous with that tensor (matching the same q_dim/k_dim/v_dim/head args), and then reshape/transpose the kernel outputs back to 2D to compare against _ref_split_qkv_contiguous(mixed_qkv, ...). Ensure you assert the output shapes for the transposed case and use torch.testing.assert_close to compare values (same rtol/atol), so the stride-sensitive branch used by split_qkv_contiguous is covered.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/unittest/_torch/modules/mamba/test_gdn_kernel_optimizations.py`:
- Around line 91-92: The wrapped function signature for
test_split_qkv_contiguous is indented in a way that triggers Flake8 E127;
reformat the continuation lines so parameters align with the first character
after the opening parenthesis (or use a consistent hanging indent) instead of
the current misaligned indentation, and apply the same fix to the other nearby
wrapped test signature that has the same continuation-style issue; update the
signature formatting so ruff/ruff-format and flake8 pass.
📒 Files selected for processing (3)
tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py
tensorrt_llm/_torch/modules/mamba/gdn_mixer.py
tests/unittest/_torch/modules/mamba/test_gdn_kernel_optimizations.py
12c72c4 to 40e5121
yechank-nvidia left a comment:
Hi, can I know how much of perf improvement we got?
Benchmark: Qwen3.5-35B-A3B BF16, B200 (TP1, GPU clocked at 1830 MHz)
Tool: trtllm-bench, ISL=1k/OSL=100
be7ee36 to 77961b8
/bot run --disable-fail-fast
3d4227a to ce4534c
/bot run --disable-fail-fast
PR_Github #43611 [ run ] triggered by Bot. Commit:
PR_Github #43613 [ run ] triggered by Bot. Commit:
PR_Github #43611 [ run ] completed with state
PR_Github #43613 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #44144 [ run ] triggered by Bot. Commit:
PR_Github #44144 [ run ] completed with state
Add 3 Triton kernel optimizations for Qwen3.5 GDN layers:
1. fused_gdn_gating_with_sigmoid: fuse sigmoid(b) into the gating kernel, eliminating a separate elementwise kernel per GDN layer.
2. split_qkv_contiguous: replace torch.split + implicit .contiguous() with a single Triton kernel that writes directly to 3 contiguous output tensors.
3. transpose_and_split_qkv: fuse transpose + split for mixed prefill+decode batches, removing a full data copy pass per layer.
Also add unit tests for all 3 kernels validating against reference PyTorch implementations.
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
Signed-off-by: Shijie <jaywan@nvidia.com>
Enable @triton.autotune for chunk_gated_delta_rule_fwd_kernel_h and chunk_fwd_kernel_o, replacing hardcoded num_warps/num_stages/block sizes. The autotune search space uses the existing NUM_WARPS and BKV_LIST variables, which are already conditioned on GPU architecture. This allows Triton to automatically select optimal kernel parameters per GPU, eliminating the need for manual per-GPU tuning.
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
Use shorter loop variable names (nw, ns) in FLA chunk kernel autotune configs to keep lines within the 80-column limit and improve readability.
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
ce4534c to f23ab97
/bot run --disable-fail-fast
PR_Github #44948 [ run ] triggered by Bot. Commit:
/bot run --disable-fail-fast
PR_Github #45693 [ run ] triggered by Bot. Commit:
PR_Github #45693 [ run ] completed with state
…IDIA#12966) Signed-off-by: Shijie Wang <jaywan@nvidia.com>
Add 3 Triton kernel optimizations for Qwen3.5 GDN layers:
1. fused_gdn_gating_with_sigmoid: fuse sigmoid(b) into the gating kernel, eliminating a separate elementwise kernel per GDN layer.
2. split_qkv_contiguous: replace torch.split + implicit .contiguous() with a single Triton kernel that writes directly to 3 contiguous output tensors.
3. transpose_and_split_qkv: fuse transpose + split for mixed prefill+decode batches, removing a full data copy pass per layer.
Also add unit tests for all 3 kernels validating against reference PyTorch implementations.
Summary by CodeRabbit
Performance Improvements
Tests
Description
Test Coverage