Skip to content

Commit

Permalink
Fix C4459 warning in custom_op_lite.h (microsoft#751)
Browse files Browse the repository at this point in the history
Internal workitem: https://task.ms/aii/29719

Co-authored-by: Xavier Dupré <[email protected]>
  • Loading branch information
skyline75489 and xadupre authored Jun 25, 2024
1 parent 3b275b1 commit 0f1f454
Showing 1 changed file with 73 additions and 74 deletions.
147 changes: 73 additions & 74 deletions include/custom_op/custom_op_lite.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ namespace Custom {

class OrtKernelContextStorage : public ITensorStorage {
public:
OrtKernelContextStorage(const OrtW::CustomOpApi& api,
OrtKernelContextStorage(const OrtW::CustomOpApi& custom_op_api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : api_(api), ctx_(ctx), indice_(indice) {
bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) {
if (is_input) {
auto input_count = api.KernelContext_GetInputCount(&ctx);
auto input_count = api_.KernelContext_GetInputCount(&ctx);
if (indice >= input_count) {
ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
}
const_value_ = api.KernelContext_GetInput(&ctx, indice);
auto* info = api.GetTensorTypeAndShape(const_value_);
shape_ = api.GetTensorShape(info);
api.ReleaseTensorTypeAndShapeInfo(info);
const_value_ = api_.KernelContext_GetInput(&ctx, indice);
auto* info = api_.GetTensorTypeAndShape(const_value_);
shape_ = api_.GetTensorShape(info);
api_.ReleaseTensorTypeAndShapeInfo(info);
}
}

Expand Down Expand Up @@ -66,18 +66,18 @@ class OrtKernelContextStorage : public ITensorStorage {
std::optional<std::vector<int64_t>> shape_;
};

static std::string get_mem_type(const OrtW::CustomOpApi& api,
OrtKernelContext& ctx,
size_t indice,
bool is_input){
static std::string get_mem_type(const OrtW::CustomOpApi& custom_op_api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) {
std::string output = "Cpu";
if (is_input) {
const OrtValue* const_value = api.KernelContext_GetInput(&ctx, indice);
const OrtValue* const_value = custom_op_api.KernelContext_GetInput(&ctx, indice);
const OrtMemoryInfo* mem_info = {};
api.ThrowOnError(api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info));
custom_op_api.ThrowOnError(custom_op_api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info));
if (mem_info) {
const char* mem_type = nullptr;
api.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type));
custom_op_api.ThrowOnError(custom_op_api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type));
if (mem_type) {
output = mem_type;
}
Expand All @@ -88,29 +88,29 @@ static std::string get_mem_type(const OrtW::CustomOpApi& api,

template <typename T>
class OrtTensor : public Tensor<T> {
public:
OrtTensor(const OrtW::CustomOpApi& api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : Tensor<T>(std::make_unique<OrtKernelContextStorage>(api, ctx, indice, is_input)),
mem_type_(get_mem_type(api, ctx, indice, is_input)) {
public:
OrtTensor(const OrtW::CustomOpApi& custom_op_api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : Tensor<T>(std::make_unique<OrtKernelContextStorage>(custom_op_api, ctx, indice, is_input)),
mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {
}

bool IsCpuTensor() const {
return mem_type_ == "Cpu";
}

private:
private:
std::string mem_type_ = "Cpu";
};

class OrtStringTensorStorage : public IStringTensorStorage<std::string> {
public:
using strings = std::vector<std::string>;
OrtStringTensorStorage(const OrtW::CustomOpApi& api,
OrtStringTensorStorage(const OrtW::CustomOpApi& custom_op_api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : api_(api), ctx_(ctx), indice_(indice) {
bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) {
if (is_input) {
auto input_count = api_.KernelContext_GetInputCount(&ctx_);
if (indice >= input_count) {
Expand Down Expand Up @@ -197,10 +197,10 @@ class OrtStringTensorStorage : public IStringTensorStorage<std::string> {
class OrtStringViewTensorStorage : public IStringTensorStorage<std::string_view> {
public:
using strings = std::vector<std::string_view>;
OrtStringViewTensorStorage(const OrtW::CustomOpApi& api,
OrtStringViewTensorStorage(const OrtW::CustomOpApi& custom_op_api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : api_(api), ctx_(ctx), indice_(indice) {
bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) {
if (is_input) {
auto input_count = api_.KernelContext_GetInputCount(&ctx_);
if (indice >= input_count) {
Expand Down Expand Up @@ -275,57 +275,56 @@ class OrtStringViewTensorStorage : public IStringTensorStorage<std::string_view>

// to make the metaprogramming magic happy.
template <>
class OrtTensor<std::string> : public Tensor<std::string>{
public:
OrtTensor(const OrtW::CustomOpApi& api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : Tensor<std::string>(std::make_unique<OrtStringTensorStorage>(api, ctx, indice, is_input)),
mem_type_(get_mem_type(api, ctx, indice, is_input)) {}
class OrtTensor<std::string> : public Tensor<std::string> {
public:
OrtTensor(const OrtW::CustomOpApi& custom_op_api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : Tensor<std::string>(std::make_unique<OrtStringTensorStorage>(custom_op_api, ctx, indice, is_input)),
mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {}

bool IsCpuTensor() const {
return mem_type_ == "Cpu";
}

private:
private:
std::string mem_type_ = "Cpu";
};

template <>
class OrtTensor<std::string_view> : public Tensor<std::string_view>{
public:
OrtTensor(const OrtW::CustomOpApi& api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : Tensor<std::string_view>(std::make_unique<OrtStringViewTensorStorage>(api, ctx, indice, is_input)),
mem_type_(get_mem_type(api, ctx, indice, is_input)) {}
class OrtTensor<std::string_view> : public Tensor<std::string_view> {
public:
OrtTensor(const OrtW::CustomOpApi& custom_op_api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : Tensor<std::string_view>(std::make_unique<OrtStringViewTensorStorage>(custom_op_api, ctx, indice, is_input)),
mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {}

bool IsCpuTensor() const {
return mem_type_ == "Cpu";
}

private:
private:
std::string mem_type_ = "Cpu";
};

using TensorPtr = std::unique_ptr<Custom::Arg>;
using TensorPtrs = std::vector<TensorPtr>;


using TensorBasePtr = std::unique_ptr<Custom::TensorBase>;
using TensorBasePtrs = std::vector<TensorBasePtr>;

// Represent variadic input or output
struct Variadic : public Arg {
Variadic(const OrtW::CustomOpApi& api,
Variadic(const OrtW::CustomOpApi& custom_op_api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : api_(api), ctx_(ctx), indice_(indice), mem_type_(get_mem_type(api, ctx, indice, is_input)) {
bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice), mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {
#if ORT_API_VERSION < 14
ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION);
#endif
if (is_input) {
auto input_count = api.KernelContext_GetInputCount(&ctx_);
auto input_count = api_.KernelContext_GetInputCount(&ctx_);
for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input);
auto* info = api_.GetTensorTypeAndShape(const_value);
Expand All @@ -334,40 +333,40 @@ struct Variadic : public Arg {
TensorBasePtr tensor;
switch (type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
tensor = std::make_unique<Custom::OrtTensor<bool>>(api, ctx, ith_input, true);
tensor = std::make_unique<Custom::OrtTensor<bool>>(api_, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
tensor = std::make_unique<Custom::OrtTensor<float>>(api, ctx, ith_input, true);
tensor = std::make_unique<Custom::OrtTensor<float>>(api_, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
tensor = std::make_unique<Custom::OrtTensor<double>>(api, ctx, ith_input, true);
tensor = std::make_unique<Custom::OrtTensor<double>>(api_, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
tensor = std::make_unique<Custom::OrtTensor<uint8_t>>(api, ctx, ith_input, true);
tensor = std::make_unique<Custom::OrtTensor<uint8_t>>(api_, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
tensor = std::make_unique<Custom::OrtTensor<int8_t>>(api, ctx, ith_input, true);
tensor = std::make_unique<Custom::OrtTensor<int8_t>>(api_, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
tensor = std::make_unique<Custom::OrtTensor<uint16_t>>(api, ctx, ith_input, true);
tensor = std::make_unique<Custom::OrtTensor<uint16_t>>(api_, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
tensor = std::make_unique<Custom::OrtTensor<int16_t>>(api, ctx, ith_input, true);
tensor = std::make_unique<Custom::OrtTensor<int16_t>>(api_, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
tensor = std::make_unique<Custom::OrtTensor<uint32_t>>(api, ctx, ith_input, true);
tensor = std::make_unique<Custom::OrtTensor<uint32_t>>(api_, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
tensor = std::make_unique<Custom::OrtTensor<int32_t>>(api, ctx, ith_input, true);
tensor = std::make_unique<Custom::OrtTensor<int32_t>>(api_, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
tensor = std::make_unique<Custom::OrtTensor<uint64_t>>(api, ctx, ith_input, true);
tensor = std::make_unique<Custom::OrtTensor<uint64_t>>(api_, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
tensor = std::make_unique<Custom::OrtTensor<int64_t>>(api, ctx, ith_input, true);
tensor = std::make_unique<Custom::OrtTensor<int64_t>>(api_, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
tensor = std::make_unique<Custom::OrtTensor<std::string>>(api, ctx, ith_input, true);
tensor = std::make_unique<Custom::OrtTensor<std::string>>(api_, ctx, ith_input, true);
break;
default:
ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
Expand Down Expand Up @@ -395,7 +394,7 @@ struct Variadic : public Arg {
size_t Size() const {
return tensors_.size();
}

const TensorBasePtr& operator[](size_t indice) const {
return tensors_.at(indice);
}
Expand All @@ -412,11 +411,11 @@ struct Variadic : public Arg {

class OrtGraphKernelContext : public KernelContext {
public:
OrtGraphKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) {
OrtGraphKernelContext(const OrtApi& ort_api, const OrtKernelContext& ctx) : api_(ort_api) {
OrtMemoryInfo* info;
OrtW::ThrowOnError(api, api.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, info, &allocator_));
api.ReleaseMemoryInfo(info);
OrtW::ThrowOnError(api_, api_.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
OrtW::ThrowOnError(api_, api_.KernelContext_GetAllocator(&ctx, info, &allocator_));
api_.ReleaseMemoryInfo(info);
}

virtual ~OrtGraphKernelContext() {
Expand Down Expand Up @@ -458,31 +457,31 @@ class OrtGraphCudaKernelContext : public CUDAKernelContext {
public:
static const int cuda_resource_ver = 1;

OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) {
api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_);
OrtGraphCudaKernelContext(const OrtApi& ort_api, const OrtKernelContext& ctx) : api_(ort_api) {
api_.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_);
if (!cuda_stream_) {
ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION);
}
api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas_);
api_.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas_);
if (!cublas_) {
ORTX_CXX_API_THROW("Failed to fetch cublas handle from context", ORT_RUNTIME_EXCEPTION);
}
void* resource = nullptr;
OrtStatusPtr result = api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource);
OrtStatusPtr result = api_.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource);
if (result) {
ORTX_CXX_API_THROW("Failed to fetch device id from context", ORT_RUNTIME_EXCEPTION);
}
memcpy(&device_id_, &resource, sizeof(int));

OrtMemoryInfo* info;
OrtW::ThrowOnError(api, api.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, info, &cpu_allocator_));
api.ReleaseMemoryInfo(info);
OrtW::ThrowOnError(api_, api_.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
OrtW::ThrowOnError(api_, api_.KernelContext_GetAllocator(&ctx, info, &cpu_allocator_));
api_.ReleaseMemoryInfo(info);

OrtMemoryInfo* cuda_mem_info;
OrtW::ThrowOnError(api, api.CreateMemoryInfo("Cuda", OrtAllocatorType::OrtArenaAllocator, device_id_, OrtMemType::OrtMemTypeDefault, &cuda_mem_info));
OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, cuda_mem_info, &cuda_allocator_));
api.ReleaseMemoryInfo(cuda_mem_info);
OrtW::ThrowOnError(api_, api_.CreateMemoryInfo("Cuda", OrtAllocatorType::OrtArenaAllocator, device_id_, OrtMemType::OrtMemTypeDefault, &cuda_mem_info));
OrtW::ThrowOnError(api_, api_.KernelContext_GetAllocator(&ctx, cuda_mem_info, &cuda_allocator_));
api_.ReleaseMemoryInfo(cuda_mem_info);
}

virtual ~OrtGraphCudaKernelContext() {
Expand Down Expand Up @@ -944,7 +943,7 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {

class OrtAttributeReader {
public:
OrtAttributeReader(const OrtApi& api, const OrtKernelInfo& info) : base_kernel_(api, info) {
OrtAttributeReader(const OrtApi& ort_api, const OrtKernelInfo& info) : base_kernel_(ort_api, info) {
}

template <class T>
Expand Down

0 comments on commit 0f1f454

Please sign in to comment.