TransformReplay::selfReplay replays contiguity #5316
base: main
Changes from all commits: a41c05e, 3c09559, ddbcbbd, 7003c37, 30a4c19, cb4a0c6, 2ed59d4, a0711a8, 8479b97, cbef21b, c9e9615, 2edd504, f65e0b5, df825a6, 974748a, 4232e21, 80e8dff, 23ee554, 3cb70ee, f569b20, 129e9bb
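For context before the diff: the change revolves around the convention, stated in a check in the old code below, that a TensorDomain carries one std::optional<bool> contiguity flag per allocation-domain ID, and that the flag is std::nullopt exactly when the ID is a broadcast or reduction. A minimal sketch of that convention, using hypothetical stand-in types rather than nvFuser's real IterDomain:

```cpp
#include <optional>
#include <vector>

// Stand-ins for illustration only; nvFuser's real IR types differ.
enum class IterType { Iteration, Broadcast, Reduction, Symbolic };

struct FakeIterDomain {
  IterType type;
  bool isBroadcast() const { return type == IterType::Broadcast; }
  bool isReduction() const { return type == IterType::Reduction; }
};

// One flag per allocation ID: true/false for ordinary IDs, nullopt for
// broadcast/reduction IDs, which carry no meaningful stride.
std::vector<std::optional<bool>> makeContiguity(
    const std::vector<FakeIterDomain>& allocation, bool contiguous) {
  std::vector<std::optional<bool>> flags;
  flags.reserve(allocation.size());
  for (const FakeIterDomain& id : allocation) {
    if (id.isBroadcast() || id.isReduction()) {
      flags.push_back(std::nullopt);
    } else {
      flags.push_back(contiguous);
    }
  }
  return flags;
}
```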
@@ -250,77 +250,89 @@ void TransformReplay::selfReplay(
     TensorDomain* new_self) {
   FUSER_PERF_SCOPE("TransformReplay::selfReplay");
 
-  std::vector<IterDomain*> logical = self->logical();
-  std::vector<IterDomain*> new_logical = new_self->logical();
-
-  // For convenience, automatically remove extra reduction dimensions.
-  bool ignore_reductions = logical.size() != new_logical.size();
-  if (logical.size() > new_logical.size()) {
-    logical = TensorDomain::noReductions(logical);
-  } else if (logical.size() < new_logical.size()) {
-    new_logical = TensorDomain::noReductions(new_logical);
-  }
-  NVF_ERROR_EQ(logical.size(), new_logical.size());
-
-  IterDomainMap axis_map;
-  for (auto&& [id, new_id] : zip(logical, new_logical)) {
-    // We don't check for equal `isRFactorProduct`, since we could replay
-    // Allocation of the output of a reduction to a later consumer tensor,
-    // which would not have the rfactor flag on.
-    //
-    // This function can be used prior to concretization, where we might have
-    // a concrete ID map a symbolic ID. Otherwise, the IterTypes must be the
-    // same.
-    auto iter_types_match = [](IterType lhs, IterType rhs) -> bool {
-      if (lhs == rhs) {
-        return true;
-      }
-      return lhs == IterType::Symbolic || rhs == IterType::Symbolic;
-    };
-    NVF_ERROR(
-        iter_types_match(id->getIterType(), new_id->getIterType()),
-        "Axes ",
-        id,
-        " and ",
-        new_id,
-        " do not match for self replay.");
-    axis_map[id] = new_id;
-  }
-
-  // We create one ReplaySelf instance to replay loop and allocation. This
-  // way, loop and allocation share the same transforms if they are split the
-  // same way.
-  //
-  // We use `loop` as the target domain because loop post-dominates
-  // allocation.
-  const std::vector<IterDomain*>& loop = self->loop();
-  ReplaySelf replay(loop, axis_map);
+  auto replay = [&]() -> IterDomainMap {
+    std::vector<IterDomain*> logical = self->logical();
+    std::vector<IterDomain*> new_logical = new_self->logical();
+
+    // For convenience, automatically remove extra reduction dimensions.
+    if (logical.size() > new_logical.size()) {
+      logical = TensorDomain::noReductions(logical);
+    } else if (logical.size() < new_logical.size()) {
+      new_logical = TensorDomain::noReductions(new_logical);
+    }
+    NVF_ERROR_EQ(logical.size(), new_logical.size());
+
+    IterDomainMap axis_map;
+    for (auto&& [id, new_id] : zip(logical, new_logical)) {
+      // We don't check for equal `isRFactorProduct`, since we could replay
+      // Allocation of the output of a reduction to a later consumer tensor,
+      // which would not have the rfactor flag on.
+      //
+      // This function can be used prior to concretization, where we might
+      // have a concrete ID map a symbolic ID. Otherwise, the IterTypes must
+      // be the same.
+      auto iter_types_match = [](IterType lhs, IterType rhs) -> bool {
+        if (lhs == rhs) {
+          return true;
+        }
+        return lhs == IterType::Symbolic || rhs == IterType::Symbolic;
+      };
+      NVF_ERROR(
+          iter_types_match(id->getIterType(), new_id->getIterType()),
+          "Axes ",
+          id,
+          " and ",
+          new_id,
+          " do not match for self replay.");
+      axis_map[id] = new_id;
+    }
+
+    // We create one ReplaySelf instance to replay loop and allocation. This
+    // way, loop and allocation share the same transforms if they are split
+    // the same way.
+    //
+    // We use `loop` as the target domain because loop post-dominates
+    // allocation.
+    ReplaySelf replay(self->loop(), axis_map);
+    return replay.getReplay();
+  }();
+
+  auto mapped_new_ids = [&]() {
+    auto values = replay | std::views::values;
+    return std::unordered_set<IterDomain*>(values.begin(), values.end());
+  }();
 
   // Replay loop.
-  if (loop != self->logical()) {
+  if (self->loop() != self->logical()) {
     std::vector<IterDomain*> new_loop;
-    if (ignore_reductions) {
-      for (auto* id : new_self->logical()) {
-        if (id->isReduction()) {
-          new_loop.push_back(id);
-        }
+    for (auto* new_id : new_self->logical()) {
+      if (mapped_new_ids.count(new_id) == 0) {
+        NVF_ERROR(
+            new_id->isReduction(),
+            new_id->toString(),
+            " should be a reduction.");
+        new_loop.push_back(new_id);
       }
     }
 
-    for (IterDomain* loop_id : loop) {
-      if (ignore_reductions && loop_id->isReduction()) {
+    for (IterDomain* loop_id : self->loop()) {
+      IterDomain* new_loop_id = getOrDefault(replay, loop_id);
+      if (new_loop_id == nullptr) {
         continue;
       }
-      auto it = replay.getReplay().find(loop_id);
-      NVF_ERROR(
-          it != replay.getReplay().end(),
-          "failed to replay IterDomain: ",
-          loop_id);
-      it->second->parallelize(loop_id->getParallelType());
-      new_loop.push_back(it->second);
+      new_loop_id->parallelize(loop_id->getParallelType());
+      new_loop.push_back(new_loop_id);
    }
 
     new_self->setLoopDomain(new_loop);
+  } else {
+    NVF_ERROR_EQ(
+        new_self->loop(),
+        new_self->logical(),
+        "It is unclear what the correct contract should be when replaying an "
+        "empty transform sequence on a non-empty loop domain. We could keep "
+        "the loop domain as is or set the loop domain to be the same as "
+        "logical. Fortunately, we do not have a use case for this scenario.");
   }
 
   // Replay allocation.
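Two helpers carry the refactor above: getOrDefault, which replaces the explicit find-and-check pattern, and the mapped_new_ids set built from the replay map's values. A self-contained sketch of their assumed behavior (the real nvFuser getOrDefault utility may differ in signature):

```cpp
#include <ranges>
#include <unordered_map>
#include <unordered_set>

struct IterDomain; // stand-in forward declaration for illustration

using IterDomainMap = std::unordered_map<IterDomain*, IterDomain*>;

// Assumed behavior: return the mapped value, or a value-initialized result
// (nullptr here) when the key is absent.
IterDomain* getOrDefault(const IterDomainMap& map, IterDomain* key) {
  auto it = map.find(key);
  return it == map.end() ? nullptr : it->second;
}

// The mapped_new_ids pattern: collect every value of the replay map into a
// set, so that unmapped IDs in the new domain can be detected in O(1).
std::unordered_set<IterDomain*> mappedValues(const IterDomainMap& map) {
  auto values = map | std::views::values;
  return std::unordered_set<IterDomain*>(values.begin(), values.end());
}
```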
@@ -331,40 +343,86 @@ void TransformReplay::selfReplay(
 
     std::vector<IterDomain*> new_allocation;
     std::vector<std::optional<bool>> new_contiguities;
     new_allocation.reserve(allocation.size());
     new_contiguities.reserve(contiguities.size());
 
     // Push back the reduction IDs that are not mapped
-    if (ignore_reductions) {
-      for (auto* id : new_self->logical()) {
-        if (id->isReduction()) {
-          new_allocation.push_back(id);
-          new_contiguities.push_back(std::nullopt);
-        }
+    for (auto* new_id : new_self->logical()) {
+      if (mapped_new_ids.count(new_id) == 0) {
+        NVF_ERROR(
+            new_id->isReduction(),
+            new_id->toString(),
+            " should be a reduction.");
+        new_allocation.push_back(new_id);
+        new_contiguities.push_back(std::nullopt);
       }
     }
 
     // Pushing the mapped IDs and corresponding contiguity flags
-    for (auto&& [alloc_id, contiguity] : zip(allocation, contiguities)) {
-      if (ignore_reductions && alloc_id->isReduction()) {
+    for (const auto& [alloc_id, contiguity] : zip(allocation, contiguities)) {
+      IterDomain* new_alloc_id = getOrDefault(replay, alloc_id);
+      if (new_alloc_id == nullptr) {
         continue;
       }
-      auto it = replay.getReplay().find(alloc_id);
-      NVF_ERROR(
-          it != replay.getReplay().end(),
-          "failed to replay IterDomain: ",
-          alloc_id);
-      NVF_ERROR_EQ(
-          (it->second->isBroadcast() || it->second->isReduction()),
-          !contiguity.has_value(),
-          "Contiguity should be nullopt iff broadcast or reduction, "
-          "true/false otherwise.");
-      new_contiguities.push_back(contiguity);
-      it->second->parallelize(alloc_id->getParallelType());
-      new_allocation.push_back(it->second);
+      new_alloc_id->parallelize(alloc_id->getParallelType());
+      new_allocation.push_back(new_alloc_id);
+
+      std::optional<bool> new_contiguity = contiguity;
+      if (new_alloc_id->isBroadcast() || new_alloc_id->isReduction()) {
+        new_contiguity = std::nullopt;
+      }
+      new_contiguities.push_back(new_contiguity);
     }
 
     new_self->setAllocationDomain(new_allocation, new_contiguities);
+  } else {
Review thread on this hunk:

Collaborator: This feels a little unexpected to me because, even though there's nothing to replay for the allocation, the contiguity could be modified.

Collaborator: What would happen if

Author: Good catch. Added a check and a comment.
+    NVF_ERROR(
+        !new_self->hasAllocation(),
+        "It is unclear what the correct contract should be when replaying an "
+        "empty transform sequence on a non-empty allocation domain. "
+        "Fortunately, we do not have a use case for this scenario.");
+    const std::vector<IterDomain*>& new_logical = new_self->logical();
+    const auto new_rank = std::ssize(new_logical);
+    std::vector<std::optional<bool>> new_contiguities(new_rank, std::nullopt);
+
+    int new_pos = 0;
+    for (auto [id, contiguity] : zip(self->logical(), self->contiguity())) {
+      IterDomain* new_id = getOrDefault(replay, id);
+      if (new_id == nullptr) {
+        continue;
+      }
+
+      // Find the corresponding contiguity in new_logical. Mapped IterDomains
+      // in self->logical() and new_logical follow the same order. So it's
+      // safe to only increment `new_pos`.
+      while (new_pos < new_rank && new_logical.at(new_pos) != new_id) {
+        new_pos++;
+      }
+      NVF_ERROR_LT(
+          new_pos,
+          new_rank,
+          "Failed to find ",
+          new_id->toString(),
+          " in ",
+          new_logical);
+      std::optional<bool>& new_contiguity = new_contiguities.at(new_pos);
+
+      new_contiguity = contiguity;
+      // When used during or before concretization,
+      // TransformReplay::selfReplay can be applied to replay transformations
+      // from symbolic dimensions to concrete dimensions, or in the reverse
+      // direction. Therefore, `new_contiguity` is not always identical to
+      // `contiguity`.
+      if (new_id->isBroadcast()) {
+        new_contiguity = std::nullopt;
+      } else if (new_id->isSymbolic()) {
+        // See AliasTest.AccumulateSlices for an example. aliasOutputToInput
+        // is called before concretization and tries to replay contiguity
+        // from a broadcast IterDomain to a symbolic IterDomain. However, a
+        // symbolic IterDomain can't have contiguity null.
+        if (!new_contiguity.has_value()) {
+          new_contiguity = true;
+        }
+      }
+    }
+
+    new_self->setContiguity(new_contiguities);
   }
 }
@@ -377,7 +435,7 @@ namespace {
 std::unordered_set<IterDomain*> getMaybeUnmappedIDs(
     const TensorView* tv,
     bool is_producer,
-    const std::unordered_map<IterDomain*, IterDomain*>& root_id_map) {
+    const IterDomainMap& root_id_map) {
   std::unordered_set<Val*> unmapped_root_ids;
 
   const auto& root_domain =
@@ -1121,7 +1179,7 @@ bool TransformReplay::fullSelfMatching(
   auto replay_dom = replay->getLoopDomain();
   auto target_root = target->getMaybeRootDomain();
   auto target_dom = target->getLoopDomain();
-  std::unordered_map<IterDomain*, IterDomain*> target2replay_map;
+  IterDomainMap target2replay_map;
   if (replay_root.size() != target_root.size()) {
     return false;
   }
@@ -1372,7 +1430,7 @@ namespace {
 TensorDomain* fullReplay(
     const TensorDomain* old_domain,
     const std::vector<IterDomain*>& new_root) {
-  std::unordered_map<IterDomain*, IterDomain*> old_root_to_new;
+  IterDomainMap old_root_to_new;
   NVF_CHECK(
       old_domain->maybeRoot().size() == new_root.size(),
       "Unable to replay transformations on a root domain of different size: ",
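The last three hunks are mechanical renames to the IterDomainMap alias. The alias definition is outside this diff; presumably it reads along these lines:

```cpp
#include <unordered_map>

class IterDomain; // nvFuser's iteration-domain IR node

// Assumed definition; the real alias lives elsewhere in the nvFuser sources.
using IterDomainMap = std::unordered_map<IterDomain*, IterDomain*>;
```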