Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions csrc/device_lower/validation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1885,7 +1885,7 @@ void validateAndConvertIterDomainGrouping(Fusion* fusion) {

fusion->removeExpr(rop);
IrBuilder::createInContainer<GroupedReductionOp>(
fusion, op_types, init_vals, outputs, inputs, is_allreduce);
fusion->container(), op_types, init_vals, outputs, inputs, is_allreduce);
} else if (tv->definition()->isA<WelfordOp>()) {
// Convert WelfordOp to GroupedWelfordOp
auto wop = def->as<WelfordOp>();
Expand All @@ -1911,7 +1911,7 @@ void validateAndConvertIterDomainGrouping(Fusion* fusion) {
{{wop->initAvg(), wop->initVar(), wop->initN()}});
fusion->removeExpr(wop);
IrBuilder::createInContainer<GroupedWelfordOp>(
fusion, output_vals, input_vals, init_vals, is_allreduce);
fusion->container(), output_vals, input_vals, init_vals, is_allreduce);
}
}
}
Expand Down
3 changes: 1 addition & 2 deletions csrc/dynamic_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ namespace nvfuser {

DynamicTransformInitialInfo DynamicTransformInitialInfo::clone(
IrCloner& ir_cloner) const {
DynamicTransformInitialInfo cloned_info(
static_cast<Fusion*>(ir_cloner.container()));
DynamicTransformInitialInfo cloned_info(ir_cloner.container()->fusion());
cloned_info.dynamic_reshaped_tvs_.reserve(dynamic_reshaped_tvs_.size());
for (const auto tv : dynamic_reshaped_tvs_) {
cloned_info.dynamic_reshaped_tvs_.push_back(ir_cloner.clone(tv));
Expand Down
2 changes: 1 addition & 1 deletion csrc/evaluator_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ void PrecomputedValues::invalidate() {
}

PrecomputedValues PrecomputedValues::clone(IrCloner& ir_cloner) const {
PrecomputedValues pv(static_cast<Fusion*>(ir_cloner.container()));
PrecomputedValues pv(ir_cloner.container()->fusion());

// this is a map to unique pointers to vectors, so we need to copy the
// vectors and create new unique pointers
Expand Down
92 changes: 72 additions & 20 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,31 @@ void swap(Fusion& a, Fusion& b) noexcept {

using std::swap;

swap(static_cast<IrContainer&>(a), static_cast<IrContainer&>(b));
swap(a.container_, b.container_);

// Update back-references after swapping containers
if (a.container_) {
a.container_->setOwningFusion(&a);
// Update all statements to point to the swapped container
a.container_->updateAllStatementContainerPointers();
}
if (b.container_) {
b.container_->setOwningFusion(&b);
// Update all statements to point to the swapped container
b.container_->updateAllStatementContainerPointers();
}

swap(a.inputs_, b.inputs_);
swap(a.outputs_, b.outputs_);

swap(a.io_alias_, b.io_alias_);

swap(a.all_tv_uses_valid_, b.all_tv_uses_valid_);
swap(a.is_during_update_uses_, b.is_during_update_uses_);
swap(a.managed_data_, b.managed_data_);
swap(a.managed_named_data_, b.managed_named_data_);
swap(a.expected_dynamic_smem_bytes_, b.expected_dynamic_smem_bytes_);
swap(a.all_tvs_ptr_, b.all_tvs_ptr_);
}

std::unique_ptr<SegmentedFusion> Fusion::segment(
Expand All @@ -122,11 +141,16 @@ std::unique_ptr<SegmentedFusion> Fusion::segment(

IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
to->clear();
auto ir_cloner = IrContainer::copy(from, to);

for (auto val : from->vals_) {
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
auto ir_cloner = IrContainer::copy(from->container(), to->container());

for (auto val : from->vals()) {
Val* cloned_val = ir_cloner.clone(val);
// Only set definition if not already set during Fusion::registerExpr()
// This avoids overwriting definitions that were properly set during cloning
if (cloned_val->definition() == nullptr && val->definition_ != nullptr) {
cloned_val->setDefinition(ir_cloner.clone(val->definition_));
}
cloned_val->setUses(ir_cloner.clone(val->uses_));
}

to->inputs_ = ir_cloner.clone(from->inputs_);
Expand Down Expand Up @@ -183,17 +207,21 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
return ir_cloner;
}

// Clang tidy complains when using default constructor for IrContainer instead
// of copy constructor. Fusion::copy has a call to IrContainer::copy, so it's
// redundant to use the IrContainer copy constructor, but it is harmless since
// Fusion::copy starts by calling clear().
Fusion::Fusion(const Fusion& other) : IrContainer(other) {
Fusion::Fusion() : container_(std::make_unique<IrContainer>()) {
container_->setOwningFusion(this);
}

Fusion::Fusion(const Fusion& other)
: container_(std::make_unique<IrContainer>()) {
FUSER_PERF_SCOPE("Fusion copy");
container_->setOwningFusion(this);
Fusion::copy(&other, this);
}

Fusion::Fusion(Fusion&& other) noexcept {
Fusion::Fusion(Fusion&& other) noexcept
: container_(std::make_unique<IrContainer>()) {
FUSER_PERF_SCOPE("Fusion move");
container_->setOwningFusion(this);
swap(*this, other);
}

Expand Down Expand Up @@ -222,7 +250,9 @@ void Fusion::clear() noexcept {
// constructor of Trace, which could throw an exception.
// FUSER_PERF_SCOPE("Fusion clear");

IrContainer::clear();
if (container_) {
container_->clear();
}

inputs_.clear();
outputs_.clear();
Expand Down Expand Up @@ -259,7 +289,7 @@ void Fusion::removeExpr(Expr* expr) {
}
}

IrContainer::removeExpr(expr);
container_->removeExpr(expr);
}

void Fusion::removeVal(Val* val) {
Expand All @@ -285,7 +315,7 @@ void Fusion::removeVal(Val* val) {
// caused a segfault when the fusion was cloned since that will clone not only
// live objects but also these dangerous dangling dead ones.
std::vector<Expr*> exprs_to_remove;
for (Expr* e : exprs_) {
for (Expr* e : unordered_exprs()) {
if (!inContainer(e)) {
continue;
}
Expand All @@ -298,7 +328,7 @@ void Fusion::removeVal(Val* val) {
for (auto e : exprs_to_remove) {
removeExpr(e);
}
IrContainer::removeVal(val);
container_->removeVal(val);

invalidateTvsAndUses();
}
Expand Down Expand Up @@ -652,6 +682,14 @@ void Fusion::printTransforms() {
t_exprs.handle(this);
}

void Fusion::registerStmt(Statement* stmt) {
if (stmt->isVal()) {
registerVal(stmt->asVal());
} else {
registerExpr(stmt->asExpr());
}
}

void Fusion::registerVal(Val* val) {
if (inContainer(val)) {
return;
Expand All @@ -662,7 +700,7 @@ void Fusion::registerVal(Val* val) {
val->fusion() == this, val, " was not found in the active fusion.");
}

IrContainer::registerVal(val);
container_->registerVal(IrBuilderPasskey(container()), val);
}

void Fusion::registerExpr(Expr* expr) {
Expand All @@ -675,7 +713,7 @@ void Fusion::registerExpr(Expr* expr) {
expr->fusion() == this, expr, " was not found in the active fusion.");
}

IrContainer::registerExpr(expr);
container_->registerExpr(IrBuilderPasskey(container()), expr);

for (Val* input : expr->inputs()) {
assertInContainer(input, "Input to expr is invalid, ");
Expand All @@ -696,7 +734,14 @@ void Fusion::registerExpr(Expr* expr) {
for (Val* output : expr->outputs()) {
assertInContainer(output, "Output to expr is invalid, ");
if (output->definition() != nullptr && is_ssa) {
removeExpr(output->definition());
// Only remove old definition if it belongs to THIS container
// During cloning, old definition might be from source container
if (inContainer(output->definition())) {
removeExpr(output->definition());
} else {
// Old definition is from a different container - clear it
output->definition_ = nullptr;
}
}
if (is_ssa || output->definition() == nullptr) {
output->setDefinition(expr);
Expand All @@ -707,6 +752,13 @@ void Fusion::registerExpr(Expr* expr) {
// vector after setDefinition.
invalidateTvsAndUses();
}
} else {
// DEBUG: This branch means definition was not set!
// This happens in non-SSA contexts (Kernel/HostIrContainer) when
// output already has a definition and we don't overwrite it
if (output->isA<TensorView>() && !is_ssa) {
// Expected for non-SSA - multiple definitions allowed
}
}
}
}
Expand All @@ -718,7 +770,7 @@ void Fusion::resetTvUses() {
// getExprs only uses definition, so even if we've modified uses already to
// remove dead exprs, this could reinsert them. getExprs is also boundeds by
// inputs as registered inputs will return nullptr as their definition.
const auto all_tvs = ir_utils::filterByType<TensorView>(vals_);
const auto all_tvs = ir_utils::filterByType<TensorView>(vals());
const auto used_exprs = StmtSort::getExprs(this);

for (auto tv : all_tvs) {
Expand Down
Loading
Loading