Implement azure invokers (#487)
* Implement azure invokers (#486) * draft azure ops * migrate triton client * AzureAudioInvoker works * triton client builds * triton invoker works * limit version * restore setup.py * limit ort version * upgrade version * pip install cmake * add ut * promote ort header version to 1.15.1 * register as cpu op * limit triton invoker to 1.14 and newer * remove test * install rapidjson * install dep * sudo install * install version script * print err msg * fix pipeline * disable from web assembly * install cmake * Fix pipelines (#479) * 1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17 * 18 * 19 * 20 * 21 * 22 * 23 * 24 * 25 * 26 * 27 * 28 * 29 * 30 * 31 * 32 * 33 * 34 * 35 * 36 * 37 * 38 * 39 * 40 * 41 * 42 * 43 * 44 * 45 * 46 * 47 * 47 * 48 * 49 * 50 * 51 * 52 * 53 * 54 * 55 * 56 * 57 * 58 * 59 * 60 * 61 * 62 * 62 * 63 * 64 * 65 * 66 * 67 * 68 * 69 * 70 * 71 * 72 * 73 * 74 * 75 * 76 * 77 * 78 * 79: * 80: --------- Co-authored-by: Randy Shuai <rashuai@microsoft.com> * fix pipelines (#481) * 1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17 * 18 * 19 * 20 * 21 * 22 * 23 * 24 * 25 * 26 * 27 * 28 * 29 * 30 * 31 * 32 * 33 * 34 * 35 * 36 * 37 * 38 * 39 * 40 * 41 * 42 * 43 * 44 * 45 * 46 * 47 * 47 * 48 * 49 * 50 * 51 * 52 * 53 * 54 * 55 * 56 * 57 * 58 * 59 * 60 * 61 * 62 * 62 * 63 * 64 * 65 * 66 * 67 * 68 * 69 * 70 * 71 * 72 * 73 * 74 * 75 * 76 * 77 * 78 * 79: * 80: * 81 * 82 * 83 * 84 * 85 * 86 --------- Co-authored-by: Randy Shuai <rashuai@microsoft.com> * test as cpu op * add ut * add ut * move cond * tune ut * tune pipeline * promote to ort 141 * reset header version * restore cmake --------- Co-authored-by: Randy Shuai <rashuai@microsoft.com> * trim changes * revert req txt --------- Co-authored-by: Randy Shuai <rashuai@microsoft.com>
This commit is contained in:
Родитель
b49c0231ab
Коммит
27132ced71
|
@ -58,7 +58,20 @@ jobs:
|
|||
displayName: Unpack ONNXRuntime package.
|
||||
|
||||
- script: |
|
||||
CPU_NUMBER=2 sh ./build.sh -DOCOS_ENABLE_CTEST=ON -DONNXRUNTIME_PKG_DIR=$(Build.SourcesDirectory)/onnxruntime-linux-x64-$(ort.version)
|
||||
wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz
|
||||
tar zxvf v1.1.0.tar.gz
|
||||
cd rapidjson-1.1.0
|
||||
mkdir build
|
||||
cd build
|
||||
cmake ..
|
||||
sudo cmake --install .
|
||||
cd ../..
|
||||
git clone https://github.com/triton-inference-server/client.git --branch r23.05 ~/client
|
||||
sudo ln -s ~/client/src/c++/library/libhttpclient.ldscript /lib/libhttpclient.ldscript
|
||||
displayName: install deps for azure invokers
|
||||
|
||||
- script: |
|
||||
CPU_NUMBER=2 sh ./build.sh -DOCOS_ENABLE_CTEST=ON -DONNXRUNTIME_PKG_DIR=$(Build.SourcesDirectory)/onnxruntime-linux-x64-$(ort.version) -DOCOS_ENABLE_AZURE=ON
|
||||
displayName: build the customop library with onnxruntime
|
||||
|
||||
- script: |
|
||||
|
@ -263,7 +276,7 @@ jobs:
|
|||
|
||||
- script: |
|
||||
call $(vsdevcmd)
|
||||
call .\build.bat -DOCOS_ENABLE_CTEST=ON -DONNXRUNTIME_PKG_DIR=.\onnxruntime-win-x64-$(ort.version)
|
||||
call .\build.bat -DOCOS_ENABLE_CTEST=ON -DONNXRUNTIME_PKG_DIR=.\onnxruntime-win-x64-$(ort.version) -DOCOS_ENABLE_AZURE=ON
|
||||
displayName: build the customop library with onnxruntime
|
||||
|
||||
- script: |
|
||||
|
|
|
@ -58,6 +58,7 @@ option(OCOS_ENABLE_OPENCV_CODECS "Enable cv2 and vision operators that require o
|
|||
option(OCOS_ENABLE_CV2 "Enable the operators in `operators/cv2`" ON)
|
||||
option(OCOS_ENABLE_VISION "Enable the operators in `operators/vision`" ON)
|
||||
option(OCOS_ENABLE_AUDIO "Enable the operators for audio processing" ON)
|
||||
option(OCOS_ENABLE_AZURE "Enable the operators for azure execution provider" OFF)
|
||||
|
||||
option(OCOS_ENABLE_STATIC_LIB "Enable generating static library" OFF)
|
||||
option(OCOS_ENABLE_SELECTED_OPLIST "Enable including the selected_ops tool file" OFF)
|
||||
|
@ -79,6 +80,7 @@ function(disable_all_operators)
|
|||
set(OCOS_ENABLE_OPENCV_CODECS OFF CACHE INTERNAL "" FORCE)
|
||||
set(OCOS_ENABLE_CV2 OFF CACHE INTERNAL "" FORCE)
|
||||
set(OCOS_ENABLE_VISION OFF CACHE INTERNAL "" FORCE)
|
||||
set(OCOS_ENABLE_AZURE OFF CACHE INTERNAL "" FORCE)
|
||||
endfunction()
|
||||
|
||||
if (CMAKE_GENERATOR_PLATFORM)
|
||||
|
@ -376,6 +378,19 @@ if(OCOS_ENABLE_BLINGFIRE)
|
|||
list(APPEND TARGET_SRC ${blingfire_TARGET_SRC})
|
||||
endif()
|
||||
|
||||
if(OCOS_ENABLE_AZURE)
|
||||
|
||||
if ($ENV{TEST_AZURE_INVOKERS_AS_CPU_OP} MATCHES "ON")
|
||||
message(STATUS "Azure inovkers will be testable as cpu ops")
|
||||
add_compile_definitions(TEST_AZURE_INVOKERS_AS_CPU_OP)
|
||||
endif()
|
||||
|
||||
# Azure endpoint invokers
|
||||
include(triton)
|
||||
file(GLOB TARGET_SRC_AZURE "operators/azure/*.cc" "operators/azure/*.h*")
|
||||
list(APPEND TARGET_SRC ${TARGET_SRC_AZURE})
|
||||
endif()
|
||||
|
||||
if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
|
||||
message(STATUS "Fetch json")
|
||||
include(json)
|
||||
|
@ -478,6 +493,10 @@ if(OCOS_ENABLE_VISION)
|
|||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_VISION)
|
||||
endif()
|
||||
|
||||
if(OCOS_ENABLE_AZURE)
|
||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_AZURE)
|
||||
endif()
|
||||
|
||||
if(OCOS_ENABLE_GPT2_TOKENIZER)
|
||||
# GPT2
|
||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_GPT2_TOKENIZER)
|
||||
|
@ -693,3 +712,30 @@ if(OCOS_ENABLE_CTEST)
|
|||
add_test(NAME extensions_test COMMAND $<TARGET_FILE:extensions_test>)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(OCOS_ENABLE_AZURE)
|
||||
|
||||
add_dependencies(ocos_operators triton)
|
||||
target_include_directories(ocos_operators PUBLIC ${TRITON_BIN}/include ${TRITON_THIRD_PARTY}/curl/include)
|
||||
target_link_directories(ocos_operators PUBLIC ${TRITON_BIN}/lib ${TRITON_BIN}/lib64 ${TRITON_THIRD_PARTY}/curl/lib ${TRITON_THIRD_PARTY}/curl/lib64)
|
||||
|
||||
if (ocos_target_platform STREQUAL "AMD64")
|
||||
set(vcpkg_target_platform "x86")
|
||||
else()
|
||||
set(vcpkg_target_platform ${ocos_target_platform})
|
||||
endif()
|
||||
|
||||
if (WIN32)
|
||||
|
||||
target_link_directories(ocos_operators PUBLIC ${VCPKG_SRC}/installed/${vcpkg_target_platform}-windows-static/lib)
|
||||
target_link_libraries(ocos_operators PUBLIC libcurl httpclient_static ws2_32 crypt32 Wldap32)
|
||||
|
||||
else()
|
||||
|
||||
find_package(ZLIB REQUIRED)
|
||||
find_package(OpenSSL REQUIRED)
|
||||
target_link_libraries(ocos_operators PUBLIC httpclient_static curl ZLIB::ZLIB OpenSSL::Crypto OpenSSL::SSL)
|
||||
|
||||
endif() #if (WIN32)
|
||||
|
||||
endif()
|
|
@ -0,0 +1,116 @@
|
|||
include(ExternalProject)
|
||||
|
||||
if (WIN32)
|
||||
|
||||
if (ocos_target_platform STREQUAL "AMD64")
|
||||
set(vcpkg_target_platform "x86")
|
||||
else()
|
||||
set(vcpkg_target_platform ${ocos_target_platform})
|
||||
endif()
|
||||
|
||||
ExternalProject_Add(vcpkg
|
||||
GIT_REPOSITORY https://github.com/microsoft/vcpkg.git
|
||||
GIT_TAG 2023.06.20
|
||||
PREFIX vcpkg
|
||||
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src
|
||||
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-build
|
||||
CONFIGURE_COMMAND ""
|
||||
INSTALL_COMMAND ""
|
||||
UPDATE_COMMAND ""
|
||||
BUILD_COMMAND "<SOURCE_DIR>/bootstrap-vcpkg.bat")
|
||||
|
||||
set(VCPKG_SRC ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src)
|
||||
set(ENV{VCPKG_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src)
|
||||
|
||||
message(STATUS "VCPKG_SRC: " ${VCPKG_SRC})
|
||||
message(STATUS "VCPKG_ROOT: " $ENV{VCPKG_ROOT})
|
||||
|
||||
add_custom_command(
|
||||
COMMAND ${VCPKG_SRC}/vcpkg integrate install
|
||||
COMMAND ${CMAKE_COMMAND} -E touch vcpkg_integrate.stamp
|
||||
OUTPUT vcpkg_integrate.stamp
|
||||
DEPENDS vcpkg
|
||||
)
|
||||
|
||||
add_custom_target(vcpkg_integrate ALL DEPENDS vcpkg_integrate.stamp)
|
||||
set(VCPKG_DEPENDENCIES "vcpkg_integrate")
|
||||
|
||||
function(vcpkg_install PACKAGE_NAME)
|
||||
add_custom_command(
|
||||
OUTPUT ${VCPKG_SRC}/packages/${PACKAGE_NAME}_${vcpkg_target_platform}-windows-static/BUILD_INFO
|
||||
COMMAND ${VCPKG_SRC}/vcpkg install ${PACKAGE_NAME}:${vcpkg_target_platform}-windows-static --vcpkg-root=${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src
|
||||
WORKING_DIRECTORY ${VCPKG_SRC}
|
||||
DEPENDS vcpkg_integrate)
|
||||
|
||||
add_custom_target(get${PACKAGE_NAME}
|
||||
ALL
|
||||
DEPENDS ${VCPKG_SRC}/packages/${PACKAGE_NAME}_${vcpkg_target_platform}-windows-static/BUILD_INFO)
|
||||
|
||||
list(APPEND VCPKG_DEPENDENCIES "get${PACKAGE_NAME}")
|
||||
set(VCPKG_DEPENDENCIES ${VCPKG_DEPENDENCIES} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
vcpkg_install(openssl)
|
||||
vcpkg_install(openssl-windows)
|
||||
vcpkg_install(rapidjson)
|
||||
vcpkg_install(re2)
|
||||
vcpkg_install(boost-interprocess)
|
||||
vcpkg_install(boost-stacktrace)
|
||||
vcpkg_install(pthread)
|
||||
vcpkg_install(b64)
|
||||
|
||||
add_dependencies(getb64 getpthread)
|
||||
add_dependencies(getpthread getboost-stacktrace)
|
||||
add_dependencies(getboost-stacktrace getboost-interprocess)
|
||||
add_dependencies(getboost-interprocess getre2)
|
||||
add_dependencies(getre2 getrapidjson)
|
||||
add_dependencies(getrapidjson getopenssl-windows)
|
||||
add_dependencies(getopenssl-windows getopenssl)
|
||||
|
||||
ExternalProject_Add(triton
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/client.git
|
||||
GIT_TAG r23.05
|
||||
PREFIX triton
|
||||
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src
|
||||
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build
|
||||
CMAKE_ARGS -DVCPKG_TARGET_TRIPLET=${vcpkg_target_platform}-windows-static -DCMAKE_TOOLCHAIN_FILE=${VCPKG_SRC}/scripts/buildsystems/vcpkg.cmake -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF
|
||||
INSTALL_COMMAND ""
|
||||
UPDATE_COMMAND "")
|
||||
|
||||
add_dependencies(triton ${VCPKG_DEPENDENCIES})
|
||||
|
||||
else()
|
||||
|
||||
if(DEFINED ENV{IS_DOCKER_BUILD})
|
||||
message(STATUS "IS_DOCKER_BUILD set")
|
||||
ExternalProject_Add(curl7
|
||||
PREFIX curl7
|
||||
GIT_REPOSITORY "https://github.com/curl/curl.git"
|
||||
GIT_TAG "curl-7_86_0"
|
||||
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/curl7-src
|
||||
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/curl7-build
|
||||
CMAKE_ARGS -DBUILD_TESTING=OFF -DBUILD_CURL_EXE=OFF -DBUILD_SHARED_LIBS=OFF -DCURL_STATICLIB=ON -DHTTP_ONLY=ON -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE})
|
||||
endif()
|
||||
|
||||
ExternalProject_Add(triton
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/client.git
|
||||
GIT_TAG r23.05
|
||||
PREFIX triton
|
||||
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src
|
||||
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build
|
||||
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF
|
||||
INSTALL_COMMAND ""
|
||||
UPDATE_COMMAND "")
|
||||
|
||||
if(DEFINED ENV{IS_DOCKER_BUILD})
|
||||
add_dependencies(triton curl7)
|
||||
endif()
|
||||
|
||||
endif() #if (WIN32)
|
||||
|
||||
ExternalProject_Get_Property(triton SOURCE_DIR)
|
||||
set(TRITON_SRC ${SOURCE_DIR})
|
||||
|
||||
ExternalProject_Get_Property(triton BINARY_DIR)
|
||||
set(TRITON_BIN ${BINARY_DIR}/binary)
|
||||
set(TRITON_THIRD_PARTY ${BINARY_DIR}/third-party)
|
|
@ -5,8 +5,6 @@
|
|||
#include "onnxruntime_customop.hpp"
|
||||
#include <optional>
|
||||
#include <numeric>
|
||||
// uplevel the version when supported ort version migrates to newer ones
|
||||
#define SUPPORT_ORT_API_VERSION_TO 13
|
||||
|
||||
namespace Ort {
|
||||
namespace Custom {
|
||||
|
@ -32,6 +30,9 @@ class TensorBase {
|
|||
ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
}
|
||||
ONNXTensorElementDataType Type() const {
|
||||
return type_;
|
||||
}
|
||||
int64_t NumberOfElement() const {
|
||||
if (shape_.has_value()) {
|
||||
return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
|
||||
|
@ -51,12 +52,16 @@ class TensorBase {
|
|||
return "empty";
|
||||
}
|
||||
}
|
||||
virtual const void* DataRaw() const = 0;
|
||||
virtual size_t SizeInBytes() const = 0;
|
||||
|
||||
protected:
|
||||
const OrtW::CustomOpApi& api_;
|
||||
OrtKernelContext& ctx_;
|
||||
size_t indice_;
|
||||
bool is_input_;
|
||||
std::optional<std::vector<int64_t>> shape_;
|
||||
ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
@ -93,12 +98,22 @@ class Tensor : public TensorBase {
|
|||
const_value_ = api_.KernelContext_GetInput(&ctx_, indice);
|
||||
auto* info = api_.GetTensorTypeAndShape(const_value_);
|
||||
shape_ = api_.GetTensorShape(info);
|
||||
type_ = api_.GetTensorElementType(info);
|
||||
api_.ReleaseTensorTypeAndShapeInfo(info);
|
||||
}
|
||||
}
|
||||
const TT* Data() const {
|
||||
return api_.GetTensorData<TT>(const_value_);
|
||||
}
|
||||
|
||||
const void* DataRaw() const override {
|
||||
return reinterpret_cast<const void*>(Data());
|
||||
}
|
||||
|
||||
size_t SizeInBytes() const override {
|
||||
return NumberOfElement() * sizeof(TT);
|
||||
}
|
||||
|
||||
TT* Allocate(const std::vector<int64_t>& shape) {
|
||||
if (!data_) {
|
||||
OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
|
||||
|
@ -148,6 +163,7 @@ class Tensor<std::string> : public TensorBase {
|
|||
auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
|
||||
auto* info = api_.GetTensorTypeAndShape(const_value);
|
||||
shape_ = api_.GetTensorShape(info);
|
||||
type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
api_.ReleaseTensorTypeAndShapeInfo(info);
|
||||
|
||||
size_t num_chars;
|
||||
|
@ -173,6 +189,18 @@ class Tensor<std::string> : public TensorBase {
|
|||
const strings& Data() const {
|
||||
return input_strings_;
|
||||
}
|
||||
const void* DataRaw() const override {
|
||||
if (input_strings_.size() != 1) {
|
||||
ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
return reinterpret_cast<const void*>(input_strings_[0].c_str());
|
||||
}
|
||||
size_t SizeInBytes() const override {
|
||||
if (input_strings_.size() != 1) {
|
||||
ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
return input_strings_[0].size();
|
||||
}
|
||||
void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
|
||||
std::vector<const char*> raw;
|
||||
for (const auto& s : ss) {
|
||||
|
@ -220,6 +248,7 @@ class Tensor<std::string_view> : public TensorBase {
|
|||
auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
|
||||
auto* info = api_.GetTensorTypeAndShape(const_value);
|
||||
shape_ = api_.GetTensorShape(info);
|
||||
type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
api_.ReleaseTensorTypeAndShapeInfo(info);
|
||||
|
||||
size_t num_chars;
|
||||
|
@ -251,6 +280,18 @@ class Tensor<std::string_view> : public TensorBase {
|
|||
const string_views& Data() const {
|
||||
return input_string_views_;
|
||||
}
|
||||
const void* DataRaw() const override {
|
||||
if (input_string_views_.size() != 1) {
|
||||
ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
return reinterpret_cast<const void*>(input_string_views_[0].data());
|
||||
}
|
||||
size_t SizeInBytes() const override {
|
||||
if (input_string_views_.size() != 1) {
|
||||
ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
return input_string_views_[0].size();
|
||||
}
|
||||
const Span<std::string_view>& AsSpan() {
|
||||
ORTX_CXX_API_THROW("span for TensorT of string view not implemented", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
|
@ -267,6 +308,103 @@ class Tensor<std::string_view> : public TensorBase {
|
|||
};
|
||||
|
||||
using TensorPtr = std::unique_ptr<Custom::TensorBase>;
|
||||
using TensorPtrs = std::vector<TensorPtr>;
|
||||
|
||||
// Represent variadic input or output
|
||||
struct Variadic : public TensorBase {
|
||||
Variadic(const OrtW::CustomOpApi& api,
|
||||
OrtKernelContext& ctx,
|
||||
size_t indice,
|
||||
bool is_input) : TensorBase(api,
|
||||
ctx,
|
||||
indice,
|
||||
is_input) {
|
||||
#if ORT_API_VERSION < 14
|
||||
ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION);
|
||||
#endif
|
||||
if (is_input) {
|
||||
auto input_count = api_.KernelContext_GetInputCount(&ctx_);
|
||||
for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
|
||||
auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input);
|
||||
auto* info = api_.GetTensorTypeAndShape(const_value);
|
||||
auto type = api_.GetTensorElementType(info);
|
||||
api_.ReleaseTensorTypeAndShapeInfo(info);
|
||||
TensorPtr tensor;
|
||||
switch (type) {
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
|
||||
tensor = std::make_unique<Custom::Tensor<bool>>(api, ctx, ith_input, true);
|
||||
break;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
|
||||
tensor = std::make_unique<Custom::Tensor<float>>(api, ctx, ith_input, true);
|
||||
break;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
|
||||
tensor = std::make_unique<Custom::Tensor<double>>(api, ctx, ith_input, true);
|
||||
break;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
|
||||
tensor = std::make_unique<Custom::Tensor<uint8_t>>(api, ctx, ith_input, true);
|
||||
break;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
|
||||
tensor = std::make_unique<Custom::Tensor<int8_t>>(api, ctx, ith_input, true);
|
||||
break;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
|
||||
tensor = std::make_unique<Custom::Tensor<uint16_t>>(api, ctx, ith_input, true);
|
||||
break;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
|
||||
tensor = std::make_unique<Custom::Tensor<int16_t>>(api, ctx, ith_input, true);
|
||||
break;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
|
||||
tensor = std::make_unique<Custom::Tensor<uint32_t>>(api, ctx, ith_input, true);
|
||||
break;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
|
||||
tensor = std::make_unique<Custom::Tensor<int32_t>>(api, ctx, ith_input, true);
|
||||
break;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
|
||||
tensor = std::make_unique<Custom::Tensor<uint64_t>>(api, ctx, ith_input, true);
|
||||
break;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
|
||||
tensor = std::make_unique<Custom::Tensor<int64_t>>(api, ctx, ith_input, true);
|
||||
break;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
|
||||
tensor = std::make_unique<Custom::Tensor<std::string_view>>(api, ctx, ith_input, true);
|
||||
break;
|
||||
default:
|
||||
ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
|
||||
break;
|
||||
}
|
||||
tensors_.emplace_back(tensor.release());
|
||||
} // for
|
||||
}
|
||||
}
|
||||
template<typename T>
|
||||
T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
|
||||
auto tensor = std::make_unique<Tensor<T>>(api_, ctx_, ith_output, false);
|
||||
auto raw_output = tensor.get()->Allocate(shape);
|
||||
tensors_.emplace_back(tensor.release());
|
||||
return raw_output;
|
||||
}
|
||||
Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
|
||||
auto tensor = std::make_unique<Tensor<std::string>>(api_, ctx_, ith_output, false);
|
||||
Tensor<std::string>& output = *tensor;
|
||||
tensors_.emplace_back(tensor.release());
|
||||
return output;
|
||||
}
|
||||
const void* DataRaw() const override {
|
||||
ORTX_CXX_API_THROW("DataRaw() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
|
||||
return nullptr;
|
||||
}
|
||||
size_t SizeInBytes() const override {
|
||||
ORTX_CXX_API_THROW("SizeInBytes() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
|
||||
return 0;
|
||||
}
|
||||
size_t Size() const {
|
||||
return tensors_.size();
|
||||
}
|
||||
const TensorPtr& operator[](size_t indice) const {
|
||||
return tensors_.at(indice);
|
||||
}
|
||||
private:
|
||||
TensorPtrs tensors_;
|
||||
};
|
||||
|
||||
struct OrtLiteCustomOp : public OrtCustomOp {
|
||||
using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
|
||||
|
@ -287,6 +425,44 @@ struct OrtLiteCustomOp : public OrtCustomOp {
|
|||
return std::tuple_cat(current, next);
|
||||
}
|
||||
|
||||
#if ORT_API_VERSION >= 14
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
||||
static typename std::enable_if<std::is_same<T, const Variadic*>::value, std::tuple<T, Ts...>>::type
|
||||
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
|
||||
tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
}
|
||||
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
||||
static typename std::enable_if<std::is_same<T, const Variadic&>::value, std::tuple<T, Ts...>>::type
|
||||
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
|
||||
tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
}
|
||||
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
||||
static typename std::enable_if<std::is_same<T, Variadic*>::value, std::tuple<T, Ts...>>::type
|
||||
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
|
||||
tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
|
||||
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
}
|
||||
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
||||
static typename std::enable_if<std::is_same<T, Variadic&>::value, std::tuple<T, Ts...>>::type
|
||||
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
|
||||
tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
|
||||
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
}
|
||||
#endif
|
||||
|
||||
#define CREATE_TUPLE_INPUT(data_type) \
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
||||
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
|
||||
|
@ -446,6 +622,48 @@ struct OrtLiteCustomOp : public OrtCustomOp {
|
|||
ParseArgs<Ts...>(input_types, output_types);
|
||||
}
|
||||
|
||||
#if ORT_API_VERSION >= 14
|
||||
template <typename T, typename... Ts>
|
||||
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic&>::value>::type
|
||||
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
|
||||
if (!input_types.empty()) {
|
||||
ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
|
||||
ParseArgs<Ts...>(input_types, output_types);
|
||||
}
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic*>::value>::type
|
||||
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
|
||||
if (!input_types.empty()) {
|
||||
ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
|
||||
ParseArgs<Ts...>(input_types, output_types);
|
||||
}
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic&>::value>::type
|
||||
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
|
||||
if (!output_types.empty()) {
|
||||
ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
|
||||
ParseArgs<Ts...>(input_types, output_types);
|
||||
}
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic*>::value>::type
|
||||
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
|
||||
if (!output_types.empty()) {
|
||||
ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
|
||||
ParseArgs<Ts...>(input_types, output_types);
|
||||
}
|
||||
#endif
|
||||
|
||||
#define PARSE_INPUT_BASE(pack_type, onnx_type) \
|
||||
template <typename T, typename... Ts> \
|
||||
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
|
||||
|
@ -508,7 +726,7 @@ struct OrtLiteCustomOp : public OrtCustomOp {
|
|||
OrtLiteCustomOp(const char* op_name,
|
||||
const char* execution_provider) : op_name_(op_name),
|
||||
execution_provider_(execution_provider) {
|
||||
OrtCustomOp::version = MIN_ORT_VERSION_SUPPORTED;
|
||||
OrtCustomOp::version = ORT_API_VERSION;
|
||||
|
||||
OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
|
||||
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
|
||||
|
@ -533,13 +751,45 @@ struct OrtLiteCustomOp : public OrtCustomOp {
|
|||
return self->output_types_[indice];
|
||||
};
|
||||
|
||||
#if ORT_API_VERSION >= 14
|
||||
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t) {
|
||||
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
|
||||
return (self->input_types_.empty() || self->input_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
|
||||
};
|
||||
|
||||
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
|
||||
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
|
||||
return (self->output_types_.empty() || self->output_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
|
||||
};
|
||||
|
||||
OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
|
||||
return 1;
|
||||
};
|
||||
|
||||
OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
|
||||
return 0;
|
||||
};
|
||||
|
||||
OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
|
||||
return 1;
|
||||
};
|
||||
|
||||
OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
|
||||
return 0;
|
||||
};
|
||||
|
||||
OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) {
|
||||
return OrtMemTypeDefault;
|
||||
};
|
||||
#else
|
||||
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) {
|
||||
return INPUT_OUTPUT_OPTIONAL;
|
||||
};
|
||||
|
||||
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp*, size_t) {
|
||||
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
|
||||
return INPUT_OUTPUT_OPTIONAL;
|
||||
};
|
||||
#endif
|
||||
}
|
||||
|
||||
const std::string op_name_;
|
||||
|
@ -617,6 +867,11 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp {
|
|||
void init(CustomComputeFn<Args...>) {
|
||||
ParseArgs<Args...>(input_types_, output_types_);
|
||||
|
||||
if (!input_types_.empty() && input_types_[0] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ||
|
||||
!output_types_.empty() && output_types_[0] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
|
||||
OrtCustomOp::version = 14;
|
||||
}
|
||||
|
||||
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
|
||||
auto kernel = reinterpret_cast<Kernel*>(op_kernel);
|
||||
std::vector<TensorPtr> tensors;
|
||||
|
|
|
@ -87,6 +87,7 @@ class CuopContainer {
|
|||
|
||||
#define CustomCpuFunc(name, f) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp(name, "CPUExecutionProvider", f)); }
|
||||
#define CustomCpuStruct(name, s) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp<s>(name, "CPUExecutionProvider")); }
|
||||
#define CustomAzureStruct(name, s) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp<s>(name, "AzureExecutionProvider")); }
|
||||
|
||||
template <typename F>
|
||||
void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
|
||||
|
@ -171,3 +172,7 @@ extern FxLoadCustomOpFactory LoadCustomOpClasses_Vision;
|
|||
#ifdef ENABLE_DR_LIBS
|
||||
extern FxLoadCustomOpFactory LoadCustomOpClasses_Audio;
|
||||
#endif
|
||||
|
||||
#if ENABLE_AZURE
|
||||
extern FxLoadCustomOpFactory LoadCustomOpClasses_Azure;
|
||||
#endif
|
||||
|
|
|
@ -0,0 +1,369 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#define CURL_STATICLIB
|
||||
|
||||
#include "http_client.h"
|
||||
#include "curl/curl.h"
|
||||
#include "azure_invokers.hpp"
|
||||
#include <sstream>
|
||||
|
||||
constexpr const char* kUri = "model_uri";
|
||||
constexpr const char* kModelName = "model_name";
|
||||
constexpr const char* kModelVer = "model_version";
|
||||
constexpr const char* kVerbose = "verbose";
|
||||
|
||||
struct StringBuffer {
|
||||
StringBuffer() = default;
|
||||
~StringBuffer() = default;
|
||||
std::stringstream ss_;
|
||||
};
|
||||
|
||||
// apply the callback only when response is for sure to be a '/0' terminated string
|
||||
static size_t WriteStringCallback(void* contents, size_t size, size_t nmemb, void* userp) {
|
||||
try {
|
||||
size_t realsize = size * nmemb;
|
||||
auto buffer = reinterpret_cast<struct StringBuffer*>(userp);
|
||||
buffer->ss_.write(reinterpret_cast<const char*>(contents), realsize);
|
||||
return realsize;
|
||||
} catch (...) {
|
||||
// exception caught, abort write
|
||||
return CURLcode::CURLE_WRITE_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
using CurlWriteCallBack = size_t (*)(void*, size_t, size_t, void*);
|
||||
|
||||
class CurlHandler {
|
||||
public:
|
||||
CurlHandler(CurlWriteCallBack call_back) : curl_(curl_easy_init(), curl_easy_cleanup),
|
||||
headers_(nullptr, curl_slist_free_all),
|
||||
from_holder_(from_, curl_formfree) {
|
||||
curl_easy_setopt(curl_.get(), CURLOPT_BUFFERSIZE, 102400L);
|
||||
curl_easy_setopt(curl_.get(), CURLOPT_NOPROGRESS, 1L);
|
||||
curl_easy_setopt(curl_.get(), CURLOPT_USERAGENT, "curl/7.83.1");
|
||||
curl_easy_setopt(curl_.get(), CURLOPT_MAXREDIRS, 50L);
|
||||
curl_easy_setopt(curl_.get(), CURLOPT_FTP_SKIP_PASV_IP, 1L);
|
||||
curl_easy_setopt(curl_.get(), CURLOPT_TCP_KEEPALIVE, 1L);
|
||||
curl_easy_setopt(curl_.get(), CURLOPT_WRITEFUNCTION, call_back);
|
||||
}
|
||||
~CurlHandler() = default;
|
||||
|
||||
void AddHeader(const char* data) {
|
||||
headers_.reset(curl_slist_append(headers_.release(), data));
|
||||
}
|
||||
template <typename... Args>
|
||||
void AddForm(Args... args) {
|
||||
curl_formadd(&from_, &last_, args...);
|
||||
}
|
||||
template <typename T>
|
||||
void SetOption(CURLoption opt, T val) {
|
||||
curl_easy_setopt(curl_.get(), opt, val);
|
||||
}
|
||||
CURLcode Perform() {
|
||||
SetOption(CURLOPT_HTTPHEADER, headers_.get());
|
||||
SetOption(CURLOPT_HTTPPOST, from_);
|
||||
return curl_easy_perform(curl_.get());
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<CURL, decltype(curl_easy_cleanup)*> curl_;
|
||||
std::unique_ptr<curl_slist, decltype(curl_slist_free_all)*> headers_;
|
||||
curl_httppost* from_{};
|
||||
curl_httppost* last_{};
|
||||
std::unique_ptr<curl_httppost, decltype(curl_formfree)*> from_holder_;
|
||||
};
|
||||
|
||||
// Caches the node attributes (endpoint URI, model name, verbosity) once at
// kernel-construction time; missing attributes fall back to defaults.
AzureAudioInvoker::AzureAudioInvoker(const OrtApi& api,
                                     const OrtKernelInfo& info) : BaseKernel(api, info) {
  verbose_ = TryToGetAttributeWithDefault<bool>(kVerbose, false);
  model_name_ = TryToGetAttributeWithDefault<std::string>(kModelName, "");
  model_uri_ = TryToGetAttributeWithDefault<std::string>(kUri, "");
}
|
||||
|
||||
// Posts the raw audio bytes to the configured endpoint as a
// multipart/form-data request (transcription-style API: "model",
// "response_format" and "file" parts) and returns the response body as a
// single string tensor of shape {1}. Throws ORT_FAIL on any curl error.
void AzureAudioInvoker::Compute(const ortc::Tensor<std::string>& auth_token,
                                const ortc::Tensor<uint8_t>& audio,
                                ortc::Tensor<std::string>& text) {
  CurlHandler curl_handler(WriteStringCallback);
  StringBuffer string_buffer;

  std::string full_auth = std::string{"Authorization: Bearer "} + auth_token.Data()[0];
  curl_handler.AddHeader(full_auth.c_str());
  curl_handler.AddHeader("Content-Type: multipart/form-data");

  curl_handler.AddForm(CURLFORM_COPYNAME, "model", CURLFORM_COPYCONTENTS, model_name_.c_str(), CURLFORM_END);
  curl_handler.AddForm(CURLFORM_COPYNAME, "response_format", CURLFORM_COPYCONTENTS, "text", CURLFORM_END);
  // BUG FIX: CURLFORM_BUFFERLENGTH is read as a long through varargs;
  // NumberOfElement() is a 64-bit count, which would be misread on
  // platforms where long is 32-bit unless narrowed explicitly.
  curl_handler.AddForm(CURLFORM_COPYNAME, "file", CURLFORM_BUFFER, "non_exist.wav", CURLFORM_BUFFERPTR, audio.Data(),
                       CURLFORM_BUFFERLENGTH, static_cast<long>(audio.NumberOfElement()), CURLFORM_END);

  curl_handler.SetOption(CURLOPT_URL, model_uri_.c_str());
  // BUG FIX: CURLOPT_VERBOSE expects a long through the varargs setopt
  // API; passing a bool is not a portable substitute.
  curl_handler.SetOption(CURLOPT_VERBOSE, verbose_ ? 1L : 0L);
  curl_handler.SetOption(CURLOPT_WRITEDATA, (void*)&string_buffer);

  auto curl_ret = curl_handler.Perform();
  if (CURLE_OK != curl_ret) {
    ORTX_CXX_API_THROW(curl_easy_strerror(curl_ret), ORT_FAIL);
  }

  text.SetStringOutput(std::vector<std::string>{string_buffer.ss_.str()}, std::vector<int64_t>{1L});
}
|
||||
|
||||
#if ORT_API_VERSION >= 14
|
||||
|
||||
namespace tc = triton::client;
|
||||
|
||||
// Reads node attributes, creates the Triton HTTP client, and caches the
// declared input/output names so Compute() can map variadic ORT tensors to
// named Triton tensors.
AzureTritonInvoker::AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
  // todo - verify the runtime ORT version is >= 1.14 before calling the
  // KernelInfo_Get{Input,Output}{Count,Name} APIs used below.

  model_uri_ = TryToGetAttributeWithDefault<std::string>(kUri, "");
  model_name_ = TryToGetAttributeWithDefault<std::string>(kModelName, "");
  model_ver_ = TryToGetAttributeWithDefault<std::string>(kModelVer, "0");
  verbose_ = TryToGetAttributeWithDefault<std::string>(kVerbose, "0");

  auto err = tc::InferenceServerHttpClient::Create(&triton_client_, model_uri_, verbose_ != "0");
  // BUG FIX: the creation status was previously ignored; a bad URI would
  // only surface later as a confusing inference failure.
  if (!err.IsOk()) {
    ORTX_CXX_API_THROW("failed to create triton client: " + err.Message(), ORT_RUNTIME_EXCEPTION);
  }

  OrtStatusPtr status = {};
  size_t input_count = {};
  status = api_.KernelInfo_GetInputCount(&info_, &input_count);
  if (status) {
    ORTX_CXX_API_THROW("failed to get input count", ORT_RUNTIME_EXCEPTION);
  }

  for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
    // NOTE(review): assumes declared names fit in 1024 bytes -- confirm
    // the API's behavior for longer names (truncate vs. error).
    char input_name[1024] = {};
    size_t name_size = 1024;
    status = api_.KernelInfo_GetInputName(&info_, ith_input, input_name, &name_size);
    if (status) {
      ORTX_CXX_API_THROW("failed to get input name", ORT_RUNTIME_EXCEPTION);
    }
    input_names_.push_back(input_name);
  }

  size_t output_count = {};
  status = api_.KernelInfo_GetOutputCount(&info_, &output_count);
  if (status) {
    ORTX_CXX_API_THROW("failed to get output count", ORT_RUNTIME_EXCEPTION);
  }

  for (size_t ith_output = 0; ith_output < output_count; ++ith_output) {
    char output_name[1024] = {};
    size_t name_size = 1024;
    status = api_.KernelInfo_GetOutputName(&info_, ith_output, output_name, &name_size);
    if (status) {
      ORTX_CXX_API_THROW("failed to get output name", ORT_RUNTIME_EXCEPTION);
    }
    output_names_.push_back(output_name);
  }
}
|
||||
|
||||
// Maps an ONNX tensor element type to the corresponding Triton wire
// data-type string. Returns an empty string for types that have no Triton
// equivalent (complex, undefined), which callers treat as unsupported.
std::string MapDataType(ONNXTensorElementDataType onnx_data_type) {
  switch (onnx_data_type) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
      return "FP32";
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
      return "UINT8";
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
      return "INT8";
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
      return "UINT16";
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
      return "INT16";
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
      return "INT32";
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
      return "INT64";
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
      return "BYTES";
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
      return "BOOL";
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
      return "FP16";
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
      return "FP64";
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
      return "UINT32";
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
      return "UINT64";
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
      return "BF16";
    default:
      return "";
  }
}
|
||||
|
||||
int8_t* CreateNonStrTensor(const std::string& data_type,
|
||||
ortc::Variadic& outputs,
|
||||
size_t i,
|
||||
const std::vector<int64_t>& shape) {
|
||||
if (data_type == "FP32") {
|
||||
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<float>(i, shape));
|
||||
} else if (data_type == "UINT8") {
|
||||
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint8_t>(i, shape));
|
||||
} else if (data_type == "INT8") {
|
||||
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int8_t>(i, shape));
|
||||
} else if (data_type == "UINT16") {
|
||||
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint16_t>(i, shape));
|
||||
} else if (data_type == "INT16") {
|
||||
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int16_t>(i, shape));
|
||||
} else if (data_type == "INT32") {
|
||||
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int32_t>(i, shape));
|
||||
} else if (data_type == "UINT32") {
|
||||
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint32_t>(i, shape));
|
||||
} else if (data_type == "INT64") {
|
||||
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int64_t>(i, shape));
|
||||
} else if (data_type == "UINT64") {
|
||||
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint64_t>(i, shape));
|
||||
} else if (data_type == "BOOL") {
|
||||
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<bool>(i, shape));
|
||||
} else if (data_type == "FP64") {
|
||||
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<double>(i, shape));
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
// Bails out of the enclosing function when a triton::client::Error is not
// OK, surfacing the Triton message via ORTX_CXX_API_THROW.
// NOTE(review): the leading `return` assumes ORTX_CXX_API_THROW expands to
// an expression that is valid in a return statement -- confirm against its
// definition; `msg` is currently unused by the expansion.
#define CHECK_TRITON_ERR(ret, msg) \
  if (!ret.IsOk()) { \
    return ORTX_CXX_API_THROW("Triton err: " + ret.Message(), ORT_RUNTIME_EXCEPTION); \
  }
