Refactor setup for Azure ops. Add Android support. (#507)
* Refactor setup for Azure ops to try and make common things more re-usable, and to have the actual ops simply layer in the specific input/output constraints for that type of request. Currently builds on Linux, Windows (x64 only) and Android. Android requires a manual pre-build of openssl and curl. Linux requires a manual pre-install of openssl. Windows currently only works for x64; other targets need the vcpkg triplet adjusted.
* Address PR comments
* Fix a couple of android build warnings.
* Update .gitignore to remove old path
* Fix build break from merge
This commit is contained in:
Parent
5881931bf2
Commit
2bde82fce9
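The core of the refactor is a shared CurlInvoker base class: ComputeImpl drives the common request/response flow, and each op overrides ValidateInputs, SetupRequest, and ProcessResponse to layer in its own constraints (see azure_invokers.hpp below). A minimal standalone sketch of that layering follows — the types here are simplified stand-ins (the real classes take OrtApi/OrtKernelInfo and ortc::Variadic tensors, and SetupRequest receives a CurlHandler& rather than returning a string), and the transport is stubbed out in place of the actual libcurl call:

#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Simplified stand-ins for the real ortc::Variadic inputs/outputs.
using Inputs = std::vector<std::string>;
using Outputs = std::vector<std::string>;

class CurlInvoker {
 public:
  virtual ~CurlInvoker() = default;

 protected:
  // Shared request lifecycle: derived ops only layer in their specific
  // input/output constraints via the virtual hooks below.
  void ComputeImpl(const Inputs& inputs, Outputs& outputs) const {
    ValidateInputs(inputs);                            // op-specific input constraints
    std::string request = SetupRequest(inputs);        // op-specific payload
    std::string response = "response-for:" + request;  // stand-in for the curl call
    ProcessResponse(response, outputs);                // op-specific output handling
  }

 private:
  virtual void ValidateInputs(const Inputs& inputs) const = 0;
  virtual std::string SetupRequest(const Inputs& inputs) const = 0;
  virtual void ProcessResponse(const std::string& response, Outputs& outputs) const = 0;
};

// One concrete op: validates shape, builds the payload, copies the response out.
class TextToTextInvoker : public CurlInvoker {
 public:
  void Compute(const Inputs& inputs, Outputs& outputs) const {
    ComputeImpl(inputs, outputs);  // use impl from CurlInvoker
  }

 private:
  void ValidateInputs(const Inputs& inputs) const override {
    if (inputs.size() != 2) throw std::runtime_error("expected auth_token and text");
  }
  std::string SetupRequest(const Inputs& inputs) const override {
    return inputs[1];  // POST body is the text input
  }
  void ProcessResponse(const std::string& response, Outputs& outputs) const override {
    outputs.push_back(response);
  }
};

int main() {
  TextToTextInvoker op;
  Outputs out;
  op.Compute({"token", "hello"}, out);
  std::cout << out[0] << "\n";  // response-for:hello
}

The string-based SetupRequest signature is only to keep the sketch self-contained; the design point is the same as in the commit: common setup lives once in the base class, and each endpoint type supplies just its constraints.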
@@ -1,9 +1,16 @@
parameters:
- name: ExtraEnv
  displayName: 'Extra env variable set to CIBW_ENVIRONMENT'
  type: string
  default: 'None=None'

jobs:
- job: windows
  timeoutInMinutes: 120
  pool: {name: 'onnxruntime-Win-CPU-2022'}
  variables:
    CIBW_BUILD: "cp3{8,9,10,11}-*amd64"
    CIBW_ENVIRONMENT: "${{ parameters.ExtraEnv }}"

  steps:
  - task: UsePythonVersion@0
CMakeLists.txt (138 lines changed)
@@ -380,8 +380,22 @@ endif()

if(OCOS_ENABLE_AZURE)
  # Azure endpoint invokers
  include(triton)
  file(GLOB TARGET_SRC_AZURE "operators/azure/*.cc" "operators/azure/*.h*")

  if (ANDROID)
    include(curl)

    # make sure these work
    find_package(OpenSSL REQUIRED)
    find_package(CURL REQUIRED)

    # exclude triton
    list(FILTER TARGET_SRC_AZURE EXCLUDE REGEX ".*triton.*")
  else()
    add_compile_definitions(AZURE_INVOKERS_ENABLE_TRITON)
    include(triton)
  endif()

  list(APPEND TARGET_SRC ${TARGET_SRC_AZURE})
endif()
@@ -454,11 +468,13 @@ if(_HAS_TOKENIZER)
endif()

if(OCOS_ENABLE_TF_STRING)
  target_include_directories(noexcep_operators PUBLIC
    ${googlere2_SOURCE_DIR}
    ${farmhash_SOURCE_DIR}/src)
  target_include_directories(noexcep_operators PUBLIC ${farmhash_SOURCE_DIR}/src)
  list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_TF_STRING NOMINMAX FARMHASH_NO_BUILTIN_EXPECT FARMHASH_DEBUG=0)
  target_link_libraries(noexcep_operators PRIVATE re2)

  if(OCOS_ENABLE_RE2_REGEX)
    target_include_directories(noexcep_operators PUBLIC ${googlere2_SOURCE_DIR})
    target_link_libraries(noexcep_operators PRIVATE re2)
  endif()
endif()

if(OCOS_ENABLE_RE2_REGEX)
@@ -526,6 +542,10 @@ if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
  list(APPEND ocos_libraries nlohmann_json::nlohmann_json)
endif()

if(OCOS_ENABLE_AZURE)
  list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_AZURE)
endif()

target_include_directories(noexcep_operators PUBLIC ${GSL_INCLUDE_DIR})
list(APPEND ocos_libraries Microsoft.GSL::GSL)
@@ -577,9 +597,79 @@ else()
  set(_BUILD_SHARED_LIBRARY TRUE)
endif()

if(OCOS_ENABLE_AZURE)
  if (ANDROID)
    # the find_package calls were made immediately after `include(curl)` so we know CURL and OpenSSL are available
    target_link_libraries(ocos_operators PUBLIC CURL::libcurl OpenSSL::Crypto OpenSSL::SSL)
  elseif(IOS)
    # TODO
  else()
    # we need files from the triton client (e.g. curl header on linux) to be available for the ocos_operators build.
    # add a dependency so the fetch and build of triton happens first.
    add_dependencies(ocos_operators triton)
    target_include_directories(ocos_operators PUBLIC ${triton_INSTALL_DIR}/include)

    target_link_directories(ocos_operators PUBLIC ${triton_INSTALL_DIR}/lib)
    target_link_directories(ocos_operators PUBLIC ${triton_INSTALL_DIR}/lib64)

    if (WIN32)
      if (ocos_target_platform STREQUAL "AMD64")
        set(vcpkg_target_platform "x64")
      else()
        set(vcpkg_target_platform ${ocos_target_platform})
      endif()

      # As per https://curl.se/docs/faq.html#Link_errors_when_building_libcur we need to set CURL_STATICLIB.
      target_compile_definitions(ocos_operators PRIVATE CURL_STATICLIB)
      target_include_directories(ocos_operators PUBLIC ${VCPKG_SRC}/installed/${vcpkg_target_platform}-windows-static-md/include)

      if (CMAKE_BUILD_TYPE STREQUAL "Debug")
        set(curl_LIB_NAME "libcurl-d")
        target_link_directories(ocos_operators PUBLIC ${VCPKG_SRC}/installed/${vcpkg_target_platform}-windows-static-md/debug/lib)
      else()
        set(curl_LIB_NAME "libcurl")
        target_link_directories(ocos_operators PUBLIC ${VCPKG_SRC}/installed/${vcpkg_target_platform}-windows-static-md/lib)
      endif()

      target_link_libraries(ocos_operators PUBLIC httpclient_static ${curl_LIB_NAME} ws2_32 crypt32 Wldap32)
    else()
      find_package(ZLIB REQUIRED)

      # If finding the OpenSSL or CURL package fails for a local build you can install libcurl4-openssl-dev.
      # See also info on triton client dependencies here: https://github.com/triton-inference-server/client
      find_package(OpenSSL REQUIRED)
      find_package(CURL)

      if (CURL_FOUND)
        message(STATUS "Found CURL package")
        set(libcurl_target CURL::libcurl)
      else()
        # curl is coming from triton but as that's an external project it isn't built yet and we have to add
        # paths and library names instead of cmake targets.
        message(STATUS "Using CURL build from triton client. Once built it should be in ${triton_THIRD_PARTY_DIR}/curl")
        target_include_directories(ocos_operators PUBLIC ${triton_THIRD_PARTY_DIR}/curl/include)

        # Install is to 'lib' except on CentOS (which is used for the manylinux build of the python wheel).
        # Side note: we have to patch the triton client CMakeLists.txt to only use 'lib64' for 64-bit builds otherwise
        # the build of the 32-bit python wheel fails with CURL not being found due to the invalid library
        # directory name.
        target_link_directories(ocos_operators PUBLIC ${triton_THIRD_PARTY_DIR}/curl/lib)
        target_link_directories(ocos_operators PUBLIC ${triton_THIRD_PARTY_DIR}/curl/lib64)

        if (CMAKE_BUILD_TYPE STREQUAL "Debug")
          set(libcurl_target "curl-d")
        else()
          set(libcurl_target "curl")
        endif()
      endif()

      target_link_libraries(ocos_operators PUBLIC httpclient_static ${libcurl_target} OpenSSL::Crypto OpenSSL::SSL ZLIB::ZLIB)
    endif()
  endif()
endif()

target_compile_definitions(ortcustomops PUBLIC ${OCOS_COMPILE_DEFINITIONS})
target_include_directories(ortcustomops PUBLIC
  "$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")
target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")
target_link_libraries(ortcustomops PUBLIC ocos_operators)

if(_BUILD_SHARED_LIBRARY)
@@ -592,8 +682,7 @@ if(_BUILD_SHARED_LIBRARY)
    set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS "-Wl,-s -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
  endif()

  target_include_directories(extensions_shared PUBLIC
    "$<TARGET_PROPERTY:ortcustomops,INTERFACE_INCLUDE_DIRECTORIES>")
  target_include_directories(extensions_shared PUBLIC "$<TARGET_PROPERTY:ortcustomops,INTERFACE_INCLUDE_DIRECTORIES>")
  target_link_libraries(extensions_shared PRIVATE ortcustomops)
  set_target_properties(extensions_shared PROPERTIES OUTPUT_NAME "ortextensions")
  if(MSVC AND ocos_target_platform MATCHES "x86|x64")
@@ -672,8 +761,10 @@ if(OCOS_ENABLE_CTEST)
    list(APPEND LINUX_CC_FLAGS stdc++fs -pthread)
  endif()

  file(GLOB shared_TEST_SRC "${TEST_SRC_DIR}/shared_test/*.cc")
  file(GLOB shared_TEST_SRC "${TEST_SRC_DIR}/shared_test/*.cc" "${TEST_SRC_DIR}/shared_test/*.hpp")
  add_executable(extensions_test ${shared_TEST_SRC})
  target_compile_definitions(extensions_test PUBLIC ${OCOS_COMPILE_DEFINITIONS})

  standardize_output_folder(extensions_test)
  target_include_directories(extensions_test PRIVATE ${spm_INCLUDE_DIRS}
    "$<TARGET_PROPERTY:extensions_shared,INTERFACE_INCLUDE_DIRECTORIES>")
@@ -706,30 +797,3 @@ if(OCOS_ENABLE_CTEST)
    add_test(NAME extensions_test COMMAND $<TARGET_FILE:extensions_test>)
  endif()
endif()

if(OCOS_ENABLE_AZURE)

  add_dependencies(ocos_operators triton)
  target_include_directories(ocos_operators PUBLIC ${TRITON_BIN}/include ${TRITON_THIRD_PARTY}/curl/include)
  target_link_directories(ocos_operators PUBLIC ${TRITON_BIN}/lib ${TRITON_BIN}/lib64 ${TRITON_THIRD_PARTY}/curl/lib ${TRITON_THIRD_PARTY}/curl/lib64)

  if (ocos_target_platform STREQUAL "AMD64")
    set(vcpkg_target_platform "x86")
  else()
    set(vcpkg_target_platform ${ocos_target_platform})
  endif()

  if (WIN32)

    target_link_directories(ocos_operators PUBLIC ${VCPKG_SRC}/installed/${vcpkg_target_platform}-windows-static/lib)
    target_link_libraries(ocos_operators PUBLIC libcurl httpclient_static ws2_32 crypt32 Wldap32)

  else()

    find_package(ZLIB REQUIRED)
    find_package(OpenSSL REQUIRED)
    target_link_libraries(ocos_operators PUBLIC httpclient_static curl ZLIB::ZLIB OpenSSL::Crypto OpenSSL::SSL)

  endif() #if (WIN32)

endif()
@@ -40,15 +40,28 @@ else()
  endif()
endif()

message(STATUS "ONNX Runtime URL suffix: ${ONNXRUNTIME_URL}")
if (ANDROID)
  set(ort_fetch_URL "https://repo1.maven.org/maven2/com/microsoft/onnxruntime/onnxruntime-android/${ONNXRUNTIME_VER}/onnxruntime-android-${ONNXRUNTIME_VER}.aar")
else()
  set(ort_fetch_URL "https://github.com/microsoft/onnxruntime/releases/download/${ONNXRUNTIME_URL}")
endif()

message(STATUS "ONNX Runtime URL: ${ort_fetch_URL}")
FetchContent_Declare(
  onnxruntime
  URL https://github.com/microsoft/onnxruntime/releases/download/${ONNXRUNTIME_URL}
  URL ${ort_fetch_URL}
)

FetchContent_makeAvailable(onnxruntime)
set(ONNXRUNTIME_INCLUDE_DIR ${onnxruntime_SOURCE_DIR}/include)
set(ONNXRUNTIME_LIB_DIR ${onnxruntime_SOURCE_DIR}/lib)

if (ANDROID)
  set(ONNXRUNTIME_INCLUDE_DIR ${onnxruntime_SOURCE_DIR}/headers)
  set(ONNXRUNTIME_LIB_DIR ${onnxruntime_SOURCE_DIR}/jni/${ANDROID_ABI})
  message(STATUS "Android onnxruntime inc=${ONNXRUNTIME_INCLUDE_DIR} lib=${ONNXRUNTIME_LIB_DIR}")
else()
  set(ONNXRUNTIME_INCLUDE_DIR ${onnxruntime_SOURCE_DIR}/include)
  set(ONNXRUNTIME_LIB_DIR ${onnxruntime_SOURCE_DIR}/lib)
endif()
endif()

if(NOT EXISTS ${ONNXRUNTIME_INCLUDE_DIR})
@@ -0,0 +1,11 @@
if (ANDROID)
  set(PREBUILD_OUTPUT_PATH "${PROJECT_SOURCE_DIR}/prebuild/openssl_for_ios_and_android/output/android")
  set(OPENSSL_ROOT_DIR "${PREBUILD_OUTPUT_PATH}/openssl-${ANDROID_ABI}")
  set(NGHTTP2_ROOT_DIR "${PREBUILD_OUTPUT_PATH}/nghttp2-${ANDROID_ABI}")
  set(CURL_ROOT_DIR "${PREBUILD_OUTPUT_PATH}/curl-${ANDROID_ABI}")

  # Update CMAKE_FIND_ROOT_PATH so find_package/find_library can find these builds
  list(APPEND CMAKE_FIND_ROOT_PATH "${OPENSSL_ROOT_DIR}")
  list(APPEND CMAKE_FIND_ROOT_PATH "${NGHTTP2_ROOT_DIR}")
  list(APPEND CMAKE_FIND_ROOT_PATH "${CURL_ROOT_DIR}")
endif()
@@ -1,32 +1,41 @@
include(ExternalProject)

if (WIN32)
set(triton_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton)
set(triton_INSTALL_DIR ${triton_PREFIX}/install)

if (WIN32)
  if (ocos_target_platform STREQUAL "AMD64")
    set(vcpkg_target_platform "x86")
    set(vcpkg_target_platform "x64")
  else()
    set(vcpkg_target_platform ${ocos_target_platform})
  endif()

  ExternalProject_Add(vcpkg
    GIT_REPOSITORY https://github.com/microsoft/vcpkg.git
    GIT_TAG 2023.06.20
    PREFIX vcpkg
    SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src
    BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-build
    CONFIGURE_COMMAND ""
    INSTALL_COMMAND ""
    UPDATE_COMMAND ""
    BUILD_COMMAND "<SOURCE_DIR>/bootstrap-vcpkg.bat")
  set(vcpkg_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg)

  set(VCPKG_SRC ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src)
  set(ENV{VCPKG_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src)
  ExternalProject_Add(vcpkg
    GIT_REPOSITORY https://github.com/microsoft/vcpkg.git
    GIT_TAG 2023.06.20
    PREFIX ${vcpkg_PREFIX}
    CONFIGURE_COMMAND ""
    INSTALL_COMMAND ""
    UPDATE_COMMAND ""
    BUILD_COMMAND "<SOURCE_DIR>/bootstrap-vcpkg.bat")

  ExternalProject_Get_Property(vcpkg SOURCE_DIR BINARY_DIR)
  set(VCPKG_SRC ${SOURCE_DIR})
  message(status "vcpkg source dir: " ${VCPKG_SRC})

  # set the environment variable so that the vcpkg.cmake file can find the vcpkg root directory
  set(ENV{VCPKG_ROOT} ${VCPKG_SRC})

  message(STATUS "VCPKG_SRC: " ${VCPKG_SRC})
  message(STATUS "VCPKG_ROOT: " $ENV{VCPKG_ROOT})
  message(STATUS "ENV{VCPKG_ROOT}: " $ENV{VCPKG_ROOT})

  # NOTE: The VCPKG_ROOT environment variable isn't propagated to an add_custom_command target, so specify --vcpkg-root
  # here and in the vcpkg_install function
  add_custom_command(
    COMMAND ${VCPKG_SRC}/vcpkg integrate install
    COMMAND ${CMAKE_COMMAND} -E echo ${VCPKG_SRC}/vcpkg integrate --vcpkg-root=$ENV{VCPKG_ROOT} install
    COMMAND ${VCPKG_SRC}/vcpkg integrate --vcpkg-root=$ENV{VCPKG_ROOT} install
    COMMAND ${CMAKE_COMMAND} -E touch vcpkg_integrate.stamp
    OUTPUT vcpkg_integrate.stamp
    DEPENDS vcpkg
@@ -35,77 +44,97 @@ if (WIN32)
  add_custom_target(vcpkg_integrate ALL DEPENDS vcpkg_integrate.stamp)
  set(VCPKG_DEPENDENCIES "vcpkg_integrate")

  # use static-md so it adjusts for debug/release CRT
  # https://stackoverflow.com/questions/67258905/vcpkg-difference-between-windows-windows-static-and-other
  function(vcpkg_install PACKAGE_NAME)
    add_custom_command(
      OUTPUT ${VCPKG_SRC}/packages/${PACKAGE_NAME}_${vcpkg_target_platform}-windows-static/BUILD_INFO
      COMMAND ${VCPKG_SRC}/vcpkg install ${PACKAGE_NAME}:${vcpkg_target_platform}-windows-static --vcpkg-root=${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src
      OUTPUT ${VCPKG_SRC}/packages/${PACKAGE_NAME}_${vcpkg_target_platform}-windows-static-md/BUILD_INFO
      COMMAND ${CMAKE_COMMAND} -E echo ${VCPKG_SRC}/vcpkg install --vcpkg-root=$ENV{VCPKG_ROOT}
        ${PACKAGE_NAME}:${vcpkg_target_platform}-windows-static-md
      COMMAND ${VCPKG_SRC}/vcpkg install --vcpkg-root=$ENV{VCPKG_ROOT}
        ${PACKAGE_NAME}:${vcpkg_target_platform}-windows-static-md
      WORKING_DIRECTORY ${VCPKG_SRC}
      DEPENDS vcpkg_integrate)

    add_custom_target(get${PACKAGE_NAME}
    add_custom_target(
      get${PACKAGE_NAME}
      ALL
      DEPENDS ${VCPKG_SRC}/packages/${PACKAGE_NAME}_${vcpkg_target_platform}-windows-static/BUILD_INFO)
      DEPENDS ${VCPKG_SRC}/packages/${PACKAGE_NAME}_${vcpkg_target_platform}-windows-static-md/BUILD_INFO)

    list(APPEND VCPKG_DEPENDENCIES "get${PACKAGE_NAME}")
    set(VCPKG_DEPENDENCIES ${VCPKG_DEPENDENCIES} PARENT_SCOPE)
  endfunction()

  vcpkg_install(openssl)
  vcpkg_install(openssl-windows)
  vcpkg_install(rapidjson)
  vcpkg_install(re2)
  vcpkg_install(boost-interprocess)
  vcpkg_install(boost-stacktrace)
  vcpkg_install(pthread)
  vcpkg_install(b64)
  vcpkg_install(openssl)
  vcpkg_install(curl)

  add_dependencies(getb64 getpthread)
  add_dependencies(getpthread getboost-stacktrace)
  add_dependencies(getboost-stacktrace getboost-interprocess)
  add_dependencies(getboost-interprocess getre2)
  add_dependencies(getre2 getrapidjson)
  add_dependencies(getrapidjson getopenssl-windows)
  add_dependencies(getopenssl-windows getopenssl)

  ExternalProject_Add(triton
    GIT_REPOSITORY https://github.com/triton-inference-server/client.git
    GIT_TAG r23.05
    PREFIX triton
    SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src
    BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build
    CMAKE_ARGS -DVCPKG_TARGET_TRIPLET=${vcpkg_target_platform}-windows-static -DCMAKE_TOOLCHAIN_FILE=${VCPKG_SRC}/scripts/buildsystems/vcpkg.cmake -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF
    INSTALL_COMMAND ""
    UPDATE_COMMAND "")

  add_dependencies(triton ${VCPKG_DEPENDENCIES})
  set(triton_extra_cmake_args -DVCPKG_TARGET_TRIPLET=${vcpkg_target_platform}-windows-static-md
    -DCMAKE_TOOLCHAIN_FILE=${VCPKG_SRC}/scripts/buildsystems/vcpkg.cmake)
  set(triton_patch_command "")
  set(triton_dependencies ${VCPKG_DEPENDENCIES})

else()
  # RapidJSON 1.1.0 (released in 2016) is compatible with the triton build. Later code is not compatible without
  # patching due to the change in variable name for the include dir from RAPIDJSON_INCLUDE_DIRS to
  # RapidJSON_INCLUDE_DIRS in the generated cmake file used by find_package:
  # https://github.com/Tencent/rapidjson/commit/b91c515afea9f0ba6a81fc670889549d77c83db3
  # The triton code here https://github.com/triton-inference-server/common/blob/main/CMakeLists.txt is using
  # RAPIDJSON_INCLUDE_DIRS so the build fails if a newer RapidJSON version is used. It will find the package but the
  # include path will be wrong so the build error is delayed/misleading and non-trivial to understand/resolve.
  set(RapidJSON_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/_deps/rapidjson)
  set(RapidJSON_INSTALL_DIR ${RapidJSON_PREFIX}/install)
  ExternalProject_Add(RapidJSON
    PREFIX ${RapidJSON_PREFIX}
    URL https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.zip
    URL_HASH SHA1=0fe7b4f7b83df4b3d517f4a202f3a383af7a0818
    CMAKE_ARGS -DRAPIDJSON_BUILD_DOC=OFF
      -DRAPIDJSON_BUILD_EXAMPLES=OFF
      -DRAPIDJSON_BUILD_TESTS=OFF
      -DRAPIDJSON_HAS_STDSTRING=ON
      -DRAPIDJSON_USE_MEMBERSMAP=ON
      -DCMAKE_INSTALL_PREFIX=${RapidJSON_INSTALL_DIR}
  )

  ExternalProject_Add(curl7
    PREFIX curl7
    GIT_REPOSITORY "https://github.com/curl/curl.git"
    GIT_TAG "curl-7_86_0"
    SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/curl7-src
    BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/curl7-build
    CMAKE_ARGS -DBUILD_TESTING=OFF -DBUILD_CURL_EXE=OFF -DBUILD_SHARED_LIBS=OFF -DCURL_STATICLIB=ON -DHTTP_ONLY=ON -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE})
  ExternalProject_Get_Property(RapidJSON SOURCE_DIR BINARY_DIR)
  # message(STATUS "RapidJSON src=${SOURCE_DIR} binary=${BINARY_DIR}")
  # Set RapidJSON_ROOT_DIR for find_package. The required RapidJSONConfig.cmake file is generated in the binary dir
  set(RapidJSON_ROOT_DIR ${BINARY_DIR})

  ExternalProject_Add(triton
    GIT_REPOSITORY https://github.com/triton-inference-server/client.git
    GIT_TAG r23.05
    PREFIX triton
    SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src
    BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build
    CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF
    INSTALL_COMMAND ""
    UPDATE_COMMAND "")
  set(triton_extra_cmake_args "")
  set(triton_patch_command patch --verbose -p1 -i ${PROJECT_SOURCE_DIR}/cmake/externals/triton_cmake.patch)
  set(triton_dependencies RapidJSON)

  add_dependencies(triton curl7)
  # Patch the triton client CMakeLists.txt to fix two issues when building the python wheels with cibuildwheel, which
  # uses CentOS 7.
  # 1) use the full path to the version script file so 'ld' doesn't fail to find it. Looks like ld is running from the
  #    parent directory but not sure why the behavior differs vs. other linux builds
  #    e.g. building locally on Ubuntu is fine without the patch
  # 2) only set the CURL lib path to 'lib64' on a 64-bit CentOS build as 'lib64' is invalid on a 32-bit OS. without
  #    this patch the build of the third-party libraries in the triton client fail as the CURL build is not found.

endif() #if (WIN32)
endif() #if (WIN32)

ExternalProject_Get_Property(triton SOURCE_DIR)
set(TRITON_SRC ${SOURCE_DIR})
# Add the triton build. We just need the library so we don't install it.
#
set(triton_VERSION_TAG r23.05)
ExternalProject_Add(triton
  URL https://github.com/triton-inference-server/client/archive/refs/heads/${triton_VERSION_TAG}.tar.gz
  URL_HASH SHA1=b8fd2a4e09eae39c33cd04cfa9ec934e39d9afc1
  PREFIX ${triton_PREFIX}
  CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${triton_INSTALL_DIR}
    -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
    -DTRITON_COMMON_REPO_TAG=${triton_VERSION_TAG}
    -DTRITON_THIRD_PARTY_REPO_TAG=${triton_VERSION_TAG}
    -DTRITON_CORE_REPO_TAG=${triton_VERSION_TAG}
    -DTRITON_ENABLE_CC_HTTP=ON
    -DTRITON_ENABLE_ZLIB=OFF
    ${triton_extra_cmake_args}
  INSTALL_COMMAND ${CMAKE_COMMAND} -E echo "Skipping install step."
  PATCH_COMMAND ${triton_patch_command}
)

ExternalProject_Get_Property(triton BINARY_DIR)
set(TRITON_BIN ${BINARY_DIR}/binary)
set(TRITON_THIRD_PARTY ${BINARY_DIR}/third-party)
add_dependencies(triton ${triton_dependencies})

ExternalProject_Get_Property(triton SOURCE_DIR BINARY_DIR)
set(triton_THIRD_PARTY_DIR ${BINARY_DIR}/third-party)
@@ -0,0 +1,30 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7b11178..7749fa9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -115,10 +115,11 @@ if(TRITON_ENABLE_CC_HTTP OR TRITON_ENABLE_CC_GRPC OR TRITON_ENABLE_PERF_ANALYZER
   file(STRINGS /etc/os-release DISTRO REGEX "^NAME=")
   string(REGEX REPLACE "NAME=\"(.*)\"" "\\1" DISTRO "${DISTRO}")
   message(STATUS "Distro Name: ${DISTRO}")
-  if(DISTRO STREQUAL "CentOS Linux")
+  if(DISTRO STREQUAL "CentOS Linux" AND CMAKE_SIZEOF_VOID_P EQUAL 8)
     set (CURL_LIB_DIR "lib64")
   endif()
 endif()
+message(STATUS "Triton client CURL_LIB_DIR=${CURL_LIB_DIR}")

 set(_cc_client_depends "")
 if(${TRITON_ENABLE_CC_HTTP})
diff --git a/src/c++/library/CMakeLists.txt b/src/c++/library/CMakeLists.txt
index bdaae25..c36dbc8 100644
--- a/src/c++/library/CMakeLists.txt
+++ b/src/c++/library/CMakeLists.txt
@@ -320,7 +320,7 @@ if(TRITON_ENABLE_CC_HTTP OR TRITON_ENABLE_PERF_ANALYZER)
   httpclient
   PROPERTIES
   LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libhttpclient.ldscript
-  LINK_FLAGS "-Wl,--version-script=libhttpclient.ldscript"
+  LINK_FLAGS "-Wl,--version-script=${CMAKE_CURRENT_BINARY_DIR}/libhttpclient.ldscript"
   )
 endif() # NOT WIN32
@@ -1,427 +1,95 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#define CURL_STATICLIB

#include "http_client.h"
#include "curl/curl.h"
#include "azure_invokers.hpp"

#include <sstream>

#define MIN_SUPPORTED_ORT_VER 14
namespace ort_extensions {

constexpr const char* kUri = "model_uri";
constexpr const char* kModelName = "model_name";
constexpr const char* kModelVer = "model_version";
constexpr const char* kVerbose = "verbose";
constexpr const char* kBinaryType = "binary_type";
////////////////////// AzureAudioToTextInvoker //////////////////////

struct StringBuffer {
  StringBuffer() = default;
  ~StringBuffer() = default;
  std::stringstream ss_;
};
AzureAudioToTextInvoker::AzureAudioToTextInvoker(const OrtApi& api, const OrtKernelInfo& info)
    : CurlInvoker(api, info) {
  audio_format_ = TryToGetAttributeWithDefault<std::string>(kAudioFormat, "");
}

// apply the callback only when response is for sure to be a '/0' terminated string
static size_t WriteStringCallback(void* contents, size_t size, size_t nmemb, void* userp) {
  try {
    size_t realsize = size * nmemb;
    auto buffer = reinterpret_cast<struct StringBuffer*>(userp);
    buffer->ss_.write(reinterpret_cast<const char*>(contents), realsize);
    return realsize;
  } catch (...) {
    // exception caught, abort write
    return CURLcode::CURLE_WRITE_ERROR;
void AzureAudioToTextInvoker::ValidateInputs(const ortc::Variadic& inputs) const {
  // TODO: Validate any required input names are present

  // We don't have a way to get the output type from the custom op API.
  // If there's a mismatch it will fail in the Compute when it allocates the output tensor.
  if (OutputNames().size() != 1) {
    ORTX_CXX_API_THROW("Expected single output", ORT_INVALID_ARGUMENT);
  }
}

using CurlWriteCallBack = size_t (*)(void*, size_t, size_t, void*);
void AzureAudioToTextInvoker::SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const {
  // theoretically the filename the content was buffered from
  static const std::string fake_filename = "audio." + audio_format_;

class CurlHandler {
 public:
  CurlHandler(CurlWriteCallBack call_back) : curl_(curl_easy_init(), curl_easy_cleanup),
                                             headers_(nullptr, curl_slist_free_all),
                                             from_holder_(from_, curl_formfree) {
    curl_easy_setopt(curl_.get(), CURLOPT_BUFFERSIZE, 102400L);
    curl_easy_setopt(curl_.get(), CURLOPT_NOPROGRESS, 1L);
    curl_easy_setopt(curl_.get(), CURLOPT_USERAGENT, "curl/7.83.1");
    curl_easy_setopt(curl_.get(), CURLOPT_MAXREDIRS, 50L);
    curl_easy_setopt(curl_.get(), CURLOPT_FTP_SKIP_PASV_IP, 1L);
    curl_easy_setopt(curl_.get(), CURLOPT_TCP_KEEPALIVE, 1L);
    curl_easy_setopt(curl_.get(), CURLOPT_WRITEFUNCTION, call_back);
  }
  ~CurlHandler() = default;
  const auto& property_names = RequestPropertyNames();

  void AddHeader(const char* data) {
    headers_.reset(curl_slist_append(headers_.release(), data));
  }
  template <typename... Args>
  void AddForm(Args... args) {
    curl_formadd(&from_, &last_, args...);
  }
  template <typename T>
  void SetOption(CURLoption opt, T val) {
    curl_easy_setopt(curl_.get(), opt, val);
  }
  CURLcode Perform() {
    SetOption(CURLOPT_HTTPHEADER, headers_.get());
    if (from_) {
      SetOption(CURLOPT_HTTPPOST, from_);
    }
    return curl_easy_perform(curl_.get());
  }

 private:
  std::unique_ptr<CURL, decltype(curl_easy_cleanup)*> curl_;
  std::unique_ptr<curl_slist, decltype(curl_slist_free_all)*> headers_;
  curl_httppost* from_{};
  curl_httppost* last_{};
  std::unique_ptr<curl_httppost, decltype(curl_formfree)*> from_holder_;
};

////////////////////// AzureInvoker //////////////////////

AzureInvoker::AzureInvoker(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
  auto ver = GetActiveOrtAPIVersion();
  if (ver < MIN_SUPPORTED_ORT_VER) {
    ORTX_CXX_API_THROW("Azure ops requires ort >= 1.14", ORT_RUNTIME_EXCEPTION);
  }

  model_uri_ = TryToGetAttributeWithDefault<std::string>(kUri, "");
  model_name_ = TryToGetAttributeWithDefault<std::string>(kModelName, "");
  model_ver_ = TryToGetAttributeWithDefault<std::string>(kModelVer, "0");
  verbose_ = TryToGetAttributeWithDefault<std::string>(kVerbose, "0");
  OrtStatusPtr status = {};
  size_t input_count = {};
  status = api_.KernelInfo_GetInputCount(&info_, &input_count);
  if (status) {
    ORTX_CXX_API_THROW("failed to get input count", ORT_RUNTIME_EXCEPTION);
  }

  for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
    char input_name[1024] = {};
    size_t name_size = 1024;
    status = api_.KernelInfo_GetInputName(&info_, ith_input, input_name, &name_size);
    if (status) {
      ORTX_CXX_API_THROW("failed to get input name", ORT_RUNTIME_EXCEPTION);
    }
    input_names_.push_back(input_name);
  }

  size_t output_count = {};
  status = api_.KernelInfo_GetOutputCount(&info_, &output_count);
  if (status) {
    ORTX_CXX_API_THROW("failed to get output count", ORT_RUNTIME_EXCEPTION);
  }

  for (size_t ith_output = 0; ith_output < output_count; ++ith_output) {
    char output_name[1024] = {};
    size_t name_size = 1024;
    status = api_.KernelInfo_GetOutputName(&info_, ith_output, output_name, &name_size);
    if (status) {
      ORTX_CXX_API_THROW("failed to get output name", ORT_RUNTIME_EXCEPTION);
    }
    output_names_.push_back(output_name);
  }
}

////////////////////// AzureAudioInvoker //////////////////////

AzureAudioInvoker::AzureAudioInvoker(const OrtApi& api, const OrtKernelInfo& info) : AzureInvoker(api, info) {
  file_name_ = std::string{"non_exist."} + TryToGetAttributeWithDefault<std::string>(kBinaryType, "wav");
}

void AzureAudioInvoker::Compute(const ortc::Variadic& inputs, ortc::Tensor<std::string>& output) const {
  if (inputs.Size() < 1 ||
      inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
    ORTX_CXX_API_THROW("invalid inputs, auto token missing", ORT_RUNTIME_EXCEPTION);
  }

  if (inputs.Size() != input_names_.size()) {
    ORTX_CXX_API_THROW("input count mismatch", ORT_RUNTIME_EXCEPTION);
  }

  if (inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING || "auth_token" != input_names_[0]) {
    ORTX_CXX_API_THROW("first input must be a string of auth token", ORT_INVALID_ARGUMENT);
  }

  std::string auth_token = reinterpret_cast<const char*>(inputs[0]->DataRaw());
  std::string full_auth = std::string{"Authorization: Bearer "} + auth_token;

  StringBuffer string_buffer;
  CurlHandler curl_handler(WriteStringCallback);
  curl_handler.AddHeader(full_auth.c_str());
  curl_handler.AddHeader("Content-Type: multipart/form-data");
  curl_handler.AddFormString("deployment_id", ModelName().c_str());

  // TODO: If the handling here stays the same as in OpenAIAudioToText we can create a helper function to re-use
  for (size_t ith_input = 1; ith_input < inputs.Size(); ++ith_input) {
    switch (inputs[ith_input]->Type()) {
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
        curl_handler.AddForm(CURLFORM_COPYNAME,
                             input_names_[ith_input].c_str(),
                             CURLFORM_COPYCONTENTS,
                             inputs[ith_input]->DataRaw(),
                             CURLFORM_END);
        curl_handler.AddFormString(property_names[ith_input].c_str(),
                                   static_cast<const char*>(inputs[ith_input]->DataRaw()));  // assumes null terminated
        break;
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
        curl_handler.AddForm(CURLFORM_COPYNAME,
                             input_names_[ith_input].data(),
                             CURLFORM_BUFFER,
                             file_name_.c_str(),
                             CURLFORM_BUFFERPTR,
                             inputs[ith_input]->DataRaw(),
                             CURLFORM_BUFFERLENGTH,
                             inputs[ith_input]->SizeInBytes(),
                             CURLFORM_END);
        curl_handler.AddFormBuffer(property_names[ith_input].c_str(),
                                   fake_filename.c_str(),
                                   inputs[ith_input]->DataRaw(),
                                   inputs[ith_input]->SizeInBytes());
        break;
      default:
        ORTX_CXX_API_THROW("input must be either text or binary", ORT_RUNTIME_EXCEPTION);
        break;
    }
  }  // for
}
}

curl_handler.SetOption(CURLOPT_URL, model_uri_.c_str());
curl_handler.SetOption(CURLOPT_VERBOSE, verbose_);
curl_handler.SetOption(CURLOPT_WRITEDATA, (void*)&string_buffer);
void AzureAudioToTextInvoker::ProcessResponse(const std::string& response, ortc::Variadic& outputs) const {
  auto& string_tensor = outputs.AllocateStringTensor(0);
  string_tensor.SetStringOutput(std::vector<std::string>{response}, std::vector<int64_t>{1});
}

auto curl_ret = curl_handler.Perform();
if (CURLE_OK != curl_ret) {
  ORTX_CXX_API_THROW(curl_easy_strerror(curl_ret), ORT_FAIL);
////////////////////// AzureTextToTextInvoker //////////////////////

AzureTextToTextInvoker::AzureTextToTextInvoker(const OrtApi& api, const OrtKernelInfo& info)
    : CurlInvoker(api, info) {
}

void AzureTextToTextInvoker::ValidateInputs(const ortc::Variadic& inputs) const {
  if (inputs.Size() != 2 || inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
    ORTX_CXX_API_THROW("Expected 2 string inputs of auth_token and text respectively", ORT_INVALID_ARGUMENT);
  }

  output.SetStringOutput(std::vector<std::string>{string_buffer.ss_.str()}, std::vector<int64_t>{1L});
  // We don't have a way to get the output type from the custom op API.
  // If there's a mismatch it will fail in the Compute when it allocates the output tensor.
  if (OutputNames().size() != 1) {
    ORTX_CXX_API_THROW("Expected single output", ORT_INVALID_ARGUMENT);
  }
}

////////////////////// AzureTextInvoker //////////////////////
void AzureTextToTextInvoker::SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const {
  gsl::span<const std::string> input_names = InputNames();

AzureTextInvoker::AzureTextInvoker(const OrtApi& api, const OrtKernelInfo& info) : AzureInvoker(api, info) {
}

void AzureTextInvoker::Compute(std::string_view auth, std::string_view input,
                               ortc::Tensor<std::string>& output) const {
  CurlHandler curl_handler(WriteStringCallback);
  StringBuffer string_buffer;

  std::string full_auth = std::string{"Authorization: Bearer "} + auth.data();
  curl_handler.AddHeader(full_auth.c_str());
  // TODO: assuming we need to create the correct json from the input text
  curl_handler.AddHeader("Content-Type: application/json");

  curl_handler.SetOption(CURLOPT_URL, model_uri_.c_str());
  curl_handler.SetOption(CURLOPT_POSTFIELDS, input.data());
  curl_handler.SetOption(CURLOPT_POSTFIELDSIZE_LARGE, input.size());
  curl_handler.SetOption(CURLOPT_VERBOSE, verbose_);
  curl_handler.SetOption(CURLOPT_WRITEDATA, (void*)&string_buffer);

  auto curl_ret = curl_handler.Perform();
  if (CURLE_OK != curl_ret) {
    ORTX_CXX_API_THROW(curl_easy_strerror(curl_ret), ORT_FAIL);
  }

  output.SetStringOutput(std::vector<std::string>{string_buffer.ss_.str()}, std::vector<int64_t>{1L});
  const auto& text_input = inputs[1];
  curl_handler.SetOption(CURLOPT_POSTFIELDS, text_input->DataRaw());
  curl_handler.SetOption(CURLOPT_POSTFIELDSIZE_LARGE, text_input->SizeInBytes());
}

////////////////////// AzureTritonInvoker //////////////////////

namespace tc = triton::client;

AzureTritonInvoker::AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info) : AzureInvoker(api, info) {
  auto err = tc::InferenceServerHttpClient::Create(&triton_client_, model_uri_, verbose_ != "0");
void AzureTextToTextInvoker::ProcessResponse(const std::string& response, ortc::Variadic& outputs) const {
  auto& string_tensor = outputs.AllocateStringTensor(0);
  string_tensor.SetStringOutput(std::vector<std::string>{response}, std::vector<int64_t>{1});
}

std::string MapDataType(ONNXTensorElementDataType onnx_data_type) {
  std::string triton_data_type;
  switch (onnx_data_type) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
      triton_data_type = "FP32";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
      triton_data_type = "UINT8";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
      triton_data_type = "INT8";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
      triton_data_type = "UINT16";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
      triton_data_type = "INT16";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
      triton_data_type = "INT32";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
      triton_data_type = "INT64";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
      triton_data_type = "BYTES";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
      triton_data_type = "BOOL";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
      triton_data_type = "FP16";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
      triton_data_type = "FP64";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
      triton_data_type = "UINT32";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
      triton_data_type = "UINT64";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
      triton_data_type = "BF16";
      break;
    default:
      break;
  }
  return triton_data_type;
}

int8_t* CreateNonStrTensor(const std::string& data_type,
                           ortc::Variadic& outputs,
                           size_t i,
                           const std::vector<int64_t>& shape) {
  if (data_type == "FP32") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<float>(i, shape));
  } else if (data_type == "UINT8") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint8_t>(i, shape));
  } else if (data_type == "INT8") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int8_t>(i, shape));
  } else if (data_type == "UINT16") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint16_t>(i, shape));
  } else if (data_type == "INT16") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int16_t>(i, shape));
  } else if (data_type == "INT32") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int32_t>(i, shape));
  } else if (data_type == "UINT32") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint32_t>(i, shape));
  } else if (data_type == "INT64") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int64_t>(i, shape));
  } else if (data_type == "UINT64") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint64_t>(i, shape));
  } else if (data_type == "BOOL") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<bool>(i, shape));
  } else if (data_type == "FP64") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<double>(i, shape));
  } else {
    return {};
  }
}

#define CHECK_TRITON_ERR(ret, msg)                                                     \
  if (!ret.IsOk()) {                                                                   \
    return ORTX_CXX_API_THROW("Triton err: " + ret.Message(), ORT_RUNTIME_EXCEPTION);  \
  }

void AzureTritonInvoker::Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
  if (inputs.Size() < 1 ||
      inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
    ORTX_CXX_API_THROW("invalid inputs, auto token missing", ORT_RUNTIME_EXCEPTION);
  }

  if (inputs.Size() != input_names_.size()) {
    ORTX_CXX_API_THROW("input count mismatch", ORT_RUNTIME_EXCEPTION);
  }

  auto auth_token = reinterpret_cast<const char*>(inputs[0]->DataRaw());
  std::vector<std::unique_ptr<tc::InferInput>> triton_input_vec;
  std::vector<tc::InferInput*> triton_inputs;
  std::vector<std::unique_ptr<const tc::InferRequestedOutput>> triton_output_vec;
  std::vector<const tc::InferRequestedOutput*> triton_outputs;
  tc::Error err;

  for (size_t ith_input = 1; ith_input < inputs.Size(); ++ith_input) {
    tc::InferInput* triton_input = {};
    std::string triton_data_type = MapDataType(inputs[ith_input]->Type());
    if (triton_data_type.empty()) {
      ORTX_CXX_API_THROW("unknow onnx data type", ORT_RUNTIME_EXCEPTION);
    }

    err = tc::InferInput::Create(&triton_input, input_names_[ith_input], inputs[ith_input]->Shape(), triton_data_type);
    CHECK_TRITON_ERR(err, "failed to create triton input");
    triton_input_vec.emplace_back(triton_input);

    triton_inputs.push_back(triton_input);
    if ("BYTES" == triton_data_type) {
      const auto* string_tensor = reinterpret_cast<const ortc::Tensor<std::string>*>(inputs[ith_input].get());
      triton_input->AppendFromString(string_tensor->Data());
    } else {
      const float* data_raw = reinterpret_cast<const float*>(inputs[ith_input]->DataRaw());
      size_t size_in_bytes = inputs[ith_input]->SizeInBytes();
      err = triton_input->AppendRaw(reinterpret_cast<const uint8_t*>(data_raw), size_in_bytes);
      CHECK_TRITON_ERR(err, "failed to append raw data to input");
    }
  }

  for (size_t ith_output = 0; ith_output < output_names_.size(); ++ith_output) {
    tc::InferRequestedOutput* triton_output = {};
    err = tc::InferRequestedOutput::Create(&triton_output, output_names_[ith_output]);
    CHECK_TRITON_ERR(err, "failed to create triton output");
    triton_output_vec.emplace_back(triton_output);
    triton_outputs.push_back(triton_output);
  }

  std::unique_ptr<tc::InferResult> results_ptr;
  tc::InferResult* results = {};
  tc::InferOptions options(model_name_);
  options.model_version_ = model_ver_;
  options.client_timeout_ = 0;

  tc::Headers http_headers;
  http_headers["Authorization"] = std::string{"Bearer "} + auth_token;

  err = triton_client_->Infer(&results, options, triton_inputs, triton_outputs,
                              http_headers, tc::Parameters(),
                              tc::InferenceServerHttpClient::CompressionType::NONE,  // support compression in config?
                              tc::InferenceServerHttpClient::CompressionType::NONE);

  results_ptr.reset(results);
  CHECK_TRITON_ERR(err, "failed to do triton inference");

  size_t output_index = 0;
  auto iter = output_names_.begin();

  while (iter != output_names_.end()) {
    std::vector<int64_t> shape;
    err = results_ptr->Shape(*iter, &shape);
    CHECK_TRITON_ERR(err, "failed to get output shape");

    std::string type;
    err = results_ptr->Datatype(*iter, &type);
    CHECK_TRITON_ERR(err, "failed to get output type");

    if ("BYTES" == type) {
      std::vector<std::string> output_strings;
      err = results_ptr->StringData(*iter, &output_strings);
      CHECK_TRITON_ERR(err, "failed to get output as string");
      auto& string_tensor = outputs.AllocateStringTensor(output_index);
      string_tensor.SetStringOutput(output_strings, shape);
    } else {
      const uint8_t* raw_data = {};
      size_t raw_size;
      err = results_ptr->RawData(*iter, &raw_data, &raw_size);
      CHECK_TRITON_ERR(err, "failed to get output raw data");
      auto* output_raw = CreateNonStrTensor(type, outputs, output_index, shape);
      memcpy(output_raw, raw_data, raw_size);
    }

    ++output_index;
    ++iter;
  }
}

const std::vector<const OrtCustomOp*>& AzureInvokerLoader() {
  static OrtOpLoader op_loader(CustomAzureStruct("AzureAudioInvoker", AzureAudioInvoker),
                               CustomAzureStruct("AzureTritonInvoker", AzureTritonInvoker),
                               CustomAzureStruct("AzureAudioInvoker", AzureAudioInvoker),
                               CustomAzureStruct("AzureTextInvoker", AzureTextInvoker),
                               CustomCpuStruct("AzureAudioInvoker", AzureAudioInvoker),
                               CustomCpuStruct("AzureTritonInvoker", AzureTritonInvoker),
                               CustomCpuStruct("AzureAudioInvoker", AzureAudioInvoker),
                               CustomCpuStruct("AzureTextInvoker", AzureTextInvoker)
  );
  return op_loader.GetCustomOps();
}

FxLoadCustomOpFactory LoadCustomOpClasses_Azure = AzureInvokerLoader;
}  // namespace ort_extensions
@@ -1,43 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "ocos.h"
#include "curl_invoker.hpp"

struct AzureInvoker : public BaseKernel {
  AzureInvoker(const OrtApi& api, const OrtKernelInfo& info);
namespace ort_extensions {

 protected:
  ~AzureInvoker() = default;
  std::string model_uri_;
  std::string model_name_;
  std::string model_ver_;
  std::string verbose_;
  std::vector<std::string> input_names_;
  std::vector<std::string> output_names_;
};
////////////////////// AzureAudioToTextInvoker //////////////////////

struct AzureAudioInvoker : public AzureInvoker {
  AzureAudioInvoker(const OrtApi& api, const OrtKernelInfo& info);
  void Compute(const ortc::Variadic& inputs, ortc::Tensor<std::string>& output) const;
/// <summary>
/// Azure Audio to Text
/// Input: auth_token {string}, ??? (Update when AOAI endpoint is defined)
/// Output: text {string}
/// </summary>
class AzureAudioToTextInvoker : public CurlInvoker {
 public:
  AzureAudioToTextInvoker(const OrtApi& api, const OrtKernelInfo& info);

  void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
    // use impl from CurlInvoker
    ComputeImpl(inputs, outputs);
  }

 private:
  std::string file_name_;
  void ValidateInputs(const ortc::Variadic& inputs) const override;
  void SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const override;
  void ProcessResponse(const std::string& response, ortc::Variadic& outputs) const override;

  static constexpr const char* kAudioFormat = "audio_format";
  std::string audio_format_;
};

struct AzureTextInvoker : public AzureInvoker {
  AzureTextInvoker(const OrtApi& api, const OrtKernelInfo& info);
  void Compute(std::string_view auth, std::string_view input, ortc::Tensor<std::string>& output) const;
////////////////////// AzureTextToTextInvoker //////////////////////

/// Azure Text to Text
/// Input: auth_token {string}, text {string}
/// Output: text {string}
struct AzureTextToTextInvoker : public CurlInvoker {
  AzureTextToTextInvoker(const OrtApi& api, const OrtKernelInfo& info);

  void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
    // use impl from CurlInvoker
    ComputeImpl(inputs, outputs);
  }

 private:
  std::string binary_type_;
  void ValidateInputs(const ortc::Variadic& inputs) const override;
  void SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const override;
  void ProcessResponse(const std::string& response, ortc::Variadic& outputs) const override;
};

struct AzureTritonInvoker : public AzureInvoker {
  AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info);
  void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const;

 private:
  std::unique_ptr<triton::client::InferenceServerHttpClient> triton_client_;
};
}  // namespace ort_extensions
@@ -0,0 +1,200 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "azure_triton_invoker.hpp"

////////////////////// AzureTritonInvoker //////////////////////

namespace tc = triton::client;

namespace ort_extensions {

AzureTritonInvoker::AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info)
    : CloudBaseKernel(api, info) {
  auto err = tc::InferenceServerHttpClient::Create(&triton_client_, ModelUri(), Verbose());
}

std::string MapDataType(ONNXTensorElementDataType onnx_data_type) {
  std::string triton_data_type;
  switch (onnx_data_type) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
      triton_data_type = "FP32";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
      triton_data_type = "UINT8";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
      triton_data_type = "INT8";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
      triton_data_type = "UINT16";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
      triton_data_type = "INT16";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
      triton_data_type = "INT32";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
      triton_data_type = "INT64";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
      triton_data_type = "BYTES";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
      triton_data_type = "BOOL";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
      triton_data_type = "FP16";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
      triton_data_type = "FP64";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
      triton_data_type = "UINT32";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
      triton_data_type = "UINT64";
      break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
      triton_data_type = "BF16";
      break;
    default:
      break;
  }
  return triton_data_type;
}

int8_t* CreateNonStrTensor(const std::string& data_type,
                           ortc::Variadic& outputs,
                           size_t i,
                           const std::vector<int64_t>& shape) {
  if (data_type == "FP32") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<float>(i, shape));
  } else if (data_type == "UINT8") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint8_t>(i, shape));
  } else if (data_type == "INT8") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int8_t>(i, shape));
  } else if (data_type == "UINT16") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint16_t>(i, shape));
  } else if (data_type == "INT16") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int16_t>(i, shape));
  } else if (data_type == "INT32") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int32_t>(i, shape));
  } else if (data_type == "UINT32") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint32_t>(i, shape));
  } else if (data_type == "INT64") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int64_t>(i, shape));
  } else if (data_type == "UINT64") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint64_t>(i, shape));
  } else if (data_type == "BOOL") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<bool>(i, shape));
  } else if (data_type == "FP64") {
    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<double>(i, shape));
  } else {
    return {};
  }
}

#define CHECK_TRITON_ERR(ret, msg)                                                     \
  if (!ret.IsOk()) {                                                                   \
    return ORTX_CXX_API_THROW("Triton err: " + ret.Message(), ORT_RUNTIME_EXCEPTION);  \
  }

void AzureTritonInvoker::Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
  auto auth_token = GetAuthToken(inputs);

  gsl::span<const std::string> input_names = InputNames();
  if (inputs.Size() != input_names.size()) {
    ORTX_CXX_API_THROW("input count mismatch", ORT_RUNTIME_EXCEPTION);
  }

  std::vector<std::unique_ptr<tc::InferInput>> triton_input_vec;
  std::vector<tc::InferInput*> triton_inputs;
  std::vector<std::unique_ptr<const tc::InferRequestedOutput>> triton_output_vec;
  std::vector<const tc::InferRequestedOutput*> triton_outputs;
  tc::Error err;

  const auto& property_names = RequestPropertyNames();

  for (size_t ith_input = 1; ith_input < inputs.Size(); ++ith_input) {
    tc::InferInput* triton_input = {};
    std::string triton_data_type = MapDataType(inputs[ith_input]->Type());
    if (triton_data_type.empty()) {
      ORTX_CXX_API_THROW("unknow onnx data type", ORT_RUNTIME_EXCEPTION);
    }

    err = tc::InferInput::Create(&triton_input, property_names[ith_input], inputs[ith_input]->Shape(),
                                 triton_data_type);
    CHECK_TRITON_ERR(err, "failed to create triton input");
    triton_input_vec.emplace_back(triton_input);

    triton_inputs.push_back(triton_input);
    if ("BYTES" == triton_data_type) {
      const auto* string_tensor = reinterpret_cast<const ortc::Tensor<std::string>*>(inputs[ith_input].get());
      triton_input->AppendFromString(string_tensor->Data());
    } else {
      const float* data_raw = reinterpret_cast<const float*>(inputs[ith_input]->DataRaw());
      size_t size_in_bytes = inputs[ith_input]->SizeInBytes();
      err = triton_input->AppendRaw(reinterpret_cast<const uint8_t*>(data_raw), size_in_bytes);
      CHECK_TRITON_ERR(err, "failed to append raw data to input");
    }
  }

  gsl::span<const std::string> output_names = OutputNames();
  for (size_t ith_output = 0; ith_output < output_names.size(); ++ith_output) {
    tc::InferRequestedOutput* triton_output = {};
    err = tc::InferRequestedOutput::Create(&triton_output, output_names[ith_output]);
    CHECK_TRITON_ERR(err, "failed to create triton output");
    triton_output_vec.emplace_back(triton_output);
    triton_outputs.push_back(triton_output);
  }

  std::unique_ptr<tc::InferResult> results_ptr;
  tc::InferResult* results = {};
  tc::InferOptions options(ModelName());
  options.model_version_ = ModelVersion();
  options.client_timeout_ = 0;

  tc::Headers http_headers;
  http_headers["Authorization"] = std::string{"Bearer "} + auth_token;

  err = triton_client_->Infer(&results, options, triton_inputs, triton_outputs,
                              http_headers, tc::Parameters(),
                              tc::InferenceServerHttpClient::CompressionType::NONE,  // support compression in config?
                              tc::InferenceServerHttpClient::CompressionType::NONE);

  results_ptr.reset(results);
  CHECK_TRITON_ERR(err, "failed to do triton inference");

  size_t output_index = 0;

  for (const auto& output_name : output_names) {
    std::vector<int64_t> shape;
    err = results_ptr->Shape(output_name, &shape);
    CHECK_TRITON_ERR(err, "failed to get output shape");

    std::string type;
    err = results_ptr->Datatype(output_name, &type);
    CHECK_TRITON_ERR(err, "failed to get output type");

    if ("BYTES" == type) {
      std::vector<std::string> output_strings;
      err = results_ptr->StringData(output_name, &output_strings);
      CHECK_TRITON_ERR(err, "failed to get output as string");
      auto& string_tensor = outputs.AllocateStringTensor(output_index);
      string_tensor.SetStringOutput(output_strings, shape);
    } else {
      const uint8_t* raw_data = {};
      size_t raw_size;
      err = results_ptr->RawData(output_name, &raw_data, &raw_size);
      CHECK_TRITON_ERR(err, "failed to get output raw data");
      auto* output_raw = CreateNonStrTensor(type, outputs, output_index, shape);
      memcpy(output_raw, raw_data, raw_size);
    }

    ++output_index;
  }
}

}  // namespace ort_extensions
@@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "cloud_base_kernel.hpp"
#include "http_client.h"  // triton

namespace ort_extensions {

struct AzureTritonInvoker : public CloudBaseKernel {
  AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info);
  void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const;

 private:
  std::unique_ptr<triton::client::InferenceServerHttpClient> triton_client_;
};
}  // namespace ort_extensions
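For illustration, a model that invokes this op can be assembled with the standard ONNX helper API, the same way the test scripts later in this diff build theirs. This is a minimal sketch, not part of the PR: the endpoint URI, Triton model name and tensor names are hypothetical, while the attribute names and the required leading 'auth_token' input come from CloudBaseKernel.

import onnx
from onnx import helper, TensorProto

auth_token = helper.make_tensor_value_info('auth_token', TensorProto.STRING, [1])
x = helper.make_tensor_value_info('X', TensorProto.FLOAT, [-1])
y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [-1])
z = helper.make_tensor_value_info('Z', TensorProto.FLOAT, [-1])

# inputs after auth_token must be named for the Triton model's inputs,
# as those names become the request property names sent to the endpoint.
invoker = helper.make_node('AzureTritonInvoker', ['auth_token', 'X', 'Y'], ['Z'],
                           domain='com.microsoft.extensions',
                           name='triton_invoker',
                           model_uri='https://my-endpoint.westus2.inference.ml.azure.com',  # hypothetical
                           model_name='addf',  # hypothetical Triton model name
                           model_version='1')

graph = helper.make_graph([invoker], 'graph', [auth_token, x, y], [z])
model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid('com.microsoft.extensions', 1)])
onnx.save(model, 'azure_triton_addf.onnx')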
@@ -0,0 +1,92 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "cloud_base_kernel.hpp"

#include <sstream>

namespace ort_extensions {
CloudBaseKernel::CloudBaseKernel(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
  auto ver = GetActiveOrtAPIVersion();
  if (ver < kMinimumSupportedOrtVersion) {
    ORTX_CXX_API_THROW("Azure custom operators require onnxruntime version >= 1.14", ORT_RUNTIME_EXCEPTION);
  }

  // require model uri. other properties are optional
  // Custom op implementation can allow user to override attributes via inputs
  if (!TryToGetAttribute<std::string>(kUri, model_uri_)) {
    ORTX_CXX_API_THROW(std::string("Required '") + kUri + "' attribute was not found", ORT_RUNTIME_EXCEPTION);
  }

  model_name_ = TryToGetAttributeWithDefault<std::string>(kModelName, "");
  model_ver_ = TryToGetAttributeWithDefault<std::string>(kModelVer, "0");
  verbose_ = TryToGetAttributeWithDefault<std::string>(kVerbose, "0") != "0";

  OrtStatusPtr status{};
  size_t input_count{};
  status = api_.KernelInfo_GetInputCount(&info_, &input_count);
  if (status) {
    ORTX_CXX_API_THROW("failed to get input count", ORT_RUNTIME_EXCEPTION);
  }

  input_names_.reserve(input_count);
  property_names_.reserve(input_count);

  for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
    char input_name[1024]{};
    size_t name_size = 1024;
    status = api_.KernelInfo_GetInputName(&info_, ith_input, input_name, &name_size);
    if (status) {
      ORTX_CXX_API_THROW("failed to get name for input " + std::to_string(ith_input), ORT_RUNTIME_EXCEPTION);
    }

    input_names_.push_back(input_name);
    property_names_.push_back(GetPropertyNameFromInputName(input_name));
  }

  if (input_names_[0] != "auth_token") {
    ORTX_CXX_API_THROW("first input name must be 'auth_token'", ORT_INVALID_ARGUMENT);
  }

  size_t output_count = {};
  status = api_.KernelInfo_GetOutputCount(&info_, &output_count);
  if (status) {
    ORTX_CXX_API_THROW("failed to get output count", ORT_RUNTIME_EXCEPTION);
  }

  output_names_.reserve(output_count);
  for (size_t ith_output = 0; ith_output < output_count; ++ith_output) {
    char output_name[1024]{};
    size_t name_size = 1024;
    status = api_.KernelInfo_GetOutputName(&info_, ith_output, output_name, &name_size);
    if (status) {
      ORTX_CXX_API_THROW("failed to get name for output " + std::to_string(ith_output), ORT_RUNTIME_EXCEPTION);
    }
    output_names_.push_back(output_name);
  }
}

std::string CloudBaseKernel::GetAuthToken(const ortc::Variadic& inputs) const {
  if (inputs.Size() < 1 ||
      inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
    ORTX_CXX_API_THROW("auth_token string is required to be the first input", ORT_INVALID_ARGUMENT);
  }

  std::string auth_token{static_cast<const char*>(inputs[0]->DataRaw())};
  return auth_token;
}

/*static*/ std::string CloudBaseKernel::GetPropertyNameFromInputName(const std::string& input_name) {
  auto idx = input_name.find_last_of('/');
  if (idx == std::string::npos) {
    return input_name;
  }

  if (idx == input_name.length() - 1) {
    ORTX_CXX_API_THROW("Input name cannot end with '/'. Invalid input:" + input_name, ORT_INVALID_ARGUMENT);
  }

  return input_name.substr(idx + 1);  // return text after the '/'
}

}  // namespace ort_extensions
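The input-name to property-name mapping above is the one piece of behaviour every invoker relies on, so it is worth pinning down. A Python mirror of GetPropertyNameFromInputName, for illustration only:

# Python mirror of CloudBaseKernel::GetPropertyNameFromInputName (illustration only):
# the request property name is the text after the last '/', or the whole name if there is none.
def property_name_from_input_name(input_name: str) -> str:
    idx = input_name.rfind('/')
    if idx == -1:
        return input_name  # 'auth_token' -> 'auth_token'
    if idx == len(input_name) - 1:
        raise ValueError(f"Input name cannot end with '/'. Invalid input: {input_name}")
    return input_name[idx + 1:]  # 'transcribe0/prompt' -> 'prompt'

assert property_name_from_input_name('transcribe0/prompt') == 'prompt'
assert property_name_from_input_name('auth_token') == 'auth_token'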
@@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "ocos.h"
#include "gsl/span"

namespace ort_extensions {

/// <summary>
/// Base kernel for custom ops that call cloud endpoints.
/// </summary>
class CloudBaseKernel : public BaseKernel {
 protected:
  CloudBaseKernel(const OrtApi& api, const OrtKernelInfo& info);
  virtual ~CloudBaseKernel() = default;

  // Names of attributes the custom operator provides.
  static constexpr const char* kUri = "model_uri";           // required
  static constexpr const char* kModelName = "model_name";    // optional
  static constexpr const char* kModelVer = "model_version";  // optional
  static constexpr const char* kVerbose = "verbose";

  static constexpr int kMinimumSupportedOrtVersion = 14;

  const std::string& ModelUri() const { return model_uri_; }
  const std::string& ModelName() const { return model_name_; }
  const std::string& ModelVersion() const { return model_ver_; }
  bool Verbose() const { return verbose_; }

  const gsl::span<const std::string> InputNames() const { return input_names_; }
  const gsl::span<const std::string> OutputNames() const { return output_names_; }

  // Request property names that are parsed from input names. 1:1 with InputNames() values.
  // e.g. 'node0/prompt' -> 'prompt', and that input provides the 'prompt' property in the request to the endpoint.
  // <see cref="GetPropertyNameFromInputName"/> for further details.
  const gsl::span<const std::string> RequestPropertyNames() const { return property_names_; }

  // first input is required to be auth token. validate that and return it.
  std::string GetAuthToken(const ortc::Variadic& inputs) const;

  /// <summary>
  /// Parse the property name to use in the request to the cloud endpoint from a node input name.
  /// Value returned is the text following the last '/', or the entire string if there is no '/'.
  /// e.g. 'node0/prompt' -> 'prompt'
  /// </summary>
  /// <param name="input_name">Node input name.</param>
  /// <returns>Request property name the input is providing data for.</returns>
  static std::string GetPropertyNameFromInputName(const std::string& input_name);

 private:
  std::string model_uri_;
  std::string model_name_;
  std::string model_ver_;
  bool verbose_;

  std::vector<std::string> input_names_;
  std::vector<std::string> property_names_;
  std::vector<std::string> output_names_;
};

}  // namespace ort_extensions
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "ocos.h"
#include "azure_invokers.hpp"
#include "openai_invokers.hpp"

#ifdef AZURE_INVOKERS_ENABLE_TRITON
#include "azure_triton_invoker.hpp"
#endif

using namespace ort_extensions;

const std::vector<const OrtCustomOp*>& AzureInvokerLoader() {
  static OrtOpLoader op_loader(CustomAzureStruct("AzureAudioToText", AzureAudioToTextInvoker),
                               CustomCpuStruct("AzureAudioToText", AzureAudioToTextInvoker),
                               CustomAzureStruct("AzureTextToText", AzureTextToTextInvoker),
                               CustomCpuStruct("AzureTextToText", AzureTextToTextInvoker),
                               CustomAzureStruct("OpenAIAudioToText", OpenAIAudioToTextInvoker),
                               CustomCpuStruct("OpenAIAudioToText", OpenAIAudioToTextInvoker)
#ifdef AZURE_INVOKERS_ENABLE_TRITON
                               ,
                               CustomAzureStruct("AzureTritonInvoker", AzureTritonInvoker),
                               CustomCpuStruct("AzureTritonInvoker", AzureTritonInvoker)
#endif
  );

  return op_loader.GetCustomOps();
}

FxLoadCustomOpFactory LoadCustomOpClasses_Azure = AzureInvokerLoader;
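Nothing beyond the normal custom-ops registration is needed to use these. A sketch of loading them from Python, mirroring the updated tests later in this diff (the model path and request body are examples):

import numpy as np
from onnxruntime import InferenceSession, SessionOptions
from onnxruntime_extensions import get_library_path

opt = SessionOptions()
opt.register_custom_ops_library(get_library_path())  # picks up the Azure ops when built with OCOS_ENABLE_AZURE

sess = InferenceSession('openai_chat.onnx', opt, providers=['CPUExecutionProvider'])
out = sess.run(None, {'auth_token': np.array(['<api key>']),
                      'chat': np.array(['{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]}'])})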
@@ -0,0 +1,85 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "curl_invoker.hpp"

#include <iostream>  // TEMP error output
#include <sstream>

namespace ort_extensions {

// apply the callback only when the response is guaranteed to be a '\0' terminated string
size_t CurlHandler::WriteStringCallback(char* contents, size_t element_size, size_t num_elements, void* userdata) {
  try {
    size_t bytes = element_size * num_elements;
    std::string& buffer = *static_cast<std::string*>(userdata);
    buffer.append(contents, bytes);
    return bytes;
  } catch (const std::exception& ex) {
    // TODO: This should be captured/logged properly
    std::cerr << ex.what() << std::endl;
    return 0;
  } catch (...) {
    // exception caught, abort write
    std::cerr << "Unknown exception caught in CurlHandler::WriteStringCallback" << std::endl;
    return 0;
  }
}

CurlHandler::CurlHandler(WriteCallBack callback) : curl_(curl_easy_init(), curl_easy_cleanup),
                                                   headers_(nullptr, curl_slist_free_all),
                                                   from_holder_(from_, curl_formfree) {
  CURL* curl = curl_.get();  // CURL == void* so can't dereference

  curl_easy_setopt(curl, CURLOPT_BUFFERSIZE, 100 * 1024L);  // how was this size chosen? should it be set on a per-operator basis?
  curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L);
  curl_easy_setopt(curl, CURLOPT_USERAGENT, "curl/7.83.1");  // should this value come from the curl src instead of being hardcoded?
  curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 50L);            // 50 seems like a lot if we're directly calling a specific endpoint
  curl_easy_setopt(curl, CURLOPT_FTP_SKIP_PASV_IP, 1L);      // what does this have to do with http requests?
  curl_easy_setopt(curl, CURLOPT_TCP_KEEPALIVE, 1L);
  curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, callback);

  // should this be configured via a node attribute? different endpoints may have different timeouts
  curl_easy_setopt(curl, CURLOPT_TIMEOUT, 15L);
}

////////////////////// CurlInvoker //////////////////////

CurlInvoker::CurlInvoker(const OrtApi& api, const OrtKernelInfo& info)
    : CloudBaseKernel(api, info) {
}

void CurlInvoker::ComputeImpl(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
  std::string auth_token = GetAuthToken(inputs);

  if (inputs.Size() != InputNames().size()) {
    ORTX_CXX_API_THROW("input count mismatch", ORT_RUNTIME_EXCEPTION);
  }

  // do any additional validation of the number and type of inputs/outputs
  ValidateInputs(inputs);

  // set the options for the curl handler that apply to all usages
  CurlHandler curl_handler(CurlHandler::WriteStringCallback);

  std::string full_auth = std::string{"Authorization: Bearer "} + auth_token;
  curl_handler.AddHeader(full_auth.c_str());
  curl_handler.SetOption(CURLOPT_URL, ModelUri().c_str());
  curl_handler.SetOption(CURLOPT_VERBOSE, Verbose());

  std::string response;
  curl_handler.SetOption(CURLOPT_WRITEDATA, (void*)&response);

  SetupRequest(curl_handler, inputs);
  ExecuteRequest(curl_handler);
  ProcessResponse(response, outputs);
}

void CurlInvoker::ExecuteRequest(CurlHandler& curl_handler) const {
  // this is where we could add any logic required to make the request async or maybe handle retries/cancellation
  auto curl_ret = curl_handler.Perform();
  if (CURLE_OK != curl_ret) {
    ORTX_CXX_API_THROW(curl_easy_strerror(curl_ret), ORT_FAIL);
  }
}
}  // namespace ort_extensions
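For reference, the request that ComputeImpl and a SetupRequest override assemble for the OpenAI audio case is an ordinary authenticated multipart POST. Expressed with Python's requests library purely as an illustration (the URI, token and file are placeholders):

import requests

response = requests.post(
    'https://api.openai.com/v1/audio/transcriptions',       # CURLOPT_URL <- model_uri attribute
    headers={'Authorization': 'Bearer <auth_token input>'},  # header added in ComputeImpl
    data={'model': 'whisper-1'},                             # AddFormString('model', ...)
    files={'file': ('audio.wav', open('audio.wav', 'rb'))},  # AddFormBuffer('file', filename, ...)
    timeout=15)                                              # CURLOPT_TIMEOUT
print(response.text)  # the raw body; WriteStringCallback captures the same thing into the output string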
@@ -0,0 +1,101 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <memory>

#include "curl/curl.h"

#include "ocos.h"
#include "cloud_base_kernel.hpp"

namespace ort_extensions {

class CurlHandler {
 public:
  using WriteCallBack = size_t (*)(char*, size_t, size_t, void*);

  CurlHandler(WriteCallBack callback);
  ~CurlHandler() = default;

  /// <summary>
  /// Callback to add contents to a string
  /// </summary>
  /// <seealso cref="https://curl.se/libcurl/c/CURLOPT_WRITEFUNCTION.html"/>
  /// <returns>Bytes processed. If this does not match element_size * num_elements the libcurl function
  /// used will return CURLE_WRITE_ERROR.</returns>
  static size_t WriteStringCallback(char* contents, size_t element_size, size_t num_elements, void* userdata);

  void AddHeader(const char* data) {
    headers_.reset(curl_slist_append(headers_.release(), data));
  }

  template <typename... Args>
  void AddForm(Args... args) {
    curl_formadd(&from_, &last_, args...);
  }

  void AddFormString(const char* name, const char* value) {
    AddForm(CURLFORM_COPYNAME, name,
            CURLFORM_COPYCONTENTS, value,
            CURLFORM_END);
  }

  void AddFormBuffer(const char* name, const char* buffer_name, const void* buffer_ptr, size_t buffer_len) {
    AddForm(CURLFORM_COPYNAME, name,
            CURLFORM_BUFFER, buffer_name,
            CURLFORM_BUFFERPTR, buffer_ptr,
            CURLFORM_BUFFERLENGTH, buffer_len,
            CURLFORM_END);
  }

  template <typename T>
  void SetOption(CURLoption opt, T val) {
    curl_easy_setopt(curl_.get(), opt, val);
  }

  CURLcode Perform() {
    SetOption(CURLOPT_HTTPHEADER, headers_.get());
    if (from_) {
      SetOption(CURLOPT_HTTPPOST, from_);
    }

    return curl_easy_perform(curl_.get());
  }

 private:
  std::unique_ptr<CURL, decltype(curl_easy_cleanup)*> curl_;
  std::unique_ptr<curl_slist, decltype(curl_slist_free_all)*> headers_;
  curl_httppost* from_{};
  curl_httppost* last_{};
  std::unique_ptr<curl_httppost, decltype(curl_formfree)*> from_holder_;  // TODO: Why no last_holder_?
};

/// <summary>
/// Base class for requests using curl
/// </summary>
class CurlInvoker : public CloudBaseKernel {
 protected:
  CurlInvoker(const OrtApi& api, const OrtKernelInfo& info);
  virtual ~CurlInvoker() = default;

  // Compute implementation that is used to co-ordinate all curl-based Azure requests.
  // Derived classes need their own Compute to work with the CustomOpLite infrastructure.
  void ComputeImpl(const ortc::Variadic& inputs, ortc::Variadic& outputs) const;

 private:
  void ExecuteRequest(CurlHandler& handler) const;

  // Derived classes can add any arg validation required.
  // Prior to this being called, `inputs` are validated to match the number of input names, and
  // the auth_token has been read from input[0] so validation can skip that.
  //
  // the ortc::Variadic outputs are empty until the Compute populates it, so only output names can be validated
  // and those are available from the base class.
  virtual void ValidateInputs(const ortc::Variadic& inputs) const {}

  // curl_handler has auth token set from input[0].
  virtual void SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const = 0;
  virtual void ProcessResponse(const std::string& response, ortc::Variadic& outputs) const = 0;
};
}  // namespace ort_extensions
@@ -0,0 +1,103 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "openai_invokers.hpp"

namespace ort_extensions {

OpenAIAudioToTextInvoker::OpenAIAudioToTextInvoker(const OrtApi& api, const OrtKernelInfo& info)
    : CurlInvoker(api, info) {
  audio_format_ = TryToGetAttributeWithDefault<std::string>(kAudioFormat, "");

  const auto& property_names = RequestPropertyNames();

  const auto find_optional_input = [&property_names](const std::string& property_name) {
    std::optional<size_t> result;
    auto optional_input = std::find_if(property_names.begin(), property_names.end(),
                                       [&property_name](const auto& name) { return name == property_name; });

    if (optional_input != property_names.end()) {
      result = optional_input - property_names.begin();
    }

    return result;
  };

  filename_input_ = find_optional_input("filename");
  model_name_input_ = find_optional_input("model");

  // OpenAI audio endpoints require 'file' and 'model'.
  if (!std::any_of(property_names.begin(), property_names.end(),
                   [](const auto& name) { return name == "file"; })) {
    ORTX_CXX_API_THROW("Required 'file' input was not found", ORT_INVALID_ARGUMENT);
  }

  if (ModelName().empty() && !model_name_input_) {
    ORTX_CXX_API_THROW("Required 'model' input was not found", ORT_INVALID_ARGUMENT);
  }
}

void OpenAIAudioToTextInvoker::ValidateInputs(const ortc::Variadic& inputs) const {
  // We don't have a way to get the output type from the custom op API.
  // If there's a mismatch it will fail in the Compute when it allocates the output tensor.
  if (OutputNames().size() != 1) {
    ORTX_CXX_API_THROW("Expected single output", ORT_INVALID_ARGUMENT);
  }
}

void OpenAIAudioToTextInvoker::SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const {
  // theoretically the filename the content was buffered from. provides the extension indicating the audio format.
  static const std::string fake_filename = "audio." + audio_format_;

  const auto& property_names = RequestPropertyNames();

  const auto& get_optional_input =
      [&](const std::optional<size_t>& input_idx, const std::string& default_value, size_t min_size = 1) {
        return (input_idx.has_value() && inputs[*input_idx]->SizeInBytes() > min_size)
                   ? static_cast<const char*>(inputs[*input_idx]->DataRaw())
                   : default_value.c_str();
      };

  // filename_input_ is optional in a model. if it's not present, use a fake filename.
  // if it's present, make sure it's not a default empty value. as the filename needs to have an extension of
  // mp3, mp4, mpeg, mpga, m4a, wav, or webm it must be at least 4 characters long.
  const char* filename = get_optional_input(filename_input_, fake_filename, 4);

  curl_handler.AddHeader("Content-Type: multipart/form-data");
  // model name could be input or attribute
  curl_handler.AddFormString("model", get_optional_input(model_name_input_, ModelName()));

  for (size_t ith_input = 1; ith_input < inputs.Size(); ++ith_input) {
    switch (inputs[ith_input]->Type()) {
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
        curl_handler.AddFormString(property_names[ith_input].c_str(),
                                   // assumes null terminated.
                                   // might be safer to pass pointer and length and use CURLFORM_CONTENTSLENGTH
                                   static_cast<const char*>(inputs[ith_input]->DataRaw()));
        break;
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
        // only the 'file' input is uint8
        if (property_names[ith_input] != "file") {
          ORTX_CXX_API_THROW("Only the 'file' input should be uint8 data. Invalid input:" + InputNames()[ith_input],
                             ORT_INVALID_ARGUMENT);
        }

        curl_handler.AddFormBuffer(property_names[ith_input].c_str(),
                                   filename,
                                   inputs[ith_input]->DataRaw(),
                                   inputs[ith_input]->SizeInBytes());
        break;
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
        // TODO - required to support 'temperature' input.
      default:
        ORTX_CXX_API_THROW("input must be either text or binary", ORT_INVALID_ARGUMENT);
        break;
    }
  }
}

void OpenAIAudioToTextInvoker::ProcessResponse(const std::string& response, ortc::Variadic& outputs) const {
  auto& string_tensor = outputs.AllocateStringTensor(0);
  string_tensor.SetStringOutput(std::vector<std::string>{response}, std::vector<int64_t>{1});
}
}  // namespace ort_extensions
@@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

#include "ocos.h"

#include <optional>

#include "curl_invoker.hpp"

namespace ort_extensions {

////////////////////// OpenAIAudioToTextInvoker //////////////////////

/// <summary>
/// OpenAI Audio to Text
/// Input: auth_token {string}, Request body values {string|uint8} as per https://platform.openai.com/docs/api-reference/audio
/// Output: text {string}
/// </summary>
/// <remarks>
/// The model URI is read from the node attributes.
/// The model name (e.g. 'whisper-1') can be provided as a node attribute or via an input.
///
/// Example input would be:
///   - string tensor named `auth_token` (required, must be first input)
///   - a uint8 tensor named `file` with audio data in the format matching the 'audio_format' attribute (required)
///     - see OpenAI documentation for current supported audio formats
///   - a string tensor named `filename` (optional) with extension indicating the format of the audio data
///     - e.g. 'audio.mp3'
///   - a string tensor named `prompt` (optional)
///
/// NOTE: 'temperature' is not currently supported.
/// </remarks>
class OpenAIAudioToTextInvoker final : public CurlInvoker {
 public:
  OpenAIAudioToTextInvoker(const OrtApi& api, const OrtKernelInfo& info);

  void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
    // use impl from CurlInvoker
    ComputeImpl(inputs, outputs);
  }

 private:
  void ValidateInputs(const ortc::Variadic& inputs) const override;
  void SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const override;
  void ProcessResponse(const std::string& response, ortc::Variadic& outputs) const override;

  // audio format to use if the optional 'filename' input is not provided
  static constexpr const char* kAudioFormat = "audio_format";
  std::string audio_format_;
  std::optional<size_t> filename_input_;    // optional override for generated filename using audio_format
  std::optional<size_t> model_name_input_;  // optional override for model_name attribute
};

}  // namespace ort_extensions
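Putting the pieces together, here is a hedged sketch of driving this op from Python using the whisper transcription model that the generator script later in this diff produces. The audio file and the OPENAI_AUTH_TOKEN value are placeholders:

import os
import numpy as np
from onnxruntime import InferenceSession, SessionOptions
from onnxruntime_extensions import get_library_path

opt = SessionOptions()
opt.register_custom_ops_library(get_library_path())
sess = InferenceSession('openai_whisper_transcriptions.onnx', opt, providers=['CPUExecutionProvider'])

with open('self-destruct-button.wav', 'rb') as f:  # placeholder audio file
    audio_blob = np.asarray(list(f.read()), dtype=np.uint8)

result = sess.run(None, {'auth_token': np.array([os.environ['OPENAI_AUTH_TOKEN']]),
                         'transcribe0/file': audio_blob})
print(result[0])  # single string tensor holding the JSON response from Whisper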
@@ -0,0 +1,38 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

set -e
set -u
set -x

# export skip_checkout if you want to repeat a build
if [ -z ${skip_checkout+x} ]; then
  git clone https://github.com/leenjewel/openssl_for_ios_and_android.git
  cd openssl_for_ios_and_android
  git checkout ci-release-663da9e2
  # patch with fixes to build on linux with NDK 25 or later
  git apply ../build_curl_for_android_on_linux.patch
else
  echo "Skipping checkout and patch"
  cd openssl_for_ios_and_android
fi

cd tools

# we target Android API level 24
export api=24

# provide a specific architecture as an argument to the script to limit the build to that
# default is to build all
# valid architecture values: "arm" "arm64" "x86" "x86_64"
if [ $# -eq 1 ]; then
  arch=$1
  ./build-android-openssl.sh $arch
  ./build-android-nghttp2.sh $arch
  ./build-android-curl.sh $arch
else
  ./build-android-openssl.sh
  ./build-android-nghttp2.sh
  ./build-android-curl.sh
fi
@@ -0,0 +1,67 @@
diff --git a/tools/build-android-common.sh b/tools/build-android-common.sh
index 87df207..797d58a 100755
--- a/tools/build-android-common.sh
+++ b/tools/build-android-common.sh
@@ -148,13 +148,20 @@ function set_android_toolchain() {
     local build_host=$(get_build_host_internal "$arch")
     local clang_target_host=$(get_clang_target_host "$arch" "$api")
 
-    export AR=${build_host}-ar
+    # NDK r23 removed a bunch of GNU things and replaced with llvm
+    # https://stackoverflow.com/questions/73105626/arm-linux-androideabi-ar-command-not-found-in-ndk
+    # export AR=${build_host}-ar
+    export AR=llvm-ar
     export CC=${clang_target_host}-clang
     export CXX=${clang_target_host}-clang++
-    export AS=${build_host}-as
-    export LD=${build_host}-ld
-    export RANLIB=${build_host}-ranlib
+    #export AS=${build_host}-as
+    export AS=llvm-as
+    #export LD=${build_host}-ld
+    export LD=ld
+    # export RANLIB=${build_host}-ranlib
+    export RANLIB=llvm-ranlib
     export STRIP=${build_host}-strip
+    export STRIP=llvm-strip
 }
 
 function get_common_includes() {
@@ -187,13 +194,13 @@ function set_android_cpu_feature() {
         export CPPFLAGS=${CFLAGS}
         ;;
     x86)
-        export CFLAGS="-march=i686 -mtune=intel -mssse3 -mfpmath=sse -m32 -Wno-unused-function -fno-integrated-as -fstrict-aliasing -fPIC -DANDROID -D__ANDROID_API__=${api} -Os -ffunction-sections -fdata-sections $(get_common_includes)"
+        export CFLAGS="-march=i686 -mtune=native -mssse3 -mfpmath=sse -m32 -Wno-unused-function -fno-integrated-as -fstrict-aliasing -fPIC -DANDROID -D__ANDROID_API__=${api} -Os -ffunction-sections -fdata-sections $(get_common_includes)"
         export CXXFLAGS="-std=c++14 -Os -ffunction-sections -fdata-sections"
         export LDFLAGS="-march=i686 -Wl,--gc-sections -Os -ffunction-sections -fdata-sections $(get_common_linked_libraries ${api} ${arch})"
         export CPPFLAGS=${CFLAGS}
         ;;
     x86-64)
-        export CFLAGS="-march=x86-64 -msse4.2 -mpopcnt -m64 -mtune=intel -Wno-unused-function -fno-integrated-as -fstrict-aliasing -fPIC -DANDROID -D__ANDROID_API__=${api} -Os -ffunction-sections -fdata-sections $(get_common_includes)"
+        export CFLAGS="-march=x86-64 -msse4.2 -mpopcnt -m64 -mtune=native -Wno-unused-function -fno-integrated-as -fstrict-aliasing -fPIC -DANDROID -D__ANDROID_API__=${api} -Os -ffunction-sections -fdata-sections $(get_common_includes)"
         export CXXFLAGS="-std=c++14 -Os -ffunction-sections -fdata-sections"
         export LDFLAGS="-march=x86-64 -Wl,--gc-sections -Os -ffunction-sections -fdata-sections $(get_common_linked_libraries ${api} ${arch})"
         export CPPFLAGS=${CFLAGS}
diff --git a/tools/build-android-openssl.sh b/tools/build-android-openssl.sh
index e13c314..5660cec 100755
--- a/tools/build-android-openssl.sh
+++ b/tools/build-android-openssl.sh
@@ -17,7 +17,7 @@
 # # read -n1 -p "Press any key to continue..."
 
 set -u
-
+set -x
 source ./build-android-common.sh
 
 if [ -z ${version+x} ]; then
@@ -115,6 +115,8 @@ function configure_make() {
     if [ $the_rc -eq 0 ] ; then
         make SHLIB_EXT='.so' install_sw >>"${OUTPUT_ROOT}/log/${ABI}.log" 2>&1
         make install_ssldirs >>"${OUTPUT_ROOT}/log/${ABI}.log" 2>&1
+    else
+        log_error "make returned $the_rc"
     fi
 
     popd
@@ -0,0 +1,66 @@
# Mobile Azure EP pre-build

Manual libraries that need to be prebuilt for the Azure operators on Android and iOS.
There is no simple cmake setup that works, so we prebuild them as a one-off.

## Requirements:
- pkg-config
- Android
  - Android SDK installed with NDK 25 or later
    - You can install a package, but that means you have to use `sudo` for all updates like installing an NDK
      - https://stackoverflow.com/questions/34556884/how-to-install-android-sdk-on-ubuntu
      - you still need to manually add the cmdline-tools to that package as well
      - probably easier to create a per-user install using the command line tools
    - Using command line tools
      - Download the command line tools from https://developer.android.com/studio
        - Download 'Command line tools only' and unzip
      - `mkdir ~/Android`
      - `unzip commandlinetools-linux-9477386_latest.zip`
      - `mkdir -p ~/Android/cmdline-tools/latest`
      - `mv cmdline-tools/* ~/Android/cmdline-tools/latest`
      - `export ANDROID_HOME=~/Android`
      - Add these to PATH
        - ~/Android/cmdline-tools/latest/bin
        - ~/Android/platform-tools/bin
      - `sdkmanager --list` to make sure the setup works
      - Install platform-tools and the latest NDK
        - `sdkmanager --install platform-tools`
        - e.g. `sdkmanager --install "ndk;25.2.9519653"` (quote the package name so the ';' survives the shell)

  That should be enough to build,
  e.g. `./build_lib.sh --android --android_api=24 --android_home=/home/me/Android --android_abi=x86_64 --android_ndk_path=/home/me/Android/ndk/25.2.9519653 --enable_cxx_tests`

  See the Android documentation for installing a system image with `sdkmanager` and
  creating an emulator with `avdmanager`.
- iOS
  - TBD

## Android build
Export ANDROID_NDK_ROOT with the value set to the NDK path, as this is used by the build script
- e.g. `export ANDROID_NDK_ROOT=~/Android/ndk/25.2.9519653`

From this directory run `./build_curl_for_android.sh`.
An architecture can optionally be specified as the first argument to limit the build to that architecture;
otherwise all 4 architectures (arm, arm64, x86, x86_64) will be built.
e.g. if you just want to build locally for the emulator you can do `./build_curl_for_android.sh x86_64`

## Android testing
Build with `--enable_cxx_tests`.
This should result in the 'bin' directory of the build output having the two test executables.
Create/start an Android emulator.
Use `adb push` to copy the bin, lib and data directories from the build output to the /data/local/tmp directory
- `adb push build/Android/bin /data/local/tmp`
  - repeat for 'lib' and 'data'
- copy the onnxruntime shared library to the lib dir (adjust the version number as needed)
  - adjust the architecture as needed (most likely x86_64 for an emulator, arm for a device)
  - `adb push build/Android/Debug/_deps/onnxruntime-src/jni/x86_64/libonnxruntime.so /data/local/tmp/lib`
- Connect to the emulator
  - `adb shell`
  - `cd /data/local/tmp`
- Add the path to the .so
  - `export LD_LIBRARY_PATH=/data/local/tmp/lib:$LD_LIBRARY_PATH`
- Make the tests executable
  - `chmod +x bin/ocos_test`
  - `chmod +x bin/extensions_test`
- Run the tests from the `tmp` dir so paths to `data` are as expected
  - `./bin/ocos_test`
  - `./bin/extensions_test`
@@ -154,9 +154,6 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
 #endif
 #if defined(ENABLE_DR_LIBS)
     LoadCustomOpClasses_Audio,
 #endif
-#if defined(ENABLE_AZURE)
-    LoadCustomOpClasses_Azure,
-#endif
     LoadCustomOpClasses<>
 };
 
@@ -187,6 +184,10 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
     return status;
   }
 
+  //
+  // New custom ops should use the com.microsoft.extensions domain.
+  //
+
   // Create domain for ops using the new domain name.
   if (status = ortApi->CreateCustomOpDomain(c_ComMsExtOpDomain, &domain); status) {
     return status;
 
@@ -200,6 +201,9 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
 #endif
 #if defined(ENABLE_TOKENIZER)
     LoadCustomOpClasses_Tokenizer,
 #endif
+#if defined(ENABLE_AZURE)
+    LoadCustomOpClasses_Azure,
+#endif
     LoadCustomOpClasses<>
 };
Binary data: test/data/azure/be-a-man-take-some-pepto-bismol-get-dressed-and-come-on-over-here.mp3 (new file, binary file not shown)
@@ -0,0 +1,52 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import onnx
from onnx import helper, TensorProto


def create_audio_model():
    auth_token = helper.make_tensor_value_info('auth_token', TensorProto.STRING, [1])
    model = helper.make_tensor_value_info('model', TensorProto.STRING, [1])
    response_format = helper.make_tensor_value_info('response_format', TensorProto.STRING, [-1])
    file = helper.make_tensor_value_info('file', TensorProto.UINT8, [-1])

    transcriptions = helper.make_tensor_value_info('transcriptions', TensorProto.STRING, [-1])

    invoker = helper.make_node('OpenAIAudioToText',
                               ['auth_token', 'model', 'response_format', 'file'],
                               ['transcriptions'],
                               domain='com.microsoft.extensions',
                               name='audio_invoker',
                               model_uri='https://api.openai.com/v1/audio/transcriptions',
                               audio_format='wav',
                               verbose=False)

    graph = helper.make_graph([invoker], 'graph', [auth_token, model, response_format, file], [transcriptions])
    model = helper.make_model(graph,
                              opset_imports=[helper.make_operatorsetid('com.microsoft.extensions', 1)])

    onnx.save(model, 'openai_audio.onnx')


def create_chat_model():
    auth_token = helper.make_tensor_value_info('auth_token', TensorProto.STRING, [-1])
    chat = helper.make_tensor_value_info('chat', TensorProto.STRING, [-1])
    response = helper.make_tensor_value_info('response', TensorProto.STRING, [-1])

    invoker = helper.make_node('AzureTextToText', ['auth_token', 'chat'], ['response'],
                               domain='com.microsoft.extensions',
                               name='chat_invoker',
                               model_uri='https://api.openai.com/v1/chat/completions',
                               verbose=False)

    graph = helper.make_graph([invoker], 'graph', [auth_token, chat], [response])
    model = helper.make_model(graph,
                              opset_imports=[helper.make_operatorsetid('com.microsoft.extensions', 1)])

    onnx.save(model, 'openai_chat.onnx')


create_audio_model()
create_chat_model()
@@ -0,0 +1,75 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from onnx import helper, numpy_helper, TensorProto

import onnx
import numpy as np
import sys


def order_repeated_field(repeated_proto, key_name, order):
    order = list(order)
    repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name)))


def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs):
    node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs)
    if doc_string == '':
        node.doc_string = ''
    order_repeated_field(node.attribute, 'name', kwargs.keys())
    return node


def make_graph(*args, doc_string=None, **kwargs):
    graph = helper.make_graph(*args, doc_string=doc_string, **kwargs)
    if doc_string == '':
        graph.doc_string = ''
    return graph


# This creates a model that allows the prompt and filename to be optionally provided as inputs.
# The filename can be specified to indicate a different audio type to the default value in the audio_format attribute.
model = helper.make_model(
    opset_imports=[helper.make_operatorsetid('com.microsoft.extensions', 1)],
    graph=make_graph(
        name='OpenAIWhisperTranscribe',
        initializer=[
            # add default values in the initializers to make the model inputs optional
            helper.make_tensor('transcribe0/filename', TensorProto.STRING, [1], [b""]),
            helper.make_tensor('transcribe0/prompt', TensorProto.STRING, [1], [b""])
        ],
        inputs=[
            helper.make_tensor_value_info('auth_token', TensorProto.STRING, shape=[1]),
            helper.make_tensor_value_info('transcribe0/file', TensorProto.UINT8, shape=["bytes"]),
            helper.make_tensor_value_info('transcribe0/filename', TensorProto.STRING, shape=["bytes"]),  # optional
            helper.make_tensor_value_info('transcribe0/prompt', TensorProto.STRING, shape=["bytes"]),  # optional
        ],
        outputs=[helper.make_tensor_value_info('transcription', TensorProto.STRING, shape=[1])],
        nodes=[
            make_node(
                'OpenAIAudioToText',
                # additional optional request inputs that could be added:
                #   response_format, temperature, language
                # Using a prefix for input names allows the model to have multiple nodes calling cloud endpoints.
                # auth_token does not need a prefix unless different auth tokens are used for different nodes.
                inputs=['auth_token', 'transcribe0/file', 'transcribe0/filename', 'transcribe0/prompt'],
                outputs=['transcription'],
                name='OpenAIAudioToText0',
                domain='com.microsoft.extensions',
                audio_format='wav',  # default audio type if filename is not specified.
                model_uri='https://api.openai.com/v1/audio/transcriptions',
                model_name='whisper-1',
                verbose=0,
            ),
        ],
    ),
)

if __name__ == '__main__':
    out_path = "openai_whisper_transcriptions.onnx"
    if len(sys.argv) == 2:
        out_path = sys.argv[1]

    onnx.save(model, out_path)
Binary data: test/data/azure/openai_audio.onnx (binary file not shown)
Binary data: test/data/azure/openai_chat.onnx (binary file not shown)
@@ -1,17 +0,0 @@
[deleted file: text-format ONNX model for an 'AzureTextInvoker' embedding op (inputs auth_token/text, output embedding, model_uri https://api.openai.com/v1/embeddings, domain ai.onnx.contrib); serialized content not displayable]
@@ -0,0 +1,106 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef ENABLE_AZURE
#include <cstdlib>

#include "gtest/gtest.h"

#include "ocos.h"
#include "narrow.h"
#include "test_kernel.hpp"
#include "utils.hpp"

using namespace ort_extensions;
using namespace ort_extensions::test;

// Test custom op with OpenAIAudioInvoker calling Whisper
// Default input format. No prompt.
TEST(AzureOps, OpenAIWhisper_basic) {
  const char* auth_token = std::getenv("OPENAI_AUTH_TOKEN");
  if (auth_token == nullptr) {
    GTEST_SKIP() << "OPENAI_AUTH_TOKEN environment variable was not set.";
  }

  auto data_dir = std::filesystem::current_path() / "data" / "azure";
  auto model_path = data_dir / "openai_whisper_transcriptions.onnx";
  auto audio_path = data_dir / "self-destruct-button.wav";
  std::vector<uint8_t> audio_data = LoadBytesFromFile(audio_path);

  auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");

  std::vector<TestValue> inputs{TestValue("auth_token", {std::string(auth_token)}, {1}),
                                TestValue("transcribe0/file", audio_data, {narrow<int64_t>(audio_data.size())})};

  // punctuation can differ between calls to OpenAI Whisper. sometimes there's a comma after 'button' and sometimes
  // a full stop. use a custom output validator that looks for substrings in the output that aren't affected by this.
  std::vector<std::string> expected_output{"Thank you for pressing the self-destruct button",
                                           "ship will self-destruct in three minutes"};

  // dims are set to '{1}' as we expect one string output. the expected_output is the collection of substrings to look
  // for in the single output
  std::vector<TestValue> outputs{TestValue("transcription", expected_output, {1})};

  OutputValidator find_strings_in_output =
      [](size_t output_idx, Ort::Value& actual, TestValue expected) {
        std::vector<std::string> output_string;
        GetTensorMutableDataString(Ort::GetApi(), actual, output_string);

        ASSERT_EQ(output_string.size(), 1) << "Expected the Whisper response to be a single string with json";

        for (auto& expected_substring : expected.values_string) {
          if (output_string[0].find(expected_substring) == std::string::npos) {
            FAIL() << "'" << expected_substring << "' was not found in output " << output_string[0];
          }
        }
      };

  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath(), find_strings_in_output);
}

// test calling Whisper with a filename to provide mp3 instead of the default wav, and the optional prompt
TEST(AzureOps, OpenAIWhisper_Prompt_CustomFormat) {
  const char* auth_token = std::getenv("OPENAI_AUTH_TOKEN");
  if (auth_token == nullptr) {
    GTEST_SKIP() << "OPENAI_AUTH_TOKEN environment variable was not set.";
  }

  std::string ort_version{OrtGetApiBase()->GetVersionString()};

  auto data_dir = std::filesystem::current_path() / "data" / "azure";
  auto model_path = data_dir / "openai_whisper_transcriptions.onnx";
  auto audio_path = data_dir / "be-a-man-take-some-pepto-bismol-get-dressed-and-come-on-over-here.mp3";
  std::vector<uint8_t> audio_data = LoadBytesFromFile(audio_path);

  auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");

  // provide filename with 'mp3' extension to indicate audio format. doesn't need to be the 'real' filename
  std::vector<TestValue> inputs{TestValue("auth_token", {std::string(auth_token)}, {1}),
                                TestValue("transcribe0/file", audio_data, {narrow<int64_t>(audio_data.size())}),
                                TestValue("transcribe0/filename", {std::string("audio.mp3")}, {1})};

  std::vector<std::string> expected_output = {"Take some Pepto-Bismol, get dressed, and come on over here."};
  std::vector<TestValue> outputs{TestValue("transcription", expected_output, {1})};

  OutputValidator find_strings_in_output =
      [](size_t output_idx, Ort::Value& actual, TestValue expected) {
        std::vector<std::string> output_string;
        GetTensorMutableDataString(Ort::GetApi(), actual, output_string);

        ASSERT_EQ(output_string.size(), 1) << "Expected the Whisper response to be a single string with json";
        const auto& expected_substring = expected.values_string[0];
        if (output_string[0].find(expected_substring) == std::string::npos) {
          FAIL() << "'" << expected_substring << "' was not found in output " << output_string[0];
        }
      };

  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath(), find_strings_in_output);

  // use optional 'prompt' input to mis-spell Pepto-Bismol in response
  std::string prompt = "Peptoe-Bismole";
  inputs.push_back(TestValue("transcribe0/prompt", {prompt}, {1}));
  outputs[0].values_string[0] = "Take some Peptoe-Bismole, get dressed, and come on over here.";
  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath(), find_strings_in_output);
}

#endif
@@ -12,7 +12,7 @@
 
 #include "ocos.h"
 #include "test_kernel.hpp"
-#include "test_utils.hpp"
+#include "utils.hpp"
 
 using namespace ort_extensions::test;
 
@@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "utils.hpp"
#include <fstream>

namespace ort_extensions {
namespace test {
std::vector<uint8_t> LoadBytesFromFile(const std::filesystem::path& filename) {
  using namespace std;
  ifstream ifs(filename, ios::binary | ios::ate);
  ifstream::pos_type pos = ifs.tellg();

  std::vector<uint8_t> input_bytes(pos);
  ifs.seekg(0, ios::beg);
  // we want uint8_t values so reinterpret_cast so we don't have to read chars and copy to uint8_t after.
  ifs.read(reinterpret_cast<char*>(input_bytes.data()), pos);

  return input_bytes;
}
}  // namespace test
}  // namespace ort_extensions
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <filesystem>
#include <vector>

namespace ort_extensions {
namespace test {
std::vector<uint8_t> LoadBytesFromFile(const std::filesystem::path& filename);

}  // namespace test
}  // namespace ort_extensions
@@ -12,6 +12,7 @@ script_dir = os.path.dirname(os.path.realpath(__file__))
 ort_ext_root = os.path.abspath(os.path.join(script_dir, ".."))
 test_data_dir = os.path.join(ort_ext_root, "test", "data", "azure")
 
+
 class TestAzureOps(unittest.TestCase):
 
     def __init__(self, config):
@@ -21,7 +22,7 @@ class TestAzureOps(unittest.TestCase):
         self.__opt = SessionOptions()
         self.__opt.register_custom_ops_library(get_library_path())
 
-    def test_addf(self):
+    def test_add_f(self):
         if self.__enabled:
             sess = InferenceSession(os.path.join(test_data_dir, "triton_addf.onnx"),
                                     self.__opt, providers=["CPUExecutionProvider"])
@@ -36,7 +37,7 @@ class TestAzureOps(unittest.TestCase):
             out = sess.run(None, ort_inputs)[0]
             self.assertTrue(np.allclose(out, [5,5,5,5]))
 
-    def testAddf8(self):
+    def test_add_f8(self):
         if self.__enabled:
             opt = SessionOptions()
             opt.register_custom_ops_library(get_library_path())
@@ -53,7 +54,7 @@ class TestAzureOps(unittest.TestCase):
             out = sess.run(None, ort_inputs)[0]
             self.assertTrue(np.allclose(out, [5,5,5,5]))
 
-    def testAddi4(self):
+    def test_add_i4(self):
         if self.__enabled:
             sess = InferenceSession(os.path.join(test_data_dir, "triton_addi4.onnx"),
                                     self.__opt, providers=["CPUExecutionProvider"])
@@ -68,7 +69,7 @@ class TestAzureOps(unittest.TestCase):
             out = sess.run(None, ort_inputs)[0]
             self.assertTrue(np.allclose(out, [5,5,5,5]))
 
-    def testAnd(self):
+    def test_and(self):
         if self.__enabled:
             sess = InferenceSession(os.path.join(test_data_dir, "triton_and.onnx"),
                                     self.__opt, providers=["CPUExecutionProvider"])
@@ -83,7 +84,7 @@ class TestAzureOps(unittest.TestCase):
             out = sess.run(None, ort_inputs)[0]
             self.assertTrue(np.allclose(out, [True, False]))
 
-    def testStr(self):
+    def test_str(self):
         if self.__enabled:
             sess = InferenceSession(os.path.join(test_data_dir, "triton_str.onnx"),
                                     self.__opt, providers=["CPUExecutionProvider"])
@@ -98,7 +99,7 @@ class TestAzureOps(unittest.TestCase):
             self.assertEqual(outs[0], ['this is the input'])
             self.assertEqual(outs[1], ['this is the input'])
 
-    def testOpenAiAudio(self):
+    def test_open_ai_audio(self):
         if self.__enabled:
             sess = InferenceSession(os.path.join(test_data_dir, "openai_audio.onnx"),
                                     self.__opt, providers=["CPUExecutionProvider"])
@@ -110,14 +111,14 @@ class TestAzureOps(unittest.TestCase):
                 audio_blob = np.asarray(list(_f.read()), dtype=np.uint8)
             ort_inputs = {
                 "auth_token": auth_token,
-                "model": model,
+                "model_name": model,
                 "response_format": response_format,
-                "file": audio_blob
+                "file": audio_blob,
             }
             out = sess.run(None, ort_inputs)[0]
             self.assertEqual(out, ['This is a test recording to test the Whisper model.\n'])
 
-    def testOpenAiChat(self):
+    def test_open_ai_chat(self):
         if self.__enabled:
             sess = InferenceSession(os.path.join(test_data_dir, "openai_chat.onnx"),
                                     self.__opt, providers=["CPUExecutionProvider"])
@@ -130,23 +131,6 @@ class TestAzureOps(unittest.TestCase):
             out = sess.run(None, ort_inputs)[0]
             self.assertTrue('assist' in out[0])
 
-    def testOpenAiEmb(self):
-        if self.__enabled:
-            opt = SessionOptions()
-            opt.register_custom_ops_library(get_library_path())
-            sess = InferenceSession(os.path.join(test_data_dir, "openai_embedding.onnx"),
-                                    opt, providers=["CPUExecutionProvider"])
-            auth_token = np.array([os.getenv('EMB', '')])
-            text = np.array(['{\"input\": \"The food was delicious and the waiter...\", \"model\": \"text-embedding-ada-002\"}'])
-
-            ort_inputs = {
-                "auth_token": auth_token,
-                "text": text,
-            }
-
-            out = sess.run(None, ort_inputs)[0]
-            self.assertTrue('text-embedding-ada' in out[0])
-
 
 if __name__ == '__main__':
     unittest.main()
@@ -154,7 +154,7 @@ def _parse_arguments():
     # WebAssembly options
     parser.add_argument("--wasm", action="store_true", help="Build for WebAssembly")
     parser.add_argument("--emsdk_path", type=Path,
-                        help="Specify path to emscripten SDK. Setup manually with: "
+                        help="Specify path to emscripten SDK. Setup manually with: "
                              " git clone https://github.com/emscripten-core/emsdk")
     parser.add_argument("--emsdk_version", default="3.1.26", help="Specify version of emsdk")
 
@@ -404,11 +404,11 @@ def _generate_build_tree(cmake_path: Path,
         _run_subprocess(cmake_args + [f"-DCMAKE_BUILD_TYPE={config}"], cwd=config_build_dir)
 
 
-def clean_targets(cmake_path, build_dir: Path, configs: Set[str]):
+def clean_targets(cmake_path: Path, build_dir: Path, configs: Set[str]):
     for config in configs:
         log.info("Cleaning targets for %s configuration", config)
         build_dir2 = _get_build_config_dir(build_dir, config)
-        cmd_args = [cmake_path, "--build", build_dir2, "--config", config, "--target", "clean"]
+        cmd_args = [str(cmake_path), "--build", str(build_dir2), "--config", config, "--target", "clean"]
 
         _run_subprocess(cmd_args)
 
@@ -564,6 +564,10 @@ def main():
     cmake_path = _resolve_executable_path(
         args.cmake_path,
         resolution_failure_allowed=(not (args.update or args.clean or args.build)))
+
+    if not cmake_path:
+        raise UsageError("Unable to find CMake executable. Please specify --cmake-path.")
+
     build_dir = args.build_dir
 
     if args.update or args.build:
@@ -13,4 +13,4 @@ if "%1" == "install" (
         del "%ProgramFiles%\Miniconda3\python3.exe"
     )
-)
+  )
 )
@@ -3,10 +3,10 @@
 if [[ "$OCOS_ENABLE_AZURE" == "1" ]]
 then
    if [[ "$1" == "many64" ]]; then
-       yum -y install openssl openssl-devel wget && wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz && tar zxvf v1.1.0.tar.gz && cd rapidjson-1.1.0 && mkdir build && cd build && cmake .. && cmake --install . && cd ../.. && git clone https://github.com/triton-inference-server/client.git --branch r23.05 ~/client && ln -s ~/client/src/c++/library/libhttpclient.ldscript /usr/lib64/libhttpclient.ldscript
+       yum -y install openssl-devel
    elif [[ "$1" == "many86" ]]; then
-       yum -y install openssl openssl-devel wget && wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz && tar zxvf v1.1.0.tar.gz && cd rapidjson-1.1.0 && mkdir build && cd build && cmake .. && cmake --install . && cd ../.. && git clone https://github.com/triton-inference-server/client.git --branch r23.05 ~/client && ln -s ~/client/src/c++/library/libhttpclient.ldscript /opt/rh/devtoolset-10/root/usr/lib/libhttpclient.ldscript
+       yum -y install openssl-devel
    else # for musllinux
-       apk add openssl-dev wget && wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz && tar zxvf v1.1.0.tar.gz && cd rapidjson-1.1.0 && mkdir build && cd build && cmake .. && cmake --install . && cd ../.. && git clone https://github.com/triton-inference-server/client.git --branch r23.05 ~/client && ln -s ~/client/src/c++/library/libhttpclient.ldscript /usr/lib/libhttpclient.ldscript
+       apk add openssl-dev
    fi
 fi