Skip to content

Conversation

@liqiangxl
Copy link
Collaborator

@liqiangxl liqiangxl commented Nov 18, 2025

Following #5553, part of #5366

TMA Pointwise Scheduler: Broadcast Domain Handling

(1) TMA Load vs. General Load (ldg/ld.global)

The current TMA pointwise scheduler does not use TMA load for inputs with concretized broadcast domains.

Example:
Given three inputs: tv0[I1, B2] + tv1[B1, I2] + tv2[I1, I2], where B1 and B2 are broadcast domains concretized to I1 and I2:

  • tv0 and tv1 will NOT be loaded with TMA

Note: This is a performance optimization. These inputs can be loaded with TMA, but only using a one-dimensional tile, as demonstrated in the newly added test.

(2) Break Point Selection

When broadcasts are present, the loop domain of the reference tv is merged to [lhs, rhs] instead of flattening to a single dimension.

Example:
Given tv0[I1, B2] + tv1[B1, I2] + tv2[I1, I2], the break point is at pos-1, which separates [I1, I2] into [lhs, rhs].

Break point selection differs between TMA and non-TMA versions:

  • TMA version: Break point is selected whenever broadcast domains are present
  • Non-TMA version: Break point is selected only when at least 10% of transferred data can be saved

Rationale: In the TMA version, we cannot safely merge broadcast and non-broadcast domains when creating 2D TMA domains and schedules, so we always break when broadcasts are present. See restrictions at #5556

@github-actions
Copy link

github-actions bot commented Nov 18, 2025

Review updated until commit ab08df1

Description

  • Add TMA-specific break point selection logic for broadcast domain handling

  • Enable TMA load for broadcast tensors with proper 2D domain scheduling

  • Refactor pointwise scheduler to use configurable break point selection

  • Add comprehensive tests for TMA scheduling with inner/outer broadcast dimensions

Changes walkthrough

Relevant files
Enhancement
pointwise_tma.cpp
TMA scheduler break point integration                                       

