Skip to content

Commit

Permalink
general appendEP for coreml
Browse files Browse the repository at this point in the history
accept provider-options
  • Loading branch information
wejoncy committed Nov 4, 2024
1 parent c64459f commit 530000a
Show file tree
Hide file tree
Showing 19 changed files with 86 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1272,7 +1272,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
/// <summary>
/// Append an execution provider instance to the native OrtSessionOptions instance.
///
/// 'SNPE' and 'XNNPACK' are currently supported as providerName values.
/// 'SNPE', 'XNNPACK' and 'CoreML' are currently supported as providerName values.
///
/// The number of providerOptionsKeys must match the number of providerOptionsValues and equal numKeys.
/// </summary>
Expand Down
5 changes: 3 additions & 2 deletions csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -430,11 +430,12 @@ public IntPtr Appender(IntPtr handle, IntPtr[] optKeys, IntPtr[] optValues, UInt
/// <summary>
/// Append QNN, SNPE, XNNPACK or CoreML execution provider
/// </summary>
/// <param name="providerName">Execution provider to add. 'QNN', 'SNPE' or 'XNNPACK' are currently supported.</param>
/// <param name="providerName">Execution provider to add. 'QNN', 'SNPE', 'XNNPACK' or 'CoreML' are currently supported.</param>
/// <param name="providerOptions">Optional key/value pairs to specify execution provider options.</param>
public void AppendExecutionProvider(string providerName, Dictionary<string, string> providerOptions = null)
{
if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN" && providerName != "AZURE")
if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN" &&
providerName != "AZURE" && providerName != "CoreML")
{
throw new NotSupportedException(
"Only QNN, SNPE, XNNPACK, CoreML and AZURE execution providers can be enabled by this method.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ public void TestSessionOptions()
ex = Assert.Throws<OnnxRuntimeException>(() => { opt.AppendExecutionProvider("QNN"); });
Assert.Contains("QNN execution provider is not supported in this build", ex.Message);
#endif
#if USE_COREML
opt.AppendExecutionProvider("CoreML");
#else
ex = Assert.Throws<OnnxRuntimeException>(() => { opt.AppendExecutionProvider("CoreML"); });
Assert.Contains("CoreML execution provider is not supported in this build", ex.Message);
#endif

opt.AppendExecutionProvider_CPU(1);
}
Expand Down Expand Up @@ -2041,7 +2047,7 @@ public SkipNonPackageTests()
}

// Test hangs on mobile.
#if !(ANDROID || IOS)
#if !(ANDROID || IOS)
[Fact(DisplayName = "TestModelRunAsyncTask")]
private async Task TestModelRunAsyncTask()
{
Expand Down
13 changes: 13 additions & 0 deletions java/src/main/java/ai/onnxruntime/OrtSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,19 @@ public void addQnn(Map<String, String> providerOptions) throws OrtException {
addExecutionProvider(qnnProviderName, providerOptions);
}

/**
 * Adds CoreML as an execution backend.
 *
 * @param providerOptions Configuration options for the CoreML backend. Refer to the CoreML
 *     execution provider's documentation.
 * @throws OrtException If there was an error in native code.
 */
public void addCoreML(Map<String, String> providerOptions) throws OrtException {
  // Name must match the native registration name ("CoreML"); local renamed to
  // lowerCamelCase for consistency with qnnProviderName in addQnn above.
  String coreMLProviderName = "CoreML";
  addExecutionProvider(coreMLProviderName, providerOptions);
}


private native void setExecutionMode(long apiHandle, long nativeHandle, int mode)
throws OrtException;

Expand Down
4 changes: 3 additions & 1 deletion js/node/src/session_options_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
#endif
#ifdef USE_COREML
} else if (name == "coreml") {
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sessionOptions, coreMlFlags));
std::unordered_map<std::string, std::string> options;
      options["coreml_flags"] = std::to_string(coreMlFlags);
