Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CoreML] Create EP by AppendExecutionProvider #22675

Merged
merged 22 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1272,7 +1272,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
/// <summary>
/// Append an execution provider instance to the native OrtSessionOptions instance.
///
/// 'SNPE' and 'XNNPACK' are currently supported as providerName values.
/// 'SNPE', 'XNNPACK' and 'CoreML' are currently supported as providerName values.
///
/// The number of providerOptionsKeys must match the number of providerOptionsValues and equal numKeys.
/// </summary>
Expand Down
8 changes: 1 addition & 7 deletions csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -430,16 +430,10 @@ public IntPtr Appender(IntPtr handle, IntPtr[] optKeys, IntPtr[] optValues, UInt
/// <summary>
/// Append QNN, SNPE or XNNPACK execution provider
/// </summary>
/// <param name="providerName">Execution provider to add. 'QNN', 'SNPE' or 'XNNPACK' are currently supported.</param>
/// <param name="providerName">Execution provider to add. 'QNN', 'SNPE', 'XNNPACK', 'CoreML' and 'AZURE' are currently supported.</param>
/// <param name="providerOptions">Optional key/value pairs to specify execution provider options.</param>
public void AppendExecutionProvider(string providerName, Dictionary<string, string> providerOptions = null)
{
if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN" && providerName != "AZURE")
{
throw new NotSupportedException(
"Only QNN, SNPE, XNNPACK and AZURE execution providers can be enabled by this method.");
}

if (providerOptions == null)
{
providerOptions = new Dictionary<string, string>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ public void TestSessionOptions()
ex = Assert.Throws<OnnxRuntimeException>(() => { opt.AppendExecutionProvider("QNN"); });
Assert.Contains("QNN execution provider is not supported in this build", ex.Message);
#endif
#if USE_COREML
opt.AppendExecutionProvider("CoreML");
#else
ex = Assert.Throws<OnnxRuntimeException>(() => { opt.AppendExecutionProvider("CoreML"); });
Assert.Contains("CoreML execution provider is not supported in this build", ex.Message);
#endif

opt.AppendExecutionProvider_CPU(1);
}
Expand Down Expand Up @@ -2041,7 +2047,7 @@ public SkipNonPackageTests()
}

// Test hangs on mobile.
#if !(ANDROID || IOS)
#if !(ANDROID || IOS)
[Fact(DisplayName = "TestModelRunAsyncTask")]
private async Task TestModelRunAsyncTask()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ enum COREMLFlags {
COREML_FLAG_LAST = COREML_FLAG_USE_CPU_AND_GPU,
};

// MLComputeUnits can be one of the following values:
// 'MLComputeUnitsCPUAndNeuralEngine|MLComputeUnitsCPUAndGPU|MLComputeUnitsCPUOnly|MLComputeUnitsAll'
static const char* const kCoremlProviderOption_MLComputeUnits = "MLComputeUnits";
static const char* const kCoremlProviderOption_MLModelFormat = "MLModelFormat";
static const char* const kCoremlProviderOption_MLAllowStaticInputShapes = "MLAllowStaticInputShapes";
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
static const char* const kCoremlProviderOption_MLEnableOnSubgraphs = "MLEnableOnSubgraphs";
static const char* const kCoremlProviderOption_MLModelCacheDir = "MLModelCacheDir";
wejoncy marked this conversation as resolved.
Show resolved Hide resolved

#ifdef __cplusplus
extern "C" {
#endif
Expand Down
12 changes: 12 additions & 0 deletions java/src/main/java/ai/onnxruntime/OrtSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,18 @@ public void addQnn(Map<String, String> providerOptions) throws OrtException {
addExecutionProvider(qnnProviderName, providerOptions);
}

/**
 * Adds CoreML as an execution backend.
 *
 * @param providerOptions Configuration options for the CoreML backend. Refer to the CoreML
 *     execution provider's documentation.
 * @throws OrtException If there was an error in native code.
 */
public void addCoreML(Map<String, String> providerOptions) throws OrtException {
  // Lower-camel-case local, consistent with addQnn's qnnProviderName.
  String coreMLProviderName = "CoreML";
  addExecutionProvider(coreMLProviderName, providerOptions);
}

private native void setExecutionMode(long apiHandle, long nativeHandle, int mode)
throws OrtException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ namespace onnxruntime {

constexpr const char* COREML = "CoreML";

CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags)
CoreMLExecutionProvider::CoreMLExecutionProvider(const CoreMLOptions& options)
: IExecutionProvider{onnxruntime::kCoreMLExecutionProvider},
coreml_flags_(coreml_flags),
coreml_flags_(options.coreml_flags),
coreml_version_(coreml::util::CoreMLVersion()) {
LOGS_DEFAULT(VERBOSE) << "CoreML version: " << coreml_version_;
if (coreml_version_ < MINIMUM_COREML_VERSION) {
LOGS_DEFAULT(ERROR) << "CoreML EP is not supported on this platform.";
}

// check if only one flag is set
if ((coreml_flags & COREML_FLAG_USE_CPU_ONLY) && (coreml_flags & COREML_FLAG_USE_CPU_AND_GPU)) {
if ((coreml_flags_ & COREML_FLAG_USE_CPU_ONLY) && (coreml_flags_ & COREML_FLAG_USE_CPU_AND_GPU)) {
// multiple device options selected
ORT_THROW(
"Multiple device options selected, you should use at most one of the following options:"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@
class Model;
}

// Configuration for the CoreML execution provider, produced either from
// legacy COREMLFlags bits or by parsing string-based ProviderOptions.
struct CoreMLOptions {
  uint32_t coreml_flags = 0;  // bitwise OR of COREMLFlags values
  std::string cache_path;     // directory for cached compiled CoreML models; empty disables caching
};

class CoreMLExecutionProvider : public IExecutionProvider {
public:
CoreMLExecutionProvider(uint32_t coreml_flags);
CoreMLExecutionProvider(const CoreMLOptions& options);

Check warning on line 22 in onnxruntime/core/providers/coreml/coreml_execution_provider.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Single-parameter constructors should be marked explicit. [runtime/explicit] [4] Raw Output: onnxruntime/core/providers/coreml/coreml_execution_provider.h:22: Single-parameter constructors should be marked explicit. [runtime/explicit] [4]
virtual ~CoreMLExecutionProvider();

std::vector<std::unique_ptr<ComputeCapability>>
Expand Down
66 changes: 61 additions & 5 deletions onnxruntime/core/providers/coreml/coreml_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,77 @@
using namespace onnxruntime;

namespace onnxruntime {

namespace {
// Converts generic string key/value ProviderOptions into a CoreMLOptions
// struct. Unknown keys and invalid option values are rejected via ORT_THROW.
CoreMLOptions ParseProviderOption(const ProviderOptions& options) {
  CoreMLOptions coreml_options;
  // Accepted values for kCoremlProviderOption_MLComputeUnits.
  const std::unordered_map<std::string, COREMLFlags> available_computeunits_options = {
      {"MLComputeUnitsCPUAndNeuralEngine", COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE},
      {"MLComputeUnitsCPUAndGPU", COREML_FLAG_USE_CPU_AND_GPU},
      {"MLComputeUnitsCPUOnly", COREML_FLAG_USE_CPU_ONLY},
      {"MLComputeUnitsAll", COREML_FLAG_USE_NONE},
  };
  // Accepted values for kCoremlProviderOption_MLModelFormat.
  const std::unordered_map<std::string, COREMLFlags> available_modelformat_options = {
      {"MLProgram", COREML_FLAG_CREATE_MLPROGRAM},
      {"NeuralNetwork", COREML_FLAG_USE_NONE},
  };
  const std::unordered_set<std::string> valid_options = {
      kCoremlProviderOption_MLComputeUnits,
      kCoremlProviderOption_MLModelFormat,
      kCoremlProviderOption_MLAllowStaticInputShapes,
      kCoremlProviderOption_MLEnableOnSubgraphs,
      kCoremlProviderOption_MLModelCacheDir,
  };

  for (const auto& option : options) {
    // Reject anything that is not a known CoreML option key.
    if (valid_options.find(option.first) == valid_options.end()) {
      ORT_THROW("Unknown option: ", option.first);
    }

    if (kCoremlProviderOption_MLComputeUnits == option.first) {
      // Single lookup instead of find() followed by at().
      const auto it = available_computeunits_options.find(option.second);
      if (it == available_computeunits_options.end()) {
        ORT_THROW("Invalid value for option ", option.first, ": ", option.second);
      }
      coreml_options.coreml_flags |= it->second;
    } else if (kCoremlProviderOption_MLModelFormat == option.first) {
      const auto it = available_modelformat_options.find(option.second);
      if (it == available_modelformat_options.end()) {
        ORT_THROW("Invalid value for option ", option.first, ": ", option.second);
      }
      coreml_options.coreml_flags |= it->second;
    } else if (kCoremlProviderOption_MLAllowStaticInputShapes == option.first) {
      // Documented values are "0"/"1"; only "1" enables the flag.
      // NOTE(review): a prior revision set the flag for any value — confirm intent.
      if (option.second == "1") {
        coreml_options.coreml_flags |= COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES;
      }
    } else if (kCoremlProviderOption_MLEnableOnSubgraphs == option.first) {
      if (option.second == "1") {
        coreml_options.coreml_flags |= COREML_FLAG_ENABLE_ON_SUBGRAPH;
      }
    } else if (kCoremlProviderOption_MLModelCacheDir == option.first) {
      coreml_options.cache_path = option.second;
    }
  }

  return coreml_options;
}
} // namespace
struct CoreMLProviderFactory : IExecutionProviderFactory {
CoreMLProviderFactory(uint32_t coreml_flags)
: coreml_flags_(coreml_flags) {}
CoreMLProviderFactory(const CoreMLOptions& options)

Check warning on line 63 in onnxruntime/core/providers/coreml/coreml_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Single-parameter constructors should be marked explicit. [runtime/explicit] [4] Raw Output: onnxruntime/core/providers/coreml/coreml_provider_factory.cc:63: Single-parameter constructors should be marked explicit. [runtime/explicit] [4]
: options_(options) {}
~CoreMLProviderFactory() override {}

std::unique_ptr<IExecutionProvider> CreateProvider() override;
uint32_t coreml_flags_;
CoreMLOptions options_;
};

// Instantiates a CoreML execution provider configured with this factory's
// stored options.
std::unique_ptr<IExecutionProvider> CoreMLProviderFactory::CreateProvider() {
  auto provider = std::make_unique<CoreMLExecutionProvider>(options_);
  return provider;
}

std::shared_ptr<IExecutionProviderFactory> CoreMLProviderFactoryCreator::Create(uint32_t coreml_flags) {
return std::make_shared<onnxruntime::CoreMLProviderFactory>(coreml_flags);
CoreMLOptions coreml_options;
coreml_options.coreml_flags = coreml_flags;
return std::make_shared<onnxruntime::CoreMLProviderFactory>(coreml_options);
}

std::shared_ptr<IExecutionProviderFactory> CoreMLProviderFactoryCreator::Create(const ProviderOptions& options) {
return std::make_shared<onnxruntime::CoreMLProviderFactory>(ParseProviderOption(options));

Check warning on line 82 in onnxruntime/core/providers/coreml/coreml_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for make_shared<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/coreml/coreml_provider_factory.cc:82: Add #include <memory> for make_shared<> [build/include_what_you_use] [4]
}
} // namespace onnxruntime

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

#include <memory>

#include "core/framework/provider_options.h"
#include "core/providers/providers.h"

namespace onnxruntime {
struct CoreMLProviderFactoryCreator {
static std::shared_ptr<IExecutionProviderFactory> Create(uint32_t coreml_flags);
static std::shared_ptr<IExecutionProviderFactory> Create(const ProviderOptions& options);
};
} // namespace onnxruntime
12 changes: 11 additions & 1 deletion onnxruntime/core/session/provider_registration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,21 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
status = create_not_supported_status();
#endif
} else if (strcmp(provider_name, "VitisAI") == 0) {
#ifdef USE_VITISAI
status = OrtApis::SessionOptionsAppendExecutionProvider_VitisAI(options, provider_options_keys, provider_options_values, num_keys);
#else
status = create_not_supported_status();
#endif
} else if (strcmp(provider_name, "CoreML") == 0) {
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
#if defined(USE_COREML)
options->provider_factories.push_back(CoreMLProviderFactoryCreator::Create(provider_options));
#else
status = create_not_supported_status();
#endif
} else {
ORT_UNUSED_PARAMETER(options);
status = OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
"Unknown provider name. Currently supported values are 'OPENVINO', 'SNPE', 'XNNPACK', 'QNN', 'WEBNN' and 'AZURE'");
"Unknown provider name. Currently supported values are 'OPENVINO', 'SNPE', 'XNNPACK', 'QNN', 'WEBNN' ,'CoreML', and 'AZURE'");
}

return status;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/python/onnxruntime_pybind_schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void addGlobalSchemaFunctions(pybind11::module& m) {
onnxruntime::RknpuProviderFactoryCreator::Create(),
#endif
#ifdef USE_COREML
onnxruntime::CoreMLProviderFactoryCreator::Create(0),
onnxruntime::CoreMLProviderFactoryCreator::Create(ProviderOptions{}),
#endif
#ifdef USE_XNNPACK
onnxruntime::XnnpackProviderFactoryCreator::Create(ProviderOptions{}, nullptr),
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,9 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
if (flags_str.find("COREML_FLAG_CREATE_MLPROGRAM") != std::string::npos) {
coreml_flags |= COREMLFlags::COREML_FLAG_CREATE_MLPROGRAM;
}
} else {
// read from provider_options
return onnxruntime::CoreMLProviderFactoryCreator::Create(options)->CreateProvider();
}
}

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/onnx/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
}
if (enable_coreml) {
#ifdef USE_COREML
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sf, 0));
sf.AppendExecutionProvider("CoreML", {});
#else
fprintf(stderr, "CoreML is not supported in this build");
return -1;
Expand Down
42 changes: 7 additions & 35 deletions onnxruntime/test/perftest/command_args_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <core/optimizer/graph_transformer_level.h>

