Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
fe90fb5
PR0: Relax assert on non-device split on allocation domain
jjsjann123 Sep 18, 2025
a0df5e9
relaxing the check
jjsjann123 Sep 18, 2025
5097533
Adding test validating vectorization
jjsjann123 Sep 18, 2025
d4b7c8b
renaming
jjsjann123 Sep 19, 2025
4b07e79
clangformat
jjsjann123 Sep 19, 2025
051fc9e
I think it's working now!
jjsjann123 Sep 19, 2025
bf85c0b
clangformat
jjsjann123 Sep 19, 2025
6ff1050
quick patch
jjsjann123 Sep 22, 2025
f2f43be
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Sep 22, 2025
a32f54b
fix clearing allocation domain on cache for cacheBefore
jjsjann123 Sep 22, 2025
cf6e609
revert changes
jjsjann123 Sep 23, 2025
b303923
updating tests
jjsjann123 Sep 23, 2025
17dbf23
i was dumb as always
jjsjann123 Sep 23, 2025
c2a3aeb
why is it so hard for me
jjsjann123 Sep 23, 2025
1a156be
Apply suggestions from code review
jjsjann123 Sep 23, 2025
2081d0c
clangformat
jjsjann123 Sep 23, 2025
6f674aa
Merge branch 'main' into jj/allocation_PR_0
jjsjann123 Sep 23, 2025
4f8ecfc
Merge remote-tracking branch 'origin/main' into jj/allocation_PR_0
jjsjann123 Sep 26, 2025
ee37038
reverting selfReplay & cacheBefore changes per reviewer's comments
jjsjann123 Sep 26, 2025
f87e99d
wip
jjsjann123 Sep 26, 2025
c5155ff
wip
jjsjann123 Sep 26, 2025
ded16ec
wip
jjsjann123 Sep 26, 2025
f02440c
wip
jjsjann123 Sep 26, 2025
aa084bc
errr zip
jjsjann123 Sep 26, 2025
f82ad1f
wip
jjsjann123 Sep 26, 2025
d9a33d8
err, forgot to push something last night
jjsjann123 Sep 26, 2025
6dda5e2
typo
jjsjann123 Sep 26, 2025
173a7e9
skipping checks
jjsjann123 Sep 26, 2025
98654a0
wip
jjsjann123 Sep 26, 2025
a870f9d
relaxing checks in tests
jjsjann123 Sep 26, 2025
bca1734
wip
jjsjann123 Sep 26, 2025
d91ac03
clean up IDs for cacheBefore
jjsjann123 Sep 26, 2025
eff3069
clear up definition of output TV for cacheBefore
jjsjann123 Sep 26, 2025
2105e1e
fixing one alias test!
jjsjann123 Sep 27, 2025
bdaaccb
wip
jjsjann123 Sep 27, 2025
fdf9dba
fixing definition
jjsjann123 Sep 27, 2025
65022bd
wip
jjsjann123 Sep 27, 2025
5431648
not set allocation domain when original output doesn't have it
jjsjann123 Sep 27, 2025
75b06b5
update output itertype
jjsjann123 Sep 27, 2025
c2bf4cf
wip
jjsjann123 Sep 27, 2025
c5d66b6
wip
jjsjann123 Sep 27, 2025
b1836f5
wip
jjsjann123 Sep 27, 2025
7ee9317
fixing contiguity in fullselfreplay
jjsjann123 Sep 27, 2025
f7bbab2
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Sep 27, 2025
afccea0
fixing transpose tests
jjsjann123 Sep 27, 2025
482afc8
set parallelization type after fullselfreplay
jjsjann123 Sep 27, 2025
255055d
fix mark alias
jjsjann123 Oct 1, 2025
599d809
fixing alias analysis
jjsjann123 Oct 1, 2025
eadf148
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Oct 24, 2025
fe9c1f6
quick patch on nvfuser::schedule_matmul::Common::cacheBefore
jjsjann123 Oct 24, 2025
f927809
quick patch on nvfuser::schedule_matmul::Common::updateIdModel
jjsjann123 Oct 25, 2025
493434a
agent you can do better!
jjsjann123 Oct 25, 2025
92fb6f9
err
jjsjann123 Oct 25, 2025
642d9a8
Merge branch 'main' into jj/allocation_PR_0
jjsjann123 Oct 28, 2025
10eb4e0
Merge branch 'main' into jj/allocation_PR_0
jjsjann123 Oct 30, 2025
6ff66f2
try self replay so allocation domain is preserved for multi device
jjsjann123 Oct 30, 2025
2950054
err revert something that's not working
jjsjann123 Oct 30, 2025
a30432c
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Nov 12, 2025
b1e7352
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Nov 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions csrc/alias_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ bool okToRelayout(
const TensorView* tv,
const Layout& new_layout,
const EmptyAllocationAs empty_allocation_as) {
// we can merge this with the one below
// when using logical domain as the allocation domain, we can basically ignore the layout when it's not used by codegen
if (empty_allocation_as == EmptyAllocationAs::kLogical && !ir_utils::canUsePresetAllocationDomain(tv, false)) {
return true;
}

if (empty_allocation_as == EmptyAllocationAs::kUndetermined &&
!tv->hasAllocation()) {
return true;
Expand Down
3 changes: 3 additions & 0 deletions csrc/device_lower/validation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,9 @@ class VectorizeValidator : public OptInDispatch {
TensorView* tv,
std::string name,
int64_t vector_word_size_bit) {
if (!ir_utils::canUsePresetAllocationDomain(tv)) {
return;
}
// aten_element_size_bit is the minimum unit (one element) of tv's
// corresponding at::Tensor. It may or may not be the same as
// dataTypeSizeBit(tv->dtype()), because we support non-ATen data types as
Expand Down
5 changes: 5 additions & 0 deletions csrc/ir/internal_base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,11 @@ class NVF_API IterDomain : public Val {
return getIterType() == IterType::Iteration;
}

//! Overwrites the rfactor-product flag of this IterDomain and returns
//! `this` so calls can be chained. With no argument the flag is cleared,
//! which is the common case when an ID is reused in a domain where it no
//! longer acts as an rfactor product (e.g. during cacheBefore).
IterDomain* resetRFactorProduct(bool is_rfactor_domain = false) {
  is_rfactor_domain_ = is_rfactor_domain;
  return this;
}

bool isRFactorProduct() const {
return is_rfactor_domain_;
}
Expand Down
52 changes: 52 additions & 0 deletions csrc/ir/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1747,4 +1747,56 @@ bool isParallelizedBy(const std::vector<IterDomain*>& ids, ParallelType pt) {
ids, [&](IterDomain* id) { return id->getParallelType() == pt; });
}

//! Returns true when the allocation domain explicitly set on `tv` should be
//! honored by downstream code (lowering/validation/alias analysis).
//!
//! \param tv the tensor whose preset allocation domain is being queried
//! \param ignore_empty_alloc when true, a tensor without an explicitly set
//!        allocation domain immediately returns false
bool canUsePresetAllocationDomain(
    const TensorView* tv,
    bool ignore_empty_alloc) {
  // Without an explicitly set allocation domain there is nothing to honor.
  if (ignore_empty_alloc && !tv->hasAllocation()) {
    return false;
  }
  // Honor the allocation domain if the tensor is global or Hopper MMA's
  // output. Note: fusion inputs have no definition, so guard before
  // dereferencing definition().
  if (tv->getMemoryType() == MemoryType::Global ||
      (tv->definition() != nullptr && tv->definition()->isA<MmaOp>() &&
       isHopper(tv->definition()->as<MmaOp>()->macro()))) {
    return true;
  }
  // If it's a shared memory tensor, the set domain is likely
  // valid if Swizzle or Bulk is used. Also, if the allocation
  // domain is just a permutation of the loop domain, use the
  // set allocation domain. This seems to happen only with
  // AllocationDomainTest.TransposedIntermediate.
  if (tv->getMemoryType() == MemoryType::Shared) {
    if (std::ranges::any_of(
            tv->getAllocationDomain(),
            [](IterDomain* allocation_domain) {
              // definition() may be null here; dynamic_cast on nullptr is
              // safe and yields nullptr.
              return dynamic_cast<Swizzle*>(
                         allocation_domain->definition()) != nullptr ||
                  allocation_domain->getParallelType() == ParallelType::Bulk;
            }) ||
        std::is_permutation(
            tv->getLoopDomain().begin(),
            tv->getLoopDomain().end(),
            tv->getAllocationDomain().begin(),
            tv->getAllocationDomain().end())) {
      return true;
    }

    // Honor the set allocation domain if the tensor is used by a
    // TMA store or MmaOp
    if (std::ranges::any_of(tv->uses(), [](Expr* expr) {
          return ir_utils::isCpAsyncBulkStore(expr) || expr->isA<MmaOp>();
        })) {
      return true;
    }

    // If a shared memory output produced by scatter has an
    // allocation domain explicitly set, it's likely to be the
    // valid allocation domain.
    if (auto def = tv->definition();
        def != nullptr && def->isA<ScatterOp>()) {
      return true;
    }
  }
  return false;
}

} // namespace nvfuser::ir_utils
2 changes: 2 additions & 0 deletions csrc/ir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -861,4 +861,6 @@ std::vector<IterDomain*> propagateScatterAllocationDomain(

bool isParallelizedBy(const std::vector<IterDomain*>& ids, ParallelType pt);

bool canUsePresetAllocationDomain(const TensorView* tv, bool ignore_empty_alloc=true);

} // namespace nvfuser::ir_utils
23 changes: 17 additions & 6 deletions csrc/scheduler/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,21 @@ void Common::updateIdModel() {
// IdModel
std::unordered_map<ValGroup, MatmulDimRole> new_id_roles;
for (auto& [k, v] : id_roles_) {
const ValGroup& new_group = new_graph.toGroup(k->front());
new_id_roles.emplace(new_group, v);
// We need to traverse the ValGroup to find the remaining IDs that survive in the new id_model. This is because cacheBefore may have eliminated the reduction ID.

// e.g.
// output [m4, n5, rk6] = mma(A [m0, k1], B [n2, k3])
// would become
// cache [m6, n7, rk8] = mma(A [m0, k1], B [n2, k3])
// output [m4, n5 ] = set(cache [m6, n7, rk8])
//
// So the old role rK6 wouldn't be mapped in new_graph.
auto old_vg = std::ranges::find_if(k->vector(), [&new_graph](Val* vg){return new_graph.hasGroup(vg);});
NVF_ERROR(
old_vg != k->vector().end(),
"Old ValGroup not found in new ValGraph"
);
new_id_roles.emplace(new_graph.toGroup(*old_vg), v);
}
id_roles_ = new_id_roles;
}
Expand Down Expand Up @@ -290,10 +303,8 @@ TensorView* Common::cacheBefore(TensorView* orig, LoadStoreOpType op_type) {
const std::vector<IterDomain*> cache_logical = c->getLogicalDomain();
NVF_ERROR(orig_logical.size() == cache_logical.size());
for (size_t i : arange(orig_logical.size())) {
// The domain of orig gets transferred to c and a new domain is applied to
// orig
ValGroup vg = graph_->toGroup(cache_logical[i]);
graph_->initializeVal(orig_logical[i], vg);
ValGroup vg = graph_->toGroup(orig_logical[i]);
graph_->initializeVal(cache_logical[i], vg);
}

return c;
Expand Down
108 changes: 99 additions & 9 deletions csrc/tensor_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1128,18 +1128,105 @@ TensorView* TensorView::cacheBefore(LoadStoreOpType op_type) {
"before computeAt.");
}

// Create Producer Domain
// This domain will be the consumer which needs a new domain, so replace the
// producers domain with this domain.

auto* producer = IrBuilder::createInContainer<TensorView>(
container(),
IrBuilder::createInContainer<TensorDomain>(container(), domain()),
getDataType().value());
// TODO: 1. test reshape; 2. test reduction
// We want the producer domain to preserve `root` & `logical`
// meanwhile, we want consumer Tensor to preserve `logical` & `allocation` (while erasing all reductions).

TensorView* producer;

if (definition()->isA<ScatterOp>()) {
// TODO: is there any way to replay a scatter op?!
// scatter output's loop is not connected to its root.
NVF_ERROR(!domain()->hasRoot(), "scatter output's root is not replayed in cacheBefore");

std::vector<IterDomain*> logical;
std::vector<IterDomain*> loop;
std::unordered_map<IterDomain*, IterDomain*> map_cloned_ids;

std::ranges::transform(domain()->logical(), std::back_inserter(logical), [&](IterDomain* id) {
IterDomain* cloned_id = IrBuilder::createInContainer<IterDomain>(container(), id);
map_cloned_ids[id] = cloned_id;
return cloned_id;
});
std::ranges::transform(domain()->loop(), std::back_inserter(loop), [&](IterDomain* id) {
if (auto it = map_cloned_ids.find(id); it != map_cloned_ids.end()) {
// reuse cloned_ids
return it->second;
}
return IrBuilder::createInContainer<IterDomain>(container(), id);
});
producer = IrBuilder::createInContainer<TensorView>(
container(),
IrBuilder::createInContainer<TensorDomain>(container(), logical, loop, TensorDomain::getContiguityFilledWith(logical, true), /*skip_loop_validation=*/true),
getDataType().value());
// TODO: we are not replaying the loop domain from consumer to producer, is that the right thing to do?!
} else {
// Create Producer Domain
// We only need root for full self replay.
std::vector<IterDomain*> root;
std::ranges::transform(domain()->hasRoot()?domain()->root():domain()->logical(), std::back_inserter(root), [&](IterDomain* id) {
return IrBuilder::createInContainer<IterDomain>(container(), id);
});

producer = IrBuilder::createInContainer<TensorView>(
container(),
IrBuilder::createInContainer<TensorDomain>(container(), root, root, root, TensorDomain::getContiguityFilledWith(root, true)),
getDataType().value());
// replay from `root`->`loop` on producer
producer->setDomain(TransformReplay::fullSelfReplay(producer->domain(), domain()));
}

// clean up consumer domain to wipe out root and all reduction IDs
std::vector<IterDomain*> logical_dom;
std::vector<IterDomain*> alloc_dom;
std::vector<IterDomain*> loop_dom;
std::vector<std::optional<bool>> contiguity;

// NOTE: I need to clear definition otherwise BestEffortReplay will not work with dangling sources
// create an issue for this, use the example from ./bin/test_nvfuser --gtest_filter="PointwiseTest.VectorizeWithBroadcastAndReshape1"
// copy non-reduction IDs onto logical and loop
std::ranges::copy_if(
domain()->logical() | std::views::transform([](IterDomain* id) { id->setDefinition(nullptr); return id->resetRFactorProduct(); }),
std::back_inserter(logical_dom),
[](IterDomain* id) {return !id->isReduction();});
if (definition()->isA<ScatterOp>()) {
// NOTE: this doesn't feel right. we would still want to replay the loop domain
// we are basically dropping transformations on loop domain for scatter op during cacheBefore
loop_dom = logical_dom;
} else {
std::ranges::copy_if(
domain()->loop() | std::views::transform([](IterDomain* id) { return id->resetRFactorProduct(); }),
std::back_inserter(loop_dom),
[](IterDomain* id) {return !id->isReduction();});
}
for (auto&& [id, c] : zip(domain()->hasAllocation() ? domain()->allocation() : domain()->logical(), domain()->contiguity())) {
if (id->isReduction()) {
continue;
}
id->resetRFactorProduct();
if (domain()->hasAllocation()) {
alloc_dom.push_back(id);
}
contiguity.push_back(c);
}
// TODO: We also need to clear all rfactor across IDs between logical->loop and logical->allocation.

// Set domain of consumer
TensorView* consumer = this;

consumer->setDomain(IrBuilder::createInContainer<TensorDomain>(
container(),
std::vector<IterDomain*>{},
logical_dom,
alloc_dom,
loop_dom,
contiguity));

// TODO: figure out scatter special handling.
// if (!producer->definition()->isA<ScatterOp>()) {
// } else {
// }

/* FIXME
std::vector<IterDomain*> new_logical_domain;
new_logical_domain.reserve(getLogicalDomain().size());
for (IterDomain* dom : getLogicalDomain() | TensorDomain::kNoReductions) {
Expand All @@ -1152,6 +1239,7 @@ TensorView* TensorView::cacheBefore(LoadStoreOpType op_type) {
container(),
new_logical_domain,
TensorDomain::getContiguityFilledWith(new_logical_domain, true)));
*/

// Insert producer - Cache_Before (CB) - before this TV.
// Before: Prev TV -> [Definition Op] -> This TV
Expand All @@ -1170,6 +1258,7 @@ TensorView* TensorView::cacheBefore(LoadStoreOpType op_type) {
// definition_ is no longer valid
// setDefinition(nullptr);

/* FIXME
// We do not want to reproduce the loop domain if it's for
// scatter. Recall that the loop domain of the scatter op is derived
// from the logical domain of the scatter index tensor. Here, the
Expand All @@ -1186,6 +1275,7 @@ TensorView* TensorView::cacheBefore(LoadStoreOpType op_type) {
producer, consumer->getLogicalDomain()),
true);
}
*/

if (consumer->hasDeviceMesh()) {
producer->setDeviceMesh(consumer->getDeviceMesh());
Expand Down
14 changes: 11 additions & 3 deletions csrc/transform_replay.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ class ReplaySelf : public ReplayTransformations {
s->outer()->getIterType(),
s->inner()->getIterType());

// Parallelize type could include device from split.
ido->parallelize(s->outer()->getParallelType());
idi->parallelize(s->inner()->getParallelType());
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wujingyue tagging you to try this guy out.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks -- this will definitely help my #5229!

At this moment, I can't really take this two-line change because at head cacheBefore still uses TransformReplay::replayCasP not fullSelfReplay. However, not pressure! I'll come back to #5229 after you make more progress on this.


// Remove mapped id from loop IDs
loop_ids_.erase(mapped);

Expand Down Expand Up @@ -106,7 +110,7 @@ class ReplaySelf : public ReplayTransformations {
id_inner_mapped,
" however one or both are not loop nodes.");

IterDomain* merged_id = IterDomain::merge(id_outer_mapped, id_inner_mapped);
IterDomain* merged_id = IterDomain::merge(id_outer_mapped, id_inner_mapped, m->out()->isRFactorProduct());

// Remove inputs from the loop IDs
loop_ids_.erase(id_outer_mapped);
Expand Down Expand Up @@ -147,11 +151,15 @@ class ReplaySelf : public ReplayTransformations {
// output domain also an rfactor
const auto resize_out_logical = resize->out()->isRFactorProduct();

// Mark output IterType
const auto resize_out_iter_type = resize->out()->getIterType();

auto replayed_out = IterDomain::resize(
mapped,
resize->leftExpand(),
resize->rightExpand(),
resize_out_logical);
resize_out_logical,
resize_out_iter_type);

loop_ids_.erase(mapped);

Expand Down Expand Up @@ -234,7 +242,7 @@ TensorDomain* TransformReplay::fullSelfReplay(
new_self_root->root(),
new_logical_domain,
new_domain,
self->contiguity());
TensorDomain::getContiguityFilledWith(new_logical_domain, true));
}
}

Expand Down
2 changes: 0 additions & 2 deletions tests/cpp/test_allocation_domain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,6 @@ TEST_F(AllocationDomainTest, NHWC4d_To_NHWC4d_cacheBefore) {

ASSERT_EQ(tv0->getAllocationDomain(), tv0_nhwc);
ASSERT_EQ(tv1->getAllocationDomain(), expected_new_allocation_domain);
ASSERT_EQ(tv2->getAllocationDomain(), tv1_nhwc);

for (auto tv : {tv1, tv2}) {
// [N, C, H, W]
Expand Down Expand Up @@ -708,7 +707,6 @@ TEST_F(AllocationDomainTest, NHWC2d_To_NHWC2d_cacheBefore) {

ASSERT_EQ(tv0->getAllocationDomain(), tv0_2d);
ASSERT_EQ(tv1->getAllocationDomain(), expected_new_allocation_domain);
ASSERT_EQ(tv2->getAllocationDomain(), tv1_2d);

for (auto tv : {tv1, tv2}) {
tv->split(0, 128);
Expand Down
Loading
Loading