Merged
53 commits
4f3d004
[EP ABI] Add weight pre-packing support to kernel-based plugin EPs
adrianlizarraga Dec 8, 2025
e954021
Add comment about sharing of prepacked weights (cpu ep only)
adrianlizarraga Dec 8, 2025
fb2998b
Update Mul kernel to pre-pack input b
adrianlizarraga Dec 9, 2025
5e64f79
Apply suggestions from code review
adrianlizarraga Dec 9, 2025
9b1c6a2
Add comments regarding prepack allocator lifetime
adrianlizarraga Dec 9, 2025
c638a1a
Merge branch 'adrianl/plugin-ep-kernel-prepack' of github.com:microso…
adrianlizarraga Dec 9, 2025
717ed4a
Added support for sharing pre-packed weights for cpu-accessible alloc…
adrianlizarraga Dec 11, 2025
bd8f6f0
Define what should happen if OrtKernelImpl::SetSharedPrePackedWeight(…
adrianlizarraga Dec 12, 2025
5c94ec4
Add some missing OrtKernelInfo APIs for plugin EPs
adrianlizarraga Dec 15, 2025
fc1fd16
Merge branch 'main' into adrianl/plugin-ep-kernel-prepack
adrianlizarraga Dec 16, 2025
8b3f56c
Clean up some exception handling
adrianlizarraga Dec 16, 2025
23503a1
Refactor example kernel classes (no inheritance)
adrianlizarraga Dec 17, 2025
7f37ffb
Merge branch 'main' into adrianl/plugin-ep-kernel-prepack
adrianlizarraga Dec 17, 2025
26eca56
Correct use of output param
adrianlizarraga Dec 17, 2025
7af257b
Add more edge-case handling for PrePack() call
adrianlizarraga Dec 17, 2025
515062e
API version checks
adrianlizarraga Dec 17, 2025
347ce4f
Use correct SAL annotation for array parameters
adrianlizarraga Dec 18, 2025
906187d
Clean up some includes
adrianlizarraga Dec 18, 2025
56074f6
Merge branch 'main' into adrianl/KernelPluginEp_KernelInfoApis
adrianlizarraga Dec 18, 2025
1611fc3
Update onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc
adrianlizarraga Dec 19, 2025
cb75d5a
Rename ExecuteIfCustomOpsApiEnabled to ExecuteifKernelApiEnabled
adrianlizarraga Dec 19, 2025
30ca590
Remove OrtAllocator parameter from SharedPrePackedWeightCache_StoreWe…
adrianlizarraga Dec 19, 2025
a5342b9
Clarify what happens when SharedPrePackedWeightCache_StoreWeightData …
adrianlizarraga Dec 20, 2025
c42d7ac
Merge branch 'main' into adrianl/KernelPluginEp_KernelInfoApis
adrianlizarraga Dec 22, 2025
3323330
Move some kernel info apis to OrtApi; Fix unrelated memory leak for K…
adrianlizarraga Dec 22, 2025
51bc731
Merge branch 'main' into adrianl/plugin-ep-kernel-prepack
adrianlizarraga Dec 22, 2025
edf3f2c
Review comments
adrianlizarraga Dec 23, 2025
e94c0aa
C++ API
adrianlizarraga Dec 23, 2025
c8eb3c9
Improve doc for c++ api convenience class
adrianlizarraga Dec 23, 2025
98e3d13
Add buffer_sizes as a parameter to OrtKernelImpl::SetSharedWeightData
adrianlizarraga Dec 24, 2025
c61ae41
Add comment to implementation of OrtKernelImpl::SetSharedPrePackedWeight
adrianlizarraga Dec 24, 2025
02d75d2
Do not prescribe what the kernel impl should return for a situation t…
adrianlizarraga Dec 24, 2025
0a84eda
Update include/onnxruntime/core/session/onnxruntime_ep_c_api.h
adrianlizarraga Dec 24, 2025
5f80f9d
Adjust comments
adrianlizarraga Dec 24, 2025
c60472d
Tweak comment again
adrianlizarraga Dec 24, 2025
12b8394
Merge main and fix conflicts
adrianlizarraga Dec 26, 2025
52453f2
Update example kernel EP to get the OrtEp from OrtKernelInfo and add …
adrianlizarraga Dec 26, 2025
79eadd8
Address copilot review comments
adrianlizarraga Dec 26, 2025
1dc98b2
Address copilot review comment: update error message when APIs for Or…
adrianlizarraga Dec 26, 2025
441c9e2
Add comments to clarify ownership scenarios
adrianlizarraga Dec 27, 2025
5cf9600
Merge branch 'adrianl/plugin-ep-kernel-prepack' into adrianl/KernelPl…
adrianlizarraga Dec 27, 2025
ec0183e
Merge main and fix conflicts
adrianlizarraga Dec 29, 2025
7e89261
Add support for Add kernel to test kernel-based EP. This used as moti…
adrianlizarraga Dec 29, 2025
1ddc9fb
Rename variable from 'mul_kernel' to 'binary_op_kernel'
adrianlizarraga Dec 29, 2025
9e6b06e
Address review comments and fix compiler warnings on Linux
adrianlizarraga Dec 29, 2025
1d496a2
Move KernelContext_GetAllocator memory leak fix to a different PR
adrianlizarraga Dec 30, 2025
6378bf2
Simplify binary op type check
adrianlizarraga Dec 30, 2025
eadc029
Merge branch 'main' into adrianl/KernelPluginEp_KernelInfoApis
adrianlizarraga Jan 5, 2026
cad57e7
Review comments + add new API to get op domain
adrianlizarraga Jan 6, 2026
4fb2be9
Check for NULL size arg in API call. Update comment
adrianlizarraga Jan 6, 2026
ebe58ed
Update SAL annotations for older KernelInfo_* API functions
adrianlizarraga Jan 6, 2026
3968ec3
Update example BinaryOp class to support Sub instead of Add
adrianlizarraga Jan 6, 2026
df45a2f
Update comment for KernelInfo_GetNodeName
adrianlizarraga Jan 6, 2026
8 changes: 8 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
@@ -416,6 +416,14 @@ class IExecutionProvider {
    return InlinedVector<const Node*>();
  }

  /**
   * Returns the underlying OrtEp instance if this IExecutionProvider wraps a plugin EP.
   * Otherwise, returns nullptr (default implementation).
   */
  virtual const OrtEp* GetOrtEp() const {
    return nullptr;
  }

 private:
  const std::string type_;
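
The hunk above adds a virtual accessor whose default returns nullptr, and the `PluginExecutionProvider` hunk later in this diff overrides it. A minimal sketch of that pattern follows, assuming simplified stand-in types: `FakeOrtEp`, `ExecutionProviderBase`, `PluginProvider` and `IsPluginEp` are hypothetical names for illustration and are not part of ONNX Runtime.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical stand-in for the opaque OrtEp handle from the ORT C API.
struct FakeOrtEp {
  int id;
};

// Simplified stand-in for IExecutionProvider.
class ExecutionProviderBase {
 public:
  virtual ~ExecutionProviderBase() = default;

  // Default: this provider does not wrap a plugin EP, so there is nothing to expose.
  virtual const FakeOrtEp* GetOrtEp() const { return nullptr; }
};

// Simplified stand-in for PluginExecutionProvider, which owns the plugin EP instance.
class PluginProvider : public ExecutionProviderBase {
 public:
  explicit PluginProvider(std::unique_ptr<FakeOrtEp> ep) : ort_ep_(std::move(ep)) {
    assert(ort_ep_ != nullptr);
  }

  // Returns a non-owning pointer; the provider keeps ownership.
  const FakeOrtEp* GetOrtEp() const override { return ort_ep_.get(); }

 private:
  std::unique_ptr<FakeOrtEp> ort_ep_;
};

// Callers branch on nullptr instead of needing to know the concrete provider class.
inline bool IsPluginEp(const ExecutionProviderBase& ep) {
  return ep.GetOrtEp() != nullptr;
}
```

With this shape, code that only holds a base-class reference can find out whether a plugin EP is behind it without a `dynamic_cast`.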

4 changes: 4 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -2779,6 +2779,10 @@ struct KernelInfoImpl : Base<T> {
  Logger GetLogger() const;

  KeyValuePairs GetConfigEntries() const;

  std::string GetOperatorType() const;  ///< Wraps KernelInfo_GetOperatorType
  int GetSinceVersion() const;          ///< Wraps KernelInfo_GetSinceVersion
  const OrtEp* GetEp() const;           ///< Wraps KernelInfo_GetEp
};

}  // namespace detail
21 changes: 21 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -2842,6 +2842,27 @@ inline KeyValuePairs KernelInfoImpl<T>::GetConfigEntries() const {
  return KeyValuePairs{out};
}

template <typename T>
inline std::string KernelInfoImpl<T>::GetOperatorType() const {
  const char* op_type = nullptr;
  Ort::ThrowOnError(GetEpApi().KernelInfo_GetOperatorType(this->p_, &op_type));
  return std::string{op_type};
}

template <typename T>
inline int KernelInfoImpl<T>::GetSinceVersion() const {
  int out = 0;
  ThrowOnError(GetEpApi().KernelInfo_GetSinceVersion(this->p_, &out));
  return out;
}

template <typename T>
inline const OrtEp* KernelInfoImpl<T>::GetEp() const {
  const OrtEp* ep = nullptr;
  ThrowOnError(GetEpApi().KernelInfo_GetEp(this->p_, &ep));
  return ep;
}

inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
  Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
}
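
The three wrappers above follow one pattern: call a C function that reports failure through a returned status and delivers its result through an output parameter, then convert a non-null status into an exception. A minimal sketch of that pattern is below. It does not use the ORT headers; `FakeStatus`, `FakeKernelInfo`, `FakeKernelInfo_GetOperatorType` and `ThrowOnFakeError` are hypothetical stand-ins for illustration only.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins; the real status and kernel-info types come from the ORT headers.
struct FakeStatus {
  std::string message;
};

struct FakeKernelInfo {
  std::string op_type;
  int since_version;
};

// C-style getter: returns nullptr on success, or a heap-allocated status on failure.
inline FakeStatus* FakeKernelInfo_GetOperatorType(const FakeKernelInfo* info, const char** out) {
  if (out == nullptr) {
    return new FakeStatus{"output parameter must not be null"};
  }
  *out = info->op_type.c_str();  // borrowed pointer, valid only while `info` is alive
  return nullptr;
}

// Plays the role that ThrowOnError plays in the wrappers: non-null status becomes an exception.
inline void ThrowOnFakeError(FakeStatus* status) {
  if (status != nullptr) {
    std::string msg = status->message;
    delete status;  // release the status before throwing
    throw std::runtime_error(msg);
  }
}

// C++ convenience wrapper: copies the borrowed C string into an owning std::string.
inline std::string GetOperatorType(const FakeKernelInfo& info) {
  const char* op_type = nullptr;
  ThrowOnFakeError(FakeKernelInfo_GetOperatorType(&info, &op_type));
  assert(op_type != nullptr);
  return std::string{op_type};
}
```

Returning `std::string` by value means the caller does not have to reason about how long the borrowed `const char*` stays valid.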
36 changes: 36 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
@@ -846,6 +846,42 @@ struct OrtEpApi {
   */
  ORT_API2_STATUS(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* graph_support_info,
                  _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtKernelDef** out_kernel_def);

  /** \brief Get the graph node operator type from OrtKernelInfo.
   *
   * \note Used within OrtKernelImpl implementations to obtain operator information.
   *
   * \param[in] info An instance of ::OrtKernelInfo.
   * \param[out] operator_type Output parameter set to the name of the node's operator type.
   *
   * \snippet{doc} snippets.dox OrtStatus Return Value
   * \since Version 1.24
   */
  ORT_API2_STATUS(KernelInfo_GetOperatorType, _In_ const OrtKernelInfo* info, _Outptr_ const char** operator_type);

  /** \brief Get from OrtKernelInfo the opset version in which the node's operator type was first defined.
   *
   * \note Used within OrtKernelImpl implementations to obtain operator information.
   *
   * \param[in] info The ::OrtKernelInfo instance.
   * \param[out] since_version The opset version in which the node's operator type was first defined.
   *
   * \snippet{doc} snippets.dox OrtStatus Return Value
   * \since Version 1.24
   */
  ORT_API2_STATUS(KernelInfo_GetSinceVersion, _In_ const OrtKernelInfo* info, _Out_ int* since_version);

  /** \brief Get from OrtKernelInfo the OrtEp instance to which the node is assigned.
   *
   * \note Used within OrtKernelImpl implementations to obtain a reference to the OrtEp.
   *
   * \param[in] info The ::OrtKernelInfo instance.
   * \param[out] ep Output parameter set to the OrtEp instance associated with the OrtKernelInfo.
   *
   * \snippet{doc} snippets.dox OrtStatus Return Value
   * \since Version 1.24
   */
  ORT_API2_STATUS(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ const OrtEp** ep);
};

/**
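
One likely use of `KernelInfo_GetEp` is to let a kernel reach EP-wide state. The sketch below assumes the plugin EP class derives from the C struct, so the base pointer the kernel receives can be cast back to the concrete type. That derivation is an assumption of this sketch, and `FakeOrtEp`, `ExampleEp`, `FakeGetEp` and `AsExampleEp` are hypothetical names that do not come from this PR.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for the C API's OrtEp struct.
struct FakeOrtEp {
  unsigned ort_version_supported;
};

// Assumed layout: the plugin's EP class derives from the C struct.
struct ExampleEp : FakeOrtEp {
  std::string name;
  float scale;  // an EP-wide setting that kernels may want to read
};

// Stand-in for KernelInfo_GetEp: the kernel only ever sees the base pointer.
inline const FakeOrtEp* FakeGetEp(const ExampleEp& ep) {
  return &ep;
}

// Kernel-side helper that recovers the concrete EP type.
// This is only safe if the kernel is created exclusively for ExampleEp instances.
inline const ExampleEp& AsExampleEp(const FakeOrtEp* ep) {
  assert(ep != nullptr);
  return *static_cast<const ExampleEp*>(ep);
}
```

The `static_cast` is unchecked, so a kernel that could be created for more than one EP type would need some other way to identify the concrete class first.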
56 changes: 56 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_api.cc
@@ -582,6 +582,59 @@ ORT_API_STATUS_IMPL(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo*
  API_IMPL_END
}

ORT_API_STATUS_IMPL(KernelInfo_GetOperatorType, _In_ const OrtKernelInfo* info, _Outptr_ const char** operator_type) {
  API_IMPL_BEGIN
  if (operator_type == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Must specify a non-null output parameter for the operator type");
  }

  auto* op_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
  *operator_type = op_info->node().OpType().c_str();
  return nullptr;
  API_IMPL_END
}

ORT_API_STATUS_IMPL(KernelInfo_GetSinceVersion, _In_ const OrtKernelInfo* info, _Out_ int* since_version) {
  API_IMPL_BEGIN
  if (since_version == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Must specify a non-null output parameter for the since version");
  }

  auto* op_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
  *since_version = op_info->node().SinceVersion();
  return nullptr;
  API_IMPL_END
}

ORT_API_STATUS_IMPL(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ const OrtEp** ep) {
  API_IMPL_BEGIN
  if (ep == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Must specify a non-null output parameter in which to store the OrtEp instance");
  }

  auto* op_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
  auto internal_ep = op_info->GetExecutionProvider();

  if (internal_ep == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "OrtKernelInfo does not have a valid reference to an execution provider instance");
  }

  const OrtEp* ort_ep = internal_ep->GetOrtEp();

  if (ort_ep == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "OrtKernelInfo is not associated with a plugin EP (OrtEp) instance.");
  }

  *ep = ort_ep;
  return nullptr;
  API_IMPL_END
}

static constexpr OrtEpApi ort_ep_api = {
    // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end,
    // and no functions can be removed (the implementation needs to change to return an error).
@@ -636,6 +689,9 @@ static constexpr OrtEpApi ort_ep_api = {
    &OrtExecutionProviderApi::KernelDef_GetOutputMemType,
    &OrtExecutionProviderApi::GetTensorDataType,
    &OrtExecutionProviderApi::EpGraphSupportInfo_LookUpKernel,
    &OrtExecutionProviderApi::KernelInfo_GetOperatorType,
    &OrtExecutionProviderApi::KernelInfo_GetSinceVersion,
    &OrtExecutionProviderApi::KernelInfo_GetEp,
};

// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned
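
The NOTE in this table says new functions must be appended and existing ones must keep their slots. One common way to enforce such a rule at compile time is to assert each member's slot index from its offset. The sketch below shows the idea on a hypothetical three-entry table; `FakeEpApi`, `Fn`, the `*Impl` functions and `fake_ep_api` are illustrative only, and the checks actually used in this file are not shown in the diff.

```cpp
#include <cstddef>

using Fn = int (*)(int);

// Hypothetical miniature of a versioned C API table; the real OrtEpApi has far more entries.
struct FakeEpApi {
  Fn GetTensorDataType;           // shipped earlier: slot 0
  Fn LookUpKernel;                // shipped earlier: slot 1
  Fn KernelInfo_GetSinceVersion;  // new entry: appended as slot 2
};

inline int GetTensorDataTypeImpl(int x) { return x; }
inline int LookUpKernelImpl(int x) { return x + 1; }
inline int GetSinceVersionImpl(int x) { return x + 2; }

// Initializer order must match the declaration order of the struct members.
static constexpr FakeEpApi fake_ep_api = {
    &GetTensorDataTypeImpl,
    &LookUpKernelImpl,
    &GetSinceVersionImpl,
};

// Compile-time guards: reordering or removing a member breaks the build
// instead of silently shifting slots for already-compiled plugins.
static_assert(offsetof(FakeEpApi, GetTensorDataType) / sizeof(Fn) == 0, "slot 0 must not move");
static_assert(offsetof(FakeEpApi, LookUpKernel) / sizeof(Fn) == 1, "slot 1 must not move");
static_assert(offsetof(FakeEpApi, KernelInfo_GetSinceVersion) / sizeof(Fn) == 2, "new entries go at the end");
```

A plugin compiled against the two-entry version of such a table would still find both of its functions at the same offsets after the third entry is appended.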
5 changes: 5 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_api.h
@@ -100,4 +100,9 @@ ORT_API_STATUS_IMPL(GetTensorDataType, _In_ ONNXTensorElementDataType elem_type,
                    _Outptr_ const OrtDataType** out);
ORT_API_STATUS_IMPL(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* graph_support_info,
                    _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtKernelDef** out_kernel_def);

// KernelInfo
ORT_API_STATUS_IMPL(KernelInfo_GetOperatorType, _In_ const OrtKernelInfo* info, _Outptr_ const char** operator_type);
ORT_API_STATUS_IMPL(KernelInfo_GetSinceVersion, _In_ const OrtKernelInfo* info, _Out_ int* since_version);
ORT_API_STATUS_IMPL(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ const OrtEp** ep);
}  // namespace OrtExecutionProviderApi
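
The commit list mentions an example BinaryOp kernel that handles Mul and Sub. A rough sketch of how one kernel class might use the operator type and since-version reported by these functions is below. The values are passed in as plain arguments rather than read from a real `OrtKernelInfo`, and `BinaryOp`, `ParseBinaryOp` and `BinaryOpKernel` are hypothetical names, not the example EP's actual code.

```cpp
#include <stdexcept>
#include <string>

enum class BinaryOp { kMul, kSub };

// Map the operator type string to an enum once, at kernel creation time.
inline BinaryOp ParseBinaryOp(const std::string& op_type) {
  if (op_type == "Mul") return BinaryOp::kMul;
  if (op_type == "Sub") return BinaryOp::kSub;
  throw std::invalid_argument("unsupported operator type: " + op_type);
}

// One kernel class serving two operator types.
class BinaryOpKernel {
 public:
  // op_type and since_version stand in for what KernelInfo_GetOperatorType and
  // KernelInfo_GetSinceVersion would report for the node.
  BinaryOpKernel(const std::string& op_type, int since_version)
      : op_(ParseBinaryOp(op_type)), since_version_(since_version) {}

  // Element-wise compute on scalars, to keep the sketch small.
  float Compute(float a, float b) const {
    return op_ == BinaryOp::kMul ? a * b : a - b;
  }

  int since_version() const { return since_version_; }

 private:
  BinaryOp op_;
  int since_version_;
};
```

Parsing the string once in the constructor keeps string comparisons out of the per-invocation compute path.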
Original file line number Diff line number Diff line change
@@ -765,4 +765,8 @@ Status PluginExecutionProvider::ValidateCompiledModelCompatibilityInfo(const std
  return Status::OK();
}

const OrtEp* PluginExecutionProvider::GetOrtEp() const {
  return ort_ep_.get();
}

}  // namespace onnxruntime
Original file line number Diff line number Diff line change
@@ -131,6 +131,8 @@ class PluginExecutionProvider : public IExecutionProvider {
  Status ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info,
                                                OrtCompiledModelCompatibility& model_compatibility) const override;

  const OrtEp* GetOrtEp() const override;

 private:
  struct FusedNodeState {
    FusedNodeState() = default;