-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Out-Tree EP feature #21450
base: main
Are you sure you want to change the base?
[WIP] Out-Tree EP feature #21450
Changes from 48 commits
0e6a80c
c30a639
7bfe57e
8e7d28d
808bfc3
49e396c
e790105
92f529d
3d83ed1
e29499a
f3678c4
ac5ae0a
0cc78e8
740a687
dad6397
94e9cf7
8698517
3d5d2bf
1f10c28
5e46d0f
85c168d
7bdb36a
7d915b7
4aea94b
865a17f
2811541
c97b19f
36f97b5
2fc7aac
4ad6993
53c736f
5fcb972
c3bb437
d1c657c
3efac97
766fec9
ea2465c
76a9305
330cdb6
6fd50f0
681585f
7db20cb
ff782e0
1d7b2df
a407944
f871b25
e84f00c
5b2de22
b1f8e2a
7acaaab
d150a03
da5b6eb
d280e59
cbe98e7
1529059
fa549f8
a28ad38
aa49805
bc65613
a1a3eea
0fe5f01
6bae1b9
ab75d98
c5510f2
08e3f20
b0b3123
9dbb0b1
5a59803
999e7fd
084f735
2b1cfdf
e337d8f
afe92e1
63f8774
bf359a1
c267ea5
72afdc4
6822206
c8ddc73
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
#include "core/session/onnxruntime_c_api_ep.h" | ||
#include <unordered_map> | ||
#include <string> | ||
#include <set> | ||
|
||
struct OrtTypeConstraints { | ||
bool AddTypeConstraint(const char* type_symbol, ONNXTensorElementDataType type); | ||
inline const std::unordered_map<std::string, std::set<ONNXTensorElementDataType>>& GetTypeConstraints() const { return type_constraints_; }; | ||
private: | ||
std::unordered_map<std::string, std::set<ONNXTensorElementDataType>> type_constraints_; | ||
}; |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
#include "core/platform/threadpool.h" | ||
#include "core/common/logging/logging.h" | ||
#include "core/framework/allocator.h" | ||
#include "core/session/onnxruntime_c_api_ep.h" | ||
|
||
struct OrtThreadingOptions; | ||
namespace onnxruntime { | ||
|
@@ -88,6 +89,10 @@ class Environment { | |
*/ | ||
Status CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo& mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg = nullptr); | ||
|
||
void InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given SessionOptionsAppendOrtExecutionProvider allows the user to register the instance of the EP, when do we need this factory? #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is another C API RegisterOrtExecutionProviderLibrary which will load the shared library, create plugin EP factory and save it in the Environment. Please see the implementation of RegisterOrtExecutionProviderLibrary and the usage in test.cpp as examples There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed to a new Name. Hope it is more clear now. |
||
|
||
OrtExecutionProviderFactory* GetOrtExecutionProviderFactory(const std::string& ep_name); | ||
|
||
private: | ||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment); | ||
Status Initialize(std::unique_ptr<logging::LoggingManager> logging_manager, | ||
|
@@ -99,5 +104,6 @@ class Environment { | |
std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_; | ||
bool create_global_thread_pools_{false}; | ||
std::vector<AllocatorPtr> shared_allocators_; | ||
std::unordered_map<std::string, std::unique_ptr<OrtExecutionProviderFactory>> custom_ep_factories_; | ||
}; | ||
} // namespace onnxruntime |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -304,6 +304,12 @@ ORT_RUNTIME_CLASS(Op); | |
ORT_RUNTIME_CLASS(OpAttr); | ||
ORT_RUNTIME_CLASS(Logger); | ||
ORT_RUNTIME_CLASS(ShapeInferContext); | ||
ORT_RUNTIME_CLASS(KernelInfo); | ||
ORT_RUNTIME_CLASS(KernelContext); | ||
ORT_RUNTIME_CLASS(CustomOp); | ||
ORT_RUNTIME_CLASS(KernelRegistry); | ||
ORT_RUNTIME_CLASS(TypeConstraints); | ||
ORT_RUNTIME_CLASS(Device); | ||
|
||
#ifdef _WIN32 | ||
typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; | ||
|
@@ -364,13 +370,6 @@ typedef enum OrtLanguageProjection { | |
ORT_PROJECTION_NODEJS = 6, | ||
} OrtLanguageProjection; | ||
|
||
struct OrtKernelInfo; | ||
typedef struct OrtKernelInfo OrtKernelInfo; | ||
struct OrtKernelContext; | ||
typedef struct OrtKernelContext OrtKernelContext; | ||
struct OrtCustomOp; | ||
typedef struct OrtCustomOp OrtCustomOp; | ||
|
||
typedef enum OrtAllocatorType { | ||
OrtInvalidAllocator = -1, | ||
OrtDeviceAllocator = 0, | ||
|
@@ -395,6 +394,13 @@ typedef enum OrtMemoryInfoDeviceType { | |
OrtMemoryInfoDeviceType_FPGA = 2 | ||
} OrtMemoryInfoDeviceType; | ||
|
||
typedef enum OrtMemoryType { | ||
OrtMemoryType_Default = 0, | ||
OrtMemoryType_CUDA_PINNED = 1, | ||
OrtMemoryType_HIP_PINNED = 2, | ||
OrtMemoryType_CANN_PINNED = 3, | ||
} OrtMemoryType; | ||
|
||
Comment on lines
+397
to
+403
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we avoid having EP specific enum values in the public API? If 'pinned' equates to 'visible by device and host' there's hopefully a better place to plugin device specific info like CUDA/HIP/CANN than the high level memory type. |
||
/** \brief Algorithm to use for cuDNN Convolution Op | ||
*/ | ||
typedef enum OrtCudnnConvAlgoSearch { | ||
|
@@ -658,6 +664,9 @@ typedef struct OrtApi OrtApi; | |
struct OrtTrainingApi; | ||
typedef struct OrtTrainingApi OrtTrainingApi; | ||
|
||
struct OrtGraphApi; | ||
typedef struct OrtGraphApi OrtGraphApi; | ||
|
||
/** \brief The helper interface to get the right version of OrtApi | ||
* | ||
* Get a pointer to this structure through ::OrtGetApiBase | ||
|
@@ -4310,9 +4319,6 @@ struct OrtApi { | |
*/ | ||
const char*(ORT_API_CALL* GetBuildInfoString)(void); | ||
|
||
/// \name OrtROCMProviderOptions | ||
/// @{ | ||
|
||
/** \brief Create an OrtROCMProviderOptions | ||
* | ||
* \param[out] out Newly created ::OrtROCMProviderOptions. Must be released with OrtApi::ReleaseROCMProviderOptions | ||
|
@@ -4665,7 +4671,32 @@ struct OrtApi { | |
_In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array, | ||
_In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths, | ||
size_t num_external_initializer_files); | ||
}; | ||
|
||
ORT_API2_STATUS(CreateDevice, _In_ enum OrtMemoryInfoDeviceType device_type, _In_ enum OrtMemoryType memory_type, _In_ int16_t device_id, _Outptr_ const OrtDevice** out); | ||
|
||
ORT_API2_STATUS(DeviceGetDeviceType, _In_ const OrtDevice* device, _Out_ OrtMemoryInfoDeviceType* out); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please document all functions including 'since version ...' info. It's easier to review if the intent of the API is documented. nit: Does 'device' need to be in the function name twice? #Resolved |
||
|
||
ORT_API2_STATUS(DeviceGetMemoryType, _In_ const OrtDevice* device, _Out_ OrtMemoryType* out); | ||
|
||
ORT_API2_STATUS(DeviceGetDeviceId, _In_ const OrtDevice* device, _Out_ int16_t* out); | ||
|
||
ORT_CLASS_RELEASE(Device); | ||
|
||
ORT_API2_STATUS(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name); | ||
|
||
ORT_API2_STATUS(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would |
||
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); | ||
|
||
ORT_API2_STATUS(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: slightly more readable if this is after the type constraint functions given it takes OrtTypeConstaints as an input. #Resolved |
||
|
||
ORT_API2_STATUS(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints); | ||
|
||
ORT_API2_STATUS(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type); | ||
|
||
ORT_CLASS_RELEASE(TypeConstraints); | ||
|
||
const OrtGraphApi*(ORT_API_CALL* GetGraphApi)(uint32_t version)NO_EXCEPTION; | ||
}; // struct OrtApi | ||
|
||
/* | ||
* Steps to use a custom op: | ||
|
@@ -4693,6 +4724,13 @@ typedef enum OrtCustomOpInputOutputCharacteristic { | |
* the implementor of the custom op. | ||
*/ | ||
struct OrtCustomOp { | ||
#ifdef __cplusplus | ||
// TODO(leca): initialize all member function pointers to nullptr? | ||
OrtCustomOp() : CreateKernel{nullptr}, KernelCompute{nullptr}, KernelDestroy{nullptr}, GetInputCharacteristic{nullptr}, | ||
GetOutputCharacteristic{nullptr}, GetVariadicInputMinArity{nullptr}, GetVariadicOutputMinArity{nullptr}, | ||
GetStartVersion{nullptr}, GetEndVersion{nullptr}, GetMayInplace{nullptr}, ReleaseMayInplace{nullptr}, | ||
GetAliasMap{nullptr}, ReleaseAliasMap{nullptr} {} | ||
#endif | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What issue is this trying to address? It seems inconsistent to not initialize these values if __cplusplus is not defined. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the kernel based plugin EP which leverages OrtCustomOp to register the compute function, the member variables of OrtCustomOp will be randomly assigned if we don't use this way to do the initialization. Please check kernel_ep.cc and MemcpyFromHost in samples/tensorRTEp/tensorrt_execution_provider.cc for the detail. |
||
uint32_t version; // Must be initialized to ORT_API_VERSION | ||
|
||
// This callback creates the kernel, which is a user defined | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,185 @@ | ||||||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||||||
// Licensed under the MIT License. | ||||||
|
||||||
#pragma once | ||||||
Check warning Code scanning / lintrunner CLANGFORMAT/format Warning
See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch. |
||||||
#include "onnxruntime_c_api.h" | ||||||
|
||||||
ORT_RUNTIME_CLASS(ExecutionProvider); | ||||||
ORT_RUNTIME_CLASS(ExecutionProviderFactory); | ||||||
ORT_RUNTIME_CLASS(Node); | ||||||
ORT_RUNTIME_CLASS(Graph); | ||||||
ORT_RUNTIME_CLASS(GraphViewer); | ||||||
|
||||||
typedef struct OrtCreateStream { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The names of many things here feel like they have lost a lot of context. They're fine if you're converting an existing internal EP, but if someone was to try and write a new OOT EP I think it would be quite confusing. Things like 'stream', 'meta def' and 'indexed sub graph' don't (at least to me) have a clear intuitive meaning. Would it be better to view this from the perspective of an external EP author instead of always being 1:1 with internal names? |
||||||
int device_type; | ||||||
void*(ORT_API_CALL* CreateStreamFunc)(const OrtDevice*); | ||||||
} OrtCreateStream; | ||||||
|
||||||
typedef struct OrtMetaDef { | ||||||
char* name; | ||||||
char* domain; | ||||||
int since_version; | ||||||
|
||||||
char** inputs; | ||||||
size_t input_len; | ||||||
char** outputs; | ||||||
size_t output_len; | ||||||
char** constant_initializers; | ||||||
size_t initializer_len; | ||||||
|
||||||
char* doc_string; | ||||||
} OrtMetaDef; | ||||||
|
||||||
typedef struct OrtIndexedSubGraph { | ||||||
OrtMetaDef* meta_def; // TODO(leca): how to define a nested structure pointer? | ||||||
size_t* node_index; | ||||||
size_t node_index_len; | ||||||
} OrtIndexedSubGraph; | ||||||
|
||||||
typedef struct OrtComputeContext { | ||||||
void*(ORT_API_CALL* AllocateFunc)(void*, size_t, size_t); | ||||||
void(ORT_API_CALL* DestroyFunc)(void*, void*); | ||||||
void* allocator_handle; | ||||||
const char* node_name; | ||||||
} OrtComputeContext; | ||||||
|
||||||
typedef struct OrtNodeComputeInfo { | ||||||
int(ORT_API_CALL* CreateFunctionStateFunc)(OrtComputeContext*, void*, void**); | ||||||
OrtStatusPtr(ORT_API_CALL* ComputeFunc)(void*, void*, const OrtApi*, OrtKernelContext*); | ||||||
void(ORT_API_CALL* DestroyFunctionStateFunc)(void*); | ||||||
} OrtNodeComputeInfo; | ||||||
|
||||||
typedef struct OrtTensorRef { // TODO(leca): OrtValueInfoRef inside OrtTensorRef? | ||||||
int64_t* shape; | ||||||
size_t shape_len; | ||||||
ONNXTensorElementDataType data_type; | ||||||
const char* data; | ||||||
size_t data_len; | ||||||
} OrtTensorRef; | ||||||
|
||||||
typedef struct OrtValueInfoRef { | ||||||
int64_t* shape; | ||||||
size_t shape_len; | ||||||
ONNXTensorElementDataType data_type; | ||||||
} OrtValueInfoRef; | ||||||
|
||||||
typedef struct OrtExecutionProvider { | ||||||
#ifdef __cplusplus | ||||||
OrtExecutionProvider() : GetCapability{nullptr}, Compile{nullptr}, RegisterKernels{nullptr}, CanCopy{nullptr}, CopyTensor{nullptr}, CreatePreferredAllocators{nullptr}, type{nullptr}, create_stream{nullptr}, default_device{nullptr}, | ||||||
extra_param_for_create_state_func{nullptr}, extra_param_for_compute_func{nullptr} {} | ||||||
#endif | ||||||
void(ORT_API_CALL* GetCapability)(const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph***); | ||||||
OrtStatusPtr(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo* node_compute_info); | ||||||
void(ORT_API_CALL* RegisterKernels)(OrtKernelRegistry* kernel_registry); | ||||||
bool(ORT_API_CALL* CanCopy)(const OrtDevice* source, const OrtDevice* target); | ||||||
OrtStatusPtr(ORT_API_CALL* CopyTensor)(const void* src, OrtMemoryInfoDeviceType source_device_type, OrtMemoryType source_mem_type, void* dst, OrtMemoryInfoDeviceType target_device_type, size_t count, void* stream); | ||||||
int(ORT_API_CALL* CreatePreferredAllocators)(OrtExecutionProvider* this_, OrtAllocator*** ort_allocators); | ||||||
const char* type; | ||||||
OrtCreateStream* create_stream; | ||||||
const OrtDevice* default_device; | ||||||
void* extra_param_for_create_state_func; | ||||||
void* extra_param_for_compute_func; | ||||||
Comment on lines
+81
to
+82
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this may be able to be replaced with a single |
||||||
} OrtExecutionProvider; | ||||||
|
||||||
typedef struct OrtExecutionProviderFactory { | ||||||
OrtExecutionProvider*(ORT_API_CALL* CreateExecutionProvider)(OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size); | ||||||
} OrtExecutionProviderFactory; | ||||||
|
||||||
struct OrtGraphApi { | ||||||
const char*(ORT_API_CALL* OrtGraph_GetName)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; | ||||||
|
||||||
bool(ORT_API_CALL* OrtGraph_IsConstantInitializer)(const OrtGraphViewer* graph, const char* name, bool check_outer_scope)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; | ||||||
|
||||||
size_t(ORT_API_CALL* OrtGraph_GetNodesIndexInTopologicalOrder)(const OrtGraphViewer* graph, int execution_order, _Out_ const size_t** nodes_index_in_topological_order); | ||||||
|
||||||
bool(ORT_API_CALL* OrtGraph_IsSubgraph)(const OrtGraph* graph); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to expose OrtGraph? Can we add GetParentGraph in GridViewer so that we don't need expose OrtGraph? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would be nice if we didn't. Having Graph and GraphViewer is going to be confusing to an EP author. |
||||||
|
||||||
const OrtGraph*(ORT_API_CALL* OrtGraph_GetParentGraph)(const OrtGraph* graph); | ||||||
|
||||||
const OrtNode*(ORT_API_CALL* OrtGraph_GetParenNode)(const OrtGraphViewer* graph); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I don't get it, what are you suggesting to change? split one line to 3 lines? The signature of the C APIs has been changed, to return OrtStatus* in order to make it consistent with the existing C APIs |
||||||
|
||||||
const void*(ORT_API_CALL* OrtGraph_GetModelPath)(const OrtGraphViewer* graph); | ||||||
|
||||||
const OrtGraph*(ORT_API_CALL* OrtGraph_GetOrtGraph)(const OrtGraphViewer* graph_viewer); | ||||||
|
||||||
size_t(ORT_API_CALL* OrtGraph_GetInputsIncludingInitializers)(const OrtGraphViewer* graph, _Outptr_ const char*** input_names); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we maybe simplify? ONNX IR versions less than 4 had all initializers as part of the graph inputs so there was complexity/confusion around whether getting the graph inputs included the initializers or not, and when an initializer was in the list of graph inputs whether it was allowed to be overridden. That was a long time ago now. Would it be easier if the API was in terms of 'required inputs' (input with no backing initializer) and 'optional inputs' (input with backing initializer that can be overridden - i.e. a non-const initializer). 'all' returns the combined set of required an optional. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added the following APIs: OrtGraph_GetRequiredInputs: Gets the Graph inputs with no matching initializers OrtGraph_GetAllInputs: Gets the Graph inputs with matching initializers OrtGraph_GetAllInitializers: Gets all the Graph initializers' name and delete this one. Since the graph viewer class does not expose graph_overridable_initializers_ ('optional inputs'), The caller can invoke 'GetAllInputs' and 'GetRequiredInputs' and do the minus themselves Please review the implementation to check if I understand correctly |
||||||
|
||||||
const OrtNode*(ORT_API_CALL* OrtGraph_GetOrtNode)(const OrtGraphViewer* graph, size_t node_index); | ||||||
|
||||||
size_t(ORT_API_CALL* OrtGraph_GetNodesConsumingInput)(const OrtGraphViewer* graph, const char* input_name, _Outptr_ const OrtNode*** consumers); // TODO(leca): ValueConsumers::comprehensive ? | ||||||
|
||||||
const OrtNode*(ORT_API_CALL* OrtGraph_GetNodeProducingOutput)(const OrtGraphViewer* graph, const char* output_name); | ||||||
|
||||||
int(ORT_API_CALL* OrtGraph_NumberOfNodes)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; | ||||||
|
||||||
int(ORT_API_CALL* OrtGraph_MaxNodeIndex)(const OrtGraphViewer* graph); | ||||||
|
||||||
size_t(ORT_API_CALL* OrtGraph_GetOutputSize)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; | ||||||
|
||||||
const char*(ORT_API_CALL* OrtGraph_GetIthOutputName)(const OrtGraphViewer*, size_t i)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; | ||||||
|
||||||
int32_t(ORT_API_CALL* OrtGraph_GetIthOutputElemType)(const OrtGraphViewer*, size_t i)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; | ||||||
|
||||||
bool(ORT_API_CALL* OrtGraph_GetInitializerTensor)(const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef**); | ||||||
|
||||||
bool(ORT_API_CALL* OrtGraph_GetValueInfo)(const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef**); | ||||||
|
||||||
size_t(ORT_API_CALL* OrtGraph_SerializeToArray)(const OrtGraphViewer*, _Out_ void** data)NO_EXCEPTION; // TODO(leca): review and discuss | ||||||
|
||||||
ORT_API2_STATUS(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); // TODO(yang): review and discuss | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we add an additional graph api or feature to check whether the input of the subgraph has shape info? For TRT EP, TRT requires onnx model to have input shape provided, so that's why TRT EP checks the input shape of the subgraph, if not it suggests user to run symbolic_shape_infer.py first. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ideally this is via more basic functions. An API to get the graph input names, and an API to get the NodeArg entry (in some more easily digested format) for that value. The checking whether it has shape info or not seem EP specific and not something the general API needs to directly support. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have that already: OrtGraph_GetInputsIncludingInitializers (naming is subject to change per Scott's comment) and OrtGraph_GetValueInfo |
||||||
|
||||||
const char*(ORT_API_CALL* OrtNode_GetName)(const OrtNode* node); | ||||||
|
||||||
const char*(ORT_API_CALL* OrtNode_GetDescription)(const OrtNode* node); | ||||||
|
||||||
const char*(ORT_API_CALL* OrtNode_GetDomain)(const OrtNode* node); | ||||||
|
||||||
int(ORT_API_CALL* OrtNode_SinceVersion)(const OrtNode* node); | ||||||
|
||||||
const char*(ORT_API_CALL* OrtNode_GetExecutionProviderType)(const OrtNode* node); | ||||||
|
||||||
const char*(ORT_API_CALL* OrtNode_GetOpType)(const OrtNode* node); | ||||||
|
||||||
size_t(ORT_API_CALL* OrtNode_GetImplicitInputSize)(const OrtNode* node); | ||||||
|
||||||
const char*(ORT_API_CALL* OrtNode_GetIthImplicitInputName)(const OrtNode* node, size_t i); | ||||||
|
||||||
size_t(ORT_API_CALL* OrtNode_GetInputSize)(const OrtNode* node); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Slightly unclear if these are returning the size of a specific input/output. Would GetNumImplicitInputs/GetNumInputs/GetNumOutputs be better? #Resolved |
||||||
|
||||||
const char*(ORT_API_CALL* OrtNode_GetIthInputName)(const OrtNode* node, size_t i); | ||||||
|
||||||
size_t(ORT_API_CALL* OrtNode_GetOutputSize)(const OrtNode* node); | ||||||
|
||||||
const char*(ORT_API_CALL* OrtNode_GetIthOutputName)(const OrtNode* node, size_t i); | ||||||
|
||||||
size_t(ORT_API_CALL* OrtNode_GetIndex)(const OrtNode* node); | ||||||
|
||||||
size_t(ORT_API_CALL* OrtNode_GetAttributeNames)(const OrtNode*, _Out_ const char*** names); | ||||||
|
||||||
size_t(ORT_API_CALL* OrtNode_GetAttributeSize)(const OrtNode* node); | ||||||
|
||||||
int(ORT_API_CALL* OrtNode_GetAttributeType)(const OrtNode* node, const char* attribute)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; // AttributeProto_AttributeType | ||||||
|
||||||
size_t(ORT_API_CALL* OrtNode_GetAttributeKeyCount)(const OrtNode* node, const char* key); | ||||||
|
||||||
int(ORT_API_CALL* OrtNode_GetAttributeIntSize)(const OrtNode* node, const char* key); | ||||||
|
||||||
int(ORT_API_CALL* OrtNode_GetAttributeFloatSize)(const OrtNode* node, const char* key); | ||||||
|
||||||
int(ORT_API_CALL* OrtNode_GetAttributeStringSize)(const OrtNode* node, const char* key); | ||||||
|
||||||
int64_t(ORT_API_CALL* OrtNode_GetAttributeIthInt)(const OrtNode* node, const char* key, int i); | ||||||
|
||||||
float(ORT_API_CALL* OrtNode_GetAttributeIthFloat)(const OrtNode* node, const char* key, int i); | ||||||
|
||||||
const char*(ORT_API_CALL* OrtNode_GetAttributeIthStr)(const OrtNode* node, const char* key, int i); | ||||||
|
||||||
const char*(ORT_API_CALL* OrtNode_GetAttributeStr)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; | ||||||
|
||||||
int64_t(ORT_API_CALL* OrtNode_GetAttributeInt)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; | ||||||
|
||||||
float(ORT_API_CALL* OrtNode_GetAttributeFloat)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; | ||||||
|
||||||
size_t(ORT_API_CALL* OrtNode_GetSubgraphs)(const OrtNode* node, _Outptr_ const OrtGraphViewer*** subgraphs); | ||||||
}; | ||||||
typedef struct OrtGraphApi OrtGraphApi; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
Check warning Code scanning / lintrunner CLANGFORMAT/format Warning
See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch. |
||
// Licensed under the MIT License. | ||
|
||
#include "core/framework/ort_type_constraints.h" | ||
|
||
bool OrtTypeConstraints::AddTypeConstraint(const char* type_symbol, ONNXTensorElementDataType type) { | ||
std::unordered_map<std::string, std::set<ONNXTensorElementDataType>>::iterator iter = type_constraints_.find(type_symbol); | ||
if (iter == type_constraints_.end()) { | ||
std::set<ONNXTensorElementDataType> types{type}; | ||
type_constraints_[type_symbol] = types; | ||
return true; | ||
} | ||
return (iter->second).insert(type).second; | ||
} |
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning