Skip to content

Add _grouped_mm dispatch handler to MXTensor#4419

Open
xiaowangintel wants to merge 12 commits into
pytorch:mainfrom
xiaowangintel:xw/mx-grouped-mm
Open

Add _grouped_mm dispatch handler to MXTensor#4419
xiaowangintel wants to merge 12 commits into
pytorch:mainfrom
xiaowangintel:xw/mx-grouped-mm

Conversation

@xiaowangintel

Copy link
Copy Markdown
Collaborator

Summary

This PR adds aten._grouped_mm.default and aten.transpose.int dispatch handlers to MXTensor, enabling MX-format (FP8/FP4) inference for MoE (Mixture of Experts) models that use torch._grouped_mm for expert computation. This follows the same pattern established by NVFP4 in #4316 and Float8 in #4390.

Motivation

MoE architectures (e.g., DeepSeek, Qwen3-MoE) store expert weights as 3D tensors (E, N, K) and use torch._grouped_mm in the forward pass. With this PR, users can quantize expert weights to MX format (FP8 or FP4) via quantize_() using MXDynamicActivationMXWeightConfig and get automatic dispatch through the _grouped_mm handler — no model code changes needed.

Related: RFC #4355, NVFP4 reference #4316, Float8 reference #4390

Design

The handler supports three modes based on act_quant_kwargs and kernel_preference:

Weight-only (act_quant_kwargs is None):

  • Dequantize weight → _grouped_mm with bf16

EMULATED (kernel_preference == EMULATED):

  • Dequantize weight → _grouped_mm with bf16 (no activation quantization)

AUTO (preferred path):

  1. Quantize activation to MX format via MXTensor.to_mx()
  2. For FP4: reinterpret packed uint8 qdata as float4_e2m1fn_x2
  3. Call F.scaled_grouped_mm with BlockWise1x32 scaling and optional swizzle
act_quant_kwargs is None?
├── YES → dequant weight → _grouped_mm (weight-only)
└── NO
    ├── EMULATED → dequant weight → _grouped_mm
    └── AUTO → quantize activation via MXTensor.to_mx
              → scaled_grouped_mm(a_qdata, b_qdata, scales, offs)

@pytorch-bot

pytorch-bot Bot commented May 18, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4419

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 18, 2026
@xiaowangintel xiaowangintel requested a review from liangan1 May 18, 2026 06:23
@xiaowangintel xiaowangintel added mx module: inference quantize_ api inference flow moe topic: new feature Use this tag if this PR adds a new feature labels May 18, 2026
@xiaowangintel

Copy link
Copy Markdown
Collaborator Author

@liangan1 @vkuzo Please help to review, thanks.

@xiaowangintel

Copy link
Copy Markdown
Collaborator Author

@vkuzo This PR incorporates the suggestions from the MoE Float8 PR, and the CI tests are already green. Please help review it.

f"Only supporting bf16 out dtype for now, got {weight.dtype}"
)
# Swizzled scales are a CUDA SM100+ optimization; disable on XPU
use_swizzled = not weight.is_xpu

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be set deterministically in MXDynamicActivationMXWeightConfig, you can look at https://github.com/pytorch/ao/pull/4290/changes (not landed yet) as an example

@vkuzo

vkuzo commented May 28, 2026

Copy link
Copy Markdown
Contributor

@claude review

@claude

claude Bot commented May 28, 2026

Copy link
Copy Markdown

Claude finished @vkuzo's task in 2m 57s —— View job


Code Review: Add _grouped_mm dispatch handler to MXTensor

  • Read PR diff and changed files
  • Review implementation in mx_tensor.py
  • Review implementation in inference_workflow.py
  • Review test changes
  • Post review feedback

Overall this is a solid PR that follows the established patterns from NVFP4's _grouped_mm handler. The code is well-structured with clean separation between the three modes (weight-only, emulated, auto). A few issues to address:


1. is_swizzled_scales should be deterministic in the config (as vkuzo noted)

File: torchao/prototype/mx_formats/inference_workflow.py:139

use_swizzled = not weight.is_xpu

This runtime decision is fragile — it couples the quantization format to the device at quantization time, and makes the behavior non-reproducible from the config alone. As vkuzo commented, this should be set deterministically in MXDynamicActivationMXWeightConfig. Consider adding an is_swizzled_scales field (or computing it in __post_init__ based on a device hint), so that the config fully describes the quantization format. See PR #4290 for the suggested pattern.

Fix this →


2. Potential issue: mat_a_mx is 2D but scaled_grouped_mm may expect alignment constraints

File: torchao/prototype/mx_formats/mx_tensor.py:1116-1123

mat_a_mx = MXTensor.to_mx(
    mat_a,
    k.elem_dtype,
    k.block_size,
    k.scaling_mode,
    k.kernel_preference,
    is_swizzled_scales=k.is_swizzled_scales,
)

