From dc5b336f1cac5ced861151cc4cd431fb3cea56e9 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 15 Nov 2024 18:13:32 +0000 Subject: [PATCH 1/3] update --- .../tensorrt/onnx_ctx_model_helper.cc | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc index ef45d6c85d6a9..9f65b05dd0ba4 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc @@ -75,17 +75,7 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, auto model_build = graph_viewer.CreateModel(*logger); auto& graph_build = model_build->MainGraph(); - // Get graph inputs and outputs std::vector inputs, outputs; - for (auto input : graph_viewer.GetInputs()) { - auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); - inputs.push_back(&n_input); - } - - for (auto output : graph_viewer.GetOutputs()) { - auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); - outputs.push_back(&n_output); - } // Create EP context node attributes auto attr_0 = ONNX_NAMESPACE::AttributeProto::Create(); // embed_mode @@ -124,6 +114,22 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, // Create EP context node graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN); + + // Get graph inputs and outputs + for (auto input : graph_viewer.GetInputs()) { + auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); + inputs.push_back(&n_input); + } + + for (auto output : graph_viewer.GetOutputs()) { + auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); + outputs.push_back(&n_output); + } + + + // Set inputs outputs explicitly to make sure the order is same as the user model + graph_build.SetInputs(inputs); + graph_build.SetOutputs(outputs); ORT_ENFORCE(graph_build.Resolve().IsOK()); // Serialize modelproto to string From b6919fbfd9d257410ac7aaf069d3365f840cab56 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 15 Nov 2024 20:28:00 +0000 Subject: [PATCH 2/3] revert --- .../tensorrt/onnx_ctx_model_helper.cc | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc index 9f65b05dd0ba4..ef45d6c85d6a9 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc @@ -75,7 +75,17 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, auto model_build = graph_viewer.CreateModel(*logger); auto& graph_build = model_build->MainGraph(); + // Get graph inputs and outputs std::vector inputs, outputs; + for (auto input : graph_viewer.GetInputs()) { + auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); + inputs.push_back(&n_input); + } + + for (auto output : graph_viewer.GetOutputs()) { + auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); + outputs.push_back(&n_output); + } // Create EP context node attributes auto attr_0 = ONNX_NAMESPACE::AttributeProto::Create(); // embed_mode @@ -114,22 +124,6 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, // Create EP context node graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN); - - // Get graph inputs and outputs - for (auto input : graph_viewer.GetInputs()) { - auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); - inputs.push_back(&n_input); - } - - for (auto output : graph_viewer.GetOutputs()) { - auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); - outputs.push_back(&n_output); - } - - - // Set inputs outputs explicitly to make sure the order is same as the user model - graph_build.SetInputs(inputs); - graph_build.SetOutputs(outputs); ORT_ENFORCE(graph_build.Resolve().IsOK()); // Serialize modelproto to string From e9042ef6dbb559ef8d426cb60476134da6ab8f7e Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 15 Nov 2024 21:06:35 +0000 Subject: [PATCH 3/3] update --- .../tensorrt/tensorrt_execution_provider.cc | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 1a5cf6ababdfc..9ed554345ca23 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1960,7 +1960,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph // Find inputs and outputs of the subgraph std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_map fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; + std::unordered_map original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; std::unordered_set erased; int input_order = 0; int output_order = 0; @@ -2052,12 +2052,25 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph fused_outputs.insert(fused_outputs_to_add.begin(), fused_outputs_to_add.end()); fused_outputs.insert(graph_outputs_to_add.begin(), graph_outputs_to_add.end()); - // Sort inputs and outputs by the order they were added std::multimap inputs, outputs; + + // Get the input order of the original graph + int order = 0; + for (const auto* input : graph.GetInputs()) { + original_inputs[input] = order++; + } + + // input order needs to be consistent with original graph's input order for (auto it = fused_inputs.begin(), end = fused_inputs.end(); it != end; ++it) { - inputs.insert(std::pair(it->second, it->first)); + const auto& iter = original_inputs.find(it->first); + if (iter != original_inputs.end()) { + inputs.insert(std::pair(iter->second, iter->first)); + } else { + inputs.insert(std::pair(it->second, it->first)); + } } + // Sort outputs by the order they were added for (auto it = fused_outputs.begin(), end = fused_outputs.end(); it != end; ++it) { outputs.insert(std::pair(it->second, it->first)); }