diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index bab8b70c538..79bc106fe7b 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -69,13 +69,15 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { } // TODO: put this into ir_cloner instead - for (const auto& [output, alias_info] : from->io_alias_) { - Val* copied_output = ir_cloner.clone(output); - Val* copied_input = ir_cloner.clone(alias_info.aliased_io); - to->io_alias_[copied_output] = { - .type = alias_info.type, - .aliased_io = copied_input, - .hide_output = alias_info.hide_output}; + for (Val* out : from->outputs_) { + const AliasInfo& alias = from->io_alias_.get(out); + if (alias.type == AllocationType::New) { + continue; + } + + Val* copied_out = ir_cloner.clone(out); + Val* copied_in = ir_cloner.clone(alias.aliased_io); + to->io_alias_.add(copied_out, copied_in, alias.type, alias.visibility); } to->all_tv_uses_valid_ = from->all_tv_uses_valid_; @@ -270,17 +272,18 @@ void Fusion::addOutputInternal(Val* output) { } void Fusion::addOutput(Val* output) { - // special handling for returning aliased output. We just need to remove its + // Special handling for returning aliased output. We just need to remove its // existing entry in the outputs_ used for inplace update - if (io_alias_.count(output) != 0) { + if (io_alias_.get(output).type != AllocationType::New) { + AliasInfo& alias = io_alias_.mutable_at(output); // if previous output is only added for aliasing purpose, we should remove // the previous entry and add a new one. Otherwise, it may be positioned // wrong in the output list. - if (io_alias_[output].hide_output) { + if (alias.visibility == OutputVisibility::kHidden) { removeOutput(output); } // output shouldn't be hidden any more - io_alias_[output].hide_output = false; + alias.visibility = OutputVisibility::kVisible; } addOutputInternal(output); @@ -333,10 +336,10 @@ void Fusion::replaceOutput(Val* output, Val* replacement) { // Temporary WAR for issue #1112 // (https://github.com/csarofeen/pytorch/issues/1112) - if (io_alias_.count(output) != 0) { - auto input = io_alias_[output]; + AliasInfo alias = io_alias_.get(output); + if (alias.type != AllocationType::New) { io_alias_.erase(output); - io_alias_[replacement] = input; + io_alias_.add(replacement, alias.aliased_io, alias.type, alias.visibility); } } @@ -759,6 +762,42 @@ std::vector Fusion::getTerminatingOutputs() const { return terminating_outputs; } +std::ostream& operator<<(std::ostream& os, OutputVisibility visibility) { + switch (visibility) { + case OutputVisibility::kVisible: + return os << "Visible"; + case OutputVisibility::kHidden: + return os << "Hidden"; + } + std::unreachable(); +} + +void AliasInfoMap::add( + Val* out, + Val* in, + AllocationType type, + OutputVisibility visibility) { + auto [_, inserted] = aliases_.try_emplace( + out, AliasInfo{.type = type, .aliased_io = in, .visibility = visibility}); + NVF_ERROR( + inserted, + "The map already has an AliasInfo for ", + out, + ": ", + aliases_.at(out).toString()); +} + +const AliasInfo& AliasInfoMap::get(const Val* v) const { + static AliasInfo no_alias_info{ + .type = AllocationType::New, + .aliased_io = nullptr, + .visibility = OutputVisibility::kVisible}; + if (auto search = aliases_.find(v); search != aliases_.end()) { + return search->second; + } + return no_alias_info; +} + void Fusion::aliasOutputToInput( Val* output, Val* input, @@ -772,8 +811,7 @@ void Fusion::aliasOutputToInput( NVF_CHECK( output->isFusionOutput(), "Only fusion outputs can be expression evaluated."); - io_alias_[output] = - AliasInfo{.type = type, .aliased_io = input, .hide_output = false}; + io_alias_.add(output, input, type, OutputVisibility::kVisible); return; } @@ -803,10 +841,12 @@ void Fusion::aliasOutputToInput( // Let integration hide any output that wasn't a fusion output when // `aliasOutputToInput` was called. For example, running mean and var for // batch norm. - io_alias_[output] = AliasInfo{ - .type = type, - .aliased_io = input, - .hide_output = !output->isFusionOutput()}; + io_alias_.add( + output, + input, + type, + !output->isFusionOutput() ? OutputVisibility::kHidden + : OutputVisibility::kVisible); // only add output when it's not in outputs_ if (!output->isFusionOutput()) { @@ -815,12 +855,7 @@ void Fusion::aliasOutputToInput( } const AliasInfo& Fusion::getOutputAlias(const Val* output) const { - static AliasInfo no_alias_info{ - .type = AllocationType::New, .aliased_io = nullptr, .hide_output = false}; - if (auto search = io_alias_.find(output); search != io_alias_.end()) { - return search->second; - } - return no_alias_info; + return io_alias_.get(output); } bool Fusion::hasDynamicTransform() { diff --git a/csrc/fusion.h b/csrc/fusion.h index 234b27f433a..f560741e69f 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -83,16 +83,23 @@ enum class AllocationType : int { Evaluate, }; +enum class OutputVisibility : int { + kHidden, + kVisible, +}; + +std::ostream& operator<<(std::ostream& os, OutputVisibility visibility); + struct AliasInfo { AllocationType type; Val* aliased_io; // Whether integration should hide the output from users. This is currently // only used for ReuseBuffer. - bool hide_output; + OutputVisibility visibility; bool operator==(const AliasInfo& other) const { return type == other.type && aliased_io == other.aliased_io && - hide_output == other.hide_output; + visibility == other.visibility; } bool operator!=(const AliasInfo& other) const { @@ -116,12 +123,34 @@ struct AliasInfo { } ss << ",\n aliased_io = " << (aliased_io == nullptr ? "nullptr" : aliased_io->toString()) << ",\n"; - ss << " hide_output = " << (hide_output ? "true" : "false") << "\n"; + ss << " visibility = " << visibility << "\n"; ss << "}\n"; return ss.str(); } }; +class AliasInfoMap { + public: + void add(Val* out, Val* in, AllocationType type, OutputVisibility visibility); + + const AliasInfo& get(const Val* v) const; + + AliasInfo& mutable_at(const Val* v) { + return aliases_.at(v); + } + + void erase(const Val* v) { + aliases_.erase(v); + } + + void clear() { + aliases_.clear(); + } + + private: + std::unordered_map aliases_; +}; + //! Fusion is mutable but unique. Nodes cannot be copied in any way from one //! Fusion to another. If anything like that is desired, it would require //! duplicating all associated values and exprs. Fusion is considered to be SSA, @@ -498,8 +527,8 @@ class NVF_API Fusion : public IrContainer { std::vector inputs_; std::vector outputs_; - // io alias pointing from output to input - std::unordered_map io_alias_; + // Aliases between fusion inputs and outputs. + AliasInfoMap io_alias_; // Records if the current use data in the IR nodes are valid // the states are either all valid or all invalid diff --git a/csrc/runtime/fusion_executor_cache.cpp b/csrc/runtime/fusion_executor_cache.cpp index d748a9c0326..f4e397f8e4c 100644 --- a/csrc/runtime/fusion_executor_cache.cpp +++ b/csrc/runtime/fusion_executor_cache.cpp @@ -91,7 +91,7 @@ KernelArgumentHolder FusionExecutorCache::runFusionWithInputs( KernelArgumentHolder unaliased_outputs; for (auto out_index : arange(outputs.size())) { Val* out = fusion->outputs()[out_index]; - if (!fusion->getOutputAlias(out).hide_output) { + if (fusion->getOutputAlias(out).visibility == OutputVisibility::kVisible) { unaliased_outputs.push(outputs[out_index]); } } diff --git a/csrc/validator_utils.cpp b/csrc/validator_utils.cpp index b73233602ce..fd3b43c5a7c 100644 --- a/csrc/validator_utils.cpp +++ b/csrc/validator_utils.cpp @@ -333,7 +333,8 @@ void testValidate( // Returns true when `out` is **not** an aliased output that's hidden // from integration. Hidden outputs won't show up in `fusion_outputs` // for us to compare, so we skip them. - return !fusion->getOutputAlias(out).hide_output; + return fusion->getOutputAlias(out).visibility == + OutputVisibility::kVisible; }); auto expr_eval = bindInputsAndLaunchParams(fusion, aten_inputs, lparams); diff --git a/python/python_common/distributed_tensor.cpp b/python/python_common/distributed_tensor.cpp index 734f82c380a..1dcfa16a615 100644 --- a/python/python_common/distributed_tensor.cpp +++ b/python/python_common/distributed_tensor.cpp @@ -46,7 +46,8 @@ std::vector getOutputShardings(Fusion* fusion) { output_shardings.reserve(fusion->outputs().size()); for (Val* out_val : fusion->outputs()) { if (auto* out_tv = dynamic_cast(out_val)) { - if (fusion->getOutputAlias(out_tv).hide_output) { + if (fusion->getOutputAlias(out_tv).visibility == + OutputVisibility::kHidden) { continue; } const DeviceMesh& mesh = out_tv->getDeviceMesh(); diff --git a/python/python_direct/python_translate.cpp b/python/python_direct/python_translate.cpp index efc03da34ad..f8cc55478b8 100644 --- a/python/python_direct/python_translate.cpp +++ b/python/python_direct/python_translate.cpp @@ -645,7 +645,7 @@ class PythonTranslator : public OptInConstDispatch { } // If not hide_output, then the aliased output is returned as a // fusion output. - if (!alias_info.hide_output) { + if (alias_info.visibility == OutputVisibility::kVisible) { handleOutput(v->as()); } break; diff --git a/python/python_frontend/translation.cpp b/python/python_frontend/translation.cpp index 32874c75197..e9c4cd461fc 100644 --- a/python/python_frontend/translation.cpp +++ b/python/python_frontend/translation.cpp @@ -282,8 +282,9 @@ class FusionTranslator : public OptInConstDispatch { handleOutput(v->as(), alias_info); } // An alias output can also be returned as a fusion output - // if it is already aliased or if not hide_output - if (num_visited > 0 || !alias_info.hide_output) { + // if it is already aliased or if the output is visible. + if (num_visited > 0 || + alias_info.visibility == OutputVisibility::kVisible) { handleOutput(v->as()); } break;