From 4fd660d8cfe481b0c6e7ac63a364a9278513f836 Mon Sep 17 00:00:00 2001
From: "Peng Wang(AI FWK)"
Date: Fri, 24 May 2024 01:45:48 +0000
Subject: [PATCH 1/7] Release external outputs as soon as they are no longer needed.

---
 .../core/framework/allocation_planner.cc | 105 +++++++++++++-----
 .../partial_graph_execution_state.cc | 6 +-
 .../framework/partial_graph_execution_state.h | 2 +-
 .../core/framework/sequential_executor.cc | 5 +-
 .../core/framework/sequential_executor.h | 2 +-
 onnxruntime/core/framework/utils.cc | 4 +-
 onnxruntime/core/framework/utils.h | 2 +-
 onnxruntime/core/optimizer/gemm_sum_fusion.cc | 2 +-
 .../core/optimizer/matmul_add_fusion.cc | 2 +-
 onnxruntime/core/session/inference_session.cc | 2 +-
 onnxruntime/core/session/inference_session.h | 6 +-
 .../orttraining/core/agent/training_agent.cc | 6 +-
 .../orttraining/core/agent/training_agent.h | 6 +-
 .../python/orttraining_pybind_state.cc | 12 +-
 14 files changed, 109 insertions(+), 53 deletions(-)

diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index 95e5380675df2..c7f9826a7ab8a 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -294,8 +294,8 @@ class PlannerImpl {
 #endif

 // Find if there exists some input tensor that we can use in-place for output_arg_num-th output in the node.
- bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input,
- bool* is_strided_tensor) {
+ bool FindReusableInput(const GraphViewer& graph, const onnxruntime::Node& node, int output_arg_num,
+ OrtValueIndex* reusable_input, bool* is_strided_tensor) {
 *is_strided_tensor = false;
 #ifdef ENABLE_TRAINING
 // Inputs of Yields are essentially the outputs for FW partial subgraph
@@ -326,6 +326,12 @@ class PlannerImpl {
 if ((0 <= pair.first) && (static_cast(pair.first) < input_args.size())) {
 auto p_input_arg = input_args[pair.first];
 if (p_input_arg->Exists()) {
+ // If the producer node does not have external output, then we can reuse the input buffer; Otherwise,
+ // we cannot.
+ const Node* producer_node = graph.GetProducerNode(p_input_arg->Name());
+ ORT_ENFORCE(producer_node == nullptr || !HasExternalOutputs(*producer_node),
+ "Node ", node.Name(), " cannot reuse input buffer for node ", producer_node->Name(),
+ " as it has external outputs, which cannot be reused.");
 *reusable_input = Index(p_input_arg->Name());
 return true;
 }
@@ -342,6 +348,12 @@ class PlannerImpl {
 if (alias_input_index >= 0 && static_cast(alias_input_index) < input_args.size()) {
 auto p_input_arg = input_args[alias_input_index];
 if (p_input_arg->Exists()) {
+ // If the producer node does not have external output, then we can reuse the input buffer; Otherwise,
+ // we cannot.
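+ // A sketch of the guard introduced in this hunk and the one above
+ // (CanReuseProducerBuffer is an illustrative name, not an ORT API):
+ //
+ //   bool CanReuseProducerBuffer(const GraphViewer& graph, const NodeArg& input_arg) {
+ //     // A value produced by a node with external outputs lives in memory that
+ //     // ORT does not own (e.g. a torch tensor bridged via dlpack), so handing
+ //     // that buffer to a downstream node for in-place update could corrupt
+ //     // state still visible outside ORT.
+ //     const Node* producer = graph.GetProducerNode(input_arg.Name());
+ //     return producer == nullptr || !HasExternalOutputs(*producer);
+ //   }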
+ const Node* producer_node = graph.GetProducerNode(p_input_arg->Name()); + ORT_ENFORCE(producer_node == nullptr || !HasExternalOutputs(*producer_node), + "Node ", node.Name(), " cannot reuse input buffer for node ", producer_node->Name(), + " as it has external outputs, which cannot be reused."); *reusable_input = Index(p_input_arg->Name()); return true; } @@ -357,10 +369,18 @@ class PlannerImpl { auto input_arg_index = Index(p_input_arg->Name()); auto original = Buffer(input_arg_index); if (1 == UseCount(original)) { - if (SameSize(*p_input_arg, *p_output_arg)) { - // we can reuse this input since it is its last use and permitted for in-place update - *reusable_input = input_arg_index; // or original; both should be okay - return true; + // If the producer node does not have external output, then we can reuse the input buffer; Otherwise, + // we cannot. + const Node* producer_node = graph.GetProducerNode(p_input_arg->Name()); + if (producer_node == nullptr || !HasExternalOutputs(*producer_node)) { + if (SameSize(*p_input_arg, *p_output_arg)) { + // we can reuse this input since it is its last use and permitted for in-place update + *reusable_input = input_arg_index; // or original; both should be okay + return true; + } + } else { + LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node " + << producer_node->Name() << " as it has external outputs"; } } } @@ -395,10 +415,17 @@ class PlannerImpl { break; } } + if (can_strided) { - *reusable_input = Index(input_args[pair.first]->Name()); - *is_strided_tensor = true; - return true; + const Node* producer_node = graph.GetProducerNode(input_args[pair.first]->Name()); + if (producer_node == nullptr || !HasExternalOutputs(*producer_node)) { + *reusable_input = Index(input_args[pair.first]->Name()); + *is_strided_tensor = true; + return true; + } else { + LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node " + << producer_node->Name() << " as it has external outputs."; + } } } } @@ -604,13 +631,12 @@ class PlannerImpl { auto outputs = pnode->OutputDefs(); auto num_outputs = outputs.size(); - bool has_external_outputs = HasExternalOutputs(*pnode); for (size_t i = 0; i < num_outputs; ++i) { auto* node_output = outputs[i]; if (!node_output->Exists()) continue; OrtValueIndex index = Index(node_output->Name()); - // Ensures external outputs will not be reused. - UseCount(index) += (has_external_outputs ? 2 : 1); + + UseCount(index) += 1; } } } @@ -1079,7 +1105,8 @@ class PlannerImpl { int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); auto origin = AllocPlan(value_idx).reused_buffer; - if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) { + if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate || + AllocPlan(origin).alloc_kind == AllocKind::kAllocatedExternally) { // add current node as consumer for origin buffer value_consumer_map[origin].insert(node_index); } @@ -1131,6 +1158,11 @@ class PlannerImpl { if ((0 <= pair.first) && (static_cast(pair.first) < input_args.size())) { auto p_input_arg = input_args[pair.first]; if (p_input_arg->Exists()) { + // If the producer node does not has external outputs, we can reuse the input buffer; + // Otherwise, we cannot reuse the buffer. 
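+ // Worked example for the UseCount hunk above (toy graph): take A -> B, where A
+ // has external outputs and B is the sole consumer of A's output X. Before this
+ // patch, X's use count was padded by an extra 2 ("Ensures external outputs will
+ // not be reused"), so the count could never drain to zero and X was simply never
+ // released. After this patch, X is counted like any other value: once B has run,
+ // the OrtValue's reference to the external buffer is dropped (the buffer itself
+ // remains owned externally).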
+ const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); + ORT_ENFORCE(producer_node == nullptr || !HasExternalOutputs(*producer_node), + "Cannot alias reuse input buffer for ", p_output_arg->Name(), " as input ", p_input_arg->Name()); OrtValueIndex reusable_input{}; if (value_map.GetIdx(p_input_arg->Name(), reusable_input).IsOK() /*&& allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate*/ @@ -1163,6 +1195,11 @@ class PlannerImpl { auto p_input_arg = input_args[alias_input_index]; if (p_input_arg->Exists()) { + // If the producer node does not has external outputs, we can reuse the input buffer; + // Otherwise, we cannot reuse the buffer. + const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); + ORT_ENFORCE(producer_node == nullptr || !HasExternalOutputs(*producer_node), + "Cannot reuse input buffer for ", p_output_arg->Name(), " as input ", p_input_arg->Name()); OrtValueIndex reusable_input{}; if (value_map.GetIdx(p_input_arg->Name(), reusable_input).IsOK() && allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate) { @@ -1172,8 +1209,8 @@ class PlannerImpl { value_consumer_map[output_idx_global].end()); reused.insert(reusable_input); continue; - } // if - } // if + } + } } } @@ -1184,16 +1221,24 @@ class PlannerImpl { if ((0 <= pair.first) && (static_cast(pair.first) < input_args.size())) { auto p_input_arg = input_args[pair.first]; if (p_input_arg->Exists()) { - OrtValueIndex input_arg_index{}; - if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() && - allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) { - if (value_consumer_map[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) { - allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; - allocation_plan[output_idx_global].reused_buffer = input_arg_index; - value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(), - value_consumer_map[output_idx_global].end()); - reused.insert(input_arg_index); + // If the producer node does not has external outputs, we can reuse the input buffer; + // Otherwise, we cannot reuse the buffer. 
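+ // A concrete failure these checks guard against (hypothetical graph; on the
+ // training side, PythonOp is one kernel registered with ExternalOutputs()):
+ //
+ //   X = PythonOp(...)   // X wraps a torch-owned buffer surfaced via dlpack
+ //   Y = Reshape(X)      // Reshape aliases its input buffer
+ //
+ // If Y were planned as an alias of X, ORT's plan would treat torch-owned memory
+ // as its own and could recycle it while the torch side still uses it.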
+ const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); + if (producer_node == nullptr || !HasExternalOutputs(*producer_node)) { + OrtValueIndex input_arg_index{}; + if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() && + allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) { + if (value_consumer_map[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) { + allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; + allocation_plan[output_idx_global].reused_buffer = input_arg_index; + value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); + reused.insert(input_arg_index); + } } + } else { + LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node " + << producer_node->Name() << " as it has external outputs"; } } } @@ -1439,7 +1484,8 @@ class PlannerImpl { } } } else if (!context_->IsParallelExecutionEnabled() && - FindReusableInput(*pnode, static_cast(output_arg_def_index), &reused, &is_strided_tensor)) { + FindReusableInput(graph_viewer_, *pnode, static_cast(output_arg_def_index), + &reused, &is_strided_tensor)) { // Re-using inputs is applicable for tensors, sequence tensors, // and optional types if the kernel has marked certain inputs as // possible candidates for re-use @@ -1522,7 +1568,8 @@ class PlannerImpl { if (!node_output->Exists()) continue; // OrtValue index of the considered output NodeArg. const auto current = Index(node_output->Name()); - if (AllocPlan(current).alloc_kind == AllocKind::kAllocate) { + if (AllocPlan(current).alloc_kind == AllocKind::kAllocate || + AllocPlan(current).alloc_kind == AllocKind::kAllocatedExternally) { AllocPlan(current).program_counter.AddStart(program_counter); } } @@ -1570,8 +1617,7 @@ class PlannerImpl { // OrtValue index of the considered output NodeArg. 
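// The kAllocatedExternally additions in this patch extend the same alloc-kind
// predicate at several call sites; as a sketch (illustrative helper, not an ORT
// function):
//
//   bool TrackedForRelease(AllocKind kind) {
//     // Externally allocated values now take part in consumer tracking and
//     // program-counter bookkeeping just like normally allocated ones, so their
//     // references can be dropped once their last consumer finishes.
//     return kind == AllocKind::kAllocate || kind == AllocKind::kAllocatedExternally;
//   }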
const auto current = Index(node_output->Name()); AllocPlan(current).life_interval.first = program_counter; - if (AllocPlan(current).alloc_kind == AllocKind::kAllocatedExternally || - AllocPlan(current).alloc_kind == AllocKind::kAllocateOutput) { + if (AllocPlan(current).alloc_kind == AllocKind::kAllocateOutput) { AllocPlan(current).life_interval.second = execution_plan.size(); } // determine if inputs of *pnode can be freed: @@ -1695,7 +1741,8 @@ class PlannerImpl { int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); auto origin = AllocPlan(value_idx).reused_buffer; - if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) { + if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate || + AllocPlan(origin).alloc_kind == AllocKind::kAllocatedExternally) { // add current node as consumer for origin buffer value_consumers[origin].push_back(node_index); } diff --git a/onnxruntime/core/framework/partial_graph_execution_state.cc b/onnxruntime/core/framework/partial_graph_execution_state.cc index 8a749008945c0..db27e249f24f8 100644 --- a/onnxruntime/core/framework/partial_graph_execution_state.cc +++ b/onnxruntime/core/framework/partial_graph_execution_state.cc @@ -59,8 +59,10 @@ DeviceStreamCollection* PartialGraphExecutionState::GetDeviceStreamCollection(co return device_stream_collection_.get(); } -StreamExecutionContext& PartialGraphExecutionState::GetExecutionContext(gsl::span& feed_mlvalue_idxs, gsl::span& feeds, - gsl::span& fetch_mlvalue_idxs, std::vector& fetches, +StreamExecutionContext& PartialGraphExecutionState::GetExecutionContext(gsl::span& feed_mlvalue_idxs, + std::vector& feeds, + gsl::span& fetch_mlvalue_idxs, + std::vector& fetches, const std::unordered_map& fetch_allocators, const SessionState& session_state, const logging::Logger& sess_logger, diff --git a/onnxruntime/core/framework/partial_graph_execution_state.h b/onnxruntime/core/framework/partial_graph_execution_state.h index 80f7e52a9cf68..369fa4061b956 100644 --- a/onnxruntime/core/framework/partial_graph_execution_state.h +++ b/onnxruntime/core/framework/partial_graph_execution_state.h @@ -30,7 +30,7 @@ struct PartialGraphExecutionState { ProgramRegion& GetProgramRegions(const SessionState& session_state); - StreamExecutionContext& GetExecutionContext(gsl::span& feed_mlvalue_idxs, gsl::span& feeds, + StreamExecutionContext& GetExecutionContext(gsl::span& feed_mlvalue_idxs, std::vector& feeds, gsl::span& fetch_mlvalue_idxs, std::vector& fetches, const std::unordered_map& fetch_allocators, const SessionState& session_state, diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc index 0cc7294a46495..8062f6baa18bd 100644 --- a/onnxruntime/core/framework/sequential_executor.cc +++ b/onnxruntime/core/framework/sequential_executor.cc @@ -615,7 +615,8 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span< #ifdef ENABLE_TRAINING onnxruntime::Status PartialExecuteThePlan(const SessionState& session_state, gsl::span feed_mlvalue_idxs, - gsl::span feeds, gsl::span fetch_mlvalue_idxs, + std::vector& feeds, + gsl::span fetch_mlvalue_idxs, std::vector& fetches, const std::unordered_map& fetch_allocators, @@ -628,6 +629,8 @@ onnxruntime::Status PartialExecuteThePlan(const SessionState& session_state, gsl int32_t partial_graph_index) { auto& ctx = state.GetExecutionContext(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs, fetches, fetch_allocators, session_state, logger, device_streams); + feeds.clear(); // Release the 
feeds at the earliest convenience. + auto* plan = session_state.GetExecutionPlan(); ctx.SetCurrentRange(&state.GetProgramRegions(session_state)); diff --git a/onnxruntime/core/framework/sequential_executor.h b/onnxruntime/core/framework/sequential_executor.h index 9bef4be34fd80..c22ccb041afdf 100644 --- a/onnxruntime/core/framework/sequential_executor.h +++ b/onnxruntime/core/framework/sequential_executor.h @@ -50,7 +50,7 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span< #ifdef ENABLE_TRAINING onnxruntime::Status PartialExecuteThePlan(const SessionState& session_state, gsl::span feed_mlvalue_idxs, - gsl::span feeds, gsl::span fetch_mlvalue_idxs, + std::vector& feeds, gsl::span fetch_mlvalue_idxs, std::vector& fetches, const std::unordered_map& fetch_allocators, const logging::Logger& logger, diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 0c4d498fae9e0..9c282210d2169 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -784,7 +784,7 @@ common::Status ExecuteGraph(const SessionState& session_state, #ifdef ENABLE_TRAINING common::Status ExecutePartialGraphImpl(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager, - gsl::span feeds, std::vector& fetches, + std::vector& feeds, std::vector& fetches, const logging::Logger& logger, PartialGraphExecutionState& state, const OrtValueCachePtr& cache, const bool& terminate_flag, DeviceStreamCollection* device_stream_collection, @@ -882,7 +882,7 @@ common::Status ExecutePartialGraphImpl(const SessionState& session_state, FeedsF } common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager, - gsl::span feeds, std::vector& fetches, + std::vector& feeds, std::vector& fetches, const logging::Logger& logger, PartialGraphExecutionState& state, const OrtValueCachePtr& cache, const bool& terminate_flag, int32_t partial_graph_index, diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index f0b1b9109d405..5b948e4887978 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -100,7 +100,7 @@ common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManag #ifdef ENABLE_TRAINING common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager, - gsl::span feeds, std::vector& fetches, + std::vector& feeds, std::vector& fetches, const logging::Logger& logger, PartialGraphExecutionState& state, const OrtValueCachePtr& cache, const bool& terminate_flag, diff --git a/onnxruntime/core/optimizer/gemm_sum_fusion.cc b/onnxruntime/core/optimizer/gemm_sum_fusion.cc index 3f2c1ac046105..be3c90a822fe2 100644 --- a/onnxruntime/core/optimizer/gemm_sum_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_sum_fusion.cc @@ -36,7 +36,7 @@ Status GemmSumFusion::Apply(Graph& graph, Node& gemm_node, RewriteRuleEffect& mo std::vector new_gemm_output_defs = sum_node.MutableOutputDefs(); ORT_ENFORCE(new_gemm_output_defs.size() == 1); - Node& new_gemm_node = graph.AddNode(graph.GenerateNodeName(gemm_node.Name() + "_sum_transformed"), + Node& new_gemm_node = graph.AddNode(graph.GenerateNodeName(gemm_node.Name() + "/GemmSumFusion/"), gemm_node.OpType(), "Fused Gemm with Sum", new_gemm_input_defs, diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc index ad77f4d143d31..2a4916ccb324a 100644 --- 
a/onnxruntime/core/optimizer/matmul_add_fusion.cc
+++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc
@@ -106,7 +106,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
 continue;
 }

- Node& gemm_node = graph.AddNode(graph.GenerateNodeName("gemm"),
+ Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion/"),
 "Gemm",
 "fused Matmul and Add " + add_node.OpType(),
 gemm_input_defs,
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index d1add79f0cb00..04149e11eab6d 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -2302,7 +2302,7 @@ common::Status InferenceSession::ValidateOutputs(gsl::span ou
 #ifdef ENABLE_TRAINING
 Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options,
- const std::vector& feeds,
+ std::vector& feeds,
 std::vector& fetches,
 PartialGraphExecutionState& state,
 FeedsFetchesManager& feeds_fetches_manager,
diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h
index 48f6d73b077cb..204b24974ff50 100644
--- a/onnxruntime/core/session/inference_session.h
+++ b/onnxruntime/core/session/inference_session.h
@@ -377,8 +377,8 @@ class InferenceSession {
 /**
 * Partially run a pre-loaded and pre-intialized model.
 * @param run_options run options.
- * @param feeds inputs owned by client code and should not be changed during
- * execution of this function.
+ * @param mutable_feeds inputs owned by client code; each feed is handed over to the session state and released
+ * as soon as it is no longer needed during execution. After this call the feeds are managed purely by the session state.
 * @param fetches outputs produced after the executin of this function.
 * @param state State of the graph needed to resume partial graph run.
 * @param feeds_fetches_manager Contains feed/fetches name to internal indices mapping and information for device
 * @param cache Contains node arg name to OrtValue map stashed from previous run
 * @param partial_graph_index Index of the partial graph to run.
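+ *
+ * A hypothetical calling sequence under the new contract (names illustrative):
+ * @code
+ *   std::vector<OrtValue> feeds = ...;  // e.g. torch tensors bridged via dlpack
+ *   std::vector<OrtValue> fetches;
+ *   ORT_RETURN_IF_ERROR(session.PartialRun(run_options, feeds, fetches, state,
+ *                                          feeds_fetches_manager, cache, partial_graph_index));
+ *   // feeds is now empty; each input was released once its last consumer ran.
+ * @endcode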
*/ common::Status PartialRun(onnxruntime::RunOptions& run_options, - const std::vector& feeds, + std::vector& mutable_feeds, std::vector& fetches, PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager, diff --git a/orttraining/orttraining/core/agent/training_agent.cc b/orttraining/orttraining/core/agent/training_agent.cc index cc8d341dbc084..314c5744bb6b9 100644 --- a/orttraining/orttraining/core/agent/training_agent.cc +++ b/orttraining/orttraining/core/agent/training_agent.cc @@ -60,7 +60,7 @@ TrainingAgent::TrainingAgent(InferenceSession& session, TrainingAgent::~TrainingAgent() = default; -common::Status TrainingAgent::RunForward(const std::vector& feeds, std::vector& fetches, +common::Status TrainingAgent::RunForward( std::vector& feeds, std::vector& fetches, PartialGraphExecutionState& state, const OrtValueCachePtr& cache) { #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) inference_session_.GetMemoryProfiler().GetMemoryInfo().SetIteration(profile_step_); @@ -74,7 +74,7 @@ common::Status TrainingAgent::RunForward(const std::vector& feeds, std return RunCore(feeds, fetches, state, *fw_feeds_fetches_manager_, cache, partial_graph_index); } -common::Status TrainingAgent::RunBackward(const std::vector& feeds, std::vector& fetches, +common::Status TrainingAgent::RunBackward( std::vector& feeds, std::vector& fetches, PartialGraphExecutionState& state) { state.SetProgramCounterStart(fw_program_counter_end_); state.SetProgramCounterEnd(bw_program_counter_end_); @@ -82,7 +82,7 @@ common::Status TrainingAgent::RunBackward(const std::vector& feeds, st return RunCore(feeds, fetches, state, *bw_feeds_fetches_manager_, nullptr, partial_graph_index); } -common::Status TrainingAgent::RunCore(const std::vector& feeds, std::vector& fetches, +common::Status TrainingAgent::RunCore( std::vector& feeds, std::vector& fetches, PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager, const OrtValueCachePtr& cache, int32_t partial_graph_index) { auto fetches_size = feeds_fetches_manager.GetFeedsFetchesInfo().output_names.size(); diff --git a/orttraining/orttraining/core/agent/training_agent.h b/orttraining/orttraining/core/agent/training_agent.h index 8d88a6df39352..2b74e56ae60e2 100644 --- a/orttraining/orttraining/core/agent/training_agent.h +++ b/orttraining/orttraining/core/agent/training_agent.h @@ -32,14 +32,14 @@ class TrainingAgent { int local_rank = 0); ~TrainingAgent(); // For ORTModule.forward() - [[nodiscard]] common::Status RunForward(const std::vector& feeds, std::vector& fetches, + [[nodiscard]] common::Status RunForward(std::vector& mutable_feeds, std::vector& fetches, PartialGraphExecutionState& state, const OrtValueCachePtr& cache); // For ORTModule.backward() - [[nodiscard]] common::Status RunBackward(const std::vector& feeds, std::vector& fetches, + [[nodiscard]] common::Status RunBackward(std::vector& mutable_feeds, std::vector& fetches, PartialGraphExecutionState& state); - [[nodiscard]] common::Status RunCore(const std::vector& feeds, std::vector& fetches, + [[nodiscard]] common::Status RunCore(std::vector& mutable_feeds, std::vector& fetches, PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager, const OrtValueCachePtr& cache, int32_t partial_graph_index); diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 5ea60102f3ef8..e2616e1b441f7 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ 
b/orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -423,14 +423,18 @@ void addObjectMethodsForTraining(py::module& m) {
 return std::make_unique(*session->GetSessionHandle(), fw_feed_names, fw_outputs_device_info, bw_fetches_names, bw_outputs_device_info, local_rank);
 }))
- .def("run_forward", [](TrainingAgent* agent, const std::vector& feeds, std::vector& fetches, PartialGraphExecutionState* state, OrtValueCachePtr cache) -> void {
- Status status = agent->RunForward(feeds, fetches, *state, cache);
+ .def("run_forward", [](TrainingAgent* agent, std::vector& mutable_feeds, std::vector& fetches, PartialGraphExecutionState* state, OrtValueCachePtr cache) -> void {
+ // Feeds are passed as a mutable reference so the internal logic can release each feed as soon as it is no longer needed.
+ // Otherwise, the feeds would only be released after the forward pass completes, holding memory longer than necessary.
+ Status status = agent->RunForward(mutable_feeds, fetches, *state, cache);
 if (!status.IsOK()) {
 throw std::runtime_error("Error in forward pass execution: " + status.ErrorMessage());
 }
 })
- .def("run_backward", [](TrainingAgent* agent, const std::vector& feeds, std::vector& fetches, PartialGraphExecutionState* state) -> void {
- Status status = agent->RunBackward(feeds, fetches, *state);
+ .def("run_backward", [](TrainingAgent* agent, std::vector& mutable_feeds, std::vector& fetches, PartialGraphExecutionState* state) -> void {
+ // Feeds are passed as a mutable reference so the internal logic can release each feed as soon as it is no longer needed.
+ // Otherwise, the feeds would only be released after the backward pass completes, holding memory longer than necessary.
+ Status status = agent->RunBackward(mutable_feeds, fetches, *state);
 if (!status.IsOK()) {
 throw std::runtime_error("Error in backward pass execution: " + status.ErrorMessage());
 }

From e684871838b7dda0aab1d2780c76948074952ada Mon Sep 17 00:00:00 2001
From: "Peng Wang(AI FWK)"
Date: Fri, 24 May 2024 02:59:51 +0000
Subject: [PATCH 2/7] minor

---
 onnxruntime/core/framework/allocation_planner.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index c7f9826a7ab8a..3edf7fdda908a 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -1237,7 +1237,7 @@ class PlannerImpl {
 }
 }
 } else {
- LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node "
+ LOGS_DEFAULT(VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node "
 << producer_node->Name() << " as it has external outputs";
 }
 }

From 8e4d34a0f6be0def38d19e2db9d880fac786eb11 Mon Sep 17 00:00:00 2001
From: "Peng Wang(AI FWK)"
Date: Tue, 11 Jun 2024 12:38:27 +0000
Subject: [PATCH 3/7] fix uts

---
 .../core/framework/allocation_planner.cc | 52 ++++++++++++-------
 .../test/framework/allocation_planner_test.cc | 40 ++++++++++++--
 .../orttraining/core/agent/training_agent.cc | 6 +--
 3 files changed, 72 insertions(+), 26 deletions(-)

diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index 3edf7fdda908a..5db867fbd3511 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -1514,6 +1514,12 @@ class PlannerImpl {
 // determine if inputs of *pnode can be freed:
 for (auto node_input : pnode->InputDefs()) {
 if (node_input->Exists()) {
+ const Node* producer_node =
graph_viewer_.GetProducerNode(node_input->Name()); + // Skip if the producer node has external outputs. + if (producer_node != nullptr && HasExternalOutputs(*producer_node)) { + continue; + } + auto& sym = node_input->Name(); auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. @@ -1526,6 +1532,12 @@ class PlannerImpl { for (auto node_input : pnode->ImplicitInputDefs()) { if (node_input->Exists()) { + const Node* producer_node = graph_viewer_.GetProducerNode(node_input->Name()); + // Skip if the producer node has external outputs. + if (producer_node != nullptr && HasExternalOutputs(*producer_node)) { + continue; + } + auto& sym = node_input->Name(); auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. @@ -1536,15 +1548,17 @@ class PlannerImpl { } } - // determine if any outputs of *pnode are unused and can be freed: - for (auto node_output : pnode->OutputDefs()) { - if (node_output->Exists()) { - auto& sym = node_output->Name(); - auto original = Buffer(Index(sym)); - // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. - // See comments in the OrtValueInfo definition. - if (0 == DecrementUseCount(original)) { - freelist_.push_front(FreeBufferInfo(original, program_counter)); + if (!HasExternalOutputs(*pnode)) { + // determine if any outputs of *pnode are unused and can be freed: + for (auto node_output : pnode->OutputDefs()) { + if (node_output->Exists()) { + auto& sym = node_output->Name(); + auto original = Buffer(Index(sym)); + // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. + // See comments in the OrtValueInfo definition. + if (0 == DecrementUseCount(original)) { + freelist_.push_front(FreeBufferInfo(original, program_counter)); + } } } } @@ -1725,9 +1739,9 @@ class PlannerImpl { // Convert information in execution plan and memory reuse plan into release plan Status GenerateDeallocationPlan() { // 1. build the consumer list for each value - std::vector> value_consumers; + std::vector> ortvalue_to_consumers_map; int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1; - value_consumers.resize(num_ml_values); + ortvalue_to_consumers_map.resize(num_ml_values); // iterate each stream from back, so the first element is the last consumer in single stream case for (auto& stream : stream_nodes_) { @@ -1744,7 +1758,7 @@ class PlannerImpl { if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate || AllocPlan(origin).alloc_kind == AllocKind::kAllocatedExternally) { // add current node as consumer for origin buffer - value_consumers[origin].push_back(node_index); + ortvalue_to_consumers_map[origin].push_back(node_index); } } return Status::OK(); @@ -1760,8 +1774,8 @@ class PlannerImpl { plan_.node_release_list[node_index].push_back(release_action_idx); }; plan_.node_release_list.resize(SafeInt(graph_viewer_.MaxNodeIndex()) + 1); - for (size_t i = 0; i < value_consumers.size(); ++i) { - if (!value_consumers[i].empty()) { + for (size_t i = 0; i < ortvalue_to_consumers_map.size(); ++i) { + if (!ortvalue_to_consumers_map[i].empty()) { plan_.release_actions.push_back(SequentialExecutionPlan::ReleaseAction{i, 0}); auto release_action_idx = plan_.release_actions.size() - 1; // check whether we can static determine where to release. 
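Two distinct mechanisms are being adjusted in this file, and the hunks above touch both: the free list feeds ORT's buffer-reuse decisions for later allocations, while the release actions built below drive dropping OrtValue references. A standalone toy sketch of that split, with illustrative names only (not ORT types):

    #include <cstddef>
    #include <deque>
    #include <vector>

    struct FreeBufferInfo {
      int buffer_id;
      std::size_t freed_at_step;
    };

    // Toy model: externally owned buffers never enter the free list, because ORT
    // must not recycle memory it does not own; they still get a release entry so
    // the OrtValue's reference is dropped right after the last consumer runs.
    struct ToyPlan {
      std::deque<FreeBufferInfo> freelist;           // reusable by later allocations
      std::vector<int> release_after_last_consumer;  // reference drops (incl. external)
      void OnLastUse(int buffer_id, std::size_t step, bool externally_owned) {
        if (!externally_owned) freelist.push_front(FreeBufferInfo{buffer_id, step});
        release_after_last_consumer.push_back(buffer_id);
      }
    };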
@@ -1769,19 +1783,19 @@ class PlannerImpl { // we actually can do better if all the consumers depends on the last consumer. // will optimize it later bool is_all_consumer_same_stream = true; - auto stream_idx = plan_.node_stream_map_[value_consumers[i][0]]; - for (size_t j = 1; j < value_consumers[i].size(); ++j) { - if (plan_.node_stream_map_[value_consumers[i][j]] != stream_idx) { + auto stream_idx = plan_.node_stream_map_[ortvalue_to_consumers_map[i][0]]; + for (size_t j = 1; j < ortvalue_to_consumers_map[i].size(); ++j) { + if (plan_.node_stream_map_[ortvalue_to_consumers_map[i][j]] != stream_idx) { is_all_consumer_same_stream = false; break; } } if (is_all_consumer_same_stream) { // all the consumers are on the same stream, so the first element is the last consumer int the stream. - process_consumer(release_action_idx, value_consumers[i][0]); + process_consumer(release_action_idx, ortvalue_to_consumers_map[i][0]); } else { // can't static determin, add all the consumers, we will use ref count in release action - for (auto node_index : value_consumers[i]) { + for (auto node_index : ortvalue_to_consumers_map[i]) { process_consumer(release_action_idx, node_index); } } diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 3a01f2c8d95ad..a4d64a51594a7 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -433,7 +433,7 @@ TEST_F(PlannerTest, ChainTest) { CreatePlan(); // Expected plan: - // W: kAllocateStatically; X: kAllocate; B: kAllocate; Y: kReuse (X); post-node3: free(B); X is returned output + // W: kAllocateStatically; X: kAllocate; B: kAllocate; Y: kReuse (X); post-node3: free(B); Z is returned output CheckAllocKind(W, AllocKind::kAllocateStatically); CheckAllocKind(X, AllocKind::kAllocate); CheckAllocKind(B, AllocKind::kAllocate); @@ -536,10 +536,42 @@ TEST_F(PlannerTest, ExternalOutputsTest) { CheckAllocKind(X4, AllocKind::kAllocateOutput); // check each ml-value is freed at appropriate step - // X2 will not be reused and will not be freed. X3 will be allocated and will be freed. + // X2 will not be reused but will be freed (to release the current reference). X3 will be allocated and will be freed. CheckFreed(0, {}); - CheckFreed(1, {}); - CheckFreed(2, {X1}); + CheckFreed(1, {X2}); + CheckFreed(2, {X3}); +} + +TEST_F(PlannerTest, ExternalOutputsNoReuseTest) { + // tensor variables: + std::string X1("X1"), X2("X2"), X3("X3"), X4("X4"), X5("X5"); + + // graph structure: + AddExternalOutputsNode(X1, X2); // external-outputs operator; X1: input; X2: temporary + AddInplaceNode(X2, X3); // may-in-place operator; X3: temporary + AddNormalNode(X3, X4); // normal operator; X4: temporary + AddNormalNode(X4, X5); // normal operator; X5: output + + // simulate shape-inference results: + Shape shape1{"M", "N"}; + auto shape = &shape1.value; + SetShape({{X1, shape}, {X2, shape}, {X3, shape}, {X4, shape}, {X5, shape}}); + + CreatePlan(); + + // check allocation kind: + CheckAllocKind(X1, AllocKind::kPreExisting); + CheckAllocKind(X2, AllocKind::kAllocatedExternally); + CheckAllocKind(X3, AllocKind::kAllocate); // Should not be Reused. + CheckAllocKind(X4, AllocKind::kAllocate); + CheckAllocKind(X5, AllocKind::kAllocateOutput); + + // check each ml-value is freed at appropriate step + // X2 will not be reused. X3 will be allocated and will be freed. 
+ CheckFreed(0, {});
+ CheckFreed(1, {X2});
+ CheckFreed(2, {X3});
+ CheckFreed(3, {X4});
}

#ifdef ENABLE_STRIDED_TENSORS
diff --git a/orttraining/orttraining/core/agent/training_agent.cc b/orttraining/orttraining/core/agent/training_agent.cc
index 314c5744bb6b9..cc3f6d1ff82ee 100644
--- a/orttraining/orttraining/core/agent/training_agent.cc
+++ b/orttraining/orttraining/core/agent/training_agent.cc
@@ -60,7 +60,7 @@ TrainingAgent::TrainingAgent(InferenceSession& session,

 TrainingAgent::~TrainingAgent() = default;

-common::Status TrainingAgent::RunForward( std::vector& feeds, std::vector& fetches,
+common::Status TrainingAgent::RunForward(std::vector& feeds, std::vector& fetches,
 PartialGraphExecutionState& state, const OrtValueCachePtr& cache) {
 #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
 inference_session_.GetMemoryProfiler().GetMemoryInfo().SetIteration(profile_step_);
@@ -74,7 +74,7 @@ common::Status TrainingAgent::RunForward( std::vector& feeds, std::vec
 return RunCore(feeds, fetches, state, *fw_feeds_fetches_manager_, cache, partial_graph_index);
 }

-common::Status TrainingAgent::RunBackward( std::vector& feeds, std::vector& fetches,
+common::Status TrainingAgent::RunBackward(std::vector& feeds, std::vector& fetches,
 PartialGraphExecutionState& state) {
 state.SetProgramCounterStart(fw_program_counter_end_);
 state.SetProgramCounterEnd(bw_program_counter_end_);
@@ -82,7 +82,7 @@ common::Status TrainingAgent::RunBackward( std::vector& feeds, std::ve
 return RunCore(feeds, fetches, state, *bw_feeds_fetches_manager_, nullptr, partial_graph_index);
 }

-common::Status TrainingAgent::RunCore( std::vector& feeds, std::vector& fetches,
+common::Status TrainingAgent::RunCore(std::vector& feeds, std::vector& fetches,
 PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager,
 const OrtValueCachePtr& cache, int32_t partial_graph_index) {
 auto fetches_size = feeds_fetches_manager.GetFeedsFetchesInfo().output_names.size();

From 5b00bebc521204051e3222e89253b32787f6bdab Mon Sep 17 00:00:00 2001
From: "Peng Wang(AI FWK)"
Date: Tue, 11 Jun 2024 13:48:17 +0000
Subject: [PATCH 4/7] fix uts and add comments

---
 .../core/framework/kernel_def_builder.h | 10 ++++
 .../partial_graph_execution_state.cc | 4 +-
 .../framework/partial_graph_execution_state.h | 1 +
 .../core/framework/sequential_executor.cc | 2 +-
 .../test/framework/allocation_planner_test.cc | 54 ++++++++++++-------
 5 files changed, 50 insertions(+), 21 deletions(-)

diff --git a/include/onnxruntime/core/framework/kernel_def_builder.h b/include/onnxruntime/core/framework/kernel_def_builder.h
index de5a3a52f5be7..baccbe1929ac4 100644
--- a/include/onnxruntime/core/framework/kernel_def_builder.h
+++ b/include/onnxruntime/core/framework/kernel_def_builder.h
@@ -253,6 +253,16 @@ class KernelDefBuilder {
 /**
 Specify that this kernel's output buffers are passed from external,
 i.e. not created or managed by ORT's memory allocator.
+
+ An OrtValue set as an external output must be safe to release once its reference
+ count in ORT's allocation/deallocation plan reaches zero. We usually create such an OrtValue
+ through the following flow: torch tensor --> dlpack tensor (whose destructor releases a view of the
+ original torch tensor, instead of the original torch tensor itself) --> OrtValue.
+
+ When the OrtValue is no longer needed in the graph, it is released by calling the attached
+ destructor. The destructor releases the view of the original torch tensor, instead of the original
+ torch tensor itself. This ensures the original torch tensor is still safe to use externally,
+ even after the OrtValue is released in the graph. (Recall that such an OrtValue is also never reused by ORT.)
 */
 KernelDefBuilder& ExternalOutputs() {
 kernel_def_->external_outputs_ = true;
diff --git a/onnxruntime/core/framework/partial_graph_execution_state.cc b/onnxruntime/core/framework/partial_graph_execution_state.cc
index db27e249f24f8..a053634adbe35 100644
--- a/onnxruntime/core/framework/partial_graph_execution_state.cc
+++ b/onnxruntime/core/framework/partial_graph_execution_state.cc
@@ -83,7 +83,7 @@ StreamExecutionContext& PartialGraphExecutionState::GetExecutionContext(gsl::spa
 execution_plan->num_barriers,
 device_streams,
 feed_mlvalue_idxs,
- feeds,
+ std::move(feeds),
 fetch_mlvalue_idxs,
 fetches,
 fetch_allocators,
@@ -91,7 +91,7 @@ StreamExecutionContext& PartialGraphExecutionState::GetExecutionContext(gsl::spa
 // partial executor in training can only be run with single thread
 true);
 } else {
- execution_context_->GetExecutionFrame().UpdateFeeds(feed_mlvalue_idxs, feeds);
+ execution_context_->GetExecutionFrame().UpdateFeeds(feed_mlvalue_idxs, std::move(feeds));
 execution_context_->GetExecutionFrame().UpdateFetches(fetch_mlvalue_idxs, fetches, session_state.GetInitializedTensors());
 execution_context_->SetLogger(sess_logger);
 }
diff --git a/onnxruntime/core/framework/partial_graph_execution_state.h b/onnxruntime/core/framework/partial_graph_execution_state.h
index 369fa4061b956..adfb5f8f5fca7 100644
--- a/onnxruntime/core/framework/partial_graph_execution_state.h
+++ b/onnxruntime/core/framework/partial_graph_execution_state.h
@@ -30,6 +30,7 @@ struct PartialGraphExecutionState {

 ProgramRegion& GetProgramRegions(const SessionState& session_state);

+ // Note: feeds is moved into the returned context, so it will be empty after this function returns.
 StreamExecutionContext& GetExecutionContext(gsl::span& feed_mlvalue_idxs, std::vector& feeds,
 gsl::span& fetch_mlvalue_idxs, std::vector& fetches,
 const std::unordered_map& fetch_allocators,
 const SessionState& session_state,
diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc
index 8062f6baa18bd..a374e381a2b0e 100644
--- a/onnxruntime/core/framework/sequential_executor.cc
+++ b/onnxruntime/core/framework/sequential_executor.cc
@@ -627,9 +627,9 @@ onnxruntime::Status PartialExecuteThePlan(const SessionState& session_state, gsl
 PartialGraphExecutionState& state,
 const OrtValueCachePtr& cache,
 int32_t partial_graph_index) {
+ // Note: feeds is moved into the ctx, so it will be empty after this function returns.
 auto& ctx = state.GetExecutionContext(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs, fetches, fetch_allocators,
 session_state, logger, device_streams);
- feeds.clear(); // Release the feeds at the earliest convenience.
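+ // With the std::move just above, the removed clear() is redundant: a moved-from
+ // std::vector is left in a valid but unspecified state, and the storage-stealing
+ // move used here leaves it empty in practice. A tiny illustration:
+ //
+ //   std::vector<OrtValue> feeds = ...;
+ //   std::vector<OrtValue> sink(std::move(feeds));
+ //   assert(feeds.empty());  // holds here, though the standard only promises
+ //                           // "valid but unspecified" for the moved-from vector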
auto* plan = session_state.GetExecutionPlan();
diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc
index a4d64a51594a7..9f9455f5cf723 100644
--- a/onnxruntime/test/framework/allocation_planner_test.cc
+++ b/onnxruntime/test/framework/allocation_planner_test.cc
@@ -370,8 +370,9 @@ class PlannerTest : public ::testing::Test {
 EXPECT_EQ(plan_->execution_plan.size(), 1U);
 int list_size = static_cast(plan_->node_release_list.size());
 EXPECT_GT(list_size, step_number);
- for (auto freed : plan_->node_release_list[step_number]) {
- plan_result.insert(static_cast(freed));
+ for (auto action_idx : plan_->node_release_list[step_number]) {
+ const size_t ortvalue_id = plan_->release_actions[action_idx].value_index;
+ plan_result.insert(static_cast(ortvalue_id));
 }
 EXPECT_EQ(plan_result, expected) << "Freed items incorrect for step " << step_number;
 }
@@ -442,8 +443,8 @@ TEST_F(PlannerTest, ChainTest) {
 CheckFreed(0, {});
 CheckFreed(1, {});
- CheckFreed(2, {"X"});
- CheckFreed(3, {"W"});
+ CheckFreed(2, {B});
+ CheckFreed(3, {X});
 }

 /* InputOutputTest: Test that:
@@ -510,7 +511,7 @@ TEST_F(PlannerTest, InPlaceTest) {
 // check each ml-value is freed at appropriate step
 CheckFreed(0, {});
 CheckFreed(1, {});
- CheckFreed(2, {X1});
+ CheckFreed(2, {X2});
 }

 TEST_F(PlannerTest, ExternalOutputsTest) {
@@ -596,9 +597,9 @@ TEST_F(PlannerTest, MayStridedTest1) {
 CheckAllocKind(X3, AllocKind::kAllocateOutput);

 // check each ml-value is freed at appropriate step
- // X2 will not be reused and will not be freed. X3 will be allocated and will be freed.
+ // X2 will not be reused because X3 is a graph output. X3 will be allocated and will be freed.
 CheckFreed(0, {});
- CheckFreed(1, {X1});
+ CheckFreed(1, {X2});
 }

 TEST_F(PlannerTest, MayStridedTest2) {
@@ -606,9 +607,9 @@ TEST_F(PlannerTest, MayStridedTest2) {
 std::string X1("X1"), X2("X2"), X3("X3"), X4("X4");

 // graph structure:
- AddMayStridedOutputNode(X1, X2);
- AddMayStridedInputNode(X2, X3);
- AddMayStridedInputNode(X2, X4);
+ AddMayStridedOutputNode(X1, X2); // X2 can reuse X1, and is a strided output.
+ AddMayStridedInputNode(X2, X3); // X3 is a graph output, cannot reuse.
+ AddMayStridedInputNode(X2, X4); // X4 is a graph output, cannot reuse.

 // simulate shape-inference results:
 Shape shape1{"M", "N"};
@@ -635,8 +636,9 @@ TEST_F(PlannerTest, MayStridedTest3) {
 std::string X1("X1"), X2("X2"), X3("X3"), X4("X4");

 // graph structure:
- AddMayStridedOutputNode(X1, X2);
- AddMayStridedInputNode(X2, X3);
+ AddMayStridedOutputNode(X1, X2); // X2 cannot reuse X1 as a strided tensor because
+ // one of X2's consumers does not support strided input. So X2 must be allocated.
+ AddMayStridedInputNode(X2, X3); // X3 is a graph output, cannot reuse.
 AddNormalNode(X2, X4);

 // simulate shape-inference results:
@@ -652,11 +654,27 @@ TEST_F(PlannerTest, MayStridedTest3) {
 CheckAllocKind(X3, AllocKind::kAllocateOutput);
 CheckAllocKind(X4, AllocKind::kAllocateOutput);

+ // Note: the last two nodes added can run in two different orders, so we need to determine the exact
+ // order planned before we know the expected free order.
+ const GraphViewer& graph_viewer = GetState().GetGraphViewer();
+ // Normal node index is 2.
+ bool does_normal_node_run_at_last = graph_viewer.GetNodesInTopologicalOrder()[2] == 2;
+
 // check each ml-value is freed at appropriate step
- // X2 will not be reused and will not be freed. X3 will be allocated and will be freed.
 CheckFreed(0, {});
- CheckFreed(1, {X1});
- CheckFreed(2, {});
+
+ if (does_normal_node_run_at_last) {
+ // The normal node has node index 2, but it may be executed after the strided node.
+ // In that case X2 will be released once the normal node has executed.
+ CheckFreed(1, {});
+ CheckFreed(2, {X2});
+ } else {
+ // The normal node has node index 2 and is executed before the strided node,
+ // so X2 will be released after the strided node (node index 1) has executed.
+ CheckFreed(2, {});
+ CheckFreed(1, {X2});
+ }
 }
 #endif

@@ -669,7 +687,7 @@ TEST_F(PlannerTest, InPlaceSizeMismatchTest) {
 // graph structure:
 AddNormalNode(X1, X2); // no in-place operator; X1: input; X2: temporary
 AddInplaceNode(X2, X3); // may-in-place operator; X3: temporary
- AddNormalNode(X3, X4); // no in-place operator; X4: temporary
+ AddNormalNode(X3, X4); // no in-place operator; X4: temporary (reuse X2)
 AddInplaceNode(X4, X5); // may-in-place operator; X5 output

 // simulate shape-inference results:
@@ -691,8 +709,8 @@ TEST_F(PlannerTest, InPlaceSizeMismatchTest) {
 // check each ml-value is freed at appropriate step
 CheckFreed(0, {});
 CheckFreed(1, {});
- CheckFreed(2, {X2});
- CheckFreed(3, {X1});
+ CheckFreed(2, {X3});
+ CheckFreed(3, {X2});
 }

 // Test operator<< to output details of an allocation & execution plan.

From e82780a4d6e00fd9cc141f13877e44c2fb95a9b7 Mon Sep 17 00:00:00 2001
From: "Peng Wang(AI FWK)"
Date: Wed, 12 Jun 2024 05:06:27 +0000
Subject: [PATCH 5/7] fix tests

---
 .../core/framework/allocation_planner.cc | 32 +++++++++++++------
 1 file changed, 22 insertions(+), 10 deletions(-)

diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index 5db867fbd3511..b2f8c24ded2b9 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -329,9 +329,11 @@ class PlannerImpl {
 // If the producer node does not have external output, then we can reuse the input buffer; Otherwise,
 // we cannot.
 const Node* producer_node = graph.GetProducerNode(p_input_arg->Name());
- ORT_ENFORCE(producer_node == nullptr || !HasExternalOutputs(*producer_node),
- "Node ", node.Name(), " cannot reuse input buffer for node ", producer_node->Name(),
- " as it has external outputs, which cannot be reused.");
+ if (producer_node && HasExternalOutputs(*producer_node)) {
+ LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
+ << producer_node->Name() << " which has external outputs. "
+ << "Be cautious the reuse MUST be a read-only usage.";
+ }
 *reusable_input = Index(p_input_arg->Name());
 return true;
 }
@@ -351,9 +353,11 @@ class PlannerImpl {
 // If the producer node does not have external output, then we can reuse the input buffer; Otherwise,
 // we cannot.
 const Node* producer_node = graph.GetProducerNode(p_input_arg->Name());
- ORT_ENFORCE(producer_node == nullptr || !HasExternalOutputs(*producer_node),
- "Node ", node.Name(), " cannot reuse input buffer for node ", producer_node->Name(),
- " as it has external outputs, which cannot be reused.");
+ if (producer_node && HasExternalOutputs(*producer_node)) {
+ LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
+ << producer_node->Name() << " which has external outputs. 
" + << "Be cautious the reuse MUST be a read-only usage."; + } *reusable_input = Index(p_input_arg->Name()); return true; } @@ -1161,8 +1165,12 @@ class PlannerImpl { // If the producer node does not has external outputs, we can reuse the input buffer; // Otherwise, we cannot reuse the buffer. const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); - ORT_ENFORCE(producer_node == nullptr || !HasExternalOutputs(*producer_node), - "Cannot alias reuse input buffer for ", p_output_arg->Name(), " as input ", p_input_arg->Name()); + if (producer_node && HasExternalOutputs(*producer_node)) { + LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " + << producer_node->Name() << " which has external outputs is reused. " + << "Be cautious the reuse MUST be a read-only usage."; + } + OrtValueIndex reusable_input{}; if (value_map.GetIdx(p_input_arg->Name(), reusable_input).IsOK() /*&& allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate*/ @@ -1198,8 +1206,12 @@ class PlannerImpl { // If the producer node does not has external outputs, we can reuse the input buffer; // Otherwise, we cannot reuse the buffer. const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); - ORT_ENFORCE(producer_node == nullptr || !HasExternalOutputs(*producer_node), - "Cannot reuse input buffer for ", p_output_arg->Name(), " as input ", p_input_arg->Name()); + if (producer_node && HasExternalOutputs(*producer_node)) { + LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " + << producer_node->Name() << " which has external outputs is reused. " + << "Be cautious the reuse MUST be a read-only usage."; + } + OrtValueIndex reusable_input{}; if (value_map.GetIdx(p_input_arg->Name(), reusable_input).IsOK() && allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate) { From d2bc31998a742f8e594ac27cac841fac6dfd8813 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Wed, 12 Jun 2024 06:19:48 +0000 Subject: [PATCH 6/7] fix minimal build --- .../core/framework/allocation_planner.cc | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index b2f8c24ded2b9..39fc8af1bd15b 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -326,6 +326,7 @@ class PlannerImpl { if ((0 <= pair.first) && (static_cast(pair.first) < input_args.size())) { auto p_input_arg = input_args[pair.first]; if (p_input_arg->Exists()) { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // If the producer node does not have external output, then we can reuse the input buffer; Otherwise, // we cannot. const Node* producer_node = graph.GetProducerNode(p_input_arg->Name()); @@ -334,6 +335,7 @@ class PlannerImpl { << producer_node->Name() << " which has external outputs. " << "Be cautious the reuse MUST be a read-only usage."; } +#endif *reusable_input = Index(p_input_arg->Name()); return true; } @@ -350,6 +352,7 @@ class PlannerImpl { if (alias_input_index >= 0 && static_cast(alias_input_index) < input_args.size()) { auto p_input_arg = input_args[alias_input_index]; if (p_input_arg->Exists()) { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // If the producer node does not have external output, then we can reuse the input buffer; Otherwise, // we cannot. 
const Node* producer_node = graph.GetProducerNode(p_input_arg->Name()); @@ -358,6 +361,7 @@ class PlannerImpl { << producer_node->Name() << " which has external outputs. " << "Be cautious the reuse MUST be a read-only usage."; } +#endif *reusable_input = Index(p_input_arg->Name()); return true; } @@ -373,10 +377,15 @@ class PlannerImpl { auto input_arg_index = Index(p_input_arg->Name()); auto original = Buffer(input_arg_index); if (1 == UseCount(original)) { + bool need_skip = false; +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // If the producer node does not have external output, then we can reuse the input buffer; Otherwise, // we cannot. const Node* producer_node = graph.GetProducerNode(p_input_arg->Name()); - if (producer_node == nullptr || !HasExternalOutputs(*producer_node)) { + need_skip = producer_node && HasExternalOutputs(*producer_node); +#endif + + if (!need_skip) { if (SameSize(*p_input_arg, *p_output_arg)) { // we can reuse this input since it is its last use and permitted for in-place update *reusable_input = input_arg_index; // or original; both should be okay @@ -421,8 +430,13 @@ class PlannerImpl { } if (can_strided) { + bool need_skip = false; +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const Node* producer_node = graph.GetProducerNode(input_args[pair.first]->Name()); - if (producer_node == nullptr || !HasExternalOutputs(*producer_node)) { + need_skip = producer_node && HasExternalOutputs(*producer_node); +#endif + + if (!need_skip) { *reusable_input = Index(input_args[pair.first]->Name()); *is_strided_tensor = true; return true; @@ -1162,6 +1176,7 @@ class PlannerImpl { if ((0 <= pair.first) && (static_cast(pair.first) < input_args.size())) { auto p_input_arg = input_args[pair.first]; if (p_input_arg->Exists()) { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // If the producer node does not has external outputs, we can reuse the input buffer; // Otherwise, we cannot reuse the buffer. const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); @@ -1170,6 +1185,7 @@ class PlannerImpl { << producer_node->Name() << " which has external outputs is reused. " << "Be cautious the reuse MUST be a read-only usage."; } +#endif OrtValueIndex reusable_input{}; if (value_map.GetIdx(p_input_arg->Name(), reusable_input).IsOK() /*&& @@ -1203,6 +1219,7 @@ class PlannerImpl { auto p_input_arg = input_args[alias_input_index]; if (p_input_arg->Exists()) { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // If the producer node does not has external outputs, we can reuse the input buffer; // Otherwise, we cannot reuse the buffer. const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); @@ -1211,6 +1228,7 @@ class PlannerImpl { << producer_node->Name() << " which has external outputs is reused. " << "Be cautious the reuse MUST be a read-only usage."; } +#endif OrtValueIndex reusable_input{}; if (value_map.GetIdx(p_input_arg->Name(), reusable_input).IsOK() && @@ -1233,10 +1251,15 @@ class PlannerImpl { if ((0 <= pair.first) && (static_cast(pair.first) < input_args.size())) { auto p_input_arg = input_args[pair.first]; if (p_input_arg->Exists()) { + bool need_skip = false; +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // If the producer node does not has external outputs, we can reuse the input buffer; // Otherwise, we cannot reuse the buffer. 
const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); - if (producer_node == nullptr || !HasExternalOutputs(*producer_node)) { + need_skip = producer_node && HasExternalOutputs(*producer_node); +#endif + + if (!need_skip) { OrtValueIndex input_arg_index{}; if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() && allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) { @@ -1526,11 +1549,13 @@ class PlannerImpl { // determine if inputs of *pnode can be freed: for (auto node_input : pnode->InputDefs()) { if (node_input->Exists()) { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const Node* producer_node = graph_viewer_.GetProducerNode(node_input->Name()); // Skip if the producer node has external outputs. if (producer_node != nullptr && HasExternalOutputs(*producer_node)) { continue; } +#endif auto& sym = node_input->Name(); auto original = Buffer(Index(sym)); @@ -1544,11 +1569,13 @@ class PlannerImpl { for (auto node_input : pnode->ImplicitInputDefs()) { if (node_input->Exists()) { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const Node* producer_node = graph_viewer_.GetProducerNode(node_input->Name()); // Skip if the producer node has external outputs. if (producer_node != nullptr && HasExternalOutputs(*producer_node)) { continue; } +#endif auto& sym = node_input->Name(); auto original = Buffer(Index(sym)); From 229868ff5b1eedbb9e55cff2d0e5004d4c402f18 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Wed, 12 Jun 2024 07:17:21 +0000 Subject: [PATCH 7/7] fix minimal build --- onnxruntime/core/framework/allocation_planner.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 39fc8af1bd15b..a13dbf4228024 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -296,6 +296,10 @@ class PlannerImpl { // Find if there exists some input tensor that we can use in-place for output_arg_num-th output in the node. bool FindReusableInput(const GraphViewer& graph, const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input, bool* is_strided_tensor) { +#if defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD) + ORT_UNUSED_PARAMETER(graph); +#endif + *is_strided_tensor = false; #ifdef ENABLE_TRAINING // Inputs of Yields are essentially the outputs for FW partial subgraph @@ -392,8 +396,10 @@ class PlannerImpl { return true; } } else { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node " << producer_node->Name() << " as it has external outputs"; +#endif } } } @@ -441,8 +447,10 @@ class PlannerImpl { *is_strided_tensor = true; return true; } else { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node " << producer_node->Name() << " as it has external outputs."; +#endif } } } @@ -1272,8 +1280,10 @@ class PlannerImpl { } } } else { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) LOGS_DEFAULT(VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node " << producer_node->Name() << " as it has external outputs"; +#endif } } }
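Taken together, the series changes the calling convention on the training path: feeds are handed over by mutable reference and each one is released as soon as its last consumer executes. A hypothetical caller-side sketch, with agent/state/cache construction omitted and signatures as declared in training_agent.h:

    #include <vector>

    #include "core/framework/ort_value.h"
    #include "orttraining/core/agent/training_agent.h"

    // Hypothetical driver for one forward pass under the new contract.
    onnxruntime::common::Status RunOneForward(onnxruntime::training::TrainingAgent& agent,
                                              std::vector<OrtValue>& feeds,
                                              onnxruntime::PartialGraphExecutionState& state,
                                              const onnxruntime::OrtValueCachePtr& cache) {
      std::vector<OrtValue> fetches;
      // On return, feeds is empty: each input's reference was dropped as soon as its
      // last consumer in the forward graph ran, not at the end of the whole pass.
      return agent.RunForward(feeds, fetches, state, cache);
    }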