|
||||
|
||||
// Forwards the variadic ORT inputs to the configured Triton server over
// HTTP and copies the response tensors into the variadic outputs.
// Input 0 must be the auth-token string tensor; inputs 1..N-1 are the
// model inputs, matched positionally to the names cached at construction.
// String ("BYTES") outputs are returned as string tensors, all other
// types as raw-copied numeric tensors.
void AzureTritonInvoker::Compute(const ortc::Variadic& inputs,
                                 ortc::Variadic& outputs) {
  if (inputs.Size() < 1 ||
      inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
    ORTX_CXX_API_THROW("invalid inputs, auth token missing", ORT_RUNTIME_EXCEPTION);
  }

  if (inputs.Size() != input_names_.size()) {
    ORTX_CXX_API_THROW("input count mismatch", ORT_RUNTIME_EXCEPTION);
  }

  auto auth_token = reinterpret_cast<const char*>(inputs[0]->DataRaw());
  std::vector<std::unique_ptr<tc::InferInput>> triton_input_vec;
  std::vector<tc::InferInput*> triton_inputs;
  std::vector<std::unique_ptr<const tc::InferRequestedOutput>> triton_output_vec;
  std::vector<const tc::InferRequestedOutput*> triton_outputs;
  tc::Error err;

  // Start at 1: index 0 is the auth token; input_names_ is indexed the
  // same way, so names and tensors stay aligned.
  for (size_t ith_input = 1; ith_input < inputs.Size(); ++ith_input) {
    tc::InferInput* triton_input = {};
    std::string triton_data_type = MapDataType(inputs[ith_input]->Type());
    if (triton_data_type.empty()) {
      ORTX_CXX_API_THROW("unknown onnx data type", ORT_RUNTIME_EXCEPTION);
    }

    err = tc::InferInput::Create(&triton_input, input_names_[ith_input], inputs[ith_input]->Shape(), triton_data_type);
    CHECK_TRITON_ERR(err, "failed to create triton input");
    triton_input_vec.emplace_back(triton_input);

    triton_inputs.push_back(triton_input);
    // todo - test string inputs through this path
    // BUG FIX: the buffer was previously cast to const float* even though
    // the tensor may hold any element type; treat it as raw bytes.
    const uint8_t* data_raw = reinterpret_cast<const uint8_t*>(inputs[ith_input]->DataRaw());
    size_t size_in_bytes = inputs[ith_input]->SizeInBytes();
    err = triton_input->AppendRaw(data_raw, size_in_bytes);
    CHECK_TRITON_ERR(err, "failed to append raw data to input");
  }

  for (size_t ith_output = 0; ith_output < output_names_.size(); ++ith_output) {
    tc::InferRequestedOutput* triton_output = {};
    err = tc::InferRequestedOutput::Create(&triton_output, output_names_[ith_output]);
    CHECK_TRITON_ERR(err, "failed to create triton output");
    triton_output_vec.emplace_back(triton_output);
    triton_outputs.push_back(triton_output);
  }

  std::unique_ptr<tc::InferResult> results_ptr;
  tc::InferResult* results = {};
  tc::InferOptions options(model_name_);
  options.model_version_ = model_ver_;
  options.client_timeout_ = 0;

  tc::Headers http_headers;
  http_headers["Authorization"] = std::string{"Bearer "} + auth_token;

  err = triton_client_->Infer(&results, options, triton_inputs, triton_outputs,
                              http_headers, tc::Parameters(),
                              tc::InferenceServerHttpClient::CompressionType::NONE,  // support compression in config?
                              tc::InferenceServerHttpClient::CompressionType::NONE);

  // Take ownership before the error check so `results` is freed either way.
  results_ptr.reset(results);
  CHECK_TRITON_ERR(err, "failed to do triton inference");

  size_t output_index = 0;
  auto iter = output_names_.begin();

  while (iter != output_names_.end()) {
    std::vector<int64_t> shape;
    err = results_ptr->Shape(*iter, &shape);
    CHECK_TRITON_ERR(err, "failed to get output shape");

    std::string type;
    err = results_ptr->Datatype(*iter, &type);
    CHECK_TRITON_ERR(err, "failed to get output type");

    if ("BYTES" == type) {
      std::vector<std::string> output_strings;
      err = results_ptr->StringData(*iter, &output_strings);
      CHECK_TRITON_ERR(err, "failed to get output as string");
      auto& string_tensor = outputs.AllocateStringTensor(output_index);
      string_tensor.SetStringOutput(output_strings, shape);
    } else {
      const uint8_t* raw_data = {};
      size_t raw_size;
      err = results_ptr->RawData(*iter, &raw_data, &raw_size);
      CHECK_TRITON_ERR(err, "failed to get output raw data");
      auto* output_raw = CreateNonStrTensor(type, outputs, output_index, shape);
      // BUG FIX: CreateNonStrTensor returns nullptr for types it does not
      // handle (e.g. FP16/BF16); memcpy into it would crash the process.
      if (output_raw == nullptr) {
        ORTX_CXX_API_THROW("unsupported triton output data type: " + type, ORT_RUNTIME_EXCEPTION);
      }
      memcpy(output_raw, raw_data, raw_size);
    }

    ++output_index;
    ++iter;
  }
}
|
||||
|
||||
// Registers the Azure invoker custom ops for the Azure execution provider.
// In this branch (ORT API >= 14) the Triton invoker is available as well.
const std::vector<const OrtCustomOp*>& AzureInvokerLoader() {
  static OrtOpLoader op_loader(CustomAzureStruct("AzureAudioInvoker", AzureAudioInvoker),
                               CustomAzureStruct("AzureTritonInvoker", AzureTritonInvoker)
#ifdef TEST_AZURE_INVOKERS_AS_CPU_OP
                               // Also register as plain CPU ops so unit tests can run
                               // the invokers without the Azure execution provider.
                               ,CustomCpuStruct("AzureAudioInvoker", AzureAudioInvoker)
                               ,CustomCpuStruct("AzureTritonInvoker", AzureTritonInvoker)
#endif
  );
  return op_loader.GetCustomOps();
}
|
||||
|
||||
#else
|
||||
|
||||
// Registers the Azure invoker custom ops for ORT API < 14, where the
// Triton invoker is unavailable (it depends on the KernelInfo_Get* APIs
// introduced in ORT 1.14), so only the audio invoker is exposed.
const std::vector<const OrtCustomOp*>& AzureInvokerLoader() {
  static OrtOpLoader op_loader(CustomAzureStruct("AzureAudioInvoker", AzureAudioInvoker)
#ifdef TEST_AZURE_INVOKERS_AS_CPU_OP
                               // Also register as a plain CPU op for unit testing.
                               ,CustomCpuStruct("AzureAudioInvoker", AzureAudioInvoker)
#endif
  );
  return op_loader.GetCustomOps();
}
|
||||
|
||||
#endif
|
||||
|
||||
// Factory hook consumed by RegisterCustomOps to pull in the Azure ops.
FxLoadCustomOpFactory LoadCustomOpClasses_Azure = AzureInvokerLoader;
|
|
@ -0,0 +1,35 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ocos.h"
|
||||
|
||||
// Custom op that sends raw audio bytes to an Azure/OpenAI-style HTTP
// endpoint (multipart/form-data) and returns the transcription response
// as a string tensor.
struct AzureAudioInvoker : public BaseKernel {
  AzureAudioInvoker(const OrtApi& api, const OrtKernelInfo& info);
  // auth_token: bearer token (single-element string tensor);
  // raw_audio_data: audio payload bytes; text: response body, shape {1}.
  void Compute(const ortc::Tensor<std::string>& auth_token,
               const ortc::Tensor<uint8_t>& raw_audio_data,
               ortc::Tensor<std::string>& text);

