Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1413,6 +1413,49 @@ struct OrtEpFactory {
* \since Version 1.24.
*/
ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options);

/** \brief Returns the number of OrtCustomOpDomains that this factory creates.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we renamed it to GetCustomOpDomains, maybe replace "creates" with "provides" or "supplies".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replaced with "provides"

*
* \param[in] this_ptr The OrtEpFactory instance.
* \param[out] num_domains Output parameter set to the number of created OrtCustomOpDomain instances.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(GetNumCustomOpDomains, _In_ OrtEpFactory* this_ptr, _Out_ size_t* num_domains);

/** \brief Creates the EP-specific OrtCustomOpDomains.
*
* This function is used when running inference on a model that contains EP-specific custom operations.
* For compile-based EPs, the EP does not need to provide a concrete kernel implementation for each custom op.
* Instead, it may provide only placeholder custom ops with the correct names so they can be recognized
* during model loading.
*
* Workflow:
* 1. The EP implements this function to supply a list of OrtCustomOpDomain instances.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should clarify that the custom op domains are provided by the EP factory, not the EP.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, modified to EP factory

* 2. The application calls SessionOptionsAppendExecutionProvider_V2() with an OrtEpDevice containing
* the plugin EP's factory.
* 3. SessionOptionsAppendExecutionProvider_V2() appends the provided OrtCustomOpDomain list to the
* session options.
*
* As a result, any session created from these session options will have these custom op domains registered
* in ORT, ensuring that the custom ops are properly recognized and validated when the model is loaded.
*
* Note: EP has the responsibility to release OrtCustomOpDomain instances it creates. It happens
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would be the EP factory and not the EP that has this responsibility, right? and the instances must be valid while any session is using them

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated.

* automatically if using ORT C++ api.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO, with a name like "CreateCustomOpDomains", there's some expectation of ownership transfer. E.g., how OrtApi::CreateCustomOpDomain() creates a new domain that the user is responsible for releasing.

since this function does not transfer ownership to the caller, maybe a name like GetCustomOpDomains() would be better?

Copy link
Contributor Author

@chilo-ms chilo-ms Jan 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes sense to me to have a name GetCustomOpDomains(). Changed.

*
* \param[in] this_ptr The OrtEpFactory instance.
* \param[out] domains Pre-allocated array of `num_domains` elements by ORT that should be filled with
OrtCustomOpDomain created by the EP.
* \param[in] num_domains The size of the `domains` array pre-allocated by ORT.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(CreateCustomOpDomains, _In_ OrtEpFactory* this_ptr,
_Outptr_result_maybenull_ OrtCustomOpDomain** domains, _Out_ size_t num_domains);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OrtCustomOpDomain!
This structure is not ABI safe or stable across boundaries.

Copy link
Contributor Author

@chilo-ms chilo-ms Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The OrtCustomOpDomain is an opaque struct to EP, and EP can create it by calling C API CreateCustomOpDomain() to get a pointer of an OrtCustomOpDomain instance and then add OrtCustomOp instances to it.
I think it's ABI stable as it's just a C pointer.

Please see the implementation in Example EP in the unit test as a reference.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the expectation for OrtCustomOpDomain** domains this will contain a deep copy or shallow copy of the OrtCustomOp that ep provides?

Copy link
Contributor Author

@chilo-ms chilo-ms Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The OrtCustomOpDomain contains a shallow copy (a pointer) of the OrtCustomOp.

The reason is:

EP should call OrtApis::CustomOpDomain_Add() to add OrtCustomOp to OrtCustomOpDomain.
As per api implementation, it simply stores the pointer of the OrtCustomOp in the vector, not constructing a new OrtCustomOp

ORT_API_STATUS_IMPL(OrtApis::CustomOpDomain_Add, _Inout_ OrtCustomOpDomain* custom_op_domain, _In_ const OrtCustomOp* op) {
  API_IMPL_BEGIN
  custom_op_domain->custom_ops_.emplace_back(op);
  return nullptr;
  API_IMPL_END
}

Then, inside ORT, it simply stores the pointer of OrtCustomOpDomain returned from EP into session options.

  Status AddEpOptionsToSessionOptions() {
  ...
      ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->CreateCustomOpDomains(ep_factory,
                                                                               domains.data(),
                                                                               domains.size())));

      const auto domains_span = gsl::span<OrtCustomOpDomain*>(domains.data(), domains.size());
      for (auto domain : domains_span) {
        if (!is_already_in_domains(domain->domain_, ort_session_options.custom_op_domains_) &&
            domain->custom_ops_.size() > 0) {
          ort_session_options.custom_op_domains_.push_back(domain);
        } else {
          LOGS_DEFAULT(WARNING) << "The custom op domain name "
                                << domain->domain_ << " is already in the session option. Skip it.";
        }
      }
    ...
  }  

};

#ifdef __cplusplus
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3332,7 +3332,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS
ep_devices_span,
ep_option_keys_span,
ep_option_vals_span,
session_options->value));
*session_options));

session_options->provider_factories.push_back(std::move(provider_factory));

Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ class EpFactoryInternalImpl {
return nullptr;
}

virtual OrtStatus* CreateCustomOpDomains(_Outptr_result_maybenull_ OrtCustomOpDomain** out,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also need GetNumCustomOpDomains in here?

_Out_ size_t* num_domains) const noexcept {
*out = nullptr;
*num_domains = 0;
return nullptr;
}

// Function ORT calls to release an EP instance.
void ReleaseEp(OrtEp* ep);

Expand Down
41 changes: 40 additions & 1 deletion onnxruntime/core/session/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,8 @@
Status AddEpOptionsToSessionOptions(gsl::span<const OrtEpDevice* const> ep_devices,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this not called during auto ep selection path, when the session_options.set_provider_selection_policy(ort.OrtExecutionProviderDevicePolicy.PREFER_GPU) is set?
For auto ep path when it creates plugin ep it goes for inference_session.cc - RegisterExecutionProvider() where it calls GetCustomOpDomainList? Is this understanding correct?

Copy link
Contributor Author

@chilo-ms chilo-ms Jan 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the case of EP using GetCustomOpDomains to register custom op, only this EP can run the model contains that custom op.
IMO, the application should explicitly call SessionOptionsAppendExecutionProvider_V2() and specify that ep device, rather than using auto ep selection as other devices might not be able to run that custom op.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registering custom ops should be independent of EP , so it should be available in auto ep selectin path as well

gsl::span<const char* const> ep_option_keys,
gsl::span<const char* const> ep_option_vals,
SessionOptions& session_options) {
OrtSessionOptions& ort_session_options) {
SessionOptions& session_options = ort_session_options.value;
const size_t num_ep_options = ep_option_keys.size();
if (ep_option_vals.size() != num_ep_options) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
Expand All @@ -505,6 +506,44 @@

ORT_RETURN_IF_ERROR(config_options.AddConfigEntry((prefix + ep_option_keys[j]).c_str(), ep_option_vals[j]));
}

// Add custom op domain provided by EP to the session options if any.
// OrtEpFactory::GetNumCustomOpDomains and OrtEpFactory::CreateCustomOpDomains
// were added in ORT 1.24.
OrtEpFactory* ep_factory = ep_device->ep_factory;
if (ep_factory &&
ep_factory->ort_version_supported >= 24 &&
ep_factory->CreateCustomOpDomains != nullptr) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also need to check ep_factory->GetNumCustomOpDomains just in case?

auto is_already_in_domains = [&](const std::string& domain_name, const std::vector<OrtCustomOpDomain*>& domains) {

Check warning on line 517 in onnxruntime/core/session/utils.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/session/utils.cc:517: Add #include <vector> for vector<> [build/include_what_you_use] [4]

Check warning on line 517 in onnxruntime/core/session/utils.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/session/utils.cc:517: Add #include <string> for string [build/include_what_you_use] [4]
for (auto ptr : domains) {
if (domain_name == ptr->domain_) {
return true;
}
}
return false;
};

size_t num_domains = 0;
ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetNumCustomOpDomains(ep_factory, &num_domains)));

InlinedVector<OrtCustomOpDomain*> domains;
domains.resize(num_domains);

ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->CreateCustomOpDomains(ep_factory,
domains.data(),
num_domains)));

const auto domains_span = gsl::span<OrtCustomOpDomain*>(domains.data(), domains.size());
for (auto domain : domains_span) {
if (!is_already_in_domains(domain->domain_, ort_session_options.custom_op_domains_) &&
domain->custom_ops_.size() > 0) {
ort_session_options.custom_op_domains_.push_back(domain);
} else {
LOGS_DEFAULT(WARNING) << "The custom op domain name "
<< domain->domain_ << " is already in the session option. Skip it.";
}
}
}
}

return Status::OK();
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/session/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env,
Status AddEpOptionsToSessionOptions(gsl::span<const OrtEpDevice* const> ep_devices,
gsl::span<const char* const> ep_options_keys,
gsl::span<const char* const> ep_options_vals,
SessionOptions& session_options);
OrtSessionOptions& session_options);

} // namespace onnxruntime
#endif // !defined(ORT_MINIMAL_BUILD)
2 changes: 1 addition & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1346,7 +1346,7 @@ static Status AddEpFactoryFromEpDevices(PySessionOptions& py_sess_options,
ORT_RETURN_IF_ERROR(AddEpOptionsToSessionOptions(ep_devices,
ep_option_keys,
ep_option_vals,
py_sess_options.value));
py_sess_options));

py_sess_options.provider_factories.push_back(std::move(provider_factory));
return Status::OK();
Expand Down
207 changes: 98 additions & 109 deletions onnxruntime/test/autoep/library/example_plugin_ep/ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,117 +15,97 @@
#include "ep_factory.h"
#include "ep_stream_support.h"

/// <summary>
/// Example implementation of ONNX Mul. Does not handle many things like broadcasting.
/// </summary>
struct MulKernel {
MulKernel(const OrtApi& ort_api, const OrtLogger& logger,
const std::unordered_map<std::string, FloatInitializer>& float_initializers,
std::string input0_name, std::string input1_name)
: ort_api(ort_api),
logger(logger),
float_initializers(float_initializers),
input0_name(input0_name),
input1_name(input1_name) {}

const FloatInitializer* TryGetSavedInitializer(const std::string& name) const {
auto iter = float_initializers.find(name);
return iter != float_initializers.end() ? &iter->second : nullptr;
}
const FloatInitializer* MulKernel::TryGetSavedInitializer(const std::string& name) const {
auto iter = float_initializers.find(name);
return iter != float_initializers.end() ? &iter->second : nullptr;
}

void GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index,
/*out*/ gsl::span<const float>& data,
/*out*/ std::vector<int64_t>& shape) const {
Ort::ConstValue input = kernel_context.GetInput(index);
auto type_shape = input.GetTensorTypeAndShapeInfo();
void MulKernel::GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index,
/*out*/ gsl::span<const float>& data,
/*out*/ std::vector<int64_t>& shape) const {
Ort::ConstValue input = kernel_context.GetInput(index);
auto type_shape = input.GetTensorTypeAndShapeInfo();

ONNXTensorElementDataType elem_type = type_shape.GetElementType();
if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL);
ONNXTensorElementDataType elem_type = type_shape.GetElementType();
if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL);

