Skip to content

Commit

Permalink
general appendEP for CoreML
Browse files Browse the repository at this point in the history
  • Loading branch information
wejoncy committed Nov 1, 2024
1 parent 9daf766 commit 4b48a38
Show file tree
Hide file tree
Showing 14 changed files with 68 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1272,7 +1272,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
/// <summary>
/// Append an execution provider instance to the native OrtSessionOptions instance.
///
/// 'SNPE' and 'XNNPACK' are currently supported as providerName values.
/// 'SNPE', 'XNNPACK' and 'CoreML' are currently supported as providerName values.
///
/// The number of providerOptionsKeys must match the number of providerOptionsValues and equal numKeys.
/// </summary>
Expand Down
5 changes: 3 additions & 2 deletions csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -430,11 +430,12 @@ public IntPtr Appender(IntPtr handle, IntPtr[] optKeys, IntPtr[] optValues, UInt
/// <summary>
/// Append QNN, SNPE, CoreML or XNNPACK execution provider
/// </summary>
/// <param name="providerName">Execution provider to add. 'QNN', 'SNPE' or 'XNNPACK' are currently supported.</param>
/// <param name="providerName">Execution provider to add. 'QNN', 'SNPE', 'XNNPACK' or 'CoreML' are currently supported.</param>
/// <param name="providerOptions">Optional key/value pairs to specify execution provider options.</param>
public void AppendExecutionProvider(string providerName, Dictionary<string, string> providerOptions = null)
{
if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN" && providerName != "AZURE")
if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN" &&
providerName != "AZURE" && providerName != "CoreML")
{
throw new NotSupportedException(
"Only QNN, SNPE, XNNPACK, CoreML and AZURE execution providers can be enabled by this method.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ public void TestSessionOptions()
ex = Assert.Throws<OnnxRuntimeException>(() => { opt.AppendExecutionProvider("QNN"); });
Assert.Contains("QNN execution provider is not supported in this build", ex.Message);
#endif
#if USE_COREML
opt.AppendExecutionProvider("CoreML");
#else
ex = Assert.Throws<OnnxRuntimeException>(() => { opt.AppendExecutionProvider("CoreML"); });
Assert.Contains("CoreML execution provider is not supported in this build", ex.Message);
#endif

opt.AppendExecutionProvider_CPU(1);
}
Expand Down Expand Up @@ -2041,7 +2047,7 @@ public SkipNonPackageTests()
}

// Test hangs on mobile.
#if !(ANDROID || IOS)
#if !(ANDROID || IOS)
[Fact(DisplayName = "TestModelRunAsyncTask")]
private async Task TestModelRunAsyncTask()
{
Expand Down
13 changes: 13 additions & 0 deletions java/src/main/java/ai/onnxruntime/OrtSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,19 @@ public void addQnn(Map<String, String> providerOptions) throws OrtException {
addExecutionProvider(qnnProviderName, providerOptions);
}

/**
 * Adds CoreML as an execution backend.
 *
 * @param providerOptions Configuration options for the CoreML backend. Refer to the CoreML
 *     execution provider's documentation.
 * @throws OrtException If there was an error in native code.
 */
public void addCoreML(Map<String, String> providerOptions) throws OrtException {
  // Name must match the case-sensitive registration key used by the native layer.
  // Local renamed to lowerCamelCase for consistency with addQnn's qnnProviderName.
  String coreMLProviderName = "CoreML";
  addExecutionProvider(coreMLProviderName, providerOptions);
}


private native void setExecutionMode(long apiHandle, long nativeHandle, int mode)
throws OrtException;

Expand Down
4 changes: 3 additions & 1 deletion js/node/src/session_options_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
#endif
#ifdef USE_COREML
} else if (name == "coreml") {
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sessionOptions, coreMlFlags));
// Forward the accumulated CoreML flag bitmask through the generic provider-options
// path. NOTE(review): std::string(coreMlFlags) does not compile from an integral
// flag value; std::to_string is the intended conversion (matches the string value
// parsed back via std::stoi in CoreMLProviderFactoryCreator::Create).
std::unordered_map<std::string, std::string> options;
options["coreml_flags"] = std::to_string(coreMlFlags);
sessionOptions.AppendExecutionProvider("CoreML", options);
#endif
#ifdef USE_QNN
} else if (name == "qnn") {
Expand Down
25 changes: 4 additions & 21 deletions js/react_native/ios/OnnxruntimeModule.mm
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@ - (void)setBlobManager:(RCTBlobManager*)manager {
* @param reject callback for returning an error back to react native js
* @note when run() is called, the same modelPath must be passed into the first parameter.
*/
RCT_EXPORT_METHOD(loadModel
: (NSString*)modelPath options
: (NSDictionary*)options resolver
: (RCTPromiseResolveBlock)resolve rejecter
: (RCTPromiseRejectBlock)reject) {
RCT_EXPORT_METHOD(loadModel : (NSString*)modelPath options : (NSDictionary*)options resolver : (RCTPromiseResolveBlock)resolve rejecter : (RCTPromiseRejectBlock)reject) {
@try {
NSDictionary* resultMap = [self loadModel:modelPath options:options];
resolve(resultMap);
Expand All @@ -95,11 +91,7 @@ - (void)setBlobManager:(RCTBlobManager*)manager {
* @param reject callback for returning an error back to react native js
* @note when run() is called, the same modelPath must be passed into the first parameter.
*/
RCT_EXPORT_METHOD(loadModelFromBlob
: (NSDictionary*)modelDataBlob options
: (NSDictionary*)options resolver
: (RCTPromiseResolveBlock)resolve rejecter
: (RCTPromiseRejectBlock)reject) {
RCT_EXPORT_METHOD(loadModelFromBlob : (NSDictionary*)modelDataBlob options : (NSDictionary*)options resolver : (RCTPromiseResolveBlock)resolve rejecter : (RCTPromiseRejectBlock)reject) {
@try {
[self checkBlobManager];
NSString* blobId = [modelDataBlob objectForKey:@"blobId"];
Expand All @@ -121,10 +113,7 @@ - (void)setBlobManager:(RCTBlobManager*)manager {
* @param resolve callback for returning output back to react native js
* @param reject callback for returning an error back to react native js
*/
RCT_EXPORT_METHOD(dispose
: (NSString*)key resolver
: (RCTPromiseResolveBlock)resolve rejecter
: (RCTPromiseRejectBlock)reject) {
RCT_EXPORT_METHOD(dispose : (NSString*)key resolver : (RCTPromiseResolveBlock)resolve rejecter : (RCTPromiseRejectBlock)reject) {
@try {
[self dispose:key];
resolve(nil);
Expand All @@ -143,13 +132,7 @@ - (void)setBlobManager:(RCTBlobManager*)manager {
* @param resolve callback for returning an inference result back to react native js
* @param reject callback for returning an error back to react native js
*/
RCT_EXPORT_METHOD(run
: (NSString*)url input
: (NSDictionary*)input output
: (NSArray*)output options
: (NSDictionary*)options resolver
: (RCTPromiseResolveBlock)resolve rejecter
: (RCTPromiseRejectBlock)reject) {
RCT_EXPORT_METHOD(run : (NSString*)url input : (NSDictionary*)input output : (NSArray*)output options : (NSDictionary*)options resolver : (RCTPromiseResolveBlock)resolve rejecter : (RCTPromiseRejectBlock)reject) {
@try {
NSDictionary* resultMap = [self run:url input:input output:output options:options];
resolve(resultMap);
Expand Down
25 changes: 14 additions & 11 deletions objectivec/ort_coreml_execution_provider.mm
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,23 @@ @implementation ORTSessionOptions (ORTSessionOptionsCoreMLEP)
- (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOptions*)options
error:(NSError**)error {
#if ORT_OBJC_API_COREML_EP_AVAILABLE
const uint32_t flags =
(options.useCPUOnly ? COREML_FLAG_USE_CPU_ONLY : 0) |
(options.useCPUAndGPU ? COREML_FLAG_USE_CPU_AND_GPU : 0) |
(options.enableOnSubgraphs ? COREML_FLAG_ENABLE_ON_SUBGRAPH : 0) |
(options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0) |
(options.onlyAllowStaticInputShapes ? COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES : 0) |
(options.createMLProgram ? COREML_FLAG_CREATE_MLPROGRAM : 0);

try {
const uint32_t flags =
(options.useCPUOnly ? COREML_FLAG_USE_CPU_ONLY : 0) |
(options.useCPUAndGPU ? COREML_FLAG_USE_CPU_AND_GPU : 0) |
(options.enableOnSubgraphs ? COREML_FLAG_ENABLE_ON_SUBGRAPH : 0) |
(options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0) |
(options.onlyAllowStaticInputShapes ? COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES : 0) |
(options.createMLProgram ? COREML_FLAG_CREATE_MLPROGRAM : 0);

Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(
[self CXXAPIOrtSessionOptions], flags));
return YES;
NSDictionary* provider_options = @{
@"coreml_flags" : [NSString stringWithFormat:@"%d", flags]
};
return [self appendExecutionProvider:@"CoreML" providerOptions:provider_options error:error];
}
ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error);
return YES;

#else // !ORT_OBJC_API_COREML_EP_AVAILABLE
static_cast<void>(options);
ORTSaveCodeAndDescriptionToError(ORT_FAIL, "CoreML execution provider is not enabled.", error);
Expand Down
9 changes: 7 additions & 2 deletions onnxruntime/core/providers/coreml/coreml_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,18 @@ std::unique_ptr<IExecutionProvider> CoreMLProviderFactory::CreateProvider() {
return std::make_unique<CoreMLExecutionProvider>(coreml_flags_);
}

std::shared_ptr<IExecutionProviderFactory> CoreMLProviderFactoryCreator::Create(uint32_t coreml_flags) {
std::shared_ptr<IExecutionProviderFactory> CoreMLProviderFactoryCreator::Create(const ProviderOptions& options) {
uint32_t coreml_flags = 0;
coreml_flags |= options.count("coreml_flags")
? std::stoi(options.at("coreml_flags"))
: 0;
return std::make_shared<onnxruntime::CoreMLProviderFactory>(coreml_flags);
}
} // namespace onnxruntime

ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CoreML,
_In_ OrtSessionOptions* options, uint32_t coreml_flags) {
options->provider_factories.push_back(onnxruntime::CoreMLProviderFactoryCreator::Create(coreml_flags));
options->provider_factories.push_back(onnxruntime::CoreMLProviderFactoryCreator::Create(
{{"coreml_flags", std::to_string(coreml_flags)}}));
return nullptr;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

#include <memory>

#include "core/framework/provider_options.h"
#include "core/providers/providers.h"

namespace onnxruntime {
struct CoreMLProviderFactoryCreator {
static std::shared_ptr<IExecutionProviderFactory> Create(uint32_t coreml_flags);
static std::shared_ptr<IExecutionProviderFactory> Create(const ProviderOptions& options);
};
} // namespace onnxruntime
10 changes: 10 additions & 0 deletions onnxruntime/core/session/provider_registration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,16 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
#endif
} else if (strcmp(provider_name, "VitisAI") == 0) {
status = OrtApis::SessionOptionsAppendExecutionProvider_VitisAI(options, provider_options_keys, provider_options_values, num_keys);
} else if (strcmp(provider_name, "CoreML") == 0) {
#if defined(USE_COREML)
std::string coreml_flags;
if (options->value.config_options.TryGetConfigEntry("coreml_flags", coreml_flags)) {
provider_options["coreml_flags"] = coreml_flags;
}
options->provider_factories.push_back(CoreMLProviderFactoryCreator::Create(provider_options));
#else
status = create_not_supported_status();
#endif
} else {
ORT_UNUSED_PARAMETER(options);
status = OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/onnx/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
}
if (enable_coreml) {
#ifdef USE_COREML
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sf, 0));
sf.AppendExecutionProvider("CoreML", {});
#else
fprintf(stderr, "CoreML is not supported in this build");
return -1;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/perftest/ort_test_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
}
}
// COREML_FLAG_CREATE_MLPROGRAM
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags));
// Provider name is matched case-sensitively ("CoreML") in provider_registration.cc;
// "Coreml" would fall through to the unsupported-provider error path.
session_options.AppendExecutionProvider("CoreML", {{"coreml_flags", std::to_string(coreml_flags)}});
#else
ORT_THROW("CoreML is not supported in this build\n");
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void testSigmoid(const char* modelPath, bool useCoreML = false, bool useWebGPU =
#if COREML_EP_AVAILABLE
if (useCoreML) {
const uint32_t flags = COREML_FLAG_USE_CPU_ONLY;
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, flags));
// Options must be a braced list of key/value pairs ({{key, value}}) to construct the
// unordered_map argument; single braces do not form a valid pair initializer
// (the sibling test file uses the correct double-brace form).
session_options.AppendExecutionProvider("CoreML", {{"coreml_flags", std::to_string(flags)}});
}
#else
(void)useCoreML;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void testSigmoid(const char* modelPath, bool useCoreML = false, bool useWebGPU =
#if COREML_EP_AVAILABLE
if (useCoreML) {
const uint32_t flags = COREML_FLAG_USE_CPU_ONLY;
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, flags));
session_options.AppendExecutionProvider("CoreML", {{"coreml_flags", std::to_string(flags)}});
}
#else
(void)useCoreML;
Expand Down

0 comments on commit 4b48a38

Please sign in to comment.