diff --git a/include/onnxruntime/ep/README.md b/include/onnxruntime/ep/README.md
new file mode 100644
index 0000000000000..64d85f80313c0
--- /dev/null
+++ b/include/onnxruntime/ep/README.md
@@ -0,0 +1,7 @@
+## EP adapter
+
+This folder contains a set of C++ header files that allow ONNX Runtime's internal kernel-based EPs to use the plugin-style EP API while keeping changes to the existing code minimal.
+
+### Usage
+
+Include "ep/_pch.h" in every source file of the implementation. Using it as a precompiled header (PCH) is recommended.
diff --git a/include/onnxruntime/ep/_pch.h b/include/onnxruntime/ep/_pch.h
new file mode 100644
index 0000000000000..ba9c3278693eb
--- /dev/null
+++ b/include/onnxruntime/ep/_pch.h
@@ -0,0 +1,61 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "api.h"
+#include "common.h"
+
+// This header is only used when building the WebGPU/CUDA EP as a shared library.
+//
+// This header file is used as a precompiled header so it is always included first.
+
+#pragma push_macro("ORT_EP_API_ADAPTER_HEADER_INCLUDED")
+#define ORT_EP_API_ADAPTER_HEADER_INCLUDED
+
+#include "adapter/allocator.h"
+#include "adapter/logging.h"
+#include "adapter/ep.h"
+#include "adapter/kernel_registry.h"
+
+#pragma pop_macro("ORT_EP_API_ADAPTER_HEADER_INCLUDED")
+
+//
+// EP specific using declarations
+//
+
+#define EP_SPECIFIC_USING_DECLARATIONS                                              \
+  using FuncManager = onnxruntime::ep::adapter::FuncManager;                        \
+  using KernelCreatePtrFn = onnxruntime::ep::adapter::KernelCreatePtrFn;            \
+  using KernelDefBuilder = onnxruntime::ep::adapter::KernelDefBuilder;              \
+  using KernelRegistry = onnxruntime::ep::adapter::KernelRegistry;                  \
+  using KernelCreateInfo = onnxruntime::ep::adapter::KernelCreateInfo;              \
+  using BuildKernelCreateInfoFn = onnxruntime::ep::adapter::KernelCreateInfo (*)(); \
+  using OpKernelInfo = onnxruntime::ep::adapter::OpKernelInfo;                      \
+  using OpKernelContext = onnxruntime::ep::adapter::OpKernelContext;                \
+  using OpKernel = onnxruntime::ep::adapter::OpKernel;                              \
+  using DataTransferManager = onnxruntime::ep::adapter::DataTransferManager;        \
+  namespace logging {                                                               \
+  using Logger = onnxruntime::ep::adapter::Logger;                                  \
+  }
+
+namespace onnxruntime {
+namespace webgpu {
+EP_SPECIFIC_USING_DECLARATIONS
+} // namespace webgpu
+namespace cuda {
+EP_SPECIFIC_USING_DECLARATIONS
+} // namespace cuda
+
+#ifndef DISABLE_CONTRIB_OPS
+namespace contrib {
+namespace webgpu {
+EP_SPECIFIC_USING_DECLARATIONS
+} // namespace webgpu
+namespace cuda {
+EP_SPECIFIC_USING_DECLARATIONS
+} // namespace cuda
+} // namespace contrib
+#endif
+
+} // namespace onnxruntime
diff --git a/include/onnxruntime/ep/adapter/allocator.h b/include/onnxruntime/ep/adapter/allocator.h
new file mode 100644
index 0000000000000..36a051a0e0edc
--- /dev/null
+++ b/include/onnxruntime/ep/adapter/allocator.h
@@ -0,0 +1,45 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/framework/allocator.h"
+
+namespace onnxruntime {
+namespace ep {
+namespace adapter {
+
+/// <summary>
+/// A bridge class between the EP API OrtAllocator and an IAllocator implementation.
+/// </summary>
+class Allocator : public OrtAllocator {
+ public:
+  explicit Allocator(AllocatorPtr impl) : OrtAllocator{}, impl_(impl) {
+    version = ORT_API_VERSION;
+    Alloc = AllocImpl;
+    Free = FreeImpl;
+    Info = InfoImpl;
+  }
+
+ private:
+  static void* ORT_API_CALL AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept {
+    auto* allocator = static_cast<Allocator*>(this_ptr);
+    return allocator->impl_->Alloc(size);
+  }
+
+  static void ORT_API_CALL FreeImpl(OrtAllocator* this_ptr, void* p) noexcept {
+    auto* allocator = static_cast<Allocator*>(this_ptr);
+    allocator->impl_->Free(p);
+  }
+
+  static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const OrtAllocator* this_ptr) noexcept {
+    auto* allocator = static_cast<const Allocator*>(this_ptr);
+    return &allocator->impl_->Info();
+  }
+
+  AllocatorPtr impl_;
+};
+
+} // namespace adapter
+} // namespace ep
+} // namespace onnxruntime
diff --git a/include/onnxruntime/ep/adapter/data_transfer_manager.h b/include/onnxruntime/ep/adapter/data_transfer_manager.h
new file mode 100644
index 0000000000000..57455b454e288
--- /dev/null
+++ b/include/onnxruntime/ep/adapter/data_transfer_manager.h
@@ -0,0 +1,53 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
+#error "This header should not be included directly. Include ep/_pch.h instead."
+#endif
+
+#include "core/common/status.h"
+#include "core/common/common.h"
+#include "core/framework/data_transfer.h"
+#include "core/framework/tensor.h"
+
+namespace onnxruntime {
+namespace ep {
+namespace adapter {
+
+/// <summary>
+/// </summary>
+struct DataTransferManager {
+  explicit DataTransferManager(std::unique_ptr<IDataTransfer> impl) : impl_{std::move(impl)} {}
+
+  common::Status CopyTensor(const Tensor& src, Tensor& dst) const {
+    if (src.Shape().Size() != dst.Shape().Size()) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME,
+                             FAIL,
+                             "Tensor size mismatch: source tensor size is ",
+                             src.Shape().Size(),
+                             ", destination tensor size is ",
+                             dst.Shape().Size());
+    }
+
+    if (impl_->CanCopy(src.Location().device, dst.Location().device)) {
+      return impl_->CopyTensor(src, dst);
+    }
+
+    return ORT_MAKE_STATUS(ONNXRUNTIME,
+                           FAIL,
+                           "There's no data transfer registered for copying tensors from ",
+                           src.Location().device.ToString(),
+                           " to ",
+                           dst.Location().device.ToString());
+  }
+
+ private:
+  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DataTransferManager);
+  std::unique_ptr<IDataTransfer> impl_;
+};
+
+} // namespace adapter
+} // namespace ep
+} // namespace onnxruntime
diff --git a/include/onnxruntime/ep/adapter/ep.h b/include/onnxruntime/ep/adapter/ep.h
new file mode 100644
index 0000000000000..02a6c2f07b0c3
--- /dev/null
+++ b/include/onnxruntime/ep/adapter/ep.h
@@ -0,0 +1,58 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
+#error "This header should not be included directly. Include ep/_pch.h instead."
+#endif
+
+#include "data_transfer_manager.h"
+
+#include "core/framework/execution_provider.h"
+
+namespace onnxruntime {
+namespace ep {
+namespace adapter {
+
+/// <summary>
+/// Wrapper around IExecutionProvider to expose via OrtEp.
+/// </summary>
+class Ep : public OrtEp {
+ protected:
+  explicit Ep(IExecutionProvider* impl, AllocatorPtr temp_space_cpu_allocator, AllocatorPtr temp_space_allocator)
+      : OrtEp{},
+        impl_(impl),
+        data_transfer_manager_{impl->GetDataTransfer()},
+        profiler_{impl->GetProfiler()},
+        temp_space_cpu_allocator_{temp_space_cpu_allocator},
+        temp_space_allocator_{temp_space_allocator} {
+  }
+
+ public:
+  inline IExecutionProvider* EpImpl() const noexcept {
+    return impl_.get();
+  }
+  inline const DataTransferManager& GetDataTransferManager() const noexcept {
+    return data_transfer_manager_;
+  }
+  [[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const {
+    *output = temp_space_cpu_allocator_;
+    return Status::OK();
+  }
+  [[nodiscard]] Status GetTempSpaceAllocator(AllocatorPtr* output) const {
+    *output = temp_space_allocator_;
+    return Status::OK();
+  }
+
+ private:
+  std::unique_ptr<IExecutionProvider> impl_;
+  DataTransferManager data_transfer_manager_;
+  std::unique_ptr<profiling::EpProfiler> profiler_;
+  AllocatorPtr temp_space_cpu_allocator_;
+  AllocatorPtr temp_space_allocator_;
+};
+
+} // namespace adapter
+} // namespace ep
+} // namespace onnxruntime
diff --git a/include/onnxruntime/ep/adapter/kernel_def_builder.h b/include/onnxruntime/ep/adapter/kernel_def_builder.h
new file mode 100644
index 0000000000000..c55cdda7ac7c3
--- /dev/null
+++ b/include/onnxruntime/ep/adapter/kernel_def_builder.h
@@ -0,0 +1,140 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
+#error "This header should not be included directly. Include ep/_pch.h instead."
+#endif
+
+#include <climits>
+
+#include "core/framework/data_types.h"
+
+namespace onnxruntime {
+namespace ep {
+namespace adapter {
+
+/// <summary>
+/// Gets an OrtMLDataType for a tensor type. Throws on error.
+/// </summary>
+/// <param name="elem_type"></param>
+/// <returns></returns>
+inline const OrtDataType* GetTensorType(ONNXTensorElementDataType elem_type) {
+  const OrtEpApi& ep_api = Ort::GetEpApi();
+  const OrtDataType* result = nullptr;
+
+  Ort::ThrowOnError(ep_api.GetTensorDataType(elem_type, &result));
+  return result;
+}
+
+inline const OrtDataType* MLDataTypeToOrtDataType(MLDataType ml_type) {
+  auto tensor_type = ml_type->AsTensorType();
+  EP_ENFORCE(tensor_type != nullptr, "EP Kernel registration only supports tensor types.");
+  auto elem_type = tensor_type->GetElementType();
+  auto primitive_type = static_cast<const PrimitiveDataTypeBase*>(elem_type);
+  auto onnx_type = static_cast<ONNXTensorElementDataType>(primitive_type->GetDataType());
+  return GetTensorType(onnx_type);
+}
+
+struct KernelDefBuilder {
+  static std::unique_ptr<KernelDefBuilder> Create() { return std::make_unique<KernelDefBuilder>(); }
+
+  explicit KernelDefBuilder() {}
+
+  KernelDefBuilder& SetName(const char* op_name) {
+    builder_.SetOperatorType(op_name);
+    return *this;
+  }
+
+  KernelDefBuilder& SetDomain(const char* domain) {
+    builder_.SetDomain(domain);
+    return *this;
+  }
+
+  KernelDefBuilder& SinceVersion(int since_version) {
+    return SinceVersion(since_version, INT_MAX);
+  }
+
+  KernelDefBuilder& SinceVersion(int since_version_start, int since_version_end) {
+    builder_.SetSinceVersion(since_version_start, since_version_end);
+    return *this;
+  }
+
+  KernelDefBuilder& Provider(const char* provider_type) {
+    builder_.SetExecutionProvider(provider_type);
+    return *this;
+  }
+
+  KernelDefBuilder& TypeConstraint(const char* arg_name, std::vector<MLDataType> types) {
+    std::vector<const OrtDataType*> ort_types;
+    ort_types.reserve(types.size());
+    for (const auto& type : types) {
+      ort_types.push_back(MLDataTypeToOrtDataType(type));
+    }
+    builder_.AddTypeConstraint(arg_name, ort_types);
+    return *this;
+  }
+
+  KernelDefBuilder& TypeConstraint(const char* arg_name, MLDataType type) {
+    builder_.AddTypeConstraint(arg_name, MLDataTypeToOrtDataType(type));
+    return *this;
+  }
+
+  KernelDefBuilder& MayInplace(const std::vector<std::pair<int, int>>& inplaces) {
+    for (const auto& pair : inplaces) {
+      builder_.AddInputOutputMutableAlias(pair.first, pair.second);
+    }
+    return *this;
+  }
+  KernelDefBuilder& MayInplace(int input_index, int output_index) {
+    builder_.AddInputOutputMutableAlias(input_index, output_index);
+    return *this;
+  }
+
+  KernelDefBuilder& Alias(const std::vector<std::pair<int, int>>& aliases) {
+    for (const auto& pair : aliases) {
+      builder_.AddInputOutputAlias(pair.first, pair.second);
+    }
+    return *this;
+  }
+  KernelDefBuilder& Alias(int input_index, int output_index) {
+    builder_.AddInputOutputAlias(input_index, output_index);
+    return *this;
+  }
+
+  KernelDefBuilder& InputMemoryType(OrtMemType type, int input_index) {
+    builder_.SetInputMemType(input_index, type);
+    return *this;
+  }
+
+  KernelDefBuilder& InputMemoryType(OrtMemType type, const std::vector<int>& input_indexes) {
+    for (int input_index : input_indexes) {
+      builder_.SetInputMemType(input_index, type);
+    }
+    return *this;
+  }
+
+  KernelDefBuilder& OutputMemoryType(OrtMemType type, int output_index) {
+    builder_.SetOutputMemType(output_index, type);
+    return *this;
+  }
+
+  KernelDefBuilder& OutputMemoryType(OrtMemType type, const std::vector<int>& output_indexes) {
+    for (int output_index : output_indexes) {
+      builder_.SetOutputMemType(output_index, type);
+    }
+    return *this;
+  }
+
+  KernelDefBuilder& ExecQueueId(int queue_id) { return *this; }
+
+  Ort::KernelDef Build() { return builder_.Build(); }
+
+ private:
+  Ort::KernelDefBuilder builder_;
+};
+
+} // namespace adapter
+} // namespace ep
+} // namespace onnxruntime
diff --git a/include/onnxruntime/ep/adapter/kernel_registry.h b/include/onnxruntime/ep/adapter/kernel_registry.h
new file mode 100644
index 0000000000000..152d956030290
--- /dev/null
+++ b/include/onnxruntime/ep/adapter/kernel_registry.h
@@ -0,0 +1,77 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
+#error "This header should not be included directly. Include ep/_pch.h instead."
+#endif
+
+#include <cassert>
+
+#include "kernel_def_builder.h"
+#include "op_kernel_info.h"
+#include "op_kernel.h"
+
+#include "core/graph/basic_types.h"
+
+namespace onnxruntime {
+namespace ep {
+namespace adapter {
+
+struct FuncManager {};
+using KernelCreatePtrFn = std::add_pointer<Status(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out)>::type;
+
+struct KernelCreateInfo {
+  Ort::KernelDef kernel_def;
+  KernelCreatePtrFn kernel_create_func;
+  Status status;
+
+  KernelCreateInfo(Ort::KernelDef definition,
+                   KernelCreatePtrFn create_func)
+      : kernel_def(std::move(definition)),
+        kernel_create_func(create_func) {
+    assert(kernel_def != nullptr);
+  }
+
+  KernelCreateInfo(KernelCreateInfo&& other) noexcept
+      : kernel_def(std::move(other.kernel_def)),
+        kernel_create_func(std::move(other.kernel_create_func)) {}
+
+  KernelCreateInfo() = default;
+};
+
+struct KernelRegistry {
+  KernelRegistry() = default;
+
+  Status Register(KernelCreateInfo&& create_info) {
+    registry_.AddKernel(create_info.kernel_def, [](void* kernel_create_func_state, const OrtKernelInfo* info, OrtKernelImpl** out) -> OrtStatus* {
+      FuncManager func_mgr; // not used
+      std::unique_ptr<OpKernel> kernel;
+      KernelCreatePtrFn create_func = reinterpret_cast<KernelCreatePtrFn>(kernel_create_func_state);
+      Status status = create_func(func_mgr, OpKernelInfo(info), kernel);
+      if (!status.IsOK()) {
+        return Ort::GetApi().CreateStatus(ORT_RUNTIME_EXCEPTION, status.ErrorMessage().c_str());
+      }
+      *out = new KernelImpl(std::move(kernel));
+      return nullptr; }, reinterpret_cast<void*>(create_info.kernel_create_func));
+    return Status::OK();
+  }
+
+  // Implicit conversion to OrtKernelRegistry* for compatibility with C API
+  operator OrtKernelRegistry*() const noexcept {
+    return registry_.operator OrtKernelRegistry*();
+  }
+
+  // Release ownership of the underlying OrtKernelRegistry*
+  OrtKernelRegistry* release() {
+    return registry_.release();
+  }
+
+ private:
+  Ort::KernelRegistry registry_;
+};
+
+} // namespace adapter
+} // namespace ep
+} // namespace onnxruntime
diff --git a/include/onnxruntime/ep/adapter/logging.h b/include/onnxruntime/ep/adapter/logging.h
new file mode 100644
index 0000000000000..b93c06bb3f12e
--- /dev/null
+++ b/include/onnxruntime/ep/adapter/logging.h
@@ -0,0 +1,111 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
+#error "This header should not be included directly. Include ep/_pch.h instead."
+#endif
+
+#include "core/common/logging/logging.h"
+#include "core/common/path_string.h"
+
+namespace onnxruntime {
+namespace ep {
+namespace adapter {
+
+struct Logger {
+  Logger(const OrtLogger* logger) : logger_(logger) {}
+
+  bool OutputIsEnabled(logging::Severity severity, logging::DataType /* data_type */) const noexcept {
+    return ((OrtLoggingLevel)severity >= logger_.GetLoggingSeverityLevel());
+  }
+
+  void Log(logging::Severity severity,
+           const char* file_path,
+           int line_number,
+           const char* func_name,
+           const char* message) const noexcept {
+    auto path_string = onnxruntime::ToPathString(file_path);
+    logger_.LogMessage((OrtLoggingLevel)severity,
+                       path_string.c_str(),
+                       line_number,
+                       func_name,
+                       message);
+  }
+
+  static const Logger& DefaultLogger() { return *instance_; }
+  static void CreateDefaultLogger(const OrtLogger* logger) {
+    instance_ = new Logger(logger);
+  }
+  static void DestroyDefaultLogger() {
+    delete instance_;
+    instance_ = nullptr;
+  }
+
+ private:
+  Ort::Logger logger_;
+  inline static Logger* instance_ = nullptr;
+};
+
+namespace detail {
+struct LoggerCapture {
+  LoggerCapture(const Logger& logger,
+                logging::Severity severity,
+                const char* category,
+                logging::DataType dataType,
+                const CodeLocation& location) : logger_{logger},
+                                                severity_{severity},
+                                                category_{category},
+                                                data_type_{dataType},
+                                                location_{location} {}
+
+  ~LoggerCapture() {
+    logger_.Log(severity_, location_.file_and_path.c_str(), location_.line_num,
+                location_.function.c_str(), stream_.str().c_str());
+  }
+
+  std::ostream& Stream() noexcept {
+    return stream_;
+  }
+
+  const Logger& logger_;
+  logging::Severity severity_;
+  const char* category_;
+  logging::DataType data_type_;
+  const CodeLocation& location_;
+  std::ostringstream stream_;
+};
+
+// Helper functions to dispatch to the correct Capture type based on logger type
+inline ::onnxruntime::logging::Capture CreateMessageCapture(
+    const ::onnxruntime::logging::Logger& logger,
+    ::onnxruntime::logging::Severity severity,
+    const char* category,
+    ::onnxruntime::logging::DataType datatype,
+    const CodeLocation& location) {
+  return ::onnxruntime::logging::Capture(logger, severity, category, datatype, location);
+}
+
+inline detail::LoggerCapture CreateMessageCapture(
+    const Logger& logger,
+    ::onnxruntime::logging::Severity severity,
+    const char* category,
+    ::onnxruntime::logging::DataType datatype,
+    const CodeLocation& location) {
+  return detail::LoggerCapture(logger, severity, category, datatype, location);
+}
+
+} // namespace detail
+} // namespace adapter
+} // namespace ep
+} // namespace onnxruntime
+
+// Undefine and redefine LOGS_DEFAULT
+#undef LOGS_DEFAULT_CATEGORY
+#define LOGS_DEFAULT_CATEGORY(severity, category) \
+  LOGS_CATEGORY(::onnxruntime::ep::adapter::Logger::DefaultLogger(), severity, category)
+
+#undef CREATE_MESSAGE
+#define CREATE_MESSAGE(logger, severity, category, datatype) \
+  ::onnxruntime::ep::adapter::detail::CreateMessageCapture(logger, ::onnxruntime::logging::Severity::k##severity, category, datatype, ORT_WHERE)
diff --git a/include/onnxruntime/ep/adapter/node.h b/include/onnxruntime/ep/adapter/node.h
new file mode 100644
index 0000000000000..6d7e1d85acf1f
--- /dev/null
+++ b/include/onnxruntime/ep/adapter/node.h
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
+#error "This header should not be included directly. Include ep/_pch.h instead."
+#endif
+
+namespace onnxruntime {
+namespace ep {
+namespace adapter {
+
+/// <summary>
+/// </summary>
+struct Node {
+  explicit Node(const OrtKernelInfo* kernel_info) : kernel_info_{kernel_info} {}
+  /** Gets the Node's name. */
+  const std::string Name() const noexcept {
+    return kernel_info_.GetNodeName();
+  }
+
+  /** Gets the Node's operator type. */
+  const std::string OpType() const noexcept {
+    return kernel_info_.GetOperatorType();
+  }
+
+  /** Gets the since version of the operator. */
+  int SinceVersion() const noexcept {
+    return kernel_info_.GetOperatorSinceVersion();
+  }
+
+ private:
+  const Ort::ConstKernelInfo kernel_info_;
+};
+
+} // namespace adapter
+} // namespace ep
+} // namespace onnxruntime
diff --git a/include/onnxruntime/ep/adapter/op_kernel.h b/include/onnxruntime/ep/adapter/op_kernel.h
new file mode 100644
index 0000000000000..0fc9159b01a5f
--- /dev/null
+++ b/include/onnxruntime/ep/adapter/op_kernel.h
@@ -0,0 +1,188 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
+#error "This header should not be included directly. Include ep/_pch.h instead."
+#endif
+
+#include <memory>
+#include <vector>
+
+#include "core/framework/allocator.h"
+#include "core/framework/tensor.h"
+
+#include "node.h"
+#include "op_kernel_info.h"
+#include "tensor_helper.h"
+
+namespace onnxruntime {
+struct PrePackedWeights;
+struct TensorShape;
+} // namespace onnxruntime
+
+namespace onnxruntime {
+namespace ep {
+namespace adapter {
+
+struct OpKernelContext;
+
+struct OpKernel {
+  explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_{info} {}
+  virtual ~OpKernel() {}
+
+  const Node& Node() const {
+    return op_kernel_info_.node();
+  }
+  const OpKernelInfo& Info() const {
+    return op_kernel_info_;
+  }
+
+  virtual Status Compute(OpKernelContext* p_op_kernel_context) const = 0;
+  virtual Status PrePack(const Tensor& tensor,
+                         int input_idx,
+                         AllocatorPtr alloc,
+                         /*out*/ bool& is_packed,
+                         /*out*/ PrePackedWeights* prepacked_weights) {
+    is_packed = false;
+    return Status::OK();
+  }
+
+  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernel);
+  OpKernelInfo op_kernel_info_;
+};
+
+struct OpKernelContext {
+  explicit OpKernelContext(OrtKernelContext* context, const OpKernel& op_kernel) : context_{context}, op_kernel_{op_kernel} {
+    input_tensors_.resize(InputCount());
+    output_tensors_.resize(OutputCount());
+  }
+
+  template <typename T, typename = std::enable_if_t<std::is_same_v<T, Tensor>>>
+  const T* Input(int index) const {
+    if (index < 0 || static_cast<size_t>(index) >= input_tensors_.size()) {
+      return nullptr;
+    }
+    if (input_tensors_[index] != nullptr) {
+      return static_cast<const T*>(input_tensors_[index].get());
+    }
+
+    auto input = context_.GetInput(index);
+    if (input == nullptr || !input.IsTensor()) {
+      return nullptr;
+    }
+
+    input_tensors_[index] = CreateTensorFromApiValue(input);
+    return static_cast<const T*>(input_tensors_[index].get());
+  }
+  Tensor* Output(int index, const TensorShape& shape) {
+    if (index < 0 || static_cast<size_t>(index) >= output_tensors_.size()) {
+      return nullptr;
+    }
+    if (output_tensors_[index] != nullptr) {
+      return output_tensors_[index].get();
+    }
+
+    auto output = context_.GetOutput(index, shape.GetDims().data(), shape.GetDims().size());
+    if (output == nullptr) {
+      return nullptr;
+    }
+
+    output_tensors_[index] = CreateTensorFromApiValue(output);
+    return output_tensors_[index].get();
+  }
+  Tensor* Output(int index, const std::vector<int64_t>& shape) {
+    return Output(index, TensorShape{shape});
+  }
+  Tensor* Output(int index, const std::initializer_list<int64_t>& shape) {
+    return Output(index, TensorShape{shape});
+  }
+  [[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const {
+    return static_cast<const Ep*>(op_kernel_.Info().GetKernelInfo().GetEp())->GetTempSpaceCPUAllocator(output);
+  }
+  [[nodiscard]] Status GetTempSpaceAllocator(AllocatorPtr* output) const {
+    return static_cast<const Ep*>(op_kernel_.Info().GetKernelInfo().GetEp())->GetTempSpaceAllocator(output);
+  }
+  size_t InputCount() const {
+    return context_.GetInputCount();
+  }
+  size_t OutputCount() const {
+    return context_.GetOutputCount();
+  }
+
+ private:
+  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernelContext);
+  Ort::KernelContext context_;
+  const OpKernel& op_kernel_;
+  mutable std::vector<std::unique_ptr<Tensor>> input_tensors_;
+  std::vector<std::unique_ptr<Tensor>> output_tensors_;
+};
+
+struct KernelImpl : OrtKernelImpl {
+  explicit KernelImpl(std::unique_ptr<OpKernel> impl)
+      : OrtKernelImpl{}, impl_(std::move(impl)) {
+    ort_version_supported = ORT_API_VERSION;
+    Compute = ComputeImpl;
+    Release = ReleaseImpl;
+    PrePackWeight = PrePackWeightImpl;
+  }
+
+ private:
+  static OrtStatus* ORT_API_CALL ComputeImpl(_In_ OrtKernelImpl* this_ptr,
+                                             _In_ OrtKernelContext* context) noexcept {
+    const auto* kernel_impl = static_cast<KernelImpl*>(this_ptr)->impl_.get();
+    OpKernelContext ctx{context, *kernel_impl};
+    Status status;
+    ORT_TRY {
+      status = kernel_impl->Compute(&ctx);
+    }
+    ORT_CATCH(const std::exception& ex) {
+      ORT_HANDLE_EXCEPTION([&]() {
+        status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what());
+      });
+    }
+    if (status.IsOK()) {
+      return nullptr;
+    } else {
+      return Ort::Status{status.ErrorMessage().c_str(), static_cast<OrtErrorCode>(status.Code())}.release();
+    }
+  }
+
+  static void ORT_API_CALL ReleaseImpl(_In_ OrtKernelImpl* this_ptr) noexcept {
+    delete static_cast<KernelImpl*>(this_ptr);
+  }
+
+  static OrtStatus* ORT_API_CALL PrePackWeightImpl(_In_ OrtKernelImpl* this_ptr,
+                                                   _In_ const OrtValue* weight,
+                                                   int input_index,
+                                                   _In_ OrtAllocator* /* allocator */,
+                                                   _In_opt_ OrtSharedPrePackedWeightCache* /* prepacked_weight_cache */,
+                                                   _Out_ bool* is_packed) noexcept {
+    auto* kernel_impl = static_cast<KernelImpl*>(this_ptr)->impl_.get();
+    const auto tensor = CreateTensorFromApiValue(Ort::ConstValue{weight});
+    Status status;
+    ORT_TRY {
+      status = kernel_impl->PrePack(*tensor.get(), input_index, AllocatorPtr{}, *is_packed, nullptr);
+    }
+    ORT_CATCH(const std::exception& ex) {
+      ORT_HANDLE_EXCEPTION([&]() {
+        status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what());
+      });
+    }
+    if (!status.IsOK()) {
+      return Ort::Status{status.ErrorMessage().c_str(), static_cast<OrtErrorCode>(status.Code())}.release();
+    }
+    return nullptr;
+  }
+
+  ~KernelImpl() = default;
+
+ private:
+  std::unique_ptr<OpKernel> impl_;
+};
+
+} // namespace adapter
+} // namespace ep
+} // namespace onnxruntime
diff --git a/include/onnxruntime/ep/adapter/op_kernel_info.h b/include/onnxruntime/ep/adapter/op_kernel_info.h
new file mode 100644
index 0000000000000..5d0f4328d38dd
--- /dev/null
+++ b/include/onnxruntime/ep/adapter/op_kernel_info.h
@@ -0,0 +1,152 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
+#error "This header should not be included directly. Include ep/_pch.h instead."
+#endif
+
+#include <memory>
+
+#include "core/common/status.h"
+#include "core/framework/tensor_shape.h"
+#include "core/framework/tensor.h"
+
+#include "node.h"
+#include "tensor_helper.h"
+
+namespace onnxruntime {
+struct DataTransferManager;
+struct IExecutionProvider;
+} // namespace onnxruntime
+
+namespace onnxruntime {
+namespace ep {
+namespace adapter {
+
+struct OpKernelInfo {
+  //
+  // A helper struct to cache kernel info data
+  //
+  // Because `KernelCreatePtrFn` is defined to use `const OrtKernelInfo&` as parameter type of the kernel creation function, `OpKernelInfo` has to be copyable.
+  // This means we cannot store cached data like `constant_input_tensors_` in `OpKernelInfo` directly to avoid ownership issues.
+  //
+  // As a workaround, we define this struct `KernelInfoCache` here to represent the cached data. We use a shared pointer to `KernelInfoCache` in `OpKernelInfo`
+  // to manage the lifetime of the cached data.
+  struct KernelInfoCache {
+    explicit KernelInfoCache(const OrtKernelInfo* kernel_info) : node(kernel_info) {
+      Ort::ConstKernelInfo info{kernel_info};
+      const int input_count = info.GetInputCount();
+      constant_input_tensors.resize(input_count);
+      for (int i = 0; i < input_count; ++i) {
+        int is_constant = 0;
+        Ort::ConstValue const_input = info.GetTensorConstantInput(i, &is_constant);
+        if (is_constant && const_input != nullptr && const_input.IsTensor()) {
+          constant_input_tensors[i] = CreateTensorFromApiValue(const_input);
+        }
+      }
+    }
+    Node node;
+    std::vector<std::unique_ptr<Tensor>> constant_input_tensors;
+    ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(KernelInfoCache);
+  };
+
+  explicit OpKernelInfo(const OrtKernelInfo* info) : info_(info), cache_{std::make_shared<KernelInfoCache>(info)} {
+  }
+
+  const DataTransferManager& GetDataTransferManager() const noexcept {
+    return (static_cast<const Ep*>(info_.GetEp()))->GetDataTransferManager();
+  }
+  const Node& node() const {
+    return cache_->node;
+  }
+  const IExecutionProvider* GetExecutionProvider() const noexcept {
+    return (static_cast<const Ep*>(info_.GetEp()))->EpImpl();
+  }
+
+  const Ort::ConstKernelInfo GetKernelInfo() const noexcept {
+    return info_;
+  }
+
+  int GetInputCount() const noexcept {
+    return info_.GetInputCount();
+  }
+
+  bool TryGetConstantInput(int input_index, const Tensor** constant_input_value) const {
+    if (input_index < 0 || static_cast<size_t>(input_index) >= cache_->constant_input_tensors.size()) {
+      return false;
+    }
+    const Tensor* tensor = cache_->constant_input_tensors[input_index].get();
+    if (tensor != nullptr) {
+      *constant_input_value = tensor;
+      return true;
+    }
+    return false;
+  }
+
+  template <typename T>
+  [[nodiscard]] T GetAttrOrDefault(const std::string& name, const T& default_value) const {
+    T tmp;
+    return GetAttr(name, &tmp).IsOK() ? tmp : default_value;
+  }
+  template <typename T>
+  void GetAttrOrDefault(const std::string& name, T* value, const T& default_value) const {
+    if (!GetAttr(name, value).IsOK())
+      *value = default_value;
+  }
+  template <typename T>
+  [[nodiscard]] T GetAttr(const std::string& name) const {
+    T value;
+    ORT_THROW_IF_ERROR(GetAttr(name, &value));
+    return value;
+  }
+  template <typename T>
+  Status GetAttr(const std::string& name, T* value) const {
+    try {
+      *value = info_.GetAttribute<T>(name.c_str());
+      return Status::OK();
+    } catch (const Ort::Exception& ex) {
+      return Status(onnxruntime::common::ONNXRUNTIME, ex.GetOrtErrorCode(), ex.what());
+    }
+  }
+  template <typename T>
+  Status GetAttrs(const std::string& name, std::vector<T>& values) const {
+    try {
+      values = info_.GetAttributes<T>(name.c_str());
+      return Status::OK();
+    } catch (const Ort::Exception& ex) {
+      return Status(onnxruntime::common::ONNXRUNTIME, ex.GetOrtErrorCode(), ex.what());
+    }
+  }
+
+  Status GetAttrs(const std::string& name, TensorShapeVector& out) const {
+    std::vector<int64_t> shape;
+    Status status = GetAttrs(name, shape);
+    if (status.IsOK()) {
+      out.reserve(shape.size());
+      out.assign(shape.begin(), shape.end());
+    }
+    return status;
+  }
+
+  template <typename T>
+  [[nodiscard]] std::vector<T> GetAttrsOrDefault(const std::string& name,
+                                                 const std::vector<T>& default_value = {}) const {
+    std::vector<T> tmp;
+    return GetAttrs(name, tmp).IsOK() ? tmp : default_value;
+  }
+  [[nodiscard]] TensorShapeVector GetAttrsOrDefault(const std::string& name,
+                                                    const TensorShapeVector& default_value = {}) const {
+    TensorShapeVector tmp;
+    return GetAttrs(name, tmp).IsOK() ? tmp : default_value;
+  }
+
+ private:
+  const Ort::ConstKernelInfo info_;
+  std::shared_ptr<KernelInfoCache> cache_;
+};
+
+} // namespace adapter
+} // namespace ep
+} // namespace onnxruntime
diff --git a/include/onnxruntime/ep/adapter/tensor_helper.h b/include/onnxruntime/ep/adapter/tensor_helper.h
new file mode 100644
index 0000000000000..4d8ee078d5836
--- /dev/null
+++ b/include/onnxruntime/ep/adapter/tensor_helper.h
@@ -0,0 +1,51 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
+#error "This header should not be included directly. Include ep/_pch.h instead."
+#endif
+
+#include <memory>
+#include <utility>
+
+#include "core/framework/tensor.h"
+
+namespace onnxruntime {
+namespace ep {
+namespace adapter {
+
+/// <summary>
+/// Create an unowned onnxruntime::Tensor from a tensor OrtValue from the C API.
+/// </summary>
+inline std::unique_ptr<Tensor> CreateTensorFromApiValue(const OrtValue* ort_value) {
+  Ort::ConstValue value{ort_value};
+  EP_ENFORCE(value.IsTensor(), "Only tensor OrtValue is supported.");
+
+  auto type_and_shape_info = value.GetTypeInfo().GetTensorTypeAndShapeInfo();
+  auto type = type_and_shape_info.GetElementType();
+  auto shape_vec = type_and_shape_info.GetShape();
+
+  auto memory_info = value.GetTensorMemoryInfo();
+  MLDataType data_type = DataTypeImpl::TensorTypeFromONNXEnum(type)->GetElementType();
+
+  return std::make_unique<Tensor>(data_type,
+                                  TensorShape{shape_vec},
+                                  const_cast<void*>(value.GetTensorRawData()),
+                                  OrtMemoryInfo{
+                                      memory_info.GetAllocatorName(),
+                                      memory_info.GetAllocatorType(),
+                                      OrtDevice{
+                                          static_cast<OrtDevice::DeviceType>(memory_info.GetDeviceType()),
+                                          static_cast<OrtDevice::MemoryType>(memory_info.GetMemoryType()),
+                                          static_cast<OrtDevice::VendorId>(memory_info.GetVendorId()),
+                                          static_cast<OrtDevice::DeviceId>(memory_info.GetDeviceId()),
+
+                                      },
+                                      memory_info.GetMemoryType()});
+}
+
+} // namespace adapter
+} // namespace ep
+} // namespace onnxruntime
diff --git a/include/onnxruntime/ep/api.h b/include/onnxruntime/ep/api.h
new file mode 100644
index 0000000000000..b05fb9e6d1cb3
--- /dev/null
+++ b/include/onnxruntime/ep/api.h
@@ -0,0 +1,51 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <memory>
+
+#pragma push_macro("ORT_API_MANUAL_INIT")
+#define ORT_API_MANUAL_INIT
+#include "onnxruntime_cxx_api.h"
+#pragma pop_macro("ORT_API_MANUAL_INIT")
+
+namespace onnxruntime {
+namespace ep {
+
+struct ApiPtrs {
+  const OrtApi& ort;
+  const OrtEpApi& ep;
+  const OrtModelEditorApi& model_editor;
+};
+
+namespace detail {
+inline std::unique_ptr<ApiPtrs> g_api_ptrs;
+}
+
+/// <summary>
+/// Get the global instance of ApiPtrs.
+/// </summary>
+inline const ApiPtrs& Api() {
+  return *detail::g_api_ptrs;
+}
+
+/// <summary>
+/// Initialize the EP API pointers and global OrtEnv if not already done.
+/// </summary>
+inline void ApiInit(const OrtApiBase* ort_api_base) {
+  // Manual init for the C++ API
+  const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION);
+  const OrtEpApi* ep_api = ort_api->GetEpApi();
+  const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi();
+  Ort::InitApi(ort_api);
+
+  // Initialize the global API instance
+  if (!detail::g_api_ptrs) {
+    detail::g_api_ptrs = std::make_unique<ApiPtrs>(
+        ApiPtrs{*ort_api, *ep_api, *model_editor_api});
+  }
+}
+
+} // namespace ep
+} // namespace onnxruntime
diff --git a/include/onnxruntime/ep/common.h b/include/onnxruntime/ep/common.h
new file mode 100644
index 0000000000000..12118c938820c
--- /dev/null
+++ b/include/onnxruntime/ep/common.h
@@ -0,0 +1,38 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#define RETURN_IF_ERROR(fn)        \
+  do {                             \
+    OrtStatus* _status = (fn);     \
+    if (_status != nullptr) {      \
+      return _status;              \
+    }                              \
+  } while (0)
+
+#define RETURN_IF(cond, ort_api, msg)                     \
+  do {                                                    \
+    if ((cond)) {                                         \
+      return (ort_api).CreateStatus(ORT_EP_FAIL, (msg));  \
+    }                                                     \
+  } while (0)
+
+// see ORT_ENFORCE for implementations that also capture a stack trace and work in builds with exceptions disabled
+// NOTE: In this simplistic implementation you must provide an argument, even if it's an empty string
+#define EP_ENFORCE(condition, ...)                        \
+  do {                                                    \
+    if (!(condition)) {                                   \
+      std::ostringstream oss;                             \
+      oss << "EP_ENFORCE failed: " << #condition << " ";  \
+      oss << __VA_ARGS__;                                 \
+      throw std::runtime_error(oss.str());                \
+    }                                                     \
+  } while (false)
+
+// Ignores an OrtStatus* while taking ownership of it so that it does not get leaked.
+#define IGNORE_ORTSTATUS(status_expr)      \
+  do {                                     \
+    OrtStatus* _status = (status_expr);    \
+    Ort::Status _ignored{_status};         \
+  } while (false)
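
Illustrative usage sketch (not part of the patch above): a minimal example of how an internal kernel might be written and registered through these adapter types, assuming the KernelDefBuilder/KernelRegistry/OpKernel interfaces shown in this diff. The operator name "MyOp", its domain, the provider string, and the RegisterMyOpKernel helper are hypothetical placeholders.

// Sketch only -- MyOp, com.example, and RegisterMyOpKernel are hypothetical names.
#include "ep/_pch.h"

namespace onnxruntime {
namespace webgpu {

class MyOp final : public OpKernel {
 public:
  explicit MyOp(const OpKernelInfo& info) : OpKernel{info} {}

  Status Compute(OpKernelContext* context) const override {
    const Tensor* input = context->Input<Tensor>(0);   // adapter OpKernelContext lazily wraps the OrtValue
    Tensor* output = context->Output(0, input->Shape());
    // ... fill `output` from `input` ...
    ORT_UNUSED_PARAMETER(output);
    return Status::OK();
  }
};

// Registers MyOp with the adapter KernelRegistry. The capture-less lambda converts to
// KernelCreatePtrFn, and KernelRegistry::Register wires it up as an OrtKernelImpl factory.
inline Status RegisterMyOpKernel(KernelRegistry& registry) {
  KernelCreateInfo create_info{
      KernelDefBuilder::Create()
          ->SetName("MyOp")
          .SetDomain("com.example")
          .SinceVersion(1)
          .Provider("WebGpuExecutionProvider")  // placeholder: use the EP's actual provider type string
          .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
          .Build(),
      [](FuncManager& /*func_mgr*/, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
        out = std::make_unique<MyOp>(info);
        return Status::OK();
      }};
  return registry.Register(std::move(create_info));
}

} // namespace webgpu
} // namespace onnxruntime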