Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 61 additions & 26 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,15 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
}

// TODO: put this into ir_cloner instead
for (const auto& [output, alias_info] : from->io_alias_) {
Val* copied_output = ir_cloner.clone(output);
Val* copied_input = ir_cloner.clone(alias_info.aliased_io);
to->io_alias_[copied_output] = {
.type = alias_info.type,
.aliased_io = copied_input,
.hide_output = alias_info.hide_output};
for (Val* out : from->outputs_) {
const AliasInfo& alias = from->io_alias_.get(out);
if (alias.type == AllocationType::New) {
continue;
}

Val* copied_out = ir_cloner.clone(out);
Val* copied_in = ir_cloner.clone(alias.aliased_io);
to->io_alias_.add(copied_out, copied_in, alias.type, alias.visibility);
}

to->all_tv_uses_valid_ = from->all_tv_uses_valid_;
Expand Down Expand Up @@ -270,17 +272,18 @@ void Fusion::addOutputInternal(Val* output) {
}

void Fusion::addOutput(Val* output) {
// special handling for returning aliased output. We just need to remove its
// Special handling for returning aliased output. We just need to remove its
// existing entry in the outputs_ used for inplace update
if (io_alias_.count(output) != 0) {
if (io_alias_.get(output).type != AllocationType::New) {
AliasInfo& alias = io_alias_.mutable_at(output);
// if previous output is only added for aliasing purpose, we should remove
// the previous entry and add a new one. Otherwise, it may be positioned
// wrong in the output list.
if (io_alias_[output].hide_output) {
if (alias.visibility == OutputVisibility::kHidden) {
removeOutput(output);
}
// output shouldn't be hidden any more
io_alias_[output].hide_output = false;
alias.visibility = OutputVisibility::kVisible;
}

addOutputInternal(output);
Expand Down Expand Up @@ -333,10 +336,10 @@ void Fusion::replaceOutput(Val* output, Val* replacement) {

// Temporary WAR for issue #1112
// (https://github.com/csarofeen/pytorch/issues/1112)
if (io_alias_.count(output) != 0) {
auto input = io_alias_[output];
AliasInfo alias = io_alias_.get(output);
if (alias.type != AllocationType::New) {
io_alias_.erase(output);
io_alias_[replacement] = input;
io_alias_.add(replacement, alias.aliased_io, alias.type, alias.visibility);
}
}

Expand Down Expand Up @@ -759,6 +762,42 @@ std::vector<Val*> Fusion::getTerminatingOutputs() const {
return terminating_outputs;
}

// Streams a human-readable name for `visibility`: "Visible" for
// OutputVisibility::kVisible and "Hidden" for OutputVisibility::kHidden.
// Returns `os` so that calls can be chained.
std::ostream& operator<<(std::ostream& os, OutputVisibility visibility) {
  const char* label = nullptr;
  // Keep a default-less switch so that the compiler can flag an enumerator
  // that is added later but not handled here.
  switch (visibility) {
    case OutputVisibility::kHidden:
      label = "Hidden";
      break;
    case OutputVisibility::kVisible:
      label = "Visible";
      break;
  }
  if (label == nullptr) {
    // Only reachable with a value that is not a declared enumerator.
    std::unreachable();
  }
  return os << label;
}

// Registers `out` as aliasing `in` with the given allocation `type` and
// `visibility`.
//
// An output can be registered only once. If `out` already has an entry, the
// stored entry is left unchanged and NVF_ERROR is raised with a message that
// includes the existing AliasInfo.
void AliasInfoMap::add(
    Val* out,
    Val* in,
    AllocationType type,
    OutputVisibility visibility) {
  const AliasInfo info{
      .type = type, .aliased_io = in, .visibility = visibility};
  // try_emplace reports whether the insertion happened and, either way,
  // yields an iterator to the entry now stored under `out`.
  const auto result = aliases_.try_emplace(out, info);
  NVF_ERROR(
      result.second,
      "The map already has an AliasInfo for ",
      out,
      ": ",
      result.first->second.toString());
}

// Looks up the alias information recorded for `v`.
//
// Returns a reference to the stored entry when `v` has one. Otherwise returns
// a reference to a shared placeholder describing a non-aliased value:
// AllocationType::New, a null `aliased_io`, and OutputVisibility::kVisible.
const AliasInfo& AliasInfoMap::get(const Val* v) const {
  // Function-local static so that the reference handed back for a missing
  // key remains valid after this call returns.
  static const AliasInfo kNotAliased{
      .type = AllocationType::New,
      .aliased_io = nullptr,
      .visibility = OutputVisibility::kVisible};

  const auto it = aliases_.find(v);
  return it == aliases_.end() ? kNotAliased : it->second;
}

void Fusion::aliasOutputToInput(
Val* output,
Val* input,
Expand All @@ -772,8 +811,7 @@ void Fusion::aliasOutputToInput(
NVF_CHECK(
output->isFusionOutput(),
"Only fusion outputs can be expression evaluated.");
io_alias_[output] =
AliasInfo{.type = type, .aliased_io = input, .hide_output = false};
io_alias_.add(output, input, type, OutputVisibility::kVisible);
return;
}

Expand Down Expand Up @@ -803,10 +841,12 @@ void Fusion::aliasOutputToInput(
// Let integration hide any output that wasn't a fusion output when
// `aliasOutputToInput` was called. For example, running mean and var for
// batch norm.
io_alias_[output] = AliasInfo{
.type = type,
.aliased_io = input,
.hide_output = !output->isFusionOutput()};
io_alias_.add(
output,
input,
type,
!output->isFusionOutput() ? OutputVisibility::kHidden
: OutputVisibility::kVisible);

// only add output when it's not in outputs_
if (!output->isFusionOutput()) {
Expand All @@ -815,12 +855,7 @@ void Fusion::aliasOutputToInput(
}

const AliasInfo& Fusion::getOutputAlias(const Val* output) const {
static AliasInfo no_alias_info{
.type = AllocationType::New, .aliased_io = nullptr, .hide_output = false};
if (auto search = io_alias_.find(output); search != io_alias_.end()) {
return search->second;
}
return no_alias_info;
return io_alias_.get(output);
}

bool Fusion::hasDynamicTransform() {
Expand Down
39 changes: 34 additions & 5 deletions csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,23 @@ enum class AllocationType : int {
Evaluate,
};

// Whether integration should hide a fusion output from users or return it.
// NOTE(review): kHidden is the first enumerator, so a value-initialized
// OutputVisibility (e.g. inside `AliasInfo{}`) is kHidden — confirm that no
// caller relies on value-initialization meaning "visible".
enum class OutputVisibility : int {
// The output is not returned to users.
kHidden,
// The output is returned to users.
kVisible,
};

// Streams the enumerator's name ("Hidden" or "Visible") to `os`.
std::ostream& operator<<(std::ostream& os, OutputVisibility visibility);

struct AliasInfo {
AllocationType type;
Val* aliased_io;
// Whether integration should hide the output from users. This is currently
// only used for ReuseBuffer.
bool hide_output;
OutputVisibility visibility;

bool operator==(const AliasInfo& other) const {
return type == other.type && aliased_io == other.aliased_io &&
hide_output == other.hide_output;
visibility == other.visibility;
}

bool operator!=(const AliasInfo& other) const {
Expand All @@ -116,12 +123,34 @@ struct AliasInfo {
}
ss << ",\n aliased_io = "
<< (aliased_io == nullptr ? "nullptr" : aliased_io->toString()) << ",\n";
ss << " hide_output = " << (hide_output ? "true" : "false") << "\n";
ss << " visibility = " << visibility << "\n";
ss << "}\n";
return ss.str();
}
};

// Associates a Val with the AliasInfo that describes which other Val it
// aliases, how it is allocated, and whether it is visible. A Val without an
// entry is treated by `get` as not aliased.
class AliasInfoMap {
public:
// Records that `out` aliases `in` with the given `type` and `visibility`.
// Raises NVF_ERROR if `out` already has an entry.
void add(Val* out, Val* in, AllocationType type, OutputVisibility visibility);

// Returns the entry for `v`, or a shared default entry (AllocationType::New,
// null aliased_io, OutputVisibility::kVisible) when `v` has none.
const AliasInfo& get(const Val* v) const;

// Returns a mutable reference to the entry for `v`. Unlike `get`, there is no
// fallback: `v` must already have an entry, otherwise
// std::unordered_map::at throws std::out_of_range.
AliasInfo& mutable_at(const Val* v) {
return aliases_.at(v);
}

// Removes the entry for `v`, if any.
void erase(const Val* v) {
aliases_.erase(v);
}

// Removes all entries.
void clear() {
aliases_.clear();
}

private:
// Keyed by the Val passed as `out` to `add`.
std::unordered_map<const Val*, AliasInfo> aliases_;
};

//! Fusion is mutable but unique. Nodes cannot be copied in any way from one
//! Fusion to another. If anything like that is desired, it would require
//! duplicating all associated values and exprs. Fusion is considered to be SSA,
Expand Down Expand Up @@ -498,8 +527,8 @@ class NVF_API Fusion : public IrContainer {
std::vector<Val*> inputs_;
std::vector<Val*> outputs_;

// io alias pointing from output to input
std::unordered_map<const Val*, AliasInfo> io_alias_;
// Aliases between fusion inputs and outputs.
AliasInfoMap io_alias_;

// Records if the current use data in the IR nodes are valid
// the states are either all valid or all invalid
Expand Down
2 changes: 1 addition & 1 deletion csrc/runtime/fusion_executor_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ KernelArgumentHolder FusionExecutorCache::runFusionWithInputs(
KernelArgumentHolder unaliased_outputs;
for (auto out_index : arange(outputs.size())) {
Val* out = fusion->outputs()[out_index];
if (!fusion->getOutputAlias(out).hide_output) {
if (fusion->getOutputAlias(out).visibility == OutputVisibility::kVisible) {
unaliased_outputs.push(outputs[out_index]);
}
}
Expand Down
3 changes: 2 additions & 1 deletion csrc/validator_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,8 @@ void testValidate(
// Returns true when `out` is **not** an aliased output that's hidden
// from integration. Hidden outputs won't show up in `fusion_outputs`
// for us to compare, so we skip them.
return !fusion->getOutputAlias(out).hide_output;
return fusion->getOutputAlias(out).visibility ==
OutputVisibility::kVisible;
});

auto expr_eval = bindInputsAndLaunchParams(fusion, aten_inputs, lparams);
Expand Down
3 changes: 2 additions & 1 deletion python/python_common/distributed_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ std::vector<Sharding> getOutputShardings(Fusion* fusion) {
output_shardings.reserve(fusion->outputs().size());
for (Val* out_val : fusion->outputs()) {
if (auto* out_tv = dynamic_cast<TensorView*>(out_val)) {
if (fusion->getOutputAlias(out_tv).hide_output) {
if (fusion->getOutputAlias(out_tv).visibility ==
OutputVisibility::kHidden) {
continue;
}
const DeviceMesh& mesh = out_tv->getDeviceMesh();
Expand Down
2 changes: 1 addition & 1 deletion python/python_direct/python_translate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ class PythonTranslator : public OptInConstDispatch {
}
// If the output is visible, then the aliased output is returned as a
// fusion output.
if (!alias_info.hide_output) {
if (alias_info.visibility == OutputVisibility::kVisible) {
handleOutput(v->as<TensorView>());
}
break;
Expand Down
5 changes: 3 additions & 2 deletions python/python_frontend/translation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,9 @@ class FusionTranslator : public OptInConstDispatch {
handleOutput(v->as<TensorView>(), alias_info);
}
// An alias output can also be returned as a fusion output
// if it is already aliased or if not hide_output
if (num_visited > 0 || !alias_info.hide_output) {
// if it is already aliased or if the output is visible.
if (num_visited > 0 ||
alias_info.visibility == OutputVisibility::kVisible) {
handleOutput(v->as<TensorView>());
}
break;
Expand Down