
[CoreML] Create EP by AppendExecutionProvider #22675

Open · wants to merge 21 commits into base: main
Changes from 10 commits
@@ -1272,7 +1272,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
/// <summary>
/// Append an execution provider instance to the native OrtSessionOptions instance.
///
/// 'SNPE' and 'XNNPACK' are currently supported as providerName values.
/// 'SNPE', 'XNNPACK' and 'CoreML' are currently supported as providerName values.
///
/// The number of providerOptionsKeys must match the number of providerOptionsValues and equal numKeys.
/// </summary>
8 changes: 1 addition & 7 deletions csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
@@ -430,16 +430,10 @@ public IntPtr Appender(IntPtr handle, IntPtr[] optKeys, IntPtr[] optValues, UInt
/// <summary>
/// Append QNN, SNPE or XNNPACK execution provider
/// </summary>
/// <param name="providerName">Execution provider to add. 'QNN', 'SNPE' or 'XNNPACK' are currently supported.</param>
/// <param name="providerName">Execution provider to add. 'QNN', 'SNPE' 'XNNPACK', 'CoreML and 'AZURE are currently supported.</param>
/// <param name="providerOptions">Optional key/value pairs to specify execution provider options.</param>
public void AppendExecutionProvider(string providerName, Dictionary<string, string> providerOptions = null)
{
if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN" && providerName != "AZURE")
{
throw new NotSupportedException(
"Only QNN, SNPE, XNNPACK and AZURE execution providers can be enabled by this method.");
}

if (providerOptions == null)
{
providerOptions = new Dictionary<string, string>();
@@ -179,6 +179,12 @@ public void TestSessionOptions()
ex = Assert.Throws<OnnxRuntimeException>(() => { opt.AppendExecutionProvider("QNN"); });
Assert.Contains("QNN execution provider is not supported in this build", ex.Message);
#endif
#if USE_COREML
opt.AppendExecutionProvider("CoreML");
#else
ex = Assert.Throws<OnnxRuntimeException>(() => { opt.AppendExecutionProvider("CoreML"); });
Assert.Contains("CoreML execution provider is not supported in this build", ex.Message);
#endif

opt.AppendExecutionProvider_CPU(1);
}
@@ -2041,7 +2047,7 @@ public SkipNonPackageTests()
}

// Test hangs on mobile.
#if !(ANDROID || IOS)
[Fact(DisplayName = "TestModelRunAsyncTask")]
private async Task TestModelRunAsyncTask()
{
@@ -41,6 +41,14 @@ enum COREMLFlags {
COREML_FLAG_LAST = COREML_FLAG_USE_CPU_AND_GPU,
};

// MLComputeUnits can be one of the following values:
// 'MLComputeUnitsCPUAndNeuralEngine|MLComputeUnitsCPUAndGPU|MLComputeUnitsCPUOnly|MLComputeUnitsAll'
static const char* const kCoremlProviderOption_MLComputeUnits = "MLComputeUnits";
static const char* const kCoremlProviderOption_MLModelFormat = "MLModelFormat";
static const char* const kCoremlProviderOption_MLAllowStaticInputShapes = "MLAllowStaticInputShapes";
static const char* const kCoremlProviderOption_MLEnableOnSubgraphs = "MLEnableOnSubgraphs";
static const char* const kCoremlProviderOption_MLModelCacheDir = "MLModelCacheDir";
@skottmckay (Contributor) commented on Nov 21, 2024:
Do we need the 'ML' prefix on these values? It makes sense to me for MLComputeUnits as that maps to a CoreML enum, but I'm not sure that any of the others do.

We should also provide clear documentation on what is valid/expected for each value even if it's simply copied from the flags doco.

A user might discover this file first, so it would also be good to mention that these values are intended to be used with Ort::SessionOptions::AppendExecutionProvider (C++ API) / SessionOptionsAppendExecutionProvider (C API). #Closed


#ifdef __cplusplus
extern "C" {
#endif
12 changes: 12 additions & 0 deletions java/src/main/java/ai/onnxruntime/OrtSession.java
@@ -1323,6 +1323,18 @@ public void addQnn(Map<String, String> providerOptions) throws OrtException {
addExecutionProvider(qnnProviderName, providerOptions);
}

/**
* Adds CoreML as an execution backend.
*
* @param providerOptions Configuration options for the CoreML backend. Refer to the CoreML
* execution provider's documentation.
* @throws OrtException If there was an error in native code.
*/
public void addCoreML(Map<String, String> providerOptions) throws OrtException {
String coreMLProviderName = "CoreML";
addExecutionProvider(coreMLProviderName, providerOptions);
}

private native void setExecutionMode(long apiHandle, long nativeHandle, int mode)
throws OrtException;

12 changes: 11 additions & 1 deletion objectivec/include/ort_coreml_execution_provider.h
@@ -70,7 +70,17 @@
*/
- (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOptions*)options
error:(NSError**)error;

/**
* Enables the CoreML execution provider in the session configuration options.
* It is appended to the execution provider list which is ordered by
* decreasing priority.
*
* @param provider_options The CoreML execution provider options as a dictionary of string keys and values.
* @param error Optional error information set if an error occurs.
* @return Whether the provider was enabled successfully.
*/
- (BOOL)appendCoreMLExecutionProviderWithOptions_v2:(NSDictionary*)provider_options
Contributor commented:
A user of appendCoreMLExecutionProviderWithOptions can easily look at the documentation for ORTCoreMLExecutionProviderOptions to know what is available. How should they discover the provider options that can be used here?

wejoncy (Contributor Author) replied:
Good catch. All available key-values are listed in the comments.

error:(NSError**)error;
@end

NS_ASSUME_NONNULL_END
15 changes: 15 additions & 0 deletions objectivec/ort_coreml_execution_provider.mm
@@ -43,6 +43,21 @@ - (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOpti
#endif
}

- (BOOL)appendCoreMLExecutionProviderWithOptions_v2:(NSDictionary*)provider_options
error:(NSError**)error {
#if ORT_OBJC_API_COREML_EP_AVAILABLE
try {
return [self appendExecutionProvider:@"CoreML" providerOptions:provider_options error:error];
}
ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error);

#else // !ORT_OBJC_API_COREML_EP_AVAILABLE
static_cast<void>(provider_options);
ORTSaveCodeAndDescriptionToError(ORT_FAIL, "CoreML execution provider is not enabled.", error);
return NO;
#endif
}

@end

NS_ASSUME_NONNULL_END
22 changes: 22 additions & 0 deletions objectivec/test/ort_session_test.mm
@@ -223,6 +223,28 @@ - (void)testAppendCoreMLEP {
ORTAssertNullableResultSuccessful(session, err);
}

- (void)testAppendCoreMLEP_v2 {
NSError* err = nil;
ORTSessionOptions* sessionOptions = [ORTSessionTest makeSessionOptions];
NSDictionary* provider_options = @{@"MLEnableOnSubgraphs" : @"1"}; // set an arbitrary option

BOOL appendResult = [sessionOptions appendCoreMLExecutionProviderWithOptions_v2:provider_options
error:&err];

if (!ORTIsCoreMLExecutionProviderAvailable()) {
ORTAssertBoolResultUnsuccessful(appendResult, err);
return;
}

ORTAssertBoolResultSuccessful(appendResult, err);

ORTSession* session = [[ORTSession alloc] initWithEnv:self.ortEnv
modelPath:[ORTSessionTest getAddModelPath]
sessionOptions:sessionOptions
error:&err];
ORTAssertNullableResultSuccessful(session, err);
}

- (void)testAppendXnnpackEP {
NSError* err = nil;
ORTSessionOptions* sessionOptions = [ORTSessionTest makeSessionOptions];
@@ -23,17 +23,17 @@ namespace onnxruntime {

constexpr const char* COREML = "CoreML";

CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags)
CoreMLExecutionProvider::CoreMLExecutionProvider(const CoreMLOptions& options)
: IExecutionProvider{onnxruntime::kCoreMLExecutionProvider},
coreml_flags_(coreml_flags),
coreml_flags_(options.coreml_flags),
coreml_version_(coreml::util::CoreMLVersion()) {
LOGS_DEFAULT(VERBOSE) << "CoreML version: " << coreml_version_;
if (coreml_version_ < MINIMUM_COREML_VERSION) {
LOGS_DEFAULT(ERROR) << "CoreML EP is not supported on this platform.";
}

// check if only one flag is set
if ((coreml_flags & COREML_FLAG_USE_CPU_ONLY) && (coreml_flags & COREML_FLAG_USE_CPU_AND_GPU)) {
if ((coreml_flags_ & COREML_FLAG_USE_CPU_ONLY) && (coreml_flags_ & COREML_FLAG_USE_CPU_AND_GPU)) {
// multiple device options selected
ORT_THROW(
"Multiple device options selected, you should use at most one of the following options:"
@@ -12,9 +12,14 @@ namespace coreml {
class Model;
}

struct CoreMLOptions {
uint32_t coreml_flags = 0;
std::string cache_path;
Contributor commented:
nit: can we keep the changes that add cache path support in a separate PR?

Also wondering if this is going to grow better if we keep it separate from coreml_flags. It's a little hard to see what is supported with the current setup. If we define the full set of options in this struct, and we populate those from either the flags, or the provider options, it's clear what the full set of options is.

I'd suggest making CoreMLOptions a class and make it responsible for validation. Can have a ctor for flags and one for provider options. That way all the validation is done in one place and it's clear which class is responsible for validation.

wejoncy (Contributor Author) replied:
Yeah, it's good to separate the cache path into another PR.

If we give each flag its own variable, how can we keep it compatible with the current coreml_flags?
Besides, we can easily pass coreml_flags (uint32_t) to Objective-C. Can we pass a class from C++ to Objective-C directly?

};

class CoreMLExecutionProvider : public IExecutionProvider {
public:
CoreMLExecutionProvider(uint32_t coreml_flags);
CoreMLExecutionProvider(const CoreMLOptions& options);
virtual ~CoreMLExecutionProvider();

std::vector<std::unique_ptr<ComputeCapability>>
66 changes: 61 additions & 5 deletions onnxruntime/core/providers/coreml/coreml_provider_factory.cc
@@ -9,21 +9,77 @@
using namespace onnxruntime;

namespace onnxruntime {

namespace {
CoreMLOptions ParseProviderOption(const ProviderOptions& options) {
CoreMLOptions coreml_options;
const std::unordered_map<std::string, COREMLFlags> available_computeunits_options = {
{"MLComputeUnitsCPUAndNeuralEngine", COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE},
{"MLComputeUnitsCPUAndGPU", COREML_FLAG_USE_CPU_AND_GPU},
{"MLComputeUnitsCPUOnly", COREML_FLAG_USE_CPU_ONLY},
{"MLComputeUnitsAll", COREML_FLAG_USE_NONE},
};
const std::unordered_map<std::string, COREMLFlags> available_modelformat_options = {
{"MLProgram", COREML_FLAG_CREATE_MLPROGRAM},
{"NeuralNetwork", COREML_FLAG_USE_NONE},
};
std::unordered_set<std::string> valid_options = {
kCoremlProviderOption_MLComputeUnits,
kCoremlProviderOption_MLModelFormat,
kCoremlProviderOption_MLAllowStaticInputShapes,
kCoremlProviderOption_MLEnableOnSubgraphs,
kCoremlProviderOption_MLModelCacheDir,
};
// Validate the options
for (const auto& option : options) {
if (valid_options.find(option.first) == valid_options.end()) {
ORT_THROW("Unknown option: ", option.first);
}
if (kCoremlProviderOption_MLComputeUnits == option.first) {
if (available_computeunits_options.find(option.second) == available_computeunits_options.end()) {
ORT_THROW("Invalid value for option ", option.first, ": ", option.second);
} else {
coreml_options.coreml_flags |= available_computeunits_options.at(option.second);
}
} else if (kCoremlProviderOption_MLModelFormat == option.first) {
if (available_modelformat_options.find(option.second) == available_modelformat_options.end()) {
ORT_THROW("Invalid value for option ", option.first, ": ", option.second);
} else {
coreml_options.coreml_flags |= available_modelformat_options.at(option.second);
}
} else if (kCoremlProviderOption_MLAllowStaticInputShapes == option.first) {
coreml_options.coreml_flags |= COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES;
} else if (kCoremlProviderOption_MLEnableOnSubgraphs == option.first) {
coreml_options.coreml_flags |= COREML_FLAG_ENABLE_ON_SUBGRAPH;
} else if (kCoremlProviderOption_MLModelCacheDir == option.first) {
coreml_options.cache_path = option.second;
}
}

return coreml_options;
}
} // namespace
struct CoreMLProviderFactory : IExecutionProviderFactory {
CoreMLProviderFactory(uint32_t coreml_flags)
: coreml_flags_(coreml_flags) {}
CoreMLProviderFactory(const CoreMLOptions& options)
: options_(options) {}
~CoreMLProviderFactory() override {}

std::unique_ptr<IExecutionProvider> CreateProvider() override;
uint32_t coreml_flags_;
CoreMLOptions options_;
};

std::unique_ptr<IExecutionProvider> CoreMLProviderFactory::CreateProvider() {
return std::make_unique<CoreMLExecutionProvider>(coreml_flags_);
return std::make_unique<CoreMLExecutionProvider>(options_);
}

std::shared_ptr<IExecutionProviderFactory> CoreMLProviderFactoryCreator::Create(uint32_t coreml_flags) {
return std::make_shared<onnxruntime::CoreMLProviderFactory>(coreml_flags);
CoreMLOptions coreml_options;
coreml_options.coreml_flags = coreml_flags;
return std::make_shared<onnxruntime::CoreMLProviderFactory>(coreml_options);
}

std::shared_ptr<IExecutionProviderFactory> CoreMLProviderFactoryCreator::Create(const ProviderOptions& options) {
return std::make_shared<onnxruntime::CoreMLProviderFactory>(ParseProviderOption(options));
}
} // namespace onnxruntime

@@ -5,10 +5,12 @@

#include <memory>

#include "core/framework/provider_options.h"
#include "core/providers/providers.h"

namespace onnxruntime {
struct CoreMLProviderFactoryCreator {
static std::shared_ptr<IExecutionProviderFactory> Create(uint32_t coreml_flags);
static std::shared_ptr<IExecutionProviderFactory> Create(const ProviderOptions& options);
};
} // namespace onnxruntime
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/coreml/model/model.mm
@@ -400,6 +400,8 @@ Status Predict(const std::unordered_map<std::string, OnnxTensorData>& inputs,
config.computeUnits = MLComputeUnitsCPUOnly;
} else if (coreml_flags_ & COREML_FLAG_USE_CPU_AND_GPU) {
config.computeUnits = MLComputeUnitsCPUAndGPU;
} else if (coreml_flags_ & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) {
config.computeUnits = MLComputeUnitsCPUAndNeuralEngine; // Apple Neural Engine
} else {
config.computeUnits = MLComputeUnitsAll;
}
12 changes: 11 additions & 1 deletion onnxruntime/core/session/provider_registration.cc
@@ -155,11 +155,21 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
status = create_not_supported_status();
#endif
} else if (strcmp(provider_name, "VitisAI") == 0) {
#ifdef USE_VITISAI
status = OrtApis::SessionOptionsAppendExecutionProvider_VitisAI(options, provider_options_keys, provider_options_values, num_keys);
#else
status = create_not_supported_status();
#endif
} else if (strcmp(provider_name, "CoreML") == 0) {
#if defined(USE_COREML)
options->provider_factories.push_back(CoreMLProviderFactoryCreator::Create(provider_options));
#else
status = create_not_supported_status();
#endif
} else {
ORT_UNUSED_PARAMETER(options);
status = OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
"Unknown provider name. Currently supported values are 'OPENVINO', 'SNPE', 'XNNPACK', 'QNN', 'WEBNN' and 'AZURE'");
"Unknown provider name. Currently supported values are 'OPENVINO', 'SNPE', 'XNNPACK', 'QNN', 'WEBNN' ,'CoreML', and 'AZURE'");
}

return status;
2 changes: 1 addition & 1 deletion onnxruntime/python/onnxruntime_pybind_schema.cc
@@ -73,7 +73,7 @@ void addGlobalSchemaFunctions(pybind11::module& m) {
onnxruntime::RknpuProviderFactoryCreator::Create(),
#endif
#ifdef USE_COREML
onnxruntime::CoreMLProviderFactoryCreator::Create(0),
onnxruntime::CoreMLProviderFactoryCreator::Create(ProviderOptions{}),
#endif
#ifdef USE_XNNPACK
onnxruntime::XnnpackProviderFactoryCreator::Create(ProviderOptions{}, nullptr),
3 changes: 3 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_state.cc
@@ -1222,6 +1222,9 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
if (flags_str.find("COREML_FLAG_CREATE_MLPROGRAM") != std::string::npos) {
coreml_flags |= COREMLFlags::COREML_FLAG_CREATE_MLPROGRAM;
}
} else {
// read from provider_options
return onnxruntime::CoreMLProviderFactoryCreator::Create(options)->CreateProvider();
}
}

2 changes: 1 addition & 1 deletion onnxruntime/test/onnx/main.cc
@@ -631,7 +631,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
}
if (enable_coreml) {
#ifdef USE_COREML
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sf, 0));
sf.AppendExecutionProvider("CoreML", {});
#else
fprintf(stderr, "CoreML is not supported in this build");
return -1;