diff --git a/csrc/ir/composite_nodes.cpp b/csrc/ir/composite_nodes.cpp index b3ec5e39c3a..282c683f8a2 100644 --- a/csrc/ir/composite_nodes.cpp +++ b/csrc/ir/composite_nodes.cpp @@ -488,67 +488,6 @@ std::vector SdpaFwdOp::evaluate( return {output, log_sumexp, philox_seed, philox_offset}; } -std::string Scope::toString(int indent_size) const { - std::stringstream ss; - for (auto expr : exprs()) { - ss << expr->toString(indent_size); - } - return ss.str(); -} - -Scope::Iterator Scope::insert(Iterator pos, Expr* expr) { - return exprs_.insert(pos, expr); -} - -Scope::Iterator Scope::insert_before(Expr* ref, Expr* expr) { - const auto it = std::find(exprs_.begin(), exprs_.end(), ref); - NVF_ERROR( - it != exprs_.end(), - "Tried to insert ", - expr, - " before the reference: ", - ref, - " @ ", - (size_t)ref, - " however the reference was not found in this scope."); - return insert(it, expr); -} - -Scope::Iterator Scope::insert_after(Expr* ref, Expr* expr) { - const auto it = std::find(exprs_.begin(), exprs_.end(), ref); - NVF_ERROR( - it != exprs_.end(), - "Tried to insert ", - expr, - " after the reference: ", - ref, - " however the reference was not found in this scope."); - auto insert_pos = std::next(it); - return insert(insert_pos, expr); -} - -void Scope::erase(Iterator pos) { - // Remove the scope of the expr if this is the scope - [[maybe_unused]] auto expr = *pos; - exprs_.erase(pos); -} - -void Scope::erase(Expr* ref) { - const auto it = std::find(exprs_.begin(), exprs_.end(), ref); - if (it != exprs_.end()) { - erase(it); - } -} - -bool Scope::contains(Expr* expr) const { - const auto it = std::find(exprs_.begin(), exprs_.end(), expr); - return it != exprs_.end(); -} - -void Scope::clear() { - exprs_.clear(); -} - SdpaBwdOp::SdpaBwdOp( IrBuilderPasskey passkey, TensorView* grad_query, diff --git a/csrc/ir/composite_nodes.h b/csrc/ir/composite_nodes.h index 13bf54de515..5a40ac4b9b3 100644 --- a/csrc/ir/composite_nodes.h +++ b/csrc/ir/composite_nodes.h @@ -7,8 +7,6 @@ // clang-format on #pragma once -#include - #include #include #include @@ -223,73 +221,6 @@ class SdpaFwdOp : public Expr { const std::vector& inputs) const override; }; -class Scope { - public: - using ExprList = std::list; - using Iterator = ExprList::const_iterator; - - explicit Scope(Expr* owner) : owner_(owner) {} - - std::string toString(int indent_size = 0) const; - - const ExprList& exprs() const { - return exprs_; - } - - // Used only by MultiDeviceExecutor. Should generally be avoided in favor of - // other modifying methods. - ExprList& mutableExprs() { - return exprs_; - } - - Expr* front() const { - NVF_ERROR( - !exprs_.empty(), "Attempting to access the front of an empty Scope"); - return exprs_.front(); - } - - Expr* back() const { - NVF_ERROR( - !exprs_.empty(), "Attempting to access the back of an empty Scope"); - return exprs_.back(); - } - - bool empty() const { - return exprs_.empty(); - } - - int64_t size() const { - return std::ssize(exprs_); - } - - Iterator insert(Iterator pos, Expr* expr); - - Iterator pushBack(Expr* e) { - return insert(exprs_.end(), e); - } - - void clear(); - - Expr* owner() const { - return owner_; - } - - // The following methods perform linear searches over exprs_. Use them only - // when necessary, as they do not scale well with large scopes. - Iterator insert_before(Expr* ref, Expr* expr); - Iterator insert_after(Expr* ref, Expr* expr); - void erase(Expr* ref); - bool contains(Expr* expr) const; - - private: - void erase(Iterator pos); - - ExprList exprs_; - - //! Owner exprssion of this scope, e.g., IfThenElse - Expr* owner_ = nullptr; -}; - // SDPA bwd node with same functionality // at::_scaled_dot_product_flash_attention_backward // grad_query = [N, H, L, E] diff --git a/csrc/ir/internal_nodes.cpp b/csrc/ir/internal_nodes.cpp index 35208c26f06..9eb51804116 100644 --- a/csrc/ir/internal_nodes.cpp +++ b/csrc/ir/internal_nodes.cpp @@ -38,6 +38,69 @@ namespace nvfuser { +std::string Scope::toString(int indent_size) const { + std::stringstream ss; + for (auto expr : exprs()) { + ss << expr->toString(indent_size); + } + return ss.str(); +} + +Scope::Iterator Scope::insert(Iterator pos, Expr* expr) { + return exprs_.insert(pos, expr); +} + +Scope::Iterator Scope::insert_before(Expr* ref, Expr* expr) { + const auto it = std::find(exprs_.begin(), exprs_.end(), ref); + NVF_ERROR( + it != exprs_.end(), + "Tried to insert ", + expr, + " before the reference: ", + ref, + " @ ", + (size_t)ref, + " however the reference was not found in this scope."); + return insert(it, expr); +} + +Scope::Iterator Scope::insert_after(Expr* ref, Expr* expr) { + const auto it = std::find(exprs_.begin(), exprs_.end(), ref); + NVF_ERROR( + it != exprs_.end(), + "Tried to insert ", + expr, + " after the reference: ", + ref, + " @ ", + (size_t)ref, + " however the reference was not found in this scope."); + auto insert_pos = std::next(it); + return insert(insert_pos, expr); +} + +void Scope::erase(Iterator pos) { + // Remove the scope of the expr if this is the scope + [[maybe_unused]] auto expr = *pos; + exprs_.erase(pos); +} + +void Scope::erase(Expr* ref) { + const auto it = std::find(exprs_.begin(), exprs_.end(), ref); + if (it != exprs_.end()) { + erase(it); + } +} + +bool Scope::contains(Expr* expr) const { + const auto it = std::find(exprs_.begin(), exprs_.end(), expr); + return it != exprs_.end(); +} + +void Scope::clear() { + exprs_.clear(); +} + FullOp::FullOp(IrBuilderPasskey passkey, Val* out, Val* fill_value) : Expr(passkey) { if (out->isA()) { diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index 9e11e78fcc0..0856638e3a8 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -35,10 +35,76 @@ namespace nvfuser { class ViewTransform; -class Scope; class IrCloner; struct AnalyzeViewResult; +class Scope { + public: + using ExprList = std::list; + using Iterator = ExprList::const_iterator; + + explicit Scope(Expr* owner) : owner_(owner) {} + + std::string toString(int indent_size = 0) const; + + const ExprList& exprs() const { + return exprs_; + } + + // Used only by MultiDeviceExecutor. Should generally be avoided in favor of + // other modifying methods. + ExprList& mutableExprs() { + return exprs_; + } + + Expr* front() const { + NVF_ERROR( + !exprs_.empty(), "Attempting to access the front of an empty Scope"); + return exprs_.front(); + } + + Expr* back() const { + NVF_ERROR( + !exprs_.empty(), "Attempting to access the back of an empty Scope"); + return exprs_.back(); + } + + bool empty() const { + return exprs_.empty(); + } + + int64_t size() const { + return std::ssize(exprs_); + } + + Iterator insert(Iterator pos, Expr* expr); + + Iterator pushBack(Expr* e) { + return insert(exprs_.end(), e); + } + + void clear(); + + Expr* owner() const { + return owner_; + } + + // The following methods perform linear searches over exprs_. Use them only + // when necessary, as they do not scale well with large scopes. + Iterator insert_before(Expr* ref, Expr* expr); + Iterator insert_after(Expr* ref, Expr* expr); + void erase(Expr* ref); + bool contains(Expr* expr) const; + + private: + void erase(Iterator pos); + + ExprList exprs_; + + //! Owner exprssion of this scope, e.g., IfThenElse + Expr* owner_ = nullptr; +}; + class NVF_API FullOp : public Expr { public: using Expr::Expr;