Conversation

jjsjann123 (Collaborator) commented Jan 8, 2026

Context

This series of PRs enables a single kernel to handle both quantization and the block scaling factor layout for grouped tensors.

The existing solution for nvfp4 quantization of the activation tensor for grouped_mm relies on two operations:
i. BlockQuantizationOp produces scaled_tv and block_scaling_factor.
ii. block_scaling_factor is then processed by PreprocessGroupedMatmulInputSf to satisfy the swizzled layout required by grouped_mm kernels.

This series of PRs merges the two operations into a single one.
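
As a rough sketch of what this means at the Python-frontend level (the argument lists of the two existing ops are paraphrased from the descriptions above, not copied from the actual bindings):

    # Before: quantize, then swizzle the block scaling factor in a second op
    fp4_act, bsf = fd.ops.nv_block_quantize(act)                      # BlockQuantizationOp
    swizzled_bsf = fd.ops.preprocess_grouped_matmul_input_sf(         # PreprocessGroupedMatmulInputSf
        bsf, offsets, blockscale_offsets
    )

    # After: quantization and swizzled layout produced by a single kernel
    fp4_act, swizzled_bsf = fd.ops.nv_grouped_block_quantize(
        act, offsets, blockscale_offsets
    )                                                                 # GroupedBlockQuantizationOp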

Stacked PRs

#5775 GroupedBlockQuantizationOp PR0: Adding runtime function
#5776 GroupedBlockQuantizationOp PR1: Adding codegen support
#5777 GroupedBlockQuantizationOp PR2: Adding python API and updating llama4 benchmark

What's in this PR

  1. Added Python API ops.nv_grouped_block_quantize for GroupedBlockQuantizationOp (see the usage sketch below);
  2. Added Python translation rule for GroupedBlockQuantizationOp;
  3. Added a Python test for GroupedBlockQuantizationOp;
  4. Switched the two-operation quantization of the grouped_mm activation to use GroupedBlockQuantizationOp instead.
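
For reference, a minimal usage sketch of the new API, following the binding signature quoted in the review below. Which of global_scale, block_size, and dtype carry defaults, and what those defaults are, is an assumption here, inferred from the three-argument calls in the benchmark and test:

    # Hypothetical full-signature call; the default values shown are assumptions.
    quantized, block_scales = fd.ops.nv_grouped_block_quantize(
        act,                   # activation tensor to be quantized, shape [m, k]
        offsets,               # int32 per-group row offsets into the activation
        blockscale_offsets,    # int32 per-group row offsets into the swizzled scaling factor
        global_scale,          # optional per-tensor global scale
        block_size=16,         # assumed default block size
        dtype=DataType.Float4_e2m1fn,  # assumed default quantized dtype
    )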

github-actions bot commented Jan 8, 2026

Review updated until commit 5c905de

Description

  • Added Python API ops.nv_grouped_block_quantize for GroupedBlockQuantizationOp

  • Implemented Python translation rule for GroupedBlockQuantizationOp

  • Updated llama4 benchmark to use consolidated quantization operation

  • Added comprehensive test for GroupedBlockQuantizationOp functionality

Changes walkthrough

Relevant files

Enhancement

ops.cpp: Add Python API for GroupedBlockQuantizationOp
python/python_direct/ops.cpp (+49/-0)

  • Added nv_grouped_block_quantize function to Python ops module
  • Function takes input tensor, offsets, global_scale, block_size and dtype parameters
  • Returns tuple of quantized tensor and block scaling factors
  • Includes comprehensive docstring with parameter descriptions

python_translate.cpp: Implement Python translation rule for GroupedBlockQuantizationOp
python/python_direct/python_translate.cpp (+25/-0)

  • Added handle method for GroupedBlockQuantizationOp in PythonTranslator
  • Generates Python translation for grouped block quantization operation
  • Handles default arguments for global_scale, block_size, and dtype
  • Maps operation inputs and outputs to Python syntax

benchmark_inference.py: Update benchmark to use consolidated quantization operation
benchmarks/python/benchmark_inference.py (+1/-2)

  • Updated nvfp4_grouped_mm_translator to use nv_grouped_block_quantize
  • Replaced two-step process (nv_block_quantize + preprocess_grouped_matmul_input_sf)
  • Consolidated quantization and layout handling into a single operation
  • Simplified benchmark code by removing intermediate variables

Tests

test_narrow_precision.py: Add comprehensive test for GroupedBlockQuantizationOp
tests/python/direct/test_narrow_precision.py (+167/-0)

  • Added test_grouped_block_quantize_op function with comprehensive validation
  • Tests grouped block quantization with multiple configuration parameters
  • Validates output against a reference PyTorch implementation
  • Checks quantization accuracy with max-difference and large-diff-ratio thresholds

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Performance validation

The benchmark update shows the main optimization (replacing two operations with one), but no performance metrics are provided in the PR. Should validate that the single-kernel approach actually provides the expected performance improvement over the previous two-operation approach.

    fp4_mat1, layout_fp8_scale1 = fd.ops.nv_grouped_block_quantize(nv_act, nv_offsets, nv_blocksf_offsets)

API consistency

The new nv_grouped_block_quantize API uses different parameter ordering and defaults compared to similar quantization ops. Should verify the API design is consistent with existing patterns and that all parameters are properly validated.

    "nv_grouped_block_quantize",
    [](TensorView* input,
       TensorView* input_offsets,
       TensorView* output_offsets,
       TensorView* global_scale,
       int64_t block_size,
       PrimDataType dtype) -> py::tuple {
      auto output = groupedBlockQuantize(
          input,
          input_offsets,
          output_offsets,
          BlockScalingFactorLayout::Block128x4,
          global_scale,
          block_size,
          dtype);
      return py::make_tuple(output.quantized_tensor, output.block_scales);
    },

Test coverage

The test covers the basic functionality but only tests one configuration. Should verify if additional test cases are needed for edge cases, different block sizes, or various tensor dimensions to ensure robustness.

    def test_grouped_block_quantize_op(
        nvfuser_direct_test,
        config,
        tokens_per_expert_neg_one,
        out_dtype,
    ):
        BLOCK_SIZE = 16
    
        # k dimension is multiple of 4 * 16 to avoid padding on block scaling factor
        m, n, k = config
        assert k % 64 == 0
        tokens_per_expert = list(tokens_per_expert_neg_one)
        tokens_per_expert.append(m - sum(tokens_per_expert))
        g = len(tokens_per_expert)
    
        mat1 = torch.randn((m, k), dtype=torch.float32, device="cuda:0")
        # format is g, n, k instead of g, k, n
        mat2 = torch.randn((g, n, k), dtype=torch.float32, device="cuda:0")
    
        offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0")
        blockscale_offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0")
        problem_sizes = torch.empty((g, 3), dtype=torch.int32, device="cuda:0")
    
        # prepare quantization for mat2
        mat2_gs = torch.empty((g,), dtype=torch.float32, device="cuda:0")
        scale2 = torch.empty(
            (g, n, k // BLOCK_SIZE), dtype=torch.float8_e4m3fn, device="cuda:0"
        )
    
        acc_tokens = 0
        rounded_acc_tokens = 0
        mat2_scaled = torch.empty(
            (g, n, k // 2), dtype=torch.float4_e2m1fn_x2, device="cuda:0"
        )
    
        for i in range(g):
            global_sf = FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX / mat2[i].max()
            offsets[i] = acc_tokens
            blockscale_offsets[i] = rounded_acc_tokens
            acc_tokens += tokens_per_expert[i]
            # Note: we technically don't need to round up, since k is perfectly sized.
            rounded_acc_tokens += round_up(tokens_per_expert[i], 128)
    
            problem_sizes[i][0] = tokens_per_expert[i]
            problem_sizes[i][1] = n
            problem_sizes[i][2] = k
    
            scaled_mat2_i, bs_mat2_i = pytorch_nvfp4_quantize(mat2[i], global_sf)
            mat2_gs[i] = 1.0 / global_sf
            mat2_scaled[i] = scaled_mat2_i
            scale2[i] = linear_to_swizzled_128_4(bs_mat2_i)
    
        def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
            mat1 = fd.define_tensor(
                shape=[-1, -1],
                contiguity=True,
                dtype=DataType.Float,
                is_cpu=False,
            )
            mat2 = fd.define_tensor(
                shape=[-1, -1, -1],
                contiguity=True,
                dtype=DataType.Float4_e2m1fn,
                is_cpu=False,
                stride_order=[2, 0, 1],
            )
            scale2 = fd.define_tensor(
                shape=[-1, -1, -1],
                contiguity=True,
                dtype=DataType.Float8_e4m3fn,
                is_cpu=False,
            )
            alpha = fd.define_tensor(
                shape=[-1], contiguity=True, dtype=DataType.Float, is_cpu=False
            )
            problem_sizes = fd.define_tensor(
                shape=[-1, -1], contiguity=True, dtype=DataType.Int32, is_cpu=False
            )
            offsets = fd.define_tensor(
                shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False
            )
            blockscale_offsets = fd.define_tensor(
                shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False
            )
    
            fp4_mat1, fp8_scale1 = fd.ops.nv_grouped_block_quantize(
                mat1, offsets, blockscale_offsets
            )
    
            out = fd.ops.cutlass_nvfp4_grouped_mm(
                fp4_mat1,
                mat2,
                fp8_scale1,
                scale2,
                alpha,
                problem_sizes,
                offsets,
                blockscale_offsets,
                DataType.BFloat16,
            )
            fd.add_output(out)
    
        inputs = [
            mat1,
            mat2_scaled.view(torch.float4_e2m1fn_x2).transpose(-1, -2),
            scale2,
            mat2_gs,
            problem_sizes,
            offsets,
            blockscale_offsets,
        ]
    
        o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs)
        # quantization for activation is needed for reference.
        # note: following sglang implementation, not computing global scaling factor for mat1
        #       similarly, we don't need to apply mat1_gs to alpha
        mat1_gs = torch.ones((g,), dtype=torch.float32, device="cuda:0")
        mat1_fp4, scale1 = activation_scale_to_nvfp4(
            mat1, mat1_gs, offsets, blockscale_offsets, BLOCK_SIZE
        )
        o_decomposed_ref = torch.empty(m, n, dtype=torch.bfloat16, device="cuda:0")
        for i in range(g):
            l = offsets[i]
            l_sf = blockscale_offsets[i]
            if i == g - 1:
                r = m
            else:
                r = offsets[i + 1]
            r_sf = round_up(tokens_per_expert[i], 128) + l_sf
            # For some reason I cannot feed mat2_gs[i] as alpha in the torch kernel.
            # This triggers a cublas invalid value error.
            o_decomposed_ref[l:r] = (
                torch._scaled_mm(
                    mat1_fp4[l:r],
                    mat2_scaled[i].transpose(-1, -2),
                    scale1[l_sf:r_sf],
                    scale2[i],
                    None,
                    None,
                    torch.bfloat16,
                )
                * mat2_gs[i]
            )
    
        # Validate: nvfuser quantization should match baseline
        abs_diff = torch.abs(o[0] - o_decomposed_ref)
        max_diff = torch.max(abs_diff)
        assert max_diff <= 10.0, f"Max difference {max_diff:.4f} exceeds threshold of 10.0"
    
        # Check that large differences (> 5.0) are rare (< 10% of elements)
        large_diff_count = torch.count_nonzero(torch.gt(abs_diff, 5.0))
        large_diff_ratio = large_diff_count / abs_diff.numel()
        assert (
            large_diff_ratio < 0.1
        ), f"Large diff ratio {large_diff_ratio:.2%} exceeds 10% threshold"

jjsjann123 changed the title from "PR2: adding python API and updating benchmarks" to "GroupedBlockQuantizeOp PR2: Adding python API and updating llama4 benchmark" on Jan 8, 2026
jjsjann123 added a commit that referenced this pull request on Jan 9, 2026
    ## What's in this PR

    1. refactor existing runtime function for re-use by the new op;
    2. added runtime function for GroupedBlockQuantizeOp.