csrc/scheduler/pointwise_tma.cpp

  • Add TMA-specific break point calculation using getBreakPoint with
    is_tma=true
  • Use configurable break_point from params instead of hardcoded 0
  • Add conditional TMA domain splitting logic based on break_point value
  • Include break_point in debug output for better visibility
  • +14/-2   
    pointwise_utils.cpp
    Break point calculation with TMA support                                 

    csrc/scheduler/pointwise_utils.cpp

  • Add is_tma parameter to getBreakPoint function signature
  • Implement TMA-specific logic: allow break at pos-0, prioritize
    transfer size savings
  • Keep non-TMA logic unchanged: maintain 10% savings requirement and
    parallelization constraints
  • Add TMA-specific grid binding flip logic for broadcast scenarios
  • +33/-21 
    pointwise_utils.h
    Function signature update for TMA support                               

    csrc/scheduler/pointwise_utils.h

  • Update getBreakPoint function signature to include bool is_tma
    parameter
  • Maintain backward compatibility with default parameter values
  • +1/-0     
    pointwise_non_tma.cpp
    Non-TMA scheduler parameter update                                             

    csrc/scheduler/pointwise_non_tma.cpp

  • Update getBreakPoint call to explicitly pass is_tma=false parameter
  • Ensure non-TMA scheduler uses original break point selection logic
  • +1/-1     
    Tests
    test_pointwise.cpp
    TMA broadcast scheduling tests                                                     

    tests/cpp/test_pointwise.cpp

  • Refactor Tma2dTileTest to TmaPointwiseTest with improved naming
  • Create TmaTestBase base class for common TMA test infrastructure
  • Add TmaPointwiseBcastTest for comprehensive broadcast scenario testing
  • Implement manual TMA scheduling examples with inner/outer broadcast
    handling
  • Test both auto-scheduler and manual TMA load/store configurations
  • +215/-17

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Break point validation

    The new code adds break point selection for TMA scheduling, but there's a potential issue with the validation logic. The code checks n_valid_dims >= 2 when break_point != 0, but this validation might be too strict or might not cover all valid cases. Need to verify this validation logic handles all expected scenarios correctly, especially for edge cases with broadcast dimensions.

    NVF_ERROR(
        n_valid_dims >= 2,
        "Required at least 2 valid dimensions for Tma scheduling, but got ",
        n_valid_dims);
    TMA break point logic

    The TMA break point selection logic always breaks when broadcasts are present, but the condition if (cur_transfer_size_bit >= min_total_transfer_bit) continue; might be redundant or overly restrictive for TMA cases. The comment suggests prioritizing break at broadcast dimensions, but the current implementation might skip valid break points. Need to verify this logic achieves the intended behavior.

    if (cur_transfer_size_bit >= min_total_transfer_bit) {
      continue;
    }
    Test coverage completeness

    The new broadcast test (TmaPointwiseBcastTest) covers various combinations of TMA vs regular loads for broadcast tensors, but the test dimensions are fixed (1024x2048). Consider adding tests with different tensor sizes to ensure the break point selection logic works correctly across various problem sizes, especially edge cases near the 10% threshold boundary for non-TMA scheduling.

    int64_t dim0 = 1024;
    int64_t dim1 = 2048;

    @liqiangxl
    Copy link
    Collaborator Author

    !test

    1 similar comment
    @liqiangxl
    Copy link
    Collaborator Author

    !test

    @liqiangxl liqiangxl marked this pull request as ready for review November 20, 2025 15:09
    @greptile-apps
    Copy link
    Contributor

    greptile-apps bot commented Nov 20, 2025

    Greptile Overview

    Greptile Summary

    This PR extends the TMA pointwise scheduler to handle broadcast domains by implementing a break point selection mechanism.

    • Break Point Selection for TMA: When broadcasts are present, the loop domain is merged to [lhs, rhs] instead of flattening to a single dimension. For TMA, break points are always used when broadcasts exist (unlike non-TMA which requires 10% data transfer savings).

    • TMA Load Restriction: Inputs with concretized broadcast domains won't use TMA load, falling back to regular global loads (ldg). This is a performance optimization as demonstrated in the new test.

    • Refactored Test Infrastructure: Extracted common TMA test setup into TmaTestBase template class and added comprehensive TmaPointwiseBcastTest to verify broadcast handling with various TMA load configurations.

    Confidence Score: 4/5

    • This PR is safe to merge - the logic is sound and new tests cover the broadcast scenarios
    • The changes correctly implement break point selection for TMA scheduling with broadcasts. The logic is well-structured and follows established patterns. Minor concern about tile size calculation when break_point > 0 using flattened element counts, but this appears intentional for heuristics.
    • csrc/scheduler/pointwise_tma.cpp - verify tile size calculation is appropriate when break_point > 0

    Important Files Changed

    File Analysis

    Filename Score Overview
    csrc/scheduler/pointwise_non_tma.cpp 5/5 Minor change - added is_tma parameter to getBreakPoint call (false for non-TMA scheduler)
    csrc/scheduler/pointwise_utils.h 5/5 Added is_tma parameter to getBreakPoint function declaration
    csrc/scheduler/pointwise_utils.cpp 4/5 Extended getBreakPoint to support TMA mode - allows break at pos-0 for TMA and removes 10% savings threshold requirement when TMA is used. Contains minor typos in comments.
    csrc/scheduler/pointwise_tma.cpp 4/5 Added break point support to TMA scheduler - computes break point and passes to commonPointwiseSchedule. When break_point > 0, skips flattening split and uses the [lhs, rhs] structure from common schedule for TMA tile splitting.
    tests/cpp/test_pointwise.cpp 5/5 Added new TmaPointwiseBcastTest test class with InnerOuterBcast test for TMA scheduling with broadcast tensors. Refactored test base class to share CUDA arch guard and TMA options setup.

    Sequence Diagram

    sequenceDiagram
        participant H as getPointwiseHeuristics
        participant BP as getBreakPoint
        participant S as schedulePointwise
        participant C as commonPointwiseSchedule
    
        H->>BP: getBreakPoint(fusion, prop, is_tma=true)
        BP-->>H: BreakPointInfo (break_point, flip_grid_binding)
        H->>H: Compute tma_domain_inner/outer
        H->>H: Compute tma_tile_inner/outer
        H-->>S: PointwiseParams (break_point, tma_tile_*)
        
        S->>C: commonPointwiseSchedule(fusion, break_point)
        alt break_point == 0
            C->>C: Flatten all dims to [I]
        else break_point > 0
            C->>C: Merge to [lhs, rhs] structure
        end
        C-->>S: CommonScheduleInfo (reference_tv)
        
        alt break_point == 0
            S->>S: reference_tv->split(0, tma_domain_inner)
        else break_point > 0
            S->>S: Validate n_valid_dims >= 2
        end
        S->>S: Split for TMA tiles
        S->>S: Parallelize TMA tensors (Bulk)
        S->>S: Schedule non-TMA tensors
    
    Loading

    Copy link
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    5 files reviewed, 1 comment

    Edit Code Review Agent Settings | Greptile

    Comment on lines 82 to 83
    // dimension. outer tile size: don't exceed the outer TMA dimension size Both
    // Both are subject to hardware constraints of 256 elements per dimension.
    Copy link
    Contributor

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    syntax: comment has formatting issue - appears to be an incomplete edit

    Suggested change
    // dimension. outer tile size: don't exceed the outer TMA dimension size Both
    // Both are subject to hardware constraints of 256 elements per dimension.
    // - Inner tile size: ensure at least 2 tiles in the inner TMA dimension
    // - Outer tile size: don't exceed the outer TMA dimension size

    Copy link
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    5 files reviewed, 4 comments

    Edit Code Review Agent Settings | Greptile

    Comment on lines +334 to +341
    if (pparams->break_point == 0) {
    reference_tv->split(0, pparams->tma_domain_inner);
    } else {
    NVF_ERROR(
    n_valid_dims >= 2,
    "Required at least 2 valid dimensions for Tma scheduling, but got ",
    n_valid_dims);
    }
    Copy link
    Contributor

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    logic: when break_point > 0, only dimension validation is performed but no actual transformation. Verify the subsequent splits at lines 349-350 work correctly with the [lhs, rhs] structure from commonPointwiseSchedule

    Comment on lines 342 to 343


    Copy link
    Contributor

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    style: extra blank line

    Suggested change

    result.flip_grid_binding = false;
    }
    } else {
    // If TMA is used, priorize break if it saves transfered size
    Copy link
    Contributor

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    syntax: typo: 'priorize' should be 'prioritize'

    Suggested change
    // If TMA is used, priorize break if it saves transfered size
    // If TMA is used, prioritize break if it saves transfered size

    }
    } else {
    // If TMA is used, priorize break if it saves transfered size
    // This ensures we break at broadcast dimensions, then we can optionally
    Copy link
    Contributor

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    syntax: typo: 'transfered' should be 'transferred'

    Suggested change
    // This ensures we break at broadcast dimensions, then we can optionally
    // This ensures we break at broadcast dimensions, then we can optionally

    Base automatically changed from llu/pt3_auto1 to main December 2, 2025 01:15
    Copy link
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    5 files reviewed, no comments

    Edit Code Review Agent Settings | Greptile

    @liqiangxl liqiangxl force-pushed the llu/pt3_auto2_bcast branch from 22a4ced to 350b75a Compare December 2, 2025 01:42
    Copy link
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    5 files reviewed, no comments

    Edit Code Review Agent Settings | Greptile

    Move break point calculation logic into pointwise_utils:
    - Add pointwise_utils::getBreakPoint() to calculate optimal 2D break point
    - Add pointwise_utils::getBlockGridConfig() to compute block/grid dimensions
    - Refactor pointwise_non_tma.cpp to use new utility functions
    - Add BreakPointInfo and BlockGridConfig structs for return values
    
    This is a mechanical refactoring with no functional changes.
    @liqiangxl liqiangxl changed the base branch from main to llu/pt3_auto2_bcast_utils_mechanical December 2, 2025 01:51
    Copy link
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    2 files reviewed, no comments

    Edit Code Review Agent Settings | Greptile

    @liqiangxl liqiangxl force-pushed the llu/pt3_auto2_bcast_utils_mechanical branch from 61412c0 to 01924fa Compare December 2, 2025 02:04
    Copy link
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    5 files reviewed, no comments

    Edit Code Review Agent Settings | Greptile

    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    2 participants