Commit

serialize refitted engine
chilo-ms committed May 23, 2024
1 parent cd5eba2 commit 89c8b0f
Showing 1 changed file with 26 additions and 11 deletions.
37 changes: 26 additions & 11 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
@@ -244,6 +244,14 @@ bool IsRelativePathToParentPath(const std::string& path_string) {
 #endif
 }
 
+// Get the refitted engine cache path
+std::string GetRefittedEnginePath(std::string engine_cache_path) {
+  std::filesystem::path full_engine_cache_path(engine_cache_path);
+  // The weight-stripped engine has the naming of xxx.stripped.engine
+  std::string refitted_engine_cache_path = full_engine_cache_path.stem().stem().string() + ".engine";
+  return refitted_engine_cache_path;
+}
+
 Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) {
   if (!ValidateEPCtxNode(graph_viewer)) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node");
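Note on the added helper: std::filesystem::path::stem() drops both the directory part and the final extension, so applying it twice reduces "model.stripped.engine" to "model", and GetRefittedEnginePath returns a bare file name such as "model.engine" with no directory prefix. A minimal standalone check of that transformation (illustrative only, not part of this commit):

#include <cassert>
#include <filesystem>
#include <string>

int main() {
  std::filesystem::path p("/ctx_model_dir/model.stripped.engine");
  // stem() removes the last extension and the directory component, so two
  // applications turn "/ctx_model_dir/model.stripped.engine" into "model".
  std::string refitted = p.stem().stem().string() + ".engine";
  assert(refitted == "model.engine");  // note: no "/ctx_model_dir/" prefix remains
  return 0;
}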
@@ -266,17 +274,6 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
     // Get engine from cache file.
     std::string cache_path = attrs.at(EP_CACHE_CONTEXT).s();
 
-    // If the serialized refitted engine is preset, directly use it without needed to refit it again
-    if (weight_stripped_engine_refit_) {
-      // The weight-stripped engine has the naming of xxx.stripped.engine
-      std::filesystem::path stripped_cache_path(cache_path);
-      std::string weight_stripped_cache_path = stripped_cache_path.stem().stem().string() + ".engine";
-      if (std::filesystem::exists(weight_stripped_cache_path)) {
-        cache_path = weight_stripped_cache_path;
-        weight_stripped_engine_refit_ = false;
-      }
-    }
-
     // For security purpose, in the case of running context model, TRT EP won't allow
     // engine cache path to be the relative path like "../file_path" or the absolute path.
     // It only allows the engine cache to be in the same directory or sub directory of the context model.
@@ -291,6 +288,15 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
     std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_));
     auto engine_cache_path = ctx_model_dir.append(cache_path);
 
+    // If the serialized refitted engine is present, use it directly without refitting the engine again
+    if (weight_stripped_engine_refit_) {
+      std::string refitted_engine_cache_path = GetRefittedEnginePath(engine_cache_path.string());
+      if (std::filesystem::exists(refitted_engine_cache_path)) {
+        engine_cache_path = refitted_engine_cache_path;
+        weight_stripped_engine_refit_ = false;
+      }
+    }
+
     if (!std::filesystem::exists(engine_cache_path)) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                              "TensorRT EP can't find engine cache: " + engine_cache_path.string() +
@@ -343,6 +349,15 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
       return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                              "TensorRT EP's IRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string());
     }
+
+    // serialize the refitted engine to disk
+    if (embed_mode == 0) {
+      std::string cache_path = attrs.at(EP_CACHE_CONTEXT).s();
+      std::string refitted_cache_path = GetRefittedEnginePath(cache_path);
+      nvinfer1::IHostMemory* serialized_engine = (*trt_engine_)->serialize();
+      std::ofstream engine_file(refitted_cache_path, std::ios::binary | std::ios::out);
+      engine_file.write(reinterpret_cast<const char*>(serialized_engine->data()), serialized_engine->size());
+    }
 #else
     return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP's IParserRefitter can only be used on TRT 10.0 onwards.");
 #endif
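The new serialization step writes the blob returned by serialize() through a raw nvinfer1::IHostMemory* and an unchecked ofstream. A minimal sketch of the same write with explicit ownership and error handling, as a hypothetical helper that is not part of this commit (it assumes TensorRT 8+/10 semantics, where the IHostMemory returned by serialize() is caller-owned and released with delete):

#include <fstream>
#include <memory>
#include <string>

#include "NvInfer.h"

// Write a serialized engine blob to disk, releasing the host copy when done.
bool WriteRefittedEngineCache(nvinfer1::ICudaEngine& engine, const std::string& path) {
  std::unique_ptr<nvinfer1::IHostMemory> blob{engine.serialize()};
  if (blob == nullptr) {
    return false;  // serialization failed
  }
  std::ofstream engine_file(path, std::ios::binary | std::ios::out);
  if (!engine_file) {
    return false;  // could not open the target file for writing
  }
  engine_file.write(reinterpret_cast<const char*>(blob->data()),
                    static_cast<std::streamsize>(blob->size()));
  return engine_file.good();
}  // blob's unique_ptr frees the IHostMemory here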