The activation mat_a is a 2D tensor (total_M, K) where total_M is the sum of all group sizes. MXTensor.to_mx quantizes the full 2D tensor at once. However, for scaled_grouped_mm, the hardware typically requires that each group's token count is aligned (e.g., multiple of 32 for MXFP8). There's no validation or padding here — if individual groups have sizes not aligned to the block requirements, this could produce incorrect results silently at runtime.

The NVFP4 handler (for comparison) uses mslk_quantize_nvfp4_stacked which handles per-group quantization. Consider adding a note or assertion about alignment requirements for the per-group token sizes.


3. b_scale.contiguous() may silently change swizzled scale layout

File: torchao/prototype/mx_formats/mx_tensor.py:1141

b_scale = mat_b.scale.contiguous()

If mat_b.is_swizzled_scales is True, the scale tensor has a specific swizzled memory layout. Calling .contiguous() on a transposed swizzled scale may re-layout it in a way that breaks the swizzle pattern expected by scaled_grouped_mm. The NVFP4 handler explicitly does mat_b.scale.transpose(-2, -1).flatten(1) to manage this. You should verify that .contiguous() preserves the correct byte ordering for the swizzled case.


4. use_swizzled logic uses or — should both sides always agree?

File: torchao/prototype/mx_formats/mx_tensor.py:1138

use_swizzled = mat_a_mx.is_swizzled_scales or mat_b.is_swizzled_scales

Using or means if only one side is swizzled, we still pass SwizzleType.SWIZZLE_32_4_4 for both swizzle_a and swizzle_b. This seems incorrect — if mat_a_mx is swizzled but mat_b is not (or vice versa), passing the same swizzle type for both would produce wrong results. Consider:

  1. Asserting both sides agree: assert mat_a_mx.is_swizzled_scales == mat_b.is_swizzled_scales
  2. Or passing swizzle_a and swizzle_b independently based on each tensor's flag.

In practice they should always match since both come from the same config, but an assertion would make this invariant explicit and catch bugs early.


5. mx_transpose handler: no return_and_correct_aliasing call

File: torchao/prototype/mx_formats/mx_tensor.py:900-916

The mx_transpose handler creates a new MXTensor directly without calling return_and_correct_aliasing, unlike other handlers (mx_slice, mx_select). While aten.transpose.int returns a view (not a copy), the aten.t.default handler (line 884) also doesn't use return_and_correct_aliasing — so this is consistent. Just flagging in case the framework requires it for correctness with FunctionalTensor dispatch.


6. Test: test_grouped_mm_mx_dynamic_activation uses torch.accelerator.current_accelerator()

File: test/prototype/mx_formats/test_inference_workflow.py:498

device = torch.accelerator.current_accelerator()

This differs from the other tests in this file (including test_grouped_mm_nvfp4) which use device = "cuda" directly. While torch.accelerator.current_accelerator() is more portable, it returns a torch.device object rather than a string. This inconsistency is minor but worth noting for maintainability.


7. Minor: transpose handler asserts 3D only — might want to support 2D too

File: torchao/prototype/mx_formats/mx_tensor.py:903

assert len(old.shape) == 3, f"unsupported rank {len(old.shape)}"

The existing mx_t handler covers the 2D case via aten.t.default. However, if someone calls tensor.transpose(0, 1) on a 2D MXTensor, it would hit aten.transpose.int and fail with this assert. This is a design choice — just noting it as a potential user-facing rough edge.


Summary

