Skip to content

Commit

Permalink
Add reading of tensor arrays in C API
Browse files Browse the repository at this point in the history
  • Loading branch information
amancini-N committed Nov 13, 2024
1 parent 5f740af commit 864e4f3
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 0 deletions.
15 changes: 15 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4656,6 +4656,21 @@ struct OrtApi {
_In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array,
_In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths,
size_t num_external_initializer_files);

/** \brief Get a sequence of ::OrtValue tensors stored as an attribute in the graph node.
*
* Used in the CreateKernel callback of an OrtCustomOp to get a sequence of tensors attribute.
*
* \param[in] info ::OrtKernelInfo instance.
* \param[in] name UTF-8 null-terminated string representing the attribute's name.
* \param[in] allocator Allocator used to allocate the internal tensor state.
* \param[out] out Array to be filled with pointers to the ::OrtValue tensors.
*             If out is nullptr, the function only writes the number of tensors to out_length.
* \param[out] out_length Number of ::OrtValue tensors in the attribute. Always written,
*             so callers may pass out = nullptr first to query the required array size.
*/
ORT_API2_STATUS(KernelInfoGetAttributeArray_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name,
_Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out, _Out_ size_t* out_length);

};

/*
Expand Down
56 changes: 56 additions & 0 deletions onnxruntime/core/session/custom_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "core/common/gsl.h"
#include "core/framework/data_types.h"
#include "core/framework/TensorSeq.h"
#include "core/framework/error_code_helper.h"
#include "core/framework/onnxruntime_typeinfo.h"
#include "core/framework/op_kernel_context_internal.h"
Expand Down Expand Up @@ -599,6 +600,61 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_tensor, _In_ const OrtKernel
});
}

ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name,
                    _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out, _Out_ size_t* out_length) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    const auto* op_kinfo = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);

    // Fetch the attribute as a list of TensorProtos.
    std::vector<onnx::TensorProto> tensor_protos;
    auto status = op_kinfo->GetAttrs<onnx::TensorProto>(name, tensor_protos);
    if (!status.IsOK()) {
      return onnxruntime::ToOrtStatus(status);
    }

    // Report the number of tensors. This is written even when only the count
    // is requested (out == nullptr), so callers can size their array first.
    *out_length = tensor_protos.size();

    // If no output array is provided, the caller only wanted the count.
    if (out == nullptr) {
      return nullptr;
    }

    // Build every OrtValue into an owning smart pointer first and transfer to
    // `out` only after all conversions succeed. Releasing into `out` eagerly
    // would leak the already-created values on a mid-loop failure, because the
    // caller has no way to know how many slots were populated before the error.
    std::vector<std::unique_ptr<OrtValue>> values;
    values.reserve(tensor_protos.size());

    for (const auto& tensor_proto : tensor_protos) {
      // Validate the proto and determine the tensor's size in bytes before
      // allocating anything for it.
      size_t req_size = 0;
      status = onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &req_size);
      if (!status.IsOK()) {
        return onnxruntime::ToOrtStatus(status);
      }

      // Create a Tensor that owns buffer memory allocated with the provided OrtAllocator.
      onnxruntime::TensorShape tensor_shape = onnxruntime::utils::GetTensorShapeFromTensorProto(tensor_proto);
      const auto* type = onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType();
      onnxruntime::AllocatorPtr alloc_ptr = std::make_shared<onnxruntime::IAllocatorImplWrappingOrtAllocator>(allocator);
      auto tensorp = std::make_unique<onnxruntime::Tensor>(type, tensor_shape, std::move(alloc_ptr));

      // Deserialize the TensorProto into the pre-allocated, empty Tensor.
      status = onnxruntime::utils::TensorProtoToTensor(onnxruntime::Env::Default(), nullptr, tensor_proto, *tensorp);
      if (!status.IsOK()) {
        return onnxruntime::ToOrtStatus(status);
      }

      // Wrap the Tensor in an OrtValue that owns it (the delete func frees the tensor).
      auto ml_tensor = onnxruntime::DataTypeImpl::GetType<onnxruntime::Tensor>();
      auto value = std::make_unique<OrtValue>();
      value->Init(tensorp.release(), ml_tensor, ml_tensor->GetDeleteFunc());
      values.push_back(std::move(value));
    }

    // All conversions succeeded: hand ownership of every value to the caller.
    for (size_t i = 0; i < values.size(); ++i) {
      out[i] = values[i].release();
    }

    return nullptr;
  });
}

ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out) {
return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
*out = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetInputCount();
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2730,6 +2730,9 @@ static constexpr OrtApi ort_api_1_to_19 = {
&OrtApis::KernelInfoGetAllocator,
&OrtApis::AddExternalInitializersFromFilesInMemory,
// End of Version 18 - DO NOT MODIFY ABOVE (see above text for more information)

&OrtApis::KernelInfoGetAttributeArray_tensor,
// End of Version 19 - DO NOT MODIFY ABOVE (see above text for more information)
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down Expand Up @@ -2761,6 +2764,7 @@ static_assert(offsetof(OrtApi, GetBuildInfoString) / sizeof(void*) == 254, "Size
static_assert(offsetof(OrtApi, KernelContext_GetResource) / sizeof(void*) == 265, "Size of version 16 API cannot change");
static_assert(offsetof(OrtApi, SessionOptionsAppendExecutionProvider_OpenVINO_V2) / sizeof(void*) == 275, "Size of version 17 API cannot change");
static_assert(offsetof(OrtApi, AddExternalInitializersFromFilesInMemory) / sizeof(void*) == 279, "Size of version 18 API cannot change");
static_assert(offsetof(OrtApi, KernelInfoGetAttributeArray_tensor) / sizeof(void*) == 280, "Size of version 19 API cannot change");

// So that nobody forgets to finish an API version, this check will serve as a reminder:
static_assert(std::string_view(ORT_VERSION) == "1.19.0.dev8",
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -523,4 +523,7 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessi
ORT_API_STATUS_IMPL(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);

ORT_API_STATUS_IMPL(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out);

ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name,
_Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out, _Out_ size_t* out_length);
} // namespace OrtApis

0 comments on commit 864e4f3

Please sign in to comment.