[None][feat] Fuse GDN elementwise ops and split/transpose kernels #12966

Merged
nv-guomingz merged 4 commits into NVIDIA:main from Wong4j:qwen35-gdn-kernel-opt on May 6, 2026

@Wong4j Wong4j commented Apr 12, 2026

Collaborator

Add 3 Triton kernel optimizations for Qwen3.5 GDN layers:

  1. fused_gdn_gating_with_sigmoid: fuse sigmoid(b) into the gating kernel, eliminating a separate elementwise kernel per GDN layer.

  2. split_qkv_contiguous: replace torch.split + implicit .contiguous() with a single Triton kernel that writes directly to 3 contiguous output tensors.

  3. transpose_and_split_qkv: fuse transpose + split for mixed prefill+decode batches, removing a full data copy pass per layer.

Also add unit tests for all 3 kernels validating against reference PyTorch implementations.

Summary by CodeRabbit

  • Performance Improvements

    • Optimized Mamba GDN mixer with fused kernels for improved Q/K/V processing in mixed prefill+decode batches.
    • Enhanced GDN gating computation efficiency by combining sigmoid operations into a single fused kernel.
  • Tests

    • Added comprehensive unit tests for GDN kernel optimizations on CUDA devices, covering gating operations and tensor splitting scenarios.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@Wong4j Wong4j requested review from a team as code owners April 12, 2026 11:04
@Wong4j Wong4j force-pushed the qwen35-gdn-kernel-opt branch from 35a1b44 to abbd450 Compare April 12, 2026 11:04
@Wong4j Wong4j changed the title from "[None][perf] Fuse GDN elementwise ops and split/transpose kernels" to "[None][feat] Fuse GDN elementwise ops and split/transpose kernels" Apr 12, 2026
coderabbitai Bot commented Apr 12, 2026

Contributor
Walkthrough

The PR introduces Triton kernel optimizations for Mamba/GDN Q/K/V materialization and gating operations. New kernels transpose_and_split_qkv and split_qkv_contiguous efficiently fuse tensor transposition and splitting for mixed prefill+decode batches. A fused_gdn_gating_with_sigmoid kernel combines gating and sigmoid computations. These are integrated into the GDN mixer forward path and validated with parameterized CUDA unit tests.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| QKV Fusion Kernels (tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py) | Added Triton kernels and Python wrappers for transpose_and_split_qkv (fuses prefill transposition and decode reading into contiguous Q/K/V buffers) and split_qkv_contiguous (splits combined QKV into separate contiguous tensors). Both kernels optimize mixed batch processing with 3D grid launches. |
| GDN Mixer Integration (tensorrt_llm/_torch/modules/mamba/gdn_mixer.py) | Extended imports to include new QKV fusion helpers. Added fused_gdn_gating_with_sigmoid kernel that computes gating and sigmoid in a single kernel. Refactored Qwen3NextGatedDeltaNet.forward_extend to replace intermediate tensor operations (copy_, cat, split) with direct fused kernel calls for both mixed and pure prefill paths. |
| Kernel Validation Tests (tests/unittest/_torch/modules/mamba/test_gdn_kernel_optimizations.py) | New CUDA-gated test module with three parameterized tests validating fused_gdn_gating_with_sigmoid, split_qkv_contiguous, and transpose_and_split_qkv against reference PyTorch implementations using torch.testing.assert_close. |

Sequence Diagram(s)

sequenceDiagram
    participant User as Application
    participant Wrapper as Python Wrapper<br/>(transpose_and_split_qkv)
    participant Kernel as Triton Kernel<br/>(_transpose_and_split_qkv_kernel)
    participant Memory as GPU Memory
    
    User->>Wrapper: Call transpose_and_split_qkv(prefill_t, decode, ...)
    Wrapper->>Memory: Allocate q_flat, k_flat, v_flat<br/>(size: total_seq × dim)
    Wrapper->>Kernel: Launch 3D grid<br/>(seq_blocks, dim_blocks, qkv_selector)
    Kernel->>Memory: Read prefill from [D, T_p] layout
    Kernel->>Memory: Read decode from [T_d, D] layout
    Kernel->>Memory: Write contiguous Q/K/V<br/>to output buffers
    Kernel-->>Wrapper: Kernel complete
    Wrapper->>Memory: Reshape outputs to<br/>[1, total_seq, heads, head_dim]
    Wrapper-->>User: Return (query, key, value)
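
For readers following the diagram above, the data movement of transpose_and_split_qkv can be sketched in plain Python, with no Triton or GPU required. This is only a reference sketch, not the PR's kernel: the name `transpose_and_split_qkv_ref` is hypothetical, nested lists stand in for tensors, and it assumes prefill tokens are placed before decode tokens in the output, which the diagram does not state explicitly.

```python
def transpose_and_split_qkv_ref(prefill_t, decode, q_dim, k_dim, v_dim):
    """Reference semantics only (hypothetical helper, not the PR's code).

    prefill_t: list of D rows, each with T_p entries  -> [D, T_p] layout
    decode:    list of T_d rows, each with D entries  -> [T_d, D] layout
    Returns q, k, v as token-major lists with T_p + T_d rows each.
    """
    num_cols = q_dim + k_dim + v_dim
    if len(prefill_t) != num_cols:
        raise ValueError(f"expected {num_cols} prefill rows, got {len(prefill_t)}")
    num_prefill = len(prefill_t[0])

    # Transpose the prefill part from [D, T_p] to token-major [T_p, D].
    rows = [[prefill_t[d][t] for d in range(num_cols)] for t in range(num_prefill)]

    # Decode tokens are already token-major; append them after the prefill tokens
    # (ordering is an assumption of this sketch).
    for row in decode:
        if len(row) != num_cols:
            raise ValueError(f"expected decode width {num_cols}, got {len(row)}")
        rows.append(list(row))

    # Slice each token row into its Q, K and V column ranges.
    q = [row[:q_dim] for row in rows]
    k = [row[q_dim:q_dim + k_dim] for row in rows]
    v = [row[q_dim + k_dim:] for row in rows]
    return q, k, v


# Two prefill tokens stored column-wise, one decode token stored row-wise.
prefill_t = [[0, 10], [1, 11], [2, 12], [3, 13], [4, 14]]
decode = [[20, 21, 22, 23, 24]]
q, k, v = transpose_and_split_qkv_ref(prefill_t, decode, q_dim=2, k_dim=2, v_dim=1)
print(q, k, v)
```

The fused kernel would produce the same values in one pass; the final reshape to [1, total_seq, heads, head_dim] is a view change and is left out here.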
sequenceDiagram
    participant User as Application
    participant Wrapper as Python Wrapper<br/>(split_qkv_contiguous)
    participant Kernel as Triton Kernel<br/>(_split_qkv_contiguous_kernel)
    participant Memory as GPU Memory
    
    User->>Wrapper: Call split_qkv_contiguous(mixed_qkv, ...)
    Wrapper->>Memory: Allocate q_flat, k_flat, v_flat<br/>(contiguous)
    Wrapper->>Kernel: Launch 3D grid
    Kernel->>Memory: Read mixed_qkv [T, qkv_dim]<br/>(any stride)
    Kernel->>Memory: Write contiguous Q, K, V
    Kernel-->>Wrapper: Kernel complete
    Wrapper->>Memory: Reshape to [1, T, heads, head_dim]
    Wrapper-->>User: Return (query, key, value)
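
The "any stride" read in the diagram above is the part that lets the kernel accept a transposed view without a prior .contiguous() copy. A minimal plain-Python sketch of that addressing follows; the helper name `split_qkv_strided_ref` and the flat-list buffer are illustrative assumptions, not the PR's implementation.

```python
def split_qkv_strided_ref(buf, seq_len, q_dim, k_dim, v_dim, stride_seq, stride_dim):
    """Gather Q, K, V from a flat buffer addressed by (stride_seq, stride_dim).

    Hypothetical reference helper: element (t, c) of the logical [T, D] source
    lives at buf[t * stride_seq + c * stride_dim], mirroring pointer arithmetic.
    Each output is a flat, token-major (contiguous) list.
    """
    def gather(col_start, width):
        return [
            buf[t * stride_seq + (col_start + c) * stride_dim]
            for t in range(seq_len)
            for c in range(width)
        ]

    q = gather(0, q_dim)
    k = gather(q_dim, k_dim)
    v = gather(q_dim + k_dim, v_dim)
    return q, k, v


# Logical source: 2 tokens x 4 columns, rows [0, 1, 2, 3] and [10, 11, 12, 13].
row_major = [0, 1, 2, 3, 10, 11, 12, 13]      # strides (4, 1): contiguous
col_major = [0, 10, 1, 11, 2, 12, 3, 13]      # strides (1, 2): transposed view

out_contig = split_qkv_strided_ref(row_major, 2, 1, 1, 2, stride_seq=4, stride_dim=1)
out_transposed = split_qkv_strided_ref(col_major, 2, 1, 1, 2, stride_seq=1, stride_dim=2)
print(out_contig)
print(out_transposed)
```

Both storage layouts yield identical contiguous outputs, which is the property a non-contiguous test case for split_qkv_contiguous would check.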
sequenceDiagram
    participant User as Application
    participant Wrapper as Python Wrapper<br/>(fused_gdn_gating_with_sigmoid)
    participant Kernel as Triton Kernel<br/>(fused_gdn_gating_with_sigmoid_kernel)
    participant Memory as GPU Memory
    
    User->>Wrapper: Call fused_gdn_gating_with_sigmoid<br/>(A_log, a, dt_bias, b, ...)
    Wrapper->>Memory: Allocate g and beta_out buffers
    Wrapper->>Kernel: Launch kernel
    Kernel->>Kernel: Compute -exp(A_log) × softplus(a + dt_bias)<br/>→ g
    Kernel->>Kernel: Compute sigmoid(b)<br/>→ beta_out
    Kernel->>Memory: Write g and beta_out
    Kernel-->>Wrapper: Kernel complete
    Wrapper-->>User: Return (g, beta_out)
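
The two formulas in the diagram above can be written out as a small plain-Python reference. This is a sketch under stated assumptions: the name `gdn_gating_with_sigmoid_ref` is hypothetical, A_log and dt_bias are assumed to be per-head vectors with a and b shaped [tokens][heads], and any extra softplus parameters the real kernel may take (the call in the diagram elides its remaining arguments) are ignored.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x):
    # Branch on sign so exp() never overflows.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def gdn_gating_with_sigmoid_ref(A_log, a, dt_bias, b):
    """Hypothetical reference: returns (g, beta_out) in one pass over the inputs.

    g[t][h]        = -exp(A_log[h]) * softplus(a[t][h] + dt_bias[h])
    beta_out[t][h] = sigmoid(b[t][h])
    """
    num_heads = len(A_log)
    g = [
        [-math.exp(A_log[h]) * softplus(a_t[h] + dt_bias[h]) for h in range(num_heads)]
        for a_t in a
    ]
    beta_out = [[sigmoid(x) for x in b_t] for b_t in b]
    return g, beta_out


# One token, one head, all-zero inputs: g = -log(2), beta_out = 0.5.
g, beta_out = gdn_gating_with_sigmoid_ref([0.0], [[0.0]], [0.0], [[0.0]])
print(g, beta_out)
```

Computing both outputs in the same function mirrors the point of the fusion: sigmoid(b) no longer needs its own elementwise launch.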

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 75.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The PR title clearly and concisely describes the main changes: adding three Triton kernel optimizations to fuse GDN elementwise operations and QKV split/transpose kernels. |
| Description check | ✅ Passed | The PR description adequately explains the three kernel optimizations (fused_gdn_gating_with_sigmoid, split_qkv_contiguous, transpose_and_split_qkv) and mentions unit tests, but the 'Description' and 'Test Coverage' sections remain unfilled template placeholders. |


@Wong4j Wong4j force-pushed the qwen35-gdn-kernel-opt branch 2 times, most recently from f518d38 to 12c72c4 Compare April 12, 2026 11:11

@coderabbitai coderabbitai Bot left a comment

Contributor


Actionable comments posted: 1

🧹 Nitpick comments (2)
tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py (1)

269-312: Fail fast on inconsistent QKV widths before launching these kernels.

Both wrappers assume the source width is exactly q_dim + k_dim + v_dim and that the head-count arguments multiply back to those dims. If a caller ever passes mismatched values, this becomes raw Triton pointer arithmetic instead of a clean Python exception.

🛡️ Suggested guardrails
 def transpose_and_split_qkv(
     prefill_t: torch.Tensor,
     decode: torch.Tensor,
@@
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@
     num_cols, num_prefill = prefill_t.shape
+    expected_cols = q_dim + k_dim + v_dim
+    if decode.ndim != 2 or num_cols != expected_cols or decode.shape[1] != expected_cols:
+        raise ValueError(
+            f"Expected prefill/decode widths to equal {expected_cols}, got "
+            f"{num_cols} and {decode.shape[1]}"
+        )
+    if q_dim != num_q_heads * head_k_dim or k_dim != num_q_heads * head_k_dim:
+        raise ValueError("Q/K dims must match num_q_heads * head_k_dim")
+    if v_dim != num_v_heads * head_v_dim:
+        raise ValueError("V dim must match num_v_heads * head_v_dim")
     num_decode = decode.shape[0]
     total_seq = num_prefill + num_decode
@@
 def split_qkv_contiguous(
     mixed_qkv: torch.Tensor,
@@
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@
-    seq_len = mixed_qkv.shape[0]
+    seq_len, num_cols = mixed_qkv.shape
+    expected_cols = q_dim + k_dim + v_dim
+    if num_cols != expected_cols:
+        raise ValueError(f"Expected mixed_qkv width {expected_cols}, got {num_cols}")
+    if q_dim != num_q_heads * head_k_dim or k_dim != num_q_heads * head_k_dim:
+        raise ValueError("Q/K dims must match num_q_heads * head_k_dim")
+    if v_dim != num_v_heads * head_v_dim:
+        raise ValueError("V dim must match num_v_heads * head_v_dim")
     src_stride_seq, src_stride_dim = mixed_qkv.stride()

Also applies to: 364-417

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py` around lines 269 -
312, Add preflight checks in transpose_and_split_qkv before launching
_transpose_and_split_qkv_kernel: verify that num_cols == q_dim + k_dim + v_dim
and that q_dim == num_q_heads * head_k_dim and v_dim == num_v_heads * head_v_dim
(and similarly k_dim if applicable), and raise a clear ValueError if any
mismatch occurs. Do the same guard checks in the analogous function at lines
364-417 so callers get a Python exception instead of invalid Triton pointer
arithmetic. Include the unique symbols transpose_and_split_qkv,
_transpose_and_split_qkv_kernel, q_dim, k_dim, v_dim, num_q_heads, head_k_dim,
num_v_heads, head_v_dim when locating where to add these checks.
tests/unittest/_torch/modules/mamba/test_gdn_kernel_optimizations.py (1)

91-117: Add a non-contiguous source case for split_qkv_contiguous().

This test only covers a contiguous [T, D] input, but tensorrt_llm/_torch/modules/mamba/gdn_mixer.py Lines 686-688 call the helper with mixed_qkv_t.transpose(0, 1). The stride-sensitive path this kernel was added for is therefore untested right now.

🧪 Suggested test expansion
 @skip_no_cuda
 @pytest.mark.parametrize("seq_len", [1, 32, 128, 1024])
+@pytest.mark.parametrize("transpose_input", [False, True])
 @pytest.mark.parametrize(
     "q_dim,k_dim,v_dim,num_q_heads,head_k_dim,num_v_heads,head_v_dim",
@@
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
-def test_split_qkv_contiguous(seq_len, q_dim, k_dim, v_dim,
-                               num_q_heads, head_k_dim, num_v_heads, head_v_dim, dtype):
+def test_split_qkv_contiguous(
+    seq_len,
+    transpose_input,
+    q_dim,
+    k_dim,
+    v_dim,
+    num_q_heads,
+    head_k_dim,
+    num_v_heads,
+    head_v_dim,
+    dtype,
+):
@@
-    mixed_qkv = torch.randn(seq_len, total_dim, dtype=dtype, device=device)
+    if transpose_input:
+        mixed_qkv = torch.randn(total_dim, seq_len, dtype=dtype, device=device).transpose(0, 1)
+        assert not mixed_qkv.is_contiguous()
+    else:
+        mixed_qkv = torch.randn(seq_len, total_dim, dtype=dtype, device=device)

@@
     assert q_out.shape == (1, seq_len, num_q_heads, head_k_dim)
     assert k_out.shape == (1, seq_len, num_q_heads, head_k_dim)
     assert v_out.shape == (1, seq_len, num_v_heads, head_v_dim)
+    assert q_out.is_contiguous()
+    assert k_out.is_contiguous()
+    assert v_out.is_contiguous()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/modules/mamba/test_gdn_kernel_optimizations.py` around
lines 91 - 117, Update test_split_qkv_contiguous to also exercise the
non-contiguous input path: create a transposed/non-contiguous source (e.g.,
mixed_qkv_t = mixed_qkv.transpose(0,1) or mixed_qkv.t()), call
split_qkv_contiguous with that tensor (matching the same q_dim/k_dim/v_dim/head
args), and then reshape/transpose the kernel outputs back to 2D to compare
against _ref_split_qkv_contiguous(mixed_qkv, ...). Ensure you assert the output
shapes for the transposed case and use torch.testing.assert_close to compare
values (same rtol/atol), so the stride-sensitive branch used by
split_qkv_contiguous is covered.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/unittest/_torch/modules/mamba/test_gdn_kernel_optimizations.py`:
- Around line 91-92: The wrapped function signature for
test_split_qkv_contiguous is indented in a way that triggers Flake8 E127;
reformat the continuation lines so parameters align with the first character
after the opening parenthesis (or use a consistent hanging indent) instead of
the current misaligned indentation, and apply the same fix to the other nearby
wrapped test signature that has the same continuation-style issue; update the
signature formatting so ruff/ruff-format and flake8 pass.



📥 Commits

Reviewing files that changed from the base of the PR and between 5653803 and abbd450.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py
  • tensorrt_llm/_torch/modules/mamba/gdn_mixer.py
  • tests/unittest/_torch/modules/mamba/test_gdn_kernel_optimizations.py

Comment thread tests/unittest/_torch/modules/mamba/test_gdn_kernel_optimizations.py Outdated
@Wong4j Wong4j force-pushed the qwen35-gdn-kernel-opt branch from 12c72c4 to 40e5121 Compare April 12, 2026 13:37
@Wong4j Wong4j changed the title [None][feat] Fuse GDN elementwise ops and split/transpose kernels [None][feat] Fuse GDN elementwise ops and split/transpose kernels Apr 12, 2026

@yechank-nvidia yechank-nvidia left a comment

Collaborator


Hi, can I know how much of perf improvement we got?

@Wong4j Wong4j requested review from nv-guomingz and rosenrodt April 13, 2026 01:40
Comment thread tensorrt_llm/_torch/modules/mamba/gdn_mixer.py
Comment thread tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py Outdated
Wong4j commented Apr 14, 2026

Collaborator Author

> Hi, can I know how much of perf improvement we got?

Benchmark: Qwen3.5-35B-A3B BF16, B200 (TP1, GPU clocked at 1830MHz)

Tool: trtllm-bench, max_num_tokens=32768, max_batch_size=512

ISL=1k/OSL=100

| Concurrency | Baseline (tps) | With PR (tps) | Speedup |
| --- | --- | --- | --- |
| 8 | 1,016.9 | 1,025.6 | +0.9% |
| 32 | 2,356.4 | 2,421.6 | +2.8% |
| 128 | 3,983.4 | 4,083.6 | +2.5% |
| 256 | 4,589.8 | 4,799.8 | +4.6% |
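
The speedup column can be reproduced from the two throughput columns as (with PR / baseline - 1) × 100. A small sketch of that arithmetic, using only the numbers reported above (the helper name `speedup_pct` is illustrative):

```python
# (concurrency, baseline tps, with-PR tps), copied from the reported results.
rows = [
    (8, 1016.9, 1025.6),
    (32, 2356.4, 2421.6),
    (128, 3983.4, 4083.6),
    (256, 4589.8, 4799.8),
]


def speedup_pct(baseline, with_pr):
    """Relative throughput gain in percent."""
    return (with_pr / baseline - 1.0) * 100.0


for concurrency, baseline, with_pr in rows:
    print(f"concurrency={concurrency}: {speedup_pct(baseline, with_pr):+.1f}%")
```

Rounded to one decimal place, the results match the reported +0.9%, +2.8%, +2.5% and +4.6%.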

@yechank-nvidia

@Wong4j Wong4j requested a review from a team as a code owner April 14, 2026 12:11
@Wong4j Wong4j requested a review from xxi-nv April 14, 2026 12:11
@Wong4j Wong4j force-pushed the qwen35-gdn-kernel-opt branch from be7ee36 to 77961b8 Compare April 14, 2026 12:12
Wong4j commented Apr 16, 2026

Collaborator Author

/bot run --disable-fail-fast

@Wong4j Wong4j force-pushed the qwen35-gdn-kernel-opt branch from 3d4227a to ce4534c Compare April 16, 2026 02:20
Wong4j commented Apr 16, 2026

Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #43611 [ run ] triggered by Bot. Commit: ce4534c Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #43613 [ run ] triggered by Bot. Commit: ce4534c Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #43611 [ run ] completed with state ABORTED. Commit: ce4534c

Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #43613 [ run ] completed with state SUCCESS. Commit: ce4534c
/LLM/main/L0_MergeRequest_PR pipeline #34105 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nv-guomingz

Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #44144 [ run ] triggered by Bot. Commit: ce4534c Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #44144 [ run ] completed with state SUCCESS. Commit: ce4534c
/LLM/main/L0_MergeRequest_PR pipeline #34571 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Wong4j added 4 commits April 22, 2026 17:15
Add 3 Triton kernel optimizations for Qwen3.5 GDN layers:

1. fused_gdn_gating_with_sigmoid: fuse sigmoid(b) into the gating
   kernel, eliminating a separate elementwise kernel per GDN layer.

2. split_qkv_contiguous: replace torch.split + implicit .contiguous()
   with a single Triton kernel that writes directly to 3 contiguous
   output tensors.

3. transpose_and_split_qkv: fuse transpose + split for mixed
   prefill+decode batches, removing a full data copy pass per layer.

Also add unit tests for all 3 kernels validating against reference
PyTorch implementations.

Signed-off-by: Shijie Wang <jaywan@nvidia.com>
Signed-off-by: Shijie <jaywan@nvidia.com>

Enable @triton.autotune for chunk_gated_delta_rule_fwd_kernel_h and
chunk_fwd_kernel_o, replacing hardcoded num_warps/num_stages/block
sizes. The autotune search space uses the existing NUM_WARPS and
BKV_LIST variables which are already conditioned on GPU architecture.

This allows Triton to automatically select optimal kernel parameters
per GPU, eliminating the need for manual per-GPU tuning.

Signed-off-by: Shijie Wang <jaywan@nvidia.com>

Use shorter loop variable names (nw, ns) in FLA chunk kernel autotune
configs to keep lines within 80-column limit and improve readability.

Signed-off-by: Shijie Wang <jaywan@nvidia.com>
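
As a rough illustration of the autotune search space described in the commit messages above, the candidate configurations can be enumerated as a Cartesian product. This is a sketch, not the PR's code: the concrete values of NUM_WARPS and BKV_LIST, the NUM_STAGES list, and the block-size names BK and BV are all assumptions made for the example. In Triton itself each entry would be wrapped in a triton.Config and passed to @triton.autotune(configs=..., key=...).

```python
from itertools import product

# Hypothetical values. The commit says the real NUM_WARPS and BKV_LIST
# are already conditioned on GPU architecture.
NUM_WARPS = [2, 4, 8]
NUM_STAGES = [2, 3, 4]   # assumed list; not named in the commit message
BKV_LIST = [32, 64]


def build_configs(bkv_list, num_warps, num_stages):
    """Enumerate every (BK, BV, num_warps, num_stages) combination as a dict.

    With Triton available, each dict would correspond to
    triton.Config({"BK": bk, "BV": bv}, num_warps=nw, num_stages=ns).
    """
    return [
        {"BK": bk, "BV": bv, "num_warps": nw, "num_stages": ns}
        for bk, bv, nw, ns in product(bkv_list, bkv_list, num_warps, num_stages)
    ]


configs = build_configs(BKV_LIST, NUM_WARPS, NUM_STAGES)
print(len(configs), configs[0])
```

The size of this product is what the autotuner benchmarks on first launch for each distinct key, so keeping the lists short bounds the warm-up cost.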
@Wong4j Wong4j force-pushed the qwen35-gdn-kernel-opt branch from ce4534c to f23ab97 Compare April 22, 2026 09:15
Wong4j commented Apr 22, 2026

Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #44948 [ run ] triggered by Bot. Commit: f23ab97 Link to invocation

@rosenrodt

Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #45693 [ run ] triggered by Bot. Commit: f23ab97 Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #45693 [ run ] completed with state SUCCESS. Commit: f23ab97
/LLM/main/L0_MergeRequest_PR pipeline #35897 completed with status: 'SUCCESS'

CI Report

Link to invocation

@QiJune QiJune left a comment

Collaborator


LGTM

@nv-guomingz nv-guomingz merged commit d2c8444 into NVIDIA:main May 6, 2026
5 checks passed
yufeiwu-nv pushed a commit to yufeiwu-nv/TensorRT-LLM that referenced this pull request May 19, 2026


6 participants