diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index b4be501d3f00a..b130b0bdcedb2 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4656,6 +4656,21 @@ struct OrtApi { _In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array, _In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths, size_t num_external_initializer_files); + + /** \brief Get a sequence of ::OrtValue tensors stored as an attribute in the graph node. + * + * Used in the CreateKernel callback of an OrtCustomOp to get a sequence of tensors attribute. + * + * \param[in] info ::OrtKernelInfo instance. + * \param[in] name UTF-8 null-terminated string representing the attribute's name. + * \param[in] allocator Allocator used to allocate the internal tensor state. + * \param[out] out Pointer of memory where the pointers of ::OrtValue tensors should be stored. + * If out is nullptr, the function will just return the number of tensors in out_length. + * \param[out] out_length Number of ::OrtValue tensors stored in out. + */ + ORT_API2_STATUS(KernelInfoGetAttributeArray_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name, + _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out, _Out_ size_t* out_length); + }; /* diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index d0c46142ac060..733327f20841c 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -12,6 +12,7 @@ #include "core/common/gsl.h" #include "core/framework/data_types.h" +#include "core/framework/TensorSeq.h" #include "core/framework/error_code_helper.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/framework/op_kernel_context_internal.h" @@ -599,6 +600,61 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_tensor, _In_ const OrtKernel }); } +ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name, + _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out, _Out_ size_t* out_length) { + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + const auto* op_kinfo = reinterpret_cast(info); + + // Get TensorProto attribute + std::vector tensor_protos; + auto status = op_kinfo->GetAttrs(name, tensor_protos); + if (!status.IsOK()) { + return onnxruntime::ToOrtStatus(status); + } + + // Initialize out_length + *out_length = tensor_protos.size(); + + // If no pointers are provided, return the length of the array + if (out == nullptr) { + return nullptr; + } + + int i = 0; + + for (const auto& tensor_proto : tensor_protos) { + // Determine the tensor's size in bytes. + size_t req_size = 0; + status = onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &req_size); + if (!status.IsOK()) { + return onnxruntime::ToOrtStatus(status); + } + + // Create Tensor that owns buffer memory that will be allocated with the provided OrtAllocator. + onnxruntime::TensorShape tensor_shape = onnxruntime::utils::GetTensorShapeFromTensorProto(tensor_proto); + const auto* type = onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); + onnxruntime::AllocatorPtr alloc_ptr = std::make_shared(allocator); + auto tensorp = std::make_unique(type, tensor_shape, std::move(alloc_ptr)); + + // Deserialize TensorProto into pre-allocated, empty Tensor. + status = onnxruntime::utils::TensorProtoToTensor(onnxruntime::Env::Default(), nullptr, tensor_proto, *tensorp); + if (!status.IsOK()) { + return onnxruntime::ToOrtStatus(status); + } + + // Initialize OrtValue from Tensor. + auto ml_tensor = onnxruntime::DataTypeImpl::GetType(); + auto value = std::make_unique(); + value->Init(tensorp.release(), ml_tensor, ml_tensor->GetDeleteFunc()); + + out[i++] = value.release(); + } + + return nullptr; + + }); +} + ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out) { return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { *out = reinterpret_cast(info)->GetInputCount(); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index ec0d962d16ee4..ff3dbc1b2da5d 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2730,6 +2730,9 @@ static constexpr OrtApi ort_api_1_to_19 = { &OrtApis::KernelInfoGetAllocator, &OrtApis::AddExternalInitializersFromFilesInMemory, // End of Version 18 - DO NOT MODIFY ABOVE (see above text for more information) + + &OrtApis::KernelInfoGetAttributeArray_tensor, + // End of Version 19 - DO NOT MODIFY ABOVE (see above text for more information) }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. @@ -2761,6 +2764,7 @@ static_assert(offsetof(OrtApi, GetBuildInfoString) / sizeof(void*) == 254, "Size static_assert(offsetof(OrtApi, KernelContext_GetResource) / sizeof(void*) == 265, "Size of version 16 API cannot change"); static_assert(offsetof(OrtApi, SessionOptionsAppendExecutionProvider_OpenVINO_V2) / sizeof(void*) == 275, "Size of version 17 API cannot change"); static_assert(offsetof(OrtApi, AddExternalInitializersFromFilesInMemory) / sizeof(void*) == 279, "Size of version 18 API cannot change"); +static_assert(offsetof(OrtApi, KernelInfoGetAttributeArray_tensor) / sizeof(void*) == 280, "Size of version 19 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.19.0.dev8", diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index fcae173e6c162..45f268c73e953 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -523,4 +523,7 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessi ORT_API_STATUS_IMPL(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); ORT_API_STATUS_IMPL(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out); + +ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name, + _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out, _Out_ size_t* out_length); } // namespace OrtApis