Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Out-Tree EP feature #21450

Draft
wants to merge 79 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 48 commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
0e6a80c
opaque pointer for graph
jslhcl Jul 17, 2024
c30a639
ORT C API RegisterOrtExecutionProviderLibrary work
jslhcl Jul 23, 2024
7bfe57e
ORT C-API SessionOptionsAppendOrtExecutionProvider work
jslhcl Jul 23, 2024
8e7d28d
Test Relu with compile based EP, build work, runtime error of loading…
jslhcl Jul 26, 2024
808bfc3
prototype works with hardcode node_compute_info's index in ExecutionP…
jslhcl Jul 29, 2024
49e396c
prototype works without hardcode
jslhcl Jul 29, 2024
e790105
fix comments for Compile function
jslhcl Jul 31, 2024
92f529d
add provider_factory_adapter.h
jslhcl Aug 1, 2024
3d83ed1
fix crash after introducing kernel based EP
jslhcl Aug 5, 2024
e29499a
kernel based EP work with type constraint check commented out
jslhcl Aug 6, 2024
f3678c4
add kernel type constraints from out tree EP
jslhcl Aug 7, 2024
ac5ae0a
add API ReleaseOrtTypeConstraints
jslhcl Aug 7, 2024
0cc78e8
introduce qnn ep
jslhcl Aug 12, 2024
740a687
more graph/node C API
jslhcl Aug 13, 2024
dad6397
stream support
jslhcl Aug 15, 2024
94e9cf7
support data transfer and OrtDevice in out tree EP API
jslhcl Aug 16, 2024
8698517
change compile return type from void to OrtStatusPtr
jslhcl Aug 20, 2024
3d5d2bf
add TensorRT dependency in tensorRT EP's CMakeLists.txt
jslhcl Aug 20, 2024
1f10c28
Add extra parameters in OrtExecutionProvider to avoid capture variabl…
jslhcl Aug 22, 2024
5e46d0f
add OrtGraph_SerializeToArray
jslhcl Aug 23, 2024
85c168d
finish Compile function
jslhcl Aug 24, 2024
7bdb36a
add override function implementation and cudart dependency for tensorrt
jslhcl Aug 26, 2024
7d915b7
add outOfTree tensorrt ep.1 (#21830)
guyang3532 Aug 27, 2024
4aea94b
GetSupportedList
jslhcl Aug 28, 2024
865a17f
GetSubGraph and TensorrtExecutionProviderInfo
jslhcl Aug 29, 2024
2811541
Add simple CUDA allocators for TRT EP (#21901)
chilo-ms Aug 29, 2024
c97b19f
add constructor for tensorrt ep and refine GetCapability (#21914)
guyang3532 Aug 29, 2024
36f97b5
relu can work on out tree TRT now
jslhcl Aug 29, 2024
2fc7aac
rebuild graph proto from scratch with the information needed from gra…
jslhcl Aug 31, 2024
4ad6993
complete the GetCapability (#21956)
guyang3532 Sep 2, 2024
53c736f
Chi's fix and reorder ep for registering shared resource
jslhcl Sep 4, 2024
5fcb972
complete the GetSubGraph (#21998)
guyang3532 Sep 5, 2024
c3bb437
run resnet18v1_7, crash on GetSubGraph()
jslhcl Sep 6, 2024
d1c657c
Merge branch 'leca/outOfTreeEP' of https://github.com/microsoft/onnxr…
jslhcl Sep 6, 2024
3efac97
resnet18-v1-7 works for TRT EP, with next_nodes_list assignment comme…
jslhcl Sep 6, 2024
766fec9
test cases for decoder and fast_rcnn, delete dynamic_cast in ShouldPo…
jslhcl Sep 9, 2024
ea2465c
add tensorrt home in CMakeLists, add trt and CUDA ep for test, change…
jslhcl Sep 11, 2024
76a9305
[WIP, DONT REVIEW] add initializer to graph proto (#22085)
jslhcl Sep 18, 2024
330cdb6
use parameter ExecutionOrder::PRIORITY_BASED for GraphViewerToProto()…
jslhcl Sep 19, 2024
6fd50f0
can create session with out tree trt ep now. Error:Name:'tensorrtEp_T…
jslhcl Sep 23, 2024
681585f
make trt_node_name_with_precision_ from string to map, to capture the…
jslhcl Sep 23, 2024
7db20cb
fix redundant inputs and outputs in GetSubgraph (#22201)
guyang3532 Sep 24, 2024
ff782e0
RunTinyYolov3()
jslhcl Sep 25, 2024
1d7b2df
fix bugs for run tinyYolo (#22233)
guyang3532 Sep 26, 2024
a407944
sample code to separate graph C API to different files
jslhcl Sep 26, 2024
f871b25
new test control_flow, error: ErrorMessage:Failed to find kernel for …
jslhcl Oct 2, 2024
e84f00c
control flow model works
jslhcl Oct 3, 2024
5b2de22
API refactor
jslhcl Oct 7, 2024
b1f8e2a
Python API
jslhcl Oct 14, 2024
7acaaab
fix memory leak (#22444)
guyang3532 Oct 15, 2024
d150a03
refactor all functions in onnxruntime_c_api_ep with status as return …
guyang3532 Oct 17, 2024
da5b6eb
resolve comments
jslhcl Oct 18, 2024
d280e59
add documents for all functions in c_api_ep (#22502)
guyang3532 Oct 18, 2024
cbe98e7
fix comments
jslhcl Oct 19, 2024
1529059
fix memory leak (#22522)
guyang3532 Oct 21, 2024
fa549f8
add mutex to plugin trt ep (#22581)
guyang3532 Oct 24, 2024
a28ad38
use std::mutex instead of OrtMutex and fix build error in Windows
jslhcl Oct 24, 2024
aa49805
openvino
jslhcl Oct 26, 2024
bc65613
openvino, GetCapability almost ready
jslhcl Oct 31, 2024
a1a3eea
openvino GetCapacity() is done. UnregisterPluginExecutionProviderLibrary
jslhcl Nov 1, 2024
0fe5f01
refine compile of openvino ep (#22689)
guyang3532 Nov 1, 2024
6bae1b9
Add utility files (#22650)
chilo-ms Nov 1, 2024
ab75d98
OpenVino, compile() is done
jslhcl Nov 2, 2024
c5510f2
Merge branch 'leca/outOfTreeEP' of https://github.com/microsoft/onnxr…
jslhcl Nov 2, 2024
08e3f20
Add unit test for TRT EP plugin (#22548)
chilo-ms Nov 2, 2024
b0b3123
add test for openvino plugin ep and fix bugs (#22734)
guyang3532 Nov 5, 2024
9dbb0b1
add missing mutex to plugin trt ep
chilo-ms Nov 6, 2024
5a59803
merge code
jslhcl Nov 6, 2024
999e7fd
Merge branch 'leca/outOfTreeEP' of https://github.com/microsoft/onnxr…
jslhcl Nov 6, 2024
084f735
fix bugs (#22744)
guyang3532 Nov 6, 2024
2b1cfdf
relu and resnet works in OpenVINO plugin
jslhcl Nov 7, 2024
e337d8f
Add OrtGraphApis::OrtNode_GetAttributeStrWithSize to handle case wher…
chilo-ms Nov 13, 2024
afe92e1
Make EP plugin be able to create and update EP Context graph (#22740)
chilo-ms Nov 13, 2024
63f8774
[TensorRT EP Plugin] use new graph api for ep context model generation
chilo-ms Nov 14, 2024
bf359a1
use cuda's preferred allocator for plugin trt and builtin cuda combin…
jslhcl Nov 16, 2024
c267ea5
[TensorRT EP Plugin] Add cuda::Impl_Cast (#22908)
chilo-ms Nov 20, 2024
72afdc4
fix build/compiler error for nvcc 11.8
chilo-ms Nov 22, 2024
6822206
Do not expose OrtGraph
jslhcl Dec 3, 2024
c8ddc73
initial commit for Graph C++ API
jslhcl Dec 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class IExecutionProvider {
*/
const OrtDevice default_device_;

bool intree_ep = true;

public:
virtual ~IExecutionProvider() = default;

Expand Down Expand Up @@ -325,6 +327,8 @@ class IExecutionProvider {
return InlinedVector<const Node*>();
}

bool IsIntreeEp() const { return intree_ep; }

private:
const std::string type_;

Expand Down
15 changes: 15 additions & 0 deletions include/onnxruntime/core/framework/ort_type_constraints.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

Check warning

Code scanning / lintrunner

CLANGFORMAT/format Warning

See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch.
// Licensed under the MIT License.

#pragma once
#include "core/session/onnxruntime_c_api_ep.h"
#include <unordered_map>
#include <string>
#include <set>

struct OrtTypeConstraints {
bool AddTypeConstraint(const char* type_symbol, ONNXTensorElementDataType type);
inline const std::unordered_map<std::string, std::set<ONNXTensorElementDataType>>& GetTypeConstraints() const { return type_constraints_; };
private:
std::unordered_map<std::string, std::set<ONNXTensorElementDataType>> type_constraints_;
};
6 changes: 6 additions & 0 deletions include/onnxruntime/core/session/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "core/platform/threadpool.h"
#include "core/common/logging/logging.h"
#include "core/framework/allocator.h"
#include "core/session/onnxruntime_c_api_ep.h"

struct OrtThreadingOptions;
namespace onnxruntime {
Expand Down Expand Up @@ -88,6 +89,10 @@ class Environment {
*/
Status CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo& mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg = nullptr);

void InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory);
Copy link
Contributor

@skottmckay skottmckay Oct 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given SessionOptionsAppendOrtExecutionProvider allows the user to register the instance of the EP, when do we need this factory? #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is another C API RegisterOrtExecutionProviderLibrary which will load the shared library, create plugin EP factory and save it in the Environment.

Please see the implementation of RegisterOrtExecutionProviderLibrary and the usage in test.cpp as examples

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to a new Name. Hope it is more clear now.


OrtExecutionProviderFactory* GetOrtExecutionProviderFactory(const std::string& ep_name);

private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment);
Status Initialize(std::unique_ptr<logging::LoggingManager> logging_manager,
Expand All @@ -99,5 +104,6 @@ class Environment {
std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_;
bool create_global_thread_pools_{false};
std::vector<AllocatorPtr> shared_allocators_;
std::unordered_map<std::string, std::unique_ptr<OrtExecutionProviderFactory>> custom_ep_factories_;
};
} // namespace onnxruntime
60 changes: 49 additions & 11 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,12 @@ ORT_RUNTIME_CLASS(Op);
ORT_RUNTIME_CLASS(OpAttr);
ORT_RUNTIME_CLASS(Logger);
ORT_RUNTIME_CLASS(ShapeInferContext);
ORT_RUNTIME_CLASS(KernelInfo);
ORT_RUNTIME_CLASS(KernelContext);
ORT_RUNTIME_CLASS(CustomOp);
ORT_RUNTIME_CLASS(KernelRegistry);
ORT_RUNTIME_CLASS(TypeConstraints);
ORT_RUNTIME_CLASS(Device);

#ifdef _WIN32
typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr;
Expand Down Expand Up @@ -364,13 +370,6 @@ typedef enum OrtLanguageProjection {
ORT_PROJECTION_NODEJS = 6,
} OrtLanguageProjection;

struct OrtKernelInfo;
typedef struct OrtKernelInfo OrtKernelInfo;
struct OrtKernelContext;
typedef struct OrtKernelContext OrtKernelContext;
struct OrtCustomOp;
typedef struct OrtCustomOp OrtCustomOp;

typedef enum OrtAllocatorType {
OrtInvalidAllocator = -1,
OrtDeviceAllocator = 0,
Expand All @@ -395,6 +394,13 @@ typedef enum OrtMemoryInfoDeviceType {
OrtMemoryInfoDeviceType_FPGA = 2
} OrtMemoryInfoDeviceType;

typedef enum OrtMemoryType {
OrtMemoryType_Default = 0,
OrtMemoryType_CUDA_PINNED = 1,
OrtMemoryType_HIP_PINNED = 2,
OrtMemoryType_CANN_PINNED = 3,
} OrtMemoryType;

Comment on lines +397 to +403
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid having EP specific enum values in the public API? If 'pinned' equates to 'visible by device and host' there's hopefully a better place to plugin device specific info like CUDA/HIP/CANN than the high level memory type.

/** \brief Algorithm to use for cuDNN Convolution Op
*/
typedef enum OrtCudnnConvAlgoSearch {
Expand Down Expand Up @@ -658,6 +664,9 @@ typedef struct OrtApi OrtApi;
struct OrtTrainingApi;
typedef struct OrtTrainingApi OrtTrainingApi;

struct OrtGraphApi;
typedef struct OrtGraphApi OrtGraphApi;

/** \brief The helper interface to get the right version of OrtApi
*
* Get a pointer to this structure through ::OrtGetApiBase
Expand Down Expand Up @@ -4310,9 +4319,6 @@ struct OrtApi {
*/
const char*(ORT_API_CALL* GetBuildInfoString)(void);

/// \name OrtROCMProviderOptions
/// @{

/** \brief Create an OrtROCMProviderOptions
*
* \param[out] out Newly created ::OrtROCMProviderOptions. Must be released with OrtApi::ReleaseROCMProviderOptions
Expand Down Expand Up @@ -4665,7 +4671,32 @@ struct OrtApi {
_In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array,
_In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths,
size_t num_external_initializer_files);
};

ORT_API2_STATUS(CreateDevice, _In_ enum OrtMemoryInfoDeviceType device_type, _In_ enum OrtMemoryType memory_type, _In_ int16_t device_id, _Outptr_ const OrtDevice** out);

ORT_API2_STATUS(DeviceGetDeviceType, _In_ const OrtDevice* device, _Out_ OrtMemoryInfoDeviceType* out);
Copy link
Contributor

@skottmckay skottmckay Oct 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please document all functions including 'since version ...' info. It's easier to review if the intent of the API is documented.

nit: Does 'device' need to be in the function name twice? #Resolved


ORT_API2_STATUS(DeviceGetMemoryType, _In_ const OrtDevice* device, _Out_ OrtMemoryType* out);

ORT_API2_STATUS(DeviceGetDeviceId, _In_ const OrtDevice* device, _Out_ int16_t* out);

ORT_CLASS_RELEASE(Device);

ORT_API2_STATUS(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name);

ORT_API2_STATUS(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env,
Copy link
Contributor

@skottmckay skottmckay Oct 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would SessionOptionsAppendPluginExecutionProvider be slightly clearer? #Resolved

_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);

ORT_API2_STATUS(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints);
Copy link
Contributor

@skottmckay skottmckay Oct 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: slightly more readable if this is after the type constraint functions given it takes OrtTypeConstaints as an input. #Resolved


ORT_API2_STATUS(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints);

ORT_API2_STATUS(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type);

ORT_CLASS_RELEASE(TypeConstraints);

const OrtGraphApi*(ORT_API_CALL* GetGraphApi)(uint32_t version)NO_EXCEPTION;
}; // struct OrtApi

/*
* Steps to use a custom op:
Expand Down Expand Up @@ -4693,6 +4724,13 @@ typedef enum OrtCustomOpInputOutputCharacteristic {
* the implementor of the custom op.
*/
struct OrtCustomOp {
#ifdef __cplusplus
// TODO(leca): initialize all member function pointers to nullptr?
OrtCustomOp() : CreateKernel{nullptr}, KernelCompute{nullptr}, KernelDestroy{nullptr}, GetInputCharacteristic{nullptr},
GetOutputCharacteristic{nullptr}, GetVariadicInputMinArity{nullptr}, GetVariadicOutputMinArity{nullptr},
GetStartVersion{nullptr}, GetEndVersion{nullptr}, GetMayInplace{nullptr}, ReleaseMayInplace{nullptr},
GetAliasMap{nullptr}, ReleaseAliasMap{nullptr} {}
#endif
Copy link
Contributor

@skottmckay skottmckay Oct 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What issue is this trying to address? It seems inconsistent to not initialize these values if __cplusplus is not defined. #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the kernel based plugin EP which leverages OrtCustomOp to register the compute function, the member variables of OrtCustomOp will be randomly assigned if we don't use this way to do the initialization.

Please check kernel_ep.cc and MemcpyFromHost in samples/tensorRTEp/tensorrt_execution_provider.cc for the detail.

uint32_t version; // Must be initialized to ORT_API_VERSION

// This callback creates the kernel, which is a user defined
Expand Down
185 changes: 185 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api_ep.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

Check warning

Code scanning / lintrunner

CLANGFORMAT/format Warning

See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch.
#include "onnxruntime_c_api.h"

ORT_RUNTIME_CLASS(ExecutionProvider);
ORT_RUNTIME_CLASS(ExecutionProviderFactory);
ORT_RUNTIME_CLASS(Node);
ORT_RUNTIME_CLASS(Graph);
ORT_RUNTIME_CLASS(GraphViewer);

typedef struct OrtCreateStream {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The names of many things here feel like they have lost a lot of context. They're fine if you're converting an existing internal EP, but if someone was to try and write a new OOT EP I think it would be quite confusing.

Things like 'stream', 'meta def' and 'indexed sub graph' don't (at least to me) have a clear intuitive meaning.

Would it be better to view this from the perspective of an external EP author instead of always being 1:1 with internal names?

int device_type;
void*(ORT_API_CALL* CreateStreamFunc)(const OrtDevice*);
} OrtCreateStream;

typedef struct OrtMetaDef {
char* name;
char* domain;
int since_version;

char** inputs;
size_t input_len;
char** outputs;
size_t output_len;
char** constant_initializers;
size_t initializer_len;

char* doc_string;
} OrtMetaDef;

typedef struct OrtIndexedSubGraph {
OrtMetaDef* meta_def; // TODO(leca): how to define a nested structure pointer?
size_t* node_index;
size_t node_index_len;
} OrtIndexedSubGraph;

typedef struct OrtComputeContext {
void*(ORT_API_CALL* AllocateFunc)(void*, size_t, size_t);
void(ORT_API_CALL* DestroyFunc)(void*, void*);
void* allocator_handle;
const char* node_name;
} OrtComputeContext;

typedef struct OrtNodeComputeInfo {
int(ORT_API_CALL* CreateFunctionStateFunc)(OrtComputeContext*, void*, void**);
OrtStatusPtr(ORT_API_CALL* ComputeFunc)(void*, void*, const OrtApi*, OrtKernelContext*);
void(ORT_API_CALL* DestroyFunctionStateFunc)(void*);
} OrtNodeComputeInfo;

typedef struct OrtTensorRef { // TODO(leca): OrtValueInfoRef inside OrtTensorRef?
int64_t* shape;
size_t shape_len;
ONNXTensorElementDataType data_type;
const char* data;
size_t data_len;
} OrtTensorRef;

typedef struct OrtValueInfoRef {
int64_t* shape;
size_t shape_len;
ONNXTensorElementDataType data_type;
} OrtValueInfoRef;

typedef struct OrtExecutionProvider {
#ifdef __cplusplus
OrtExecutionProvider() : GetCapability{nullptr}, Compile{nullptr}, RegisterKernels{nullptr}, CanCopy{nullptr}, CopyTensor{nullptr}, CreatePreferredAllocators{nullptr}, type{nullptr}, create_stream{nullptr}, default_device{nullptr},
extra_param_for_create_state_func{nullptr}, extra_param_for_compute_func{nullptr} {}
#endif
void(ORT_API_CALL* GetCapability)(const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph***);
OrtStatusPtr(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo* node_compute_info);
void(ORT_API_CALL* RegisterKernels)(OrtKernelRegistry* kernel_registry);
bool(ORT_API_CALL* CanCopy)(const OrtDevice* source, const OrtDevice* target);
OrtStatusPtr(ORT_API_CALL* CopyTensor)(const void* src, OrtMemoryInfoDeviceType source_device_type, OrtMemoryType source_mem_type, void* dst, OrtMemoryInfoDeviceType target_device_type, size_t count, void* stream);
int(ORT_API_CALL* CreatePreferredAllocators)(OrtExecutionProvider* this_, OrtAllocator*** ort_allocators);
const char* type;
OrtCreateStream* create_stream;
const OrtDevice* default_device;
void* extra_param_for_create_state_func;
void* extra_param_for_compute_func;
Comment on lines +81 to +82
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this may be able to be replaced with a single void* state;. See https://github.com/microsoft/onnxruntime/pull/21450/files#r1868553280

} OrtExecutionProvider;

typedef struct OrtExecutionProviderFactory {
OrtExecutionProvider*(ORT_API_CALL* CreateExecutionProvider)(OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size);
} OrtExecutionProviderFactory;

struct OrtGraphApi {
const char*(ORT_API_CALL* OrtGraph_GetName)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

bool(ORT_API_CALL* OrtGraph_IsConstantInitializer)(const OrtGraphViewer* graph, const char* name, bool check_outer_scope)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

size_t(ORT_API_CALL* OrtGraph_GetNodesIndexInTopologicalOrder)(const OrtGraphViewer* graph, int execution_order, _Out_ const size_t** nodes_index_in_topological_order);

bool(ORT_API_CALL* OrtGraph_IsSubgraph)(const OrtGraph* graph);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to expose OrtGraph? Can we add GetParentGraph in GridViewer so that we don't need expose OrtGraph?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be nice if we didn't. Having Graph and GraphViewer is going to be confusing to an EP author.


const OrtGraph*(ORT_API_CALL* OrtGraph_GetParentGraph)(const OrtGraph* graph);

const OrtNode*(ORT_API_CALL* OrtGraph_GetParenNode)(const OrtGraphViewer* graph);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const OrtNode*(ORT_API_CALL* OrtGraph_GetParenNode)(const OrtGraphViewer* graph);
const OrtNode*(ORT_API_CALL* OrtGraph_GetParentNode)(const OrtGraphViewer* graph);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I don't get it, what are you suggesting to change? split one line to 3 lines?

The signature of the C APIs has been changed, to return OrtStatus* in order to make it consistent with the existing C APIs


const void*(ORT_API_CALL* OrtGraph_GetModelPath)(const OrtGraphViewer* graph);

const OrtGraph*(ORT_API_CALL* OrtGraph_GetOrtGraph)(const OrtGraphViewer* graph_viewer);

size_t(ORT_API_CALL* OrtGraph_GetInputsIncludingInitializers)(const OrtGraphViewer* graph, _Outptr_ const char*** input_names);
Copy link
Contributor

@skottmckay skottmckay Oct 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we maybe simplify? ONNX IR versions less than 4 had all initializers as part of the graph inputs so there was complexity/confusion around whether getting the graph inputs included the initializers or not, and when an initializer was in the list of graph inputs whether it was allowed to be overridden. That was a long time ago now.

Would it be easier if the API was in terms of 'required inputs' (input with no backing initializer) and 'optional inputs' (input with backing initializer that can be overridden - i.e. a non-const initializer). 'all' returns the combined set of required an optional.

#Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the following APIs:

OrtGraph_GetRequiredInputs: Gets the Graph inputs with no matching initializers

OrtGraph_GetAllInputs: Gets the Graph inputs with matching initializers

OrtGraph_GetAllInitializers: Gets all the Graph initializers' name

and delete this one.

Since the graph viewer class does not expose graph_overridable_initializers_ ('optional inputs'), The caller can invoke 'GetAllInputs' and 'GetRequiredInputs' and do the minus themselves

Please review the implementation to check if I understand correctly


const OrtNode*(ORT_API_CALL* OrtGraph_GetOrtNode)(const OrtGraphViewer* graph, size_t node_index);

size_t(ORT_API_CALL* OrtGraph_GetNodesConsumingInput)(const OrtGraphViewer* graph, const char* input_name, _Outptr_ const OrtNode*** consumers); // TODO(leca): ValueConsumers::comprehensive ?

const OrtNode*(ORT_API_CALL* OrtGraph_GetNodeProducingOutput)(const OrtGraphViewer* graph, const char* output_name);

int(ORT_API_CALL* OrtGraph_NumberOfNodes)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

int(ORT_API_CALL* OrtGraph_MaxNodeIndex)(const OrtGraphViewer* graph);

size_t(ORT_API_CALL* OrtGraph_GetOutputSize)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

const char*(ORT_API_CALL* OrtGraph_GetIthOutputName)(const OrtGraphViewer*, size_t i)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

int32_t(ORT_API_CALL* OrtGraph_GetIthOutputElemType)(const OrtGraphViewer*, size_t i)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

bool(ORT_API_CALL* OrtGraph_GetInitializerTensor)(const OrtGraphViewer* graph, const char* initializer_name, _Outptr_ OrtTensorRef**);

bool(ORT_API_CALL* OrtGraph_GetValueInfo)(const OrtGraphViewer* graph, const char* name, _Outptr_ OrtValueInfoRef**);

size_t(ORT_API_CALL* OrtGraph_SerializeToArray)(const OrtGraphViewer*, _Out_ void** data)NO_EXCEPTION; // TODO(leca): review and discuss

ORT_API2_STATUS(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int node_num, const size_t* node_indices, _Outptr_ const OrtGraphViewer** subgraph); // TODO(yang): review and discuss
Copy link
Contributor

@chilo-ms chilo-ms Oct 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add an additional graph api or feature to check whether the input of the subgraph has shape info?

For TRT EP, TRT requires onnx model to have input shape provided, so that's why TRT EP checks the input shape of the subgraph, if not it suggests user to run symbolic_shape_infer.py first. #Resolved

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally this is via more basic functions. An API to get the graph input names, and an API to get the NodeArg entry (in some more easily digested format) for that value.

The checking whether it has shape info or not seem EP specific and not something the general API needs to directly support.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have that already: OrtGraph_GetInputsIncludingInitializers (naming is subject to change per Scott's comment) and OrtGraph_GetValueInfo


const char*(ORT_API_CALL* OrtNode_GetName)(const OrtNode* node);

const char*(ORT_API_CALL* OrtNode_GetDescription)(const OrtNode* node);

const char*(ORT_API_CALL* OrtNode_GetDomain)(const OrtNode* node);

int(ORT_API_CALL* OrtNode_SinceVersion)(const OrtNode* node);

const char*(ORT_API_CALL* OrtNode_GetExecutionProviderType)(const OrtNode* node);

const char*(ORT_API_CALL* OrtNode_GetOpType)(const OrtNode* node);

size_t(ORT_API_CALL* OrtNode_GetImplicitInputSize)(const OrtNode* node);

const char*(ORT_API_CALL* OrtNode_GetIthImplicitInputName)(const OrtNode* node, size_t i);

size_t(ORT_API_CALL* OrtNode_GetInputSize)(const OrtNode* node);
Copy link
Contributor

@skottmckay skottmckay Oct 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slightly unclear if these are returning the size of a specific input/output.

Would GetNumImplicitInputs/GetNumInputs/GetNumOutputs be better? #Resolved


const char*(ORT_API_CALL* OrtNode_GetIthInputName)(const OrtNode* node, size_t i);

size_t(ORT_API_CALL* OrtNode_GetOutputSize)(const OrtNode* node);

const char*(ORT_API_CALL* OrtNode_GetIthOutputName)(const OrtNode* node, size_t i);

size_t(ORT_API_CALL* OrtNode_GetIndex)(const OrtNode* node);

size_t(ORT_API_CALL* OrtNode_GetAttributeNames)(const OrtNode*, _Out_ const char*** names);

size_t(ORT_API_CALL* OrtNode_GetAttributeSize)(const OrtNode* node);

int(ORT_API_CALL* OrtNode_GetAttributeType)(const OrtNode* node, const char* attribute)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; // AttributeProto_AttributeType

size_t(ORT_API_CALL* OrtNode_GetAttributeKeyCount)(const OrtNode* node, const char* key);

int(ORT_API_CALL* OrtNode_GetAttributeIntSize)(const OrtNode* node, const char* key);

int(ORT_API_CALL* OrtNode_GetAttributeFloatSize)(const OrtNode* node, const char* key);

int(ORT_API_CALL* OrtNode_GetAttributeStringSize)(const OrtNode* node, const char* key);

int64_t(ORT_API_CALL* OrtNode_GetAttributeIthInt)(const OrtNode* node, const char* key, int i);

float(ORT_API_CALL* OrtNode_GetAttributeIthFloat)(const OrtNode* node, const char* key, int i);

const char*(ORT_API_CALL* OrtNode_GetAttributeIthStr)(const OrtNode* node, const char* key, int i);

const char*(ORT_API_CALL* OrtNode_GetAttributeStr)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

int64_t(ORT_API_CALL* OrtNode_GetAttributeInt)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

float(ORT_API_CALL* OrtNode_GetAttributeFloat)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

size_t(ORT_API_CALL* OrtNode_GetSubgraphs)(const OrtNode* node, _Outptr_ const OrtGraphViewer*** subgraphs);
};
typedef struct OrtGraphApi OrtGraphApi;
14 changes: 14 additions & 0 deletions onnxruntime/core/framework/ort_type_constraints.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

Check warning

Code scanning / lintrunner

CLANGFORMAT/format Warning

See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch.
// Licensed under the MIT License.

#include "core/framework/ort_type_constraints.h"

bool OrtTypeConstraints::AddTypeConstraint(const char* type_symbol, ONNXTensorElementDataType type) {
std::unordered_map<std::string, std::set<ONNXTensorElementDataType>>::iterator iter = type_constraints_.find(type_symbol);
if (iter == type_constraints_.end()) {
std::set<ONNXTensorElementDataType> types{type};
type_constraints_[type_symbol] = types;
return true;
}
return (iter->second).insert(type).second;
}
Loading
Loading