[CoreML] Create EP by AppendExecutionProvider #22675

Merged 22 commits on Nov 27, 2024. The diff below shows changes from 10 commits.
@@ -1272,7 +1272,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
/// <summary>
/// Append an execution provider instance to the native OrtSessionOptions instance.
///
/// 'SNPE' and 'XNNPACK' are currently supported as providerName values.
/// 'SNPE', 'XNNPACK' and 'CoreML' are currently supported as providerName values.
///
/// The number of providerOptionsKeys must match the number of providerOptionsValues and equal numKeys.
/// </summary>
8 changes: 1 addition & 7 deletions csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
@@ -430,16 +430,10 @@ public IntPtr Appender(IntPtr handle, IntPtr[] optKeys, IntPtr[] optValues, UInt
/// <summary>
/// Append QNN, SNPE, XNNPACK, CoreML or AZURE execution provider
/// </summary>
/// <param name="providerName">Execution provider to add. 'QNN', 'SNPE' or 'XNNPACK' are currently supported.</param>
/// <param name="providerName">Execution provider to add. 'QNN', 'SNPE', 'XNNPACK', 'CoreML' and 'AZURE' are currently supported.</param>
/// <param name="providerOptions">Optional key/value pairs to specify execution provider options.</param>
public void AppendExecutionProvider(string providerName, Dictionary<string, string> providerOptions = null)
{
if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN" && providerName != "AZURE")
{
throw new NotSupportedException(
"Only QNN, SNPE, XNNPACK and AZURE execution providers can be enabled by this method.");
}

if (providerOptions == null)
{
providerOptions = new Dictionary<string, string>();
@@ -179,6 +179,12 @@ public void TestSessionOptions()
ex = Assert.Throws<OnnxRuntimeException>(() => { opt.AppendExecutionProvider("QNN"); });
Assert.Contains("QNN execution provider is not supported in this build", ex.Message);
#endif
#if USE_COREML
opt.AppendExecutionProvider("CoreML");
#else
ex = Assert.Throws<OnnxRuntimeException>(() => { opt.AppendExecutionProvider("CoreML"); });
Assert.Contains("CoreML execution provider is not supported in this build", ex.Message);
#endif

opt.AppendExecutionProvider_CPU(1);
}
@@ -2041,7 +2047,7 @@ public SkipNonPackageTests()
}

// Test hangs on mobile.
#if !(ANDROID || IOS)
[Fact(DisplayName = "TestModelRunAsyncTask")]
private async Task TestModelRunAsyncTask()
{
@@ -41,6 +41,14 @@ enum COREMLFlags {
COREML_FLAG_LAST = COREML_FLAG_USE_CPU_AND_GPU,
};

// MLComputeUnits can be one of the following values:
// 'MLComputeUnitsCPUAndNeuralEngine|MLComputeUnitsCPUAndGPU|MLComputeUnitsCPUOnly|MLComputeUnitsAll'
static const char* const kCoremlProviderOption_MLComputeUnits = "MLComputeUnits";
static const char* const kCoremlProviderOption_MLModelFormat = "MLModelFormat";
static const char* const kCoremlProviderOption_MLAllowStaticInputShapes = "MLAllowStaticInputShapes";
static const char* const kCoremlProviderOption_MLEnableOnSubgraphs = "MLEnableOnSubgraphs";
static const char* const kCoremlProviderOption_MLModelCacheDir = "MLModelCacheDir";

#ifdef __cplusplus
extern "C" {
#endif
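For illustration, a minimal C++ sketch (not part of this diff) of how these option keys might be consumed through the string-based API. It assumes a CoreML-enabled build on Apple hardware; "model.onnx" is a placeholder path.

#include <string>
#include <unordered_map>

#include "onnxruntime_cxx_api.h"

// Minimal sketch: append the CoreML EP using the string-keyed provider
// options declared above. Assumes a CoreML-enabled build; "model.onnx"
// is a placeholder path.
int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "coreml_ep_demo"};
  Ort::SessionOptions session_options;

  std::unordered_map<std::string, std::string> provider_options{
      {"MLComputeUnits", "MLComputeUnitsCPUAndNeuralEngine"},
      {"MLModelFormat", "MLProgram"},
  };
  session_options.AppendExecutionProvider("CoreML", provider_options);

  Ort::Session session{env, "model.onnx", session_options};
  return 0;
}

The value strings must match the enumerations validated in coreml_provider_factory.cc later in this diff.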
12 changes: 12 additions & 0 deletions java/src/main/java/ai/onnxruntime/OrtSession.java
@@ -1323,6 +1323,18 @@ public void addQnn(Map<String, String> providerOptions) throws OrtException {
addExecutionProvider(qnnProviderName, providerOptions);
}

/**
* Adds CoreML as an execution backend.
*
* @param providerOptions Configuration options for the CoreML backend. Refer to the CoreML
* execution provider's documentation.
* @throws OrtException If there was an error in native code.
*/
public void addCoreML(Map<String, String> providerOptions) throws OrtException {
String coreMLProviderName = "CoreML";
addExecutionProvider(coreMLProviderName, providerOptions);
}

private native void setExecutionMode(long apiHandle, long nativeHandle, int mode)
throws OrtException;

12 changes: 11 additions & 1 deletion objectivec/include/ort_coreml_execution_provider.h
@@ -70,7 +70,17 @@ NS_ASSUME_NONNULL_BEGIN
*/
- (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOptions*)options
error:(NSError**)error;

/**
* Enables the CoreML execution provider in the session configuration options.
* It is appended to the execution provider list which is ordered by
* decreasing priority.
*
* @param provider_options The CoreML execution provider options as a dictionary of string keys and values.
* @param error Optional error information set if an error occurs.
* @return Whether the provider was enabled successfully.
*/
- (BOOL)appendCoreMLExecutionProviderWithOptions_v2:(NSDictionary*)provider_options
error:(NSError**)error;
@end

NS_ASSUME_NONNULL_END
15 changes: 15 additions & 0 deletions objectivec/ort_coreml_execution_provider.mm
@@ -43,6 +43,21 @@ - (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOpti
#endif
}

- (BOOL)appendCoreMLExecutionProviderWithOptions_v2:(NSDictionary*)provider_options
error:(NSError**)error {
#if ORT_OBJC_API_COREML_EP_AVAILABLE
try {
return [self appendExecutionProvider:@"CoreML" providerOptions:provider_options error:error];
}
ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error);

#else // !ORT_OBJC_API_COREML_EP_AVAILABLE
static_cast<void>(provider_options);
ORTSaveCodeAndDescriptionToError(ORT_FAIL, "CoreML execution provider is not enabled.", error);
return NO;
#endif
}

@end

NS_ASSUME_NONNULL_END
22 changes: 22 additions & 0 deletions objectivec/test/ort_session_test.mm
@@ -223,6 +223,28 @@ - (void)testAppendCoreMLEP {
ORTAssertNullableResultSuccessful(session, err);
}

- (void)testAppendCoreMLEP_v2 {
NSError* err = nil;
ORTSessionOptions* sessionOptions = [ORTSessionTest makeSessionOptions];
NSDictionary* provider_options = @{@"MLEnableOnSubgraphs" : @"1"}; // set an arbitrary option

BOOL appendResult = [sessionOptions appendCoreMLExecutionProviderWithOptions_v2:provider_options
error:&err];

if (!ORTIsCoreMLExecutionProviderAvailable()) {
ORTAssertBoolResultUnsuccessful(appendResult, err);
return;
}

ORTAssertBoolResultSuccessful(appendResult, err);

ORTSession* session = [[ORTSession alloc] initWithEnv:self.ortEnv
modelPath:[ORTSessionTest getAddModelPath]
sessionOptions:sessionOptions
error:&err];
ORTAssertNullableResultSuccessful(session, err);
}

- (void)testAppendXnnpackEP {
NSError* err = nil;
ORTSessionOptions* sessionOptions = [ORTSessionTest makeSessionOptions];
@@ -23,17 +23,17 @@ namespace onnxruntime {

constexpr const char* COREML = "CoreML";

CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags)
CoreMLExecutionProvider::CoreMLExecutionProvider(const CoreMLOptions& options)
: IExecutionProvider{onnxruntime::kCoreMLExecutionProvider},
coreml_flags_(coreml_flags),
coreml_flags_(options.coreml_flags),
coreml_version_(coreml::util::CoreMLVersion()) {
LOGS_DEFAULT(VERBOSE) << "CoreML version: " << coreml_version_;
if (coreml_version_ < MINIMUM_COREML_VERSION) {
LOGS_DEFAULT(ERROR) << "CoreML EP is not supported on this platform.";
}

// check if only one flag is set
if ((coreml_flags & COREML_FLAG_USE_CPU_ONLY) && (coreml_flags & COREML_FLAG_USE_CPU_AND_GPU)) {
if ((coreml_flags_ & COREML_FLAG_USE_CPU_ONLY) && (coreml_flags_ & COREML_FLAG_USE_CPU_AND_GPU)) {
// multiple device options selected
ORT_THROW(
"Multiple device options selected, you should use at most one of the following options:"
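For context, a hedged sketch (not part of this diff) of how this mutual-exclusion check can be tripped via the legacy flag-based C entry point used elsewhere in this PR. Because the factory defers provider construction, the ORT_THROW is expected to surface when the session is created rather than when the provider is appended.

#include "onnxruntime_cxx_api.h"
// Also requires the header declaring COREMLFlags and
// OrtSessionOptionsAppendExecutionProvider_CoreML (shown earlier in this diff).

// Sketch: combining the two CPU device flags violates the check above.
void AppendConflictingCoreMLFlags(OrtSessionOptions* session_options) {
  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(
      session_options, COREML_FLAG_USE_CPU_ONLY | COREML_FLAG_USE_CPU_AND_GPU));
  // Creating a session with these options is expected to fail with
  // "Multiple device options selected ...".
}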
@@ -12,9 +12,14 @@ namespace coreml {
class Model;
}

struct CoreMLOptions {
uint32_t coreml_flags = 0;
std::string cache_path;
};

class CoreMLExecutionProvider : public IExecutionProvider {
public:
CoreMLExecutionProvider(uint32_t coreml_flags);
CoreMLExecutionProvider(const CoreMLOptions& options);
virtual ~CoreMLExecutionProvider();

std::vector<std::unique_ptr<ComputeCapability>>
66 changes: 61 additions & 5 deletions onnxruntime/core/providers/coreml/coreml_provider_factory.cc
@@ -9,21 +9,77 @@
using namespace onnxruntime;

namespace onnxruntime {

namespace {
CoreMLOptions ParseProviderOption(const ProviderOptions& options) {
CoreMLOptions coreml_options;
const std::unordered_map<std::string, COREMLFlags> available_computeunits_options = {
{"MLComputeUnitsCPUAndNeuralEngine", COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE},
{"MLComputeUnitsCPUAndGPU", COREML_FLAG_USE_CPU_AND_GPU},
{"MLComputeUnitsCPUOnly", COREML_FLAG_USE_CPU_ONLY},
{"MLComputeUnitsAll", COREML_FLAG_USE_NONE},
};
const std::unordered_map<std::string, COREMLFlags> available_modelformat_options = {
{"MLProgram", COREML_FLAG_CREATE_MLPROGRAM},
{"NeuralNetwork", COREML_FLAG_USE_NONE},
};
std::unordered_set<std::string> valid_options = {
kCoremlProviderOption_MLComputeUnits,
kCoremlProviderOption_MLModelFormat,
kCoremlProviderOption_MLAllowStaticInputShapes,
kCoremlProviderOption_MLEnableOnSubgraphs,
kCoremlProviderOption_MLModelCacheDir,
};
// Validate the options
for (const auto& option : options) {
if (valid_options.find(option.first) == valid_options.end()) {
ORT_THROW("Unknown option: ", option.first);
}
if (kCoremlProviderOption_MLComputeUnits == option.first) {
if (available_computeunits_options.find(option.second) == available_computeunits_options.end()) {
ORT_THROW("Invalid value for option ", option.first, ": ", option.second);
} else {
coreml_options.coreml_flags |= available_computeunits_options.at(option.second);
}
} else if (kCoremlProviderOption_MLModelFormat == option.first) {
if (available_modelformat_options.find(option.second) == available_modelformat_options.end()) {
ORT_THROW("Invalid value for option ", option.first, ": ", option.second);
} else {
coreml_options.coreml_flags |= available_modelformat_options.at(option.second);
}
} else if (kCoremlProviderOption_MLAllowStaticInputShapes == option.first) {
coreml_options.coreml_flags |= COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES;
} else if (kCoremlProviderOption_MLEnableOnSubgraphs == option.first) {
coreml_options.coreml_flags |= COREML_FLAG_ENABLE_ON_SUBGRAPH;
} else if (kCoremlProviderOption_MLModelCacheDir == option.first) {
coreml_options.cache_path = option.second;
}
}

return coreml_options;
}
} // namespace
struct CoreMLProviderFactory : IExecutionProviderFactory {
CoreMLProviderFactory(uint32_t coreml_flags)
: coreml_flags_(coreml_flags) {}
CoreMLProviderFactory(const CoreMLOptions& options)
: options_(options) {}
~CoreMLProviderFactory() override {}

std::unique_ptr<IExecutionProvider> CreateProvider() override;
uint32_t coreml_flags_;
CoreMLOptions options_;
};

std::unique_ptr<IExecutionProvider> CoreMLProviderFactory::CreateProvider() {
return std::make_unique<CoreMLExecutionProvider>(coreml_flags_);
return std::make_unique<CoreMLExecutionProvider>(options_);
}

std::shared_ptr<IExecutionProviderFactory> CoreMLProviderFactoryCreator::Create(uint32_t coreml_flags) {
return std::make_shared<onnxruntime::CoreMLProviderFactory>(coreml_flags);
CoreMLOptions coreml_options;
coreml_options.coreml_flags = coreml_flags;
return std::make_shared<onnxruntime::CoreMLProviderFactory>(coreml_options);
}

std::shared_ptr<IExecutionProviderFactory> CoreMLProviderFactoryCreator::Create(const ProviderOptions& options) {
return std::make_shared<onnxruntime::CoreMLProviderFactory>(ParseProviderOption(options));
}
} // namespace onnxruntime

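Given the validation above, an unrecognized key or value fails fast. A sketch (not part of this diff) of that failure path as seen from the C++ API, using a deliberately bogus key:

#include <iostream>

#include "onnxruntime_cxx_api.h"

// Sketch: ParseProviderOption() throws on unknown keys and invalid values,
// which surfaces to C++ callers as an Ort::Exception at append time.
void AppendWithBogusOption(Ort::SessionOptions& session_options) {
  try {
    // "NotARealOption" is a hypothetical, invalid key.
    session_options.AppendExecutionProvider("CoreML", {{"NotARealOption", "1"}});
  } catch (const Ort::Exception& e) {
    std::cerr << "append failed: " << e.what() << "\n";
  }
}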
@@ -5,10 +5,12 @@

#include <memory>

#include "core/framework/provider_options.h"
#include "core/providers/providers.h"

namespace onnxruntime {
struct CoreMLProviderFactoryCreator {
static std::shared_ptr<IExecutionProviderFactory> Create(uint32_t coreml_flags);
static std::shared_ptr<IExecutionProviderFactory> Create(const ProviderOptions& options);
};
} // namespace onnxruntime
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/coreml/model/model.mm
@@ -400,6 +400,8 @@ Status Predict(const std::unordered_map<std::string, OnnxTensorData>& inputs,
config.computeUnits = MLComputeUnitsCPUOnly;
} else if (coreml_flags_ & COREML_FLAG_USE_CPU_AND_GPU) {
config.computeUnits = MLComputeUnitsCPUAndGPU;
} else if (coreml_flags_ & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) {
config.computeUnits = MLComputeUnitsCPUAndNeuralEngine; // Apple Neural Engine
} else {
config.computeUnits = MLComputeUnitsAll;
}
12 changes: 11 additions & 1 deletion onnxruntime/core/session/provider_registration.cc
@@ -155,11 +155,21 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
status = create_not_supported_status();
#endif
} else if (strcmp(provider_name, "VitisAI") == 0) {
#ifdef USE_VITISAI
status = OrtApis::SessionOptionsAppendExecutionProvider_VitisAI(options, provider_options_keys, provider_options_values, num_keys);
#else
status = create_not_supported_status();
#endif
} else if (strcmp(provider_name, "CoreML") == 0) {
#if defined(USE_COREML)
options->provider_factories.push_back(CoreMLProviderFactoryCreator::Create(provider_options));
#else
status = create_not_supported_status();
#endif
} else {
ORT_UNUSED_PARAMETER(options);
status = OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
"Unknown provider name. Currently supported values are 'OPENVINO', 'SNPE', 'XNNPACK', 'QNN', 'WEBNN' and 'AZURE'");
"Unknown provider name. Currently supported values are 'OPENVINO', 'SNPE', 'XNNPACK', 'QNN', 'WEBNN' ,'CoreML', and 'AZURE'");
}

return status;
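The same registration path can be exercised from the plain C API with parallel key/value arrays; a minimal sketch (not part of this diff), where the caller is responsible for checking or releasing the returned status:

#include "onnxruntime_c_api.h"

// Sketch: append the CoreML EP through the C API. num_keys must equal the
// length of both arrays, per the documented contract.
OrtStatus* AppendCoreML(const OrtApi* ort, OrtSessionOptions* session_options) {
  const char* const keys[] = {"MLComputeUnits", "MLModelFormat"};
  const char* const values[] = {"MLComputeUnitsCPUAndGPU", "NeuralNetwork"};
  return ort->SessionOptionsAppendExecutionProvider(session_options, "CoreML",
                                                    keys, values, 2);
}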
2 changes: 1 addition & 1 deletion onnxruntime/python/onnxruntime_pybind_schema.cc
@@ -73,7 +73,7 @@ void addGlobalSchemaFunctions(pybind11::module& m) {
onnxruntime::RknpuProviderFactoryCreator::Create(),
#endif
#ifdef USE_COREML
onnxruntime::CoreMLProviderFactoryCreator::Create(0),
onnxruntime::CoreMLProviderFactoryCreator::Create(ProviderOptions{}),
#endif
#ifdef USE_XNNPACK
onnxruntime::XnnpackProviderFactoryCreator::Create(ProviderOptions{}, nullptr),
3 changes: 3 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_state.cc
@@ -1222,6 +1222,9 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
if (flags_str.find("COREML_FLAG_CREATE_MLPROGRAM") != std::string::npos) {
coreml_flags |= COREMLFlags::COREML_FLAG_CREATE_MLPROGRAM;
}
} else {
// read from provider_options
return onnxruntime::CoreMLProviderFactoryCreator::Create(options)->CreateProvider();
}
}

2 changes: 1 addition & 1 deletion onnxruntime/test/onnx/main.cc
@@ -631,7 +631,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
}
if (enable_coreml) {
#ifdef USE_COREML
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sf, 0));
sf.AppendExecutionProvider("CoreML", {});
#else
fprintf(stderr, "CoreML is not supported in this build");
return -1;