
Conversation

@jjsjann123
Collaborator

@jjsjann123 jjsjann123 commented Sep 18, 2025

Stacked PRs

Breaking original PR #5170 into three:
#5186 Fix allocation logic: non-divisible split
#5185 Fix allocation logic: unconnected alloc/logical
#5184 Allow split on logical->allocation <- this one

This PR

Allows splitting an ID on the logical->allocation path to represent padding logic on the allocation domain. Notably, we no longer require the allocation domain to lie on the path between logical->loop.

  1. Relaxes the assert in vectorization analysis, which used to fire on any non-device split;
  2. Updates TensorView::cacheBefore to replay not only logical->loop but also logical->allocation transforms on the output, because a padding-related split appears only on logical->allocation. Without the extra replay, TensorView::cacheBefore would alter semantics by changing the allocation domain of outputs.

TODO:

  • add a test case where it could impact vectorization analysis.

@github-actions

github-actions bot commented Sep 18, 2025

Review updated until commit 2950054

Description

  • Relax validation of allocation domain splits for padding

  • Preserve allocation domain in cacheBefore for correctness

  • Fix ID model update during matmul scheduling

  • Improve split/merge handling in transform replay


Changes walkthrough 📝

Relevant files

Enhancement (6 files)
  • alias_analysis.cpp: Allow layout relaxation when allocation is logical (+6/-0)
  • utils.cpp: Add logic to determine valid preset allocation domains (+52/-0)
  • tensor_view.cpp: Redesign cacheBefore to preserve allocation domain (+99/-9)
  • transform_replay.cpp: Propagate parallelization and iter types in replay (+11/-3)
  • internal_base_nodes.h: Add resetRFactorProduct method to IterDomain (+5/-0)
  • utils.h: Declare canUsePresetAllocationDomain utility function (+2/-0)

Bug fix (2 files)
  • validation.cpp: Skip vectorization validation without preset allocation (+3/-0)
  • matmul.cpp: Update ID model mapping in cacheBefore (+17/-6)

Tests (2 files)
  • test_allocation_domain.cpp: Update cacheBefore tests for allocation domain (+0/-2)
  • test_layout_op.cpp: Add tests for allocation domain split and padding (+73/-0)

Miscellaneous (1 file)
  • test_multidevice_transformer.cpp: Add debug print for distributed test (+10/-0)

Documentation (1 file)
  • test_transpose.cpp: Clarify reshape-permute-transpose test comment (+2/-1)

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Issue

The handling of scatter operations in cacheBefore appears inconsistent and potentially incorrect. The code comments indicate uncertainty about whether loop domain replay is appropriate, and there's a TODO about special handling that's commented out. This could lead to incorrect domain propagation for scatter operations.

if (definition()->isA<ScatterOp>()) {
  // TODO: is there any way to replay a scatter op?!
  // scatter output's loop is not connected to its root.
  NVF_ERROR(!domain()->hasRoot(), "scatter output's root is not replayed in cacheBefore");

  std::vector<IterDomain*> logical;
  std::vector<IterDomain*> loop;
  std::unordered_map<IterDomain*, IterDomain*> map_cloned_ids;

  std::ranges::transform(domain()->logical(), std::back_inserter(logical), [&](IterDomain* id) {
      IterDomain* cloned_id = IrBuilder::createInContainer<IterDomain>(container(), id);
      map_cloned_ids[id] = cloned_id;
      return cloned_id;
  });
  std::ranges::transform(domain()->loop(), std::back_inserter(loop), [&](IterDomain* id) {
      if (auto it = map_cloned_ids.find(id); it != map_cloned_ids.end()) {
        // reuse cloned_ids
        return it->second;
      }
      return IrBuilder::createInContainer<IterDomain>(container(), id);
  });
  producer = IrBuilder::createInContainer<TensorView>(
      container(),
      IrBuilder::createInContainer<TensorDomain>(container(), logical, loop, TensorDomain::getContiguityFilledWith(logical, true), /*skip_loop_validation=*/true),
      getDataType().value());
  // TODO:  we are not replaying the loop domain from consumer to producer, is that the right thing to do?!
} else {
  // Create Producer Domain
  // We only need root for full self replay.
  std::vector<IterDomain*> root;
  std::ranges::transform(domain()->hasRoot()?domain()->root():domain()->logical(), std::back_inserter(root), [&](IterDomain* id) {
      return IrBuilder::createInContainer<IterDomain>(container(), id);
  });

  producer = IrBuilder::createInContainer<TensorView>(
      container(),
      IrBuilder::createInContainer<TensorDomain>(container(), root, root, root, TensorDomain::getContiguityFilledWith(root, true)),
      getDataType().value());
  // replay from `root`->`loop` on producer
  producer->setDomain(TransformReplay::fullSelfReplay(producer->domain(), domain()));
}

// clean up consumer domain to wipe out root and all reduction IDs
std::vector<IterDomain*> logical_dom;
std::vector<IterDomain*> alloc_dom;
std::vector<IterDomain*> loop_dom;
std::vector<std::optional<bool>> contiguity;

// NOTE: I need to clear definition otherwise BestEffortReplay will not work with dangling sources
// create an issue for this, use the example from ./bin/test_nvfuser --gtest_filter="PointwiseTest.VectorizeWithBroadcastAndReshape1"
// copy non-reduction IDs onto logical and loop
std::ranges::copy_if(
    domain()->logical() | std::views::transform([](IterDomain* id) { id->setDefinition(nullptr); return id->resetRFactorProduct(); }),
    std::back_inserter(logical_dom),
    [](IterDomain* id) {return !id->isReduction();});
if (definition()->isA<ScatterOp>()) {
  // NOTE: this doesn't feel right. we would still want to replay the loop domain
  // we are basically dropping transformations on loop domain for scatter op during cacheBefore
  loop_dom = logical_dom;
} else {
std::ranges::copy_if(
    domain()->loop() | std::views::transform([](IterDomain* id) { return id->resetRFactorProduct(); }),
    std::back_inserter(loop_dom),
    [](IterDomain* id) {return !id->isReduction();});
}
for (auto&& [id, c] : zip(domain()->hasAllocation() ? domain()->allocation() : domain()->logical(), domain()->contiguity())) {
  if (id->isReduction()) {
    continue;
  }
  id->resetRFactorProduct();
  if (domain()->hasAllocation()) {
    alloc_dom.push_back(id);
  }
  contiguity.push_back(c);
}
// TODO: We also need to clear all rfactor across IDs between logical->loop and logical->allocation.

// Set domain of consumer
TensorView* consumer = this;
consumer->setDomain(IrBuilder::createInContainer<TensorDomain>(
    container(),
    std::vector<IterDomain*>{},
    logical_dom,
    alloc_dom,
    loop_dom,
    contiguity));

// TODO: figure out scatter special handling.
// if (!producer->definition()->isA<ScatterOp>()) {
// } else {
// }
Logic Gap

The canUsePresetAllocationDomain function has a complex set of conditions but doesn't clearly handle the case where a tensor has no allocation domain but is used in a context where padding might be needed. The relationship between the ignore_empty_alloc parameter and the overall logic could lead to unexpected behavior.

bool canUsePresetAllocationDomain(const TensorView* tv, bool ignore_empty_alloc) {
  if (ignore_empty_alloc && !tv->hasAllocation()) {
    return false;
  }
  // Honor the allocation domain if the tensor is global or Hopper MMA's
  // output
  if (tv->getMemoryType() == MemoryType::Global ||
      (tv->definition()->isA<MmaOp>() &&
       isHopper(tv->definition()->as<MmaOp>()->macro()))) {
    return true;
  }
  // If it's a shared memory tensor, the set domain is likely
  // valid if Swizzle or Bulk is used. Also, if the allocation
  // domain is just a permutation of the loop domain, use the
  // set allocation domain. This seems to happen only with
  // AllocationDomainTest.TransposedIntermediate.
  if (tv->getMemoryType() == MemoryType::Shared) {
    if (std::any_of(
            tv->getAllocationDomain().begin(),
            tv->getAllocationDomain().end(),
            [](IterDomain* allocation_domain) {
              return dynamic_cast<Swizzle*>(
                         allocation_domain->definition()) != nullptr ||
                  allocation_domain->getParallelType() == ParallelType::Bulk;
            }) ||
        std::is_permutation(
            tv->getLoopDomain().begin(),
            tv->getLoopDomain().end(),
            tv->getAllocationDomain().begin(),
            tv->getAllocationDomain().end())) {
      return true;
    }

    // Honor the set allocation domain if the tensor is used by a
    // TMA store or MmaOp
    if (std::ranges::any_of(tv->uses(), [](Expr* expr) {
          return ir_utils::isCpAsyncBulkStore(expr) || expr->isA<MmaOp>();
        })) {
      return true;
    }

    // If a shared memory output produced by scatter has an
    // allocation domain explicitly set, it's likely to be the
    // valid allocation domain.
    if (auto def = tv->definition();
        def != nullptr && def->isA<ScatterOp>()) {
      return true;
    }
  }
  return false;
}
Performance Impact

The vectorization validation has been modified to return early when canUsePresetAllocationDomain returns false, but this might skip important validation checks that could affect performance, especially in cases where padding is used to enable vectorization.

if (!ir_utils::canUsePresetAllocationDomain(tv)) {
  return;
}

out->split(1, 16);
out->setAllocationDomain(out->getLoopDomain(), true);
// restore loop domain
out->merge(1);
Collaborator

This doesn't restore. Is this necessary?

Collaborator Author

Touché. It unsplits the loop domain so that it has the same number of IDs as the logical domain.
You are right that the extent is no longer the same, so it's not a restoration.

Schedulers expect an un-scheduled fusion. Without this merge, I'm hitting the assert here:

NVF_ERROR(broadcast_bit_multiples.size() == ref_loop.size());

Collaborator

Hmm, not sure that's a good enough WAR, though this is just a test.

I thought the schedulers can work with some scheduled loop domains (for DID parallelization), no?

Collaborator Author

// We always cacheBefore output at the beginning of the scheduling. And after
// cacheBefore, the reference tensor will have all reduction IDs removed.
ref_loop = TensorDomain::noDevices(TensorDomain::noReductions(ref_loop));

DID-related IDs are just ignored by the scheduler. So that's just too specific to multi-device.

I'm not a fan of this either. Let me see if I can skip messing with the loop domain and apply the transformation on allocation directly.

Collaborator

I suppose you can just modify the allocation domain with AbstractTensor. I remember there are some tests.

Collaborator Author

I can also directly use IterDomain::split for that.

Anyway, it looks like if the transformation is not on the logical->loop path, our replay won't pick it up. Feels similar to the allocation domain replay where rfactor was missing. fyi @Priya2698

#0  nvfuser::nvfCheckFail (func=0xaaaaac218080 "validateDomainEquivalence",
    file=0xaaaaac216938 "/opt/pytorch/nvfuser/csrc/ir/utils.cpp", line=1162,
    msg=" INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:1162, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. \nExpected !compare_result.dom0_has_u"...) at /opt/pytorch/nvfuser/csrc/exceptions.cpp:267
#1  0x0000aaaaab1bbe68 in nvfuser::nvfErrorFail (func=0xaaaaac218080 "validateDomainEquivalence",
    file=0xaaaaac216938 "/opt/pytorch/nvfuser/csrc/ir/utils.cpp", line=1162,
    condMsg=0xaaaaac217fd8 " INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:1162, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. ",
    userMsg="Expected !compare_result.dom0_has_unreachable_ids . dom0 has unreachable IDs. dom0: iS10{i0}, iS11{i2}. dom1: iS10{i0}") at /opt/pytorch/nvfuser/csrc/exceptions.cpp:277
#2  0x0000aaaaab60a3e8 in nvfuser::ir_utils::validateDomainEquivalence (
    dom0=std::vector of length 2, capacity 2 = {...}, dom1=std::vector of length 1, capacity 3 = {...},
    additional_ids=std::vector of length 0, capacity 0) at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:1162
#3  0x0000aaaaab4aac30 in nvfuser::TensorDomain::setAllocationDomain (this=0xaaaab20918b0,
    new_allocation_domain=std::vector of length 1, capacity 3 = {...},
    new_contiguity=std::vector of length 1, capacity 3 = {...})
    at /opt/pytorch/nvfuser/csrc/ir/nodes.cpp:4055
#4  0x0000aaaaabc7b368 in nvfuser::TransformReplay::replayCasP (consumer=0xaaaab2088c00,
    producer=0xaaaab2091200, producer_pos=2, logical_map=..., opt=...)
    at /opt/pytorch/nvfuser/csrc/transform_replay.cpp:917
#5  0x0000aaaaabc7b7fc in nvfuser::TransformReplay::replayCasP (consumer=0xaaaab2088c00,
    producer=0xaaaab2091200, compute_at_axis=-1, opt=...)
    at /opt/pytorch/nvfuser/csrc/transform_replay.cpp:945
#6  0x0000aaaaabc44ccc in nvfuser::TensorView::cacheBefore (this=0xaaaab2088c00,
    op_type=nvfuser::LoadStoreOpType::Set) at /opt/pytorch/nvfuser/csrc/tensor_view.cpp:1160
#7  0x0000aaaaabbdb250 in nvfuser::scheduler_utils::cacheAndForkOutputs (fusion=0xaaaab2084910,
    unroll=true) at /opt/pytorch/nvfuser/csrc/scheduler/utils.cpp:1357
#8  0x0000aaaaabb067dc in nvfuser::schedulePointwise (fusion=0xaaaab2084910, pparams=0xaaaab207f880)
    at /opt/pytorch/nvfuser/csrc/scheduler/pointwise.cpp:822
#9  0x0000aaaaabb0898c in nvfuser::PointWiseScheduler::schedule (this=0xaaaab2083460,
    fusion=0xaaaab2084910, params=0xaaaab207f880)
    at /opt/pytorch/nvfuser/csrc/scheduler/pointwise.cpp:1304

Collaborator

So, what did you decide to do? Nothing seems to have changed?

I can also directly use IterDomain::split for that.

Of course, but you'd need to maintain the proper ordering of the ID vector yourself.

Collaborator

I can also directly use IterDomain::split for that.

Anyway, it looks like if the transformation is not on the logical->loop path, our replay won't pick it up. Feels similar to the allocation domain replay where rfactor was missing. fyi @Priya2698

Yes, rfactor replay for allocation will also complain similarly if allocation transforms are disjoint from root-to-loop.
replayPasC also uses the loop domain as the target, so if you intend to use IterDomain::split, we will have to update that, among other things.

Collaborator Author

Yep, switched to selfReplay instead of replayCasP for TensorView::cacheBefore.

}
};

TEST_F(LayoutOpTest, LogicalAndAllocationSizes) {
Collaborator

What is being tested here?

Collaborator Author

Without the relaxation in vectorization analysis, this test would trigger an assert.

So the test just verifies that we now allow an allocation domain split.
In the follow-up PR, we add more validation to this test to check that the produced tensor matches the logical sizes.

Collaborator

@Priya2698 Priya2698 left a comment

The changes look good for the multidevice support part. I am not familiar enough with the requirements for LayoutOp, so I will defer to Naoya to approve the PR.
Is there an existing issue or doc detailing the LayoutOp design?

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

Is there an existing issue or doc detailing the LayoutOp design?

Sorry, I don't have anything on that yet. I'll try to write one up when I have the end-to-end example working, at least as a prototype. Mostly trying to wing it at the moment.

@jjsjann123
Collaborator Author

!test


// Replay loop.
if (self_loop != self->logical()) {
ReplaySelf replay(self_loop, axis_map);
Collaborator

Just FYI: #4585 reversed this. I expect some tests to break.

Collaborator Author

Thanks a ton. Let me sweep through the failing tests and see if there's anything easy to patch. 🧑‍💼

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

!test

fusion.addOutput(out);
// padding output to multiple of 16 on allocation domain
auto&& [io, ii] = IterDomain::split(
out->axis(1), IrBuilder::create<Val>(16L, DataType::Index), true);
Collaborator Author

Tagging @naoyam: changed the test to apply the split only on logical -> allocation.

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

Errr, seeing wrong results coming from:

[  FAILED  ] 6 tests, listed below:
[  FAILED  ] DistributedTransformerTest.MLP_Backward/__half, where GetParam() = __half
[  FAILED  ] DistributedTransformerTest.MLP_Backward/__bfloat, where GetParam() = __bfloat
[  FAILED  ] DistributedTransformerTest.MHA_Backward/__half, where GetParam() = __half
[  FAILED  ] DistributedTransformerTest.MHA_Backward/__bfloat, where GetParam() = __bfloat
[  FAILED  ] DistributedTransformerTest.Backward/__half, where GetParam() = __half
[  FAILED  ] DistributedTransformerTest.Backward/__bfloat, where GetParam() = __bfloat

That's pretty scary. Keep digging.

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

I need to double-check the tensor layout produced in the tests.

TransposeTest.FusionTransposeSelfMapping
I'm guessing the transpose test failures come from ExprEval not respecting the allocation domain set in the fusion.

TransposeTest.FusionReshapeSmallTransposeDimensionSchedule
I'm surprised that expr_eval can take that segment. Double-check the result as well.


// Parallelize type could include device from split.
ido->parallelize(s->outer()->getParallelType());
idi->parallelize(s->inner()->getParallelType());
Collaborator Author

@wujingyue tagging you to try this guy out.

Collaborator

Thanks -- this will definitely help my #5229!

At this moment, I can't really take this two-line change because at head cacheBefore still uses TransformReplay::replayCasP, not fullSelfReplay. However, no pressure! I'll come back to #5229 after you make more progress on this.

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

For my own record.

HopperMatmulTest seems to be affected by this as well.

00:11:19 [  FAILED  ] 37 tests, listed below:
00:11:19 [  FAILED  ] HopperMatmulTest.HSH_NN_UseScheduler
00:11:19 [  FAILED  ] HopperMatmulTest.ScheduleWithTranslation
00:11:19 [  FAILED  ] HopperMatmulTest.EpilogueBiasPersistentBroadcastInputs
00:11:19 [  FAILED  ] HopperMatmulTest.HSS_NT_SplitKTMAStore
00:11:19 [  FAILED  ] MatmulSchedulerTest.PreBroadcastMmaBiasNeg
00:11:19 [  FAILED  ] MatmulNodeTranslationTest.AutomaticSchedulerMatmulNode/4dA_4dB, where GetParam() = (4, 4, true, false, false, matmul)
00:11:19 [  FAILED  ] LinearNodeTranslationTest.AutomaticSchedulerLinearNode/2dA_2dB, where GetParam() = (2, 2, -1, true, false, false)
00:11:19 [  FAILED  ] MLPBenchmarkTest.FwdGEMM/persistent_non_warpspec, where GetParam() = 2-byte object <00-01>
00:11:19 [  FAILED  ] MLPBenchmarkTest.FwdGEMM_BroadcastInputs/dataparallel_warpspec, where GetParam() = 2-byte object <01-00>
00:11:19 [  FAILED  ] MLPBenchmarkTest.FwdEpilogueBiasFusion/dataparallel_non_warpspec, where GetParam() = 2-byte object <00-00>
00:11:19 [  FAILED  ] MLPBenchmarkTest.FwdEpilogueBiasFusion/persistent_warpspec, where GetParam() = 2-byte object <01-01>
00:11:19 [  FAILED  ] MLPBenchmarkTest.FwdEpilogueSiluFusion/persistent_non_warpspec, where GetParam() = 2-byte object <00-01>
00:11:19 [  FAILED  ] MLPBenchmarkTest.FwdHorizontalFusion/dataparallel_non_warpspec, where GetParam() = 2-byte object <00-00>
00:11:19 [  FAILED  ] MLPBenchmarkTest.FwdHorizontalFusion/persistent_warpspec, where GetParam() = 2-byte object <01-01>
00:11:19 [  FAILED  ] MLPBenchmarkTest.FwdHorizontalFusion_BroadcastInputs/persistent_non_warpspec, where GetParam() = 2-byte object <00-01>
00:11:19 [  FAILED  ] MLPBenchmarkTest.BatchGEMM/dataparallel_warpspec, where GetParam() = 2-byte object <01-00>
00:11:19 [  FAILED  ] HopperMatmulTest/MLPGemmPersistentBroadcastInputs.NumWarpGroups/1, where GetParam() = 1
00:11:19 [  FAILED  ] MatmulSchedulerTest/AllocationDomainTest.BasicMatmul/2, where GetParam() = (true, false)
00:11:19 [  FAILED  ] General/HopperPlusMatmulSchedulerTest.FusedMultiplySum/MN_512_256_128_MmaMacro_m64_n128_k16_tma_store, where GetParam() = (true, false, false, 512, 256, 128, 8-byte object <10-00 80-00 40-00 04-00>, 1)
00:11:19 [  FAILED  ] General/HopperPlusMatmulSchedulerTest.FusedMultiplySum/KN_512_256_128_MmaMacro_m64_n128_k16_tma_store_splitk_2, where GetParam() = (true, true, false, 512, 256, 128, 8-byte object <10-00 80-00 40-00 04-00>, 2)
00:11:19 [  FAILED  ] General/HopperPlusMatmulSchedulerTest.FusedMultiplySum/KK_512_256_128_MmaMacro_m64_n128_k16_tma_store, where GetParam() = (true, true, true, 512, 256, 128, 8-byte object <10-00 80-00 40-00 04-00>, 1)
00:11:19 [  FAILED  ] General/HopperPlusMatmulSchedulerTest.FusedMultiplySumBiasNeg/MN_512_256_128_MmaMacro_m64_n128_k16_tma_store_splitk_2, where GetParam() = (true, false, false, 512, 256, 128, 8-byte object <10-00 80-00 40-00 04-00>, 2)
00:11:19 [  FAILED  ] General/HopperPlusMatmulSchedulerTest.FusedMultiplySumBiasNeg/MK_512_256_128_MmaMacro_m64_n128_k16_tma_store, where GetParam() = (true, false, true, 512, 256, 128, 8-byte object <10-00 80-00 40-00 04-00>, 1)
00:11:19 [  FAILED  ] General/HopperPlusMatmulSchedulerTest.FusedMultiplySumBiasNeg/KK_512_256_128_MmaMacro_m64_n128_k16_tma_store_splitk_2, where GetParam() = (true, true, true, 512, 256, 128, 8-byte object <10-00 80-00 40-00 04-00>, 2)
00:11:19 [  FAILED  ] Swizzle/HopperPlusMatmulSchedulerTest.FusedMultiplySum/MN_512_256_64_MmaMacro_m64_n256_k16_tma_store_128BSwizzle, where GetParam() = (true, false, false, 512, 256, 64, 8-byte object <10-00 00-01 40-00 04-00>, 1)
00:11:19 [  FAILED  ] Swizzle/HopperPlusMatmulSchedulerTest.FusedMultiplySum/MN_512_256_64_MmaMacro_m64_n32_k16_tma_store_64BSwizzle, where GetParam() = (true, false, false, 512, 256, 64, 8-byte object <10-00 20-00 40-00 04-00>, 1)
00:11:19 [  FAILED  ] Swizzle/HopperPlusMatmulSchedulerTest.FusedMultiplySum/MK_512_256_64_MmaMacro_m64_n64_k16_tma_store_128BSwizzle, where GetParam() = (true, false, true, 512, 256, 64, 8-byte object <10-00 40-00 40-00 04-00>, 1)
00:11:19 [  FAILED  ] Swizzle/HopperPlusMatmulSchedulerTest.FusedMultiplySum/KN_512_256_64_MmaMacro_m64_n128_k16_tma_store_128BSwizzle, where GetParam() = (true, true, false, 512, 256, 64, 8-byte object <10-00 80-00 40-00 04-00>, 1)
00:11:19 [  FAILED  ] Swizzle/HopperPlusMatmulSchedulerTest.FusedMultiplySum/KN_512_256_64_MmaMacro_m64_n16_k16_tma_store_32BSwizzle, where GetParam() = (true, true, false, 512, 256, 64, 8-byte object <10-00 10-00 40-00 04-00>, 1)
00:11:19 [  FAILED  ] Swizzle/HopperPlusMatmulSchedulerTest.FusedMultiplySum/KK_512_256_64_MmaMacro_m64_n256_k16_tma_store_128BSwizzle, where GetParam() = (true, true, true, 512, 256, 64, 8-byte object <10-00 00-01 40-00 04-00>, 1)
00:11:19 [  FAILED  ] Swizzle/HopperPlusMatmulSchedulerTest.FusedMultiplySum/KK_512_256_64_MmaMacro_m64_n32_k16_tma_store_64BSwizzle, where GetParam() = (true, true, true, 512, 256, 64, 8-byte object <10-00 20-00 40-00 04-00>, 1)
00:11:19 [  FAILED  ] Swizzle/HopperPlusMatmulSchedulerTest.FusedMultiplySumBiasNeg/MN_512_256_64_MmaMacro_m64_n64_k16_tma_store_128BSwizzle, where GetParam() = (true, false, false, 512, 256, 64, 8-byte object <10-00 40-00 40-00 04-00>, 1)
00:11:19 [  FAILED  ] Swizzle/HopperPlusMatmulSchedulerTest.FusedMultiplySumBiasNeg/MK_512_256_64_MmaMacro_m64_n128_k16_tma_store_128BSwizzle, where GetParam() = (true, false, true, 512, 256, 64, 8-byte object <10-00 80-00 40-00 04-00>, 1)
00:11:19 [  FAILED  ] Swizzle/HopperPlusMatmulSchedulerTest.FusedMultiplySumBiasNeg/MK_512_256_64_MmaMacro_m64_n16_k16_tma_store_32BSwizzle, where GetParam() = (true, false, true, 512, 256, 64, 8-byte object <10-00 10-00 40-00 04-00>, 1)
00:11:19 [  FAILED  ] Swizzle/HopperPlusMatmulSchedulerTest.FusedMultiplySumBiasNeg/KN_512_256_64_MmaMacro_m64_n256_k16_tma_store_128BSwizzle, where GetParam() = (true, true, false, 512, 256, 64, 8-byte object <10-00 00-01 40-00 04-00>, 1)
00:11:19 [  FAILED  ] Swizzle/HopperPlusMatmulSchedulerTest.FusedMultiplySumBiasNeg/KN_512_256_64_MmaMacro_m64_n32_k16_tma_store_64BSwizzle, where GetParam() = (true, true, false, 512, 256, 64, 8-byte object <10-00 20-00 40-00 04-00>, 1)
00:11:19 [  FAILED  ] Swizzle/HopperPlusMatmulSchedulerTest.FusedMultiplySumBiasNeg/KK_512_256_64_MmaMacro_m64_n64_k16_tma_store_128BSwizzle, where GetParam() = (true, true, true, 512, 256, 64, 8-byte object <10-00 40-00 40-00 04-00>, 1)
00:11:19 

@jjsjann123
Collaborator Author

I'm shelving this PR for now. We have changed the approach taken here on how cacheBefore needs to be handled, which is causing too much test breakage that's tricky to debug at this time.

It's not blocking my grouped_mm layout support, so I'll jump back to this once the stack of PRs on that has been cleaned up.
I'll come back to this with incremental changes. 🤞

bool only_valid_device_split = true;
for (Expr* expr : exprs | std::views::reverse) {
validateDeviceSplit(expr);
if (!isValidDeviceSplit(expr)) {
Collaborator Author

@protonu You might need this relaxed (things coming from vectorize_helper.cpp and multidevice/...).

I'll start a PR on the side for this.

Collaborator

Thanks - For now, I modified it here:
https://github.com/NVIDIA/Fuser/pull/5322/files

jjsjann123 added a commit that referenced this pull request Oct 23, 2025
Cherry-picked from #5184
A non-divisible split between the logical and allocation domains can be used to represent padding.
@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

Ha, I'm seeing GB200 fail with:

[1,2]<stdout>:[ RUN      ] DistributedTransformerTest.LoopSplitMHAFwd/__half
[1,0]<stdout>:unknown file: Failure
[1,0]<stdout>:C++ exception with description "
[1,0]<stdout>:Error from segmentation group 5: shape '[2, 128, 768]' is invalid for input of size 49152
[1,0]<stdout>:Exception raised from infer_size_impl at /opt/pytorch/pytorch/aten/src/ATen/InferSize.h:100 (most recent call first):
[1,0]<stdout>:frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xb4 (0x497ba4 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
[1,0]<stdout>:frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x68 (0x42e76c in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
[1,0]<stdout>:frame #2: <unknown function> + 0xfc8528 (0xffffe0ea8528 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
[1,0]<stdout>:frame #3: <unknown function> + 0x175d398 (0xffffe163d398 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
[1,0]<stdout>:frame #4: at::native::view(at::Tensor const&, c10::ArrayRef<long>) + 0x14 (0xffffe163d7d8 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
[1,0]<stdout>:frame #5: <unknown function> + 0x2fe2b50 (0xffffc4842b50 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cuda.so)
[1,0]<stdout>:frame #6: at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>) + 0x94 (0xffffe1e82b24 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
[1,0]<stdout>:frame #7: <unknown function> + 0x4fc40b4 (0xffffe4ea40b4 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
[1,0]<stdout>:frame #8: <unknown function> + 0x4fc44c8 (0xffffe4ea44c8 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
[1,0]<stdout>:frame #9: at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>) + 0x94 (0xffffe1e82b24 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
[1,0]<stdout>:frame #10: <unknown function> + 0x4a00668 (0xffffe48e0668 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
[1,0]<stdout>:frame #11: <unknown function> + 0x4a00b78 (0xffffe48e0b78 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
[1,0]<stdout>:frame #12: at::_ops::view::call(at::Tensor const&, c10::ArrayRef<c10::SymInt>) + 0x184 (0xffffe1edcce4 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
[1,0]<stdout>:frame #13: at::native::reshape_symint(at::Tensor const&, c10::ArrayRef<c10::SymInt>) + 0x304 (0xffffe1649c44 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
[1,0]<stdout>:frame #14: <unknown function> + 0x24fed70 (0xffffe23ded70 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
[1,0]<stdout>:frame #15: at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<c10::SymInt>) + 0x184 (0xffffe1edc624 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
[1,0]<stdout>:frame #16: <unknown function> + 0x864cf0 (0xaaaaab304cf0 in bin/test_multidevice)
[1,0]<stdout>:frame #17: <unknown function> + 0x7ea3a4 (0xaaaaab28a3a4 in bin/test_multidevice)
[1,0]<stdout>:frame #18: <unknown function> + 0x683cf8 (0xaaaaab123cf8 in bin/test_multidevice)
[1,0]<stdout>:frame #19: <unknown function> + 0x7ea1f4 (0xaaaaab28a1f4 in bin/test_multidevice)
[1,0]<stdout>:frame #20: <unknown function> + 0x683cf8 (0xaaaaab123cf8 in bin/test_multidevice)
[1,0]<stdout>:frame #21: <unknown function> + 0x684070 (0xaaaaab124070 in bin/test_multidevice)
[1,0]<stdout>:frame #22: <unknown function> + 0xa9fbf4 (0xaaaaab53fbf4 in bin/test_multidevice)
[1,0]<stdout>:frame #23: <unknown function> + 0xaa0ab0 (0xaaaaab540ab0 in bin/test_multidevice)
[1,0]<stdout>:frame #24: <unknown function> + 0xacdd98 (0xaaaaab56dd98 in bin/test_multidevice)
[1,0]<stdout>:frame #25: <unknown function> + 0xad0e04 (0xaaaaab570e04 in bin/test_multidevice)
[1,0]<stdout>:frame #26: <unknown function> + 0xad7a94 (0xaaaaab577a94 in bin/test_multidevice)
[1,0]<stdout>:frame #27: <unknown function> + 0xb0ac74 (0xaaaaab5aac74 in bin/test_multidevice)
[1,0]<stdout>:frame #28: <unknown function> + 0xb0d6ec (0xaaaaab5ad6ec in bin/test_multidevice)
[1,0]<stdout>:frame #29: <unknown function> + 0xb0dfac (0xaaaaab5adfac in bin/test_multidevice)
[1,0]<stdout>:frame #30: <unknown function> + 0xb034f4 (0xaaaaab5a34f4 in bin/test_multidevice)
[1,0]<stdout>:frame #31: <unknown function> + 0xeb350c (0xaaaaab95350c in bin/test_multidevice)
[1,0]<stdout>:frame #32: <unknown function> + 0xf0e5f0 (0xaaaaab9ae5f0 in bin/test_multidevice)
[1,0]<stdout>:frame #33: <unknown function> + 0xef5fc4 (0xaaaaab995fc4 in bin/test_multidevice)
[1,0]<stdout>:frame #34: <unknown function> + 0xef64b8 (0xaaaaab9964b8 in bin/test_multidevice)
[1,0]<stdout>:frame #35: <unknown function> + 0xef6ab4 (0xaaaaab996ab4 in bin/test_multidevice)
[1,0]<stdout>:frame #36: <unknown function> + 0xf03d60 (0xaaaaab9a3d60 in bin/test_multidevice)
[1,0]<stdout>:frame #37: <unknown function> + 0xef6c90 (0xaaaaab996c90 in bin/test_multidevice)
[1,0]<stdout>:frame #38: <unknown function> + 0x361194 (0xaaaaaae01194 in bin/test_multidevice)
[1,0]<stdout>:frame #39: <unknown function> + 0x284c4 (0xffffc15284c4 in /usr/lib/aarch64-linux-gnu/libc.so.6)
[1,0]<stdout>:frame #40: __libc_start_main + 0x98 (0xffffc1528598 in /usr/lib/aarch64-linux-gnu/libc.so.6)
[1,0]<stdout>:frame #41: <unknown function> + 0x382e30 (0xaaaaaae22e30 in bin/test_multidevice)

Which is a lot more promising to poke at than the wrong result. 🦅
There's also still another internal assert:

[ RUN      ] MatmulSchedulerTest/AllocationDomainTest.BasicMatmul/0
unknown file: Failure
C++ exception with description " INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/val_graph.cpp:82, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. 
Expected disjoint_set_it != disjoint_vals_.disjointSetMap().end() . 
Val group could not be found in graph associated with: rS14{256}

Exception raised from toGroup at /opt/pytorch/nvfuser/csrc/val_graph.cpp:82 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, long, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xfb (0x55c4c8ce7553 in /opt/pytorch/nvfuser/bin/test_matmul)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, long, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x69 (0x55c4c90d7b69 in /opt/pytorch/nvfuser/bin/test_matmul)
frame #2: <unknown function> + 0xde03b3 (0x55c4c97673b3 in /opt/pytorch/nvfuser/bin/test_matmul)
frame #3: <unknown function> + 0xc1880b (0x55c4c959f80b in /opt/pytorch/nvfuser/bin/test_matmul)
frame #4: <unknown function> + 0xc19197 (0x55c4c95a0197 in /opt/pytorch/nvfuser/bin/test_matmul)
frame #5: <unknown function> + 0xc48281 (0x55c4c95cf281 in /opt/pytorch/nvfuser/bin/test_matmul)
frame #6: <unknown function> + 0xc18314 (0x55c4c959f314 in /opt/pytorch/nvfuser/bin/test_matmul)
frame #7: <unknown function> + 0xf8e473 (0x55c4c9915473 in /opt/pytorch/nvfuser/bin/test_matmul)
To reproduce: NVFUSER_TEST_RANDOM_SEED=1761334588 NVFUSER_TEST_ATEN_RANDOM_SEED=0 test_nvfuser --gtest_filter='MatmulSchedulerTest/AllocationDomainTest.BasicMatmul/3'
frame #8: <unknown function> + 0x1040ae1 (0x55c4c99c7ae1 in /opt/pytorch/nvfuser/bin/test_matmul)
frame #9: <unknown function> + 0x102811a (0x55c4c99af11a in /opt/pytorch/nvfuser/bin/test_matmul)
frame #10: <unknown function> + 0x10286d2 (0x55c4c99af6d2 in /opt/pytorch/nvfuser/bin/test_matmul)
frame #11: <unknown function> + 0x1028d11 (0x55c4c99afd11 in /opt/pytorch/nvfuser/bin/test_matmul)
frame #12: <unknown function> + 0x10368ba (0x55c4c99bd8ba in /opt/pytorch/nvfuser/bin/test_matmul)
frame #13: <unknown function> + 0x1028f0a (0x55c4c99aff0a in /opt/pytorch/nvfuser/bin/test_matmul)
frame #14: <unknown function> + 0x42d6d4 (0x55c4c8db46d4 in /opt/pytorch/nvfuser/bin/test_matmul)
frame #15: <unknown function> + 0x2a1ca (0x7f49c1eb11ca in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #16: __libc_start_main + 0x8b (0x7f49c1eb128b in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #17: <unknown function> + 0x447665 (0x55c4c8dce665 in /opt/pytorch/nvfuser/bin/test_matmul)
" thrown in the test body.

@jjsjann123
Collaborator Author

jjsjann123 commented Oct 25, 2025

DistributedTransformerTest.LoopSplitMHAFwd/__half repros only with multiple GPUs: mpirun -n 2 ./bin/test_multidevice --gtest_filter="DistributedTransformerTest.LoopSplitMHAFwd/__half"

😢
I think I see where the issue is coming from. T35 is the cache tensor, which we decided is not going to keep the allocation domain. But in this example, the TV is not getting the right sharding on the IR. I'm surprised the error is coming from aten::reshape asserting on sizes via the evaluator.
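Mechanically the assert makes sense, though: `at::infer_size_impl` just checks that the requested view shape accounts for exactly the number of elements in the input. A minimal Python mimic of that size check (hypothetical helper, not nvFuser/ATen code), using the sizes from this repro:

```python
def infer_size(shape, numel):
    # Mimic of at::infer_size_impl: resolve at most one -1 dim,
    # then require the shape to cover exactly `numel` elements.
    out = list(shape)
    known = 1
    infer_dim = None
    for i, s in enumerate(shape):
        if s == -1:
            infer_dim = i
        else:
            known *= s
    if infer_dim is not None:
        if known == 0 or numel % known != 0:
            raise RuntimeError(f"shape {shape} is invalid for input of size {numel}")
        out[infer_dim] = numel // known
    elif known != numel:
        raise RuntimeError(f"shape {shape} is invalid for input of size {numel}")
    return out

# Full logical extent of the view (2, 16, 128, 48) -> (2, 128, 768):
assert infer_size((2, 128, 768), 2 * 16 * 128 * 48) == [2, 128, 768]

# If the producer buffer only holds the per-device shard (16 -> 8 after the
# outer split by 2), the element count is halved and the view asserts:
try:
    infer_size((2, 128, 768), 2 * 8 * 128 * 48)
except RuntimeError as e:
    print(e)
```

So a cache tensor that silently drops the sharded allocation domain hands the evaluator a shard-sized buffer against a full-sized view shape, which is consistent with the failure landing inside aten::reshape.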

T13_g___half[iS215{( ceilDiv(48, blockDim.x) )}, iS220{( ceilDiv(256, blockDim.y) )}, iS221{blockDim.y}, iS219{1}, iS214{8}, iS216{blockDim.x}, ideviceIdx.x175{2}] (DeviceMesh{0 1})
 logical domain : (iS46{2}, iS47{16}, iS48{128}, iS49{48})
 allocation domain : (iS46{2}, iS48{128}, ideviceIdx.x127{2}, iS128{8}, iS49{48})
 contiguity: t t t t t
  Outer split: iS47{16} by factor 2 -> ideviceIdx.x127{2}, iS128{8}
  Merge: iS47{16} and iS49{48} -> iS174{768}
  Merge: iS46{2} and iS48{128} -> iS217{256}
  Split: iS217{256} by factor 1 -> iS218{256}, iS219{1}
  Outer split: iS174{768} by factor 2 -> ideviceIdx.x175{2}, iS176{384}
  Split: iS218{256} by factor blockDim.y -> iS220{( ceilDiv(256, blockDim.y) )}, iS221{blockDim.y}
  Split: iS176{384} by factor 8 -> iS213{48}, iS214{8}
  Split: iS213{48} by factor blockDim.x -> iS215{( ceilDiv(48, blockDim.x) )}, iS216{blockDim.x}
 loop domain : (iS215{( ceilDiv(48, blockDim.x) )}, iS220{( ceilDiv(256, blockDim.y) )}, iS221{blockDim.y}, iS219{1}, iS214{8}, iS216{blockDim.x}, ideviceIdx.x175{2})
T34_l___half[iblockIdx.x206{( ceilDiv(48, blockDim.x) )}, iblockIdx.y211{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y212{blockDim.y}, iUS210{1}, iV205{8}, ithreadIdx.x207{blockDim.x}, ideviceIdx.x172{2}] ca_pos( 4 ) (DeviceMesh{0 1})
 logical domain : (iS155{2}, iS156{16}, iS157{128}, iS158{48})
 allocation domain : (iS155{2}, iS157{128}, ideviceIdx.x159{2}, iS160{8}, iS158{48})
 contiguity: t t t t t
  Outer split: iS156{16} by factor 2 -> ideviceIdx.x159{2}, iS160{8}
  Merge: iS156{16} and iS158{48} -> iS171{768}
  Merge: iS155{2} and iS157{128} -> iS208{256}
  Split: iS208{256} by factor 1 -> iS209{256}, iUS210{1}
  Outer split: iS171{768} by factor 2 -> ideviceIdx.x172{2}, iS173{384}
  Split: iS209{256} by factor blockDim.y -> iblockIdx.y211{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y212{blockDim.y}
  Split: iS173{384} by factor 8 -> iS204{48}, iV205{8}
  Split: iS204{48} by factor blockDim.x -> iblockIdx.x206{( ceilDiv(48, blockDim.x) )}, ithreadIdx.x207{blockDim.x}
 loop domain : (iblockIdx.x206{( ceilDiv(48, blockDim.x) )}, iblockIdx.y211{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y212{blockDim.y}, iUS210{1}, iV205{8}, ithreadIdx.x207{blockDim.x}, ideviceIdx.x172{2})
T17_l___half[iblockIdx.x197{( ceilDiv(48, blockDim.x) )}, iblockIdx.y202{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y203{blockDim.y}, iUS201{1}, iS196{8}, ithreadIdx.x198{blockDim.x}, ideviceIdx.x169{2}] ca_pos( 7 ) produce_pos( 4 ) (DeviceMesh{0 1})
 root domain : (iS54{2}, iS55{16}, iS56{128}, iS57{48})
 logical domain : (iS54{2}, iS56{128}, iS55{16}, iS57{48})
 allocation domain : (iS54{2}, ideviceIdx.x131{2}, iS132{8}, iS56{128}, iS57{48})
 contiguity: t t t t t
  Outer split: iS55{16} by factor 2 -> ideviceIdx.x131{2}, iS132{8}
  Merge: iS55{16} and iS57{48} -> iS168{768}
  Merge: iS54{2} and iS56{128} -> iS199{256}
  Split: iS199{256} by factor 1 -> iS200{256}, iUS201{1}
  Outer split: iS168{768} by factor 2 -> ideviceIdx.x169{2}, iS170{384}
  Split: iS200{256} by factor blockDim.y -> iblockIdx.y202{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y203{blockDim.y}
  Split: iS170{384} by factor 8 -> iS195{48}, iS196{8}
  Split: iS195{48} by factor blockDim.x -> iblockIdx.x197{( ceilDiv(48, blockDim.x) )}, ithreadIdx.x198{blockDim.x}
 loop domain : (iblockIdx.x197{( ceilDiv(48, blockDim.x) )}, iblockIdx.y202{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y203{blockDim.y}, iUS201{1}, iS196{8}, ithreadIdx.x198{blockDim.x}, ideviceIdx.x169{2})
T35_l___half[iblockIdx.x188{( ceilDiv(48, blockDim.x) )}, iblockIdx.y193{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y194{blockDim.y}, iUS192{1}, iS187{8}, ithreadIdx.x189{blockDim.x}, ideviceIdx.x166{2}] ca_pos( 4 ) produce_pos( 7 ) (DeviceMesh{0 1})
 root domain : (iS161{2}, iS162{128}, iS163{16}rf, iS164{48}rf)
  Merge: iS163{16}rf and iS164{48}rf -> iS165{768}rf
 logical domain : (iS161{2}, iS162{128}, iS165{768}rf)
 contiguity: t t t
  Merge: iS161{2} and iS162{128} -> iS190{256}
  Split: iS190{256} by factor 1 -> iS191{256}, iUS192{1}
  Merge: iS163{16}rf and iS164{48}rf -> iS165{768}rf
  Outer split: iS165{768}rf by factor 2 -> ideviceIdx.x166{2}, iS167{384}
  Split: iS191{256} by factor blockDim.y -> iblockIdx.y193{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y194{blockDim.y}
  Split: iS167{384} by factor 8 -> iS186{48}, iS187{8}
  Split: iS186{48} by factor blockDim.x -> iblockIdx.x188{( ceilDiv(48, blockDim.x) )}, ithreadIdx.x189{blockDim.x}
 loop domain : (iblockIdx.x188{( ceilDiv(48, blockDim.x) )}, iblockIdx.y193{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y194{blockDim.y}, iUS192{1}, iS187{8}, ithreadIdx.x189{blockDim.x}, ideviceIdx.x166{2})
T18_g___half[iblockIdx.x180{( ceilDiv(48, blockDim.x) )}, iblockIdx.y184{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y185{blockDim.y}, iUS183{1}, iV179{8}, ithreadIdx.x181{blockDim.x}, ideviceIdx.x133{2}] ca_pos( 4 ) produce_pos( 4 ) (DeviceMesh{0 1})
 logical domain : (iS58{2}, iS59{128}, iS64{768})
 allocation domain : (iS58{2}, iS59{128}, ideviceIdx.x133{2}, iS134{384})
 contiguity: t t t t
  Outer split: iS64{768} by factor 2 -> ideviceIdx.x133{2}, iS134{384}
  Split: iS134{384} by factor 8 -> iS178{48}, iV179{8}
  Merge: iS58{2} and iS59{128} -> iS177{256}
  Split: iS178{48} by factor blockDim.x -> iblockIdx.x180{( ceilDiv(48, blockDim.x) )}, ithreadIdx.x181{blockDim.x}
  Split: iS177{256} by factor 1 -> iS182{256}, iUS183{1}
  Split: iS182{256} by factor blockDim.y -> iblockIdx.y184{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y185{blockDim.y}
 loop domain : (iblockIdx.x180{( ceilDiv(48, blockDim.x) )}, iblockIdx.y184{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y185{blockDim.y}, iUS183{1}, iV179{8}, ithreadIdx.x181{blockDim.x}, ideviceIdx.x133{2})
Inputs:
  T13_g___half[iS215{( ceilDiv(48, blockDim.x) )}, iS220{( ceilDiv(256, blockDim.y) )}, iS221{blockDim.y}, iS219{1}, iS214{8}, iS216{blockDim.x}, ideviceIdx.x175{2}] (DeviceMesh{0 1})
Outputs:
  T18_g___half[iblockIdx.x180{( ceilDiv(48, blockDim.x) )}, iblockIdx.y184{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y185{blockDim.y}, iUS183{1}, iV179{8}, ithreadIdx.x181{blockDim.x}, ideviceIdx.x133{2}] ca_pos( 4 ) produce_pos( 4 ) (DeviceMesh{0 1})

%kernel_math {
T34_l___half[iblockIdx.x206{( ceilDiv(48, blockDim.x) )}, iblockIdx.y211{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y212{blockDim.y}, iUS210{1}, iV205{8}, ithreadIdx.x207{blockDim.x}, ideviceIdx.x172{2}] ca_pos( 4 ) (DeviceMesh{0 1})
   = Set( T13_g___half[iS215{( ceilDiv(48, blockDim.x) )}, iS220{( ceilDiv(256, blockDim.y) )}, iS221{blockDim.y}, iS219{1}, iS214{8}, iS216{blockDim.x}, ideviceIdx.x175{2}] (DeviceMesh{0 1}), cache_op=Streaming )
T17_l___half[iblockIdx.x197{( ceilDiv(48, blockDim.x) )}, iblockIdx.y202{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y203{blockDim.y}, iUS201{1}, iS196{8}, ithreadIdx.x198{blockDim.x}, ideviceIdx.x169{2}] ca_pos( 7 ) produce_pos( 4 ) (DeviceMesh{0 1})
   = Set.Permute( T34_l___half[iblockIdx.x206{( ceilDiv(48, blockDim.x) )}, iblockIdx.y211{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y212{blockDim.y}, iUS210{1}, iV205{8}, ithreadIdx.x207{blockDim.x}, ideviceIdx.x172{2}] ca_pos( 4 ) (DeviceMesh{0 1}), cache_op=Streaming )
T35_l___half[iblockIdx.x188{( ceilDiv(48, blockDim.x) )}, iblockIdx.y193{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y194{blockDim.y}, iUS192{1}, iS187{8}, ithreadIdx.x189{blockDim.x}, ideviceIdx.x166{2}] ca_pos( 4 ) produce_pos( 7 ) (DeviceMesh{0 1}) = view( T17_l___half[iblockIdx.x197{( ceilDiv(48, blockDim.x) )}, iblockIdx.y202{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y203{blockDim.y}, iUS201{1}, iS196{8}, ithreadIdx.x198{blockDim.x}, ideviceIdx.x169{2}] ca_pos( 7 ) produce_pos( 4 ) (DeviceMesh{0 1}) )
T18_g___half[iblockIdx.x180{( ceilDiv(48, blockDim.x) )}, iblockIdx.y184{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y185{blockDim.y}, iUS183{1}, iV179{8}, ithreadIdx.x181{blockDim.x}, ideviceIdx.x133{2}] ca_pos( 4 ) produce_pos( 4 ) (DeviceMesh{0 1})
   = Set( T35_l___half[iblockIdx.x188{( ceilDiv(48, blockDim.x) )}, iblockIdx.y193{( ceilDiv(256, blockDim.y) )}, ithreadIdx.y194{blockDim.y}, iUS192{1}, iS187{8}, ithreadIdx.x189{blockDim.x}, ideviceIdx.x166{2}] ca_pos( 4 ) produce_pos( 7 ) (DeviceMesh{0 1}), cache_op=Streaming )
} // %kernel_math

This issue showed up during shape inference when figuring out the output buffer size for allocation (backtrace below):

(gdb) bt
#0  0x0000000000499254 in c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so
#1  0x000000000043544c in c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so
#2  0x0000f21c8a037068 in void at::infer_size_impl<c10::ArrayRef<long>, long, c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#3  0x0000f21c8a7fa654 in at::native::view_impl(at::Tensor const&, c10::ArrayRef<long>) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#4  0x0000f21c8a0394c4 in at::native::view(at::Tensor const&, c10::ArrayRef<long>) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#5  0x0000f21c6b4453b0 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, c10::ArrayRef<c10::SymInt>), &at::(anonymous namespace)::(anonym
ous namespace)::wrapper_CUDA__view>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, c10::ArrayRef<c10::SymInt> > >, at::Tensor (at::Tensor const&, c10::ArrayRef<c10::SymInt>)>::call(c10::OperatorKernel*, c10::DispatchKeySet
, at::Tensor const&, c10::ArrayRef<c10::SymInt>) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cuda.so
#6  0x0000f21c8a05283c in at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#7  0x0000f21c8a078844 in torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#8  0x0000f21c8a078c64 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>), &torch::ADInpl
aceOrView::(anonymous namespace)::view>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt> > >, at::Tensor (c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>)>::
call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#9  0x0000f21c8a05283c in at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#10 0x0000f21c8a072ea8 in torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#11 0x0000f21c8a0733c4 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>), &torch::autogr
ad::VariableType::(anonymous namespace)::view>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt> > >, at::Tensor (c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymI
nt>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#12 0x0000f21c8a0567ac in at::_ops::view::call(at::Tensor const&, c10::ArrayRef<c10::SymInt>) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#13 0x0000f21c8a03d694 in at::native::reshape_symint(at::Tensor const&, c10::ArrayRef<c10::SymInt>) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#14 0x0000f21c8a05e1a0 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, c10::ArrayRef<c10::SymInt>), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__reshape>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, c10::ArrayRef<c10::SymInt> > >, at::Tensor (at::Tensor const&, c10::ArrayRef<c10::SymInt>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#15 0x0000f21c8a05643c in at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<c10::SymInt>) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#16 0x0000bff60a8be910 in at::Tensor::reshape (this=0xbff6411017a0, shape=...) at /usr/local/lib/python3.12/dist-packages/torch/include/ATen/core/TensorBody.h:3270
#17 0x0000bff60a8948c8 in nvfuser::ReshapeOp::evaluate (this=0xf218340b0910, ee=..., inputs=std::vector of length 1, capacity 1 = {...}) at /opt/pytorch/nvfuser/csrc/ir/nodes.cpp:2439
#18 0x0000bff60a828544 in nvfuser::Expr::evaluate (this=0xf218340b0910, ee=..., known_values=std::unordered_map with 2 elements = {...})
    at /opt/pytorch/nvfuser/csrc/ir/base_nodes.cpp:408
#19 0x0000bff60a5ad000 in nvfuser::ExpressionEvaluator::evaluate (this=0xffffe663b3b8, value=0xf21834091c20, known_values=std::unordered_map with 2 elements = {...})
    at /opt/pytorch/nvfuser/csrc/expr_evaluator.cpp:275
#20 0x0000bff60a828450 in nvfuser::Expr::evaluate (this=0xf218340b06b0, ee=..., known_values=std::unordered_map with 2 elements = {...})
    at /opt/pytorch/nvfuser/csrc/ir/base_nodes.cpp:402
#21 0x0000bff60a5ad000 in nvfuser::ExpressionEvaluator::evaluate (this=0xffffe663b3b8, value=0xf2183405af70, known_values=std::unordered_map with 2 elements = {...})
    at /opt/pytorch/nvfuser/csrc/expr_evaluator.cpp:275
#22 0x0000bff60a5ace30 in nvfuser::ExpressionEvaluator::evaluate (this=0xffffe663b3b8, value=0xf2183405af70) at /opt/pytorch/nvfuser/csrc/expr_evaluator.cpp:258
#23 0x0000bff60ac9b730 in nvfuser::inferTensorShapes (tv=0xf2183405af70, expr_eval=...) at /opt/pytorch/nvfuser/csrc/runtime/allocations.cpp:885
#24 0x0000bff60ac985f0 in nvfuser::getBufferInfos (expr_eval=..., index_dtype=..., tvs=std::vector of length 1, capacity 1 = {...})
    at /opt/pytorch/nvfuser/csrc/runtime/allocations.cpp:355
#25 0x0000bff60acf0f74 in nvfuser::KernelExecutor::initializeExecutorEntry (this=0xf2183404ea60, executor_entry=..., args=..., launch_constraints=..., compile_params=...,
    output_args=..., index_type=...) at /opt/pytorch/nvfuser/csrc/runtime/executor.cpp:763
#26 0x0000bff60acf34b0 in nvfuser::KernelExecutor::run (this=0xf2183404ea60, args=..., output_args=..., launch_constraints=..., compile_params=...)
    at /opt/pytorch/nvfuser/csrc/runtime/executor.cpp:1065

@jjsjann123
Collaborator Author

!test

// NOTE: this doesn't feel right, we have to mark contiguity on axis(0) as
// `false` to avoid accidental indexing collapsing; this should be figured out
// by indexing from the ceilDiv.
out->setAllocationDomain({out->axis(0), io, ii}, {false, true, true});
Collaborator

Am I understanding this issue correctly?
a) The tensor actually is contiguous with respect to this allocation domain, which has size M, ceilDiv(K, 16), 16.
b) The tensor winds up not being contiguous with respect to its logical domain, which is of size M, K, because the non-divisible split adds some padding to K.
c) By "indexing collapsing" you mean it does contiguous indexing so that stride is not part of the index? Is that wrong? It seems like indexing as a contiguous allocation is what we want here.

My question is what specifically goes wrong when allocation is set to contiguous?

Collaborator Author

Yes, you are absolutely correct about a) and b).

Indexing collapsing is wrong here, because we are mapping from logical to allocation, and that mapping does not access contiguous memory (because of the non-divisible split).
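To make that concrete with made-up sizes (M=4, K=50, split factor 16 — these numbers are hypothetical, not from the test): the allocation pads each logical row to ceilDiv(K, 16) * 16 slots, so collapsing the logical index into a single contiguous offset lands inside the padding as soon as you leave row 0.

```python
import math

M, K, F = 4, 50, 16                 # hypothetical sizes; 16 does not divide 50
padded_K = math.ceil(K / F) * F     # 64 slots per row after the split

def strided_index(m, k):
    # Correct logical->allocation mapping: each logical row occupies
    # padded_K slots, the last padded_K - K of which are padding.
    return m * padded_K + k

def collapsed_index(m, k):
    # What contiguous index collapsing would compute: m * K + k,
    # as if the allocation had no padding at all.
    return m * K + k

assert strided_index(0, 10) == collapsed_index(0, 10)  # agree within row 0
assert strided_index(1, 0) == 64    # row 1 starts after the padded row
assert collapsed_index(1, 0) == 50  # collapsed offset points into row 0's padding
```

That is why axis(0)'s contiguity is marked `false` here: it forces the generated index to keep the row stride instead of folding it away.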

This is the before and after of the indexing.

with false contiguity flag

root@812ada01cb39:/opt/pytorch/nvfuser# NVFUSER_DUMP=cuda_kernel ./bin/test_layout_op --gtest_filter="*LogicalAndAllocationSizes"
Running main() from /opt/pytorch/nvfuser/third_party/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *LogicalAndAllocationSizes
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from LayoutOpTest
[ RUN      ] LayoutOpTest.LogicalAndAllocationSizes

======= Codegen output for kernel: nvfuser_pointwise_f0_c1_r0_g0 =======

// Codegen generated code
__global__ void nvfuser_pointwise_f0_c1_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 2, 3> T1) {
  nvfuser_index_t i0;
  i0 = ((nvfuser_index_t)threadIdx.x) + (128LL * ((nvfuser_index_t)blockIdx.x));
  nvfuser_index_t i1;
  i1 = i0 % T0.logical_size[1LL];
  nvfuser_index_t i2;
  i2 = i0 / T0.logical_size[1LL];
  if ((i0 < (T0.logical_size[0LL] * T0.logical_size[1LL]))) {
    Array<float, 1LL, 1> T2;
    T2[0LL] = 0LL;
    T2[0LL]
       = T0[((T0.alloc_stride[0LL] * i2) + (T0.alloc_stride[1LL] * i1))];
    Array<float, 1LL, 1> T3;
    T3[0LL]
       = T2[0LL];
    T1[(i1 + (T1.alloc_stride[0LL] * i2))]
       = T3[0LL];
  }
}

======================================

[       OK ] LayoutOpTest.LogicalAndAllocationSizes (966 ms)
[----------] 1 test from LayoutOpTest (966 ms total)

with true contiguity flag

root@558d9dfeefb8:/opt/pytorch/nvfuser# NVFUSER_DUMP=cuda_kernel ./bin/test_layout_op --gtest_filter="*LogicalAndAllocationSizes"
Running main() from /opt/pytorch/nvfuser/third_party/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *LogicalAndAllocationSizes
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from LayoutOpTest
[ RUN      ] LayoutOpTest.LogicalAndAllocationSizes

======= Codegen output for kernel: nvfuser_pointwise_f0_c1_r0_g0 =======

// Codegen generated code
__global__ void nvfuser_pointwise_f0_c1_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 2, 3> T1) {
  nvfuser_index_t i0;
  i0 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x));
  if ((i0 < (T0.logical_size[0LL] * T0.logical_size[1LL]))) {
    Array<float, 1, 1> T2;
    T2[0] = 0;
    T2[0]
       = T0[((T0.alloc_stride[0LL] * (i0 / T0.logical_size[1LL])) + (T0.alloc_stride[1LL] * (i0 % T0.logical_size[1LL])))];
    Array<float, 1, 1> T3;
    T3[0]
       = T2[0];
    T1[i0]
       = T3[0];
  }
}

======================================

/opt/pytorch/nvfuser/tests/cpp/test_layout_op.cpp:128: Failure
Value of: t0.equal(cg_outputs[0].as<at::Tensor>().slice(1, 0, k))
  Actual: false
Expected: true
