Draft

59 commits
fe90fb5
PR0: Relax assert on non-device split on allocation domain
jjsjann123 Sep 18, 2025
a0df5e9
relaxing the check
jjsjann123 Sep 18, 2025
5097533
Adding test validating vectorization
jjsjann123 Sep 18, 2025
d4b7c8b
renaming
jjsjann123 Sep 19, 2025
4b07e79
clangformat
jjsjann123 Sep 19, 2025
051fc9e
I think it's working now!
jjsjann123 Sep 19, 2025
bf85c0b
clangformat
jjsjann123 Sep 19, 2025
6ff1050
quick patch
jjsjann123 Sep 22, 2025
f2f43be
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Sep 22, 2025
a32f54b
fix clearing allocation domain on cache for cacheBefore
jjsjann123 Sep 22, 2025
cf6e609
revert changes
jjsjann123 Sep 23, 2025
b303923
updating tests
jjsjann123 Sep 23, 2025
17dbf23
i was dumb as always
jjsjann123 Sep 23, 2025
c2a3aeb
why is it so hard for me
jjsjann123 Sep 23, 2025
1a156be
Apply suggestions from code review
jjsjann123 Sep 23, 2025
2081d0c
clangformat
jjsjann123 Sep 23, 2025
6f674aa
Merge branch 'main' into jj/allocation_PR_0
jjsjann123 Sep 23, 2025
4f8ecfc
Merge remote-tracking branch 'origin/main' into jj/allocation_PR_0
jjsjann123 Sep 26, 2025
ee37038
reverting selfReplay & cacheBefore changes per reviewer's comments
jjsjann123 Sep 26, 2025
f87e99d
wip
jjsjann123 Sep 26, 2025
c5155ff
wip
jjsjann123 Sep 26, 2025
ded16ec
wip
jjsjann123 Sep 26, 2025
f02440c
wip
jjsjann123 Sep 26, 2025
aa084bc
errr zip
jjsjann123 Sep 26, 2025
f82ad1f
wip
jjsjann123 Sep 26, 2025
d9a33d8
err, forgot to push something last night
jjsjann123 Sep 26, 2025
6dda5e2
typo
jjsjann123 Sep 26, 2025
173a7e9
skipping checks
jjsjann123 Sep 26, 2025
98654a0
wip
jjsjann123 Sep 26, 2025
a870f9d
relaxing checks in tests
jjsjann123 Sep 26, 2025
bca1734
wip
jjsjann123 Sep 26, 2025
d91ac03
clean up IDs for cacheBefore
jjsjann123 Sep 26, 2025
eff3069
clear up definition of output TV for cacheBefore
jjsjann123 Sep 26, 2025
2105e1e
fixing one alias test!
jjsjann123 Sep 27, 2025
bdaaccb
wip
jjsjann123 Sep 27, 2025
fdf9dba
fixing definition
jjsjann123 Sep 27, 2025
65022bd
wip
jjsjann123 Sep 27, 2025
5431648
not set allocation domain when original output doesn't have it
jjsjann123 Sep 27, 2025
75b06b5
update output itertype
jjsjann123 Sep 27, 2025
c2bf4cf
wip
jjsjann123 Sep 27, 2025
c5d66b6
wip
jjsjann123 Sep 27, 2025
b1836f5
wip
jjsjann123 Sep 27, 2025
7ee9317
fixing contiguity in fullselfreplay
jjsjann123 Sep 27, 2025
f7bbab2
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Sep 27, 2025
afccea0
fixing transpose tests
jjsjann123 Sep 27, 2025
482afc8
set parallelization type after fullseflreplay
jjsjann123 Sep 27, 2025
255055d
fix mark alias
jjsjann123 Oct 1, 2025
599d809
fixing alias analysis
jjsjann123 Oct 1, 2025
eadf148
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Oct 24, 2025
fe9c1f6
quick patch on nvfuser::schedule_matmul::Common::cacheBefore
jjsjann123 Oct 24, 2025
f927809
quick patch on nvfuser::schedule_matmul::Common::updateIdModel
jjsjann123 Oct 25, 2025
493434a
agent you can do better!
jjsjann123 Oct 25, 2025
92fb6f9
err
jjsjann123 Oct 25, 2025
642d9a8
Merge branch 'main' into jj/allocation_PR_0
jjsjann123 Oct 28, 2025
10eb4e0
Merge branch 'main' into jj/allocation_PR_0
jjsjann123 Oct 30, 2025
6ff66f2
try self replay so allocation domain is preserved for multi device
jjsjann123 Oct 30, 2025
2950054
err revert something that's not working
jjsjann123 Oct 30, 2025
a30432c
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Nov 12, 2025
b1e7352
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Nov 14, 2025
35 changes: 18 additions & 17 deletions csrc/multidevice/utils.cpp
@@ -776,21 +776,16 @@ std::unordered_set<TensorView*> getTvsWithDifferentSharding(
return ret;
}

void validateDeviceSplit(Expr* expr) {
NVF_ERROR(expr != nullptr, "Expected a valid expression.");
auto* split = dynamic_cast<Split*>(expr);
NVF_ERROR(
split != nullptr,
"Only split expressions are supported for producing device ids: ",
expr->toString());
NVF_ERROR(
split->outer()->isDeviceDim(),
"Expected the outer dimension to be a device dimension: ",
expr->toString());
NVF_ERROR(
!split->innerSplit(),
"Inner split by device dimension is not supported: ",
expr->toString());
bool isValidateDeviceSplit(Expr* expr) {
if (expr == nullptr || !expr->isA<Split>()) {
return false;
}
auto* split = expr->as<Split>();
if (split == nullptr || !split->outer()->isDeviceDim() ||
split->innerSplit()) {
return false;
}
return true;
}

IterDomain* projectShardedAllocationToLogical(
@@ -806,7 +801,10 @@ IterDomain* projectShardedAllocationToLogical(

IterDomain* logical_id = allocation_id;
for (Expr* expr : exprs | std::views::reverse) {
validateDeviceSplit(expr);
NVF_ERROR(
isValidateDeviceSplit(expr),
"invalid device split: ",
expr->toString());
logical_id = expr->as<Split>()->in();
}
return logical_id;
@@ -825,7 +823,10 @@ IterDomain* projectLogicalToShardedAllocation(
tv->getMaybeAllocationDomain().end()});
IterDomain* allocation_id = logical_id;
for (auto expr : exprs) {
validateDeviceSplit(expr);
NVF_ERROR(
isValidateDeviceSplit(expr),
"invalid device split: ",
expr->toString());
allocation_id = expr->as<Split>()->inner();
}
return allocation_id;
2 changes: 1 addition & 1 deletion csrc/multidevice/utils.h
@@ -169,7 +169,7 @@ std::vector<int64_t> unshardedSizes(

// Validate the expression is a valid DID split: expr is an outer split with
// device dim as the outer dimension.
void validateDeviceSplit(Expr* expr);
bool isValidateDeviceSplit(Expr* expr);

// Find the producing logical id of the given allocation id traversing
// through device splits. For unsharded allocation_id, logical_id is the same as
11 changes: 10 additions & 1 deletion csrc/scheduler/vectorize_helper.cpp
@@ -807,13 +807,22 @@ Val* ContiguousInnerDimensionsMapper::getContigMergeOfInnerSize(
{alloc_iid});
IterDomain* logical_id = alloc_iid;
Val* num_devices = of_tv->container()->oneVal();
bool only_valid_device_split = true;
for (Expr* expr : exprs | std::views::reverse) {
validateDeviceSplit(expr);
if (!isValidateDeviceSplit(expr)) {
only_valid_device_split = false;
break;
}
auto* split = expr->as<Split>();
logical_id = split->in();
num_devices = SimplifyingIrBuilder::mulExpr(num_devices, split->factor());
}

// A non-device split could lead to padding, which prevents vectorization
Collaborator:
Can we have a non-device split if not for padding? Should we throw an error here so we do not end up with random transforms?

Collaborator Author:
I might have missed your suggestion.

If I throw on seeing a split here, then we wouldn't be able to support padding via transformations on the allocation domain.

Collaborator:
Sorry about the confusion. To clarify, I was wondering if we can assert that a non-device split is in fact padding, and not a random transform. For example, if it is a divisible split, we can include it in vectorization, correct?

Collaborator Author:
Gotcha.

I think a split on the allocation domain is meant for padding;
vectorization should be handled in the loop domain.

A more complicated case is if we use split and permute on the allocation domain to represent blocking, i.e. in some sense it is indeed used to facilitate vectorization/TMA. But conceptually I think they are still different.

Going back to the topic of divisible vs. non-divisible splits: we would need to specialize during concretization if we are to distinguish them (assuming dynamic shape). I tend to think of that as more of an optimization.
Without that assert, we are leaving margin in the max vectorization factor, which I think isn't too big a deal. 🤞

if (!only_valid_device_split) {
break;
}

// Mapping order isn't correct, cannot expand vectorization dimension.
if (projected_dims[--projected_dims_i] != logical_id) {
break;
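A minimal standalone model of the relaxed allocation-to-logical walk sketched in the diff above (plain C++, not nvfuser code; `SplitInfo` and `foldDeviceSplits` are illustrative names over a simplified split representation): device splits fold into the device count, and the first non-device split, e.g. a padding split on the allocation domain, simply ends the walk instead of asserting, which only leaves some margin on the maximum vectorization factor.

```cpp
// Simplified standalone model of the relaxed allocation->logical walk.
// Not nvfuser code: SplitInfo and foldDeviceSplits are illustrative only.
#include <cassert>
#include <cstdint>
#include <vector>

struct SplitInfo {
  bool is_device_split;  // outer split whose outer output is a device dimension
  int64_t factor;
};

// Folds device-split factors into a device count; stops at the first
// non-device split (e.g. a padding split) instead of erroring out.
int64_t foldDeviceSplits(const std::vector<SplitInfo>& allocation_to_logical) {
  int64_t num_devices = 1;
  for (const SplitInfo& split : allocation_to_logical) {
    if (!split.is_device_split) {
      break;  // padding/blocking split: just stop, leaving vectorization margin
    }
    num_devices *= split.factor;
  }
  return num_devices;
}

int main() {
  // A DID-sharded id: a single outer device split by 2 -> device count 2.
  assert(foldDeviceSplits({{true, 2}}) == 2);
  // A padded id: split(1, 16) on the allocation domain is not a device split,
  // so the walk stops right away instead of triggering an assert.
  assert(foldDeviceSplits({{false, 16}}) == 1);
  return 0;
}
```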
64 changes: 64 additions & 0 deletions tests/cpp/test_layout_op.cpp
@@ -70,6 +70,70 @@
}
};

TEST_F(LayoutOpTest, LogicalAndAllocationSizes) {
Collaborator:
What is being tested here?

Collaborator Author:
Without the relaxation in the vectorization analysis, this test would trigger an assert.

So the test just verifies that we now allow a split on the allocation domain.
In the follow-up PR, we add more validation to this test to check that the produced tensor matches the logical sizes.

auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
FusionGuard fg(&fusion);

auto inp = makeSymbolicTensor(2);
fusion.addInput(inp);
auto out = set(inp);
fusion.addOutput(out);
// padding output to multiple of 16
out->split(1, 16);
out->setAllocationDomain(out->getLoopDomain(), true);
// restore loop domain
out->merge(1);
Collaborator:
This doesn't restore. Is this necessary?

Collaborator Author:
Touché. It un-splits the loop domain so that it has the same size as the logical domain.
You are right that the extent is no longer the same, so it's not a restoration.

Schedulers expect an un-scheduled fusion. Without this merge, I'm hitting this assert:

NVF_ERROR(broadcast_bit_multiples.size() == ref_loop.size());

Collaborator:
Hmm, not sure that's a good enough WAR, though this is just a test.

I thought the schedulers could work with some scheduled loop domains (for DID parallelization), no?

Collaborator Author:
// We always cacheBefore output at the beginning of the scheduling. And after
// cacheBefore, the reference tensor will have all reduction IDs removed.
ref_loop = TensorDomain::noDevices(TensorDomain::noReductions(ref_loop));

DID-related IDs are just ignored by the scheduler, so that handling is too specific to multi-device.

I'm not a fan of this either. Let me see if I can skip messing with the loop domain and apply the transformation on the allocation domain directly.

Collaborator:
I suppose you can just modify the allocation domain with AbstractTensor. I remember there are some tests.

Collaborator Author:
I can also directly use IterDomain::split for that.

Anyway, it looks like if the transformation is not on the logical-to-loop path, our replay won't pick it up. It feels similar to the allocation domain replay that rfactor was missing. FYI @Priya2698

#0  nvfuser::nvfCheckFail (func=0xaaaaac218080 "validateDomainEquivalence",
    file=0xaaaaac216938 "/opt/pytorch/nvfuser/csrc/ir/utils.cpp", line=1162,
    msg=" INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:1162, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. \nExpected !compare_result.dom0_has_u"...) at /opt/pytorch/nvfuser/csrc/exceptions.cpp:267
#1  0x0000aaaaab1bbe68 in nvfuser::nvfErrorFail (func=0xaaaaac218080 "validateDomainEquivalence",
    file=0xaaaaac216938 "/opt/pytorch/nvfuser/csrc/ir/utils.cpp", line=1162,
    condMsg=0xaaaaac217fd8 " INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:1162, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. ",
    userMsg="Expected !compare_result.dom0_has_unreachable_ids . dom0 has unreachable IDs. dom0: iS10{i0}, iS11{i2}. dom1: iS10{i0}") at /opt/pytorch/nvfuser/csrc/exceptions.cpp:277
#2  0x0000aaaaab60a3e8 in nvfuser::ir_utils::validateDomainEquivalence (
    dom0=std::vector of length 2, capacity 2 = {...}, dom1=std::vector of length 1, capacity 3 = {...},
    additional_ids=std::vector of length 0, capacity 0) at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:1162
#3  0x0000aaaaab4aac30 in nvfuser::TensorDomain::setAllocationDomain (this=0xaaaab20918b0,
    new_allocation_domain=std::vector of length 1, capacity 3 = {...},
    new_contiguity=std::vector of length 1, capacity 3 = {...})
    at /opt/pytorch/nvfuser/csrc/ir/nodes.cpp:4055
#4  0x0000aaaaabc7b368 in nvfuser::TransformReplay::replayCasP (consumer=0xaaaab2088c00,
    producer=0xaaaab2091200, producer_pos=2, logical_map=..., opt=...)
    at /opt/pytorch/nvfuser/csrc/transform_replay.cpp:917
#5  0x0000aaaaabc7b7fc in nvfuser::TransformReplay::replayCasP (consumer=0xaaaab2088c00,
    producer=0xaaaab2091200, compute_at_axis=-1, opt=...)
    at /opt/pytorch/nvfuser/csrc/transform_replay.cpp:945
#6  0x0000aaaaabc44ccc in nvfuser::TensorView::cacheBefore (this=0xaaaab2088c00,
    op_type=nvfuser::LoadStoreOpType::Set) at /opt/pytorch/nvfuser/csrc/tensor_view.cpp:1160
#7  0x0000aaaaabbdb250 in nvfuser::scheduler_utils::cacheAndForkOutputs (fusion=0xaaaab2084910,
    unroll=true) at /opt/pytorch/nvfuser/csrc/scheduler/utils.cpp:1357
#8  0x0000aaaaabb067dc in nvfuser::schedulePointwise (fusion=0xaaaab2084910, pparams=0xaaaab207f880)
    at /opt/pytorch/nvfuser/csrc/scheduler/pointwise.cpp:822
#9  0x0000aaaaabb0898c in nvfuser::PointWiseScheduler::schedule (this=0xaaaab2083460,
    fusion=0xaaaab2084910, params=0xaaaab207f880)
    at /opt/pytorch/nvfuser/csrc/scheduler/pointwise.cpp:1304

Collaborator:
So, what did you decide to do? Nothing seems to have changed?

> I can also directly use IterDomain::split for that.

Of course, but you'd need to maintain the proper ordering of the ID vector yourself.

Collaborator:
> I can also directly use IterDomain::split for that.
>
> Anyway, it looks like if the transformation is not on the logical-to-loop path, our replay won't pick it up. It feels similar to the allocation domain replay that rfactor was missing. FYI @Priya2698

Yes, rfactor replay for allocation will also complain similarly if the allocation transforms are disjoint from root-to-loop.
replayPasC also uses the loop domain as the target, so if you intend to use IterDomain::split, we will have to update that, among other things.

Collaborator Author:
Yep, switched to selfReplay instead of replayCasP for TensorView::cacheBefore.


auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
int m = 512;
int k = 9; // note: padded column size would be 16
auto t0 = at::randn({m, k}, options);

FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto cg_outputs = executor_cache.runFusionWithInputs({t0});
// padding on the inner dimension is represented as stride on the outer
// dimension
EXPECT_EQ(
cg_outputs[0].as<at::Tensor>().strides(), std::vector<int64_t>({16, 1}));
// We need to slice because output buffer shape is not right
EXPECT_TRUE(t0.equal(cg_outputs[0].as<at::Tensor>().slice(1, 0, k)));
// TODO: enable this when output buffer shape is fixed.
// output should remain the correct logical size
// EXPECT_EQ(
// cg_outputs[0].as<at::Tensor>().sizes(), std::vector<int64_t>({512,
// 9}));
}
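A quick standalone check of the stride expectation above (plain C++, not nvfuser code; `ceilDiv` here is a local helper, not the library's): with k = 9 padded by the `split(1, 16)` to an allocation row of 16 elements, element (i, j) of the output sits at linear offset i * 16 + j, which is exactly strides {16, 1} over the [512, 9] logical view.

```cpp
// Standalone arithmetic check (not nvfuser code) for the {16, 1} stride
// expectation in LogicalAndAllocationSizes: the inner extent of 9 is padded
// to ceilDiv(9, 16) * 16 = 16 elements per row.
#include <cassert>
#include <cstdint>

int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const int64_t k = 9;
  const int64_t factor = 16;
  const int64_t padded_row = ceilDiv(k, factor) * factor;  // 16
  auto offset = [&](int64_t i, int64_t j) { return i * padded_row + j; };
  assert(offset(1, 0) - offset(0, 0) == 16);  // outer stride
  assert(offset(0, 1) - offset(0, 0) == 1);   // inner stride
  assert(padded_row - k == 7);                // 7 padding elements per row
  return 0;
}
```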

TEST_F(LayoutOpTest, AllocationDomainSplitVectorizationFactor) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
FusionGuard fg(&fusion);

auto inp = makeSymbolicTensor(3);
fusion.addInput(inp);
auto out = set(inp);
fusion.addOutput(out);
// split would prevent vectorization
out->split(1, 16);
out->setAllocationDomain(out->getLoopDomain(), true);
// restore loop domain
out->merge(1);
Collaborator Author:
note: revert this!


auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
// because of the split on the middle dimension, we only have the fastest
// dimension participating in vectorization.
auto t0 = at::randn({512, 128, 2}, options);

// NOTE force pointwise scheduler here just for testing purpose
auto cg_results =
scheduleAndRun(fusion_ptr.get(), SchedulerType::PointWise, {t0});
auto pparams = cg_results.heuristic_params->as<PointwiseParams>();
EXPECT_EQ(pparams->vectorization_factor, 2);

testValidate(fusion_ptr.get(), cg_results.outputs, {t0}, __LINE__, __FILE__);
}
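For intuition on why the expected factor above is 2, a rough standalone model (plain C++, not nvfuser code; `Dim` and `contiguousInnerElems` are illustrative only): the contiguous-inner-size expansion walks dimensions from the innermost outward and stops at the middle dimension because its allocation-domain split is not a device split, so only the innermost extent of 2 floats participates in vectorization.

```cpp
// Rough standalone model (not nvfuser code) of the contiguous-inner-size
// expansion for AllocationDomainSplitVectorizationFactor.
#include <cassert>
#include <cstdint>
#include <vector>

struct Dim {
  int64_t extent;
  bool stops_expansion;  // e.g. a non-device allocation split on this dim
};

// Multiplies extents from the innermost dimension outward, stopping at the
// first dimension whose allocation transform blocks further expansion.
int64_t contiguousInnerElems(const std::vector<Dim>& dims_outer_to_inner) {
  int64_t elems = 1;
  for (auto it = dims_outer_to_inner.rbegin(); it != dims_outer_to_inner.rend();
       ++it) {
    if (it->stops_expansion) {
      break;
    }
    elems *= it->extent;
  }
  return elems;
}

int main() {
  // {512, 128, 2} with the middle dimension split by 16 on the allocation domain.
  std::vector<Dim> dims = {{512, false}, {128, true}, {2, false}};
  assert(contiguousInnerElems(dims) == 2);  // matches vectorization_factor == 2
  return 0;
}
```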

TEST_F(LayoutOpTest, CppApi) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();