Flydsl mxfp8 quantize #4357

Open
zstreet87 wants to merge 12 commits into pytorch:main from zstreet87:flydsl-mxfp8-quantize

Conversation

@zstreet87

@zstreet87 zstreet87 commented Apr 30, 2026

Summary

Adds the FlyDSL backend for MXFP8 quantize kernels on AMD CDNA, mirroring the existing
cuTeDSL surface so the dispatcher can route to either backend by hardware. Three new kernels
(2d_1x32, 2d_32x1, 3d) ship with matching dispatcher entrypoints, custom-op registration,
tests, and benchmarks.

What's in this PR

  • 3 FlyDSL kernels
    (torchao/prototype/moe_training/kernels/mxfp8/flydsl_quantize_{2d_1x32,2d_32x1,3d}.py) —
    same I/O contract as the cutedsl peers, with FlyDSL-specific perf work already applied
    (multi-wave workgroups, buffer_load_dwordx2 widening, M-stacked WG layout).
  • Dispatcher integration — mxfp8_quantize_*_flydsl entrypoints in quant.py, registered as
    torch.library custom ops, gated on a _mxfp8_flydsl_kernels_available flag and wired into
    MXFP8Dim1CastKernelChoice.
  • API surface parity — kernel-module signatures accept the same kwargs as the cutedsl
    wrappers (stage_count, blocked_scale_output, offs, scale_block_n); unsupported values raise
    NotImplementedError at both the dispatcher and kernel-module boundaries.
  • 3D blocked scale output + scale_block_k=32 implemented (matches cutedsl 3D's full option
    set).
  • Tests — numerics tests for all 3 kernels mirror cutedsl coverage where divisibility
    allows; dispatcher and kernel-module raise contracts both asserted; custom-op registration
    checked.
  • Benchmarks — new bench_flydsl_quantize_2d_{1x32,32x1}.py mirror the cutedsl bench files
    (FLOOR-only, baseline = triton_to_mxfp8_dim{0,1}); bench_quantize_3d.py extended with a
    flydsl column and gated so the script runs on either backend.
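
A minimal sketch of the "accept the same kwargs, reject unsupported values" contract described above. The function name and error messages are hypothetical and not taken from the PR; the wave size of 64 is an assumption consistent with the K % 2048 == 0 requirement listed under known limitations.

```python
WAVE_SIZE = 64    # assumed AMD wavefront width
BLOCK_SIZE = 32   # MX block size (elements per scale)
K_PER_ITER = WAVE_SIZE * BLOCK_SIZE  # 2048 elements per workgroup-iteration


def check_flydsl_1x32_options(
    K,
    scaling_mode="floor",
    stage_count=2,
    blocked_scale_output=False,
    offs=None,
):
    """Hypothetical validator for the 1x32 path; not the PR's actual helper."""
    del stage_count  # accepted for signature parity with cutedsl, then ignored
    if scaling_mode != "floor":
        raise NotImplementedError(
            f"scaling_mode={scaling_mode!r} is not implemented; only 'floor' is"
        )
    if blocked_scale_output:
        raise NotImplementedError("blocked_scale_output=True is not implemented for 2D")
    if offs is not None:
        raise NotImplementedError("offs is not wired into the 2D kernels yet")
    if K % K_PER_ITER != 0:
        raise NotImplementedError(
            f"K={K} must be a multiple of {K_PER_ITER} (no tail handling yet)"
        )
```

Running the same check at both the dispatcher and the kernel-module boundary means a direct import of the kernel module fails the same way as a call through the dispatcher.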

Known limitations (follow-up PRs)

  • 1x32 requires K % 2048 == 0 (one workgroup-iteration consumes wave_size × block_size).
    Supporting K=7168 needs tail-handling in the kernel; see ROCm/FlyDSL PR #433's
    compute_compile_constants + make_pingpong_kloop for the compile-time-loop +
    runtime-tail-guard pattern.
  • scaling_mode="rceil" not implemented. The 3-op exponent-roundup pattern is straightforward
    to port from ROCm/FlyDSL PR #433's
    tests/kernels/blockscale_gemm_test_utils.py:fp32_to_e8m0: extract exponent byte, add 1 iff
    mantissa is nonzero, clamp to [1, 254]. Required for correctness: FLOOR systematically
    biases scales down and causes FP8 saturation.
  • 2D blocked_scale_output not implemented (tcgen05-specific scale layout).
  • offs (token-group offsets) not yet wired into the 2D kernels.
  • 3D scale_block_n != 32 not implemented.

All raise NotImplementedError with descriptive messages.
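
For reference, the 3-op exponent-roundup pattern named in the rceil bullet can be sketched in plain Python. This is an illustration of the steps as described (extract exponent byte, add 1 iff mantissa is nonzero, clamp to [1, 254]), not the FlyDSL or torchao implementation. The helper names are hypothetical, and inputs are assumed to be positive finite fp32 values.

```python
import struct

E8M0_BIAS = 127  # E8M0 stores a biased power-of-two exponent


def _fp32_bits(x: float) -> int:
    """Reinterpret a Python float as its IEEE-754 binary32 bit pattern."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def e8m0_floor(x: float) -> int:
    """Truncate: keep the exponent byte and drop the mantissa."""
    return (_fp32_bits(x) >> 23) & 0xFF


def e8m0_rceil(x: float) -> int:
    """Round up to the next power of two when any mantissa bit is set."""
    bits = _fp32_bits(x)
    exponent = (bits >> 23) & 0xFF
    if bits & 0x7FFFFF:          # nonzero mantissa: value is above 2**(exponent - 127)
        exponent += 1
    return max(1, min(exponent, 254))


def e8m0_to_float(e: int) -> float:
    return 2.0 ** (e - E8M0_BIAS)


scale = 1.5
print(e8m0_to_float(e8m0_floor(scale)))  # 1.0, below the requested scale
print(e8m0_to_float(e8m0_rceil(scale)))  # 2.0, at or above the requested scale
```

With truncation the decoded scale can be smaller than the requested one, so values divided by it can exceed the FP8 range; rounding up avoids that at the cost of up to one bit of headroom.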

Test plan

  • pytest test/prototype/moe_training/test_kernels.py -k "flydsl or amd_mx_3d_flydsl": 458
    passing on gfx950
  • pytest test/prototype/moe_training/test_kernels.py::test_flydsl_dispatcher_rejects_unsupported_options:
    contract test passes
  • pytest test/prototype/moe_training/test_kernels.py::test_flydsl_kernel_module_rejects_unsupported_options:
    direct-import contract test passes
  • pytest test/prototype/moe_training/test_kernels.py::test_flydsl_custom_ops_registered:
    ops registered
  • python benchmarks/prototype/moe_training/mxfp8/bench_flydsl_quantize_2d_1x32.py: runs
    end-to-end on AMD
  • python benchmarks/prototype/moe_training/mxfp8/bench_flydsl_quantize_2d_32x1.py: runs
    end-to-end on AMD
  • python benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py: runs on either
    backend; columns NaN'd on the absent one

@vkuzo: measured on AMD Instinct MI355X (gfx950), HBM3E spec peak = 8.0 TB/s. Comparison is FlyDSL
vs the next-best AMD path (triton: triton_to_mxfp8_dim0 for 2D 1x32, triton_to_mxfp8_dim1 for 2D
32x1, per-expert triton_to_mxfp8_dim1 loop for 3D; there is no native 3D triton MXFP8 kernel on
AMD, so the loop is the honest fallback). Methodology unchanged from the bench scripts in this PR;
FLOOR scaling; bytes = read(bf16) + write(fp8) + write(scale_e8m0). % of peak is vs spec-sheet 8
TB/s, not measured achievable peak. Achievable peak is typically ~80–90% of spec, so % of
achievable would be roughly 1.1–1.25× the figures shown.
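
The byte accounting above can be written out as a short sketch. The function names are hypothetical, one E8M0 scale byte per 32 elements is assumed, and 1 GB is taken as 1e9 bytes; the bench scripts in the PR may differ in detail.

```python
HBM_SPEC_PEAK_GBPS = 8000.0  # 8.0 TB/s spec-sheet peak quoted above


def quantize_bytes_moved(M: int, K: int, block_size: int = 32) -> int:
    """bytes = read(bf16) + write(fp8) + write(scale_e8m0)."""
    n = M * K
    read_bf16 = 2 * n              # 2 bytes per bf16 input element
    write_fp8 = n                  # 1 byte per fp8 output element
    write_scale = n // block_size  # 1 byte per 32-element MX block
    return read_bf16 + write_fp8 + write_scale


def achieved_gbps(M: int, K: int, seconds: float) -> float:
    return quantize_bytes_moved(M, K) / seconds / 1e9


def pct_of_peak(gbps: float, peak_gbps: float = HBM_SPEC_PEAK_GBPS) -> float:
    return 100.0 * gbps / peak_gbps


print(quantize_bytes_moved(8192, 2048))  # 50855936
print(round(pct_of_peak(4157.0), 1))     # 52.0
```

Passing a measured achievable peak as `peak_gbps` instead of the spec value gives the "% of achievable" figure.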

2D 1x32 (bench_flydsl_quantize_2d_1x32.py)

| shape (M, K) | flydsl GB/s (% peak) | triton GB/s (% peak) | speedup |
|---|---|---|---|
| (8192, 2048) | 1746 (21.8%) | 497 (6.2%) | 3.52× |
| (32768, 2048) | 3204 (40.1%) | 1763 (22.0%) | 1.82× |
| (131072, 2048) | 4157 (52.0%) | 3360 (42.0%) | 1.24× |

2D 32x1 (bench_flydsl_quantize_2d_32x1.py)

| shape (M, K) | flydsl GB/s (% peak) | triton GB/s (% peak) | speedup |
|---|---|---|---|
| (8192, 2048) | 1781 (22.3%) | 518 (6.5%) | 3.44× |
| (8192, 7168) | 2478 (31.0%) | 1638 (20.5%) | 1.51× |
| (32768, 2048) | 2289 (28.6%) | 2171 (27.1%) | 1.05× |
| (32768, 7168) | 3328 (41.6%) | 2908 (36.4%) | 1.14× |
| (131072, 2048) | 2997 (37.5%) | 3068 (38.4%) | 0.98× |
| (131072, 7168) | 3772 (47.2%) | 3083 (38.5%) | 1.22× |

3D FLOOR, scale_block_k=1 (bench_quantize_3d.py; triton baseline is a per-expert triton_to_mxfp8_dim1 loop)

| shape (E, N, K) | flydsl GB/s (% peak) | triton GB/s (% peak) | speedup |
|---|---|---|---|
| (1, 8192, 5120) | 2428 (30.4%) | 1352 (16.9%) | 1.80× |
| (1, 7168, 2048) | 1344 (16.8%) | 461 (5.8%) | 2.91× |
| (8, 8192, 5120) | 2868 (35.8%) | 1338 (16.7%) | 2.14× |
| (8, 7168, 2048) | 2433 (30.4%) | 478 (6.0%) | 5.09× |
| (32, 7168, 2048) | 3053 (38.2%) | 497 (6.2%) | 6.15× |
| (32, 8192, 5120) | 2877 (36.0%) | 1378 (17.2%) | 2.09× |

Note: scale_block_k=32 rows are omitted from the 3D comparison — triton_to_mxfp8_dim1 is a 32×1
kernel, not the 32×32 tile, so no apples-to-apples triton baseline exists for the sbk=32 path.

Headline

  • 2D peaks at 47–52% of HBM3E spec peak at the largest DSV3 shapes (1.24× / 1.22× over triton
    there).
  • 2D smaller-M shapes are 1.5–3.5× over triton.
  • 3D is 1.80–6.15× over triton across every shape; biggest wins on small-K shapes where the
    per-expert triton loop pays heavy launch/rearrange overhead.
  • One parity row: 2D 32x1 (131072, 2048) lands at 0.98× (both kernels saturate HBM at large M for
    this shape).

Numbers are single-run; happy to rerun with N=10 + min/median if that's useful. FlyDSL → cuTeDSL
parity is follow-up work; this PR establishes that the AMD path beats the triton fallback.

Move FLOOR scale derivation, FP8 clamp constants, and chunk-quantize-pack
into flydsl_utils.py as module-level helpers (cutedsl pattern). Each
kernel now imports them at module level and calls them inline. Net
~256 lines (-22%) removed across the three kernel files.
…olidate tests

- FlyDSL dispatcher functions now accept the same kwargs as their cutedsl
  counterparts (stage_count, blocked_scale_output, offs). Unsupported values
  raise NotImplementedError; stage_count is accepted but ignored (no TMA on
  CDNA3). Allows callers to swap backends without changing call signatures.
- Move FlyDSL tests from 4 separate test_flydsl_mxfp8_*.py files into
  test_kernels.py with @pytest.mark.skipif gates, matching the cutedsl
  test layout convention. Remove the _flydsl_test_utils.py helper module.

Add MXFP8Dim1CastKernelChoice.FLYDSL and dispatch branch in mx_formats
utils. Mirrors the CUTEDSL branch for AMD: callers can now pick FlyDSL
as the dim1 cast backend just like cutedsl on NVIDIA.

The FlyDSL backend currently supports FLOOR scale only; the dispatch
asserts on RCEIL with a clear message. All existing MXFP8TrainingRecipe
entries hardcode RCEIL, so end-to-end training-recipe coverage waits on
either a FLOOR recipe or a FlyDSL RCEIL implementation. Direct
MXFP8TrainingOpConfig with scale_calculation_mode=FLOOR exercises the
backend today.

Each workgroup now runs up to 4 waves, each handling its own K-tile,
mirroring triton_to_mxfp8_dim1's num_warps=4 strategy. The SIMD
scheduler can overlap memory latency across waves within a CU.

waves_per_block is picked at launch time based on K (largest power of
two such that K % (waves * K_TILE) == 0) so small-K shapes still work
via the original 1-wave path.
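
A minimal sketch of the launch-time selection rule described above: largest power of two, capped at 4, such that K % (waves * K_TILE) == 0. The function name is hypothetical and `k_tile` is passed in as a parameter because the kernel's actual tile constant is not shown here.

```python
def pick_waves_per_block(K: int, k_tile: int, max_waves: int = 4) -> int:
    """Largest power-of-two wave count (<= max_waves) whose combined K-tile divides K.

    Assumes max_waves is a power of two and K is a multiple of k_tile;
    shapes that do not divide evenly fall back to the 1-wave path.
    """
    waves = max_waves
    while waves > 1 and K % (waves * k_tile) != 0:
        waves //= 2
    return waves


for K in (256, 512, 1024, 1536, 7168):
    print(K, pick_waves_per_block(K, k_tile=256))
```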
…er launch

32x1 Phase-1 loads now issue dwordx2 (vec_width=VEC=4 bf16 per lane);
K_TILE grows to AMD_WAVE_SIZE * VEC = 256 and per-WG LDS is capped at
the 64 KB budget. Test K list updated.

All three FlyDSL wrappers pass `x` as a raw torch tensor so the JIT
runtime takes its bare-pointer fast path (~46 us/launch saved vs the
from_dlpack adapter). Outputs use as_strided over flat fp8 storage.
…lookup

32x1 now stacks up to 4 MX blocks (M-direction) per workgroup, with 4
waves cooperating on a (128, 256) bf16 tile sharing one 64 KB WG-level
LDS region; each wave owns one of the 4 stacked 32-row blocks. Phase-1
keeps the multi-wave HBM hide; Phase-4 has 4 waves write 4 disjoint
32-byte fragments of the same 128 B HBM cache line concurrently from
4 SIMDs of one CU, letting the L2 controller coalesce into a single
line fill (no eviction, no RMW).

Closes the 16384^2 0.75x regression: PMC WriteSize drops from 396 MB
(+43% over ideal) to 338 MB (+22%, matching Triton exactly). MI355X
bf16 end-to-end: 16384^2 0.75x -> 1.19x; 8192^2 1.33x -> 1.50x. Layout
is adaptive via _pick_layout(M, K, ...), so M=32/64/128/... all still
compile via 1/2/4-wave configurations and existing small-shape numerics
tests pass unchanged.
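
A simplified sketch of the adaptive M-stacking idea, assuming the 4→2→1 fallback order and the 64 KB LDS budget mentioned in this thread. The real _pick_layout(M, K, ...) takes more inputs; the function below and its constants are hypothetical.

```python
BLOCK_SIZE = 32          # MX block rows per wave
K_TILE = 256             # AMD_WAVE_SIZE * VEC bf16 columns per tile
BF16_BYTES = 2
LDS_BUDGET = 64 * 1024   # per-workgroup LDS cap in bytes


def pick_m_stack(M: int, max_stack: int = 4):
    """Return (stack, m_tile, lds_bytes) for the largest stack that divides M."""
    if M % BLOCK_SIZE != 0:
        raise ValueError(f"M={M} must be a multiple of {BLOCK_SIZE}")
    stack = max_stack
    while stack > 1 and M % (stack * BLOCK_SIZE) != 0:
        stack //= 2
    m_tile = stack * BLOCK_SIZE
    lds_bytes = m_tile * K_TILE * BF16_BYTES  # one bf16 tile shared by the workgroup
    if lds_bytes > LDS_BUDGET:
        raise ValueError("tile does not fit in the LDS budget")
    return stack, m_tile, lds_bytes


for M in (32, 64, 96, 128):
    print(M, pick_m_stack(M))
```

At stack=4 the (128, 256) bf16 tile occupies exactly 65536 bytes, which is why 4 is the upper bound under a 64 KB budget.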

Stream fast-path: replace torch.cuda.current_stream() (2.6 us/call)
with fx.Stream(torch._C._cuda_getCurrentStream(idx)[0]) (0.2 us) in all
three wrappers via a shared current_stream_fast helper in flydsl_utils.
Saves ~2.7 us/call; brings 1024^2 and 4096^2 from ~0.89x to within 1%
of Triton parity.

Mirrors the cutedsl MXFP8 quantize surface so the FlyDSL backend can be
swapped in without changing call sites. No new perf work — existing FlyDSL
optimizations (multi-wave, dwordx2, M-stacked WG layout) are preserved.

- Symmetrize all three FlyDSL kernel-module signatures (1x32, 32x1, 3D) to
  accept the same kwargs as their cutedsl peers; raise NotImplementedError
  at the kernel-module boundary (belt-and-suspenders with the dispatcher).
- Add 3D blocked_scale_output + scale_block_k=32 support so the 3D FlyDSL
  kernel matches cutedsl_quantize_3d's option surface (2D blocked output
  remains follow-up).
- Add bench_flydsl_quantize_2d_{1x32,32x1}.py mirroring the cutedsl bench
  files (FLOOR-only; baseline = triton_to_mxfp8_dim{0,1}).
- Extend bench_quantize_3d.py with a flydsl column and gate cuda/cutedsl
  baselines on _mxfp8_cuda_kernels_available so the script runs on either
  backend.
- Widen the FlyDSL 2D test param grids to mirror cutedsl coverage where
  divisibility allows; add test_flydsl_kernel_module_rejects_unsupported_options
  asserting NotImplementedError on direct kernel-module imports.

Known limitations (follow-up PRs):
- 1x32 requires K % 2048 == 0 (one wg-iter consumes wave_size * block_size);
  K=7168 needs tail-handling.
- scaling_mode="rceil" not implemented (needs sw cvt.rp.satfinite.ue8m0x2.f32).
- 2D blocked_scale_output not implemented (tcgen05-specific).
- offs (token-group offsets) not yet wired into the 2D kernels.
@pytorch-bot

pytorch-bot Bot commented Apr 30, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4357

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 3 New Failures

As of commit 956f102 with merge base 9472d7d (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Apr 30, 2026
@zstreet87 zstreet87 marked this pull request as draft April 30, 2026 18:35
@vkuzo
Contributor

vkuzo commented May 4, 2026

@zstreet87 could you add the performance achieved (GB/s in both absolute units and % of peak) to the PR summary?

@zstreet87 zstreet87 marked this pull request as ready for review May 4, 2026 17:37
@vkuzo
Contributor

vkuzo commented May 4, 2026

@claude review this PR

@claude

claude Bot commented May 4, 2026

Claude finished @vkuzo's task in 3m 31s


PR Review: FlyDSL MXFP8 Quantize (#4357)

  • Read kernel implementations (flydsl_quantize_2d_1x32, 2d_32x1, 3d)
  • Read utility module (flydsl_utils.py)
  • Read dispatcher integration (quant.py changes)
  • Read test file changes
  • Read benchmark files
  • Read config/utils changes
  • Provide review feedback

Overall Assessment

This is a well-structured PR that adds the FlyDSL backend for MXFP8 quantize kernels on AMD CDNA. The code is high quality — clean API parity with the existing cutedsl surface, thorough error handling for unsupported options, and solid test coverage. The architecture mirrors the cutedsl pattern exactly, making it easy to follow.

Strengths

  1. Clean layering: validation at both the dispatcher (quant.py) and kernel-module boundaries, so direct imports also get contract checks.
  2. Shared utility module (flydsl_utils.py): the floor_scale_and_inv_scale, make_fp8_clamp_vectors, and quantize_pack_chunk_to_i32 helpers are well-documented and eliminate duplication across the 3 kernels.
  3. functools.cache on compilation functions: JIT-compiled kernels are cached by (dtype, scaling_mode, shape) — appropriate for production use.
  4. Comprehensive tests: bit-exact numerics checks against to_mx reference, contract tests for unsupported options, input validation tests, and custom-op registration checks.
  5. Performance: the benchmark numbers in the PR description are impressive (1.2–6x over triton fallback).

Issues & Suggestions

Medium Priority

  1. _pick_layout in 32x1 does not validate M % m_tile == 0 at the call site: flydsl_quantize_2d_32x1.py:316 only asserts block_size == BLOCK_SIZE (i.e. M % 32 == 0), but _pick_layout may select m_tile=128 (stack=4). If M=32 and stack tries 4→2→1, it'll fall through to stack=1 which works, but the test grid only goes down to M=32. Consider adding M=1 to _FLYDSL_2D_M or documenting that 32x1 requires M≥32 (the cutedsl counterpart may have the same constraint).

  2. Unbounded functools.cache growth: _compile_quantize_2d_1x32 is keyed on (input_dtype_name, scaling_mode, K) and _compile_quantize_3d adds (N, K, waves_per_block, scale_block_k, blocked_scale_output). In long-running training with varying batch sizes, the cache can grow indefinitely. This is unlikely to be a real problem today but worth noting; an lru_cache(maxsize=32) or similar would bound it. Low priority since the cutedsl side does the same thing.

  3. current_stream_fast uses private API: torch._C._cuda_getCurrentStream (flydsl_utils.py:106) is internal and could break across PyTorch versions. The # ~0.2 µs vs 2.6 µs perf rationale is sound for small shapes, but consider adding a version guard or a fallback to torch.cuda.current_stream() if the private API disappears.

  4. os.environ.get("FLYDSL_3D_FORCE_WAVES") in hot path: flydsl_quantize_3d.py:383-391 calls os.environ.get every time mxfp8_quantize_flydsl_3d runs (not just at compile time). This is likely negligible but would be cleaner to read once at module-load or inside the cached _compile_quantize_3d.
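
Suggestions 2 and 4 could look roughly like the sketch below. It uses a placeholder in place of the real JIT compile step, and the names `compile_quantize` and `_FORCE_WAVES` are hypothetical; only functools.lru_cache and os.environ.get are real APIs, and FLYDSL_3D_FORCE_WAVES is the variable named in the review.

```python
import functools
import os

# Read the diagnostic override once at import time instead of on every launch.
_raw = os.environ.get("FLYDSL_3D_FORCE_WAVES")
_FORCE_WAVES = int(_raw) if _raw is not None else None


@functools.lru_cache(maxsize=32)
def compile_quantize(input_dtype_name: str, scaling_mode: str, K: int):
    """Stand-in for a JIT compile; the cache key is the argument tuple."""
    # A real implementation would build and return a compiled kernel here.
    return ("compiled", input_dtype_name, scaling_mode, K, _FORCE_WAVES)


# With maxsize=32, the least recently used entries are evicted once the
# number of distinct shapes exceeds 32, so memory stays bounded.
for i in range(40):
    compile_quantize("bfloat16", "floor", 2048 * (i + 1))
print(compile_quantize.cache_info().currsize)  # 32
```

The trade-off is that an evicted shape pays the compile cost again if it reappears, so the bound should sit comfortably above the number of shapes a training run cycles through.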

Low Priority / Nits

  1. Test parametrization is large: test_amd_mx_3d_flydsl_numerics has 4 × 3 × 3 × 2 × 2 × 2 = 288 parametrized cases. On CI without MI300/MI350 hardware they're all skipped anyway, but if AMD CI runs them this could be slow. This is fine if gating is reliable.

  2. _K_PER_CHUNK comment in 1x32 (flydsl_quantize_2d_1x32.py:42) says "AMD_WAVE_SIZE lanes × BLOCK_SIZE elements/block" = 2048, which is correct. The K divisibility requirement (K % 2048 == 0) is well-documented in the PR description's "Known limitations" section.

  3. The error message in _check_flydsl_unsupported_params (quant.py:1621) says "blocked_scale_output=True is tcgen05-specific to SM 10.x" — this is accurate for the 2D paths, but the 3D FlyDSL kernel does implement it. The separation into _check_flydsl_unsupported_params (2D) vs _check_flydsl_3d_unsupported_params (3D) handles this correctly; just noting the distinction is non-obvious to a reader.

  4. SmemAllocator usage — In both 32x1 and 3D kernels, alloc._align(alloc.ptr, 16) uses a private API of the FlyDSL SmemAllocator. This is fragile to FlyDSL internal changes but may be the only supported way. Worth tracking FlyDSL API stability for this pattern.

  5. alloc.finalized = False in launch functions (flydsl_quantize_2d_32x1.py:251, flydsl_quantize_3d.py:313) — re-setting finalized before alloc.finalize() in the @flyc.jit launcher is a workaround pattern. A brief comment explaining why this reset is needed would help future maintainers (is it because functools.cache reuses the same allocator object across calls?).

Questions for Author

  • Is there a reason M=1 is in _FLYDSL_2D_M for 1x32 tests but not exercised for 32x1? The 32x1 kernel requires M % 32 == 0, so M=1 would fail — just want to confirm the test grids were intentionally filtered.
  • The 3D kernel's _load_chunks_and_amax is called twice in the scale_block_k == 32 path (once for amax accumulation, once for the actual quantization). The comment says "Re-read chunks in the second pass to keep at most CHUNKS_PER_BLOCK live." Is the LDS re-read actually cheaper than keeping 8×VEC=32 f32 vectors live (128 VGPRs per k_local × 4 k_local = 512 VGPRs total)? Presumably yes given gfx950's register file, but curious if this was measured.

Summary

The PR is well-executed and ready for merge from a code quality perspective. The major limitations (RCEIL, K=7168 tail handling, 2D blocked scale output) are clearly documented and gated with NotImplementedError. No blocking issues found.


Comment on lines +936 to +947
# 1x32 (K-direction): K % 2048 (no tail handling yet).
# 32x1 (M-direction): M % 32, K % 256 (lane × VEC bf16 dwordx2 loads).
# 3D (per-expert N-direction): N % 32, K % 256 (lane × VEC bf16 dwordx2 loads).
# Shape grids mirror cutedsl test coverage (`(128, 8192)` × cutedsl K set
# for 1x32; `(128, 1024)` × cutedsl K set for 32x1) wherever the FlyDSL
# divisibility constraints above allow.
_FLYDSL_1X32_K = (2048, 4096, 8192)
_FLYDSL_2D_M = (1, 32, 128, 1024, 8192)
_FLYDSL_32X1_K = (256, 512, 1536, 2048, 4096, 5120, 7168, 8192)
_FLYDSL_3D_E = (1, 2, 4, 8)
_FLYDSL_3D_N = (32, 64, 256)
_FLYDSL_3D_K = (256, 1024, 4096)
Contributor

+1 to claude that this is a lot of cases, can we make sure this is quick to run on supported hardware

Contributor

this is still a lot of cases, can we cut down?

@vkuzo
Contributor

vkuzo commented May 5, 2026

thanks for adding the benchmarks! I had one nit inline, otherwise lgtm. Can you fix the ruff error? also @danielvegamyhre should chime in as a lot of this is going into moe_training.

PR review feedback (pytorch#4357):
- Fix ruff F401: drop unused E8M0_EXPONENT_BIAS import in 1x32.
- Trim test_amd_mx_3d_flydsl_numerics from 288 -> 32 cases (the
  default-config test above already sweeps full (E,N,K,dtype); this
  one only needs the (scale_block_k x blocked_scale_output) cross).
- Add test_flydsl_2d_32x1_rejects_misaligned_M failure-path test
  (M in (1, 16, 33)) to cover the M%32==0 contract.
- Document M%32==0 + K%256==0 shape requirements on the 32x1 entry.
- Comment alloc.finalized=False reset in 32x1 and 3D launch helpers
  (functools.cache reuses the SmemAllocator across launches).
- Document the FLYDSL_3D_* env vars as diagnostic-only knobs.
- Harden current_stream_fast: fall back to torch.cuda.current_stream
  if torch._C._cuda_getCurrentStream is removed in a future release.

Diagnostic env vars (mirror the existing FLYDSL_3D_* pattern):
- Add FLYDSL_1X32_WAVES_PER_EU and FLYDSL_32X1_WAVES_PER_EU
  CompilationContext.compile_hints overrides on the launch path.
- Add FLYDSL_3D_WAVES_PER_EU and FLYDSL_3D_NT_STORES (the latter
  threads a cache_modifier through buffer_store on the fp8 data
  writes; folded into the JIT cache key).
@zstreet87
Author

hey @vkuzo any movements on my PR?

# AMD CDNA3+ counterpart of CUTEDSL via FlyDSL. FLOOR-mode scale only;
# all existing MXFP8TrainingRecipe entries force RCEIL today, so callers
# must construct an MXFP8TrainingOpConfig with scale_calculation_mode=FLOOR
# explicitly to actually exercise this backend until FlyDSL gains RCEIL.
Contributor

@vkuzo
Contributor

vkuzo commented May 15, 2026

overall it looks reasonable to me. I would recommend just doing RCEIL in software in your path, the FLOOR method is not widely used anymore and will limit the practical usability of your kernels. I think @danielvegamyhre should accept this one.

looks like there are also some ruff errors, could you rebase and make sure ruff passes

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

2 participants