Update OrtExtensionsUsage to also use the ORT Objective-C API. (#483)
This commit is contained in:
Родитель
ef19c6672a
Коммит
9abcda779f
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"lineLength": 120
|
||||
}
|
|
@ -16,6 +16,7 @@
|
|||
2ECE960B293A77FD00039409 /* OrtClient.mm in Sources */ = {isa = PBXBuildFile; fileRef = 2ECE9608293A77FD00039409 /* OrtClient.mm */; };
|
||||
2ECE960C293A77FD00039409 /* decode_image.onnx in Resources */ = {isa = PBXBuildFile; fileRef = 2ECE9609293A77FD00039409 /* decode_image.onnx */; };
|
||||
2ECE960E293A784E00039409 /* r32_g64_b128_32x32.png in Resources */ = {isa = PBXBuildFile; fileRef = 2ECE960D293A784E00039409 /* r32_g64_b128_32x32.png */; };
|
||||
2EF816742A572DC300CABC99 /* OrtSwiftClient.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2EF816732A572DC300CABC99 /* OrtSwiftClient.swift */; };
|
||||
/* End PBXBuildFile section */
|
||||
|
||||
/* Begin PBXContainerItemProxy section */
|
||||
|
@ -50,6 +51,7 @@
|
|||
2ECE9609293A77FD00039409 /* decode_image.onnx */ = {isa = PBXFileReference; lastKnownFileType = file; path = decode_image.onnx; sourceTree = "<group>"; };
|
||||
2ECE960A293A77FD00039409 /* OrtClient.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = OrtClient.h; sourceTree = "<group>"; };
|
||||
2ECE960D293A784E00039409 /* r32_g64_b128_32x32.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = r32_g64_b128_32x32.png; sourceTree = "<group>"; };
|
||||
2EF816732A572DC300CABC99 /* OrtSwiftClient.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OrtSwiftClient.swift; sourceTree = "<group>"; };
|
||||
/* End PBXFileReference section */
|
||||
|
||||
/* Begin PBXFrameworksBuildPhase section */
|
||||
|
@ -111,6 +113,7 @@
|
|||
children = (
|
||||
2ECE960A293A77FD00039409 /* OrtClient.h */,
|
||||
2ECE9608293A77FD00039409 /* OrtClient.mm */,
|
||||
2EF816732A572DC300CABC99 /* OrtSwiftClient.swift */,
|
||||
2ECE95DD293A742700039409 /* OrtExtensionsUsageApp.swift */,
|
||||
2ECE95DF293A742700039409 /* ContentView.swift */,
|
||||
2ECE9607293A77FD00039409 /* OrtExtensionsUsage-Bridging-Header.h */,
|
||||
|
@ -286,6 +289,7 @@
|
|||
isa = PBXSourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
2EF816742A572DC300CABC99 /* OrtSwiftClient.swift in Sources */,
|
||||
2ECE960B293A77FD00039409 /* OrtClient.mm in Sources */,
|
||||
2ECE95E0293A742700039409 /* ContentView.swift in Sources */,
|
||||
2ECE95DE293A742700039409 /* OrtExtensionsUsageApp.swift in Sources */,
|
||||
|
|
|
@ -7,6 +7,7 @@ struct ContentView: View {
|
|||
func runOrtDecodeAndCheckImage() -> String {
|
||||
do {
|
||||
try OrtClient.decodeAndCheckImage()
|
||||
try swiftDecodeAndCheckImage()
|
||||
return "Ok"
|
||||
} catch let error as NSError {
|
||||
return "Error: \(error.localizedDescription)"
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
|
||||
@implementation OrtClient
|
||||
|
||||
// Runs a model and checks the result.
|
||||
// Uses the ORT C++ API.
|
||||
+ (BOOL)decodeAndCheckImageWithError:(NSError **)error {
|
||||
try {
|
||||
const auto ort_log_level = ORT_LOGGING_LEVEL_INFO;
|
||||
|
@ -65,8 +67,10 @@
|
|||
const auto output_type_and_shape_info =
|
||||
output_tensor.GetTensorTypeAndShapeInfo();
|
||||
|
||||
// We expect the model output to be BGR values (3 uint8's) for each pixel
|
||||
// in the decoded image.
|
||||
// The input image has 32x32 pixels.
|
||||
const int64_t h = 32, w = 32, c = 3;
|
||||
const uint8_t expected_pixel_bgr_data[] = {128, 64, 32};
|
||||
const std::vector<int64_t> expected_output_shape{h, w, c};
|
||||
const auto output_shape = output_type_and_shape_info.GetShape();
|
||||
if (output_shape != expected_output_shape) {
|
||||
|
@ -79,6 +83,9 @@
|
|||
throw std::runtime_error("Unexpected output element type");
|
||||
}
|
||||
|
||||
// Each pixel in the input image has an RGB value of [32, 64, 128], or
|
||||
// equivalently, a BGR value of [128, 64, 32].
|
||||
const uint8_t expected_pixel_bgr_data[] = {128, 64, 32};
|
||||
const uint8_t *output_data_raw = output_tensor.GetTensorData<uint8_t>();
|
||||
for (size_t i = 0; i < h * w * c; ++i) {
|
||||
if (output_data_raw[i] != expected_pixel_bgr_data[i % 3]) {
|
||||
|
|
|
@ -3,3 +3,6 @@
|
|||
//
|
||||
|
||||
#import "OrtClient.h"
|
||||
|
||||
#import <onnxruntime.h>
|
||||
#include <onnxruntime_extensions.h>
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import Foundation
|
||||
|
||||
enum OrtSwiftClientError: Error {
|
||||
case error(_ message: String)
|
||||
}
|
||||
|
||||
// Runs a model and checks the result.
|
||||
// Uses the ORT Objective-C/Swift API.
|
||||
func swiftDecodeAndCheckImage() throws {
|
||||
let ort_log_level = ORTLoggingLevel.info
|
||||
let ort_env = try ORTEnv(loggingLevel: ort_log_level)
|
||||
|
||||
let ort_session_options = try ORTSessionOptions()
|
||||
try ort_session_options.registerCustomOps(functionPointer: RegisterCustomOps)
|
||||
|
||||
guard let model_path = Bundle.main.path(forResource: "decode_image", ofType: "onnx") else {
|
||||
throw OrtSwiftClientError.error("Failed to get model path")
|
||||
}
|
||||
|
||||
let ort_session = try ORTSession(
|
||||
env: ort_env, modelPath: model_path, sessionOptions: ort_session_options)
|
||||
|
||||
// note: need to set Xcode settings to prevent it from messing with PNG files:
|
||||
// in "Build Settings":
|
||||
// - set "Compress PNG Files" to "No"
|
||||
// - set "Remove Text Metadata From PNG Files" to "No"
|
||||
guard
|
||||
let input_image_url = Bundle.main.url(forResource: "r32_g64_b128_32x32", withExtension: "png")
|
||||
else {
|
||||
throw OrtSwiftClientError.error("Failed to get image URL")
|
||||
}
|
||||
|
||||
let input_data = try Data(contentsOf: input_image_url)
|
||||
let input_data_length = input_data.count
|
||||
let input_shape = [NSNumber(integerLiteral: input_data_length)]
|
||||
let input_tensor = try ORTValue(
|
||||
tensorData: NSMutableData(data: input_data), elementType: ORTTensorElementDataType.uInt8,
|
||||
shape: input_shape)
|
||||
|
||||
let outputs = try ort_session.run(
|
||||
withInputs: ["image": input_tensor], outputNames: ["bgr_data"], runOptions: nil)
|
||||
|
||||
guard let output_tensor = outputs["bgr_data"] else {
|
||||
throw OrtSwiftClientError.error("Failed to get output")
|
||||
}
|
||||
|
||||
let output_type_and_shape = try output_tensor.tensorTypeAndShapeInfo()
|
||||
|
||||
// We expect the model output to be BGR values (3 uint8's) for each pixel
|
||||
// in the decoded image.
|
||||
// The input image has 32x32 pixels.
|
||||
let expected_output_shape: [NSNumber] = [32, 32, 3]
|
||||
guard output_type_and_shape.shape == expected_output_shape else {
|
||||
throw OrtSwiftClientError.error("Unexpected output shape")
|
||||
}
|
||||
|
||||
let expected_output_element_type = ORTTensorElementDataType.uInt8
|
||||
guard output_type_and_shape.elementType == expected_output_element_type else {
|
||||
throw OrtSwiftClientError.error("Unexpected output element type")
|
||||
}
|
||||
|
||||
// Each pixel in the input image has an RGB value of [32, 64, 128], or
|
||||
// equivalently, a BGR value of [128, 64, 32].
|
||||
let output_data: Data = try output_tensor.tensorData() as Data
|
||||
let expected_pixel_bgr_data: [UInt8] = [128, 64, 32]
|
||||
for (idx, byte) in output_data.enumerated() {
|
||||
guard byte == expected_pixel_bgr_data[idx % expected_pixel_bgr_data.count] else {
|
||||
throw OrtSwiftClientError.error("Unexpected pixel data")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
import XCTest
|
||||
|
||||
@testable import OrtExtensionsUsage
|
||||
|
||||
final class OrtExtensionsUsageTests: XCTestCase {
|
||||
|
@ -19,4 +20,8 @@ final class OrtExtensionsUsageTests: XCTestCase {
|
|||
try OrtClient.decodeAndCheckImage()
|
||||
}
|
||||
|
||||
func testSwiftDecodeAndCheckImage() throws {
|
||||
// test that it doesn't throw
|
||||
try swiftDecodeAndCheckImage()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,7 +24,8 @@ final class OrtExtensionsUsageUITests: XCTestCase {
|
|||
app.launch()
|
||||
|
||||
// Use XCTAssert and related functions to verify your tests produce the correct results.
|
||||
let element = app.staticTexts.element(matching: XCUIElement.ElementType.staticText, identifier: "decodeImageResult")
|
||||
let element = app.staticTexts.element(
|
||||
matching: XCUIElement.ElementType.staticText, identifier: "decodeImageResult")
|
||||
XCTAssertEqual(element.label, "Decode image result: Ok")
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ target 'OrtExtensionsUsage' do
|
|||
use_frameworks!
|
||||
|
||||
# Pods for OrtExtensionsUsage
|
||||
pod 'onnxruntime-c'
|
||||
pod 'onnxruntime-objc', '>= 1.16.0'
|
||||
|
||||
# environment variable ORT_EXTENSIONS_LOCAL_POD_PATH can be used to specify a local onnxruntime-extensions-c pod path
|
||||
ort_extensions_local_pod_path = ENV['ORT_EXTENSIONS_LOCAL_POD_PATH']
|
||||
|
|
Загрузка…
Ссылка в новой задаче