 private:
  std::string model_uri_;   // endpoint URL ("model_uri" attribute)
  std::string model_name_;  // model identifier sent in the form
  bool verbose_;            // enables curl verbose logging
};
|
||||
|
||||
#if ORT_API_VERSION >= 14
|
||||
// Custom op that forwards its variadic inputs (input 0 = auth token) to a
// Triton Inference Server over HTTP and returns the response tensors.
// Requires ORT API >= 14 for the KernelInfo input/output introspection.
struct AzureTritonInvoker : public BaseKernel {
  AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info);
  void Compute(const ortc::Variadic& inputs,
               ortc::Variadic& outputs);

 private:
  std::string model_uri_;   // Triton server URL ("model_uri" attribute)
  std::string model_name_;  // Triton model name
  std::string model_ver_;   // Triton model version; "0" selects the default
  std::string verbose_;     // "0" disables client verbose logging
  std::unique_ptr<triton::client::InferenceServerHttpClient> triton_client_;
  std::vector<std::string> input_names_;   // declared op input names (index 0 = token)
  std::vector<std::string> output_names_;  // declared op output names
};
|
||||
#endif
|
|
@ -121,6 +121,10 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
|
|||
#if defined(ENABLE_DR_LIBS)
|
||||
,
|
||||
LoadCustomOpClasses_Audio
|
||||
#endif
|
||||
#if defined(ENABLE_AZURE)
|
||||
,
|
||||
LoadCustomOpClasses_Azure
|
||||
#endif
|
||||
};
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче