Closed
Changes from 9 commits
52 commits
0b5b4d7
update
chilo-ms Dec 8, 2025
80561db
update
chilo-ms Dec 9, 2025
6bd316f
add API summary
chilo-ms Dec 9, 2025
ad0a023
update
chilo-ms Dec 9, 2025
5e398d4
address reviewer's comments and add GetNumCustomOpDomains
chilo-ms Dec 10, 2025
aeb2386
update example ep to run Custom_Mul op
chilo-ms Dec 10, 2025
3849cd3
address reviewr's comment
chilo-ms Dec 10, 2025
9c987be
lintrunner -a
chilo-ms Dec 10, 2025
fbe2434
update example ep GetCapability()
chilo-ms Dec 10, 2025
40fa8fe
update Example EP
chilo-ms Dec 11, 2025
c7a0491
add more comments in API summary
chilo-ms Dec 12, 2025
4787c3f
address reviewer's comments
chilo-ms Dec 18, 2025
632ce31
lintrunner -a
chilo-ms Dec 18, 2025
5905434
Merge branch 'main' into chi/custom_op_for_ep
chilo-ms Jan 7, 2026
6017c00
Use CreateKernelV2 and ComputeKernelV2
chilo-ms Jan 8, 2026
47bb4dc
address reviewer's comments
chilo-ms Jan 8, 2026
6721a98
lintrunner -a
chilo-ms Jan 8, 2026
1ab246d
update
chilo-ms Jan 8, 2026
3478732
Remove accidentally added file
chilo-ms Jan 8, 2026
a1d36af
address reviewer's comments
chilo-ms Jan 9, 2026
3065e9d
address reviewer's comment
chilo-ms Jan 9, 2026
d340de5
address reveiwer's comment
chilo-ms Jan 9, 2026
ee8851b
Merge branch 'main' into chi/custom_op_for_ep
chilo-ms Jan 9, 2026
15f5baf
update
chilo-ms Jan 9, 2026
6b01e7f
lintrunner -a
chilo-ms Jan 9, 2026
adf565e
fix bug when merging main
chilo-ms Jan 9, 2026
062280e
Make auto ep selection be able to register custom op
chilo-ms Jan 13, 2026
ff58721
Merge branch 'main' into chi/custom_op_for_ep
chilo-ms Jan 13, 2026
002fcdc
add comments
chilo-ms Jan 13, 2026
bb7e082
Make code be able to get model_metadata from model during auto ep sel…
chilo-ms Jan 13, 2026
953dbd3
Use Model::Load
chilo-ms Jan 14, 2026
cf5948a
revert unnecessary change
chilo-ms Jan 14, 2026
cc31408
update API comment
chilo-ms Jan 14, 2026
e2604b9
fix build issue for minimal build
chilo-ms Jan 14, 2026
27cb17a
address reviewer's comments
chilo-ms Jan 14, 2026
f77e5f7
Merge branch 'main' into chi/custom_op_for_ep
chilo-ms Jan 14, 2026
2e91855
lintrunner -a
chilo-ms Jan 14, 2026
1f448a6
fix compile warning for minimal build
chilo-ms Jan 14, 2026
debdc79
address reviewer's comment
chilo-ms Jan 14, 2026
452bb26
Add AddEpCustomDomainsToSessionOptions()
chilo-ms Jan 15, 2026
96d42fb
clean up code
chilo-ms Jan 15, 2026
04b75e8
clean up code and fix compile error
chilo-ms Jan 15, 2026
84cff1f
revert auto ep selection
chilo-ms Jan 15, 2026
4d9849f
Merge branch 'main' into chi/custom_op_for_ep
chilo-ms Jan 15, 2026
32e2e57
add back accidentaly removed code
chilo-ms Jan 15, 2026
b80b451
address reviewer's comments
chilo-ms Jan 16, 2026
0b7302e
update
chilo-ms Jan 16, 2026
be94f18
fix compile error for onnxruntime_pybind_state.cc
chilo-ms Jan 16, 2026
1841117
address reveiwer's comment
chilo-ms Jan 16, 2026
6a571ef
address reviewer's comments
chilo-ms Jan 16, 2026
3b2e5a5
address Copilot comment
chilo-ms Jan 17, 2026
5d0b15b
address Copilot comment
chilo-ms Jan 17, 2026
43 changes: 43 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
@@ -1413,6 +1413,49 @@ struct OrtEpFactory {
* \since Version 1.24.
*/
ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options);

/** \brief Returns the number of OrtCustomOpDomains that this factory creates.
*
* \param[in] this_ptr The OrtEpFactory instance.
* \param[out] num_domains Output parameter set to the number of created OrtCustomOpDomain instances.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(GetNumCustomOpDomains, _In_ OrtEpFactory* this_ptr, _Out_ size_t* num_domains);

/** \brief Creates the EP-specific OrtCustomOpDomains.
*
* This function is used when running inference on a model that contains EP-specific custom operations.
* For compile-based EPs, the EP does not need to provide a concrete kernel implementation for each custom op.
* Instead, it may provide only placeholder custom ops with the correct names so they can be recognized
* during model loading.
*
* Workflow:
* 1. The EP implements this function to supply a list of OrtCustomOpDomain instances.
* 2. The application calls SessionOptionsAppendExecutionProvider_V2() with an OrtEpDevice containing
* the plugin EP's factory.
* 3. SessionOptionsAppendExecutionProvider_V2() appends the provided OrtCustomOpDomain list to the
* session options.
*
* As a result, any session created from these session options will have these custom op domains registered
* in ORT, ensuring that the custom ops are properly recognized and validated when the model is loaded.
*
* Note: The EP is responsible for releasing the OrtCustomOpDomain instances it creates. This happens
* automatically when using the ORT C++ API.
*
* \param[in] this_ptr The OrtEpFactory instance.
* \param[out] domains Array of `num_domains` elements, pre-allocated by ORT, that the EP should fill with
* the OrtCustomOpDomain instances it creates.
* \param[in] num_domains The size of the `domains` array pre-allocated by ORT.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(CreateCustomOpDomains, _In_ OrtEpFactory* this_ptr,
_Outptr_result_maybenull_ OrtCustomOpDomain** domains, _In_ size_t num_domains);
};

#ifdef __cplusplus
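The two new factory entry points follow a count-then-fill pattern: ORT first asks how many domains exist, allocates an array of that size, and then asks the EP to fill it. The sketch below illustrates that pattern only. `FakeCustomOpDomain`, `FakeEpFactory`, `CollectDomains`, and the domain name `example_ep_domain` are hypothetical stand-ins, not ORT types, and the real functions return `OrtStatus*` rather than `bool`.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for OrtCustomOpDomain (not the real ORT type).
struct FakeCustomOpDomain {
  std::string name;
};

// Hypothetical factory that owns the domains it creates, mirroring the note
// that the EP is responsible for releasing them.
class FakeEpFactory {
 public:
  FakeEpFactory() {
    domains_.push_back(
        std::make_unique<FakeCustomOpDomain>(FakeCustomOpDomain{"example_ep_domain"}));
  }

  // Counterpart of GetNumCustomOpDomains: report how many domains exist.
  void GetNumCustomOpDomains(size_t* num_domains) const {
    *num_domains = domains_.size();
  }

  // Counterpart of CreateCustomOpDomains: fill the caller-allocated array.
  // Returns false if the array size does not match the reported count.
  bool CreateCustomOpDomains(FakeCustomOpDomain** domains, size_t num_domains) const {
    if (num_domains != domains_.size()) return false;
    if (domains == nullptr && num_domains != 0) return false;
    for (size_t i = 0; i < num_domains; ++i) {
      domains[i] = domains_[i].get();  // The factory keeps ownership.
    }
    return true;
  }

 private:
  std::vector<std::unique_ptr<FakeCustomOpDomain>> domains_;
};

// Caller side: query the count, allocate, then ask the factory to fill the array.
std::vector<FakeCustomOpDomain*> CollectDomains(const FakeEpFactory& factory) {
  size_t num_domains = 0;
  factory.GetNumCustomOpDomains(&num_domains);

  std::vector<FakeCustomOpDomain*> domains(num_domains, nullptr);
  const bool ok = factory.CreateCustomOpDomains(domains.data(), num_domains);
  assert(ok);
  (void)ok;  // Avoid an unused-variable warning when assertions are disabled.
  return domains;
}
```

Because the caller passes the array size by value, the size parameter is an input on the fill call, which matches the `\param[in] num_domains` documentation above.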
2 changes: 1 addition & 1 deletion onnxruntime/core/session/onnxruntime_c_api.cc
@@ -3332,7 +3332,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS
ep_devices_span,
ep_option_keys_span,
ep_option_vals_span,
session_options->value));
*session_options));

session_options->provider_factories.push_back(std::move(provider_factory));

7 changes: 7 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ class EpFactoryInternalImpl {
return nullptr;
}

virtual OrtStatus* CreateCustomOpDomains(_Outptr_result_maybenull_ OrtCustomOpDomain** out,
_Out_ size_t* num_domains) const noexcept {
*out = nullptr;
*num_domains = 0;
return nullptr;
}

// Function ORT calls to release an EP instance.
void ReleaseEp(OrtEp* ep);

41 changes: 40 additions & 1 deletion onnxruntime/core/session/utils.cc
@@ -482,7 +482,8 @@
Status AddEpOptionsToSessionOptions(gsl::span<const OrtEpDevice* const> ep_devices,
gsl::span<const char* const> ep_option_keys,
gsl::span<const char* const> ep_option_vals,
SessionOptions& session_options) {
OrtSessionOptions& ort_session_options) {
SessionOptions& session_options = ort_session_options.value;
const size_t num_ep_options = ep_option_keys.size();
if (ep_option_vals.size() != num_ep_options) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -505,6 +506,44 @@

ORT_RETURN_IF_ERROR(config_options.AddConfigEntry((prefix + ep_option_keys[j]).c_str(), ep_option_vals[j]));
}

// Add the custom op domains provided by the EP, if any, to the session options.
// OrtEpFactory::GetNumCustomOpDomains and OrtEpFactory::CreateCustomOpDomains
// were added in ORT 1.24.
OrtEpFactory* ep_factory = ep_device->ep_factory;
if (ep_factory &&
ep_factory->ort_version_supported >= 24 &&
ep_factory->CreateCustomOpDomains != nullptr) {
auto is_already_in_domains = [&](const std::string& domain_name, const std::vector<OrtCustomOpDomain*>& domains) {
// [cpplint annotation, GitHub Actions / Optional Lint C++, onnxruntime/core/session/utils.cc:517]
// Add #include <vector> for vector<> and #include <string> for string [build/include_what_you_use] [4]
for (auto ptr : domains) {
if (domain_name == ptr->domain_) {
return true;
}
}
return false;
};

size_t num_domains = 0;
ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetNumCustomOpDomains(ep_factory, &num_domains)));

InlinedVector<OrtCustomOpDomain*> domains;
domains.resize(num_domains);

ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->CreateCustomOpDomains(ep_factory,
domains.data(),
num_domains)));

const auto domains_span = gsl::span<OrtCustomOpDomain*>(domains.data(), domains.size());
for (auto domain : domains_span) {
if (!is_already_in_domains(domain->domain_, ort_session_options.custom_op_domains_) &&
domain->custom_ops_.size() > 0) {
ort_session_options.custom_op_domains_.push_back(domain);
} else {
LOGS_DEFAULT(WARNING) << "The custom op domain name "
<< domain->domain_ << " is already in the session option. Skip it.";
}
}
}
}

return Status::OK();
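The loop in this hunk skips a domain in two different situations (its name is already registered, or it contains no custom ops) but logs the same "already in the session option" message for both. The sketch below shows one possible way to keep the two reasons apart so the caller can log an accurate message. `FakeCustomOpDomain`, `MergeResult`, and `MergeDomain` are hypothetical names used for illustration and are not part of ORT.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for OrtCustomOpDomain with only the fields used here.
struct FakeCustomOpDomain {
  std::string domain;
  std::vector<std::string> custom_ops;
};

// Why a candidate domain was or was not appended.
enum class MergeResult { kAdded, kSkippedDuplicate, kSkippedEmpty };

// Appends `candidate` to `registered` unless it is empty or its name is taken.
MergeResult MergeDomain(std::vector<FakeCustomOpDomain*>& registered,
                        FakeCustomOpDomain* candidate) {
  assert(candidate != nullptr);
  if (candidate->custom_ops.empty()) {
    return MergeResult::kSkippedEmpty;
  }
  for (const FakeCustomOpDomain* existing : registered) {
    if (existing->domain == candidate->domain) {
      return MergeResult::kSkippedDuplicate;
    }
  }
  registered.push_back(candidate);
  return MergeResult::kAdded;
}
```

With a result like this, the warning text could name the actual reason for skipping instead of always reporting a duplicate.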
2 changes: 1 addition & 1 deletion onnxruntime/core/session/utils.h
@@ -69,7 +69,7 @@ Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env,
Status AddEpOptionsToSessionOptions(gsl::span<const OrtEpDevice* const> ep_devices,
gsl::span<const char* const> ep_options_keys,
gsl::span<const char* const> ep_options_vals,
SessionOptions& session_options);
OrtSessionOptions& session_options);

} // namespace onnxruntime
#endif // !defined(ORT_MINIMAL_BUILD)
2 changes: 1 addition & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
@@ -1346,7 +1346,7 @@ static Status AddEpFactoryFromEpDevices(PySessionOptions& py_sess_options,
ORT_RETURN_IF_ERROR(AddEpOptionsToSessionOptions(ep_devices,
ep_option_keys,
ep_option_vals,
py_sess_options.value));
py_sess_options));

py_sess_options.provider_factories.push_back(std::move(provider_factory));
return Status::OK();
207 changes: 98 additions & 109 deletions onnxruntime/test/autoep/library/example_plugin_ep/ep.cc
@@ -15,117 +15,97 @@
#include "ep_factory.h"
#include "ep_stream_support.h"

/// <summary>
/// Example implementation of ONNX Mul. Does not handle many things like broadcasting.
/// </summary>
struct MulKernel {
MulKernel(const OrtApi& ort_api, const OrtLogger& logger,
const std::unordered_map<std::string, FloatInitializer>& float_initializers,
std::string input0_name, std::string input1_name)
: ort_api(ort_api),
logger(logger),
float_initializers(float_initializers),
input0_name(input0_name),
input1_name(input1_name) {}

const FloatInitializer* TryGetSavedInitializer(const std::string& name) const {
auto iter = float_initializers.find(name);
return iter != float_initializers.end() ? &iter->second : nullptr;
}
const FloatInitializer* MulKernel::TryGetSavedInitializer(const std::string& name) const {
auto iter = float_initializers.find(name);
return iter != float_initializers.end() ? &iter->second : nullptr;
}

void GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index,
/*out*/ gsl::span<const float>& data,
/*out*/ std::vector<int64_t>& shape) const {
Ort::ConstValue input = kernel_context.GetInput(index);
auto type_shape = input.GetTensorTypeAndShapeInfo();
void MulKernel::GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index,
/*out*/ gsl::span<const float>& data,
/*out*/ std::vector<int64_t>& shape) const {
Ort::ConstValue input = kernel_context.GetInput(index);
auto type_shape = input.GetTensorTypeAndShapeInfo();

ONNXTensorElementDataType elem_type = type_shape.GetElementType();
if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL);
ONNXTensorElementDataType elem_type = type_shape.GetElementType();
if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL);

const float* float_data = input.GetTensorData<float>();
size_t num_elems = type_shape.GetElementCount();
data = gsl::span<const float>(float_data, num_elems);
shape = type_shape.GetShape();
}

OrtStatus* Compute(OrtKernelContext* kernel_ctx) {
RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger,
OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
"MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__));
Ort::KernelContext kernel_context(kernel_ctx);
try {
gsl::span<const float> input0;
gsl::span<const float> input1;
std::vector<int64_t> shape0;
std::vector<int64_t> shape1;

size_t num_inputs = kernel_context.GetInputCount();

if (num_inputs == 2) {
// Both inputs are non-constant. Get them from ORT's KernelContext.
GetInputDataAndShape(kernel_context, 0, input0, shape0);
GetInputDataAndShape(kernel_context, 1, input1, shape1);
} else if (num_inputs == 1) {
// ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs.
// Get the constant input from the initializers saved by the EP.
// Refer to "NodeFusionOptions_DropConstantInitializers()".

if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) {
GetInputDataAndShape(kernel_context, 0, input1, shape1);
input0 = gsl::span<const float>(const_input0->data);
shape0 = const_input0->shape;
} else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) {
GetInputDataAndShape(kernel_context, 0, input0, shape0);
input1 = gsl::span<const float>(const_input1->data);
shape1 = const_input1->shape;
}
} else {
// Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding)
// are disabled.
const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name);
const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name);
RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api,
"Expected 2 initializer inputs to be saved by EP");
const float* float_data = input.GetTensorData<float>();
size_t num_elems = type_shape.GetElementCount();
data = gsl::span<const float>(float_data, num_elems);
shape = type_shape.GetShape();
}

OrtStatus* MulKernel::Compute(OrtKernelContext* kernel_ctx) {
RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger,
OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
"MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__));
Ort::KernelContext kernel_context(kernel_ctx);
try {
gsl::span<const float> input0;
gsl::span<const float> input1;
std::vector<int64_t> shape0;
std::vector<int64_t> shape1;

size_t num_inputs = kernel_context.GetInputCount();

if (num_inputs == 2) {
// Both inputs are non-constant. Get them from ORT's KernelContext.
GetInputDataAndShape(kernel_context, 0, input0, shape0);
GetInputDataAndShape(kernel_context, 1, input1, shape1);
} else if (num_inputs == 1) {
// ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs.
// Get the constant input from the initializers saved by the EP.
// Refer to "NodeFusionOptions_DropConstantInitializers()".

if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) {
GetInputDataAndShape(kernel_context, 0, input1, shape1);
input0 = gsl::span<const float>(const_input0->data);
input1 = gsl::span<const float>(const_input1->data);
shape0 = const_input0->shape;
} else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) {
GetInputDataAndShape(kernel_context, 0, input0, shape0);
input1 = gsl::span<const float>(const_input1->data);
shape1 = const_input1->shape;
}
} else {
// Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding)
// are disabled.
const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name);
const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name);
RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api,
"Expected 2 initializer inputs to be saved by EP");

input0 = gsl::span<const float>(const_input0->data);
input1 = gsl::span<const float>(const_input1->data);
shape0 = const_input0->shape;
shape1 = const_input1->shape;
}

if (shape0 != shape1) {
throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT);
}
if (shape0 != shape1) {
throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT);
}

size_t num_outputs = kernel_context.GetOutputCount();
if (num_outputs != 1) {
throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT);
}
size_t num_outputs = kernel_context.GetOutputCount();
if (num_outputs != 1) {
throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT);
}

auto output = kernel_context.GetOutput(0, shape0);
float* output_data = output.GetTensorMutableData<float>();
auto output = kernel_context.GetOutput(0, shape0);
float* output_data = output.GetTensorMutableData<float>();

for (size_t i = 0; i < input0.size(); ++i) {
output_data[i] = input0[i] * input1[i];
}
} catch (const Ort::Exception& ex) {
Ort::Status status(ex);
return status.release();
} catch (const std::exception& ex) {
Ort::Status status(ex.what(), ORT_EP_FAIL);
return status.release();
for (size_t i = 0; i < input0.size(); ++i) {
output_data[i] = input0[i] * input1[i];
}

return nullptr;
} catch (const Ort::Exception& ex) {
Ort::Status status(ex);
return status.release();
} catch (const std::exception& ex) {
Ort::Status status(ex.what(), ORT_EP_FAIL);
return status.release();
}

const OrtApi& ort_api;
const OrtLogger& logger;
const std::unordered_map<std::string, FloatInitializer>& float_initializers;
std::string input0_name;
std::string input1_name;
};
return nullptr;
}

/// <summary>
/// Example OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph.
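Stripped of the ORT plumbing, the core of `MulKernel::Compute` is a shape check followed by an element-wise product with no broadcasting. The following is a minimal sketch of just that step, assuming plain `std::vector` inputs instead of `gsl::span` and ORT tensors. `ElementwiseMul` is a hypothetical helper name, and it throws `std::invalid_argument` where the kernel throws `Ort::Exception`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Multiplies two tensors element by element. Like the example kernel, this
// requires identical shapes and does not implement broadcasting.
std::vector<float> ElementwiseMul(const std::vector<float>& input0,
                                  const std::vector<int64_t>& shape0,
                                  const std::vector<float>& input1,
                                  const std::vector<int64_t>& shape1) {
  if (shape0 != shape1) {
    throw std::invalid_argument("Expected same dimensions for both inputs");
  }
  // Equal shapes imply equal element counts for well-formed tensors.
  assert(input0.size() == input1.size());

  std::vector<float> output(input0.size());
  for (size_t i = 0; i < input0.size(); ++i) {
    output[i] = input0[i] * input1[i];
  }
  return output;
}
```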
@@ -262,26 +242,35 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG

supported_nodes.push_back(node); // Only support a single Mul for now.
break;
} else if (op_type == "Custom_Mul") {
supported_nodes.push_back(node);
}
}

if (supported_nodes.empty()) {
return nullptr;
}

// Create (optional) fusion options for the supported nodes to fuse.
OrtNodeFusionOptions node_fusion_options = {};
node_fusion_options.ort_version_supported = ORT_API_VERSION;

// Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers
// as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers.
// This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use
// during inference.
node_fusion_options.drop_constant_initializers = true;
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
reinterpret_cast<const OrtNode* const*>(supported_nodes.data()),
supported_nodes.size(),
&node_fusion_options));
if (supported_nodes[0].GetOperatorType() == "Mul") {
// Create (optional) fusion options for the supported nodes to fuse.
OrtNodeFusionOptions node_fusion_options = {};
node_fusion_options.ort_version_supported = ORT_API_VERSION;

// Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers
// as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers.
// This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use
// during inference.
node_fusion_options.drop_constant_initializers = true;
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
reinterpret_cast<const OrtNode* const*>(supported_nodes.data()),
supported_nodes.size(),
&node_fusion_options));
} else if (supported_nodes[0].GetOperatorType() == "Custom_Mul") {
// Call EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled,
// because Custom_Mul has a concrete kernel implementation.
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_nodes[0]));
}

} catch (const Ort::Exception& ex) {
Ort::Status status(ex);
return status.release();
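The updated `GetCapabilityImpl` makes a per-operator decision: `Mul` nodes are fused and compiled by the EP, while `Custom_Mul` nodes are claimed as single nodes because a concrete kernel already exists. A minimal sketch of that decision as a standalone function is shown below. `NodeClaim` and `ClassifyNode` are hypothetical names, and the real code reports its choice through `EpGraphSupportInfo_AddNodesToFuse` and `EpGraphSupportInfo_AddSingleNode` rather than returning an enum.

```cpp
#include <cassert>
#include <string>

// How the example EP would claim a node, based on its operator type.
enum class NodeClaim {
  kNone,            // Not supported by this EP.
  kFuseAndCompile,  // Reported to ORT as nodes to fuse.
  kSingleNode       // Reported to ORT as a single, already-implemented node.
};

// Maps an operator type to the claim the example EP makes for it.
NodeClaim ClassifyNode(const std::string& op_type) {
  if (op_type == "Mul") {
    return NodeClaim::kFuseAndCompile;
  }
  if (op_type == "Custom_Mul") {
    return NodeClaim::kSingleNode;
  }
  return NodeClaim::kNone;
}
```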