Skip to content

Conversation

@liqiangxl
Copy link
Collaborator

  1. Add flashinfer in benchmarks.

  2. Two optimizations for inner persistent scheduler:
    (1) Given a LLM model, the hidden dimension is fixed, we can use static shape in that dimension. Change bdimx to static.
    (2) For input with small batch size, not all SMs are used, pick a large bdimx to allow more warps.
    After these two optimizations, the performance of nvFuser is better than flashinfer.

image

raw data: https://docs.google.com/spreadsheets/d/1JMea0s_Z2mYDgbdNSI7-pS4MtdUBrIONl5EqiezFrDo/edit?usp=sharing

@github-actions
Copy link

github-actions bot commented Oct 1, 2025

Review updated until commit 5bb64ee

Description

  • Enable static block dimension for inner persistent scheduler

  • Optimize warp utilization for small batch sizes

  • Improve unrolling logic for persistent batch domains

  • Add FlashInfer benchmark for RMSNormAdd


Changes walkthrough 📝

Relevant files
Enhancement
normalization_inner.cpp
Enable static bdimx for persistent scheduling                       

csrc/scheduler/normalization_inner.cpp

  • Set static_bdimx based on static reduction size
  • Enable compute with first consumer for static case
  • Round up bdimx to multiple of warp size
  • Conditionally disable static bdimx via env var
  • +22/-1   
    normalization_utils.cpp
    Refine persistent batch unrolling logic                                   

    csrc/scheduler/normalization_utils.cpp

  • Pass is_static_reduction_size to kernel properties
  • Simplify unroll condition by removing batch count check
  • Identify persistent batch domain by split and serial parallel type
  • Ensure single domain matches criteria
  • +7/-13   
    utils.cpp
    Detect static reduction size in properties                             

    csrc/scheduler/utils.cpp

  • Add is_static_reduction_size detection
  • Check all reduction domains for const extent
  • Set property in ReductionTvProperties
  • Ignore non-reduction domains in check
  • +8/-0     
    normalization_utils.h
    Expose static reduction size in properties                             

    csrc/scheduler/normalization_utils.h

  • Add is_static_reduction_size to PersistentKernelProperties
  • Include in toString() output for debugging
  • Document new field in struct
  • +2/-0     
    utils.h
    Declare static reduction size property                                     

    csrc/scheduler/utils.h

  • Add is_static_reduction_size field to ReductionTvProperties
  • Initialize to false by default
  • Document meaning in comment
  • +3/-0     
    Bug fix
    reduction_utils.cpp
    Adapt reduction scheduling for static bdimx                           

    csrc/scheduler/reduction_utils.cpp

  • Use static bdimx for splitting when enabled
  • Apply padToMultipleOfWarp only in non-static case
  • Adjust parallelization logic based on static_bdimx
  • Fix conditional placement of padding
  • +16/-8   
    Tests
    test_rmsnorm_add_fwd.py
    Add FlashInfer RMSNormAdd benchmark                                           

    benchmarks/python/test_rmsnorm_add_fwd.py

  • Import fused_add_rmsnorm inside wrapper
  • Enable FlashInfer benchmark comparison
  • Support in-place update semantics
  • Return updated tensors for IO measurement
  • +2/-1     

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Conditional Override

    The use of environment variables to override static_bdimx and compute_persistent_buffer_with_first_consumer may lead to inconsistent behavior between development and production environments. This should be validated for robustness and documented clearly.

    if (std::getenv("USE_MAIN") != nullptr) {
      rparams->static_bdimx = false;
    }
    
    if (rparams->static_bdimx) {
      rparams->compute_persistent_buffer_with_first_consumer = true;
      if (std::getenv("NO_COMPUTE_WITH") != nullptr) {
        rparams->compute_persistent_buffer_with_first_consumer = false;
      }
    Split Logic Change

    The logic for splitting the reduction axis has been altered with the introduction of static_bdimx. This change may affect how warps are utilized and should be verified for correctness across different input sizes and configurations.

    if (rparams->static_bdimx) {
      // [R, TIDx, Vect]
      reduction_tv->split(inner_reduce_axis, rparams->lparams.bdimx());
      // [R, TIDx, Vect]
      reduction_tv->axis(inner_reduce_axis + 1)
          ->parallelize(rparams->block_dim_inner_reduction);
    } else {
      reduction_tv->split(
          outer_i++, rparams->batches_per_block_inner_reduction, false);
    }
    Unroll Condition Update

    The condition for unrolling persistent cached inputs has been modified to include compute_persistent_buffer_with_first_consumer. This change could impact performance and should be tested thoroughly to ensure it does not introduce regressions.

    bool unroll_persistent_cached_inputs = rparams->vectorize_inner_reduction &&
        rparams->fastest_dim && !rparams->schedule_3D &&
        !rparams->compute_persistent_buffer_with_first_consumer;

    @liqiangxl liqiangxl changed the title perf optimization for rmsnorm add use static bdimx in inner persistent scheduler for static fusion Oct 13, 2025
    @liqiangxl liqiangxl changed the title use static bdimx in inner persistent scheduler for static fusion optimize inner persistent scheduler for static fusion Oct 13, 2025
    @liqiangxl liqiangxl force-pushed the llu/perf_rmsnorm_add branch from 214d963 to 2f03db7 Compare October 14, 2025 14:36
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    2 participants