From 9daf7664fc2436b7a2ffb6334c5fce9647f53ae5 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Fri, 1 Nov 2024 08:37:56 +0800 Subject: [PATCH] [CoreML] ML Program more ops (2/N) (#22480) - cast - argmax - gelu - cast - LayerNorm - GroupNorm - InstanceNorm ### Description ### Motivation and Context --------- Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> Co-authored-by: Scott McKay Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../coreml/coreml_provider_factory.h | 2 +- .../builders/impl/activation_op_builder.cc | 78 +++-- .../coreml/builders/impl/argmax_op_builder.cc | 83 ++++-- .../coreml/builders/impl/base_op_builder.cc | 9 +- .../builders/impl/batch_norm_op_builder.cc | 71 ++++- .../coreml/builders/impl/cast_op_builder.cc | 96 +++++- .../builders/impl/normalization_op_builder.cc | 277 ++++++++++++++++++ .../coreml/builders/model_builder.cc | 43 +++ .../providers/coreml/builders/model_builder.h | 6 + .../coreml/builders/op_builder_factory.cc | 8 +- .../coreml/builders/op_builder_factory.h | 1 + .../core/providers/coreml/model/host_utils.h | 4 + .../core/providers/coreml/model/host_utils.mm | 2 + .../test/common/tensor_op_test_utils.h | 18 ++ .../test/contrib_ops/layer_norm_op_test.cc | 77 +++-- .../providers/coreml/coreml_basic_test.cc | 9 + .../cpu/activation/activation_op_test.h | 5 + .../providers/cpu/nn/batch_norm_op_test.cc | 26 +- .../cpu/nn/conv_transpose_op_test.cc | 12 +- .../providers/cpu/nn/group_norm_op_test.cc | 144 +++++++++ .../providers/cpu/nn/instance_norm_op_test.cc | 132 +++++---- .../cpu/reduction/reduction_ops_test.cc | 23 +- .../providers/cpu/tensor/concat_op_test.cc | 12 +- .../providers/cpu/tensor/slice_op.test.cc | 17 +- .../providers/cpu/tensor/split_op_test.cc | 12 +- .../providers/cpu/tensor/transpose_test.cc | 11 - .../apple/coreml_supported_mlprogram_ops.md | 7 + 27 files changed, 951 insertions(+), 234 deletions(-) create mode 100644 onnxruntime/core/providers/coreml/builders/impl/normalization_op_builder.cc create mode 100644 onnxruntime/test/providers/cpu/nn/group_norm_op_test.cc diff --git a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h index 7a6ba3afddce7..98fa9e09f1ba8 100644 --- a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h +++ b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h @@ -31,10 +31,10 @@ enum COREMLFlags { // Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later. 
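A minimal usage sketch for this flag (the wrapper function name is illustrative; OrtSessionOptionsAppendExecutionProvider_CoreML is the C API entry point declared in this header):

```cpp
// Minimal sketch (illustrative helper): request the ML Program format when
// registering the CoreML EP through the C API declared in this header.
#include "onnxruntime_cxx_api.h"
#include "coreml_provider_factory.h"  // COREMLFlags, OrtSessionOptionsAppendExecutionProvider_CoreML

Ort::Session CreateSessionWithMLProgram(Ort::Env& env, const ORTCHAR_T* model_path) {
  Ort::SessionOptions so;
  uint32_t coreml_flags = COREML_FLAG_CREATE_MLPROGRAM;  // optionally |= COREML_FLAG_USE_CPU_AND_GPU
  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(so, coreml_flags));
  return Ort::Session(env, model_path, so);
}
```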
COREML_FLAG_CREATE_MLPROGRAM = 0x010, - // Exclude ANE as sometimes this decrease performance // https://developer.apple.com/documentation/coreml/mlcomputeunits?language=objc // there are four compute units: // MLComputeUnitsCPUAndNeuralEngine|MLComputeUnitsCPUAndGPU|MLComputeUnitsCPUOnly|MLComputeUnitsAll + // different CU will have different performance and power consumption COREML_FLAG_USE_CPU_AND_GPU = 0x020, // Keep COREML_FLAG_LAST at the end of the enum definition // And assign the last COREMLFlag to it diff --git a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc index 5389eb5ab7e95..4481a5172966b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc @@ -40,6 +40,25 @@ void ActivationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, con } namespace { + +template +void HandlePReluWeight(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger, + std::vector& alpha_values) { + // add slope initializer as alpha weight + const auto& slope_tensor = *model_builder.GetConstantInitializer(node.InputDefs()[1]->Name()); + Initializer unpacked_tensor(slope_tensor); + const auto alpha_v = unpacked_tensor.DataAsSpan(); + + if (alpha_v.size() == 1) { + // expand to number of channels + std::vector x_shape; + GetShape(*node.InputDefs()[0], x_shape, logger); + alpha_values.resize(x_shape[x_shape.size() - 3], alpha_v[0]); + } else { + alpha_values.assign(alpha_v.begin(), alpha_v.end()); + } +} + Status AddPReluWeight(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger, COREML_SPEC::ActivationPReLU& prelu) { @@ -84,6 +103,7 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.activation std::string_view coreml_op_type; bool add_alpha = false; + bool add_gelu_mode = false; if (op_type == "Sigmoid") { coreml_op_type = "sigmoid"; } else if (op_type == "Tanh") { @@ -93,6 +113,12 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } else if (op_type == "LeakyRelu") { coreml_op_type = "leaky_relu"; add_alpha = true; + } else if (op_type == "Gelu") { + coreml_op_type = "gelu"; + add_gelu_mode = true; + } else if (op_type == "PRelu") { + coreml_op_type = "prelu"; + add_alpha = true; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); @@ -102,16 +128,39 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, AddOperationInput(*op, "x", node.InputDefs()[0]->Name()); if (add_alpha) { - NodeAttrHelper helper(node); - const auto alpha = helper.Get("alpha", 0.01f); - auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha)); + + if ("PRelu" == op_type) { + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + std::vector alpha_values; + HandlePReluWeight(model_builder, node, logger, alpha_values); + AddOperationInput(*op, "alpha", model_builder.AddConstant(op->type(), "alpha", alpha_values)); + } else { + std::vector alpha_values; + 
HandlePReluWeight(model_builder, node, logger, alpha_values); + AddOperationInput(*op, "alpha", model_builder.AddConstant(op->type(), "alpha", alpha_values)); + } } else { - AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", MLFloat16(alpha))); + NodeAttrHelper helper(node); + const auto alpha = helper.Get("alpha", 0.01f); + + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha)); + } else { + AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", MLFloat16(alpha))); + } } } + if (add_gelu_mode) { + NodeAttrHelper helper(node); + std::string approximate = helper.Get("approximate", std::string("none")); + if (approximate == "tanh") { + approximate = "TANH_APPROXIMATION"; + } else if (approximate == "none") { + approximate = "EXACT"; + } + AddOperationInput(*op, "mode", model_builder.AddScalarConstant(op->type(), "mode", std::string(approximate))); + } AddOperationOutput(*op, *node.OutputDefs()[0]); @@ -213,17 +262,11 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInp const logging::Logger& logger) const { const auto& op_type = node.OpType(); -#if defined(COREML_ENABLE_MLPROGRAM) - if (input_params.create_mlprogram) { - if (op_type == "PRelu") { // TODO: ML Program supports this so should be easy to enable - return false; - } - } else -#endif // (COREML_ENABLE_MLPROGRAM) - { - if (op_type == "PRelu") { - return IsPReluOpSupported(node, input_params, logger); - } + if (op_type == "Gelu" && !input_params.create_mlprogram) { + return false; + } + if (op_type == "PRelu") { + return IsPReluOpSupported(node, input_params, logger); } return true; @@ -245,6 +288,7 @@ void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistration "Relu", "PRelu", "LeakyRelu", + "Gelu", }; op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc index bc8b2d1a3505d..6169090a36014 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc @@ -3,6 +3,7 @@ #include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" @@ -15,6 +16,9 @@ class ArgMaxOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + + public: + bool SupportsMLProgram() const override { return true; } }; Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -24,41 +28,60 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& graph_viewer = model_builder.GetGraphViewer(); NodeAttrHelper helper(node); - const auto axis = helper.Get("axis", 0); - const auto keepdims = helper.Get("keepdims", 1); + const int64_t axis = helper.Get("axis", 0); + const int64_t keepdims = helper.Get("keepdims", 1); const bool removedim = keepdims != 1; - auto* coreml_argmax = layer->mutable_argmax(); - coreml_argmax->set_axis(axis); - coreml_argmax->set_removedim(removedim); - - // There are two cases here: - // 1. 
Special Case (ArgMax-Cast(from int64 to int32)), we fuse the Argmax's output/Cast's input - // (We still have this special case here because CoreML model does not have Cast) - // 2. Otherwise, we add Argmax layer normally - if (node.GetOutputEdgesCount() == 1) { - auto it = node.OutputEdgesBegin(); - const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index()); - // If Argmax's successive node is a Cast from int64 to int32 output - // The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl()) - // so we omit the check here - if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") { - // Skip the cast's input/argmax's output - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name(); - model_builder.AddLayer(std::move(layer)); - return Status::OK(); +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.reduction + + std::unique_ptr op = model_builder.CreateOperation(node, "reduce_argmax"); + AddOperationInput(*op, "x", node.InputDefs()[0]->Name()); + AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", axis)); + AddOperationInput(*op, "keep_dims", model_builder.AddScalarConstant(op->type(), "keep_dims", bool(keepdims))); + + int32_t output_datatype = ONNX_NAMESPACE::TensorProto_DataType_INT32; + // the output of ArgMax must be int32 + AddOperationOutput(*op, *node.OutputDefs()[0], output_datatype); + model_builder.AddOperation(std::move(op)); + } else +#endif // (COREML_ENABLE_MLPROGRAM) + { + auto* coreml_argmax = layer->mutable_argmax(); + coreml_argmax->set_axis(axis); + coreml_argmax->set_removedim(removedim); + + // There are two cases here: + // 1. Special Case (ArgMax-Cast(from int64 to int32)), we fuse the Argmax's output/Cast's input + // (We still have this special case here because CoreML model does not have Cast) + // 2. 
Otherwise, we add Argmax layer normally + if (node.GetOutputEdgesCount() == 1) { + auto it = node.OutputEdgesBegin(); + const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index()); + // If Argmax's successive node is a Cast from int64 to int32 output + // The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl()) + // so we omit the check here + if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") { + // Skip the cast's input/argmax's output + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + *layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name(); + model_builder.AddLayer(std::move(layer)); + return Status::OK(); + } } - } - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); - model_builder.AddLayer(std::move(layer)); + model_builder.AddLayer(std::move(layer)); + } return Status::OK(); } -bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, +bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, + [[maybe_unused]] const OpBuilderInputParams& input_params, const logging::Logger& logger) const { // Attribute `select_last_index` of ArgMax op is not supported NodeAttrHelper helper(node); @@ -68,6 +91,12 @@ bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa return false; } +#if defined(COREML_ENABLE_MLPROGRAM) + if (input_params.create_mlprogram) { + return true; + } +#endif + // If there are multiple downstream nodes and cast (toint32) is one of them // not supported, exit here // Otherwise, for general multiple downstream nodes, supported diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index f185a80de3cbf..70002b6295f5a 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -16,11 +16,10 @@ namespace coreml { // Once all ops are supportted FP16, we can remove it. Before that, we keep a set of ops to // filter suppported ones. 
static std::set Float16Ops = { - "Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal", - "Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool", - "Clip", "DepthToSpace", "Resize", "Slice", "Conv", - "ConvTranspose", "GlobalMaxPool", "Gemm", "MatMul", - "AveragePool", "MaxPool", "Reshape", "Split", "Transpose"}; + "Add", "ArgMax", "AveragePool", "BatchNormalization", "Cast", "Clip", "Concat", "Conv", "ConvTranspose", + "DepthToSpace", "Div", "Gelu", "Gemm", "GlobalAveragePool", "GlobalMaxPool", "GridSample", "GroupNormalization", + "InstanceNormalization", "LayerNormalization", "LeakyRelu", "MatMul", "MaxPool", "Mul", "PRelu", "Pow", + "Reciprocal", "Relu", "Reshape", "Resize", "Sigmoid", "Slice", "Split", "Sqrt", "Sub", "Tanh", "Transpose"}; namespace { // TODO, move this to shared_library diff --git a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc index 8da58f659acf1..cc68fa6ec399a 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc @@ -10,6 +10,10 @@ #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" +#ifdef __APPLE__ +#include +#endif + namespace onnxruntime { namespace coreml { @@ -24,6 +28,9 @@ class BatchNormalizationOpBuilder : public BaseOpBuilder { // BatchNormalization opset 6- has unsupported attributes int GetMinSupportedOpSet(const Node& /* node */) const override { return 7; } + + public: + bool SupportsMLProgram() const override { return true; } }; void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { @@ -50,21 +57,46 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu const auto eps = helper.Get("epsilon", 1e-5f); const auto channels = scale_tensor.dims()[0]; - auto* coreml_batch_norm = layer->mutable_batchnorm(); - coreml_batch_norm->set_channels(channels); - coreml_batch_norm->set_epsilon(eps); - coreml_batch_norm->set_computemeanvar(false); - coreml_batch_norm->set_instancenormalization(false); - - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_gamma(), scale_tensor)); // scale - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_beta(), bias_tensor)); // B - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_mean(), mean_tensor)); // mean - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_variance(), var_tensor)); // var - - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); - - model_builder.AddLayer(std::move(layer)); +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.normalization.batch_norm + + std::unique_ptr op = model_builder.CreateOperation(node, "batch_norm"); + AddOperationInput(*op, "x", input_defs[0]->Name()); + AddOperationInput(*op, "mean", model_builder.AddConstant(op->type(), input_defs[3]->Name() + "mean", mean_tensor)); + AddOperationInput(*op, "variance", model_builder.AddConstant(op->type(), input_defs[4]->Name() + "variance", var_tensor)); + AddOperationInput(*op, "gamma", model_builder.AddConstant(op->type(), 
input_defs[1]->Name(), scale_tensor)); + AddOperationInput(*op, "beta", model_builder.AddConstant(op->type(), input_defs[2]->Name(), bias_tensor)); + auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type(); + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + MLFloat16 epsilon_fp16(eps); + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", epsilon_fp16)); + } else { + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", eps)); + } + + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); + } else +#endif // (COREML_ENABLE_MLPROGRAM) + { + auto* coreml_batch_norm = layer->mutable_batchnorm(); + coreml_batch_norm->set_channels(channels); + coreml_batch_norm->set_epsilon(eps); + coreml_batch_norm->set_computemeanvar(false); + coreml_batch_norm->set_instancenormalization(false); + + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_gamma(), scale_tensor)); // scale + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_beta(), bias_tensor)); // B + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_mean(), mean_tensor)); // mean + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_variance(), var_tensor)); // var + + *layer->mutable_input()->Add() = input_defs[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } return Status::OK(); } @@ -119,6 +151,15 @@ bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBu return false; } +#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64) + // To Pass IOS pipeline https://dev.azure.com/onnxruntime/onnxruntime/_build?definitionId=134&_a=summary + auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type(); + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && input_params.coreml_version < 7) { + LOGS(logger, VERBOSE) << "float16 input is not supported on the iOS x86_64 simulator" + << " due to CoreML producing invalid output."; + return false; + } +#endif return true; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc index fc8879abbefb0..7c7363d4c81ad 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc @@ -4,6 +4,7 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" @@ -18,14 +19,62 @@ class CastOpBuilder : public BaseOpBuilder { bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + + public: + bool SupportsMLProgram() const override { return true; } }; -Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& /* model_builder */, - const Node& /* node */, - const logging::Logger& /* logger */) const { - // This is a special handling case for ArgMax Op, where argmax is followed by a cast to int32 type. - // The ArgMax is fused with the Cast node and produces an int32 output. 
- // Cast node is not provided in CoreML model, so we're skipping adding the Cast node here. +Status CastOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model_builder, + [[maybe_unused]] const Node& node, + [[maybe_unused]] const logging::Logger& logger) const { +// This is a special handling case for ArgMax Op, where argmax is followed by a cast to int32 type. +// The ArgMax is fused with the Cast node and produces an int32 output. +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_unary.cast + + NodeAttrHelper helper(node); + auto cast_to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto::UNDEFINED); + std::string to_dtype = ""; + if (cast_to_type == ONNX_NAMESPACE::TensorProto::INT32 || cast_to_type == ONNX_NAMESPACE::TensorProto::INT64) { + to_dtype = "int32"; + // CoreML doesn't support int64, while ONNX uses int64 for indices and as well as data values. + // We convert the data inputs/outputs between int64 and int32 when calling onnxruntime::coreml::Model::Predict, + // and when adding int64 initializers to the CoreML model. + // CoreML operators can only produce int32 and not int64 values. + // Due to that there should be no actual int64 values inside the CoreML model and we can infer any + // ONNX_NAMESPACE::TensorProto::INT64 values to be int32. + cast_to_type = ONNX_NAMESPACE::TensorProto::INT32; + } else if (cast_to_type == ONNX_NAMESPACE::TensorProto::FLOAT) { + to_dtype = "fp32"; + } else if (cast_to_type == ONNX_NAMESPACE::TensorProto::FLOAT16) { + to_dtype = "fp16"; + } else if (cast_to_type == ONNX_NAMESPACE::TensorProto::BOOL) { + to_dtype = "bool"; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported cast type: ", cast_to_type); + } + + std::string_view op_type = "cast"; + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (((input_dtype == ONNX_NAMESPACE::TensorProto_DataType_INT64 || + input_dtype == ONNX_NAMESPACE::TensorProto_DataType_INT32) && + to_dtype == "int32") || + cast_to_type == input_dtype) { + op_type = "identity"; + } + + std::unique_ptr op = model_builder.CreateOperation(node, op_type); + AddOperationInput(*op, "x", node.InputDefs()[0]->Name()); + if (op_type == "cast") { + AddOperationInput(*op, "dtype", model_builder.AddScalarConstant(op->type(), "dtype", std::string(to_dtype))); + } + AddOperationOutput(*op, *node.OutputDefs()[0], cast_to_type); + model_builder.AddOperation(std::move(op)); + } +#endif + return Status::OK(); } @@ -36,6 +85,10 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return false; } + if (input_params.create_mlprogram) { + return true; + } + const auto& prec_node = node.InputEdgesBegin()->GetNode(); /*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax @@ -67,14 +120,39 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return true; } -bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, +bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, [[maybe_unused]] const OpBuilderInputParams& input_params, const logging::Logger& logger) const { // We only check the type of input 0 const auto& input = *node.InputDefs()[0]; + const auto& output = 
*node.OutputDefs()[0]; - int32_t input_type; - if (!GetType(input, input_type, logger)) + int32_t input_type, output_type; + if (!GetType(input, input_type, logger)) { return false; + } + if (!GetType(output, output_type, logger)) { + return false; + } + +#if defined(COREML_ENABLE_MLPROGRAM) + if (input_params.create_mlprogram) { + if ((input_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 || + input_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 || + input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || + input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) && + (output_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 || + output_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 || + output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || + output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)) { + return true; + } else { + LOGS(logger, VERBOSE) << "[" << node.OpType() + << "] Input type: [" << input_type + << "] is not supported."; + return false; + } + } +#endif // only support int64 coming from ArgMax (check for ArgMax is done in IsOpSupportedImpl()) if (input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) { diff --git a/onnxruntime/core/providers/coreml/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/normalization_op_builder.cc new file mode 100644 index 0000000000000..b4dc8d1647ad0 --- /dev/null +++ b/onnxruntime/core/providers/coreml/builders/impl/normalization_op_builder.cc @@ -0,0 +1,277 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/coreml/builders/helper.h" +#include "core/optimizer/initializer.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/op_builder_factory.h" +#include "core/providers/coreml/shape_utils.h" +#include "core/providers/shared/utils/utils.h" +#include + +namespace onnxruntime { +namespace coreml { + +class NormalizationOpBuilder : public BaseOpBuilder { + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override; + Status AddGroupNormToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const; + + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + int GetMinSupportedOpSet(const Node& /* node */) const override { return 1; } + + public: + bool SupportsMLProgram() const override { return true; } +}; + +void NormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // skip everything except input0 for Normalization + const auto& input_defs = node.InputDefs(); + model_builder.AddInitializerToSkip(input_defs[1]->Name()); // scale + if (input_defs.size() > 2) { + model_builder.AddInitializerToSkip(input_defs[2]->Name()); // B + } +} + +Status NormalizationOpBuilder::AddToModelBuilderImpl( + [[maybe_unused]] ModelBuilder& model_builder, + [[maybe_unused]] const Node& node, + [[maybe_unused]] const logging::Logger& logger) 
const { + if (node.OpType() == "GroupNormalization") { + return AddGroupNormToModelBuilderImpl(model_builder, node, logger); + } +#if defined(COREML_ENABLE_MLPROGRAM) + const auto& input_defs = node.InputDefs(); + NodeAttrHelper helper(node); + const auto& scale_tensor = *model_builder.GetConstantInitializer(input_defs[1]->Name()); + + const auto eps = helper.Get("epsilon", 1e-5f); + + std::vector input_shape; + // GetShape will never fail as we have already verified the input shape in IsOpSupportedImpl + GetShape(*input_defs[0], input_shape, logger); + + const auto rank = input_shape.size(); + auto axis = static_cast(HandleNegativeAxis(helper.Get("axis", 1), rank)); + + std::vector axes(rank - axis); + std::iota(axes.begin(), axes.end(), axis); + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + std::string_view layer_input_name_x = node.InputDefs()[0]->Name(); + std::string_view op_name = (node.OpType() == "InstanceNormalization") ? "instance_norm" : "layer_norm"; + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.normalization.layer_norm + + std::unique_ptr op = model_builder.CreateOperation(node, op_name); + AddOperationInput(*op, "x", layer_input_name_x); + if (op_name == "layer_norm") { + AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), input_defs[0]->Name() + "axes", axes)); + } + AddOperationInput(*op, "gamma", model_builder.AddConstant(op->type(), input_defs[1]->Name() + "gamma", scale_tensor)); + if (input_defs.size() > 2) { + const auto& bias_tensor = *model_builder.GetConstantInitializer(input_defs[2]->Name()); + AddOperationInput(*op, "beta", model_builder.AddConstant(op->type(), input_defs[2]->Name() + "beta", bias_tensor)); + } + + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + MLFloat16 epsilon_fp16(eps); + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", epsilon_fp16)); + } else { + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", eps)); + } + + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); + } +#endif // (COREML_ENABLE_MLPROGRAM) + + return Status::OK(); +} + +Status NormalizationOpBuilder::AddGroupNormToModelBuilderImpl( + [[maybe_unused]] ModelBuilder& model_builder, + [[maybe_unused]] const Node& node, + [[maybe_unused]] const logging::Logger& logger) const { +#if defined(COREML_ENABLE_MLPROGRAM) + const auto& input_defs = node.InputDefs(); + NodeAttrHelper helper(node); + // Coreml hasn't supported GroupNorm yet. + // we decompose GroupNorm to sub ops and levrage LayerNorm to implement GroupNorm. 
+ // groupnorm --> reshape [b, num_groups, c // (num_groups), h, w] --> layer_norm --> reshape [b, c, h, w]->mul(scale)->add(bias) + + // scale and bias is required for group-norm by the onnx spec + const auto& scale_tensor = *model_builder.GetConstantInitializer(input_defs[1]->Name()); + const auto& bias_tensor = *model_builder.GetConstantInitializer(input_defs[2]->Name()); + + const auto eps = helper.Get("epsilon", 1e-5f); + int64_t num_groups = helper.Get("num_groups", 1); // GroupNorm + + std::vector input_shape; + GetShape(*input_defs[0], input_shape, logger); + + const auto input_size = input_shape.size(); + int64_t axis = 2; + std::vector axes(input_size + 1 - axis); // Group add one more dim + std::iota(axes.begin(), axes.end(), axis); + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int64_t channel_dims = input_shape[1]; + + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + std::string_view layer_input_name_x = node.InputDefs()[0]->Name(); + const int32_t elem_type = static_cast(input_dtype); + + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.normalization.layer_norm + // https://github.com/apple/coremltools/blob/9827d424b3c5b5fbb6ddc8891a000d87a188c84f/coremltools/converters/mil/frontend/torch/ops.py#L1354 + // reshape to [b, num_groups, c // (num_groups), h, w] + auto reshape1 = model_builder.CreateOperation(node, "reshape", "pre"); + std::vector shape1 = input_shape; + shape1.insert(shape1.begin() + 1, num_groups); + shape1[2] = input_shape[1] / num_groups; + std::vector shape_scale_bias(input_shape.size(), 1); + shape_scale_bias[1] = channel_dims; + AddOperationInput(*reshape1, "x", node.InputDefs()[0]->Name()); + AddOperationInput(*reshape1, "shape", model_builder.AddConstant(reshape1->type(), "shape1", shape1)); + layer_input_name_x = model_builder.GetUniqueName(node, "ln_reshape1_"); + AddIntermediateOperationOutput(*reshape1, layer_input_name_x, elem_type, shape1); + + std::unique_ptr layer_norm = model_builder.CreateOperation(node, "layer_norm"); + AddOperationInput(*layer_norm, "x", layer_input_name_x); + AddOperationInput(*layer_norm, "axes", model_builder.AddConstant(layer_norm->type(), "axes", axes)); + + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + MLFloat16 epsilon_fp16(eps); + AddOperationInput(*layer_norm, "epsilon", model_builder.AddScalarConstant(layer_norm->type(), "epsilon", epsilon_fp16)); + } else { + AddOperationInput(*layer_norm, "epsilon", model_builder.AddScalarConstant(layer_norm->type(), "epsilon", eps)); + } + + const auto& ln_output_name = model_builder.GetUniqueName(node, "ln_output_"); + AddIntermediateOperationOutput(*layer_norm, ln_output_name, elem_type, shape1); + + auto reshape2 = model_builder.CreateOperation(node, "reshape", "post"); + AddOperationInput(*reshape2, "x", ln_output_name); + AddOperationInput(*reshape2, "shape", model_builder.AddConstant(reshape2->type(), "shape2", input_shape)); + + const auto& reshape2_output_name = model_builder.GetUniqueName(node, "gn_reshape_output_"); + AddIntermediateOperationOutput(*reshape2, reshape2_output_name, elem_type, input_shape); + + auto mul = model_builder.CreateOperation(node, "mul", "post_mul"); + AddOperationInput(*mul, "x", reshape2_output_name); + AddOperationInput(*mul, "y", model_builder.AddConstant(mul->type(), "mul1", scale_tensor, shape_scale_bias)); + const auto& mul_output_name = 
model_builder.GetUniqueName(node, "mul_output_"); + AddIntermediateOperationOutput(*mul, mul_output_name, elem_type, input_shape); + + auto add = model_builder.CreateOperation(node, "add", "post_add"); + AddOperationInput(*add, "x", mul_output_name); + AddOperationInput(*add, "y", model_builder.AddConstant(add->type(), "add1", bias_tensor, shape_scale_bias)); + AddOperationOutput(*add, *node.OutputDefs()[0]); + + model_builder.AddOperation(std::move(reshape1)); + model_builder.AddOperation(std::move(layer_norm)); + model_builder.AddOperation(std::move(reshape2)); + model_builder.AddOperation(std::move(mul)); + model_builder.AddOperation(std::move(add)); + } +#endif // (COREML_ENABLE_MLPROGRAM) + return Status::OK(); +} + +bool NormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + // LayerNormalization may have three output in the training mode, but we only support the inference mode + // for InstanceNormalization and GroupNormalization, they only have one output, so this check will always return true + if (node.OutputDefs().size() != 1) { + LOGS(logger, VERBOSE) << "Your onnx model (with LayerNormalization) may be in training mode," + << " please export it for inferencing."; + return false; + } + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } + + // groupnorm and layernorm has attribute "stash_type", while InstanceNormalization doesn't have this attribute + // Type of Mean and InvStdDev. This also specifies stage one’s computation precision. + // if stash_type is 1, this operator casts all input variables to 32-bit float, + // perform the computation, and finally cast Normalized back to the original type of X + // coreml didn't have a similiar attribute to stash_type, for now, we support the default value + if (node.OpType() != "InstanceNormalization") { + NodeAttrHelper helper(node); + const auto stash_type = helper.Get("stash_type", 1); + if (stash_type != 1) { + LOGS(logger, VERBOSE) << "stash_type != 1 is not supported"; + return false; + } + } + + const auto& scale_name = input_defs[1]->Name(); + const auto* scale_tensor = input_params.graph_viewer.GetConstantInitializer(scale_name); + if (!scale_tensor) { + LOGS(logger, VERBOSE) << "Scale must be a constant initializer"; + return false; + } + + if (input_defs.size() > 2) { + const auto& b_name = input_defs[2]->Name(); + const auto& b_tensor = input_params.graph_viewer.GetConstantInitializer(b_name); + if (!b_tensor) { + LOGS(logger, VERBOSE) << "Bias must be a constant initializer"; + return false; + } + } + + return true; +} + +bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + if (!input_params.create_mlprogram) { + return false; + } + // We only check the type of input 0,1,2 + const auto& input_0 = *node.InputDefs()[0]; + const auto& input_1 = *node.InputDefs()[1]; + const auto& input_2 = node.InputDefs().size() > 2 ? 
*node.InputDefs()[2] : input_0; + int32_t input_type_0, input_type_1, input_type_2; + if (!GetType(input_0, input_type_0, logger)) { + return false; + } + if (!GetType(input_1, input_type_1, logger)) { + return false; + } + if (!GetType(input_2, input_type_2, logger)) { + return false; + } + if (input_type_0 != input_type_1 || input_type_0 != input_type_2) { + LOGS(logger, VERBOSE) << "Input types of LayerNorm must be the same"; + return false; + } + + if (input_type_0 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + input_type_0 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + LOGS(logger, VERBOSE) << "Input types of LayerNorm must be float or float16"; + return false; + } + return true; +} + +void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc index 50faebf06875d..f12e4dab5b3ec 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc @@ -14,6 +14,7 @@ #include "core/providers/coreml/coreml_provider_factory.h" #include "core/providers/coreml/model/host_utils.h" #include "core/providers/coreml/shape_utils.h" +#include "core/optimizer/initializer.h" #if defined(COREML_ENABLE_MLPROGRAM) // includes from coremltools-src in _deps @@ -1003,6 +1004,48 @@ Status ModelBuilder::LoadModel(std::unique_ptr& model) { return model->LoadModel(); // load using CoreML API, including compilation } +#if defined(COREML_ENABLE_MLPROGRAM) +std::string_view ModelBuilder::AddConstant(std::string_view op_type, std::string_view value_type, + const ONNX_NAMESPACE::TensorProto& tensor, + std::optional> shape) { + const auto data_type = tensor.data_type(); + Initializer unpacked_tensor(tensor); + std::string_view ret; + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape ? shape : tensor.dims()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape ? shape : tensor.dims()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape ? 
shape : tensor.dims()); + break; + // case ONNX_NAMESPACE::TensorProto_DataType_INT32: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + // case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + // case ONNX_NAMESPACE::TensorProto_DataType_INT8: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + // case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + // case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + // case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + default: + ORT_THROW("AddConstant: Unsupported data type: ", data_type); + } + + return ret; +} +#endif // static Status ModelBuilder::Build(const GraphViewer& graph_viewer, const logging::Logger& logger, int32_t coreml_version, uint32_t coreml_flags, diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h index b3dfec29872a2..c566dbe160b50 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.h +++ b/onnxruntime/core/providers/coreml/builders/model_builder.h @@ -129,6 +129,12 @@ class ModelBuilder { return AddConstant(op_type, value_type, gsl::span(value), shape); } + // helper to convert a initializer to a constant + // by default, shape is inferred from the tensor.dims(), but can be provided to override if needed + std::string_view AddConstant(std::string_view op_type, std::string_view value_type, + const ONNX_NAMESPACE::TensorProto& tensor, + std::optional> shape = std::nullopt); + /// /// Add a scalar value as a 'const' operation. See AddConstant for details. 
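A hedged sketch of how the new overload is meant to be consumed from an op builder (the helper name and the "y" input are illustrative, not actual ORT code; real call sites are in the batch-norm and normalization builders above):

```cpp
// Illustrative only (assumes namespace onnxruntime::coreml, like the builders above):
// pass an ONNX initializer straight through as an MIL constant, overriding its shape
// so a [C] scale broadcasts as [1, C, 1, 1] in an assumed NCHW layout.
void AddBroadcastScale(ModelBuilder& model_builder,
                       COREML_SPEC::MILSpec::Operation& op,
                       const ONNX_NAMESPACE::TensorProto& scale_tensor,
                       int64_t channels) {
  std::vector<int64_t> broadcast_shape{1, channels, 1, 1};
  AddOperationInput(op, "y",
                    model_builder.AddConstant(op.type(), "scale", scale_tensor, broadcast_shape));
}
```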
/// diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc index b0006b24e7d75..4fd0c0577a9b8 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc @@ -21,6 +21,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateActivationOpBuilder("Relu", op_registrations); CreateActivationOpBuilder("PRelu", op_registrations); CreateActivationOpBuilder("LeakyRelu", op_registrations); + CreateActivationOpBuilder("Gelu", op_registrations); // Unary ops CreateUnaryOpBuilder("Reciprocal", op_registrations); @@ -43,8 +44,13 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateReductionOpBuilder("ReduceMean", op_registrations); CreateReductionOpBuilder("ReduceSum", op_registrations); - CreateArgMaxOpBuilder("ArgMax", op_registrations); + // Normalization ops CreateBatchNormalizationOpBuilder("BatchNormalization", op_registrations); + CreateNormalizationOpBuilder("GroupNormalization", op_registrations); + CreateNormalizationOpBuilder("InstanceNormalization", op_registrations); + CreateNormalizationOpBuilder("LayerNormalization", op_registrations); + + CreateArgMaxOpBuilder("ArgMax", op_registrations); CreateCastOpBuilder("Cast", op_registrations); CreateClipOpBuilder("Clip", op_registrations); CreateConcatOpBuilder("Concat", op_registrations); diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h index 1990fb6400ce1..9b51b53d73e9e 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h @@ -19,6 +19,7 @@ const std::unordered_map& GetOpBuilders(); void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateArgMaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateBatchNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); diff --git a/onnxruntime/core/providers/coreml/model/host_utils.h b/onnxruntime/core/providers/coreml/model/host_utils.h index a9991ccb945ce..145c64e5320d3 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.h +++ b/onnxruntime/core/providers/coreml/model/host_utils.h @@ -26,6 +26,8 @@ // - iOS 16 ops // 8 : iOS 17, macOS 14, tvOS 17, watchOS 10 (Core ML 7) // - iOS 17 ops +// 9 : iOS 18, macOS 15, tvOS 18, watchOS 11 (Core ML 8) +// - iOS 18 ops // // **NOTE** We use the Core ML version not the spec version. 
// @@ -39,6 +41,7 @@ #define API_AVAILABLE_COREML5 API_AVAILABLE(macos(12), ios(15)) #define API_AVAILABLE_COREML6 API_AVAILABLE(macos(13), ios(16)) #define API_AVAILABLE_COREML7 API_AVAILABLE(macos(14), ios(17)) +#define API_AVAILABLE_COREML8 API_AVAILABLE(macos(15), ios(18)) // @available is used in implementation code // Base required OS to run CoreML Specification Version 4 (Core ML 3) @@ -47,6 +50,7 @@ #define HAS_COREML5_OR_LATER @available(macOS 12, iOS 15, *) #define HAS_COREML6_OR_LATER @available(macOS 13, iOS 16, *) #define HAS_COREML7_OR_LATER @available(macOS 14, iOS 17, *) +#define HAS_COREML8_OR_LATER @available(macOS 15, iOS 18, *) #endif diff --git a/onnxruntime/core/providers/coreml/model/host_utils.mm b/onnxruntime/core/providers/coreml/model/host_utils.mm index 70052f50ae1c2..4239121a42c97 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.mm +++ b/onnxruntime/core/providers/coreml/model/host_utils.mm @@ -16,6 +16,8 @@ bool HasRequiredBaseOS() { } int32_t CoreMLVersion() { + if (HAS_COREML8_OR_LATER) + return 8; if (HAS_COREML7_OR_LATER) return 7; if (HAS_COREML6_OR_LATER) diff --git a/onnxruntime/test/common/tensor_op_test_utils.h b/onnxruntime/test/common/tensor_op_test_utils.h index e0891c7ced63e..acb520f894569 100644 --- a/onnxruntime/test/common/tensor_op_test_utils.h +++ b/onnxruntime/test/common/tensor_op_test_utils.h @@ -194,6 +194,24 @@ inline void CheckTensor(const Tensor& expected_tensor, const Tensor& output_tens } } +template +std::vector GetTypedArray(std::vector inputs) { + static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_integral_v, + "Only float, double, MLFloat16, and integral types are supported."); + if constexpr (std::is_same::value) { + return inputs; + } else if constexpr (std::is_integral_v || std::is_same::value) { + std::vector result(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + result[i] = static_cast(inputs[i]); + } + return result; + } else { + return ToFloat16(inputs); + } +} + class ParallelRandomValueGenerator { public: using RandomEngine = std::default_random_engine; diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index f1e0e99a5fb79..52e67bf0616d1 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -162,7 +162,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput_Initializers) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Bias) { @@ -211,20 +211,31 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16ScaleBiasOutput) { } TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput) { - OpTester test("LayerNormalization"); - test.AddAttribute("epsilon", 1e-05f); - - std::vector dims{1, 3, 2}; - test.AddInput("x", dims, ToFloat16({1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f})); - test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f})); - test.AddInput("bias", {2}, ToFloat16({0.6435f, -0.3964f})); - test.AddOutput("output", dims, ToFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f})); - // TRT, DNNL, OpenVINO and NNAPI, CoreML don't 
support this combination of datatypes - test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); + auto run_test = [](bool is_initializer) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 3, 2}; + test.AddInput("x", dims, ToFloat16({1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f})); + test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f}), is_initializer); + test.AddInput("bias", {2}, ToFloat16({0.6435f, -0.3964f}), is_initializer); + test.AddOutput("output", dims, ToFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f})); + // TRT, DNNL, OpenVINO and NNAPI don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); + }; + run_test(false); + run_test(true); } +template +class LayerNormTest : public ::testing::Test { +}; + +using LayerNormTestTypes = ::testing::Types; +TYPED_TEST_SUITE(LayerNormTest, LayerNormTestTypes); + TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput_Initializers) { OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f); @@ -237,19 +248,41 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput_Initializer // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider}); } // LayerNormalization became an ONNX operator in opset 17. It uses the same implementation so this is a sanity check. -TEST(LayerNormTest, LayerNorm17_float) { - OpTester test("LayerNormalization", 17); - test.AddAttribute("epsilon", 1e-05f); +TYPED_TEST(LayerNormTest, LayerNorm17_opset) { + auto run_test = [](bool is_initializer) { + OpTester test("LayerNormalization", 17); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 2, 3}; + test.AddInput("x", dims, GetTypedArray({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f})); + test.AddInput("gamma", {3}, GetTypedArray({1.0f, 1.0f, 1.0f}), is_initializer); + test.AddOutput("output", dims, GetTypedArray({-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f})); + if (std::is_same::value) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); + // coreml EP requires weight and bias to be initializers + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, + kNnapiExecutionProvider, kQnnExecutionProvider}, + nullptr, &execution_providers); + } else { + test.Run(); + } + }; + // Execution provider entry invalid. + // when other EPs support layer-norm fp16, this test should be updated to include them. 
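The expected values used above (±1.2247 and 0) follow directly from normalizing each row of x over its last axis; a standalone back-of-the-envelope check, not part of the test suite:

```cpp
// LayerNorm over {1, 2, 3} with unit gamma gives {-1.2247, 0, 1.2247}.
#include <cmath>
#include <cstdio>

int main() {
  const float x[3] = {1.f, 2.f, 3.f};
  const float eps = 1e-5f;
  const float mean = (x[0] + x[1] + x[2]) / 3.f;
  float var = 0.f;
  for (float v : x) var += (v - mean) * (v - mean);
  var /= 3.f;  // population variance, as LayerNormalization specifies
  for (float v : x) std::printf("%.4f ", (v - mean) / std::sqrt(var + eps));
  return 0;
}
```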
+ if (std::is_same::value) { +#if !defined(COREML_ENABLE_MLPROGRAM) + return; +#endif + } - std::vector dims{1, 2, 3}; - test.AddInput("x", dims, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); - test.AddInput("gamma", {3}, {1.0f, 1.0f, 1.0f}); - test.AddOutput("output", dims, {-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f}); - test.Run(); + run_test(false); + run_test(true); } TEST(LayerNormTest, LayerNorm17_double) { diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index daa24db134114..de647d9e3aa3e 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -127,6 +127,10 @@ TEST(CoreMLExecutionProviderTest, ArgMaxCastTest) { MakeCoreMLExecutionProvider(), feeds, verification_params); + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), + MakeCoreMLExecutionProvider(COREML_FLAG_CREATE_MLPROGRAM), + feeds, + verification_params); #else TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); #endif @@ -164,6 +168,11 @@ TEST(CoreMLExecutionProviderTest, ArgMaxUnsupportedCastTest) { MakeCoreMLExecutionProvider(), feeds, verification_params); + + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), + MakeCoreMLExecutionProvider(COREML_FLAG_CREATE_MLPROGRAM), + feeds, + verification_params); #else TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::Some); #endif diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.h b/onnxruntime/test/providers/cpu/activation/activation_op_test.h index 8ca0f6d845a09..59813f433dc41 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.h +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.h @@ -105,7 +105,12 @@ class ActivationOpTest : public ::testing::Test { std::random_device rd; std::mt19937 gen(rd()); std::uniform_real_distribution dist(low, high); +#ifdef COREML_ENABLE_MLPROGRAM + // please check onnxruntime/onnxruntime/core/providers/coreml/builders/helper.cc:81 + std::vector batch_size_list = {1, 2, 4, 9, 100}; +#else std::vector batch_size_list = {1, 2, 4, 9, 100000}; +#endif for (auto batch_size : batch_size_list) { std::vector vec(batch_size); for (size_t i = 0; i != batch_size; ++i) { diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc index b0d97410ac9b3..08c4e608aada3 100644 --- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc @@ -704,7 +704,7 @@ TEST(BatchNormTest, NonSpatial_Complicated) { } // Only CUDA and ROCm kernels have float 16 support -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST(BatchNormTest, BatchNorm2d_fp16) { vector X{-0.91221f, -0.283559f, 0.937637f, 2.09818f, -0.100199f, -0.608113f, 0.444562f, -1.07505f, 0.940591f, -0.922262f, 0.0931303f, 0.69611f, 1.55187f, 0.159808f, 0.914874f, -1.24856f, -1.98928f, -0.331621f, @@ -765,9 +765,6 @@ TEST(BatchNormTest, BatchNorm2d_fp16) { -0.0989828f, -0.160014f, 0.362077f, 0.0649763f, -0.371465f, 0.727401f, 0.0320011f}; float epsilon = 1e-05f; - OpTester test("BatchNormalization"); - test.AddAttribute("epsilon", epsilon); - vector input_shape{2, 3, 6, 6}; int input_size = 2 * 3 * 6 * 6; @@ -785,13 +782,20 @@ TEST(BatchNormTest, BatchNorm2d_fp16) { ConvertFloatToMLFloat16(var.data(), 
f_var.data(), 3); ConvertFloatToMLFloat16(expected_output.data(), f_output.data(), input_size); - test.AddInput("X", input_shape, f_X); - test.AddInput("scale", {3}, f_scale); - test.AddInput("B", {3}, f_B); - test.AddInput("mean", {3}, f_mean); - test.AddInput("var", {3}, f_var); - test.AddOutput("output", input_shape, f_output); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + auto run_test = [&](bool is_initializer) { + OpTester test("BatchNormalization"); + test.AddAttribute("epsilon", epsilon); + test.AddInput("X", input_shape, f_X); + test.AddInput("scale", {3}, f_scale, is_initializer); + test.AddInput("B", {3}, f_B, is_initializer); + test.AddInput("mean", {3}, f_mean, is_initializer); + test.AddInput("var", {3}, f_var, is_initializer); + test.AddOutput("output", input_shape, f_output, is_initializer); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + }; + run_test(false); + // coreml EP requires initializer + run_test(true); } #endif diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 0ce87fb65898b..83b27f10fe04f 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -4,6 +4,7 @@ #include "core/providers/xnnpack/xnnpack_init.h" #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "test/common/tensor_op_test_utils.h" #include "default_providers.h" using namespace std; @@ -130,17 +131,6 @@ TEST(ConvTransposeTest, ConvTranspose_1D) { TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } -template -static std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { - if constexpr (std::is_same::value) { - return inputs; - } else { - std::vector inputs_fp16(inputs.size()); - ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); - return inputs_fp16; - } -} - TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_outputpadding_strides2) { ConvTransposeOpAttributes attrs = { vector{3, 3}, // kernel_shape diff --git a/onnxruntime/test/providers/cpu/nn/group_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/group_norm_op_test.cc new file mode 100644 index 0000000000000..ac517193a2c77 --- /dev/null +++ b/onnxruntime/test/providers/cpu/nn/group_norm_op_test.cc @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
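The new normalization builder lowers GroupNormalization as reshape to [N, num_groups, C/num_groups, ...] --> layer_norm --> reshape back --> mul(scale) --> add(bias). The following standalone reference (my own sketch, not EP or test code) computes the same quantity directly and can be used to sanity-check the torch-generated expectations in the tests below:

```cpp
// Standalone GroupNorm reference following the same decomposition (illustrative only).
// x is [N, C, L] in row-major order; scale/bias are per-channel.
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<float> GroupNormRef(const std::vector<float>& x, int64_t N, int64_t C, int64_t L,
                                int64_t num_groups, const std::vector<float>& scale,
                                const std::vector<float>& bias, float eps) {
  std::vector<float> y(x.size());
  const int64_t cg = C / num_groups;   // channels per group ("reshape" step)
  const int64_t group_size = cg * L;   // elements normalized together ("layer_norm" step)
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t g = 0; g < num_groups; ++g) {
      const int64_t base = (n * C + g * cg) * L;
      float mean = 0.f, var = 0.f;
      for (int64_t i = 0; i < group_size; ++i) mean += x[base + i];
      mean /= static_cast<float>(group_size);
      for (int64_t i = 0; i < group_size; ++i) var += (x[base + i] - mean) * (x[base + i] - mean);
      var /= static_cast<float>(group_size);
      const float inv_std = 1.f / std::sqrt(var + eps);
      for (int64_t i = 0; i < group_size; ++i) {
        const int64_t c = g * cg + i / L;  // channel index for the trailing mul/add steps
        y[base + i] = (x[base + i] - mean) * inv_std * scale[c] + bias[c];
      }
    }
  }
  return y;
}
```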
+
+#include "gtest/gtest.h"
+#include "test/providers/provider_test_utils.h"
+#include "test/common/tensor_op_test_utils.h"
+#include "test/util/include/default_providers.h"
+
+#ifdef COREML_ENABLE_MLPROGRAM
+using namespace std;
+namespace onnxruntime {
+namespace test {
+
+template <typename T>
+class GroupNormalizationOpTest : public ::testing::Test {
+};
+using GroupNormalizationOpTestTypes = ::testing::Types<float, MLFloat16>;
+TYPED_TEST_SUITE(GroupNormalizationOpTest, GroupNormalizationOpTestTypes);
+
+// GroupSize = channel_dims to simulate InstanceNorm
+// Disable TensorRT on some of the tests because its parser doesn't support weight as input
+TYPED_TEST(GroupNormalizationOpTest, Equivalent_InstanceNorm_G_C) {
+  OpTester test("GroupNormalization", 18);
+  test.AddAttribute("epsilon", 0.3F);
+  test.AddAttribute("num_groups", int64_t(3));
+
+  vector<float> input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F,
+                         8.519701F, 1.2382338F, 1.7930176F, 5.1099434F,
+                         7.9195533F, 7.638727F, 8.065445F, 3.8082376F,
+
+                         2.3667817F, 2.8248506F, 3.7754705F, 5.861325F,
+                         5.058735F, 3.2787242F, 3.6843839F, 9.755121F,
+                         2.7902672F, 7.3974323F, 8.283609F, 8.488337F};
+  vector<int64_t> input_dims = {2, 3, 4};
+  test.AddInput<TypeParam>("X", input_dims, GetTypedArray<TypeParam>(input));
+
+  vector<float> scale = {1.F, 1.F, 1.F};
+  vector<int64_t> scale_dims = {3};
+  test.AddInput<TypeParam>("scale", scale_dims, GetTypedArray<TypeParam>(scale), true);
+
+  vector<float> B = {0.F, 0.F, 0.F};
+  vector<int64_t> B_dims = {3};
+  test.AddInput<TypeParam>("bias", B_dims, GetTypedArray<TypeParam>(B), true);
+
+  // expected output is calculated using torch.nn.GroupNorm(3, 3, eps=0.3)
+  vector<float> expected_output = {-0.56495477f, 1.48930046f, -1.13334329f, 0.20899761f,
+                                   1.46688162f, -0.98600774f, -0.79911913f, 0.31824524f,
+                                   0.57370438f, 0.42193634f, 0.6525492f, -1.64818992f,
+
+                                   -0.92380346f, -0.60808484f, 0.04711878f, 1.48476953f,
+                                   -0.14644464f, -0.82262872f, -0.66852817f, 1.63760153f,
+                                   -1.65898662f, 0.27618144f, 0.64840618f, 0.734399f};
+
+  test.AddOutput<TypeParam>("Y", input_dims, GetTypedArray<TypeParam>(expected_output));
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultCoreMLExecutionProvider(true));
+  // coreml EP requires weight and bias to be initializers
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+// GroupSize = 1 to simulate LayerNorm, (LayerNorm)
+// expected output is calculated using torch.nn.GroupNorm(1, 3, eps=1e-5f)
+TYPED_TEST(GroupNormalizationOpTest, Equivalent_LayerNorm_G_1) {
+  auto run_test = [](bool is_initializer) {
+    OpTester test("GroupNormalization", 18);
+    test.AddAttribute("epsilon", 1e-5f);
+    test.AddAttribute("num_groups", int64_t(1));
+
+    std::vector<int64_t> dims{1, 2, 3};
+    test.AddInput<TypeParam>("x", dims, GetTypedArray<TypeParam>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}));
+    test.AddInput<TypeParam>("scale", {2}, GetTypedArray<TypeParam>({1.0f, 1.0f}), is_initializer);
+    test.AddInput<TypeParam>("bias", {2}, GetTypedArray<TypeParam>({2.0f, 1.0f}), is_initializer);
+    test.AddOutput<TypeParam>("output", dims, GetTypedArray<TypeParam>({0.5361f, 1.1216f, 1.7072f, 1.2928f, 1.8783f, 2.4638f}));
+
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCoreMLExecutionProvider(true));
+    // coreml EP requires weight and bias to be initializers
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  };
+
+  run_test(true);
+}
+
+// expected output is calculated using torch.nn.GroupNorm(2, 6, eps=0.3)
+TYPED_TEST(GroupNormalizationOpTest, GroupSize_N) {
+  OpTester test("GroupNormalization", 18);
+  test.AddAttribute("epsilon", 0.3F);
+  test.AddAttribute("num_groups", int64_t(2));
+
+  vector<float> input = {-1.1258f,
-1.1524f, -0.2506f, -0.4339f,
+                         0.8487f, 0.6920f, -0.3160f, -2.1152f,
+                         0.3223f, -1.2633f, 0.3500f, 0.3081f,
+                         0.1198f, 1.2377f, 1.1168f, -0.2473f,
+                         -1.3527f, -1.6959f, 0.5667f, 0.7935f,
+                         0.5988f, -1.5551f, -0.3414f, 1.8530f,
+
+                         0.7502f, -0.5855f, -0.1734f, 0.1835f,
+                         1.3894f, 1.5863f, 0.9463f, -0.8437f,
+                         -0.6136f, 0.0316f, -0.4927f, 0.2484f,
+                         0.4397f, 0.1124f, 0.6408f, 0.4412f,
+                         -0.1023f, 0.7924f, -0.2897f, 0.0525f,
+                         0.5229f, 2.3022f, -1.4689f, -1.5867f};
+  vector<int64_t> input_dims = {2, 6, 4};
+  test.AddInput<TypeParam>("X", input_dims, GetTypedArray<TypeParam>(input));
+
+  vector<float> scale = {1.F, 1.F, 1.F, 1.F, 1.F, 1.F};
+  vector<int64_t> scale_dims = {6};
+  test.AddInput<TypeParam>("scale", scale_dims, GetTypedArray<TypeParam>(scale), true);
+
+  vector<float> B = {.0F, .0F, .0F, .0F, .0F, .0F};
+  vector<int64_t> B_dims = {6};
+  test.AddInput<TypeParam>("bias", B_dims, GetTypedArray<TypeParam>(B), true);
+
+  vector<float> expected_output = {
+      -0.7590f, -0.7848f, 0.0914f, -0.0867f,
+      1.1595f, 1.0073f, 0.0278f, -1.7203f,
+      0.6480f, -0.8926f, 0.6749f, 0.6343f,
+      0.0232f, 0.9274f, 0.8296f, -0.2738f,
+      -1.1679f, -1.4456f, 0.3846f, 0.5681f,
+      0.4107f, -1.3317f, -0.3499f, 1.4252f,
+
+      0.5772f, -0.8298f, -0.3957f, -0.0198f,
+      1.2505f, 1.4580f, 0.7838f, -1.1017f,
+      -0.8594f, -0.1798f, -0.7320f, 0.0486f,
+      0.2541f, -0.0377f, 0.4334f, 0.2554f,
+      -0.2291f, 0.5686f, -0.3962f, -0.0911f,
+      0.3282f, 1.9145f, -1.4475f, -1.5525f};
+  test.AddOutput<TypeParam>("Y", input_dims, GetTypedArray<TypeParam>(expected_output));
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultCoreMLExecutionProvider(true));
+  // coreml EP requires weight and bias to be initializers
+  if constexpr (std::is_same<TypeParam, float>::value) {
+    test.SetOutputTolerance(1e-4f);
+  } else {
+    test.SetOutputTolerance(0.005f);
+  }
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+}  // namespace test
+}  // namespace onnxruntime
+#endif
diff --git a/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc
index 31f119ec6b0e9..341bb8a4fc957 100644
--- a/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc
@@ -3,71 +3,87 @@
 #include "gtest/gtest.h"
 #include "test/providers/provider_test_utils.h"
+#include "test/common/tensor_op_test_utils.h"
+
 using namespace std;
 namespace onnxruntime {
 namespace test {
-// Disable TensorRT on some of the tests because its parser doesn't support weight as input
+template <typename T>
+class InstanceNormalizationOpTest : public ::testing::Test {
+};
+using InstanceNormalizationOpTestTypes = ::testing::Types<float, MLFloat16>;
+TYPED_TEST_SUITE(InstanceNormalizationOpTest, InstanceNormalizationOpTestTypes);
 
-TEST(InstanceNormalizationOpTest, InstanceNorm) {
-  OpTester test("InstanceNormalization");
-  test.AddAttribute("epsilon", 0.3F);
-
-  vector<float> input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F,
-                         8.519701F, 1.2382338F, 1.7930176F, 5.1099434F,
-                         7.9195533F, 7.638727F, 8.065445F, 3.8082376F,
-
-                         2.3667817F, 2.8248506F, 3.7754705F, 5.861325F,
-                         5.058735F, 3.2787242F, 3.6843839F, 9.755121F,
-                         2.7902672F, 7.3974323F, 8.283609F, 8.488337F};
-  vector<int64_t> input_dims = {2, 3, 4};
-  test.AddInput<float>("input", input_dims, input);
-
-  // vector<float> scale = {2.1F, 0.1F, 1.F};
-  vector<float> scale = {1.0F, 1.0F, 1.F};
-  vector<int64_t> scale_dims = {3};
-  test.AddInput<float>("scale", scale_dims, scale);
-
-  // vector<float> B = {2.3F, 1.5F, 0.F};
-  vector<float> B = {0.0F, 0.0F, 0.F};
-  vector<int64_t> B_dims = {3};
-  test.AddInput<float>("B", B_dims, B);
-
-  vector<float> expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F,
-
1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, - 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F, +// Disable TensorRT on some of the tests because its parser doesn't support weight as input - -0.92380346F, -0.60808484F, 0.04711878F, 1.48476953F, - -0.14644464F, -0.82262872F, -0.66852817F, 1.63760153F, - -1.65898662F, 0.27618144F, 0.64840618F, 0.734399F}; - test.AddOutput("Y", input_dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +TYPED_TEST(InstanceNormalizationOpTest, InstanceNorm) { + auto run_test = [](bool is_initializer) { + OpTester test("InstanceNormalization"); + test.AddAttribute("epsilon", 0.3F); + + vector input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, + 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, + 7.9195533F, 7.638727F, 8.065445F, 3.8082376F, + + 2.3667817F, 2.8248506F, 3.7754705F, 5.861325F, + 5.058735F, 3.2787242F, 3.6843839F, 9.755121F, + 2.7902672F, 7.3974323F, 8.283609F, 8.488337F}; + vector input_dims = {2, 3, 4}; + test.AddInput("input", input_dims, GetTypedArray(input)); + + // vector scale = {2.1F, 0.1F, 1.F}; + vector scale = {1.0F, 1.0F, 1.F}; + vector scale_dims = {3}; + test.AddInput("scale", scale_dims, GetTypedArray(scale), is_initializer); + + // vector B = {2.3F, 1.5F, 0.F}; + vector B = {0.0F, 0.0F, 0.F}; + vector B_dims = {3}; + test.AddInput("B", B_dims, GetTypedArray(B), is_initializer); + + vector expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F, + 1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, + 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F, + + -0.92380346F, -0.60808484F, 0.04711878F, 1.48476953F, + -0.14644464F, -0.82262872F, -0.66852817F, 1.63760153F, + -1.65898662F, 0.27618144F, 0.64840618F, 0.734399F}; + test.AddOutput("Y", input_dims, GetTypedArray(expected_output)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + }; + run_test(false); + run_test(true); } -TEST(InstanceNormalizationOpTest, InstanceNormBatch1) { - OpTester test("InstanceNormalization"); - test.AddAttribute("epsilon", 0.3F); - - vector input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, - 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, - 7.9195533F, 7.638727F, 8.065445F, 3.8082376F}; - vector input_dims = {1, 3, 4}; - test.AddInput("input", input_dims, input); - - vector scale = {1.0F, 1.0F, 1.F}; - vector scale_dims = {3}; - test.AddInput("scale", scale_dims, scale); - - vector B = {0.0F, 0.0F, 0.F}; - vector B_dims = {3}; - test.AddInput("B", B_dims, B); - - vector expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F, - 1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, - 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F}; - test.AddOutput("Y", input_dims, expected_output); - - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +TYPED_TEST(InstanceNormalizationOpTest, InstanceNormBatch1) { + auto run_test = [](bool is_initializer) { + OpTester test("InstanceNormalization"); + test.AddAttribute("epsilon", 0.3F); + + vector input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, + 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, + 7.9195533F, 7.638727F, 8.065445F, 3.8082376F}; + vector input_dims = {1, 3, 4}; + test.AddInput("input", input_dims, GetTypedArray(input)); + + vector scale = {1.0F, 1.0F, 1.F}; + vector scale_dims = {3}; + test.AddInput("scale", scale_dims, GetTypedArray(scale), is_initializer); + + vector B = {0.0F, 0.0F, 0.F}; + vector B_dims = {3}; + 
test.AddInput("B", B_dims, GetTypedArray(B), is_initializer); + + vector expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F, + 1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, + 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F}; + test.AddOutput("Y", input_dims, GetTypedArray(expected_output)); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + }; + run_test(false); + run_test(true); } TEST(InstanceNormalizationOpTest, InstanceNormBatch2) { @@ -105,7 +121,7 @@ TEST(InstanceNormalizationOpTest, InstanceNormBatch2) { } // Only CUDA and ROCm kernels have float 16 support -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST(InstanceNormalizationOpTest, InstanceNormBatch1_fp16) { OpTester test("InstanceNormalization"); diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 0968bc32e0de4..bb6d732fccb8f 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -3175,19 +3175,26 @@ TEST(ReductionOpTest, ReduceProd0DTensor) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -TEST(ReductionOpTest, ArgMax) { +template +class ReductionOpTest : public ::testing::Test { +}; + +using ReductionOpTestTypes = ::testing::Types; +TYPED_TEST_SUITE(ReductionOpTest, ReductionOpTestTypes); + +TYPED_TEST(ReductionOpTest, ArgMax) { OpTester test("ArgMax"); test.AddAttribute("axis", (int64_t)1); test.AddAttribute("keepdims", (int64_t)1); - test.AddInput("data", {3, 2, 2}, - {1.0f, 2.0f, - 3.0f, 4.0f, + test.AddInput("data", {3, 2, 2}, + GetTypedArray({1.0f, 2.0f, + 3.0f, 4.0f, - 5.0f, 6.0f, - 7.0f, 8.0f, + 5.0f, 6.0f, + 7.0f, 8.0f, - 9.0f, 10.0f, - 11.0f, 12.0f}); + 9.0f, 10.0f, + 11.0f, 12.0f})); test.AddOutput("reduced", {3, 1, 2}, {1, 1, 1, 1, diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index 4a1888a5ca7d6..9e0fb81cbb0fc 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -3,6 +3,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "test/common/tensor_op_test_utils.h" namespace onnxruntime { namespace test { @@ -75,17 +76,6 @@ TEST(ConcatOpTest, Concat1D_2) { kQnnExecutionProvider}); // QNN: not support dynamic shape tensor } -template -static std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { - if constexpr (std::is_same::value) { - return inputs; - } else { - std::vector inputs_fp16(inputs.size()); - ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); - return inputs_fp16; - } -} - TYPED_TEST(ConcatOpTest, Concat2D_1) { OpTester test("Concat"); test.AddAttribute("axis", int64_t{0}); diff --git a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc index a32d43f296250..2169436255727 100644 --- a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc @@ -5,6 +5,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" +#include "test/common/tensor_op_test_utils.h" namespace onnxruntime { namespace test { @@ -263,22 +264,6 @@ 
TEST(SliceTest, Slice3D) {
                       332.0f, 333.0f});
 }
 
-template <typename T>
-static std::vector<T> GetTypedArray(std::vector<float> inputs, [[maybe_unused]] T v = T(0.f)) {
-  std::vector<T> inputs_T(inputs.size());
-  if constexpr (std::is_same<T, float>::value) {
-    return inputs;
-  } else if constexpr (std::is_integral_v<T>) {
-    for (size_t i = 0; i < inputs.size(); i++) {
-      inputs_T[i] = static_cast<T>(inputs[i]);
-    }
-    return inputs_T;
-  } else {
-    ConvertFloatToMLFloat16(inputs.data(), inputs_T.data(), inputs.size());
-    return inputs_T;
-  }
-}
-
 template <typename T>
 static void TestSlice1DIntData() {
   // static_assert(std::is_integral_v<T>);
diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
index 48872404f08bd..1c2a86bb808b5 100644
--- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
@@ -4,6 +4,7 @@
 #include "gtest/gtest.h"
 #include "core/framework/to_tensor_proto_element_type.h"
 #include "test/providers/provider_test_utils.h"
+#include "test/common/tensor_op_test_utils.h"
 
 namespace onnxruntime {
 namespace test {
@@ -178,17 +179,6 @@ TEST(SplitOperatorTest, Axis0UnequalSplitFloat) {
   RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true);
 }
 
-template <typename T>
-std::vector<T> GetTypedArray(std::vector<float> inputs, [[maybe_unused]] T v = T(0.f)) {
-  if constexpr (std::is_same<T, float>::value) {
-    return inputs;
-  } else {
-    std::vector<T> inputs_fp16(inputs.size());
-    ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size());
-    return inputs_fp16;
-  }
-}
-
 TEST(SplitOperatorTest, Axis0UnequalSplitString) {
   constexpr int64_t axis = 0;
   std::vector<ShapeAndStringData> outputs;
diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc
index 3b46dc3f5d6a2..73a5bce768a2a 100644
--- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc
@@ -69,17 +69,6 @@ void TransposeTest(const std::vector<int64_t>& input_shape,
   }
 }
 
-template <typename T>
-std::vector<T> GetTypedArray(std::vector<float> inputs, [[maybe_unused]] T v = T(0.f)) {
-  if constexpr (std::is_same<T, float>::value) {
-    return inputs;
-  } else {
-    std::vector<T> inputs_fp16(inputs.size());
-    ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size());
-    return inputs_fp16;
-  }
-}
-
 // Test 2 dimensional transpose, with no permutation attribute specified
 TYPED_TEST(TransposeOpTest, TwoDimNoAttr) {
   std::vector<int64_t> input_shape({2, 3});
diff --git a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md
index 0c28b272f7fa3..b269026ea02ac 100644
--- a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md
+++ b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md
@@ -4,7 +4,9 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution
 |Operator|Note|
 |--------|------|
 |ai.onnx:Add||
+|ai.onnx:ArgMax||
 |ai.onnx:AveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
+|ai.onnx:Cast||
 |ai.onnx:Clip||
 |ai.onnx:Concat||
 |ai.onnx:Conv|Only 1D/2D Conv is supported.<br/>Bias if provided must be constant.|
@@ -12,14 +14,19 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution
 |ai.onnx:DepthToSpace|If 'mode' is 'CRD' the input must have a fixed shape.|
 |ai.onnx:Div||
 |ai.onnx:Gemm|Input B must be constant.|
+|ai.onnx:Gelu||
 |ai.onnx:GlobalAveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
 |ai.onnx:GlobalMaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
 |ai.onnx:GridSample|4D input.<br/>'mode' of 'linear' or 'zeros'.<br/>(mode==linear && padding_mode==reflection && align_corners==0) is not supported.|
+|ai.onnx:GroupNormalization||
+|ai.onnx:InstanceNormalization||
+|ai.onnx:LayerNormalization||
 |ai.onnx:LeakyRelu||
 |ai.onnx:MatMul|Only support for transA == 0, alpha == 1.0 and beta == 1.0 is currently implemented.|
 |ai.onnx:MaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
 |ai.onnx:Mul||
 |ai.onnx:Pow|Only supports cases when both inputs are fp32.|
+|ai.onnx:PRelu||
 |ai.onnx:Reciprocal|this ask for a `epislon` (default 1e-4) where onnx don't provide|
 |ai.onnx:Relu||
 |ai.onnx:Reshape||
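
Usage note (illustrative sketch, not part of the patch): the tests above opt into the new ops by creating the CoreML EP with the ML Program flag (DefaultCoreMLExecutionProvider(true) / MakeCoreMLExecutionProvider(COREML_FLAG_CREATE_MLPROGRAM)). A minimal sketch of how an application could request the same ML Program path through the public flag-based API follows; the model path and include layout are placeholder assumptions, not something introduced by this change.

// Illustrative sketch only: ask the CoreML EP to build an ML Program so that ops such as
// Gelu, LayerNormalization, GroupNormalization and InstanceNormalization can be assigned to it.
#include <onnxruntime_cxx_api.h>
#include "coreml_provider_factory.h"  // declares COREML_FLAG_CREATE_MLPROGRAM; path depends on your include setup

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "coreml_mlprogram_demo");
  Ort::SessionOptions session_options;

  // COREML_FLAG_CREATE_MLPROGRAM selects the ML Program model format (requires Core ML 5 or later).
  uint32_t coreml_flags = COREML_FLAG_CREATE_MLPROGRAM;
  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags));

  // "model.onnx" is a placeholder path for whatever model exercises the ops listed above.
  Ort::Session session(env, "model.onnx", session_options);
  return 0;
}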