diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 4d202af13165b..01c7fa3ac1126 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -2133,8 +2133,8 @@ if (onnxruntime_BUILD_SHARED_LIB AND "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc" - "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.h" - "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.cc") + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.cc") onnxruntime_add_shared_library_module(example_plugin_ep_kernel_registry ${onnxruntime_autoep_test_example_plugin_ep_kernel_registry_src}) target_include_directories(example_plugin_ep_kernel_registry PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session) target_link_libraries(example_plugin_ep_kernel_registry PRIVATE onnxruntime ${GSL_TARGET}) diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index f54f4a5a6f1ef..923174cbfe488 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -416,6 +416,15 @@ class IExecutionProvider { return InlinedVector(); } + /** + * Returns the underlying OrtEp instance if this IExecutionProvider wraps a plugin EP. + * Otherwise, returns a nullptr (default implementation). + * This is used to retrieve the OrtEp instance from a OrtKernelInfo instance in a plugin EP's kernel implementation. + */ + virtual const OrtEp* GetOrtEp() const { + return nullptr; + } + private: const std::string type_; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 68899d75e9294..3c85491837793 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4204,7 +4204,7 @@ struct OrtApi { * If `out` is nullptr, the value of `size` is set to the size of the name * string (including null-terminator), and a success status is returned. * - * If the `size` parameter is greater than or equal to the name string's size, + * If the `size` parameter is greater than or equal to the name string's size and `out` is not nullptr, * the value of `size` is set to the true size of the string (including null-terminator), * the provided memory is filled with the string's contents, and a success status is returned. * @@ -4220,7 +4220,7 @@ struct OrtApi { * \snippet{doc} snippets.dox OrtStatus Return Value * \since Version 1.14 */ - ORT_API2_STATUS(KernelInfo_GetInputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, + ORT_API2_STATUS(KernelInfo_GetInputName, _In_ const OrtKernelInfo* info, size_t index, _Out_opt_ char* out, _Inout_ size_t* size); /** \brief Get the name of a ::OrtKernelInfo's output. @@ -4231,7 +4231,7 @@ struct OrtApi { * If `out` is nullptr, the value of `size` is set to the size of the name * string (including null-terminator), and a success status is returned. * - * If the `size` parameter is greater than or equal to the name string's size, + * If the `size` parameter is greater than or equal to the name string's size and `out` is not nullptr, * the value of `size` is set to the true size of the string (including null-terminator), * the provided memory is filled with the string's contents, and a success status is returned. * @@ -4248,7 +4248,7 @@ struct OrtApi { * \snippet{doc} snippets.dox OrtStatus Return Value * \since Version 1.14 */ - ORT_API2_STATUS(KernelInfo_GetOutputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, + ORT_API2_STATUS(KernelInfo_GetOutputName, _In_ const OrtKernelInfo* info, size_t index, _Out_opt_ char* out, _Inout_ size_t* size); /** \brief Get the type information for a ::OrtKernelInfo's input. @@ -4428,7 +4428,7 @@ struct OrtApi { * If `out` is nullptr, the value of `size` is set to the size of the name * string (including null-terminator), and a success status is returned. * - * If the `size` parameter is greater than or equal to the name string's size, + * If the `size` parameter is greater than or equal to the name string's size and `out` is not nullptr, * the value of `size` is set to the true size of the string (including null-terminator), * the provided memory is filled with the string's contents, and a success status is returned. * @@ -4445,7 +4445,7 @@ struct OrtApi { * \snippet{doc} snippets.dox OrtStatus Return Value * \since Version 1.15 */ - ORT_API2_STATUS(KernelInfo_GetNodeName, _In_ const OrtKernelInfo* info, _Out_ char* out, _Inout_ size_t* size); + ORT_API2_STATUS(KernelInfo_GetNodeName, _In_ const OrtKernelInfo* info, _Out_opt_ char* out, _Inout_ size_t* size); /** \brief Get the session logger from ::OrtKernelInfo. * @@ -6608,6 +6608,65 @@ struct OrtApi { * \since Version 1.24 */ ORT_API2_STATUS(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out); + + /** \brief Get the graph node's operator domain from ::OrtKernelInfo. + * + * If `out` is nullptr, the value of `size` is set to the size of the operator domain + * string (including null-terminator), and a success status is returned. + * + * If the `size` parameter is greater than or equal to the string's size and `out` is not nullptr, + * the value of `size` is set to the true size of the string (including null-terminator), + * the provided memory is filled with the string's contents, and a success status is returned. + * + * If the `size` parameter is less than the actual string's size and `out` + * is not nullptr, the value of `size` is set to the true size of the string + * and a failure status with error code ORT_INVALID_ARGUMENT is returned. + * + * \param[in] info An instance of ::OrtKernelInfo. + * \param[out] out Memory location into which to write the UTF-8 null-terminated string representing the + * operator domain. + * \param[in,out] size Pointer to the size of the `out` buffer. See above comments for details. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.24 + */ + ORT_API2_STATUS(KernelInfo_GetOperatorDomain, _In_ const OrtKernelInfo* info, _Out_opt_ char* out, + _Inout_ size_t* size); + + /** \brief Get the graph node's operator type from ::OrtKernelInfo. + * + * If `out` is nullptr, the value of `size` is set to the size of the operator type + * string (including null-terminator), and a success status is returned. + * + * If the `size` parameter is greater than or equal to the string's size and `out` is not nullptr, + * the value of `size` is set to the true size of the string (including null-terminator), + * the provided memory is filled with the string's contents, and a success status is returned. + * + * If the `size` parameter is less than the actual string's size and `out` + * is not nullptr, the value of `size` is set to the true size of the string + * and a failure status with error code ORT_INVALID_ARGUMENT is returned. + * + * \param[in] info An instance of ::OrtKernelInfo. + * \param[out] out Memory location into which to write the UTF-8 null-terminated string representing the + * operator type. + * \param[in,out] size Pointer to the size of the `out` buffer. See above comments for details. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.24 + */ + ORT_API2_STATUS(KernelInfo_GetOperatorType, _In_ const OrtKernelInfo* info, _Out_opt_ char* out, + _Inout_ size_t* size); + + /** \brief Get the opset version in which the given node's operator type was first defined from ::OrtKernelInfo. + * + * \param[in] info An instance of ::OrtKernelInfo. + * \param[out] since_version The opset version in which the node's operator type was first defined. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.24 + */ + ORT_API2_STATUS(KernelInfo_GetOperatorSinceVersion, _In_ const OrtKernelInfo* info, + _Out_ int* since_version); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index fd4d9a683b7cd..d98757b2379a8 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2782,6 +2782,11 @@ struct KernelInfoImpl : Base { Logger GetLogger() const; KeyValuePairs GetConfigEntries() const; + + std::string GetOperatorDomain() const; ///< Wraps OrtApi::KernelInfo_GetOperatorDomain + std::string GetOperatorType() const; ///< Wraps OrtApi::KernelInfo_GetOperatorType + int GetOperatorSinceVersion() const; ///< Wraps OrtApi::KernelInfo_GetOperatorSinceVersion + const OrtEp* GetEp() const; ///< Wraps OrtEpApi::KernelInfo_GetEp }; } // namespace detail diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 63dfc85560a39..b7e1156f38a34 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -2842,6 +2842,50 @@ inline KeyValuePairs KernelInfoImpl::GetConfigEntries() const { return KeyValuePairs{out}; } +template +inline std::string KernelInfoImpl::GetOperatorDomain() const { + size_t size = 0; + + // Feed nullptr for the data buffer to query the true size of the string value + Ort::ThrowOnError(GetApi().KernelInfo_GetOperatorDomain(this->p_, nullptr, &size)); + + std::string out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfo_GetOperatorDomain(this->p_, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' + + return out; +} + +template +inline std::string KernelInfoImpl::GetOperatorType() const { + size_t size = 0; + + // Feed nullptr for the data buffer to query the true size of the string value + Ort::ThrowOnError(GetApi().KernelInfo_GetOperatorType(this->p_, nullptr, &size)); + + std::string out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfo_GetOperatorType(this->p_, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' + + return out; +} + +template +inline int KernelInfoImpl::GetOperatorSinceVersion() const { + int out = 0; + Ort::ThrowOnError(GetApi().KernelInfo_GetOperatorSinceVersion(this->p_, &out)); + return out; +} + +template +inline const OrtEp* KernelInfoImpl::GetEp() const { + const OrtEp* ep = nullptr; + Ort::ThrowOnError(GetEpApi().KernelInfo_GetEp(this->p_, &ep)); + return ep; +} + inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) { Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out)); } diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 85981d875fd70..bd7f63f53ed8c 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -971,6 +971,18 @@ struct OrtEpApi { _In_ OrtSharedPrePackedWeightCache* prepacked_weight_cache, _In_reads_(num_buffers) void** buffer_data_ptrs, _In_reads_(num_buffers) size_t* buffer_data_sizes, _In_ size_t num_buffers); + + /** \brief Get the OrtEp instance to which the node is assigned from the OrtKernelInfo. + * + * \note Used within OrtKernelImpl implementations to obtain a reference to the OrtEp. + * + * \param[in] info The ::OrtKernelInfo instance. + * \param[out] ep Output parameter set to the OrtEp instance associated with the OrtKernelInfo. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.24 + */ + ORT_API2_STATUS(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ const OrtEp** ep); }; /** diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 612497ccfd845..bafa1b88599d0 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -27,11 +27,12 @@ #include "core/session/ort_apis.h" #include "core/platform/threadpool.h" -// NOTE: OrtKernelContext is used by both custom ops and compiled kernels. -// In a minimal build, ORT_EXTENDED_MINIMAL_BUILD is used to enable EPs like CoreML/NNAPI which use compiled kernels, -// and ORT_MINIMAL_BUILD_CUSTOM_OPS is used to allow external custom op libraries to be used. +// NOTE: OrtKernelContext/OrtKernelInfo are used by both custom ops and kernels for both plugin and provider-bridge EPs. +// In a minimal build, ORT_EXTENDED_MINIMAL_BUILD is used to enable EPs like CoreML/NNAPI which use compiled kernels. +// ORT_MINIMAL_BUILD_CUSTOM_OPS is used to allow external custom op libraries to be used. +// Plugin EPs (with registered and compiled kernels) are enabled in non-minimal builds. #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) -#define ENABLE_ORT_KERNEL_CONTEXT_API 1 +#define ENABLE_ORT_KERNEL_API 1 #endif #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) @@ -143,29 +144,30 @@ struct OrtShapeInferContext { }; #endif -#if ENABLE_ORT_KERNEL_CONTEXT_API +#if ENABLE_ORT_KERNEL_API template -static OrtStatusPtr ExecuteIfKernelContextApiEnabled(const T& fn) { +static OrtStatusPtr ExecuteIfKernelApiEnabled(const T& fn) { API_IMPL_BEGIN return fn(); API_IMPL_END } #else template -static OrtStatusPtr ExecuteIfKernelContextApiEnabled(const T&) { - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "OrtKernelContext API is not enabled in this build"); +static OrtStatusPtr ExecuteIfKernelApiEnabled(const T&) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, + "APIs for OrtKernelContext and OrtKernelInfo are not enabled in this build"); } #endif ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetInputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out) { - return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { *out = reinterpret_cast(context)->InputCount(); return nullptr; }); }; ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out) { - return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { *out = reinterpret_cast(context)->OutputCount(); return nullptr; }); @@ -173,7 +175,7 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetOutputCount, _In_ const OrtKernelC ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, _Out_ const OrtValue** out) { - return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { const auto* ctx = reinterpret_cast(context); *out = reinterpret_cast(ctx->GetInputMLValue(onnxruntime::narrow(index))); return nullptr; @@ -182,7 +184,7 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetInput, _In_ const OrtKernelContext ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count, _Out_ OrtValue** out) { - return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { onnxruntime::TensorShape shape(dim_values, dim_count); auto* ctx = reinterpret_cast(context); *out = reinterpret_cast(ctx->OutputMLValue(onnxruntime::narrow(index), shape)); @@ -192,7 +194,7 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetOutput, _Inout_ OrtKernelContext* ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out) { - return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { auto* stream = reinterpret_cast(context)->GetComputeStream(); if (stream) *out = stream->GetHandle(); @@ -204,7 +206,7 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetGPUComputeStream, _In_ const OrtKe ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetAllocator, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _Outptr_ OrtAllocator** out) { - return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { const auto* ctx = reinterpret_cast(context); onnxruntime::AllocatorPtr allocator = ctx->GetAllocator(mem_info->device); if (!allocator) { @@ -219,7 +221,7 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetAllocator, _In_ const OrtKernelCon ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resource_version, _In_ int resource_id, _Outptr_ void** resource) { - return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { *resource = {}; const auto* ctx = reinterpret_cast(context); auto* stream = reinterpret_cast(ctx->GetComputeStream()); @@ -233,7 +235,7 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetResource, _In_ const OrtKernelCont ORT_API_STATUS_IMPL(OrtApis::KernelContext_ParallelFor, _In_ const OrtKernelContext* context, _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch, _In_ void* usr_data) { - return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { if (!context) { return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, "Invalid context"); } @@ -259,7 +261,7 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_ParallelFor, _In_ const OrtKernelCont ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetLogger, _In_ const OrtKernelContext* context, _Outptr_ const OrtLogger** logger) { - return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { const auto& kernel_ctx_logger = reinterpret_cast(context)->Logger(); *logger = kernel_ctx_logger.ToExternal(); @@ -267,11 +269,11 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetLogger, _In_ const OrtKernelContex }); } -// Enabled via ExecuteIfKernelContextApiEnabled due to KernelContext_GetLogger +// Enabled via ExecuteIfKernelApiEnabled due to KernelContext_GetLogger ORT_API_STATUS_IMPL(OrtApis::Logger_LogMessage, _In_ const OrtLogger* logger, OrtLoggingLevel log_severity_level, _In_z_ const char* message, _In_z_ const ORTCHAR_T* file_path, int line_number, _In_z_ const char* func_name) { - return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { const auto& actual_logger = *reinterpret_cast(logger); const auto severity = static_cast(log_severity_level); const auto log_data_type = onnxruntime::logging::DataType::SYSTEM; @@ -298,10 +300,10 @@ ORT_API_STATUS_IMPL(OrtApis::Logger_LogMessage, _In_ const OrtLogger* logger, Or }); } -// Enabled via ExecuteIfKernelContextApiEnabled due to KernelContext_GetLogger +// Enabled via ExecuteIfKernelApiEnabled due to KernelContext_GetLogger ORT_API_STATUS_IMPL(OrtApis::Logger_GetLoggingSeverityLevel, _In_ const OrtLogger* logger, _Out_ OrtLoggingLevel* out) { - return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { const auto& actual_logger = *reinterpret_cast(logger); *out = static_cast(actual_logger.GetSeverity()); return nullptr; @@ -486,7 +488,7 @@ ORT_API_STATUS_IMPL(OrtApis::ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ Ort ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) { - return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { auto status = reinterpret_cast(info)->GetAttr(name, out); if (status.IsOK()) return nullptr; @@ -496,7 +498,7 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_float, _In_ const OrtKernelI ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out) { - return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { auto status = reinterpret_cast(info)->GetAttr(name, out); if (status.IsOK()) return nullptr; @@ -506,7 +508,7 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_int64, _In_ const OrtKernelI ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ char* out, _Inout_ size_t* size) { - return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { std::string value; auto status = reinterpret_cast(info)->GetAttr(name, &value); if (status.IsOK()) { @@ -545,7 +547,7 @@ static Status CopyDataFromVectorToMemory(const std::vector& values, T* out, s ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out, _Inout_ size_t* size) { - return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { std::vector values; auto status = reinterpret_cast(info)->GetAttrs(name, values); if (status.IsOK()) { @@ -557,7 +559,7 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_float, _In_ const OrtKe ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out, _Inout_ size_t* size) { - return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { std::vector values; auto status = reinterpret_cast(info)->GetAttrs(name, values); if (status.IsOK()) { @@ -569,7 +571,7 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_int64, _In_ const OrtKe ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name, _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out) { - return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { const auto* op_kinfo = reinterpret_cast(info); // Get TensorProto attribute @@ -609,22 +611,22 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_tensor, _In_ const OrtKernel } ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out) { - return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { *out = reinterpret_cast(info)->GetInputCount(); return nullptr; }); }; ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOutputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out) { - return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { *out = reinterpret_cast(info)->GetOutputCount(); return nullptr; }); }; ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputName, _In_ const OrtKernelInfo* info, size_t index, - _Out_ char* out, _Inout_ size_t* size) { - return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + _Out_opt_ char* out, _Inout_ size_t* size) { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { const auto* op_info = reinterpret_cast(info); const auto input_defs = op_info->node().InputDefs(); @@ -639,9 +641,9 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputName, _In_ const OrtKernelInfo* }); } -ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOutputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, - _Inout_ size_t* size) { - return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOutputName, _In_ const OrtKernelInfo* info, size_t index, + _Out_opt_ char* out, _Inout_ size_t* size) { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { const auto* op_info = reinterpret_cast(info); const auto output_defs = op_info->node().OutputDefs(); @@ -659,7 +661,7 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOutputName, _In_ const OrtKernelInfo* ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputTypeInfo, _In_ const OrtKernelInfo* info, size_t index, _Outptr_ OrtTypeInfo** type_info) { - return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { const auto* op_info = reinterpret_cast(info); const auto input_defs = op_info->node().InputDefs(); @@ -682,7 +684,7 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputTypeInfo, _In_ const OrtKernelIn ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOutputTypeInfo, _In_ const OrtKernelInfo* info, size_t index, _Outptr_ OrtTypeInfo** type_info) { - return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { const auto* op_info = reinterpret_cast(info); const auto output_defs = op_info->node().OutputDefs(); @@ -705,16 +707,16 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOutputTypeInfo, _In_ const OrtKernelI ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetConstantInput_tensor, _In_ const OrtKernelInfo* info, _In_ size_t index, _Out_ int* is_constant, _Outptr_ const OrtValue** out) { - return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { const auto* op_info = reinterpret_cast(info); *is_constant = static_cast(op_info->TryGetConstantInput(gsl::narrow_cast(index), out)); return nullptr; }); }; -ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetNodeName, _In_ const OrtKernelInfo* info, _Out_ char* out, +ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetNodeName, _In_ const OrtKernelInfo* info, _Out_opt_ char* out, _Inout_ size_t* size) { - return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { const auto* op_info = reinterpret_cast(info); auto status = CopyStringToOutputArg(op_info->node().Name(), @@ -724,8 +726,49 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetNodeName, _In_ const OrtKernelInfo* i }); } +ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOperatorDomain, _In_ const OrtKernelInfo* info, _Out_opt_ char* out, + _Inout_ size_t* size) { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { + const auto* op_info = reinterpret_cast(info); + + auto status = CopyStringToOutputArg(op_info->node().Domain(), + "Output buffer is not large enough for ::OrtKernelInfo's operator domain", + out, size); + + return onnxruntime::ToOrtStatus(status); + }); +} + +ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOperatorType, _In_ const OrtKernelInfo* info, _Out_opt_ char* out, + _Inout_ size_t* size) { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { + const auto* op_info = reinterpret_cast(info); + + auto status = CopyStringToOutputArg(op_info->node().OpType(), + "Output buffer is not large enough for ::OrtKernelInfo's operator type", + out, size); + + return onnxruntime::ToOrtStatus(status); + }); +} + +ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOperatorSinceVersion, _In_ const OrtKernelInfo* info, + _Out_ int* since_version) { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { + const auto* op_info = reinterpret_cast(info); + + if (since_version == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Output parameter for ::OrtKernelInfo's since version is NULL"); + } + + *since_version = op_info->node().SinceVersion(); + return nullptr; + }); +} + ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetLogger, _In_ const OrtKernelInfo* info, _Outptr_ const OrtLogger** logger) { - return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { const auto* ep = reinterpret_cast(info)->GetExecutionProvider(); if (ep == nullptr) { @@ -746,7 +789,7 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetLogger, _In_ const OrtKernelInfo* inf } ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out) { - return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { onnxruntime::AllocatorPtr allocator = reinterpret_cast(info)->GetAllocator(mem_type); if (!allocator) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available"); @@ -758,7 +801,7 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAllocator, _In_ const OrtKernelInfo* i } ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out) { - return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { const auto* op_info = reinterpret_cast(info); const auto& config_options_map = op_info->GetConfigOptions().GetConfigOptionsMap(); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 82f7cef4aec49..b1c2e07b9ffb7 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4238,6 +4238,9 @@ static constexpr OrtApi ort_api_1_to_24 = { &OrtApis::TensorTypeAndShape_HasShape, &OrtApis::KernelInfo_GetConfigEntries, + &OrtApis::KernelInfo_GetOperatorDomain, + &OrtApis::KernelInfo_GetOperatorType, + &OrtApis::KernelInfo_GetOperatorSinceVersion, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index f3525d8de7b95..ccdfa53e1b225 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -428,9 +428,9 @@ ORT_API_STATUS_IMPL(RegisterCustomOpsUsingFunction, _Inout_ OrtSessionOptions* o ORT_API_STATUS_IMPL(KernelInfo_GetInputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out); ORT_API_STATUS_IMPL(KernelInfo_GetOutputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out); -ORT_API_STATUS_IMPL(KernelInfo_GetInputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, +ORT_API_STATUS_IMPL(KernelInfo_GetInputName, _In_ const OrtKernelInfo* info, size_t index, _Out_opt_ char* out, _Inout_ size_t* size); -ORT_API_STATUS_IMPL(KernelInfo_GetOutputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, +ORT_API_STATUS_IMPL(KernelInfo_GetOutputName, _In_ const OrtKernelInfo* info, size_t index, _Out_opt_ char* out, _Inout_ size_t* size); ORT_API_STATUS_IMPL(KernelInfo_GetInputTypeInfo, _In_ const OrtKernelInfo* info, size_t index, _Outptr_ OrtTypeInfo** type_info); @@ -455,7 +455,7 @@ ORT_API_STATUS_IMPL(GetDnnlProviderOptionsAsString, _In_ const OrtDnnlProviderOp _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); ORT_API(void, ReleaseDnnlProviderOptions, _Frees_ptr_opt_ OrtDnnlProviderOptions*); -ORT_API_STATUS_IMPL(KernelInfo_GetNodeName, _In_ const OrtKernelInfo* info, _Out_ char* out, _Inout_ size_t* size); +ORT_API_STATUS_IMPL(KernelInfo_GetNodeName, _In_ const OrtKernelInfo* info, _Out_opt_ char* out, _Inout_ size_t* size); ORT_API_STATUS_IMPL(KernelInfo_GetLogger, _In_ const OrtKernelInfo* info, _Outptr_ const OrtLogger** logger); ORT_API_STATUS_IMPL(KernelContext_GetLogger, _In_ const OrtKernelContext* context, _Outptr_ const OrtLogger** logger); @@ -754,5 +754,11 @@ ORT_API_STATUS_IMPL(CopyTensors, _In_ const OrtEnv* env, _In_ size_t num_tensors); ORT_API_STATUS_IMPL(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out); +ORT_API_STATUS_IMPL(KernelInfo_GetOperatorDomain, _In_ const OrtKernelInfo* info, _Out_opt_ char* out, + _Inout_ size_t* size); +ORT_API_STATUS_IMPL(KernelInfo_GetOperatorType, _In_ const OrtKernelInfo* info, _Out_opt_ char* out, + _Inout_ size_t* size); +ORT_API_STATUS_IMPL(KernelInfo_GetOperatorSinceVersion, _In_ const OrtKernelInfo* info, + _Out_ int* since_version); } // namespace OrtApis diff --git a/onnxruntime/core/session/plugin_ep/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc index b0059f87da207..47b44b8eeba64 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -623,6 +623,38 @@ ORT_API_STATUS_IMPL(SharedPrePackedWeightCache_StoreWeightData, API_IMPL_END } +ORT_API_STATUS_IMPL(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ const OrtEp** ep) { + API_IMPL_BEGIN + if (info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Must specify a non-null OrtKernelInfo instance from which to obtain an OrtEp"); + } + + if (ep == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Must specify a non-null output parameter in which to store the OrtEp instance"); + } + + auto* op_info = reinterpret_cast(info); + auto internal_ep = op_info->GetExecutionProvider(); + + if (internal_ep == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, + "OrtKernelInfo does not have a valid reference to an execution provider instance"); + } + + const OrtEp* ort_ep = internal_ep->GetOrtEp(); + + if (ort_ep == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, + "OrtKernelInfo is not associated with a plugin EP (OrtEp) instance."); + } + + *ep = ort_ep; + return nullptr; + API_IMPL_END +} + static constexpr OrtEpApi ort_ep_api = { // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, // and no functions can be removed (the implementation needs to change to return an error). @@ -678,6 +710,7 @@ static constexpr OrtEpApi ort_ep_api = { &OrtExecutionProviderApi::GetTensorDataType, &OrtExecutionProviderApi::EpGraphSupportInfo_LookUpKernel, &OrtExecutionProviderApi::SharedPrePackedWeightCache_StoreWeightData, + &OrtExecutionProviderApi::KernelInfo_GetEp, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/plugin_ep/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h index b2abad622c9a6..2d504f5ad2a64 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.h +++ b/onnxruntime/core/session/plugin_ep/ep_api.h @@ -105,4 +105,7 @@ ORT_API_STATUS_IMPL(SharedPrePackedWeightCache_StoreWeightData, _In_ OrtSharedPrePackedWeightCache* prepacked_weight_cache, _In_reads_(num_buffers) void** buffer_data_ptrs, _In_reads_(num_buffers) size_t* buffer_data_sizes, _In_ size_t num_buffers); + +// KernelInfo +ORT_API_STATUS_IMPL(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ const OrtEp** ep); } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 9ada57dca7f34..4db8bb05f94de 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -776,4 +776,8 @@ Status PluginExecutionProvider::ValidateCompiledModelCompatibilityInfo(const std return Status::OK(); } +const OrtEp* PluginExecutionProvider::GetOrtEp() const { + return ort_ep_.get(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h index 5b909df4cac1e..8d94607cdace8 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h @@ -133,6 +133,8 @@ class PluginExecutionProvider : public IExecutionProvider { Status ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info, OrtCompiledModelCompatibility& model_compatibility) const override; + const OrtEp* GetOrtEp() const override; + private: struct FusedNodeState { FusedNodeState() = default; diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 7257b27a48cf3..944e83d8cad66 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -102,6 +102,10 @@ Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, con #endif // !defined(ORT_MINIMAL_BUILD) common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size) { + if (size == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "`size` argument is NULL"); + } + const size_t str_len = str.size(); const size_t req_size = str_len + 1; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc index 7b939c0685237..9ce5dbbe91d75 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc @@ -14,12 +14,13 @@ #include "ep_factory.h" #include "../plugin_ep_utils.h" -ExampleKernelEp::ExampleKernelEp(ExampleKernelEpFactory& factory, const OrtLogger& logger) +ExampleKernelEp::ExampleKernelEp(ExampleKernelEpFactory& factory, const Config& config, const OrtLogger& logger) : OrtEp{}, // explicitly call the struct ctor to ensure all optional values are default initialized factory_{factory}, ort_api_{factory.GetOrtApi()}, ep_api_{factory.GetEpApi()}, name_{factory.GetEpName()}, + config_{config}, logger_{logger} { ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. @@ -67,10 +68,10 @@ OrtStatus* ORT_API_CALL ExampleKernelEp::GetCapabilityImpl(OrtEp* this_ptr, cons if (op_type == "Relu" || op_type == "Squeeze") { candidate_nodes.push_back(node); - } else if (op_type == "Mul") { + } else if (op_type == "Mul" || op_type == "Sub") { std::vector inputs = node.GetInputs(); - // Note: ONNX shape inference should ensure Mul has two inputs. + // Note: ONNX shape inference should ensure Mul/Sub has two inputs. std::optional> input_0_shape = GetTensorShape(inputs[0]); std::optional> input_1_shape = GetTensorShape(inputs[1]); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.h index 35357ddf3f5e2..bb283cee35af8 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.h @@ -14,11 +14,16 @@ class ExampleKernelEpFactory; /// class ExampleKernelEp : public OrtEp { public: - ExampleKernelEp(ExampleKernelEpFactory& factory, const OrtLogger& logger); + struct Config { + bool enable_prepack_weight_sharing = false; + }; + + ExampleKernelEp(ExampleKernelEpFactory& factory, const Config& config, const OrtLogger& logger); ~ExampleKernelEp(); const OrtApi& GetOrtApi() const { return ort_api_; } const OrtEpApi& GetEpApi() const { return ep_api_; } + const Config& GetConfig() const { return config_; } private: static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; @@ -34,5 +39,6 @@ class ExampleKernelEp : public OrtEp { const OrtApi& ort_api_; const OrtEpApi& ep_api_; std::string name_; + Config config_; const OrtLogger& logger_; }; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc index a520b02c20cba..bd0afcd9fc46d 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc @@ -176,7 +176,7 @@ OrtStatus* ORT_API_CALL ExampleKernelEpFactory::CreateEpImpl(OrtEpFactory* this_ const OrtHardwareDevice* const* /*devices*/, const OrtKeyValuePairs* const* /*ep_metadata*/, size_t num_devices, - const OrtSessionOptions* /*session_options*/, + const OrtSessionOptions* session_options, const OrtLogger* logger, OrtEp** ep) noexcept { auto* factory = static_cast(this_ptr); @@ -187,7 +187,14 @@ OrtStatus* ORT_API_CALL ExampleKernelEpFactory::CreateEpImpl(OrtEpFactory* this_ "ExampleKernelEpFactory only supports selection for one device."); } - auto actual_ep = std::make_unique(*factory, *logger); + std::string enable_prepack_weight_sharing; + RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(*session_options, "ep.examplekernelep.enable_prepack_weight_sharing", + "0", enable_prepack_weight_sharing)); + + ExampleKernelEp::Config config = {}; + config.enable_prepack_weight_sharing = enable_prepack_weight_sharing == "1"; + + auto actual_ep = std::make_unique(*factory, config, *logger); *ep = actual_ep.release(); return nullptr; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc index b9518786f3a04..8b8cc35afe1ae 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc @@ -11,6 +11,9 @@ static const BuildKernelCreateInfoFn build_kernel_create_info_funcs[] = { // Mul version 14 BuildKernelCreateInfo, + // Sub version 14 + BuildKernelCreateInfo, + // Relu version 14 BuildKernelCreateInfo, diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.cc similarity index 54% rename from onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.cc rename to onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.cc index 046ee04f37786..fb3bf6cfdb347 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.cc @@ -1,10 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include -#include "mul.h" +#include "binary_op.h" #include "utils.h" +#include "../ep.h" // Defines a kernel creation function for version 14 of Mul. ONNX_OPERATOR_KERNEL_EX( @@ -13,9 +15,18 @@ ONNX_OPERATOR_KERNEL_EX( /*version*/ 14, // Equivalent to start_version: 14, end_version: 14 (inclusive) (Ort::KernelDefBuilder() .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), - Mul) + BinaryOp) -Mul::Mul(const OrtKernelInfo* info, void* state, PrivateTag) +// Defines a kernel creation function for version 14 of Sub. +ONNX_OPERATOR_KERNEL_EX( + Sub, + kOnnxDomain, + /*version*/ 14, // Equivalent to start_version: 14, end_version: 14 (inclusive) + (Ort::KernelDefBuilder() + .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), + BinaryOp) + +BinaryOp::BinaryOp(Ort::ConstKernelInfo info, void* state, PrivateTag) : OrtKernelImpl{}, // Initialize all OrtKernelImpl functions to NULL info_{info}, data_transfer_impl_{reinterpret_cast(state)} { @@ -23,32 +34,44 @@ Mul::Mul(const OrtKernelInfo* info, void* state, PrivateTag) Compute = ComputeImpl; Release = ReleaseImpl; - // Optional functions that are only needed to pre-pack weights. This Mul kernel pre-packs - // input[1] weights as an example (not typically done by an actual implementation of Mul). + // Optional functions that are only needed to pre-pack weights. This BinaryOp kernel pre-packs + // input[1] weights as an example (not typically done by an actual implementations of Mul/Sub). PrePackWeight = PrePackWeightImpl; SetSharedPrePackedWeight = SetSharedPrePackedWeightImpl; } /*static*/ -OrtStatus* Mul::Create(const OrtKernelInfo* info, void* state, - /*out*/ std::unique_ptr& result) noexcept { +OrtStatus* BinaryOp::Create(const OrtKernelInfo* info, void* state, + /*out*/ std::unique_ptr& result) noexcept { EXCEPTION_TO_RETURNED_STATUS_BEGIN + Ort::ConstKernelInfo kernel_info(info); + // Note: can do basic validation or preprocessing via the OrtKernelInfo APIs. - result = std::make_unique(info, state, PrivateTag{}); + // Here, we check that this BinaryOp class is only instantiated for an onnx Mul or Sub operator. + std::string op_domain = kernel_info.GetOperatorDomain(); + std::string op_type = kernel_info.GetOperatorType(); + + if ((!op_domain.empty() && op_domain != "ai.onnx") || (op_type != "Sub" && op_type != "Mul")) { + std::ostringstream oss; + oss << "ExampleKernelEp's BinaryOp class does not support operator with domain '" << op_domain << "' and " + << " type '" << op_type << "'."; + return Ort::GetApi().CreateStatus(ORT_EP_FAIL, oss.str().c_str()); + } + + result = std::make_unique(kernel_info, state, PrivateTag{}); return nullptr; EXCEPTION_TO_RETURNED_STATUS_END } /*static*/ -void ORT_API_CALL Mul::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept { - delete static_cast(this_ptr); +void ORT_API_CALL BinaryOp::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept { + delete static_cast(this_ptr); } /*static*/ -OrtStatus* ORT_API_CALL Mul::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept { +OrtStatus* ORT_API_CALL BinaryOp::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept { EXCEPTION_TO_RETURNED_STATUS_BEGIN - Mul* mul_kernel = static_cast(this_ptr); - static_cast(mul_kernel->info_); // NOTE: Unused in this example. + BinaryOp* binary_op_kernel = static_cast(this_ptr); Ort::KernelContext kernel_context(kernel_ctx); @@ -62,8 +85,8 @@ OrtStatus* ORT_API_CALL Mul::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelConte gsl::span input1; std::vector shape1; - if (mul_kernel->packed_weight_1_info_.has_value()) { - const PackedWeightInfo& packed_weight_info = *mul_kernel->packed_weight_1_info_; + if (binary_op_kernel->packed_weight_1_info_.has_value()) { + const PackedWeightInfo& packed_weight_info = *binary_op_kernel->packed_weight_1_info_; shape1 = packed_weight_info.shape; size_t num_elems = 1; @@ -80,13 +103,22 @@ OrtStatus* ORT_API_CALL Mul::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelConte RETURN_IF_ERROR(GetValueDataAndShape(kernel_context.GetInput(1), input1, shape1)); } - RETURN_IF(shape0 != shape1, Ort::GetApi(), "Mul kernel doesn't support broadcasting."); // Checked by GetCapability + // Equal input shapes is checked by GetCapability, but verify here. + RETURN_IF(shape0 != shape1, Ort::GetApi(), "BinaryOp kernel does not support broadcasting."); Ort::UnownedValue output = kernel_context.GetOutput(0, shape0); float* output_data = output.GetTensorMutableData(); - for (size_t i = 0; i < input0.size(); ++i) { - output_data[i] = input0[i] * input1[i]; + std::string op_type = binary_op_kernel->info_.GetOperatorType(); + if (op_type == "Sub") { + for (size_t i = 0; i < input0.size(); ++i) { + output_data[i] = input0[i] - input1[i]; + } + } else { + assert(op_type == "Mul"); // Checked by BinaryOp::Create + for (size_t i = 0; i < input0.size(); ++i) { + output_data[i] = input0[i] * input1[i]; + } } return nullptr; @@ -94,14 +126,14 @@ OrtStatus* ORT_API_CALL Mul::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelConte } /*static*/ -OrtStatus* ORT_API_CALL Mul::PrePackWeightImpl(OrtKernelImpl* this_ptr, const OrtValue* tensor, - int input_index, OrtAllocator* allocator, - OrtSharedPrePackedWeightCache* prepacked_weight_cache, - /*out*/ bool* is_packed) noexcept { +OrtStatus* ORT_API_CALL BinaryOp::PrePackWeightImpl(OrtKernelImpl* this_ptr, const OrtValue* tensor, + int input_index, OrtAllocator* allocator, + OrtSharedPrePackedWeightCache* prepacked_weight_cache, + /*out*/ bool* is_packed) noexcept { EXCEPTION_TO_RETURNED_STATUS_BEGIN - Mul* mul_kernel = static_cast(this_ptr); + BinaryOp* binary_op_kernel = static_cast(this_ptr); - // This example Mul kernel does not really need to pre-pack mul initializers, but we show it here as an example. + // This example BinaryOp kernel does not really need to pre-pack mul initializers, but we show it here as an example. // This implementation just copies original tensor without modification. An actual implementation would, for example, // transform to an appropriate data layout. @@ -126,11 +158,13 @@ OrtStatus* ORT_API_CALL Mul::PrePackWeightImpl(OrtKernelImpl* this_ptr, const Or weight_info.num_bytes, weight_info.shape.data(), weight_info.shape.size(), weight_info.elem_type); - RETURN_IF_ERROR(CopyTensor(*mul_kernel->data_transfer_impl_, original_weight, packed_weight.GetUnowned())); + RETURN_IF_ERROR(CopyTensor(*binary_op_kernel->data_transfer_impl_, original_weight, packed_weight.GetUnowned())); - const bool sharing_allowed = prepacked_weight_cache != nullptr; + const ExampleKernelEp* ep = static_cast(binary_op_kernel->info_.GetEp()); + const bool ep_sharing_enabled = ep->GetConfig().enable_prepack_weight_sharing; + const bool ort_sharing_allowed = prepacked_weight_cache != nullptr; - if (sharing_allowed) { + if (ort_sharing_allowed && ep_sharing_enabled) { std::array buffer_data_ptrs = {weight_info.owned_data.get()}; std::array buffer_data_sizes = {weight_info.num_bytes}; @@ -147,7 +181,7 @@ OrtStatus* ORT_API_CALL Mul::PrePackWeightImpl(OrtKernelImpl* this_ptr, const Or weight_info.owned_data.release(); } - mul_kernel->packed_weight_1_info_ = std::move(weight_info); + binary_op_kernel->packed_weight_1_info_ = std::move(weight_info); *is_packed = true; return nullptr; @@ -155,33 +189,34 @@ OrtStatus* ORT_API_CALL Mul::PrePackWeightImpl(OrtKernelImpl* this_ptr, const Or } /*static*/ -OrtStatus* ORT_API_CALL Mul::SetSharedPrePackedWeightImpl(OrtKernelImpl* this_ptr, - const void* const* buffer_data_ptrs, - const size_t* buffer_data_sizes, - size_t num_buffers, int input_index) noexcept { +OrtStatus* ORT_API_CALL BinaryOp::SetSharedPrePackedWeightImpl(OrtKernelImpl* this_ptr, + const void* const* buffer_data_ptrs, + const size_t* buffer_data_sizes, + size_t num_buffers, int input_index) noexcept { EXCEPTION_TO_RETURNED_STATUS_BEGIN - Mul* mul_kernel = static_cast(this_ptr); + BinaryOp* binary_op_kernel = static_cast(this_ptr); if (input_index != 1) { std::ostringstream oss; oss << "ExampleKernelEp did not expect a call to OrtKernelImpl::SetSharedPrePackedWeight for input index " - << input_index << " of the Mul kernel."; + << input_index << " of the BinaryOp kernel."; return Ort::GetApi().CreateStatus(ORT_EP_FAIL, oss.str().c_str()); } - RETURN_IF(num_buffers != 1, Ort::GetApi(), "Invalid number of pre-packed data buffers for Mul kernel's 2nd input"); - RETURN_IF(!mul_kernel->packed_weight_1_info_.has_value(), Ort::GetApi(), + RETURN_IF(num_buffers != 1, Ort::GetApi(), + "Invalid number of pre-packed data buffers for BinaryOp kernel's 2nd input"); + RETURN_IF(!binary_op_kernel->packed_weight_1_info_.has_value(), Ort::GetApi(), "ERROR! OrtKernelImpl::PrePackWeight should have " "initialized a valid PackedWeightInfo struct for use in SetSharedPrePackedWeight."); // Check that the buffer size is what we expect. - RETURN_IF(buffer_data_sizes[0] != mul_kernel->packed_weight_1_info_->num_bytes, Ort::GetApi(), + RETURN_IF(buffer_data_sizes[0] != binary_op_kernel->packed_weight_1_info_->num_bytes, Ort::GetApi(), "ExampleKernelEp received an unexpected buffer size in a call to OrtKernelImpl::SetSharedPrePackedWeight " - "for the Mul kernel."); + "for the BinaryOp kernel."); // Update buffer data pointer because the shared memory could potentially originate from a different // kernel instance. - mul_kernel->packed_weight_1_info_->shared_data = buffer_data_ptrs[0]; + binary_op_kernel->packed_weight_1_info_->shared_data = buffer_data_ptrs[0]; return nullptr; EXCEPTION_TO_RETURNED_STATUS_END diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.h similarity index 85% rename from onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.h rename to onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.h index f84fda6a8b0ec..b6cddccb22290 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.h @@ -7,7 +7,11 @@ #include "../../plugin_ep_utils.h" #include "../ep_allocator.h" -class Mul : public OrtKernelImpl { +/// +/// An OrtKernelImpl class for binary element-wise operations. +/// Only Sub and Mul are supported currently. +/// +class BinaryOp : public OrtKernelImpl { private: struct PrivateTag {}; @@ -25,8 +29,8 @@ class Mul : public OrtKernelImpl { }; public: - static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel) noexcept; - Mul(const OrtKernelInfo* info, void* state, PrivateTag); + static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel) noexcept; + BinaryOp(Ort::ConstKernelInfo info, void* state, PrivateTag); // Static functions assigned to the OrtKernelImpl fields: static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept; @@ -41,7 +45,7 @@ class Mul : public OrtKernelImpl { size_t num_buffers, int input_index) noexcept; private: - const OrtKernelInfo* info_; + Ort::ConstKernelInfo info_; OrtDataTransferImpl* data_transfer_impl_; // Custom state passed from OrtEp std::optional packed_weight_1_info_ = std::nullopt; }; diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index bb391bb0bca23..437ca37c1a7b6 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -47,8 +47,71 @@ void RunMulModelWithPluginEp(const Ort::SessionOptions& session_options) { EXPECT_THAT(output_span, ::testing::ElementsAre(2, 4, 6, 8, 10, 12)); } +void RunSqueezeMulReluModel(const Ort::SessionOptions& session_options) { + Ort::Session session(*ort_env, ORT_TSTR("testdata/squeeze_mul_relu.onnx"), session_options); + + // Create inputs + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::array a_shape = {3, 1, 2}; + std::array b_shape = {3, 2}; + + std::array a_data = {1.f, -2.f, 3.f, 4.f, -5.f, 6.f}; + std::array b_data = {2.f, 3.f, 4.f, -5.f, 6.f, 7.f}; + + std::vector ort_inputs{}; + ort_inputs.emplace_back( + Ort::Value::CreateTensor(memory_info, a_data.data(), a_data.size(), a_shape.data(), a_shape.size())); + ort_inputs.emplace_back( + Ort::Value::CreateTensor(memory_info, b_data.data(), b_data.size(), b_shape.data(), b_shape.size())); + + std::array ort_input_names{"A", "B"}; + + // Run session and get outputs + std::array output_names{"C"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check expected output values + Ort::Value& ort_output = ort_outputs[0]; + const float* output_data = ort_output.GetTensorData(); + gsl::span output_span(output_data, 6); + EXPECT_THAT(output_span, ::testing::ElementsAre(4, 0, 24, 0, 0, 84)); +} + +void RunSubMulSubModel(const Ort::SessionOptions& session_options) { + // This model has Sub -> Mul -> Sub: (A - B) * B - A + // The example plugin EP supports all ops. + Ort::Session session(*ort_env, ORT_TSTR("testdata/sub_mul_sub.onnx"), session_options); + + // Create inputs + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector shape = {3, 2}; + + std::vector a_data{1, 2, 3, 4, 5, 6}; + std::vector b_data{2, 3, 4, 5, 6, 7}; + + std::vector ort_inputs{}; + ort_inputs.emplace_back( + Ort::Value::CreateTensor(memory_info, a_data.data(), a_data.size(), shape.data(), shape.size())); + ort_inputs.emplace_back( + Ort::Value::CreateTensor(memory_info, b_data.data(), b_data.size(), shape.data(), shape.size())); + + std::array ort_input_names{"A", "B"}; + + // Run session and get outputs + std::array output_names{"C"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check expected output values + Ort::Value& ort_output = ort_outputs[0]; + const float* output_data = ort_output.GetTensorData(); + gsl::span output_span(output_data, 6); + EXPECT_THAT(output_span, ::testing::ElementsAre(-3, -5, -7, -9, -11, -13)); +} + void RunPartiallySupportedModelWithPluginEp(const Ort::SessionOptions& session_options) { - // This model has Add -> Mul -> Add. The example plugin EP only supports Mul. + // This model has Add -> Mul -> Add. The example plugin EP supports Mul but not Add. Ort::Session session(*ort_env, ORT_TSTR("testdata/add_mul_add.onnx"), session_options); // Create inputs @@ -245,42 +308,49 @@ TEST(OrtEpLibrary, KernelPluginEp_Inference) { example_kernel_ep)); Ort::ConstEpDevice plugin_ep_device(example_kernel_ep.get()); - // Create session with example kernel-based plugin EP - Ort::SessionOptions session_options; - session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Fail if any node assigned to CPU EP. - - std::unordered_map ep_options; - session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + // Run model with squeeze, mul, and relu nodes. + // No sharing of pre-packed weights. + { + std::unordered_map ep_options; + Ort::SessionOptions session_options; - // This model has Squeeze, Mul, and Relu nodes. The example plugin EP supports all nodes using registered kernels. - Ort::Session session(*ort_env, ORT_TSTR("testdata/squeeze_mul_relu.onnx"), session_options); + session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Fail if any node assigned to CPU EP. + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + ASSERT_NO_FATAL_FAILURE(RunSqueezeMulReluModel(session_options)); + } - // Create inputs - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); - std::array a_shape = {3, 1, 2}; - std::array b_shape = {3, 2}; + // Run model with squeeze, mul, and relu nodes. + // Enable sharing of pre-packed weights. + { + std::unordered_map ep_options = {{"enable_prepack_weight_sharing", "1"}}; + Ort::SessionOptions session_options; - std::array a_data = {1.f, -2.f, 3.f, 4.f, -5.f, 6.f}; - std::array b_data = {2.f, 3.f, 4.f, -5.f, 6.f, 7.f}; + session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Fail if any node assigned to CPU EP + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + ASSERT_NO_FATAL_FAILURE(RunSqueezeMulReluModel(session_options)); + } - std::vector ort_inputs{}; - ort_inputs.emplace_back( - Ort::Value::CreateTensor(memory_info, a_data.data(), a_data.size(), a_shape.data(), a_shape.size())); - ort_inputs.emplace_back( - Ort::Value::CreateTensor(memory_info, b_data.data(), b_data.size(), b_shape.data(), b_shape.size())); + // Run model with sub, mul, sub. + // No sharing of pre-packed weights. + { + Ort::SessionOptions session_options; + std::unordered_map ep_options; - std::array ort_input_names{"A", "B"}; + session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Fail if any node assigned to CPU EP + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + ASSERT_NO_FATAL_FAILURE(RunSubMulSubModel(session_options)); + } - // Run session and get outputs - std::array output_names{"C"}; - std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), - ort_inputs.size(), output_names.data(), output_names.size()); + // Run model with sub, mul, sub. + // Enable sharing of pre-packed weights. + { + std::unordered_map ep_options = {{"enable_prepack_weight_sharing", "1"}}; + Ort::SessionOptions session_options; - // Check expected output values - Ort::Value& ort_output = ort_outputs[0]; - const float* output_data = ort_output.GetTensorData(); - gsl::span output_span(output_data, 6); - EXPECT_THAT(output_span, ::testing::ElementsAre(4, 0, 24, 0, 0, 84)); + session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Fail if any node assigned to CPU EP + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + ASSERT_NO_FATAL_FAILURE(RunSubMulSubModel(session_options)); + } } } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/shared_lib/custom_op_utils.cc b/onnxruntime/test/shared_lib/custom_op_utils.cc index 66ce6a0838713..01d93aac6d4c7 100644 --- a/onnxruntime/test/shared_lib/custom_op_utils.cc +++ b/onnxruntime/test/shared_lib/custom_op_utils.cc @@ -18,6 +18,14 @@ template void cuda_slice(const T*, int64_t, int64_t, T*, cudaStream_t compute_stream); #endif +MyCustomKernel::MyCustomKernel(const OrtApi& ort_api, const OrtKernelInfo* info) + : ort_(ort_api) { + Ort::ConstKernelInfo kernel_info(info); + EXPECT_EQ(kernel_info.GetOperatorDomain(), "test"); + EXPECT_EQ(kernel_info.GetOperatorType(), "Foo"); + EXPECT_EQ(kernel_info.GetOperatorSinceVersion(), 1); +} + void MyCustomKernel::Compute(OrtKernelContext* context) { // Setup inputs Ort::KernelContext ctx(context); diff --git a/onnxruntime/test/shared_lib/custom_op_utils.h b/onnxruntime/test/shared_lib/custom_op_utils.h index 424c2e2fe3a08..45f522eefe084 100644 --- a/onnxruntime/test/shared_lib/custom_op_utils.h +++ b/onnxruntime/test/shared_lib/custom_op_utils.h @@ -9,10 +9,7 @@ #endif struct MyCustomKernel { - MyCustomKernel(const OrtApi& ort_api, const OrtKernelInfo* /*info*/) - : ort_(ort_api) { - } - + MyCustomKernel(const OrtApi& ort_api, const OrtKernelInfo* info); void Compute(OrtKernelContext* context); private: diff --git a/onnxruntime/test/testdata/sub_mul_sub.onnx b/onnxruntime/test/testdata/sub_mul_sub.onnx new file mode 100644 index 0000000000000..ce7a87ed1215f --- /dev/null +++ b/onnxruntime/test/testdata/sub_mul_sub.onnx @@ -0,0 +1,27 @@ + :´ + +A +B +sub_outputsub_0"Sub +' + +sub_output +B +mul_outputmul_0"Mul + + +mul_output +ACsub_1"Sub +Main_graphZ +A +  + +Z +B +  + +b +C +  + +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/sub_mul_sub.py b/onnxruntime/test/testdata/sub_mul_sub.py new file mode 100644 index 0000000000000..d76f89c8fe38c --- /dev/null +++ b/onnxruntime/test/testdata/sub_mul_sub.py @@ -0,0 +1,37 @@ +from onnx import TensorProto, checker, helper, save + +# (A - B) * B - A +graph_proto = helper.make_graph( + nodes=[ + helper.make_node( + "Sub", + inputs=["A", "B"], + outputs=["sub_output"], + name="sub_0", + ), + helper.make_node( + "Mul", + inputs=["sub_output", "B"], + outputs=["mul_output"], + name="mul_0", + ), + helper.make_node( + "Sub", + inputs=["mul_output", "A"], + outputs=["C"], + name="sub_1", + ), + ], + name="Main_graph", + inputs=[ + helper.make_tensor_value_info("A", TensorProto.FLOAT, [3, 2]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [3, 2]), + ], + outputs=[ + helper.make_tensor_value_info("C", TensorProto.FLOAT, [3, 2]), + ], +) + +model = helper.make_model(graph_proto) +checker.check_model(model, True) +save(model, "sub_mul_sub.onnx")