Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 7 additions & 22 deletions csrc/host_ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,6 @@ std::string LaunchKernel::toString(int indent_size) const {
return ss.str();
}

std::string LaunchKernel::toInlineString(int indent_size) const {
NVF_CHECK(false, "Can not be printed inline");
}

Deallocate::Deallocate(IrBuilderPasskey passkey, TensorView* tv)
: Expr(passkey) {
addAttribute(tv);
Expand Down Expand Up @@ -221,7 +217,7 @@ bool Stream::sameAs(const Statement* other) const {
}

SetCurrentStream::SetCurrentStream(IrBuilderPasskey passkey, Stream* stream)
: Expr(passkey, {stream}, {}, {stream}) {
: Expr(passkey, {stream}, {}, {}) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(passkey.ir_container_->isA<HostIrContainer>());
}
Expand All @@ -230,34 +226,23 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(SetCurrentStream)

std::string SetCurrentStream::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "SetCurrentStream to " << stream()->toString()
<< std::endl;
indent(ss, indent_size) << "SetCurrentStream(" << stream()->toInlineString()
<< ")" << std::endl;
return ss.str();
}

// TODO: implement better ?
std::string SetCurrentStream::toInlineString(int indent_size) const {
NVF_CHECK(false, "Cannot be printed inline");
}

// TODO: implement
bool SetCurrentStream::sameAs(const Statement* other) const {
return false;
}

GetCurrentStream::GetCurrentStream(IrBuilderPasskey passkey) : Expr(passkey) {
GetCurrentStream::GetCurrentStream(IrBuilderPasskey passkey, Stream* stream)
: Expr(passkey, {}, {stream}, {}) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(passkey.ir_container_->isA<HostIrContainer>());
auto stream = IrBuilder::createInContainer<Stream>(passkey.ir_container_);
addAttribute(stream);
}

NVFUSER_DEFINE_CLONE_AND_CREATE(GetCurrentStream)

std::string GetCurrentStream::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "GetCurrentStream into " << stream()->toString()
<< std::endl;
indent(ss, indent_size) << stream()->toInlineString()
<< " = GetCurrentStream()" << std::endl;
return ss.str();
}

Expand Down
10 changes: 3 additions & 7 deletions csrc/host_ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ class LaunchKernel : public Expr {
NVFUSER_DECLARE_CLONE_AND_CREATE

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;
const char* getOpString() const override {
return "hir::LaunchKernel";
}
Expand Down Expand Up @@ -226,22 +225,19 @@ class SetCurrentStream : public Expr {
NVFUSER_DECLARE_CLONE_AND_CREATE

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;
const char* getOpString() const override {
return "hir::SetCurrentStream";
}

bool sameAs(const Statement* other) const override;

Stream* stream() const {
return attributes_.at(0)->as<Stream>();
return inputs().at(0)->as<Stream>();
}
};

class GetCurrentStream : public Expr {
public:
using Expr::Expr;
GetCurrentStream(IrBuilderPasskey passkey);
GetCurrentStream(IrBuilderPasskey passkey, Stream* stream);

GetCurrentStream(const GetCurrentStream& other) = delete;
GetCurrentStream& operator=(const GetCurrentStream& other) = delete;
Expand All @@ -256,7 +252,7 @@ class GetCurrentStream : public Expr {
}

Stream* stream() const {
return attributes_.at(0)->as<Stream>();
return outputs().at(0)->as<Stream>();
}
};

Expand Down
5 changes: 3 additions & 2 deletions csrc/host_ir/pass/stream_parallel_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,9 @@ std::list<Expr*> addStreamManagement(std::list<Expr*> top_level_exprs) {
auto* for_loop = top_level_expr->as<kir::ForLoop>();

// Get the current stream before entering the loop
auto* get_current_stream = IrBuilder::create<hir::GetCurrentStream>();
hir::Stream* original_stream = get_current_stream->stream();
hir::Stream* original_stream = IrBuilder::create<hir::Stream>();
auto* get_current_stream =
IrBuilder::create<hir::GetCurrentStream>(original_stream);
new_top_level_exprs.push_back(get_current_stream);

// Create a new for-loop for getting the current stream
Expand Down
4 changes: 1 addition & 3 deletions tests/cpp/test_multidevice_lower_communication_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,7 @@ INSTANTIATE_TEST_SUITE_P(
testing::Values(
2 * 1024 * 1024LL, // 2 MB
8 * 1024 * 1024LL, // 8 MB
32 * 1024 * 1024LL, // 32 MB
128 * 1024 * 1024LL, // 128 MB
256 * 1024 * 1024LL // 256 MB
32 * 1024 * 1024LL // 32 MB
),
testing::Values(
CommunicationProtocol::kMemcpy,
Expand Down