Update OrtExtensionsUsage to also use the ORT Objective-C API. (#483)

This commit is contained in:
Edward Chen 2023-09-25 09:35:37 -05:00 коммит произвёл GitHub
Родитель ef19c6672a
Коммит 9abcda779f
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
10 изменённых файлов: 150 добавлений и 52 удалений

3
.swift-format Normal file
Просмотреть файл

@ -0,0 +1,3 @@
{
"lineLength": 120
}

Просмотреть файл

@ -16,6 +16,7 @@
2ECE960B293A77FD00039409 /* OrtClient.mm in Sources */ = {isa = PBXBuildFile; fileRef = 2ECE9608293A77FD00039409 /* OrtClient.mm */; };
2ECE960C293A77FD00039409 /* decode_image.onnx in Resources */ = {isa = PBXBuildFile; fileRef = 2ECE9609293A77FD00039409 /* decode_image.onnx */; };
2ECE960E293A784E00039409 /* r32_g64_b128_32x32.png in Resources */ = {isa = PBXBuildFile; fileRef = 2ECE960D293A784E00039409 /* r32_g64_b128_32x32.png */; };
2EF816742A572DC300CABC99 /* OrtSwiftClient.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2EF816732A572DC300CABC99 /* OrtSwiftClient.swift */; };
/* End PBXBuildFile section */
/* Begin PBXContainerItemProxy section */
@ -50,6 +51,7 @@
2ECE9609293A77FD00039409 /* decode_image.onnx */ = {isa = PBXFileReference; lastKnownFileType = file; path = decode_image.onnx; sourceTree = "<group>"; };
2ECE960A293A77FD00039409 /* OrtClient.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = OrtClient.h; sourceTree = "<group>"; };
2ECE960D293A784E00039409 /* r32_g64_b128_32x32.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = r32_g64_b128_32x32.png; sourceTree = "<group>"; };
2EF816732A572DC300CABC99 /* OrtSwiftClient.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OrtSwiftClient.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
@ -111,6 +113,7 @@
children = (
2ECE960A293A77FD00039409 /* OrtClient.h */,
2ECE9608293A77FD00039409 /* OrtClient.mm */,
2EF816732A572DC300CABC99 /* OrtSwiftClient.swift */,
2ECE95DD293A742700039409 /* OrtExtensionsUsageApp.swift */,
2ECE95DF293A742700039409 /* ContentView.swift */,
2ECE9607293A77FD00039409 /* OrtExtensionsUsage-Bridging-Header.h */,
@ -286,6 +289,7 @@
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
2EF816742A572DC300CABC99 /* OrtSwiftClient.swift in Sources */,
2ECE960B293A77FD00039409 /* OrtClient.mm in Sources */,
2ECE95E0293A742700039409 /* ContentView.swift in Sources */,
2ECE95DE293A742700039409 /* OrtExtensionsUsageApp.swift in Sources */,

Просмотреть файл

