From e9042ef6dbb559ef8d426cb60476134da6ab8f7e Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 15 Nov 2024 21:06:35 +0000 Subject: [PATCH] update --- .../tensorrt/tensorrt_execution_provider.cc | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 1a5cf6ababdfc..9ed554345ca23 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1960,7 +1960,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph // Find inputs and outputs of the subgraph std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_map fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; + std::unordered_map original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; std::unordered_set erased; int input_order = 0; int output_order = 0; @@ -2052,12 +2052,25 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph fused_outputs.insert(fused_outputs_to_add.begin(), fused_outputs_to_add.end()); fused_outputs.insert(graph_outputs_to_add.begin(), graph_outputs_to_add.end()); - // Sort inputs and outputs by the order they were added std::multimap inputs, outputs; + + // Get the input order of the original graph + int order = 0; + for (const auto* input : graph.GetInputs()) { + original_inputs[input] = order++; + } + + // input order needs to be consistent with original graph's input order for (auto it = fused_inputs.begin(), end = fused_inputs.end(); it != end; ++it) { - inputs.insert(std::pair(it->second, it->first)); + const auto& iter = original_inputs.find(it->first); + if (iter != original_inputs.end()) { + inputs.insert(std::pair(iter->second, iter->first)); + } else { + inputs.insert(std::pair(it->second, it->first)); + } } + // Sort outputs by the order they were added for (auto it = fused_outputs.begin(), end = fused_outputs.end(); it != end; ++it) { outputs.insert(std::pair(it->second, it->first)); }