Tags: pytorch/ao
[ROCm] MXFP8 MoE: persistent grouped kernel, F.scaled_mm dense, correctness + tests
- Persistent grouped-MM kernel (grid = num_CUs * ctas_per_cu) that walks experts
in-kernel via a global tile counter (see the sketch after this list). This avoids
silent row-dropping under the (M+E-1)//E per-expert bound and keeps the
dispatcher torch.compile-clean.
- Dense MXFP8 path: dispatch to F.scaled_mm with BlockWise1x32.
- Wgrad: retune default tile to (BN=256, BK=256, BM=64, nw=8).
- K-tail and scale-tail masking; m_mask bounded by group_end and global M.
- torch.compile: register pad/unpad helpers via torch.library.custom_op
(sketch after this list); skip nonstrict_trace on ROCm.
- mx_linear / MXFP8TrainingOpConfig: drop is_ROCM() auto-switch; expose
mxfp8_dim1_cast_kernel_choice as explicit arg (CUDA default).
- bench_2d_3d_grouped_gemm.py: run on MI350+ via bench_mxfp8_grouped_mm_rocm;
fix the flops formula to 2 * M * N * K (sketch at the end of this entry).
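A minimal host-side sketch of the tile-walk idea above, assuming a group_offsets
layout and hypothetical helper names (the real implementation is a Triton kernel):
each persistent CTA claims tile IDs from a shared counter and maps them to
(expert, m_tile, n_tile) using the actual per-group row counts, so no expert's
tail rows are dropped.

    # Hypothetical host-side stand-in for the persistent kernel's scheduling
    # loop; `group_offsets` and the helper names are assumptions.
    def expert_tiles(group_offsets, N, BLOCK_M, BLOCK_N):
        """List (expert, m_tile, n_tile) for every tile that owns real rows."""
        tiles = []
        for e in range(len(group_offsets) - 1):
            rows = group_offsets[e + 1] - group_offsets[e]   # rows of expert e
            m_tiles = (rows + BLOCK_M - 1) // BLOCK_M        # per-expert, not (M+E-1)//E
            n_tiles = (N + BLOCK_N - 1) // BLOCK_N
            tiles += [(e, mt, nt) for mt in range(m_tiles) for nt in range(n_tiles)]
        return tiles

    def persistent_walk(num_ctas, group_offsets, N, BLOCK_M=128, BLOCK_N=128):
        """Each 'CTA' strides a global tile counter until all tiles are claimed."""
        tiles = expert_tiles(group_offsets, N, BLOCK_M, BLOCK_N)
        schedule = {cta: [] for cta in range(num_ctas)}
        for tile_id, tile in enumerate(tiles):               # global tile counter
            schedule[tile_id % num_ctas].append(tile)        # strided claim
        return schedule

    # 3 experts with uneven row counts (5, 130, 0); N = 256.
    print(persistent_walk(num_ctas=4, group_offsets=[0, 5, 135, 135], N=256))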
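A minimal sketch of the torch.library.custom_op pattern used for the pad/unpad
helpers; the op name, namespace, and padding rule below are assumptions, not the
actual torchao helpers. The register_fake (meta) implementation is what lets
torch.compile trace the op without executing it.

    import torch
    import torch.nn.functional as F

    # Hedged sketch: "torchao_sketch::pad_rows" is a hypothetical op, not the
    # real helper registered in torchao.
    @torch.library.custom_op("torchao_sketch::pad_rows", mutates_args=())
    def pad_rows(x: torch.Tensor, multiple: int) -> torch.Tensor:
        """Eager impl: pad the row dim of a 2D tensor up to a multiple."""
        pad = (-x.shape[0]) % multiple
        return F.pad(x, (0, 0, 0, pad))

    @pad_rows.register_fake
    def _(x: torch.Tensor, multiple: int) -> torch.Tensor:
        # Shape-only version so FakeTensor / torch.compile can trace through it.
        pad = (-x.shape[0]) % multiple
        return x.new_empty((x.shape[0] + pad, x.shape[1]))

    compiled = torch.compile(lambda t: pad_rows(t, 128))
    assert compiled(torch.randn(100, 64)).shape == (128, 64)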
Tested on MI355X / gfx950 / ROCm 7.1 / Triton 3.7:
Accuracy: test/prototype/moe_training/test_mxfp8_grouped_mm.py
-> 129 passed, 16 skipped.
SQNR margins: out >= 27.6 (>= 27), in_grad >= 25.2 (>= 25),
w_grad >= 25.5 (>= 24).
Perf: benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py
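For reference, the corrected flop accounting behind the 2 * M * N * K formula,
as a hedged sketch (the helper name and the per-group M layout are assumptions,
not the benchmark's API): each group's M_g x K by K x N matmul contributes
2 * M_g * N * K flops, summed over groups.

    # Hypothetical helper illustrating the 2*M*N*K accounting per group.
    def grouped_mm_tflops(m_per_group, N, K, seconds):
        flops = sum(2 * m * N * K for m in m_per_group)
        return flops / seconds / 1e12

    # e.g. 8 experts, 8192 rows each, N = K = 4096, measured at 2.5 ms:
    print(f"{grouped_mm_tflops([8192] * 8, 4096, 4096, 2.5e-3):.0f} TFLOP/s")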
Remove unnecessary Torch dependency from ExecuTorch ops build
The ExecuTorch ops target defines TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH=1, which means shared kernel headers use ExecuTorch includes, not Torch includes. The find_package(Torch) and TORCH_INCLUDE_DIRS were therefore unnecessary and caused CMake configure failures in standalone ExecuTorch builds where PyTorch is not discoverable via find_package.