@@ -655,9 +655,9 @@ void NvExecutionProvider::PerThreadContext::ResetTensorRTContext(std::string fus
 }
 }
 
-bool NvExecutionProvider::PerThreadContext::UpdateTensorRTContext(std::string fused_node, std::unique_ptr<nvinfer1::IExecutionContext> context) {
+bool NvExecutionProvider::PerThreadContext::UpdateTensorRTContext(std::string fused_node, tensorrt_ptr::unique_pointer_exec_ctx context) {
   if (!context) {
-    context = std::make_unique<nvinfer1::IExecutionContext>();
+    context = tensorrt_ptr::unique_pointer_exec_ctx();
   }
   trt_context_map_[fused_node] = std::move(context);
@@ -758,11 +758,11 @@ bool NvExecutionProvider::PerThreadContext::IsTensorRTContextInMap(std::string f
 nvinfer1::IExecutionContext& NvExecutionProvider::PerThreadContext::GetTensorRTContext(std::string fused_node) {
   auto it = trt_context_map_.find(fused_node);
   if (it != trt_context_map_.end()) {
-    return *(it->second);  // dereference shared pointer
+    return *(it->second.get());  // dereference shared pointer
   }
-  auto context = std::make_unique<nvinfer1::IExecutionContext>();
+  auto context = tensorrt_ptr::unique_pointer_exec_ctx();
   trt_context_map_[fused_node] = std::move(context);
-  return *(trt_context_map_[fused_node]);  // dereference shared pointer
+  return *(trt_context_map_[fused_node].get());  // dereference shared pointer
 }
 
 void NvExecutionProvider::ReleasePerThreadContext() const {
@@ -871,6 +871,20 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info)
   max_shared_mem_size_ = info.max_shared_mem_size;
   dump_subgraphs_ = info.dump_subgraphs;
   weight_stripped_engine_enable_ = info.weight_stripped_engine_enable;
+  // make runtime cache path absolute and create directory if it doesn't exist
+  if (!info.runtime_cache_path.empty()) {
+    std::filesystem::path p(info.runtime_cache_path);
+    std::filesystem::path abs_path = std::filesystem::absolute(p);
+    const auto& env = GetDefaultEnv();
+    auto status = env.CreateFolder(abs_path.string());
+    if (!status.IsOK()) {
+      LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] The runtime cache directory could not be created at: " << abs_path
+                            << ". Runtime cache is disabled.";
+    } else {
+      runtime_cache_ = abs_path;
+    }
+  }
+
   onnx_model_folder_path_ = info.onnx_model_folder_path;
   onnx_model_bytestream_ = info.onnx_bytestream;
   onnx_model_bytestream_size_ = info.onnx_bytestream_size;
@@ -1054,7 +1068,8 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info)
                         << ", nv_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_
                         << ", nv_onnx_external_bytestream_size_: " << onnx_external_data_bytestream_size_
                         << ", nv_use_external_data_initializer_: " << use_external_data_initializer_
-                        << ", nv_op_types_to_exclude: " << op_types_to_exclude_;
+                        << ", nv_op_types_to_exclude: " << op_types_to_exclude_
+                        << ", nv_runtime_cache_path: " << runtime_cache_;
 }
 
 Status NvExecutionProvider::Sync() const {
@@ -2637,8 +2652,10 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
   //
   // Otherwise engine will be handled at inference time.
   std::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
-  std::unique_ptr<nvinfer1::IExecutionContext> trt_context;
+  tensorrt_ptr::unique_pointer_exec_ctx trt_context;
+  std::unique_ptr<nvinfer1::IRuntimeCache> trt_runtime_cache;
   std::unique_ptr<nvinfer1::IRuntimeConfig> trt_runtime_config;
+  std::string runtime_cache_file = "";
 
   // Generate file name for dumping ep context model
   if (dump_ep_context_model_ && ctx_model_path_.empty()) {
@@ -2667,6 +2684,18 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
     trt_runtime_config->setDynamicShapesKernelSpecializationStrategy(nvinfer1::DynamicShapesKernelSpecializationStrategy::kEAGER);
   }
   trt_runtime_config->setExecutionContextAllocationStrategy(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED);
+  if (!runtime_cache_.empty()) {
+    runtime_cache_file = (runtime_cache_ / fused_node.Name()).string();
+    trt_runtime_cache = std::unique_ptr<nvinfer1::IRuntimeCache>(trt_runtime_config->createRuntimeCache());
+    auto cache_data = file_utils::ReadFile(runtime_cache_file);
+    if (!trt_runtime_cache->deserialize(cache_data.data(), cache_data.size())) {
+      trt_runtime_cache = std::unique_ptr<nvinfer1::IRuntimeCache>(trt_runtime_config->createRuntimeCache());
+      LOGS_DEFAULT(INFO) << "TensorRT RTX failed to deserialize the runtime cache, will overwrite with new one" << std::endl;
+    }
+    if (!trt_runtime_config->setRuntimeCache(*trt_runtime_cache)) {
+      LOGS_DEFAULT(INFO) << "TensorRT RTX failed to set the runtime cache" << std::endl;
+    }
+  }
 
   if (detailed_build_log_) {
     auto engine_build_stop = std::chrono::steady_clock::now();
@@ -2727,7 +2756,9 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
   // Build context
   // Note: Creating an execution context from an engine is thread safe per TRT doc
   // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
-  trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext(trt_runtime_config.get()));
+  trt_context = tensorrt_ptr::unique_pointer_exec_ctx(
+      trt_engine->createExecutionContext(trt_runtime_config.get()),
+      tensorrt_ptr::IExecutionContextDeleter(runtime_cache_file, std::move(trt_runtime_cache)));
   if (!trt_context) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                            "NvTensorRTRTX EP could not build execution context for fused node: " + fused_node.Name());
@@ -3008,7 +3039,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
                                                         std::unordered_map<std::string, size_t>& output_map,
                                                         std::vector<NodeComputeInfo>& node_compute_funcs) {
   std::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
-  std::unique_ptr<nvinfer1::IExecutionContext> trt_context;
+  tensorrt_ptr::unique_pointer_exec_ctx trt_context;
   std::unordered_map<std::string, size_t> input_indexes;   // TRT engine input name -> ORT kernel context input index
   std::unordered_map<std::string, size_t> output_indexes;  // TRT engine output name -> ORT kernel context output index
   std::unordered_map<std::string, size_t> output_types;    // TRT engine output name -> ORT output tensor type
@@ -3030,11 +3061,33 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
     return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
   }
 
+  std::unique_ptr<nvinfer1::IRuntimeCache> trt_runtime_cache;
+  auto trt_runtime_config = std::unique_ptr<nvinfer1::IRuntimeConfig>(trt_engine->createRuntimeConfig());
+  if (trt_runtime_config && cuda_graph_enable_) {
+    trt_runtime_config->setDynamicShapesKernelSpecializationStrategy(nvinfer1::DynamicShapesKernelSpecializationStrategy::kEAGER);
+  }
+  trt_runtime_config->setExecutionContextAllocationStrategy(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED);
+  std::string runtime_cache_file = "";
+  if (!runtime_cache_.empty()) {
+    runtime_cache_file = (runtime_cache_ / graph_body_viewer.GetNode(node_idx)->Name()).string();
+    trt_runtime_cache = std::unique_ptr<nvinfer1::IRuntimeCache>(trt_runtime_config->createRuntimeCache());
+    auto cache_data = file_utils::ReadFile(runtime_cache_file);
+    if (!trt_runtime_cache->deserialize(cache_data.data(), cache_data.size())) {
+      trt_runtime_cache = std::unique_ptr<nvinfer1::IRuntimeCache>(trt_runtime_config->createRuntimeCache());
+      LOGS_DEFAULT(INFO) << "TensorRT RTX failed to deserialize the runtime cache, will overwrite with new one" << std::endl;
+    }
+    if (!trt_runtime_config->setRuntimeCache(*trt_runtime_cache)) {
+      LOGS_DEFAULT(INFO) << "TensorRT RTX failed to set the runtime cache" << std::endl;
+    }
+  }
+
   // Build context
   //
   // Note: Creating an execution context from an engine is thread safe per TRT doc
   // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
-  trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
+  trt_context = tensorrt_ptr::unique_pointer_exec_ctx(
+      trt_engine->createExecutionContext(trt_runtime_config.get()),
+      tensorrt_ptr::IExecutionContextDeleter(runtime_cache_file, std::move(trt_runtime_cache)));
   if (!trt_context) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                            "NvTensorRTRTX EP could not build execution context for fused node: " + fused_node.Name());
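The custom deleter wired into tensorrt_ptr::unique_pointer_exec_ctx is likewise defined outside this diff. The call sites suggest its role: when the execution context is destroyed, serialize the runtime cache that was attached to it back to the per-node cache file so JIT compilation results survive across sessions. A minimal sketch under those assumptions (in particular, that IRuntimeCache::serialize() returns an nvinfer1::IHostMemory*, mirroring the deserialize() call in the hunks above; the real tensorrt_ptr definitions may differ):

// Hypothetical sketch of the tensorrt_ptr helpers referenced in the diff above.
// The nvinfer1 types are assumed to come from the TensorRT RTX headers the EP already includes.
#include <fstream>
#include <memory>
#include <string>
#include <utility>

namespace tensorrt_ptr {

class IExecutionContextDeleter {
 public:
  IExecutionContextDeleter() = default;
  IExecutionContextDeleter(std::string runtime_cache_path,
                           std::unique_ptr<nvinfer1::IRuntimeCache> runtime_cache)
      : runtime_cache_path_(std::move(runtime_cache_path)),
        runtime_cache_(std::move(runtime_cache)) {}

  void operator()(nvinfer1::IExecutionContext* context) {
    // Persist the runtime cache before the context goes away, if caching is enabled.
    if (runtime_cache_ && !runtime_cache_path_.empty()) {
      // Assumption: serialize() returns IHostMemory*, the counterpart of deserialize().
      std::unique_ptr<nvinfer1::IHostMemory> blob(runtime_cache_->serialize());
      if (blob) {
        std::ofstream out(runtime_cache_path_, std::ios::binary);
        out.write(static_cast<const char*>(blob->data()),
                  static_cast<std::streamsize>(blob->size()));
      }
    }
    delete context;
  }

 private:
  std::string runtime_cache_path_;
  std::unique_ptr<nvinfer1::IRuntimeCache> runtime_cache_;
};

// Alias matching how trt_context is declared in the hunks above.
using unique_pointer_exec_ctx =
    std::unique_ptr<nvinfer1::IExecutionContext, IExecutionContextDeleter>;

}  // namespace tensorrt_ptr

Tying the cache's lifetime to the execution context this way would mean the cache file is written exactly once, when the context that owns it is released, rather than on every inference call.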