diff --git a/.pipelines/wheels_win32.yml b/.pipelines/wheels_win32.yml
index 11d55c61..1e6732b1 100644
--- a/.pipelines/wheels_win32.yml
+++ b/.pipelines/wheels_win32.yml
@@ -1,9 +1,16 @@
+parameters:
+- name: ExtraEnv
+  displayName: 'Extra env variable set to CIBW_ENVIRONMENT'
+  type: string
+  default: 'None=None'
+
 jobs:
 - job: windows
   timeoutInMinutes: 120
   pool: {name: 'onnxruntime-Win-CPU-2022'}
   variables:
     CIBW_BUILD: "cp3{8,9,10,11}-*amd64"
+    CIBW_ENVIRONMENT: "${{ parameters.ExtraEnv }}"
 
   steps:
     - task: UsePythonVersion@0
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2663cb71..1259a33e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -380,8 +380,22 @@ endif()
 
 if(OCOS_ENABLE_AZURE)
   # Azure endpoint invokers
-  include(triton)
   file(GLOB TARGET_SRC_AZURE "operators/azure/*.cc" "operators/azure/*.h*")
+
+  if (ANDROID)
+    include(curl)
+
+    # make sure these work
+    find_package(OpenSSL REQUIRED)
+    find_package(CURL REQUIRED)
+
+    # exclude triton
+    list(FILTER TARGET_SRC_AZURE EXCLUDE REGEX ".*triton.*")
+  else()
+    add_compile_definitions(AZURE_INVOKERS_ENABLE_TRITON)
+    include(triton)
+  endif()
+
   list(APPEND TARGET_SRC ${TARGET_SRC_AZURE})
 endif()
 
@@ -454,11 +468,13 @@ if(_HAS_TOKENIZER)
 endif()
 
 if(OCOS_ENABLE_TF_STRING)
-  target_include_directories(noexcep_operators PUBLIC
-    ${googlere2_SOURCE_DIR}
-    ${farmhash_SOURCE_DIR}/src)
+  target_include_directories(noexcep_operators PUBLIC ${farmhash_SOURCE_DIR}/src)
   list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_TF_STRING NOMINMAX FARMHASH_NO_BUILTIN_EXPECT FARMHASH_DEBUG=0)
-  target_link_libraries(noexcep_operators PRIVATE re2)
+
+  if(OCOS_ENABLE_RE2_REGEX)
+    target_include_directories(noexcep_operators PUBLIC ${googlere2_SOURCE_DIR})
+    target_link_libraries(noexcep_operators PRIVATE re2)
+  endif()
 endif()
 
 if(OCOS_ENABLE_RE2_REGEX)
@@ -526,6 +542,10 @@ if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
   list(APPEND ocos_libraries nlohmann_json::nlohmann_json)
 endif()
 
+if(OCOS_ENABLE_AZURE)
+  list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_AZURE)
+endif()
+
 target_include_directories(noexcep_operators PUBLIC ${GSL_INCLUDE_DIR})
 list(APPEND ocos_libraries Microsoft.GSL::GSL)
 
@@ -577,9 +597,79 @@ else()
   set(_BUILD_SHARED_LIBRARY TRUE)
 endif()
 
+if(OCOS_ENABLE_AZURE)
+  if (ANDROID)
+    # the find_package calls were made immediately after `include(curl)` so we know CURL and OpenSSL are available
+    target_link_libraries(ocos_operators PUBLIC CURL::libcurl OpenSSL::Crypto OpenSSL::SSL)
+  elseif(IOS)
+    # TODO
+  else()
+    # we need files from the triton client (e.g. curl header on linux) to be available for the ocos_operators build.
+    # add a dependency so the fetch and build of triton happens first.
+    add_dependencies(ocos_operators triton)
+    target_include_directories(ocos_operators PUBLIC ${triton_INSTALL_DIR}/include)
+
+    target_link_directories(ocos_operators PUBLIC ${triton_INSTALL_DIR}/lib)
+    target_link_directories(ocos_operators PUBLIC ${triton_INSTALL_DIR}/lib64)
+
+    if (WIN32)
+      if (ocos_target_platform STREQUAL "AMD64")
+        set(vcpkg_target_platform "x64")
+      else()
+        set(vcpkg_target_platform ${ocos_target_platform})
+      endif()
+
+      # As per https://curl.se/docs/faq.html#Link_errors_when_building_libcur we need to set CURL_STATICLIB.
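+      # (without CURL_STATICLIB the curl headers assume a DLL import, so linking the static library
+      # fails with unresolved __imp_ prefixed curl symbols)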
+      target_compile_definitions(ocos_operators PRIVATE CURL_STATICLIB)
+      target_include_directories(ocos_operators PUBLIC ${VCPKG_SRC}/installed/${vcpkg_target_platform}-windows-static-md/include)
+
+      if (CMAKE_BUILD_TYPE STREQUAL "Debug")
+        set(curl_LIB_NAME "libcurl-d")
+        target_link_directories(ocos_operators PUBLIC ${VCPKG_SRC}/installed/${vcpkg_target_platform}-windows-static-md/debug/lib)
+      else()
+        set(curl_LIB_NAME "libcurl")
+        target_link_directories(ocos_operators PUBLIC ${VCPKG_SRC}/installed/${vcpkg_target_platform}-windows-static-md/lib)
+      endif()
+
+      target_link_libraries(ocos_operators PUBLIC httpclient_static ${curl_LIB_NAME} ws2_32 crypt32 Wldap32)
+    else()
+      find_package(ZLIB REQUIRED)
+
+      # If finding the OpenSSL or CURL package fails for a local build you can install libcurl4-openssl-dev.
+      # See also info on triton client dependencies here: https://github.com/triton-inference-server/client
+      find_package(OpenSSL REQUIRED)
+      find_package(CURL)
+
+      if (CURL_FOUND)
+        message(STATUS "Found CURL package")
+        set(libcurl_target CURL::libcurl)
+      else()
+        # curl is coming from triton but as that's an external project it isn't built yet and we have to add
+        # paths and library names instead of cmake targets.
+        message(STATUS "Using CURL build from triton client. Once built it should be in ${triton_THIRD_PARTY_DIR}/curl")
+        target_include_directories(ocos_operators PUBLIC ${triton_THIRD_PARTY_DIR}/curl/include)
+
+        # Install is to 'lib' except on CentOS (which is used for the manylinux build of the python wheel).
+        # Side note: we have to patch the triton client CMakeLists.txt to only use 'lib64' for 64-bit builds otherwise
+        #   the build of the 32-bit python wheel fails with CURL not being found due to the invalid library
+        #   directory name.
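+        # add both candidate lib dirs; a link directory that doesn't exist is simply ignored at link time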
+        target_link_directories(ocos_operators PUBLIC ${triton_THIRD_PARTY_DIR}/curl/lib)
+        target_link_directories(ocos_operators PUBLIC ${triton_THIRD_PARTY_DIR}/curl/lib64)
+
+        if (CMAKE_BUILD_TYPE STREQUAL "Debug")
+          set(libcurl_target "curl-d")
+        else()
+          set(libcurl_target "curl")
+        endif()
+      endif()
+
+      target_link_libraries(ocos_operators PUBLIC httpclient_static ${libcurl_target} OpenSSL::Crypto OpenSSL::SSL ZLIB::ZLIB)
+    endif()
+  endif()
+endif()
+
 target_compile_definitions(ortcustomops PUBLIC ${OCOS_COMPILE_DEFINITIONS})
-target_include_directories(ortcustomops PUBLIC
-  "$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")
+target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")
 target_link_libraries(ortcustomops PUBLIC ocos_operators)
 
 if(_BUILD_SHARED_LIBRARY)
@@ -592,8 +682,7 @@ if(_BUILD_SHARED_LIBRARY)
     set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS "-Wl,-s -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
   endif()
 
-  target_include_directories(extensions_shared PUBLIC
-    "$<TARGET_PROPERTY:ortcustomops,INTERFACE_INCLUDE_DIRECTORIES>")
+  target_include_directories(extensions_shared PUBLIC "$<TARGET_PROPERTY:ortcustomops,INTERFACE_INCLUDE_DIRECTORIES>")
   target_link_libraries(extensions_shared PRIVATE ortcustomops)
   set_target_properties(extensions_shared PROPERTIES OUTPUT_NAME "ortextensions")
   if(MSVC AND ocos_target_platform MATCHES "x86|x64")
@@ -672,8 +761,10 @@ if(OCOS_ENABLE_CTEST)
     list(APPEND LINUX_CC_FLAGS stdc++fs -pthread)
   endif()
 
-  file(GLOB shared_TEST_SRC "${TEST_SRC_DIR}/shared_test/*.cc")
+  file(GLOB shared_TEST_SRC "${TEST_SRC_DIR}/shared_test/*.cc" "${TEST_SRC_DIR}/shared_test/*.hpp")
   add_executable(extensions_test ${shared_TEST_SRC})
+  target_compile_definitions(extensions_test PUBLIC ${OCOS_COMPILE_DEFINITIONS})
+
   standardize_output_folder(extensions_test)
   target_include_directories(extensions_test PRIVATE
     ${spm_INCLUDE_DIRS}
     "$<TARGET_PROPERTY:ortcustomops,INTERFACE_INCLUDE_DIRECTORIES>")
@@ -706,30 +797,3 @@ if(OCOS_ENABLE_CTEST)
     add_test(NAME extensions_test COMMAND $<TARGET_FILE:extensions_test>)
   endif()
 endif()
-
-if(OCOS_ENABLE_AZURE)
-
-  add_dependencies(ocos_operators triton)
-  target_include_directories(ocos_operators PUBLIC ${TRITON_BIN}/include ${TRITON_THIRD_PARTY}/curl/include)
-  target_link_directories(ocos_operators PUBLIC ${TRITON_BIN}/lib ${TRITON_BIN}/lib64 ${TRITON_THIRD_PARTY}/curl/lib ${TRITON_THIRD_PARTY}/curl/lib64)
-
-  if (ocos_target_platform STREQUAL "AMD64")
-    set(vcpkg_target_platform "x86")
-  else()
-    set(vcpkg_target_platform ${ocos_target_platform})
-  endif()
-
-  if (WIN32)
-
-    target_link_directories(ocos_operators PUBLIC ${VCPKG_SRC}/installed/${vcpkg_target_platform}-windows-static/lib)
-    target_link_libraries(ocos_operators PUBLIC libcurl httpclient_static ws2_32 crypt32 Wldap32)
-
-  else()
-
-    find_package(ZLIB REQUIRED)
-    find_package(OpenSSL REQUIRED)
-    target_link_libraries(ocos_operators PUBLIC httpclient_static curl ZLIB::ZLIB OpenSSL::Crypto OpenSSL::SSL)
-
-  endif() #if (WIN32)
-
-endif()
\ No newline at end of file
diff --git a/cmake/ext_ortlib.cmake b/cmake/ext_ortlib.cmake
index 31837760..5405bb7a 100644
--- a/cmake/ext_ortlib.cmake
+++ b/cmake/ext_ortlib.cmake
@@ -40,15 +40,28 @@ else()
     endif()
   endif()
 
-  message(STATUS "ONNX Runtime URL suffix: ${ONNXRUNTIME_URL}")
+  if (ANDROID)
+    set(ort_fetch_URL "https://repo1.maven.org/maven2/com/microsoft/onnxruntime/onnxruntime-android/${ONNXRUNTIME_VER}/onnxruntime-android-${ONNXRUNTIME_VER}.aar")
+  else()
+    set(ort_fetch_URL "https://github.com/microsoft/onnxruntime/releases/download/${ONNXRUNTIME_URL}")
+  endif()
+
+  message(STATUS "ONNX Runtime URL: ${ort_fetch_URL}")
 
   FetchContent_Declare(
     onnxruntime
-    URL https://github.com/microsoft/onnxruntime/releases/download/${ONNXRUNTIME_URL}
+    URL ${ort_fetch_URL}
   )
 
   FetchContent_makeAvailable(onnxruntime)
-  set(ONNXRUNTIME_INCLUDE_DIR ${onnxruntime_SOURCE_DIR}/include)
-  set(ONNXRUNTIME_LIB_DIR ${onnxruntime_SOURCE_DIR}/lib)
+
+  if (ANDROID)
+    set(ONNXRUNTIME_INCLUDE_DIR ${onnxruntime_SOURCE_DIR}/headers)
+    set(ONNXRUNTIME_LIB_DIR ${onnxruntime_SOURCE_DIR}/jni/${ANDROID_ABI})
+    message(STATUS "Android onnxruntime inc=${ONNXRUNTIME_INCLUDE_DIR} lib=${ONNXRUNTIME_LIB_DIR}")
+  else()
+    set(ONNXRUNTIME_INCLUDE_DIR ${onnxruntime_SOURCE_DIR}/include)
+    set(ONNXRUNTIME_LIB_DIR ${onnxruntime_SOURCE_DIR}/lib)
+  endif()
 endif()
 
 if(NOT EXISTS ${ONNXRUNTIME_INCLUDE_DIR})
diff --git a/cmake/externals/curl.cmake b/cmake/externals/curl.cmake
new file mode 100644
index 00000000..d57af5b1
--- /dev/null
+++ b/cmake/externals/curl.cmake
@@ -0,0 +1,11 @@
+if (ANDROID)
+  set(PREBUILD_OUTPUT_PATH "${PROJECT_SOURCE_DIR}/prebuild/openssl_for_ios_and_android/output/android")
+  set(OPENSSL_ROOT_DIR "${PREBUILD_OUTPUT_PATH}/openssl-${ANDROID_ABI}")
+  set(NGHTTP2_ROOT_DIR "${PREBUILD_OUTPUT_PATH}/nghttp2-${ANDROID_ABI}")
+  set(CURL_ROOT_DIR "${PREBUILD_OUTPUT_PATH}/curl-${ANDROID_ABI}")
+
+  # Update CMAKE_FIND_ROOT_PATH so find_package/find_library can find these builds
+  list(APPEND CMAKE_FIND_ROOT_PATH "${OPENSSL_ROOT_DIR}")
+  list(APPEND CMAKE_FIND_ROOT_PATH "${NGHTTP2_ROOT_DIR}")
+  list(APPEND CMAKE_FIND_ROOT_PATH "${CURL_ROOT_DIR}")
+endif()
\ No newline at end of file
diff --git a/cmake/externals/triton.cmake b/cmake/externals/triton.cmake
index 31826da5..06c34cc8 100644
--- a/cmake/externals/triton.cmake
+++ b/cmake/externals/triton.cmake
@@ -1,32 +1,41 @@
 include(ExternalProject)
 
-if (WIN32)
+set(triton_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton)
+set(triton_INSTALL_DIR ${triton_PREFIX}/install)
 
+if (WIN32)
   if (ocos_target_platform STREQUAL "AMD64")
-    set(vcpkg_target_platform "x86")
+    set(vcpkg_target_platform "x64")
   else()
     set(vcpkg_target_platform ${ocos_target_platform})
   endif()
 
-  ExternalProject_Add(vcpkg
-                      GIT_REPOSITORY https://github.com/microsoft/vcpkg.git
-                      GIT_TAG 2023.06.20
-                      PREFIX vcpkg
-                      SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src
-                      BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-build
-                      CONFIGURE_COMMAND ""
-                      INSTALL_COMMAND ""
-                      UPDATE_COMMAND ""
-                      BUILD_COMMAND "<SOURCE_DIR>/bootstrap-vcpkg.bat")
+  set(vcpkg_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg)
 
-  set(VCPKG_SRC ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src)
-  set(ENV{VCPKG_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src)
+  ExternalProject_Add(vcpkg
+                      GIT_REPOSITORY https://github.com/microsoft/vcpkg.git
+                      GIT_TAG 2023.06.20
+                      PREFIX ${vcpkg_PREFIX}
+                      CONFIGURE_COMMAND ""
+                      INSTALL_COMMAND ""
+                      UPDATE_COMMAND ""
+                      BUILD_COMMAND "<SOURCE_DIR>/bootstrap-vcpkg.bat")
+
+  ExternalProject_Get_Property(vcpkg SOURCE_DIR BINARY_DIR)
+  set(VCPKG_SRC ${SOURCE_DIR})
+  message(STATUS "vcpkg source dir: " ${VCPKG_SRC})
+
+  # set the environment variable so that the vcpkg.cmake file can find the vcpkg root directory
+  set(ENV{VCPKG_ROOT} ${VCPKG_SRC})
 
   message(STATUS "VCPKG_SRC: " ${VCPKG_SRC})
-  message(STATUS "VCPKG_ROOT: " $ENV{VCPKG_ROOT})
+  message(STATUS "ENV{VCPKG_ROOT}: " $ENV{VCPKG_ROOT})
 
+  # NOTE: The VCPKG_ROOT environment variable isn't propagated to an add_custom_command target, so specify --vcpkg-root
+  # here and in the vcpkg_install function
   add_custom_command(
-    COMMAND ${VCPKG_SRC}/vcpkg integrate install
+    COMMAND ${CMAKE_COMMAND} -E echo ${VCPKG_SRC}/vcpkg integrate --vcpkg-root=$ENV{VCPKG_ROOT} install
+    COMMAND ${VCPKG_SRC}/vcpkg integrate --vcpkg-root=$ENV{VCPKG_ROOT} install
     COMMAND ${CMAKE_COMMAND} -E touch vcpkg_integrate.stamp
     OUTPUT vcpkg_integrate.stamp
     DEPENDS vcpkg
@@ -35,77 +44,97 @@ if (WIN32)
   add_custom_target(vcpkg_integrate ALL DEPENDS vcpkg_integrate.stamp)
   set(VCPKG_DEPENDENCIES "vcpkg_integrate")
 
+  # use static-md so it adjusts for debug/release CRT
+  # https://stackoverflow.com/questions/67258905/vcpkg-difference-between-windows-windows-static-and-other
   function(vcpkg_install PACKAGE_NAME)
     add_custom_command(
-      OUTPUT ${VCPKG_SRC}/packages/${PACKAGE_NAME}_${vcpkg_target_platform}-windows-static/BUILD_INFO
-      COMMAND ${VCPKG_SRC}/vcpkg install ${PACKAGE_NAME}:${vcpkg_target_platform}-windows-static --vcpkg-root=${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src
+      OUTPUT ${VCPKG_SRC}/packages/${PACKAGE_NAME}_${vcpkg_target_platform}-windows-static-md/BUILD_INFO
+      COMMAND ${CMAKE_COMMAND} -E echo ${VCPKG_SRC}/vcpkg install --vcpkg-root=$ENV{VCPKG_ROOT}
+              ${PACKAGE_NAME}:${vcpkg_target_platform}-windows-static-md
+      COMMAND ${VCPKG_SRC}/vcpkg install --vcpkg-root=$ENV{VCPKG_ROOT}
+              ${PACKAGE_NAME}:${vcpkg_target_platform}-windows-static-md
       WORKING_DIRECTORY ${VCPKG_SRC}
       DEPENDS vcpkg_integrate)
 
-    add_custom_target(get${PACKAGE_NAME}
+    add_custom_target(
+      get${PACKAGE_NAME}
       ALL
-      DEPENDS ${VCPKG_SRC}/packages/${PACKAGE_NAME}_${vcpkg_target_platform}-windows-static/BUILD_INFO)
+      DEPENDS ${VCPKG_SRC}/packages/${PACKAGE_NAME}_${vcpkg_target_platform}-windows-static-md/BUILD_INFO)
 
     list(APPEND VCPKG_DEPENDENCIES "get${PACKAGE_NAME}")
     set(VCPKG_DEPENDENCIES ${VCPKG_DEPENDENCIES} PARENT_SCOPE)
   endfunction()
 
-  vcpkg_install(openssl)
-  vcpkg_install(openssl-windows)
   vcpkg_install(rapidjson)
-  vcpkg_install(re2)
-  vcpkg_install(boost-interprocess)
-  vcpkg_install(boost-stacktrace)
-  vcpkg_install(pthread)
-  vcpkg_install(b64)
+  vcpkg_install(openssl)
+  vcpkg_install(curl)
 
-  add_dependencies(getb64 getpthread)
-  add_dependencies(getpthread getboost-stacktrace)
-  add_dependencies(getboost-stacktrace getboost-interprocess)
-  add_dependencies(getboost-interprocess getre2)
-  add_dependencies(getre2 getrapidjson)
-  add_dependencies(getrapidjson getopenssl-windows)
-  add_dependencies(getopenssl-windows getopenssl)
-
-  ExternalProject_Add(triton
-                      GIT_REPOSITORY https://github.com/triton-inference-server/client.git
-                      GIT_TAG r23.05
-                      PREFIX triton
-                      SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src
-                      BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build
-                      CMAKE_ARGS -DVCPKG_TARGET_TRIPLET=${vcpkg_target_platform}-windows-static -DCMAKE_TOOLCHAIN_FILE=${VCPKG_SRC}/scripts/buildsystems/vcpkg.cmake -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF
-                      INSTALL_COMMAND ""
-                      UPDATE_COMMAND "")
-
-  add_dependencies(triton ${VCPKG_DEPENDENCIES})
+  set(triton_extra_cmake_args -DVCPKG_TARGET_TRIPLET=${vcpkg_target_platform}-windows-static-md
+                              -DCMAKE_TOOLCHAIN_FILE=${VCPKG_SRC}/scripts/buildsystems/vcpkg.cmake)
+  set(triton_patch_command "")
+  set(triton_dependencies ${VCPKG_DEPENDENCIES})
 
 else()
+  # RapidJSON 1.1.0 (released in 2016) is compatible with the triton build. Later code is not compatible without
+  # patching due to the change in variable name for the include dir from RAPIDJSON_INCLUDE_DIRS to
+  # RapidJSON_INCLUDE_DIRS in the generated cmake file used by find_package:
+  #   https://github.com/Tencent/rapidjson/commit/b91c515afea9f0ba6a81fc670889549d77c83db3
+  # The triton code here https://github.com/triton-inference-server/common/blob/main/CMakeLists.txt is using
+  # RAPIDJSON_INCLUDE_DIRS so the build fails if a newer RapidJSON version is used. It will find the package but the
+  # include path will be wrong so the build error is delayed/misleading and non-trivial to understand/resolve.
+  set(RapidJSON_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/_deps/rapidjson)
+  set(RapidJSON_INSTALL_DIR ${RapidJSON_PREFIX}/install)
+  ExternalProject_Add(RapidJSON
+                      PREFIX ${RapidJSON_PREFIX}
+                      URL https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.zip
+                      URL_HASH SHA1=0fe7b4f7b83df4b3d517f4a202f3a383af7a0818
+                      CMAKE_ARGS -DRAPIDJSON_BUILD_DOC=OFF
+                                 -DRAPIDJSON_BUILD_EXAMPLES=OFF
+                                 -DRAPIDJSON_BUILD_TESTS=OFF
+                                 -DRAPIDJSON_HAS_STDSTRING=ON
+                                 -DRAPIDJSON_USE_MEMBERSMAP=ON
+                                 -DCMAKE_INSTALL_PREFIX=${RapidJSON_INSTALL_DIR}
+                      )
 
-  ExternalProject_Add(curl7
-                      PREFIX curl7
-                      GIT_REPOSITORY "https://github.com/curl/curl.git"
-                      GIT_TAG "curl-7_86_0"
-                      SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/curl7-src
-                      BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/curl7-build
-                      CMAKE_ARGS -DBUILD_TESTING=OFF -DBUILD_CURL_EXE=OFF -DBUILD_SHARED_LIBS=OFF -DCURL_STATICLIB=ON -DHTTP_ONLY=ON -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE})
+  ExternalProject_Get_Property(RapidJSON SOURCE_DIR BINARY_DIR)
+  # message(STATUS "RapidJSON src=${SOURCE_DIR} binary=${BINARY_DIR}")
 
-  ExternalProject_Add(triton
-                      GIT_REPOSITORY https://github.com/triton-inference-server/client.git
-                      GIT_TAG r23.05
-                      PREFIX triton
-                      SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src
-                      BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build
-                      CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF
-                      INSTALL_COMMAND ""
-                      UPDATE_COMMAND "")
+  # Set RapidJSON_ROOT_DIR for find_package. The required RapidJSONConfig.cmake file is generated in the binary dir
+  set(RapidJSON_ROOT_DIR ${BINARY_DIR})
 
-  add_dependencies(triton curl7)
+  set(triton_extra_cmake_args "")
+  set(triton_patch_command patch --verbose -p1 -i ${PROJECT_SOURCE_DIR}/cmake/externals/triton_cmake.patch)
+  set(triton_dependencies RapidJSON)
 
+  # Patch the triton client CMakeLists.txt to fix two issues when building the python wheels with cibuildwheel, which
+  # uses CentOS 7.
+  # 1) use the full path to the version script file so 'ld' doesn't fail to find it. Looks like ld is running from the
+  #    parent directory but not sure why the behavior differs vs. other linux builds
+  #    e.g. building locally on Ubuntu is fine without the patch
+  # 2) only set the CURL lib path to 'lib64' on a 64-bit CentOS build as 'lib64' is invalid on a 32-bit OS. without
+  #    this patch the build of the third-party libraries in the triton client fail as the CURL build is not found.
 
-endif() #if (WIN32)
+endif()  #if (WIN32)
 
-ExternalProject_Get_Property(triton SOURCE_DIR)
-set(TRITON_SRC ${SOURCE_DIR})
+# Add the triton build. We just need the library so we don't install it.
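+# Note: r23.05 is a branch rather than a fixed tag, so URL_HASH guards against the tarball changing unexpectedly.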
+
+set(triton_VERSION_TAG r23.05)
+ExternalProject_Add(triton
+                    URL https://github.com/triton-inference-server/client/archive/refs/heads/${triton_VERSION_TAG}.tar.gz
+                    URL_HASH SHA1=b8fd2a4e09eae39c33cd04cfa9ec934e39d9afc1
+                    PREFIX ${triton_PREFIX}
+                    CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${triton_INSTALL_DIR}
+                               -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
+                               -DTRITON_COMMON_REPO_TAG=${triton_VERSION_TAG}
+                               -DTRITON_THIRD_PARTY_REPO_TAG=${triton_VERSION_TAG}
+                               -DTRITON_CORE_REPO_TAG=${triton_VERSION_TAG}
+                               -DTRITON_ENABLE_CC_HTTP=ON
+                               -DTRITON_ENABLE_ZLIB=OFF
+                               ${triton_extra_cmake_args}
+                    INSTALL_COMMAND ${CMAKE_COMMAND} -E echo "Skipping install step."
+                    PATCH_COMMAND ${triton_patch_command}
+                    )
 
-ExternalProject_Get_Property(triton BINARY_DIR)
-set(TRITON_BIN ${BINARY_DIR}/binary)
-set(TRITON_THIRD_PARTY ${BINARY_DIR}/third-party)
+add_dependencies(triton ${triton_dependencies})
+
+ExternalProject_Get_Property(triton SOURCE_DIR BINARY_DIR)
+set(triton_THIRD_PARTY_DIR ${BINARY_DIR}/third-party)
diff --git a/cmake/externals/triton_cmake.patch b/cmake/externals/triton_cmake.patch
new file mode 100644
index 00000000..47ce1403
--- /dev/null
+++ b/cmake/externals/triton_cmake.patch
@@ -0,0 +1,30 @@
+diff --git a/CMakeLists.txt b/CMakeLists.txt
+index 7b11178..7749fa9 100644
+--- a/CMakeLists.txt
++++ b/CMakeLists.txt
+@@ -115,10 +115,11 @@ if(TRITON_ENABLE_CC_HTTP OR TRITON_ENABLE_CC_GRPC OR TRITON_ENABLE_PERF_ANALYZER
+     file(STRINGS /etc/os-release DISTRO REGEX "^NAME=")
+     string(REGEX REPLACE "NAME=\"(.*)\"" "\\1" DISTRO "${DISTRO}")
+     message(STATUS "Distro Name: ${DISTRO}")
+-    if(DISTRO STREQUAL "CentOS Linux")
++    if(DISTRO STREQUAL "CentOS Linux" AND CMAKE_SIZEOF_VOID_P EQUAL 8)
+       set (CURL_LIB_DIR "lib64")
+     endif()
+   endif()
++  message(STATUS "Triton client CURL_LIB_DIR=${CURL_LIB_DIR}")
+ 
+ set(_cc_client_depends "")
+ if(${TRITON_ENABLE_CC_HTTP})
+diff --git a/src/c++/library/CMakeLists.txt b/src/c++/library/CMakeLists.txt
+index bdaae25..c36dbc8 100644
+--- a/src/c++/library/CMakeLists.txt
++++ b/src/c++/library/CMakeLists.txt
+@@ -320,7 +320,7 @@ if(TRITON_ENABLE_CC_HTTP OR TRITON_ENABLE_PERF_ANALYZER)
+     httpclient
+     PROPERTIES
+       LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libhttpclient.ldscript
+-      LINK_FLAGS "-Wl,--version-script=libhttpclient.ldscript"
++      LINK_FLAGS "-Wl,--version-script=${CMAKE_CURRENT_BINARY_DIR}/libhttpclient.ldscript"
+     )
+ endif() # NOT WIN32
+ 
diff --git a/operators/azure/azure_invokers.cc b/operators/azure/azure_invokers.cc
index fdd6ea3c..bdb66b02 100644
--- a/operators/azure/azure_invokers.cc
+++ b/operators/azure/azure_invokers.cc
@@ -1,427 +1,95 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#define CURL_STATICLIB
-
-#include "http_client.h"
-#include "curl/curl.h"
 #include "azure_invokers.hpp"
+
 #include <sstream>
 
-#define MIN_SUPPORTED_ORT_VER 14
+namespace ort_extensions {
 
-constexpr const char* kUri = "model_uri";
-constexpr const char* kModelName = "model_name";
-constexpr const char* kModelVer = "model_version";
-constexpr const char* kVerbose = "verbose";
-constexpr const char* kBinaryType = "binary_type";
+////////////////////// AzureAudioToTextInvoker //////////////////////
 
-struct StringBuffer {
-  StringBuffer() = default;
-  ~StringBuffer() = default;
-  std::stringstream ss_;
-};
+AzureAudioToTextInvoker::AzureAudioToTextInvoker(const OrtApi& api, const OrtKernelInfo& info)
+    : CurlInvoker(api, info) {
+  audio_format_ = TryToGetAttributeWithDefault<std::string>(kAudioFormat, "");
+}
 
-// apply the callback only when response is for sure to be a '/0' terminated string
-static size_t WriteStringCallback(void* contents, size_t size, size_t nmemb, void* userp) {
-  try {
-    size_t realsize = size * nmemb;
-    auto buffer = reinterpret_cast<StringBuffer*>(userp);
-    buffer->ss_.write(reinterpret_cast<const char*>(contents), realsize);
-    return realsize;
-  } catch (...) {
-    // exception caught, abort write
-    return CURLcode::CURLE_WRITE_ERROR;
+void AzureAudioToTextInvoker::ValidateInputs(const ortc::Variadic& inputs) const {
+  // TODO: Validate any required input names are present
+
+  // We don't have a way to get the output type from the custom op API.
+  // If there's a mismatch it will fail in the Compute when it allocates the output tensor.
+  if (OutputNames().size() != 1) {
+    ORTX_CXX_API_THROW("Expected single output", ORT_INVALID_ARGUMENT);
   }
 }
 
-using CurlWriteCallBack = size_t (*)(void*, size_t, size_t, void*);
+void AzureAudioToTextInvoker::SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const {
+  // theoretically the filename the content was buffered from
+  static const std::string fake_filename = "audio." + audio_format_;
 
-class CurlHandler {
- public:
-  CurlHandler(CurlWriteCallBack call_back) : curl_(curl_easy_init(), curl_easy_cleanup),
-                                             headers_(nullptr, curl_slist_free_all),
-                                             from_holder_(from_, curl_formfree) {
-    curl_easy_setopt(curl_.get(), CURLOPT_BUFFERSIZE, 102400L);
-    curl_easy_setopt(curl_.get(), CURLOPT_NOPROGRESS, 1L);
-    curl_easy_setopt(curl_.get(), CURLOPT_USERAGENT, "curl/7.83.1");
-    curl_easy_setopt(curl_.get(), CURLOPT_MAXREDIRS, 50L);
-    curl_easy_setopt(curl_.get(), CURLOPT_FTP_SKIP_PASV_IP, 1L);
-    curl_easy_setopt(curl_.get(), CURLOPT_TCP_KEEPALIVE, 1L);
-    curl_easy_setopt(curl_.get(), CURLOPT_WRITEFUNCTION, call_back);
-  }
-  ~CurlHandler() = default;
+  const auto& property_names = RequestPropertyNames();
 
-  void AddHeader(const char* data) {
-    headers_.reset(curl_slist_append(headers_.release(), data));
-  }
-  template <typename... Args>
-  void AddForm(Args... args) {
-    curl_formadd(&from_, &last_, args...);
-  }
-  template <typename T>
-  void SetOption(CURLoption opt, T val) {
-    curl_easy_setopt(curl_.get(), opt, val);
-  }
-  CURLcode Perform() {
-    SetOption(CURLOPT_HTTPHEADER, headers_.get());
-    if (from_) {
-      SetOption(CURLOPT_HTTPPOST, from_);
-    }
-    return curl_easy_perform(curl_.get());
-  }
-
- private:
-  std::unique_ptr<CURL, decltype(curl_easy_cleanup)*> curl_;
-  std::unique_ptr<curl_slist, decltype(curl_slist_free_all)*> headers_;
-  curl_httppost* from_{};
-  curl_httppost* last_{};
-  std::unique_ptr<curl_httppost, decltype(curl_formfree)*> from_holder_;
-};
-
-////////////////////// AzureInvoker //////////////////////
-
-AzureInvoker::AzureInvoker(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
-  auto ver = GetActiveOrtAPIVersion();
-  if (ver < MIN_SUPPORTED_ORT_VER) {
-    ORTX_CXX_API_THROW("Azure ops requires ort >= 1.14", ORT_RUNTIME_EXCEPTION);
-  }
-
-  model_uri_ = TryToGetAttributeWithDefault<std::string>(kUri, "");
-  model_name_ = TryToGetAttributeWithDefault<std::string>(kModelName, "");
-  model_ver_ = TryToGetAttributeWithDefault<std::string>(kModelVer, "0");
-  verbose_ = TryToGetAttributeWithDefault<std::string>(kVerbose, "0");
-
-  OrtStatusPtr status = {};
-  size_t input_count = {};
-  status = api_.KernelInfo_GetInputCount(&info_, &input_count);
-  if (status) {
-    ORTX_CXX_API_THROW("failed to get input count", ORT_RUNTIME_EXCEPTION);
-  }
-
-  for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
-    char input_name[1024] = {};
-    size_t name_size = 1024;
-    status = api_.KernelInfo_GetInputName(&info_, ith_input, input_name, &name_size);
-    if (status) {
-      ORTX_CXX_API_THROW("failed to get input name", ORT_RUNTIME_EXCEPTION);
-    }
-    input_names_.push_back(input_name);
-  }
-
-  size_t output_count = {};
-  status = api_.KernelInfo_GetOutputCount(&info_, &output_count);
-  if (status) {
-    ORTX_CXX_API_THROW("failed to get output count", ORT_RUNTIME_EXCEPTION);
-  }
-
-  for (size_t ith_output = 0; ith_output < output_count; ++ith_output) {
-    char output_name[1024] = {};
-    size_t name_size = 1024;
-    status = api_.KernelInfo_GetOutputName(&info_, ith_output, output_name, &name_size);
-    if (status) {
-      ORTX_CXX_API_THROW("failed to get output name", ORT_RUNTIME_EXCEPTION);
-    }
-    output_names_.push_back(output_name);
-  }
-}
-
-////////////////////// AzureAudioInvoker //////////////////////
-
-AzureAudioInvoker::AzureAudioInvoker(const OrtApi& api, const OrtKernelInfo& info) : AzureInvoker(api, info) {
-  file_name_ = std::string{"non_exist."} + TryToGetAttributeWithDefault<std::string>(kBinaryType, "wav");
-}
-
-void AzureAudioInvoker::Compute(const ortc::Variadic& inputs, ortc::Tensor<std::string>& output) const {
-  if (inputs.Size() < 1 ||
-      inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
-    ORTX_CXX_API_THROW("invalid inputs, auto token missing", ORT_RUNTIME_EXCEPTION);
-  }
-
-  if (inputs.Size() != input_names_.size()) {
-    ORTX_CXX_API_THROW("input count mismatch", ORT_RUNTIME_EXCEPTION);
-  }
-
-  if (inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING || "auth_token" != input_names_[0]) {
-    ORTX_CXX_API_THROW("first input must be a string of auth token", ORT_INVALID_ARGUMENT);
-  }
-
-  std::string auth_token = reinterpret_cast<const char*>(inputs[0]->DataRaw());
-  std::string full_auth = std::string{"Authorization: Bearer "} + auth_token;
-
-  StringBuffer string_buffer;
-  CurlHandler curl_handler(WriteStringCallback);
-  curl_handler.AddHeader(full_auth.c_str());
   curl_handler.AddHeader("Content-Type: multipart/form-data");
+  curl_handler.AddFormString("deployment_id", ModelName().c_str());
 
+  // TODO: If the handling here stays the same as in OpenAIAudioToText we can create a helper function to re-use
   for (size_t ith_input = 1; ith_input < inputs.Size(); ++ith_input) {
     switch (inputs[ith_input]->Type()) {
       case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
-        curl_handler.AddForm(CURLFORM_COPYNAME,
-                             input_names_[ith_input].c_str(),
-                             CURLFORM_COPYCONTENTS,
-                             inputs[ith_input]->DataRaw(),
-                             CURLFORM_END);
+        curl_handler.AddFormString(property_names[ith_input].c_str(),
+                                   static_cast<const char*>(inputs[ith_input]->DataRaw()));  // assumes null terminated
         break;
       case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
-        curl_handler.AddForm(CURLFORM_COPYNAME,
-                             input_names_[ith_input].data(),
-                             CURLFORM_BUFFER,
-                             file_name_.c_str(),
-                             CURLFORM_BUFFERPTR,
-                             inputs[ith_input]->DataRaw(),
-                             CURLFORM_BUFFERLENGTH,
-                             inputs[ith_input]->SizeInBytes(),
-                             CURLFORM_END);
+        curl_handler.AddFormBuffer(property_names[ith_input].c_str(),
+                                   fake_filename.c_str(),
+                                   inputs[ith_input]->DataRaw(),
+                                   inputs[ith_input]->SizeInBytes());
         break;
       default:
         ORTX_CXX_API_THROW("input must be either text or binary", ORT_RUNTIME_EXCEPTION);
        break;
     }
-  }  // for
+  }
+}
 
-  curl_handler.SetOption(CURLOPT_URL, model_uri_.c_str());
-  curl_handler.SetOption(CURLOPT_VERBOSE, verbose_);
-  curl_handler.SetOption(CURLOPT_WRITEDATA, (void*)&string_buffer);
+void AzureAudioToTextInvoker::ProcessResponse(const std::string& response, ortc::Variadic& outputs) const {
+  auto& string_tensor = outputs.AllocateStringTensor(0);
+  string_tensor.SetStringOutput(std::vector<std::string>{response}, std::vector<int64_t>{1});
+}
 
-  auto curl_ret = curl_handler.Perform();
-  if (CURLE_OK != curl_ret) {
-    ORTX_CXX_API_THROW(curl_easy_strerror(curl_ret), ORT_FAIL);
+////////////////////// AzureTextToTextInvoker //////////////////////
+
+AzureTextToTextInvoker::AzureTextToTextInvoker(const OrtApi& api, const OrtKernelInfo& info)
+    : CurlInvoker(api, info) {
+}
+
+void AzureTextToTextInvoker::ValidateInputs(const ortc::Variadic& inputs) const {
+  if (inputs.Size() != 2 || inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
+    ORTX_CXX_API_THROW("Expected 2 string inputs of auth_token and text respectively", ORT_INVALID_ARGUMENT);
   }
 
-  output.SetStringOutput(std::vector<std::string>{string_buffer.ss_.str()}, std::vector<int64_t>{1L});
+  // We don't have a way to get the output type from the custom op API.
+  // If there's a mismatch it will fail in the Compute when it allocates the output tensor.
+  if (OutputNames().size() != 1) {
+    ORTX_CXX_API_THROW("Expected single output", ORT_INVALID_ARGUMENT);
+  }
 }
 
-////////////////////// AzureTextInvoker //////////////////////
+void AzureTextToTextInvoker::SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const {
+  gsl::span<const std::string> input_names = InputNames();
 
-AzureTextInvoker::AzureTextInvoker(const OrtApi& api, const OrtKernelInfo& info) : AzureInvoker(api, info) {
-}
-
-void AzureTextInvoker::Compute(std::string_view auth, std::string_view input,
-                               ortc::Tensor<std::string>& output) const {
-  CurlHandler curl_handler(WriteStringCallback);
-  StringBuffer string_buffer;
-
-  std::string full_auth = std::string{"Authorization: Bearer "} + auth.data();
-  curl_handler.AddHeader(full_auth.c_str());
+  // TODO: assuming we need to create the correct json from the input text
   curl_handler.AddHeader("Content-Type: application/json");
 
-  curl_handler.SetOption(CURLOPT_URL, model_uri_.c_str());
-  curl_handler.SetOption(CURLOPT_POSTFIELDS, input.data());
-  curl_handler.SetOption(CURLOPT_POSTFIELDSIZE_LARGE, input.size());
-  curl_handler.SetOption(CURLOPT_VERBOSE, verbose_);
-  curl_handler.SetOption(CURLOPT_WRITEDATA, (void*)&string_buffer);
-
-  auto curl_ret = curl_handler.Perform();
-  if (CURLE_OK != curl_ret) {
-    ORTX_CXX_API_THROW(curl_easy_strerror(curl_ret), ORT_FAIL);
-  }
-
-  output.SetStringOutput(std::vector<std::string>{string_buffer.ss_.str()}, std::vector<int64_t>{1L});
+  const auto& text_input = inputs[1];
+  curl_handler.SetOption(CURLOPT_POSTFIELDS, text_input->DataRaw());
+  curl_handler.SetOption(CURLOPT_POSTFIELDSIZE_LARGE, text_input->SizeInBytes());
 }
 
-////////////////////// AzureTritonInvoker //////////////////////
-
-namespace tc = triton::client;
-
-AzureTritonInvoker::AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info) : AzureInvoker(api, info) {
-  auto err = tc::InferenceServerHttpClient::Create(&triton_client_, model_uri_, verbose_ != "0");
+void AzureTextToTextInvoker::ProcessResponse(const std::string& response, ortc::Variadic& outputs) const {
+  auto& string_tensor = outputs.AllocateStringTensor(0);
+  string_tensor.SetStringOutput(std::vector<std::string>{response}, std::vector<int64_t>{1});
 }
 
-std::string MapDataType(ONNXTensorElementDataType onnx_data_type) {
-  std::string triton_data_type;
-  switch (onnx_data_type) {
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
-      triton_data_type = "FP32";
-      break;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
-      triton_data_type = "UINT8";
-      break;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
-      triton_data_type = "INT8";
-      break;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
-      triton_data_type = "UINT16";
-      break;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
-      triton_data_type = "INT16";
-      break;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
-      triton_data_type = "INT32";
-      break;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
-      triton_data_type = "INT64";
-      break;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
-      triton_data_type = "BYTES";
-      break;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
-      triton_data_type = "BOOL";
-      break;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
-      triton_data_type = "FP16";
-      break;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
-      triton_data_type = "FP64";
-      break;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
-      triton_data_type = "UINT32";
-      break;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
-      triton_data_type = "UINT64";
-      break;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
-      triton_data_type = "BF16";
-      break;
-    default:
-      break;
-  }
-  return triton_data_type;
-}
-
-int8_t* CreateNonStrTensor(const std::string& data_type,
-                           ortc::Variadic& outputs,
-                           size_t i,
-                           const std::vector<int64_t>& shape) {
-  if (data_type == "FP32") {
-    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<float>(i, shape));
-  } else if (data_type == "UINT8") {
-    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint8_t>(i, shape));
-  } else if (data_type == "INT8") {
-    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int8_t>(i, shape));
-  } else if (data_type == "UINT16") {
-    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint16_t>(i, shape));
-  } else if (data_type == "INT16") {
-    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int16_t>(i, shape));
-  } else if (data_type == "INT32") {
-    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int32_t>(i, shape));
-  } else if (data_type == "UINT32") {
-    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint32_t>(i, shape));
-  } else if (data_type == "INT64") {
-    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int64_t>(i, shape));
-  } else if (data_type == "UINT64") {
-    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint64_t>(i, shape));
-  } else if (data_type == "BOOL") {
-    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<bool>(i, shape));
-  } else if (data_type == "FP64") {
-    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<double>(i, shape));
-  } else {
-    return {};
-  }
-}
-
-#define CHECK_TRITON_ERR(ret, msg) \
-  if (!ret.IsOk()) {               \
-    return ORTX_CXX_API_THROW("Triton err: " + ret.Message(), ORT_RUNTIME_EXCEPTION); \
-  }
-
-void AzureTritonInvoker::Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
-  if (inputs.Size() < 1 ||
-      inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
-    ORTX_CXX_API_THROW("invalid inputs, auto token missing", ORT_RUNTIME_EXCEPTION);
-  }
-
-  if (inputs.Size() != input_names_.size()) {
-    ORTX_CXX_API_THROW("input count mismatch", ORT_RUNTIME_EXCEPTION);
-  }
-
-  auto auth_token = reinterpret_cast<const char*>(inputs[0]->DataRaw());
-  std::vector<std::unique_ptr<tc::InferInput>> triton_input_vec;
-  std::vector<tc::InferInput*> triton_inputs;
-  std::vector<std::unique_ptr<const tc::InferRequestedOutput>> triton_output_vec;
-  std::vector<const tc::InferRequestedOutput*> triton_outputs;
-  tc::Error err;
-
-  for (size_t ith_input = 1; ith_input < inputs.Size(); ++ith_input) {
-    tc::InferInput* triton_input = {};
-    std::string triton_data_type = MapDataType(inputs[ith_input]->Type());
-    if (triton_data_type.empty()) {
-      ORTX_CXX_API_THROW("unknow onnx data type", ORT_RUNTIME_EXCEPTION);
-    }
-
-    err = tc::InferInput::Create(&triton_input, input_names_[ith_input], inputs[ith_input]->Shape(), triton_data_type);
-    CHECK_TRITON_ERR(err, "failed to create triton input");
-    triton_input_vec.emplace_back(triton_input);
-
-    triton_inputs.push_back(triton_input);
-    if ("BYTES" == triton_data_type) {
-      const auto* string_tensor = reinterpret_cast<const ortc::Tensor<std::string>*>(inputs[ith_input].get());
-      triton_input->AppendFromString(string_tensor->Data());
-    } else {
-      const float* data_raw = reinterpret_cast<const float*>(inputs[ith_input]->DataRaw());
-      size_t size_in_bytes = inputs[ith_input]->SizeInBytes();
-      err = triton_input->AppendRaw(reinterpret_cast<const uint8_t*>(data_raw), size_in_bytes);
-      CHECK_TRITON_ERR(err, "failed to append raw data to input");
-    }
-  }
-
-  for (size_t ith_output = 0; ith_output < output_names_.size(); ++ith_output) {
-    tc::InferRequestedOutput* triton_output = {};
-    err = tc::InferRequestedOutput::Create(&triton_output, output_names_[ith_output]);
-    CHECK_TRITON_ERR(err, "failed to create triton output");
-    triton_output_vec.emplace_back(triton_output);
-    triton_outputs.push_back(triton_output);
-  }
-
-  std::unique_ptr<tc::InferResult> results_ptr;
-  tc::InferResult* results = {};
-  tc::InferOptions options(model_name_);
-  options.model_version_ = model_ver_;
-  options.client_timeout_ = 0;
-
-  tc::Headers http_headers;
-  http_headers["Authorization"] = std::string{"Bearer "} + auth_token;
-
-  err = triton_client_->Infer(&results, options, triton_inputs, triton_outputs,
-                              http_headers, tc::Parameters(),
-                              tc::InferenceServerHttpClient::CompressionType::NONE,  // support compression in config?
-                              tc::InferenceServerHttpClient::CompressionType::NONE);
-
-  results_ptr.reset(results);
-  CHECK_TRITON_ERR(err, "failed to do triton inference");
-
-  size_t output_index = 0;
-  auto iter = output_names_.begin();
-
-  while (iter != output_names_.end()) {
-    std::vector<int64_t> shape;
-    err = results_ptr->Shape(*iter, &shape);
-    CHECK_TRITON_ERR(err, "failed to get output shape");
-
-    std::string type;
-    err = results_ptr->Datatype(*iter, &type);
-    CHECK_TRITON_ERR(err, "failed to get output type");
-
-    if ("BYTES" == type) {
-      std::vector<std::string> output_strings;
-      err = results_ptr->StringData(*iter, &output_strings);
-      CHECK_TRITON_ERR(err, "failed to get output as string");
-      auto& string_tensor = outputs.AllocateStringTensor(output_index);
-      string_tensor.SetStringOutput(output_strings, shape);
-    } else {
-      const uint8_t* raw_data = {};
-      size_t raw_size;
-      err = results_ptr->RawData(*iter, &raw_data, &raw_size);
-      CHECK_TRITON_ERR(err, "failed to get output raw data");
-      auto* output_raw = CreateNonStrTensor(type, outputs, output_index, shape);
-      memcpy(output_raw, raw_data, raw_size);
-    }
-
-    ++output_index;
-    ++iter;
-  }
-}
-
-const std::vector<const OrtCustomOp*>& AzureInvokerLoader() {
-  static OrtOpLoader op_loader(CustomAzureStruct("AzureAudioInvoker", AzureAudioInvoker),
-                               CustomAzureStruct("AzureTritonInvoker", AzureTritonInvoker),
-                               CustomAzureStruct("AzureAudioInvoker", AzureAudioInvoker),
-                               CustomAzureStruct("AzureTextInvoker", AzureTextInvoker),
-                               CustomCpuStruct("AzureAudioInvoker", AzureAudioInvoker),
-                               CustomCpuStruct("AzureTritonInvoker", AzureTritonInvoker),
-                               CustomCpuStruct("AzureAudioInvoker", AzureAudioInvoker),
-                               CustomCpuStruct("AzureTextInvoker", AzureTextInvoker));
-  return op_loader.GetCustomOps();
-}
-
-FxLoadCustomOpFactory LoadCustomOpClasses_Azure = AzureInvokerLoader;
+}  // namespace ort_extensions
diff --git a/operators/azure/azure_invokers.hpp b/operators/azure/azure_invokers.hpp
index 1165aa41..db84ac2b 100644
--- a/operators/azure/azure_invokers.hpp
+++ b/operators/azure/azure_invokers.hpp
@@ -1,43 +1,54 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
-
 #pragma once
 
 #include "ocos.h"
+#include "curl_invoker.hpp"
 
-struct AzureInvoker : public BaseKernel {
-  AzureInvoker(const OrtApi& api, const OrtKernelInfo& info);
+namespace ort_extensions {
 
- protected:
-  ~AzureInvoker() = default;
-  std::string model_uri_;
-  std::string model_name_;
-  std::string model_ver_;
-  std::string verbose_;
-  std::vector<std::string> input_names_;
-  std::vector<std::string> output_names_;
-};
+////////////////////// AzureAudioToTextInvoker //////////////////////
 
-struct AzureAudioInvoker : public AzureInvoker {
-  AzureAudioInvoker(const OrtApi& api, const OrtKernelInfo& info);
-  void Compute(const ortc::Variadic& inputs, ortc::Tensor<std::string>& output) const;
+/// <summary>
+/// Azure Audio to Text
+/// Input: auth_token {string}, ??? (Update when AOAI endpoint is defined)
+/// Output: text {string}
+/// </summary>
+class AzureAudioToTextInvoker : public CurlInvoker {
+ public:
+  AzureAudioToTextInvoker(const OrtApi& api, const OrtKernelInfo& info);
+
+  void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
+    // use impl from CurlInvoker
+    ComputeImpl(inputs, outputs);
+  }
 
  private:
-  std::string file_name_;
+  void ValidateInputs(const ortc::Variadic& inputs) const override;
+  void SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const override;
+  void ProcessResponse(const std::string& response, ortc::Variadic& outputs) const override;
+
+  static constexpr const char* kAudioFormat = "audio_format";
+  std::string audio_format_;
 };
 
-struct AzureTextInvoker : public AzureInvoker {
-  AzureTextInvoker(const OrtApi& api, const OrtKernelInfo& info);
-  void Compute(std::string_view auth, std::string_view input, ortc::Tensor<std::string>& output) const;
+////////////////////// AzureTextToTextInvoker //////////////////////
+
+/// Azure Text to Text
+/// Input: auth_token {string}, text {string}
+/// Output: text {string}
+struct AzureTextToTextInvoker : public CurlInvoker {
+  AzureTextToTextInvoker(const OrtApi& api, const OrtKernelInfo& info);
+
+  void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
+    // use impl from CurlInvoker
+    ComputeImpl(inputs, outputs);
+  }
 
  private:
-  std::string binary_type_;
+  void ValidateInputs(const ortc::Variadic& inputs) const override;
+  void SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const override;
+  void ProcessResponse(const std::string& response, ortc::Variadic& outputs) const override;
 };
 
-struct AzureTritonInvoker : public AzureInvoker {
-  AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info);
-  void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const;
-
- private:
-  std::unique_ptr<triton::client::InferenceServerHttpClient> triton_client_;
-};
+}  // namespace ort_extensions
diff --git a/operators/azure/azure_triton_invoker.cc b/operators/azure/azure_triton_invoker.cc
new file mode 100644
index 00000000..8bb4fadf
--- /dev/null
+++ b/operators/azure/azure_triton_invoker.cc
@@ -0,0 +1,200 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "azure_triton_invoker.hpp"
+
+////////////////////// AzureTritonInvoker //////////////////////
+
+namespace tc = triton::client;
+
+namespace ort_extensions {
+
+AzureTritonInvoker::AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info)
+    : CloudBaseKernel(api, info) {
+  auto err = tc::InferenceServerHttpClient::Create(&triton_client_, ModelUri(), Verbose());
+}
+
+std::string MapDataType(ONNXTensorElementDataType onnx_data_type) {
+  std::string triton_data_type;
+  switch (onnx_data_type) {
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
+      triton_data_type = "FP32";
+      break;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
+      triton_data_type = "UINT8";
+      break;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
+      triton_data_type = "INT8";
+      break;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
+      triton_data_type = "UINT16";
+      break;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
+      triton_data_type = "INT16";
+      break;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
+      triton_data_type = "INT32";
+      break;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
+      triton_data_type = "INT64";
+      break;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
+      triton_data_type = "BYTES";
+      break;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
+      triton_data_type = "BOOL";
+      break;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
+      triton_data_type = "FP16";
+      break;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
+      triton_data_type = "FP64";
+      break;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
+      triton_data_type = "UINT32";
+      break;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
+      triton_data_type = "UINT64";
+      break;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
+      triton_data_type = "BF16";
+      break;
+    default:
+      break;
+  }
+  return triton_data_type;
+}
+
+int8_t* CreateNonStrTensor(const std::string& data_type,
+                           ortc::Variadic& outputs,
+                           size_t i,
+                           const std::vector<int64_t>& shape) {
+  if (data_type == "FP32") {
+    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<float>(i, shape));
+  } else if (data_type == "UINT8") {
+    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint8_t>(i, shape));
+  } else if (data_type == "INT8") {
+    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int8_t>(i, shape));
+  } else if (data_type == "UINT16") {
+    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint16_t>(i, shape));
+  } else if (data_type == "INT16") {
+    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int16_t>(i, shape));
+  } else if (data_type == "INT32") {
+    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int32_t>(i, shape));
+  } else if (data_type == "UINT32") {
+    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint32_t>(i, shape));
+  } else if (data_type == "INT64") {
+    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int64_t>(i, shape));
+  } else if (data_type == "UINT64") {
+    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint64_t>(i, shape));
+  } else if (data_type == "BOOL") {
+    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<bool>(i, shape));
+  } else if (data_type == "FP64") {
+    return reinterpret_cast<int8_t*>(outputs.AllocateOutput<double>(i, shape));
+  } else {
+    return {};
+  }
+}
+
+#define CHECK_TRITON_ERR(ret, msg) \
+  if (!ret.IsOk()) {               \
+    return ORTX_CXX_API_THROW("Triton err: " + ret.Message(), ORT_RUNTIME_EXCEPTION); \
+  }
+
+void AzureTritonInvoker::Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
+  auto auth_token = GetAuthToken(inputs);
+
+  gsl::span<const std::string> input_names = InputNames();
+  if (inputs.Size() != input_names.size()) {
+    ORTX_CXX_API_THROW("input count mismatch", ORT_RUNTIME_EXCEPTION);
+  }
+
+  std::vector<std::unique_ptr<tc::InferInput>> triton_input_vec;
+  std::vector<tc::InferInput*> triton_inputs;
+  std::vector<std::unique_ptr<const tc::InferRequestedOutput>> triton_output_vec;
+  std::vector<const tc::InferRequestedOutput*> triton_outputs;
+  tc::Error err;
+
+  const auto& property_names = RequestPropertyNames();
+
+  for (size_t ith_input = 1; ith_input < inputs.Size(); ++ith_input) {
+    tc::InferInput* triton_input = {};
+    std::string triton_data_type = MapDataType(inputs[ith_input]->Type());
+    if (triton_data_type.empty()) {
+      ORTX_CXX_API_THROW("unknown onnx data type", ORT_RUNTIME_EXCEPTION);
+    }
+
+    err = tc::InferInput::Create(&triton_input, property_names[ith_input], inputs[ith_input]->Shape(),
+                                 triton_data_type);
+    CHECK_TRITON_ERR(err, "failed to create triton input");
+    triton_input_vec.emplace_back(triton_input);
+
+    triton_inputs.push_back(triton_input);
+    if ("BYTES" == triton_data_type) {
+      const auto* string_tensor = reinterpret_cast<const ortc::Tensor<std::string>*>(inputs[ith_input].get());
+      triton_input->AppendFromString(string_tensor->Data());
+    } else {
+      const float* data_raw = reinterpret_cast<const float*>(inputs[ith_input]->DataRaw());
+      size_t size_in_bytes = inputs[ith_input]->SizeInBytes();
+      err = triton_input->AppendRaw(reinterpret_cast<const uint8_t*>(data_raw), size_in_bytes);
+      CHECK_TRITON_ERR(err, "failed to append raw data to input");
+    }
+  }
+
+  gsl::span<const std::string> output_names = OutputNames();
+  for (size_t ith_output = 0; ith_output < output_names.size(); ++ith_output) {
+    tc::InferRequestedOutput* triton_output = {};
+    err = tc::InferRequestedOutput::Create(&triton_output, output_names[ith_output]);
+    CHECK_TRITON_ERR(err, "failed to create triton output");
+    triton_output_vec.emplace_back(triton_output);
+    triton_outputs.push_back(triton_output);
+  }
+
+  std::unique_ptr<tc::InferResult> results_ptr;
+  tc::InferResult* results = {};
+  tc::InferOptions options(ModelName());
+  options.model_version_ = ModelVersion();
+  options.client_timeout_ = 0;
+
+  tc::Headers http_headers;
+  http_headers["Authorization"] = std::string{"Bearer "} + auth_token;
+
+  err = triton_client_->Infer(&results, options, triton_inputs, triton_outputs,
+                              http_headers, tc::Parameters(),
+                              tc::InferenceServerHttpClient::CompressionType::NONE,  // support compression in config?
+                              tc::InferenceServerHttpClient::CompressionType::NONE);
+
+  results_ptr.reset(results);
+  CHECK_TRITON_ERR(err, "failed to do triton inference");
+
+  size_t output_index = 0;
+
+  for (const auto& output_name : output_names) {
+    std::vector<int64_t> shape;
+    err = results_ptr->Shape(output_name, &shape);
+    CHECK_TRITON_ERR(err, "failed to get output shape");
+
+    std::string type;
+    err = results_ptr->Datatype(output_name, &type);
+    CHECK_TRITON_ERR(err, "failed to get output type");
+
+    if ("BYTES" == type) {
+      std::vector<std::string> output_strings;
+      err = results_ptr->StringData(output_name, &output_strings);
+      CHECK_TRITON_ERR(err, "failed to get output as string");
+      auto& string_tensor = outputs.AllocateStringTensor(output_index);
+      string_tensor.SetStringOutput(output_strings, shape);
+    } else {
+      const uint8_t* raw_data = {};
+      size_t raw_size;
+      err = results_ptr->RawData(output_name, &raw_data, &raw_size);
+      CHECK_TRITON_ERR(err, "failed to get output raw data");
+      auto* output_raw = CreateNonStrTensor(type, outputs, output_index, shape);
+      memcpy(output_raw, raw_data, raw_size);
+    }
+
+    ++output_index;
+  }
+}
+
+}  // namespace ort_extensions
diff --git a/operators/azure/azure_triton_invoker.hpp b/operators/azure/azure_triton_invoker.hpp
new file mode 100644
index 00000000..1d197e42
--- /dev/null
+++ b/operators/azure/azure_triton_invoker.hpp
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "cloud_base_kernel.hpp"
+#include "http_client.h"  // triton
+
+namespace ort_extensions {
+
+struct AzureTritonInvoker : public CloudBaseKernel {
+  AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info);
+  void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const;
+
+ private:
+  std::unique_ptr<triton::client::InferenceServerHttpClient> triton_client_;
+};
+}  // namespace ort_extensions
diff --git a/operators/azure/cloud_base_kernel.cc b/operators/azure/cloud_base_kernel.cc
new file mode 100644
index 00000000..5fda1f43
--- /dev/null
+++ b/operators/azure/cloud_base_kernel.cc
@@ -0,0 +1,92 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "cloud_base_kernel.hpp"
+
+#include <string>
+
+namespace ort_extensions {
+CloudBaseKernel::CloudBaseKernel(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
+  auto ver = GetActiveOrtAPIVersion();
+  if (ver < kMinimumSupportedOrtVersion) {
+    ORTX_CXX_API_THROW("Azure custom operators require onnxruntime version >= 1.14", ORT_RUNTIME_EXCEPTION);
+  }
+
+  // require model uri. other properties are optional
+  // Custom op implementation can allow user to override attributes via inputs
+  if (!TryToGetAttribute(kUri, model_uri_)) {
+    ORTX_CXX_API_THROW(std::string("Required '") + kUri + "' attribute was not found", ORT_RUNTIME_EXCEPTION);
+  }
+
+  model_name_ = TryToGetAttributeWithDefault<std::string>(kModelName, "");
+  model_ver_ = TryToGetAttributeWithDefault<std::string>(kModelVer, "0");
+  verbose_ = TryToGetAttributeWithDefault<std::string>(kVerbose, "0") != "0";
+
+  OrtStatusPtr status{};
+  size_t input_count{};
+  status = api_.KernelInfo_GetInputCount(&info_, &input_count);
+  if (status) {
+    ORTX_CXX_API_THROW("failed to get input count", ORT_RUNTIME_EXCEPTION);
+  }
+
+  input_names_.reserve(input_count);
+  property_names_.reserve(input_count);
+
+  for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
+    char input_name[1024]{};
+    size_t name_size = 1024;
+    status = api_.KernelInfo_GetInputName(&info_, ith_input, input_name, &name_size);
+    if (status) {
+      ORTX_CXX_API_THROW("failed to get name for input " + std::to_string(ith_input), ORT_RUNTIME_EXCEPTION);
+    }
+
+    input_names_.push_back(input_name);
+    property_names_.push_back(GetPropertyNameFromInputName(input_name));
+  }
+
+  if (input_names_[0] != "auth_token") {
+    ORTX_CXX_API_THROW("first input name must be 'auth_token'", ORT_INVALID_ARGUMENT);
+  }
+
+  size_t output_count = {};
+  status = api_.KernelInfo_GetOutputCount(&info_, &output_count);
+  if (status) {
+    ORTX_CXX_API_THROW("failed to get output count", ORT_RUNTIME_EXCEPTION);
+  }
+
+  output_names_.reserve(output_count);
+  for (size_t ith_output = 0; ith_output < output_count; ++ith_output) {
+    char output_name[1024]{};
+    size_t name_size = 1024;
+    status = api_.KernelInfo_GetOutputName(&info_, ith_output, output_name, &name_size);
+    if (status) {
+      ORTX_CXX_API_THROW("failed to get name for output " + std::to_string(ith_output), ORT_RUNTIME_EXCEPTION);
+    }
+    output_names_.push_back(output_name);
+  }
+}
+
+std::string CloudBaseKernel::GetAuthToken(const ortc::Variadic& inputs) const {
+  if (inputs.Size() < 1 ||
+      inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
+    ORTX_CXX_API_THROW("auth_token string is required to be the first input", ORT_INVALID_ARGUMENT);
+  }
+
+  std::string auth_token{static_cast<const char*>(inputs[0]->DataRaw())};
+  return auth_token;
+}
+
+/*static*/ std::string CloudBaseKernel::GetPropertyNameFromInputName(const std::string& input_name) {
+  auto idx = input_name.find_last_of('/');
+  if (idx == std::string::npos) {
+    return input_name;
+  }
+
+  if (idx == input_name.length() - 1) {
+    ORTX_CXX_API_THROW("Input name cannot end with '/'. Invalid input:" + input_name, ORT_INVALID_ARGUMENT);
+  }
+
+  return input_name.substr(idx + 1);  // return text after the '/'
+}
+
+}  // namespace ort_extensions
diff --git a/operators/azure/cloud_base_kernel.hpp b/operators/azure/cloud_base_kernel.hpp
new file mode 100644
index 00000000..2b51c0cf
--- /dev/null
+++ b/operators/azure/cloud_base_kernel.hpp
@@ -0,0 +1,63 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "ocos.h"
+#include "gsl/span"
+
+namespace ort_extensions {
+
+/// <summary>
+/// Base kernel for custom ops that call cloud endpoints.
+/// </summary>
+class CloudBaseKernel : public BaseKernel {
+ protected:
+  CloudBaseKernel(const OrtApi& api, const OrtKernelInfo& info);
+  virtual ~CloudBaseKernel() = default;
+
+  // Names of attributes the custom operator provides.
+  static constexpr const char* kUri = "model_uri";           // required
+  static constexpr const char* kModelName = "model_name";    // optional
+  static constexpr const char* kModelVer = "model_version";  // optional
+  static constexpr const char* kVerbose = "verbose";
+
+  static constexpr int kMinimumSupportedOrtVersion = 14;
+
+  const std::string& ModelUri() const { return model_uri_; }
+  const std::string& ModelName() const { return model_name_; }
+  const std::string& ModelVersion() const { return model_ver_; }
+  bool Verbose() const { return verbose_; }
+
+  const gsl::span<const std::string> InputNames() const { return input_names_; }
+  const gsl::span<const std::string> OutputNames() const { return output_names_; }
+
+  // Request property names that are parsed from input names. 1:1 with InputNames() values.
+  // e.g. 'node0/prompt' -> 'prompt' and that input provides the 'prompt' property in the request to the endpoint.
+  // See GetPropertyNameFromInputName for further details.
+  const gsl::span<const std::string> RequestPropertyNames() const { return property_names_; }
+
+  // first input is required to be auth token. validate that and return it.
+  std::string GetAuthToken(const ortc::Variadic& inputs) const;
+
+  /// <summary>
+  /// Parse the property name to use in the request to the cloud endpoint from a node input name.
+  /// Value returned is text following last '/', or the entire string if no '/'.
+  /// e.g. 'node0/prompt' -> 'prompt'
+  /// </summary>
+  /// <param name="input_name">Node input name.</param>
+  /// <returns>Request property name the input is providing data for.</returns>
+  static std::string GetPropertyNameFromInputName(const std::string& input_name);
+
+ private:
+  std::string model_uri_;
+  std::string model_name_;
+  std::string model_ver_;
+  bool verbose_;
+
+  std::vector<std::string> input_names_;
+  std::vector<std::string> property_names_;
+  std::vector<std::string> output_names_;
+};
+
+}  // namespace ort_extensions
diff --git a/operators/azure/cloud_ops_registration.cc b/operators/azure/cloud_ops_registration.cc
new file mode 100644
index 00000000..4e6773aa
--- /dev/null
+++ b/operators/azure/cloud_ops_registration.cc
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
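+// Registers the Azure custom ops. The Triton invoker is only compiled in when the build enables the
+// triton client (AZURE_INVOKERS_ENABLE_TRITON, which CMakeLists.txt defines for non-Android builds).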
+ +#include "ocos.h" +#include "azure_invokers.hpp" +#include "openai_invokers.hpp" + +#ifdef AZURE_INVOKERS_ENABLE_TRITON +#include "azure_triton_invoker.hpp" +#endif + +using namespace ort_extensions; + +const std::vector& AzureInvokerLoader() { + static OrtOpLoader op_loader(CustomAzureStruct("AzureAudioToText", AzureAudioToTextInvoker), + CustomCpuStruct("AzureAudioToText", AzureAudioToTextInvoker), + CustomAzureStruct("AzureTextToText", AzureTextToTextInvoker), + CustomCpuStruct("AzureTextToText", AzureTextToTextInvoker), + CustomAzureStruct("OpenAIAudioToText", OpenAIAudioToTextInvoker), + CustomCpuStruct("OpenAIAudioToText", OpenAIAudioToTextInvoker) +#ifdef AZURE_INVOKERS_ENABLE_TRITON + , + CustomAzureStruct("AzureTritonInvoker", AzureTritonInvoker), + CustomCpuStruct("AzureTritonInvoker", AzureTritonInvoker) +#endif + ); + + return op_loader.GetCustomOps(); +} + +FxLoadCustomOpFactory LoadCustomOpClasses_Azure = AzureInvokerLoader; diff --git a/operators/azure/curl_invoker.cc b/operators/azure/curl_invoker.cc new file mode 100644 index 00000000..8cb86be5 --- /dev/null +++ b/operators/azure/curl_invoker.cc @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "curl_invoker.hpp" + +#include // TEMP error output +#include + +namespace ort_extensions { + +// apply the callback only when response is for sure to be a '/0' terminated string +size_t CurlHandler::WriteStringCallback(char* contents, size_t element_size, size_t num_elements, void* userdata) { + try { + size_t bytes = element_size * num_elements; + std::string& buffer = *static_cast(userdata); + buffer.append(contents, bytes); + return bytes; + } catch (const std::exception& ex) { + // TODO: This should be captured/logger properly + std::cerr << ex.what() << std::endl; + return 0; + } catch (...) { + // exception caught, abort write + std::cerr << "Unknown exception caught in CurlHandler::WriteStringCallback" << std::endl; + return 0; + } +} + +CurlHandler::CurlHandler(WriteCallBack callback) : curl_(curl_easy_init(), curl_easy_cleanup), + headers_(nullptr, curl_slist_free_all), + from_holder_(from_, curl_formfree) { + CURL* curl = curl_.get(); // CURL == void* so can't dereference + + curl_easy_setopt(curl, CURLOPT_BUFFERSIZE, 100 * 1024L); // how was this size chosen? should it be set on a per operator basis? + curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L); + curl_easy_setopt(curl, CURLOPT_USERAGENT, "curl/7.83.1"); // should this value come from the curl src instead of being hardcoded? + curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 50L); // 50 seems like a lot if we're directly calling a specific endpoint + curl_easy_setopt(curl, CURLOPT_FTP_SKIP_PASV_IP, 1L); // what does this have to do with http requests? + curl_easy_setopt(curl, CURLOPT_TCP_KEEPALIVE, 1L); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, callback); + + // should this be configured via a node attribute? 
+  curl_easy_setopt(curl, CURLOPT_TIMEOUT, 15);
+}
+
+////////////////////// CurlInvoker //////////////////////
+
+CurlInvoker::CurlInvoker(const OrtApi& api, const OrtKernelInfo& info)
+    : CloudBaseKernel(api, info) {
+}
+
+void CurlInvoker::ComputeImpl(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
+  std::string auth_token = GetAuthToken(inputs);
+
+  if (inputs.Size() != InputNames().size()) {
+    ORTX_CXX_API_THROW("input count mismatch", ORT_RUNTIME_EXCEPTION);
+  }
+
+  // do any additional validation of the number and type of inputs/outputs
+  ValidateInputs(inputs);
+
+  // set the options for the curl handler that apply to all usages
+  CurlHandler curl_handler(CurlHandler::WriteStringCallback);
+
+  std::string full_auth = std::string{"Authorization: Bearer "} + auth_token;
+  curl_handler.AddHeader(full_auth.c_str());
+  curl_handler.SetOption(CURLOPT_URL, ModelUri().c_str());
+  curl_handler.SetOption(CURLOPT_VERBOSE, Verbose());
+
+  std::string response;
+  curl_handler.SetOption(CURLOPT_WRITEDATA, (void*)&response);
+
+  SetupRequest(curl_handler, inputs);
+  ExecuteRequest(curl_handler);
+  ProcessResponse(response, outputs);
+}
+
+void CurlInvoker::ExecuteRequest(CurlHandler& curl_handler) const {
+  // this is where we could add any logic required to make the request async or maybe handle retries/cancellation.
+  auto curl_ret = curl_handler.Perform();
+  if (CURLE_OK != curl_ret) {
+    ORTX_CXX_API_THROW(curl_easy_strerror(curl_ret), ORT_FAIL);
+  }
+}
+}  // namespace ort_extensions
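Stripped of the libcurl plumbing, ComputeImpl amounts to a bearer-authenticated POST whose response body is buffered into a string. A rough Python sketch of the same flow, using the third-party `requests` package (function name and payload are illustrative, not part of this change):

```python
import requests  # third-party: pip install requests


def invoke_endpoint(model_uri: str, auth_token: str, form_data: dict, timeout_sec: int = 15) -> str:
    """Rough equivalent of CurlInvoker::ComputeImpl + ExecuteRequest."""
    response = requests.post(
        model_uri,
        headers={"Authorization": f"Bearer {auth_token}"},  # same header CurlInvoker adds
        data=form_data,
        timeout=timeout_sec,  # CurlHandler hardcodes CURLOPT_TIMEOUT to 15 seconds
    )
    response.raise_for_status()  # analogous to checking the CURLcode from curl_easy_perform
    return response.text  # WriteStringCallback appends the body to a std::string
```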
diff --git a/operators/azure/curl_invoker.hpp b/operators/azure/curl_invoker.hpp
new file mode 100644
index 00000000..07406560
--- /dev/null
+++ b/operators/azure/curl_invoker.hpp
@@ -0,0 +1,101 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include <memory>
+
+#include "curl/curl.h"
+
+#include "ocos.h"
+#include "cloud_base_kernel.hpp"
+
+namespace ort_extensions {
+
+class CurlHandler {
+ public:
+  using WriteCallBack = size_t (*)(char*, size_t, size_t, void*);
+
+  CurlHandler(WriteCallBack callback);
+  ~CurlHandler() = default;
+
+  /// <summary>
+  /// Callback to add contents to a string
+  /// </summary>
+  /// <returns>Bytes processed. If this does not match element_size * num_elements the libcurl function
+  /// used will return CURLE_WRITE_ERROR</returns>
+  static size_t WriteStringCallback(char* contents, size_t element_size, size_t num_elements, void* userdata);
+
+  void AddHeader(const char* data) {
+    headers_.reset(curl_slist_append(headers_.release(), data));
+  }
+
+  template <typename... Args>
+  void AddForm(Args... args) {
+    curl_formadd(&from_, &last_, args...);
+  }
+
+  void AddFormString(const char* name, const char* value) {
+    AddForm(CURLFORM_COPYNAME, name,
+            CURLFORM_COPYCONTENTS, value,
+            CURLFORM_END);
+  }
+
+  void AddFormBuffer(const char* name, const char* buffer_name, const void* buffer_ptr, size_t buffer_len) {
+    AddForm(CURLFORM_COPYNAME, name,
+            CURLFORM_BUFFER, buffer_name,
+            CURLFORM_BUFFERPTR, buffer_ptr,
+            CURLFORM_BUFFERLENGTH, buffer_len,
+            CURLFORM_END);
+  }
+
+  template <typename T>
+  void SetOption(CURLoption opt, T val) {
+    curl_easy_setopt(curl_.get(), opt, val);
+  }
+
+  CURLcode Perform() {
+    SetOption(CURLOPT_HTTPHEADER, headers_.get());
+    if (from_) {
+      SetOption(CURLOPT_HTTPPOST, from_);
+    }
+
+    return curl_easy_perform(curl_.get());
+  }
+
+ private:
+  std::unique_ptr<CURL, decltype(curl_easy_cleanup)*> curl_;
+  std::unique_ptr<curl_slist, decltype(curl_slist_free_all)*> headers_;
+  curl_httppost* from_{};
+  curl_httppost* last_{};
+  std::unique_ptr<curl_httppost, decltype(curl_formfree)*> from_holder_;  // TODO: Why no last_holder_?
+};
+
+/// <summary>
+/// Base class for requests using Curl
+/// </summary>
+class CurlInvoker : public CloudBaseKernel {
+ protected:
+  CurlInvoker(const OrtApi& api, const OrtKernelInfo& info);
+  virtual ~CurlInvoker() = default;
+
+  // Compute implementation that is used to co-ordinate all Curl based Azure requests.
+  // Derived classes need their own Compute to work with the CustomOpLite infrastructure
+  void ComputeImpl(const ortc::Variadic& inputs, ortc::Variadic& outputs) const;
+
+ private:
+  void ExecuteRequest(CurlHandler& handler) const;
+
+  // Derived classes can add any arg validation required.
+  // Prior to this being called, `inputs` are validated to match the number of input names, and
+  // the auth_token has been read from input[0] so validation can skip that.
+  //
+  // the ortc::Variadic outputs are empty until the Compute populates it, so only output names can be validated
+  // and those are available from the base class.
+  virtual void ValidateInputs(const ortc::Variadic& inputs) const {}
+
+  // curl_handler has auth token set from input[0].
+  virtual void SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const = 0;
+  virtual void ProcessResponse(const std::string& response, ortc::Variadic& outputs) const = 0;
+};
+}  // namespace ort_extensions
diff --git a/operators/azure/openai_invokers.cc b/operators/azure/openai_invokers.cc
new file mode 100644
index 00000000..e7b72e2b
--- /dev/null
+++ b/operators/azure/openai_invokers.cc
@@ -0,0 +1,103 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "openai_invokers.hpp"
+
+namespace ort_extensions {
+
+OpenAIAudioToTextInvoker::OpenAIAudioToTextInvoker(const OrtApi& api, const OrtKernelInfo& info)
+    : CurlInvoker(api, info) {
+  audio_format_ = TryToGetAttributeWithDefault<std::string>(kAudioFormat, "");
+
+  const auto& property_names = RequestPropertyNames();
+
+  const auto find_optional_input = [&property_names](const std::string& property_name) {
+    std::optional<size_t> result;
+    auto optional_input = std::find_if(property_names.begin(), property_names.end(),
+                                       [&property_name](const auto& name) { return name == property_name; });
+
+    if (optional_input != property_names.end()) {
+      result = optional_input - property_names.begin();
+    }
+
+    return result;
+  };
+
+  filename_input_ = find_optional_input("filename");
+  model_name_input_ = find_optional_input("model");
+
+  // OpenAI audio endpoints require 'file' and 'model'.
+  if (!std::any_of(property_names.begin(), property_names.end(),
+                   [](const auto& name) { return name == "file"; })) {
+    ORTX_CXX_API_THROW("Required 'file' input was not found", ORT_INVALID_ARGUMENT);
+  }
+
+  if (ModelName().empty() && !model_name_input_) {
+    ORTX_CXX_API_THROW("Required 'model' input was not found", ORT_INVALID_ARGUMENT);
+  }
+}
+
+void OpenAIAudioToTextInvoker::ValidateInputs(const ortc::Variadic& inputs) const {
+  // We don't have a way to get the output type from the custom op API.
+  // If there's a mismatch it will fail in the Compute when it allocates the output tensor.
+  if (OutputNames().size() != 1) {
+    ORTX_CXX_API_THROW("Expected single output", ORT_INVALID_ARGUMENT);
+  }
+}
+
+void OpenAIAudioToTextInvoker::SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const {
+  // notionally the filename the content was buffered from. provides the extension indicating the audio format
+  static const std::string fake_filename = "audio." + audio_format_;
+
+  const auto& property_names = RequestPropertyNames();
+
+  const auto& get_optional_input =
+      [&](const std::optional<size_t>& input_idx, const std::string& default_value, size_t min_size = 1) {
+        return (input_idx.has_value() && inputs[*input_idx]->SizeInBytes() > min_size)
+                   ? static_cast<const char*>(inputs[*input_idx]->DataRaw())
+                   : default_value.c_str();
+      };
+
+  // filename_input_ is optional in a model. if it's not present, use a fake filename.
+  // if it's present make sure it's not a default empty value. as the filename needs to have an extension of
+  // mp3, mp4, mpeg, mpga, m4a, wav, or webm it must be at least 4 characters long.
+  const char* filename = get_optional_input(filename_input_, fake_filename, 4);
+
+  curl_handler.AddHeader("Content-Type: multipart/form-data");
+  // model name could be an input or an attribute
+  curl_handler.AddFormString("model", get_optional_input(model_name_input_, ModelName()));
+
+  for (size_t ith_input = 1; ith_input < inputs.Size(); ++ith_input) {
+    switch (inputs[ith_input]->Type()) {
+      case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
+        curl_handler.AddFormString(property_names[ith_input].c_str(),
+                                   // assumes null terminated.
+                                   // might be safer to pass pointer and length and use CURLFORM_CONTENTSLENGTH
+                                   static_cast<const char*>(inputs[ith_input]->DataRaw()));
+        break;
+      case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
+        // only the 'file' input is uint8
+        if (property_names[ith_input] != "file") {
+          ORTX_CXX_API_THROW("Only the 'file' input should be uint8 data. Invalid input:" + InputNames()[ith_input],
+                             ORT_INVALID_ARGUMENT);
+        }
+
+        curl_handler.AddFormBuffer(property_names[ith_input].c_str(),
+                                   filename,
+                                   inputs[ith_input]->DataRaw(),
+                                   inputs[ith_input]->SizeInBytes());
+        break;
+      case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
+        // TODO - required to support 'temperature' input.
+      default:
+        ORTX_CXX_API_THROW("input must be either text or binary", ORT_INVALID_ARGUMENT);
+        break;
+    }
+  }
+}
+
+void OpenAIAudioToTextInvoker::ProcessResponse(const std::string& response, ortc::Variadic& outputs) const {
+  auto& string_tensor = outputs.AllocateStringTensor(0);
+  string_tensor.SetStringOutput(std::vector<std::string>{response}, std::vector<int64_t>{1});
+}
+}  // namespace ort_extensions
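For comparison, the multipart/form-data request that SetupRequest assembles looks roughly like this when expressed with Python's third-party `requests` package (a sketch only; the token, bytes and field values are placeholders):

```python
from typing import Optional

import requests  # third-party: pip install requests


def transcribe(auth_token: str, audio_bytes: bytes, filename: str = "audio.wav",
               model: str = "whisper-1", prompt: Optional[str] = None) -> str:
    """Sketch of the form OpenAIAudioToTextInvoker builds."""
    files = {
        "model": (None, model),           # plain form field, like AddFormString
        "file": (filename, audio_bytes),  # named buffer upload, like AddFormBuffer
    }
    if prompt is not None:
        files["prompt"] = (None, prompt)

    response = requests.post("https://api.openai.com/v1/audio/transcriptions",
                             headers={"Authorization": f"Bearer {auth_token}"},
                             files=files, timeout=15)
    response.raise_for_status()
    return response.text
```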
diff --git a/operators/azure/openai_invokers.hpp b/operators/azure/openai_invokers.hpp
new file mode 100644
index 00000000..18323179
--- /dev/null
+++ b/operators/azure/openai_invokers.hpp
@@ -0,0 +1,55 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+
+#include "ocos.h"
+
+#include <optional>
+
+#include "curl_invoker.hpp"
+
+namespace ort_extensions {
+
+////////////////////// OpenAIAudioToTextInvoker //////////////////////
+
+/// <summary>
+/// OpenAI Audio to Text
+/// Input: auth_token {string}, Request body values {string|uint8} as per https://platform.openai.com/docs/api-reference/audio
+/// Output: text {string}
+/// </summary>
+/// <remarks>
+/// The model URI is read from the node attributes.
+/// The model name (e.g. 'whisper-1') can be provided as a node attribute or via an input.
+///
+/// Example input would be:
+///   - string tensor named `auth_token` (required, must be first input)
+///   - a uint8 tensor named `file` with audio data in the format matching the 'audio_format' attribute (required)
+///     - see OpenAI documentation for current supported audio formats
+///   - a string tensor named `filename` (optional) with extension indicating the format of the audio data
+///     - e.g. 'audio.mp3'
+///   - a string tensor named `prompt` (optional)
+///
+/// NOTE: 'temperature' is not currently supported.
+/// </remarks>
+class OpenAIAudioToTextInvoker final : public CurlInvoker {
+ public:
+  OpenAIAudioToTextInvoker(const OrtApi& api, const OrtKernelInfo& info);
+
+  void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) const {
+    // use impl from CurlInvoker
+    ComputeImpl(inputs, outputs);
+  }
+
+ private:
+  void ValidateInputs(const ortc::Variadic& inputs) const override;
+  void SetupRequest(CurlHandler& curl_handler, const ortc::Variadic& inputs) const override;
+  void ProcessResponse(const std::string& response, ortc::Variadic& outputs) const override;
+
+  // audio format to use if the optional 'filename' input is not provided
+  static constexpr const char* kAudioFormat = "audio_format";
+  std::string audio_format_;
+  std::optional<size_t> filename_input_;    // optional override for generated filename using audio_format
+  std::optional<size_t> model_name_input_;  // optional override for model_name attribute
+};
+
+}  // namespace ort_extensions
diff --git a/prebuild/build_curl_for_android.sh b/prebuild/build_curl_for_android.sh
new file mode 100755
index 00000000..e1e46b5e
--- /dev/null
+++ b/prebuild/build_curl_for_android.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+set -e
+set -u
+set -x
+
+# export skip_checkout if you want to repeat a build
+if [ -z ${skip_checkout+x} ]; then
+  git clone https://github.com/leenjewel/openssl_for_ios_and_android.git
+  cd openssl_for_ios_and_android
+  git checkout ci-release-663da9e2
+  # patch with fixes to build on linux with NDK 25 or later
+  git apply ../build_curl_for_android_on_linux.patch
+else
+  echo "Skipping checkout and patch"
+  cd openssl_for_ios_and_android
+fi
+
+cd tools
+
+# we target Android API level 24
+export api=24
+
+# provide a specific architecture as an argument to the script to limit the build to that
+# default is to build all
+# valid architecture values: "arm" "arm64" "x86" "x86_64"
+if [ $# -eq 1 ]; then
+  arch=$1
+  ./build-android-openssl.sh $arch
+  ./build-android-nghttp2.sh $arch
+  ./build-android-curl.sh $arch
+else
+  ./build-android-openssl.sh
+  ./build-android-nghttp2.sh
+  ./build-android-curl.sh
+fi
diff --git a/prebuild/build_curl_for_android_on_linux.patch b/prebuild/build_curl_for_android_on_linux.patch
new file mode 100644
index 00000000..69030fc0
--- /dev/null
+++ b/prebuild/build_curl_for_android_on_linux.patch
@@ -0,0 +1,67 @@
+diff --git a/tools/build-android-common.sh b/tools/build-android-common.sh
+index 87df207..797d58a 100755
+--- a/tools/build-android-common.sh
++++ b/tools/build-android-common.sh
+@@ -148,13 +148,20 @@ function set_android_toolchain() {
+     local build_host=$(get_build_host_internal "$arch")
+     local clang_target_host=$(get_clang_target_host "$arch" "$api")
+ 
+-    export AR=${build_host}-ar
++    # NDK r23 removed a bunch of GNU things and replaced with llvm
++    # https://stackoverflow.com/questions/73105626/arm-linux-androideabi-ar-command-not-found-in-ndk
++    # export AR=${build_host}-ar
++    export AR=llvm-ar
+     export CC=${clang_target_host}-clang
+     export CXX=${clang_target_host}-clang++
+-    export AS=${build_host}-as
+-    export LD=${build_host}-ld
+-    export RANLIB=${build_host}-ranlib
++    #export AS=${build_host}-as
++    export AS=llvm-as
++    #export LD=${build_host}-ld
++    export LD=ld
++    # export RANLIB=${build_host}-ranlib
++    export RANLIB=llvm-ranlib
+     export STRIP=${build_host}-strip
++    export STRIP=llvm-strip
+ }
+ 
+ function get_common_includes() {
+@@ -187,13 +194,13 @@ function set_android_cpu_feature() {
+         export CPPFLAGS=${CFLAGS}
+         ;;
+     x86)
+-        export CFLAGS="-march=i686 -mtune=intel -mssse3 -mfpmath=sse -m32 -Wno-unused-function -fno-integrated-as -fstrict-aliasing -fPIC -DANDROID -D__ANDROID_API__=${api} -Os -ffunction-sections -fdata-sections $(get_common_includes)"
++        export CFLAGS="-march=i686 -mtune=native -mssse3 -mfpmath=sse -m32 -Wno-unused-function -fno-integrated-as -fstrict-aliasing -fPIC -DANDROID -D__ANDROID_API__=${api} -Os -ffunction-sections -fdata-sections $(get_common_includes)"
+         export CXXFLAGS="-std=c++14 -Os -ffunction-sections -fdata-sections"
+         export LDFLAGS="-march=i686 -Wl,--gc-sections -Os -ffunction-sections -fdata-sections $(get_common_linked_libraries ${api} ${arch})"
+         export CPPFLAGS=${CFLAGS}
+         ;;
+     x86-64)
+-        export CFLAGS="-march=x86-64 -msse4.2 -mpopcnt -m64 -mtune=intel -Wno-unused-function -fno-integrated-as -fstrict-aliasing -fPIC -DANDROID -D__ANDROID_API__=${api} -Os -ffunction-sections -fdata-sections $(get_common_includes)"
++        export CFLAGS="-march=x86-64 -msse4.2 -mpopcnt -m64 -mtune=native -Wno-unused-function -fno-integrated-as -fstrict-aliasing -fPIC -DANDROID -D__ANDROID_API__=${api} -Os -ffunction-sections -fdata-sections $(get_common_includes)"
+         export CXXFLAGS="-std=c++14 -Os -ffunction-sections -fdata-sections"
+         export LDFLAGS="-march=x86-64 -Wl,--gc-sections -Os -ffunction-sections -fdata-sections $(get_common_linked_libraries ${api} ${arch})"
+         export CPPFLAGS=${CFLAGS}
+diff --git a/tools/build-android-openssl.sh b/tools/build-android-openssl.sh
+index e13c314..5660cec 100755
+--- a/tools/build-android-openssl.sh
++++ b/tools/build-android-openssl.sh
+@@ -17,7 +17,7 @@
+ # # read -n1 -p "Press any key to continue..."
+ 
+ set -u
+-
++set -x
+ source ./build-android-common.sh
+ 
+ if [ -z ${version+x} ]; then
+@@ -115,6 +115,8 @@ function configure_make() {
+     if [ $the_rc -eq 0 ] ; then
+         make SHLIB_EXT='.so' install_sw >>"${OUTPUT_ROOT}/log/${ABI}.log" 2>&1
+         make install_ssldirs >>"${OUTPUT_ROOT}/log/${ABI}.log" 2>&1
++    else
++        log_error "make returned $the_rc"
+     fi
+ 
+     popd
diff --git a/prebuild/readme.md b/prebuild/readme.md
new file mode 100644
index 00000000..5841230b
--- /dev/null
+++ b/prebuild/readme.md
@@ -0,0 +1,66 @@
+# Mobile Azure EP pre-build
+
+Libraries that need to be manually prebuilt for the Azure operators on Android and iOS.
+There is no simple cmake setup that works, so we prebuild them as a one-off.
+
+## Requirements:
+- pkg-config
+- Android
+  - Android SDK installed with NDK 25 or later
+    - You can install a package, but that means you have to use `sudo` for all updates like installing an NDK
+      - https://stackoverflow.com/questions/34556884/how-to-install-android-sdk-on-ubuntu
+      - you still need to manually add the cmdline-tools to that package as well
+    - probably easier to create a per-user install using the command line tools
+  - Using command line tools
+    - Download the command line tools from https://developer.android.com/studio
+      - Download the 'Command line tools only' package and unzip it
+    - `mkdir ~/Android`
+    - `unzip commandlinetools-linux-9477386_latest.zip`
+    - `mkdir -p ~/Android/cmdline-tools/latest`
+    - `mv cmdline-tools/* ~/Android/cmdline-tools/latest`
+    - `export ANDROID_HOME=~/Android`
+    - Add these to PATH
+      - `~/Android/cmdline-tools/latest/bin`
+      - `~/Android/platform-tools/bin`
+    - `sdkmanager --list` to make sure the setup works
+    - Install platform-tools and the latest NDK
+      - `sdkmanager --install platform-tools`
+      - e.g. `sdkmanager --install "ndk;25.2.9519653"`
+
+  That should be enough to build.
+  e.g. `./build_lib.sh --android --android_api=24 --android_home=/home/me/Android --android_abi=x86_64 --android_ndk_path=/home/me/Android/ndk/25.2.9519653 --enable_cxx_tests`
+
+  See the Android documentation for installing a system image with `sdkmanager` and
+  creating an emulator with `avdmanager`.
+- iOS
+  - TBD
+
+## Android build
+  Export ANDROID_NDK_ROOT with the value set to the NDK path, as this is used by the build script
+  - e.g. `export ANDROID_NDK_ROOT=~/Android/ndk/25.2.9519653`
+
+  From this directory run `./build_curl_for_android.sh`
+  An architecture can optionally be specified as the first argument to limit the build to that architecture.
+  Otherwise all 4 architectures (arm, arm64, x86, x86_64) will be built.
+  e.g. if you just want to build locally for the emulator you can do `./build_curl_for_android.sh x86_64`
+
+## Android testing
+  Build with `--enable_cxx_tests`.
+  This should result in the 'bin' directory of the build output containing the two test executables.
+  Create/start an Android emulator.
+  Use `adb push` to copy the bin, lib and data directories from the build output to the /data/local/tmp directory
+  - `adb push build/Android/bin /data/local/tmp`
+    - repeat for 'lib' and 'data'
+  - copy the onnxruntime shared library to the lib dir (adjust version number as needed)
+    - adjust architecture as needed (most likely x86_64 for emulator and arm)
+    - `adb push build/Android/Debug/_deps/onnxruntime-src/jni/x86_64/libonnxruntime.so /data/local/tmp/lib`
+  - Connect to the emulator
+    - `adb shell`
+    - `cd /data/local/tmp`
+  - Add the path to the .so
+    - `export LD_LIBRARY_PATH=/data/local/tmp/lib:$LD_LIBRARY_PATH`
+  - Make the tests executable
+    - `chmod +x bin/ocos_test`
+    - `chmod +x bin/extensions_test`
+  - Run the tests from the `tmp` dir so the paths to `data` are as expected
+    - `./bin/ocos_test`
+    - `./bin/extensions_test`
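The push/run steps above are scriptable; a small Python sketch (assumes `adb` is on PATH and the build output layout described above):

```python
import subprocess


def adb(*args: str) -> None:
    subprocess.run(["adb", *args], check=True)


for d in ("bin", "lib", "data"):
    adb("push", f"build/Android/{d}", "/data/local/tmp")

adb("shell", "chmod +x /data/local/tmp/bin/ocos_test /data/local/tmp/bin/extensions_test")
for test in ("ocos_test", "extensions_test"):
    adb("shell", f"cd /data/local/tmp && LD_LIBRARY_PATH=/data/local/tmp/lib ./bin/{test}")
```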
diff --git a/shared/lib/ortcustomops.cc b/shared/lib/ortcustomops.cc
index e7f4d159..e18750cb 100644
--- a/shared/lib/ortcustomops.cc
+++ b/shared/lib/ortcustomops.cc
@@ -154,9 +154,6 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
 #endif
 #if defined(ENABLE_DR_LIBS)
       LoadCustomOpClasses_Audio,
-#endif
-#if defined(ENABLE_AZURE)
-      LoadCustomOpClasses_Azure,
 #endif
       LoadCustomOpClasses<>
   };
@@ -187,6 +184,10 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
     return status;
   }
 
+  //
+  // New custom ops should use the com.microsoft.extensions domain.
+  //
+
   // Create domain for ops using the new domain name.
   if (status = ortApi->CreateCustomOpDomain(c_ComMsExtOpDomain, &domain); status) {
     return status;
@@ -200,6 +201,9 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
 #endif
 #if defined(ENABLE_TOKENIZER)
       LoadCustomOpClasses_Tokenizer,
+#endif
+#if defined(ENABLE_AZURE)
+      LoadCustomOpClasses_Azure,
 #endif
       LoadCustomOpClasses<>
   };
diff --git a/test/data/azure/be-a-man-take-some-pepto-bismol-get-dressed-and-come-on-over-here.mp3 b/test/data/azure/be-a-man-take-some-pepto-bismol-get-dressed-and-come-on-over-here.mp3
new file mode 100644
index 00000000..d0c8a33a
Binary files /dev/null and b/test/data/azure/be-a-man-take-some-pepto-bismol-get-dressed-and-come-on-over-here.mp3 differ
diff --git a/test/data/azure/create_basic_models.py b/test/data/azure/create_basic_models.py
new file mode 100644
index 00000000..820be84e
--- /dev/null
+++ b/test/data/azure/create_basic_models.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python3
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import onnx
+from onnx import helper, TensorProto
+
+
+def create_audio_model():
+    auth_token = helper.make_tensor_value_info('auth_token', TensorProto.STRING, [1])
+    model = helper.make_tensor_value_info('model', TensorProto.STRING, [1])
+    response_format = helper.make_tensor_value_info('response_format', TensorProto.STRING, [-1])
+    file = helper.make_tensor_value_info('file', TensorProto.UINT8, [-1])
+
+    transcriptions = helper.make_tensor_value_info('transcriptions', TensorProto.STRING, [-1])
+
+    invoker = helper.make_node('OpenAIAudioToText',
+                               ['auth_token', 'model', 'response_format', 'file'],
+                               ['transcriptions'],
+                               domain='com.microsoft.extensions',
+                               name='audio_invoker',
+                               model_uri='https://api.openai.com/v1/audio/transcriptions',
+                               audio_format='wav',
+                               verbose=False)
+
+    graph = helper.make_graph([invoker], 'graph', [auth_token, model, response_format, file], [transcriptions])
+    audio_model = helper.make_model(graph,
+                                    opset_imports=[helper.make_operatorsetid('com.microsoft.extensions', 1)])
+
+    onnx.save(audio_model, 'openai_audio.onnx')
+
+
+def create_chat_model():
+    auth_token = helper.make_tensor_value_info('auth_token', TensorProto.STRING, [-1])
+    chat = helper.make_tensor_value_info('chat', TensorProto.STRING, [-1])
+    response = helper.make_tensor_value_info('response', TensorProto.STRING, [-1])
+
+    invoker = helper.make_node('AzureTextToText', ['auth_token', 'chat'], ['response'],
+                               domain='com.microsoft.extensions',
+                               name='chat_invoker',
+                               model_uri='https://api.openai.com/v1/chat/completions',
+                               verbose=False)
+
+    graph = helper.make_graph([invoker], 'graph', [auth_token, chat], [response])
+    chat_model = helper.make_model(graph,
+                                   opset_imports=[helper.make_operatorsetid('com.microsoft.extensions', 1)])
+
+    onnx.save(chat_model, 'openai_chat.onnx')
+
+
+create_audio_model()
+create_chat_model()
diff --git a/test/data/azure/create_openai_whisper_transcriptions.py b/test/data/azure/create_openai_whisper_transcriptions.py
new file mode 100644
index 00000000..7dae8fd3
--- /dev/null
+++ b/test/data/azure/create_openai_whisper_transcriptions.py
@@ -0,0 +1,75 @@
+#!/usr/bin/env python3
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+from onnx import helper, numpy_helper, TensorProto
+
+import onnx
+import numpy as np
+import sys
+
+
+def order_repeated_field(repeated_proto, key_name, order):
+    order = list(order)
+    repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name)))
+
+
+def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs):
+    node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs)
+    if doc_string == '':
+        node.doc_string = ''
+    order_repeated_field(node.attribute, 'name', kwargs.keys())
+    return node
+
+
+def make_graph(*args, doc_string=None, **kwargs):
+    graph = helper.make_graph(*args, doc_string=doc_string, **kwargs)
+    if doc_string == '':
+        graph.doc_string = ''
+    return graph
+
+
+# This creates a model that allows the prompt and filename to be optionally provided as inputs.
+# The filename can be specified to indicate a different audio type to the default value in the audio_format attribute.
+model = helper.make_model(
+    opset_imports=[helper.make_operatorsetid('com.microsoft.extensions', 1)],
+    graph=make_graph(
+        name='OpenAIWhisperTranscribe',
+        initializer=[
+            # add default values in the initializers to make the model inputs optional
+            helper.make_tensor('transcribe0/filename', TensorProto.STRING, [1], [b""]),
+            helper.make_tensor('transcribe0/prompt', TensorProto.STRING, [1], [b""])
+        ],
+        inputs=[
+            helper.make_tensor_value_info('auth_token', TensorProto.STRING, shape=[1]),
+            helper.make_tensor_value_info('transcribe0/file', TensorProto.UINT8, shape=["bytes"]),
+            helper.make_tensor_value_info('transcribe0/filename', TensorProto.STRING, shape=["bytes"]),  # optional
+            helper.make_tensor_value_info('transcribe0/prompt', TensorProto.STRING, shape=["bytes"]),  # optional
+        ],
+        outputs=[helper.make_tensor_value_info('transcription', TensorProto.STRING, shape=[1])],
+        nodes=[
+            make_node(
+                'OpenAIAudioToText',
+                # additional optional request inputs that could be added:
+                #   response_format, temperature, language
+                # Using a prefix for input names allows the model to have multiple nodes calling cloud endpoints.
+                # auth_token does not need a prefix unless different auth tokens are used for different nodes.
+                inputs=['auth_token', 'transcribe0/file', 'transcribe0/filename', 'transcribe0/prompt'],
+                outputs=['transcription'],
+                name='OpenAIAudioToText0',
+                domain='com.microsoft.extensions',
+                audio_format='wav',  # default audio type if filename is not specified.
+                model_uri='https://api.openai.com/v1/audio/transcriptions',
+                model_name='whisper-1',
+                verbose=0,
+            ),
+        ],
+    ),
+)
+
+if __name__ == '__main__':
+    out_path = "openai_whisper_transcriptions.onnx"
+    if len(sys.argv) == 2:
+        out_path = sys.argv[1]
+
+    onnx.save(model, out_path)
diff --git a/test/data/azure/openai_audio.onnx b/test/data/azure/openai_audio.onnx
index fa5bf70d..863a6ca4 100644
Binary files a/test/data/azure/openai_audio.onnx and b/test/data/azure/openai_audio.onnx differ
diff --git a/test/data/azure/openai_chat.onnx b/test/data/azure/openai_chat.onnx
index f5e1ef97..d0a9954d 100644
Binary files a/test/data/azure/openai_chat.onnx and b/test/data/azure/openai_chat.onnx differ
diff --git a/test/data/azure/openai_embedding.onnx b/test/data/azure/openai_embedding.onnx
deleted file mode 100644
index ff0ef01d..00000000
--- a/test/data/azure/openai_embedding.onnx
+++ /dev/null
@@ -1,17 +0,0 @@
-:…
-™
-
-auth_token
-text embeddingembedding_invoker"AzureTextInvoker*4
- model_uri"$https://api.openai.com/v1/embeddings *
-verbose :ai.onnx.contribgraphZ!
-
-auth_token
-
- ˙˙˙˙˙˙˙˙˙Z
-text
-
- ˙˙˙˙˙˙˙˙˙b
- embedding
-
- ˙˙˙˙˙˙˙˙˙B
\ No newline at end of file
diff --git a/test/data/azure/openai_whisper_transcriptions.onnx b/test/data/azure/openai_whisper_transcriptions.onnx
new file mode 100644
index 00000000..bf276530
Binary files /dev/null and b/test/data/azure/openai_whisper_transcriptions.onnx differ
diff --git a/test/data/azure/self-destruct-button.wav b/test/data/azure/self-destruct-button.wav
new file mode 100644
index 00000000..f85a654f
Binary files /dev/null and b/test/data/azure/self-destruct-button.wav differ
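Driving the generated model from Python follows the same pattern as the tests below; a minimal sketch (assumes onnxruntime and onnxruntime-extensions are installed, the model and audio file are in the working directory, and OPENAI_AUTH_TOKEN is set):

```python
import os

import numpy as np
import onnxruntime as ort
from onnxruntime_extensions import get_library_path

opts = ort.SessionOptions()
opts.register_custom_ops_library(get_library_path())  # registers the com.microsoft.extensions ops
sess = ort.InferenceSession("openai_whisper_transcriptions.onnx", opts,
                            providers=["CPUExecutionProvider"])

with open("self-destruct-button.wav", "rb") as f:
    audio_blob = np.asarray(list(f.read()), dtype=np.uint8)

# 'transcribe0/filename' and 'transcribe0/prompt' can be omitted thanks to the initializers
outputs = sess.run(None, {
    "auth_token": np.array([os.environ["OPENAI_AUTH_TOKEN"]]),
    "transcribe0/file": audio_blob,
})
print(outputs[0])
```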
diff --git a/test/shared_test/test_ortops_azure.cc b/test/shared_test/test_ortops_azure.cc
new file mode 100644
index 00000000..8fa5eecf
--- /dev/null
+++ b/test/shared_test/test_ortops_azure.cc
@@ -0,0 +1,106 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#ifdef ENABLE_AZURE
+#include <filesystem>
+
+#include "gtest/gtest.h"
+
+#include "ocos.h"
+#include "narrow.h"
+#include "test_kernel.hpp"
+#include "utils.hpp"
+
+using namespace ort_extensions;
+using namespace ort_extensions::test;
+
+// Test custom op with OpenAIAudioToTextInvoker calling Whisper.
+// Default input format. No prompt.
+TEST(AzureOps, OpenAIWhisper_basic) {
+  const char* auth_token = std::getenv("OPENAI_AUTH_TOKEN");
+  if (auth_token == nullptr) {
+    GTEST_SKIP() << "OPENAI_AUTH_TOKEN environment variable was not set.";
+  }
+
+  auto data_dir = std::filesystem::current_path() / "data" / "azure";
+  auto model_path = data_dir / "openai_whisper_transcriptions.onnx";
+  auto audio_path = data_dir / "self-destruct-button.wav";
+  std::vector<uint8_t> audio_data = LoadBytesFromFile(audio_path);
+
+  auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
+
+  std::vector<TestValue> inputs{TestValue("auth_token", {std::string(auth_token)}, {1}),
+                                TestValue("transcribe0/file", audio_data, {narrow<int64_t>(audio_data.size())})};
+
+  // punctuation can differ between calls to OpenAI Whisper. sometimes there's a comma after 'button' and sometimes
+  // a full stop. use a custom output validator that looks for substrings in the output that aren't affected by this.
+  std::vector<std::string> expected_output{"Thank you for pressing the self-destruct button",
+                                           "ship will self-destruct in three minutes"};
+
+  // dims are set to '{1}' as we expect one string output. the expected_output is the collection of substrings to look
+  // for in the single output
+  std::vector<TestValue> outputs{TestValue("transcription", expected_output, {1})};
+
+  OutputValidator find_strings_in_output =
+      [](size_t output_idx, Ort::Value& actual, TestValue expected) {
+        std::vector<std::string> output_string;
+        GetTensorMutableDataString(Ort::GetApi(), actual, output_string);
+
+        ASSERT_EQ(output_string.size(), 1) << "Expected the Whisper response to be a single string with json";
+
+        for (auto& expected_substring : expected.values_string) {
+          if (output_string[0].find(expected_substring) == std::string::npos) {
+            FAIL() << "'" << expected_substring << "' was not found in output " << output_string[0];
+          }
+        }
+      };
+
+  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath(), find_strings_in_output);
+}
+
+// test calling Whisper with a filename to provide mp3 instead of the default wav, and the optional prompt
+TEST(AzureOps, OpenAIWhisper_Prompt_CustomFormat) {
+  const char* auth_token = std::getenv("OPENAI_AUTH_TOKEN");
+  if (auth_token == nullptr) {
+    GTEST_SKIP() << "OPENAI_AUTH_TOKEN environment variable was not set.";
+  }
+
+  std::string ort_version{OrtGetApiBase()->GetVersionString()};
+
+  auto data_dir = std::filesystem::current_path() / "data" / "azure";
+  auto model_path = data_dir / "openai_whisper_transcriptions.onnx";
+  auto audio_path = data_dir / "be-a-man-take-some-pepto-bismol-get-dressed-and-come-on-over-here.mp3";
+  std::vector<uint8_t> audio_data = LoadBytesFromFile(audio_path);
+
+  auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
+
+  // provide filename with 'mp3' extension to indicate the audio format. doesn't need to be the 'real' filename
+  std::vector<TestValue> inputs{TestValue("auth_token", {std::string(auth_token)}, {1}),
+                                TestValue("transcribe0/file", audio_data, {narrow<int64_t>(audio_data.size())}),
+                                TestValue("transcribe0/filename", {std::string("audio.mp3")}, {1})};
+
+  std::vector<std::string> expected_output = {"Take some Pepto-Bismol, get dressed, and come on over here."};
+  std::vector<TestValue> outputs{TestValue("transcription", expected_output, {1})};
+
+  OutputValidator find_strings_in_output =
+      [](size_t output_idx, Ort::Value& actual, TestValue expected) {
+        std::vector<std::string> output_string;
+        GetTensorMutableDataString(Ort::GetApi(), actual, output_string);
+
+        ASSERT_EQ(output_string.size(), 1) << "Expected the Whisper response to be a single string with json";
+        const auto& expected_substring = expected.values_string[0];
+        if (output_string[0].find(expected_substring) == std::string::npos) {
+          FAIL() << "'" << expected_substring << "' was not found in output " << output_string[0];
+        }
+      };
+
+  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath(), find_strings_in_output);
+
+  // use the optional 'prompt' input to mis-spell Pepto-Bismol in the response
+  std::string prompt = "Peptoe-Bismole";
+  inputs.push_back(TestValue("transcribe0/prompt", {prompt}, {1}));
+  outputs[0].values_string[0] = "Take some Peptoe-Bismole, get dressed, and come on over here.";
+  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath(), find_strings_in_output);
+}
+
+#endif
diff --git a/test/shared_test/test_ortops_vision.cc b/test/shared_test/test_ortops_vision.cc
index aa353c38..72339fef 100644
--- a/test/shared_test/test_ortops_vision.cc
+++ b/test/shared_test/test_ortops_vision.cc
@@ -12,7 +12,7 @@
 
 #include "ocos.h"
 #include "test_kernel.hpp"
-#include "test_utils.hpp"
+#include "utils.hpp"
 
 using namespace ort_extensions::test;
 
diff --git a/test/shared_test/utils.cc b/test/shared_test/utils.cc
new file mode 100644
index 00000000..a1147ba8
--- /dev/null
+++ b/test/shared_test/utils.cc
@@ -0,0 +1,22 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "utils.hpp"
+#include <fstream>
+
+namespace ort_extensions {
+namespace test {
+std::vector<uint8_t> LoadBytesFromFile(const std::filesystem::path& filename) {
+  using namespace std;
+  ifstream ifs(filename, ios::binary | ios::ate);
+  ifstream::pos_type pos = ifs.tellg();
+
+  std::vector<uint8_t> input_bytes(pos);
+  ifs.seekg(0, ios::beg);
+  // we want uint8_t values, so reinterpret_cast avoids reading chars and copying to uint8_t after.
+  ifs.read(reinterpret_cast<char*>(input_bytes.data()), pos);
+
+  return input_bytes;
+}
+}  // namespace test
+}  // namespace ort_extensions
diff --git a/test/shared_test/utils.hpp b/test/shared_test/utils.hpp
new file mode 100644
index 00000000..74df0462
--- /dev/null
+++ b/test/shared_test/utils.hpp
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <filesystem>
+#include <vector>
+
+namespace ort_extensions {
+namespace test {
+std::vector<uint8_t> LoadBytesFromFile(const std::filesystem::path& filename);
+
+}  // namespace test
+}  // namespace ort_extensions
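The Python tests load audio the same way LoadBytesFromFile does on the C++ side; for parity, a one-liner sketch (hypothetical helper name):

```python
import numpy as np


def load_bytes_from_file(path: str) -> np.ndarray:
    """Python counterpart of ort_extensions::test::LoadBytesFromFile."""
    with open(path, "rb") as f:
        return np.asarray(list(f.read()), dtype=np.uint8)
```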
diff --git a/test/test_azure_ops.py b/test/test_azure_ops.py
index 992b0ad4..05295201 100644
--- a/test/test_azure_ops.py
+++ b/test/test_azure_ops.py
@@ -12,6 +12,7 @@ script_dir = os.path.dirname(os.path.realpath(__file__))
 ort_ext_root = os.path.abspath(os.path.join(script_dir, ".."))
 test_data_dir = os.path.join(ort_ext_root, "test", "data", "azure")
 
+
 class TestAzureOps(unittest.TestCase):
 
     def __init__(self, config):
@@ -21,7 +22,7 @@ class TestAzureOps(unittest.TestCase):
         self.__opt = SessionOptions()
         self.__opt.register_custom_ops_library(get_library_path())
 
-    def test_addf(self):
+    def test_add_f(self):
         if self.__enabled:
             sess = InferenceSession(os.path.join(test_data_dir, "triton_addf.onnx"),
                                     self.__opt, providers=["CPUExecutionProvider"])
@@ -36,7 +37,7 @@ class TestAzureOps(unittest.TestCase):
             out = sess.run(None, ort_inputs)[0]
             self.assertTrue(np.allclose(out, [5,5,5,5]))
 
-    def testAddf8(self):
+    def test_add_f8(self):
         if self.__enabled:
             opt = SessionOptions()
             opt.register_custom_ops_library(get_library_path())
@@ -53,7 +54,7 @@ class TestAzureOps(unittest.TestCase):
             out = sess.run(None, ort_inputs)[0]
             self.assertTrue(np.allclose(out, [5,5,5,5]))
 
-    def testAddi4(self):
+    def test_add_i4(self):
         if self.__enabled:
             sess = InferenceSession(os.path.join(test_data_dir, "triton_addi4.onnx"),
                                     self.__opt, providers=["CPUExecutionProvider"])
@@ -68,7 +69,7 @@ class TestAzureOps(unittest.TestCase):
             out = sess.run(None, ort_inputs)[0]
             self.assertTrue(np.allclose(out, [5,5,5,5]))
 
-    def testAnd(self):
+    def test_and(self):
         if self.__enabled:
             sess = InferenceSession(os.path.join(test_data_dir, "triton_and.onnx"),
                                     self.__opt, providers=["CPUExecutionProvider"])
@@ -83,7 +84,7 @@ class TestAzureOps(unittest.TestCase):
             out = sess.run(None, ort_inputs)[0]
             self.assertTrue(np.allclose(out, [True, False]))
 
-    def testStr(self):
+    def test_str(self):
         if self.__enabled:
             sess = InferenceSession(os.path.join(test_data_dir, "triton_str.onnx"),
                                     self.__opt, providers=["CPUExecutionProvider"])
@@ -98,7 +99,7 @@ class TestAzureOps(unittest.TestCase):
             self.assertEqual(outs[0], ['this is the input'])
             self.assertEqual(outs[1], ['this is the input'])
 
-    def testOpenAiAudio(self):
+    def test_open_ai_audio(self):
         if self.__enabled:
             sess = InferenceSession(os.path.join(test_data_dir, "openai_audio.onnx"),
                                     self.__opt, providers=["CPUExecutionProvider"])
@@ -110,14 +111,14 @@ class TestAzureOps(unittest.TestCase):
                 audio_blob = np.asarray(list(_f.read()), dtype=np.uint8)
                 ort_inputs = {
                     "auth_token": auth_token,
-                    "model": model,
+                    "model_name": model,
                     "response_format": response_format,
-                    "file": audio_blob
+                    "file": audio_blob,
                 }
             out = sess.run(None, ort_inputs)[0]
             self.assertEqual(out, ['This is a test recording to test the Whisper model.\n'])
 
-    def testOpenAiChat(self):
+    def test_open_ai_chat(self):
         if self.__enabled:
             sess = InferenceSession(os.path.join(test_data_dir, "openai_chat.onnx"),
                                     self.__opt, providers=["CPUExecutionProvider"])
@@ -130,23 +131,6 @@ class TestAzureOps(unittest.TestCase):
             out = sess.run(None, ort_inputs)[0]
             self.assertTrue('assist' in out[0])
 
-    def testOpenAiEmb(self):
-        if self.__enabled:
-            opt = SessionOptions()
-            opt.register_custom_ops_library(get_library_path())
-            sess = InferenceSession(os.path.join(test_data_dir, "openai_embedding.onnx"),
-                                    opt, providers=["CPUExecutionProvider"])
-            auth_token = np.array([os.getenv('EMB', '')])
-            text = np.array(['{\"input\": \"The food was delicious and the waiter...\", \"model\": \"text-embedding-ada-002\"}'])
-
-            ort_inputs = {
-                "auth_token": auth_token,
-                "text": text,
-            }
-
-            out = sess.run(None, ort_inputs)[0]
-            self.assertTrue('text-embedding-ada' in out[0])
-
 
 if __name__ == '__main__':
     unittest.main()
\ No newline at end of file
diff --git a/tools/build.py b/tools/build.py
index c6ab0f44..8b955b81 100755
--- a/tools/build.py
+++ b/tools/build.py
@@ -154,7 +154,7 @@ def _parse_arguments():
     # WebAssembly options
     parser.add_argument("--wasm", action="store_true", help="Build for WebAssembly")
     parser.add_argument("--emsdk_path", type=Path,
-                        help="Specify path to emscripten SDK. Setup manually with: "
+                        help="Specify path to emscripten SDK. Setup manually with: "
                              " git clone https://github.com/emscripten-core/emsdk")
     parser.add_argument("--emsdk_version", default="3.1.26", help="Specify version of emsdk")
 
@@ -404,11 +404,11 @@ def _generate_build_tree(cmake_path: Path,
         _run_subprocess(cmake_args + [f"-DCMAKE_BUILD_TYPE={config}"], cwd=config_build_dir)
 
 
-def clean_targets(cmake_path, build_dir: Path, configs: Set[str]):
+def clean_targets(cmake_path: Path, build_dir: Path, configs: Set[str]):
     for config in configs:
         log.info("Cleaning targets for %s configuration", config)
         build_dir2 = _get_build_config_dir(build_dir, config)
-        cmd_args = [cmake_path, "--build", build_dir2, "--config", config, "--target", "clean"]
+        cmd_args = [str(cmake_path), "--build", str(build_dir2), "--config", config, "--target", "clean"]
 
         _run_subprocess(cmd_args)
 
@@ -564,6 +564,10 @@ def main():
     cmake_path = _resolve_executable_path(
         args.cmake_path, resolution_failure_allowed=(not (args.update or args.clean or args.build)))
+
+    if not cmake_path:
+        raise UsageError("Unable to find CMake executable. Please specify --cmake_path.")
+
     build_dir = args.build_dir
 
     if args.update or args.build:
diff --git a/tools/install_deps.bat b/tools/install_deps.bat
index 82730e3a..0cf06401 100644
--- a/tools/install_deps.bat
+++ b/tools/install_deps.bat
@@ -13,4 +13,4 @@ if "%1" == "install" (
     del "%ProgramFiles%\Miniconda3\python3.exe"
   )
 )
-)
\ No newline at end of file
+)
diff --git a/tools/install_deps.sh b/tools/install_deps.sh
index fa9d65fa..57ea1f1d 100755
--- a/tools/install_deps.sh
+++ b/tools/install_deps.sh
@@ -3,10 +3,10 @@
 if [[ "$OCOS_ENABLE_AZURE" == "1" ]]
 then
   if [[ "$1" == "many64" ]]; then
-    yum -y install openssl openssl-devel wget && wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz && tar zxvf v1.1.0.tar.gz && cd rapidjson-1.1.0 && mkdir build && cd build && cmake .. && cmake --install . && cd ../.. && git clone https://github.com/triton-inference-server/client.git --branch r23.05 ~/client && ln -s ~/client/src/c++/library/libhttpclient.ldscript /usr/lib64/libhttpclient.ldscript
+    yum -y install openssl-devel
   elif [[ "$1" == "many86" ]]; then
-    yum -y install openssl openssl-devel wget && wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz && tar zxvf v1.1.0.tar.gz && cd rapidjson-1.1.0 && mkdir build && cd build && cmake .. && cmake --install . && cd ../.. && git clone https://github.com/triton-inference-server/client.git --branch r23.05 ~/client && ln -s ~/client/src/c++/library/libhttpclient.ldscript /opt/rh/devtoolset-10/root/usr/lib/libhttpclient.ldscript
+    yum -y install openssl-devel
   else
     # for musllinux
-    apk add openssl-dev wget && wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz && tar zxvf v1.1.0.tar.gz && cd rapidjson-1.1.0 && mkdir build && cd build && cmake .. && cmake --install . && cd ../.. && git clone https://github.com/triton-inference-server/client.git --branch r23.05 ~/client && ln -s ~/client/src/c++/library/libhttpclient.ldscript /usr/lib/libhttpclient.ldscript
+    apk add openssl-dev
   fi
 fi