@liqiangxl
Collaborator

WAR (workaround) fix for a bug in ComputeAtMap-based loop ID retrieval

Problem

See issue #5326.
The NVFuser kernel generator was creating loops with incorrect extents, causing index computation failures:

INTERNAL ASSERT FAILED at index_compute.cpp:1987
Couldn't find allocation mapping for T18_l_float[iblockIdx.x137{8}, ...]

The assertion fails because the loop FOR blockIdx.x in iblockIdx.x202{1}: was generated when it should have been FOR blockIdx.x in iblockIdx.x202{8}:

Root Cause

In getConcreteLoopID(), when using the ComputeAtMap path:

  • Multiple IterDomains with different extents (1 and 8) are mapped into the same LOOP group
  • ComputeAtMap::computeConcreteId() picks a representative based on graph topology, apparently without comparing extents, so the extent-1 ID can end up as the concrete loop ID

The WAR Fix

The IdModel path generates loop IDs correctly, and the ComputeAtMap path is scheduled for deprecation. This PR therefore introduces only a temporary workaround (WAR) in the ComputeAtMap path to unblock execution.

What was changed:

Added extent validation logic in the CA map path (non-IdModel):

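  1. Broadcast Promotion: Prefer non-broadcast IDs over broadcast/size-1 IDs
  2. Maximum Extent: Choose the ID with the largest constant extent in the group
  3. Extent Validation: Check whether the concrete ID returned by the CA map is broadcast or has constant extent 1 while the loop group has other members, and replace it when a better candidate exists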
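
The selection rules above can be sketched in isolation. This is a minimal model, not the code from the PR: `LoopId` is a hypothetical stand-in for nvFuser's IterDomain, `pickConcreteLoopId` is an illustrative name, and a missing `const_extent` is assumed to represent a symbolic extent.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for an IterDomain (illustration only).
struct LoopId {
  bool is_broadcast;
  // Constant extent when known; std::nullopt models a symbolic extent.
  std::optional<int64_t> const_extent;
};

// Given a loop group and the index the (modeled) CA map chose as concrete,
// return the index that should be used as the concrete loop ID.
std::size_t pickConcreteLoopId(
    const std::vector<LoopId>& group,
    std::size_t concrete) {
  assert(concrete < group.size());
  const LoopId& current = group[concrete];
  const bool broadcast_or_one = current.is_broadcast ||
      (current.const_extent.has_value() && *current.const_extent == 1);

  // Only second-guess the choice when it looks like a broadcast/size-1 ID
  // and the group actually offers alternatives.
  if (!broadcast_or_one || group.size() <= 1) {
    return concrete;
  }

  std::size_t best = concrete;
  int64_t max_extent = 1;
  for (std::size_t i = 0; i < group.size(); ++i) {
    const LoopId& candidate = group[i];
    // Skip broadcast IDs and IDs whose extent is not a known constant.
    if (candidate.is_broadcast || !candidate.const_extent.has_value()) {
      continue;
    }
    if (*candidate.const_extent > max_extent) {
      max_extent = *candidate.const_extent;
      best = i;
    }
  }
  return best;
}
```

For a group holding extents 1 and 8, as in the failing case described above, the sketch returns the extent-8 member regardless of which one was chosen initially. A choice that is already non-broadcast with extent greater than 1 is left untouched.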

@liqiangxl
Collaborator Author

!test

@github-actions

github-actions bot commented Oct 7, 2025

Review updated until commit f16107e

Description

  • Fix loop ID extent mismatch in CA map path

  • Prefer non-broadcast IDs with maximum extent

  • Add test case reproducing loop ID issue

  • Validate concrete loop ID extent compatibility


Changes walkthrough 📝

Relevant files

Bug fix
utils.cpp: Fix concrete loop ID selection in CA map
csrc/device_lower/utils.cpp (+40/-1)

  • Added extent validation in getConcreteLoopID
  • Prefer non-broadcast IDs in loop group
  • Select ID with largest constant extent
  • WAR fix for incorrect loop extent generation

Tests
test_repro.py: Add repro test for loop ID bug
tests/python/direct/test_repro.py (+47/-0)

  • Add repro test for CA map loop ID issue
  • Define fusion with broadcast and reshape ops
  • Validate inputs triggering the bug
  • Test case covers extent mismatch scenario

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Extent Validation

    The WAR fix introduces extent validation logic to prefer non-broadcast IDs with larger extents, but it only checks constant integer extents. This may miss cases where extents are symbolic or dynamic, leading to potential incorrect loop ID selection.

    bool concrete_is_broadcast_or_one = concrete->isBroadcast() ||
        (concrete->extent()->isConstInt() &&
         concrete->extent()->evaluate().as<int64_t>() == 1);
    
    if (concrete_is_broadcast_or_one && disjoint_set->vector().size() > 1) {
      // Look for a non-broadcast ID with a larger extent in the same loop group
      IterDomain* better_concrete = nullptr;
      int64_t max_extent = 1;
    
      for (auto loop_id : disjoint_set->vector()) {
        if (loop_id->isBroadcast()) {
          continue;
        }
    
        if (loop_id->extent()->isConstInt()) {
          auto extent_val = loop_id->extent()->evaluate().as<int64_t>();
          if (extent_val > max_extent) {
            max_extent = extent_val;
            better_concrete = loop_id;
          }
        }
      }
    
      // If we found a better candidate, use it instead
      if (better_concrete != nullptr && better_concrete != concrete) {
        concrete = better_concrete;
      }
    }
    Loop Group Iteration

    The loop over disjoint_set->vector() may not consider all relevant IterDomains if the set contains non-constant extents that could be larger than 1. This could result in suboptimal or incorrect concrete ID selection when non-constant extents are involved.

    for (auto loop_id : disjoint_set->vector()) {
      if (loop_id->isBroadcast()) {
        continue;
      }
    
      if (loop_id->extent()->isConstInt()) {
        auto extent_val = loop_id->extent()->evaluate().as<int64_t>();
        if (extent_val > max_extent) {
          max_extent = extent_val;
          better_concrete = loop_id;
        }
      }
    }
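
The concern raised in the two focus areas can be made concrete with a small diagnostic. This is a hypothetical sketch and not part of the PR: `LoopId` is an illustrative stand-in for IterDomain, a missing `const_extent` is assumed to model a symbolic extent, and `countSkippedSymbolic` is an invented helper that counts the non-broadcast IDs the scan above would pass over.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for an IterDomain (illustration only).
struct LoopId {
  bool is_broadcast;
  // Constant extent when known; std::nullopt models a symbolic extent.
  std::optional<int64_t> const_extent;
};

// Count non-broadcast IDs whose extent is not a compile-time constant.
// A constant-only scan never considers these as candidates, even though
// their runtime extent could be larger than 1.
std::size_t countSkippedSymbolic(const std::vector<LoopId>& group) {
  std::size_t skipped = 0;
  for (const LoopId& id : group) {
    if (!id.is_broadcast && !id.const_extent.has_value()) {
      ++skipped;
    }
  }
  return skipped;
}
```

A nonzero count means the group contains candidates the constant-extent comparison cannot rank, which is the situation the reviewer guide flags.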

    @naoyam
    Collaborator

    naoyam commented Oct 7, 2025

    !test --diff

    auto disjoint_set = ca_map.disjointSetOf(id, IdMappingMode::LOOP);
    auto concrete = ca_map.getConcreteMappedID(id, IdMappingMode::LOOP);

    // The CA map's concrete ID may have an incompatible extent.
    Collaborator

    Please link the issue number.

    Collaborator Author

    added

    @naoyam
    Collaborator

    naoyam commented Oct 7, 2025

    The fix looks good. I agree just a war would be sufficient.

    Started the diff check to see if there's any unexpected side effect.

    @liqiangxl
    Collaborator Author

    > The fix looks good. I agree just a war would be sufficient.
    >
    > Started the diff check to see if there's any unexpected side effect.

    The codediff is ptx diff.

    @liqiangxl liqiangxl marked this pull request as ready for review October 8, 2025 12:28
    @liqiangxl
    Collaborator Author

    !test

    Collaborator

    @naoyam naoyam left a comment

    LGTM. Thanks for the fix.

    @naoyam naoyam merged commit 6e8136a into main Oct 8, 2025
    56 of 57 checks passed
    @naoyam naoyam deleted the llu/alloc_mapping_5326 branch October 8, 2025 15:07