From 86ff8b04623bfc4fcc6991f36afe45886ce11f6a Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 19 Nov 2024 17:38:10 +0000 Subject: [PATCH] update --- .../tensorrt/tensorrt_provider_options.h | 6 +- .../tensorrt/tensorrt_execution_provider.cc | 40 +++++++++---- .../tensorrt/tensorrt_execution_provider.h | 1 + .../tensorrt_execution_provider_info.cc | 6 ++ .../tensorrt_execution_provider_info.h | 1 + .../tensorrt/tensorrt_provider_factory.cc | 1 + .../core/session/provider_bridge_ort.cc | 1 + .../python/onnxruntime_pybind_state.cc | 9 ++- .../providers/tensorrt/tensorrt_basic_test.cc | 60 +++++++++++++++++++ 9 files changed, 109 insertions(+), 16 deletions(-) diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index ec9be80a63574..bf24fbaf8a45b 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -85,7 +85,7 @@ struct OrtTensorRTProviderOptionsV2 { // can be updated using: UpdateTensorRTProviderOptionsWithValue size_t trt_onnx_bytestream_size{0}; // size of the byte stream provided as "trt_onnx_bytestream" // can be updated using: UpdateTensorRTProviderOptionsWithValue - - const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix - int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true + const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix + int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true + const char* trt_op_types_to_exclude{nullptr}; // Exclude specific ops from running on TRT. }; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index a7330c9bcf13d..1fe2c8146a817 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1379,6 +1379,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv profile_opt_shapes = info.profile_opt_shapes; cuda_graph_enable_ = info.cuda_graph_enable; engine_hw_compatible_ = info.engine_hw_compatible; + op_types_to_exclude_ = info.op_types_to_exclude; + } else { try { const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations); @@ -1565,6 +1567,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true); } + const std::string op_types_to_exclude_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kOpTypesToExclude); + if (!op_types_to_exclude_env.empty()) { + op_types_to_exclude_ = op_types_to_exclude_env; + } + } catch (const std::invalid_argument& ex) { LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what(); } catch (const std::out_of_range& ex) { @@ -1766,7 +1773,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv << ", trt_ep_context_embed_mode: " << ep_context_embed_mode_ << ", trt_cache_prefix: " << cache_prefix_ << ", trt_engine_hw_compatible: " << engine_hw_compatible_ - << ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_; + << ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_ + << ", trt_op_types_to_exclude: " << op_types_to_exclude_; } TensorrtExecutionProvider::~TensorrtExecutionProvider() { @@ -2434,6 +2442,18 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& return cycle_detected; } +std::set GetExcludedNodeSet(std::string node_list_to_exclude) { + std::set set; + if (!node_list_to_exclude.empty()) { + std::stringstream node_list(node_list_to_exclude); + std::string node; + while (std::getline(node_list, node, ',')) { + set.insert(node); + } + } + return set; +} + std::vector> TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { @@ -2466,17 +2486,13 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, std::vector nodes_vector(number_of_ort_nodes); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); - std::set exclude_ops_set; + std::set exclude_ops_set = GetExcludedNodeSet(op_types_to_exclude_); - /* - * There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) in TRT 10. - * TRT EP automatically excludes DDS ops from running on TRT. - */ - if (trt_version_ >= 100000 && trt_version_ < 110000) { - exclude_ops_set.insert("NonMaxSuppression"); - exclude_ops_set.insert("NonZero"); - exclude_ops_set.insert("RoiAlign"); - LOGS_DEFAULT(VERBOSE) << "There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) in TRT 10. TRT EP automatically excludes DDS ops from running on TRT, if applicable"; + // Print excluded nodes, if any. + std::set::iterator it; + for (it = exclude_ops_set.begin(); it != exclude_ops_set.end(); ++it) { + std::string op = *it; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Exclude \"" << op << "\" from running on TRT, if any."; } SubGraphCollection_t parser_nodes_vector, supported_nodes_vector; @@ -2485,7 +2501,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, /* Iterate all the nodes and exclude the node if: * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible. - * 2. It's a DDS op. + * 2. It's in the exlucded ops set which specified by trt_op_types_to_exclude. */ for (const auto& index : nodes_vector) { const auto& node = graph.GetNode(node_index[index]); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 9e3a03417d917..9d8af02ba10e6 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -57,6 +57,7 @@ static const std::string kDumpEpContextModel = "ORT_DUMP_EP_CONTEXT_MODEL"; static const std::string kEpContextEmbedMode = "ORT_EP_CONTEXT_EMBED_MODE"; static const std::string kEpContextComputeCapabilityEnable = "ORT_EP_CONTEXT_COMPUTE_CAPABILITY_ENABLE"; static const std::string kEngineCachePrefix = "ORT_TENSORRT_CACHE_PREFIX"; +static const std::string kOpTypesToExclude = "ORT_TENSORRT_OP_TYPES_TO_EXCLUDE"; // Old env variable for backward compatibility static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH"; } // namespace tensorrt_env_vars diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc index 63b6d35072290..bc0d00ec6791f 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc @@ -56,6 +56,7 @@ constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model"; constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible"; constexpr const char* kONNXBytestream = "trt_onnx_bytestream"; constexpr const char* kONNXBytestreamSize = "trt_onnx_bytestream_size"; +constexpr const char* kOpTypesToExclude = "trt_op_types_to_exclude"; } // namespace provider_option_names } // namespace tensorrt @@ -134,6 +135,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions return Status::OK(); }) .AddAssignmentToReference(tensorrt::provider_option_names::kONNXBytestreamSize, info.onnx_bytestream_size) + .AddAssignmentToReference(tensorrt::provider_option_names::kOpTypesToExclude, info.op_types_to_exclude) .Parse(options)); // add new provider option here. info.user_compute_stream = user_compute_stream; @@ -188,6 +190,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.engine_hw_compatible)}, {tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(info.onnx_bytestream)}, {tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.onnx_bytestream_size)}, + {tensorrt::provider_option_names::kOpTypesToExclude, MakeStringWithClassicLocale(info.op_types_to_exclude)}, }; return options; } @@ -206,6 +209,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor const std::string kProfilesOptShapes_ = empty_if_null(info.trt_profile_opt_shapes); const std::string kEpContextFilePath_ = empty_if_null(info.trt_ep_context_file_path); const std::string kOnnxModelFolderPath_ = empty_if_null(info.trt_onnx_model_folder_path); + const std::string kOpTypesToExclude_ = empty_if_null(info.trt_op_types_to_exclude); const ProviderOptions options{ {tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, @@ -251,6 +255,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.trt_engine_hw_compatible)}, {tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(reinterpret_cast(info.trt_onnx_bytestream))}, {tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.trt_onnx_bytestream_size)}, + {tensorrt::provider_option_names::kOpTypesToExclude, kOpTypesToExclude_}, }; return options; } @@ -355,5 +360,6 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options trt_provider_options_v2.trt_engine_hw_compatible = internal_options.engine_hw_compatible; trt_provider_options_v2.trt_onnx_bytestream = internal_options.onnx_bytestream; trt_provider_options_v2.trt_onnx_bytestream_size = internal_options.onnx_bytestream_size; + trt_provider_options_v2.trt_op_types_to_exclude = copy_string_if_needed(internal_options.op_types_to_exclude); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h index fa1bbd6d3d7e6..a6fda73b55529 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h @@ -60,6 +60,7 @@ struct TensorrtExecutionProviderInfo { int ep_context_embed_mode{0}; std::string engine_cache_prefix{""}; bool engine_hw_compatible{false}; + std::string op_types_to_exclude{""}; static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index e242788ff389a..e4521ddd18ade 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -118,6 +118,7 @@ struct Tensorrt_Provider : Provider { info.engine_hw_compatible = options.trt_engine_hw_compatible != 0; info.onnx_bytestream = options.trt_onnx_bytestream; info.onnx_bytestream_size = options.trt_onnx_bytestream_size; + info.op_types_to_exclude = options.trt_op_types_to_exclude == nullptr ? "" : options.trt_op_types_to_exclude; return std::make_shared(info); } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index ecbdd31160c7a..bf1f874d89361 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2410,6 +2410,7 @@ ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensor delete[] ptr->trt_profile_opt_shapes; delete[] ptr->trt_ep_context_file_path; delete[] ptr->trt_onnx_model_folder_path; + delete[] ptr->trt_op_types_to_exclude; } std::unique_ptr p(ptr); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 4d9583be0ef0f..0479059cd59af 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -526,7 +526,7 @@ std::unique_ptr CreateExecutionProviderInstance( // and TRT EP instance, so it won't be released.) std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources, trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile, ep_context_file_path, - onnx_model_folder_path; + onnx_model_folder_path, trt_op_types_to_exclude; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { OrtTensorRTProviderOptionsV2 params; @@ -824,6 +824,13 @@ std::unique_ptr CreateExecutionProviderInstance( } else { ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_hw_compatible' should be 'True' or 'False'. Default value is 'False'.\n"); } + } else if (option.first == "trt_op_types_to_exclude") { + if (!option.second.empty()) { + trt_op_types_to_exclude = option.second; + params.trt_op_types_to_exclude = trt_op_types_to_exclude.c_str(); + } else { + ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_op_types_to_exclude' should be a string.\n"); + } } else { ORT_THROW("Invalid TensorRT EP option: ", option.first); } diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 63327a028c6f4..b4199548ae515 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -612,6 +612,66 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) { RunSession(session_object9, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m); } +TEST(TensorrtExecutionProviderTest, ExcludeOpsTest) { + /* The mnist.onnx looks like this: + * Conv + * | + * Add + * . + * . + * | + * MaxPool + * | + * . + * . + * MaxPool + * | + * Reshape + * | + * MatMul + * . + * . + * + */ + PathString model_name = ORT_TSTR("testdata/mnist.onnx"); + SessionOptions so; + so.session_logid = "TensorrtExecutionProviderExcludeOpsTest"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + InferenceSession session_object{so, GetEnvironment()}; + auto cuda_provider = DefaultCudaExecutionProvider(); + auto cpu_allocator = cuda_provider->CreatePreferredAllocators()[1]; + std::vector dims_op_x = {1, 1, 28, 28}; + std::vector values_op_x(784, 1.0f); // 784=1*1*28*28 + OrtValue ml_value_x; + CreateMLValue(cpu_allocator, dims_op_x, values_op_x, &ml_value_x); + NameMLValMap feeds; + feeds.insert(std::make_pair("Input3", ml_value_x)); + + // prepare outputs + std::vector output_names; + output_names.push_back("Plus214_Output_0"); + std::vector fetches; + + RemoveCachesByType("./", ".engine"); + OrtTensorRTProviderOptionsV2 params; + params.trt_engine_cache_enable = 1; + params.trt_op_types_to_exclude = "MaxPool"; + std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); + EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); + auto status = session_object.Load(model_name); + ASSERT_TRUE(status.IsOK()); + status = session_object.Initialize(); + ASSERT_TRUE(status.IsOK()); + status = session_object.Run(run_options, feeds, output_names, &fetches); + ASSERT_TRUE(status.IsOK()); + + std::vector engine_files; + engine_files = GetCachesByType("./", ".engine"); + // The whole graph should be partitioned into 3 TRT subgraphs and 2 cpu nodes + ASSERT_EQ(engine_files.size(), 3); +} + TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) { PathString model_name = ORT_TSTR("testdata/trt_plugin_custom_op_test.onnx"); SessionOptions so;