Update Objective-C API (#7567)
Update Objective-C API to be more usable from Swift. E.g., to allow conversion from Objective-C methods with trailing NSError** parameter to throwing Swift methods. Update CMake Objective-C framework setup.
This commit is contained in:
Родитель
f3a70f1aec
Коммит
830f0b45d0
|
@ -37,9 +37,10 @@ set(OBJC_ARC_COMPILE_OPTIONS "-fobjc-arc" "-fobjc-arc-exceptions")
|
|||
# explicitly list them here so it is easy to see what is included
|
||||
set(onnxruntime_objc_headers
|
||||
"${OBJC_ROOT}/include/onnxruntime.h"
|
||||
"${OBJC_ROOT}/include/onnxruntime/ort_env.h"
|
||||
"${OBJC_ROOT}/include/onnxruntime/ort_session.h"
|
||||
"${OBJC_ROOT}/include/onnxruntime/ort_value.h")
|
||||
"${OBJC_ROOT}/include/ort_enums.h"
|
||||
"${OBJC_ROOT}/include/ort_env.h"
|
||||
"${OBJC_ROOT}/include/ort_session.h"
|
||||
"${OBJC_ROOT}/include/ort_value.h")
|
||||
|
||||
file(GLOB onnxruntime_objc_srcs
|
||||
"${OBJC_ROOT}/src/*.h"
|
||||
|
@ -84,12 +85,29 @@ set_target_properties(onnxruntime_objc PROPERTIES
|
|||
FRAMEWORK_VERSION "A"
|
||||
PUBLIC_HEADER "${onnxruntime_objc_headers}"
|
||||
FOLDER "ONNXRuntime"
|
||||
CXX_STANDARD 17) # TODO remove when everything else moves to 17
|
||||
CXX_STANDARD 17 # TODO remove when everything else moves to 17
|
||||
)
|
||||
|
||||
target_link_options(onnxruntime_objc PRIVATE "-Wl,-headerpad_max_install_names")
|
||||
|
||||
add_custom_command(TARGET onnxruntime_objc POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory
|
||||
"$<TARGET_BUNDLE_CONTENT_DIR:onnxruntime_objc>/Libraries"
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
"$<TARGET_FILE:onnxruntime>"
|
||||
"$<TARGET_BUNDLE_CONTENT_DIR:onnxruntime_objc>/Libraries"
|
||||
COMMAND install_name_tool
|
||||
-change "@rpath/$<TARGET_FILE_NAME:onnxruntime>"
|
||||
"@rpath/$<TARGET_NAME:onnxruntime_objc>.framework/Libraries/$<TARGET_FILE_NAME:onnxruntime>"
|
||||
"$<TARGET_FILE:onnxruntime_objc>")
|
||||
|
||||
install(TARGETS onnxruntime_objc
|
||||
FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR})
|
||||
|
||||
if(onnxruntime_BUILD_UNIT_TESTS)
|
||||
find_package(XCTest REQUIRED)
|
||||
|
||||
# onnxruntime_test_objc target
|
||||
# onnxruntime_objc_test target
|
||||
|
||||
file(GLOB onnxruntime_objc_test_srcs
|
||||
"${OBJC_ROOT}/test/*.h"
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
// this header contains the entire ONNX Runtime Objective-C API
|
||||
// the headers below can also be imported individually
|
||||
|
||||
#import "onnxruntime/ort_env.h"
|
||||
#import "onnxruntime/ort_session.h"
|
||||
#import "onnxruntime/ort_value.h"
|
||||
#import "ort_enums.h"
|
||||
#import "ort_env.h"
|
||||
#import "ort_session.h"
|
||||
#import "ort_value.h"
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* The ORT logging verbosity levels.
|
||||
*/
|
||||
typedef NS_ENUM(int32_t, ORTLoggingLevel) {
|
||||
ORTLoggingLevelVerbose,
|
||||
ORTLoggingLevelInfo,
|
||||
ORTLoggingLevelWarning,
|
||||
ORTLoggingLevelError,
|
||||
ORTLoggingLevelFatal,
|
||||
};
|
||||
|
||||
/**
|
||||
* The ORT value types.
|
||||
* Currently, a subset of all types is supported.
|
||||
*/
|
||||
typedef NS_ENUM(int32_t, ORTValueType) {
|
||||
ORTValueTypeUnknown,
|
||||
ORTValueTypeTensor,
|
||||
};
|
||||
|
||||
/**
|
||||
* The ORT tensor element data types.
|
||||
* Currently, a subset of all types is supported.
|
||||
*/
|
||||
typedef NS_ENUM(int32_t, ORTTensorElementDataType) {
|
||||
ORTTensorElementDataTypeUndefined,
|
||||
ORTTensorElementDataTypeFloat,
|
||||
ORTTensorElementDataTypeInt32,
|
||||
};
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
|
@ -3,6 +3,8 @@
|
|||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#import "ort_enums.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
|
@ -10,15 +12,17 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
*/
|
||||
@interface ORTEnv : NSObject
|
||||
|
||||
- (nullable instancetype)init NS_UNAVAILABLE;
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
/**
|
||||
* Creates an ORT Environment.
|
||||
*
|
||||
* @param loggingLevel The environment logging level.
|
||||
* @param[out] error Optional error information set if an error occurs.
|
||||
* @return The instance, or nil if an error occurs.
|
||||
*/
|
||||
- (nullable instancetype)initWithError:(NSError**)error NS_DESIGNATED_INITIALIZER;
|
||||
- (nullable instancetype)initWithLoggingLevel:(ORTLoggingLevel)loggingLevel
|
||||
error:(NSError**)error NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
@end
|
||||
|
|
@ -13,7 +13,7 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
*/
|
||||
@interface ORTSession : NSObject
|
||||
|
||||
- (nullable instancetype)init NS_UNAVAILABLE;
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
/**
|
||||
* Creates an ORT Session.
|
|
@ -3,38 +3,26 @@
|
|||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#import "ort_enums.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* The supported ORT value types.
|
||||
*/
|
||||
typedef NS_ENUM(int32_t, ORTValueType) {
|
||||
ORTValueTypeUnknown,
|
||||
ORTValueTypeTensor,
|
||||
};
|
||||
|
||||
/**
|
||||
* The supported ORT tensor element data types.
|
||||
*/
|
||||
typedef NS_ENUM(int32_t, ORTTensorElementDataType) {
|
||||
ORTTensorElementDataTypeUndefined,
|
||||
ORTTensorElementDataTypeFloat,
|
||||
ORTTensorElementDataTypeInt32,
|
||||
};
|
||||
@class ORTValueTypeInfo;
|
||||
@class ORTTensorTypeAndShapeInfo;
|
||||
|
||||
/**
|
||||
* An ORT value encapsulates data used as an input or output to a model at runtime.
|
||||
*/
|
||||
@interface ORTValue : NSObject
|
||||
|
||||
- (nullable instancetype)init NS_UNAVAILABLE;
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
/**
|
||||
* Creates a value that is a tensor.
|
||||
* The tensor data is allocated by the caller.
|
||||
*
|
||||
* @param data The tensor data.
|
||||
* @param type The tensor data element type.
|
||||
* @param tensorData The tensor data.
|
||||
* @param elementType The tensor element data type.
|
||||
* @param shape The tensor shape.
|
||||
* @param[out] error Optional error information set if an error occurs.
|
||||
* @return The instance, or nil if an error occurs.
|
||||
|
@ -45,34 +33,21 @@ typedef NS_ENUM(int32_t, ORTTensorElementDataType) {
|
|||
error:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Gets the value type.
|
||||
* Gets the type information.
|
||||
*
|
||||
* @param[out] valueType The type of the value.
|
||||
* @param[out] error Optional error information set if an error occurs.
|
||||
* @return Whether the value type was retrieved successfully.
|
||||
* @return The type information, or nil if an error occurs.
|
||||
*/
|
||||
- (BOOL)valueType:(ORTValueType*)valueType
|
||||
error:(NSError**)error;
|
||||
- (nullable ORTValueTypeInfo*)typeInfoWithError:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Gets the tensor data element type.
|
||||
* This assumes that the value is a tensor.
|
||||
*
|
||||
* @param[out] elementType The type of the tensor's data elements.
|
||||
* @param[out] error Optional error information set if an error occurs.
|
||||
* @return Whether the tensor data element type was retrieved successfully.
|
||||
*/
|
||||
- (BOOL)tensorElementType:(ORTTensorElementDataType*)elementType
|
||||
error:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Gets the tensor shape.
|
||||
* Gets the tensor type and shape information.
|
||||
* This assumes that the value is a tensor.
|
||||
*
|
||||
* @param[out] error Optional error information set if an error occurs.
|
||||
* @return The tensor shape, or nil if an error occurs.
|
||||
* @return The tensor type and shape information, or nil if an error occurs.
|
||||
*/
|
||||
- (nullable NSArray<NSNumber*>*)tensorShapeWithError:(NSError**)error;
|
||||
- (nullable ORTTensorTypeAndShapeInfo*)tensorTypeAndShapeInfoWithError:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Gets the tensor data.
|
||||
|
@ -85,4 +60,30 @@ typedef NS_ENUM(int32_t, ORTTensorElementDataType) {
|
|||
|
||||
@end
|
||||
|
||||
/**
|
||||
* A value's type information.
|
||||
*/
|
||||
@interface ORTValueTypeInfo : NSObject
|
||||
|
||||
/** The value type. */
|
||||
@property(nonatomic) ORTValueType type;
|
||||
|
||||
/** The tensor type and shape information, if the value is a tensor. */
|
||||
@property(nonatomic, nullable) ORTTensorTypeAndShapeInfo* tensorTypeAndShapeInfo;
|
||||
|
||||
@end
|
||||
|
||||
/**
|
||||
* A tensor's type and shape information.
|
||||
*/
|
||||
@interface ORTTensorTypeAndShapeInfo : NSObject
|
||||
|
||||
/** The tensor element data type. */
|
||||
@property(nonatomic) ORTTensorElementDataType elementType;
|
||||
|
||||
/** The tensor shape. */
|
||||
@property(nonatomic) NSArray<NSNumber*>* shape;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
|
@ -0,0 +1,111 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#import "ort_enums_internal.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
|
||||
namespace {
|
||||
|
||||
struct LoggingLevelInfo {
|
||||
ORTLoggingLevel logging_level;
|
||||
OrtLoggingLevel capi_logging_level;
|
||||
};
|
||||
|
||||
// supported ORT logging levels
|
||||
// define the mapping from ORTLoggingLevel to C API OrtLoggingLevel here
|
||||
constexpr LoggingLevelInfo kLoggingLevelInfos[]{
|
||||
{ORTLoggingLevelVerbose, ORT_LOGGING_LEVEL_VERBOSE},
|
||||
{ORTLoggingLevelInfo, ORT_LOGGING_LEVEL_INFO},
|
||||
{ORTLoggingLevelWarning, ORT_LOGGING_LEVEL_WARNING},
|
||||
{ORTLoggingLevelError, ORT_LOGGING_LEVEL_ERROR},
|
||||
{ORTLoggingLevelFatal, ORT_LOGGING_LEVEL_FATAL},
|
||||
};
|
||||
|
||||
struct ValueTypeInfo {
|
||||
ORTValueType type;
|
||||
ONNXType capi_type;
|
||||
};
|
||||
|
||||
// supported ORT value types
|
||||
// define the mapping from ORTValueType to C API ONNXType here
|
||||
constexpr ValueTypeInfo kValueTypeInfos[]{
|
||||
{ORTValueTypeUnknown, ONNX_TYPE_UNKNOWN},
|
||||
{ORTValueTypeTensor, ONNX_TYPE_TENSOR},
|
||||
};
|
||||
|
||||
struct TensorElementTypeInfo {
|
||||
ORTTensorElementDataType type;
|
||||
ONNXTensorElementDataType capi_type;
|
||||
size_t element_size;
|
||||
};
|
||||
|
||||
// supported ORT tensor element data types
|
||||
// define the mapping from ORTTensorElementDataType to C API
|
||||
// ONNXTensorElementDataType here
|
||||
constexpr TensorElementTypeInfo kElementTypeInfos[]{
|
||||
{ORTTensorElementDataTypeUndefined, ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, 0},
|
||||
{ORTTensorElementDataTypeFloat, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, sizeof(float)},
|
||||
{ORTTensorElementDataTypeInt32, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, sizeof(int32_t)},
|
||||
};
|
||||
|
||||
template <typename Container, typename SelectFn, typename TransformFn>
|
||||
auto SelectAndTransform(
|
||||
const Container& container, SelectFn select_fn, TransformFn transform_fn,
|
||||
const char* not_found_msg)
|
||||
-> decltype(transform_fn(*std::begin(container))) {
|
||||
const auto it = std::find_if(
|
||||
std::begin(container), std::end(container), select_fn);
|
||||
if (it == std::end(container)) {
|
||||
throw Ort::Exception{not_found_msg, ORT_NOT_IMPLEMENTED};
|
||||
}
|
||||
return transform_fn(*it);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
OrtLoggingLevel PublicToCAPILoggingLevel(ORTLoggingLevel logging_level) {
|
||||
return SelectAndTransform(
|
||||
kLoggingLevelInfos,
|
||||
[logging_level](const auto& logging_level_info) {
|
||||
return logging_level_info.logging_level == logging_level;
|
||||
},
|
||||
[](const auto& logging_level_info) {
|
||||
return logging_level_info.capi_logging_level;
|
||||
},
|
||||
"unsupported logging level");
|
||||
}
|
||||
|
||||
ORTValueType CAPIToPublicValueType(ONNXType capi_type) {
|
||||
return SelectAndTransform(
|
||||
kValueTypeInfos,
|
||||
[capi_type](const auto& type_info) { return type_info.capi_type == capi_type; },
|
||||
[](const auto& type_info) { return type_info.type; },
|
||||
"unsupported value type");
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType PublicToCAPITensorElementType(ORTTensorElementDataType type) {
|
||||
return SelectAndTransform(
|
||||
kElementTypeInfos,
|
||||
[type](const auto& type_info) { return type_info.type == type; },
|
||||
[](const auto& type_info) { return type_info.capi_type; },
|
||||
"unsupported tensor element type");
|
||||
}
|
||||
|
||||
ORTTensorElementDataType CAPIToPublicTensorElementType(ONNXTensorElementDataType capi_type) {
|
||||
return SelectAndTransform(
|
||||
kElementTypeInfos,
|
||||
[capi_type](const auto& type_info) { return type_info.capi_type == capi_type; },
|
||||
[](const auto& type_info) { return type_info.type; },
|
||||
"unsupported tensor element type");
|
||||
}
|
||||
|
||||
size_t SizeOfCAPITensorElementType(ONNXTensorElementDataType capi_type) {
|
||||
return SelectAndTransform(
|
||||
kElementTypeInfos,
|
||||
[capi_type](const auto& type_info) { return type_info.capi_type == capi_type; },
|
||||
[](const auto& type_info) { return type_info.element_size; },
|
||||
"unsupported tensor element type");
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#import "ort_enums.h"
|
||||
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
|
||||
OrtLoggingLevel PublicToCAPILoggingLevel(ORTLoggingLevel logging_level);
|
||||
|
||||
ORTValueType CAPIToPublicValueType(ONNXType capi_type);
|
||||
|
||||
ONNXTensorElementDataType PublicToCAPITensorElementType(ORTTensorElementDataType type);
|
||||
ORTTensorElementDataType CAPIToPublicTensorElementType(ONNXTensorElementDataType capi_type);
|
||||
|
||||
size_t SizeOfCAPITensorElementType(ONNXTensorElementDataType capi_type);
|
|
@ -8,6 +8,7 @@
|
|||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
|
||||
#import "src/error_utils.h"
|
||||
#import "src/ort_enums_internal.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
|
@ -15,11 +16,13 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
std::optional<Ort::Env> _env;
|
||||
}
|
||||
|
||||
- (nullable instancetype)initWithError:(NSError**)error {
|
||||
- (nullable instancetype)initWithLoggingLevel:(ORTLoggingLevel)loggingLevel
|
||||
error:(NSError**)error {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
try {
|
||||
_env = Ort::Env{};
|
||||
const auto CAPILoggingLevel = PublicToCAPILoggingLevel(loggingLevel);
|
||||
_env = Ort::Env{CAPILoggingLevel};
|
||||
} catch (const Ort::Exception& e) {
|
||||
ORTSaveExceptionToError(e, error);
|
||||
self = nil;
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#import "onnxruntime/ort_env.h"
|
||||
#import "ort_env.h"
|
||||
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#import "onnxruntime/ort_session.h"
|
||||
#import "ort_session.h"
|
||||
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
#import "src/ort_value_internal.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <optional>
|
||||
|
||||
#include "safeint/SafeInt.hpp"
|
||||
|
@ -11,82 +10,49 @@
|
|||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
|
||||
#import "src/error_utils.h"
|
||||
#import "src/ort_enums_internal.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
namespace {
|
||||
struct ValueTypeInfo {
|
||||
ORTValueType type;
|
||||
ONNXType capi_type;
|
||||
};
|
||||
|
||||
// supported ORT value types
|
||||
// define the mapping from ORTValueType to C API ONNXType here
|
||||
constexpr ValueTypeInfo kValueTypeInfos[]{
|
||||
{ORTValueTypeUnknown, ONNX_TYPE_UNKNOWN},
|
||||
{ORTValueTypeTensor, ONNX_TYPE_TENSOR},
|
||||
};
|
||||
ORTTensorTypeAndShapeInfo* CXXAPIToPublicTensorTypeAndShapeInfo(
|
||||
const Ort::TensorTypeAndShapeInfo& CXXAPITensorTypeAndShapeInfo) {
|
||||
auto* result = [[ORTTensorTypeAndShapeInfo alloc] init];
|
||||
const auto elementType = CXXAPITensorTypeAndShapeInfo.GetElementType();
|
||||
const std::vector<int64_t> shape = CXXAPITensorTypeAndShapeInfo.GetShape();
|
||||
|
||||
struct TensorElementTypeInfo {
|
||||
ORTTensorElementDataType type;
|
||||
ONNXTensorElementDataType capi_type;
|
||||
size_t element_size;
|
||||
};
|
||||
|
||||
// supported ORT tensor element data types
|
||||
// define the mapping from ORTTensorElementDataType to C API
|
||||
// ONNXTensorElementDataType here
|
||||
constexpr TensorElementTypeInfo kElementTypeInfos[]{
|
||||
{ORTTensorElementDataTypeUndefined, ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, 0},
|
||||
{ORTTensorElementDataTypeFloat, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, sizeof(float)},
|
||||
{ORTTensorElementDataTypeInt32, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, sizeof(int32_t)},
|
||||
};
|
||||
|
||||
ORTValueType CAPIToPublicValueType(ONNXType capi_type) {
|
||||
const auto it = std::find_if(
|
||||
std::begin(kValueTypeInfos), std::end(kValueTypeInfos),
|
||||
[capi_type](const auto& type_info) { return type_info.capi_type == capi_type; });
|
||||
if (it == std::end(kValueTypeInfos)) {
|
||||
throw Ort::Exception{"unsupported value type", ORT_NOT_IMPLEMENTED};
|
||||
result.elementType = CAPIToPublicTensorElementType(elementType);
|
||||
auto* shapeArray = [[NSMutableArray alloc] initWithCapacity:shape.size()];
|
||||
for (size_t i = 0; i < shape.size(); ++i) {
|
||||
shapeArray[i] = @(shape[i]);
|
||||
}
|
||||
return it->type;
|
||||
result.shape = shapeArray;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType PublicToCAPITensorElementType(ORTTensorElementDataType type) {
|
||||
const auto it = std::find_if(
|
||||
std::begin(kElementTypeInfos), std::end(kElementTypeInfos),
|
||||
[type](const auto& type_info) { return type_info.type == type; });
|
||||
if (it == std::end(kElementTypeInfos)) {
|
||||
throw Ort::Exception{"unsupported tensor element type", ORT_NOT_IMPLEMENTED};
|
||||
ORTValueTypeInfo* CXXAPIToPublicValueTypeInfo(
|
||||
const Ort::TypeInfo& CXXAPITypeInfo) {
|
||||
auto* result = [[ORTValueTypeInfo alloc] init];
|
||||
const auto valueType = CXXAPITypeInfo.GetONNXType();
|
||||
|
||||
result.type = CAPIToPublicValueType(valueType);
|
||||
|
||||
if (valueType == ONNX_TYPE_TENSOR) {
|
||||
const auto tensorTypeAndShapeInfo = CXXAPITypeInfo.GetTensorTypeAndShapeInfo();
|
||||
result.tensorTypeAndShapeInfo = CXXAPIToPublicTensorTypeAndShapeInfo(tensorTypeAndShapeInfo);
|
||||
}
|
||||
return it->capi_type;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
ORTTensorElementDataType CAPIToPublicTensorElementType(ONNXTensorElementDataType capi_type) {
|
||||
const auto it = std::find_if(
|
||||
std::begin(kElementTypeInfos), std::end(kElementTypeInfos),
|
||||
[capi_type](const auto& type_info) { return type_info.capi_type == capi_type; });
|
||||
if (it == std::end(kElementTypeInfos)) {
|
||||
throw Ort::Exception{"unsupported tensor element type", ORT_NOT_IMPLEMENTED};
|
||||
}
|
||||
return it->type;
|
||||
}
|
||||
|
||||
size_t SizeOfCAPITensorElementType(ONNXTensorElementDataType capi_type) {
|
||||
const auto it = std::find_if(
|
||||
std::begin(kElementTypeInfos), std::end(kElementTypeInfos),
|
||||
[capi_type](const auto& type_info) { return type_info.capi_type == capi_type; });
|
||||
if (it == std::end(kElementTypeInfos)) {
|
||||
throw Ort::Exception{"unsupported tensor element type", ORT_NOT_IMPLEMENTED};
|
||||
}
|
||||
return it->element_size;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
@interface ORTValue ()
|
||||
|
||||
// pointer to any external tensor data to keep alive for the lifetime of the ORTValue
|
||||
@property(nullable) NSMutableData* externalTensorData;
|
||||
@property(nonatomic, nullable) NSMutableData* externalTensorData;
|
||||
|
||||
@end
|
||||
|
||||
|
@ -127,40 +93,19 @@ size_t SizeOfCAPITensorElementType(ONNXTensorElementDataType capi_type) {
|
|||
return self;
|
||||
}
|
||||
|
||||
- (BOOL)valueType:(ORTValueType*)valueType
|
||||
error:(NSError**)error {
|
||||
- (nullable ORTValueTypeInfo*)typeInfoWithError:(NSError**)error {
|
||||
try {
|
||||
const auto ortValueType = _typeInfo->GetONNXType();
|
||||
*valueType = CAPIToPublicValueType(ortValueType);
|
||||
return YES;
|
||||
return CXXAPIToPublicValueTypeInfo(*_typeInfo);
|
||||
} catch (const Ort::Exception& e) {
|
||||
ORTSaveExceptionToError(e, error);
|
||||
return NO;
|
||||
return nil;
|
||||
}
|
||||
}
|
||||
|
||||
- (BOOL)tensorElementType:(ORTTensorElementDataType*)elementType
|
||||
error:(NSError**)error {
|
||||
- (nullable ORTTensorTypeAndShapeInfo*)tensorTypeAndShapeInfoWithError:(NSError**)error {
|
||||
try {
|
||||
const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo();
|
||||
const auto ortElementType = tensorTypeAndShapeInfo.GetElementType();
|
||||
*elementType = CAPIToPublicTensorElementType(ortElementType);
|
||||
return YES;
|
||||
} catch (const Ort::Exception& e) {
|
||||
ORTSaveExceptionToError(e, error);
|
||||
return NO;
|
||||
}
|
||||
}
|
||||
|
||||
- (nullable NSArray<NSNumber*>*)tensorShapeWithError:(NSError**)error {
|
||||
try {
|
||||
const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo();
|
||||
const std::vector<int64_t> shape = tensorTypeAndShapeInfo.GetShape();
|
||||
NSMutableArray<NSNumber*>* shapeArray = [[NSMutableArray alloc] initWithCapacity:shape.size()];
|
||||
for (size_t i = 0; i < shape.size(); ++i) {
|
||||
shapeArray[i] = @(shape[i]);
|
||||
}
|
||||
return shapeArray;
|
||||
return CXXAPIToPublicTensorTypeAndShapeInfo(tensorTypeAndShapeInfo);
|
||||
} catch (const Ort::Exception& e) {
|
||||
ORTSaveExceptionToError(e, error);
|
||||
return nil;
|
||||
|
@ -212,4 +157,10 @@ size_t SizeOfCAPITensorElementType(ONNXTensorElementDataType capi_type) {
|
|||
|
||||
@end
|
||||
|
||||
@implementation ORTValueTypeInfo
|
||||
@end
|
||||
|
||||
@implementation ORTTensorTypeAndShapeInfo
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#import "onnxruntime/ort_value.h"
|
||||
#import "ort_value.h"
|
||||
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#import "onnxruntime/ort_env.h"
|
||||
#import "ort_env.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
|
@ -14,7 +14,8 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
|
||||
- (void)testInitOk {
|
||||
NSError* err = nil;
|
||||
ORTEnv* env = [[ORTEnv alloc] initWithError:&err];
|
||||
ORTEnv* env = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning
|
||||
error:&err];
|
||||
XCTAssertNotNil(env);
|
||||
XCTAssertNil(err);
|
||||
}
|
||||
|
|
|
@ -3,9 +3,9 @@
|
|||
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#import "onnxruntime/ort_env.h"
|
||||
#import "onnxruntime/ort_session.h"
|
||||
#import "onnxruntime/ort_value.h"
|
||||
#import "ort_env.h"
|
||||
#import "ort_session.h"
|
||||
#import "ort_value.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
|
@ -24,7 +24,8 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
|
||||
self.continueAfterFailure = NO;
|
||||
|
||||
_ortEnv = [[ORTEnv alloc] initWithError:nil];
|
||||
_ortEnv = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning
|
||||
error:nil];
|
||||
XCTAssertNotNil(_ortEnv);
|
||||
}
|
||||
|
||||
|
@ -55,23 +56,23 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
+ (ORTValue*)ortValueWithScalarFloatData:(NSMutableData*)data {
|
||||
NSArray<NSNumber*>* shape = @[ @1 ];
|
||||
NSError* err = nil;
|
||||
ORTValue* ort_value = [[ORTValue alloc] initTensorWithData:data
|
||||
elementType:ORTTensorElementDataTypeFloat
|
||||
shape:shape
|
||||
error:&err];
|
||||
XCTAssertNotNil(ort_value);
|
||||
ORTValue* ortValue = [[ORTValue alloc] initTensorWithData:data
|
||||
elementType:ORTTensorElementDataTypeFloat
|
||||
shape:shape
|
||||
error:&err];
|
||||
XCTAssertNotNil(ortValue);
|
||||
XCTAssertNil(err);
|
||||
return ort_value;
|
||||
return ortValue;
|
||||
}
|
||||
|
||||
- (void)testInitAndRunWithPreallocatedOutputOk {
|
||||
NSMutableData* a_data = [ORTSessionTest dataWithScalarFloat:1.0f];
|
||||
NSMutableData* b_data = [ORTSessionTest dataWithScalarFloat:2.0f];
|
||||
NSMutableData* c_data = [ORTSessionTest dataWithScalarFloat:0.0f];
|
||||
NSMutableData* aData = [ORTSessionTest dataWithScalarFloat:1.0f];
|
||||
NSMutableData* bData = [ORTSessionTest dataWithScalarFloat:2.0f];
|
||||
NSMutableData* cData = [ORTSessionTest dataWithScalarFloat:0.0f];
|
||||
|
||||
ORTValue* a = [ORTSessionTest ortValueWithScalarFloatData:a_data];
|
||||
ORTValue* b = [ORTSessionTest ortValueWithScalarFloatData:b_data];
|
||||
ORTValue* c = [ORTSessionTest ortValueWithScalarFloatData:c_data];
|
||||
ORTValue* a = [ORTSessionTest ortValueWithScalarFloatData:aData];
|
||||
ORTValue* b = [ORTSessionTest ortValueWithScalarFloatData:bData];
|
||||
ORTValue* c = [ORTSessionTest ortValueWithScalarFloatData:cData];
|
||||
|
||||
NSError* err = nil;
|
||||
ORTSession* session = [[ORTSession alloc] initWithEnv:self.ortEnv
|
||||
|
@ -80,24 +81,24 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
XCTAssertNotNil(session);
|
||||
XCTAssertNil(err);
|
||||
|
||||
BOOL run_result = [session runWithInputs:@{@"A" : a, @"B" : b}
|
||||
outputs:@{@"C" : c}
|
||||
error:&err];
|
||||
XCTAssertTrue(run_result);
|
||||
BOOL runResult = [session runWithInputs:@{@"A" : a, @"B" : b}
|
||||
outputs:@{@"C" : c}
|
||||
error:&err];
|
||||
XCTAssertTrue(runResult);
|
||||
XCTAssertNil(err);
|
||||
|
||||
const float c_expected = 3.0f;
|
||||
float c_actual;
|
||||
memcpy(&c_actual, c_data.bytes, sizeof(float));
|
||||
XCTAssertEqual(c_actual, c_expected);
|
||||
const float cExpected = 3.0f;
|
||||
float cActual;
|
||||
memcpy(&cActual, cData.bytes, sizeof(float));
|
||||
XCTAssertEqual(cActual, cExpected);
|
||||
}
|
||||
|
||||
- (void)testInitAndRunOk {
|
||||
NSMutableData* a_data = [ORTSessionTest dataWithScalarFloat:1.0f];
|
||||
NSMutableData* b_data = [ORTSessionTest dataWithScalarFloat:2.0f];
|
||||
NSMutableData* aData = [ORTSessionTest dataWithScalarFloat:1.0f];
|
||||
NSMutableData* bData = [ORTSessionTest dataWithScalarFloat:2.0f];
|
||||
|
||||
ORTValue* a = [ORTSessionTest ortValueWithScalarFloatData:a_data];
|
||||
ORTValue* b = [ORTSessionTest ortValueWithScalarFloatData:b_data];
|
||||
ORTValue* a = [ORTSessionTest ortValueWithScalarFloatData:aData];
|
||||
ORTValue* b = [ORTSessionTest ortValueWithScalarFloatData:bData];
|
||||
|
||||
NSError* err = nil;
|
||||
ORTSession* session = [[ORTSession alloc] initWithEnv:self.ortEnv
|
||||
|
@ -113,35 +114,35 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
XCTAssertNotNil(outputs);
|
||||
XCTAssertNil(err);
|
||||
|
||||
ORTValue* c_output = outputs[@"C"];
|
||||
XCTAssertNotNil(c_output);
|
||||
ORTValue* cOutput = outputs[@"C"];
|
||||
XCTAssertNotNil(cOutput);
|
||||
|
||||
NSData* c_data = [c_output tensorDataWithError:&err];
|
||||
XCTAssertNotNil(c_data);
|
||||
NSData* cData = [cOutput tensorDataWithError:&err];
|
||||
XCTAssertNotNil(cData);
|
||||
XCTAssertNil(err);
|
||||
|
||||
const float c_expected = 3.0f;
|
||||
float c_actual;
|
||||
memcpy(&c_actual, c_data.bytes, sizeof(float));
|
||||
XCTAssertEqual(c_actual, c_expected);
|
||||
const float cExpected = 3.0f;
|
||||
float cActual;
|
||||
memcpy(&cActual, cData.bytes, sizeof(float));
|
||||
XCTAssertEqual(cActual, cExpected);
|
||||
}
|
||||
|
||||
- (void)testInitFailsWithInvalidPath {
|
||||
NSString* invalid_model_path = [ORTSessionTest getTestDataWithRelativePath:@"/invalid/path/to/model.onnx"];
|
||||
NSString* invalidModelPath = [ORTSessionTest getTestDataWithRelativePath:@"/invalid/path/to/model.onnx"];
|
||||
NSError* err = nil;
|
||||
ORTSession* session = [[ORTSession alloc] initWithEnv:self.ortEnv
|
||||
modelPath:invalid_model_path
|
||||
modelPath:invalidModelPath
|
||||
error:&err];
|
||||
XCTAssertNil(session);
|
||||
XCTAssertNotNil(err);
|
||||
}
|
||||
|
||||
- (void)testRunFailsWithInvalidInput {
|
||||
NSMutableData* d_data = [ORTSessionTest dataWithScalarFloat:1.0f];
|
||||
NSMutableData* c_data = [ORTSessionTest dataWithScalarFloat:0.0f];
|
||||
NSMutableData* dData = [ORTSessionTest dataWithScalarFloat:1.0f];
|
||||
NSMutableData* cData = [ORTSessionTest dataWithScalarFloat:0.0f];
|
||||
|
||||
ORTValue* d = [ORTSessionTest ortValueWithScalarFloatData:d_data];
|
||||
ORTValue* c = [ORTSessionTest ortValueWithScalarFloatData:c_data];
|
||||
ORTValue* d = [ORTSessionTest ortValueWithScalarFloatData:dData];
|
||||
ORTValue* c = [ORTSessionTest ortValueWithScalarFloatData:cData];
|
||||
|
||||
NSError* err = nil;
|
||||
ORTSession* session = [[ORTSession alloc] initWithEnv:self.ortEnv
|
||||
|
@ -150,10 +151,10 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
XCTAssertNotNil(session);
|
||||
XCTAssertNil(err);
|
||||
|
||||
BOOL run_result = [session runWithInputs:@{@"D" : d}
|
||||
outputs:@{@"C" : c}
|
||||
error:&err];
|
||||
XCTAssertFalse(run_result);
|
||||
BOOL runResult = [session runWithInputs:@{@"D" : d}
|
||||
outputs:@{@"C" : c}
|
||||
error:&err];
|
||||
XCTAssertFalse(runResult);
|
||||
XCTAssertNotNil(err);
|
||||
}
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#import "onnxruntime/ort_value.h"
|
||||
#import "ort_value.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
|
@ -26,28 +26,32 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
length:sizeof(int32_t)];
|
||||
NSArray<NSNumber*>* shape = @[ @1 ];
|
||||
|
||||
const ORTTensorElementDataType elementType = ORTTensorElementDataTypeInt32;
|
||||
|
||||
NSError* err = nil;
|
||||
ORTValue* ortValue = [[ORTValue alloc] initTensorWithData:data
|
||||
elementType:ORTTensorElementDataTypeInt32
|
||||
elementType:elementType
|
||||
shape:shape
|
||||
error:&err];
|
||||
XCTAssertNotNil(ortValue);
|
||||
XCTAssertNil(err);
|
||||
|
||||
ORTValueType actualValueType;
|
||||
XCTAssertTrue([ortValue valueType:&actualValueType error:&err]);
|
||||
XCTAssertNil(err);
|
||||
XCTAssertEqual(actualValueType, ORTValueTypeTensor);
|
||||
auto checkTensorInfo = [&](ORTTensorTypeAndShapeInfo* tensorInfo) {
|
||||
XCTAssertEqual(tensorInfo.elementType, elementType);
|
||||
XCTAssertEqualObjects(tensorInfo.shape, shape);
|
||||
};
|
||||
|
||||
ORTTensorElementDataType actualElementType;
|
||||
XCTAssertTrue([ortValue tensorElementType:&actualElementType error:&err]);
|
||||
ORTValueTypeInfo* typeInfo = [ortValue typeInfoWithError:&err];
|
||||
XCTAssertNotNil(typeInfo);
|
||||
XCTAssertNil(err);
|
||||
XCTAssertEqual(actualElementType, ORTTensorElementDataTypeInt32);
|
||||
XCTAssertEqual(typeInfo.type, ORTValueTypeTensor);
|
||||
XCTAssertNotNil(typeInfo.tensorTypeAndShapeInfo);
|
||||
checkTensorInfo(typeInfo.tensorTypeAndShapeInfo);
|
||||
|
||||
NSArray<NSNumber*>* actualShape = [ortValue tensorShapeWithError:&err];
|
||||
XCTAssertNotNil(actualShape);
|
||||
ORTTensorTypeAndShapeInfo* tensorInfo = [ortValue tensorTypeAndShapeInfoWithError:&err];
|
||||
XCTAssertNotNil(tensorInfo);
|
||||
XCTAssertNil(err);
|
||||
XCTAssertEqualObjects(shape, actualShape);
|
||||
checkTensorInfo(tensorInfo);
|
||||
|
||||
NSData* actualData = [ortValue tensorDataWithError:&err];
|
||||
XCTAssertNotNil(actualData);
|
||||
|
@ -65,11 +69,11 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
NSArray<NSNumber*>* shape = @[ @2, @3 ]; // too large
|
||||
|
||||
NSError* err = nil;
|
||||
ORTValue* ort_value = [[ORTValue alloc] initTensorWithData:data
|
||||
elementType:ORTTensorElementDataTypeInt32
|
||||
shape:shape
|
||||
error:&err];
|
||||
XCTAssertNil(ort_value);
|
||||
ORTValue* ortValue = [[ORTValue alloc] initTensorWithData:data
|
||||
elementType:ORTTensorElementDataTypeInt32
|
||||
shape:shape
|
||||
error:&err];
|
||||
XCTAssertNil(ortValue);
|
||||
XCTAssertNotNil(err);
|
||||
}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче