diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index d6c46833f1..d628b065ce 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -1269,7 +1269,7 @@ namespace Microsoft.ML.OnnxRuntime
///
/// Append an execution provider instance to the native OrtSessionOptions instance.
///
- /// 'SNPE' and 'XNNPACK' are currently supported as providerName values.
+ /// 'SNPE', 'XNNPACK' and 'CoreML' are currently supported as providerName values.
///
/// The number of providerOptionsKeys must match the number of providerOptionsValues and equal numKeys.
///
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
index 9841d972fa..bd450451a1 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
@@ -395,16 +395,10 @@ namespace Microsoft.ML.OnnxRuntime
///
/// Append QNN, SNPE or XNNPACK execution provider
///
- /// Execution provider to add. 'QNN', 'SNPE' or 'XNNPACK' are currently supported.
+ /// Execution provider to add. 'QNN', 'SNPE', 'XNNPACK', 'CoreML' and 'AZURE' are currently supported.
/// Optional key/value pairs to specify execution provider options.
public void AppendExecutionProvider(string providerName, Dictionary<string, string> providerOptions = null)
{
- if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN" && providerName != "AZURE")
- {
- throw new NotSupportedException(
- "Only QNN, SNPE, XNNPACK and AZURE execution providers can be enabled by this method.");
- }
-
if (providerOptions == null)
{
providerOptions = new Dictionary<string, string>();
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs
index b2a863a48e..17738da515 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs
@@ -175,6 +175,12 @@ namespace Microsoft.ML.OnnxRuntime.Tests
ex = Assert.Throws<OnnxRuntimeException>(() => { opt.AppendExecutionProvider("QNN"); });
Assert.Contains("QNN execution provider is not supported in this build", ex.Message);
#endif
+#if USE_COREML
+ opt.AppendExecutionProvider("CoreML");
+#else
+ ex = Assert.Throws<OnnxRuntimeException>(() => { opt.AppendExecutionProvider("CoreML"); });
+ Assert.Contains("CoreML execution provider is not supported in this build", ex.Message);
+#endif
opt.AppendExecutionProvider_CPU(1);
}
@@ -2037,7 +2043,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
// Test hangs on mobile.
-#if !(ANDROID || IOS)
+#if !(ANDROID || IOS)
[Fact(DisplayName = "TestModelRunAsyncTask")]
private async Task TestModelRunAsyncTask()
{
diff --git a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
index 98fa9e09f1..3963b80de5 100644
--- a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
+++ b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
@@ -41,6 +41,15 @@ enum COREMLFlags {
COREML_FLAG_LAST = COREML_FLAG_USE_CPU_AND_GPU,
};
+// MLComputeUnits can be one of the following values:
+// 'MLComputeUnitsCPUAndNeuralEngine|MLComputeUnitsCPUAndGPU|MLComputeUnitsCPUOnly|MLComputeUnitsAll'
+// these values are intended to be used with Ort::SessionOptions::AppendExecutionProvider (C++ API)
+// and SessionOptionsAppendExecutionProvider (C API). For the old API, use COREMLFlags instead.
+static const char* const kCoremlProviderOption_MLComputeUnits = "MLComputeUnits";
+static const char* const kCoremlProviderOption_ModelFormat = "ModelFormat";
+static const char* const kCoremlProviderOption_RequireStaticInputShapes = "RequireStaticInputShapes";
+static const char* const kCoremlProviderOption_EnableOnSubgraphs = "EnableOnSubgraphs";
+
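+// Illustrative usage sketch (mirrors the Apple package tests in this change; not itself part of the C API):
+//   std::unordered_map<std::string, std::string> provider_options = {
+//       {kCoremlProviderOption_MLComputeUnits, "CPUAndGPU"},
+//       {kCoremlProviderOption_ModelFormat, "MLProgram"}};
+//   Ort::SessionOptions session_options;
+//   session_options.AppendExecutionProvider("CoreML", provider_options);
+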
#ifdef __cplusplus
extern "C" {
#endif
diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java
index 7280f3c88e..32dc9d9f84 100644
--- a/java/src/main/java/ai/onnxruntime/OrtSession.java
+++ b/java/src/main/java/ai/onnxruntime/OrtSession.java
@@ -1323,6 +1323,18 @@ public class OrtSession implements AutoCloseable {
addExecutionProvider(qnnProviderName, providerOptions);
}
+ /**
+ * Adds CoreML as an execution backend.
+ *
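+ * <p>Usage sketch (illustrative; assumes a CoreML-enabled build and uses option keys documented for
+ * the CoreML execution provider):
+ *
+ * <pre>{@code
+ * OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
+ * Map<String, String> coremlOptions = new HashMap<>();
+ * coremlOptions.put("MLComputeUnits", "CPUAndGPU");
+ * coremlOptions.put("ModelFormat", "MLProgram");
+ * opts.addCoreML(coremlOptions);
+ * }</pre>
+ *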
+ * @param providerOptions Configuration options for the CoreML backend. Refer to the CoreML
+ * execution provider's documentation.
+ * @throws OrtException If there was an error in native code.
+ */
+  public void addCoreML(Map<String, String> providerOptions) throws OrtException {
+    String coreMLProviderName = "CoreML";
+    addExecutionProvider(coreMLProviderName, providerOptions);
+ }
+
private native void setExecutionMode(long apiHandle, long nativeHandle, int mode)
throws OrtException;
diff --git a/objectivec/include/ort_coreml_execution_provider.h b/objectivec/include/ort_coreml_execution_provider.h
index d7d873f5eb..41d15aa394 100644
--- a/objectivec/include/ort_coreml_execution_provider.h
+++ b/objectivec/include/ort_coreml_execution_provider.h
@@ -70,7 +70,22 @@ NS_ASSUME_NONNULL_BEGIN
*/
- (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOptions*)options
error:(NSError**)error;
-
+/**
+ * Enables the CoreML execution provider in the session configuration options.
+ * It is appended to the execution provider list which is ordered by
+ * decreasing priority.
+ *
+ * @param provider_options The CoreML execution provider options as a dictionary.
+ *  Available key/value pairs (see core/providers/coreml/coreml_execution_provider.h for details):
+ * kCoremlProviderOption_MLComputeUnits: one of "CPUAndNeuralEngine", "CPUAndGPU", "CPUOnly", "All"
+ * kCoremlProviderOption_ModelFormat: one of "MLProgram", "NeuralNetwork"
+ * kCoremlProviderOption_RequireStaticInputShapes: "1" or "0"
+ * kCoremlProviderOption_EnableOnSubgraphs: "1" or "0"
+ * @param error Optional error information set if an error occurs.
+ * @return Whether the provider was enabled successfully.
+ */
+- (BOOL)appendCoreMLExecutionProviderWithOptionsV2:(NSDictionary<NSString*, NSString*>*)provider_options
+ error:(NSError**)error;
@end
NS_ASSUME_NONNULL_END
diff --git a/objectivec/ort_coreml_execution_provider.mm b/objectivec/ort_coreml_execution_provider.mm
index 6cb5026b93..0c790a91fb 100644
--- a/objectivec/ort_coreml_execution_provider.mm
+++ b/objectivec/ort_coreml_execution_provider.mm
@@ -43,6 +43,21 @@ BOOL ORTIsCoreMLExecutionProviderAvailable() {
#endif
}
+- (BOOL)appendCoreMLExecutionProviderWithOptionsV2:(NSDictionary<NSString*, NSString*>*)provider_options
+ error:(NSError**)error {
+#if ORT_OBJC_API_COREML_EP_AVAILABLE
+ try {
+ return [self appendExecutionProvider:@"CoreML" providerOptions:provider_options error:error];
+ }
+ ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error);
+
+#else // !ORT_OBJC_API_COREML_EP_AVAILABLE
+  static_cast<void>(provider_options);
+ ORTSaveCodeAndDescriptionToError(ORT_FAIL, "CoreML execution provider is not enabled.", error);
+ return NO;
+#endif
+}
+
@end
NS_ASSUME_NONNULL_END
diff --git a/objectivec/test/ort_session_test.mm b/objectivec/test/ort_session_test.mm
index 508289f7bc..409ee7e158 100644
--- a/objectivec/test/ort_session_test.mm
+++ b/objectivec/test/ort_session_test.mm
@@ -223,6 +223,28 @@ NS_ASSUME_NONNULL_BEGIN
ORTAssertNullableResultSuccessful(session, err);
}
+- (void)testAppendCoreMLEP_v2 {
+ NSError* err = nil;
+ ORTSessionOptions* sessionOptions = [ORTSessionTest makeSessionOptions];
+  NSDictionary<NSString*, NSString*>* provider_options = @{@"EnableOnSubgraphs" : @"1"};  // set an arbitrary option
+
+ BOOL appendResult = [sessionOptions appendCoreMLExecutionProviderWithOptionsV2:provider_options
+ error:&err];
+
+ if (!ORTIsCoreMLExecutionProviderAvailable()) {
+ ORTAssertBoolResultUnsuccessful(appendResult, err);
+ return;
+ }
+
+ ORTAssertBoolResultSuccessful(appendResult, err);
+
+ ORTSession* session = [[ORTSession alloc] initWithEnv:self.ortEnv
+ modelPath:[ORTSessionTest getAddModelPath]
+ sessionOptions:sessionOptions
+ error:&err];
+ ORTAssertNullableResultSuccessful(session, err);
+}
+
- (void)testAppendXnnpackEP {
NSError* err = nil;
ORTSessionOptions* sessionOptions = [ORTSessionTest makeSessionOptions];
diff --git a/onnxruntime/core/providers/coreml/builders/helper.cc b/onnxruntime/core/providers/coreml/builders/helper.cc
index e1f148fa93..38ac629331 100644
--- a/onnxruntime/core/providers/coreml/builders/helper.cc
+++ b/onnxruntime/core/providers/coreml/builders/helper.cc
@@ -24,11 +24,12 @@ namespace coreml {
OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer,
int32_t coreml_version,
- uint32_t coreml_flags) {
+ bool only_allow_static_input_shapes,
+ bool create_mlprogram) {
return OpBuilderInputParams{graph_viewer,
coreml_version,
- (coreml_flags & COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES) != 0,
- (coreml_flags & COREML_FLAG_CREATE_MLPROGRAM) != 0};
+ only_allow_static_input_shapes,
+ create_mlprogram};
}
const IOpBuilder* GetOpBuilder(const Node& node) {
@@ -133,13 +134,13 @@ bool CheckIsConstantInitializer(const NodeArg& node_arg, const GraphViewer& grap
return true;
}
-bool HasNeuralEngine(const logging::Logger& logger) {
+bool HasNeuralEngine() {
bool has_neural_engine = false;
#ifdef __APPLE__
struct utsname system_info;
uname(&system_info);
- LOGS(logger, VERBOSE) << "Current Apple hardware info: " << system_info.machine;
+ LOGS_DEFAULT(VERBOSE) << "Current Apple hardware info: " << system_info.machine;
#if TARGET_OS_IPHONE
// utsname.machine has device identifier. For example, identifier for iPhone Xs is "iPhone11,2".
@@ -163,7 +164,7 @@ bool HasNeuralEngine(const logging::Logger& logger) {
#else
// In this case, we are running the EP on non-apple platform, which means we are running the model
// conversion with CoreML EP enabled, for this we always assume the target system has Neural Engine
- LOGS(logger, INFO) << "HasNeuralEngine running on non-Apple hardware. "
+ LOGS_DEFAULT(INFO) << "HasNeuralEngine running on non-Apple hardware. "
"Returning true to enable model conversion and local testing of CoreML EP implementation. "
"No CoreML model will be compiled or run.";
has_neural_engine = true;
diff --git a/onnxruntime/core/providers/coreml/builders/helper.h b/onnxruntime/core/providers/coreml/builders/helper.h
index 0acaa0dd8a..ae7f3bdbc3 100644
--- a/onnxruntime/core/providers/coreml/builders/helper.h
+++ b/onnxruntime/core/providers/coreml/builders/helper.h
@@ -25,7 +25,8 @@ namespace coreml {
OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer,
int32_t coreml_version,
- uint32_t coreml_flags);
+ bool only_allow_static_input_shapes,
+ bool create_mlprogram);
const IOpBuilder* GetOpBuilder(const Node& node);
@@ -45,7 +46,7 @@ bool CheckIsConstantInitializer(const NodeArg& node_arg, const GraphViewer& grap
// CoreML is more efficient running using Apple Neural Engine
// This is to detect if the current system has Apple Neural Engine
-bool HasNeuralEngine(const logging::Logger& logger);
+bool HasNeuralEngine();
} // namespace coreml
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc
index f12e4dab5b..2a02c1f412 100644
--- a/onnxruntime/core/providers/coreml/builders/model_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc
@@ -8,6 +8,7 @@
#include "core/platform/env.h"
#include "core/providers/common.h"
#include "core/providers/coreml/builders/model_builder.h"
+#include "core/providers/coreml/coreml_execution_provider.h"
#include "core/providers/coreml/builders/helper.h"
#include "core/providers/coreml/builders/op_builder_factory.h"
#include "core/providers/coreml/builders/impl/builder_utils.h"
@@ -401,14 +402,14 @@ std::string GetModelOutputPath(bool create_ml_program) {
} // namespace
ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
- int32_t coreml_version, uint32_t coreml_flags,
+ int32_t coreml_version, const CoreMLOptions& coreml_options,
std::vector&& onnx_input_names,
std::vector&& onnx_output_names)
: graph_viewer_(graph_viewer),
logger_(logger),
coreml_version_(coreml_version),
- coreml_flags_(coreml_flags),
- create_ml_program_((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0),
+ coreml_compute_unit_(coreml_options.ComputeUnits()),
+ create_ml_program_(coreml_options.CreateMLProgram()),
model_output_path_(GetModelOutputPath(create_ml_program_)),
onnx_input_names_(std::move(onnx_input_names)),
onnx_output_names_(std::move(onnx_output_names)),
@@ -988,7 +989,7 @@ Status ModelBuilder::LoadModel(std::unique_ptr& model) {
get_sanitized_io_info(std::move(input_output_info_)),
std::move(scalar_outputs_),
std::move(int64_outputs_),
- logger_, coreml_flags_);
+ logger_, coreml_compute_unit_);
} else
#endif
{
@@ -998,7 +999,7 @@ Status ModelBuilder::LoadModel(std::unique_ptr& model) {
std::move(input_output_info_),
std::move(scalar_outputs_),
std::move(int64_outputs_),
- logger_, coreml_flags_);
+ logger_, coreml_compute_unit_);
}
return model->LoadModel(); // load using CoreML API, including compilation
@@ -1048,11 +1049,11 @@ std::string_view ModelBuilder::AddConstant(std::string_view op_type, std::string
#endif
// static
Status ModelBuilder::Build(const GraphViewer& graph_viewer, const logging::Logger& logger,
- int32_t coreml_version, uint32_t coreml_flags,
+ int32_t coreml_version, const CoreMLOptions& coreml_options,
std::vector&& onnx_input_names,
std::vector&& onnx_output_names,
std::unique_ptr& model) {
- ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_flags,
+ ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_options,
std::move(onnx_input_names), std::move(onnx_output_names));
ORT_RETURN_IF_ERROR(builder.CreateModel());
diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h
index c566dbe160..af47869f7e 100644
--- a/onnxruntime/core/providers/coreml/builders/model_builder.h
+++ b/onnxruntime/core/providers/coreml/builders/model_builder.h
@@ -22,6 +22,8 @@ class StorageWriter;
#endif
namespace onnxruntime {
+class CoreMLOptions;
+
namespace coreml {
class IOpBuilder;
@@ -29,14 +31,14 @@ class IOpBuilder;
class ModelBuilder {
private:
ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
- int32_t coreml_version, uint32_t coreml_flags,
+ int32_t coreml_version, const CoreMLOptions& coreml_options,
std::vector&& onnx_input_names,
std::vector&& onnx_output_names);
public:
// Create the CoreML model, serialize to disk, load and compile using the CoreML API and return in `model`
static Status Build(const GraphViewer& graph_viewer, const logging::Logger& logger,
- int32_t coreml_version, uint32_t coreml_flags,
+ int32_t coreml_version, const CoreMLOptions& coreml_options,
std::vector&& onnx_input_names,
std::vector&& onnx_output_names,
std::unique_ptr& model);
@@ -216,7 +218,7 @@ class ModelBuilder {
const GraphViewer& graph_viewer_;
const logging::Logger& logger_;
const int32_t coreml_version_;
- const uint32_t coreml_flags_;
+ const uint32_t coreml_compute_unit_;
const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old)
const std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel
diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc
index f7afbb2f98..5a2867e552 100644
--- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc
+++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc
@@ -23,35 +23,14 @@ namespace onnxruntime {
constexpr const char* COREML = "CoreML";
-CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags)
+CoreMLExecutionProvider::CoreMLExecutionProvider(const CoreMLOptions& options)
: IExecutionProvider{onnxruntime::kCoreMLExecutionProvider},
- coreml_flags_(coreml_flags),
+ coreml_options_(options),
coreml_version_(coreml::util::CoreMLVersion()) {
LOGS_DEFAULT(VERBOSE) << "CoreML version: " << coreml_version_;
if (coreml_version_ < MINIMUM_COREML_VERSION) {
- LOGS_DEFAULT(ERROR) << "CoreML EP is not supported on this platform.";
+ ORT_THROW("CoreML EP is not supported on this platform.");
}
-
- // check if only one flag is set
- if ((coreml_flags & COREML_FLAG_USE_CPU_ONLY) && (coreml_flags & COREML_FLAG_USE_CPU_AND_GPU)) {
- // multiple device options selected
- ORT_THROW(
- "Multiple device options selected, you should use at most one of the following options:"
- "COREML_FLAG_USE_CPU_ONLY or COREML_FLAG_USE_CPU_AND_GPU or not set");
- }
-
-#if defined(COREML_ENABLE_MLPROGRAM)
- if (coreml_version_ < MINIMUM_COREML_MLPROGRAM_VERSION &&
- (coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0) {
- LOGS_DEFAULT(WARNING) << "ML Program is not supported on this OS version. Falling back to NeuralNetwork.";
- coreml_flags_ ^= COREML_FLAG_CREATE_MLPROGRAM;
- }
-#else
- if ((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0) {
- LOGS_DEFAULT(WARNING) << "ML Program is not supported in this build. Falling back to NeuralNetwork.";
- coreml_flags_ ^= COREML_FLAG_CREATE_MLPROGRAM;
- }
-#endif
}
CoreMLExecutionProvider::~CoreMLExecutionProvider() {}
@@ -61,26 +40,17 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
const IKernelLookup& /*kernel_lookup*/) const {
std::vector> result;
- if (coreml_version_ < MINIMUM_COREML_VERSION) {
- return result;
- }
-
const auto& logger = *GetLogger();
// We do not run CoreML EP on subgraph, instead we cover this in the control flow nodes
// TODO investigate whether we want to support subgraph using CoreML EP. May simply require processing the
// implicit inputs of the control flow node that contains the subgraph as inputs to the CoreML model we generate.
- if (graph_viewer.IsSubgraph() && !(coreml_flags_ & COREML_FLAG_ENABLE_ON_SUBGRAPH)) {
+ if (graph_viewer.IsSubgraph() && !coreml_options_.EnableOnSubgraph()) {
return result;
}
- const bool has_neural_engine = coreml::HasNeuralEngine(logger);
- if ((coreml_flags_ & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) && !has_neural_engine) {
- LOGS(logger, WARNING) << "The current system does not have Apple Neural Engine. CoreML EP will not be used.";
- return result;
- }
-
- const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_version_, coreml_flags_);
+ const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_version_,
+ coreml_options_.RequireStaticShape(), coreml_options_.CreateMLProgram());
const auto supported_nodes = coreml::GetSupportedNodes(graph_viewer, builder_params, logger);
const auto gen_metadef_name =
@@ -143,7 +113,7 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
    std::vector<std::string> onnx_output_names = get_names(fused_node.OutputDefs());
const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
- ORT_RETURN_IF_ERROR(coreml::ModelBuilder::Build(graph_viewer, *GetLogger(), coreml_version_, coreml_flags_,
+ ORT_RETURN_IF_ERROR(coreml::ModelBuilder::Build(graph_viewer, *GetLogger(), coreml_version_, coreml_options_,
std::move(onnx_input_names), std::move(onnx_output_names),
coreml_model));
}
diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.h b/onnxruntime/core/providers/coreml/coreml_execution_provider.h
index 24a001280e..650d81a4fe 100644
--- a/onnxruntime/core/providers/coreml/coreml_execution_provider.h
+++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.h
@@ -3,7 +3,7 @@
#pragma once
-#include "core/common/inlined_containers.h"
+#include "core/providers/coreml/coreml_options.h"
#include "core/framework/execution_provider.h"
#include "core/framework/model_metadef_id_generator.h"
@@ -14,7 +14,7 @@ class Model;
class CoreMLExecutionProvider : public IExecutionProvider {
public:
- CoreMLExecutionProvider(uint32_t coreml_flags);
+ CoreMLExecutionProvider(const CoreMLOptions& options);
virtual ~CoreMLExecutionProvider();
std::vector>
@@ -29,7 +29,7 @@ class CoreMLExecutionProvider : public IExecutionProvider {
private:
// The bit flags which define bool options for COREML EP, bits are defined as
// COREMLFlags in include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
- uint32_t coreml_flags_;
+ CoreMLOptions coreml_options_;
const int32_t coreml_version_;
ModelMetadefIdGenerator metadef_id_generator_;
diff --git a/onnxruntime/core/providers/coreml/coreml_options.cc b/onnxruntime/core/providers/coreml/coreml_options.cc
new file mode 100644
index 0000000000..df78f74383
--- /dev/null
+++ b/onnxruntime/core/providers/coreml/coreml_options.cc
@@ -0,0 +1,96 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/coreml/coreml_execution_provider.h"
+#include "core/providers/coreml/coreml_provider_factory.h" // defines flags
+#include "core/providers/coreml/model/host_utils.h"
+#include "core/providers/coreml/builders/helper.h"
+
+namespace onnxruntime {
+
+CoreMLOptions::CoreMLOptions(uint32_t coreml_flags) {
+  // Validate the flags and populate the members (this logic moved here from the CoreML EP constructor).
+ require_static_shape_ = (coreml_flags & COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES) != 0;
+ create_mlprogram_ = (coreml_flags & COREML_FLAG_CREATE_MLPROGRAM) != 0;
+ enable_on_subgraph_ = (coreml_flags & COREML_FLAG_ENABLE_ON_SUBGRAPH) != 0;
+
+#if defined(COREML_ENABLE_MLPROGRAM)
+ if (coreml::util::CoreMLVersion() < MINIMUM_COREML_MLPROGRAM_VERSION && create_mlprogram_ != 0) {
+ LOGS_DEFAULT(WARNING) << "ML Program is not supported on this OS version. Falling back to NeuralNetwork.";
+ create_mlprogram_ = false;
+ }
+#else
+ if (create_mlprogram_ != 0) {
+ LOGS_DEFAULT(WARNING) << "ML Program is not supported in this build. Falling back to NeuralNetwork.";
+ create_mlprogram_ = false;
+ }
+#endif
+
+ compute_units_ = 0; // 0 for all
+
+ if (coreml_flags & COREML_FLAG_USE_CPU_ONLY) {
+ compute_units_ |= COREML_FLAG_USE_CPU_ONLY;
+ }
+ if (coreml_flags & COREML_FLAG_USE_CPU_AND_GPU) {
+ compute_units_ |= COREML_FLAG_USE_CPU_AND_GPU;
+ }
+ if (coreml_flags & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) {
+ compute_units_ |= COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE;
+ }
+
+ // assure only one device option is selected
+ if (compute_units_ & (compute_units_ - 1)) {
+ // multiple device options selected
+ ORT_THROW(
+ "Multiple device options selected, you should use at most one of the following options:"
+ "[COREML_FLAG_USE_CPU_ONLY, COREML_FLAG_USE_CPU_AND_GPU, COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE]");
+ }
+
+ const bool has_neural_engine = coreml::HasNeuralEngine();
+ if (ComputeUnits(COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) && !has_neural_engine) {
+ ORT_THROW("The current system does not have Apple Neural Engine.");
+ }
+}
+
+void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& options) {
+  const std::unordered_map<std::string, COREMLFlags> available_computeunits_options = {
+ {"CPUAndNeuralEngine", COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE},
+ {"CPUAndGPU", COREML_FLAG_USE_CPU_AND_GPU},
+ {"CPUOnly", COREML_FLAG_USE_CPU_ONLY},
+ {"ALL", COREML_FLAG_USE_NONE},
+ };
+  const std::unordered_map<std::string, COREMLFlags> available_modelformat_options = {
+ {"MLProgram", COREML_FLAG_CREATE_MLPROGRAM},
+ {"NeuralNetwork", COREML_FLAG_USE_NONE},
+ };
+  std::unordered_set<std::string> valid_options = {
+ kCoremlProviderOption_MLComputeUnits,
+ kCoremlProviderOption_ModelFormat,
+ kCoremlProviderOption_RequireStaticInputShapes,
+ kCoremlProviderOption_EnableOnSubgraphs,
+ };
+ // Validate the options
+ for (const auto& option : options) {
+ if (valid_options.find(option.first) == valid_options.end()) {
+ ORT_THROW("Unknown option: ", option.first);
+ }
+ if (kCoremlProviderOption_MLComputeUnits == option.first) {
+ if (available_computeunits_options.find(option.second) == available_computeunits_options.end()) {
+ ORT_THROW("Invalid value for option `", option.first, "`: ", option.second);
+ } else {
+ compute_units_ = available_computeunits_options.at(option.second);
+ }
+ } else if (kCoremlProviderOption_ModelFormat == option.first) {
+ if (available_modelformat_options.find(option.second) == available_modelformat_options.end()) {
+ ORT_THROW("Invalid value for option ", option.first, ": ", option.second);
+ } else {
+ create_mlprogram_ = available_modelformat_options.at(option.second) & COREML_FLAG_CREATE_MLPROGRAM;
+ }
+ } else if (kCoremlProviderOption_RequireStaticInputShapes == option.first) {
+ require_static_shape_ = option.second == "1";
+ } else if (kCoremlProviderOption_EnableOnSubgraphs == option.first) {
+ enable_on_subgraph_ = option.second == "1";
+ }
+ }
+}
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/coreml/coreml_options.h b/onnxruntime/core/providers/coreml/coreml_options.h
new file mode 100644
index 0000000000..8bb748fcd6
--- /dev/null
+++ b/onnxruntime/core/providers/coreml/coreml_options.h
@@ -0,0 +1,32 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/inlined_containers.h"
+#include "core/framework/execution_provider.h"
+
+namespace onnxruntime {
+
+class CoreMLOptions {
+ private:
+ bool require_static_shape_{false};
+ bool create_mlprogram_{false};
+ bool enable_on_subgraph_{false};
+ uint32_t compute_units_{0};
+
+ public:
+ explicit CoreMLOptions(uint32_t coreml_flags);
+
+ CoreMLOptions(const ProviderOptions& options) {
+ ValidateAndParseProviderOption(options);
+ }
+ bool RequireStaticShape() const { return require_static_shape_; }
+ bool CreateMLProgram() const { return create_mlprogram_; }
+ bool EnableOnSubgraph() const { return enable_on_subgraph_; }
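+  // Returns the stored COREML_FLAG_* compute-unit bits masked by `specific_flag`, e.g.
+  // ComputeUnits(COREML_FLAG_USE_CPU_ONLY) is non-zero only when CPU-only execution was requested.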
+ uint32_t ComputeUnits(uint32_t specific_flag = 0xffffffff) const { return compute_units_ & specific_flag; }
+
+ private:
+ void ValidateAndParseProviderOption(const ProviderOptions& options);
+};
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/coreml/coreml_provider_factory.cc b/onnxruntime/core/providers/coreml/coreml_provider_factory.cc
index fcdf37c446..bc8702d329 100644
--- a/onnxruntime/core/providers/coreml/coreml_provider_factory.cc
+++ b/onnxruntime/core/providers/coreml/coreml_provider_factory.cc
@@ -9,21 +9,28 @@
using namespace onnxruntime;
namespace onnxruntime {
+
struct CoreMLProviderFactory : IExecutionProviderFactory {
- CoreMLProviderFactory(uint32_t coreml_flags)
- : coreml_flags_(coreml_flags) {}
+ CoreMLProviderFactory(const CoreMLOptions& options)
+ : options_(options) {}
~CoreMLProviderFactory() override {}
std::unique_ptr<IExecutionProvider> CreateProvider() override;
- uint32_t coreml_flags_;
+ CoreMLOptions options_;
};
std::unique_ptr<IExecutionProvider> CoreMLProviderFactory::CreateProvider() {
-  return std::make_unique<CoreMLExecutionProvider>(coreml_flags_);
+  return std::make_unique<CoreMLExecutionProvider>(options_);
}
std::shared_ptr<IExecutionProviderFactory> CoreMLProviderFactoryCreator::Create(uint32_t coreml_flags) {
-  return std::make_shared<CoreMLProviderFactory>(coreml_flags);
+  CoreMLOptions coreml_options(coreml_flags);
+  return std::make_shared<CoreMLProviderFactory>(coreml_options);
+}
+
+std::shared_ptr<IExecutionProviderFactory> CoreMLProviderFactoryCreator::Create(const ProviderOptions& options) {
+  CoreMLOptions coreml_options(options);
+  return std::make_shared<CoreMLProviderFactory>(coreml_options);
}
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/coreml/coreml_provider_factory_creator.h b/onnxruntime/core/providers/coreml/coreml_provider_factory_creator.h
index ba701724c4..93ec2af506 100644
--- a/onnxruntime/core/providers/coreml/coreml_provider_factory_creator.h
+++ b/onnxruntime/core/providers/coreml/coreml_provider_factory_creator.h
@@ -5,10 +5,12 @@
#include <memory>
+#include "core/framework/provider_options.h"
#include "core/providers/providers.h"
namespace onnxruntime {
struct CoreMLProviderFactoryCreator {
static std::shared_ptr<IExecutionProviderFactory> Create(uint32_t coreml_flags);
+  static std::shared_ptr<IExecutionProviderFactory> Create(const ProviderOptions& options);
};
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h
index 7fdd6b25bc..68ecbe5fb8 100644
--- a/onnxruntime/core/providers/coreml/model/model.h
+++ b/onnxruntime/core/providers/coreml/model/model.h
@@ -53,7 +53,7 @@ class Model {
std::unordered_map&& input_output_info,
std::unordered_set&& scalar_outputs,
std::unordered_set&& int64_outputs,
- const logging::Logger& logger, uint32_t coreml_flags);
+ const logging::Logger& logger, uint32_t coreml_compute_unit);
~Model();
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Model);
diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm
index ff32c52f94..c8edb64ff5 100644
--- a/onnxruntime/core/providers/coreml/model/model.mm
+++ b/onnxruntime/core/providers/coreml/model/model.mm
@@ -320,13 +320,13 @@ class Execution {
NSString* coreml_model_path_{nil};
NSString* compiled_model_path_{nil};
const logging::Logger& logger_;
- uint32_t coreml_flags_{0};
+ uint32_t coreml_compute_unit_{0};
MLModel* model_{nil};
};
-Execution::Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags)
+Execution::Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_compute_unit)
: logger_(logger),
- coreml_flags_(coreml_flags) {
+ coreml_compute_unit_(coreml_compute_unit) {
@autoreleasepool {
coreml_model_path_ = util::Utf8StringToNSString(path.c_str());
}
@@ -396,10 +396,12 @@ Status Execution::LoadModel() {
MLModelConfiguration* config = [[MLModelConfiguration alloc] init];
- if (coreml_flags_ & COREML_FLAG_USE_CPU_ONLY) {
+ if (coreml_compute_unit_ & COREML_FLAG_USE_CPU_ONLY) {
config.computeUnits = MLComputeUnitsCPUOnly;
- } else if (coreml_flags_ & COREML_FLAG_USE_CPU_AND_GPU) {
+ } else if (coreml_compute_unit_ & COREML_FLAG_USE_CPU_AND_GPU) {
config.computeUnits = MLComputeUnitsCPUAndGPU;
+ } else if (coreml_compute_unit_ & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) {
+ config.computeUnits = MLComputeUnitsCPUAndNeuralEngine; // Apple Neural Engine
} else {
config.computeUnits = MLComputeUnitsAll;
}
diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc
index 8bea347c85..7fb518cdc0 100644
--- a/onnxruntime/core/session/provider_registration.cc
+++ b/onnxruntime/core/session/provider_registration.cc
@@ -155,11 +155,21 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
status = create_not_supported_status();
#endif
} else if (strcmp(provider_name, "VitisAI") == 0) {
+#ifdef USE_VITISAI
status = OrtApis::SessionOptionsAppendExecutionProvider_VitisAI(options, provider_options_keys, provider_options_values, num_keys);
+#else
+ status = create_not_supported_status();
+#endif
+ } else if (strcmp(provider_name, "CoreML") == 0) {
+#if defined(USE_COREML)
+ options->provider_factories.push_back(CoreMLProviderFactoryCreator::Create(provider_options));
+#else
+ status = create_not_supported_status();
+#endif
} else {
ORT_UNUSED_PARAMETER(options);
status = OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
- "Unknown provider name. Currently supported values are 'OPENVINO', 'SNPE', 'XNNPACK', 'QNN', 'WEBNN' and 'AZURE'");
+                                   "Unknown provider name. Currently supported values are 'OPENVINO', 'SNPE', 'XNNPACK', 'QNN', 'WEBNN', 'CoreML' and 'AZURE'");
}
return status;
diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc
index 1319e8f6fe..dcd021494c 100644
--- a/onnxruntime/python/onnxruntime_pybind_schema.cc
+++ b/onnxruntime/python/onnxruntime_pybind_schema.cc
@@ -73,7 +73,7 @@ void addGlobalSchemaFunctions(pybind11::module& m) {
onnxruntime::RknpuProviderFactoryCreator::Create(),
#endif
#ifdef USE_COREML
- onnxruntime::CoreMLProviderFactoryCreator::Create(0),
+ onnxruntime::CoreMLProviderFactoryCreator::Create(ProviderOptions{}),
#endif
#ifdef USE_XNNPACK
onnxruntime::XnnpackProviderFactoryCreator::Create(ProviderOptions{}, nullptr),
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 54accf7ed8..c20d0e64bf 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -1212,6 +1212,9 @@ std::unique_ptr CreateExecutionProviderInstance(
if (flags_str.find("COREML_FLAG_CREATE_MLPROGRAM") != std::string::npos) {
coreml_flags |= COREMLFlags::COREML_FLAG_CREATE_MLPROGRAM;
}
+ } else {
+ // read from provider_options
+ return onnxruntime::CoreMLProviderFactoryCreator::Create(options)->CreateProvider();
}
}
diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc
index 93a1bf9f30..ddc453f84f 100644
--- a/onnxruntime/test/onnx/main.cc
+++ b/onnxruntime/test/onnx/main.cc
@@ -631,7 +631,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
}
if (enable_coreml) {
#ifdef USE_COREML
- Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sf, 0));
+ sf.AppendExecutionProvider("CoreML", {});
#else
fprintf(stderr, "CoreML is not supported in this build");
return -1;
diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc
index 040355d5e0..e406405464 100644
--- a/onnxruntime/test/perftest/command_args_parser.cc
+++ b/onnxruntime/test/perftest/command_args_parser.cc
@@ -24,6 +24,7 @@
#include
#include "test_configuration.h"
+#include "strings_helper.h"
namespace onnxruntime {
namespace perftest {
@@ -129,8 +130,11 @@ namespace perftest {
"\t [NNAPI only] [NNAPI_FLAG_CPU_ONLY]: Using CPU only in NNAPI EP.\n"
"\t [Example] [For NNAPI EP] -e nnapi -i \"NNAPI_FLAG_USE_FP16 NNAPI_FLAG_USE_NCHW NNAPI_FLAG_CPU_DISABLED\"\n"
"\n"
- "\t [CoreML only] [COREML_FLAG_CREATE_MLPROGRAM COREML_FLAG_USE_CPU_ONLY COREML_FLAG_USE_CPU_AND_GPU]: Create an ML Program model instead of Neural Network.\n"
- "\t [Example] [For CoreML EP] -e coreml -i \"COREML_FLAG_CREATE_MLPROGRAM\"\n"
+ "\t [CoreML only] [ModelFormat]:[MLProgram, NeuralNetwork] Create an ML Program model or Neural Network. Default is NeuralNetwork.\n"
+ "\t [CoreML only] [MLComputeUnits]:[CPUAndNeuralEngine CPUAndGPU ALL CPUOnly] Specify to limit the backend device used to run the model.\n"
+      "\t    [CoreML only] [RequireStaticInputShapes]:[0 1].\n"
+ "\t [CoreML only] [EnableOnSubgraphs]:[0 1].\n"
+ "\t [Example] [For CoreML EP] -e coreml -i \"ModelFormat|MLProgram MLComputeUnits|CPUAndGPU\"\n"
"\n"
"\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n"
"\t [SNPE only] [priority]: execution priority, options: 'low', 'normal'. \n"
@@ -175,39 +179,6 @@ static bool ParseDimensionOverride(std::basic_string& dim_identifier,
return true;
}
-static bool ParseSessionConfigs(const std::string& configs_string,
- std::unordered_map& session_configs) {
- std::istringstream ss(configs_string);
- std::string token;
-
- while (ss >> token) {
- if (token == "") {
- continue;
- }
-
- std::string_view token_sv(token);
-
- auto pos = token_sv.find("|");
- if (pos == std::string_view::npos || pos == 0 || pos == token_sv.length()) {
- // Error: must use a '|' to separate the key and value for session configuration entries.
- return false;
- }
-
- std::string key(token_sv.substr(0, pos));
- std::string value(token_sv.substr(pos + 1));
-
- auto it = session_configs.find(key);
- if (it != session_configs.end()) {
- // Error: specified duplicate session configuration entry: {key}
- return false;
- }
-
- session_configs.insert(std::make_pair(std::move(key), std::move(value)));
- }
-
- return true;
-}
-
/*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) {
int ch;
while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqznlR:"))) != -1) {
@@ -382,7 +353,13 @@ static bool ParseSessionConfigs(const std::string& configs_string,
test_config.run_config.intra_op_thread_affinities = ToUTF8String(optarg);
break;
case 'C': {
- if (!ParseSessionConfigs(ToUTF8String(optarg), test_config.run_config.session_config_entries)) {
+ ORT_TRY {
+ ParseSessionConfigs(ToUTF8String(optarg), test_config.run_config.session_config_entries);
+ }
+ ORT_CATCH(const std::exception& ex) {
+ ORT_HANDLE_EXCEPTION([&]() {
+ fprintf(stderr, "Error parsing session configuration entries: %s\n", ex.what());
+ });
return false;
}
break;
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc
index 8f2e5282ed..02768b8c08 100644
--- a/onnxruntime/test/perftest/ort_test_session.cc
+++ b/onnxruntime/test/perftest/ort_test_session.cc
@@ -17,6 +17,7 @@
#include
#include "providers.h"
#include "TestCase.h"
+#include "strings_helper.h"
#ifdef USE_OPENVINO
#include "nlohmann/json.hpp"
@@ -58,6 +59,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
Ort::SessionOptions session_options;
provider_name_ = performance_test_config.machine_config.provider_type_name;
+  std::unordered_map<std::string, std::string> provider_options;
if (provider_name_ == onnxruntime::kDnnlExecutionProvider) {
#ifdef USE_DNNL
// Generate provider options
@@ -72,24 +74,10 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
std::string ov_string = performance_test_config.run_config.ep_runtime_config_string;
#endif // defined(_MSC_VER)
int num_threads = 0;
- std::istringstream ss(ov_string);
- std::string token;
- while (ss >> token) {
- if (token == "") {
- continue;
- }
- auto pos = token.find("|");
- if (pos == std::string::npos || pos == 0 || pos == token.length()) {
- ORT_THROW(
- "[ERROR] [OneDNN] Use a '|' to separate the key and value for the "
- "run-time option you are trying to use.\n");
- }
-
- auto key = token.substr(0, pos);
- auto value = token.substr(pos + 1);
-
- if (key == "num_of_threads") {
- std::stringstream sstream(value);
+ ParseSessionConfigs(ov_string, provider_options, {"num_of_threads"});
+ for (const auto& provider_option : provider_options) {
+ if (provider_option.first == "num_of_threads") {
+ std::stringstream sstream(provider_option.second);
sstream >> num_threads;
if (num_threads < 0) {
ORT_THROW(
@@ -97,10 +85,6 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
" set number of threads or use '0' for default\n");
// If the user doesnt define num_threads, auto detect threads later
}
- } else {
- ORT_THROW(
- "[ERROR] [OneDNN] wrong key type entered. "
- "Choose from the following runtime key options that are available for OneDNN. ['num_of_threads']\n");
}
}
dnnl_options.threadpool_args = static_cast(&num_threads);
@@ -144,22 +128,10 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
#else
std::string ov_string = performance_test_config.run_config.ep_runtime_config_string;
#endif
- std::istringstream ss(ov_string);
- std::string token;
- while (ss >> token) {
- if (token == "") {
- continue;
- }
- auto pos = token.find("|");
- if (pos == std::string::npos || pos == 0 || pos == token.length()) {
- ORT_THROW(
- "[ERROR] [CUDA] Use a '|' to separate the key and value for the run-time option you are trying to use.\n");
- }
-
- buffer.emplace_back(token.substr(0, pos));
- option_keys.push_back(buffer.back().c_str());
- buffer.emplace_back(token.substr(pos + 1));
- option_values.push_back(buffer.back().c_str());
+ ParseSessionConfigs(ov_string, provider_options);
+ for (const auto& provider_option : provider_options) {
+ option_keys.push_back(provider_option.first.c_str());
+ option_values.push_back(provider_option.second.c_str());
}
Ort::Status status(api.UpdateCUDAProviderOptions(cuda_options,
@@ -192,24 +164,11 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
#else
std::string ov_string = performance_test_config.run_config.ep_runtime_config_string;
#endif
- std::istringstream ss(ov_string);
- std::string token;
- while (ss >> token) {
- if (token == "") {
- continue;
- }
- auto pos = token.find("|");
- if (pos == std::string::npos || pos == 0 || pos == token.length()) {
- ORT_THROW(
- "[ERROR] [TensorRT] Use a '|' to separate the key and value for the run-time option you are trying to use.\n");
- }
-
- buffer.emplace_back(token.substr(0, pos));
- option_keys.push_back(buffer.back().c_str());
- buffer.emplace_back(token.substr(pos + 1));
- option_values.push_back(buffer.back().c_str());
+ ParseSessionConfigs(ov_string, provider_options);
+ for (const auto& provider_option : provider_options) {
+ option_keys.push_back(provider_option.first.c_str());
+ option_values.push_back(provider_option.second.c_str());
}
-
Ort::Status status(api.UpdateTensorRTProviderOptions(tensorrt_options,
option_keys.data(), option_values.data(), option_keys.size()));
if (!status.IsOK()) {
@@ -239,22 +198,14 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
#else
std::string option_string = performance_test_config.run_config.ep_runtime_config_string;
#endif
- std::istringstream ss(option_string);
- std::string token;
- std::unordered_map qnn_options;
-
- while (ss >> token) {
- if (token == "") {
- continue;
- }
- auto pos = token.find("|");
- if (pos == std::string::npos || pos == 0 || pos == token.length()) {
- ORT_THROW("Use a '|' to separate the key and value for the run-time option you are trying to use.");
- }
-
- std::string key(token.substr(0, pos));
- std::string value(token.substr(pos + 1));
-
+ ParseSessionConfigs(option_string, provider_options,
+ {"backend_path", "profiling_file_path", "profiling_level", "rpc_control_latency",
+ "vtcm_mb", "soc_model", "device_id", "htp_performance_mode", "qnn_saver_path",
+ "htp_graph_finalization_optimization_mode", "qnn_context_priority", "htp_arch",
+ "enable_htp_fp16_precision", "offload_graph_io_quantization"});
+ for (const auto& provider_option : provider_options) {
+ const std::string& key = provider_option.first;
+ const std::string& value = provider_option.second;
if (key == "backend_path" || key == "profiling_file_path") {
if (value.empty()) {
ORT_THROW("Please provide the valid file path.");
@@ -311,16 +262,9 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
std::string str = str_stream.str();
ORT_THROW("Wrong value for ", key, ". select from: ", str);
}
- } else {
- ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path',
-'profiling_level', 'profiling_file_path', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode',
-'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', 'soc_model',
-'htp_arch', 'device_id', 'enable_htp_fp16_precision', 'offload_graph_io_quantization'])");
}
-
- qnn_options[key] = value;
}
- session_options.AppendExecutionProvider("QNN", qnn_options);
+ session_options.AppendExecutionProvider("QNN", provider_options);
#else
ORT_THROW("QNN is not supported in this build\n");
#endif
@@ -331,22 +275,8 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
#else
std::string option_string = performance_test_config.run_config.ep_runtime_config_string;
#endif
- std::istringstream ss(option_string);
- std::string token;
- std::unordered_map snpe_options;
-
- while (ss >> token) {
- if (token == "") {
- continue;
- }
- auto pos = token.find("|");
- if (pos == std::string::npos || pos == 0 || pos == token.length()) {
- ORT_THROW("Use a '|' to separate the key and value for the run-time option you are trying to use.\n");
- }
-
- std::string key(token.substr(0, pos));
- std::string value(token.substr(pos + 1));
-
+ ParseSessionConfigs(option_string, provider_options, {"runtime", "priority", "buffer_type", "enable_init_cache"});
+      for (const auto& provider_option : provider_options) {
+        const std::string& key = provider_option.first;
+        const std::string& value = provider_option.second;
if (key == "runtime") {
std::set supported_runtime = {"CPU", "GPU_FP32", "GPU", "GPU_FLOAT16", "DSP", "AIP_FIXED_TF"};
if (supported_runtime.find(value) == supported_runtime.end()) {
@@ -365,14 +295,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
if (value != "1") {
ORT_THROW("Set to 1 to enable_init_cache.");
}
- } else {
- ORT_THROW("Wrong key type entered. Choose from options: ['runtime', 'priority', 'buffer_type', 'enable_init_cache'] \n");
}
-
- snpe_options[key] = value;
}
- session_options.AppendExecutionProvider("SNPE", snpe_options);
+ session_options.AppendExecutionProvider("SNPE", provider_options);
#else
ORT_THROW("SNPE is not supported in this build\n");
#endif
@@ -416,30 +342,34 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
} else if (provider_name_ == onnxruntime::kCoreMLExecutionProvider) {
#ifdef __APPLE__
#ifdef USE_COREML
- uint32_t coreml_flags = 0;
std::string ov_string = performance_test_config.run_config.ep_runtime_config_string;
- std::istringstream ss(ov_string);
+      static const std::unordered_set<std::string> available_keys = {kCoremlProviderOption_MLComputeUnits,
+ kCoremlProviderOption_ModelFormat,
+ kCoremlProviderOption_RequireStaticInputShapes,
+ kCoremlProviderOption_EnableOnSubgraphs};
+ ParseSessionConfigs(ov_string, provider_options, available_keys);
- std::string key;
- while (ss >> key) {
- if (key == "COREML_FLAG_CREATE_MLPROGRAM") {
- coreml_flags |= COREML_FLAG_CREATE_MLPROGRAM;
- std::cout << "Enabling ML Program.\n";
- } else if (key == "COREML_FLAG_USE_CPU_ONLY") {
- coreml_flags |= COREML_FLAG_USE_CPU_ONLY;
- std::cout << "CoreML enabled COREML_FLAG_USE_CPU_ONLY.\n";
- } else if (key == "COREML_FLAG_USE_CPU_AND_GPU") {
- coreml_flags |= COREML_FLAG_USE_CPU_AND_GPU;
- std::cout << "CoreML enabled COREML_FLAG_USE_CPU_AND_GPU.\n";
- } else if (key.empty()) {
+      std::unordered_map<std::string, std::string> available_options = {
+ {"CPUAndNeuralEngine", "1"},
+ {"CPUAndGPU", "1"},
+ {"CPUOnly", "1"},
+ {"ALL", "1"},
+ };
+ for (const auto& provider_option : provider_options) {
+ if (provider_option.first == kCoremlProviderOption_MLComputeUnits &&
+ available_options.find(provider_option.second) != available_options.end()) {
+ } else if (provider_option.first == kCoremlProviderOption_ModelFormat &&
+ (provider_option.second == "MLProgram" || provider_option.second == "NeuralNetwork")) {
+ } else if (provider_option.first == kCoremlProviderOption_RequireStaticInputShapes &&
+ (provider_option.second == "1" || provider_option.second == "0")) {
+ } else if (provider_option.first == kCoremlProviderOption_EnableOnSubgraphs &&
+ (provider_option.second == "0" || provider_option.second == "1")) {
} else {
- ORT_THROW(
- "[ERROR] [CoreML] wrong key type entered. Choose from the following runtime key options "
- "that are available for CoreML. ['COREML_FLAG_CREATE_MLPROGRAM'] \n");
+ ORT_THROW("Invalid value for option ", provider_option.first, ": ", provider_option.second);
}
}
// COREML_FLAG_CREATE_MLPROGRAM
- Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags));
+ session_options.AppendExecutionProvider("CoreML", provider_options);
#else
ORT_THROW("CoreML is not supported in this build\n");
#endif
@@ -448,34 +378,20 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
#endif
} else if (provider_name_ == onnxruntime::kDmlExecutionProvider) {
#ifdef USE_DML
- std::unordered_map dml_options;
- dml_options["performance_preference"] = "high_performance";
- dml_options["device_filter"] = "gpu";
- dml_options["disable_metacommands"] = "false";
- dml_options["enable_graph_capture"] = "false";
#ifdef _MSC_VER
std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string);
#else
std::string ov_string = performance_test_config.run_config.ep_runtime_config_string;
#endif
- std::istringstream ss(ov_string);
- std::string token;
- while (ss >> token) {
- if (token == "") {
- continue;
- }
- auto pos = token.find("|");
- if (pos == std::string::npos || pos == 0 || pos == token.length()) {
- ORT_THROW("[ERROR] [DML] Use a '|' to separate the key and value for the run-time option you are trying to use.\n");
- }
-
- auto key = token.substr(0, pos);
- auto value = token.substr(pos + 1);
-
+ ParseSessionConfigs(ov_string, provider_options,
+ {"device_filter", "performance_preference", "disable_metacommands",
+ "enable_graph_capture", "enable_graph_serialization"});
+ for (const auto& provider_option : provider_options) {
+ const std::string& key = provider_option.first;
+ const std::string& value = provider_option.second;
if (key == "device_filter") {
std::set ov_supported_device_types = {"gpu", "npu"};
if (ov_supported_device_types.find(value) != ov_supported_device_types.end()) {
- dml_options[key] = value;
} else {
ORT_THROW(
"[ERROR] [DML] You have selected a wrong configuration value for the key 'device_filter'. "
@@ -484,7 +400,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
} else if (key == "performance_preference") {
std::set ov_supported_values = {"default", "high_performance", "minimal_power"};
if (ov_supported_values.find(value) != ov_supported_values.end()) {
- dml_options[key] = value;
} else {
ORT_THROW(
"[ERROR] [DML] You have selected a wrong configuration value for the key 'performance_preference'. "
@@ -493,7 +408,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
} else if (key == "disable_metacommands") {
std::set ov_supported_values = {"true", "True", "false", "False"};
if (ov_supported_values.find(value) != ov_supported_values.end()) {
- dml_options[key] = value;
} else {
ORT_THROW(
"[ERROR] [DML] You have selected a wrong value for the key 'disable_metacommands'. "
@@ -502,7 +416,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
} else if (key == "enable_graph_capture") {
std::set ov_supported_values = {"true", "True", "false", "False"};
if (ov_supported_values.find(value) != ov_supported_values.end()) {
- dml_options[key] = value;
} else {
ORT_THROW(
"[ERROR] [DML] You have selected a wrong value for the key 'enable_graph_capture'. "
@@ -519,7 +432,19 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
}
}
}
- session_options.AppendExecutionProvider("DML", dml_options);
+ if (provider_options.find("performance_preference") == provider_options.end()) {
+ provider_options["performance_preference"] = "high_performance";
+ }
+ if (provider_options.find("device_filter") == provider_options.end()) {
+ provider_options["device_filter"] = "gpu";
+ }
+ if (provider_options.find("disable_metacommands") == provider_options.end()) {
+ provider_options["disable_metacommands"] = "false";
+ }
+ if (provider_options.find("enable_graph_capture") == provider_options.end()) {
+ provider_options["enable_graph_capture"] = "false";
+ }
+ session_options.AppendExecutionProvider("DML", provider_options);
#else
ORT_THROW("DML is not supported in this build\n");
#endif
@@ -530,21 +455,9 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
#else
std::string ov_string = performance_test_config.run_config.ep_runtime_config_string;
#endif // defined(_MSC_VER)
- std::istringstream ss(ov_string);
- std::string token;
bool enable_fast_math = false;
- while (ss >> token) {
- if (token == "") {
- continue;
- }
- auto pos = token.find("|");
- if (pos == std::string::npos || pos == 0 || pos == token.length()) {
- ORT_THROW("[ERROR] [ACL] Use a '|' to separate the key and value for the run-time option you are trying to use.\n");
- }
-
- auto key = token.substr(0, pos);
- auto value = token.substr(pos + 1);
-
+ ParseSessionConfigs(ov_string, provider_options, {"enable_fast_math"});
+      for (const auto& provider_option : provider_options) {
+        const std::string& key = provider_option.first;
+        const std::string& value = provider_option.second;
if (key == "enable_fast_math") {
std::set ov_supported_values = {"true", "True", "false", "False"};
if (ov_supported_values.find(value) != ov_supported_values.end()) {
@@ -554,9 +467,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
"[ERROR] [ACL] You have selcted an invalid value for the key 'enable_fast_math'. "
"Select from 'true' or 'false' \n");
}
- } else {
- ORT_THROW(
- "[ERROR] [ACL] Unrecognized option: ", key);
}
}
Ort::ThrowOnError(
@@ -612,24 +522,9 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
#else
std::string option_string = performance_test_config.run_config.ep_runtime_config_string;
#endif
- std::istringstream ss(option_string);
- std::string token;
- std::unordered_map vitisai_session_options;
+ ParseSessionConfigs(option_string, provider_options);
- while (ss >> token) {
- if (token == "") {
- continue;
- }
- auto pos = token.find("|");
- if (pos == std::string::npos || pos == 0 || pos == token.length()) {
- ORT_THROW("[ERROR] [VitisAI] Use a '|' to separate the key and value for the run-time option you are trying to use.\n");
- }
-
- std::string key(token.substr(0, pos));
- std::string value(token.substr(pos + 1));
- vitisai_session_options[key] = value;
- }
- session_options.AppendExecutionProvider_VitisAI(vitisai_session_options);
+ session_options.AppendExecutionProvider_VitisAI(provider_options);
#else
ORT_THROW("VitisAI is not supported in this build\n");
#endif
diff --git a/onnxruntime/test/perftest/strings_helper.cc b/onnxruntime/test/perftest/strings_helper.cc
new file mode 100644
index 0000000000..e09c8fac70
--- /dev/null
+++ b/onnxruntime/test/perftest/strings_helper.cc
@@ -0,0 +1,57 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) 2023 NVIDIA Corporation.
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+// Licensed under the MIT License.
+
+#include <sstream>
+#include <utility>
+
+#include "strings_helper.h"
+#include "core/common/common.h"
+
+namespace onnxruntime {
+namespace perftest {
+
+void ParseSessionConfigs(const std::string& configs_string,
+                         std::unordered_map<std::string, std::string>& session_configs,
+                         const std::unordered_set<std::string>& available_keys) {
+ std::istringstream ss(configs_string);
+ std::string token;
+
+ while (ss >> token) {
+ if (token == "") {
+ continue;
+ }
+
+ std::string_view token_sv(token);
+
+ auto pos = token_sv.find("|");
+ if (pos == std::string_view::npos || pos == 0 || pos == token_sv.length()) {
+ ORT_THROW("Use a '|' to separate the key and value for the run-time option you are trying to use.\n");
+ }
+
+ std::string key(token_sv.substr(0, pos));
+ std::string value(token_sv.substr(pos + 1));
+
+ if (available_keys.empty() == false && available_keys.count(key) == 0) {
+ // Error: unknown option: {key}
+ std::string available_keys_str;
+ for (const auto& av_key : available_keys) {
+ available_keys_str += av_key;
+ available_keys_str += ", ";
+ }
+      ORT_THROW("[ERROR] wrong key type entered: `", key,
+                "`. The following runtime key options are available: [", available_keys_str, "]");
+ }
+
+ auto it = session_configs.find(key);
+ if (it != session_configs.end()) {
+ // Error: specified duplicate session configuration entry: {key}
+ ORT_THROW("Specified duplicate session configuration entry: ", key);
+ }
+
+ session_configs.insert(std::make_pair(std::move(key), std::move(value)));
+ }
+}
+} // namespace perftest
+} // namespace onnxruntime
diff --git a/onnxruntime/test/perftest/strings_helper.h b/onnxruntime/test/perftest/strings_helper.h
new file mode 100644
index 0000000000..0d6c56709f
--- /dev/null
+++ b/onnxruntime/test/perftest/strings_helper.h
@@ -0,0 +1,16 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) 2023 NVIDIA Corporation.
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+// Licensed under the MIT License.
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+
+namespace onnxruntime {
+namespace perftest {
+
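+// Splits `configs_string` on whitespace into "key|value" tokens and inserts them into `session_configs`,
+// e.g. the CoreML perf-test example "ModelFormat|MLProgram MLComputeUnits|CPUAndGPU" yields two entries.
+// A token without a '|' separator, an empty key, a duplicate key, or (when `available_keys` is non-empty)
+// a key outside `available_keys` is reported via ORT_THROW.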
+void ParseSessionConfigs(const std::string& configs_string,
+                         std::unordered_map<std::string, std::string>& session_configs,
+                         const std::unordered_set<std::string>& available_keys = {});
+} // namespace perftest
+} // namespace onnxruntime
diff --git a/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm b/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm
index 32b4b32e29..fa95c1fc52 100644
--- a/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm
+++ b/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm
@@ -35,8 +35,9 @@ void testSigmoid(const char* modelPath, bool useCoreML = false, bool useWebGPU =
#if COREML_EP_AVAILABLE
if (useCoreML) {
- const uint32_t flags = COREML_FLAG_USE_CPU_ONLY;
- Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, flags));
+    std::unordered_map<std::string, std::string> provider_options = {
+ {kCoremlProviderOption_MLComputeUnits, "CPUOnly"}};
+ session_options.AppendExecutionProvider("CoreML", provider_options);
}
#else
(void)useCoreML;
diff --git a/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm b/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm
index 86001b6cb5..b53a4a2df0 100644
--- a/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm
+++ b/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm
@@ -35,8 +35,9 @@ void testSigmoid(const char* modelPath, bool useCoreML = false, bool useWebGPU =
#if COREML_EP_AVAILABLE
if (useCoreML) {
- const uint32_t flags = COREML_FLAG_USE_CPU_ONLY;
- Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, flags));
+    std::unordered_map<std::string, std::string> provider_options = {
+ {kCoremlProviderOption_MLComputeUnits, "CPUOnly"}};
+ session_options.AppendExecutionProvider("CoreML", provider_options);
}
#else
(void)useCoreML;
diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc
index de647d9e3a..a8480e7416 100644
--- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc
+++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc
@@ -4,7 +4,7 @@
#include "core/common/logging/logging.h"
#include "core/graph/graph.h"
#include "core/graph/graph_viewer.h"
-#include "core/providers/coreml/coreml_execution_provider.h"
+#include "core/providers/coreml/coreml_provider_factory_creator.h"
#include "core/providers/coreml/coreml_provider_factory.h"
#include "core/session/inference_session.h"
#include "test/common/tensor_op_test_utils.h"
@@ -30,11 +30,11 @@ using namespace ::onnxruntime::logging;
namespace onnxruntime {
namespace test {
-// We want to run UT on CPU only to get output value without losing precision to pass the verification
-static constexpr uint32_t s_coreml_flags = COREML_FLAG_USE_CPU_ONLY;
-
-static std::unique_ptr<IExecutionProvider> MakeCoreMLExecutionProvider(uint32_t flags = s_coreml_flags) {
-  return std::make_unique<CoreMLExecutionProvider>(flags);
+static std::unique_ptr<IExecutionProvider> MakeCoreMLExecutionProvider(
+    std::string ModelFormat = "NeuralNetwork", std::string ComputeUnits = "CPUOnly") {
+  std::unordered_map<std::string, std::string> provider_options = {{kCoremlProviderOption_MLComputeUnits, ComputeUnits},
+                                                                   {kCoremlProviderOption_ModelFormat, ModelFormat}};
+ return CoreMLProviderFactoryCreator::Create(provider_options)->CreateProvider();
}
#if !defined(ORT_MINIMAL_BUILD)
@@ -128,7 +128,7 @@ TEST(CoreMLExecutionProviderTest, ArgMaxCastTest) {
feeds,
verification_params);
RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(),
- MakeCoreMLExecutionProvider(COREML_FLAG_CREATE_MLPROGRAM),
+ MakeCoreMLExecutionProvider("MLProgram"),
feeds,
verification_params);
#else
@@ -170,7 +170,7 @@ TEST(CoreMLExecutionProviderTest, ArgMaxUnsupportedCastTest) {
verification_params);
RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(),
- MakeCoreMLExecutionProvider(COREML_FLAG_CREATE_MLPROGRAM),
+ MakeCoreMLExecutionProvider("MLProgram"),
feeds,
verification_params);
#else
diff --git a/onnxruntime/test/providers/coreml/dynamic_input_test.cc b/onnxruntime/test/providers/coreml/dynamic_input_test.cc
index c91ef23650..8294f65745 100644
--- a/onnxruntime/test/providers/coreml/dynamic_input_test.cc
+++ b/onnxruntime/test/providers/coreml/dynamic_input_test.cc
@@ -7,6 +7,7 @@
#include
#include "core/providers/coreml/coreml_execution_provider.h"
+#include "core/providers/coreml/coreml_provider_factory_creator.h"
#include "core/providers/coreml/coreml_provider_factory.h" // for COREMLFlags
#include "test/common/random_generator.h"
#include "test/providers/model_tester.h"
@@ -20,8 +21,8 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, MatMul) {
auto test = [&](const size_t M) {
SCOPED_TRACE(MakeString("M=", M));
-
- auto coreml_ep = std::make_unique(0);
+    std::unordered_map<std::string, std::string> options;
+ auto coreml_ep = CoreMLProviderFactoryCreator::Create(options)->CreateProvider();
const auto ep_verification_params = EPVerificationParams{
ExpectedEPNodeAssignment::All,
@@ -54,8 +55,8 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, MobileNetExcerpt) {
auto test = [&](const size_t batch_size) {
SCOPED_TRACE(MakeString("batch_size=", batch_size));
-
- auto coreml_ep = std::make_unique(0);
+    std::unordered_map<std::string, std::string> options;
+ auto coreml_ep = CoreMLProviderFactoryCreator::Create(options)->CreateProvider();
const auto ep_verification_params = EPVerificationParams{
ExpectedEPNodeAssignment::All,
@@ -87,6 +88,7 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, EmptyInputFails) {
constexpr auto model_path = ORT_TSTR("testdata/matmul_with_dynamic_input_shape.onnx");
ModelTester tester(CurrentTestName(), model_path);
+  std::unordered_map<std::string, std::string> options;
tester.AddInput("A", {0, 2}, {});
tester.AddOutput("Y", {0, 4}, {});
@@ -94,14 +96,15 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, EmptyInputFails) {
tester
.Config(ModelTester::ExpectResult::kExpectFailure,
"the runtime shape ({0,2}) has zero elements. This is not supported by the CoreML EP.")
- .ConfigEp(std::make_unique(0))
+ .ConfigEp(CoreMLProviderFactoryCreator::Create(options)->CreateProvider())
.RunWithConfig();
}
TEST(CoreMLExecutionProviderDynamicInputShapeTest, OnlyAllowStaticInputShapes) {
constexpr auto model_path = ORT_TSTR("testdata/matmul_with_dynamic_input_shape.onnx");
-
- auto coreml_ep = std::make_unique(COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES);
+  std::unordered_map<std::string, std::string> options = {{kCoremlProviderOption_RequireStaticInputShapes, "1"}};
+  auto coreml_ep = CoreMLProviderFactoryCreator::Create(options)->CreateProvider();
TestModelLoad(model_path, std::move(coreml_ep),
// expect no supported nodes because we disable dynamic input shape support
diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc
index 3519c5d72c..59926bbcd1 100644
--- a/onnxruntime/test/util/default_providers.cc
+++ b/onnxruntime/test/util/default_providers.cc
@@ -251,14 +251,14 @@ std::unique_ptr DefaultCoreMLExecutionProvider(bool use_mlpr
// The test will create a model but execution of it will obviously fail.
#if defined(USE_COREML) && defined(__APPLE__)
// We want to run UT on CPU only to get output value without losing precision
- uint32_t coreml_flags = 0;
- coreml_flags |= COREML_FLAG_USE_CPU_ONLY;
+ auto option = ProviderOptions();
+ option[kCoremlProviderOption_MLComputeUnits] = "CPUOnly";
if (use_mlprogram) {
- coreml_flags |= COREML_FLAG_CREATE_MLPROGRAM;
+ option[kCoremlProviderOption_ModelFormat] = "MLProgram";
}
- return CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider();
+ return CoreMLProviderFactoryCreator::Create(option)->CreateProvider();
#else
ORT_UNUSED_PARAMETER(use_mlprogram);
return nullptr;