Commit 5537d33

[TRT RTX EP] Add support for RTX runtime caches (microsoft#25917)
### Description

Runtime caches can reduce JIT time when deserializing a TensorRT RTX engine. This change introduces per-engine runtime caching in a user-specified folder. Each cache file is named after the fused node name, which is also the node name of the corresponding EP context node.

@chilo-ms we would like to cherry-pick this into 1.23.
1 parent daa0306 commit 5537d33
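
For context, a minimal sketch of how an application might enable the new runtime cache through the ONNX Runtime C++ API. The provider registration name "NvTensorRtRtx", the model path, and the cache folder below are assumptions for illustration; `nv_runtime_cache_path` is the provider option added by this commit.

// Sketch only: provider name and paths are assumed, not taken from this commit.
#include <onnxruntime_cxx_api.h>
#include <string>
#include <unordered_map>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "trt_rtx_runtime_cache_demo");
  Ort::SessionOptions session_options;

  // Point the EP at a writable folder; one cache file per fused node is
  // written there, named after the fused node (the EP context node name).
  std::unordered_map<std::string, std::string> provider_options{
      {"nv_runtime_cache_path", "./trt_rtx_runtime_cache"}};
  session_options.AppendExecutionProvider("NvTensorRtRtx", provider_options);

  // Later runs reuse the serialized runtime caches, reducing JIT time.
  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);
  return 0;
}

On the first run the folder is populated with one cache file per fused node; subsequent runs deserialize those caches instead of redoing the JIT work.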

File tree

7 files changed: +230 -17 lines


include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ constexpr const char* kProfilesOptShapes = "nv_profile_opt_shapes";
 constexpr const char* kCudaGraphEnable = "enable_cuda_graph";
 constexpr const char* kMultiProfileEnable = "nv_multi_profile_enable";
 constexpr const char* kUseExternalDataInitializer = "nv_use_external_data_initializer";
+constexpr const char* kRuntimeCacheFile = "nv_runtime_cache_path";
 
 }  // namespace provider_option_names
 namespace run_option_names {

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc

Lines changed: 63 additions & 10 deletions
@@ -655,9 +655,9 @@ void NvExecutionProvider::PerThreadContext::ResetTensorRTContext(std::string fus
   }
 }
 
-bool NvExecutionProvider::PerThreadContext::UpdateTensorRTContext(std::string fused_node, std::unique_ptr<nvinfer1::IExecutionContext> context) {
+bool NvExecutionProvider::PerThreadContext::UpdateTensorRTContext(std::string fused_node, tensorrt_ptr::unique_pointer_exec_ctx context) {
   if (!context) {
-    context = std::make_unique<nvinfer1::IExecutionContext>();
+    context = tensorrt_ptr::unique_pointer_exec_ctx();
   }
   trt_context_map_[fused_node] = std::move(context);
 
@@ -758,11 +758,11 @@ bool NvExecutionProvider::PerThreadContext::IsTensorRTContextInMap(std::string f
 nvinfer1::IExecutionContext& NvExecutionProvider::PerThreadContext::GetTensorRTContext(std::string fused_node) {
   auto it = trt_context_map_.find(fused_node);
   if (it != trt_context_map_.end()) {
-    return *(it->second);  // dereference shared pointer
+    return *(it->second.get());  // dereference shared pointer
   }
-  auto context = std::make_unique<nvinfer1::IExecutionContext>();
+  auto context = tensorrt_ptr::unique_pointer_exec_ctx();
   trt_context_map_[fused_node] = std::move(context);
-  return *(trt_context_map_[fused_node]);  // dereference shared pointer
+  return *(trt_context_map_[fused_node].get());  // dereference shared pointer
 }
 
 void NvExecutionProvider::ReleasePerThreadContext() const {
@@ -871,6 +871,20 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info)
   max_shared_mem_size_ = info.max_shared_mem_size;
   dump_subgraphs_ = info.dump_subgraphs;
   weight_stripped_engine_enable_ = info.weight_stripped_engine_enable;
+  // make runtime cache path absolute and create directory if it doesn't exist
+  if (!info.runtime_cache_path.empty()) {
+    std::filesystem::path p(info.runtime_cache_path);
+    std::filesystem::path abs_path = std::filesystem::absolute(p);
+    const auto& env = GetDefaultEnv();
+    auto status = env.CreateFolder(abs_path.string());
+    if (!status.IsOK()) {
+      LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] The runtime cache directory could not be created at: " << abs_path
+                            << ". Runtime cache is disabled.";
+    } else {
+      runtime_cache_ = abs_path;
+    }
+  }
+
   onnx_model_folder_path_ = info.onnx_model_folder_path;
   onnx_model_bytestream_ = info.onnx_bytestream;
   onnx_model_bytestream_size_ = info.onnx_bytestream_size;
@@ -1054,7 +1068,8 @@
                         << ", nv_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_
                         << ", nv_onnx_external_bytestream_size_: " << onnx_external_data_bytestream_size_
                         << ", nv_use_external_data_initializer_: " << use_external_data_initializer_
-                        << ", nv_op_types_to_exclude: " << op_types_to_exclude_;
+                        << ", nv_op_types_to_exclude: " << op_types_to_exclude_
+                        << ", nv_runtime_cache_path: " << runtime_cache_;
 }
 
 Status NvExecutionProvider::Sync() const {
@@ -2637,8 +2652,10 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
   //
   // Otherwise engine will be handled at inference time.
   std::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
-  std::unique_ptr<nvinfer1::IExecutionContext> trt_context;
+  tensorrt_ptr::unique_pointer_exec_ctx trt_context;
+  std::unique_ptr<nvinfer1::IRuntimeCache> trt_runtime_cache;
   std::unique_ptr<nvinfer1::IRuntimeConfig> trt_runtime_config;
+  std::string runtime_cache_file = "";
 
   // Generate file name for dumping ep context model
   if (dump_ep_context_model_ && ctx_model_path_.empty()) {
@@ -2667,6 +2684,18 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
     trt_runtime_config->setDynamicShapesKernelSpecializationStrategy(nvinfer1::DynamicShapesKernelSpecializationStrategy::kEAGER);
   }
   trt_runtime_config->setExecutionContextAllocationStrategy(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED);
+  if (!runtime_cache_.empty()) {
+    runtime_cache_file = (runtime_cache_ / fused_node.Name()).string();
+    trt_runtime_cache = std::unique_ptr<nvinfer1::IRuntimeCache>(trt_runtime_config->createRuntimeCache());
+    auto cache_data = file_utils::ReadFile(runtime_cache_file);
+    if (!trt_runtime_cache->deserialize(cache_data.data(), cache_data.size())) {
+      trt_runtime_cache = std::unique_ptr<nvinfer1::IRuntimeCache>(trt_runtime_config->createRuntimeCache());
+      LOGS_DEFAULT(INFO) << "TensorRT RTX failed to deserialize the runtime cache, will overwrite with new one" << std::endl;
+    }
+    if (!trt_runtime_config->setRuntimeCache(*trt_runtime_cache)) {
+      LOGS_DEFAULT(INFO) << "TensorRT RTX failed to set the runtime cache" << std::endl;
+    }
+  }
 
   if (detailed_build_log_) {
     auto engine_build_stop = std::chrono::steady_clock::now();
@@ -2727,7 +2756,9 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
   // Build context
   // Note: Creating an execution context from an engine is thread safe per TRT doc
   // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
-  trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext(trt_runtime_config.get()));
+  trt_context = tensorrt_ptr::unique_pointer_exec_ctx(
+      trt_engine->createExecutionContext(trt_runtime_config.get()),
+      tensorrt_ptr::IExecutionContextDeleter(runtime_cache_file, std::move(trt_runtime_cache)));
   if (!trt_context) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                            "NvTensorRTRTX EP could not build execution context for fused node: " + fused_node.Name());
@@ -3008,7 +3039,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
                                         std::unordered_map<std::string, size_t>& output_map,
                                         std::vector<NodeComputeInfo>& node_compute_funcs) {
   std::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
-  std::unique_ptr<nvinfer1::IExecutionContext> trt_context;
+  tensorrt_ptr::unique_pointer_exec_ctx trt_context;
   std::unordered_map<std::string, size_t> input_indexes;   // TRT engine input name -> ORT kernel context input index
   std::unordered_map<std::string, size_t> output_indexes;  // TRT engine output name -> ORT kernel context output index
   std::unordered_map<std::string, size_t> output_types;    // TRT engine output name -> ORT output tensor type
@@ -3030,11 +3061,33 @@
     return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
   }
 
+  std::unique_ptr<nvinfer1::IRuntimeCache> trt_runtime_cache;
+  auto trt_runtime_config = std::unique_ptr<nvinfer1::IRuntimeConfig>(trt_engine->createRuntimeConfig());
+  if (trt_runtime_config && cuda_graph_enable_) {
+    trt_runtime_config->setDynamicShapesKernelSpecializationStrategy(nvinfer1::DynamicShapesKernelSpecializationStrategy::kEAGER);
+  }
+  trt_runtime_config->setExecutionContextAllocationStrategy(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED);
+  std::string runtime_cache_file = "";
+  if (!runtime_cache_.empty()) {
+    runtime_cache_file = (runtime_cache_ / graph_body_viewer.GetNode(node_idx)->Name()).string();
+    trt_runtime_cache = std::unique_ptr<nvinfer1::IRuntimeCache>(trt_runtime_config->createRuntimeCache());
+    auto cache_data = file_utils::ReadFile(runtime_cache_file);
+    if (!trt_runtime_cache->deserialize(cache_data.data(), cache_data.size())) {
+      trt_runtime_cache = std::unique_ptr<nvinfer1::IRuntimeCache>(trt_runtime_config->createRuntimeCache());
+      LOGS_DEFAULT(INFO) << "TensorRT RTX failed to deserialize the runtime cache, will overwrite with new one" << std::endl;
+    }
+    if (!trt_runtime_config->setRuntimeCache(*trt_runtime_cache)) {
+      LOGS_DEFAULT(INFO) << "TensorRT RTX failed to set the runtime cache" << std::endl;
+    }
+  }
+
   // Build context
   //
   // Note: Creating an execution context from an engine is thread safe per TRT doc
   // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
-  trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
+  trt_context = tensorrt_ptr::unique_pointer_exec_ctx(
+      trt_engine->createExecutionContext(trt_runtime_config.get()),
+      tensorrt_ptr::IExecutionContextDeleter(runtime_cache_file, std::move(trt_runtime_cache)));
   if (!trt_context) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                            "NvTensorRTRTX EP could not build execution context for fused node: " + fused_node.Name());

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h

Lines changed: 28 additions & 5 deletions
@@ -16,6 +16,7 @@ typedef void* cudnnStatus_t;
 #include <mutex>
 #include "core/providers/cuda/cuda_graph.h"
 #include "nv_execution_provider_info.h"
+#include "core/providers/nv_tensorrt_rtx/nv_file_utils.h"
 
 namespace onnxruntime {
 
@@ -58,6 +59,26 @@ class TensorrtLogger : public nvinfer1::ILogger {
 };
 
 namespace tensorrt_ptr {
+/*
+ * custom deleter that will dump the optimized runtime cache when the execution context is destructed
+ */
+struct IExecutionContextDeleter {
+  IExecutionContextDeleter() = default;
+  IExecutionContextDeleter(const std::string& runtime_cache_path, std::unique_ptr<nvinfer1::IRuntimeCache>&& runtime_cache) : runtime_cache_path_(runtime_cache_path), runtime_cache_(std::move(runtime_cache)) {};
+  void operator()(nvinfer1::IExecutionContext* context) {
+    if (context != nullptr) {
+      if (!runtime_cache_path_.empty()) {
+        auto serialized_cache_data = std::unique_ptr<nvinfer1::IHostMemory>(runtime_cache_->serialize());
+        file_utils::WriteFile(runtime_cache_path_, serialized_cache_data->data(), serialized_cache_data->size());
+      }
+      delete context;
+    }
+  }
+
+ private:
+  std::string runtime_cache_path_;
+  std::unique_ptr<nvinfer1::IRuntimeCache> runtime_cache_;
+};
 
 struct TensorrtInferDeleter {
   template <typename T>
@@ -70,6 +91,7 @@
 
 template <typename T>
 using unique_pointer = std::unique_ptr<T, TensorrtInferDeleter>;
+using unique_pointer_exec_ctx = std::unique_ptr<nvinfer1::IExecutionContext, IExecutionContextDeleter>;
 };  // namespace tensorrt_ptr
 
 //
@@ -196,7 +218,7 @@ struct TensorrtFuncState {
   std::string fused_node_name;
   nvinfer1::IBuilder* builder;
   std::unique_ptr<nvinfer1::ICudaEngine>* engine = nullptr;
-  std::unique_ptr<nvinfer1::IExecutionContext>* context = nullptr;
+  tensorrt_ptr::unique_pointer_exec_ctx* context = nullptr;
   std::unique_ptr<nvinfer1::INetworkDefinition>* network = nullptr;
   std::vector<std::unordered_map<std::string, size_t>> input_info;
   std::vector<std::unordered_map<std::string, size_t>> output_info;
@@ -233,7 +255,7 @@ struct TensorrtShortFuncState {
   AllocatorHandle allocator = nullptr;
   std::string fused_node_name;
   std::unique_ptr<nvinfer1::ICudaEngine>* engine = nullptr;
-  std::unique_ptr<nvinfer1::IExecutionContext>* context = nullptr;
+  tensorrt_ptr::unique_pointer_exec_ctx* context = nullptr;
   std::vector<std::unordered_map<std::string, size_t>> input_info;
   std::vector<std::unordered_map<std::string, size_t>> output_info;
   std::mutex* tensorrt_mu_ptr = nullptr;
@@ -357,6 +379,7 @@ class NvExecutionProvider : public IExecutionProvider {
   bool detailed_build_log_ = false;
   bool cuda_graph_enable_ = false;
   bool multi_profile_enable_ = false;
+  std::filesystem::path runtime_cache_;
   std::string cache_prefix_;
   std::string op_types_to_exclude_;
   int nv_profile_index_ = 0;
@@ -387,7 +410,7 @@
   // But there are still some thread safe operations, please see here https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
   // For those non thread safe operations, TRT EP uses (1) lock_guard or (2) PerThreadContext to make sure synchronization.
   std::unordered_map<std::string, std::unique_ptr<nvinfer1::ICudaEngine>> engines_;
-  std::unordered_map<std::string, std::unique_ptr<nvinfer1::IExecutionContext>> contexts_;
+  std::unordered_map<std::string, tensorrt_ptr::unique_pointer_exec_ctx> contexts_;
   std::unordered_map<std::string, std::unique_ptr<nvinfer1::IBuilder>> builders_;
   std::unordered_map<std::string, std::unique_ptr<nvinfer1::INetworkDefinition>> networks_;
   std::unordered_map<std::string, std::vector<std::unordered_map<std::string, size_t>>> input_info_;
@@ -425,7 +448,7 @@
 
   bool IsTensorRTContextInMap(std::string fused_node);
   nvinfer1::IExecutionContext& GetTensorRTContext(std::string fused_node);
-  bool UpdateTensorRTContext(std::string fused_node, std::unique_ptr<nvinfer1::IExecutionContext> context);
+  bool UpdateTensorRTContext(std::string fused_node, tensorrt_ptr::unique_pointer_exec_ctx context);
   void ResetTensorRTContext(std::string fused_node);
 
   // CUDA Graph management
@@ -455,7 +478,7 @@
   // See more details here:
   // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
   // https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_execution_context.html#a63cd95430852038ce864e17c670e0b36
-  std::unordered_map<std::string, std::unique_ptr<nvinfer1::IExecutionContext>> trt_context_map_;
+  std::unordered_map<std::string, tensorrt_ptr::unique_pointer_exec_ctx> trt_context_map_;
 
   // The profile shape ranges for the engine that the execution context maintained by the PerThreadContext is built with.
   // TRT EP needs this info to determine whether to rebuild the execution context.
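
The IExecutionContextDeleter introduced above ties persisting the runtime cache to the destruction of the execution context. For readers less familiar with stateful deleters, here is a standalone sketch of that flush-on-destruction pattern in plain standard C++; the type names, the int resource, and the "demo.cache" path are made up for illustration and are not part of the commit.

// Sketch only: illustrates the unique_ptr-with-stateful-deleter pattern in isolation.
#include <cstdio>
#include <memory>
#include <string>
#include <vector>

// Stateful deleter: persists a blob of bytes right before releasing the resource,
// mirroring how the EP's deleter serializes the runtime cache on context destruction.
struct FlushingDeleter {
  std::string path;
  std::vector<char> cache_bytes;
  void operator()(int* resource) const {
    if (resource != nullptr) {
      if (!path.empty()) {
        if (std::FILE* f = std::fopen(path.c_str(), "wb")) {
          std::fwrite(cache_bytes.data(), 1, cache_bytes.size(), f);
          std::fclose(f);
        }
      }
      delete resource;  // release the wrapped resource after the cache is written
    }
  }
};

using FlushingPtr = std::unique_ptr<int, FlushingDeleter>;

int main() {
  FlushingPtr p(new int(42), FlushingDeleter{"demo.cache", {'a', 'b', 'c'}});
  // When p goes out of scope, the deleter writes demo.cache and then frees the int.
  return 0;
}

Keeping the cache and its target path inside the deleter means the cache is written exactly when the context is finally released, with no extra bookkeeping in the EP's context maps.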

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc

Lines changed: 3 additions & 1 deletion
@@ -51,6 +51,7 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi
           .AddAssignmentToReference(nv::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable)
           .AddAssignmentToReference(nv::provider_option_names::kUseExternalDataInitializer, info.use_external_data_initializer)
           .AddAssignmentToReference(nv::provider_option_names::kMultiProfileEnable, info.multi_profile_enable)
+          .AddAssignmentToReference(nv::provider_option_names::kRuntimeCacheFile, info.runtime_cache_path)
           .Parse(options));  // add new provider option here.
 
   info.user_compute_stream = user_compute_stream;

@@ -105,7 +106,8 @@ ProviderOptions NvExecutionProviderInfo::ToProviderOptions(const NvExecutionProv
       {nv::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)},
       {nv::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)},
       {nv::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)},
-      {nv::provider_option_names::kUseExternalDataInitializer, MakeStringWithClassicLocale(info.use_external_data_initializer)}};
+      {nv::provider_option_names::kUseExternalDataInitializer, MakeStringWithClassicLocale(info.use_external_data_initializer)},
+      {nv::provider_option_names::kRuntimeCacheFile, MakeStringWithClassicLocale(info.runtime_cache_path)}};
   return options;
 }
 }  // namespace onnxruntime

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ struct NvExecutionProviderInfo {
   bool engine_decryption_enable{false};
   std::string engine_decryption_lib_path{""};
   bool force_sequential_engine_build{false};
-  std::string timing_cache_path{""};
+  std::string runtime_cache_path{""};
   bool detailed_build_log{false};
   bool sparsity_enable{false};
   int auxiliary_streams{-1};

onnxruntime/core/providers/nv_tensorrt_rtx/nv_file_utils.h

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+#pragma once
+#include <string>
+#include <fstream>
+#include <filesystem>
+#include <stdexcept>
+#include <vector>
+#include "core/providers/shared_library/provider_api.h"
+
+namespace onnxruntime {
+namespace file_utils {
+
+inline std::vector<char> ReadFile(const std::string& path) {
+  if (!std::filesystem::exists(path)) {
+    LOGS_DEFAULT(INFO) << "TensorRT RTX could not find the file and will create a new one " << path << std::endl;
+    return {};
+  }
+  std::ifstream file(path, std::ios::in | std::ios::binary);
+  if (!file) {
+    ORT_THROW("Failed to open file: " + path);
+  }
+  file.seekg(0, std::ios::end);
+  std::streamsize size = file.tellg();
+  file.seekg(0, std::ios::beg);
+  std::vector<char> buffer(size);
+  if (size > 0 && !file.read(buffer.data(), size)) {
+    ORT_THROW("Failed to read file: " + path);
+  }
+  return buffer;
+}
+
+inline void WriteFile(const std::string& path, const void* data, size_t size) {
+  if (std::filesystem::exists(path)) {
+    std::ofstream file(path, std::ios::out | std::ios::binary | std::ios::trunc);
+    if (!file) {
+      ORT_THROW("Failed to open file for writing: " + path);
+    }
+    file.write(static_cast<const char*>(data), size);
+  } else {
+    LOGS_DEFAULT(INFO) << "TensorRT RTX a new file cache was written to " << path << std::endl;
+    // Create new file
+    std::ofstream file(path, std::ios::out | std::ios::binary);
+    if (!file) {
+      ORT_THROW("Failed to create file: " + path);
+    }
+    file.write(static_cast<const char*>(data), size);
+  }
+}
+
+inline void WriteFile(const std::string& path, const std::vector<char>& data) { WriteFile(path, data.data(), data.size()); }
+
+}  // namespace file_utils
+}  // namespace onnxruntime
