diff --git a/include/onnxruntime/core/framework/kernel_def_builder.h b/include/onnxruntime/core/framework/kernel_def_builder.h
index de5a3a52f5be7..baccbe1929ac4 100644
--- a/include/onnxruntime/core/framework/kernel_def_builder.h
+++ b/include/onnxruntime/core/framework/kernel_def_builder.h
@@ -253,6 +253,16 @@ class KernelDefBuilder {
   /**
      Specify that this kernel's output buffers are passed from external,
      i.e. not created or managed by ORT's memory allocator.
+
+     An OrtValue set as an external output must be safe to release once its reference count reaches zero
+     in ORT's allocation/deallocation plan. We usually create such an OrtValue through the following flow:
+     torch tensor --> DLPack tensor (whose destructor releases a view of the original torch tensor,
+     not the original torch tensor itself) --> OrtValue.
+
+     Once the OrtValue is no longer needed in the graph, it is released by calling the attached destructor.
+     The destructor releases the view of the original torch tensor rather than the tensor itself, so the
+     original torch tensor remains valid for external use even after the OrtValue is released in the graph.
+     (Recall that this OrtValue is also never reused by ORT.)
   */
   KernelDefBuilder& ExternalOutputs() {
     kernel_def_->external_outputs_ = true;
diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index fec4e6e87edc3..7747058f0d0aa 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -294,8 +294,12 @@ class PlannerImpl {
 #endif
   // Find if there exists some input tensor that we can use in-place for output_arg_num-th output in the node.
-  bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input,
-                         bool* is_strided_tensor) {
+  bool FindReusableInput(const GraphViewer& graph, const onnxruntime::Node& node, int output_arg_num,
+                         OrtValueIndex* reusable_input, bool* is_strided_tensor) {
+#if defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD)
+    ORT_UNUSED_PARAMETER(graph);
+#endif
+
     *is_strided_tensor = false;
 #ifdef ENABLE_TRAINING
     // Inputs of Yields are essentially the outputs for FW partial subgraph
@@ -326,6 +330,16 @@ class PlannerImpl {
       if ((0 <= pair.first) && (static_cast(pair.first) < input_args.size())) {
         auto p_input_arg = input_args[pair.first];
         if (p_input_arg->Exists()) {
+#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
+          // If the producer node does not have external outputs, we can reuse the input buffer; otherwise,
+          // we cannot.
+          const Node* producer_node = graph.GetProducerNode(p_input_arg->Name());
+          if (producer_node && HasExternalOutputs(*producer_node)) {
+            LOGS_DEFAULT(VERBOSE) << "Note: Node " << node.Name() << " is reusing the input buffer of node "
+                                  << producer_node->Name() << ", which has external outputs. "
+                                  << "This reuse MUST be read-only.";
+          }
+#endif
           *reusable_input = Index(p_input_arg->Name());
           return true;
         }
@@ -342,6 +356,16 @@ class PlannerImpl {
       if (alias_input_index >= 0 && static_cast(alias_input_index) < input_args.size()) {
         auto p_input_arg = input_args[alias_input_index];
         if (p_input_arg->Exists()) {
+#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
+          // If the producer node does not have external outputs, we can reuse the input buffer; otherwise,
+          // we cannot.
+ const Node* producer_node = graph.GetProducerNode(p_input_arg->Name()); + if (producer_node && HasExternalOutputs(*producer_node)) { + LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node " + << producer_node->Name() << " which has external outputs. " + << "Be cautious the reuse MUST be a read-only usage."; + } +#endif *reusable_input = Index(p_input_arg->Name()); return true; } @@ -357,10 +381,25 @@ class PlannerImpl { auto input_arg_index = Index(p_input_arg->Name()); auto original = Buffer(input_arg_index); if (1 == UseCount(original)) { - if (SameSize(*p_input_arg, *p_output_arg)) { - // we can reuse this input since it is its last use and permitted for in-place update - *reusable_input = input_arg_index; // or original; both should be okay - return true; + bool need_skip = false; +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + // If the producer node does not have external output, then we can reuse the input buffer; Otherwise, + // we cannot. + const Node* producer_node = graph.GetProducerNode(p_input_arg->Name()); + need_skip = producer_node && HasExternalOutputs(*producer_node); +#endif + + if (!need_skip) { + if (SameSize(*p_input_arg, *p_output_arg)) { + // we can reuse this input since it is its last use and permitted for in-place update + *reusable_input = input_arg_index; // or original; both should be okay + return true; + } + } else { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node " + << producer_node->Name() << " as it has external outputs"; +#endif } } } @@ -395,10 +434,24 @@ class PlannerImpl { break; } } + if (can_strided) { - *reusable_input = Index(input_args[pair.first]->Name()); - *is_strided_tensor = true; - return true; + bool need_skip = false; +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + const Node* producer_node = graph.GetProducerNode(input_args[pair.first]->Name()); + need_skip = producer_node && HasExternalOutputs(*producer_node); +#endif + + if (!need_skip) { + *reusable_input = Index(input_args[pair.first]->Name()); + *is_strided_tensor = true; + return true; + } else { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node " + << producer_node->Name() << " as it has external outputs."; +#endif + } } } } @@ -613,13 +666,12 @@ class PlannerImpl { auto outputs = pnode->OutputDefs(); auto num_outputs = outputs.size(); - bool has_external_outputs = HasExternalOutputs(*pnode); for (size_t i = 0; i < num_outputs; ++i) { auto* node_output = outputs[i]; if (!node_output->Exists()) continue; OrtValueIndex index = Index(node_output->Name()); - // Ensures external outputs will not be reused. - UseCount(index) += (has_external_outputs ? 
2 : 1); + + UseCount(index) += 1; } } } @@ -1088,7 +1140,8 @@ class PlannerImpl { int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); auto origin = AllocPlan(value_idx).reused_buffer; - if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) { + if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate || + AllocPlan(origin).alloc_kind == AllocKind::kAllocatedExternally) { // add current node as consumer for origin buffer value_consumer_map[origin].insert(node_index); } @@ -1140,6 +1193,17 @@ class PlannerImpl { if ((0 <= pair.first) && (static_cast(pair.first) < input_args.size())) { auto p_input_arg = input_args[pair.first]; if (p_input_arg->Exists()) { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + // If the producer node does not has external outputs, we can reuse the input buffer; + // Otherwise, we cannot reuse the buffer. + const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); + if (producer_node && HasExternalOutputs(*producer_node)) { + LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " + << producer_node->Name() << " which has external outputs is reused. " + << "Be cautious the reuse MUST be a read-only usage."; + } +#endif + OrtValueIndex reusable_input{}; if (value_map.GetIdx(p_input_arg->Name(), reusable_input).IsOK() /*&& allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate*/ @@ -1172,6 +1236,17 @@ class PlannerImpl { auto p_input_arg = input_args[alias_input_index]; if (p_input_arg->Exists()) { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + // If the producer node does not has external outputs, we can reuse the input buffer; + // Otherwise, we cannot reuse the buffer. + const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); + if (producer_node && HasExternalOutputs(*producer_node)) { + LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " + << producer_node->Name() << " which has external outputs is reused. " + << "Be cautious the reuse MUST be a read-only usage."; + } +#endif + OrtValueIndex reusable_input{}; if (value_map.GetIdx(p_input_arg->Name(), reusable_input).IsOK() && allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate) { @@ -1193,16 +1268,31 @@ class PlannerImpl { if ((0 <= pair.first) && (static_cast(pair.first) < input_args.size())) { auto p_input_arg = input_args[pair.first]; if (p_input_arg->Exists()) { - OrtValueIndex input_arg_index{}; - if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() && - allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) { - if (value_consumer_map[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) { - allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; - allocation_plan[output_idx_global].reused_buffer = input_arg_index; - value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(), - value_consumer_map[output_idx_global].end()); - reused.insert(input_arg_index); + bool need_skip = false; +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + // If the producer node does not has external outputs, we can reuse the input buffer; + // Otherwise, we cannot reuse the buffer. 
+ const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); + need_skip = producer_node && HasExternalOutputs(*producer_node); +#endif + + if (!need_skip) { + OrtValueIndex input_arg_index{}; + if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() && + allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) { + if (value_consumer_map[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) { + allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; + allocation_plan[output_idx_global].reused_buffer = input_arg_index; + value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); + reused.insert(input_arg_index); + } } + } else { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + LOGS_DEFAULT(VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node " + << producer_node->Name() << " as it has external outputs"; +#endif } } } @@ -1448,7 +1538,8 @@ class PlannerImpl { } } } else if (!context_->IsParallelExecutionEnabled() && - FindReusableInput(*pnode, static_cast(output_arg_def_index), &reused, &is_strided_tensor)) { + FindReusableInput(graph_viewer_, *pnode, static_cast(output_arg_def_index), + &reused, &is_strided_tensor)) { // Re-using inputs is applicable for tensors, sequence tensors, // and optional types if the kernel has marked certain inputs as // possible candidates for re-use @@ -1483,6 +1574,14 @@ class PlannerImpl { // determine if inputs of *pnode can be freed: for (auto node_input : pnode->InputDefs()) { if (node_input->Exists()) { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + const Node* producer_node = graph_viewer_.GetProducerNode(node_input->Name()); + // Skip if the producer node has external outputs. + if (producer_node != nullptr && HasExternalOutputs(*producer_node)) { + continue; + } +#endif + auto& sym = node_input->Name(); auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. @@ -1495,6 +1594,14 @@ class PlannerImpl { for (auto node_input : pnode->ImplicitInputDefs()) { if (node_input->Exists()) { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + const Node* producer_node = graph_viewer_.GetProducerNode(node_input->Name()); + // Skip if the producer node has external outputs. + if (producer_node != nullptr && HasExternalOutputs(*producer_node)) { + continue; + } +#endif + auto& sym = node_input->Name(); auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. @@ -1505,15 +1612,17 @@ class PlannerImpl { } } - // determine if any outputs of *pnode are unused and can be freed: - for (auto node_output : pnode->OutputDefs()) { - if (node_output->Exists()) { - auto& sym = node_output->Name(); - auto original = Buffer(Index(sym)); - // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. - // See comments in the OrtValueInfo definition. 
- if (0 == DecrementUseCount(original)) { - freelist_.push_front(FreeBufferInfo(original, program_counter)); + if (!HasExternalOutputs(*pnode)) { + // determine if any outputs of *pnode are unused and can be freed: + for (auto node_output : pnode->OutputDefs()) { + if (node_output->Exists()) { + auto& sym = node_output->Name(); + auto original = Buffer(Index(sym)); + // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. + // See comments in the OrtValueInfo definition. + if (0 == DecrementUseCount(original)) { + freelist_.push_front(FreeBufferInfo(original, program_counter)); + } } } } @@ -1537,7 +1646,8 @@ class PlannerImpl { if (!node_output->Exists()) continue; // OrtValue index of the considered output NodeArg. const auto current = Index(node_output->Name()); - if (AllocPlan(current).alloc_kind == AllocKind::kAllocate) { + if (AllocPlan(current).alloc_kind == AllocKind::kAllocate || + AllocPlan(current).alloc_kind == AllocKind::kAllocatedExternally) { AllocPlan(current).program_counter.AddStart(program_counter); } } @@ -1585,8 +1695,7 @@ class PlannerImpl { // OrtValue index of the considered output NodeArg. const auto current = Index(node_output->Name()); AllocPlan(current).life_interval.first = program_counter; - if (AllocPlan(current).alloc_kind == AllocKind::kAllocatedExternally || - AllocPlan(current).alloc_kind == AllocKind::kAllocateOutput) { + if (AllocPlan(current).alloc_kind == AllocKind::kAllocateOutput) { AllocPlan(current).life_interval.second = execution_plan.size(); } // determine if inputs of *pnode can be freed: @@ -1694,9 +1803,9 @@ class PlannerImpl { // Convert information in execution plan and memory reuse plan into release plan Status GenerateDeallocationPlan() { // 1. build the consumer list for each value - std::vector> value_consumers; + std::vector> ortvalue_to_consumers_map; int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1; - value_consumers.resize(num_ml_values); + ortvalue_to_consumers_map.resize(num_ml_values); // iterate each stream from back, so the first element is the last consumer in single stream case for (auto& stream : stream_nodes_) { @@ -1710,9 +1819,10 @@ class PlannerImpl { int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); auto origin = AllocPlan(value_idx).reused_buffer; - if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) { + if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate || + AllocPlan(origin).alloc_kind == AllocKind::kAllocatedExternally) { // add current node as consumer for origin buffer - value_consumers[origin].push_back(node_index); + ortvalue_to_consumers_map[origin].push_back(node_index); } } return Status::OK(); @@ -1728,8 +1838,8 @@ class PlannerImpl { plan_.node_release_list[node_index].push_back(release_action_idx); }; plan_.node_release_list.resize(SafeInt(graph_viewer_.MaxNodeIndex()) + 1); - for (size_t i = 0; i < value_consumers.size(); ++i) { - if (!value_consumers[i].empty()) { + for (size_t i = 0; i < ortvalue_to_consumers_map.size(); ++i) { + if (!ortvalue_to_consumers_map[i].empty()) { plan_.release_actions.push_back(SequentialExecutionPlan::ReleaseAction{i, 0}); auto release_action_idx = plan_.release_actions.size() - 1; // check whether we can static determine where to release. @@ -1737,19 +1847,19 @@ class PlannerImpl { // we actually can do better if all the consumers depends on the last consumer. 
// will optimize it later bool is_all_consumer_same_stream = true; - auto stream_idx = plan_.node_stream_map_[value_consumers[i][0]]; - for (size_t j = 1; j < value_consumers[i].size(); ++j) { - if (plan_.node_stream_map_[value_consumers[i][j]] != stream_idx) { + auto stream_idx = plan_.node_stream_map_[ortvalue_to_consumers_map[i][0]]; + for (size_t j = 1; j < ortvalue_to_consumers_map[i].size(); ++j) { + if (plan_.node_stream_map_[ortvalue_to_consumers_map[i][j]] != stream_idx) { is_all_consumer_same_stream = false; break; } } if (is_all_consumer_same_stream) { // all the consumers are on the same stream, so the first element is the last consumer int the stream. - process_consumer(release_action_idx, value_consumers[i][0]); + process_consumer(release_action_idx, ortvalue_to_consumers_map[i][0]); } else { // can't static determin, add all the consumers, we will use ref count in release action - for (auto node_index : value_consumers[i]) { + for (auto node_index : ortvalue_to_consumers_map[i]) { process_consumer(release_action_idx, node_index); } } diff --git a/onnxruntime/core/framework/partial_graph_execution_state.cc b/onnxruntime/core/framework/partial_graph_execution_state.cc index 8a749008945c0..a053634adbe35 100644 --- a/onnxruntime/core/framework/partial_graph_execution_state.cc +++ b/onnxruntime/core/framework/partial_graph_execution_state.cc @@ -59,8 +59,10 @@ DeviceStreamCollection* PartialGraphExecutionState::GetDeviceStreamCollection(co return device_stream_collection_.get(); } -StreamExecutionContext& PartialGraphExecutionState::GetExecutionContext(gsl::span& feed_mlvalue_idxs, gsl::span& feeds, - gsl::span& fetch_mlvalue_idxs, std::vector& fetches, +StreamExecutionContext& PartialGraphExecutionState::GetExecutionContext(gsl::span& feed_mlvalue_idxs, + std::vector& feeds, + gsl::span& fetch_mlvalue_idxs, + std::vector& fetches, const std::unordered_map& fetch_allocators, const SessionState& session_state, const logging::Logger& sess_logger, @@ -81,7 +83,7 @@ StreamExecutionContext& PartialGraphExecutionState::GetExecutionContext(gsl::spa execution_plan->num_barriers, device_streams, feed_mlvalue_idxs, - feeds, + std::move(feeds), fetch_mlvalue_idxs, fetches, fetch_allocators, @@ -89,7 +91,7 @@ StreamExecutionContext& PartialGraphExecutionState::GetExecutionContext(gsl::spa // partial executor in training can only be run with single thread true); } else { - execution_context_->GetExecutionFrame().UpdateFeeds(feed_mlvalue_idxs, feeds); + execution_context_->GetExecutionFrame().UpdateFeeds(feed_mlvalue_idxs, std::move(feeds)); execution_context_->GetExecutionFrame().UpdateFetches(fetch_mlvalue_idxs, fetches, session_state.GetInitializedTensors()); execution_context_->SetLogger(sess_logger); } diff --git a/onnxruntime/core/framework/partial_graph_execution_state.h b/onnxruntime/core/framework/partial_graph_execution_state.h index 80f7e52a9cf68..adfb5f8f5fca7 100644 --- a/onnxruntime/core/framework/partial_graph_execution_state.h +++ b/onnxruntime/core/framework/partial_graph_execution_state.h @@ -30,7 +30,8 @@ struct PartialGraphExecutionState { ProgramRegion& GetProgramRegions(const SessionState& session_state); - StreamExecutionContext& GetExecutionContext(gsl::span& feed_mlvalue_idxs, gsl::span& feeds, + // Be noted: feeds will be std::move to ctx, so it will be empty after this function. 
+ StreamExecutionContext& GetExecutionContext(gsl::span& feed_mlvalue_idxs, std::vector& feeds, gsl::span& fetch_mlvalue_idxs, std::vector& fetches, const std::unordered_map& fetch_allocators, const SessionState& session_state, diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc index 0cc7294a46495..a374e381a2b0e 100644 --- a/onnxruntime/core/framework/sequential_executor.cc +++ b/onnxruntime/core/framework/sequential_executor.cc @@ -615,7 +615,8 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span< #ifdef ENABLE_TRAINING onnxruntime::Status PartialExecuteThePlan(const SessionState& session_state, gsl::span feed_mlvalue_idxs, - gsl::span feeds, gsl::span fetch_mlvalue_idxs, + std::vector& feeds, + gsl::span fetch_mlvalue_idxs, std::vector& fetches, const std::unordered_map& fetch_allocators, @@ -626,8 +627,10 @@ onnxruntime::Status PartialExecuteThePlan(const SessionState& session_state, gsl PartialGraphExecutionState& state, const OrtValueCachePtr& cache, int32_t partial_graph_index) { + // Be noted: feeds will be std::move to ctx, so it will be empty after this function. auto& ctx = state.GetExecutionContext(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs, fetches, fetch_allocators, session_state, logger, device_streams); + auto* plan = session_state.GetExecutionPlan(); ctx.SetCurrentRange(&state.GetProgramRegions(session_state)); diff --git a/onnxruntime/core/framework/sequential_executor.h b/onnxruntime/core/framework/sequential_executor.h index 9bef4be34fd80..c22ccb041afdf 100644 --- a/onnxruntime/core/framework/sequential_executor.h +++ b/onnxruntime/core/framework/sequential_executor.h @@ -50,7 +50,7 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span< #ifdef ENABLE_TRAINING onnxruntime::Status PartialExecuteThePlan(const SessionState& session_state, gsl::span feed_mlvalue_idxs, - gsl::span feeds, gsl::span fetch_mlvalue_idxs, + std::vector& feeds, gsl::span fetch_mlvalue_idxs, std::vector& fetches, const std::unordered_map& fetch_allocators, const logging::Logger& logger, diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 0c4d498fae9e0..9c282210d2169 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -784,7 +784,7 @@ common::Status ExecuteGraph(const SessionState& session_state, #ifdef ENABLE_TRAINING common::Status ExecutePartialGraphImpl(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager, - gsl::span feeds, std::vector& fetches, + std::vector& feeds, std::vector& fetches, const logging::Logger& logger, PartialGraphExecutionState& state, const OrtValueCachePtr& cache, const bool& terminate_flag, DeviceStreamCollection* device_stream_collection, @@ -882,7 +882,7 @@ common::Status ExecutePartialGraphImpl(const SessionState& session_state, FeedsF } common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager, - gsl::span feeds, std::vector& fetches, + std::vector& feeds, std::vector& fetches, const logging::Logger& logger, PartialGraphExecutionState& state, const OrtValueCachePtr& cache, const bool& terminate_flag, int32_t partial_graph_index, diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index fd14eeeb33d27..17cf9671b70eb 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -100,7 +100,7 @@ common::Status ExecuteGraph(const 
SessionState& session_state, FeedsFetchesManag #ifdef ENABLE_TRAINING common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager, - gsl::span feeds, std::vector& fetches, + std::vector& feeds, std::vector& fetches, const logging::Logger& logger, PartialGraphExecutionState& state, const OrtValueCachePtr& cache, const bool& terminate_flag, diff --git a/onnxruntime/core/optimizer/gemm_sum_fusion.cc b/onnxruntime/core/optimizer/gemm_sum_fusion.cc index 3f2c1ac046105..be3c90a822fe2 100644 --- a/onnxruntime/core/optimizer/gemm_sum_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_sum_fusion.cc @@ -36,7 +36,7 @@ Status GemmSumFusion::Apply(Graph& graph, Node& gemm_node, RewriteRuleEffect& mo std::vector new_gemm_output_defs = sum_node.MutableOutputDefs(); ORT_ENFORCE(new_gemm_output_defs.size() == 1); - Node& new_gemm_node = graph.AddNode(graph.GenerateNodeName(gemm_node.Name() + "_sum_transformed"), + Node& new_gemm_node = graph.AddNode(graph.GenerateNodeName(gemm_node.Name() + "/GemmSumFusion/"), gemm_node.OpType(), "Fused Gemm with Sum", new_gemm_input_defs, diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc index ad77f4d143d31..2a4916ccb324a 100644 --- a/onnxruntime/core/optimizer/matmul_add_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc @@ -106,7 +106,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, continue; } - Node& gemm_node = graph.AddNode(graph.GenerateNodeName("gemm"), + Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion/"), "Gemm", "fused Matmul and Add " + add_node.OpType(), gemm_input_defs, diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index d5f72df4e07d3..e8d33bc154b0c 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2349,7 +2349,7 @@ common::Status InferenceSession::ValidateOutputs(gsl::span ou #ifdef ENABLE_TRAINING Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options, - const std::vector& feeds, + std::vector& feeds, std::vector& fetches, PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager, diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 48f6d73b077cb..204b24974ff50 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -377,8 +377,8 @@ class InferenceSession { /** * Partially run a pre-loaded and pre-intialized model. * @param run_options run options. - * @param feeds inputs owned by client code and should not be changed during - * execution of this function. + * @param mutable_feeds inputs owned by client code and will be released as long as the feeds be set in session states. + * Then the feeds will purely managed in the session states. * @param fetches outputs produced after the executin of this function. * @param state State of the graph needed to resume partial graph run. * @param feeds_fetches_manager Contains feed/fetches name to internal indices mapping and information for device @@ -388,7 +388,7 @@ class InferenceSession { * @param partial_graph_index Index of the partial graph to run. 
*/ common::Status PartialRun(onnxruntime::RunOptions& run_options, - const std::vector& feeds, + std::vector& mutable_feeds, std::vector& fetches, PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager, diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 9cbf80f16ee33..72eee5aca2638 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -370,8 +370,9 @@ class PlannerTest : public ::testing::Test { EXPECT_EQ(plan_->execution_plan.size(), 1U); int list_size = static_cast(plan_->node_release_list.size()); EXPECT_GT(list_size, step_number); - for (auto freed : plan_->node_release_list[step_number]) { - plan_result.insert(static_cast(freed)); + for (auto action_idx : plan_->node_release_list[step_number]) { + const size_t ortvalue_id = plan_->release_actions[action_idx].value_index; + plan_result.insert(static_cast(ortvalue_id)); } EXPECT_EQ(plan_result, expected) << "Freed items incorrect for step " << step_number; } @@ -433,7 +434,7 @@ TEST_F(PlannerTest, ChainTest) { CreatePlan(); // Expected plan: - // W: kAllocateStatically; X: kAllocate; B: kAllocate; Y: kReuse (X); post-node3: free(B); X is returned output + // W: kAllocateStatically; X: kAllocate; B: kAllocate; Y: kReuse (X); post-node3: free(B); Z is returned output CheckAllocKind(W, AllocKind::kAllocateStatically); CheckAllocKind(X, AllocKind::kAllocate); CheckAllocKind(B, AllocKind::kAllocate); @@ -442,8 +443,8 @@ TEST_F(PlannerTest, ChainTest) { CheckFreed(0, {}); CheckFreed(1, {}); - CheckFreed(2, {"X"}); - CheckFreed(3, {"W"}); + CheckFreed(2, {B}); + CheckFreed(3, {X}); } /* InputOutputTest: Test that: @@ -510,7 +511,7 @@ TEST_F(PlannerTest, InPlaceTest) { // check each ml-value is freed at appropriate step CheckFreed(0, {}); CheckFreed(1, {}); - CheckFreed(2, {X1}); + CheckFreed(2, {X2}); } TEST_F(PlannerTest, ExternalOutputsTest) { @@ -536,10 +537,42 @@ TEST_F(PlannerTest, ExternalOutputsTest) { CheckAllocKind(X4, AllocKind::kAllocateOutput); // check each ml-value is freed at appropriate step - // X2 will not be reused and will not be freed. X3 will be allocated and will be freed. + // X2 will not be reused but will be freed (to release the current reference). X3 will be allocated and will be freed. CheckFreed(0, {}); - CheckFreed(1, {}); - CheckFreed(2, {X1}); + CheckFreed(1, {X2}); + CheckFreed(2, {X3}); +} + +TEST_F(PlannerTest, ExternalOutputsNoReuseTest) { + // tensor variables: + std::string X1("X1"), X2("X2"), X3("X3"), X4("X4"), X5("X5"); + + // graph structure: + AddExternalOutputsNode(X1, X2); // external-outputs operator; X1: input; X2: temporary + AddInplaceNode(X2, X3); // may-in-place operator; X3: temporary + AddNormalNode(X3, X4); // normal operator; X4: temporary + AddNormalNode(X4, X5); // normal operator; X5: output + + // simulate shape-inference results: + Shape shape1{"M", "N"}; + auto shape = &shape1.value; + SetShape({{X1, shape}, {X2, shape}, {X3, shape}, {X4, shape}, {X5, shape}}); + + CreatePlan(); + + // check allocation kind: + CheckAllocKind(X1, AllocKind::kPreExisting); + CheckAllocKind(X2, AllocKind::kAllocatedExternally); + CheckAllocKind(X3, AllocKind::kAllocate); // Should not be Reused. + CheckAllocKind(X4, AllocKind::kAllocate); + CheckAllocKind(X5, AllocKind::kAllocateOutput); + + // check each ml-value is freed at appropriate step + // X2 will not be reused. X3 will be allocated and will be freed. 
+ CheckFreed(0, {}); + CheckFreed(1, {X2}); + CheckFreed(2, {X3}); + CheckFreed(3, {X4}); } #ifdef ENABLE_STRIDED_TENSORS @@ -564,9 +597,9 @@ TEST_F(PlannerTest, MayStridedTest1) { CheckAllocKind(X3, AllocKind::kAllocateOutput); // check each ml-value is freed at appropriate step - // X2 will not be reused and will not be freed. X3 will be allocated and will be freed. + // X2 will not be reused because X3 is a graph output. X3 will be allocated and will be freed. CheckFreed(0, {}); - CheckFreed(1, {X1}); + CheckFreed(1, {X2}); } TEST_F(PlannerTest, MayStridedTest2) { @@ -574,9 +607,9 @@ TEST_F(PlannerTest, MayStridedTest2) { std::string X1("X1"), X2("X2"), X3("X3"), X4("X4"); // graph structure: - AddMayStridedOutputNode(X1, X2); - AddMayStridedInputNode(X2, X3); - AddMayStridedInputNode(X2, X4); + AddMayStridedOutputNode(X1, X2); // X2 can reuse X1, and is a strided output. + AddMayStridedInputNode(X2, X3); // X3 is a graph output, cannot reuse. + AddMayStridedInputNode(X2, X4); // X4 is a graph output, cannot reuse. // simulate shape-inference results: Shape shape1{"M", "N"}; @@ -603,8 +636,9 @@ TEST_F(PlannerTest, MayStridedTest3) { std::string X1("X1"), X2("X2"), X3("X3"), X4("X4"); // graph structure: - AddMayStridedOutputNode(X1, X2); - AddMayStridedInputNode(X2, X3); + AddMayStridedOutputNode(X1, X2); // X2 cannot strided reuse X1 because, + // one of X2's consumers is a node not supporting strided input. So X2 is a allocate. + AddMayStridedInputNode(X2, X3); // X3 is a graph output, cannot reuse. AddNormalNode(X2, X4); // simulate shape-inference results: @@ -620,11 +654,27 @@ TEST_F(PlannerTest, MayStridedTest3) { CheckAllocKind(X3, AllocKind::kAllocateOutput); CheckAllocKind(X4, AllocKind::kAllocateOutput); + // Be noted: the last two nodes added can run in two different orders, we need figure out the exact order + // we planned then we know how to check the free order. + const GraphViewer& graph_viewer = GetState().GetGraphViewer(); + // Normal node index is 2. + bool does_normal_node_run_at_last = graph_viewer.GetNodesInTopologicalOrder()[2] == 2; + // check each ml-value is freed at appropriate step - // X2 will not be reused and will not be freed. X3 will be allocated and will be freed. + CheckFreed(0, {}); - CheckFreed(1, {X1}); - CheckFreed(2, {}); + + if (does_normal_node_run_at_last) { + // Normal node has node index to be 2, but it is possible that the normal node is executed after the strided node. + // Then X2 will released once the normal node is executed. + CheckFreed(1, {}); + CheckFreed(2, {X2}); + } else { + // Normal node has node index to be 2, and is executed before the strided node. + // So X2 will be released after the strided node (node index to be 1) is executed. 
+ CheckFreed(2, {}); + CheckFreed(1, {X2}); + } } #endif @@ -637,7 +687,7 @@ TEST_F(PlannerTest, InPlaceSizeMismatchTest) { // graph structure: AddNormalNode(X1, X2); // no in-place operator; X1: input; X2: temporary AddInplaceNode(X2, X3); // may-in-place operator; X3: temporary - AddNormalNode(X3, X4); // no in-place operator; X4: temporary + AddNormalNode(X3, X4); // no in-place operator; X4: temporary (reuse X2) AddInplaceNode(X4, X5); // may-in-place operator; X5 output // simulate shape-inference results: @@ -659,8 +709,8 @@ TEST_F(PlannerTest, InPlaceSizeMismatchTest) { // check each ml-value is freed at appropriate step CheckFreed(0, {}); CheckFreed(1, {}); - CheckFreed(2, {X2}); - CheckFreed(3, {X1}); + CheckFreed(2, {X3}); + CheckFreed(3, {X2}); } // Test operator<< to output details of an allocation & execution plan. diff --git a/orttraining/orttraining/core/agent/training_agent.cc b/orttraining/orttraining/core/agent/training_agent.cc index cc8d341dbc084..cc3f6d1ff82ee 100644 --- a/orttraining/orttraining/core/agent/training_agent.cc +++ b/orttraining/orttraining/core/agent/training_agent.cc @@ -60,7 +60,7 @@ TrainingAgent::TrainingAgent(InferenceSession& session, TrainingAgent::~TrainingAgent() = default; -common::Status TrainingAgent::RunForward(const std::vector& feeds, std::vector& fetches, +common::Status TrainingAgent::RunForward(std::vector& feeds, std::vector& fetches, PartialGraphExecutionState& state, const OrtValueCachePtr& cache) { #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) inference_session_.GetMemoryProfiler().GetMemoryInfo().SetIteration(profile_step_); @@ -74,7 +74,7 @@ common::Status TrainingAgent::RunForward(const std::vector& feeds, std return RunCore(feeds, fetches, state, *fw_feeds_fetches_manager_, cache, partial_graph_index); } -common::Status TrainingAgent::RunBackward(const std::vector& feeds, std::vector& fetches, +common::Status TrainingAgent::RunBackward(std::vector& feeds, std::vector& fetches, PartialGraphExecutionState& state) { state.SetProgramCounterStart(fw_program_counter_end_); state.SetProgramCounterEnd(bw_program_counter_end_); @@ -82,7 +82,7 @@ common::Status TrainingAgent::RunBackward(const std::vector& feeds, st return RunCore(feeds, fetches, state, *bw_feeds_fetches_manager_, nullptr, partial_graph_index); } -common::Status TrainingAgent::RunCore(const std::vector& feeds, std::vector& fetches, +common::Status TrainingAgent::RunCore(std::vector& feeds, std::vector& fetches, PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager, const OrtValueCachePtr& cache, int32_t partial_graph_index) { auto fetches_size = feeds_fetches_manager.GetFeedsFetchesInfo().output_names.size(); diff --git a/orttraining/orttraining/core/agent/training_agent.h b/orttraining/orttraining/core/agent/training_agent.h index 8d88a6df39352..2b74e56ae60e2 100644 --- a/orttraining/orttraining/core/agent/training_agent.h +++ b/orttraining/orttraining/core/agent/training_agent.h @@ -32,14 +32,14 @@ class TrainingAgent { int local_rank = 0); ~TrainingAgent(); // For ORTModule.forward() - [[nodiscard]] common::Status RunForward(const std::vector& feeds, std::vector& fetches, + [[nodiscard]] common::Status RunForward(std::vector& mutable_feeds, std::vector& fetches, PartialGraphExecutionState& state, const OrtValueCachePtr& cache); // For ORTModule.backward() - [[nodiscard]] common::Status RunBackward(const std::vector& feeds, std::vector& fetches, + [[nodiscard]] common::Status RunBackward(std::vector& mutable_feeds, std::vector& 
fetches, PartialGraphExecutionState& state); - [[nodiscard]] common::Status RunCore(const std::vector& feeds, std::vector& fetches, + [[nodiscard]] common::Status RunCore(std::vector& mutable_feeds, std::vector& fetches, PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager, const OrtValueCachePtr& cache, int32_t partial_graph_index); diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 5ea60102f3ef8..e2616e1b441f7 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -423,14 +423,18 @@ void addObjectMethodsForTraining(py::module& m) { return std::make_unique(*session->GetSessionHandle(), fw_feed_names, fw_outputs_device_info, bw_fetches_names, bw_outputs_device_info, local_rank); })) - .def("run_forward", [](TrainingAgent* agent, const std::vector& feeds, std::vector& fetches, PartialGraphExecutionState* state, OrtValueCachePtr cache) -> void { - Status status = agent->RunForward(feeds, fetches, *state, cache); + .def("run_forward", [](TrainingAgent* agent, std::vector& mutable_feeds, std::vector& fetches, PartialGraphExecutionState* state, OrtValueCachePtr cache) -> void { + // Feed is passed in mutable way, to allow the internal logic to release the feeds as long as it is not needed. + // Otherwise, the feeds will be released after the forward pass, which hold some unnecessary memory. + Status status = agent->RunForward(mutable_feeds, fetches, *state, cache); if (!status.IsOK()) { throw std::runtime_error("Error in forward pass execution: " + status.ErrorMessage()); } }) - .def("run_backward", [](TrainingAgent* agent, const std::vector& feeds, std::vector& fetches, PartialGraphExecutionState* state) -> void { - Status status = agent->RunBackward(feeds, fetches, *state); + .def("run_backward", [](TrainingAgent* agent, std::vector& mutable_feeds, std::vector& fetches, PartialGraphExecutionState* state) -> void { + // Feed is passed in mutable way, to allow the internal logic to release the feeds as long as it is not needed. + // Otherwise, the feeds will be released after the forward pass, which hold some unnecessary memory. + Status status = agent->RunBackward(mutable_feeds, fetches, *state); if (!status.IsOK()) { throw std::runtime_error("Error in backward pass execution: " + status.ErrorMessage()); }
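
For context on the external-output lifecycle described in the ExternalOutputs() documentation above, the sketch below illustrates the torch tensor --> DLPack-style view --> OrtValue flow, where the attached destructor drops only a view of the original tensor. This is a minimal, self-contained illustration under stated assumptions: ExternalTensor, TensorView, and ExternalValue are hypothetical stand-in types, not ORT or PyTorch APIs.

#include <cstdio>
#include <memory>
#include <vector>

// Stand-in for a torch tensor that lives and is owned outside ORT.
struct ExternalTensor {
  std::vector<float> data;
};

// A "view" keeps the external tensor alive while it is fed into the graph.
// Destroying the view drops one reference to the tensor, never the tensor
// itself while the external owner still holds it.
struct TensorView {
  std::shared_ptr<ExternalTensor> keep_alive;
  float* data_ptr = nullptr;
};

// Stand-in for an OrtValue planned as kAllocatedExternally: it exposes the raw
// buffer plus a destructor that releases only the view once the deallocation
// plan drops the last reference.
class ExternalValue {
 public:
  explicit ExternalValue(std::unique_ptr<TensorView> view) : view_(std::move(view)) {}
  float* Data() const { return view_->data_ptr; }
  // In ORT this happens when the release action's reference count reaches zero.
  void Release() { view_.reset(); }

 private:
  std::unique_ptr<TensorView> view_;
};

int main() {
  auto torch_like_tensor = std::make_shared<ExternalTensor>();
  torch_like_tensor->data = {1.f, 2.f, 3.f};

  // torch tensor -> (DLPack-style) view -> OrtValue-like wrapper.
  auto view = std::make_unique<TensorView>();
  view->keep_alive = torch_like_tensor;
  view->data_ptr = torch_like_tensor->data.data();
  ExternalValue value(std::move(view));

  std::printf("first element seen by the graph: %f\n", value.Data()[0]);

  // The graph no longer needs the value: only the view is dropped.
  value.Release();

  // The original tensor is still valid for its external owner.
  std::printf("still usable externally: %f\n", torch_like_tensor->data[0]);
  return 0;
}

Because only the view is released, the buffer cannot be handed to a downstream node for in-place (write) reuse, which is what the HasExternalOutputs(*producer_node) checks added to the planner above enforce; any remaining reuse must stay read-only.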
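
The PartialRun/TrainingAgent changes pass feeds as a mutable std::vector so the execution frame can take ownership (std::move) and drop each input as soon as its last consumer has run, rather than the caller holding all of them until the call returns. Below is a minimal sketch of that calling convention; RunPartial and ValueHandle are hypothetical stand-ins, not the actual ORT signatures.

#include <cassert>
#include <memory>
#include <vector>

// Stand-in for an OrtValue holding a large buffer.
using ValueHandle = std::shared_ptr<std::vector<float>>;

// Hypothetical stand-in for PartialRun: feeds are taken by non-const reference
// and moved into internal state, so the caller's vector is left empty.
void RunPartial(std::vector<ValueHandle>& mutable_feeds, std::vector<ValueHandle>& fetches) {
  std::vector<ValueHandle> owned = std::move(mutable_feeds);  // caller's vector becomes empty
  for (ValueHandle& feed : owned) {
    // ... once the last consumer of this feed has run, release it now instead of
    // keeping it alive until the whole call returns ...
    feed.reset();
  }
  fetches.push_back(std::make_shared<std::vector<float>>(1, 42.f));
}

int main() {
  std::vector<ValueHandle> feeds;
  feeds.push_back(std::make_shared<std::vector<float>>(1024, 1.f));

  std::vector<ValueHandle> fetches;
  RunPartial(feeds, fetches);

  // As the header comment above warns, the feeds vector is empty after the call.
  assert(feeds.empty());
  return 0;
}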