Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 21 additions & 38 deletions csrc/host_ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,6 @@ std::string LaunchKernel::toString(int indent_size) const {
return ss.str();
}

std::string LaunchKernel::toInlineString(int indent_size) const {
NVF_CHECK(false, "Can not be printed inline");
}

Deallocate::Deallocate(IrBuilderPasskey passkey, TensorView* tv)
: Expr(passkey) {
addAttribute(tv);
Expand Down Expand Up @@ -197,7 +193,10 @@ std::string Stream::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "Stream ";
if (index() == nullptr) {
ss << name();
// HostIrEvaluator looks up streams by address when index is null.
// Print address as identifier. We used to print `name()` but that's often
// an integer which would be confusing/ambiguous with the index.
ss << static_cast<const void*>(this);
} else {
ss << index()->toInlineString();
}
Expand All @@ -221,7 +220,7 @@ bool Stream::sameAs(const Statement* other) const {
}

SetCurrentStream::SetCurrentStream(IrBuilderPasskey passkey, Stream* stream)
: Expr(passkey, {stream}, {}, {stream}) {
: Expr(passkey, {stream}, {}, {}) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(passkey.ir_container_->isA<HostIrContainer>());
}
Expand All @@ -230,34 +229,23 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(SetCurrentStream)

std::string SetCurrentStream::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "SetCurrentStream to " << stream()->toString()
indent(ss, indent_size) << "SetCurrentStream(" << stream()->toString() << ")"
<< std::endl;
return ss.str();
}

// TODO: implement better ?
std::string SetCurrentStream::toInlineString(int indent_size) const {
NVF_CHECK(false, "Cannot be printed inline");
}

// TODO: implement
bool SetCurrentStream::sameAs(const Statement* other) const {
return false;
}

GetCurrentStream::GetCurrentStream(IrBuilderPasskey passkey) : Expr(passkey) {
GetCurrentStream::GetCurrentStream(IrBuilderPasskey passkey, Stream* stream)
: Expr(passkey, {}, {stream}, {}) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(passkey.ir_container_->isA<HostIrContainer>());
auto stream = IrBuilder::createInContainer<Stream>(passkey.ir_container_);
addAttribute(stream);
}

NVFUSER_DEFINE_CLONE_AND_CREATE(GetCurrentStream)

std::string GetCurrentStream::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "GetCurrentStream into " << stream()->toString()
<< std::endl;
indent(ss, indent_size) << stream()->toInlineString()
<< " = GetCurrentStream()" << std::endl;
return ss.str();
}

Expand All @@ -277,15 +265,14 @@ Wait::Wait(IrBuilderPasskey passkey, Expr* expr)
NVFUSER_DEFINE_CLONE_AND_CREATE(Wait)

std::string Wait::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "Wait Communication " << communication()->name()
<< std::endl;
return ss.str();
return toInlineString(indent_size) + "\n";
}

// TODO: implement better ?
std::string Wait::toInlineString(int indent_size) const {
NVF_CHECK(false, "Cannot be printed inline");
std::stringstream ss;
indent(ss, indent_size) << "Wait(Communication " << communication()->name()
<< ")";
return ss.str();
}

// TODO: implement
Expand All @@ -294,7 +281,7 @@ bool Wait::sameAs(const Statement* other) const {
}

Synchronize::Synchronize(IrBuilderPasskey passkey, Stream* stream)
: Expr(passkey, {}, {}, {stream}) {
: Expr(passkey, {stream}, {}, {}) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(
passkey.ir_container_->isA<HostIrContainer>(),
Expand All @@ -305,13 +292,13 @@ Synchronize::Synchronize(IrBuilderPasskey passkey, Stream* stream)
NVFUSER_DEFINE_CLONE_AND_CREATE(Synchronize)

std::string Synchronize::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "Synchronize " << stream() << std::endl;
return ss.str();
return toInlineString(indent_size) + "\n";
}

std::string Synchronize::toInlineString(int indent_size) const {
NVF_CHECK(false, "Cannot be printed inline");
std::stringstream ss;
indent(ss, indent_size) << "Synchronize(" << stream() << ")";
return ss.str();
}

// TODO: implement
Expand Down Expand Up @@ -451,15 +438,11 @@ std::string ShardByStream::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << out()->toString() << " = ShardByStream("
<< in()->toString()
<< ", stream_index = " << stream_index()->toString()
<< ", stream_index=" << stream_index()->toString()
<< ")" << std::endl;
return ss.str();
}