const float* float_data = input.GetTensorData<float>();
size_t num_elems = type_shape.GetElementCount();
data = gsl::span<const float>(float_data, num_elems);
shape = type_shape.GetShape();
}

OrtStatus* Compute(OrtKernelContext* kernel_ctx) {
RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger,
OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
"MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__));
Ort::KernelContext kernel_context(kernel_ctx);
try {
gsl::span<const float> input0;
gsl::span<const float> input1;
std::vector<int64_t> shape0;
std::vector<int64_t> shape1;

size_t num_inputs = kernel_context.GetInputCount();

if (num_inputs == 2) {
// Both inputs are non-constant. Get them from ORT's KernelContext.
GetInputDataAndShape(kernel_context, 0, input0, shape0);
GetInputDataAndShape(kernel_context, 1, input1, shape1);
} else if (num_inputs == 1) {
// ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs.
// Get the constant input from the initializers saved by the EP.
// Refer to "NodeFusionOptions_DropConstantInitializers()".

if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) {
GetInputDataAndShape(kernel_context, 0, input1, shape1);
input0 = gsl::span<const float>(const_input0->data);
shape0 = const_input0->shape;
} else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) {
GetInputDataAndShape(kernel_context, 0, input0, shape0);
input1 = gsl::span<const float>(const_input1->data);
shape1 = const_input1->shape;
}
} else {
// Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding)
// are disabled.
const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name);
const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name);
RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api,
"Expected 2 initializer inputs to be saved by EP");
const float* float_data = input.GetTensorData<float>();
size_t num_elems = type_shape.GetElementCount();
data = gsl::span<const float>(float_data, num_elems);
shape = type_shape.GetShape();
}

