Support Split from logical domain to allocation domain to represent padding #5184
base: main
```diff
@@ -807,13 +807,22 @@ Val* ContiguousInnerDimensionsMapper::getContigMergeOfInnerSize(
       {alloc_iid});
   IterDomain* logical_id = alloc_iid;
   Val* num_devices = of_tv->container()->oneVal();
+  bool only_valid_device_split = true;
   for (Expr* expr : exprs | std::views::reverse) {
-    validateDeviceSplit(expr);
+    if (!isValidateDeviceSplit(expr)) {
+      only_valid_device_split = false;
+      break;
+    }
     auto* split = expr->as<Split>();
     logical_id = split->in();
     num_devices = SimplifyingIrBuilder::mulExpr(num_devices, split->factor());
   }
+
+  // Non device split could lead to padding, which prevents vectorization
```
Collaborator
Can we have a non-device split if not for padding? Should we throw an error here so we do not have random transforms?
Collaborator (Author)
I might have missed your suggestion. If I throw on seeing a
Collaborator
Sorry about the confusion. To clarify, I was wondering if we can assert that a non-device split is in fact padding, and not a random transform. For example, if it is a divisible split, we can include it in vectorization, correct?
Collaborator (Author)
Got'ya. I think a split on the allocation domain is meant for padding. A more complicated case is if we use split and permute on the allocation domain to represent blocking, i.e. in some sense it's indeed used to facilitate vectorization/TMA. But conceptually I think they are still different. Going back to the topic between
```diff
+  if (!only_valid_device_split) {
+    break;
+  }

   // Mapping order isn't correct, cannot expand vectorization dimension.
   if (projected_dims[--projected_dims_i] != logical_id) {
     break;
```
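To make the concern in the thread above concrete, here is a small standalone sketch (illustration only, not nvFuser code): an inner split by 16 between the logical and allocation domain rounds the allocated extent of that axis up to a multiple of 16, so the allocation can hold padding elements that do not exist logically. Projecting a vectorized access past such a split could therefore step into padding, which is why the analysis now stops at the first non-device split instead of asserting.

```cpp
#include <cstdint>
#include <iostream>

// Standalone illustration (not nvFuser code): an inner split by `factor` on
// the allocation domain pads the allocated extent of that axis up to a
// multiple of `factor`, so allocation size >= logical size.
int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const int64_t M = 4, N = 10, factor = 16;
  const int64_t logical_size = M * N;                               // 40 elements
  const int64_t allocation_size = M * ceilDiv(N, factor) * factor;  // 64 elements
  std::cout << "logical: " << logical_size
            << ", allocated (padded): " << allocation_size << "\n";
  return 0;
}
```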
```diff
@@ -70,6 +70,70 @@ class LayoutOpTest : public NVFuserTest {
   }
 };

+TEST_F(LayoutOpTest, LogicalAndAllocationSizes) {
```
Collaborator
What is being tested here?

Collaborator (Author)
Without the relaxation in the vectorization analysis, this test will trigger an assert. So the test just verifies that we do allow an allocation domain split now.
```diff
+  auto fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr.get();
+  FusionGuard fg(&fusion);
+
+  auto inp = makeSymbolicTensor(2);
+  fusion.addInput(inp);
+  auto out = set(inp);
+  fusion.addOutput(out);
+  // padding output to multiple of 16
+  out->split(1, 16);
+  out->setAllocationDomain(out->getLoopDomain(), true);
+  // restore loop domain
+  out->merge(1);
```
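For orientation, the hunk above shows only the scheduling part of the new test. A typical continuation in nvFuser's C++ tests would run the fusion through FusionExecutorCache so the pointwise scheduler's vectorization analysis is actually exercised. The snippet below is a sketch of that common pattern, not the PR's actual assertions, and the input shape {4, 10} is made up for illustration.

```cpp
// Sketch of a typical test continuation (assumed, not the PR's actual body).
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({4, 10}, options);

// Running through FusionExecutorCache triggers scheduling, which is where the
// vectorization analysis would previously have asserted on the padding split.
FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto outputs = executor_cache.runFusionWithInputs({t0});
testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__);
```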
```cpp
NVF_ERROR(broadcast_bit_multiples.size() == ref_loop.size());
```

Hmm, not sure that's a good enough WAR, though this is just a test.
I thought the schedulers can work with some scheduled loop domains (for DID parallelization), no?
Fuser/csrc/scheduler/pointwise.cpp, lines 231 to 233 in 12121b9:

```cpp
// We always cacheBefore output at the beginning of the scheduling. And after
// cacheBefore, the reference tensor will have all reduction IDs removed.
ref_loop = TensorDomain::noDevices(TensorDomain::noReductions(ref_loop));
```

DID-related IDs are just ignored by the scheduler, so that's just too specific to multi-device.
I'm not a fan of this either. Let me see if I can skip messing with the loop domain and play the transformation on the allocation domain directly.
I suppose you can just modify the allocation domain with AbstractTensor. I remember there are some tests.
I can also directly use IterDomain::split for that.
Anyway, it looks like if the transformation is not on the logical-to-loop path, our replay wouldn't pick it up. Felt similar to the allocation domain replay that rfactor was missing. fyi @Priya2698
```
#0  nvfuser::nvfCheckFail (func=0xaaaaac218080 "validateDomainEquivalence",
    file=0xaaaaac216938 "/opt/pytorch/nvfuser/csrc/ir/utils.cpp", line=1162,
    msg=" INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:1162, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. \nExpected !compare_result.dom0_has_u"...) at /opt/pytorch/nvfuser/csrc/exceptions.cpp:267
#1  0x0000aaaaab1bbe68 in nvfuser::nvfErrorFail (func=0xaaaaac218080 "validateDomainEquivalence",
    file=0xaaaaac216938 "/opt/pytorch/nvfuser/csrc/ir/utils.cpp", line=1162,
    condMsg=0xaaaaac217fd8 " INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:1162, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. ",
    userMsg="Expected !compare_result.dom0_has_unreachable_ids . dom0 has unreachable IDs. dom0: iS10{i0}, iS11{i2}. dom1: iS10{i0}") at /opt/pytorch/nvfuser/csrc/exceptions.cpp:277
#2  0x0000aaaaab60a3e8 in nvfuser::ir_utils::validateDomainEquivalence (
    dom0=std::vector of length 2, capacity 2 = {...}, dom1=std::vector of length 1, capacity 3 = {...},
    additional_ids=std::vector of length 0, capacity 0) at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:1162
#3  0x0000aaaaab4aac30 in nvfuser::TensorDomain::setAllocationDomain (this=0xaaaab20918b0,
    new_allocation_domain=std::vector of length 1, capacity 3 = {...},
    new_contiguity=std::vector of length 1, capacity 3 = {...})
    at /opt/pytorch/nvfuser/csrc/ir/nodes.cpp:4055
#4  0x0000aaaaabc7b368 in nvfuser::TransformReplay::replayCasP (consumer=0xaaaab2088c00,
    producer=0xaaaab2091200, producer_pos=2, logical_map=..., opt=...)
    at /opt/pytorch/nvfuser/csrc/transform_replay.cpp:917
#5  0x0000aaaaabc7b7fc in nvfuser::TransformReplay::replayCasP (consumer=0xaaaab2088c00,
    producer=0xaaaab2091200, compute_at_axis=-1, opt=...)
    at /opt/pytorch/nvfuser/csrc/transform_replay.cpp:945
#6  0x0000aaaaabc44ccc in nvfuser::TensorView::cacheBefore (this=0xaaaab2088c00,
    op_type=nvfuser::LoadStoreOpType::Set) at /opt/pytorch/nvfuser/csrc/tensor_view.cpp:1160
#7  0x0000aaaaabbdb250 in nvfuser::scheduler_utils::cacheAndForkOutputs (fusion=0xaaaab2084910,
    unroll=true) at /opt/pytorch/nvfuser/csrc/scheduler/utils.cpp:1357
#8  0x0000aaaaabb067dc in nvfuser::schedulePointwise (fusion=0xaaaab2084910, pparams=0xaaaab207f880)
    at /opt/pytorch/nvfuser/csrc/scheduler/pointwise.cpp:822
#9  0x0000aaaaabb0898c in nvfuser::PointWiseScheduler::schedule (this=0xaaaab2083460,
    fusion=0xaaaab2084910, params=0xaaaab207f880)
    at /opt/pytorch/nvfuser/csrc/scheduler/pointwise.cpp:1304
```
So, what did you decide to do? Nothing seems to have changed?

> I can also directly use IterDomain::split for that.

Of course, but you'd need to maintain the proper ordering of the ID vector yourself.
> I can also directly use IterDomain::split for that. Anyway, it looks like if the transformation is not on the logical-to-loop path, our replay wouldn't pick it up. Felt similar to the allocation domain replay that rfactor was missing. fyi @Priya2698

Yes, rfactor replay for allocation will also complain similarly if the allocation transforms are disjoint from root-to-loop.
replayPasC also uses the loop domain as the target, so if you intend to use IterDomain::split, we will have to update that, among other things.
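For reference, the IterDomain::split route discussed here would look roughly like the sketch below: split the logical ID directly and assemble the allocation vector by hand, leaving the loop domain untouched. This is a sketch assuming IterDomain::split(in, factor, inner_split) and setAllocationDomain behave as used elsewhere in this PR; it is not the code the PR settled on, and, as the thread notes, replay (e.g. in cacheBefore) still has to be taught about allocation-only transforms.

```cpp
// Sketch (assumed API shapes, not the PR's final code): express the padding
// split directly between the logical and allocation domain, mirroring the
// test setup above.
TensorView* out = set(inp);
std::vector<IterDomain*> alloc = out->getLogicalDomain();

// Split the inner logical axis by 16. The caller is responsible for keeping
// the resulting IDs in the right order in the allocation vector.
auto [outer, inner] = IterDomain::split(
    alloc[1], IrBuilder::create<Val>(16L, DataType::Index), /*inner_split=*/true);
alloc[1] = outer;
alloc.push_back(inner);

// Allocation becomes [i0, outer, inner] while logical/loop stay [i0, i1].
out->setAllocationDomain(alloc, /*new_contiguity=*/true);
```

As the next comment notes, the PR instead updated TensorView::cacheBefore to use selfReplay so that such allocation-domain transforms survive caching.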
Yep, switched to selfReplay instead of replayCasP for TensorView::cacheBefore.
Outdated
note: revert this!