diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index be157a0419fc0..828ecaa25e6b8 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -1272,7 +1272,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
///
/// Append an execution provider instance to the native OrtSessionOptions instance.
///
- /// 'SNPE' and 'XNNPACK' are currently supported as providerName values.
+ /// 'SNPE', 'XNNPACK' and 'CoreML' are currently supported as providerName values.
///
/// The number of providerOptionsKeys must match the number of providerOptionsValues and equal numKeys.
///
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
index 3acd84b3016de..c6e576ca84fb9 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
@@ -430,16 +430,10 @@ public IntPtr Appender(IntPtr handle, IntPtr[] optKeys, IntPtr[] optValues, UInt
///
/// Append QNN, SNPE or XNNPACK execution provider
///
- /// Execution provider to add. 'QNN', 'SNPE' or 'XNNPACK' are currently supported.
+ /// Execution provider to add. 'QNN', 'SNPE', 'XNNPACK', 'CoreML' and 'AZURE' are currently supported.
/// Optional key/value pairs to specify execution provider options.
public void AppendExecutionProvider(string providerName, Dictionary<string, string> providerOptions = null)
{
- if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN" && providerName != "AZURE")
- {
- throw new NotSupportedException(
- "Only QNN, SNPE, XNNPACK and AZURE execution providers can be enabled by this method.");
- }
-
if (providerOptions == null)
{
providerOptions = new Dictionary<string, string>();
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs
index aa0e6ee62248a..1941ca72d689d 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs
@@ -179,6 +179,12 @@ public void TestSessionOptions()
ex = Assert.Throws(() => { opt.AppendExecutionProvider("QNN"); });
Assert.Contains("QNN execution provider is not supported in this build", ex.Message);
#endif
+#if USE_COREML
+ opt.AppendExecutionProvider("CoreML");
+#else
+ ex = Assert.Throws(() => { opt.AppendExecutionProvider("CoreML"); });
+ Assert.Contains("CoreML execution provider is not supported in this build", ex.Message);
+#endif
opt.AppendExecutionProvider_CPU(1);
}
@@ -2041,7 +2047,7 @@ public SkipNonPackageTests()
}
// Test hangs on mobile.
-#if !(ANDROID || IOS)
+#if !(ANDROID || IOS)
[Fact(DisplayName = "TestModelRunAsyncTask")]
private async Task TestModelRunAsyncTask()
{
diff --git a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
index 98fa9e09f1ba8..79e6229e3891c 100644
--- a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
+++ b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
@@ -41,6 +41,10 @@ enum COREMLFlags {
COREML_FLAG_LAST = COREML_FLAG_USE_CPU_AND_GPU,
};
+// MLComputeUnits can be one of the following values:
+// 'MLComputeUnitsCPUAndNeuralEngine|MLComputeUnitsCPUAndGPU|MLComputeUnitsCPUOnly|MLComputeUnitsAll'
+static const char* const kCoremlProviderOption_MLComputeUnits = "MLComputeUnits";
+
#ifdef __cplusplus
extern "C" {
#endif
diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java
index 7280f3c88e2e8..32dc9d9f84aaa 100644
--- a/java/src/main/java/ai/onnxruntime/OrtSession.java
+++ b/java/src/main/java/ai/onnxruntime/OrtSession.java
@@ -1323,6 +1323,18 @@ public void addQnn(Map providerOptions) throws OrtException {
addExecutionProvider(qnnProviderName, providerOptions);
}
+ /**
+ * Adds CoreML as an execution backend.
+ *
+ * @param providerOptions Configuration options for the CoreML backend. Refer to the CoreML
+ * execution provider's documentation.
+ * @throws OrtException If there was an error in native code.
+ */
+ public void addCoreML(Map<String, String> providerOptions) throws OrtException {
+ String coreMLProviderName = "CoreML";
+ addExecutionProvider(coreMLProviderName, providerOptions);
+ }
+
private native void setExecutionMode(long apiHandle, long nativeHandle, int mode)
throws OrtException;
diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc
index f7afbb2f98bd8..f3e8bd9b0e2af 100644
--- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc
+++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc
@@ -23,9 +23,9 @@ namespace onnxruntime {
constexpr const char* COREML = "CoreML";
-CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags)
+CoreMLExecutionProvider::CoreMLExecutionProvider(const CoreMLOptions& options)
: IExecutionProvider{onnxruntime::kCoreMLExecutionProvider},
- coreml_flags_(coreml_flags),
+ coreml_flags_(options.coreml_flags),
coreml_version_(coreml::util::CoreMLVersion()) {
LOGS_DEFAULT(VERBOSE) << "CoreML version: " << coreml_version_;
if (coreml_version_ < MINIMUM_COREML_VERSION) {
@@ -33,7 +33,7 @@ CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags)
}
// check if only one flag is set
- if ((coreml_flags & COREML_FLAG_USE_CPU_ONLY) && (coreml_flags & COREML_FLAG_USE_CPU_AND_GPU)) {
+ if ((coreml_flags_ & COREML_FLAG_USE_CPU_ONLY) && (coreml_flags_ & COREML_FLAG_USE_CPU_AND_GPU)) {
// multiple device options selected
ORT_THROW(
"Multiple device options selected, you should use at most one of the following options:"
diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.h b/onnxruntime/core/providers/coreml/coreml_execution_provider.h
index 24a001280eef5..d37f6bdc2732d 100644
--- a/onnxruntime/core/providers/coreml/coreml_execution_provider.h
+++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.h
@@ -12,9 +12,14 @@ namespace coreml {
class Model;
}
+struct CoreMLOptions {
+ uint32_t coreml_flags = 0;
+ std::string cache_path;
+};
+
class CoreMLExecutionProvider : public IExecutionProvider {
public:
- CoreMLExecutionProvider(uint32_t coreml_flags);
+ CoreMLExecutionProvider(const CoreMLOptions& options);
virtual ~CoreMLExecutionProvider();
std::vector>
diff --git a/onnxruntime/core/providers/coreml/coreml_provider_factory.cc b/onnxruntime/core/providers/coreml/coreml_provider_factory.cc
index fcdf37c446ce7..bcb2927150713 100644
--- a/onnxruntime/core/providers/coreml/coreml_provider_factory.cc
+++ b/onnxruntime/core/providers/coreml/coreml_provider_factory.cc
@@ -9,21 +9,60 @@
using namespace onnxruntime;
namespace onnxruntime {
+
+namespace {
+CoreMLOptions ParseProviderOption(const ProviderOptions& options) {
+ CoreMLOptions coreml_options;
+ const std::unordered_map<std::string, COREMLFlags> available_device_options = {
+ {"MLComputeUnitsCPUAndNeuralEngine", COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE},
+ {"MLComputeUnitsCPUAndGPU", COREML_FLAG_USE_CPU_AND_GPU},
+ {"MLComputeUnitsCPUOnly", COREML_FLAG_USE_CPU_ONLY},
+ {"MLComputeUnitsAll", COREML_FLAG_USE_NONE},
+ };
+ const std::unordered_map<std::string, COREMLFlags> available_format_options = {
+ {"MLProgram", COREML_FLAG_CREATE_MLPROGRAM},
+ {"NeuralNetwork", COREML_FLAG_USE_NONE},
+ };
+ if (options.count("ComputeUnits")) {
+ coreml_options.coreml_flags |= available_device_options.at(options.at("ComputeUnits"));
+ }
+ if (options.count("ModelFormat")) {
+ coreml_options.coreml_flags |= available_format_options.at(options.at("ModelFormat"));
+ }
+ if (options.count("AllowStaticInputShapes")) {
+ coreml_options.coreml_flags |= COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES;
+ }
+ if (options.count("EnableOnSubgraphs")) {
+ coreml_options.coreml_flags |= COREML_FLAG_ENABLE_ON_SUBGRAPH;
+ }
+ if (options.count("ModelCacheDir")) {
+ coreml_options.cache_path = options.at("ModelCacheDir");
+ }
+
+ return coreml_options;
+}
+} // namespace
struct CoreMLProviderFactory : IExecutionProviderFactory {
- CoreMLProviderFactory(uint32_t coreml_flags)
- : coreml_flags_(coreml_flags) {}
+ CoreMLProviderFactory(const CoreMLOptions& options)
+ : options_(options) {}
~CoreMLProviderFactory() override {}
std::unique_ptr CreateProvider() override;
- uint32_t coreml_flags_;
+ CoreMLOptions options_;
};
std::unique_ptr CoreMLProviderFactory::CreateProvider() {
- return std::make_unique(coreml_flags_);
+ return std::make_unique<CoreMLExecutionProvider>(options_);
}
std::shared_ptr CoreMLProviderFactoryCreator::Create(uint32_t coreml_flags) {
- return std::make_shared(coreml_flags);
+ CoreMLOptions coreml_options;
+ coreml_options.coreml_flags = coreml_flags;
+ return std::make_shared<CoreMLProviderFactory>(coreml_options);
+}
+
+std::shared_ptr<IExecutionProviderFactory> CoreMLProviderFactoryCreator::Create(const ProviderOptions& options) {
+ return std::make_shared<CoreMLProviderFactory>(ParseProviderOption(options));
}
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/coreml/coreml_provider_factory_creator.h b/onnxruntime/core/providers/coreml/coreml_provider_factory_creator.h
index ba701724c4da9..93ec2af50698d 100644
--- a/onnxruntime/core/providers/coreml/coreml_provider_factory_creator.h
+++ b/onnxruntime/core/providers/coreml/coreml_provider_factory_creator.h
@@ -5,10 +5,12 @@
#include
+#include "core/framework/provider_options.h"
#include "core/providers/providers.h"
namespace onnxruntime {
struct CoreMLProviderFactoryCreator {
static std::shared_ptr Create(uint32_t coreml_flags);
+ static std::shared_ptr<IExecutionProviderFactory> Create(const ProviderOptions& options);
};
} // namespace onnxruntime
diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc
index 8c512c561ea8c..4ae912b23f80b 100644
--- a/onnxruntime/core/session/provider_registration.cc
+++ b/onnxruntime/core/session/provider_registration.cc
@@ -155,11 +155,25 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
status = create_not_supported_status();
#endif
} else if (strcmp(provider_name, "VitisAI") == 0) {
+#ifdef USE_VITISAI
status = OrtApis::SessionOptionsAppendExecutionProvider_VitisAI(options, provider_options_keys, provider_options_values, num_keys);
+#else
+ status = create_not_supported_status();
+#endif
+ } else if (strcmp(provider_name, "CoreML") == 0) {
+#if defined(USE_COREML)
+ std::string coreml_flags;
+ if (options->value.config_options.TryGetConfigEntry("coreml_flags", coreml_flags)) {
+ provider_options["coreml_flags"] = coreml_flags;
+ }
+ options->provider_factories.push_back(CoreMLProviderFactoryCreator::Create(provider_options));
+#else
+ status = create_not_supported_status();
+#endif
} else {
ORT_UNUSED_PARAMETER(options);
status = OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
- "Unknown provider name. Currently supported values are 'OPENVINO', 'SNPE', 'XNNPACK', 'QNN', 'WEBNN' and 'AZURE'");
+ "Unknown provider name. Currently supported values are 'OPENVINO', 'SNPE', 'XNNPACK', 'QNN', 'WEBNN', 'CoreML' and 'AZURE'");
}
return status;
diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc
index 1319e8f6fe959..dcd021494c24c 100644
--- a/onnxruntime/python/onnxruntime_pybind_schema.cc
+++ b/onnxruntime/python/onnxruntime_pybind_schema.cc
@@ -73,7 +73,7 @@ void addGlobalSchemaFunctions(pybind11::module& m) {
onnxruntime::RknpuProviderFactoryCreator::Create(),
#endif
#ifdef USE_COREML
- onnxruntime::CoreMLProviderFactoryCreator::Create(0),
+ onnxruntime::CoreMLProviderFactoryCreator::Create(ProviderOptions{}),
#endif
#ifdef USE_XNNPACK
onnxruntime::XnnpackProviderFactoryCreator::Create(ProviderOptions{}, nullptr),
diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc
index 93a1bf9f30651..ddc453f84feb6 100644
--- a/onnxruntime/test/onnx/main.cc
+++ b/onnxruntime/test/onnx/main.cc
@@ -631,7 +631,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
}
if (enable_coreml) {
#ifdef USE_COREML
- Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sf, 0));
+ sf.AppendExecutionProvider("CoreML", {});
#else
fprintf(stderr, "CoreML is not supported in this build");
return -1;
diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc
index e40544d950ed7..46777eda793e0 100644
--- a/onnxruntime/test/perftest/command_args_parser.cc
+++ b/onnxruntime/test/perftest/command_args_parser.cc
@@ -24,6 +24,7 @@
#include
#include "test_configuration.h"
+#include "strings_helper.h"
namespace onnxruntime {
namespace perftest {
@@ -175,39 +176,6 @@ static bool ParseDimensionOverride(std::basic_string& dim_identifier,
return true;
}
-static bool ParseSessionConfigs(const std::string& configs_string,
- std::unordered_map& session_configs) {
- std::istringstream ss(configs_string);
- std::string token;
-
- while (ss >> token) {
- if (token == "") {
- continue;
- }
-
- std::string_view token_sv(token);
-
- auto pos = token_sv.find("|");
- if (pos == std::string_view::npos || pos == 0 || pos == token_sv.length()) {
- // Error: must use a '|' to separate the key and value for session configuration entries.
- return false;
- }
-
- std::string key(token_sv.substr(0, pos));
- std::string value(token_sv.substr(pos + 1));
-
- auto it = session_configs.find(key);
- if (it != session_configs.end()) {
- // Error: specified duplicate session configuration entry: {key}
- return false;
- }
-
- session_configs.insert(std::make_pair(std::move(key), std::move(value)));
- }
-
- return true;
-}
-
/*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) {
int ch;
while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqznlR:"))) != -1) {
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc
index 8f2e5282ede9a..85429273409de 100644
--- a/onnxruntime/test/perftest/ort_test_session.cc
+++ b/onnxruntime/test/perftest/ort_test_session.cc
@@ -17,6 +17,7 @@
#include
#include "providers.h"
#include "TestCase.h"
+#include "strings_helper.h"
#ifdef USE_OPENVINO
#include "nlohmann/json.hpp"
@@ -58,6 +59,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
Ort::SessionOptions session_options;
provider_name_ = performance_test_config.machine_config.provider_type_name;
+ std::unordered_map<std::string, std::string> provider_options;
if (provider_name_ == onnxruntime::kDnnlExecutionProvider) {
#ifdef USE_DNNL
// Generate provider options
@@ -72,24 +74,14 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
std::string ov_string = performance_test_config.run_config.ep_runtime_config_string;
#endif // defined(_MSC_VER)
int num_threads = 0;
- std::istringstream ss(ov_string);
- std::string token;
- while (ss >> token) {
- if (token == "") {
- continue;
- }
- auto pos = token.find("|");
- if (pos == std::string::npos || pos == 0 || pos == token.length()) {
- ORT_THROW(
- "[ERROR] [OneDNN] Use a '|' to separate the key and value for the "
- "run-time option you are trying to use.\n");
- }
-
- auto key = token.substr(0, pos);
- auto value = token.substr(pos + 1);
-
- if (key == "num_of_threads") {
- std::stringstream sstream(value);
+ if (!ParseSessionConfigs(ov_string, provider_options)) {
+ ORT_THROW(
+ "[ERROR] Use a '|' to separate the key and value for the "
+ "run-time option you are trying to use.\n");
+ }
+ for (const auto& provider_option : provider_options) {
+ if (provider_option.first == "num_of_threads") {
+ std::stringstream sstream(provider_option.second);
sstream >> num_threads;
if (num_threads < 0) {
ORT_THROW(
@@ -144,22 +136,14 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
#else
std::string ov_string = performance_test_config.run_config.ep_runtime_config_string;
#endif
- std::istringstream ss(ov_string);
- std::string token;
- while (ss >> token) {
- if (token == "") {
- continue;
- }
- auto pos = token.find("|");
- if (pos == std::string::npos || pos == 0 || pos == token.length()) {
- ORT_THROW(
- "[ERROR] [CUDA] Use a '|' to separate the key and value for the run-time option you are trying to use.\n");
- }
-
- buffer.emplace_back(token.substr(0, pos));
- option_keys.push_back(buffer.back().c_str());
- buffer.emplace_back(token.substr(pos + 1));
- option_values.push_back(buffer.back().c_str());
+ if (!ParseSessionConfigs(ov_string, provider_options)) {
+ ORT_THROW(
+ "[ERROR] Use a '|' to separate the key and value for the "
+ "run-time option you are trying to use.\n");
+ }
+ for (const auto& provider_option : provider_options) {
+ option_keys.push_back(provider_option.first.c_str());
+ option_values.push_back(provider_option.second.c_str());
}
Ort::Status status(api.UpdateCUDAProviderOptions(cuda_options,
@@ -192,24 +176,15 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
#else
std::string ov_string = performance_test_config.run_config.ep_runtime_config_string;
#endif
- std::istringstream ss(ov_string);
- std::string token;
- while (ss >> token) {
- if (token == "") {
- continue;
- }
- auto pos = token.find("|");
- if (pos == std::string::npos || pos == 0 || pos == token.length()) {
- ORT_THROW(
- "[ERROR] [TensorRT] Use a '|' to separate the key and value for the run-time option you are trying to use.\n");
- }
-
- buffer.emplace_back(token.substr(0, pos));
- option_keys.push_back(buffer.back().c_str());
- buffer.emplace_back(token.substr(pos + 1));
- option_values.push_back(buffer.back().c_str());
+ if (!ParseSessionConfigs(ov_string, provider_options)) {
+ ORT_THROW(
+ "[ERROR] Use a '|' to separate the key and value for the "
+ "run-time option you are trying to use.\n");
+ }
+ for (const auto& provider_option : provider_options) {
+ option_keys.push_back(provider_option.first.c_str());
+ option_values.push_back(provider_option.second.c_str());
}
-
Ort::Status status(api.UpdateTensorRTProviderOptions(tensorrt_options,
option_keys.data(), option_values.data(), option_keys.size()));
if (!status.IsOK()) {
@@ -239,22 +214,12 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
#else
std::string option_string = performance_test_config.run_config.ep_runtime_config_string;
#endif
- std::istringstream ss(option_string);
- std::string token;
- std::unordered_map qnn_options;
-
- while (ss >> token) {
- if (token == "") {
- continue;
- }
- auto pos = token.find("|");
- if (pos == std::string::npos || pos == 0 || pos == token.length()) {
- ORT_THROW("Use a '|' to separate the key and value for the run-time option you are trying to use.");
- }
-
- std::string key(token.substr(0, pos));
- std::string value(token.substr(pos + 1));
-
+ if (!ParseSessionConfigs(option_string, provider_options)) {
+ ORT_THROW(
+ "[ERROR] Use a '|' to separate the key and value for the "
+ "run-time option you are trying to use.\n");
+ }
+ for (const auto& [key, value] : provider_options) {
if (key == "backend_path" || key == "profiling_file_path") {
if (value.empty()) {
ORT_THROW("Please provide the valid file path.");
@@ -317,10 +282,8 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', 'soc_model',
'htp_arch', 'device_id', 'enable_htp_fp16_precision', 'offload_graph_io_quantization'])");
}
-
- qnn_options[key] = value;
}
- session_options.AppendExecutionProvider("QNN", qnn_options);
+ session_options.AppendExecutionProvider("QNN", provider_options);
#else
ORT_THROW("QNN is not supported in this build\n");
#endif
@@ -331,22 +294,12 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
#else
std::string option_string = performance_test_config.run_config.ep_runtime_config_string;
#endif
- std::istringstream ss(option_string);
- std::string token;
- std::unordered_map snpe_options;
-
- while (ss >> token) {
- if (token == "") {
- continue;
- }
- auto pos = token.find("|");
- if (pos == std::string::npos || pos == 0 || pos == token.length()) {
- ORT_THROW("Use a '|' to separate the key and value for the run-time option you are trying to use.\n");
- }
-
- std::string key(token.substr(0, pos));
- std::string value(token.substr(pos + 1));
-
+ if (!ParseSessionConfigs(option_string, provider_options)) {
+ ORT_THROW(
+ "[ERROR] Use a '|' to separate the key and value for the "
+ "run-time option you are trying to use.\n");
+ }
+ for (const auto& [key, value] : provider_options) {
if (key == "runtime") {
std::set supported_runtime = {"CPU", "GPU_FP32", "GPU", "GPU_FLOAT16", "DSP", "AIP_FIXED_TF"};
if (supported_runtime.find(value) == supported_runtime.end()) {
@@ -368,11 +321,9 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
} else {
ORT_THROW("Wrong key type entered. Choose from options: ['runtime', 'priority', 'buffer_type', 'enable_init_cache'] \n");
}
-
- snpe_options[key] = value;
}
- session_options.AppendExecutionProvider("SNPE", snpe_options);
+ session_options.AppendExecutionProvider("SNPE", provider_options);
#else
ORT_THROW("SNPE is not supported in this build\n");
#endif
@@ -448,34 +399,20 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
#endif
} else if (provider_name_ == onnxruntime::kDmlExecutionProvider) {
#ifdef USE_DML
- std::unordered_map dml_options;
- dml_options["performance_preference"] = "high_performance";
- dml_options["device_filter"] = "gpu";
- dml_options["disable_metacommands"] = "false";
- dml_options["enable_graph_capture"] = "false";
#ifdef _MSC_VER
std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string);
#else
std::string ov_string = performance_test_config.run_config.ep_runtime_config_string;
#endif
- std::istringstream ss(ov_string);
- std::string token;
- while (ss >> token) {
- if (token == "") {
- continue;
- }
- auto pos = token.find("|");
- if (pos == std::string::npos || pos == 0 || pos == token.length()) {
- ORT_THROW("[ERROR] [DML] Use a '|' to separate the key and value for the run-time option you are trying to use.\n");
- }
-
- auto key = token.substr(0, pos);
- auto value = token.substr(pos + 1);
-
+ if (!ParseSessionConfigs(ov_string, provider_options)) {
+ ORT_THROW(
+ "[ERROR] Use a '|' to separate the key and value for the "
+ "run-time option you are trying to use.\n");
+ }
+ for (const auto& [key, value] : provider_options) {
if (key == "device_filter") {
std::set ov_supported_device_types = {"gpu", "npu"};
if (ov_supported_device_types.find(value) != ov_supported_device_types.end()) {
- dml_options[key] = value;
} else {
ORT_THROW(
"[ERROR] [DML] You have selected a wrong configuration value for the key 'device_filter'. "
@@ -484,7 +421,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
} else if (key == "performance_preference") {
std::set ov_supported_values = {"default", "high_performance", "minimal_power"};
if (ov_supported_values.find(value) != ov_supported_values.end()) {
- dml_options[key] = value;
} else {
ORT_THROW(
"[ERROR] [DML] You have selected a wrong configuration value for the key 'performance_preference'. "
@@ -493,7 +429,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
} else if (key == "disable_metacommands") {
std::set ov_supported_values = {"true", "True", "false", "False"};
if (ov_supported_values.find(value) != ov_supported_values.end()) {
- dml_options[key] = value;
} else {
ORT_THROW(
"[ERROR] [DML] You have selected a wrong value for the key 'disable_metacommands'. "
@@ -502,7 +437,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
} else if (key == "enable_graph_capture") {
std::set ov_supported_values = {"true", "True", "false", "False"};
if (ov_supported_values.find(value) != ov_supported_values.end()) {
- dml_options[key] = value;
} else {
ORT_THROW(
"[ERROR] [DML] You have selected a wrong value for the key 'enable_graph_capture'. "
@@ -519,7 +453,19 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
}
}
}
- session_options.AppendExecutionProvider("DML", dml_options);
+ if (provider_options.find("performance_preference") == provider_options.end()) {
+ provider_options["performance_preference"] = "high_performance";
+ }
+ if (provider_options.find("device_filter") == provider_options.end()) {
+ provider_options["device_filter"] = "gpu";
+ }
+ if (provider_options.find("disable_metacommands") == provider_options.end()) {
+ provider_options["disable_metacommands"] = "false";
+ }
+ if (provider_options.find("enable_graph_capture") == provider_options.end()) {
+ provider_options["enable_graph_capture"] = "false";
+ }
+ session_options.AppendExecutionProvider("DML", provider_options);
#else
ORT_THROW("DML is not supported in this build\n");
#endif
@@ -530,21 +476,13 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
#else
std::string ov_string = performance_test_config.run_config.ep_runtime_config_string;
#endif // defined(_MSC_VER)
- std::istringstream ss(ov_string);
- std::string token;
bool enable_fast_math = false;
- while (ss >> token) {
- if (token == "") {
- continue;
- }
- auto pos = token.find("|");
- if (pos == std::string::npos || pos == 0 || pos == token.length()) {
- ORT_THROW("[ERROR] [ACL] Use a '|' to separate the key and value for the run-time option you are trying to use.\n");
- }
-
- auto key = token.substr(0, pos);
- auto value = token.substr(pos + 1);
-
+ if (!ParseSessionConfigs(ov_string, provider_options)) {
+ ORT_THROW(
+ "[ERROR] Use a '|' to separate the key and value for the "
+ "run-time option you are trying to use.\n");
+ }
+ for (const auto& [key, value] : provider_options) {
if (key == "enable_fast_math") {
std::set ov_supported_values = {"true", "True", "false", "False"};
if (ov_supported_values.find(value) != ov_supported_values.end()) {
@@ -612,24 +550,13 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
#else
std::string option_string = performance_test_config.run_config.ep_runtime_config_string;
#endif
- std::istringstream ss(option_string);
- std::string token;
- std::unordered_map vitisai_session_options;
-
- while (ss >> token) {
- if (token == "") {
- continue;
- }
- auto pos = token.find("|");
- if (pos == std::string::npos || pos == 0 || pos == token.length()) {
- ORT_THROW("[ERROR] [VitisAI] Use a '|' to separate the key and value for the run-time option you are trying to use.\n");
- }
-
- std::string key(token.substr(0, pos));
- std::string value(token.substr(pos + 1));
- vitisai_session_options[key] = value;
+ if (!ParseSessionConfigs(option_string, provider_options)) {
+ ORT_THROW(
+ "[ERROR] Use a '|' to separate the key and value for the "
+ "run-time option you are trying to use.\n");
}
- session_options.AppendExecutionProvider_VitisAI(vitisai_session_options);
+
+ session_options.AppendExecutionProvider_VitisAI(provider_options);
#else
ORT_THROW("VitisAI is not supported in this build\n");
#endif
diff --git a/onnxruntime/test/perftest/strings_helper.cc b/onnxruntime/test/perftest/strings_helper.cc
new file mode 100644
index 0000000000000..22f682159b924
--- /dev/null
+++ b/onnxruntime/test/perftest/strings_helper.cc
@@ -0,0 +1,47 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) 2023 NVIDIA Corporation.
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+// Licensed under the MIT License.
+
+#include <sstream>
+#include <string_view>
+
+#include "strings_helper.h"
+
+namespace onnxruntime {
+namespace perftest {
+
+bool ParseSessionConfigs(const std::string& configs_string,
+ std::unordered_map<std::string, std::string>& session_configs) {
+ std::istringstream ss(configs_string);
+ std::string token;
+
+ while (ss >> token) {
+ if (token == "") {
+ continue;
+ }
+
+ std::string_view token_sv(token);
+
+ auto pos = token_sv.find("|");
+ if (pos == std::string_view::npos || pos == 0 || pos == token_sv.length()) {
+ // Error: must use a '|' to separate the key and value for session configuration entries.
+ return false;
+ }
+
+ std::string key(token_sv.substr(0, pos));
+ std::string value(token_sv.substr(pos + 1));
+
+ auto it = session_configs.find(key);
+ if (it != session_configs.end()) {
+ // Error: specified duplicate session configuration entry: {key}
+ return false;
+ }
+
+ session_configs.insert(std::make_pair(std::move(key), std::move(value)));
+ }
+
+ return true;
+}
+} // namespace perftest
+} // namespace onnxruntime
diff --git a/onnxruntime/test/perftest/strings_helper.h b/onnxruntime/test/perftest/strings_helper.h
new file mode 100644
index 0000000000000..24feb90a20a61
--- /dev/null
+++ b/onnxruntime/test/perftest/strings_helper.h
@@ -0,0 +1,14 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) 2023 NVIDIA Corporation.
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+// Licensed under the MIT License.
+#include <string>
+#include <unordered_map>
+
+namespace onnxruntime {
+namespace perftest {
+
+bool ParseSessionConfigs(const std::string& configs_string,
+ std::unordered_map<std::string, std::string>& session_configs);
+} // namespace perftest
+} // namespace onnxruntime
diff --git a/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm b/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm
index 32b4b32e299d6..9bead11109f3b 100644
--- a/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm
+++ b/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm
@@ -36,7 +36,9 @@ void testSigmoid(const char* modelPath, bool useCoreML = false, bool useWebGPU =
#if COREML_EP_AVAILABLE
if (useCoreML) {
const uint32_t flags = COREML_FLAG_USE_CPU_ONLY;
- Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, flags));
+ std::unordered_map<std::string, std::string> provider_options = {{"coreml_flags", std::to_string(flags)}};
+
+ session_options.AppendExecutionProvider("CoreML", provider_options);
}
#else
(void)useCoreML;
diff --git a/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm b/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm
index 86001b6cb50a5..a7851d078ece4 100644
--- a/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm
+++ b/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm
@@ -36,7 +36,8 @@ void testSigmoid(const char* modelPath, bool useCoreML = false, bool useWebGPU =
#if COREML_EP_AVAILABLE
if (useCoreML) {
const uint32_t flags = COREML_FLAG_USE_CPU_ONLY;
- Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, flags));
+ std::unordered_map<std::string, std::string> provider_options = {{"coreml_flags", std::to_string(flags)}};
+ session_options.AppendExecutionProvider("CoreML", provider_options);
}
#else
(void)useCoreML;
diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc
index de647d9e3aa3e..ca89b22cbc088 100644
--- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc
+++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc
@@ -34,7 +34,8 @@ namespace test {
static constexpr uint32_t s_coreml_flags = COREML_FLAG_USE_CPU_ONLY;
static std::unique_ptr MakeCoreMLExecutionProvider(uint32_t flags = s_coreml_flags) {
- return std::make_unique(flags);
+ std::unordered_map<std::string, std::string> provider_options = {{"coreml_flags", std::to_string(flags)}};
+ return std::make_unique<CoreMLExecutionProvider>(provider_options);
}
#if !defined(ORT_MINIMAL_BUILD)
diff --git a/onnxruntime/test/providers/coreml/dynamic_input_test.cc b/onnxruntime/test/providers/coreml/dynamic_input_test.cc
index c91ef23650040..8eecdbcce33c6 100644
--- a/onnxruntime/test/providers/coreml/dynamic_input_test.cc
+++ b/onnxruntime/test/providers/coreml/dynamic_input_test.cc
@@ -20,8 +20,8 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, MatMul) {
auto test = [&](const size_t M) {
SCOPED_TRACE(MakeString("M=", M));
-
- auto coreml_ep = std::make_unique(0);
+ std::unordered_map<std::string, std::string> options;
+ auto coreml_ep = std::make_unique<CoreMLExecutionProvider>(options);
const auto ep_verification_params = EPVerificationParams{
ExpectedEPNodeAssignment::All,
@@ -54,8 +54,8 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, MobileNetExcerpt) {
auto test = [&](const size_t batch_size) {
SCOPED_TRACE(MakeString("batch_size=", batch_size));
-
- auto coreml_ep = std::make_unique(0);
+ std::unordered_map<std::string, std::string> options;
+ auto coreml_ep = std::make_unique<CoreMLExecutionProvider>(options);
const auto ep_verification_params = EPVerificationParams{
ExpectedEPNodeAssignment::All,
@@ -87,6 +87,7 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, EmptyInputFails) {
constexpr auto model_path = ORT_TSTR("testdata/matmul_with_dynamic_input_shape.onnx");
ModelTester tester(CurrentTestName(), model_path);
+ std::unordered_map<std::string, std::string> options;
tester.AddInput("A", {0, 2}, {});
tester.AddOutput("Y", {0, 4}, {});
@@ -94,14 +95,14 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, EmptyInputFails) {
tester
.Config(ModelTester::ExpectResult::kExpectFailure,
"the runtime shape ({0,2}) has zero elements. This is not supported by the CoreML EP.")
- .ConfigEp(std::make_unique(0))
+ .ConfigEp(std::make_unique<CoreMLExecutionProvider>(options))
.RunWithConfig();
}
TEST(CoreMLExecutionProviderDynamicInputShapeTest, OnlyAllowStaticInputShapes) {
constexpr auto model_path = ORT_TSTR("testdata/matmul_with_dynamic_input_shape.onnx");
-
- auto coreml_ep = std::make_unique(COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES);
+ std::unordered_map<std::string, std::string> options = {{"coreml_flags", std::to_string(COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES)}};
+ auto coreml_ep = std::make_unique<CoreMLExecutionProvider>(options);
TestModelLoad(model_path, std::move(coreml_ep),
// expect no supported nodes because we disable dynamic input shape support
diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc
index 62bdedd833025..32a0560a6bc87 100644
--- a/onnxruntime/test/util/default_providers.cc
+++ b/onnxruntime/test/util/default_providers.cc
@@ -259,8 +259,9 @@ std::unique_ptr DefaultCoreMLExecutionProvider(bool use_mlpr
if (use_mlprogram) {
coreml_flags |= COREML_FLAG_CREATE_MLPROGRAM;
}
-
- return CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider();
+ auto option = ProviderOptions();
+ option["coreml_flags"] = std::to_string(coreml_flags);
+ return CoreMLProviderFactoryCreator::Create(option)->CreateProvider();
#else
ORT_UNUSED_PARAMETER(use_mlprogram);
return nullptr;