sessionOptions.AppendExecutionProvider("CoreML", options);
#endif
#ifdef USE_QNN
} else if (name == "qnn") {
Expand Down
25 changes: 4 additions & 21 deletions js/react_native/ios/OnnxruntimeModule.mm
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@ - (void)setBlobManager:(RCTBlobManager*)manager {
* @param reject callback for returning an error back to react native js
* @note when run() is called, the same modelPath must be passed into the first parameter.
*/
RCT_EXPORT_METHOD(loadModel
: (NSString*)modelPath options
: (NSDictionary*)options resolver
: (RCTPromiseResolveBlock)resolve rejecter
: (RCTPromiseRejectBlock)reject) {
RCT_EXPORT_METHOD(loadModel : (NSString*)modelPath options : (NSDictionary*)options resolver : (RCTPromiseResolveBlock)resolve rejecter : (RCTPromiseRejectBlock)reject) {
@try {
NSDictionary* resultMap = [self loadModel:modelPath options:options];
resolve(resultMap);
Expand All @@ -95,11 +91,7 @@ - (void)setBlobManager:(RCTBlobManager*)manager {
* @param reject callback for returning an error back to react native js
* @note when run() is called, the same modelPath must be passed into the first parameter.
*/
RCT_EXPORT_METHOD(loadModelFromBlob
: (NSDictionary*)modelDataBlob options
: (NSDictionary*)options resolver
: (RCTPromiseResolveBlock)resolve rejecter
: (RCTPromiseRejectBlock)reject) {
RCT_EXPORT_METHOD(loadModelFromBlob : (NSDictionary*)modelDataBlob options : (NSDictionary*)options resolver : (RCTPromiseResolveBlock)resolve rejecter : (RCTPromiseRejectBlock)reject) {
@try {
[self checkBlobManager];
NSString* blobId = [modelDataBlob objectForKey:@"blobId"];
Expand All @@ -121,10 +113,7 @@ - (void)setBlobManager:(RCTBlobManager*)manager {
* @param resolve callback for returning output back to react native js
* @param reject callback for returning an error back to react native js
*/
RCT_EXPORT_METHOD(dispose
: (NSString*)key resolver
: (RCTPromiseResolveBlock)resolve rejecter
: (RCTPromiseRejectBlock)reject) {
RCT_EXPORT_METHOD(dispose : (NSString*)key resolver : (RCTPromiseResolveBlock)resolve rejecter : (RCTPromiseRejectBlock)reject) {
@try {
[self dispose:key];
resolve(nil);
Expand All @@ -143,13 +132,7 @@ - (void)setBlobManager:(RCTBlobManager*)manager {
* @param resolve callback for returning an inference result back to react native js
* @param reject callback for returning an error back to react native js
*/
RCT_EXPORT_METHOD(run
: (NSString*)url input
: (NSDictionary*)input output
: (NSArray*)output options
: (NSDictionary*)options resolver
: (RCTPromiseResolveBlock)resolve rejecter
: (RCTPromiseRejectBlock)reject) {
RCT_EXPORT_METHOD(run : (NSString*)url input : (NSDictionary*)input output : (NSArray*)output options : (NSDictionary*)options resolver : (RCTPromiseResolveBlock)resolve rejecter : (RCTPromiseRejectBlock)reject) {
@try {
NSDictionary* resultMap = [self run:url input:input output:output options:options];
resolve(resultMap);
Expand Down
25 changes: 14 additions & 11 deletions objectivec/ort_coreml_execution_provider.mm
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,23 @@ @implementation ORTSessionOptions (ORTSessionOptionsCoreMLEP)
- (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOptions*)options
error:(NSError**)error {
#if ORT_OBJC_API_COREML_EP_AVAILABLE
const uint32_t flags =
(options.useCPUOnly ? COREML_FLAG_USE_CPU_ONLY : 0) |
(options.useCPUAndGPU ? COREML_FLAG_USE_CPU_AND_GPU : 0) |
(options.enableOnSubgraphs ? COREML_FLAG_ENABLE_ON_SUBGRAPH : 0) |
(options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0) |
(options.onlyAllowStaticInputShapes ? COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES : 0) |
(options.createMLProgram ? COREML_FLAG_CREATE_MLPROGRAM : 0);

try {
const uint32_t flags =
(options.useCPUOnly ? COREML_FLAG_USE_CPU_ONLY : 0) |
(options.useCPUAndGPU ? COREML_FLAG_USE_CPU_AND_GPU : 0) |
(options.enableOnSubgraphs ? COREML_FLAG_ENABLE_ON_SUBGRAPH : 0) |
(options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0) |
(options.onlyAllowStaticInputShapes ? COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES : 0) |
(options.createMLProgram ? COREML_FLAG_CREATE_MLPROGRAM : 0);

Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(
[self CXXAPIOrtSessionOptions], flags));
return YES;
NSDictionary* provider_options = @{
@"coreml_flags" : [NSString stringWithFormat:@"%d", flags]
};
return [self appendExecutionProvider:@"CoreML" providerOptions:provider_options error:error];
}
ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error);
return YES;

#else // !ORT_OBJC_API_COREML_EP_AVAILABLE
static_cast<void>(options);
ORTSaveCodeAndDescriptionToError(ORT_FAIL, "CoreML execution provider is not enabled.", error);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ namespace onnxruntime {

constexpr const char* COREML = "CoreML";

CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags)
CoreMLExecutionProvider::CoreMLExecutionProvider(const ProviderOptions& options)
: IExecutionProvider{onnxruntime::kCoreMLExecutionProvider},
coreml_flags_(coreml_flags),
coreml_version_(coreml::util::CoreMLVersion()) {
  coreml_flags_ = options.count("coreml_flags") ? std::stoi(options.at("coreml_flags")) : 0;
LOGS_DEFAULT(VERBOSE) << "CoreML version: " << coreml_version_;
if (coreml_version_ < MINIMUM_COREML_VERSION) {
LOGS_DEFAULT(ERROR) << "CoreML EP is not supported on this platform.";
}

// check if only one flag is set
if ((coreml_flags & COREML_FLAG_USE_CPU_ONLY) && (coreml_flags & COREML_FLAG_USE_CPU_AND_GPU)) {
if ((coreml_flags_ & COREML_FLAG_USE_CPU_ONLY) && (coreml_flags_ & COREML_FLAG_USE_CPU_AND_GPU)) {
// multiple device options selected
ORT_THROW(
"Multiple device options selected, you should use at most one of the following options:"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Model;

class CoreMLExecutionProvider : public IExecutionProvider {
public:
CoreMLExecutionProvider(uint32_t coreml_flags);
CoreMLExecutionProvider(const ProviderOptions& options);

Check warning on line 17 in onnxruntime/core/providers/coreml/coreml_execution_provider.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Single-parameter constructors should be marked explicit. [runtime/explicit] [4] Raw Output: onnxruntime/core/providers/coreml/coreml_execution_provider.h:17: Single-parameter constructors should be marked explicit. [runtime/explicit] [4]
virtual ~CoreMLExecutionProvider();

std::vector<std::unique_ptr<ComputeCapability>>
Expand Down
15 changes: 8 additions & 7 deletions onnxruntime/core/providers/coreml/coreml_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,26 @@ using namespace onnxruntime;

namespace onnxruntime {
struct CoreMLProviderFactory : IExecutionProviderFactory {
CoreMLProviderFactory(uint32_t coreml_flags)
: coreml_flags_(coreml_flags) {}
CoreMLProviderFactory(const ProviderOptions& options)

Check warning on line 13 in onnxruntime/core/providers/coreml/coreml_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Single-parameter constructors should be marked explicit. [runtime/explicit] [4] Raw Output: onnxruntime/core/providers/coreml/coreml_provider_factory.cc:13: Single-parameter constructors should be marked explicit. [runtime/explicit] [4]
: options_(options) {}
~CoreMLProviderFactory() override {}

std::unique_ptr<IExecutionProvider> CreateProvider() override;
uint32_t coreml_flags_;
  ProviderOptions options_;  // stored by value: callers pass temporaries (see OrtSessionOptionsAppendExecutionProvider_CoreML), so a reference member would dangle
};

std::unique_ptr<IExecutionProvider> CoreMLProviderFactory::CreateProvider() {
return std::make_unique<CoreMLExecutionProvider>(coreml_flags_);
return std::make_unique<CoreMLExecutionProvider>(options_);
}

std::shared_ptr<IExecutionProviderFactory> CoreMLProviderFactoryCreator::Create(uint32_t coreml_flags) {
return std::make_shared<onnxruntime::CoreMLProviderFactory>(coreml_flags);
std::shared_ptr<IExecutionProviderFactory> CoreMLProviderFactoryCreator::Create(const ProviderOptions& options) {
return std::make_shared<onnxruntime::CoreMLProviderFactory>(options);

Check warning on line 26 in onnxruntime/core/providers/coreml/coreml_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for make_shared<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/coreml/coreml_provider_factory.cc:26: Add #include <memory> for make_shared<> [build/include_what_you_use] [4]
}
} // namespace onnxruntime

ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CoreML,
_In_ OrtSessionOptions* options, uint32_t coreml_flags) {
options->provider_factories.push_back(onnxruntime::CoreMLProviderFactoryCreator::Create(coreml_flags));
options->provider_factories.push_back(onnxruntime::CoreMLProviderFactoryCreator::Create(
{{"coreml_flags", std::to_string(coreml_flags)}}));
return nullptr;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

#include <memory>

#include "core/framework/provider_options.h"
#include "core/providers/providers.h"

namespace onnxruntime {
struct CoreMLProviderFactoryCreator {
static std::shared_ptr<IExecutionProviderFactory> Create(uint32_t coreml_flags);
static std::shared_ptr<IExecutionProviderFactory> Create(const ProviderOptions& options);
};
} // namespace onnxruntime
10 changes: 10 additions & 0 deletions onnxruntime/core/session/provider_registration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,16 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
#endif
} else if (strcmp(provider_name, "VitisAI") == 0) {
status = OrtApis::SessionOptionsAppendExecutionProvider_VitisAI(options, provider_options_keys, provider_options_values, num_keys);
} else if (strcmp(provider_name, "CoreML") == 0) {
#if defined(USE_COREML)
std::string coreml_flags;
if (options->value.config_options.TryGetConfigEntry("coreml_flags", coreml_flags)) {
provider_options["coreml_flags"] = coreml_flags;
}
options->provider_factories.push_back(CoreMLProviderFactoryCreator::Create(provider_options));
#else
status = create_not_supported_status();
#endif
} else {
ORT_UNUSED_PARAMETER(options);
status = OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/onnx/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
}
if (enable_coreml) {
#ifdef USE_COREML
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sf, 0));
sf.AppendExecutionProvider("CoreML", {});
#else
fprintf(stderr, "CoreML is not supported in this build");
return -1;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/perftest/ort_test_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
}
}
// COREML_FLAG_CREATE_MLPROGRAM
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags));
    session_options.AppendExecutionProvider("CoreML", {{"coreml_flags", std::to_string(coreml_flags)}});
#else
ORT_THROW("CoreML is not supported in this build\n");
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void testSigmoid(const char* modelPath, bool useCoreML = false, bool useWebGPU =
#if COREML_EP_AVAILABLE
if (useCoreML) {
const uint32_t flags = COREML_FLAG_USE_CPU_ONLY;
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, flags));
    session_options.AppendExecutionProvider("CoreML", {{"coreml_flags", std::to_string(flags)}});
}
#else
(void)useCoreML;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void testSigmoid(const char* modelPath, bool useCoreML = false, bool useWebGPU =
#if COREML_EP_AVAILABLE
if (useCoreML) {
const uint32_t flags = COREML_FLAG_USE_CPU_ONLY;
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, flags));
session_options.AppendExecutionProvider("CoreML", {{"coreml_flags", std::to_string(flags)}});
}
#else
(void)useCoreML;
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/test/providers/coreml/coreml_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ namespace test {
static constexpr uint32_t s_coreml_flags = COREML_FLAG_USE_CPU_ONLY;

static std::unique_ptr<IExecutionProvider> MakeCoreMLExecutionProvider(uint32_t flags = s_coreml_flags) {
return std::make_unique<CoreMLExecutionProvider>(flags);
  std::unordered_map<std::string, std::string> provider_options = {{"coreml_flags", std::to_string(flags)}};
return std::make_unique<CoreMLExecutionProvider>(provider_options);
}

#if !defined(ORT_MINIMAL_BUILD)
Expand Down
15 changes: 8 additions & 7 deletions onnxruntime/test/providers/coreml/dynamic_input_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, MatMul) {

auto test = [&](const size_t M) {
SCOPED_TRACE(MakeString("M=", M));

auto coreml_ep = std::make_unique<CoreMLExecutionProvider>(0);
std::unordered_map<std::string, std::string> options;
auto coreml_ep = std::make_unique<CoreMLExecutionProvider>(options);

const auto ep_verification_params = EPVerificationParams{
ExpectedEPNodeAssignment::All,
Expand Down Expand Up @@ -54,8 +54,8 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, MobileNetExcerpt) {

auto test = [&](const size_t batch_size) {
SCOPED_TRACE(MakeString("batch_size=", batch_size));

auto coreml_ep = std::make_unique<CoreMLExecutionProvider>(0);
std::unordered_map<std::string, std::string> options;
auto coreml_ep = std::make_unique<CoreMLExecutionProvider>(options);

const auto ep_verification_params = EPVerificationParams{
ExpectedEPNodeAssignment::All,
Expand Down Expand Up @@ -87,21 +87,22 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, EmptyInputFails) {
constexpr auto model_path = ORT_TSTR("testdata/matmul_with_dynamic_input_shape.onnx");

ModelTester tester(CurrentTestName(), model_path);
std::unordered_map<std::string, std::string> options;

tester.AddInput<float>("A", {0, 2}, {});
tester.AddOutput<float>("Y", {0, 4}, {});

tester
.Config(ModelTester::ExpectResult::kExpectFailure,
"the runtime shape ({0,2}) has zero elements. This is not supported by the CoreML EP.")
.ConfigEp(std::make_unique<CoreMLExecutionProvider>(0))
.ConfigEp(std::make_unique<CoreMLExecutionProvider>(options))
.RunWithConfig();
}

TEST(CoreMLExecutionProviderDynamicInputShapeTest, OnlyAllowStaticInputShapes) {
constexpr auto model_path = ORT_TSTR("testdata/matmul_with_dynamic_input_shape.onnx");

auto coreml_ep = std::make_unique<CoreMLExecutionProvider>(COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES);
std::unordered_map<std::string, std::string> options = {{"coreml_flags", std::to_string(COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES)}};
auto coreml_ep = std::make_unique<CoreMLExecutionProvider>(options);

TestModelLoad(model_path, std::move(coreml_ep),
// expect no supported nodes because we disable dynamic input shape support
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/test/util/default_providers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,9 @@ std::unique_ptr<IExecutionProvider> DefaultCoreMLExecutionProvider(bool use_mlpr
if (use_mlprogram) {
coreml_flags |= COREML_FLAG_CREATE_MLPROGRAM;
}

return CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider();
auto option = ProviderOptions();
option["coreml_flags"] = std::to_string(coreml_flags);
return CoreMLProviderFactoryCreator::Create(option)->CreateProvider();
#else
ORT_UNUSED_PARAMETER(use_mlprogram);
return nullptr;
Expand Down

0 comments on commit 530000a

Please sign in to comment.