
Conversation

@jjsjann123
Collaborator

@jjsjann123 jjsjann123 commented Jan 8, 2026

Context

This series of PRs enables a single kernel for quantization and layout handling of the block scaling factor on grouped tensors.

The existing solution for nvfp4 quantization of an activation tensor for grouped_mm relies on two operations:
i. BlockQuantizationOp produces scaled_tv and block_scaling_factor.
ii. block_scaling_factor is then processed by PreprocessGroupedMatmulInputSf to satisfy the swizzle layout required by grouped_mm kernels.

This series of PRs merges the two operations into a single one.

Stacked PRs

#5775 GroupedBlockQuantizationOp PR0: Adding runtime function
#5776 GroupedBlockQuantizationOp PR1: Adding codegen support
#5777 GroupedBlockQuantizationOp PR2: Adding python API and updating llama4 benchmark

What's in this PR

  1. Refactor existing runtime functions for re-use by the new op.
  2. Add a runtime function for GroupedBlockQuantizeOp.

@jjsjann123
Collaborator Author

!test

@github-actions

github-actions bot commented Jan 8, 2026

Review updated until commit e66c444

Description

  • Refactored runtime functions by merging block_layout.cu into block_quantization_kernels.cu

  • Moved preprocessGroupedMatmulInputSf function from block_layout namespace to bq namespace

  • Added grouped_block_quantize_to_nvfp4 runtime function for combined quantization and layout handling

  • Updated codegen to use new bq::preprocessGroupedMatmulInputSf function

  • Removed block_layout.cu from build system and resource includes

Changes walkthrough

Relevant files

Enhancement

codegen.cpp (csrc/codegen.cpp): Update function call to use new namespace

  • Updated function call from block_layout::preprocessGroupedMatmulInputSf to bq::preprocessGroupedMatmulInputSf
  • Reflects namespace change for refactored runtime function
  • +1/-1

compiled_kernel.cpp (csrc/runtime/compiled_kernel.cpp): Update resource includes and compilation logic

  • Removed include of the deleted block_layout.h header
  • Removed conditional inclusion of the block_layout_cu resource
  • Updated block_quantization_kernels inclusion to handle both the block_layout and block_quantize_op cases
  • +1/-5

block_layout.cu (runtime/block_layout.cu): Delete merged file with relocated functions

  • Entire file deleted; functions moved to block_quantization_kernels.cu
  • preprocessGroupedMatmulInputSf and outputOffsetAfterSwizzlePadding functions relocated
  • +0/-102

block_quantization_kernels.cu (runtime/block_quantization_kernels.cu): Consolidate runtime functions and add grouped quantization

  • Added outputOffsetAfterSwizzlePadding function (moved from block_layout.cu)
  • Added preprocessGroupedMatmulInputSf function (moved from block_layout.cu)
  • Added exp2f_rcp and block_quantize_to_mxfp8 utility functions
  • Refactored block_quantize_to_nvfp4 by extracting common logic into block_quantize_to_nvfp4_util
  • Added grouped_block_quantize_to_nvfp4 function for combined quantization and layout handling
  • +233/-73

Configuration changes

CMakeLists.txt: Remove deleted file from build system

  • Removed runtime/block_layout.cu from the NVFUSER_RUNTIME_FILES list
  • Reflects file deletion and function consolidation
  • +0/-1

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    🔒 No security concerns identified
    ⚡ Recommended focus areas for review
    Code Quality

    The new grouped_block_quantize_to_nvfp4 function uses a linear search to find the expert_id, which could be inefficient for large group sizes. Consider whether a binary search or another optimization is warranted; a sketch follows the snippet below.

    // find corresponding expert_id
    int expert_id = group_size - 1;
    for (int i = 1; i < group_size; ++i) {
      if (row_idx < input_offsets[i]) {
        expert_id = i - 1;
        break;
      }
    }
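
    If input_offsets is sorted in ascending order, as the linear scan above assumes, a binary search keeps the lookup at O(log group_size). A minimal sketch, assuming input_offsets[0] == 0 and a raw-pointer parameter; the helper name and signature are illustrative, not the PR's actual code:

    // Illustrative only: returns the last i with input_offsets[i] <= row_idx,
    // matching the result of the linear search above.
    __device__ int find_expert_id(
        const nvfuser_index_t* input_offsets,
        int group_size,
        nvfuser_index_t row_idx) {
      int lo = 0;
      int hi = group_size - 1;
      while (lo < hi) {
        int mid = (lo + hi + 1) / 2; // bias up so lo always advances
        if (input_offsets[mid] <= row_idx) {
          lo = mid;
        } else {
          hi = mid - 1;
        }
      }
      return lo;
    }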
    API Consistency

    The PR refactors block_quantize_to_nvfp4 into block_quantize_to_nvfp4_util and block_quantize_to_nvfp4, but the original function signature is preserved. Verify that all existing callers will continue to work correctly with the refactored interface.

    __device__ void block_quantize_to_nvfp4(
        const Array<T, ITEMS_PER_THREAD, ALIGNMENT_1>& input,
        Array<__e2m1, ITEMS_PER_THREAD, ALIGNMENT_2>& output,
        Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& block_scales,
        nvfuser_index_t logical_index,
        Tensor<float, 0, 0> global_scale,
        int64_t fp8_scaling_factors_inner_dim = -1,
        int64_t alloc_dim0 = -1,
        int64_t alloc_dim1 = -1,
        int64_t alloc_dim2 = -1,
        int64_t alloc_dim3 = -1,
        int64_t alloc_dim4 = -1) {
      // Write out the block scaling factor to global memory.
      // This assumes 16 elements in the input were contiguous.
      // Only one block scaling factor is written out per 16(assumed block size)
      // elements.
      int offset = logical_index / 16;
    
      if (fp8_scaling_factors_inner_dim > 0) {
        auto stride_4 = 1;
        auto stride_3 = stride_4 * alloc_dim4;
        auto stride_2 = stride_3 * alloc_dim3;
        auto stride_1 = stride_2 * alloc_dim2;
        auto stride_0 = stride_1 * alloc_dim1;
    
        auto logical_inner = offset % fp8_scaling_factors_inner_dim;
        auto logical_outer = offset / fp8_scaling_factors_inner_dim;
    
        // The allocation domain swizzle logic is:
        // m, k -> m, k/4, 4
        // m, k/4, 4 -> m/128, 128, k/4, 4 ->
        // m/128, 4(m), 32, k/4, 4(k) ->
        // m/128, k/4, 32, 4(m), 4(k)
    
        auto pos_4 = logical_inner % 4;
        auto pos_1 = logical_inner / 4;
        auto pos_t = logical_outer % 128;
        auto pos_0 = logical_outer / 128;
        auto pos_3 = pos_t / 32;
        auto pos_2 = pos_t % 32;
    
        offset = pos_4 * stride_4 + pos_3 * stride_3 + pos_2 * stride_2 +
            pos_1 * stride_1 + pos_0 * stride_0;
      }
      block_quantize_to_nvfp4_util<USE_GLOBAL_SCALE>(
          input, output, block_scales, global_scale, offset);
    }
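
    For reference, a minimal host-side sketch (not part of the PR) that mirrors the swizzle arithmetic above and can be used to sanity-check the mapping on the CPU; the sample dimensions in main are illustrative assumptions only:

    #include <cstdint>
    #include <cstdio>

    // offset is the linear index of a block scaling factor, inner is
    // fp8_scaling_factors_inner_dim, and d1..d4 correspond to alloc_dim1..alloc_dim4.
    int64_t swizzled_offset(
        int64_t offset, int64_t inner,
        int64_t d1, int64_t d2, int64_t d3, int64_t d4) {
      int64_t stride_4 = 1;
      int64_t stride_3 = stride_4 * d4;
      int64_t stride_2 = stride_3 * d3;
      int64_t stride_1 = stride_2 * d2;
      int64_t stride_0 = stride_1 * d1;
      int64_t logical_inner = offset % inner;
      int64_t logical_outer = offset / inner;
      int64_t pos_4 = logical_inner % 4;
      int64_t pos_1 = logical_inner / 4;
      int64_t pos_t = logical_outer % 128;
      int64_t pos_0 = logical_outer / 128;
      int64_t pos_3 = pos_t / 32;
      int64_t pos_2 = pos_t % 32;
      return pos_4 * stride_4 + pos_3 * stride_3 + pos_2 * stride_2 +
          pos_1 * stride_1 + pos_0 * stride_0;
    }

    int main() {
      // Example: 8 scaling factors per row, alloc dims {2, 32, 4, 4}; prints 596.
      std::printf("%lld\n", (long long)swizzled_offset(300, 8, 2, 32, 4, 4));
      return 0;
    }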

    @jjsjann123 jjsjann123 changed the title PR0: Adding runtime function for GroupedBlockQuantizeOp GroupedBlockQuantizeOp PR0: Adding runtime function Jan 8, 2026
    @jjsjann123
    Collaborator Author

    !test

    @jjsjann123 jjsjann123 marked this pull request as ready for review January 8, 2026 01:36
    @jjsjann123 jjsjann123 requested a review from protonu January 8, 2026 01:36
    @greptile-apps
    Contributor

    greptile-apps bot commented Jan 8, 2026

    Greptile Summary

    Refactored block quantization runtime code by consolidating block_layout.cu functionality into block_quantization_kernels.cu to enable code reuse for the new GroupedBlockQuantizeOp.

    • Moved outputOffsetAfterSwizzlePadding and preprocessGroupedMatmulInputSf from nvf::block_layout namespace to nvf::bq namespace
    • Extracted shared quantization logic into block_quantize_to_nvfp4_util helper function
    • Added new grouped_block_quantize_to_nvfp4 runtime function combining quantization with swizzle layout handling
    • Updated codegen to reference bq::preprocessGroupedMatmulInputSf instead of block_layout::preprocessGroupedMatmulInputSf
    • Consolidated build configuration to include block_quantization_kernels_cu for both block layout and quantization operations

    Confidence Score: 4/5

    • Safe to merge with minimal risk - primarily code consolidation and refactoring
    • This is a well-structured refactoring that consolidates related functionality without changing core logic. The code moves existing functions between files and extracts shared logic into reusable utilities. However, confidence is 4 rather than 5 because the new grouped_block_quantize_to_nvfp4 function has not been tested in this PR (testing comes in subsequent PRs), and there are minor logical concerns about edge case handling in the expert_id lookup.
    • Pay close attention to runtime/block_quantization_kernels.cu for the new grouped quantization logic

    Important Files Changed

    • csrc/codegen.cpp: Updated namespace reference from block_layout:: to bq:: for the preprocessGroupedMatmulInputSf function call
    • csrc/runtime/compiled_kernel.cpp: Removed the block_layout.h include and merged conditional logic to include block_quantization_kernels_cu for both layout and quantize ops
    • runtime/block_quantization_kernels.cu: Added outputOffsetAfterSwizzlePadding, preprocessGroupedMatmulInputSf, and grouped_block_quantize_to_nvfp4 functions; refactored block_quantize_to_nvfp4 to extract shared logic into block_quantize_to_nvfp4_util

    Sequence Diagram

    sequenceDiagram
        participant Kernel as Generated Kernel Code
        participant CG as CodeGen (codegen.cpp)
        participant CK as CompiledKernel (compiled_kernel.cpp)
        participant BQ as block_quantization_kernels.cu
        
        Note over CG,CK: Build Phase
        CK->>CK: Check has_block_layout || has_block_quantize_op
        CK->>BQ: Include block_quantization_kernels_cu
        Note over BQ: Contains merged functionality:<br/>- outputOffsetAfterSwizzlePadding<br/>- preprocessGroupedMatmulInputSf<br/>- block_quantize_to_nvfp4_util<br/>- grouped_block_quantize_to_nvfp4
        
        Note over Kernel,BQ: Runtime Phase
        Kernel->>BQ: Call bq::preprocessGroupedMatmulInputSf
        Note over BQ: Uses outputOffsetAfterSwizzlePadding<br/>for swizzle layout computation
        
        Kernel->>BQ: Call bq::grouped_block_quantize_to_nvfp4
        BQ->>BQ: Find expert_id from input_offsets
        BQ->>BQ: Compute output offset with swizzle
        BQ->>BQ: Call block_quantize_to_nvfp4_util
        Note over BQ: Performs quantization and<br/>writes block scaling factors
    

    @greptile-apps
    Contributor

    greptile-apps bot commented Jan 8, 2026

    Greptile's behavior is changing!

    From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

    This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

    @jjsjann123
    Collaborator Author

    !test

    Contributor

    @greptile-apps greptile-apps bot left a comment


    Greptile Overview

    Greptile Summary

    This PR refactors runtime functions to prepare for grouped block quantization operations. The changes consolidate block layout functions from block_layout.cu into block_quantization_kernels.cu and introduce new runtime functions.

    Key changes:

    • Moved outputOffsetAfterSwizzlePadding and preprocessGroupedMatmulInputSf from block_layout.cu to block_quantization_kernels.cu, under the bq:: namespace
    • Refactored block_quantize_to_nvfp4 by extracting core logic into block_quantize_to_nvfp4_util for reuse
    • Added new grouped_block_quantize_to_nvfp4 function that combines quantization with grouped layout handling
    • Added block_quantize_to_mxfp8 function for FP8 quantization
    • Updated build system to remove block_layout.cu and consolidated resource loading

    Issue found:

    • Parameter type mismatch in block_quantize_to_nvfp4_util: it expects global_scale by reference, but callers pass it by value

    Confidence Score: 3/5

    • This PR has a parameter type mismatch that needs to be fixed before merging
    • The refactoring is well-structured and consolidates related functionality appropriately. However, there's a critical parameter type inconsistency in block_quantize_to_nvfp4_util where global_scale is declared as a reference but callers pass it by value. This will cause compilation errors or incorrect behavior and must be resolved.
    • Pay close attention to runtime/block_quantization_kernels.cu - fix the parameter type mismatch on line 202

    Important Files Changed

    File Analysis

    • CMakeLists.txt (score 5/5): Removed block_layout.cu from the build files, as the functions were consolidated into block_quantization_kernels.cu
    • csrc/runtime/compiled_kernel.cpp (score 5/5): Removed the block_layout.h include and consolidated resource loading; block_quantization_kernels_cu is now loaded for both the has_block_layout and has_block_quantize_op flags
    • runtime/block_quantization_kernels.cu (score 3/5): Major refactoring; added helper functions from block_layout.cu, refactored block_quantize_to_nvfp4 into a utility function, added new block_quantize_to_mxfp8 and grouped_block_quantize_to_nvfp4 functions; a parameter type mismatch was found

    Sequence Diagram

    sequenceDiagram
        participant Caller as Fusion Kernel
        participant GBQ as grouped_block_quantize_to_nvfp4
        participant Util as block_quantize_to_nvfp4_util
        participant Helper as outputOffsetAfterSwizzlePadding
        
        Note over Caller,Helper: New Grouped Block Quantization Flow
        
        Caller->>GBQ: input, row_idx, col_idx, offsets, group_size
        
        Note over GBQ: Find expert_id for current row
        GBQ->>GBQ: Search input_offsets array
        
        Note over GBQ: Calculate group-relative indices
        GBQ->>GBQ: c_row_idx = row_idx - input_offsets[expert_id]
        GBQ->>GBQ: padded_col_size = ceil(col_size/BLOCK_SIZE / BLOCK_COL) * BLOCK_COL
        
        GBQ->>Helper: c_row_idx, col_idx/BLOCK_SIZE, padded_col_size
        Helper-->>GBQ: swizzled index
        
        Note over GBQ: Calculate output offset
        GBQ->>GBQ: offset = output_offsets[expert_id] * padded_col_size + index
        
        GBQ->>Util: input, output, block_scales, global_scale, offset
        
        Note over Util: Quantization Logic
        Util->>Util: Convert to float & compute max
        Util->>Util: Reduce across threads
        Util->>Util: Scale with global_scale (if enabled)
        Util->>Util: Clamp to FP8
        Util->>Util: Write block_scales[offset]
        Util->>Util: Quantize to nvfp4
        
        Util-->>GBQ: quantized output
        GBQ-->>Caller: output with swizzled layout
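
    Assembled from the diagram above, a rough device-side sketch of the offset computation in the grouped flow. The helper prototype, constants, and parameter types are assumptions for illustration; this is not the PR's actual code:

    // Assumed prototype of the relocated layout helper (actual signature may differ).
    __device__ nvfuser_index_t outputOffsetAfterSwizzlePadding(
        nvfuser_index_t row, nvfuser_index_t col, nvfuser_index_t padded_col_size);

    // BLOCK_SIZE is the quantization block size (16 for nvfp4); BLOCK_COL is the
    // column padding granularity required by the swizzle layout.
    template <int BLOCK_SIZE, int BLOCK_COL>
    __device__ nvfuser_index_t grouped_scale_offset(
        nvfuser_index_t row_idx,
        nvfuser_index_t col_idx,
        nvfuser_index_t col_size,
        const nvfuser_index_t* input_offsets,   // per-group input row offsets
        const nvfuser_index_t* output_offsets,  // per-group padded output row offsets
        int group_size) {
      // Find the expert (group) that owns this row.
      int expert_id = group_size - 1;
      for (int i = 1; i < group_size; ++i) {
        if (row_idx < input_offsets[i]) {
          expert_id = i - 1;
          break;
        }
      }
      // Group-relative row index and padded count of scaling-factor columns.
      nvfuser_index_t c_row_idx = row_idx - input_offsets[expert_id];
      nvfuser_index_t padded_col_size =
          (col_size / BLOCK_SIZE + BLOCK_COL - 1) / BLOCK_COL * BLOCK_COL;
      // Swizzled index within the group, then the global offset into block_scales;
      // the actual kernel passes this offset on to block_quantize_to_nvfp4_util.
      nvfuser_index_t index = outputOffsetAfterSwizzlePadding(
          c_row_idx, col_idx / BLOCK_SIZE, padded_col_size);
      return output_offsets[expert_id] * padded_col_size + index;
    }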
    

    Collaborator

    @protonu protonu left a comment


    LGTM

    @jjsjann123 jjsjann123 merged commit be448c4 into main Jan 9, 2026
    63 checks passed
    @jjsjann123 jjsjann123 deleted the jj/grouped_block_quantize_op_0 branch January 9, 2026 19:53
