Skip to content

Commit 9fc41c3

Browse files
authored
[EP ABI] Add Node_GetEpType API (microsoft#25350)
Add a new API `Node_GetEpType` to get the EP that the node is assigned to run on. This API is needed when porting the plugin TRT EP in `GetCapability` where ep needs to know whether the subgraph(s) of the control flow node is assigned to the ep and then to add this control flow op to the support list.
1 parent fb0f6c6 commit 9fc41c3

File tree

6 files changed

+44
-0
lines changed

6 files changed

+44
-0
lines changed

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6026,6 +6026,18 @@ struct OrtApi {
60266026
*/
60276027
ORT_API2_STATUS(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph);
60286028

6029+
/** \brief Returns the execution provider type (name) that this node is assigned to run on.
6030+
* Returns NULL if the node has not been assigned to any execution provider yet.
6031+
*
6032+
* \param[in] node The OrtNode instance.
6033+
* \param[out] out Output execution provider type and can be NULL if node has not been assigned.
6034+
*
6035+
* \snippet{doc} snippets.dox OrtStatus Return Value
6036+
*
6037+
* \since Version 1.23.
6038+
*/
6039+
ORT_API2_STATUS(Node_GetEpType, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out);
6040+
60296041
/// @}
60306042

60316043
/// \name OrtRunOptions

onnxruntime/core/graph/ep_api_types.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,10 @@ const OrtOpAttr* EpNode::GetAttribute(const std::string& name) const {
276276
}
277277
}
278278

279+
const std::string& EpNode::GetEpType() const {
280+
return node_.GetExecutionProviderType();
281+
}
282+
279283
//
280284
// EpValueInfo
281285
//

onnxruntime/core/graph/ep_api_types.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,9 @@ struct EpNode : public OrtNode {
208208
// Helper that gets the node's attributes by name.
209209
const OrtOpAttr* GetAttribute(const std::string& name) const;
210210

211+
// Helper that gets the execution provider that this node is assigned to run on.
212+
const std::string& GetEpType() const;
213+
211214
private:
212215
// Back pointer to containing graph. Useful when traversing through nested subgraphs.
213216
// Will be nullptr if the EpNode was created without an owning graph.

onnxruntime/core/session/onnxruntime_c_api.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3052,6 +3052,23 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetGraph, _In_ const OrtNode* node,
30523052
API_IMPL_END
30533053
}
30543054

3055+
ORT_API_STATUS_IMPL(OrtApis::Node_GetEpType, _In_ const OrtNode* node,
3056+
_Outptr_result_maybenull_ const char** out) {
3057+
API_IMPL_BEGIN
3058+
if (out == nullptr) {
3059+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'out' argument is NULL");
3060+
}
3061+
3062+
const EpNode* ep_node = EpNode::ToInternal(node);
3063+
if (ep_node == nullptr) {
3064+
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetEpType.");
3065+
}
3066+
3067+
*out = ep_node->GetEpType().c_str();
3068+
return nullptr;
3069+
API_IMPL_END
3070+
}
3071+
30553072
ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) {
30563073
#ifdef ENABLE_TRAINING_APIS
30573074
if (version >= 13 && version <= ORT_API_VERSION)
@@ -3734,6 +3751,7 @@ static constexpr OrtApi ort_api_1_to_23 = {
37343751
&OrtApis::Node_GetNumSubgraphs,
37353752
&OrtApis::Node_GetSubgraphs,
37363753
&OrtApis::Node_GetGraph,
3754+
&OrtApis::Node_GetEpType,
37373755

37383756
&OrtApis::GetRunConfigEntry,
37393757

onnxruntime/core/session/ort_apis.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,7 @@ ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node,
680680
_Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs,
681681
_Out_writes_opt_(num_subgraphs) const char** attribute_names);
682682
ORT_API_STATUS_IMPL(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph);
683+
ORT_API_STATUS_IMPL(Node_GetEpType, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out);
683684

684685
ORT_API_STATUS_IMPL(GetRunConfigEntry, _In_ const OrtRunOptions* options,
685686
_In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value);

onnxruntime/test/autoep/library/ep.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,12 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const
328328
RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[0], &node_input_names[0]));
329329
RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[1], &node_input_names[1]));
330330

331+
const char* ep_type = nullptr;
332+
RETURN_IF_ERROR(ort_api.Node_GetEpType(fused_nodes[0], &ep_type));
333+
if (std::strncmp(ep_type, "example_ep", 11) != 0) {
334+
return ort_api.CreateStatus(ORT_EP_FAIL, "The fused node is expected to assigned to this EP to run on");
335+
}
336+
331337
// Associate the name of the fused node with our MulKernel.
332338
const char* fused_node_name = nullptr;
333339
RETURN_IF_ERROR(ort_api.Node_GetName(fused_nodes[0], &fused_node_name));

0 commit comments

Comments
 (0)