9 changes: 3 additions & 6 deletions csrc/ir/nodes.cpp
@@ -3210,13 +3210,10 @@ namespace {
 void validateContiguity(
     const std::vector<IterDomain*>& allocation_domain,
     const std::vector<std::optional<bool>>& contiguity) {
-  NVF_CHECK(
-      contiguity.size() == allocation_domain.size(),
-      "Invalid contiguity information provided, incorrect size. Received "
-      "vector of size ",
+  NVF_CHECK_EQ(
       contiguity.size(),
-      " but needed one of size ",
-      allocation_domain.size());
+      allocation_domain.size(),
+      "Invalid contiguity information provided, incorrect size.");
   for (auto i : arange(contiguity.size())) {
     bool expect_null =
         (allocation_domain.at(i)->isBroadcast() ||
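(Aside, not part of the diff: NVF_CHECK_EQ takes the two operands separately, so the macro can print both values itself and the call site no longer assembles the "Received ... but needed ..." message by hand. A minimal sketch of the idea in plain C++ follows; checkEq is a hypothetical stand-in, not nvFuser's actual macro.)

#include <sstream>
#include <stdexcept>

// Hypothetical stand-in for an NVF_CHECK_EQ-style helper: on mismatch, both
// operands are streamed into the error message automatically.
template <typename A, typename B>
void checkEq(const A& lhs, const B& rhs, const char* msg) {
  if (!(lhs == rhs)) {
    std::ostringstream oss;
    oss << msg << " Expected " << lhs << " == " << rhs << ".";
    throw std::runtime_error(oss.str());
  }
}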
228 changes: 143 additions & 85 deletions csrc/transform_replay.cpp
@@ -250,77 +250,89 @@ void TransformReplay::selfReplay(
     TensorDomain* new_self) {
   FUSER_PERF_SCOPE("TransformReplay::selfReplay");
 
-  std::vector<IterDomain*> logical = self->logical();
-  std::vector<IterDomain*> new_logical = new_self->logical();
-
-  // For convenience, automatically remove extra reduction dimensions.
-  bool ignore_reductions = logical.size() != new_logical.size();
-  if (logical.size() > new_logical.size()) {
-    logical = TensorDomain::noReductions(logical);
-  } else if (logical.size() < new_logical.size()) {
-    new_logical = TensorDomain::noReductions(new_logical);
-  }
-  NVF_ERROR_EQ(logical.size(), new_logical.size());
-
-  IterDomainMap axis_map;
-  for (auto&& [id, new_id] : zip(logical, new_logical)) {
-    // We don't check for equal `isRFactorProduct`, since we could replay
-    // Allocation of the output of a reduction to a later consumer tensor, which
-    // would not have the rfactor flag on.
-    //
-    // This function can be used prior to concretization, where we might have a
-    // concrete ID map a symbolic ID. Otherwise, the IterTypes must be the
-    // same.
-    auto iter_types_match = [](IterType lhs, IterType rhs) -> bool {
-      if (lhs == rhs) {
-        return true;
-      }
-      return lhs == IterType::Symbolic || rhs == IterType::Symbolic;
-    };
-    NVF_ERROR(
-        iter_types_match(id->getIterType(), new_id->getIterType()),
-        "Axes ",
-        id,
-        " and ",
-        new_id,
-        " do not match for self replay.");
-    axis_map[id] = new_id;
-  }
-
-  // We create one ReplaySelf instance to replay loop and allocation. This way,
-  // loop and allocation share the same transforms if they are split the same
-  // way.
-  //
-  // We use `loop` as the target domain because loop post-dominates
-  // allocation.
-  const std::vector<IterDomain*>& loop = self->loop();
-  ReplaySelf replay(loop, axis_map);
+  auto replay = [&]() -> IterDomainMap {
+    std::vector<IterDomain*> logical = self->logical();
+    std::vector<IterDomain*> new_logical = new_self->logical();
+
+    // For convenience, automatically remove extra reduction dimensions.
+    if (logical.size() > new_logical.size()) {
+      logical = TensorDomain::noReductions(logical);
+    } else if (logical.size() < new_logical.size()) {
+      new_logical = TensorDomain::noReductions(new_logical);
+    }
+    NVF_ERROR_EQ(logical.size(), new_logical.size());
+
+    IterDomainMap axis_map;
+    for (auto&& [id, new_id] : zip(logical, new_logical)) {
+      // We don't check for equal `isRFactorProduct`, since we could replay
+      // Allocation of the output of a reduction to a later consumer tensor,
+      // which would not have the rfactor flag on.
+      //
+      // This function can be used prior to concretization, where we might have
+      // a concrete ID map a symbolic ID. Otherwise, the IterTypes must be the
+      // same.
+      auto iter_types_match = [](IterType lhs, IterType rhs) -> bool {
+        if (lhs == rhs) {
+          return true;
+        }
+        return lhs == IterType::Symbolic || rhs == IterType::Symbolic;
+      };
+      NVF_ERROR(
+          iter_types_match(id->getIterType(), new_id->getIterType()),
+          "Axes ",
+          id,
+          " and ",
+          new_id,
+          " do not match for self replay.");
+      axis_map[id] = new_id;
+    }
+
+    // We create one ReplaySelf instance to replay loop and allocation. This
+    // way, loop and allocation share the same transforms if they are split the
+    // same way.
+    //
+    // We use `loop` as the target domain because loop post-dominates
+    // allocation.
+    ReplaySelf replay(self->loop(), axis_map);
+    return replay.getReplay();
+  }();
+
+  auto mapped_new_ids = [&]() {
+    auto values = replay | std::views::values;
+    return std::unordered_set<IterDomain*>(values.begin(), values.end());
+  }();
 
   // Replay loop.
-  if (loop != self->logical()) {
+  if (self->loop() != self->logical()) {
     std::vector<IterDomain*> new_loop;
-    if (ignore_reductions) {
-      for (auto* id : new_self->logical()) {
-        if (id->isReduction()) {
-          new_loop.push_back(id);
-        }
+    for (auto* new_id : new_self->logical()) {
+      if (mapped_new_ids.count(new_id) == 0) {
+        NVF_ERROR(
+            new_id->isReduction(),
+            new_id->toString(),
+            " should be a reduction.");
+        new_loop.push_back(new_id);
       }
     }
 
-    for (IterDomain* loop_id : loop) {
-      if (ignore_reductions && loop_id->isReduction()) {
+    for (IterDomain* loop_id : self->loop()) {
+      IterDomain* new_loop_id = getOrDefault(replay, loop_id);
+      if (new_loop_id == nullptr) {
         continue;
       }
-      auto it = replay.getReplay().find(loop_id);
-      NVF_ERROR(
-          it != replay.getReplay().end(),
-          "failed to replay IterDomain: ",
-          loop_id);
-      it->second->parallelize(loop_id->getParallelType());
-      new_loop.push_back(it->second);
+      new_loop_id->parallelize(loop_id->getParallelType());
+      new_loop.push_back(new_loop_id);
     }
 
     new_self->setLoopDomain(new_loop);
+  } else {
+    NVF_ERROR_EQ(
+        new_self->loop(),
+        new_self->logical(),
+        "It is unclear what the correct contract should be when replaying an "
+        "empty transform sequence on a non-empty loop domain. We could keep "
+        "the loop domain as is or set the loop domain to be the same as "
+        "logical. Fortunately, we do not have a use case for this scenario.");
   }
 
   // Replay allocation.
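(Aside, not part of the diff: the hunk above builds `replay` and `mapped_new_ids` with immediately-invoked lambdas, so each can be initialized in one expression while the scratch variables stay scoped inside the lambda. A minimal sketch of the idiom with hypothetical names; `table` and `values` are stand-ins, not nvFuser code.)

#include <string>
#include <unordered_map>
#include <unordered_set>

int main() {
  // Scratch state lives and dies inside the lambda; only the result escapes.
  const std::unordered_map<std::string, int> table = [&]() {
    std::unordered_map<std::string, int> scratch;
    scratch["a"] = 1;
    scratch["b"] = 2;
    return scratch;
  }();

  // A second lambda derives a value set from the map, mirroring how
  // mapped_new_ids collects the range of `replay` (the PR itself uses the
  // C++20 std::views::values adaptor for this step).
  const std::unordered_set<int> values = [&]() {
    std::unordered_set<int> result;
    for (const auto& [key, value] : table) {
      result.insert(value);
    }
    return result;
  }();

  return values.count(1) == 1 ? 0 : 1;
}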
@@ -331,40 +343,86 @@
 
     std::vector<IterDomain*> new_allocation;
     std::vector<std::optional<bool>> new_contiguities;
-    new_allocation.reserve(allocation.size());
-    new_contiguities.reserve(contiguities.size());
-
-    // Push back the reduction IDs that are not mapped
-    if (ignore_reductions) {
-      for (auto* id : new_self->logical()) {
-        if (id->isReduction()) {
-          new_allocation.push_back(id);
-          new_contiguities.push_back(std::nullopt);
-        }
+
+    for (auto* new_id : new_self->logical()) {
+      if (mapped_new_ids.count(new_id) == 0) {
+        NVF_ERROR(
+            new_id->isReduction(),
+            new_id->toString(),
+            " should be a reduction.");
+        new_allocation.push_back(new_id);
+        new_contiguities.push_back(std::nullopt);
       }
     }
 
     // Pushing the mapped IDs and corresponding contiguity flags
-    for (auto&& [alloc_id, contiguity] : zip(allocation, contiguities)) {
-      if (ignore_reductions && alloc_id->isReduction()) {
+    for (const auto& [alloc_id, contiguity] : zip(allocation, contiguities)) {
+      IterDomain* new_alloc_id = getOrDefault(replay, alloc_id);
+      if (new_alloc_id == nullptr) {
         continue;
       }
-      auto it = replay.getReplay().find(alloc_id);
-      NVF_ERROR(
-          it != replay.getReplay().end(),
-          "failed to replay IterDomain: ",
-          alloc_id);
-      NVF_ERROR_EQ(
-          (it->second->isBroadcast() || it->second->isReduction()),
-          !contiguity.has_value(),
-          "Contiguity should be nullopt iff broadcast or reduction, true/false "
-          "otherwise.");
-      new_contiguities.push_back(contiguity);
-      it->second->parallelize(alloc_id->getParallelType());
-      new_allocation.push_back(it->second);
+      new_alloc_id->parallelize(alloc_id->getParallelType());
+      new_allocation.push_back(new_alloc_id);
+
+      std::optional<bool> new_contiguity = contiguity;
+      if (new_alloc_id->isBroadcast() || new_alloc_id->isReduction()) {
+        new_contiguity = std::nullopt;
+      }
+      new_contiguities.push_back(new_contiguity);
     }
 
     new_self->setAllocationDomain(new_allocation, new_contiguities);
+  } else {
Collaborator:
This feels a little unexpected to me because, even though there's nothing to replay for the allocation, the contiguity could be modified.

Collaborator:
What would happen if self doesn't have an allocation domain but new_self does? Would it work?

Author:
Good catch. Added a check and a comment.

+    NVF_ERROR(
+        !new_self->hasAllocation(),
+        "It is unclear what the correct contract should be when replaying an "
+        "empty transform sequence on a non-empty allocation domain. "
+        "Fortunately, we do not have a use case for this scenario.");
+    const std::vector<IterDomain*>& new_logical = new_self->logical();
+    const auto new_rank = std::ssize(new_logical);
+    std::vector<std::optional<bool>> new_contiguities(new_rank, std::nullopt);
+
+    int new_pos = 0;
+    for (auto [id, contiguity] : zip(self->logical(), self->contiguity())) {
+      IterDomain* new_id = getOrDefault(replay, id);
+      if (new_id == nullptr) {
+        continue;
+      }
+
+      // Find the corresponding contiguity in new_logical. Mapped IterDomains
+      // in self->logical() and new_logical follow the same order. So it's safe
+      // to only increment `new_pos`.
+      while (new_pos < new_rank && new_logical.at(new_pos) != new_id) {
+        new_pos++;
+      }
+      NVF_ERROR_LT(
+          new_pos,
+          new_rank,
+          "Failed to find ",
+          new_id->toString(),
+          " in ",
+          new_logical);
+      std::optional<bool>& new_contiguity = new_contiguities.at(new_pos);
+
+      new_contiguity = contiguity;
+      // When used during or before concretization, TransformReplay::selfReplay
+      // can be applied to replay transformations from symbolic dimensions to
+      // concrete dimensions, or in the reverse direction. Therefore,
+      // `new_contiguity` is not always identical to `contiguity`.
+      if (new_id->isBroadcast()) {
+        new_contiguity = std::nullopt;
+      } else if (new_id->isSymbolic()) {
+        // See AliasTest.AccumulateSlices for an example. aliasOutputToInput is
+        // called before concretization and tries to replay contiguity from a
+        // broadcast IterDomain to a symbolic IterDomain. However, a symbolic
+        // IterDomain can't have contiguity null.
+        if (!new_contiguity.has_value()) {
+          new_contiguity = true;
+        }
+      }
+    }
+
+    new_self->setContiguity(new_contiguities);
   }
 }
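(Aside, not part of the diff: throughout selfReplay, the rewrite replaces the find-then-NVF_ERROR pattern with getOrDefault plus a nullptr check. nvFuser has its own helper; the sketch below only illustrates what a getOrDefault-style lookup generally does and is not the project's definition.)

#include <unordered_map>

// Illustrative only: returns the mapped value, or a value-initialized V
// (nullptr for pointer value types) when the key is absent.
template <typename K, typename V>
V getOrDefaultSketch(const std::unordered_map<K, V>& map, const K& key) {
  auto it = map.find(key);
  return it == map.end() ? V{} : it->second;
}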

@@ -377,7 +435,7 @@ namespace {
 std::unordered_set<IterDomain*> getMaybeUnmappedIDs(
     const TensorView* tv,
     bool is_producer,
-    const std::unordered_map<IterDomain*, IterDomain*>& root_id_map) {
+    const IterDomainMap& root_id_map) {
   std::unordered_set<Val*> unmapped_root_ids;
 
   const auto& root_domain =
@@ -1121,7 +1179,7 @@ bool TransformReplay::fullSelfMatching(
   auto replay_dom = replay->getLoopDomain();
   auto target_root = target->getMaybeRootDomain();
   auto target_dom = target->getLoopDomain();
-  std::unordered_map<IterDomain*, IterDomain*> target2replay_map;
+  IterDomainMap target2replay_map;
   if (replay_root.size() != target_root.size()) {
     return false;
   }
@@ -1372,7 +1430,7 @@ namespace {
 TensorDomain* fullReplay(
     const TensorDomain* old_domain,
     const std::vector<IterDomain*>& new_root) {
-  std::unordered_map<IterDomain*, IterDomain*> old_root_to_new;
+  IterDomainMap old_root_to_new;
   NVF_CHECK(
       old_domain->maybeRoot().size() == new_root.size(),
       "Unable to replay transformations on a root domain of different size: ",