Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
chilo-ms committed Nov 15, 2024
1 parent b6919fb commit e9042ef
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1960,7 +1960,7 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph

// Find inputs and outputs of the subgraph
std::unique_ptr<IndexedSubGraph> sub_graph = onnxruntime::IndexedSubGraph::Create();
std::unordered_map<const NodeArg*, int> fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add;
std::unordered_map<const NodeArg*, int> original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add;
std::unordered_set<const NodeArg*> erased;
int input_order = 0;
int output_order = 0;
Expand Down Expand Up @@ -2052,12 +2052,25 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
fused_outputs.insert(fused_outputs_to_add.begin(), fused_outputs_to_add.end());
fused_outputs.insert(graph_outputs_to_add.begin(), graph_outputs_to_add.end());

// Sort inputs and outputs by the order they were added
std::multimap<int, const NodeArg*> inputs, outputs;

// Get the input order of the original graph
int order = 0;
for (const auto* input : graph.GetInputs()) {
original_inputs[input] = order++;
}

// input order needs to be consistent with original graph's input order
for (auto it = fused_inputs.begin(), end = fused_inputs.end(); it != end; ++it) {
inputs.insert(std::pair<int, const NodeArg*>(it->second, it->first));
const auto& iter = original_inputs.find(it->first);
if (iter != original_inputs.end()) {
inputs.insert(std::pair<int, const NodeArg*>(iter->second, iter->first));
} else {
inputs.insert(std::pair<int, const NodeArg*>(it->second, it->first));
}
}

// Sort outputs by the order they were added
for (auto it = fused_outputs.begin(), end = fused_outputs.end(); it != end; ++it) {
outputs.insert(std::pair<int, const NodeArg*>(it->second, it->first));
}
Expand Down

0 comments on commit e9042ef

Please sign in to comment.