Skip to content

Commit

Permalink
Make EP plugin be able to create and update EP Context graph (#22740)
Browse files Browse the repository at this point in the history
This PR support several features:

- Add new graph API to create and update EP Context graph, and dump EP
Context model.

1.  OrtGraph_CreateOrUpdateEpCtxGraph
2. OrtGraph_DumpOnnxModel
3. OrtGraph_ReleaseGraph 

- Add new graph API to dump onnx model
- The APIs provided by this PR can dump EP Context model when the whole
model can be run by one EP, the APIs also aim to support the case where
the whole model is partitioned into multiple EP's subgraphs. (Note: i
haven't fully tested the partitioning case, please help review it)
- Modify TRT EP plugin to use those APIs.
  • Loading branch information
chilo-ms authored Nov 13, 2024
1 parent e337d8f commit afe92e1
Show file tree
Hide file tree
Showing 5 changed files with 346 additions and 55 deletions.
61 changes: 57 additions & 4 deletions include/onnxruntime/core/session/onnxruntime_c_api_ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,48 @@ ORT_API2_STATUS(OrtGraph_ReleaseValueInfo, OrtValueInfoRef* value_info);
*
*/
ORT_API2_STATUS(OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data, _Out_ size_t* data_size); // TODO(leca): review and discuss

/** \brief Serialize the graph(model) to disk.
*
* \param[in] graph The graph to be serialized
* \param[in] onnx_model_path The file path to save to
*
*/
ORT_API2_STATUS(OrtGraph_DumpOnnxModel, const OrtGraph* graph, const char* onnx_model_path);

/** \brief Construct an "EP Context" graph if the given ep_context_graph graph is empty, otherwise:
* 1. if the given node name can't be found in the graph, add an new "EP Context" node to the existing graph
* 2. if the node being found with the givne node name, update the node attributes only
*
* Please see https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html for more details about EP Context design
*
* \param[in] graph The graph to create or add
* \param[in] node_name The node to be added or updated
* \param[in] main_context The attribute of EP Context op
* \param[in] embed_mode The attribute of EP Context op
* \param[in] cache_path The cache or binary file path. It's for setting the ep_cache_context attribute if embed_mode is 0
* \param[in] cache_data The cache or binary data. It's for setting the ep_cache_context attribute if embed_mode is 1
* \param[in] size The size of cache data.
* \param[in] extra_attr_keys The other attribute names
* \param[in] extra_attr_values The other attribute value in string
* \param[in] extra_attr_num Number of other attributes
* \param[out] ep_context_graph The constructed or updated ep context graph
*
* \remarks The caller is responsible for releasing the ep_context_graph using OrtGraph_ReleaseGraph.
*
*/
ORT_API2_STATUS(OrtGraph_CreateOrUpdateEpCtxGraph,
const OrtGraphViewer* graph,
const char* node_name,
const int64_t main_context,
const int64_t embed_mode,
const char* cache_path,
char* cache_data,
size_t size,
const char* const* extra_attr_keys,
const char* const* extra_attr_values,
size_t extra_attr_num,
_Outptr_ OrtGraph** ep_context_graph);

/** \brief Construct a subgraph from the Graph with the given node indices.
*
Expand All @@ -345,17 +387,28 @@ ORT_API2_STATUS(OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ vo
*
*/
ORT_API2_STATUS(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); // TODO(yang): review and discuss

/** \brief Release the graph instance.
*
* NOTE!!: Invoke this function after the use of OrtGraph_CreateOrUpdateEpCtxGraph. As OrtGraph_CreateOrUpdateEpCtxGraph allocates model instead of
* graph, this API releases graph's owning_model explicitly which in turn will release the graph
* (because graph is hosted in an unique_ptr in Model class)
*
* \param[in] graph The graph to release
*
*/
ORT_API2_STATUS(OrtGraph_ReleaseGraph, const OrtGraph* graph);

/** \brief Release the graph.
/** \brief Release the graph viewer instance.
*
* NOTE!!: Invoke this function after the use of OrtGraph_GetSubGraph. As OrtGraph_GetSubGraph allocate model instead of
* graph, this API release graph's owning_model explicitly which in turn will release the graph
* NOTE!!: Invoke this function after the use of OrtGraph_GetSubGraph. As OrtGraph_GetSubGraph allocates model instead of
* graph, this API releases graph's owning_model explicitly which in turn will release the graph
* (because graph is hosted in an unique_ptr in Model class)
*
* \param[in] graph The graph to release
*
*/
ORT_API2_STATUS(OrtGraph_ReleaseGraph, const OrtGraphViewer* graph);
ORT_API2_STATUS(OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph);

/** \brief Gets the name of the node
*
Expand Down
194 changes: 193 additions & 1 deletion onnxruntime/core/session/onnxruntime_c_api_ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
#include "core/framework/tensorprotoutils.h"
#include "core/session/ort_apis.h"

#include <fstream>
#include <iostream>

using namespace onnxruntime;

ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetName, const OrtGraphViewer* graph, _Out_ const char** out) {
Expand Down Expand Up @@ -477,6 +480,184 @@ static void SetAllGraphInputs(Graph& graph, std::unordered_map<std::string, std:
graph.SetInputs(graph_inputs_including_initializers);
}

/*
* Given a graph, get the corresponding model and serialize it to disk.
*/
ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_DumpOnnxModel,
const OrtGraph* graph,
const char* onnx_model_path) {
const ::onnxruntime::Graph* internal_graph = reinterpret_cast<const ::onnxruntime::Graph*>(graph);
auto model = &(internal_graph->GetModel());

// Two options to generate model proto:
// 1. directly call model->ToProto()
// 2. new model ---> model->ToProto ---> update graph proto in model proto with GraphViewerToProto()
//
// TODO: (Chi) Need more thinking on which to choose

// option 1
std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto = std::make_unique<ONNX_NAMESPACE::ModelProto>(model->ToProto());

// option 2
//auto model_proto = model->ToProto();
//graph->ToProto(*model_proto->mutable_graph(), true, true);
//model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);

std::fstream dump(onnx_model_path, std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(&dump);
//LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Dumped " + ctx_model_path;
return nullptr;
}

/* Construct an "EP Context" graph if the given ep_context_graph graph is empty, otherwise:
* 1. if the given node name can't be found in the graph, add an new "EP Context" node to the existing graph
* 2. if the node is already existed, update the node attributes only
*/
ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_CreateOrUpdateEpCtxGraph,
const OrtGraphViewer* graph,
const char* node_name,
const int64_t main_context,
const int64_t embed_mode,
const char* cache_path,
char* cache_data,
size_t size,
const char* const* extra_attr_keys,
const char* const* extra_attr_values,
size_t extra_attr_num,
_Outptr_ OrtGraph** ep_context_graph) {

const std::string EPCONTEXT_OP = "EPContext";
const std::string MAIN_CONTEXT = "main_context";
const std::string EMBED_MODE = "embed_mode";
const std::string EP_CACHE_CONTEXT = "ep_cache_context";
const std::string ONNX_MODEL_FILENAME = "onnx_model_filename";
const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft";
const std::string EPCONTEXT_WARNING =
"It's suggested to set the ORT graph optimization level to 0 and \
make \"embed_mode\" to 0 (\"ep_cache_context\" is the cache path)\
for the best model loading time";

const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
::onnxruntime::Graph* graph_build;

if (!graph_viewer && !(*ep_context_graph)) return nullptr;

std::unordered_map<std::string, std::string> attr_keys_values;
for (size_t i = 0; i < extra_attr_num; i++) {
attr_keys_values[extra_attr_keys[i]] = extra_attr_values[i];
}

// Create a new graph or use the existing one
if (*ep_context_graph == nullptr) {
Model* model_build = new Model (graph_viewer->Name(), true, ModelMetaData(), PathString(),
#if !defined(ORT_MINIMAL_BUILD)
IOnnxRuntimeOpSchemaRegistryList({graph_viewer->GetSchemaRegistry()}), graph_viewer->DomainToVersionMap(),
#else
IOnnxRuntimeOpSchemaRegistryList(), graph_viewer->DomainToVersionMap(),
#endif // ORT_MINIMAL_BUILD
std::vector<ONNX_NAMESPACE::FunctionProto>(), graph_viewer->GetGraph().GetLogger());
graph_build = &(model_build->MainGraph());
*ep_context_graph = reinterpret_cast<OrtGraph*>(graph_build);
} else {
graph_build = reinterpret_cast<::onnxruntime::Graph*>(*ep_context_graph);
}

// Get graph inputs and outputs
std::vector<onnxruntime::NodeArg*> inputs, outputs;
if (graph_viewer) {
for (auto input : graph_viewer->GetInputs()) {
auto& n_input = graph_build->GetOrCreateNodeArg(input->Name(), input->TypeAsProto());
inputs.push_back(&n_input);
}

for (auto output : graph_viewer->GetOutputs()) {
auto& n_output = graph_build->GetOrCreateNodeArg(output->Name(), output->TypeAsProto());
outputs.push_back(&n_output);
}
}

// locate specific node if any
auto get_node_index = [&](Graph* graph, const char* node_name) -> size_t {
std::string name = node_name;
for (auto& node : graph->Nodes()) {
if (name == node.Name()) {
return node.Index();
}
}
// return impossible value to indicate the node is not existed
return std::numeric_limits<size_t>::max();
};
size_t node_idx = get_node_index(graph_build, node_name);
bool node_existed = node_idx != std::numeric_limits<size_t>::max() ? true : false;

// Create or get EP context node attributes
auto new_node_attributes = NodeAttributes(); // using NodeAttributes = std::unordered_map<std::string, ONNX_NAMESPACE::AttributeProto>
NodeAttributes* node_attributes;
if (node_existed) {
node_attributes = &graph_build->GetNode(node_idx)->GetMutableAttributes();
} else {
new_node_attributes.reserve(3 + extra_attr_num);
node_attributes = &new_node_attributes;
}
std::unique_ptr<ONNX_NAMESPACE::AttributeProto> attr_0 = std::make_unique<ONNX_NAMESPACE::AttributeProto>(); // main_context
std::unique_ptr<ONNX_NAMESPACE::AttributeProto> attr_1 = std::make_unique<ONNX_NAMESPACE::AttributeProto>(); // embed_mode
std::unique_ptr<ONNX_NAMESPACE::AttributeProto> attr_2 = std::make_unique<ONNX_NAMESPACE::AttributeProto>(); // ep_cache_context

std::string cache_data_str = "";
std::string cache_path_str = cache_path;

// main_context
attr_0->set_name(MAIN_CONTEXT);
attr_0->set_type(onnx::AttributeProto_AttributeType_INT);
attr_0->set_i(main_context);

// embed_mode
attr_1->set_name(EMBED_MODE);
attr_1->set_type(onnx::AttributeProto_AttributeType_INT);
attr_1->set_i(embed_mode);

// ep_cache_context
attr_2->set_name(EP_CACHE_CONTEXT);
attr_2->set_type(onnx::AttributeProto_AttributeType_STRING);
if (embed_mode) {
if (size > 0) {
cache_data_str.assign(cache_data, size);
}
attr_2->set_s(cache_data_str);
//LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING;
} else {
attr_2->set_s(cache_path_str);
}

(*node_attributes)[MAIN_CONTEXT] = *attr_0;
(*node_attributes)[EMBED_MODE] = *attr_1;
(*node_attributes)[EP_CACHE_CONTEXT] = *attr_2;

// other attributes
std::unordered_map<std::string, std::string>::iterator it;
for (it = attr_keys_values.begin(); it != attr_keys_values.end(); ++it) {
std::string key = it->first;
std::string value = it->second;
if (key == ONNX_MODEL_FILENAME) value = std::filesystem::path(value).filename().string();

std::unique_ptr<ONNX_NAMESPACE::AttributeProto> attr = std::make_unique<ONNX_NAMESPACE::AttributeProto>();
attr->set_name(key);
attr->set_type(onnx::AttributeProto_AttributeType_STRING);
attr->set_s(value);
(*node_attributes)[key] = *attr;
}

if (!node_existed && graph_viewer) {
std::string name = node_name;
graph_build->AddNode(name, EPCONTEXT_OP, "", inputs, outputs, node_attributes, EPCONTEXT_OP_DOMAIN);
}

common::Status status = graph_build->Resolve();
if (status != Status::OK()) return onnxruntime::ToOrtStatus(status);

return nullptr;
}

ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph) {
const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
// Get parent graph output names
Expand Down Expand Up @@ -595,7 +776,15 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_GetSubGraph, const OrtGraphViewer* gr
return nullptr;
}

ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraph, const OrtGraphViewer* graph) {
ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraph, const OrtGraph* ort_graph) {
if (ort_graph) {
const ::onnxruntime::Graph* graph = reinterpret_cast<const ::onnxruntime::Graph*>(ort_graph);
delete &(graph->GetModel());
}
return nullptr;
}

ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph) {
if (graph) {
const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
delete &(graph_viewer->GetGraph()).GetModel();
Expand Down Expand Up @@ -830,8 +1019,11 @@ static constexpr OrtGraphApi ort_graph_api = {
&OrtGraphApis::OrtGraph_GetValueInfo,
&OrtGraphApis::OrtGraph_ReleaseValueInfo,
&OrtGraphApis::OrtGraph_SerializeToArray,
&OrtGraphApis::OrtGraph_DumpOnnxModel,
&OrtGraphApis::OrtGraph_CreateOrUpdateEpCtxGraph,
&OrtGraphApis::OrtGraph_GetSubGraph,
&OrtGraphApis::OrtGraph_ReleaseGraph,
&OrtGraphApis::OrtGraph_ReleaseGraphViewer,
&OrtGraphApis::OrtNode_GetName,
&OrtGraphApis::OrtNode_GetDescription,
&OrtGraphApis::OrtNode_GetDomain,
Expand Down
19 changes: 18 additions & 1 deletion onnxruntime/core/session/ort_apis_ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,26 @@ ORT_API_STATUS_IMPL(OrtGraph_ReleaseValueInfo, OrtValueInfoRef* value_info);

ORT_API_STATUS_IMPL(OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data, _Out_ size_t* data_size);

ORT_API_STATUS_IMPL(OrtGraph_DumpOnnxModel, const OrtGraph* graph, const char* onnx_model_path);

ORT_API_STATUS_IMPL(OrtGraph_CreateOrUpdateEpCtxGraph,
const OrtGraphViewer* graph,
const char* node_name,
const int64_t main_context,
const int64_t embed_mode,
const char* cache_path,
char* cache_data,
size_t size,
const char* const* extra_attr_keys,
const char* const* extra_attr_values,
size_t extra_attr_num,
_Outptr_ OrtGraph** ep_context_graph);

ORT_API_STATUS_IMPL(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph);

ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraph, const OrtGraphViewer* graph);
ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraph, const OrtGraph* graph);

ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph);

ORT_API_STATUS_IMPL(OrtNode_GetName, const OrtNode* node, _Outptr_ const char** out);

Expand Down
Loading

0 comments on commit afe92e1

Please sign in to comment.