
Conversation

@jjsjann123
Collaborator

  1. Fix the block quantization op without a global scaling factor;
  2. Python example with the block quantization op followed by layout and grouped_mm.

Note: I'm pretty nervous about the numerics of the block quantization op. I left the check here as relaxed as the existing python ops, since I can't get it to perform close to what the decomposed version does.
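
For reference, a minimal sketch of the op in isolation, mirroring the fusion definition in the new test below (the fusion name here is illustrative, not from the PR):

    # Minimal sketch: nv_block_quantize without a global scaling factor.
    # Mirrors the usage in the test added by this PR.
    def block_quantize_only(fd: FusionDefinition) -> None:
        x = fd.define_tensor(
            shape=[-1, -1], contiguity=True, dtype=DataType.Float, is_cpu=False
        )
        # Returns the fp4 (e2m1) data and the per-block fp8 (e4m3) scaling factors.
        fp4_x, fp8_scale = fd.ops.nv_block_quantize(x)
        fd.add_output(fp4_x)
        fd.add_output(fp8_scale)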

@jjsjann123
Collaborator Author

!test

@jjsjann123 jjsjann123 requested a review from protonu December 16, 2025 18:03
@github-actions

github-actions bot commented Dec 16, 2025

Review updated until commit 090e395

Auto-merge Status

✅ Internal CI is finished
❌ No failed checks (nvfuser-ci/jit_binary_distributed_tests_20_GB200, nvfuser-ci/jit_python_distributed_tests_20_H100)
✅ PR is mergeable
ℹ️ PR mergeable_state: unstable

Description

  • Replace manual quantization logic with built-in nv_block_quantize operation

  • Add comprehensive test for block quantization + layout + grouped_mm pipeline

  • Fix block quantization op without global scaling factor

  • Validate numerical accuracy against decomposed reference implementation

Changes walkthrough

Relevant files

Enhancement

benchmarks/python/benchmark_inference.py — Replace manual quantization with built-in operation (+1/-21)

  • Remove unused constants FLOAT4_E2M1_MAX, FLOAT8_E4M3_EPS, FLOAT8_E4M3_MAX
  • Replace manual quantization logic with a single fd.ops.nv_block_quantize call (sketched below)
  • Simplify activation preprocessing from 8 lines to 2 lines

Tests

tests/python/direct/test_narrow_precision.py — Add comprehensive block quantization test (+174/-0)

  • Add test_block_quantize_op_and_layout_op function with comprehensive validation
  • Test block quantization followed by layout and grouped_mm operations
  • Compare nvfuser implementation against decomposed reference
  • Include relaxed error checking for numerical stability validation
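
For context, a rough PyTorch sketch of the kind of decomposed block quantization being replaced (per-16-element blocks along k with e4m3 block scales, as in the test below; the constants and the exact clamping are assumptions, not the removed benchmark code):

    import torch

    # Constants as assumed from the removed imports; values are the standard
    # e2m1 / e4m3 ranges, not copied from the benchmark.
    FLOAT4_E2M1_MAX = 6.0
    FLOAT8_E4M3_MAX = 448.0
    FLOAT8_E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny

    def decomposed_block_quantize(x: torch.Tensor, block_size: int = 16):
        # Reshape the last dimension into blocks and take the per-block absolute max.
        m, k = x.shape
        blocks = x.view(m, k // block_size, block_size)
        amax = blocks.abs().amax(dim=-1, keepdim=True)
        # Per-block scale so the largest element maps onto the fp4 range,
        # clamped into what e4m3 can represent.
        scale = (amax / FLOAT4_E2M1_MAX).clamp(min=FLOAT8_E4M3_EPS, max=FLOAT8_E4M3_MAX)
        scale_fp8 = scale.to(torch.float8_e4m3fn)
        # Divide by the rounded scale and clamp to the fp4 range; the real op also
        # packs the result into float4_e2m1fn_x2, which this sketch skips.
        q = (blocks / scale_fp8.to(torch.float32)).clamp(-FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX)
        return q.view(m, k), scale_fp8.view(m, k // block_size)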

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    🔒 No security concerns identified
    ⚡ Recommended focus areas for review
    Numerical Accuracy Concerns

    The test author expresses nervousness about the numerics of the block quantization op and uses very relaxed tolerances (rtol=1e-1, atol=1e-2). The test includes additional validation for large differences (>5.0), but the overall numerical accuracy appears questionable. The PR description mentions being unable to get results close to the decomposed version. This needs thorough validation to ensure the new nv_block_quantize op produces acceptable results for production use.

    def test_block_quantize_op_and_layout_op(
        nvfuser_direct_test,
        config,
        tokens_per_expert_neg_one,
        out_dtype,
    ):
        BLOCK_SIZE = 16
    
        # k dimension is multiple of 4 * 16 to avoid padding on block scaling factor
        m, n, k = config
        assert k % 64 == 0
        tokens_per_expert = list(tokens_per_expert_neg_one)
        tokens_per_expert.append(m - sum(tokens_per_expert))
        g = len(tokens_per_expert)
    
        mat1 = torch.randn((m, k), dtype=torch.float32, device="cuda:0")
        # format is g, n, k instead of g, k, n
        mat2 = torch.randn((g, n, k), dtype=torch.float32, device="cuda:0")
    
        offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0")
        blockscale_offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0")
        problem_sizes = torch.empty((g, 3), dtype=torch.int32, device="cuda:0")
    
        # prepare quantization for mat2
        mat2_gs = torch.empty((g,), dtype=torch.float32, device="cuda:0")
        scale2 = torch.empty(
            (g, n, k // BLOCK_SIZE), dtype=torch.float8_e4m3fn, device="cuda:0"
        )
    
        acc_tokens = 0
        rounded_acc_tokens = 0
        mat2_scaled = torch.empty(
            (g, n, k // 2), dtype=torch.float4_e2m1fn_x2, device="cuda:0"
        )
    
        for i in range(g):
            global_sf = FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX / mat2[i].max()
            offsets[i] = acc_tokens
            blockscale_offsets[i] = rounded_acc_tokens
            acc_tokens += tokens_per_expert[i]
            # Note: we technically don't need to round up, since k is perfectly sized.
            rounded_acc_tokens += round_up(tokens_per_expert[i], 128)
    
            problem_sizes[i][0] = tokens_per_expert[i]
            problem_sizes[i][1] = n
            problem_sizes[i][2] = k
    
            scaled_mat2_i, bs_mat2_i = pytorch_nvfp4_quantize(mat2[i], global_sf)
            mat2_gs[i] = 1.0 / global_sf
            mat2_scaled[i] = scaled_mat2_i
            scale2[i] = linear_to_swizzled_128_4(bs_mat2_i)
    
        def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
            mat1 = fd.define_tensor(
                shape=[-1, -1],
                contiguity=True,
                dtype=DataType.Float,
                is_cpu=False,
            )
            mat2 = fd.define_tensor(
                shape=[-1, -1, -1],
                contiguity=True,
                dtype=DataType.Float4_e2m1fn,
                is_cpu=False,
                stride_order=[2, 0, 1],
            )
            scale2 = fd.define_tensor(
                shape=[-1, -1, -1],
                contiguity=True,
                dtype=DataType.Float8_e4m3fn,
                is_cpu=False,
            )
            alpha = fd.define_tensor(
                shape=[-1], contiguity=True, dtype=DataType.Float, is_cpu=False
            )
            problem_sizes = fd.define_tensor(
                shape=[-1, -1], contiguity=True, dtype=DataType.Int32, is_cpu=False
            )
            offsets = fd.define_tensor(
                shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False
            )
            blockscale_offsets = fd.define_tensor(
                shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False
            )
    
            # Note: the decomposed quantization seems to give much better numerics.
            # quantization math with nv_block_quantize op
            fp4_mat1, fp8_scale1 = fd.ops.nv_block_quantize(mat1)
    
            # swizzle & pad block sf
            layout_fp8_scale1 = fd.ops.preprocess_grouped_matmul_input_sf(
                fp8_scale1, offsets, blockscale_offsets
            )
            out = fd.ops.cutlass_nvfp4_grouped_mm(
                fp4_mat1,
                mat2,
                layout_fp8_scale1,
                scale2,
                alpha,
                problem_sizes,
                offsets,
                blockscale_offsets,
                DataType.BFloat16,
            )
            fd.add_output(out)
    
        inputs = [
            mat1,
            mat2_scaled.view(torch.float4_e2m1fn_x2).transpose(-1, -2),
            scale2,
            mat2_gs,
            problem_sizes,
            offsets,
            blockscale_offsets,
        ]
    
        o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs)
        # quantization for activation is needed for reference.
        # note: following sglang implementation, not computing global scaling factor for mat1
        #       similarly, we don't need to apply mat1_gs to alpha
        mat1_gs = torch.ones((g,), dtype=torch.float32, device="cuda:0")
        mat1_fp4, scale1 = activation_scale_to_nvfp4(
            mat1, mat1_gs, offsets, blockscale_offsets, BLOCK_SIZE
        )
        o_decomposed_ref = torch.empty(m, n, dtype=torch.bfloat16, device="cuda:0")
        for i in range(g):
            l = offsets[i]
            l_sf = blockscale_offsets[i]
            if i == g - 1:
                r = m
            else:
                r = offsets[i + 1]
            r_sf = round_up(tokens_per_expert[i], 128) + l_sf
            # For some reason I cannot feed mat2_gs[i] as alpha in the torch kernel.
            # This triggers a cublas invalid value error.
            o_decomposed_ref[l:r] = (
                torch._scaled_mm(
                    mat1_fp4[l:r],
                    mat2_scaled[i].transpose(-1, -2),
                    scale1[l_sf:r_sf],
                    scale2[i],
                    None,
                    None,
                    torch.bfloat16,
                )
                * mat2_gs[i]
            )
    
        # Validate: nvfuser quantization should match baseline
        abs_diff = torch.abs(o[0] - o_decomposed_ref)
        max_diff = torch.max(abs_diff)
        assert max_diff <= 10.0, f"Max difference {max_diff:.4f} exceeds threshold of 10.0"
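
    The full relaxed comparison mentioned above (including the >5.0 large-difference check) is not shown in the snippet; a hypothetical reconstruction, based only on the reviewer note and not the test's actual code, might look like:

        # Hypothetical sketch of the relaxed validation described in the reviewer note;
        # thresholds and the ratio-based guard are assumptions, not the test's exact code.
        abs_diff = torch.abs(o[0] - o_decomposed_ref)
        assert torch.max(abs_diff) <= 10.0
        # Bound the fraction of elements that differ by more than 5.0.
        large_diff_ratio = (abs_diff > 5.0).float().mean().item()
        assert large_diff_ratio < 0.10, f"{large_diff_ratio:.2%} of elements differ by more than 5.0"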
    Test Data Generation

    The test changed from using torch.testing.make_tensor to torch.randn for input data generation. This change might affect test reproducibility and consistency. The reason for this change should be validated to ensure it doesn't introduce flakiness or mask numerical issues.

    mat1 = torch.randn((m, k), dtype=torch.float32, device="cuda:0")
    # format is g, n, k instead of g, k, n
    mat2 = torch.randn((g, n, k), dtype=torch.float32, device="cuda:0")
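
    For comparison, the previous make_tensor-based generation would have looked roughly like this (torch.testing.make_tensor is a standard PyTorch API; the exact arguments used before this change are an assumption):

        # Roughly what the earlier make_tensor-based input generation may have looked like.
        mat1 = torch.testing.make_tensor((m, k), dtype=torch.float32, device="cuda:0")
        mat2 = torch.testing.make_tensor((g, n, k), dtype=torch.float32, device="cuda:0")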

    Test failures

    • (High, 44) NCCL NVLink SHARP (NVLS) multicast memory binding errors in multidevice distributed / nvFuser test suites on dlcluster_viking_ci runner (H100, distributed)

      Failing tests:
      - tests.python.multidevice.test_communication.test_allgather
      - tests.python.multidevice.test_communication.test_allgather_expanded_broadcast
      - tests.python.multidevice.test_communication.test_allreduce
      - tests.python.multidevice.test_communication.test_reduce_scatter
      - tests.python.multidevice.test_communication.test_reduce_scatter_noncontiguous
      - tests.python.multidevice.test_dtensor.test_column_parallel_linear
      - tests.python.multidevice.test_dtensor.test_plus_one
      - tests.python.multidevice.test_dtensor.test_row_parallel_linear
      - tests.python.multidevice.test_expert_parallel.test_dispatch_and_combine
      - tests.python.multidevice.test_matmul.test_column_parallel_grouped_mm
      ... with 34 more test failures omitted. Check internal logs.

    • (Medium, 1) NCCL invalid usage in multidevice overlap allgather test (tests/python/multidevice), also on H100 (distributed)

      Failing test:
      - tests.python.multidevice.test_overlap.test_overlap_allgather_matmul_shard_outermost[backend_type=CommunicatorBackend.cuda]

    if constexpr (USE_GLOBAL_SCALE) {
      scaled_max = global_scale[0] / scaled_max;
    } else {
      scaled_max = 1.0 / scaled_max;
    }
    Collaborator Author


    This should be the right math, but it looks like it's causing C++ test errors. I'll double-check that.
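
    A small Python illustration of the two branches above (the variable semantics are assumed: scaled_max holds the block's scale before inversion, and the reciprocal is what the data gets multiplied by):

        # Sketch of the branch above in Python; names and semantics are assumptions.
        def reciprocal_block_scale(scaled_max, global_scale=None):
            if global_scale is not None:
                # With a global scaling factor, fold it into the reciprocal.
                return global_scale / scaled_max
            # Without one, fall back to a plain reciprocal.
            return 1.0 / scaled_max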

    @jjsjann123 jjsjann123 removed the request for review from protonu December 16, 2025 20:10
    @jjsjann123
    Collaborator Author

    I should re-evaluate the test after #5696

    @jjsjann123 jjsjann123 requested review from protonu and tbqh December 18, 2025 23:21
    @jjsjann123 jjsjann123 marked this pull request as ready for review December 18, 2025 23:21
    @greptile-apps
    Contributor

    greptile-apps bot commented Dec 18, 2025

    Greptile Summary

    This PR replaces the decomposed block quantization implementation with the single nv_block_quantize op in both the test suite and benchmark code. The change simplifies the code by consolidating manual reshape/scale/clamp operations into a single operation.

    Key Changes:

    • benchmark_inference.py: Removes ~20 lines of manual quantization math (reshape, abs, max, div, clamp operations) and replaces with single fd.ops.nv_block_quantize() call
    • test_narrow_precision.py: Adds new test test_block_quantize_op_and_layout_op that validates the op with grouped matmul operations

    Critical Issues:

    • The test uses extremely relaxed error thresholds (max diff 10.0, large diff ratio 10%) that may hide significant accuracy problems
    • The author explicitly acknowledges being "nervous about the numeric of block quantization op" and notes it "can't get it to perform close to what decomposed version does"
    • Test switched from torch.testing.make_tensor to torch.randn, reducing coverage of edge cases (zeros, infinities, NaNs)
    • Code comment on line 757 states "decomposed quantization seems to give much better numerics"

    Concerns:
    The numerical accuracy discrepancy between nv_block_quantize and the decomposed version is concerning. While the PR description mentions this will be "re-evaluated after PR #5696", the current thresholds are too permissive for production use. Consider whether the single-op convenience is worth the accuracy trade-off, or if the op implementation needs fixes before wider adoption.

    Confidence Score: 2/5

    • This PR has significant numerical accuracy concerns that need investigation before merging
    • Score reflects the acknowledged numerical accuracy issues with nv_block_quantize compared to decomposed implementation. The extremely relaxed test thresholds (max diff 10.0), reduced test coverage (switched to torch.randn), and author's explicit nervousness about numerics indicate this change may introduce regressions. While the code simplification is valuable, merging without understanding the accuracy gap risks production issues.
    • Pay close attention to tests/python/direct/test_narrow_precision.py - the test thresholds mask potential accuracy issues that need investigation

    Important Files Changed

    tests/python/direct/test_narrow_precision.py — Adds test for nv_block_quantize op with grouped matmul, but with extremely relaxed error thresholds (max diff 10.0) that mask potential accuracy issues

    benchmarks/python/benchmark_inference.py — Replaces decomposed quantization implementation with single nv_block_quantize op call, simplifying code and removing manual constant imports

    Sequence Diagram

    sequenceDiagram
        participant Test as Test Code
        participant NVFuser as nvFuser Fusion
        participant BlockQuant as nv_block_quantize Op
        participant Layout as preprocess_grouped_matmul_input_sf
        participant GroupedMM as cutlass_nvfp4_grouped_mm
        participant Reference as Decomposed Reference
    
        Note over Test: Prepare input matrices
        Test->>Test: Create mat1 (m, k) float32
        Test->>Test: Create mat2 (g, n, k) float32
        Test->>Test: Quantize mat2 to FP4 with block scales
    
        Note over NVFuser: nvFuser path (new)
        Test->>NVFuser: Execute fusion with mat1
        NVFuser->>BlockQuant: nv_block_quantize(mat1)
        BlockQuant-->>NVFuser: (fp4_mat1, fp8_scale1)
        NVFuser->>Layout: preprocess_grouped_matmul_input_sf(fp8_scale1)
        Layout-->>NVFuser: layout_fp8_scale1
        NVFuser->>GroupedMM: cutlass_nvfp4_grouped_mm(fp4_mat1, mat2, scales...)
        GroupedMM-->>NVFuser: output (bfloat16)
        NVFuser-->>Test: nvfuser_output
    
        Note over Reference: Reference path (decomposed)
        Test->>Reference: activation_scale_to_nvfp4(mat1)
        Reference-->>Test: (mat1_fp4, scale1)
        loop For each expert group
            Test->>Reference: torch._scaled_mm(mat1_fp4[slice], mat2[i], scales...)
            Reference-->>Test: output[slice]
            Test->>Test: Multiply by mat2_gs[i]
        end
    
        Note over Test: Validation
        Test->>Test: Compare abs_diff = |nvfuser_output - reference|
        Test->>Test: Assert max_diff <= 10.0
        Test->>Test: Assert large_diff_ratio < 10%
    

    Contributor

    @greptile-apps greptile-apps bot left a comment


    2 files reviewed, 2 comments


    @protonu
    Collaborator

    protonu commented Dec 22, 2025

    This looks good to me - do you want to rebase on #5696 and test out the accuracy and we can land both these PRs.

    @jjsjann123
    Collaborator Author

    This looks good to me - do you want to rebase on #5696 and test out the accuracy and we can land both these PRs.

    Err, I actually still won't be able to improve the reference result from this, since the TE quantization doesn't allow me to skip the global scaling factor (for nvfp4_quantize_with_te). I don't see an example of how to do that for the activation in grouped mm.

    I might have to stick with the current validation. 😢

    @jjsjann123
    Collaborator Author

    !test

    Contributor

    @greptile-apps greptile-apps bot left a comment


    2 files reviewed, 3 comments


    Collaborator

    @protonu protonu left a comment


    LGTM!

    @jjsjann123
    Collaborator Author

    The failing NCCL error seems to be unrelated. Merging as-is.

    @jjsjann123
    Collaborator Author

    @protonu I can use an approving stamp to merge.

    @jjsjann123 jjsjann123 added the enable-auto-merge label (Auto-merge a PR when: 1) PR mergeable 2) Internal CI complete 3) No failures) Dec 23, 2025