@ -7,6 +7,7 @@ struct ContentView: View {
func runOrtDecodeAndCheckImage() -> String {
do {
try OrtClient.decodeAndCheckImage()
try swiftDecodeAndCheckImage()
return "Ok"
} catch let error as NSError {
return "Error: \(error.localizedDescription)"

Просмотреть файл

@ -15,6 +15,8 @@
@implementation OrtClient
// Runs a model and checks the result.
// Uses the ORT C++ API.
+ (BOOL)decodeAndCheckImageWithError:(NSError **)error {
try {
const auto ort_log_level = ORT_LOGGING_LEVEL_INFO;
@ -65,8 +67,10 @@
const auto output_type_and_shape_info =
output_tensor.GetTensorTypeAndShapeInfo();
// We expect the model output to be BGR values (3 uint8's) for each pixel
// in the decoded image.
// The input image has 32x32 pixels.
const int64_t h = 32, w = 32, c = 3;
const uint8_t expected_pixel_bgr_data[] = {128, 64, 32};
const std::vector<int64_t> expected_output_shape{h, w, c};
const auto output_shape = output_type_and_shape_info.GetShape();
if (output_shape != expected_output_shape) {
@ -79,6 +83,9 @@
throw std::runtime_error("Unexpected output element type");
}
// Each pixel in the input image has an RGB value of [32, 64, 128], or
// equivalently, a BGR value of [128, 64, 32].
const uint8_t expected_pixel_bgr_data[] = {128, 64, 32};
const uint8_t *output_data_raw = output_tensor.GetTensorData<uint8_t>();
for (size_t i = 0; i < h * w * c; ++i) {
if (output_data_raw[i] != expected_pixel_bgr_data[i % 3]) {

Просмотреть файл

@ -3,3 +3,6 @@
//
#import "OrtClient.h"
#import <onnxruntime.h>
#include <onnxruntime_extensions.h>

Просмотреть файл

@ -0,0 +1,74 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import Foundation
enum OrtSwiftClientError: Error {
case error(_ message: String)
}
// Runs a model and checks the result.
// Uses the ORT Objective-C/Swift API.
func swiftDecodeAndCheckImage() throws {
let ort_log_level = ORTLoggingLevel.info
let ort_env = try ORTEnv(loggingLevel: ort_log_level)
let ort_session_options = try ORTSessionOptions()
try ort_session_options.registerCustomOps(functionPointer: RegisterCustomOps)
guard let model_path = Bundle.main.path(forResource: "decode_image", ofType: "onnx") else {
throw OrtSwiftClientError.error("Failed to get model path")
}
let ort_session = try ORTSession(
env: ort_env, modelPath: model_path, sessionOptions: ort_session_options)
// note: need to set Xcode settings to prevent it from messing with PNG files:
// in "Build Settings":
// - set "Compress PNG Files" to "No"
// - set "Remove Text Metadata From PNG Files" to "No"
guard
let input_image_url = Bundle.main.url(forResource: "r32_g64_b128_32x32", withExtension: "png")
else {
throw OrtSwiftClientError.error("Failed to get image URL")
}
let input_data = try Data(contentsOf: input_image_url)
let input_data_length = input_data.count
let input_shape = [NSNumber(integerLiteral: input_data_length)]
let input_tensor = try ORTValue(
tensorData: NSMutableData(data: input_data), elementType: ORTTensorElementDataType.uInt8,
shape: input_shape)
let outputs = try ort_session.run(
withInputs: ["image": input_tensor], outputNames: ["bgr_data"], runOptions: nil)
guard let output_tensor = outputs["bgr_data"] else {
throw OrtSwiftClientError.error("Failed to get output")
}
let output_type_and_shape = try output_tensor.tensorTypeAndShapeInfo()
// We expect the model output to be BGR values (3 uint8's) for each pixel
// in the decoded image.
// The input image has 32x32 pixels.
let expected_output_shape: [NSNumber] = [32, 32, 3]
guard output_type_and_shape.shape == expected_output_shape else {
throw OrtSwiftClientError.error("Unexpected output shape")
}
let expected_output_element_type = ORTTensorElementDataType.uInt8
guard output_type_and_shape.elementType == expected_output_element_type else {
throw OrtSwiftClientError.error("Unexpected output element type")
}
// Each pixel in the input image has an RGB value of [32, 64, 128], or
// equivalently, a BGR value of [128, 64, 32].
let output_data: Data = try output_tensor.tensorData() as Data
let expected_pixel_bgr_data: [UInt8] = [128, 64, 32]
for (idx, byte) in output_data.enumerated() {
guard byte == expected_pixel_bgr_data[idx % expected_pixel_bgr_data.count] else {
throw OrtSwiftClientError.error("Unexpected pixel data")
}
}
}

Просмотреть файл

@ -2,6 +2,7 @@
// Licensed under the MIT License.
import XCTest
@testable import OrtExtensionsUsage
final class OrtExtensionsUsageTests: XCTestCase {
@ -19,4 +20,8 @@ final class OrtExtensionsUsageTests: XCTestCase {
try OrtClient.decodeAndCheckImage()
}
func testSwiftDecodeAndCheckImage() throws {
// test that it doesn't throw
try swiftDecodeAndCheckImage()
}
}

Просмотреть файл

@ -24,7 +24,8 @@ final class OrtExtensionsUsageUITests: XCTestCase {
app.launch()
// Use XCTAssert and related functions to verify your tests produce the correct results.
let element = app.staticTexts.element(matching: XCUIElement.ElementType.staticText, identifier: "decodeImageResult")
let element = app.staticTexts.element(
matching: XCUIElement.ElementType.staticText, identifier: "decodeImageResult")
XCTAssertEqual(element.label, "Decode image result: Ok")
}

Просмотреть файл

@ -5,7 +5,7 @@ target 'OrtExtensionsUsage' do
use_frameworks!
# Pods for OrtExtensionsUsage
pod 'onnxruntime-c'
pod 'onnxruntime-objc', '>= 1.16.0'
# environment variable ORT_EXTENSIONS_LOCAL_POD_PATH can be used to specify a local onnxruntime-extensions-c pod path
ort_extensions_local_pod_path = ENV['ORT_EXTENSIONS_LOCAL_POD_PATH']