Skip to content

Commit

Permalink
refactor coreml EP creater
Browse files Browse the repository at this point in the history
  • Loading branch information
wejoncy committed Nov 7, 2024
1 parent c64459f commit 43c882f
Show file tree
Hide file tree
Showing 21 changed files with 248 additions and 210 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1272,7 +1272,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
/// <summary>
/// Append an execution provider instance to the native OrtSessionOptions instance.
///
/// 'SNPE' and 'XNNPACK' are currently supported as providerName values.
/// 'SNPE', 'XNNPACK' and 'CoreML' are currently supported as providerName values.
///
/// The number of providerOptionsKeys must match the number of providerOptionsValues and equal numKeys.
/// </summary>
Expand Down
8 changes: 1 addition & 7 deletions csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -430,16 +430,10 @@ public IntPtr Appender(IntPtr handle, IntPtr[] optKeys, IntPtr[] optValues, UInt
/// <summary>
/// Append QNN, SNPE or XNNPACK execution provider
/// </summary>
/// <param name="providerName">Execution provider to add. 'QNN', 'SNPE' or 'XNNPACK' are currently supported.</param>
/// <param name="providerName">Execution provider to add. 'QNN', 'SNPE' 'XNNPACK', 'CoreML and 'AZURE are currently supported.</param>
/// <param name="providerOptions">Optional key/value pairs to specify execution provider options.</param>
public void AppendExecutionProvider(string providerName, Dictionary<string, string> providerOptions = null)
{
if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN" && providerName != "AZURE")
{
throw new NotSupportedException(
"Only QNN, SNPE, XNNPACK and AZURE execution providers can be enabled by this method.");
}

if (providerOptions == null)
{
providerOptions = new Dictionary<string, string>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ public void TestSessionOptions()
ex = Assert.Throws<OnnxRuntimeException>(() => { opt.AppendExecutionProvider("QNN"); });
Assert.Contains("QNN execution provider is not supported in this build", ex.Message);
#endif
#if USE_COREML
opt.AppendExecutionProvider("CoreML");
#else
ex = Assert.Throws<OnnxRuntimeException>(() => { opt.AppendExecutionProvider("CoreML"); });
Assert.Contains("CoreML execution provider is not supported in this build", ex.Message);
#endif

opt.AppendExecutionProvider_CPU(1);
}
Expand Down Expand Up @@ -2041,7 +2047,7 @@ public SkipNonPackageTests()
}

// Test hangs on mobile.
#if !(ANDROID || IOS)
#if !(ANDROID || IOS)
[Fact(DisplayName = "TestModelRunAsyncTask")]
private async Task TestModelRunAsyncTask()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ enum COREMLFlags {
COREML_FLAG_LAST = COREML_FLAG_USE_CPU_AND_GPU,
};

// MLComputeUnits can be one of the following values:
// 'MLComputeUnitsCPUAndNeuralEngine|MLComputeUnitsCPUAndGPU|MLComputeUnitsCPUOnly|MLComputeUnitsAll'
static const char* const kCoremlProviderOption_MLComputeUnits = "MLComputeUnits";

#ifdef __cplusplus
extern "C" {
#endif
Expand Down
12 changes: 12 additions & 0 deletions java/src/main/java/ai/onnxruntime/OrtSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,18 @@ public void addQnn(Map<String, String> providerOptions) throws OrtException {
addExecutionProvider(qnnProviderName, providerOptions);
}

/**
* Adds CoreML as an execution backend.
*
* @param providerOptions Configuration options for the CoreML backend. Refer to the CoreML
* execution provider's documentation.
* @throws OrtException If there was an error in native code.
*/
public void addCoreML(Map<String, String> providerOptions) throws OrtException {
String CoreMLProviderName = "CoreML";
addExecutionProvider(CoreMLProviderName, providerOptions);
}

private native void setExecutionMode(long apiHandle, long nativeHandle, int mode)
throws OrtException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ namespace onnxruntime {

constexpr const char* COREML = "CoreML";

CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags)
CoreMLExecutionProvider::CoreMLExecutionProvider(const CoreMLOptions& options)
: IExecutionProvider{onnxruntime::kCoreMLExecutionProvider},
coreml_flags_(coreml_flags),
coreml_flags_(options.coreml_flags),
coreml_version_(coreml::util::CoreMLVersion()) {
LOGS_DEFAULT(VERBOSE) << "CoreML version: " << coreml_version_;
if (coreml_version_ < MINIMUM_COREML_VERSION) {
LOGS_DEFAULT(ERROR) << "CoreML EP is not supported on this platform.";
}

// check if only one flag is set
if ((coreml_flags & COREML_FLAG_USE_CPU_ONLY) && (coreml_flags & COREML_FLAG_USE_CPU_AND_GPU)) {
if ((coreml_flags_ & COREML_FLAG_USE_CPU_ONLY) && (coreml_flags_ & COREML_FLAG_USE_CPU_AND_GPU)) {
// multiple device options selected
ORT_THROW(
"Multiple device options selected, you should use at most one of the following options:"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@ namespace coreml {
class Model;
}

struct CoreMLOptions {
uint32_t coreml_flags = 0;
std::string cache_path;
};

class CoreMLExecutionProvider : public IExecutionProvider {
public:
CoreMLExecutionProvider(uint32_t coreml_flags);
CoreMLExecutionProvider(const CoreMLOptions& options);
virtual ~CoreMLExecutionProvider();

std::vector<std::unique_ptr<ComputeCapability>>
Expand Down
49 changes: 44 additions & 5 deletions onnxruntime/core/providers/coreml/coreml_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,60 @@
using namespace onnxruntime;

namespace onnxruntime {

namespace {
CoreMLOptions ParseProviderOption(const ProviderOptions& options) {
CoreMLOptions coreml_options;
const std::unordered_map<std::string, COREMLFlags> available_device_options = {
{"MLComputeUnitsCPUAndNeuralEngine", COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE},
{"MLComputeUnitsCPUAndGPU", COREML_FLAG_USE_CPU_AND_GPU},
{"MLComputeUnitsCPUOnly", COREML_FLAG_USE_CPU_ONLY},
{"MLComputeUnitsAll", COREML_FLAG_USE_NONE},
};
const std::unordered_map<std::string, COREMLFlags> available_format_options = {
{"MLProgram", COREML_FLAG_CREATE_MLPROGRAM},
{"NeuralNetwork", COREML_FLAG_USE_NONE},
};
if (options.count("ComputeUnits")) {
coreml_options.coreml_flags |= available_device_options.at(options.at("ComputeUnits"));
}
if (options.count("ModelFormat")) {
coreml_options.coreml_flags |= available_format_options.at(options.at("ModelFormat"));
}
if (options.count("AllowStaticInputShapes")) {
coreml_options.coreml_flags |= COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES;
}
if (options.count("EnableOnSubgraphs")) {
coreml_options.coreml_flags |= COREML_FLAG_ENABLE_ON_SUBGRAPH;
}
if (options.count("ModelCacheDir")) {
coreml_options.cache_path = options.at("ModelCacheDir");
}

return coreml_options;
}
} // namespace
struct CoreMLProviderFactory : IExecutionProviderFactory {
CoreMLProviderFactory(uint32_t coreml_flags)
: coreml_flags_(coreml_flags) {}
CoreMLProviderFactory(const CoreMLOptions& options)
: options_(options) {}
~CoreMLProviderFactory() override {}

std::unique_ptr<IExecutionProvider> CreateProvider() override;
uint32_t coreml_flags_;
CoreMLOptions options_;
};

std::unique_ptr<IExecutionProvider> CoreMLProviderFactory::CreateProvider() {
return std::make_unique<CoreMLExecutionProvider>(coreml_flags_);
return std::make_unique<CoreMLExecutionProvider>(options_);
}

std::shared_ptr<IExecutionProviderFactory> CoreMLProviderFactoryCreator::Create(uint32_t coreml_flags) {
return std::make_shared<onnxruntime::CoreMLProviderFactory>(coreml_flags);
CoreMLOptions coreml_options;
coreml_options.coreml_flags = coreml_flags;
return std::make_shared<onnxruntime::CoreMLProviderFactory>(coreml_options);
}

std::shared_ptr<IExecutionProviderFactory> CoreMLProviderFactoryCreator::Create(const ProviderOptions& options) {
return std::make_shared<onnxruntime::CoreMLProviderFactory>(ParseProviderOption(options));
}
} // namespace onnxruntime

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

#include <memory>

#include "core/framework/provider_options.h"
#include "core/providers/providers.h"

namespace onnxruntime {
struct CoreMLProviderFactoryCreator {
static std::shared_ptr<IExecutionProviderFactory> Create(uint32_t coreml_flags);
static std::shared_ptr<IExecutionProviderFactory> Create(const ProviderOptions& options);
};
} // namespace onnxruntime
16 changes: 15 additions & 1 deletion onnxruntime/core/session/provider_registration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,25 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
status = create_not_supported_status();
#endif
} else if (strcmp(provider_name, "VitisAI") == 0) {
#ifdef USE_VITISAI
status = OrtApis::SessionOptionsAppendExecutionProvider_VitisAI(options, provider_options_keys, provider_options_values, num_keys);
#else
status = create_not_supported_status();
#endif
} else if (strcmp(provider_name, "CoreML") == 0) {
#if defined(USE_COREML)
std::string coreml_flags;
if (options->value.config_options.TryGetConfigEntry("coreml_flags", coreml_flags)) {
provider_options["coreml_flags"] = coreml_flags;
}
options->provider_factories.push_back(CoreMLProviderFactoryCreator::Create(provider_options));
#else
status = create_not_supported_status();
#endif
} else {
ORT_UNUSED_PARAMETER(options);
status = OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
"Unknown provider name. Currently supported values are 'OPENVINO', 'SNPE', 'XNNPACK', 'QNN', 'WEBNN' and 'AZURE'");
"Unknown provider name. Currently supported values are 'OPENVINO', 'SNPE', 'XNNPACK', 'QNN', 'WEBNN' ,'CoreML', and 'AZURE'");
}

return status;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/python/onnxruntime_pybind_schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void addGlobalSchemaFunctions(pybind11::module& m) {
onnxruntime::RknpuProviderFactoryCreator::Create(),
#endif
#ifdef USE_COREML
onnxruntime::CoreMLProviderFactoryCreator::Create(0),
onnxruntime::CoreMLProviderFactoryCreator::Create(ProviderOptions{}),
#endif
#ifdef USE_XNNPACK
onnxruntime::XnnpackProviderFactoryCreator::Create(ProviderOptions{}, nullptr),
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/onnx/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
}
if (enable_coreml) {
#ifdef USE_COREML
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sf, 0));
sf.AppendExecutionProvider("CoreML", {});
#else
fprintf(stderr, "CoreML is not supported in this build");
return -1;
Expand Down
34 changes: 1 addition & 33 deletions onnxruntime/test/perftest/command_args_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <core/optimizer/graph_transformer_level.h>

#include "test_configuration.h"
#include "strings_helper.h"

namespace onnxruntime {
namespace perftest {
Expand Down Expand Up @@ -175,39 +176,6 @@ static bool ParseDimensionOverride(std::basic_string<ORTCHAR_T>& dim_identifier,
return true;
}

static bool ParseSessionConfigs(const std::string& configs_string,
std::unordered_map<std::string, std::string>& session_configs) {
std::istringstream ss(configs_string);
std::string token;

while (ss >> token) {
if (token == "") {
continue;
}

std::string_view token_sv(token);

auto pos = token_sv.find("|");
if (pos == std::string_view::npos || pos == 0 || pos == token_sv.length()) {
// Error: must use a '|' to separate the key and value for session configuration entries.
return false;
}

std::string key(token_sv.substr(0, pos));
std::string value(token_sv.substr(pos + 1));

auto it = session_configs.find(key);
if (it != session_configs.end()) {
// Error: specified duplicate session configuration entry: {key}
return false;
}

session_configs.insert(std::make_pair(std::move(key), std::move(value)));
}

return true;
}

/*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) {
int ch;
while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqznlR:"))) != -1) {
Expand Down
Loading

0 comments on commit 43c882f

Please sign in to comment.