Skip to content

Commit

Permalink
Add OrtGraphApis::OrtNode_GetAttributeStrWithSize to handle case wher…
Browse files Browse the repository at this point in the history
…e attribute might contain null character (#22769)

When running EP Context model, EP might call
`OrtGraphApis::OrtNode_GetAttributeStr` to get the string-based content
of the attribute.
However, the API returns only the c_str() of the string, and the cache
context may contain null characters, so the string can be cut off at the
first one and the caller ends up with the wrong string.

Add a new OrtGraphApis::OrtNode_GetAttributeStrWithSize that returns both
the const char* pointer and the string size.
  • Loading branch information
chilo-ms authored Nov 13, 2024
1 parent 2b1cfdf commit e337d8f
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 2 deletions.
21 changes: 21 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api_ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,17 @@ ORT_API2_STATUS(OrtNode_GetAttributeIthFloat, const OrtNode* node, const char* k
*/
ORT_API2_STATUS(OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i, _Outptr_ const char** out);

/** \brief Gets the i-th string in the attribute with the given key, together with its length.
*
* Unlike OrtNode_GetAttributeIthStr, the length is reported explicitly through `size`, so a
* value that contains embedded null characters can be read in full instead of being treated
* as ending at the first '\0'.
*
* NOTE(review): `out` appears to point at storage owned by the node's attribute rather than
* a copy; confirm the intended lifetime before documenting it as a guarantee.
*
* \param[in] node The node to query
* \param[in] key The attribute key
* \param[in] i The index of the string
* \param[out] out Pointer to the first character of the i-th string in the attribute
* \param[out] size The length of the string, not counting the terminating null character
*
*/
ORT_API2_STATUS(OrtNode_GetAttributeIthStrWithSize, const OrtNode* node, const char* key, int i, _Outptr_ const char** out, _Outptr_ size_t* size);

/** \brief Gets the string value of the attribute with the given key.
*
* \param[in] node The node to query
Expand All @@ -565,6 +576,16 @@ ORT_API2_STATUS(OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key
*/
ORT_API2_STATUS(OrtNode_GetAttributeStr, const OrtNode* node, const char* key, _Outptr_ const char** out);

/** \brief Gets the string value of the attribute with the given key, together with its length.
*
* Unlike OrtNode_GetAttributeStr, the length is reported explicitly through `size`, so a
* value that contains embedded null characters (for example binary data stored in a string
* attribute) can be read in full instead of being treated as ending at the first '\0'.
*
* NOTE(review): `out` appears to point at storage owned by the node's attribute rather than
* a copy; confirm the intended lifetime before documenting it as a guarantee.
*
* \param[in] node The node to query
* \param[in] key The attribute key
* \param[out] out Pointer to the first character of the attribute's string value
* \param[out] size The length of the string, not counting the terminating null character
*
*/
ORT_API2_STATUS(OrtNode_GetAttributeStrWithSize, const OrtNode* node, const char* key, _Outptr_ const char** out, _Outptr_ size_t* size);

/** \brief Gets the int value of the attribute with the given key.
*
* \param[in] node The node to query
Expand Down
16 changes: 16 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api_ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -752,12 +752,26 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeIthStr, const OrtNode* nod
return nullptr;
}

// Outputs the i-th string of attribute `key` on `node` as a pointer/length pair, so callers
// can read values that contain embedded null characters.
//   node - opaque handle, reinterpreted as a const ::onnxruntime::Node*
//   key  - attribute name to look up
//   i    - index into the attribute's list of strings
//   out  - receives a pointer to the string's characters
//   size - receives the string's length (std::string::size(), terminator not counted)
// Always returns nullptr.
// NOTE(review): `.at(key)` / `.strings(i)` are not guarded here; presumably a missing key or
// out-of-range index is a caller error - confirm how failures are expected to surface.
ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeIthStrWithSize, const OrtNode* node, const char* key, int i, _Outptr_ const char** out, _Outptr_ size_t* size) {
  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
  // Resolve the attribute once; pointer and length must describe the same string, and a
  // second lookup would only repeat the key construction and search.
  const auto& value = n->GetAttributes().at(key).strings(i);
  *size = value.size();
  *out = value.c_str();
  return nullptr;
}

// Outputs the string value of attribute `key` on `node` as a null-terminated C string.
// Only the pointer is reported, so a caller cannot see past an embedded '\0' in the value.
// Always returns nullptr.
ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeStr, const OrtNode* node, const char* key, _Outptr_ const char** out) {
  const auto* graph_node = reinterpret_cast<const ::onnxruntime::Node*>(node);
  const auto& attribute = graph_node->GetAttributes().at(key);
  *out = attribute.s().c_str();
  return nullptr;
}

// Outputs the string value of attribute `key` on `node` as a pointer/length pair, so callers
// can read values that contain embedded null characters.
//   node - opaque handle, reinterpreted as a const ::onnxruntime::Node*
//   key  - attribute name to look up
//   out  - receives a pointer to the string's characters
//   size - receives the string's length (std::string::size(), terminator not counted)
// Always returns nullptr.
// NOTE(review): `.at(key)` is not guarded here; presumably a missing key is a caller
// error - confirm how failures are expected to surface.
ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeStrWithSize, const OrtNode* node, const char* key, _Outptr_ const char** out, _Outptr_ size_t* size) {
  const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
  // Resolve the attribute once; pointer and length must describe the same string, and a
  // second lookup would only repeat the key construction and search.
  const auto& value = n->GetAttributes().at(key).s();
  *size = value.size();
  *out = value.c_str();
  return nullptr;
}

ORT_API_STATUS_IMPL(OrtGraphApis::OrtNode_GetAttributeInt, const OrtNode* node, const char* key, _Out_ int64_t* out) {
const ::onnxruntime::Node* n = reinterpret_cast<const ::onnxruntime::Node*>(node);
*out = n->GetAttributes().at(key).i();
Expand Down Expand Up @@ -841,7 +855,9 @@ static constexpr OrtGraphApi ort_graph_api = {
&OrtGraphApis::OrtNode_GetAttributeIthInt,
&OrtGraphApis::OrtNode_GetAttributeIthFloat,
&OrtGraphApis::OrtNode_GetAttributeIthStr,
&OrtGraphApis::OrtNode_GetAttributeIthStrWithSize,
&OrtGraphApis::OrtNode_GetAttributeStr,
&OrtGraphApis::OrtNode_GetAttributeStrWithSize,
&OrtGraphApis::OrtNode_GetAttributeInt,
&OrtGraphApis::OrtNode_GetAttributeFloat,
&OrtGraphApis::OrtNode_GetSubgraphs,
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/session/ort_apis_ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,12 @@ ORT_API_STATUS_IMPL(OrtNode_GetAttributeIthFloat, const OrtNode* node, const cha

ORT_API_STATUS_IMPL(OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i, _Outptr_ const char** out);

// Same lookup as OrtNode_GetAttributeIthStr, with an extra `size` out-parameter that
// receives the string's length.
ORT_API_STATUS_IMPL(OrtNode_GetAttributeIthStrWithSize, const OrtNode* node, const char* key, int i, _Outptr_ const char** out, _Outptr_ size_t* size);

// Outputs the string value of attribute `key` through `out` (pointer only, no length).
ORT_API_STATUS_IMPL(OrtNode_GetAttributeStr, const OrtNode* node, const char* key, _Outptr_ const char** out);

// Same lookup as OrtNode_GetAttributeStr, with an extra `size` out-parameter that
// receives the string's length.
ORT_API_STATUS_IMPL(OrtNode_GetAttributeStrWithSize, const OrtNode* node, const char* key, _Outptr_ const char** out, _Outptr_ size_t* size);

ORT_API_STATUS_IMPL(OrtNode_GetAttributeInt, const OrtNode* node, const char* key, _Out_ int64_t* out);

ORT_API_STATUS_IMPL(OrtNode_GetAttributeFloat, const OrtNode* node, const char* key, _Out_ float* out);
Expand Down
5 changes: 3 additions & 2 deletions samples/tensorRTEp/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,9 @@ OrtStatusPtr TensorRTCacheModelHandler::GetEpContextFromGraph(const OrtGraphView
if (embed_mode) {
// Get engine from byte stream.
const char* context_binary_cstr = nullptr;
// The embedded engine is binary data and may contain '\0' bytes, so ask the graph API for
// the explicit length instead of relying on null termination.
size_t size = 0;
graph_api_->OrtNode_GetAttributeStrWithSize(node, EP_CACHE_CONTEXT.c_str(), &context_binary_cstr, &size);
// Construct from (pointer, length) so the copy keeps every byte, including embedded nulls.
std::string context_binary(context_binary_cstr, size);
*(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(const_cast<char*>(context_binary.c_str()),
static_cast<size_t>(context_binary.length())));
// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it";
Expand Down

0 comments on commit e337d8f

Please sign in to comment.