The core logic is correct and follows established patterns well. The main actionable items are:

  1. Must fix: Move is_swizzled_scales to be config-driven (vkuzo's existing comment)
  2. Should fix: Assert mat_a_mx.is_swizzled_scales == mat_b.is_swizzled_scales instead of using or
  3. Should verify: That .contiguous() on a transposed swizzled scale doesn't break the swizzle pattern

@vkuzo

vkuzo commented May 28, 2026

Copy link
Copy Markdown
Contributor

overall this looks good, once all the comments are implemented we should land this

@vkuzo

vkuzo commented May 28, 2026

Copy link
Copy Markdown
Contributor

1, 2, 3, 4, 5 from claude review are real and should be handled!

@xiaowangintel

Copy link
Copy Markdown
Collaborator Author

@claude Please explain the second point of the review.

@claude

claude Bot commented May 29, 2026

Copy link
Copy Markdown

Claude finished @xiaowangintel's task in 1m 4s —— View job


Explanation of Review Point #2

  • Read relevant code for context
  • Post explanation

Review point #2 concerns a potential alignment issue with how the activation tensor is quantized for scaled_grouped_mm.

The setup: In mx_tensor.py:1116-1123, when the AUTO path is taken, the full activation mat_a — a 2D tensor of shape (total_M, K) where total_M is the sum of all group token counts — is quantized in one shot via MXTensor.to_mx(mat_a, ...).

What to_mx does internally: It reshapes the last dimension into blocks of size block_size (typically 32), computes a per-block scale, and quantizes. Crucially, it operates on the entire 2D tensor at once, treating it as a flat (total_M, K) matrix. The only alignment check is that K % block_size == 0 (line 226-228 of mx_tensor.py).

The concern: scaled_grouped_mm performs separate GEMMs for each group, where each group's tokens are defined by the offs tensor. For example, if offs = [5, 12, 20], group 0 has 5 tokens, group 1 has 7, group 2 has 8. Hardware implementations of scaled_grouped_mm may require each group's token count (the M-dimension per group) to be aligned to some boundary (e.g., multiples of 16 or 32) for the block-scaled quantization to be applied correctly per-group.

Since MXTensor.to_mx quantizes the entire (total_M, K) tensor without awareness of group boundaries, the block scales span across group boundaries in the M-dimension. This is fine for the K-dimension (scales are computed per row along K), but if the hardware expects per-group alignment of the M-dimension, this could produce incorrect results silently (no error, just wrong numbers).

Contrast with NVFP4: The NVFP4 handler (nvfp4_tensor.py:728) uses mslk_quantize_nvfp4_stacked(m_sizes, mat_a, ...) which explicitly receives the per-group sizes (m_sizes) and quantizes each group's tokens separately with proper alignment/padding. This is a group-aware quantization.

Practical impact: For MX formats with BlockWise1x32 scaling (scales computed along K in groups of 32), the scales are row-wise — each row gets its own set of scales along the K dimension. Since rows don't cross group boundaries (each group is a contiguous set of rows), the quantization is actually correct in this specific case. The M-dimension alignment would only matter if the hardware kernel itself requires padded group sizes, which depends on the GPU and driver version.

TL;DR: The review point flags that unlike NVFP4 (which quantizes per-group), this code quantizes the full activation at once. For BlockWise1x32 scaling this is likely correct because scales are per-row-along-K, but it's worth adding a comment or assertion documenting the alignment assumptions, in case future hardware or scaling modes have stricter per-group requirements.


@xiaowangintel

Copy link
Copy Markdown
Collaborator Author

@claude Based on an analysis of https://github.com/pytorch/pytorch/blob/9e7520818c7d72bcbb813fce421628ba3bf683ae/aten/src/ATen/native/cuda/GroupedBlas.cpp, what kind of checks should be added to address the second point raised in the review?

@claude

claude Bot commented May 29, 2026

Copy link
Copy Markdown

Claude finished @xiaowangintel's task in 9m 45s —— View job


Analysis of GroupedBlas.cpp alignment constraints for review point #2

  • Fetch and analyze PyTorch's GroupedBlas.cpp source
  • Identify alignment/validation constraints in scaled_grouped_mm
  • Cross-reference with existing handlers in this repo (FP8, MXFP8, NVFP4)
  • Recommend specific checks to add to the MXTensor handler

While I couldn't directly download the file from GitHub due to network restrictions in this CI environment, I was able to reconstruct the full set of alignment constraints by analyzing the existing code in this repo that calls into torch._scaled_grouped_mm and F.scaled_grouped_mm — specifically:

  • torchao/prototype/moe_training/fp8_grouped_mm.py (FP8 rowwise path, alignment=16)
  • torchao/prototype/moe_training/mxfp8_grouped_mm.py (MXFP8 path, alignment=32)
  • torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp (CUDA kernel checks)
  • torchao/prototype/mx_formats/nvfp4_tensor.py (NVFP4 grouped_mm handler)

Constraints from GroupedBlas.cpp and the SM100 CUTLASS kernels

Based on the validation patterns used across this codebase (which mirror what GroupedBlas.cpp enforces at the CUDA level), the following constraints apply to scaled_grouped_mm with MX block-wise (1x32) scaling:

Constraint FP8 Rowwise MXFP8 (BlockWise1x32) Source
K (inner dim) K % 16 == 0 K % 32 == 0 fp8_grouped_mm.py:86, mxfp8_extension.cpp:132
N (output dim) N % 16 == 0 N % 32 == 0 fp8_grouped_mm.py:88, mxfp8_extension.cpp:131
Per-group token count (M_i) M_i % 16 == 0 M_i % 32 == 0 fp8_grouped_mm.py:45, mxfp8_grouped_mm.py:110
Offsets dtype int32 int32 fp8_grouped_mm.py:99
A memory layout Row-major Row-major fp8_grouped_mm.py:109
B memory layout Column-major Column-major (per-group) fp8_grouped_mm.py:112

The critical constraint is that each group's token count must be a multiple of 32 for MXFP8/MX BlockWise1x32 scaled_grouped_mm. This is because:

  1. The SM100 CUTLASS kernel processes tiles of 128 rows × 4 scale columns (from mxfp8_extension.cpp:215-216).
  2. The fused_pad_token_groups_cuda kernel enforces alignment_size == 32 (line 298 of mxfp8_extension.cpp).
  3. The blocked scale rearrangement (mx_block_rearrange_2d_M_groups) pads each group to 128-row boundaries.

Recommended checks to add to mx_grouped_mm

Here are the specific validations that should be added to the AUTO path in mx_grouped_mm (at torchao/prototype/mx_formats/mx_tensor.py:1098):

@implements([aten._grouped_mm.default])
def mx_grouped_mm(func, types, args, kwargs):
    """Handles torch._grouped_mm when weight (mat_b) is an MXTensor."""
    mat_a, mat_b = args[0], args[1]
    offs = args[2] if len(args) > 2 else kwargs.get("offs", None)
    assert isinstance(mat_b, MXTensor)
    assert offs is not None, "offs is required for MXTensor grouped_mm"

    act_quant_kwargs = mat_b.act_quant_kwargs

    if act_quant_kwargs is None:
        return torch._grouped_mm(mat_a, mat_b.dequantize(mat_b.orig_dtype), offs=offs)

    k = act_quant_kwargs
    if k.kernel_preference == KernelPreference.EMULATED:
        return torch._grouped_mm(mat_a, mat_b.dequantize(mat_b.orig_dtype), offs=offs)

    # === ADD THESE CHECKS for the AUTO (hardware) path ===
    block_size = k.block_size  # typically 32

    # 1. K dimension must be divisible by block_size (required for 1x32 block scaling)
    K = mat_a.shape[-1]
    assert K % block_size == 0, (
        f"K dimension ({K}) must be divisible by block_size ({block_size}) "
        f"for MX scaled_grouped_mm"
    )

    # 2. N dimension (weight output dim) must be divisible by block_size
    #    mat_b is transposed (E, K, N) so the last dim is N
    N = mat_b.shape[-1]
    assert N % block_size == 0, (
        f"N dimension ({N}) must be divisible by block_size ({block_size}) "
        f"for MX scaled_grouped_mm"
    )

    # 3. Per-group token counts must be divisible by block_size
    #    offs contains cumulative end indices; group sizes are the diffs
    group_sizes = torch.diff(offs, prepend=offs.new_zeros(1))
    misaligned = group_sizes % block_size
    assert misaligned.sum() == 0, (
        f"All per-group token counts must be divisible by block_size ({block_size}) "
        f"for MX scaled_grouped_mm. Got group sizes: {group_sizes.tolist()}"
    )

    # 4. Offsets must be int32
    assert offs.dtype == torch.int32, (
        f"offs must be int32, got {offs.dtype}"
    )

    # ... rest of the handler (quantize activation, call scaled_grouped_mm)

Why these specific checks?

  1. K % block_size == 0: The MXTensor.to_mx() call reshapes the last dimension into blocks of block_size. If K is not divisible, quantization will fail or produce wrong results. This is already checked inside to_mx (line 227), but checking it early gives a clearer error message in the grouped_mm context.

  2. N % block_size == 0: Required by mxfp8_extension.cpp:131 — "rows must be a multiple of 32". Since the weight is stored as (E, N, K) and transposed, N becomes the "rows" dimension of the weight scales.

  3. Per-group M_i % block_size == 0: This is the most important check and the core of review point Headers and gitignore #2. The SM100 hardware kernel requires each group's row count to be aligned. Without this check, you'll get silent incorrect results. The MXFP8 training handler addresses this via pad_token_groups_for_grouped_mm=True (which pads to multiples of 32), and FP8 pads to multiples of 16. You should either:

    • Add an assertion requiring pre-aligned group sizes (simplest, documented requirement), or
    • Add automatic padding via pad_token_groups / unpad_token_groups (more user-friendly, matches MXFP8 training handler pattern)
  4. offs.dtype == int32: Required by GroupedBlas.cpp — the CUDA kernel expects int32 offsets.

Which approach: assert or pad?

Looking at the two patterns in the repo:

  • NVFP4 (nvfp4_tensor.py:707): Does NOT check alignment — relies on mslk_quantize_nvfp4_stacked handling padding internally (each segment can add up to 127 padding rows).
  • MXFP8 training (mxfp8_grouped_mm.py:205): Provides pad_token_groups_for_grouped_mm flag — if True, pads automatically; if False, assumes pre-aligned input.

For the inference handler, I'd recommend the assertion approach (option 1) since:

  • It's simpler and makes the requirement explicit
  • Users can pad their inputs before calling the model
  • The training handler's padding utilities depend on SM100 CUDA kernels that may not be available in all inference environments

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: inference quantize_ api inference flow moe mx topic: new feature Use this tag if this PR adds a new feature

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants