[Triton] e2e fused MoE for small N and fp8 blockscale MoE benching #1126
base: main
Conversation
… dimension like in Qwen 3
Pull Request Overview
This PR implements end-to-end (e2e) fused MoE (Mixture of Experts) kernels optimized for small intermediate dimensions (N ≤ 1024) with support for both FP16 and blockscaled FP8 quantization. The key innovation is fitting the entire intermediate token representation in shared memory to fuse two matrix multiplications efficiently.
- End-to-end fused MoE kernel implementation that combines two GEMMs and gated activation in a single kernel
- Blockscaled FP8 quantization support with optimized scaling factor loading and broadcasting (a sketch of the blockscale scheme follows this list)
- Enhanced benchmarking capabilities with new model configurations and e2e benchmarking support
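As a rough illustration of what "blockscaled FP8" means here, the minimal PyTorch sketch below quantizes a weight with one higher-precision scale per block and broadcasts the scales back over their blocks on dequantization. This is not the PR's code; the 128x128 block size and the `torch.float8_e4m3fn` dtype are assumptions for illustration (the actual FP8 format depends on the hardware/backend).

```python
import torch

BLOCK = 128                      # assumed block size
K, N = 512, 1024                 # example weight shape

w = torch.randn(K, N)

# Quantize: one float32 scale per (BLOCK, BLOCK) tile maps each tile into FP8 range.
w_tiles = w.reshape(K // BLOCK, BLOCK, N // BLOCK, BLOCK)
scales = w_tiles.abs().amax(dim=(1, 3)) / torch.finfo(torch.float8_e4m3fn).max
w_fp8 = (w_tiles / scales[:, None, :, None]).to(torch.float8_e4m3fn)

# Dequantize: broadcast each tile's scale back over its BLOCK x BLOCK elements.
w_deq = (w_fp8.to(torch.float32) * scales[:, None, :, None]).reshape(K, N)

print((w - w_deq).abs().max())   # small quantization error
```

In the actual kernel the scales are presumably not applied up front like this; they would be loaded per block and broadcast onto the corresponding tiles inside the GEMM loop, which is what the "optimized scaling factor loading and broadcasting" above refers to.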
Reviewed Changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 4 comments.
Summary per file
| File | Description |
| --- | --- |
| op_tests/triton_tests/test_moe.py | Implements test infrastructure for e2e MoE with blockscale quantization support |
| op_tests/op_benchmarks/triton/utils/model_configs.json | Adds MoE-specific configurations for qwen3/qwen3next models |
| op_tests/op_benchmarks/triton/bench_moe.py | Extends benchmarking to support e2e fused MoE with improved memory bandwidth calculations |
| aiter/ops/triton/utils/moe_config_utils.py | Adds configuration management for e2e MoE kernels |
| aiter/ops/triton/moe_op_e2e.py | Core e2e MoE kernel interface with blockscale FP8 support |
| aiter/ops/triton/configs/moe/*.json | Configuration files for optimized e2e MoE kernel parameters |
| aiter/ops/triton/_triton_kernels/moe_op_e2e.py | Low-level Triton kernel implementation for e2e fused MoE |
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Authors: @juuso-oskari @Chi-Chu319
e2e fused MoE for small N
This PR adds an end-to-end implementation of MoE optimized for a short intermediate representation (N ≤ 1024). The idea is that because the intermediate dimension is so small, we can fit the whole intermediate token representation in shared memory and fuse the two GEMMs of MoE efficiently as two Triton tl.dot operations.
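To make the fusion concrete, here is a minimal sketch of the idea. It is not the kernel in aiter/ops/triton/_triton_kernels/moe_op_e2e.py: expert routing, top-k weighting, and FP8 blockscaling are omitted, a single expert's gate/up/down weights are assumed, and all names and tile sizes are illustrative. The point it shows is that the (BLOCK_M, N) intermediate produced by the first tl.dot stays on-chip and feeds the second tl.dot directly instead of round-tripping through global memory.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def fused_mlp_kernel(
    x_ptr, wg_ptr, wu_ptr, wd_ptr, out_ptr,
    M, K,
    N: tl.constexpr,           # small intermediate dim (power of two here, for tl.arange)
    BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(0)
    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, N)
    mask_m = offs_m < M

    # First GEMM pair: accumulate x @ W_gate and x @ W_up over K in tiles.
    acc_g = tl.zeros((BLOCK_M, N), dtype=tl.float32)
    acc_u = tl.zeros((BLOCK_M, N), dtype=tl.float32)
    for k0 in range(0, K, BLOCK_K):
        offs_k = k0 + tl.arange(0, BLOCK_K)
        x = tl.load(x_ptr + offs_m[:, None] * K + offs_k[None, :],
                    mask=mask_m[:, None] & (offs_k[None, :] < K), other=0.0)
        wg = tl.load(wg_ptr + offs_k[:, None] * N + offs_n[None, :],
                     mask=offs_k[:, None] < K, other=0.0)
        wu = tl.load(wu_ptr + offs_k[:, None] * N + offs_n[None, :],
                     mask=offs_k[:, None] < K, other=0.0)
        acc_g += tl.dot(x, wg)
        acc_u += tl.dot(x, wu)

    # Gated SiLU activation on the (BLOCK_M, N) intermediate tile, kept on-chip.
    h = (acc_g / (1.0 + tl.exp(-acc_g)) * acc_u).to(tl.float16)

    # Second GEMM: consume the on-chip intermediate directly, one K tile of output at a time.
    for k0 in range(0, K, BLOCK_K):
        offs_k = k0 + tl.arange(0, BLOCK_K)
        wd = tl.load(wd_ptr + offs_n[:, None] * K + offs_k[None, :],
                     mask=offs_k[None, :] < K, other=0.0)
        y = tl.dot(h, wd)
        tl.store(out_ptr + offs_m[:, None] * K + offs_k[None, :],
                 y.to(out_ptr.dtype.element_ty),
                 mask=mask_m[:, None] & (offs_k[None, :] < K))


# Illustrative launch with assumed shapes (requires a GPU).
M, K, N = 32, 2048, 512
x = torch.randn(M, K, device="cuda", dtype=torch.float16)
wg = torch.randn(K, N, device="cuda", dtype=torch.float16)
wu = torch.randn(K, N, device="cuda", dtype=torch.float16)
wd = torch.randn(N, K, device="cuda", dtype=torch.float16)
out = torch.empty(M, K, device="cuda", dtype=torch.float16)
fused_mlp_kernel[(triton.cdiv(M, 16),)](x, wg, wu, wd, out, M, K, N=N, BLOCK_M=16, BLOCK_K=64)
```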
We provide both the default fp16 version and a blockscaled fp8 version. Some implementation details for the blockscaled fp8 version:
We end up parallelizing only along the length of sorted_token_ids, which is max_num_tokens_padded = topk * M + E * (BLOCK_SIZE_M - 1), but in the newer Qwen3 models this is sufficiently large due to the large topk and E.
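As a back-of-the-envelope check of the grid size, the snippet below evaluates that formula with illustrative Qwen3-like values; the top-k, expert count, and block size are assumptions for illustration, not values taken from the PR's configs.

```python
# Illustrative grid-size estimate; all values below are assumptions.
M = 32              # tokens in the batch (the small-M case benchmarked below)
topk = 8            # experts activated per token (Qwen3-like, assumed)
E = 128             # total number of routed experts (Qwen3-like, assumed)
BLOCK_SIZE_M = 16   # assumed token-block size

max_num_tokens_padded = topk * M + E * (BLOCK_SIZE_M - 1)   # 256 + 1920 = 2176

# Assuming each program instance handles one block of BLOCK_SIZE_M sorted token ids,
# even this single parallel dimension already yields 136 program instances.
num_programs = (max_num_tokens_padded + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
print(max_num_tokens_padded, num_programs)  # 2176 136
```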
fp8 blockscale MoE benching
This PR also modifies the MoE benchmarking script to enable benchmarking of blockscaled fp8 MoE.
Bench results
Initial performance results for small M (M=32).
FP8 blockscale
baseline (i.e., the two kernels launched for the two GEMMs):
e2e fused MoE:
So the fused kernel runs in roughly 0.1156 / (0.090924 + 0.068197) ≈ 0.73× the baseline runtime, i.e., about a 27% reduction.
BF16
baseline (i.e., the two kernels launched for the two GEMMs):
e2e fused MoE:
So the fused kernel runs in roughly 0.17934 / (0.246026 + 0.163532) ≈ 0.44× the baseline runtime, i.e., about a 56% reduction.
Note that we expect the baseline to underperform here because it has not been tuned for Qwen3 shapes. Quick-and-dirty tuning of the BF16 baseline brings its performance to:
But the e2e fused MoE is still faster.