-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Out-Tree EP feature #21450
base: main
Are you sure you want to change the base?
[WIP] Out-Tree EP feature #21450
Changes from all commits
0e6a80c
c30a639
7bfe57e
8e7d28d
808bfc3
49e396c
e790105
92f529d
3d83ed1
e29499a
f3678c4
ac5ae0a
0cc78e8
740a687
dad6397
94e9cf7
8698517
3d5d2bf
1f10c28
5e46d0f
85c168d
7bdb36a
7d915b7
4aea94b
865a17f
2811541
c97b19f
36f97b5
2fc7aac
4ad6993
53c736f
5fcb972
c3bb437
d1c657c
3efac97
766fec9
ea2465c
76a9305
330cdb6
6fd50f0
681585f
7db20cb
ff782e0
1d7b2df
a407944
f871b25
e84f00c
5b2de22
b1f8e2a
7acaaab
d150a03
da5b6eb
d280e59
cbe98e7
1529059
fa549f8
a28ad38
aa49805
bc65613
a1a3eea
0fe5f01
6bae1b9
ab75d98
c5510f2
08e3f20
b0b3123
9dbb0b1
5a59803
999e7fd
084f735
2b1cfdf
e337d8f
afe92e1
63f8774
bf359a1
c267ea5
72afdc4
6822206
c8ddc73
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
#include "core/session/onnxruntime_c_api_ep.h" | ||
#include <unordered_map> | ||
#include <string> | ||
#include <set> | ||
|
||
struct OrtTypeConstraints { | ||
bool AddTypeConstraint(const char* type_symbol, ONNXTensorElementDataType type); | ||
inline const std::unordered_map<std::string, std::set<ONNXTensorElementDataType>>& GetTypeConstraints() const { return type_constraints_; }; | ||
private: | ||
std::unordered_map<std::string, std::set<ONNXTensorElementDataType>> type_constraints_; | ||
}; |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -304,6 +304,12 @@ ORT_RUNTIME_CLASS(Op); | |
ORT_RUNTIME_CLASS(OpAttr); | ||
ORT_RUNTIME_CLASS(Logger); | ||
ORT_RUNTIME_CLASS(ShapeInferContext); | ||
ORT_RUNTIME_CLASS(KernelInfo); | ||
ORT_RUNTIME_CLASS(KernelContext); | ||
ORT_RUNTIME_CLASS(CustomOp); | ||
ORT_RUNTIME_CLASS(KernelRegistry); | ||
ORT_RUNTIME_CLASS(TypeConstraints); | ||
ORT_RUNTIME_CLASS(Device); | ||
|
||
#ifdef _WIN32 | ||
typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; | ||
|
@@ -364,13 +370,6 @@ typedef enum OrtLanguageProjection { | |
ORT_PROJECTION_NODEJS = 6, | ||
} OrtLanguageProjection; | ||
|
||
struct OrtKernelInfo; | ||
typedef struct OrtKernelInfo OrtKernelInfo; | ||
struct OrtKernelContext; | ||
typedef struct OrtKernelContext OrtKernelContext; | ||
struct OrtCustomOp; | ||
typedef struct OrtCustomOp OrtCustomOp; | ||
|
||
typedef enum OrtAllocatorType { | ||
OrtInvalidAllocator = -1, | ||
OrtDeviceAllocator = 0, | ||
|
@@ -395,6 +394,13 @@ typedef enum OrtMemoryInfoDeviceType { | |
OrtMemoryInfoDeviceType_FPGA = 2 | ||
} OrtMemoryInfoDeviceType; | ||
|
||
typedef enum OrtMemoryType { | ||
OrtMemoryType_Default = 0, | ||
OrtMemoryType_CUDA_PINNED = 1, | ||
OrtMemoryType_HIP_PINNED = 2, | ||
OrtMemoryType_CANN_PINNED = 3, | ||
} OrtMemoryType; | ||
|
||
Comment on lines
+397
to
+403
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we avoid having EP specific enum values in the public API? If 'pinned' equates to 'visible by device and host' there's hopefully a better place to plugin device specific info like CUDA/HIP/CANN than the high level memory type. |
||
/** \brief Algorithm to use for cuDNN Convolution Op | ||
*/ | ||
typedef enum OrtCudnnConvAlgoSearch { | ||
|
@@ -658,6 +664,9 @@ typedef struct OrtApi OrtApi; | |
struct OrtTrainingApi; | ||
typedef struct OrtTrainingApi OrtTrainingApi; | ||
|
||
struct OrtGraphApi; | ||
typedef struct OrtGraphApi OrtGraphApi; | ||
|
||
/** \brief The helper interface to get the right version of OrtApi | ||
* | ||
* Get a pointer to this structure through ::OrtGetApiBase | ||
|
@@ -4310,9 +4319,6 @@ struct OrtApi { | |
*/ | ||
const char*(ORT_API_CALL* GetBuildInfoString)(void); | ||
|
||
/// \name OrtROCMProviderOptions | ||
/// @{ | ||
|
||
/** \brief Create an OrtROCMProviderOptions | ||
* | ||
* \param[out] out Newly created ::OrtROCMProviderOptions. Must be released with OrtApi::ReleaseROCMProviderOptions | ||
|
@@ -4665,7 +4671,128 @@ struct OrtApi { | |
_In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array, | ||
_In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths, | ||
size_t num_external_initializer_files); | ||
}; | ||
|
||
/** \brief Create OrtDevice object. | ||
* | ||
* \param[in] device_type | ||
* \param[in] memory_type | ||
* \param[in] device_id | ||
* \param[out] out OrtDevice object | ||
* | ||
* \since Version 1.xx. | ||
*/ | ||
ORT_API2_STATUS(CreateDevice, _In_ enum OrtMemoryInfoDeviceType device_type, _In_ enum OrtMemoryType memory_type, _In_ int16_t device_id, _Outptr_ const OrtDevice** out); | ||
|
||
/** \brief Get OrtMemoryInfoDeviceType property from OrtDevice object. | ||
* | ||
* \param[in] device OrtDevice object | ||
* \param[out] out OrtMemoryInfoDeviceType property | ||
* | ||
* \since Version 1.xx. | ||
*/ | ||
ORT_API2_STATUS(DeviceGetType, _In_ const OrtDevice* device, _Out_ OrtMemoryInfoDeviceType* out); | ||
|
||
/** \brief Get OrtMemoryType property from OrtDevice object. | ||
* | ||
* \param[in] device OrtDevice object | ||
* \param[out] out OrtMemoryType property | ||
* | ||
* \since Version 1.xx. | ||
*/ | ||
ORT_API2_STATUS(DeviceGetMemoryType, _In_ const OrtDevice* device, _Out_ OrtMemoryType* out); | ||
|
||
/** \brief Get device id property from OrtDevice object. | ||
* | ||
* \param[in] device OrtDevice object | ||
* \param[out] out device id property | ||
* | ||
* \since Version 1.xx. | ||
*/ | ||
ORT_API2_STATUS(DeviceGetId, _In_ const OrtDevice* device, _Out_ int16_t* out); | ||
|
||
/** \brief Release OrtDevice object. | ||
* | ||
* \since Version 1.xx. | ||
*/ | ||
ORT_CLASS_RELEASE(Device); | ||
|
||
/** \brief Register the plugin ExecutionProvider library | ||
* | ||
* The plugin ExecutionProvider library will be loaded and EP factory object will be created and saved in OrtEnv object | ||
* | ||
* \param[in] lib_path the path of the plugin ExecutionProvider library | ||
* \param[in] env OrtEnv object | ||
* \param[in] ep_name the plugin ExecutionProvider name | ||
* | ||
* \since Version 1.xx. | ||
*/ | ||
ORT_API2_STATUS(RegisterPluginExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name); | ||
|
||
/** \brief Unregister the plugin ExecutionProvider library | ||
* | ||
* The plugin ExecutionProvider factory will be removed from OrtEnv object | ||
* | ||
* \param[in] env OrtEnv object | ||
* \param[in] ep_name the plugin ExecutionProvider name | ||
* | ||
* \since Version 1.xx. | ||
*/ | ||
ORT_API2_STATUS(UnregisterPluginExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* ep_name); | ||
|
||
/** \brief Append the plugin ExecutionProvider factory into the session option with provider options | ||
* | ||
* \param[in] options OrtSessionOptions object | ||
* \param[in] ep_name the plugin ExecutionProvider name | ||
* \param[in] env OrtEnv object | ||
* \param[in] provider_options_keys provider options' keys | ||
* \param[in] provider_options_values provider options' values | ||
* \param[in] num_keys the number of the provider options' key-value pairs | ||
* | ||
* \since Version 1.xx. | ||
*/ | ||
ORT_API2_STATUS(SessionOptionsAppendPluginExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this can be removed? We may be able to use the existing C API function called SessionOptionsAppendExecutionProvider. The existing C API does not take an OrtEnv parameter, but we can just get the default OrtEnv since there is only one per process. |
||
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); | ||
|
||
/** \brief Create OrtTypeConstraints object | ||
* | ||
* \param[out] OrtTypeConstraints object | ||
* | ||
* \since Version 1.xx. | ||
*/ | ||
ORT_API2_STATUS(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints); | ||
|
||
/** \brief Add a specific type constraint into OrtTypeConstraints object | ||
* | ||
* \param[in] type_constraints OrtTypeConstraints object | ||
* \param[in] type_symbol symbol string to represent a specific type | ||
* \param[in] type a specific type | ||
* | ||
* \since Version 1.xx. | ||
*/ | ||
ORT_API2_STATUS(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type); | ||
|
||
/** \brief Release OrtTypeConstraints object. | ||
* | ||
* \since Version 1.xx. | ||
*/ | ||
ORT_CLASS_RELEASE(TypeConstraints); | ||
|
||
/** \brief Create KernelCreateInfo with custom op and type constraints, and register it | ||
* | ||
* \param[in] kernel_registry Opaque pointer of KernelRegistry object | ||
* \param[in] custom_op Custom Op where the kernel compute function is defined | ||
* \param[in] type_constraints | ||
* | ||
* \since Version 1.xx. | ||
*/ | ||
ORT_API2_STATUS(OrtKernelRegistry_RegisterKernel, _In_ OrtKernelRegistry* kernel_registry, _In_ OrtCustomOp* custom_op, _In_ OrtTypeConstraints* type_constraints); | ||
|
||
/** \brief Get Graph API | ||
* | ||
* \since Version 1.xx. | ||
*/ | ||
const OrtGraphApi*(ORT_API_CALL* GetGraphApi)(uint32_t version)NO_EXCEPTION; | ||
}; // struct OrtApi | ||
|
||
/* | ||
* Steps to use a custom op: | ||
|
@@ -4693,6 +4820,15 @@ typedef enum OrtCustomOpInputOutputCharacteristic { | |
* the implementor of the custom op. | ||
*/ | ||
struct OrtCustomOp { | ||
#ifdef __cplusplus | ||
OrtCustomOp() : CreateKernel{nullptr}, GetName{nullptr}, GetExecutionProviderType{nullptr}, GetInputType{nullptr}, | ||
GetInputTypeCount{nullptr}, GetOutputType{nullptr}, GetOutputTypeCount{nullptr}, KernelCompute{nullptr}, | ||
KernelDestroy{nullptr}, GetInputCharacteristic{nullptr}, GetOutputCharacteristic{nullptr}, | ||
GetInputMemoryType{nullptr}, GetVariadicInputMinArity{nullptr}, GetVariadicInputHomogeneity{nullptr}, | ||
GetVariadicOutputMinArity{nullptr}, GetVariadicOutputHomogeneity{nullptr}, CreateKernelV2{nullptr}, | ||
KernelComputeV2{nullptr}, InferOutputShapeFn{nullptr}, GetStartVersion{nullptr}, GetEndVersion{nullptr}, | ||
GetMayInplace{nullptr}, ReleaseMayInplace{nullptr}, GetAliasMap{nullptr}, ReleaseAliasMap{nullptr} {} | ||
#endif | ||
uint32_t version; // Must be initialized to ORT_API_VERSION | ||
|
||
// This callback creates the kernel, which is a user defined | ||
|
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning