[EP ABI] Add CreateCustomOpDomains() API for plugin EP to register custom ops #26759
onnxruntime/core/session/utils.cc
Outdated
    return false;
    };
    if (ep_factory->CreateCustomOpDomains == nullptr) {
ep_factory->CreateCustomOpDomains does not exist for ORT 1.23. Think we need a check here that skips a factory with version older than 24.
Added the check.
    /*static*/
    OrtStatus* ORT_API_CALL ExampleEpFactory::CreateCustomOpDomainsImpl(OrtEpFactory* this_ptr,
        _Outptr_result_maybenull_ OrtCustomOpDomain** out,
I don't think a OrtCustomOpDomain** parameter allows this function to return more than 1 domain, right?
Maybe it would be easier to have a separate API that returns the number of domains (e.g., OrtEpFactory::GetNumberOfCustomOpDomains()), then ORT allocates an array of the right size, ORT calls OrtEpFactory::CreateCustomOpDomains with the array it allocated, and then the EP factory just fills the array.
The EP is responsible for calling ReleaseCustomOpDomain (happens automatically if using C++ api). ORT is responsible for managing the memory for the array of domains (can just be a std::vector<OrtCustomOpDomain*>)
Good catch!
I updated CreateCustomOpDomains() and added GetNumCustomOpDomains() so that it can now support more than one domain.
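The count-then-fill pattern discussed in this thread can be sketched roughly as follows. This is a minimal illustration, not the ORT implementation: MockDomain, MockFactory, and CollectDomains are hypothetical stand-ins for OrtCustomOpDomain, the EP factory, and the ORT-side caller. The method names simply mirror the ones used in the thread.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for the opaque OrtCustomOpDomain type.
struct MockDomain {
  std::string name;
};

// Hypothetical stand-in for an EP factory that owns its domains.
struct MockFactory {
  std::vector<MockDomain> owned_domains;  // owned by the factory, not by the caller

  // Step 1: report how many domains the factory provides.
  size_t GetNumCustomOpDomains() const { return owned_domains.size(); }

  // Step 2: fill a caller-allocated array with non-owning pointers.
  // Returns false if the caller's buffer size does not match the count.
  bool CreateCustomOpDomains(MockDomain** out, size_t num_domains) {
    if (num_domains != owned_domains.size()) return false;
    for (size_t i = 0; i < num_domains; ++i) out[i] = &owned_domains[i];
    return true;
  }
};

// Caller side: query the count, allocate the array, then let the factory fill it.
std::vector<MockDomain*> CollectDomains(MockFactory& factory) {
  std::vector<MockDomain*> domains(factory.GetNumCustomOpDomains(), nullptr);
  bool ok = factory.CreateCustomOpDomains(domains.data(), domains.size());
  assert(ok);
  (void)ok;  // silence unused-variable warnings when asserts are disabled
  return domains;
}
```

In this sketch the caller owns only the array of pointers, while the factory keeps ownership of the domain objects themselves, which matches the split of responsibilities suggested by the reviewer.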
    compute_stream_(compute_stream) {
    }
    void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const {
where is this called?
I updated the unit test, so now the CreateKernel will be called by ORT.
    }
    void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const {
    return new PluginEpCustomKernel(info, compute_stream_);
Does this lead to a memory leak?
It won't leak.
ExampleEpCustomOp derives from CustomOpBase so the KernelDestroy will be called during clean up.
struct CustomOpBase : OrtCustomOp {
...
OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
...
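The ownership argument in this reply can be illustrated with a simplified sketch. The types below (MockCustomOp, MockCustomOpBase, CountingKernel, ExampleOp) are hypothetical stand-ins and much smaller than the real OrtCustomOp and Ort::CustomOpBase; the only point is that the base class installs a destroy callback that deletes whatever the create callback allocated.

```cpp
#include <cassert>

// Hypothetical, simplified C-style op table (stand-in for OrtCustomOp).
struct MockCustomOp {
  void* (*CreateKernel)(const MockCustomOp* op);
  void (*KernelDestroy)(void* op_kernel);
};

// Simplified CRTP base, loosely modeled on the snippet quoted above.
template <typename TOp, typename TKernel>
struct MockCustomOpBase : MockCustomOp {
  MockCustomOpBase() {
    // Forward creation to the derived op.
    MockCustomOp::CreateKernel = [](const MockCustomOp* op) -> void* {
      return static_cast<const TOp*>(op)->CreateKernelImpl();
    };
    // Pair every created kernel with a typed delete.
    MockCustomOp::KernelDestroy = [](void* op_kernel) {
      delete static_cast<TKernel*>(op_kernel);
    };
  }
};

// Kernel that counts live instances so the round trip can be checked.
struct CountingKernel {
  static int live;
  CountingKernel() { ++live; }
  ~CountingKernel() { --live; }
};
int CountingKernel::live = 0;

struct ExampleOp : MockCustomOpBase<ExampleOp, CountingKernel> {
  void* CreateKernelImpl() const { return new CountingKernel(); }
};

// Runtime side: only the function pointers are used, as a C caller would.
int LiveKernelsAfterRoundTrip() {
  ExampleOp op;
  const MockCustomOp* c_op = &op;
  void* kernel = c_op->CreateKernel(c_op);
  assert(CountingKernel::live == 1);
  c_op->KernelDestroy(kernel);
  return CountingKernel::live;
}
```

The raw `new` is therefore not a leak as long as the runtime calls the destroy callback once for every kernel it created.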
You can commit the suggested changes from lintrunner.
    * \since Version 1.24.
    */
    ORT_API2_STATUS(CreateCustomOpDomains, _In_ OrtEpFactory* this_ptr,
        _Outptr_result_maybenull_ OrtCustomOpDomain** domains, _Out_ size_t num_domains);
OrtCustomOpDomain!
This structure is not ABI safe or stable across boundaries.
OrtCustomOpDomain is an opaque struct to the EP. The EP can create one by calling the C API CreateCustomOpDomain() to get a pointer to an OrtCustomOpDomain instance and then add OrtCustomOp instances to it.
I think it's ABI stable as it's just a C pointer.
Please see the implementation in the Example EP in the unit test as a reference.
    * \since Version 1.24.
    */
    ORT_API2_STATUS(CreateCustomOpDomains, _In_ OrtEpFactory* this_ptr,
        _Outptr_result_maybenull_ OrtCustomOpDomain** domains, _Out_ size_t num_domains);
What is the expectation for OrtCustomOpDomain** domains: will it contain a deep copy or a shallow copy of the OrtCustomOp that the EP provides?
The OrtCustomOpDomain contains a shallow copy (a pointer) of the OrtCustomOp.
The reason is:
The EP should call OrtApis::CustomOpDomain_Add() to add an OrtCustomOp to an OrtCustomOpDomain.
As per the API implementation, it simply stores the pointer to the OrtCustomOp in the vector; it does not construct a new OrtCustomOp.
ORT_API_STATUS_IMPL(OrtApis::CustomOpDomain_Add, _Inout_ OrtCustomOpDomain* custom_op_domain, _In_ const OrtCustomOp* op) {
API_IMPL_BEGIN
custom_op_domain->custom_ops_.emplace_back(op);
return nullptr;
API_IMPL_END
}
Then, inside ORT, it simply stores the pointer to the OrtCustomOpDomain returned from the EP in the session options.
Status AddEpOptionsToSessionOptions() {
...
ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->CreateCustomOpDomains(ep_factory,
domains.data(),
domains.size())));
const auto domains_span = gsl::span<OrtCustomOpDomain*>(domains.data(), domains.size());
for (auto domain : domains_span) {
if (!is_already_in_domains(domain->domain_, ort_session_options.custom_op_domains_) &&
domain->custom_ops_.size() > 0) {
ort_session_options.custom_op_domains_.push_back(domain);
} else {
LOGS_DEFAULT(WARNING) << "The custom op domain name "
<< domain->domain_ << " is already in the session option. Skip it.";
}
}
...
}
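The shallow-copy behavior described above can be sketched with plain standard-library types. MockOp, MockDomain, and AppendDomains are hypothetical names used only for illustration; the sketch assumes, as in the quoted code, that a domain is skipped when it has no custom ops or when a domain with the same name is already registered.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-ins for OrtCustomOp and OrtCustomOpDomain.
struct MockOp {};
struct MockDomain {
  std::string name;
  std::vector<const MockOp*> custom_ops;  // non-owning pointers (shallow copy)
};

// Appends pointers to the factory's domains to the session's list.
// No domain or op is copied; only pointers are stored.
// Returns the number of domains appended.
size_t AppendDomains(const std::vector<MockDomain*>& from_factory,
                     std::vector<MockDomain*>& session_domains) {
  size_t appended = 0;
  for (MockDomain* domain : from_factory) {
    // Skip missing or empty domains.
    if (domain == nullptr || domain->custom_ops.empty()) continue;
    // Skip domains whose name is already registered in the session.
    bool already_registered = std::any_of(
        session_domains.begin(), session_domains.end(),
        [&](const MockDomain* d) { return d->name == domain->name; });
    if (already_registered) continue;
    session_domains.push_back(domain);
    ++appended;
  }
  return appended;
}
```

Because only pointers are stored, the objects they point to must outlive every session that uses them.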
adrianlizarraga
left a comment
some minor things
onnxruntime/core/session/utils.cc
Outdated
    OrtEpFactory* ep_factory = ep_device->ep_factory;
    if (ep_factory &&
        ep_factory->ort_version_supported >= 24 &&
        ep_factory->CreateCustomOpDomains != nullptr) {
also need to check ep_factory->GetNumCustomOpDomains just in case?
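A guard that covers both function pointers might look like the sketch below. MockEpFactory and SupportsCustomOpDomains are hypothetical; the field names mirror the ones in this thread, and the version threshold of 24 is taken from the diff above. The version check comes first so that the newer members are only read on factories that actually declare them.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for OrtEpFactory with two members added in a later version.
struct MockEpFactory {
  uint32_t ort_version_supported;
  size_t (*GetNumCustomOpDomains)(MockEpFactory* this_ptr);
  bool (*CreateCustomOpDomains)(MockEpFactory* this_ptr, void** domains, size_t num_domains);
};

// First version in which the two members exist (value taken from the diff above).
constexpr uint32_t kFirstVersionWithCustomOpDomains = 24;

// Returns true only if the factory is new enough and implements both functions.
// Short-circuit evaluation ensures the function pointers are not read for older factories.
bool SupportsCustomOpDomains(const MockEpFactory* factory) {
  return factory != nullptr &&
         factory->ort_version_supported >= kFirstVersionWithCustomOpDomains &&
         factory->GetNumCustomOpDomains != nullptr &&
         factory->CreateCustomOpDomains != nullptr;
}
```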
    return nullptr;
    }
    virtual OrtStatus* CreateCustomOpDomains(_Outptr_result_maybenull_ OrtCustomOpDomain** out,
Do we also need GetNumCustomOpDomains in here?
    //
    // Note: Approach #2 is suitable for plugin TRT RTX EP to support TRT plugins.
    struct CustomMulKernel : MulKernel {
nit: could these custom op classes be placed in a separate file?
    #include "onnxruntime_c_api.h"
    #include "ep.h"
    * Note: EP has the responsibility to release OrtCustomOpDomain instances it creates. It happens
    * automatically if using ORT C++ api.
IMO, with a name like "CreateCustomOpDomains", there's some expectation of ownership transfer. E.g., how OrtApi::CreateCustomOpDomain() creates a new domain that the user is responsible for releasing.
since this function does not transfer ownership to the caller, maybe a name like GetCustomOpDomains() would be better?
It makes sense to me to have a name GetCustomOpDomains(). Changed.
    }
    };
    struct ExampleEpCustomOp : Ort::CustomOpBase<ExampleEpCustomOp, CustomMulKernel> {
CustomMulKernel::Compute() returns OrtStatus*. should this inherit from CustomOpBase<ExampleEpCustomOp, CustomMulKernel, /*WithStatus*/ true> instead so that the OrtStatus will be checked?
Good catch. I also used CreateKernelV2 and KernelComputeV2 so that the OrtStatus will be checked.
    * \since Version 1.24.
    */
    ORT_API2_STATUS(CreateCustomOpDomains, _In_ OrtEpFactory* this_ptr,
        _Outptr_result_maybenull_ OrtCustomOpDomain** domains, _In_ size_t num_domains);
why is _Outptr_result_maybenull_ used for domains?
I think an EP can still implement this function and create no custom op domains.
Or do you think _Outptr_result_ makes more sense?
hm, would _Out_writes_all_(num_domains) make sense here? assuming that num_domains is the value returned by GetNumCustomOpDomains() - perhaps that should also be documented. the implementation is expected to treat domains as a buffer.
    return Status::OK();
    }
    Status AddEpOptionsToSessionOptions(gsl::span<const OrtEpDevice* const> ep_devices,
Is this not called during auto ep selection path, when the session_options.set_provider_selection_policy(ort.OrtExecutionProviderDevicePolicy.PREFER_GPU) is set?
For auto ep path when it creates plugin ep it goes for inference_session.cc - RegisterExecutionProvider() where it calls GetCustomOpDomainList? Is this understanding correct?
When an EP uses GetCustomOpDomains to register a custom op, only this EP can run a model that contains that custom op.
IMO, the application should explicitly call SessionOptionsAppendExecutionProvider_V2() and specify that EP device, rather than using auto EP selection, as other devices might not be able to run that custom op.
Registering custom ops should be independent of the EP, so it should be available in the auto EP selection path as well.
    auto* factory = static_cast<ExampleEpFactory*>(this_ptr);
    // Custom Op Domains
    factory->custom_op_domains_[0] = Ort::CustomOpDomain{"test"};
would it be better to only create the custom op domains once, e.g., in the factory ctor?
Ah, that's a good suggestion. Moved to the factory ctor.
    * This function is used when running inference on a model that contains EP-specific custom operations.
    *
    * Workflow:
    * 1. The EP implements this function to supply a list of OrtCustomOpDomain instances.
should clarify that the custom op domains are provided by the EP factory, not the EP.
yes, modified to EP factory
    * - In Compile(), the EP executes its compiled bits to perform inference for
    * the fused custom node.
    *
    * Note: EP has the responsibility to release OrtCustomOpDomain instances it creates. It happens
it would be the EP factory and not the EP that has this responsibility, right? and the instances must be valid while any session is using them
updated.
    */
    ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options);
    /** \brief Returns the number of OrtCustomOpDomains that this factory creates.
since we renamed it to GetCustomOpDomains, maybe replace "creates" with "provides" or "supplies".
replaced with "provides"
    supported_nodes.push_back(node); // Only support a single Mul for now.
    break;
    } else if (op_type == "Custom_Mul") {
should we also check the domain?
I think we don't need to. As when the model is loading, ORT already checks the domain.
Also, we can't really get the domain info given only an op, can we? i might be wrong
I mean the op domain, like node.GetDomain(). domain + op type identifies the op. here we are matching an op from a custom domain.
Ah, okay, i added the domain check.
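The point made in this thread, that domain plus op type identifies an op, can be sketched as follows. MockNode and IsExampleCustomMul are hypothetical; the domain name "test" and the op type "Custom_Mul" are taken from the diff fragments quoted in this review, and the real check would use the node accessors of the ORT API instead of plain strings.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for a graph node with a domain and an op type.
struct MockNode {
  std::string domain;
  std::string op_type;
};

// Matches the custom op only when both the domain and the op type agree.
// Checking the op type alone could match an unrelated op from another domain.
bool IsExampleCustomMul(const MockNode& node) {
  return node.domain == "test" && node.op_type == "Custom_Mul";
}
```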
Description
The two newly added APIs, CreateCustomOpDomains() and GetNumCustomOpDomains(), are used when running inference on a model that contains EP-specific custom operations.
Workflow:
1. The EP factory implements these functions to supply a list of OrtCustomOpDomain instances.
2. The application calls SessionOptionsAppendExecutionProvider_V2() with the OrtEpDevice containing the plugin EP's factory.
3. SessionOptionsAppendExecutionProvider_V2() internally appends the provided OrtCustomOpDomain list to the session options.
As a result, any session created from these session options will have these custom op domains registered in ORT, ensuring that the custom ops are properly recognized and validated when the model is loaded.
Plugin EPs can provide two types of custom ops:
1. A full OrtCustomOp with a concrete kernel implementation
   - The EP informs ORT that the custom node should NOT be fused or compiled. Instead, ORT should invoke the custom node's Compute() function at runtime.
2. A "placeholder" OrtCustomOp with an empty kernel implementation
   - The kernel's Compute() does nothing. The purpose is to satisfy model validation during model loading by registering the custom op as a valid operator in the session.
   - The EP should notify ORT that this custom node should be fused and compiled by the EP.
   - In Compile(), the EP executes its compiled bits to perform inference for the fused custom node.
Motivation and Context
Currently, the provider-bridge TRT RTX EP and TRT EP support registering a custom op domain list in the session options so that they can run models that contain TRT-specific custom ops.
This PR adds the same feature for plugin EPs.