OrtStatus* MulKernel::Compute(OrtKernelContext* kernel_ctx) {
RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger,
OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
"MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__));
Ort::KernelContext kernel_context(kernel_ctx);
try {
gsl::span<const float> input0;
gsl::span<const float> input1;
std::vector<int64_t> shape0;
std::vector<int64_t> shape1;

size_t num_inputs = kernel_context.GetInputCount();

if (num_inputs == 2) {
// Both inputs are non-constant. Get them from ORT's KernelContext.
GetInputDataAndShape(kernel_context, 0, input0, shape0);
GetInputDataAndShape(kernel_context, 1, input1, shape1);
} else if (num_inputs == 1) {
// ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs.
// Get the constant input from the initializers saved by the EP.
// Refer to "NodeFusionOptions_DropConstantInitializers()".

if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) {
GetInputDataAndShape(kernel_context, 0, input1, shape1);
input0 = gsl::span<const float>(const_input0->data);
input1 = gsl::span<const float>(const_input1->data);
shape0 = const_input0->shape;
} else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) {
GetInputDataAndShape(kernel_context, 0, input0, shape0);
input1 = gsl::span<const float>(const_input1->data);
shape1 = const_input1->shape;
}
} else {
// Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding)
// are disabled.
const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name);
const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name);
RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api,
"Expected 2 initializer inputs to be saved by EP");

