Flydsl mxfp8 quantize #4357
Conversation
Move FLOOR scale derivation, FP8 clamp constants, and chunk-quantize-pack into flydsl_utils.py as module-level helpers (cutedsl pattern). Each kernel now imports them at module level and calls them inline. Net removal: ~256 lines (-22%) across the three kernel files.
…olidate tests

- FlyDSL dispatcher functions now accept the same kwargs as their cutedsl counterparts (stage_count, blocked_scale_output, offs). Unsupported values raise NotImplementedError; stage_count is accepted but ignored (no TMA on CDNA3). This allows callers to swap backends without changing call signatures.
- Move FlyDSL tests from 4 separate test_flydsl_mxfp8_*.py files into test_kernels.py with @pytest.mark.skipif gates, matching the cutedsl test layout convention. Remove the _flydsl_test_utils.py helper module.
Add MXFP8Dim1CastKernelChoice.FLYDSL and dispatch branch in mx_formats utils. Mirrors the CUTEDSL branch for AMD: callers can now pick FlyDSL as the dim1 cast backend just like cutedsl on NVIDIA. The FlyDSL backend currently supports FLOOR scale only; the dispatch asserts on RCEIL with a clear message. All existing MXFP8TrainingRecipe entries hardcode RCEIL, so end-to-end training-recipe coverage waits on either a FLOOR recipe or a FlyDSL RCEIL implementation. Direct MXFP8TrainingOpConfig with scale_calculation_mode=FLOOR exercises the backend today.
Each workgroup now runs up to 4 waves, each handling its own K-tile, mirroring triton_to_mxfp8_dim1's num_warps=4 strategy. The SIMD scheduler can overlap memory latency across waves within a CU. waves_per_block is picked at launch time based on K (largest power of two such that K % (waves * K_TILE) == 0) so small-K shapes still work via the original 1-wave path.
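The launch-time wave selection described above can be sketched as a small helper. This is an illustration, not the PR's actual code: the function name and the K_TILE=256 default are assumptions.

```python
def pick_waves_per_block(K: int, k_tile: int = 256, max_waves: int = 4) -> int:
    """Pick the largest power-of-two wave count (<= max_waves) such that
    K splits evenly into chunks of waves * k_tile elements; small-K
    shapes fall back to the original 1-wave path."""
    waves = max_waves
    while waves > 1 and K % (waves * k_tile) != 0:
        waves //= 2
    return waves

# e.g. K=8192 -> 4 waves, K=1536 -> 2 waves, K=256 -> 1 wave
```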
…er launch

32x1 Phase-1 loads now issue dwordx2 (vec_width VEC = 4 bf16 per lane); K_TILE grows to AMD_WAVE_SIZE * VEC = 256 and per-WG LDS is capped at the 64 KB budget. Test K list updated. All three FlyDSL wrappers pass `x` as a raw torch tensor so the JIT runtime takes its bare-pointer fast path (~46 us/launch saved vs the from_dlpack adapter). Outputs use as_strided over flat fp8 storage.
…lookup

32x1 now stacks up to 4 MX blocks (M-direction) per workgroup, with 4 waves cooperating on a (128, 256) bf16 tile sharing one 64 KB WG-level LDS region; each wave owns one of the 4 stacked 32-row blocks. Phase-1 keeps the multi-wave HBM hide; Phase-4 has 4 waves write 4 disjoint 32-byte fragments of the same 128 B HBM cache line concurrently from 4 SIMDs of one CU, letting the L2 controller coalesce them into a single line fill (no eviction, no RMW).

Closes the 16384^2 0.75x regression: PMC WriteSize drops from 396 MB (+43% over ideal) to 338 MB (+22%, matching Triton exactly). MI355X bf16 end-to-end: 16384^2 0.75x -> 1.19x; 8192^2 1.33x -> 1.50x. Layout is adaptive via _pick_layout(M, K, ...), so M = 32/64/128/... all still compile via 1/2/4-wave configurations and the existing small-shape numerics tests pass unchanged.

Stream fast-path: replace torch.cuda.current_stream() (2.6 us/call) with fx.Stream(torch._C._cuda_getCurrentStream(idx)[0]) (0.2 us) in all three wrappers via a shared current_stream_fast helper in flydsl_utils. Saves ~2.7 us/call; brings 1024^2 and 4096^2 from ~0.89x to within 1% of Triton parity.
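The current_stream_fast idea (prefer the cheap private getter, keep the public API as a fallback) can be sketched with stand-in namespaces so it runs without a GPU. The helper signature and demo objects here are hypothetical; the real helper lives in flydsl_utils and calls torch directly.

```python
from types import SimpleNamespace

def make_current_stream_fast(torch_c, torch_cuda, device_index=0):
    """Return a zero-arg stream getter. Prefer the private fast getter
    when present; otherwise fall back to the public current_stream API.
    getattr guards against the private symbol being removed later."""
    fast = getattr(torch_c, "_cuda_getCurrentStream", None)
    if fast is not None:
        # the private API returns a tuple; element 0 is the raw stream id
        return lambda: fast(device_index)[0]
    return lambda: torch_cuda.current_stream(device_index)

# stand-in namespaces demonstrating both paths (no CUDA required)
fast_c = SimpleNamespace(_cuda_getCurrentStream=lambda idx: (0xABC, 0, 0))
slow_c = SimpleNamespace()
cuda = SimpleNamespace(current_stream=lambda idx: "public-stream")
```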
Mirrors the cutedsl MXFP8 quantize surface so the FlyDSL backend can be
swapped in without changing call sites. No new perf work — existing FlyDSL
optimizations (multi-wave, dwordx2, M-stacked WG layout) are preserved.
- Symmetrize all three FlyDSL kernel-module signatures (1x32, 32x1, 3D) to
accept the same kwargs as their cutedsl peers; raise NotImplementedError
at the kernel-module boundary (belt-and-suspenders with the dispatcher).
- Add 3D blocked_scale_output + scale_block_k=32 support so the 3D FlyDSL
kernel matches cutedsl_quantize_3d's option surface (2D blocked output
remains follow-up).
- Add bench_flydsl_quantize_2d_{1x32,32x1}.py mirroring the cutedsl bench
files (FLOOR-only; baseline = triton_to_mxfp8_dim{0,1}).
- Extend bench_quantize_3d.py with a flydsl column and gate cuda/cutedsl
baselines on _mxfp8_cuda_kernels_available so the script runs on either
backend.
- Widen the FlyDSL 2D test param grids to mirror cutedsl coverage where
divisibility allows; add test_flydsl_kernel_module_rejects_unsupported_options
asserting NotImplementedError on direct kernel-module imports.
Known limitations (follow-up PRs):
- 1x32 requires K % 2048 == 0 (one wg-iter consumes wave_size * block_size);
K=7168 needs tail-handling.
- scaling_mode="rceil" not implemented (needs sw cvt.rp.satfinite.ue8m0x2.f32).
- 2D blocked_scale_output not implemented (tcgen05-specific).
- offs (token-group offsets) not yet wired into the 2D kernels.
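The 1x32 divisibility limitation above is simple arithmetic: one wg-iteration covers wave_size * block_size = 64 * 32 = 2048 elements. A minimal check (helper and constant names are illustrative, not the PR's code):

```python
AMD_WAVE_SIZE = 64   # CDNA wavefront width
MX_BLOCK_SIZE = 32   # elements per MX scale block

def supports_1x32(K: int) -> bool:
    # one wg-iteration consumes wave_size * block_size = 2048 elements,
    # so K must be a multiple of 2048 until tail-handling lands
    return K % (AMD_WAVE_SIZE * MX_BLOCK_SIZE) == 0

# K=7168 = 3*2048 + 1024, hence the documented tail-handling follow-up
```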
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4357.
As of commit 956f102 (merge base 9472d7d): 3 new failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@zstreet87 could you add the performance achieved (GB/s in both absolute units and % of peak) to the PR summary?
@claude review this PR
Claude finished @vkuzo's task in 3m 31s.

PR Review: FlyDSL MXFP8 Quantize (#4357)
Overall Assessment

This is a well-structured PR that adds the FlyDSL backend for MXFP8 quantize kernels on AMD CDNA. The code is high quality: clean API parity with the existing cutedsl surface, thorough error handling for unsupported options, and solid test coverage. The architecture mirrors the cutedsl pattern exactly, making it easy to follow.

Strengths
Issues & Suggestions

Medium Priority
Low Priority / Nits
Questions for Author
Summary

The PR is well-executed and ready for merge from a code quality perspective. The major limitations (RCEIL, K=7168 tail handling, 2D blocked scale output) are clearly documented and gated.
```python
# 1x32 (K-direction): K % 2048 (no tail handling yet).
# 32x1 (M-direction): M % 32, K % 256 (lane × VEC bf16 dwordx2 loads).
# 3D (per-expert N-direction): N % 32, K % 256 (lane × VEC bf16 dwordx2 loads).
# Shape grids mirror cutedsl test coverage (`(128, 8192)` × cutedsl K set
# for 1x32; `(128, 1024)` × cutedsl K set for 32x1) wherever the FlyDSL
# divisibility constraints above allow.
_FLYDSL_1X32_K = (2048, 4096, 8192)
_FLYDSL_2D_M = (1, 32, 128, 1024, 8192)
_FLYDSL_32X1_K = (256, 512, 1536, 2048, 4096, 5120, 7168, 8192)
_FLYDSL_3D_E = (1, 2, 4, 8)
_FLYDSL_3D_N = (32, 64, 256)
_FLYDSL_3D_K = (256, 1024, 4096)
```
+1 to claude that this is a lot of cases, can we make sure this is quick to run on supported hardware
this is still a lot of cases, can we cut down?
thanks for adding the benchmarks! I had one nit inline, otherwise lgtm. Can you fix the
PR review feedback (pytorch#4357):

- Fix ruff F401: drop unused E8M0_EXPONENT_BIAS import in 1x32.
- Trim test_amd_mx_3d_flydsl_numerics from 288 -> 32 cases (the default-config test above already sweeps the full (E, N, K, dtype) space; this one only needs the (scale_block_k x blocked_scale_output) cross).
- Add test_flydsl_2d_32x1_rejects_misaligned_M failure-path test (M in (1, 16, 33)) to cover the M % 32 == 0 contract.
- Document the M % 32 == 0 and K % 256 == 0 shape requirements on the 32x1 entry.
- Comment the alloc.finalized=False reset in the 32x1 and 3D launch helpers (functools.cache reuses the SmemAllocator across launches).
- Document the FLYDSL_3D_* env vars as diagnostic-only knobs.
- Harden current_stream_fast: fall back to torch.cuda.current_stream if torch._C._cuda_getCurrentStream is removed in a future release.

Diagnostic env vars (mirror the existing FLYDSL_3D_* pattern):

- Add FLYDSL_1X32_WAVES_PER_EU and FLYDSL_32X1_WAVES_PER_EU CompilationContext.compile_hints overrides on the launch path.
- Add FLYDSL_3D_WAVES_PER_EU and FLYDSL_3D_NT_STORES (the latter threads a cache_modifier through buffer_store on the fp8 data writes; folded into the JIT cache key).
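The diagnostic env-var knobs above follow a common pattern: read an integer override with a compiled-in default, and fold any codegen-affecting value into the JIT cache key. A hedged sketch; the variable names mirror the commit message, but the helper itself is illustrative:

```python
import os

def env_int(name: str, default: int) -> int:
    """Diagnostic-only override: unset means use the compiled-in default."""
    value = os.environ.get(name)
    return default if value is None else int(value)

# hypothetical usage mirroring the FLYDSL_3D_* knobs
waves_per_eu = env_int("FLYDSL_3D_WAVES_PER_EU", 0)  # 0 = let the compiler pick
nt_stores = bool(env_int("FLYDSL_3D_NT_STORES", 0))

# anything that changes generated code must participate in the JIT cache key,
# otherwise a cached kernel compiled under old knob values would be reused
cache_key_extra = (waves_per_eu, nt_stores)
```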
hey @vkuzo any movement on my PR?
```python
# AMD CDNA3+ counterpart of CUTEDSL via FlyDSL. FLOOR-mode scale only;
# all existing MXFP8TrainingRecipe entries force RCEIL today, so callers
# must construct an MXFP8TrainingOpConfig with scale_calculation_mode=FLOOR
# explicitly to actually exercise this backend until FlyDSL gains RCEIL.
```
do you want to just do RCEIL in software? Here is the code: https://github.com/pytorch/ao/blame/13cd013d65769db5f67cda2c3ea311cb5b70a665/torchao/prototype/mx_formats/mx_tensor.py#L168
Overall it looks reasonable to me. I would recommend doing RCEIL in software in your path; the FLOOR method is not widely used anymore and will limit the practical usability of your kernels. I think @danielvegamyhre should accept this one. It looks like there are also some ruff errors; could you rebase and make sure ruff passes?
Summary
Adds the FlyDSL backend for MXFP8 quantize kernels on AMD CDNA, mirroring the existing
cuTeDSL surface so the dispatcher can route to either backend by hardware. Three new kernels
(2d_1x32, 2d_32x1, 3d) ship with matching dispatcher entrypoints, custom-op registration,
tests, and benchmarks.
What's in this PR

- Three FlyDSL kernel modules (torchao/prototype/moe_training/kernels/mxfp8/flydsl_quantize_{2d_1x32,2d_32x1,3d}.py) — same I/O contract as the cutedsl peers, with FlyDSL-specific perf work already applied (multi-wave workgroups, buffer_load_dwordx2 widening, M-stacked WG layout).
- torch.library custom ops, gated on a _mxfp8_flydsl_kernels_available flag and wired into MXFP8Dim1CastKernelChoice.
- Kwarg parity with the cutedsl wrappers (stage_count, blocked_scale_output, offs, scale_block_n); unsupported values raise NotImplementedError at both the dispatcher and kernel-module boundaries.
- Test param grids widened to mirror cutedsl coverage (cutedsl K set) wherever divisibility allows; dispatcher and kernel-module raise contracts both asserted; custom-op registration checked.
- Benchmarks: bench_flydsl_quantize_2d_{1x32,32x1}.py (FLOOR-only, baseline = triton_to_mxfp8_dim{0,1}); bench_quantize_3d.py extended with a flydsl column and gated so the script runs on either backend.
Known limitations (follow-up PRs)

- 1x32 requires K % 2048 == 0; supporting K=7168 needs tail-handling in the kernel — see ROCm/FlyDSL PR #433's compute_compile_constants + make_pingpong_kloop for the compile-time-loop + runtime-tail-guard pattern.
- RCEIL scaling, to port from ROCm/FlyDSL PR #433's tests/kernels/blockscale_gemm_test_utils.py:fp32_to_e8m0: extract the exponent byte, add 1 iff the mantissa is nonzero, clamp to [1, 254]. Required for correctness — FLOOR systematically biases scales down and causes FP8 saturation.

All raise NotImplementedError with descriptive messages.
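The fp32_to_e8m0 RCEIL recipe in the limitations list (extract the exponent byte, bump by 1 iff the mantissa is nonzero, clamp to [1, 254]) is easy to prototype on the CPU. A sketch assuming scalar fp32 input; the function name follows the referenced test util, but this is not the PR's code:

```python
import struct

def fp32_to_e8m0_rceil(x: float) -> int:
    """Round-up (RCEIL) conversion of an fp32 magnitude to a biased e8m0
    exponent byte: ceil(log2(x)) + 127, clamped to [1, 254]."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    exponent = (bits >> 23) & 0xFF
    mantissa = bits & 0x7FFFFF
    if mantissa != 0:
        exponent += 1  # any nonzero mantissa rounds the exponent up
    return max(1, min(254, exponent))

# 1.0 -> 127 (2^0), 1.5 -> 128 (rounds up to 2^1), 2.0 -> 128
```

FLOOR would simply skip the mantissa bump, which is exactly the downward bias the limitation calls out.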
Test plan

- …passing on gfx950
- …rted_options — contract test passes
- …pported_options — direct-import contract test passes
- …ops registered
- …end-to-end on AMD
- …end-to-end on AMD
- …backend; columns NaN'd on the absent one
@vkuzo — measured on AMD Instinct MI355X (gfx950), HBM3E spec peak = 8.0 TB/s. Comparison is FlyDSL
vs the next-best AMD path (triton: triton_to_mxfp8_dim0 for 2D 1x32, triton_to_mxfp8_dim1 for 2D
32x1, per-expert triton_to_mxfp8_dim1 loop for 3D — there is no native 3D triton MXFP8 kernel on
AMD, so the loop is the honest fallback). Methodology unchanged from the bench scripts in this PR;
FLOOR scaling; bytes = read(bf16) + write(fp8) + write(scale_e8m0). % of peak is vs spec-sheet 8
TB/s (not measured achievable peak, which is typically ~80–90% of spec — so % of achievable would be
~10–20 pts higher).
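The bytes accounting above (bf16 read + fp8 write + one e8m0 scale byte per 32 elements) is straightforward to reproduce. A sketch with illustrative helper names, matching the methodology as described:

```python
def quantize_bytes_moved(m: int, k: int) -> int:
    """HBM traffic model for MXFP8 quantize of an (m, k) bf16 tensor:
    read bf16 (2 B/elem) + write fp8 (1 B/elem) + write one e8m0
    scale byte per 32-element block."""
    elems = m * k
    return elems * 2 + elems * 1 + (elems // 32) * 1

def pct_of_spec_peak(bytes_moved: int, seconds: float,
                     peak_tb_s: float = 8.0) -> float:
    """% of spec-sheet peak bandwidth (8.0 TB/s for MI355X HBM3E)."""
    gb_per_s = bytes_moved / seconds / 1e9
    return 100.0 * gb_per_s / (peak_tb_s * 1e3)
```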
2D 1x32 (bench_flydsl_quantize_2d_1x32.py)
2D 32x1 (bench_flydsl_quantize_2d_32x1.py)
3D FLOOR, scale_block_k=1 (bench_quantize_3d.py; triton baseline
is a per-expert triton_to_mxfp8_dim1 loop)
Note: scale_block_k=32 rows are omitted from the 3D comparison — triton_to_mxfp8_dim1 is a 32×1
kernel, not the 32×32 tile, so no apples-to-apples triton baseline exists for the sbk=32 path.
Headline

- …there).
- The per-expert triton loop pays heavy launch/rearrange overhead.
- …this shape).
Numbers are single-run; happy to rerun with N=10 + min/median if that's useful. FlyDSL → cuTeDSL
parity is follow-up work; this PR establishes that the AMD path beats the triton fallback.