Skip to content

Commit

Permalink
[TensorRT EP] Add supportsModelV2 (#22081)
Browse files Browse the repository at this point in the history
`supportsModel` is deprecated in TRT 10.1.
Add `supportsModelV2`, but still keep `supportsModel`, as we still need to
support TRT 8.6, where `supportsModelV2` is not
supported.
  • Loading branch information
chilo-ms authored Sep 17, 2024
1 parent 9786909 commit 6dcdc70
Showing 1 changed file with 38 additions and 7 deletions.
45 changes: 38 additions & 7 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2255,23 +2255,54 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED);
#endif
network_flags |= 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(network_flags));

auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(network_flags));
auto trt_parser = tensorrt_ptr::unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4996)
#endif

#if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 1) || NV_TENSORRT_MAJOR > 10
auto is_model_supported = trt_parser->supportsModelV2(string_buf.data(), string_buf.size(), model_path_);

// Note: Calling getNbSubgraphs or getSubgraphNodes before calling supportsModelV2 results in undefined behavior.
auto num_subgraphs = trt_parser->getNbSubgraphs();
parser_nodes_list.reserve(num_subgraphs);

for (int64_t i = 0; i < num_subgraphs; ++i) {
int64_t subgraph_len = 0;
int64_t* nodes = trt_parser->getSubgraphNodes(i, subgraph_len);
parser_nodes_list.emplace_back();
parser_nodes_list.back().first.reserve(subgraph_len);
for (int64_t j = 0; j < subgraph_len; ++j) {
parser_nodes_list.back().first.push_back(nodes[j]);
}
parser_nodes_list.back().second = is_model_supported ? true : false;
}
#else
trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_);
#if defined(_MSC_VER)
#pragma warning(pop)
#endif

SubGraphCollection_t next_nodes_list;
const std::vector<NodeIndex>& subgraph_node_index = graph_viewer->GetNodesInTopologicalOrder(1 /*priority-based topological sort*/);
next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, *graph_viewer, early_termination);
for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) {
for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) {
/*
* Convert the supported node list returning from onnx-tensorrt parser to the node list recognized by ORT TRT.
*
* TRT EP reconstructs the graph based on the nodes in group.first and feeds this graph (converts to model proto and to string buffer) to onnx-tensorrt parser.
* The node index in the list returning from onnx-tensorrt parser might not be the same as the node index in group.first. Therefore, TRT EP needs a node index mapping table here.
*
* The order of iterating the nodes in group.first and calling graph_build.AddNode() determines the node order in the newly constructed graph (see Graph::AllocateNode() in graph.cc),
* however, once the graph is converted to model proto, the node proto order in model proto (ex: onnx-tensorrt calls model.graph().node() to iterate NodeProto in ModelProto) is decided by topo sort.
*
* The topo sort list (i.e. subgraph_node_index) acts as the node index mapping table:
* subgraph_node_index[node index from onnx-tensorrt parser] = index in group.first
*
* In the past, TRT EP used ORT's default reversed DFS topo sort, which might produce a sorting result that is not the sequence 0, 1, ... n-1, e.g. subgraph_node_index = [0,2,1,3,4].
* With the change to ORT's priority-based topo sort (the node with the lower node index is output first), the sorting result is the sequence 0, 1, ... n-1 in most cases,
* therefore subgraph_node_index is no longer needed as a mapping table.
*
* TODO: Remove the subgraph_node_index
*/
next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]];
}
nodes_list_output.push_back(next_nodes_list[i]);
Expand Down

0 comments on commit 6dcdc70

Please sign in to comment.