Refactor setup for Azure ops. Add Android support. (#507)

* Refactor setup for Azure ops to try to make the common pieces more re-usable, with the actual ops simply layering in the specific input/output constraints for their type of request (a sketch follows the build notes below).

Currently builds on Linux, Windows (x64 only) and Android.

Android requires a manual pre-build of openssl and curl.

Linux requires a manual pre-install of openssl.

Windows currently only works for x64. Other targets need the triplet adjusted.
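
As a rough sketch of that layering (class and method names match the new CurlInvoker base class in this change; the example op itself is hypothetical):

// Hypothetical invoker layering its constraints over the shared CurlInvoker base.
// CurlInvoker::ComputeImpl owns auth-token handling, URL/verbosity setup and the
// request execution; the subclass only validates its I/O and shapes the request/response.
class MyTextInvoker : public CurlInvoker {
 public:
  MyTextInvoker(const OrtApi& api, const OrtKernelInfo& info) : CurlInvoker(api, info) {}

  void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
    ComputeImpl(inputs, outputs);  // common curl handling lives in the base class
  }

 private:
  void ValidateInputs(const ortc::Variadic& inputs) const override {
    if (inputs.Size() != 2 || inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
      ORTX_CXX_API_THROW("Expected auth_token and text string inputs", ORT_INVALID_ARGUMENT);
    }
  }

  void SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const override {
    curl_handler.AddHeader("Content-Type: application/json");
    curl_handler.SetOption(CURLOPT_POSTFIELDS, inputs[1]->DataRaw());
    curl_handler.SetOption(CURLOPT_POSTFIELDSIZE_LARGE, inputs[1]->SizeInBytes());
  }

  void ProcessResponse(const std::string& response, ortc::Variadic& outputs) const override {
    auto& string_tensor = outputs.AllocateStringTensor(0);
    string_tensor.SetStringOutput(std::vector<std::string>{response}, std::vector<int64_t>{1});
  }
};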

* Address PR comments

* Fix a couple of Android build warnings.

* Update .gitignore to remove old path

* Fix build break from merge
This commit is contained in:
Scott McKay 2023-08-08 19:54:30 +10:00 committed by GitHub
Parent 5881931bf2
Commit 2bde82fce9
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
37 changed files with 1573 additions and 579 deletions


@@ -1,9 +1,16 @@
parameters:
- name: ExtraEnv
displayName: 'Extra env variable set to CIBW_ENVIRONMENT'
type: string
default: 'None=None'
jobs:
- job: windows
timeoutInMinutes: 120
pool: {name: 'onnxruntime-Win-CPU-2022'}
variables:
CIBW_BUILD: "cp3{8,9,10,11}-*amd64"
CIBW_ENVIRONMENT: "${{ parameters.ExtraEnv }}"
steps:
- task: UsePythonVersion@0


@@ -380,8 +380,22 @@ endif()
if(OCOS_ENABLE_AZURE)
# Azure endpoint invokers
include(triton)
file(GLOB TARGET_SRC_AZURE "operators/azure/*.cc" "operators/azure/*.h*")
if (ANDROID)
include(curl)
# make sure the prebuilt openssl and curl libraries can be found
find_package(OpenSSL REQUIRED)
find_package(CURL REQUIRED)
# exclude triton
list(FILTER TARGET_SRC_AZURE EXCLUDE REGEX ".*triton.*")
else()
add_compile_definitions(AZURE_INVOKERS_ENABLE_TRITON)
include(triton)
endif()
list(APPEND TARGET_SRC ${TARGET_SRC_AZURE})
endif()
@@ -454,11 +468,13 @@ if(_HAS_TOKENIZER)
endif()
if(OCOS_ENABLE_TF_STRING)
target_include_directories(noexcep_operators PUBLIC
${googlere2_SOURCE_DIR}
${farmhash_SOURCE_DIR}/src)
target_include_directories(noexcep_operators PUBLIC ${farmhash_SOURCE_DIR}/src)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_TF_STRING NOMINMAX FARMHASH_NO_BUILTIN_EXPECT FARMHASH_DEBUG=0)
target_link_libraries(noexcep_operators PRIVATE re2)
if(OCOS_ENABLE_RE2_REGEX)
target_include_directories(noexcep_operators PUBLIC ${googlere2_SOURCE_DIR})
target_link_libraries(noexcep_operators PRIVATE re2)
endif()
endif()
if(OCOS_ENABLE_RE2_REGEX)
@@ -526,6 +542,10 @@ if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
list(APPEND ocos_libraries nlohmann_json::nlohmann_json)
endif()
if(OCOS_ENABLE_AZURE)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_AZURE)
endif()
target_include_directories(noexcep_operators PUBLIC ${GSL_INCLUDE_DIR})
list(APPEND ocos_libraries Microsoft.GSL::GSL)
@@ -577,9 +597,79 @@ else()
set(_BUILD_SHARED_LIBRARY TRUE)
endif()
if(OCOS_ENABLE_AZURE)
if (ANDROID)
# the find_package calls were made immediately after `include(curl)` so we know CURL and OpenSSL are available
target_link_libraries(ocos_operators PUBLIC CURL::libcurl OpenSSL::Crypto OpenSSL::SSL)
elseif(IOS)
# TODO
else()
# we need files from the triton client (e.g. curl header on linux) to be available for the ocos_operators build.
# add a dependency so the fetch and build of triton happens first.
add_dependencies(ocos_operators triton)
target_include_directories(ocos_operators PUBLIC ${triton_INSTALL_DIR}/include)
target_link_directories(ocos_operators PUBLIC ${triton_INSTALL_DIR}/lib)
target_link_directories(ocos_operators PUBLIC ${triton_INSTALL_DIR}/lib64)
if (WIN32)
if (ocos_target_platform STREQUAL "AMD64")
set(vcpkg_target_platform "x64")
else()
set(vcpkg_target_platform ${ocos_target_platform})
endif()
# As per https://curl.se/docs/faq.html#Link_errors_when_building_libcur we need to set CURL_STATICLIB.
target_compile_definitions(ocos_operators PRIVATE CURL_STATICLIB)
target_include_directories(ocos_operators PUBLIC ${VCPKG_SRC}/installed/${vcpkg_target_platform}-windows-static-md/include)
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(curl_LIB_NAME "libcurl-d")
target_link_directories(ocos_operators PUBLIC ${VCPKG_SRC}/installed/${vcpkg_target_platform}-windows-static-md/debug/lib)
else()
set(curl_LIB_NAME "libcurl")
target_link_directories(ocos_operators PUBLIC ${VCPKG_SRC}/installed/${vcpkg_target_platform}-windows-static-md/lib)
endif()
target_link_libraries(ocos_operators PUBLIC httpclient_static ${curl_LIB_NAME} ws2_32 crypt32 Wldap32)
else()
find_package(ZLIB REQUIRED)
# If finding the OpenSSL or CURL package fails for a local build you can install libcurl4-openssl-dev.
# See also info on triton client dependencies here: https://github.com/triton-inference-server/client
find_package(OpenSSL REQUIRED)
find_package(CURL)
if (CURL_FOUND)
message(STATUS "Found CURL package")
set(libcurl_target CURL::libcurl)
else()
# curl is coming from triton but as that's an external project it isn't built yet and we have to add
# paths and library names instead of cmake targets.
message(STATUS "Using CURL build from triton client. Once built it should be in ${triton_THIRD_PARTY_DIR}/curl")
target_include_directories(ocos_operators PUBLIC ${triton_THIRD_PARTY_DIR}/curl/include)
# Install is to 'lib' except on CentOS (which is used for the manylinux build of the python wheel).
# Side note: we have to patch the triton client CMakeLists.txt to only use 'lib64' for 64-bit builds otherwise
# the build of the 32-bit python wheel fails with CURL not being found due to the invalid library
# directory name.
target_link_directories(ocos_operators PUBLIC ${triton_THIRD_PARTY_DIR}/curl/lib)
target_link_directories(ocos_operators PUBLIC ${triton_THIRD_PARTY_DIR}/curl/lib64)
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(libcurl_target "curl-d")
else()
set(libcurl_target "curl")
endif()
endif()
target_link_libraries(ocos_operators PUBLIC httpclient_static ${libcurl_target} OpenSSL::Crypto OpenSSL::SSL ZLIB::ZLIB)
endif()
endif()
endif()
target_compile_definitions(ortcustomops PUBLIC ${OCOS_COMPILE_DEFINITIONS})
target_include_directories(ortcustomops PUBLIC
"$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")
target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")
target_link_libraries(ortcustomops PUBLIC ocos_operators)
if(_BUILD_SHARED_LIBRARY)
@@ -592,8 +682,7 @@ if(_BUILD_SHARED_LIBRARY)
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS "-Wl,-s -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
endif()
target_include_directories(extensions_shared PUBLIC
"$<TARGET_PROPERTY:ortcustomops,INTERFACE_INCLUDE_DIRECTORIES>")
target_include_directories(extensions_shared PUBLIC "$<TARGET_PROPERTY:ortcustomops,INTERFACE_INCLUDE_DIRECTORIES>")
target_link_libraries(extensions_shared PRIVATE ortcustomops)
set_target_properties(extensions_shared PROPERTIES OUTPUT_NAME "ortextensions")
if(MSVC AND ocos_target_platform MATCHES "x86|x64")
@@ -672,8 +761,10 @@ if(OCOS_ENABLE_CTEST)
list(APPEND LINUX_CC_FLAGS stdc++fs -pthread)
endif()
file(GLOB shared_TEST_SRC "${TEST_SRC_DIR}/shared_test/*.cc")
file(GLOB shared_TEST_SRC "${TEST_SRC_DIR}/shared_test/*.cc" "${TEST_SRC_DIR}/shared_test/*.hpp")
add_executable(extensions_test ${shared_TEST_SRC})
target_compile_definitions(extensions_test PUBLIC ${OCOS_COMPILE_DEFINITIONS})
standardize_output_folder(extensions_test)
target_include_directories(extensions_test PRIVATE ${spm_INCLUDE_DIRS}
"$<TARGET_PROPERTY:extensions_shared,INTERFACE_INCLUDE_DIRECTORIES>")
@@ -706,30 +797,3 @@ if(OCOS_ENABLE_CTEST)
add_test(NAME extensions_test COMMAND $<TARGET_FILE:extensions_test>)
endif()
endif()
if(OCOS_ENABLE_AZURE)
add_dependencies(ocos_operators triton)
target_include_directories(ocos_operators PUBLIC ${TRITON_BIN}/include ${TRITON_THIRD_PARTY}/curl/include)
target_link_directories(ocos_operators PUBLIC ${TRITON_BIN}/lib ${TRITON_BIN}/lib64 ${TRITON_THIRD_PARTY}/curl/lib ${TRITON_THIRD_PARTY}/curl/lib64)
if (ocos_target_platform STREQUAL "AMD64")
set(vcpkg_target_platform "x86")
else()
set(vcpkg_target_platform ${ocos_target_platform})
endif()
if (WIN32)
target_link_directories(ocos_operators PUBLIC ${VCPKG_SRC}/installed/${vcpkg_target_platform}-windows-static/lib)
target_link_libraries(ocos_operators PUBLIC libcurl httpclient_static ws2_32 crypt32 Wldap32)
else()
find_package(ZLIB REQUIRED)
find_package(OpenSSL REQUIRED)
target_link_libraries(ocos_operators PUBLIC httpclient_static curl ZLIB::ZLIB OpenSSL::Crypto OpenSSL::SSL)
endif() #if (WIN32)
endif()


@@ -40,15 +40,28 @@ else()
endif()
endif()
message(STATUS "ONNX Runtime URL suffix: ${ONNXRUNTIME_URL}")
if (ANDROID)
set(ort_fetch_URL "https://repo1.maven.org/maven2/com/microsoft/onnxruntime/onnxruntime-android/${ONNXRUNTIME_VER}/onnxruntime-android-${ONNXRUNTIME_VER}.aar")
else()
set(ort_fetch_URL "https://github.com/microsoft/onnxruntime/releases/download/${ONNXRUNTIME_URL}")
endif()
message(STATUS "ONNX Runtime URL: ${ort_fetch_URL}")
FetchContent_Declare(
onnxruntime
URL https://github.com/microsoft/onnxruntime/releases/download/${ONNXRUNTIME_URL}
URL ${ort_fetch_URL}
)
FetchContent_MakeAvailable(onnxruntime)
set(ONNXRUNTIME_INCLUDE_DIR ${onnxruntime_SOURCE_DIR}/include)
set(ONNXRUNTIME_LIB_DIR ${onnxruntime_SOURCE_DIR}/lib)
if (ANDROID)
set(ONNXRUNTIME_INCLUDE_DIR ${onnxruntime_SOURCE_DIR}/headers)
set(ONNXRUNTIME_LIB_DIR ${onnxruntime_SOURCE_DIR}/jni/${ANDROID_ABI})
message(STATUS "Android onnxruntime inc=${ONNXRUNTIME_INCLUDE_DIR} lib=${ONNXRUNTIME_LIB_DIR}")
else()
set(ONNXRUNTIME_INCLUDE_DIR ${onnxruntime_SOURCE_DIR}/include)
set(ONNXRUNTIME_LIB_DIR ${onnxruntime_SOURCE_DIR}/lib)
endif()
endif()
if(NOT EXISTS ${ONNXRUNTIME_INCLUDE_DIR})

cmake/externals/curl.cmake (vendored, new file)

@@ -0,0 +1,11 @@
if (ANDROID)
set(PREBUILD_OUTPUT_PATH "${PROJECT_SOURCE_DIR}/prebuild/openssl_for_ios_and_android/output/android")
set(OPENSSL_ROOT_DIR "${PREBUILD_OUTPUT_PATH}/openssl-${ANDROID_ABI}")
set(NGHTTP2_ROOT_DIR "${PREBUILD_OUTPUT_PATH}/nghttp2-${ANDROID_ABI}")
set(CURL_ROOT_DIR "${PREBUILD_OUTPUT_PATH}/curl-${ANDROID_ABI}")
# Update CMAKE_FIND_ROOT_PATH so find_package/find_library can find these builds
list(APPEND CMAKE_FIND_ROOT_PATH "${OPENSSL_ROOT_DIR}" )
list(APPEND CMAKE_FIND_ROOT_PATH "${NGHTTP2_ROOT_DIR}" )
list(APPEND CMAKE_FIND_ROOT_PATH "${CURL_ROOT_DIR}" )
endif()

cmake/externals/triton.cmake (vendored)

@@ -1,32 +1,41 @@
include(ExternalProject)
if (WIN32)
set(triton_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton)
set(triton_INSTALL_DIR ${triton_PREFIX}/install)
if (WIN32)
if (ocos_target_platform STREQUAL "AMD64")
set(vcpkg_target_platform "x86")
set(vcpkg_target_platform "x64")
else()
set(vcpkg_target_platform ${ocos_target_platform})
endif()
ExternalProject_Add(vcpkg
GIT_REPOSITORY https://github.com/microsoft/vcpkg.git
GIT_TAG 2023.06.20
PREFIX vcpkg
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-build
CONFIGURE_COMMAND ""
INSTALL_COMMAND ""
UPDATE_COMMAND ""
BUILD_COMMAND "<SOURCE_DIR>/bootstrap-vcpkg.bat")
set(vcpkg_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg)
set(VCPKG_SRC ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src)
set(ENV{VCPKG_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src)
ExternalProject_Add(vcpkg
GIT_REPOSITORY https://github.com/microsoft/vcpkg.git
GIT_TAG 2023.06.20
PREFIX ${vcpkg_PREFIX}
CONFIGURE_COMMAND ""
INSTALL_COMMAND ""
UPDATE_COMMAND ""
BUILD_COMMAND "<SOURCE_DIR>/bootstrap-vcpkg.bat")
ExternalProject_Get_Property(vcpkg SOURCE_DIR BINARY_DIR)
set(VCPKG_SRC ${SOURCE_DIR})
message(status "vcpkg source dir: " ${VCPKG_SRC})
# set the environment variable so that the vcpkg.cmake file can find the vcpkg root directory
set(ENV{VCPKG_ROOT} ${VCPKG_SRC})
message(STATUS "VCPKG_SRC: " ${VCPKG_SRC})
message(STATUS "VCPKG_ROOT: " $ENV{VCPKG_ROOT})
message(STATUS "ENV{VCPKG_ROOT}: " $ENV{VCPKG_ROOT})
# NOTE: The VCPKG_ROOT environment variable isn't propagated to an add_custom_command target, so specify --vcpkg-root
# here and in the vcpkg_install function
add_custom_command(
COMMAND ${VCPKG_SRC}/vcpkg integrate install
COMMAND ${CMAKE_COMMAND} -E echo ${VCPKG_SRC}/vcpkg integrate --vcpkg-root=$ENV{VCPKG_ROOT} install
COMMAND ${VCPKG_SRC}/vcpkg integrate --vcpkg-root=$ENV{VCPKG_ROOT} install
COMMAND ${CMAKE_COMMAND} -E touch vcpkg_integrate.stamp
OUTPUT vcpkg_integrate.stamp
DEPENDS vcpkg
@@ -35,77 +44,97 @@ if (WIN32)
add_custom_target(vcpkg_integrate ALL DEPENDS vcpkg_integrate.stamp)
set(VCPKG_DEPENDENCIES "vcpkg_integrate")
# use static-md so it adjusts for debug/release CRT
# https://stackoverflow.com/questions/67258905/vcpkg-difference-between-windows-windows-static-and-other
function(vcpkg_install PACKAGE_NAME)
add_custom_command(
OUTPUT ${VCPKG_SRC}/packages/${PACKAGE_NAME}_${vcpkg_target_platform}-windows-static/BUILD_INFO
COMMAND ${VCPKG_SRC}/vcpkg install ${PACKAGE_NAME}:${vcpkg_target_platform}-windows-static --vcpkg-root=${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src
OUTPUT ${VCPKG_SRC}/packages/${PACKAGE_NAME}_${vcpkg_target_platform}-windows-static-md/BUILD_INFO
COMMAND ${CMAKE_COMMAND} -E echo ${VCPKG_SRC}/vcpkg install --vcpkg-root=$ENV{VCPKG_ROOT}
${PACKAGE_NAME}:${vcpkg_target_platform}-windows-static-md
COMMAND ${VCPKG_SRC}/vcpkg install --vcpkg-root=$ENV{VCPKG_ROOT}
${PACKAGE_NAME}:${vcpkg_target_platform}-windows-static-md
WORKING_DIRECTORY ${VCPKG_SRC}
DEPENDS vcpkg_integrate)
add_custom_target(get${PACKAGE_NAME}
add_custom_target(
get${PACKAGE_NAME}
ALL
DEPENDS ${VCPKG_SRC}/packages/${PACKAGE_NAME}_${vcpkg_target_platform}-windows-static/BUILD_INFO)
DEPENDS ${VCPKG_SRC}/packages/${PACKAGE_NAME}_${vcpkg_target_platform}-windows-static-md/BUILD_INFO)
list(APPEND VCPKG_DEPENDENCIES "get${PACKAGE_NAME}")
set(VCPKG_DEPENDENCIES ${VCPKG_DEPENDENCIES} PARENT_SCOPE)
endfunction()
vcpkg_install(openssl)
vcpkg_install(openssl-windows)
vcpkg_install(rapidjson)
vcpkg_install(re2)
vcpkg_install(boost-interprocess)
vcpkg_install(boost-stacktrace)
vcpkg_install(pthread)
vcpkg_install(b64)
vcpkg_install(openssl)
vcpkg_install(curl)
add_dependencies(getb64 getpthread)
add_dependencies(getpthread getboost-stacktrace)
add_dependencies(getboost-stacktrace getboost-interprocess)
add_dependencies(getboost-interprocess getre2)
add_dependencies(getre2 getrapidjson)
add_dependencies(getrapidjson getopenssl-windows)
add_dependencies(getopenssl-windows getopenssl)
ExternalProject_Add(triton
GIT_REPOSITORY https://github.com/triton-inference-server/client.git
GIT_TAG r23.05
PREFIX triton
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build
CMAKE_ARGS -DVCPKG_TARGET_TRIPLET=${vcpkg_target_platform}-windows-static -DCMAKE_TOOLCHAIN_FILE=${VCPKG_SRC}/scripts/buildsystems/vcpkg.cmake -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF
INSTALL_COMMAND ""
UPDATE_COMMAND "")
add_dependencies(triton ${VCPKG_DEPENDENCIES})
set(triton_extra_cmake_args -DVCPKG_TARGET_TRIPLET=${vcpkg_target_platform}-windows-static-md
-DCMAKE_TOOLCHAIN_FILE=${VCPKG_SRC}/scripts/buildsystems/vcpkg.cmake)
set(triton_patch_command "")
set(triton_dependencies ${VCPKG_DEPENDENCIES})
else()
# RapidJSON 1.1.0 (released in 2016) is compatible with the triton build. Later code is not compatible without
# patching due to the change in variable name for the include dir from RAPIDJSON_INCLUDE_DIRS to
# RapidJSON_INCLUDE_DIRS in the generated cmake file used by find_package:
# https://github.com/Tencent/rapidjson/commit/b91c515afea9f0ba6a81fc670889549d77c83db3
# The triton code here https://github.com/triton-inference-server/common/blob/main/CMakeLists.txt is using
# RAPIDJSON_INCLUDE_DIRS so the build fails if a newer RapidJSON version is used. It will find the package but the
# include path will be wrong so the build error is delayed/misleading and non-trivial to understand/resolve.
set(RapidJSON_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/_deps/rapidjson)
set(RapidJSON_INSTALL_DIR ${RapidJSON_PREFIX}/install)
ExternalProject_Add(RapidJSON
PREFIX ${RapidJSON_PREFIX}
URL https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.zip
URL_HASH SHA1=0fe7b4f7b83df4b3d517f4a202f3a383af7a0818
CMAKE_ARGS -DRAPIDJSON_BUILD_DOC=OFF
-DRAPIDJSON_BUILD_EXAMPLES=OFF
-DRAPIDJSON_BUILD_TESTS=OFF
-DRAPIDJSON_HAS_STDSTRING=ON
-DRAPIDJSON_USE_MEMBERSMAP=ON
-DCMAKE_INSTALL_PREFIX=${RapidJSON_INSTALL_DIR}
)
ExternalProject_Add(curl7
PREFIX curl7
GIT_REPOSITORY "https://github.com/curl/curl.git"
GIT_TAG "curl-7_86_0"
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/curl7-src
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/curl7-build
CMAKE_ARGS -DBUILD_TESTING=OFF -DBUILD_CURL_EXE=OFF -DBUILD_SHARED_LIBS=OFF -DCURL_STATICLIB=ON -DHTTP_ONLY=ON -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE})
ExternalProject_Get_Property(RapidJSON SOURCE_DIR BINARY_DIR)
# message(STATUS "RapidJSON src=${SOURCE_DIR} binary=${BINARY_DIR}")
# Set RapidJSON_ROOT_DIR for find_package. The required RapidJSONConfig.cmake file is generated in the binary dir
set(RapidJSON_ROOT_DIR ${BINARY_DIR})
ExternalProject_Add(triton
GIT_REPOSITORY https://github.com/triton-inference-server/client.git
GIT_TAG r23.05
PREFIX triton
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF
INSTALL_COMMAND ""
UPDATE_COMMAND "")
set(triton_extra_cmake_args "")
set(triton_patch_command patch --verbose -p1 -i ${PROJECT_SOURCE_DIR}/cmake/externals/triton_cmake.patch)
set(triton_dependencies RapidJSON)
add_dependencies(triton curl7)
# Patch the triton client CMakeLists.txt to fix two issues when building the python wheels with cibuildwheel, which
# uses CentOS 7.
# 1) use the full path to the version script file so 'ld' doesn't fail to find it. Looks like ld is running from the
# parent directory but not sure why the behavior differs vs. other linux builds
# e.g. building locally on Ubuntu is fine without the patch
# 2) only set the CURL lib path to 'lib64' on a 64-bit CentOS build as 'lib64' is invalid on a 32-bit OS. without
# this patch the build of the third-party libraries in the triton client fail as the CURL build is not found.
endif() #if (WIN32)
endif() #if (WIN32)
ExternalProject_Get_Property(triton SOURCE_DIR)
set(TRITON_SRC ${SOURCE_DIR})
# Add the triton build. We just need the library so we don't install it.
#
set(triton_VERSION_TAG r23.05)
ExternalProject_Add(triton
URL https://github.com/triton-inference-server/client/archive/refs/heads/${triton_VERSION_TAG}.tar.gz
URL_HASH SHA1=b8fd2a4e09eae39c33cd04cfa9ec934e39d9afc1
PREFIX ${triton_PREFIX}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${triton_INSTALL_DIR}
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DTRITON_COMMON_REPO_TAG=${triton_VERSION_TAG}
-DTRITON_THIRD_PARTY_REPO_TAG=${triton_VERSION_TAG}
-DTRITON_CORE_REPO_TAG=${triton_VERSION_TAG}
-DTRITON_ENABLE_CC_HTTP=ON
-DTRITON_ENABLE_ZLIB=OFF
${triton_extra_cmake_args}
INSTALL_COMMAND ${CMAKE_COMMAND} -E echo "Skipping install step."
PATCH_COMMAND ${triton_patch_command}
)
ExternalProject_Get_Property(triton BINARY_DIR)
set(TRITON_BIN ${BINARY_DIR}/binary)
set(TRITON_THIRD_PARTY ${BINARY_DIR}/third-party)
add_dependencies(triton ${triton_dependencies})
ExternalProject_Get_Property(triton SOURCE_DIR BINARY_DIR)
set(triton_THIRD_PARTY_DIR ${BINARY_DIR}/third-party)

cmake/externals/triton_cmake.patch (vendored, new file)

@@ -0,0 +1,30 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7b11178..7749fa9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -115,10 +115,11 @@ if(TRITON_ENABLE_CC_HTTP OR TRITON_ENABLE_CC_GRPC OR TRITON_ENABLE_PERF_ANALYZER
file(STRINGS /etc/os-release DISTRO REGEX "^NAME=")
string(REGEX REPLACE "NAME=\"(.*)\"" "\\1" DISTRO "${DISTRO}")
message(STATUS "Distro Name: ${DISTRO}")
- if(DISTRO STREQUAL "CentOS Linux")
+ if(DISTRO STREQUAL "CentOS Linux" AND CMAKE_SIZEOF_VOID_P EQUAL 8)
set (CURL_LIB_DIR "lib64")
endif()
endif()
+ message(STATUS "Triton client CURL_LIB_DIR=${CURL_LIB_DIR}")
set(_cc_client_depends "")
if(${TRITON_ENABLE_CC_HTTP})
diff --git a/src/c++/library/CMakeLists.txt b/src/c++/library/CMakeLists.txt
index bdaae25..c36dbc8 100644
--- a/src/c++/library/CMakeLists.txt
+++ b/src/c++/library/CMakeLists.txt
@@ -320,7 +320,7 @@ if(TRITON_ENABLE_CC_HTTP OR TRITON_ENABLE_PERF_ANALYZER)
httpclient
PROPERTIES
LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libhttpclient.ldscript
- LINK_FLAGS "-Wl,--version-script=libhttpclient.ldscript"
+ LINK_FLAGS "-Wl,--version-script=${CMAKE_CURRENT_BINARY_DIR}/libhttpclient.ldscript"
)
endif() # NOT WIN32


@@ -1,427 +1,95 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define CURL_STATICLIB
#include "http_client.h"
#include "curl/curl.h"
#include "azure_invokers.hpp"
#include <sstream>
#define MIN_SUPPORTED_ORT_VER 14
namespace ort_extensions {
constexpr const char* kUri = "model_uri";
constexpr const char* kModelName = "model_name";
constexpr const char* kModelVer = "model_version";
constexpr const char* kVerbose = "verbose";
constexpr const char* kBinaryType = "binary_type";
////////////////////// AzureAudioToTextInvoker //////////////////////
struct StringBuffer {
StringBuffer() = default;
~StringBuffer() = default;
std::stringstream ss_;
};
AzureAudioToTextInvoker::AzureAudioToTextInvoker(const OrtApi& api, const OrtKernelInfo& info)
: CurlInvoker(api, info) {
audio_format_ = TryToGetAttributeWithDefault<std::string>(kAudioFormat, "");
}
// apply the callback only when the response is guaranteed to be a '\0'-terminated string
static size_t WriteStringCallback(void* contents, size_t size, size_t nmemb, void* userp) {
try {
size_t realsize = size * nmemb;
auto buffer = reinterpret_cast<struct StringBuffer*>(userp);
buffer->ss_.write(reinterpret_cast<const char*>(contents), realsize);
return realsize;
} catch (...) {
// exception caught, abort write
return CURLcode::CURLE_WRITE_ERROR;
void AzureAudioToTextInvoker::ValidateInputs(const ortc::Variadic& inputs) const {
// TODO: Validate any required input names are present
// We don't have a way to get the output type from the custom op API.
// If there's a mismatch it will fail in the Compute when it allocates the output tensor.
if (OutputNames().size() != 1) {
ORTX_CXX_API_THROW("Expected single output", ORT_INVALID_ARGUMENT);
}
}
using CurlWriteCallBack = size_t (*)(void*, size_t, size_t, void*);
void AzureAudioToTextInvoker::SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const {
// nominally the filename the content was buffered from
static const std::string fake_filename = "audio." + audio_format_;
class CurlHandler {
public:
CurlHandler(CurlWriteCallBack call_back) : curl_(curl_easy_init(), curl_easy_cleanup),
headers_(nullptr, curl_slist_free_all),
from_holder_(from_, curl_formfree) {
curl_easy_setopt(curl_.get(), CURLOPT_BUFFERSIZE, 102400L);
curl_easy_setopt(curl_.get(), CURLOPT_NOPROGRESS, 1L);
curl_easy_setopt(curl_.get(), CURLOPT_USERAGENT, "curl/7.83.1");
curl_easy_setopt(curl_.get(), CURLOPT_MAXREDIRS, 50L);
curl_easy_setopt(curl_.get(), CURLOPT_FTP_SKIP_PASV_IP, 1L);
curl_easy_setopt(curl_.get(), CURLOPT_TCP_KEEPALIVE, 1L);
curl_easy_setopt(curl_.get(), CURLOPT_WRITEFUNCTION, call_back);
}
~CurlHandler() = default;
const auto& property_names = RequestPropertyNames();
void AddHeader(const char* data) {
headers_.reset(curl_slist_append(headers_.release(), data));
}
template <typename... Args>
void AddForm(Args... args) {
curl_formadd(&from_, &last_, args...);
}
template <typename T>
void SetOption(CURLoption opt, T val) {
curl_easy_setopt(curl_.get(), opt, val);
}
CURLcode Perform() {
SetOption(CURLOPT_HTTPHEADER, headers_.get());
if (from_) {
SetOption(CURLOPT_HTTPPOST, from_);
}
return curl_easy_perform(curl_.get());
}
private:
std::unique_ptr<CURL, decltype(curl_easy_cleanup)*> curl_;
std::unique_ptr<curl_slist, decltype(curl_slist_free_all)*> headers_;
curl_httppost* from_{};
curl_httppost* last_{};
std::unique_ptr<curl_httppost, decltype(curl_formfree)*> from_holder_;
};
////////////////////// AzureInvoker //////////////////////
AzureInvoker::AzureInvoker(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
auto ver = GetActiveOrtAPIVersion();
if (ver < MIN_SUPPORTED_ORT_VER) {
ORTX_CXX_API_THROW("Azure ops requires ort >= 1.14", ORT_RUNTIME_EXCEPTION);
}
model_uri_ = TryToGetAttributeWithDefault<std::string>(kUri, "");
model_name_ = TryToGetAttributeWithDefault<std::string>(kModelName, "");
model_ver_ = TryToGetAttributeWithDefault<std::string>(kModelVer, "0");
verbose_ = TryToGetAttributeWithDefault<std::string>(kVerbose, "0");
OrtStatusPtr status = {};
size_t input_count = {};
status = api_.KernelInfo_GetInputCount(&info_, &input_count);
if (status) {
ORTX_CXX_API_THROW("failed to get input count", ORT_RUNTIME_EXCEPTION);
}
for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
char input_name[1024] = {};
size_t name_size = 1024;
status = api_.KernelInfo_GetInputName(&info_, ith_input, input_name, &name_size);
if (status) {
ORTX_CXX_API_THROW("failed to get input name", ORT_RUNTIME_EXCEPTION);
}
input_names_.push_back(input_name);
}
size_t output_count = {};
status = api_.KernelInfo_GetOutputCount(&info_, &output_count);
if (status) {
ORTX_CXX_API_THROW("failed to get output count", ORT_RUNTIME_EXCEPTION);
}
for (size_t ith_output = 0; ith_output < output_count; ++ith_output) {
char output_name[1024] = {};
size_t name_size = 1024;
status = api_.KernelInfo_GetOutputName(&info_, ith_output, output_name, &name_size);
if (status) {
ORTX_CXX_API_THROW("failed to get output name", ORT_RUNTIME_EXCEPTION);
}
output_names_.push_back(output_name);
}
}
////////////////////// AzureAudioInvoker //////////////////////
AzureAudioInvoker::AzureAudioInvoker(const OrtApi& api, const OrtKernelInfo& info) : AzureInvoker(api, info) {
file_name_ = std::string{"non_exist."} + TryToGetAttributeWithDefault<std::string>(kBinaryType, "wav");
}
void AzureAudioInvoker::Compute(const ortc::Variadic& inputs, ortc::Tensor<std::string>& output) const {
if (inputs.Size() < 1 ||
inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
ORTX_CXX_API_THROW("invalid inputs, auto token missing", ORT_RUNTIME_EXCEPTION);
}
if (inputs.Size() != input_names_.size()) {
ORTX_CXX_API_THROW("input count mismatch", ORT_RUNTIME_EXCEPTION);
}
if (inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING || "auth_token" != input_names_[0]) {
ORTX_CXX_API_THROW("first input must be a string of auth token", ORT_INVALID_ARGUMENT);
}
std::string auth_token = reinterpret_cast<const char*>(inputs[0]->DataRaw());
std::string full_auth = std::string{"Authorization: Bearer "} + auth_token;
StringBuffer string_buffer;
CurlHandler curl_handler(WriteStringCallback);
curl_handler.AddHeader(full_auth.c_str());
curl_handler.AddHeader("Content-Type: multipart/form-data");
curl_handler.AddFormString("deployment_id", ModelName().c_str());
// TODO: If the handling here stays the same as in OpenAIAudioToText we can create a helper function to re-use
for (size_t ith_input = 1; ith_input < inputs.Size(); ++ith_input) {
switch (inputs[ith_input]->Type()) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
curl_handler.AddForm(CURLFORM_COPYNAME,
input_names_[ith_input].c_str(),
CURLFORM_COPYCONTENTS,
inputs[ith_input]->DataRaw(),
CURLFORM_END);
curl_handler.AddFormString(property_names[ith_input].c_str(),
static_cast<const char*>(inputs[ith_input]->DataRaw())); // assumes null terminated
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
curl_handler.AddForm(CURLFORM_COPYNAME,
input_names_[ith_input].data(),
CURLFORM_BUFFER,
file_name_.c_str(),
CURLFORM_BUFFERPTR,
inputs[ith_input]->DataRaw(),
CURLFORM_BUFFERLENGTH,
inputs[ith_input]->SizeInBytes(),
CURLFORM_END);
curl_handler.AddFormBuffer(property_names[ith_input].c_str(),
fake_filename.c_str(),
inputs[ith_input]->DataRaw(),
inputs[ith_input]->SizeInBytes());
break;
default:
ORTX_CXX_API_THROW("input must be either text or binary", ORT_RUNTIME_EXCEPTION);
break;
}
} // for
}
}
curl_handler.SetOption(CURLOPT_URL, model_uri_.c_str());
curl_handler.SetOption(CURLOPT_VERBOSE, verbose_);
curl_handler.SetOption(CURLOPT_WRITEDATA, (void*)&string_buffer);
void AzureAudioToTextInvoker::ProcessResponse(const std::string& response, ortc::Variadic& outputs) const {
auto& string_tensor = outputs.AllocateStringTensor(0);
string_tensor.SetStringOutput(std::vector<std::string>{response}, std::vector<int64_t>{1});
}
auto curl_ret = curl_handler.Perform();
if (CURLE_OK != curl_ret) {
ORTX_CXX_API_THROW(curl_easy_strerror(curl_ret), ORT_FAIL);
////////////////////// AzureTextToTextInvoker //////////////////////
AzureTextToTextInvoker::AzureTextToTextInvoker(const OrtApi& api, const OrtKernelInfo& info)
: CurlInvoker(api, info) {
}
void AzureTextToTextInvoker::ValidateInputs(const ortc::Variadic& inputs) const {
if (inputs.Size() != 2 || inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
ORTX_CXX_API_THROW("Expected 2 string inputs of auth_token and text respectively", ORT_INVALID_ARGUMENT);
}
output.SetStringOutput(std::vector<std::string>{string_buffer.ss_.str()}, std::vector<int64_t>{1L});
// We don't have a way to get the output type from the custom op API.
// If there's a mismatch it will fail in the Compute when it allocates the output tensor.
if (OutputNames().size() != 1) {
ORTX_CXX_API_THROW("Expected single output", ORT_INVALID_ARGUMENT);
}
}
////////////////////// AzureTextInvoker //////////////////////
void AzureTextToTextInvoker::SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const {
gsl::span<const std::string> input_names = InputNames();
AzureTextInvoker::AzureTextInvoker(const OrtApi& api, const OrtKernelInfo& info) : AzureInvoker(api, info) {
}
void AzureTextInvoker::Compute(std::string_view auth, std::string_view input,
ortc::Tensor<std::string>& output) const {
CurlHandler curl_handler(WriteStringCallback);
StringBuffer string_buffer;
std::string full_auth = std::string{"Authorization: Bearer "} + auth.data();
curl_handler.AddHeader(full_auth.c_str());
// TODO: assuming we need to create the correct json from the input text
curl_handler.AddHeader("Content-Type: application/json");
curl_handler.SetOption(CURLOPT_URL, model_uri_.c_str());
curl_handler.SetOption(CURLOPT_POSTFIELDS, input.data());
curl_handler.SetOption(CURLOPT_POSTFIELDSIZE_LARGE, input.size());
curl_handler.SetOption(CURLOPT_VERBOSE, verbose_);
curl_handler.SetOption(CURLOPT_WRITEDATA, (void*)&string_buffer);
auto curl_ret = curl_handler.Perform();
if (CURLE_OK != curl_ret) {
ORTX_CXX_API_THROW(curl_easy_strerror(curl_ret), ORT_FAIL);
}
output.SetStringOutput(std::vector<std::string>{string_buffer.ss_.str()}, std::vector<int64_t>{1L});
const auto& text_input = inputs[1];
curl_handler.SetOption(CURLOPT_POSTFIELDS, text_input->DataRaw());
curl_handler.SetOption(CURLOPT_POSTFIELDSIZE_LARGE, text_input->SizeInBytes());
}
////////////////////// AzureTritonInvoker //////////////////////
namespace tc = triton::client;
AzureTritonInvoker::AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info) : AzureInvoker(api, info) {
auto err = tc::InferenceServerHttpClient::Create(&triton_client_, model_uri_, verbose_ != "0");
void AzureTextToTextInvoker::ProcessResponse(const std::string& response, ortc::Variadic& outputs) const {
auto& string_tensor = outputs.AllocateStringTensor(0);
string_tensor.SetStringOutput(std::vector<std::string>{response}, std::vector<int64_t>{1});
}
std::string MapDataType(ONNXTensorElementDataType onnx_data_type) {
std::string triton_data_type;
switch (onnx_data_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
triton_data_type = "FP32";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
triton_data_type = "UINT8";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
triton_data_type = "INT8";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
triton_data_type = "UINT16";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
triton_data_type = "INT16";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
triton_data_type = "INT32";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
triton_data_type = "INT64";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
triton_data_type = "BYTES";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
triton_data_type = "BOOL";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
triton_data_type = "FP16";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
triton_data_type = "FP64";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
triton_data_type = "UINT32";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
triton_data_type = "UINT64";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
triton_data_type = "BF16";
break;
default:
break;
}
return triton_data_type;
}
int8_t* CreateNonStrTensor(const std::string& data_type,
ortc::Variadic& outputs,
size_t i,
const std::vector<int64_t>& shape) {
if (data_type == "FP32") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<float>(i, shape));
} else if (data_type == "UINT8") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint8_t>(i, shape));
} else if (data_type == "INT8") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int8_t>(i, shape));
} else if (data_type == "UINT16") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint16_t>(i, shape));
} else if (data_type == "INT16") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int16_t>(i, shape));
} else if (data_type == "INT32") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int32_t>(i, shape));
} else if (data_type == "UINT32") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint32_t>(i, shape));
} else if (data_type == "INT64") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int64_t>(i, shape));
} else if (data_type == "UINT64") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint64_t>(i, shape));
} else if (data_type == "BOOL") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<bool>(i, shape));
} else if (data_type == "FP64") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<double>(i, shape));
} else {
return {};
}
}
#define CHECK_TRITON_ERR(ret, msg) \
if (!ret.IsOk()) { \
return ORTX_CXX_API_THROW("Triton err: " + ret.Message(), ORT_RUNTIME_EXCEPTION); \
}
void AzureTritonInvoker::Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
if (inputs.Size() < 1 ||
inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
ORTX_CXX_API_THROW("invalid inputs, auto token missing", ORT_RUNTIME_EXCEPTION);
}
if (inputs.Size() != input_names_.size()) {
ORTX_CXX_API_THROW("input count mismatch", ORT_RUNTIME_EXCEPTION);
}
auto auth_token = reinterpret_cast<const char*>(inputs[0]->DataRaw());
std::vector<std::unique_ptr<tc::InferInput>> triton_input_vec;
std::vector<tc::InferInput*> triton_inputs;
std::vector<std::unique_ptr<const tc::InferRequestedOutput>> triton_output_vec;
std::vector<const tc::InferRequestedOutput*> triton_outputs;
tc::Error err;
for (size_t ith_input = 1; ith_input < inputs.Size(); ++ith_input) {
tc::InferInput* triton_input = {};
std::string triton_data_type = MapDataType(inputs[ith_input]->Type());
if (triton_data_type.empty()) {
ORTX_CXX_API_THROW("unknow onnx data type", ORT_RUNTIME_EXCEPTION);
}
err = tc::InferInput::Create(&triton_input, input_names_[ith_input], inputs[ith_input]->Shape(), triton_data_type);
CHECK_TRITON_ERR(err, "failed to create triton input");
triton_input_vec.emplace_back(triton_input);
triton_inputs.push_back(triton_input);
if ("BYTES" == triton_data_type) {
const auto* string_tensor = reinterpret_cast<const ortc::Tensor<std::string>*>(inputs[ith_input].get());
triton_input->AppendFromString(string_tensor->Data());
} else {
const float* data_raw = reinterpret_cast<const float*>(inputs[ith_input]->DataRaw());
size_t size_in_bytes = inputs[ith_input]->SizeInBytes();
err = triton_input->AppendRaw(reinterpret_cast<const uint8_t*>(data_raw), size_in_bytes);
CHECK_TRITON_ERR(err, "failed to append raw data to input");
}
}
for (size_t ith_output = 0; ith_output < output_names_.size(); ++ith_output) {
tc::InferRequestedOutput* triton_output = {};
err = tc::InferRequestedOutput::Create(&triton_output, output_names_[ith_output]);
CHECK_TRITON_ERR(err, "failed to create triton output");
triton_output_vec.emplace_back(triton_output);
triton_outputs.push_back(triton_output);
}
std::unique_ptr<tc::InferResult> results_ptr;
tc::InferResult* results = {};
tc::InferOptions options(model_name_);
options.model_version_ = model_ver_;
options.client_timeout_ = 0;
tc::Headers http_headers;
http_headers["Authorization"] = std::string{"Bearer "} + auth_token;
err = triton_client_->Infer(&results, options, triton_inputs, triton_outputs,
http_headers, tc::Parameters(),
tc::InferenceServerHttpClient::CompressionType::NONE, // support compression in config?
tc::InferenceServerHttpClient::CompressionType::NONE);
results_ptr.reset(results);
CHECK_TRITON_ERR(err, "failed to do triton inference");
size_t output_index = 0;
auto iter = output_names_.begin();
while (iter != output_names_.end()) {
std::vector<int64_t> shape;
err = results_ptr->Shape(*iter, &shape);
CHECK_TRITON_ERR(err, "failed to get output shape");
std::string type;
err = results_ptr->Datatype(*iter, &type);
CHECK_TRITON_ERR(err, "failed to get output type");
if ("BYTES" == type) {
std::vector<std::string> output_strings;
err = results_ptr->StringData(*iter, &output_strings);
CHECK_TRITON_ERR(err, "failed to get output as string");
auto& string_tensor = outputs.AllocateStringTensor(output_index);
string_tensor.SetStringOutput(output_strings, shape);
} else {
const uint8_t* raw_data = {};
size_t raw_size;
err = results_ptr->RawData(*iter, &raw_data, &raw_size);
CHECK_TRITON_ERR(err, "failed to get output raw data");
auto* output_raw = CreateNonStrTensor(type, outputs, output_index, shape);
memcpy(output_raw, raw_data, raw_size);
}
++output_index;
++iter;
}
}
const std::vector<const OrtCustomOp*>& AzureInvokerLoader() {
static OrtOpLoader op_loader(CustomAzureStruct("AzureAudioInvoker", AzureAudioInvoker),
CustomAzureStruct("AzureTritonInvoker", AzureTritonInvoker),
CustomAzureStruct("AzureAudioInvoker", AzureAudioInvoker),
CustomAzureStruct("AzureTextInvoker", AzureTextInvoker),
CustomCpuStruct("AzureAudioInvoker", AzureAudioInvoker),
CustomCpuStruct("AzureTritonInvoker", AzureTritonInvoker),
CustomCpuStruct("AzureAudioInvoker", AzureAudioInvoker),
CustomCpuStruct("AzureTextInvoker", AzureTextInvoker)
);
return op_loader.GetCustomOps();
}
FxLoadCustomOpFactory LoadCustomOpClasses_Azure = AzureInvokerLoader;
} // namespace ort_extensions


@@ -1,43 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
#include "curl_invoker.hpp"
struct AzureInvoker : public BaseKernel {
AzureInvoker(const OrtApi& api, const OrtKernelInfo& info);
namespace ort_extensions {
protected:
~AzureInvoker() = default;
std::string model_uri_;
std::string model_name_;
std::string model_ver_;
std::string verbose_;
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
};
////////////////////// AzureAudioToTextInvoker //////////////////////
struct AzureAudioInvoker : public AzureInvoker {
AzureAudioInvoker(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Variadic& inputs, ortc::Tensor<std::string>& output) const;
/// <summary>
/// Azure Audio to Text
/// Input: auth_token {string}, ??? (Update when AOAI endpoint is defined)
/// Output: text {string}
/// </summary>
class AzureAudioToTextInvoker : public CurlInvoker {
public:
AzureAudioToTextInvoker(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
// use impl from CurlInvoker
ComputeImpl(inputs, outputs);
}
private:
std::string file_name_;
void ValidateInputs(const ortc::Variadic& inputs) const override;
void SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const override;
void ProcessResponse(const std::string& response, ortc::Variadic& outputs) const override;
static constexpr const char* kAudioFormat = "audio_format";
std::string audio_format_;
};
struct AzureTextInvoker : public AzureInvoker {
AzureTextInvoker(const OrtApi& api, const OrtKernelInfo& info);
void Compute(std::string_view auth, std::string_view input, ortc::Tensor<std::string>& output) const;
////////////////////// AzureTextToTextInvoker //////////////////////
/// Azure Text to Text
/// Input: auth_token {string}, text {string}
/// Output: text {string}
struct AzureTextToTextInvoker : public CurlInvoker {
AzureTextToTextInvoker(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
// use impl from CurlInvoker
ComputeImpl(inputs, outputs);
}
private:
std::string binary_type_;
void ValidateInputs(const ortc::Variadic& inputs) const override;
void SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const override;
void ProcessResponse(const std::string& response, ortc::Variadic& outputs) const override;
};
struct AzureTritonInvoker : public AzureInvoker {
AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const;
private:
std::unique_ptr<triton::client::InferenceServerHttpClient> triton_client_;
};
} // namespace ort_extensions


@@ -0,0 +1,200 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "azure_triton_invoker.hpp"
////////////////////// AzureTritonInvoker //////////////////////
namespace tc = triton::client;
namespace ort_extensions {
AzureTritonInvoker::AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info)
: CloudBaseKernel(api, info) {
auto err = tc::InferenceServerHttpClient::Create(&triton_client_, ModelUri(), Verbose());
}
std::string MapDataType(ONNXTensorElementDataType onnx_data_type) {
std::string triton_data_type;
switch (onnx_data_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
triton_data_type = "FP32";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
triton_data_type = "UINT8";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
triton_data_type = "INT8";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
triton_data_type = "UINT16";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
triton_data_type = "INT16";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
triton_data_type = "INT32";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
triton_data_type = "INT64";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
triton_data_type = "BYTES";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
triton_data_type = "BOOL";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
triton_data_type = "FP16";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
triton_data_type = "FP64";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
triton_data_type = "UINT32";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
triton_data_type = "UINT64";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
triton_data_type = "BF16";
break;
default:
break;
}
return triton_data_type;
}
int8_t* CreateNonStrTensor(const std::string& data_type,
ortc::Variadic& outputs,
size_t i,
const std::vector<int64_t>& shape) {
if (data_type == "FP32") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<float>(i, shape));
} else if (data_type == "UINT8") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint8_t>(i, shape));
} else if (data_type == "INT8") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int8_t>(i, shape));
} else if (data_type == "UINT16") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint16_t>(i, shape));
} else if (data_type == "INT16") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int16_t>(i, shape));
} else if (data_type == "INT32") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int32_t>(i, shape));
} else if (data_type == "UINT32") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint32_t>(i, shape));
} else if (data_type == "INT64") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int64_t>(i, shape));
} else if (data_type == "UINT64") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint64_t>(i, shape));
} else if (data_type == "BOOL") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<bool>(i, shape));
} else if (data_type == "FP64") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<double>(i, shape));
} else {
return {};
}
}
#define CHECK_TRITON_ERR(ret, msg) \
if (!ret.IsOk()) { \
return ORTX_CXX_API_THROW("Triton err: " + ret.Message(), ORT_RUNTIME_EXCEPTION); \
}
void AzureTritonInvoker::Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
auto auth_token = GetAuthToken(inputs);
gsl::span<const std::string> input_names = InputNames();
if (inputs.Size() != input_names.size()) {
ORTX_CXX_API_THROW("input count mismatch", ORT_RUNTIME_EXCEPTION);
}
std::vector<std::unique_ptr<tc::InferInput>> triton_input_vec;
std::vector<tc::InferInput*> triton_inputs;
std::vector<std::unique_ptr<const tc::InferRequestedOutput>> triton_output_vec;
std::vector<const tc::InferRequestedOutput*> triton_outputs;
tc::Error err;
const auto& property_names = RequestPropertyNames();
for (size_t ith_input = 1; ith_input < inputs.Size(); ++ith_input) {
tc::InferInput* triton_input = {};
std::string triton_data_type = MapDataType(inputs[ith_input]->Type());
if (triton_data_type.empty()) {
ORTX_CXX_API_THROW("unknow onnx data type", ORT_RUNTIME_EXCEPTION);
}
err = tc::InferInput::Create(&triton_input, property_names[ith_input], inputs[ith_input]->Shape(),
triton_data_type);
CHECK_TRITON_ERR(err, "failed to create triton input");
triton_input_vec.emplace_back(triton_input);
triton_inputs.push_back(triton_input);
if ("BYTES" == triton_data_type) {
const auto* string_tensor = reinterpret_cast<const ortc::Tensor<std::string>*>(inputs[ith_input].get());
triton_input->AppendFromString(string_tensor->Data());
} else {
const float* data_raw = reinterpret_cast<const float*>(inputs[ith_input]->DataRaw());
size_t size_in_bytes = inputs[ith_input]->SizeInBytes();
err = triton_input->AppendRaw(reinterpret_cast<const uint8_t*>(data_raw), size_in_bytes);
CHECK_TRITON_ERR(err, "failed to append raw data to input");
}
}
gsl::span<const std::string> output_names = OutputNames();
for (size_t ith_output = 0; ith_output < output_names.size(); ++ith_output) {
tc::InferRequestedOutput* triton_output = {};
err = tc::InferRequestedOutput::Create(&triton_output, output_names[ith_output]);
CHECK_TRITON_ERR(err, "failed to create triton output");
triton_output_vec.emplace_back(triton_output);
triton_outputs.push_back(triton_output);
}
std::unique_ptr<tc::InferResult> results_ptr;
tc::InferResult* results = {};
tc::InferOptions options(ModelName());
options.model_version_ = ModelVersion();
options.client_timeout_ = 0;
tc::Headers http_headers;
http_headers["Authorization"] = std::string{"Bearer "} + auth_token;
err = triton_client_->Infer(&results, options, triton_inputs, triton_outputs,
http_headers, tc::Parameters(),
tc::InferenceServerHttpClient::CompressionType::NONE, // support compression in config?
tc::InferenceServerHttpClient::CompressionType::NONE);
results_ptr.reset(results);
CHECK_TRITON_ERR(err, "failed to do triton inference");
size_t output_index = 0;
for (const auto& output_name : output_names) {
std::vector<int64_t> shape;
err = results_ptr->Shape(output_name, &shape);
CHECK_TRITON_ERR(err, "failed to get output shape");
std::string type;
err = results_ptr->Datatype(output_name, &type);
CHECK_TRITON_ERR(err, "failed to get output type");
if ("BYTES" == type) {
std::vector<std::string> output_strings;
err = results_ptr->StringData(output_name, &output_strings);
CHECK_TRITON_ERR(err, "failed to get output as string");
auto& string_tensor = outputs.AllocateStringTensor(output_index);
string_tensor.SetStringOutput(output_strings, shape);
} else {
const uint8_t* raw_data = {};
size_t raw_size;
err = results_ptr->RawData(output_name, &raw_data, &raw_size);
CHECK_TRITON_ERR(err, "failed to get output raw data");
auto* output_raw = CreateNonStrTensor(type, outputs, output_index, shape);
memcpy(output_raw, raw_data, raw_size);
}
++output_index;
}
}
} // namespace ort_extensions


@@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "cloud_base_kernel.hpp"
#include "http_client.h" // triton
namespace ort_extensions {
struct AzureTritonInvoker : public CloudBaseKernel {
AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const;
private:
std::unique_ptr<triton::client::InferenceServerHttpClient> triton_client_;
};
} // namespace ort_extensions


@@ -0,0 +1,92 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "cloud_base_kernel.hpp"
#include <sstream>
namespace ort_extensions {
CloudBaseKernel::CloudBaseKernel(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
auto ver = GetActiveOrtAPIVersion();
if (ver < kMinimumSupportedOrtVersion) {
ORTX_CXX_API_THROW("Azure custom operators require onnxruntime version >= 1.14", ORT_RUNTIME_EXCEPTION);
}
// require model uri. other properties are optional
// Custom op implementation can allow user to override attributes via inputs
if (!TryToGetAttribute<std::string>(kUri, model_uri_)) {
ORTX_CXX_API_THROW("Required " + model_uri_ + " attribute was not found", ORT_RUNTIME_EXCEPTION);
}
model_name_ = TryToGetAttributeWithDefault<std::string>(kModelName, "");
model_ver_ = TryToGetAttributeWithDefault<std::string>(kModelVer, "0");
verbose_ = TryToGetAttributeWithDefault<std::string>(kVerbose, "0") != "0";
OrtStatusPtr status{};
size_t input_count{};
status = api_.KernelInfo_GetInputCount(&info_, &input_count);
if (status) {
ORTX_CXX_API_THROW("failed to get input count", ORT_RUNTIME_EXCEPTION);
}
input_names_.reserve(input_count);
property_names_.reserve(input_count);
for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
char input_name[1024]{};
size_t name_size = 1024;
status = api_.KernelInfo_GetInputName(&info_, ith_input, input_name, &name_size);
if (status) {
ORTX_CXX_API_THROW("failed to get name for input " + std::to_string(ith_input), ORT_RUNTIME_EXCEPTION);
}
input_names_.push_back(input_name);
property_names_.push_back(GetPropertyNameFromInputName(input_name));
}
if (input_names_[0] != "auth_token") {
ORTX_CXX_API_THROW("first input name must be 'auth_token'", ORT_INVALID_ARGUMENT);
}
size_t output_count = {};
status = api_.KernelInfo_GetOutputCount(&info_, &output_count);
if (status) {
ORTX_CXX_API_THROW("failed to get output count", ORT_RUNTIME_EXCEPTION);
}
output_names_.reserve(output_count);
for (size_t ith_output = 0; ith_output < output_count; ++ith_output) {
char output_name[1024]{};
size_t name_size = 1024;
status = api_.KernelInfo_GetOutputName(&info_, ith_output, output_name, &name_size);
if (status) {
ORTX_CXX_API_THROW("failed to get name for output " + std::to_string(ith_output), ORT_RUNTIME_EXCEPTION);
}
output_names_.push_back(output_name);
}
}
std::string CloudBaseKernel::GetAuthToken(const ortc::Variadic& inputs) const {
if (inputs.Size() < 1 ||
inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
ORTX_CXX_API_THROW("auth_token string is required to be the first input", ORT_INVALID_ARGUMENT);
}
std::string auth_token{static_cast<const char*>(inputs[0]->DataRaw())};
return auth_token;
}
/*static */ std::string CloudBaseKernel::GetPropertyNameFromInputName(const std::string& input_name) {
auto idx = input_name.find_last_of('/');
if (idx == std::string::npos) {
return input_name;
}
if (idx == input_name.length() - 1) {
ORTX_CXX_API_THROW("Input name cannot end with '/'. Invalid input:" + input_name, ORT_INVALID_ARGUMENT);
}
return input_name.substr(idx + 1); // return text after the '/'
}
} // namespace ort_extensions


@@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
#include "gsl/span"
namespace ort_extensions {
/// <summary>
/// Base kernel for custom ops that call cloud endpoints.
/// </summary>
class CloudBaseKernel : public BaseKernel {
protected:
CloudBaseKernel(const OrtApi& api, const OrtKernelInfo& info);
virtual ~CloudBaseKernel() = default;
// Names of attributes the custom operator provides.
static constexpr const char* kUri = "model_uri"; // required
static constexpr const char* kModelName = "model_name"; // optional
static constexpr const char* kModelVer = "model_version"; // optional
static constexpr const char* kVerbose = "verbose";
static constexpr int kMinimumSupportedOrtVersion = 14;
const std::string& ModelUri() const { return model_uri_; }
const std::string& ModelName() const { return model_name_; }
const std::string& ModelVersion() const { return model_ver_; }
bool Verbose() const { return verbose_; }
const gsl::span<const std::string> InputNames() const { return input_names_; }
const gsl::span<const std::string> OutputNames() const { return output_names_; }
// Request property names that are parsed from input names. 1:1 with InputNames() values.
// e.g. 'node0/prompt' -> 'prompt' and that input provides the 'prompt' property in the request to the endpoint.
// <see cref="GetPropertyNameFromInputName"/> for further details.
const gsl::span<const std::string> RequestPropertyNames() const { return property_names_; }
// first input is required to be auth token. validate that and return it.
std::string GetAuthToken(const ortc::Variadic& inputs) const;
/// <summary>
/// Parse the property name to use in the request to the cloud endpoint from a node input name.
/// Value returned is text following last '/', or the entire string if no '/'.
/// e.g. 'node0/prompt' -> 'prompt'
/// </summary>
/// <param name="input_name">Node input name.</param>
/// <returns>Request property name the input is providing data for.</returns>
static std::string GetPropertyNameFromInputName(const std::string& input_name);
private:
std::string model_uri_;
std::string model_name_;
std::string model_ver_;
bool verbose_;
std::vector<std::string> input_names_;
std::vector<std::string> property_names_;
std::vector<std::string> output_names_;
};
} // namespace ort_extensions
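
For illustration, the input-name to request-property parsing behaves like this standalone sketch (a re-statement of GetPropertyNameFromInputName without the trailing-'/' error check; not the class member itself):

#include <cassert>
#include <string>

// Standalone sketch of the property-name parsing: text after the last '/',
// or the whole name when there is no '/'.
std::string PropertyNameFromInputName(const std::string& input_name) {
  auto idx = input_name.find_last_of('/');
  return idx == std::string::npos ? input_name : input_name.substr(idx + 1);
}

int main() {
  assert(PropertyNameFromInputName("node0/prompt") == "prompt");     // scoped input name
  assert(PropertyNameFromInputName("auth_token") == "auth_token");   // no '/' -> unchanged
  return 0;
}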


@@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "ocos.h"
#include "azure_invokers.hpp"
#include "openai_invokers.hpp"
#ifdef AZURE_INVOKERS_ENABLE_TRITON
#include "azure_triton_invoker.hpp"
#endif
using namespace ort_extensions;
const std::vector<const OrtCustomOp*>& AzureInvokerLoader() {
static OrtOpLoader op_loader(CustomAzureStruct("AzureAudioToText", AzureAudioToTextInvoker),
CustomCpuStruct("AzureAudioToText", AzureAudioToTextInvoker),
CustomAzureStruct("AzureTextToText", AzureTextToTextInvoker),
CustomCpuStruct("AzureTextToText", AzureTextToTextInvoker),
CustomAzureStruct("OpenAIAudioToText", OpenAIAudioToTextInvoker),
CustomCpuStruct("OpenAIAudioToText", OpenAIAudioToTextInvoker)
#ifdef AZURE_INVOKERS_ENABLE_TRITON
,
CustomAzureStruct("AzureTritonInvoker", AzureTritonInvoker),
CustomCpuStruct("AzureTritonInvoker", AzureTritonInvoker)
#endif
);
return op_loader.GetCustomOps();
}
FxLoadCustomOpFactory LoadCustomOpClasses_Azure = AzureInvokerLoader;
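
For context, an application picks these ops up by registering the built extensions library with an onnxruntime session. A minimal sketch using the onnxruntime C++ API (ORT >= 1.14 as required above; library and model paths are placeholders, and the exact registration call may vary by ORT version):

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "azure_ops"};
  Ort::SessionOptions session_options;

  // Load the shared library built by this repo so the Azure/OpenAI custom ops
  // (e.g. AzureTextToText) are available to the session. Path is a placeholder.
  session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libortextensions.so"));

  // Model using the custom ops; the first input must be the auth_token string.
  Ort::Session session{env, ORT_TSTR("model_with_azure_op.onnx"), session_options};
  return 0;
}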


@ -0,0 +1,85 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "curl_invoker.hpp"
#include <iostream> // TEMP error output
#include <sstream>
namespace ort_extensions {
// apply the callback only when the response is guaranteed to be a '\0'-terminated string
size_t CurlHandler::WriteStringCallback(char* contents, size_t element_size, size_t num_elements, void* userdata) {
try {
size_t bytes = element_size * num_elements;
std::string& buffer = *static_cast<std::string*>(userdata);
buffer.append(contents, bytes);
return bytes;
} catch (const std::exception& ex) {
// TODO: This should be captured/logged properly
std::cerr << ex.what() << std::endl;
return 0;
} catch (...) {
// exception caught, abort write
std::cerr << "Unknown exception caught in CurlHandler::WriteStringCallback" << std::endl;
return 0;
}
}
CurlHandler::CurlHandler(WriteCallBack callback) : curl_(curl_easy_init(), curl_easy_cleanup),
headers_(nullptr, curl_slist_free_all),
from_holder_(from_, curl_formfree) {
CURL* curl = curl_.get(); // CURL == void* so can't dereference
curl_easy_setopt(curl, CURLOPT_BUFFERSIZE, 100 * 1024L); // how was this size chosen? should it be set on a per operator basis?
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L);
curl_easy_setopt(curl, CURLOPT_USERAGENT, "curl/7.83.1"); // should this value come from the curl src instead of being hardcoded?
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 50L); // 50 seems like a lot if we're directly calling a specific endpoint
curl_easy_setopt(curl, CURLOPT_FTP_SKIP_PASV_IP, 1L); // what does this have to do with http requests?
curl_easy_setopt(curl, CURLOPT_TCP_KEEPALIVE, 1L);
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, callback);
// should this be configured via a node attribute? different endpoints may have different timeouts
  curl_easy_setopt(curl, CURLOPT_TIMEOUT, 15L);  // curl_easy_setopt expects a long, so use the L suffix
}
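
The comment above asks whether the timeout should be configurable per node. A hedged sketch of that idea, assuming a hypothetical 'timeout_seconds' attribute read with the TryToGetAttributeWithDefault helper used elsewhere in this change:

// in a CurlInvoker-derived constructor: read an optional per-node timeout attribute
int64_t timeout_s = TryToGetAttributeWithDefault<int64_t>("timeout_seconds", 15);

// in ComputeImpl: apply it to the handler instead of the hardcoded default
curl_handler.SetOption(CURLOPT_TIMEOUT, static_cast<long>(timeout_s));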
////////////////////// CurlInvoker //////////////////////
CurlInvoker::CurlInvoker(const OrtApi& api, const OrtKernelInfo& info)
: CloudBaseKernel(api, info) {
}
void CurlInvoker::ComputeImpl(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
std::string auth_token = GetAuthToken(inputs);
if (inputs.Size() != InputNames().size()) {
ORTX_CXX_API_THROW("input count mismatch", ORT_RUNTIME_EXCEPTION);
}
// do any additional validation of the number and type of inputs/outputs
ValidateInputs(inputs);
// set the options for the curl handler that apply to all usages
CurlHandler curl_handler(CurlHandler::WriteStringCallback);
std::string full_auth = std::string{"Authorization: Bearer "} + auth_token;
curl_handler.AddHeader(full_auth.c_str());
curl_handler.SetOption(CURLOPT_URL, ModelUri().c_str());
curl_handler.SetOption(CURLOPT_VERBOSE, Verbose());
std::string response;
curl_handler.SetOption(CURLOPT_WRITEDATA, (void*)&response);
SetupRequest(curl_handler, inputs);
ExecuteRequest(curl_handler);
ProcessResponse(response, outputs);
}
void CurlInvoker::ExecuteRequest(CurlHandler& curl_handler) const {
// this is where we could add any logic required to make the request async or maybe handle retries/cancellation.
auto curl_ret = curl_handler.Perform();
if (CURLE_OK != curl_ret) {
ORTX_CXX_API_THROW(curl_easy_strerror(curl_ret), ORT_FAIL);
}
}
} // namespace ort_extensions
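
As the comment in ExecuteRequest notes, that function is the natural seam for retry or cancellation logic. A minimal sketch of what bounded retries might look like, assuming a fixed attempt count and linear backoff (both values are illustrative, not part of this change):

#include <chrono>
#include <thread>

void CurlInvoker::ExecuteRequest(CurlHandler& curl_handler) const {
  constexpr int max_attempts = 3;  // hypothetical retry policy
  CURLcode curl_ret = CURLE_OK;
  for (int attempt = 1; attempt <= max_attempts; ++attempt) {
    curl_ret = curl_handler.Perform();
    if (curl_ret == CURLE_OK) {
      return;
    }
    if (attempt < max_attempts) {
      // simple linear backoff before the next attempt
      std::this_thread::sleep_for(std::chrono::milliseconds(100 * attempt));
    }
  }
  ORTX_CXX_API_THROW(curl_easy_strerror(curl_ret), ORT_FAIL);
}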

View File

@ -0,0 +1,101 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <memory>
#include "curl/curl.h"
#include "ocos.h"
#include "cloud_base_kernel.hpp"
namespace ort_extensions {
class CurlHandler {
public:
using WriteCallBack = size_t (*)(char*, size_t, size_t, void*);
CurlHandler(WriteCallBack callback);
~CurlHandler() = default;
/// <summary>
/// Callback to add contents to a string
/// </summary>
/// <seealso cref="https://curl.se/libcurl/c/CURLOPT_WRITEFUNCTION.html"/>
/// <returns>Bytes processed. If this does not match element_size * num_elements the libcurl function
/// used will return CURLE_WRITE_ERROR</returns>
static size_t WriteStringCallback(char* contents, size_t element_size, size_t num_elements, void* userdata);
void AddHeader(const char* data) {
headers_.reset(curl_slist_append(headers_.release(), data));
}
template <typename... Args>
void AddForm(Args... args) {
curl_formadd(&from_, &last_, args...);
}
void AddFormString(const char* name, const char* value) {
AddForm(CURLFORM_COPYNAME, name,
CURLFORM_COPYCONTENTS, value,
CURLFORM_END);
}
void AddFormBuffer(const char* name, const char* buffer_name, const void* buffer_ptr, size_t buffer_len) {
AddForm(CURLFORM_COPYNAME, name,
CURLFORM_BUFFER, buffer_name,
CURLFORM_BUFFERPTR, buffer_ptr,
CURLFORM_BUFFERLENGTH, buffer_len,
CURLFORM_END);
}
template <typename T>
void SetOption(CURLoption opt, T val) {
curl_easy_setopt(curl_.get(), opt, val);
}
CURLcode Perform() {
SetOption(CURLOPT_HTTPHEADER, headers_.get());
if (from_) {
SetOption(CURLOPT_HTTPPOST, from_);
}
return curl_easy_perform(curl_.get());
}
private:
std::unique_ptr<CURL, decltype(curl_easy_cleanup)*> curl_;
std::unique_ptr<curl_slist, decltype(curl_slist_free_all)*> headers_;
curl_httppost* from_{};
curl_httppost* last_{};
  std::unique_ptr<curl_httppost, decltype(curl_formfree)*> from_holder_;  // owns the whole form list. 'last_' is just a tail pointer into the same list, so it needs no separate holder.
};
/// <summary>
/// Base class for requests using Curl
/// </summary>
class CurlInvoker : public CloudBaseKernel {
protected:
CurlInvoker(const OrtApi& api, const OrtKernelInfo& info);
virtual ~CurlInvoker() = default;
// Compute implementation that is used to co-ordinate all Curl based Azure requests.
// Derived classes need their own Compute to work with the CustomOpLite infrastructure
void ComputeImpl(const ortc::Variadic& inputs, ortc::Variadic& outputs) const;
private:
void ExecuteRequest(CurlHandler& handler) const;
// Derived classes can add any arg validation required.
// Prior to this being called, `inputs` are validated to match the number of input names, and
// the auth_token has been read from input[0] so validation can skip that.
//
// the ortc::Variadic outputs are empty until the Compute populates it, so only output names can be validated
// and those are available from the base class.
virtual void ValidateInputs(const ortc::Variadic& inputs) const {}
// curl_handler has auth token set from input[0].
virtual void SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const = 0;
virtual void ProcessResponse(const std::string& response, ortc::Variadic& outputs) const = 0;
};
} // namespace ort_extensions
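
To make the contract concrete, here is a hypothetical minimal derived invoker (the class name and request shape are illustrative only, not part of this change). It forwards Compute to the base and implements the two pure virtuals:

class EchoInvoker final : public CurlInvoker {
 public:
  EchoInvoker(const OrtApi& api, const OrtKernelInfo& info) : CurlInvoker(api, info) {}

  // derived classes provide Compute for the CustomOpLite infrastructure and forward to the base
  void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
    ComputeImpl(inputs, outputs);
  }

 private:
  void SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const override {
    curl_handler.AddHeader("Content-Type: application/json");
    // assumes input 1 is a null-terminated JSON string to use as the request body
    curl_handler.SetOption(CURLOPT_POSTFIELDS, inputs[1]->DataRaw());
  }

  void ProcessResponse(const std::string& response, ortc::Variadic& outputs) const override {
    auto& string_tensor = outputs.AllocateStringTensor(0);
    string_tensor.SetStringOutput(std::vector<std::string>{response}, std::vector<int64_t>{1});
  }
};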

View File

@ -0,0 +1,103 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "openai_invokers.hpp"
namespace ort_extensions {
OpenAIAudioToTextInvoker::OpenAIAudioToTextInvoker(const OrtApi& api, const OrtKernelInfo& info)
: CurlInvoker(api, info) {
audio_format_ = TryToGetAttributeWithDefault<std::string>(kAudioFormat, "");
const auto& property_names = RequestPropertyNames();
const auto find_optional_input = [&property_names](const std::string& property_name) {
std::optional<size_t> result;
auto optional_input = std::find_if(property_names.begin(), property_names.end(),
[&property_name](const auto& name) { return name == property_name; });
if (optional_input != property_names.end()) {
result = optional_input - property_names.begin();
}
return result;
};
filename_input_ = find_optional_input("filename");
model_name_input_ = find_optional_input("model");
// OpenAI audio endpoints require 'file' and 'model'.
if (!std::any_of(property_names.begin(), property_names.end(),
[](const auto& name) { return name == "file"; })) {
ORTX_CXX_API_THROW("Required 'file' input was not found", ORT_INVALID_ARGUMENT);
}
if (ModelName().empty() && !model_name_input_) {
ORTX_CXX_API_THROW("Required 'model' input was not found", ORT_INVALID_ARGUMENT);
}
}
void OpenAIAudioToTextInvoker::ValidateInputs(const ortc::Variadic& inputs) const {
// We don't have a way to get the output type from the custom op API.
// If there's a mismatch it will fail in the Compute when it allocates the output tensor.
if (OutputNames().size() != 1) {
ORTX_CXX_API_THROW("Expected single output", ORT_INVALID_ARGUMENT);
}
}
void OpenAIAudioToTextInvoker::SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const {
  // notionally the filename the content was buffered from; provides the extension indicating the audio format.
  // note: not 'static' as audio_format_ is a per-instance member and different nodes may use different formats.
  const std::string fake_filename = "audio." + audio_format_;
const auto& property_names = RequestPropertyNames();
const auto& get_optional_input =
[&](const std::optional<size_t>& input_idx, const std::string& default_value, size_t min_size = 1) {
return (input_idx.has_value() && inputs[*input_idx]->SizeInBytes() > min_size)
? static_cast<const char*>(inputs[*input_idx]->DataRaw())
: default_value.c_str();
};
  // filename_input_ is optional in a model. If it's not present, use a fake filename.
  // If it is present, make sure it's not a default empty value. As the filename needs to have an extension of
  // mp3, mp4, mpeg, mpga, m4a, wav, or webm, it must be at least 4 characters long.
const char* filename = get_optional_input(filename_input_, fake_filename, 4);
curl_handler.AddHeader("Content-Type: multipart/form-data");
// model name could be input or attribute
curl_handler.AddFormString("model", get_optional_input(model_name_input_, ModelName()));
for (size_t ith_input = 1; ith_input < inputs.Size(); ++ith_input) {
switch (inputs[ith_input]->Type()) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
curl_handler.AddFormString(property_names[ith_input].c_str(),
                                   // assumes null terminated.
                                   // might be safer to pass pointer and length and use CURLFORM_CONTENTSLENGTH
static_cast<const char*>(inputs[ith_input]->DataRaw()));
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
// only the 'file' input is uint8
if (property_names[ith_input] != "file") {
ORTX_CXX_API_THROW("Only the 'file' input should be uint8 data. Invalid input:" + InputNames()[ith_input],
ORT_INVALID_ARGUMENT);
}
curl_handler.AddFormBuffer(property_names[ith_input].c_str(),
filename,
inputs[ith_input]->DataRaw(),
inputs[ith_input]->SizeInBytes());
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
// TODO - required to support 'temperature' input.
default:
ORTX_CXX_API_THROW("input must be either text or binary", ORT_INVALID_ARGUMENT);
break;
}
}
}
void OpenAIAudioToTextInvoker::ProcessResponse(const std::string& response, ortc::Variadic& outputs) const {
auto& string_tensor = outputs.AllocateStringTensor(0);
string_tensor.SetStringOutput(std::vector<std::string>{response}, std::vector<int64_t>{1});
}
} // namespace ort_extensions
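
The single string output is the raw JSON response from the endpoint, so downstream code is expected to parse it. For illustration, a hedged sketch using nlohmann::json (already used elsewhere in this project); the 'text' field matches the documented OpenAI transcription response:

#include <nlohmann/json.hpp>

std::string ExtractTranscriptionText(const std::string& response) {
  // the transcription endpoint returns e.g. {"text": "..."} for the default json response_format
  const auto parsed = nlohmann::json::parse(response);
  return parsed.value("text", std::string{});
}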

View File

@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
#include <optional>
#include "curl_invoker.hpp"
namespace ort_extensions {
////////////////////// OpenAIAudioToTextInvoker //////////////////////
/// <summary>
/// OpenAI Audio to Text
/// Input: auth_token {string}, Request body values {string|uint8} as per https://platform.openai.com/docs/api-reference/audio
/// Output: text {string}
/// </summary>
/// <remarks>
/// The model URI is read from the node attributes.
/// The model name (e.g. 'whisper-1') can be provided as a node attribute or via an input.
///
/// Example input would be:
/// - string tensor named `auth_token` (required, must be first input)
/// - a uint8 tensor named `file` with audio data in the format matching the 'audio_format' attribute (required)
/// - see OpenAI documentation for current supported audio formats
/// - a string tensor named `filename` (optional) with extension indicating the format of the audio data
/// - e.g. 'audio.mp3'
/// - a string tensor named `prompt` (optional)
///
/// NOTE: 'temperature' is not currently supported.
/// </remarks>
class OpenAIAudioToTextInvoker final : public CurlInvoker {
public:
OpenAIAudioToTextInvoker(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
// use impl from CurlInvoker
ComputeImpl(inputs, outputs);
}
private:
void ValidateInputs(const ortc::Variadic& inputs) const override;
void SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const override;
void ProcessResponse(const std::string& response, ortc::Variadic& outputs) const override;
// audio format to use if the optional 'filename' input is not provided
static constexpr const char* kAudioFormat = "audio_format";
std::string audio_format_;
std::optional<size_t> filename_input_; // optional override for generated filename using audio_format
std::optional<size_t> model_name_input_; // optional override for model_name attribute
};
} // namespace ort_extensions

View File

@ -0,0 +1,38 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
set -e
set -u
set -x
# export skip_checkout if you want to repeat a build
if [ -z ${skip_checkout+x} ]; then
git clone https://github.com/leenjewel/openssl_for_ios_and_android.git
cd openssl_for_ios_and_android
git checkout ci-release-663da9e2
# patch with fixes to build on linux with NDK 25 or later
git apply ../build_curl_for_android_on_linux.patch
else
echo "Skipping checkout and patch"
cd openssl_for_ios_and_android
fi
cd tools
# we target Android API level 24
export api=24
# provide a specific architecture as an argument to the script to limit the build to that
# default is to build all
# valid architecture values: "arm" "arm64" "x86" "x86_64"
if [ $# -eq 1 ]; then
arch=$1
./build-android-openssl.sh $arch
./build-android-nghttp2.sh $arch
./build-android-curl.sh $arch
else
./build-android-openssl.sh
./build-android-nghttp2.sh
./build-android-curl.sh
fi

View File

@ -0,0 +1,67 @@
diff --git a/tools/build-android-common.sh b/tools/build-android-common.sh
index 87df207..797d58a 100755
--- a/tools/build-android-common.sh
+++ b/tools/build-android-common.sh
@@ -148,13 +148,20 @@ function set_android_toolchain() {
local build_host=$(get_build_host_internal "$arch")
local clang_target_host=$(get_clang_target_host "$arch" "$api")
- export AR=${build_host}-ar
+ # NDK r23 removed a bunch of GNU things and replaced with llvm
+ # https://stackoverflow.com/questions/73105626/arm-linux-androideabi-ar-command-not-found-in-ndk
+ # export AR=${build_host}-ar
+ export AR=llvm-ar
export CC=${clang_target_host}-clang
export CXX=${clang_target_host}-clang++
- export AS=${build_host}-as
- export LD=${build_host}-ld
- export RANLIB=${build_host}-ranlib
+ #export AS=${build_host}-as
+ export AS=llvm-as
+ #export LD=${build_host}-ld
+ export LD=ld
+ # export RANLIB=${build_host}-ranlib
+ export RANLIB=llvm-ranlib
export STRIP=${build_host}-strip
+ export STRIP=llvm-strip
}
function get_common_includes() {
@@ -187,13 +194,13 @@ function set_android_cpu_feature() {
export CPPFLAGS=${CFLAGS}
;;
x86)
- export CFLAGS="-march=i686 -mtune=intel -mssse3 -mfpmath=sse -m32 -Wno-unused-function -fno-integrated-as -fstrict-aliasing -fPIC -DANDROID -D__ANDROID_API__=${api} -Os -ffunction-sections -fdata-sections $(get_common_includes)"
+ export CFLAGS="-march=i686 -mtune=native -mssse3 -mfpmath=sse -m32 -Wno-unused-function -fno-integrated-as -fstrict-aliasing -fPIC -DANDROID -D__ANDROID_API__=${api} -Os -ffunction-sections -fdata-sections $(get_common_includes)"
export CXXFLAGS="-std=c++14 -Os -ffunction-sections -fdata-sections"
export LDFLAGS="-march=i686 -Wl,--gc-sections -Os -ffunction-sections -fdata-sections $(get_common_linked_libraries ${api} ${arch})"
export CPPFLAGS=${CFLAGS}
;;
x86-64)
- export CFLAGS="-march=x86-64 -msse4.2 -mpopcnt -m64 -mtune=intel -Wno-unused-function -fno-integrated-as -fstrict-aliasing -fPIC -DANDROID -D__ANDROID_API__=${api} -Os -ffunction-sections -fdata-sections $(get_common_includes)"
+ export CFLAGS="-march=x86-64 -msse4.2 -mpopcnt -m64 -mtune=native -Wno-unused-function -fno-integrated-as -fstrict-aliasing -fPIC -DANDROID -D__ANDROID_API__=${api} -Os -ffunction-sections -fdata-sections $(get_common_includes)"
export CXXFLAGS="-std=c++14 -Os -ffunction-sections -fdata-sections"
export LDFLAGS="-march=x86-64 -Wl,--gc-sections -Os -ffunction-sections -fdata-sections $(get_common_linked_libraries ${api} ${arch})"
export CPPFLAGS=${CFLAGS}
diff --git a/tools/build-android-openssl.sh b/tools/build-android-openssl.sh
index e13c314..5660cec 100755
--- a/tools/build-android-openssl.sh
+++ b/tools/build-android-openssl.sh
@@ -17,7 +17,7 @@
# # read -n1 -p "Press any key to continue..."
set -u
-
+set -x
source ./build-android-common.sh
if [ -z ${version+x} ]; then
@@ -115,6 +115,8 @@ function configure_make() {
if [ $the_rc -eq 0 ] ; then
make SHLIB_EXT='.so' install_sw >>"${OUTPUT_ROOT}/log/${ABI}.log" 2>&1
make install_ssldirs >>"${OUTPUT_ROOT}/log/${ABI}.log" 2>&1
+ else
+ log_error "make returned $the_rc"
fi
popd

prebuild/readme.md Normal file
View File

@ -0,0 +1,66 @@
# Mobile Azure EP pre-build
Libraries that need to be manually prebuilt for the Azure operators on Android and iOS.
There is no simple cmake setup that works, so we prebuild as a one-off.
## Requirements:
- pkg-config
- Android
- Android SDK installed with NDK 25 or later
- You can install it as a system package, but that means you have to use `sudo` for all updates, such as installing an NDK
- https://stackoverflow.com/questions/34556884/how-to-install-android-sdk-on-ubuntu
- you still need to manually add the cmdline-tools to that package as well
- it is probably easier to create a per-user install using the command line tools
- Using command line tools
- Download the command line tools from https://developer.android.com/studio
- Download the 'Command line tools only' and unzip
- `mkdir ~/Android`
- `unzip commandlinetools-linux-9477386_latest.zip`
- `mkdir -p ~/Android/cmdline-tools/latest`
- `mv cmdline-tools/* ~/Android/cmdline-tools/latest`
- `export ANDROID_HOME=~/Android`
- Add these to PATH
- ~/Android/cmdline-tools/latest/bin
- ~/Android/platform-tools/bin
- `sdkmanager --list` to make sure the setup works
- Install platform-tools and latest NDK
- `sdkmanager --install platform-tools`
- e.g. `sdkmanager --install "ndk;25.2.9519653"` (quotes needed so the shell does not treat ';' as a command separator)
That should be enough to build.
e.g. `./build_lib.sh --android --android_api=24 --android_home=/home/me/Android --android_abi=x86_64 --android_ndk_path=/home/me/Android/ndk/25.2.9519653 --enable_cxx_tests`
See Android documentation for installing a system image with `sdkmanager` and
creating an emulator with `avdmanager`.
- iOS
- TBD
## Android build
Export ANDROID_NDK_ROOT with the value set to the NDK path, as it is used by the build script
- e.g. `export ANDROID_NDK_ROOT=~/Android/ndk/25.2.9519653`
From this directory run `./build_curl_for_android.sh`
An architecture can optionally be specified as the first argument to limit the build to that architecture.
Otherwise all 4 architectures (arm, arm64, x86, x86_64) will be built.
e.g. if you just want to build locally for the emulator you can do `./build_curl_for_android.sh x86_64`
## Android testing
Build with `--enable_cxx_tests`.
This should result in the 'bin' directory of the build output having the two test executables.
Create/start an Android emulator
Use `adb push` to copy the bin, lib and data directories from the build output to the /data/local/tmp directory
- `adb push build/Android/bin /data/local/tmp`
- repeat for 'lib' and 'data'
- copy the onnxruntime shared library to the lib dir (adjust version number as needed)
- adjust architecture as needed (most likely x86_64 for the emulator, arm64 for a device)
- `adb push build/Android/Debug/_deps/onnxruntime-src/jni/x86_64/libonnxruntime.so /data/local/tmp/lib`
- Connect to the emulator
- `adb shell`
- `cd /data/local/tmp`
- Add the path to the .so
- `export LD_LIBRARY_PATH=/data/local/tmp/lib:$LD_LIBRARY_PATH`
- Make the tests executable
- `chmod +x bin/ocos_test`
- `chmod +x bin/extensions_test`
- Run the tests from the `tmp` dir so paths to `data` are as expected
- `./bin/ocos_test`
- `./bin/extensions_test`

View File

@ -154,9 +154,6 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
#endif
#if defined(ENABLE_DR_LIBS)
LoadCustomOpClasses_Audio,
#endif
#if defined(ENABLE_AZURE)
LoadCustomOpClasses_Azure,
#endif
LoadCustomOpClasses<>
};
@ -187,6 +184,10 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
return status;
}
//
// New custom ops should use the com.microsoft.extensions domain.
//
// Create domain for ops using the new domain name.
if (status = ortApi->CreateCustomOpDomain(c_ComMsExtOpDomain, &domain); status) {
return status;
@ -200,6 +201,9 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
#endif
#if defined(ENABLE_TOKENIZER)
LoadCustomOpClasses_Tokenizer,
#endif
#if defined(ENABLE_AZURE)
LoadCustomOpClasses_Azure,
#endif
LoadCustomOpClasses<>
};

Binary file not shown.

View File

@ -0,0 +1,52 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import onnx
from onnx import helper, TensorProto
def create_audio_model():
auth_token = helper.make_tensor_value_info('auth_token', TensorProto.STRING, [1])
model = helper.make_tensor_value_info('model', TensorProto.STRING, [1])
response_format = helper.make_tensor_value_info('response_format', TensorProto.STRING, [-1])
file = helper.make_tensor_value_info('file', TensorProto.UINT8, [-1])
transcriptions = helper.make_tensor_value_info('transcriptions', TensorProto.STRING, [-1])
invoker = helper.make_node('OpenAIAudioToText',
['auth_token', 'model', 'response_format', 'file'],
['transcriptions'],
domain='com.microsoft.extensions',
name='audio_invoker',
model_uri='https://api.openai.com/v1/audio/transcriptions',
audio_format='wav',
verbose=False)
graph = helper.make_graph([invoker], 'graph', [auth_token, model, response_format, file], [transcriptions])
model = helper.make_model(graph,
opset_imports=[helper.make_operatorsetid('com.microsoft.extensions', 1)])
onnx.save(model, 'openai_audio.onnx')
def create_chat_model():
auth_token = helper.make_tensor_value_info('auth_token', TensorProto.STRING, [-1])
chat = helper.make_tensor_value_info('chat', TensorProto.STRING, [-1])
response = helper.make_tensor_value_info('response', TensorProto.STRING, [-1])
invoker = helper.make_node('AzureTextToText', ['auth_token', 'chat'], ['response'],
domain='com.microsoft.extensions',
name='chat_invoker',
model_uri='https://api.openai.com/v1/chat/completions',
verbose=False)
graph = helper.make_graph([invoker], 'graph', [auth_token, chat], [response])
model = helper.make_model(graph,
opset_imports=[helper.make_operatorsetid('com.microsoft.extensions', 1)])
onnx.save(model, 'openai_chat.onnx')
create_audio_model()
create_chat_model()

View File

@ -0,0 +1,75 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from onnx import helper, numpy_helper, TensorProto
import onnx
import numpy as np
import sys
def order_repeated_field(repeated_proto, key_name, order):
order = list(order)
repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name)))
def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs):
node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs)
if doc_string == '':
node.doc_string = ''
order_repeated_field(node.attribute, 'name', kwargs.keys())
return node
def make_graph(*args, doc_string=None, **kwargs):
graph = helper.make_graph(*args, doc_string=doc_string, **kwargs)
if doc_string == '':
graph.doc_string = ''
return graph
# This creates a model that allows the prompt and filename to be optionally provided as inputs.
# The filename can be specified to indicate a different audio type to the default value in the audio_format attribute.
model = helper.make_model(
opset_imports=[helper.make_operatorsetid('com.microsoft.extensions', 1)],
graph=make_graph(
name='OpenAIWhisperTranscribe',
initializer=[
# add default values in the initializers to make the model inputs optional
helper.make_tensor('transcribe0/filename', TensorProto.STRING, [1], [b""]),
helper.make_tensor('transcribe0/prompt', TensorProto.STRING, [1], [b""])
],
inputs=[
helper.make_tensor_value_info('auth_token', TensorProto.STRING, shape=[1]),
helper.make_tensor_value_info('transcribe0/file', TensorProto.UINT8, shape=["bytes"]),
helper.make_tensor_value_info('transcribe0/filename', TensorProto.STRING, shape=["bytes"]), # optional
helper.make_tensor_value_info('transcribe0/prompt', TensorProto.STRING, shape=["bytes"]), # optional
],
outputs=[helper.make_tensor_value_info('transcription', TensorProto.STRING, shape=[1])],
nodes=[
make_node(
'OpenAIAudioToText',
# additional optional request inputs that could be added:
# response_format, temperature, language
# Using a prefix for input names allows the model to have multiple nodes calling cloud endpoints.
# auth_token does not need a prefix unless different auth tokens are used for different nodes.
inputs=['auth_token', 'transcribe0/file', 'transcribe0/filename', 'transcribe0/prompt'],
outputs=['transcription'],
name='OpenAIAudioToText0',
domain='com.microsoft.extensions',
audio_format='wav', # default audio type if filename is not specified.
model_uri='https://api.openai.com/v1/audio/transcriptions',
model_name='whisper-1',
verbose=0,
),
],
),
)
if __name__ == '__main__':
out_path = "openai_whisper_transcriptions.onnx"
if len(sys.argv) == 2:
out_path = sys.argv[1]
onnx.save(model, out_path)

Binary data
test/data/azure/openai_audio.onnx

Binary file not shown.

Binary data
test/data/azure/openai_chat.onnx

Binary file not shown.

View File

@ -1,17 +0,0 @@
[removed file: the previous OpenAI embedding test model (evidently test/data/azure/openai_embedding.onnx) — an AzureTextInvoker node in the ai.onnx.contrib domain with inputs 'auth_token' and 'text', output 'embedding', and model_uri https://api.openai.com/v1/embeddings. The serialized ONNX content is not rendered.]
Binary data
test/data/azure/openai_whisper_transcriptions.onnx Normal file

Binary file not shown.

Binary data
test/data/azure/self-destruct-button.wav Normal file

Binary file not shown.

View File

@ -0,0 +1,106 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_AZURE
#include <cstdlib>
#include "gtest/gtest.h"
#include "ocos.h"
#include "narrow.h"
#include "test_kernel.hpp"
#include "utils.hpp"
using namespace ort_extensions;
using namespace ort_extensions::test;
// Test custom op with OpenAIAudioInvoker calling Whisper
// Default input format. No prompt.
TEST(AzureOps, OpenAIWhisper_basic) {
const char* auth_token = std::getenv("OPENAI_AUTH_TOKEN");
if (auth_token == nullptr) {
GTEST_SKIP() << "OPENAI_AUTH_TOKEN environment variable was not set.";
}
auto data_dir = std::filesystem::current_path() / "data" / "azure";
auto model_path = data_dir / "openai_whisper_transcriptions.onnx";
auto audio_path = data_dir / "self-destruct-button.wav";
std::vector<uint8_t> audio_data = LoadBytesFromFile(audio_path);
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
std::vector<TestValue> inputs{TestValue("auth_token", {std::string(auth_token)}, {1}),
TestValue("transcribe0/file", audio_data, {narrow<int64_t>(audio_data.size())})};
// punctuation can differ between calls to OpenAI Whisper. sometimes there's a comma after 'button' and sometimes
// a full stop. use a custom output validator that looks for substrings in the output that aren't affected by this.
std::vector<std::string> expected_output{"Thank you for pressing the self-destruct button",
"ship will self-destruct in three minutes"};
// dims are set to '{1}' as we expect one string output. the expected_output is the collection of substrings to look
// for in the single output
std::vector<TestValue> outputs{TestValue("transcription", expected_output, {1})};
OutputValidator find_strings_in_output =
[](size_t output_idx, Ort::Value& actual, TestValue expected) {
std::vector<std::string> output_string;
GetTensorMutableDataString(Ort::GetApi(), actual, output_string);
ASSERT_EQ(output_string.size(), 1) << "Expected the Whisper response to be a single string with json";
for (auto& expected_substring : expected.values_string) {
if (output_string[0].find(expected_substring) == std::string::npos) {
FAIL() << "'" << expected_substring << "' was not found in output " << output_string[0];
}
}
};
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath(), find_strings_in_output);
}
// test calling Whisper with a filename to provide mp3 instead of the default wav, and the optional prompt
TEST(AzureOps, OpenAIWhisper_Prompt_CustomFormat) {
const char* auth_token = std::getenv("OPENAI_AUTH_TOKEN");
if (auth_token == nullptr) {
GTEST_SKIP() << "OPENAI_AUTH_TOKEN environment variable was not set.";
}
std::string ort_version{OrtGetApiBase()->GetVersionString()};
auto data_dir = std::filesystem::current_path() / "data" / "azure";
auto model_path = data_dir / "openai_whisper_transcriptions.onnx";
auto audio_path = data_dir / "be-a-man-take-some-pepto-bismol-get-dressed-and-come-on-over-here.mp3";
std::vector<uint8_t> audio_data = LoadBytesFromFile(audio_path);
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
// provide filename with 'mp3' extension to indicate audio format. doesn't need to be the 'real' filename
std::vector<TestValue> inputs{TestValue("auth_token", {std::string(auth_token)}, {1}),
TestValue("transcribe0/file", audio_data, {narrow<int64_t>(audio_data.size())}),
TestValue("transcribe0/filename", {std::string("audio.mp3")}, {1})};
std::vector<std::string> expected_output = {"Take some Pepto-Bismol, get dressed, and come on over here."};
std::vector<TestValue> outputs{TestValue("transcription", expected_output, {1})};
OutputValidator find_strings_in_output =
[](size_t output_idx, Ort::Value& actual, TestValue expected) {
std::vector<std::string> output_string;
GetTensorMutableDataString(Ort::GetApi(), actual, output_string);
ASSERT_EQ(output_string.size(), 1) << "Expected the Whisper response to be a single string with json";
const auto& expected_substring = expected.values_string[0];
if (output_string[0].find(expected_substring) == std::string::npos) {
FAIL() << "'" << expected_substring << "' was not found in output " << output_string[0];
}
};
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath(), find_strings_in_output);
// use optional 'prompt' input to mis-spell Pepto-Bismol in response
std::string prompt = "Peptoe-Bismole";
inputs.push_back(TestValue("transcribe0/prompt", {prompt}, {1}));
outputs[0].values_string[0] = "Take some Peptoe-Bismole, get dressed, and come on over here.";
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath(), find_strings_in_output);
}
#endif

View File

@ -12,7 +12,7 @@
#include "ocos.h"
#include "test_kernel.hpp"
#include "test_utils.hpp"
#include "utils.hpp"
using namespace ort_extensions::test;

test/shared_test/utils.cc Normal file
View File

@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "utils.hpp"
#include <fstream>
namespace ort_extensions {
namespace test {
std::vector<uint8_t> LoadBytesFromFile(const std::filesystem::path& filename) {
using namespace std;
ifstream ifs(filename, ios::binary | ios::ate);
ifstream::pos_type pos = ifs.tellg();
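  // note: assumes the file exists and opened successfully. tellg() returns -1 on failure, which would make
  // the vector size below invalid.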
std::vector<uint8_t> input_bytes(pos);
ifs.seekg(0, ios::beg);
  // we want uint8_t values, so reinterpret_cast to avoid reading chars and copying to uint8_t afterwards.
ifs.read(reinterpret_cast<char*>(input_bytes.data()), pos);
return input_bytes;
}
} // namespace test
} // namespace ort_extensions

View File

@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cstdint>
#include <filesystem>
#include <vector>
namespace ort_extensions {
namespace test {
std::vector<uint8_t> LoadBytesFromFile(const std::filesystem::path& filename);
} // namespace test
} // namespace ort_extensions

View File

@ -12,6 +12,7 @@ script_dir = os.path.dirname(os.path.realpath(__file__))
ort_ext_root = os.path.abspath(os.path.join(script_dir, ".."))
test_data_dir = os.path.join(ort_ext_root, "test", "data", "azure")
class TestAzureOps(unittest.TestCase):
def __init__(self, config):
@ -21,7 +22,7 @@ class TestAzureOps(unittest.TestCase):
self.__opt = SessionOptions()
self.__opt.register_custom_ops_library(get_library_path())
def test_addf(self):
def test_add_f(self):
if self.__enabled:
sess = InferenceSession(os.path.join(test_data_dir, "triton_addf.onnx"),
self.__opt, providers=["CPUExecutionProvider"])
@ -36,7 +37,7 @@ class TestAzureOps(unittest.TestCase):
out = sess.run(None, ort_inputs)[0]
self.assertTrue(np.allclose(out, [5,5,5,5]))
def testAddf8(self):
def test_add_f8(self):
if self.__enabled:
opt = SessionOptions()
opt.register_custom_ops_library(get_library_path())
@ -53,7 +54,7 @@ class TestAzureOps(unittest.TestCase):
out = sess.run(None, ort_inputs)[0]
self.assertTrue(np.allclose(out, [5,5,5,5]))
def testAddi4(self):
def test_add_i4(self):
if self.__enabled:
sess = InferenceSession(os.path.join(test_data_dir, "triton_addi4.onnx"),
self.__opt, providers=["CPUExecutionProvider"])
@ -68,7 +69,7 @@ class TestAzureOps(unittest.TestCase):
out = sess.run(None, ort_inputs)[0]
self.assertTrue(np.allclose(out, [5,5,5,5]))
def testAnd(self):
def test_and(self):
if self.__enabled:
sess = InferenceSession(os.path.join(test_data_dir, "triton_and.onnx"),
self.__opt, providers=["CPUExecutionProvider"])
@ -83,7 +84,7 @@ class TestAzureOps(unittest.TestCase):
out = sess.run(None, ort_inputs)[0]
self.assertTrue(np.allclose(out, [True, False]))
def testStr(self):
def test_str(self):
if self.__enabled:
sess = InferenceSession(os.path.join(test_data_dir, "triton_str.onnx"),
self.__opt, providers=["CPUExecutionProvider"])
@ -98,7 +99,7 @@ class TestAzureOps(unittest.TestCase):
self.assertEqual(outs[0], ['this is the input'])
self.assertEqual(outs[1], ['this is the input'])
def testOpenAiAudio(self):
def test_open_ai_audio(self):
if self.__enabled:
sess = InferenceSession(os.path.join(test_data_dir, "openai_audio.onnx"),
self.__opt, providers=["CPUExecutionProvider"])
@ -110,14 +111,14 @@ class TestAzureOps(unittest.TestCase):
audio_blob = np.asarray(list(_f.read()), dtype=np.uint8)
ort_inputs = {
"auth_token": auth_token,
"model": model,
"model_name": model,
"response_format": response_format,
"file": audio_blob
"file": audio_blob,
}
out = sess.run(None, ort_inputs)[0]
self.assertEqual(out, ['This is a test recording to test the Whisper model.\n'])
def testOpenAiChat(self):
def test_open_ai_chat(self):
if self.__enabled:
sess = InferenceSession(os.path.join(test_data_dir, "openai_chat.onnx"),
self.__opt, providers=["CPUExecutionProvider"])
@ -130,23 +131,6 @@ class TestAzureOps(unittest.TestCase):
out = sess.run(None, ort_inputs)[0]
self.assertTrue('assist' in out[0])
def testOpenAiEmb(self):
if self.__enabled:
opt = SessionOptions()
opt.register_custom_ops_library(get_library_path())
sess = InferenceSession(os.path.join(test_data_dir, "openai_embedding.onnx"),
opt, providers=["CPUExecutionProvider"])
auth_token = np.array([os.getenv('EMB', '')])
text = np.array(['{\"input\": \"The food was delicious and the waiter...\", \"model\": \"text-embedding-ada-002\"}'])
ort_inputs = {
"auth_token": auth_token,
"text": text,
}
out = sess.run(None, ort_inputs)[0]
self.assertTrue('text-embedding-ada' in out[0])
if __name__ == '__main__':
unittest.main()

View File

@ -154,7 +154,7 @@ def _parse_arguments():
# WebAssembly options
parser.add_argument("--wasm", action="store_true", help="Build for WebAssembly")
parser.add_argument("--emsdk_path", type=Path,
help="Specify path to emscripten SDK. Setup manually with: "
help="Specify path to emscripten SDK. Setup manually with: "
" git clone https://github.com/emscripten-core/emsdk")
parser.add_argument("--emsdk_version", default="3.1.26", help="Specify version of emsdk")
@ -404,11 +404,11 @@ def _generate_build_tree(cmake_path: Path,
_run_subprocess(cmake_args + [f"-DCMAKE_BUILD_TYPE={config}"], cwd=config_build_dir)
def clean_targets(cmake_path, build_dir: Path, configs: Set[str]):
def clean_targets(cmake_path: Path, build_dir: Path, configs: Set[str]):
for config in configs:
log.info("Cleaning targets for %s configuration", config)
build_dir2 = _get_build_config_dir(build_dir, config)
cmd_args = [cmake_path, "--build", build_dir2, "--config", config, "--target", "clean"]
cmd_args = [str(cmake_path), "--build", str(build_dir2), "--config", config, "--target", "clean"]
_run_subprocess(cmd_args)
@ -564,6 +564,10 @@ def main():
cmake_path = _resolve_executable_path(
args.cmake_path,
resolution_failure_allowed=(not (args.update or args.clean or args.build)))
if not cmake_path:
raise UsageError("Unable to find CMake executable. Please specify --cmake-path.")
build_dir = args.build_dir
if args.update or args.build:

View File

@ -13,4 +13,4 @@ if "%1" == "install" (
del "%ProgramFiles%\Miniconda3\python3.exe"
)
)
)
)

View File

@ -3,10 +3,10 @@
if [[ "$OCOS_ENABLE_AZURE" == "1" ]]
then
if [[ "$1" == "many64" ]]; then
yum -y install openssl openssl-devel wget && wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz && tar zxvf v1.1.0.tar.gz && cd rapidjson-1.1.0 && mkdir build && cd build && cmake .. && cmake --install . && cd ../.. && git clone https://github.com/triton-inference-server/client.git --branch r23.05 ~/client && ln -s ~/client/src/c++/library/libhttpclient.ldscript /usr/lib64/libhttpclient.ldscript
yum -y install openssl-devel
elif [[ "$1" == "many86" ]]; then
yum -y install openssl openssl-devel wget && wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz && tar zxvf v1.1.0.tar.gz && cd rapidjson-1.1.0 && mkdir build && cd build && cmake .. && cmake --install . && cd ../.. && git clone https://github.com/triton-inference-server/client.git --branch r23.05 ~/client && ln -s ~/client/src/c++/library/libhttpclient.ldscript /opt/rh/devtoolset-10/root/usr/lib/libhttpclient.ldscript
yum -y install openssl-devel
else # for musllinux
apk add openssl-dev wget && wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz && tar zxvf v1.1.0.tar.gz && cd rapidjson-1.1.0 && mkdir build && cd build && cmake .. && cmake --install . && cd ../.. && git clone https://github.com/triton-inference-server/client.git --branch r23.05 ~/client && ln -s ~/client/src/c++/library/libhttpclient.ldscript /usr/lib/libhttpclient.ldscript
apk add openssl-dev
fi
fi