7 changes: 7 additions & 0 deletions include/onnxruntime/ep/README.md
@@ -0,0 +1,7 @@
## EP adapter

This folder contains a set of C++ header files. They allow ONNX Runtime's internal kernel-based EPs to use the plugin-style EP API while keeping changes to the existing code minimal.

### Usage

Make sure "ep/_pch.h" is the first include in every source file of the implementation. Using it as a precompiled header (PCH) is recommended.
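
For illustration, a kernel source file in such an EP might be structured like the sketch below. The class, its method bodies, and the assumption that the adapter's OpKernel mirrors the internal interface are all hypothetical; the point is where the PCH fits in.

```cpp
// my_op.cc -- hypothetical kernel source in the WebGPU EP.
// "ep/_pch.h" must be the first include so the adapter aliases are in effect.
#include "ep/_pch.h"

namespace onnxruntime {
namespace webgpu {

// OpKernel, OpKernelInfo and OpKernelContext resolve to the adapter types
// injected by the PCH when building as a plugin-style EP.
class MyOp final : public OpKernel {
 public:
  explicit MyOp(const OpKernelInfo& info) : OpKernel(info) {}

  Status Compute(OpKernelContext* context) const override {
    // ... actual kernel logic elided ...
    return Status::OK();
  }
};

}  // namespace webgpu
}  // namespace onnxruntime
```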
61 changes: 61 additions & 0 deletions include/onnxruntime/ep/_pch.h
@@ -0,0 +1,61 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "api.h"

Check warning on line 6 in include/onnxruntime/ep/_pch.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: include/onnxruntime/ep/_pch.h:6: Include the directory when naming header files [build/include_subdir] [4]
#include "common.h"

// This header is only used when building the WebGPU/CUDA EP as a shared library.
//
// It is intended to be used as a precompiled header, so it must always be included first.

#pragma push_macro("ORT_EP_API_ADAPTER_HEADER_INCLUDED")
#define ORT_EP_API_ADAPTER_HEADER_INCLUDED

#include "adapter/allocator.h"
#include "adapter/logging.h"
#include "adapter/ep.h"
#include "adapter/kernel_registry.h"

#pragma pop_macro("ORT_EP_API_ADAPTER_HEADER_INCLUDED")

//
// EP specific using declarations
//

#define EP_SPECIFIC_USING_DECLARATIONS \
using FuncManager = onnxruntime::ep::adapter::FuncManager; \
using KernelCreatePtrFn = onnxruntime::ep::adapter::KernelCreatePtrFn; \
using KernelDefBuilder = onnxruntime::ep::adapter::KernelDefBuilder; \
using KernelRegistry = onnxruntime::ep::adapter::KernelRegistry; \
using KernelCreateInfo = onnxruntime::ep::adapter::KernelCreateInfo; \
using BuildKernelCreateInfoFn = onnxruntime::ep::adapter::KernelCreateInfo (*)(); \
using OpKernelInfo = onnxruntime::ep::adapter::OpKernelInfo; \
using OpKernelContext = onnxruntime::ep::adapter::OpKernelContext; \
using OpKernel = onnxruntime::ep::adapter::OpKernel; \
using DataTransferManager = onnxruntime::ep::adapter::DataTransferManager; \
namespace logging { \
using Logger = onnxruntime::ep::adapter::Logger; \
}

namespace onnxruntime {
namespace webgpu {
EP_SPECIFIC_USING_DECLARATIONS
} // namespace webgpu
namespace cuda {
EP_SPECIFIC_USING_DECLARATIONS
} // namespace cuda

#ifndef DISABLE_CONTRIB_OPS
namespace contrib {
namespace webgpu {
EP_SPECIFIC_USING_DECLARATIONS
} // namespace webgpu
namespace cuda {
EP_SPECIFIC_USING_DECLARATIONS
} // namespace cuda
} // namespace contrib
#endif

} // namespace onnxruntime
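
Adding another internal EP to the adapter would follow the same pattern: introduce its namespaces with the macro. A sketch, using a hypothetical `rocm` namespace:

```cpp
namespace onnxruntime {
namespace rocm {
EP_SPECIFIC_USING_DECLARATIONS
}  // namespace rocm

#ifndef DISABLE_CONTRIB_OPS
namespace contrib {
namespace rocm {
EP_SPECIFIC_USING_DECLARATIONS
}  // namespace rocm
}  // namespace contrib
#endif
}  // namespace onnxruntime
```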
45 changes: 45 additions & 0 deletions include/onnxruntime/ep/adapter/allocator.h
@@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/framework/allocator.h"

namespace onnxruntime {
namespace ep {
namespace adapter {

/// <summary>
/// A bridge class between the EP API OrtAllocator and an IAllocator implementation.
/// </summary>
class Allocator : public OrtAllocator {
public:
explicit Allocator(AllocatorPtr impl) : OrtAllocator{}, impl_(impl) {
version = ORT_API_VERSION;
Alloc = AllocImpl;
Free = FreeImpl;
Info = InfoImpl;
}

private:
static void* ORT_API_CALL AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept {
auto* allocator = static_cast<Allocator*>(this_ptr);
return allocator->impl_->Alloc(size);
}

static void ORT_API_CALL FreeImpl(OrtAllocator* this_ptr, void* p) noexcept {
auto* allocator = static_cast<Allocator*>(this_ptr);
allocator->impl_->Free(p);
}

static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const OrtAllocator* this_ptr) noexcept {
auto* allocator = static_cast<const Allocator*>(this_ptr);
return &allocator->impl_->Info();
}

AllocatorPtr impl_;
};

} // namespace adapter
} // namespace ep
} // namespace onnxruntime
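
A usage sketch: an EP that already holds an onnxruntime::AllocatorPtr can expose it across the EP API boundary through this bridge. The surrounding function and the allocation size are illustrative.

```cpp
#include "core/framework/allocator.h"

void UseBridgedAllocator(onnxruntime::AllocatorPtr device_allocator) {
  // The bridge keeps the IAllocator and forwards the OrtAllocator callbacks
  // (Alloc / Free / Info) to it.
  onnxruntime::ep::adapter::Allocator bridge{device_allocator};
  OrtAllocator* ort_allocator = &bridge;

  void* buffer = ort_allocator->Alloc(ort_allocator, 256);
  const OrtMemoryInfo* mem_info = ort_allocator->Info(ort_allocator);
  static_cast<void>(mem_info);  // e.g. inspect the device/location here
  ort_allocator->Free(ort_allocator, buffer);
}
```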
53 changes: 53 additions & 0 deletions include/onnxruntime/ep/adapter/data_transfer_manager.h
@@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
#error "This header should not be included directly. Include ep/_pch.h instead."
#endif

#include "core/common/status.h"
#include "core/common/common.h"
#include "core/framework/data_transfer.h"
#include "core/framework/tensor.h"

namespace onnxruntime {
namespace ep {
namespace adapter {

/// <summary>
/// A bridge class that exposes an IDataTransfer implementation through the
/// interface of the internal DataTransferManager (i.e. CopyTensor).
/// </summary>
struct DataTransferManager {
explicit DataTransferManager(std::unique_ptr<IDataTransfer> impl) : impl_{std::move(impl)} {}

common::Status CopyTensor(const Tensor& src, Tensor& dst) const {
if (src.Shape().Size() != dst.Shape().Size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME,
FAIL,
"Tensor size mismatch: source tensor size is ",
src.Shape().Size(),
", destination tensor size is ",
dst.Shape().Size());
}

if (impl_->CanCopy(src.Location().device, dst.Location().device)) {
return impl_->CopyTensor(src, dst);
}

return ORT_MAKE_STATUS(ONNXRUNTIME,
FAIL,
"There's no data transfer registered for copying tensors from ",
src.Location().device.ToString(),
" to ",
dst.Location().device.ToString());
}

private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DataTransferManager);
std::unique_ptr<IDataTransfer> impl_;
};

} // namespace adapter
} // namespace ep
} // namespace onnxruntime
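
A usage sketch, assuming the EP's IDataTransfer and the two tensors already exist:

```cpp
#include <memory>
#include <utility>

onnxruntime::common::Status CopyExample(std::unique_ptr<onnxruntime::IDataTransfer> transfer,
                                        const onnxruntime::Tensor& src,
                                        onnxruntime::Tensor& dst) {
  // The manager takes ownership of the IDataTransfer, checks that the element
  // counts match and that the devices are copyable, then delegates the copy.
  onnxruntime::ep::adapter::DataTransferManager manager{std::move(transfer)};
  return manager.CopyTensor(src, dst);
}
```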
58 changes: 58 additions & 0 deletions include/onnxruntime/ep/adapter/ep.h
@@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
#error "This header should not be included directly. Include ep/_pch.h instead."
#endif

#include "data_transfer_manager.h"

#include "core/framework/execution_provider.h"

namespace onnxruntime {
namespace ep {
namespace adapter {

/// <summary>
/// Wrapper around IExecutionProvider to expose via OrtEp.
/// </summary>
class Ep : public OrtEp {
protected:
explicit Ep(IExecutionProvider* impl, AllocatorPtr temp_space_cpu_allocator, AllocatorPtr temp_space_allocator)
: OrtEp{},
impl_(impl),
data_transfer_manager_{impl->GetDataTransfer()},
profiler_{impl->GetProfiler()},
temp_space_cpu_allocator_{temp_space_cpu_allocator},
temp_space_allocator_{temp_space_allocator} {
}

public:
inline IExecutionProvider* EpImpl() const noexcept {
return impl_.get();
}
inline const DataTransferManager& GetDataTransferManager() const noexcept {
return data_transfer_manager_;
}
[[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const {
*output = temp_space_cpu_allocator_;
return Status::OK();
}
[[nodiscard]] Status GetTempSpaceAllocator(AllocatorPtr* output) const {
*output = temp_space_allocator_;
return Status::OK();
}

private:
std::unique_ptr<IExecutionProvider> impl_;
DataTransferManager data_transfer_manager_;
std::unique_ptr<profiling::EpProfiler> profiler_;
AllocatorPtr temp_space_cpu_allocator_;
AllocatorPtr temp_space_allocator_;
};

} // namespace adapter
} // namespace ep
} // namespace onnxruntime
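
Because the constructor is protected, a concrete EP derives from the adapter. A minimal sketch follows; the class name and the way the OrtEp callbacks get wired up are illustrative. Note that the adapter takes ownership of `impl` (it is stored in a unique_ptr).

```cpp
class WebGpuPluginEp : public onnxruntime::ep::adapter::Ep {
 public:
  WebGpuPluginEp(onnxruntime::IExecutionProvider* impl,
                 onnxruntime::AllocatorPtr cpu_allocator,
                 onnxruntime::AllocatorPtr device_allocator)
      : Ep(impl, cpu_allocator, device_allocator) {
    // Fill in the OrtEp function pointers required by the plugin EP API here.
    // Kernel code reaches the wrapped IExecutionProvider, the data transfer
    // manager, and the temp-space allocators through the accessors on Ep.
  }
};
```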
140 changes: 140 additions & 0 deletions include/onnxruntime/ep/adapter/kernel_def_builder.h
@@ -0,0 +1,140 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
#error "This header should not be included directly. Include ep/_pch.h instead."
#endif

#include <climits>
#include <memory>
#include <utility>
#include <vector>

#include "core/framework/data_types.h"

namespace onnxruntime {
namespace ep {
namespace adapter {

/// <summary>
/// Gets the OrtDataType for a tensor with the given element type. Throws on error.
/// </summary>
/// <param name="elem_type">The tensor element data type.</param>
/// <returns>The corresponding OrtDataType.</returns>
inline const OrtDataType* GetTensorType(ONNXTensorElementDataType elem_type) {
const OrtEpApi& ep_api = Ort::GetEpApi();
const OrtDataType* result = nullptr;

Ort::ThrowOnError(ep_api.GetTensorDataType(elem_type, &result));
return result;
}

inline const OrtDataType* MLDataTypeToOrtDataType(MLDataType ml_type) {
auto tensor_type = ml_type->AsTensorType();
EP_ENFORCE(tensor_type != nullptr, "EP Kernel registration only supports tensor types.");
auto elem_type = tensor_type->GetElementType();
auto primitive_type = static_cast<const PrimitiveDataTypeBase*>(elem_type);
auto onnx_type = static_cast<ONNXTensorElementDataType>(primitive_type->GetDataType());
return GetTensorType(onnx_type);
}

struct KernelDefBuilder {
static std::unique_ptr<KernelDefBuilder> Create() { return std::make_unique<KernelDefBuilder>(); }

KernelDefBuilder() = default;

KernelDefBuilder& SetName(const char* op_name) {
builder_.SetOperatorType(op_name);
return *this;
}

KernelDefBuilder& SetDomain(const char* domain) {
builder_.SetDomain(domain);
return *this;
}

KernelDefBuilder& SinceVersion(int since_version) {
return SinceVersion(since_version, INT_MAX);
}

KernelDefBuilder& SinceVersion(int since_version_start, int since_version_end) {
builder_.SetSinceVersion(since_version_start, since_version_end);
return *this;
}

KernelDefBuilder& Provider(const char* provider_type) {
builder_.SetExecutionProvider(provider_type);
return *this;
}

KernelDefBuilder& TypeConstraint(const char* arg_name, std::vector<MLDataType> types) {
std::vector<const OrtDataType*> ort_types;
ort_types.reserve(types.size());
for (const auto& type : types) {
ort_types.push_back(MLDataTypeToOrtDataType(type));
}
builder_.AddTypeConstraint(arg_name, ort_types);
return *this;
}

KernelDefBuilder& TypeConstraint(const char* arg_name, MLDataType type) {
builder_.AddTypeConstraint(arg_name, MLDataTypeToOrtDataType(type));
return *this;
}

KernelDefBuilder& MayInplace(const std::vector<std::pair<int, int>>& inplaces) {
for (const auto& pair : inplaces) {
builder_.AddInputOutputMutableAlias(pair.first, pair.second);
}
return *this;
}
KernelDefBuilder& MayInplace(int input_index, int output_index) {
builder_.AddInputOutputMutableAlias(input_index, output_index);
return *this;
}

KernelDefBuilder& Alias(const std::vector<std::pair<int, int>>& aliases) {
for (const auto& pair : aliases) {
builder_.AddInputOutputAlias(pair.first, pair.second);
}
return *this;
}
KernelDefBuilder& Alias(int input_index, int output_index) {
builder_.AddInputOutputAlias(input_index, output_index);
return *this;
}

KernelDefBuilder& InputMemoryType(OrtMemType type, int input_index) {
builder_.SetInputMemType(input_index, type);
return *this;
}

KernelDefBuilder& InputMemoryType(OrtMemType type, const std::vector<int>& input_indexes) {
for (int input_index : input_indexes) {
builder_.SetInputMemType(input_index, type);
}
return *this;
}

KernelDefBuilder& OutputMemoryType(OrtMemType type, int output_index) {
builder_.SetOutputMemType(output_index, type);
return *this;
}

KernelDefBuilder& OutputMemoryType(OrtMemType type, const std::vector<int>& output_indexes) {
for (int output_index : output_indexes) {
builder_.SetOutputMemType(output_index, type);
}
return *this;
}

// Accepted for source compatibility with the internal KernelDefBuilder; queue IDs have no effect in the adapter.
KernelDefBuilder& ExecQueueId(int /*queue_id*/) { return *this; }

Ort::KernelDef Build() { return builder_.Build(); }

private:
Ort::KernelDefBuilder builder_;
};

} // namespace adapter
} // namespace ep
} // namespace onnxruntime
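
The builder mirrors the fluent interface of the internal onnxruntime::KernelDefBuilder, so existing registration code keeps its shape. A sketch; the operator, opset range, provider type string, and type constraint are illustrative:

```cpp
Ort::KernelDef BuildExampleKernelDef() {
  auto builder = onnxruntime::ep::adapter::KernelDefBuilder::Create();
  return builder->SetName("Cast")
      .SetDomain("")  // default ONNX domain
      .SinceVersion(19)
      .Provider("WebGpuExecutionProvider")  // provider type string is illustrative
      .TypeConstraint("T1", onnxruntime::DataTypeImpl::GetTensorType<float>())
      .InputMemoryType(OrtMemTypeCPUInput, 1)
      .Build();
}
```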