Conversation

@juuso-oskari juuso-oskari commented Oct 2, 2025

Authors: @juuso-oskari @Chi-Chu319

e2e fused MoE for small N

This PR adds an end-to-end fused implementation of MoE optimized for a small intermediate dimension (N <= 1024). The idea is that because the intermediate dimension is so small, the whole intermediate token representation fits in shared memory, so the two GEMMs of the MoE layer can be fused efficiently into two Triton tl.dot operations within a single kernel.
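
For reference, the per-expert computation the fused kernel performs looks roughly like the following PyTorch sketch. This is my illustration, not the PR's kernel: the gated SiLU activation and the w1/w2 layouts are assumptions based on typical MoE MLPs; in the Triton kernel the intermediate stays in shared memory instead of being materialized between the two GEMMs.

import torch
import torch.nn.functional as F

# Reference for one expert's block of routed tokens: gemm-1, gated activation and
# gemm-2 evaluated back to back, with the (M_block, N) intermediate kept local.
def fused_expert_block(x_block, w1, w2):
    # x_block: (M_block, K) tokens routed to this expert
    # w1: (2*N, K) gate/up projection, w2: (K, N) down projection, with N <= 1024
    inter = x_block @ w1.t()              # gemm-1: (M_block, 2*N)
    gate, up = inter.chunk(2, dim=-1)
    inter = F.silu(gate) * up             # gated activation: (M_block, N)
    return inter @ w2.t()                 # gemm-2: (M_block, K)

x = torch.randn(16, 2048, dtype=torch.float16)
w1 = torch.randn(2 * 512, 2048, dtype=torch.float16)
w2 = torch.randn(2048, 512, dtype=torch.float16)
y = fused_expert_block(x, w1, w2)         # (16, 2048)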

We provide both the default fp16 version and a blockscaled fp8 version. Some implementation details for the blockscaled fp8 version:

  • We don't do the second GEMM in fp8; instead we descale the second weight matrix. We are forced to do this because BLOCK_SIZE_N // 2 > group_n (the block size is larger than the quantization group size along n) and N is the inner dimension of the second tl.dot. This, however, lets us skip the quantization / dequantization of the intermediate entirely (see the descaled second-GEMM sketch after the snippet below).
  • We load only the unique scaling factors and then broadcast them to the block dimensions. This differs from other quantized Triton kernels, where we usually do a block-sized load along the outer block dimensions for the scaling factors. We found that this gives a performance boost, especially when loading the scaling factors of the second weight matrix.
# num_scales_along_n = BLOCK_SIZE_N // group_n
# num_scales_along_k2 = BLOCK_SIZE_K2 // group_k
# load only the unique scales for this block: (num_scales_along_n, num_scales_along_k2)
w2_scale = tl.load(w2_scale_ptrs + k2 * BLOCK_SIZE_K2 // group_k * stride_w2sk)
# broadcast each scale over its quantization group, first along n, then along k2,
# expanding from (num_scales_along_n, num_scales_along_k2) to (BLOCK_SIZE_N, BLOCK_SIZE_K2)
w2_scale = group_broadcast(w2_scale, num_scales_along_n, num_scales_along_k2, group_n, 0)
w2_scale = group_broadcast(w2_scale, num_scales_along_n * group_n, num_scales_along_k2, group_k, 1)
# w2_scale is now (BLOCK_SIZE_N, BLOCK_SIZE_K2)
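
As a reference for the descaling trick from the first bullet, here is a runnable PyTorch sketch (again my illustration, not the kernel itself; the shapes, the e4m3 fp8 format, and the variable names are assumptions). The per-group scales are broadcast back to the full weight block, W2 is dequantized once, and the second GEMM runs in fp16, so the intermediate activations never get quantized to fp8 and back.

import torch

group_n, group_k = 128, 128
M, N_half, K = 32, 256, 2048   # token count, gated intermediate width, hidden size (example values)

inter = torch.randn(M, N_half, dtype=torch.float16)           # gemm-1 output after the gated activation
w2_fp8 = torch.randn(K, N_half).to(torch.float8_e4m3fn)       # block-quantized second weight
w2_scale = torch.rand(K // group_k, N_half // group_n) + 0.5  # one scale per (group_k x group_n) block

# broadcast the unique scales over their quantization groups, then descale W2
scale_full = w2_scale.repeat_interleave(group_k, dim=0).repeat_interleave(group_n, dim=1)
w2_deq = (w2_fp8.to(torch.float32) * scale_full).to(torch.float16)

out = inter @ w2_deq.t()                                      # second GEMM runs in fp16: (M, K)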

We end up parallelizing only along the length of sorted_token_ids, which is max_num_tokens_padded = topk * M + E * (BLOCK_SIZE_M - 1), but in the newer Qwen3 models this is sufficiently large thanks to the large topk and E.
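To make the grid size concrete, a small sketch (BLOCK_SIZE_M = 16 is an assumption for the example, not the tuned config):

# Illustration of the number of padded token slots the fused kernel parallelizes over.
def max_num_tokens_padded(M, topk, E, BLOCK_SIZE_M):
    # each expert's segment of sorted_token_ids is padded up to a multiple of BLOCK_SIZE_M,
    # so in the worst case every expert contributes BLOCK_SIZE_M - 1 padding slots
    return topk * M + E * (BLOCK_SIZE_M - 1)

# Qwen3-Next-80B shape from the benchmarks below: M=32, top_k=10, E=512
print(max_num_tokens_padded(32, 10, 512, 16))   # 8000 slots -> 8000 / 16 = 500 M-blocks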

fp8 blockscale MoE benching

This PR also modifies the MoE benching script to enable benching of blockscaled fp8 MoE.

Bench results

Initial performance results for small M (M=32).

FP8 blockscale

baseline (i.e. the two separate kernels launched for the two GEMMs):

~/aiter$ python op_tests/op_benchmarks/triton/bench_moe.py --model qwen3next -M 32 -block_shape 128 128 -fp8_w8a8
           model   M     N     K    E  top_k  Time_(ms)    TFLOPS  Bandwidth_(GB/s)  Arithmetic_Intensity_(Flops/Byte)
0  qwen3next-80B  32   512  2048  512     10   0.090924  7.334154       2717.911854                           2.662507
1  qwen3next-80B  32  2048   256  512     10   0.068197  5.255668       1891.087017                           2.554731

e2e fused moe:

~/aiter$ python op_tests/op_benchmarks/triton/bench_moe.py --model qwen3next -e2e_fused -M 32 -block_shape 128 128 -fp8_w8a8
           model   M    N     K    E  top_k  Time_(ms)    TFLOPS  Bandwidth_(GB/s)  Arithmetic_Intensity_(Flops/Byte)
0  qwen3next-80B  32  512  2048  512     10     0.1156  9.043616       3271.196095                           2.657412

So the fused kernel takes roughly 0.1156 / (0.090924 + 0.068197) ≈ 0.73× of the baseline runtime, i.e. about a 27% reduction.

BF16

baseline (i.e. the two separate kernels launched for the two GEMMs):

~/aiter$ python op_tests/op_benchmarks/triton/bench_moe.py --model qwen3next -M 32
           model   M     N     K    E  top_k  Time_(ms)    TFLOPS  Bandwidth_(GB/s)  Arithmetic_Intensity_(Flops/Byte)
0  qwen3next-80B  32   512  2048  512     10   0.246026  2.762488       2075.698561                           1.326597
1  qwen3next-80B  32  2048   256  512     10   0.163532  2.036682       1567.156753                           1.315435

e2e fused moe:

~/aiter$ python op_tests/op_benchmarks/triton/bench_moe.py --model qwen3next -M 32 -e2e_fused
           model   M    N     K    E  top_k  Time_(ms)    TFLOPS  Bandwidth_(GB/s)  Arithmetic_Intensity_(Flops/Byte)
0  qwen3next-80B  32  512  2048  512     10    0.17934  5.528023       4333.753851                            1.30923

So the fused kernel takes roughly 0.17934 / (0.246026 + 0.163532) ≈ 0.44× of the baseline runtime, i.e. about a 56% reduction.

Note that we expect the baseline to be underperforming because it lacks tuning for the Qwen3 shapes. Quick-and-dirty tuning of the BF16 baseline brings its performance to:

~/aiter$ python op_tests/op_benchmarks/triton/bench_moe.py --model qwen3next -M 32
           model   M     N     K    E  top_k  Time_(ms)    TFLOPS  Bandwidth_(GB/s)  Arithmetic_Intensity_(Flops/Byte)
0  qwen3next-80B  32   512  2048  512     10   0.120142  5.519248       4044.894499                           1.348966
1  qwen3next-80B  32  2048   256  512     10   0.086003  3.770810       2826.478909                           1.365971

But even then, the e2e fused MoE is still faster.

@juuso-oskari juuso-oskari changed the title [Triton] End-to-end Mixture of Experts for QWEN 3 [Triton] e2e fused MoE for small N and fp8 blockscale MoE benching Oct 2, 2025
@juuso-oskari juuso-oskari marked this pull request as ready for review October 8, 2025 10:16
@Copilot Copilot AI left a comment


Pull Request Overview

This PR implements end-to-end (e2e) fused MoE (Mixture of Experts) kernels optimized for small intermediate dimensions (N ≤ 1024) with support for both FP16 and blockscaled FP8 quantization. The key innovation is fitting the entire intermediate token representation in shared memory to fuse two matrix multiplications efficiently.

  • End-to-end fused MoE kernel implementation that combines two GEMMs and gated activation in a single kernel
  • Blockscaled FP8 quantization support with optimized scaling factor loading and broadcasting
  • Enhanced benchmarking capabilities with new model configurations and e2e benchmarking support

Reviewed Changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 4 comments.

Summary per file:

  • op_tests/triton_tests/test_moe.py: Implements test infrastructure for e2e MoE with blockscale quantization support
  • op_tests/op_benchmarks/triton/utils/model_configs.json: Adds MoE-specific configurations for qwen3/qwen3next models
  • op_tests/op_benchmarks/triton/bench_moe.py: Extends benchmarking to support e2e fused MoE with improved memory bandwidth calculations
  • aiter/ops/triton/utils/moe_config_utils.py: Adds configuration management for e2e MoE kernels
  • aiter/ops/triton/moe_op_e2e.py: Core e2e MoE kernel interface with blockscale FP8 support
  • aiter/ops/triton/configs/moe/*.json: Configuration files for optimized e2e MoE kernel parameters
  • aiter/ops/triton/_triton_kernels/moe_op_e2e.py: Low-level Triton kernel implementation for e2e fused MoE


juuso-oskari and others added 3 commits October 8, 2025 13:19
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>