[CoreML] ML Program more ops (2/N) (#22480)
- Cast
- ArgMax
- Gelu
- LayerNorm
- GroupNorm
- InstanceNorm
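
These ops take effect only when the CoreML EP is asked to build an ML Program model. A minimal usage sketch (assuming the public C API from coreml_provider_factory.h; the model path is hypothetical):

```cpp
#include "onnxruntime_cxx_api.h"
#include "coreml_provider_factory.h"  // declares COREML_FLAG_CREATE_MLPROGRAM

int main() {
  Ort::Env env;
  Ort::SessionOptions so;
  // Request an ML Program model; without this flag the EP builds a
  // NeuralNetwork model and the ops added in this change are not used.
  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(
      so, COREML_FLAG_CREATE_MLPROGRAM));
  Ort::Session session(env, "model.onnx", so);  // hypothetical model path
  return 0;
}
```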

### Description
Adds ML Program (MIL) implementations of the ops listed above to the CoreML execution provider.

### Motivation and Context
Continues expanding ML Program operator coverage in the CoreML EP (part 2 of N).
---------

Co-authored-by: Edward Chen <[email protected]>
Co-authored-by: Scott McKay <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
4 people authored Nov 1, 2024
1 parent c7ecc08 commit 9daf766
Showing 27 changed files with 951 additions and 234 deletions.
@@ -31,10 +31,10 @@ enum COREMLFlags {
// Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later.
COREML_FLAG_CREATE_MLPROGRAM = 0x010,

// Exclude the ANE (Apple Neural Engine), as it sometimes decreases performance.
// https://developer.apple.com/documentation/coreml/mlcomputeunits?language=objc
// There are four compute unit options:
// MLComputeUnitsCPUAndNeuralEngine|MLComputeUnitsCPUAndGPU|MLComputeUnitsCPUOnly|MLComputeUnitsAll
// Different compute units differ in performance and power consumption.
COREML_FLAG_USE_CPU_AND_GPU = 0x020,
// Keep COREML_FLAG_LAST at the end of the enum definition
// And assign the last COREMLFlag to it
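For context, a sketch of how the flags above typically map to Core ML compute-unit selection (illustrative Objective-C++, not the provider's actual code; COREML_FLAG_USE_CPU_ONLY is another member of the same enum):

```cpp
#import <CoreML/CoreML.h>

// Illustrative only: choose MLComputeUnits from the COREMLFlags bit mask.
MLComputeUnits ComputeUnitsFromFlags(uint32_t coreml_flags) {
  if (coreml_flags & COREML_FLAG_USE_CPU_ONLY) return MLComputeUnitsCPUOnly;
  if (coreml_flags & COREML_FLAG_USE_CPU_AND_GPU) return MLComputeUnitsCPUAndGPU;
  return MLComputeUnitsAll;  // default: CPU, GPU, and the Neural Engine
}
```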
@@ -40,6 +40,25 @@ void ActivationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, con
}

namespace {

template <typename T>
void HandlePReluWeight(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger,
std::vector<T>& alpha_values) {
// add slope initializer as alpha weight
const auto& slope_tensor = *model_builder.GetConstantInitializer(node.InputDefs()[1]->Name());
Initializer unpacked_tensor(slope_tensor);
const auto alpha_v = unpacked_tensor.DataAsSpan<T>();

if (alpha_v.size() == 1) {
// expand to number of channels
std::vector<int64_t> x_shape;
GetShape(*node.InputDefs()[0], x_shape, logger);
alpha_values.resize(x_shape[x_shape.size() - 3], alpha_v[0]);
} else {
alpha_values.assign(alpha_v.begin(), alpha_v.end());
}
}
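
A standalone sketch of the broadcast rule in HandlePReluWeight above, with hypothetical shapes and values: a scalar ONNX PRelu slope is expanded to one alpha per channel, taken from dimension -3 of an NCHW input.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> x_shape{1, 3, 224, 224};  // hypothetical NCHW input shape
  std::vector<float> alpha_v{0.25f};             // scalar ONNX "slope" input
  std::vector<float> alpha_values;

  if (alpha_v.size() == 1) {
    // expand the scalar to one value per channel (C == x_shape[rank - 3])
    alpha_values.resize(x_shape[x_shape.size() - 3], alpha_v[0]);
  } else {
    alpha_values.assign(alpha_v.begin(), alpha_v.end());
  }

  assert(alpha_values.size() == 3);  // {0.25f, 0.25f, 0.25f}
  return 0;
}
```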

Status AddPReluWeight(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger,
COREML_SPEC::ActivationPReLU& prelu) {
@@ -84,6 +103,7 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.activation
std::string_view coreml_op_type;
bool add_alpha = false;
bool add_gelu_mode = false;
if (op_type == "Sigmoid") {
coreml_op_type = "sigmoid";
} else if (op_type == "Tanh") {
@@ -93,6 +113,12 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
} else if (op_type == "LeakyRelu") {
coreml_op_type = "leaky_relu";
add_alpha = true;
} else if (op_type == "Gelu") {
coreml_op_type = "gelu";
add_gelu_mode = true;
} else if (op_type == "PRelu") {
coreml_op_type = "prelu";
add_alpha = true;
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
@@ -102,16 +128,39 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
AddOperationInput(*op, "x", node.InputDefs()[0]->Name());

if (add_alpha) {
NodeAttrHelper helper(node);
const auto alpha = helper.Get("alpha", 0.01f);

auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha));

if ("PRelu" == op_type) {
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
std::vector<float> alpha_values;
HandlePReluWeight(model_builder, node, logger, alpha_values);
AddOperationInput(*op, "alpha", model_builder.AddConstant(op->type(), "alpha", alpha_values));
} else {
std::vector<MLFloat16> alpha_values;
HandlePReluWeight(model_builder, node, logger, alpha_values);
AddOperationInput(*op, "alpha", model_builder.AddConstant(op->type(), "alpha", alpha_values));
}
} else {
AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", MLFloat16(alpha)));
NodeAttrHelper helper(node);
const auto alpha = helper.Get("alpha", 0.01f);

if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha));
} else {
AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", MLFloat16(alpha)));
}
}
}
if (add_gelu_mode) {
NodeAttrHelper helper(node);
std::string approximate = helper.Get("approximate", std::string("none"));
if (approximate == "tanh") {
approximate = "TANH_APPROXIMATION";
} else if (approximate == "none") {
approximate = "EXACT";
}
AddOperationInput(*op, "mode", model_builder.AddScalarConstant(op->type(), "mode", std::string(approximate)));
}
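// For reference, the two modes used here correspond to the standard GELU
// definitions (Core ML MIL also offers SIGMOID_APPROXIMATION, which ONNX Gelu
// does not use):
//   EXACT:              gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
//   TANH_APPROXIMATION: gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))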

AddOperationOutput(*op, *node.OutputDefs()[0]);

@@ -213,17 +262,11 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInp
const logging::Logger& logger) const {
const auto& op_type = node.OpType();

#if defined(COREML_ENABLE_MLPROGRAM)
if (input_params.create_mlprogram) {
if (op_type == "PRelu") { // TODO: ML Program supports this so should be easy to enable
return false;
}
} else
#endif // (COREML_ENABLE_MLPROGRAM)
{
if (op_type == "PRelu") {
return IsPReluOpSupported(node, input_params, logger);
}
if (op_type == "Gelu" && !input_params.create_mlprogram) {
return false;
}
if (op_type == "PRelu") {
return IsPReluOpSupported(node, input_params, logger);
}

return true;
@@ -245,6 +288,7 @@ void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistration
"Relu",
"PRelu",
"LeakyRelu",
"Gelu",
};

op_registrations.builders.push_back(std::make_unique<ActivationOpBuilder>());
@@ -3,6 +3,7 @@

#include "core/providers/coreml/builders/impl/base_op_builder.h"
#include "core/providers/coreml/builders/model_builder.h"
#include "core/providers/coreml/builders/impl/builder_utils.h"
#include "core/providers/coreml/builders/op_builder_factory.h"
#include "core/providers/shared/utils/utils.h"

@@ -15,6 +16,9 @@ class ArgMaxOpBuilder : public BaseOpBuilder {

bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const override;

public:
bool SupportsMLProgram() const override { return true; }
};

Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
@@ -24,41 +28,60 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const auto& graph_viewer = model_builder.GetGraphViewer();

NodeAttrHelper helper(node);
const auto axis = helper.Get("axis", 0);
const auto keepdims = helper.Get("keepdims", 1);
const int64_t axis = helper.Get("axis", 0);
const int64_t keepdims = helper.Get("keepdims", 1);
const bool removedim = keepdims != 1;

auto* coreml_argmax = layer->mutable_argmax();
coreml_argmax->set_axis(axis);
coreml_argmax->set_removedim(removedim);

// There are two cases here:
// 1. Special case (ArgMax followed by a Cast from int64 to int32): we fuse ArgMax's output with the Cast's input
//    (this special case exists because the CoreML model format has no Cast)
// 2. Otherwise, we add the ArgMax layer normally
if (node.GetOutputEdgesCount() == 1) {
auto it = node.OutputEdgesBegin();
const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index());
// If ArgMax's successor node is a Cast from int64 to int32.
// The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl()),
// so we omit the check here.
if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") {
// Skip the cast's input/argmax's output
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
*layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name();
model_builder.AddLayer(std::move(layer));
return Status::OK();
#if defined(COREML_ENABLE_MLPROGRAM)
if (model_builder.CreateMLProgram()) {
using namespace CoreML::Specification::MILSpec;
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.reduction

std::unique_ptr<Operation> op = model_builder.CreateOperation(node, "reduce_argmax");
AddOperationInput(*op, "x", node.InputDefs()[0]->Name());
AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", axis));
AddOperationInput(*op, "keep_dims", model_builder.AddScalarConstant(op->type(), "keep_dims", bool(keepdims)));

int32_t output_datatype = ONNX_NAMESPACE::TensorProto_DataType_INT32;
// ONNX ArgMax produces int64 output, which CoreML does not support, so the output is fixed to int32
AddOperationOutput(*op, *node.OutputDefs()[0], output_datatype);
model_builder.AddOperation(std::move(op));
} else
#endif // (COREML_ENABLE_MLPROGRAM)
{
auto* coreml_argmax = layer->mutable_argmax();
coreml_argmax->set_axis(axis);
coreml_argmax->set_removedim(removedim);

// There are two cases here:
// 1. Special case (ArgMax followed by a Cast from int64 to int32): we fuse ArgMax's output with the Cast's input
//    (this special case exists because the CoreML model format has no Cast)
// 2. Otherwise, we add the ArgMax layer normally
if (node.GetOutputEdgesCount() == 1) {
auto it = node.OutputEdgesBegin();
const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index());
// If ArgMax's successor node is a Cast from int64 to int32.
// The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl()),
// so we omit the check here.
if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") {
// Skip the cast's input/argmax's output
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
*layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name();
model_builder.AddLayer(std::move(layer));
return Status::OK();
}
}
}

*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();

model_builder.AddLayer(std::move(layer));
model_builder.AddLayer(std::move(layer));
}
return Status::OK();
}

bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node,
[[maybe_unused]] const OpBuilderInputParams& input_params,
const logging::Logger& logger) const {
// Attribute `select_last_index` of ArgMax op is not supported
NodeAttrHelper helper(node);
@@ -68,6 +91,12 @@ bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa
return false;
}

#if defined(COREML_ENABLE_MLPROGRAM)
if (input_params.create_mlprogram) {
return true;
}
#endif

// If there are multiple downstream nodes and cast (toint32) is one of them
// not supported, exit here
// Otherwise, for general multiple downstream nodes, supported
@@ -16,11 +16,10 @@ namespace coreml {
// Once all ops support FP16, we can remove this. Until then, we keep a set of ops
// known to support FP16 and use it to filter out unsupported ones.
static std::set<std::string> Float16Ops = {
"Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal",
"Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool",
"Clip", "DepthToSpace", "Resize", "Slice", "Conv",
"ConvTranspose", "GlobalMaxPool", "Gemm", "MatMul",
"AveragePool", "MaxPool", "Reshape", "Split", "Transpose"};
"Add", "ArgMax", "AveragePool", "BatchNormalization", "Cast", "Clip", "Concat", "Conv", "ConvTranspose",
"DepthToSpace", "Div", "Gelu", "Gemm", "GlobalAveragePool", "GlobalMaxPool", "GridSample", "GroupNormalization",
"InstanceNormalization", "LayerNormalization", "LeakyRelu", "MatMul", "MaxPool", "Mul", "PRelu", "Pow",
"Reciprocal", "Relu", "Reshape", "Resize", "Sigmoid", "Slice", "Split", "Sqrt", "Sub", "Tanh", "Transpose"};

namespace {
// TODO, move this to shared_library
@@ -10,6 +10,10 @@
#include "core/providers/coreml/shape_utils.h"
#include "core/providers/shared/utils/utils.h"

#ifdef __APPLE__
#include <TargetConditionals.h>
#endif

namespace onnxruntime {
namespace coreml {

@@ -24,6 +28,9 @@ class BatchNormalizationOpBuilder : public BaseOpBuilder {

// BatchNormalization opset 6- has unsupported attributes
int GetMinSupportedOpSet(const Node& /* node */) const override { return 7; }

public:
bool SupportsMLProgram() const override { return true; }
};

void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
@@ -50,21 +57,46 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
const auto eps = helper.Get("epsilon", 1e-5f);
const auto channels = scale_tensor.dims()[0];

auto* coreml_batch_norm = layer->mutable_batchnorm();
coreml_batch_norm->set_channels(channels);
coreml_batch_norm->set_epsilon(eps);
coreml_batch_norm->set_computemeanvar(false);
coreml_batch_norm->set_instancenormalization(false);

ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_gamma(), scale_tensor)); // scale
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_beta(), bias_tensor)); // B
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_mean(), mean_tensor)); // mean
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_variance(), var_tensor)); // var

*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();

model_builder.AddLayer(std::move(layer));
#if defined(COREML_ENABLE_MLPROGRAM)
if (model_builder.CreateMLProgram()) {
using namespace CoreML::Specification::MILSpec;
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.normalization.batch_norm

std::unique_ptr<Operation> op = model_builder.CreateOperation(node, "batch_norm");
AddOperationInput(*op, "x", input_defs[0]->Name());
AddOperationInput(*op, "mean", model_builder.AddConstant(op->type(), input_defs[3]->Name() + "mean", mean_tensor));
AddOperationInput(*op, "variance", model_builder.AddConstant(op->type(), input_defs[4]->Name() + "variance", var_tensor));
AddOperationInput(*op, "gamma", model_builder.AddConstant(op->type(), input_defs[1]->Name(), scale_tensor));
AddOperationInput(*op, "beta", model_builder.AddConstant(op->type(), input_defs[2]->Name(), bias_tensor));
auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type();
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
MLFloat16 epsilon_fp16(eps);
AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", epsilon_fp16));
} else {
AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", eps));
}
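// MIL expects scalar parameters such as "epsilon" to match the dtype of "x",
// hence the MLFloat16 conversion above for fp16 inputs.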

AddOperationOutput(*op, *node.OutputDefs()[0]);
model_builder.AddOperation(std::move(op));
} else
#endif // (COREML_ENABLE_MLPROGRAM)
{
auto* coreml_batch_norm = layer->mutable_batchnorm();
coreml_batch_norm->set_channels(channels);
coreml_batch_norm->set_epsilon(eps);
coreml_batch_norm->set_computemeanvar(false);
coreml_batch_norm->set_instancenormalization(false);

ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_gamma(), scale_tensor)); // scale
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_beta(), bias_tensor)); // B
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_mean(), mean_tensor)); // mean
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_variance(), var_tensor)); // var

*layer->mutable_input()->Add() = input_defs[0]->Name();
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();

model_builder.AddLayer(std::move(layer));
}
return Status::OK();
}

@@ -119,6 +151,15 @@ bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBu
return false;
}

#if TARGET_OS_IOS && TARGET_CPU_X86_64
// To pass the iOS pipeline: https://dev.azure.com/onnxruntime/onnxruntime/_build?definitionId=134&_a=summary
auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type();
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && input_params.coreml_version < 7) {
LOGS(logger, VERBOSE) << "float16 input is not supported on the iOS x86_64 simulator"
<< " due to CoreML producing invalid output.";
return false;
}
#endif
return true;
}
