diff --git a/csrc/host_ir/ir.cpp b/csrc/host_ir/ir.cpp index 5b46b7f0abb..2f943d4c9bc 100644 --- a/csrc/host_ir/ir.cpp +++ b/csrc/host_ir/ir.cpp @@ -158,10 +158,6 @@ std::string LaunchKernel::toString(int indent_size) const { return ss.str(); } -std::string LaunchKernel::toInlineString(int indent_size) const { - NVF_CHECK(false, "Can not be printed inline"); -} - Deallocate::Deallocate(IrBuilderPasskey passkey, TensorView* tv) : Expr(passkey) { addAttribute(tv); @@ -197,7 +193,10 @@ std::string Stream::toString(int indent_size) const { std::stringstream ss; indent(ss, indent_size) << "Stream "; if (index() == nullptr) { - ss << name(); + // HostIrEvaluator looks up streams by address when index is null. + // Print address as identifier. We used to print `name()` but that's often + // an integer which would be confusing/ambiguous with the index. + ss << static_cast(this); } else { ss << index()->toInlineString(); } @@ -221,7 +220,7 @@ bool Stream::sameAs(const Statement* other) const { } SetCurrentStream::SetCurrentStream(IrBuilderPasskey passkey, Stream* stream) - : Expr(passkey, {stream}, {}, {stream}) { + : Expr(passkey, {stream}, {}, {}) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR(passkey.ir_container_->isA()); } @@ -230,34 +229,23 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(SetCurrentStream) std::string SetCurrentStream::toString(int indent_size) const { std::stringstream ss; - indent(ss, indent_size) << "SetCurrentStream to " << stream()->toString() + indent(ss, indent_size) << "SetCurrentStream(" << stream()->toString() << ")" << std::endl; return ss.str(); } -// TODO: implement better ? -std::string SetCurrentStream::toInlineString(int indent_size) const { - NVF_CHECK(false, "Cannot be printed inline"); -} - -// TODO: implement -bool SetCurrentStream::sameAs(const Statement* other) const { - return false; -} - -GetCurrentStream::GetCurrentStream(IrBuilderPasskey passkey) : Expr(passkey) { +GetCurrentStream::GetCurrentStream(IrBuilderPasskey passkey, Stream* stream) + : Expr(passkey, {}, {stream}, {}) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR(passkey.ir_container_->isA()); - auto stream = IrBuilder::createInContainer(passkey.ir_container_); - addAttribute(stream); } NVFUSER_DEFINE_CLONE_AND_CREATE(GetCurrentStream) std::string GetCurrentStream::toString(int indent_size) const { std::stringstream ss; - indent(ss, indent_size) << "GetCurrentStream into " << stream()->toString() - << std::endl; + indent(ss, indent_size) << stream()->toInlineString() + << " = GetCurrentStream()" << std::endl; return ss.str(); } @@ -277,15 +265,14 @@ Wait::Wait(IrBuilderPasskey passkey, Expr* expr) NVFUSER_DEFINE_CLONE_AND_CREATE(Wait) std::string Wait::toString(int indent_size) const { - std::stringstream ss; - indent(ss, indent_size) << "Wait Communication " << communication()->name() - << std::endl; - return ss.str(); + return toInlineString(indent_size) + "\n"; } -// TODO: implement better ? std::string Wait::toInlineString(int indent_size) const { - NVF_CHECK(false, "Cannot be printed inline"); + std::stringstream ss; + indent(ss, indent_size) << "Wait(Communication " << communication()->name() + << ")"; + return ss.str(); } // TODO: implement @@ -294,7 +281,7 @@ bool Wait::sameAs(const Statement* other) const { } Synchronize::Synchronize(IrBuilderPasskey passkey, Stream* stream) - : Expr(passkey, {}, {}, {stream}) { + : Expr(passkey, {stream}, {}, {}) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( passkey.ir_container_->isA(), @@ -305,13 +292,13 @@ Synchronize::Synchronize(IrBuilderPasskey passkey, Stream* stream) NVFUSER_DEFINE_CLONE_AND_CREATE(Synchronize) std::string Synchronize::toString(int indent_size) const { - std::stringstream ss; - indent(ss, indent_size) << "Synchronize " << stream() << std::endl; - return ss.str(); + return toInlineString(indent_size) + "\n"; } std::string Synchronize::toInlineString(int indent_size) const { - NVF_CHECK(false, "Cannot be printed inline"); + std::stringstream ss; + indent(ss, indent_size) << "Synchronize(" << stream() << ")"; + return ss.str(); } // TODO: implement @@ -451,15 +438,11 @@ std::string ShardByStream::toString(int indent_size) const { std::stringstream ss; indent(ss, indent_size) << out()->toString() << " = ShardByStream(" << in()->toString() - << ", stream_index = " << stream_index()->toString() + << ", stream_index=" << stream_index()->toString() << ")" << std::endl; return ss.str(); } -std::string ShardByStream::toInlineString(int indent_size) const { - NVF_CHECK(false, "Cannot be printed inline"); -} - SymmetricContiguousView::SymmetricContiguousView( IrBuilderPasskey passkey, TensorView* out, diff --git a/csrc/host_ir/ir.h b/csrc/host_ir/ir.h index a3b36284e36..22681a75324 100644 --- a/csrc/host_ir/ir.h +++ b/csrc/host_ir/ir.h @@ -138,7 +138,6 @@ class LaunchKernel : public Expr { NVFUSER_DECLARE_CLONE_AND_CREATE std::string toString(int indent_size = 0) const override; - std::string toInlineString(int indent_size = 0) const override; const char* getOpString() const override { return "hir::LaunchKernel"; } @@ -226,22 +225,19 @@ class SetCurrentStream : public Expr { NVFUSER_DECLARE_CLONE_AND_CREATE std::string toString(int indent_size = 0) const override; - std::string toInlineString(int indent_size = 0) const override; const char* getOpString() const override { return "hir::SetCurrentStream"; } - bool sameAs(const Statement* other) const override; - Stream* stream() const { - return attributes_.at(0)->as(); + return inputs().at(0)->as(); } }; class GetCurrentStream : public Expr { public: using Expr::Expr; - GetCurrentStream(IrBuilderPasskey passkey); + GetCurrentStream(IrBuilderPasskey passkey, Stream* stream); GetCurrentStream(const GetCurrentStream& other) = delete; GetCurrentStream& operator=(const GetCurrentStream& other) = delete; @@ -256,7 +252,7 @@ class GetCurrentStream : public Expr { } Stream* stream() const { - return attributes_.at(0)->as(); + return outputs().at(0)->as(); } }; @@ -308,7 +304,7 @@ class Synchronize : public Expr { bool sameAs(const Statement* other) const override; Stream* stream() const { - return attributes_.at(0)->as(); + return inputs().at(0)->as(); } }; @@ -478,7 +474,6 @@ class ShardByStream : public Expr { NVFUSER_DECLARE_CLONE_AND_CREATE std::string toString(int indent_size = 0) const override; - std::string toInlineString(int indent_size = 0) const override; const char* getOpString() const override { return "hir::ShardByStream"; } diff --git a/csrc/host_ir/pass/stream_parallel_type.cpp b/csrc/host_ir/pass/stream_parallel_type.cpp index 8e82975bfc2..44d3712d08c 100644 --- a/csrc/host_ir/pass/stream_parallel_type.cpp +++ b/csrc/host_ir/pass/stream_parallel_type.cpp @@ -600,8 +600,9 @@ std::list addStreamManagement(std::list top_level_exprs) { auto* for_loop = top_level_expr->as(); // Get the current stream before entering the loop - auto* get_current_stream = IrBuilder::create(); - hir::Stream* original_stream = get_current_stream->stream(); + hir::Stream* original_stream = IrBuilder::create(); + auto* get_current_stream = + IrBuilder::create(original_stream); new_top_level_exprs.push_back(get_current_stream); // Create a new for-loop for getting the current stream diff --git a/tests/cpp/test_host_irs.cpp b/tests/cpp/test_host_irs.cpp index f1d3dc21cc3..4793e21c189 100644 --- a/tests/cpp/test_host_irs.cpp +++ b/tests/cpp/test_host_irs.cpp @@ -567,8 +567,8 @@ TEST_F(StreamTest, HostIrDefaultStream) { TEST_F(StreamTest, HostIrGetCurrentStream) { auto hic = std::make_unique(); FusionGuard fg(hic.get()); - auto get_stream = IrBuilder::create(); - auto current_stream = get_stream->stream(); + hir::Stream* current_stream = IrBuilder::create(); + auto* get_stream = IrBuilder::create(current_stream); auto other_stream = IrBuilder::create(); hic->pushBackTopLevelExprs(get_stream); hic->pushBackTopLevelExprs(IrBuilder::create(other_stream));