Support Split between logical domain and allocation domain to represent padding #5184
@@ -807,13 +807,22 @@ Val* ContiguousInnerDimensionsMapper::getContigMergeOfInnerSize(
         {alloc_iid});
     IterDomain* logical_id = alloc_iid;
     Val* num_devices = of_tv->container()->oneVal();
+    bool only_valid_device_split = true;
     for (Expr* expr : exprs | std::views::reverse) {
-      validateDeviceSplit(expr);
+      if (!isValidDeviceSplit(expr)) {
+        only_valid_device_split = false;
+        break;
+      }
       auto* split = expr->as<Split>();
       logical_id = split->in();
       num_devices = SimplifyingIrBuilder::mulExpr(num_devices, split->factor());
     }
+
+    // Non device split could lead to padding, which prevents vectorization
Collaborator: Can we have a non-device split if not for padding? Should we throw an error here so we do not have random transforms?

Author: I might have missed your suggestion. If I throw on seeing a …

Collaborator: Sorry about the confusion. To clarify, I was wondering if we can assert that a non-device split is in fact padding, and not a random transform. For example, if it is a divisible split, we can include it in vectorization, correct?

Author: Gotcha. I think a split on the allocation domain is meant for padding. A more complicated case is if we use split and permute on the allocation domain to represent blocking, i.e. in some sense it is indeed used to facilitate vectorization/TMA. But conceptually I think they are still different. Going back to the topic between …
+    if (!only_valid_device_split) {
+      break;
+    }

     // Mapping order isn't correct, cannot expand vectorization dimension.
     if (projected_dims[--projected_dims_i] != logical_id) {
       break;
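For context on the thread above: the kind of non-device split this check now tolerates is a logical-to-allocation split used to represent padding, as exercised by the LogicalAndAllocationSizes test later in this PR. A minimal sketch of that pattern, reusing the API from that test (the tensor names and the factor 16 are illustrative, not part of this diff):

    // Sketch: pad the inner dimension to a multiple of 16 on the allocation
    // domain only; the logical domain (and logical shape) stays untouched.
    auto out = set(inp);
    fusion.addOutput(out);
    // ceilDiv split of the logical inner axis into (outer, inner = 16).
    auto&& [outer, inner] = IterDomain::split(
        out->axis(1), IrBuilder::create<Val>(16L, DataType::Index), true);
    // The allocation domain holds the padded layout {axis(0), outer, inner}.
    out->setAllocationDomain({out->axis(0), outer, inner}, {false, true, true});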
@@ -288,10 +288,10 @@ void TransformReplay::selfReplay(
   // We use `self_loop` as the target domain because loop post-dominates
   // allocation.
   const std::vector<IterDomain*>& self_loop = self->loop();
-  ReplaySelf replay(self_loop, axis_map);

   // Replay loop.
   if (self_loop != self->logical()) {
+    ReplaySelf replay(self_loop, axis_map);

     std::vector<IterDomain*> new_loop;
     if (ignore_reductions) {
       for (auto* id : new_self->logical()) {

@@ -321,6 +321,7 @@ void TransformReplay::selfReplay(
   // Replay allocation.
   if (self->hasAllocation()) {
     const std::vector<IterDomain*>& self_allocation = self->allocation();
+    ReplaySelf replay(self_allocation, axis_map);
     const std::vector<std::optional<bool>>& self_contiguity =
         self->contiguity();
     NVF_ERROR_EQ(self_allocation.size(), self_contiguity.size());
@@ -70,6 +70,79 @@ class LayoutOpTest : public NVFuserTest {
   }
 };

+TEST_F(LayoutOpTest, LogicalAndAllocationSizes) {
Collaborator: What is being tested here?

Author: Without the relaxation in the vectorization analysis, this test would trigger an assert. The test just verifies that we now allow an allocation-domain split.
+  auto fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr.get();
+  FusionGuard fg(&fusion);
+
+  auto inp = makeSymbolicTensor(2);
+  fusion.addInput(inp);
+  auto out = set(inp);
+  fusion.addOutput(out);
+  // padding output to multiple of 16 on allocation domain
+  auto&& [io, ii] = IterDomain::split(
+      out->axis(1), IrBuilder::create<Val>(16L, DataType::Index), true);
Author: Tagging @naoyam: changed the test to only apply the split on logical -> allocation.
+
+  // NOTE: this doesn't feel right; we have to mark contiguity on axis(0) as
+  // `false` to avoid accidental indexing collapsing. This should be figured
+  // out by indexing from the ceilDiv.
+  out->setAllocationDomain({out->axis(0), io, ii}, {false, true, true});
Collaborator: Am I understanding this issue correctly? My question is what specifically goes wrong when allocation is set to contiguous?

Author: Yes, you are absolutely correct about a) and b). Indexing collapsing is wrong here, because we are mapping from logical to allocation, which is not accessing contiguous memory (because of the non-divisible split). This is the before and after of the indexing: with the false contiguity flag vs. with the true contiguity flag.
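To make the indexing concern concrete for the sizes used in this test (m = 512, k = 9 padded to 16), a hand-worked illustration (not code from the PR; the helper names are made up):

    #include <cstdint>

    // Logical element (i, j) of the {512, 9} output sits in a padded row of
    // 16 floats, so the correct allocation offset uses the padded row stride.
    int64_t padded_offset(int64_t i, int64_t j) { return i * 16 + j; }

    // If indexing collapsed the outer axis with the split axes (what a fully
    // contiguous allocation would permit), it would compute the flat logical
    // offset instead, skipping the 7 padding slots at the end of each row.
    int64_t collapsed_offset(int64_t i, int64_t j) { return i * 9 + j; }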
+
+  // Two issues with the split-and-merge approach:
+  // 1. It causes predication to expand to the padded region.
+  // 2. Indexing with the allocation domain set as `true` is wrong.
+  // out->split(1, 16); // padding output to multiple of 16
+  // out->setAllocationDomain(out->getLoopDomain(), true);
+  // out->merge(1); // restore loop domain
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  int m = 512;
+  int k = 9; // note: padded column size would be 16
+  auto t0 = at::randn({m, k}, options);
+
+  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  auto cg_outputs = executor_cache.runFusionWithInputs({t0});
+  // padding on the inner dimension is represented as stride on the outer
+  // dimension
+  EXPECT_EQ(
+      cg_outputs[0].as<at::Tensor>().strides(), std::vector<int64_t>({16, 1}));
+  // We need to slice because output buffer shape is not right
+  EXPECT_TRUE(t0.equal(cg_outputs[0].as<at::Tensor>().slice(1, 0, k)));
+  // TODO: enable this when output buffer shape is fixed.
+  // output should remain the correct logical size
+  // EXPECT_EQ(
+  //     cg_outputs[0].as<at::Tensor>().sizes(), std::vector<int64_t>({512,
+  //     9}));
+}
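A note on the expected layout: a {512, 9} logical tensor with strides {16, 1} is exactly a column slice of a contiguous {512, 16} buffer, which is also why the test compares values through slice(). A small ATen-only illustration under that assumption (not part of the PR):

    #include <ATen/ATen.h>

    at::Tensor padded_like_output() {
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      // Contiguous padded buffer: 512 rows of 16 floats each.
      at::Tensor buffer = at::zeros({512, 16}, options);
      // The view keeps sizes {512, 9} but strides {16, 1}, matching the
      // strides() expectation and the slice() used for the value comparison.
      return buffer.slice(/*dim=*/1, /*start=*/0, /*end=*/9);
    }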
+
+TEST_F(LayoutOpTest, AllocationDomainSplitVectorizationFactor) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr.get();
+  FusionGuard fg(&fusion);
+
+  auto inp = makeSymbolicTensor(3);
+  fusion.addInput(inp);
+  auto out = set(inp);
+  fusion.addOutput(out);
+  // split would prevent vectorization
+  out->split(1, 16);
+  out->setAllocationDomain(out->getLoopDomain(), true);
+  // restore loop domain
+  out->merge(1);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  // because of the split on the middle dimension, we only have the fastest
+  // dimension participating in vectorization.
+  auto t0 = at::randn({512, 128, 2}, options);
+
+  // NOTE force pointwise scheduler here just for testing purpose
+  auto cg_results =
+      scheduleAndRun(fusion_ptr.get(), SchedulerType::PointWise, {t0});
+  auto pparams = cg_results.heuristic_params->as<PointwiseParams>();
+  EXPECT_EQ(pparams->vectorization_factor, 2);
+
+  testValidate(fusion_ptr.get(), cg_results.outputs, {t0}, __LINE__, __FILE__);
+}
+
 TEST_F(LayoutOpTest, CppApi) {
   auto fusion_ptr = std::make_unique<Fusion>();
   Fusion& fusion = *fusion_ptr.get();
Collaborator: @protonu You might need this relaxed (things coming from vectorize_helper.cpp and multidevice/...). I'll start a PR on the side for this.

Collaborator: Thanks. For now, I modified it here: https://github.com/NVIDIA/Fuser/pull/5322/files