Add _grouped_mm dispatch handler to MXTensor #4419
Force-pushed from 1c5bbe6 to a7a74c8.
|
@vkuzo This PR incorporates the suggestions from the MoE Float8 PR, and the CI tests are already green. Please help review it. |
            f"Only supporting bf16 out dtype for now, got {weight.dtype}"
        )
        # Swizzled scales are a CUDA SM100+ optimization; disable on XPU
        use_swizzled = not weight.is_xpu

this should be set deterministically in MXDynamicActivationMXWeightConfig, you can look at https://github.com/pytorch/ao/pull/4290/changes (not landed yet) as an example
|
@claude review |
|
Claude finished @vkuzo's task in 2m 57s. Code Review: Add _grouped_mm dispatch handler to MXTensor
Overall this is a solid PR that follows the established patterns from NVFP4's 1.
|
|
overall this looks good, once all the comments are implemented we should land this |
|
1, 2, 3, 4, 5 from claude review are real and should be handled! |
|
@claude Please explain the second point of the review. |
|
Claude finished @xiaowangintel's task in 1m 4s. Explanation of Review Point #2
Review point #2 concerns a potential alignment issue with how the activation tensor is quantized for …
The setup: In …
What …
The concern: Since …
Contrast with NVFP4: The NVFP4 handler ( …
Practical impact: For MX formats with …
TL;DR: The review point flags that unlike NVFP4 (which quantizes per-group), this code quantizes the full activation at once. For … |
|
@claude Based on an analysis of https://github.com/pytorch/pytorch/blob/9e7520818c7d72bcbb813fce421628ba3bf683ae/aten/src/ATen/native/cuda/GroupedBlas.cpp, what kind of checks should be added to address the second point raised in the review? |
|
Claude finished @xiaowangintel's task in 9m 45s. Analysis of
| Constraint | FP8 Rowwise | MXFP8 (BlockWise1x32) | Source |
|---|---|---|---|
| K (inner dim) | K % 16 == 0 | K % 32 == 0 | fp8_grouped_mm.py:86, mxfp8_extension.cpp:132 |
| N (output dim) | N % 16 == 0 | N % 32 == 0 | fp8_grouped_mm.py:88, mxfp8_extension.cpp:131 |
| Per-group token count (M_i) | M_i % 16 == 0 | M_i % 32 == 0 | fp8_grouped_mm.py:45, mxfp8_grouped_mm.py:110 |
| Offsets dtype | int32 | int32 | fp8_grouped_mm.py:99 |
| A memory layout | Row-major | Row-major | fp8_grouped_mm.py:109 |
| B memory layout | Column-major | Column-major (per-group) | fp8_grouped_mm.py:112 |
The critical constraint is that each group's token count must be a multiple of 32 for MXFP8/MX BlockWise1x32 `scaled_grouped_mm`. This is because:
- The SM100 CUTLASS kernel processes tiles of 128 rows × 4 scale columns (from `mxfp8_extension.cpp:215-216`).
- The `fused_pad_token_groups_cuda` kernel enforces `alignment_size == 32` (line 298 of `mxfp8_extension.cpp`).
- The blocked scale rearrangement (`mx_block_rearrange_2d_M_groups`) pads each group to 128-row boundaries.
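To make the two alignment numbers above concrete, here is a small pure-Python sketch of the round-up arithmetic. The helper name `round_up` and the example group sizes are made up for illustration; nothing here comes from the torchao or PyTorch sources.

```python
def round_up(n: int, multiple: int) -> int:
    """Return the smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple


# Hypothetical per-group token counts for three experts.
group_sizes = [64, 50, 7]

# Token-count alignment discussed above (multiples of 32).
aligned_32 = [round_up(m, 32) for m in group_sizes]

# Row padding applied by the blocked scale layout (128-row boundaries).
padded_128 = [round_up(m, 128) for m in group_sizes]

print(aligned_32, padded_128)
```

Only the first group (64 tokens) already satisfies the 32-token requirement; the other two would need padding or would trip an alignment assertion.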
Recommended checks to add to mx_grouped_mm
Here are the specific validations that should be added to the AUTO path in mx_grouped_mm (at torchao/prototype/mx_formats/mx_tensor.py:1098):

    @implements([aten._grouped_mm.default])
    def mx_grouped_mm(func, types, args, kwargs):
        """Handles torch._grouped_mm when weight (mat_b) is an MXTensor."""
        mat_a, mat_b = args[0], args[1]
        offs = args[2] if len(args) > 2 else kwargs.get("offs", None)
        assert isinstance(mat_b, MXTensor)
        assert offs is not None, "offs is required for MXTensor grouped_mm"
        act_quant_kwargs = mat_b.act_quant_kwargs
        if act_quant_kwargs is None:
            return torch._grouped_mm(mat_a, mat_b.dequantize(mat_b.orig_dtype), offs=offs)
        k = act_quant_kwargs
        if k.kernel_preference == KernelPreference.EMULATED:
            return torch._grouped_mm(mat_a, mat_b.dequantize(mat_b.orig_dtype), offs=offs)
        # === ADD THESE CHECKS for the AUTO (hardware) path ===
        block_size = k.block_size  # typically 32
        # 1. K dimension must be divisible by block_size (required for 1x32 block scaling)
        K = mat_a.shape[-1]
        assert K % block_size == 0, (
            f"K dimension ({K}) must be divisible by block_size ({block_size}) "
            f"for MX scaled_grouped_mm"
        )
        # 2. N dimension (weight output dim) must be divisible by block_size
        # mat_b is transposed (E, K, N) so the last dim is N
        N = mat_b.shape[-1]
        assert N % block_size == 0, (
            f"N dimension ({N}) must be divisible by block_size ({block_size}) "
            f"for MX scaled_grouped_mm"
        )
        # 3. Per-group token counts must be divisible by block_size
        # offs contains cumulative end indices; group sizes are the diffs
        group_sizes = torch.diff(offs, prepend=offs.new_zeros(1))
        misaligned = group_sizes % block_size
        assert misaligned.sum() == 0, (
            f"All per-group token counts must be divisible by block_size ({block_size}) "
            f"for MX scaled_grouped_mm. Got group sizes: {group_sizes.tolist()}"
        )
        # 4. Offsets must be int32
        assert offs.dtype == torch.int32, (
            f"offs must be int32, got {offs.dtype}"
        )
        # ... rest of the handler (quantize activation, call scaled_grouped_mm)

Why these specific checks?
- K % block_size == 0: The `MXTensor.to_mx()` call reshapes the last dimension into blocks of `block_size`. If K is not divisible, quantization will fail or produce wrong results. This is already checked inside `to_mx` (line 227), but checking it early gives a clearer error message in the grouped_mm context.
- N % block_size == 0: Required by `mxfp8_extension.cpp:131` ("rows must be a multiple of 32"). Since the weight is stored as `(E, N, K)` and transposed, N becomes the "rows" dimension of the weight scales.
- Per-group M_i % block_size == 0: This is the most important check and the core of review point #2. The SM100 hardware kernel requires each group's row count to be aligned. Without this check, you'll get silent incorrect results. The MXFP8 training handler addresses this via `pad_token_groups_for_grouped_mm=True` (which pads to multiples of 32), and FP8 pads to multiples of 16. You should either:
  - Add an assertion requiring pre-aligned group sizes (simplest, documented requirement), or
  - Add automatic padding via `pad_token_groups`/`unpad_token_groups` (more user-friendly, matches MXFP8 training handler pattern)
- offs.dtype == int32: Required by `GroupedBlas.cpp`; the CUDA kernel expects int32 offsets.
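As a rough illustration of the per-group check, the sketch below recovers group sizes from cumulative end offsets and reports which groups are misaligned. It uses plain Python lists instead of tensors, and the function names are hypothetical, so treat it as a model of the logic rather than code from the PR.

```python
def group_sizes_from_offsets(offs):
    """Convert cumulative end indices (e.g. [32, 96, 100]) into per-group sizes."""
    sizes, prev = [], 0
    for end in offs:
        sizes.append(end - prev)
        prev = end
    return sizes


def misaligned_groups(offs, block_size=32):
    """Return (group_index, size) pairs whose size is not a multiple of block_size."""
    sizes = group_sizes_from_offsets(offs)
    return [(i, m) for i, m in enumerate(sizes) if m % block_size != 0]


# Hypothetical offsets: groups of 32, 64, and 4 tokens.
bad = misaligned_groups([32, 96, 100])
print(bad)
```

Reporting the offending group indices, rather than only the full size list, may make the assertion message easier to act on when there are many experts.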
Which approach: assert or pad?
Looking at the two patterns in the repo:
- NVFP4 (`nvfp4_tensor.py:707`): Does NOT check alignment; relies on `mslk_quantize_nvfp4_stacked` handling padding internally (each segment can add up to 127 padding rows).
- MXFP8 training (`mxfp8_grouped_mm.py:205`): Provides a `pad_token_groups_for_grouped_mm` flag; if True, pads automatically; if False, assumes pre-aligned input.
For the inference handler, I'd recommend the assertion approach (option 1) since:
- It's simpler and makes the requirement explicit
- Users can pad their inputs before calling the model
- The training handler's padding utilities depend on SM100 CUDA kernels that may not be available in all inference environments
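For users who go the "pad before calling the model" route, the following pure-Python sketch shows one way the padding could work on a list of token rows. `pad_token_groups_sketch` is a hypothetical name and is not the repo's `pad_token_groups` utility; a real implementation would operate on tensors and would also need a matching unpad step on the output.

```python
def pad_token_groups_sketch(rows, offs, alignment=32, fill=0.0):
    """Pad each group of rows with `fill` rows up to a multiple of `alignment`.

    rows: list of equal-length lists (one per token).
    offs: cumulative end index of each group in `rows`.
    Returns (padded_rows, padded_offs).
    """
    width = len(rows[0]) if rows else 0
    padded_rows, padded_offs, prev = [], [], 0
    for end in offs:
        group = rows[prev:end]
        pad = (-len(group)) % alignment  # rows needed to reach the next multiple
        padded_rows.extend(group)
        padded_rows.extend([[fill] * width for _ in range(pad)])
        padded_offs.append(len(padded_rows))
        prev = end
    return padded_rows, padded_offs


# Small demo with alignment=4 so the result is easy to inspect.
demo_rows = [[float(i), float(i)] for i in range(5)]
padded_rows, padded_offs = pad_token_groups_sketch(demo_rows, [2, 5], alignment=4)
print(padded_offs)
```

The padded offsets stay cumulative, so they can be passed wherever the original `offs` would go, provided the caller later drops the padded rows from the result.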
Summary
This PR adds `aten._grouped_mm.default` and `aten.transpose.int` dispatch handlers to `MXTensor`, enabling MX-format (FP8/FP4) inference for MoE (Mixture of Experts) models that use `torch._grouped_mm` for expert computation. This follows the same pattern established by NVFP4 in #4316 and Float8 in #4390.
Motivation
MoE architectures (e.g., DeepSeek, Qwen3-MoE) store expert weights as 3D tensors `(E, N, K)` and use `torch._grouped_mm` in the forward pass. With this PR, users can quantize expert weights to MX format (FP8 or FP4) via `quantize_()` using `MXDynamicActivationMXWeightConfig` and get automatic dispatch through the `_grouped_mm` handler, with no model code changes needed.
Related: RFC #4355, NVFP4 reference #4316, Float8 reference #4390
Design
The handler supports three modes based on `act_quant_kwargs` and `kernel_preference`:
- Weight-only (`act_quant_kwargs is None`): dequantize the weight and call `_grouped_mm` with bf16.
- EMULATED (`kernel_preference == EMULATED`): dequantize the weight and call `_grouped_mm` with bf16 (no activation quantization).
- AUTO (preferred path): quantize the activation with `MXTensor.to_mx()`, treat uint8 `qdata` as `float4_e2m1fn_x2`, and call `F.scaled_grouped_mm` with BlockWise1x32 scaling and optional swizzle.
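The three-way dispatch can be modeled with a few lines of plain Python. The enum and dataclass below are stand-ins invented for this sketch (they are not the torchao `KernelPreference` or activation-quantization kwargs types); only the branching order mirrors the handler code quoted in the review.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class KernelPreferenceStub(Enum):
    """Stand-in for the real kernel preference enum (illustration only)."""
    AUTO = auto()
    EMULATED = auto()


@dataclass
class ActQuantKwargsStub:
    """Stand-in for the activation quantization settings stored on the weight."""
    kernel_preference: KernelPreferenceStub = KernelPreferenceStub.AUTO
    block_size: int = 32


def select_mode(act_quant_kwargs: Optional[ActQuantKwargsStub]) -> str:
    """Return which of the three paths the handler would take."""
    if act_quant_kwargs is None:
        return "weight_only"  # dequantize weight, plain grouped mm
    if act_quant_kwargs.kernel_preference is KernelPreferenceStub.EMULATED:
        return "emulated"     # dequantize weight, no activation quantization
    return "auto"             # quantize activation, scaled grouped mm


print(select_mode(None), select_mode(ActQuantKwargsStub()))
```

Checking for `None` first matters: the weight-only path has no settings object to read a kernel preference from.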