input0 = gsl::span<const float>(const_input0->data);
input1 = gsl::span<const float>(const_input1->data);
shape0 = const_input0->shape;
shape1 = const_input1->shape;
}

if (shape0 != shape1) {
throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT);
}
if (shape0 != shape1) {
throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT);
}

size_t num_outputs = kernel_context.GetOutputCount();
if (num_outputs != 1) {
throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT);
}
size_t num_outputs = kernel_context.GetOutputCount();
if (num_outputs != 1) {
throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT);
}

auto output = kernel_context.GetOutput(0, shape0);
float* output_data = output.GetTensorMutableData<float>();
auto output = kernel_context.GetOutput(0, shape0);
float* output_data = output.GetTensorMutableData<float>();

for (size_t i = 0; i < input0.size(); ++i) {
output_data[i] = input0[i] * input1[i];
}
} catch (const Ort::Exception& ex) {
Ort::Status status(ex);
return status.release();
} catch (const std::exception& ex) {
Ort::Status status(ex.what(), ORT_EP_FAIL);
return status.release();
for (size_t i = 0; i < input0.size(); ++i) {
output_data[i] = input0[i] * input1[i];
}

return nullptr;
} catch (const Ort::Exception& ex) {
Ort::Status status(ex);
return status.release();
} catch (const std::exception& ex) {
Ort::Status status(ex.what(), ORT_EP_FAIL);
return status.release();
}

const OrtApi& ort_api;
const OrtLogger& logger;
const std::unordered_map<std::string, FloatInitializer>& float_initializers;
std::string input0_name;
std::string input1_name;
};
return nullptr;
}

/// <summary>
/// Example OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph.
Expand Down Expand Up @@ -262,26 +242,35 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG

supported_nodes.push_back(node); // Only support a single Mul for now.
break;
} else if (op_type == "Custom_Mul") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we also check the domain?

Copy link
Contributor Author

@chilo-ms chilo-ms Jan 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't need to. As when the model is loading, ORT already checks the domain.
Also, we can't really get the domain info given only an op, can we? i might be wrong

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean the op domain, like node.GetDomain(). domain + op type identifies the op. here we are matching an op from a custom domain.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, okay, i added the domain check.

supported_nodes.push_back(node);
}
}

if (supported_nodes.empty()) {
return nullptr;
}

// Create (optional) fusion options for the supported nodes to fuse.
OrtNodeFusionOptions node_fusion_options = {};
node_fusion_options.ort_version_supported = ORT_API_VERSION;

// Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers
// as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers.
// This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use
// during inference.
node_fusion_options.drop_constant_initializers = true;
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
reinterpret_cast<const OrtNode* const*>(supported_nodes.data()),
supported_nodes.size(),
&node_fusion_options));
if (supported_nodes[0].GetOperatorType() == "Mul") {
// Create (optional) fusion options for the supported nodes to fuse.
OrtNodeFusionOptions node_fusion_options = {};
node_fusion_options.ort_version_supported = ORT_API_VERSION;

// Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers
// as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers.
// This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use
// during inference.
node_fusion_options.drop_constant_initializers = true;
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
reinterpret_cast<const OrtNode* const*>(supported_nodes.data()),
supported_nodes.size(),
&node_fusion_options));
} else if (supported_nodes[0].GetOperatorType() == "Custom_Mul") {
// Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled,
// as CustomMul has the concrete kernel implementation.
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_nodes[0]));
}

} catch (const Ort::Exception& ex) {
Ort::Status status(ex);
return status.release();
Expand Down
Loading
Loading