diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index f786f6847ac..e784768b2c9 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -1885,7 +1885,7 @@ void validateAndConvertIterDomainGrouping(Fusion* fusion) { fusion->removeExpr(rop); IrBuilder::createInContainer( - fusion, op_types, init_vals, outputs, inputs, is_allreduce); + fusion->container(), op_types, init_vals, outputs, inputs, is_allreduce); } else if (tv->definition()->isA()) { // Convert WelfordOp to GroupedWelfordOp auto wop = def->as(); @@ -1911,7 +1911,7 @@ void validateAndConvertIterDomainGrouping(Fusion* fusion) { {{wop->initAvg(), wop->initVar(), wop->initN()}}); fusion->removeExpr(wop); IrBuilder::createInContainer( - fusion, output_vals, input_vals, init_vals, is_allreduce); + fusion->container(), output_vals, input_vals, init_vals, is_allreduce); } } } diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 9951339fd90..eaa07221e2b 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -28,8 +28,7 @@ namespace nvfuser { DynamicTransformInitialInfo DynamicTransformInitialInfo::clone( IrCloner& ir_cloner) const { - DynamicTransformInitialInfo cloned_info( - static_cast(ir_cloner.container())); + DynamicTransformInitialInfo cloned_info(ir_cloner.container()->fusion()); cloned_info.dynamic_reshaped_tvs_.reserve(dynamic_reshaped_tvs_.size()); for (const auto tv : dynamic_reshaped_tvs_) { cloned_info.dynamic_reshaped_tvs_.push_back(ir_cloner.clone(tv)); diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index 5e983777b04..83d7bc18345 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -294,7 +294,7 @@ void PrecomputedValues::invalidate() { } PrecomputedValues PrecomputedValues::clone(IrCloner& ir_cloner) const { - PrecomputedValues pv(static_cast(ir_cloner.container())); + PrecomputedValues pv(ir_cloner.container()->fusion()); // this is a map to unique pointers to vectors, so we need to copy the // vectors and create new unique pointers diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 5c086599579..af014bf8ce6 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -106,12 +106,31 @@ void swap(Fusion& a, Fusion& b) noexcept { using std::swap; - swap(static_cast(a), static_cast(b)); + swap(a.container_, b.container_); + + // Update back-references after swapping containers + if (a.container_) { + a.container_->setOwningFusion(&a); + // Update all statements to point to the swapped container + a.container_->updateAllStatementContainerPointers(); + } + if (b.container_) { + b.container_->setOwningFusion(&b); + // Update all statements to point to the swapped container + b.container_->updateAllStatementContainerPointers(); + } swap(a.inputs_, b.inputs_); swap(a.outputs_, b.outputs_); swap(a.io_alias_, b.io_alias_); + + swap(a.all_tv_uses_valid_, b.all_tv_uses_valid_); + swap(a.is_during_update_uses_, b.is_during_update_uses_); + swap(a.managed_data_, b.managed_data_); + swap(a.managed_named_data_, b.managed_named_data_); + swap(a.expected_dynamic_smem_bytes_, b.expected_dynamic_smem_bytes_); + swap(a.all_tvs_ptr_, b.all_tvs_ptr_); } std::unique_ptr Fusion::segment( @@ -122,11 +141,16 @@ std::unique_ptr Fusion::segment( IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->clear(); - auto ir_cloner = IrContainer::copy(from, to); - - for (auto val : from->vals_) { - ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); - ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_)); + auto ir_cloner = IrContainer::copy(from->container(), to->container()); + + for (auto val : from->vals()) { + Val* cloned_val = ir_cloner.clone(val); + // Only set definition if not already set during Fusion::registerExpr() + // This avoids overwriting definitions that were properly set during cloning + if (cloned_val->definition() == nullptr && val->definition_ != nullptr) { + cloned_val->setDefinition(ir_cloner.clone(val->definition_)); + } + cloned_val->setUses(ir_cloner.clone(val->uses_)); } to->inputs_ = ir_cloner.clone(from->inputs_); @@ -183,17 +207,21 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { return ir_cloner; } -// Clang tidy complains when using default constructor for IrContainer instead -// of copy constructor. Fusion::copy has a call to IrContainer::copy, so it's -// redundant to use the IrContainer copy constructor, but it is harmless since -// Fusion::copy starts by calling clear(). -Fusion::Fusion(const Fusion& other) : IrContainer(other) { +Fusion::Fusion() : container_(std::make_unique()) { + container_->setOwningFusion(this); +} + +Fusion::Fusion(const Fusion& other) + : container_(std::make_unique()) { FUSER_PERF_SCOPE("Fusion copy"); + container_->setOwningFusion(this); Fusion::copy(&other, this); } -Fusion::Fusion(Fusion&& other) noexcept { +Fusion::Fusion(Fusion&& other) noexcept + : container_(std::make_unique()) { FUSER_PERF_SCOPE("Fusion move"); + container_->setOwningFusion(this); swap(*this, other); } @@ -222,7 +250,9 @@ void Fusion::clear() noexcept { // constructor of Trace, which could throw an exception. // FUSER_PERF_SCOPE("Fusion clear"); - IrContainer::clear(); + if (container_) { + container_->clear(); + } inputs_.clear(); outputs_.clear(); @@ -259,7 +289,7 @@ void Fusion::removeExpr(Expr* expr) { } } - IrContainer::removeExpr(expr); + container_->removeExpr(expr); } void Fusion::removeVal(Val* val) { @@ -285,7 +315,7 @@ void Fusion::removeVal(Val* val) { // caused a segfault when the fusion was cloned since that will clone not only // live objects but also these dangerous dangling dead ones. std::vector exprs_to_remove; - for (Expr* e : exprs_) { + for (Expr* e : unordered_exprs()) { if (!inContainer(e)) { continue; } @@ -298,7 +328,7 @@ void Fusion::removeVal(Val* val) { for (auto e : exprs_to_remove) { removeExpr(e); } - IrContainer::removeVal(val); + container_->removeVal(val); invalidateTvsAndUses(); } @@ -652,6 +682,14 @@ void Fusion::printTransforms() { t_exprs.handle(this); } +void Fusion::registerStmt(Statement* stmt) { + if (stmt->isVal()) { + registerVal(stmt->asVal()); + } else { + registerExpr(stmt->asExpr()); + } +} + void Fusion::registerVal(Val* val) { if (inContainer(val)) { return; @@ -662,7 +700,7 @@ void Fusion::registerVal(Val* val) { val->fusion() == this, val, " was not found in the active fusion."); } - IrContainer::registerVal(val); + container_->registerVal(IrBuilderPasskey(container()), val); } void Fusion::registerExpr(Expr* expr) { @@ -675,7 +713,7 @@ void Fusion::registerExpr(Expr* expr) { expr->fusion() == this, expr, " was not found in the active fusion."); } - IrContainer::registerExpr(expr); + container_->registerExpr(IrBuilderPasskey(container()), expr); for (Val* input : expr->inputs()) { assertInContainer(input, "Input to expr is invalid, "); @@ -696,7 +734,14 @@ void Fusion::registerExpr(Expr* expr) { for (Val* output : expr->outputs()) { assertInContainer(output, "Output to expr is invalid, "); if (output->definition() != nullptr && is_ssa) { - removeExpr(output->definition()); + // Only remove old definition if it belongs to THIS container + // During cloning, old definition might be from source container + if (inContainer(output->definition())) { + removeExpr(output->definition()); + } else { + // Old definition is from a different container - clear it + output->definition_ = nullptr; + } } if (is_ssa || output->definition() == nullptr) { output->setDefinition(expr); @@ -707,6 +752,13 @@ void Fusion::registerExpr(Expr* expr) { // vector after setDefinition. invalidateTvsAndUses(); } + } else { + // DEBUG: This branch means definition was not set! + // This happens in non-SSA contexts (Kernel/HostIrContainer) when + // output already has a definition and we don't overwrite it + if (output->isA() && !is_ssa) { + // Expected for non-SSA - multiple definitions allowed + } } } } @@ -718,7 +770,7 @@ void Fusion::resetTvUses() { // getExprs only uses definition, so even if we've modified uses already to // remove dead exprs, this could reinsert them. getExprs is also boundeds by // inputs as registered inputs will return nullptr as their definition. - const auto all_tvs = ir_utils::filterByType(vals_); + const auto all_tvs = ir_utils::filterByType(vals()); const auto used_exprs = StmtSort::getExprs(this); for (auto tv : all_tvs) { diff --git a/csrc/fusion.h b/csrc/fusion.h index b42ac8623dd..0ef0696ac2a 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -142,11 +142,11 @@ class AliasInfoMap { //! The Fusion owns the whole IR graph (Vals and Exprs) //! // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -class NVF_API Fusion : public IrContainer { +class NVF_API Fusion : public PolymorphicBase { typedef std::unordered_map> PermutationMap; public: - Fusion() = default; + Fusion(); Fusion(const Fusion& other); Fusion(Fusion&& other) noexcept; @@ -154,12 +154,105 @@ class NVF_API Fusion : public IrContainer { Fusion& operator=(const Fusion& other); Fusion& operator=(Fusion&& other) noexcept; - ~Fusion() override; + virtual ~Fusion(); friend void swap(Fusion& a, Fusion& b) noexcept; void clear() noexcept; + // Accessor for the underlying container + IrContainer* container() { + return container_.get(); + } + const IrContainer* container() const { + return container_.get(); + } + + // Forward IrContainer methods + bool inContainer(const Statement* stmt) const { + return container_->inContainer(stmt); + } + + void assertInContainer(const Statement* stmt, const std::string& msg) const { + container_->assertInContainer(stmt, msg); + } + + const std::deque deterministic_vals() const noexcept { + return container_->deterministic_vals(); + } + + const std::deque deterministic_exprs() const noexcept { + return container_->deterministic_exprs(); + } + + const std::unordered_map deterministic_vals_map() const noexcept { + return container_->deterministic_vals_map(); + } + + const std::unordered_map deterministic_exprs_map() const noexcept { + return container_->deterministic_exprs_map(); + } + + const std::unordered_set& unordered_exprs() const noexcept { + return container_->unordered_exprs(); + } + + const std::unordered_set& vals() const noexcept { + return container_->vals(); + } + + int64_t numExprs() const noexcept { + return container_->numExprs(); + } + + int64_t numVals(bool include_shortcuts) const noexcept { + return container_->numVals(include_shortcuts); + } + + Val* zeroVal() { + return container_->zeroVal(); + } + + Val* oneVal() { + return container_->oneVal(); + } + + Val* falseVal() { + return container_->falseVal(); + } + + Val* trueVal() { + return container_->trueVal(); + } + + NamedScalar* magicZeroVal() { + return container_->magicZeroVal(); + } + + Val* zeroVal(DataType dtype) { + return container_->zeroVal(dtype); + } + + Val* oneVal(DataType dtype) { + return container_->oneVal(dtype); + } + + Val* metadataOf(Val* val) { + return container_->metadataOf(val); + } + + const std::vector& axioms() { + return container_->axioms(); + } + + void assumePositive(Val* val) { + container_->assumePositive(val); + } + + void assumeNonNegative(Val* val) { + container_->assumeNonNegative(val); + } + // Hash the fusion. This is used to identify the fusion in the cache. size_t hash() const; @@ -168,11 +261,11 @@ class NVF_API Fusion : public IrContainer { //! Break dependency chains associated with Expr, remove references to expr //! delete expr - void removeExpr(Expr* expr) override; + virtual void removeExpr(Expr* expr); //! Completely remove val from the fusion, break all dependencies associated //! with it - void removeVal(Val* val) override; + virtual void removeVal(Val* val); //! Register input as an input of the fusion void addInput(Val* input); @@ -484,12 +577,13 @@ class NVF_API Fusion : public IrContainer { friend SegmentedFusion; friend class TranslateApplicableWelford; friend Val; + friend class IrBuilder; - using IrContainer::registerExpr; - using IrContainer::registerVal; + //! Register a statement (Val or Expr) with this fusion + virtual void registerStmt(Statement* stmt); //! Register the Val with this fusion - void registerVal(Val* val) override; + virtual void registerVal(Val* val); //! Register expr with this fusion. //! When we register an expression, we want to update the dependency tracking @@ -497,7 +591,7 @@ class NVF_API Fusion : public IrContainer { //! definitions of outputs and register this Expr as the definition. Otherwise //! will update definition if not previously set, but will not remove old //! definitions. - void registerExpr(Expr* expr) override; + virtual void registerExpr(Expr* expr); //! Clear Expr's from TV uses that are not required to produce outputs from //! inputs. Only other place this is used (other than Fusion) is in @@ -512,6 +606,9 @@ class NVF_API Fusion : public IrContainer { } private: + // Container that owns all IR nodes + std::unique_ptr container_; + // Fusion inputs and outputs std::vector inputs_; std::vector outputs_; diff --git a/csrc/grouped_reduction.cpp b/csrc/grouped_reduction.cpp index 699ca71ed91..e8c0067e713 100644 --- a/csrc/grouped_reduction.cpp +++ b/csrc/grouped_reduction.cpp @@ -56,7 +56,7 @@ bool validateReductionGrouping( NVF_ERROR(inputs.size() == outputs.size()); NVF_ERROR(!inputs.empty()); - auto fusion = dynamic_cast(outputs[0]->container()); + auto fusion = dynamic_cast(outputs[0]->container()->fusion()); NVF_ERROR( fusion != nullptr, "Grouping of reductions must be done within a Fusion"); diff --git a/csrc/host_ir/container.cpp b/csrc/host_ir/container.cpp index 71888e04200..89619815460 100644 --- a/csrc/host_ir/container.cpp +++ b/csrc/host_ir/container.cpp @@ -23,7 +23,7 @@ namespace hir { Stream* HostIrContainer::getDefaultStream() { if (default_stream_ == nullptr) { - default_stream_ = IrBuilder::createInContainer(this); + default_stream_ = IrBuilder::createInContainer(this->container()); } return default_stream_; } diff --git a/csrc/host_ir/ir.cpp b/csrc/host_ir/ir.cpp index 5b46b7f0abb..cd1818f6725 100644 --- a/csrc/host_ir/ir.cpp +++ b/csrc/host_ir/ir.cpp @@ -26,7 +26,7 @@ namespace nvfuser::hir { HostUnit::HostUnit(IrBuilderPasskey passkey, std::unique_ptr fusion) : Expr(passkey), fusion_(std::make_unique(*fusion)) { NVF_ERROR(passkey.ir_container_ != nullptr); - NVF_ERROR(passkey.ir_container_->isA()); + NVF_ERROR(passkey.ir_container_->fusion()->isA()); } HostUnit::HostUnit(const HostUnit* src, IrCloner* ir_cloner) @@ -72,7 +72,7 @@ PostOnStream::PostOnStream( : Expr(passkey, std::move(inputs), std::move(outputs), {host_op}) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), this, "must be registered in a HostIrContainer"); NVF_ERROR( @@ -223,7 +223,7 @@ bool Stream::sameAs(const Statement* other) const { SetCurrentStream::SetCurrentStream(IrBuilderPasskey passkey, Stream* stream) : Expr(passkey, {stream}, {}, {stream}) { NVF_ERROR(passkey.ir_container_ != nullptr); - NVF_ERROR(passkey.ir_container_->isA()); + NVF_ERROR(passkey.ir_container_->fusion()->isA()); } NVFUSER_DEFINE_CLONE_AND_CREATE(SetCurrentStream) @@ -247,7 +247,7 @@ bool SetCurrentStream::sameAs(const Statement* other) const { GetCurrentStream::GetCurrentStream(IrBuilderPasskey passkey) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); - NVF_ERROR(passkey.ir_container_->isA()); + NVF_ERROR(passkey.ir_container_->fusion()->isA()); auto stream = IrBuilder::createInContainer(passkey.ir_container_); addAttribute(stream); } @@ -265,7 +265,7 @@ Wait::Wait(IrBuilderPasskey passkey, Expr* expr) : Expr(passkey, {}, {}, {expr}) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), this, "must be registered in a HostIrContainer"); NVF_ERROR( @@ -297,7 +297,7 @@ Synchronize::Synchronize(IrBuilderPasskey passkey, Stream* stream) : Expr(passkey, {}, {}, {stream}) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), this, "must be registered in a HostIrContainer"); } @@ -322,7 +322,7 @@ bool Synchronize::sameAs(const Statement* other) const { StartCoalescing::StartCoalescing(IrBuilderPasskey passkey) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), this, "must be registered in a HostIrContainer"); } @@ -342,7 +342,7 @@ std::string StartCoalescing::toInlineString(int indent_size) const { EndCoalescing::EndCoalescing(IrBuilderPasskey passkey) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), this, "must be registered in a HostIrContainer"); } @@ -365,7 +365,7 @@ ShareMemHandles::ShareMemHandles( : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), this, "must be registered in a HostIrContainer"); addDataAttribute(std::move(communications)); @@ -396,7 +396,7 @@ HirAliasSelect::HirAliasSelect( : Expr(passkey, {in, index}, {}, {}) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), this, "must be registered in a HostIrContainer"); NVF_ERROR( @@ -439,7 +439,7 @@ ShardByStream::ShardByStream( Val* stream_index) : Expr(passkey, {in, stream_index}, {out}, {}) { NVF_ERROR(passkey.ir_container_ != nullptr); - NVF_ERROR(passkey.ir_container_->isA()); + NVF_ERROR(passkey.ir_container_->fusion()->isA()); NVF_ERROR_EQ( TensorDomain::noReductions(in->getLogicalDomain()).size(), out->getLogicalDomain().size()); @@ -466,7 +466,7 @@ SymmetricContiguousView::SymmetricContiguousView( TensorView* in) : Expr(passkey, {in}, {out}, {}) { NVF_ERROR(passkey.ir_container_ != nullptr); - NVF_ERROR(passkey.ir_container_->isA()); + NVF_ERROR(passkey.ir_container_->fusion()->isA()); NVF_ERROR( in->getMemoryType() == MemoryType::Symmetric, "Input tensor must have symmetric memory type, got: ", @@ -489,7 +489,7 @@ std::string SymmetricContiguousView::toInlineString(int indent_size) const { ForLoop::ForLoop(IrBuilderPasskey passkey, Val* index, Val* start, Val* stop) : Expr(passkey, {index, start, stop}, {}, {}) { NVF_ERROR(passkey.ir_container_ != nullptr); - NVF_ERROR(passkey.ir_container_->isA()); + NVF_ERROR(passkey.ir_container_->fusion()->isA()); addDataAttribute(Scope(this)); } diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 6c29c472d89..d5c69bf5b50 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -123,7 +123,7 @@ std::unique_ptr HostIrLower::lower( // the segmented fusion will be translated to a HostIR auto hic = std::make_unique(); FusionGuard fg(hic.get()); - IrCloner ir_cloner(hic.get()); + IrCloner ir_cloner(hic->container()); auto clone = [&ir_cloner](const std::vector& vals) -> std::vector { std::vector cloned_vals(vals.size()); diff --git a/csrc/id_model/schedule.cpp b/csrc/id_model/schedule.cpp index a5867598259..5745ea06309 100644 --- a/csrc/id_model/schedule.cpp +++ b/csrc/id_model/schedule.cpp @@ -131,7 +131,7 @@ std::pair split( graph, g, IrBuilder::createInContainer( - g->front()->fusion(), factor, DataType::Index), + g->front()->fusion()->container(), factor, DataType::Index), inner_split); } diff --git a/csrc/ir/base_nodes.cpp b/csrc/ir/base_nodes.cpp index 974f2f109a9..9464e04a54f 100644 --- a/csrc/ir/base_nodes.cpp +++ b/csrc/ir/base_nodes.cpp @@ -71,16 +71,17 @@ std::string Statement::toInlineString(int indent_size) const { } Fusion* Statement::fusion() const { - NVF_ERROR( - ir_container_->isA(), "Statement does not belong to a fusion."); - return ir_container_->as(); + Fusion* fusion = ir_container_->fusion(); + NVF_ERROR(fusion != nullptr, "Statement does not belong to a fusion."); + return fusion; } kir::Kernel* Statement::kernel() const { + Fusion* fusion = ir_container_->fusion(); NVF_ERROR( - ir_container_->isA(), + fusion != nullptr && fusion->isA(), "Statement does not belong to a kernel."); - return ir_container_->as(); + return fusion->as(); } NVFUSER_DEFINE_CLONE(Val) @@ -364,7 +365,7 @@ Expr::Expr( Expr* Expr::shallowCopy() const { auto result = newObjectFunc()(ir_container_, inputs(), outputs(), attributes()); - if (container()->isA()) { + if (container()->fusion() != nullptr && container()->fusion()->isA()) { result->predicate_ = predicate_; result->write_predicate_ = write_predicate_; } @@ -483,14 +484,16 @@ bool Expr::sameAs(const Statement* other) const { kir::Predicate* Expr::predicate() const { NVF_ERROR( - (container()->isOneOf()), + container()->fusion() != nullptr && + (container()->fusion()->isOneOf()), "Function invalid for fusion."); return predicate_; } void Expr::setPredicate(kir::Predicate* predicate) { NVF_ERROR( - (container()->isOneOf()), + container()->fusion() != nullptr && + (container()->fusion()->isOneOf()), "Function invalid for fusion."); predicate_ = predicate; } @@ -502,12 +505,12 @@ Expr* Expr::withPredicate(kir::Predicate* predicate) { } kir::Predicate* Expr::writePredicate() const { - NVF_ERROR(container()->isA(), "Function invalid for fusion."); + NVF_ERROR(container()->fusion() != nullptr && container()->fusion()->isA(), "Function invalid for fusion."); return write_predicate_; } void Expr::setWritePredicate(kir::Predicate* write_predicate) { - NVF_ERROR(container()->isA(), "Function invalid for fusion."); + NVF_ERROR(container()->fusion() != nullptr && container()->fusion()->isA(), "Function invalid for fusion."); write_predicate_ = write_predicate; } diff --git a/csrc/ir/base_nodes.h b/csrc/ir/base_nodes.h index 232d7f4fade..9fc076b2c01 100644 --- a/csrc/ir/base_nodes.h +++ b/csrc/ir/base_nodes.h @@ -95,6 +95,7 @@ class ExprPasskey { class NVF_API Statement : public NonCopyable, public PolymorphicBase { friend void swap(Fusion&, Fusion&) noexcept; friend void swap(IrContainer& a, IrContainer& b) noexcept; + friend class IrContainer; public: Statement() = delete; diff --git a/csrc/ir/builder.cpp b/csrc/ir/builder.cpp index 45ab743b22a..e4ce600353b 100644 --- a/csrc/ir/builder.cpp +++ b/csrc/ir/builder.cpp @@ -20,6 +20,23 @@ namespace nvfuser { +IrContainer* IrBuilder::getActiveContainer() { + Fusion* fusion = FusionGuard::getCurFusion(); + NVF_ERROR(fusion != nullptr, "Need an active fusion to build IR."); + return fusion->container(); +} + +void IrBuilder::registerWithContainer(IrContainer* container, Statement* stmt) { + // If the container belongs to a Fusion, register through the Fusion + // to ensure proper definition tracking and other Fusion-level bookkeeping + if (container->fusion() != nullptr) { + container->fusion()->registerStmt(stmt); + } else { + // For standalone containers, register directly + container->registerStmt(IrBuilderPasskey(container), stmt); + } +} + Val* IrBuilder::newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs) { NVF_CHECK( lhs != nullptr && rhs != nullptr, diff --git a/csrc/ir/builder.h b/csrc/ir/builder.h index 6de18ce6c78..79baee3bc9f 100644 --- a/csrc/ir/builder.h +++ b/csrc/ir/builder.h @@ -29,15 +29,21 @@ class Val; //! IR builder interface class IrBuilder { + private: + // Helper to get container from current fusion + static IrContainer* getActiveContainer(); + public: //! Allocate a new IR node, forwarding the arguments to the appropriate //! constructor and registering with the container template static T* create(Args&&... args) { - Fusion* fusion = FusionGuard::getCurFusion(); - return createInContainer(fusion, std::forward(args)...); + return createInContainer(getActiveContainer(), std::forward(args)...); } + // Helper to register a statement with the appropriate container/fusion + static void registerWithContainer(IrContainer* container, Statement* stmt); + //! Allocate a new IR node, forwarding the arguments to the appropriate //! constructor and registering with the container template @@ -45,7 +51,7 @@ class IrBuilder { NVF_ERROR(container != nullptr, "Need an active container to build IR."); T* node = new T(IrBuilderPasskey(container), std::forward(args)...); - container->registerStmt(IrBuilderPasskey(container), node); + registerWithContainer(container, node); return node; } diff --git a/csrc/ir/builder_passkey.h b/csrc/ir/builder_passkey.h index 9feb26fb010..daed35484f4 100644 --- a/csrc/ir/builder_passkey.h +++ b/csrc/ir/builder_passkey.h @@ -15,6 +15,7 @@ class IrContainer; // functions in IrContainer class IrBuilderPasskey { friend class IrBuilder; + friend class Fusion; public: // TODO: Collapse ir_container and Kernel once Kernel inherits from diff --git a/csrc/ir/cloner.cpp b/csrc/ir/cloner.cpp index 1cdf98fc08d..f1f960c2257 100644 --- a/csrc/ir/cloner.cpp +++ b/csrc/ir/cloner.cpp @@ -90,7 +90,7 @@ TensorView* RecomputeTv::recompute( return cloned_val->as(); } -RecomputeTv::RecomputeTv(Fusion* fusion) : IrCloner(fusion) { +RecomputeTv::RecomputeTv(Fusion* fusion) : IrCloner(fusion->container()) { // Add inputs to the clones map to prevent cloning them. for (const auto inp : fusion->inputs()) { clones_map_[inp] = inp; diff --git a/csrc/ir/cloner.h b/csrc/ir/cloner.h index 14f15323ee0..553843147fb 100644 --- a/csrc/ir/cloner.h +++ b/csrc/ir/cloner.h @@ -195,7 +195,9 @@ T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) { auto dest_container = ir_cloner->container(); auto src_container = src_stmt->container(); - dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt); + // Use the same registration helper as createInContainer to ensure + // proper registration through Fusion if one exists + registerWithContainer(dest_container, dest_stmt); if (src_container != dest_container) { dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name()); diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index e3b33b26479..cc5fcd389eb 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -5,12 +5,15 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include +#include #include #include #include #include #include #include +#include namespace nvfuser { @@ -57,14 +60,16 @@ IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { // that are not registered in the container. for (auto val : from->deterministic_vals()) { if (from->vals().count(val) > 0) { - to->vals_.insert(ir_cloner.clone(val)); + // Clone - registration happens inside clone() via registerWithContainer() + ir_cloner.clone(val); } } // Copy expressions in deterministic order for (auto expr : from->deterministic_exprs()) { if (from->unordered_exprs().count(expr) > 0) { - to->exprs_.insert(ir_cloner.clone(expr)); + // Clone - registration happens inside clone() via registerWithContainer() + ir_cloner.clone(expr); } } @@ -241,6 +246,20 @@ bool IrContainer::inContainer(const Statement* const_stmt) const { return true; } +void IrContainer::updateAllStatementContainerPointers() { + // Update all Val pointers + for (auto* val : vals_) { + // Access the protected ir_container_ member through Statement base class + const_cast(static_cast(val))->ir_container_ = + this; + } + // Update all Expr pointers + for (auto* expr : exprs_) { + const_cast(static_cast(expr))->ir_container_ = + this; + } +} + // Shortcuts for frequently used vals Val* IrContainer::zeroVal() { if (!zero_val_) { diff --git a/csrc/ir/container.h b/csrc/ir/container.h index 02ad0572c9a..f8645958020 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -32,7 +32,7 @@ class IrContainerPasskey { explicit IrContainerPasskey() = default; }; -class IrContainer : public PolymorphicBase { +class IrContainer { public: NVF_API IrContainer(); @@ -42,7 +42,7 @@ class IrContainer : public PolymorphicBase { IrContainer& operator=(const IrContainer& other); IrContainer& operator=(IrContainer&& other) noexcept; - ~IrContainer() override; + ~IrContainer(); bool inContainer(const Statement* stmt) const; @@ -51,6 +51,10 @@ class IrContainer : public PolymorphicBase { inContainer(stmt), msg, " it was not found in the active container."); } + //! Update all statements in this container to point to this container + //! This is needed after swapping containers between Fusions + void updateAllStatementContainerPointers(); + //! Return values in insertion order const std::deque deterministic_vals() const noexcept { std::deque vals_deque; @@ -104,13 +108,13 @@ class IrContainer : public PolymorphicBase { } //! Register the Statement with this container - NVF_API virtual void registerStmt(IrBuilderPasskey, Statement* stmt); + NVF_API void registerStmt(IrBuilderPasskey, Statement* stmt); //! Register the Val with this container - NVF_API virtual void registerVal(IrBuilderPasskey, Val* val); + NVF_API void registerVal(IrBuilderPasskey, Val* val); //! Register expr with this container. - NVF_API virtual void registerExpr(IrBuilderPasskey, Expr* expr); + NVF_API void registerExpr(IrBuilderPasskey, Expr* expr); //! Return the set of Exprs registered with this fusion. Warning: This will //! return exprs outside inputs/outputs, so can be unsafe for use with @@ -152,6 +156,17 @@ class IrContainer : public PolymorphicBase { void assumePositive(Val* val); void assumeNonNegative(Val* val); + // Get the owning Fusion (if this container is owned by a Fusion) + // Returns nullptr for standalone containers + Fusion* fusion() const { + return owning_fusion_; + } + + // Set the owning fusion - should only be called by Fusion + void setOwningFusion(Fusion* fusion) { + owning_fusion_ = fusion; + } + protected: static IrCloner copy(const IrContainer* from, IrContainer* to); @@ -160,17 +175,20 @@ class IrContainer : public PolymorphicBase { // Let mutator remove Exprs. friend OptOutMutator; - virtual void removeExpr(Expr* expr); + // Allow Fusion to access protected members since it now uses composition + friend class Fusion; + + void removeExpr(Expr* expr); //! Completely remove val from the fusion, break all dependencies associated //! with it - virtual void removeVal(Val* val); + void removeVal(Val* val); //! Register the Val with this container - virtual void registerVal(Val* val); + void registerVal(Val* val); //! Register expr with this container. - virtual void registerExpr(Expr* expr); + void registerExpr(Expr* expr); StmtNameType getValName(ValType vtype) { if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) { @@ -234,6 +252,10 @@ class IrContainer : public PolymorphicBase { std::unique_ptr magic_zero_val_; std::unique_ptr> axioms_; std::unordered_map> metadata_; + + // Back-reference to owning Fusion (if any) for composition pattern + // This is set by Fusion when it creates/owns this container + Fusion* owning_fusion_ = nullptr; }; } // namespace nvfuser diff --git a/csrc/ir/internal_nodes.cpp b/csrc/ir/internal_nodes.cpp index 7d7a318545c..9b0f551d886 100644 --- a/csrc/ir/internal_nodes.cpp +++ b/csrc/ir/internal_nodes.cpp @@ -2518,7 +2518,7 @@ std::string LoadStoreOp::toString(int indent_size) const { indent(ss, indent_size + 1) << " = " << optype << modifier << "( " << in()->toString(); // Fusion IR does not have predicate - if (container()->isA() && predicate() != nullptr) { + if (container()->fusion()->isA() && predicate() != nullptr) { ss << ", " << std::endl; indent(ss, indent_size + 1) << std::string(optype.size() + 5, ' ') << predicate()->toInlineString(); @@ -3005,7 +3005,7 @@ CatOp::CatOp( passkey.ir_container_ != nullptr, "IrContainer must be provided to create a CatOp."); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "Should only be used for Kernel container."); addOutput(out); @@ -3037,7 +3037,7 @@ std::string CatOp::toInlineString(int indent_size) const { Val* CatOp::getConcatenatedDomainIndex() const { NVF_ERROR( - container()->isA(), + container()->fusion()->isA(), "Should only be used for Kernel container."); NVF_ERROR(!attributes().empty(), "No attribute found"); NVF_ERROR(attribute(1) != nullptr, "nulllptr attribute is invalid"); @@ -3047,7 +3047,7 @@ Val* CatOp::getConcatenatedDomainIndex() const { Val* CatOp::getPred(int input_idx) const { NVF_ERROR( - container()->isA(), + container()->fusion()->isA(), "Should only be used for Kernel container."); const auto num_input_tensors = static_cast(inputs().size()); NVF_ERROR(input_idx < num_input_tensors, "Invalid input index: ", input_idx); diff --git a/csrc/kernel.h b/csrc/kernel.h index c9de9084cc7..4a644e39441 100644 --- a/csrc/kernel.h +++ b/csrc/kernel.h @@ -280,9 +280,6 @@ class NVF_API Kernel final : public Fusion { } protected: - using IrContainer::registerExpr; - using IrContainer::registerVal; - //! Register the Val with this fusion void registerVal(Val* val) override; diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index b0aae31f75d..72132121e48 100644 --- a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -54,8 +54,8 @@ ForLoop::ForLoop( : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA() || - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA() || + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel or Host container."); NVF_ERROR(isIntegralType(index->dtype())); addInput(index); @@ -312,7 +312,7 @@ namespace { class RuntimeReductionFinder : kir::ConstIrVisitor { public: static bool exists(const Expr* expr) { - NVF_CHECK(expr->container()->isA()); + NVF_CHECK(expr->container()->fusion()->isA()); RuntimeReductionFinder finder; finder.handle(std::vector{expr}); return finder.is_found_; @@ -356,7 +356,7 @@ Predicate::Predicate( thread_pred_(thread_pred) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); NVF_ERROR(ptype != PredicateType::Unswitch && ptype != PredicateType::Manual); } @@ -372,7 +372,7 @@ Predicate::Predicate( tma_1d_load_loops_(std::move(tma_1d_load_loops)) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); NVF_ERROR(ptype == PredicateType::OneDimTmaLoadExpectArrive); NVF_ERROR(!tma_1d_load_loops_.empty()); @@ -384,7 +384,7 @@ Predicate::Predicate(IrBuilderPasskey passkey, ForLoop* unrolled_loop) unrolled_loop_(unrolled_loop) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); NVF_ERROR(unrolled_loop != nullptr); } @@ -395,7 +395,7 @@ Predicate::Predicate(IrBuilderPasskey passkey, Val* value) value_(value) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - (passkey.ir_container_->isOneOf()), + (passkey.ir_container_->fusion()->isOneOf()), "IR type only valid for Kernel or HostIr container."); NVF_ERROR(value != nullptr); } @@ -425,7 +425,7 @@ TensorIndex::TensorIndex( index_(index) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); auto uint16x2 = ArrayType{std::make_shared(DataType::UInt16), 2}; NVF_ERROR( @@ -478,7 +478,7 @@ Allocate::Allocate( : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - (passkey.ir_container_->isOneOf()), + (passkey.ir_container_->fusion()->isOneOf()), "IR type only valid for Kernel or HostIr container."); if (!shape.empty()) { NVF_ERROR( @@ -854,7 +854,7 @@ AllocTMem::AllocTMem(IrBuilderPasskey passkey, Val* address, Val* num_columns) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); NVF_ERROR( ir_utils::getTv(address)->getMemoryType() == MemoryType::Shared, @@ -886,7 +886,7 @@ BlockSync::BlockSync( : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); addDataAttribute(war_sync); addDataAttribute(optional_compute_or_load_sync); @@ -911,7 +911,7 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(BlockSync) ClusterSync::ClusterSync(IrBuilderPasskey passkey) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); } @@ -953,7 +953,7 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(GridSync) FenceAsyncProxy::FenceAsyncProxy(IrBuilderPasskey passkey) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); } @@ -970,7 +970,7 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(FenceAsyncProxy) WgMmaFence::WgMmaFence(IrBuilderPasskey passkey) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); } @@ -991,7 +991,7 @@ SetMaxNReg::SetMaxNReg( : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); addInput(number_of_registers); addDataAttribute(increase_registers); @@ -1016,7 +1016,7 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(SetMaxNReg) Continue::Continue(IrBuilderPasskey passkey) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); } @@ -1035,7 +1035,7 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(Continue) Return::Return(IrBuilderPasskey passkey) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); } @@ -1247,7 +1247,7 @@ AsyncWait::AsyncWait( : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); addDataAttribute(async_op_type); addDataAttribute(keep_stages); @@ -1298,7 +1298,7 @@ AsyncCommit::AsyncCommit(IrBuilderPasskey passkey, AsyncOpType async_op_type) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); addDataAttribute(async_op_type); } @@ -1343,7 +1343,7 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(AsyncCommit) InitMagicZero::InitMagicZero(IrBuilderPasskey passkey) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); } @@ -1362,7 +1362,7 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(InitMagicZero) UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); } @@ -1420,7 +1420,7 @@ GridReduction::GridReduction( : ReductionOp(passkey, reduction_op_type, init, out, in, is_allreduce) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); NVF_ERROR( attributes().size() == num_reduction_op_attr, @@ -1503,7 +1503,7 @@ GroupedGridReduction::GroupedGridReduction( is_allreduce) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); NVF_ERROR( attributes().size() == numGroupedReductionOpAttr(), @@ -1571,7 +1571,7 @@ GridBroadcast::GridBroadcast( : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); addAttribute(broadcast_op); addAttribute(broadcast_buffer); @@ -1609,7 +1609,7 @@ GridWelford::GridWelford( : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); addAttribute(welford_op); addAttribute(var_buffer); @@ -1708,7 +1708,7 @@ GroupedGridWelford::GroupedGridWelford( is_allreduce) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); NVF_ERROR( attributes().size() == numGroupedWelfordOpAttr(), @@ -1737,7 +1737,7 @@ int64_t GroupedGridWelford::getSmemBufferSize( int64_t bdimy, int64_t bdimz) const { auto out_tv = ir_utils::getTvOutput(this); - auto kernel = dynamic_cast(container()); + auto kernel = container()->fusion()->as(); NVF_ERROR(kernel != nullptr); // By default, the required size is the same as the normal Welford reduction @@ -1842,7 +1842,7 @@ VectorizedWelfordOp::VectorizedWelfordOp( : WelfordOp(passkey, output, input, init, false) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); addAttribute(count); addAttribute(reciprocal_of_count); @@ -1857,7 +1857,7 @@ AllocateFusedReduction::AllocateFusedReduction( : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); addAttribute(grid_expr); } @@ -2039,7 +2039,7 @@ RNGOp::RNGOp( : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); NVF_ERROR(out->isA()); NVF_ERROR(rng_result->isA()); @@ -2105,7 +2105,7 @@ ClusterReductionOp::ClusterReductionOp( is_all_reduce) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); addInput(mbarrier); } @@ -2136,7 +2136,7 @@ GroupedLoadStoreOp::GroupedLoadStoreOp( int64_t group_size) : Expr(passkey) { NVF_ERROR( - passkey.ir_container_->isA(), + passkey.ir_container_->fusion()->isA(), "IR type only valid for Kernel container."); NVF_ERROR( in->isScalar(), "Expected to have a scalar input: ", in->toString()); diff --git a/csrc/multidevice/resharding.cpp b/csrc/multidevice/resharding.cpp index 44d290ffbbe..e4359743e0f 100644 --- a/csrc/multidevice/resharding.cpp +++ b/csrc/multidevice/resharding.cpp @@ -354,7 +354,7 @@ bool haveDifferentShardings( bool isResharding(const Expr* expr) { FUSER_PERF_SCOPE("isResharding"); - if (!ir_utils::isTvOp(expr)) { + if (expr == nullptr || !ir_utils::isTvOp(expr)) { return false; } diff --git a/csrc/preseg_passes/move_gather.cpp b/csrc/preseg_passes/move_gather.cpp index 4bf10b6ede1..a2b41ea74ab 100644 --- a/csrc/preseg_passes/move_gather.cpp +++ b/csrc/preseg_passes/move_gather.cpp @@ -101,7 +101,7 @@ TensorView* addPostGatherUnary( GatherOp* old_gather, GatherOp* new_gather, Expr* def) { - IrCloner ir_cloner(fusion); + IrCloner ir_cloner(fusion->container()); auto cloned_def = static_cast(def->clone(&ir_cloner)); diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index 45e66fbffc9..57aed2c0fb4 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -60,7 +60,7 @@ class LoopDomainSchedulerReplayTransform : OptInConstDispatch { NVF_ERROR(input_ids_.size() == 1); NVF_ERROR(output_ids_.size() == 2); replayed_expr_ = IrBuilder::createInContainer( - split->fusion(), + split->fusion()->container(), output_ids_[0], output_ids_[1], input_ids_[0], @@ -72,14 +72,14 @@ class LoopDomainSchedulerReplayTransform : OptInConstDispatch { NVF_ERROR(input_ids_.size() == 2); NVF_ERROR(output_ids_.size() == 1); replayed_expr_ = IrBuilder::createInContainer( - merge->fusion(), output_ids_[0], input_ids_[0], input_ids_[1]); + merge->fusion()->container(), output_ids_[0], input_ids_[0], input_ids_[1]); } void handle(const Resize* resize) final { NVF_ERROR(input_ids_.size() == 1); NVF_ERROR(output_ids_.size() == 1); replayed_expr_ = IrBuilder::createInContainer( - resize->fusion(), + resize->fusion()->container(), output_ids_[0], input_ids_[0], resize->leftExpand(), diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 134d514b9c1..6d3e44754e5 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -1235,6 +1235,7 @@ std::vector getReductionTvs(Fusion* fusion) { std::vector reduction_tvs; for (auto tv : all_tvs) { if (!tv->isFusionInput() && + tv->definition() != nullptr && std::any_of( tv->getLoopDomain().begin(), tv->getLoopDomain().end(), diff --git a/csrc/statement_guard.cpp b/csrc/statement_guard.cpp index 717ed11d8e2..2a3b49bf0f0 100644 --- a/csrc/statement_guard.cpp +++ b/csrc/statement_guard.cpp @@ -23,7 +23,7 @@ StatementGuard::StatementGuard(Fusion* fusion) prev_num_vals_(fusion_->numVals(/*include_shortcuts=*/false)) {} StatementGuard::~StatementGuard() { - fusion_->removeStatementsCreatedAfter(prev_num_exprs_, prev_num_vals_); + fusion_->container()->removeStatementsCreatedAfter(prev_num_exprs_, prev_num_vals_); } } // namespace nvfuser diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index 7e0d39c02bd..d60071fb8d6 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -165,7 +165,7 @@ void TensorView::inlineAt( bool best_effort, MaxPosCalculator* calc) { NVF_ERROR( - !container()->isA(), + !container()->fusion()->isA(), "Function invalid for kernel container."); std::unique_ptr calc_owner; @@ -272,7 +272,7 @@ TensorView* TensorView::computeAt( int64_t position, ComputeAtMode mode) { NVF_ERROR( - !container()->isA(), + !container()->fusion()->isA(), "Function invalid for kernel container."); // Make sure this and consumer are not the same tensor, that's illegal NVF_CHECK(!sameAs(consumer), "Cannot call this->computeAt(this, ...)"); @@ -301,7 +301,7 @@ TensorView* TensorView::computeAt( void TensorView::computeWith(int64_t pos, bool best_effort) { NVF_ERROR( - !container()->isA(), + !container()->fusion()->isA(), "Function invalid for kernel container."); if (isFusionInput()) { @@ -403,7 +403,7 @@ int64_t TensorView::getComputePosition(const TensorView* consumer) const { } bool TensorView::resolveComputeWith(const std::vector& sorted_exprs) { - NVF_ERROR(container()->isA(), "Function invalid for fusion."); + NVF_ERROR(container()->fusion()->isA(), "Function invalid for fusion."); auto siblings = ir_utils::filterByType(definition()->outputs()); @@ -452,7 +452,7 @@ bool TensorView::resolveComputeWith(const std::vector& sorted_exprs) { void TensorView::clearComputeWith() { //! This should be only used while in a Fusion container. NVF_ERROR( - !container()->isA(), + !container()->fusion()->isA(), "Function invalid for kernel container."); compute_with_pos_ = getComputeAtPosition(); @@ -619,7 +619,7 @@ TensorView* TensorView::flatten(int64_t from, int64_t to) { TensorView* TensorView::reorder( const std::unordered_map& old2new_) { NVF_ERROR( - !container()->isA(), + !container()->fusion()->isA(), "Function invalid for kernel container."); NVF_ERROR( !(nDims() == 0 && !old2new_.empty()), @@ -803,7 +803,7 @@ TensorView* TensorView::swizzle( TensorView* TensorView::rFactor(const std::vector& axes) { NVF_ERROR( - !container()->isA(), + !container()->fusion()->isA(), "Function invalid for kernel container."); // TODO: I think we should do this but // NVFuserTest.FusionSmemBlockGemmCache_CUDA prevents it from going in at the @@ -954,7 +954,7 @@ TensorView* TensorView::multiOutputRFactorHelper( TensorView* tv, const std::vector& axes) { NVF_ERROR( - !container()->isA(), + !container()->fusion()->isA(), "Function invalid for kernel container."); // Hack: // Semantically we should always keep the outputs of multi reduction ops @@ -986,7 +986,7 @@ std::vector TensorView::rFactor( const std::vector& axes, const std::vector& tvs) { NVF_CHECK( - !container()->isA(), + !container()->fusion()->isA(), "Function invalid for kernel container."); NVF_CHECK(nDims() > 0, "Tried to rFactor a 0-dim TensorView"); FusionGuard fg(fusion()); @@ -1090,7 +1090,7 @@ std::vector TensorView::rFactor( TensorView* TensorView::cacheBefore(LoadStoreOpType op_type) { NVF_ERROR( - !container()->isA(), + !container()->fusion()->isA(), "Function invalid for kernel container."); FusionGuard fg(fusion()); @@ -1186,7 +1186,7 @@ TensorView* TensorView::cacheBefore(LoadStoreOpType op_type) { TensorView* TensorView::cacheFork() { NVF_ERROR( - !container()->isA(), + !container()->fusion()->isA(), "Function invalid for kernel container."); FusionGuard fg(fusion()); @@ -1246,7 +1246,7 @@ TensorView* TensorView::cacheAfter( bool propagate_allocation_domain, std::vector cached_uses) { NVF_ERROR( - !container()->isA(), + !container()->fusion()->isA(), "Function invalid for kernel container."); FusionGuard fg(fusion()); diff --git a/tests/cpp/test_host_ir_evaluator.cpp b/tests/cpp/test_host_ir_evaluator.cpp index f7441b6cf70..73b6ed77e0f 100644 --- a/tests/cpp/test_host_ir_evaluator.cpp +++ b/tests/cpp/test_host_ir_evaluator.cpp @@ -51,7 +51,7 @@ TEST_F(HostIrEvaluatorTest, LaunchKernel) { ke->compile(&fusion, {in_tensor}); hic->addKernelExecutor(std::move(ke)); - IrCloner ir_cloner(hic.get()); + IrCloner ir_cloner(hic->container()); Val* in = ir_cloner.clone(fusion.inputs().at(0)); Val* out = ir_cloner.clone(fusion.outputs().at(0)); @@ -176,7 +176,7 @@ TEST_F(HostIrEvaluatorTest, AddInLoop) { auto hic = std::make_unique(); { FusionGuard fg(hic.get()); - IrCloner ir_cloner(hic.get()); + IrCloner ir_cloner(hic->container()); auto* in = ir_cloner.clone(fusion.inputs().at(0))->as(); auto* out = ir_cloner.clone(fusion.outputs().at(0))->as(); hic->addInput(in); diff --git a/tests/cpp/test_host_ir_jit.cpp b/tests/cpp/test_host_ir_jit.cpp index c766394b888..98241b17db3 100644 --- a/tests/cpp/test_host_ir_jit.cpp +++ b/tests/cpp/test_host_ir_jit.cpp @@ -285,7 +285,7 @@ TEST_F(HostIrJitTest, LaunchKernel) { hic->addKernelExecutor(std::move(ke)); - IrCloner ir_cloner(hic.get()); + IrCloner ir_cloner(hic->container()); auto hic_in = ir_cloner.clone(in); auto hic_out = ir_cloner.clone(out); diff --git a/tests/cpp/test_host_ir_stream_lowering.cpp b/tests/cpp/test_host_ir_stream_lowering.cpp index 7811352d01b..47fbb74cc34 100644 --- a/tests/cpp/test_host_ir_stream_lowering.cpp +++ b/tests/cpp/test_host_ir_stream_lowering.cpp @@ -479,7 +479,7 @@ TEST_F(HirLowerStreamTest, DoNotSupportPostOnStream) { auto host_unit = IrBuilder::create(get_fusion()); - IrCloner ir_cloner(hic.get()); + IrCloner ir_cloner(hic->container()); TensorView* input = ir_cloner.clone(host_unit->fusion_to_execute()->inputs().at(0)) ->as(); diff --git a/tests/cpp/test_host_irs.cpp b/tests/cpp/test_host_irs.cpp index f1d3dc21cc3..9711d98d5e0 100644 --- a/tests/cpp/test_host_irs.cpp +++ b/tests/cpp/test_host_irs.cpp @@ -102,7 +102,7 @@ TEST_P(HostIrTest, SingleFusion) { // [Step 4)] Create TensorViews representing the Fusion's I/O at the Host // level - IrCloner ir_cloner(hic.get()); + IrCloner ir_cloner(hic->container()); std::vector post_on_stream_inputs = { ir_cloner.clone(host_unit->fusion_to_execute()->inputs().at(0))}; std::vector post_on_stream_outputs = { @@ -185,7 +185,7 @@ TEST_P(HostIrTest, TwoFusions) { // [Step 4)a.] Create TensorViews representing the first Fusions I/O at the // Host level - IrCloner ir_cloner(hic.get()); + IrCloner ir_cloner(hic->container()); std::vector post_on_stream_inputs_0 = { ir_cloner.clone(host_unit_0->fusion_to_execute()->inputs().at(0))}; std::vector post_on_stream_outputs_0 = { @@ -290,7 +290,7 @@ TEST_P(HostIrTest, ThreeFusions) { // [Step 4)a.] Create TensorViews representing the first Fusions I/O at the // Host level - IrCloner ir_cloner(hic.get()); + IrCloner ir_cloner(hic->container()); auto clone = [&](std::vector vals) { std::vector ret; for (auto val : vals) { @@ -422,7 +422,7 @@ TEST_P(HostIrTest, ForLoops) { auto buffer_input = makeContigConcreteTensor({1}, DataType::Int); auto buffer_ouput = makeContigConcreteTensor({1}, DataType::Int); - IrCloner ir_cloner(hic.get()); + IrCloner ir_cloner(hic->container()); std::vector post_on_stream_inputs = {index, buffer_input}; std::vector post_on_stream_outputs = {buffer_ouput}; auto* host_unit = IrBuilder::create(std::move(fusion)); @@ -476,7 +476,7 @@ TEST_P(HostIrTest, PreAllocatedOutputs) { auto host_unit = IrBuilder::create(get_fusion()); - IrCloner ir_cloner(hic.get()); + IrCloner ir_cloner(hic->container()); std::vector post_on_stream_inputs = { ir_cloner.clone(host_unit->fusion_to_execute()->inputs().at(0))}; std::vector post_on_stream_outputs = { @@ -673,7 +673,7 @@ TEST_P(StreamHostIrTest, SingleFusionMultipleStreams) { // [Step 4)] Create TensorViews representing the Fusion's inputs at the Host // level - IrCloner ir_cloner_input(hic.get()); + IrCloner ir_cloner_input(hic->container()); std::vector post_on_stream_inputs = { ir_cloner_input.clone(host_unit->fusion_to_execute()->inputs().at(0))}; hic->addInput(post_on_stream_inputs.at(0)); @@ -681,7 +681,7 @@ TEST_P(StreamHostIrTest, SingleFusionMultipleStreams) { for (int i = 0; i < n_iterations; i++) { // [Step 4)] Create TensorViews representing the Fusion's ouputs at the Host // level - IrCloner ir_cloner_output(hic.get()); + IrCloner ir_cloner_output(hic->container()); std::vector post_on_stream_outputs = {ir_cloner_output.clone( host_unit->fusion_to_execute()->outputs().at(0))}; diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index 5898a0365b1..e91f93cd817 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -77,7 +77,7 @@ TEST_P(MultiDeviceHostIrTest, SingleFusionSingleComm) { auto hu = IrBuilder::create(std::move(fusion)); // [Step 4)] Create TensorViews at the Host level - IrCloner ir_cloner(hic.get()); + IrCloner ir_cloner(hic->container()); auto tv0 = ir_cloner.clone( hu->fusion_to_execute()->inputs().at(0)->as()); auto tv1 = ir_cloner.clone( @@ -173,7 +173,7 @@ TEST_P(MultiDeviceHostIrTest, SingleCommTwoFusionAndWait) { auto hu = IrBuilder::create(std::move(fusion)); // [Step 4)] Create TensorViews at the Host level - IrCloner ir_cloner(hic.get()); + IrCloner ir_cloner(hic->container()); auto tv0 = ir_cloner.clone( hu->fusion_to_execute()->inputs().at(0)->as()); auto tv1 = ir_cloner.clone(