
Conversation

@liqiangxl
Collaborator

@liqiangxl liqiangxl commented Dec 22, 2025

When an input tv has broadcast dimensions, it may cause one of the following issues:

  1. A merge of an iteration domain with a broadcast dimension. This is not supported by TMA and will trigger a TMA lowering validation error.
  2. For the 2D scheduler with a break point, the right side or the left side contains only a single broadcast domain, which violates our 2D tile assumption. This restriction can be lifted if we further revise the scheduler.
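
The fix is an early rejection in isTvSuitableForTma. A minimal sketch of the idea, assuming nvFuser's usual IterDomain API (the exact body is an assumption; the actual implementation in csrc/scheduler/pointwise_tma.cpp may differ in detail):

    // Sketch only: reject any input tv whose logical domain contains a
    // broadcast IterDomain, so the TMA path never has to merge an iteration
    // domain with a broadcast domain.
    bool hasBroadcastInLogicalDomain(TensorView* tv) {
      for (IterDomain* id : tv->getLogicalDomain()) {
        if (id->isBroadcast()) {
          return true; // tv is skipped for TMA load
        }
      }
      return false;
    }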

@liqiangxl
Collaborator Author

!test

@github-actions

github-actions bot commented Dec 22, 2025

Review updated until commit 5bc0cb8

Description

  • Prevent TMA usage for tensors with broadcast dimensions by adding a validation check

  • Added comprehensive test cases covering outer, inner, and middle broadcast dimensions

  • Moved bits_per_element validation earlier with improved documentation

  • Ensures the TMA scheduler avoids invalid tensor configurations that cause lowering errors

Changes walkthrough

Relevant files

Bug fix: csrc/scheduler/pointwise_tma.cpp (+25/-1)
Add broadcast dimension validation for TMA

  • Added broadcast dimension check in the isTvSuitableForTma function
  • Added detailed comments explaining TMA broadcast dimension restrictions
  • Moved bits_per_element validation earlier with improved documentation
  • Removed the now-duplicate bits_per_element check later in the function

Tests: tests/cpp/test_pointwise.cpp (+101/-0)
Add comprehensive TMA broadcast dimension tests

  • Added OuterDimOne test for outer dimension broadcast validation
  • Added InnerDimOne test for inner dimension broadcast validation
  • Added MiddleDimOne test for middle dimension broadcast validation
  • Added OneBcastOneNonBcast test for mixed broadcast/non-broadcast inputs

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
🔒 No security concerns identified
⚡ Recommended focus areas for review

Early return logic

The early-return check for bits_per_element == 0 is moved earlier in the function, which avoids unnecessary computation. However, ensure this doesn't change the logic flow or introduce edge cases where the function returns nullptr when it should continue processing.

    const int64_t bits_per_element = getInputBitsPerElement(prop);
    if (bits_per_element == 0) {
      return nullptr;
    }
Test coverage completeness

The tests cover various broadcast dimension scenarios (outer, inner, middle, mixed), which is comprehensive. Consider adding a test case with multiple broadcast dimensions to ensure the logic handles more complex scenarios correctly; a hypothetical sketch of such a test follows the excerpt below.

    // Input tvs have a broadcast dimension, so they are not suitable for TMA
    // load. The outer dimension is the broadcast dimension.
    TEST_F(TmaPointwiseTestF, OuterDimOne) {
      int64_t dim1 = 8192;
      DataType dtype = DataType::Float;
      auto fusion_ptr = std::make_unique<Fusion>();
      auto fusion = fusion_ptr.get();
      FusionGuard fg(fusion);
      auto tv0 = makeContigConcreteTensor({1, dim1}, dtype);
      auto tv1 = makeContigConcreteTensor({1, dim1}, dtype);
      fusion->addInput(tv0);
      fusion->addInput(tv1);
      auto tv2 = add(tv0, tv1);
      fusion->addOutput(tv2);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      auto t0 = at::randn({1, dim1}, options);
      auto t1 = at::randn({1, dim1}, options);
    
      auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, {t0, t1});
      auto pparams = cg_results.heuristic_params->as<PointwiseParams>();
      EXPECT_FALSE(pparams->use_tma_load);
      testValidate(fusion, cg_results.outputs, {t0, t1}, __LINE__, __FILE__);
    }
    
    // Input tvs have a broadcast dimension, so they are not suitable for TMA
    // load. The inner dimension is the broadcast dimension.
    TEST_F(TmaPointwiseTestF, InnerDimOne) {
      int64_t dim0 = 8192;
      DataType dtype = DataType::Float;
      auto fusion_ptr = std::make_unique<Fusion>();
      auto fusion = fusion_ptr.get();
      FusionGuard fg(fusion);
      auto tv0 = makeContigConcreteTensor({dim0, 1}, dtype);
      auto tv1 = makeContigConcreteTensor({dim0, 1}, dtype);
      fusion->addInput(tv0);
      fusion->addInput(tv1);
      auto tv2 = add(tv0, tv1);
      fusion->addOutput(tv2);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      auto t0 = at::randn({dim0, 1}, options);
      auto t1 = at::randn({dim0, 1}, options);
    
      auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, {t0, t1});
      auto pparams = cg_results.heuristic_params->as<PointwiseParams>();
      EXPECT_FALSE(pparams->use_tma_load);
      testValidate(fusion, cg_results.outputs, {t0, t1}, __LINE__, __FILE__);
    }
    
    // Input tvs have a broadcast dimension, so they are not suitable for TMA
    // load. The middle dimension is the broadcast dimension.
    TEST_F(TmaPointwiseTestF, MiddleDimOne) {
      int64_t dim0 = 8192;
      int64_t dim2 = 1024;
      DataType dtype = DataType::Float;
      auto fusion_ptr = std::make_unique<Fusion>();
      auto fusion = fusion_ptr.get();
      FusionGuard fg(fusion);
      auto tv0 = makeContigConcreteTensor({dim0, 1, dim2}, dtype);
      auto tv1 = makeContigConcreteTensor({dim0, 1, dim2}, dtype);
      fusion->addInput(tv0);
      fusion->addInput(tv1);
      auto tv2 = add(tv0, tv1);
      fusion->addOutput(tv2);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      auto t0 = at::randn({dim0, 1, dim2}, options);
      auto t1 = at::randn({dim0, 1, dim2}, options);
    
      auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, {t0, t1});
      auto pparams = cg_results.heuristic_params->as<PointwiseParams>();
      EXPECT_FALSE(pparams->use_tma_load);
      testValidate(fusion, cg_results.outputs, {t0, t1}, __LINE__, __FILE__);
    }
    
    // tv0 has a broadcast dimension and is not suitable for TMA load;
    // tv1 has no broadcast dimension and is suitable for TMA load.
    TEST_F(TmaPointwiseTestF, OneBcastOneNonBcast) {
      int64_t dim0 = 8192;
      int64_t dim2 = 1024;
      DataType dtype = DataType::Float;
      auto fusion_ptr = std::make_unique<Fusion>();
      auto fusion = fusion_ptr.get();
      FusionGuard fg(fusion);
      auto tv0 = makeContigConcreteTensor({dim0, 1, dim2}, dtype);
      auto tv1 = makeContigConcreteTensor({dim0, 2, dim2}, dtype);
      fusion->addInput(tv0);
      fusion->addInput(tv1);
      auto tv2 = add(tv0, tv1);
      fusion->addOutput(tv2);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      auto t0 = at::randn({dim0, 1, dim2}, options);
      auto t1 = at::randn({dim0, 2, dim2}, options);
    
      auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, {t0, t1});
      auto pparams = cg_results.heuristic_params->as<PointwiseParams>();
      EXPECT_TRUE(pparams->use_tma_load);
      testValidate(fusion, cg_results.outputs, {t0, t1}, __LINE__, __FILE__);
    }
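
A hypothetical sketch of the multiple-broadcast-dimensions case suggested above, modeled directly on the tests in this excerpt (the test name and shapes are assumptions, not part of the PR):

    // Hypothetical test (not in this PR): two broadcast dimensions, outer and
    // middle. TMA load should still be rejected for these inputs.
    TEST_F(TmaPointwiseTestF, MultipleBcastDims) {
      int64_t dim2 = 1024;
      DataType dtype = DataType::Float;
      auto fusion_ptr = std::make_unique<Fusion>();
      auto fusion = fusion_ptr.get();
      FusionGuard fg(fusion);
      auto tv0 = makeContigConcreteTensor({1, 1, dim2}, dtype);
      auto tv1 = makeContigConcreteTensor({1, 1, dim2}, dtype);
      fusion->addInput(tv0);
      fusion->addInput(tv1);
      auto tv2 = add(tv0, tv1);
      fusion->addOutput(tv2);

      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      auto t0 = at::randn({1, 1, dim2}, options);
      auto t1 = at::randn({1, 1, dim2}, options);

      auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, {t0, t1});
      auto pparams = cg_results.heuristic_params->as<PointwiseParams>();
      EXPECT_FALSE(pparams->use_tma_load);
      testValidate(fusion, cg_results.outputs, {t0, t1}, __LINE__, __FILE__);
    }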

@greptile-apps
Contributor

greptile-apps bot commented Dec 22, 2025

Greptile Summary

Prevents TMA (Tensor Memory Accelerator) usage for input tensors with broadcast dimensions in their logical domain by adding an early validation check in isTvSuitableForTma.

Key changes:

• Added broadcast dimension check in isTvSuitableForTma (lines 59-72) that rejects any TensorView with broadcast IterDomains in its logical domain
• Moved bits_per_element calculation earlier in the heuristics flow (lines 118-127) to enable early bailout when no suitable inputs exist
• Added comprehensive test coverage for outer, inner, and middle broadcast dimensions, plus mixed scenarios

Rationale:
The PR addresses two specific TMA incompatibilities: (1) TMA doesn't support merging iteration domains with broadcast dimensions, and (2) the 2D scheduler's tile assumptions break when the left or right side contains a single broadcast domain. The TODO comment at lines 56-57 notes this restriction could be relaxed for specific patterns, such as outer broadcast inputs [B, I], in future work.

Confidence Score: 5/5

• This PR is safe to merge with minimal risk
• The changes are well-contained and defensive in nature: they prevent TMA usage in problematic cases rather than changing the behavior of existing working code. The broadcast check is straightforward and clearly documented. The moved bits_per_element check provides proper early validation. Four comprehensive test cases validate the fix across different broadcast positions and mixed scenarios.
• No files require special attention

Important Files Changed

csrc/scheduler/pointwise_tma.cpp: Added broadcast dimension check in isTvSuitableForTma and moved bits_per_element computation earlier for proper validation
tests/cpp/test_pointwise.cpp: Added four comprehensive test cases covering broadcast dimensions in outer, inner, and middle positions, plus mixed broadcast/non-broadcast inputs

Sequence Diagram

    sequenceDiagram
        participant Scheduler as Pointwise Scheduler
        participant Heuristics as getPointwiseHeuristics
        participant Check as isTvSuitableForTma
        participant Bits as getInputBitsPerElement
        participant Domain as TensorView Logical Domain
    
        Scheduler->>Heuristics: Request TMA scheduling
        Heuristics->>Heuristics: Determine break point
        Note over Heuristics: Step 0: break point computed
        
        Heuristics->>Bits: Calculate bits_per_element
        Bits->>Check: Check each input TV
        Check->>Domain: Get logical domain
        Domain-->>Check: Return IterDomains
        Check->>Check: Check for broadcast dimensions
        alt Has broadcast dimension
            Check-->>Bits: Return false (not suitable)
            Note over Check: NEW: Reject TVs with<br/>broadcast in logical domain
        else No broadcast dimension
            Check-->>Bits: Return true (suitable)
        end
        Bits-->>Heuristics: Return total bits_per_element
        
        alt bits_per_element == 0
            Heuristics-->>Scheduler: Return nullptr (no TMA)
            Note over Heuristics: NEW: Early validation<br/>before Step 1
        else bits_per_element > 0
            Heuristics->>Heuristics: Step 1: Compute TMA domain
            Heuristics->>Heuristics: Step 2: Determine elements per CTA
            Heuristics->>Heuristics: Step 3+: Configure tiles
            Heuristics-->>Scheduler: Return TMA params
        end
    

Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, 1 comment

Comment on lines 187 to 189

    if (bits_per_element == 0) {
      return nullptr;
    }

Contributor

style: redundant check: bits_per_element == 0 is already checked at lines 125-127, so this condition will never be true

Suggested change

    - if (bits_per_element == 0) {
    -   return nullptr;
    - }
    + // bits_per_element already validated at line 125-127

@liqiangxl liqiangxl changed the title from "don't use tma when iter and bcast domains are merged" to "Prevent TMA usage for tensors with broadcast dimensions" on Jan 6, 2026
@liqiangxl liqiangxl requested a review from naoyam January 16, 2026 16:19
Collaborator

@naoyam naoyam left a comment

When a fusion has a tensor with a broadcast, does that cause the fusion to not use TMA at all? Or is it just the broadcast tensor that is affected?

What about if an input tensor doesn't have a broadcast ID but is broadcast inside the fusion?

    executor_cache.fusion(), out_tensors, {t0, t1}, __LINE__, __FILE__);
    }

    TEST_F(TmaPointwiseTestF, OuterDimOne) {
Collaborator

Can you just leave a quick comment on what each of the tests is meant to validate?

@liqiangxl
Collaborator Author

liqiangxl commented Jan 20, 2026

When a fusion has a tensor with a broadcast, does that cause the fusion to not use TMA at all? Or is it just the broadcast tensor that is affected?

Just the broadcast tv. Added a new test, TmaPointwiseTestF.OneBcastOneNonBcast, where one input is suitable for TMA and the other input is not due to broadcast.

What about if an input tensor doesn't have a broadcast ID but is broadcast inside the fusion?

It's already addressed. The current scheduler checks the number of non-reduction/non-broadcast/non-device dims in the logical domain; if it doesn't match that of the reference tv, TMA is not used.
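
In code, that existing guard amounts to something like the sketch below (the helper name is hypothetical and the exact predicates are assumptions; the real logic lives in the pointwise scheduler):

    // Sketch only: count the dims that TMA tiling actually sees. If an
    // input's count differs from the reference tv's, the input is broadcast
    // inside the fusion and TMA is not used for it.
    int64_t countMappableDims(const std::vector<IterDomain*>& logical_domain) {
      int64_t count = 0;
      for (IterDomain* id : logical_domain) {
        if (!id->isReduction() && !id->isBroadcast() && !id->isDeviceDim()) {
          ++count;
        }
      }
      return count;
    }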

@liqiangxl
Collaborator Author

!test

Collaborator

@naoyam naoyam left a comment

LGTM

@liqiangxl liqiangxl merged commit ae84859 into main Jan 21, 2026
63 checks passed
@liqiangxl liqiangxl deleted the llu/pt_contig branch January 21, 2026 13:02