#include "test_configuration.h"
#include "strings_helper.h"

namespace onnxruntime {
namespace perftest {
Expand Down Expand Up @@ -129,8 +130,12 @@ namespace perftest {
"\t [NNAPI only] [NNAPI_FLAG_CPU_ONLY]: Using CPU only in NNAPI EP.\n"
"\t [Example] [For NNAPI EP] -e nnapi -i \"NNAPI_FLAG_USE_FP16 NNAPI_FLAG_USE_NCHW NNAPI_FLAG_CPU_DISABLED\"\n"
"\n"
"\t [CoreML only] [COREML_FLAG_CREATE_MLPROGRAM COREML_FLAG_USE_CPU_ONLY COREML_FLAG_USE_CPU_AND_GPU]: Create an ML Program model instead of Neural Network.\n"
"\t [Example] [For CoreML EP] -e coreml -i \"COREML_FLAG_CREATE_MLPROGRAM\"\n"
"\t [CoreML only] [ModelFormat]:[MLProgram, NeuralNetwork] Create an ML Program model or Neural Network.\n"
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
"\t [CoreML only] [ComputeUnits]:[MLComputeUnitsCPUAndNeuralEngine MLComputeUnitsCPUAndGPU MLComputeUnitsCPUAndGPU MLComputeUnitsCPUOnly] the backend device to run model.\n"
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
"\t [CoreML only] [AllowStaticInputShapes]:[0 1].\n"
"\t [CoreML only] [EnableOnSubgraphs]:[0 1].\n"
"\t [CoreML only] [ModelCacheDir]: a path to cached compiled coreml model.\n"
"\t [Example] [For CoreML EP] -e coreml -i \"ModelFormat|MLProgram ComputeUnits|MLComputeUnitsCPUAndGPU\"\n"
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
"\n"
"\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n"
"\t [SNPE only] [priority]: execution priority, options: 'low', 'normal'. \n"
Expand Down Expand Up @@ -175,39 +180,6 @@ static bool ParseDimensionOverride(std::basic_string<ORTCHAR_T>& dim_identifier,
return true;
}

// Parses a space-separated list of "key|value" session configuration
// entries into session_configs. Returns false when an entry lacks a '|'
// separator, has an empty key, or repeats a key already seen; on failure
// session_configs may already contain the entries parsed so far.
static bool ParseSessionConfigs(const std::string& configs_string,
                                std::unordered_map<std::string, std::string>& session_configs) {
  std::istringstream stream(configs_string);
  std::string entry;

  while (stream >> entry) {
    if (entry.empty()) {
      continue;
    }

    const std::string_view entry_view{entry};

    const auto separator = entry_view.find('|');
    if (separator == std::string_view::npos || separator == 0 || separator == entry_view.length()) {
      // Malformed entry: key and value must be separated by a '|'.
      return false;
    }

    std::string key{entry_view.substr(0, separator)};
    std::string value{entry_view.substr(separator + 1)};

    if (session_configs.count(key) != 0) {
      // Duplicate session configuration entry for this key.
      return false;
    }

    session_configs.emplace(std::move(key), std::move(value));
  }

  return true;
}

/*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) {
int ch;
while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqznlR:"))) != -1) {
Expand Down
Loading
Loading