std::string ShardByStream::toInlineString(int indent_size) const {
NVF_CHECK(false, "Cannot be printed inline");
}

SymmetricContiguousView::SymmetricContiguousView(
IrBuilderPasskey passkey,
TensorView* out,
Expand Down
13 changes: 4 additions & 9 deletions csrc/host_ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ class LaunchKernel : public Expr {
NVFUSER_DECLARE_CLONE_AND_CREATE

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;
const char* getOpString() const override {
return "hir::LaunchKernel";
}
Expand Down Expand Up @@ -226,22 +225,19 @@ class SetCurrentStream : public Expr {
NVFUSER_DECLARE_CLONE_AND_CREATE

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;
const char* getOpString() const override {
return "hir::SetCurrentStream";
}

bool sameAs(const Statement* other) const override;

Stream* stream() const {
return attributes_.at(0)->as<Stream>();
return inputs().at(0)->as<Stream>();
}
};

class GetCurrentStream : public Expr {
public:
using Expr::Expr;
GetCurrentStream(IrBuilderPasskey passkey);
GetCurrentStream(IrBuilderPasskey passkey, Stream* stream);

GetCurrentStream(const GetCurrentStream& other) = delete;
GetCurrentStream& operator=(const GetCurrentStream& other) = delete;
Expand All @@ -256,7 +252,7 @@ class GetCurrentStream : public Expr {
}

Stream* stream() const {
return attributes_.at(0)->as<Stream>();
return outputs().at(0)->as<Stream>();
}
};

Expand Down Expand Up @@ -308,7 +304,7 @@ class Synchronize : public Expr {
bool sameAs(const Statement* other) const override;

Stream* stream() const {
return attributes_.at(0)->as<Stream>();
return inputs().at(0)->as<Stream>();
}
};

Expand Down Expand Up @@ -478,7 +474,6 @@ class ShardByStream : public Expr {
NVFUSER_DECLARE_CLONE_AND_CREATE

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;
const char* getOpString() const override {
return "hir::ShardByStream";
}
Expand Down
5 changes: 3 additions & 2 deletions csrc/host_ir/pass/stream_parallel_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,9 @@ std::list<Expr*> addStreamManagement(std::list<Expr*> top_level_exprs) {
auto* for_loop = top_level_expr->as<kir::ForLoop>();

// Get the current stream before entering the loop
auto* get_current_stream = IrBuilder::create<hir::GetCurrentStream>();
hir::Stream* original_stream = get_current_stream->stream();
hir::Stream* original_stream = IrBuilder::create<hir::Stream>();
auto* get_current_stream =
IrBuilder::create<hir::GetCurrentStream>(original_stream);
new_top_level_exprs.push_back(get_current_stream);

// Create a new for-loop for getting the current stream
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/test_host_irs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,8 +567,8 @@ TEST_F(StreamTest, HostIrDefaultStream) {
TEST_F(StreamTest, HostIrGetCurrentStream) {
auto hic = std::make_unique<HostIrContainer>();
FusionGuard fg(hic.get());
auto get_stream = IrBuilder::create<GetCurrentStream>();
auto current_stream = get_stream->stream();
hir::Stream* current_stream = IrBuilder::create<hir::Stream>();
auto* get_stream = IrBuilder::create<hir::GetCurrentStream>(current_stream);
auto other_stream = IrBuilder::create<Stream>();
hic->pushBackTopLevelExprs(get_stream);
hic->pushBackTopLevelExprs(IrBuilder::create<SetCurrentStream>(other_stream));
Expand Down