* Implement azure invokers (#486)

* draft azure ops

* migrate triton client

* AzureAudioInvoker works

* triton client builds

* triton invoker works

* limit version

* restore setup.py

* limit ort version

* upgrade version

* pip install cmake

* add ut

* promote ort header version to 1.15.1

* register as cpu op

* limit triton invoker to 1.14 and newer

* remove test

* install rapidjson

* install dep

* sudo install

* install version script

* print err msg

* fix pipeline

* disable from web assembly

* install cmake

* Fix pipelines (#479)

* 1–80 (a run of numbered placeholder commits, squashed)

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>

* fix pipelines (#481)

* 1–86 (a run of numbered placeholder commits, squashed)

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>

* test as cpu op

* add ut

* add ut

* move cond

* tune ut

* tune pipeline

* promote to ort 141

* reset header version

* restore cmake

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>

* trim changes

* revert req txt

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>
RandySheriffH 2023-07-10 10:07:33 -07:00, committed by GitHub
Parent: b49c0231ab
Commit: 27132ced71
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
8 changed files: 849 additions and 6 deletions


@@ -58,7 +58,20 @@ jobs:
displayName: Unpack ONNXRuntime package.
- script: |
CPU_NUMBER=2 sh ./build.sh -DOCOS_ENABLE_CTEST=ON -DONNXRUNTIME_PKG_DIR=$(Build.SourcesDirectory)/onnxruntime-linux-x64-$(ort.version)
wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz
tar zxvf v1.1.0.tar.gz
cd rapidjson-1.1.0
mkdir build
cd build
cmake ..
sudo cmake --install .
cd ../..
git clone https://github.com/triton-inference-server/client.git --branch r23.05 ~/client
sudo ln -s ~/client/src/c++/library/libhttpclient.ldscript /lib/libhttpclient.ldscript
displayName: install deps for azure invokers
- script: |
CPU_NUMBER=2 sh ./build.sh -DOCOS_ENABLE_CTEST=ON -DONNXRUNTIME_PKG_DIR=$(Build.SourcesDirectory)/onnxruntime-linux-x64-$(ort.version) -DOCOS_ENABLE_AZURE=ON
displayName: build the customop library with onnxruntime
- script: |
@@ -263,7 +276,7 @@ jobs:
- script: |
call $(vsdevcmd)
call .\build.bat -DOCOS_ENABLE_CTEST=ON -DONNXRUNTIME_PKG_DIR=.\onnxruntime-win-x64-$(ort.version)
call .\build.bat -DOCOS_ENABLE_CTEST=ON -DONNXRUNTIME_PKG_DIR=.\onnxruntime-win-x64-$(ort.version) -DOCOS_ENABLE_AZURE=ON
displayName: build the customop library with onnxruntime
- script: |


@@ -58,6 +58,7 @@ option(OCOS_ENABLE_OPENCV_CODECS "Enable cv2 and vision operators that require o
option(OCOS_ENABLE_CV2 "Enable the operators in `operators/cv2`" ON)
option(OCOS_ENABLE_VISION "Enable the operators in `operators/vision`" ON)
option(OCOS_ENABLE_AUDIO "Enable the operators for audio processing" ON)
option(OCOS_ENABLE_AZURE "Enable the operators for azure execution provider" OFF)
option(OCOS_ENABLE_STATIC_LIB "Enable generating static library" OFF)
option(OCOS_ENABLE_SELECTED_OPLIST "Enable including the selected_ops tool file" OFF)
@@ -79,6 +80,7 @@ function(disable_all_operators)
set(OCOS_ENABLE_OPENCV_CODECS OFF CACHE INTERNAL "" FORCE)
set(OCOS_ENABLE_CV2 OFF CACHE INTERNAL "" FORCE)
set(OCOS_ENABLE_VISION OFF CACHE INTERNAL "" FORCE)
set(OCOS_ENABLE_AZURE OFF CACHE INTERNAL "" FORCE)
endfunction()
if (CMAKE_GENERATOR_PLATFORM)
@@ -376,6 +378,19 @@ if(OCOS_ENABLE_BLINGFIRE)
list(APPEND TARGET_SRC ${blingfire_TARGET_SRC})
endif()
if(OCOS_ENABLE_AZURE)
if ($ENV{TEST_AZURE_INVOKERS_AS_CPU_OP} MATCHES "ON")
message(STATUS "Azure inovkers will be testable as cpu ops")
add_compile_definitions(TEST_AZURE_INVOKERS_AS_CPU_OP)
endif()
# Azure endpoint invokers
include(triton)
file(GLOB TARGET_SRC_AZURE "operators/azure/*.cc" "operators/azure/*.h*")
list(APPEND TARGET_SRC ${TARGET_SRC_AZURE})
endif()
if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
message(STATUS "Fetch json")
include(json)
@@ -478,6 +493,10 @@ if(OCOS_ENABLE_VISION)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_VISION)
endif()
if(OCOS_ENABLE_AZURE)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_AZURE)
endif()
if(OCOS_ENABLE_GPT2_TOKENIZER)
# GPT2
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_GPT2_TOKENIZER)
@@ -693,3 +712,30 @@ if(OCOS_ENABLE_CTEST)
add_test(NAME extensions_test COMMAND $<TARGET_FILE:extensions_test>)
endif()
endif()
if(OCOS_ENABLE_AZURE)
add_dependencies(ocos_operators triton)
target_include_directories(ocos_operators PUBLIC ${TRITON_BIN}/include ${TRITON_THIRD_PARTY}/curl/include)
target_link_directories(ocos_operators PUBLIC ${TRITON_BIN}/lib ${TRITON_BIN}/lib64 ${TRITON_THIRD_PARTY}/curl/lib ${TRITON_THIRD_PARTY}/curl/lib64)
if (ocos_target_platform STREQUAL "AMD64")
set(vcpkg_target_platform "x86")
else()
set(vcpkg_target_platform ${ocos_target_platform})
endif()
if (WIN32)
target_link_directories(ocos_operators PUBLIC ${VCPKG_SRC}/installed/${vcpkg_target_platform}-windows-static/lib)
target_link_libraries(ocos_operators PUBLIC libcurl httpclient_static ws2_32 crypt32 Wldap32)
else()
find_package(ZLIB REQUIRED)
find_package(OpenSSL REQUIRED)
target_link_libraries(ocos_operators PUBLIC httpclient_static curl ZLIB::ZLIB OpenSSL::Crypto OpenSSL::SSL)
endif() #if (WIN32)
endif()

cmake/externals/triton.cmake (vendored, new file, 116 lines)

@@ -0,0 +1,116 @@
include(ExternalProject)
if (WIN32)
if (ocos_target_platform STREQUAL "AMD64")
set(vcpkg_target_platform "x86")
else()
set(vcpkg_target_platform ${ocos_target_platform})
endif()
ExternalProject_Add(vcpkg
GIT_REPOSITORY https://github.com/microsoft/vcpkg.git
GIT_TAG 2023.06.20
PREFIX vcpkg
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-build
CONFIGURE_COMMAND ""
INSTALL_COMMAND ""
UPDATE_COMMAND ""
BUILD_COMMAND "<SOURCE_DIR>/bootstrap-vcpkg.bat")
set(VCPKG_SRC ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src)
set(ENV{VCPKG_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src)
message(STATUS "VCPKG_SRC: " ${VCPKG_SRC})
message(STATUS "VCPKG_ROOT: " $ENV{VCPKG_ROOT})
add_custom_command(
COMMAND ${VCPKG_SRC}/vcpkg integrate install
COMMAND ${CMAKE_COMMAND} -E touch vcpkg_integrate.stamp
OUTPUT vcpkg_integrate.stamp
DEPENDS vcpkg
)
add_custom_target(vcpkg_integrate ALL DEPENDS vcpkg_integrate.stamp)
set(VCPKG_DEPENDENCIES "vcpkg_integrate")
function(vcpkg_install PACKAGE_NAME)
add_custom_command(
OUTPUT ${VCPKG_SRC}/packages/${PACKAGE_NAME}_${vcpkg_target_platform}-windows-static/BUILD_INFO
COMMAND ${VCPKG_SRC}/vcpkg install ${PACKAGE_NAME}:${vcpkg_target_platform}-windows-static --vcpkg-root=${CMAKE_CURRENT_BINARY_DIR}/_deps/vcpkg-src
WORKING_DIRECTORY ${VCPKG_SRC}
DEPENDS vcpkg_integrate)
add_custom_target(get${PACKAGE_NAME}
ALL
DEPENDS ${VCPKG_SRC}/packages/${PACKAGE_NAME}_${vcpkg_target_platform}-windows-static/BUILD_INFO)
list(APPEND VCPKG_DEPENDENCIES "get${PACKAGE_NAME}")
set(VCPKG_DEPENDENCIES ${VCPKG_DEPENDENCIES} PARENT_SCOPE)
endfunction()
vcpkg_install(openssl)
vcpkg_install(openssl-windows)
vcpkg_install(rapidjson)
vcpkg_install(re2)
vcpkg_install(boost-interprocess)
vcpkg_install(boost-stacktrace)
vcpkg_install(pthread)
vcpkg_install(b64)
add_dependencies(getb64 getpthread)
add_dependencies(getpthread getboost-stacktrace)
add_dependencies(getboost-stacktrace getboost-interprocess)
add_dependencies(getboost-interprocess getre2)
add_dependencies(getre2 getrapidjson)
add_dependencies(getrapidjson getopenssl-windows)
add_dependencies(getopenssl-windows getopenssl)
ExternalProject_Add(triton
GIT_REPOSITORY https://github.com/triton-inference-server/client.git
GIT_TAG r23.05
PREFIX triton
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build
CMAKE_ARGS -DVCPKG_TARGET_TRIPLET=${vcpkg_target_platform}-windows-static -DCMAKE_TOOLCHAIN_FILE=${VCPKG_SRC}/scripts/buildsystems/vcpkg.cmake -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF
INSTALL_COMMAND ""
UPDATE_COMMAND "")
add_dependencies(triton ${VCPKG_DEPENDENCIES})
else()
if(DEFINED ENV{IS_DOCKER_BUILD})
message(STATUS "IS_DOCKER_BUILD set")
ExternalProject_Add(curl7
PREFIX curl7
GIT_REPOSITORY "https://github.com/curl/curl.git"
GIT_TAG "curl-7_86_0"
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/curl7-src
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/curl7-build
CMAKE_ARGS -DBUILD_TESTING=OFF -DBUILD_CURL_EXE=OFF -DBUILD_SHARED_LIBS=OFF -DCURL_STATICLIB=ON -DHTTP_ONLY=ON -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE})
endif()
ExternalProject_Add(triton
GIT_REPOSITORY https://github.com/triton-inference-server/client.git
GIT_TAG r23.05
PREFIX triton
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF
INSTALL_COMMAND ""
UPDATE_COMMAND "")
if(DEFINED ENV{IS_DOCKER_BUILD})
add_dependencies(triton curl7)
endif()
endif() #if (WIN32)
ExternalProject_Get_Property(triton SOURCE_DIR)
set(TRITON_SRC ${SOURCE_DIR})
ExternalProject_Get_Property(triton BINARY_DIR)
set(TRITON_BIN ${BINARY_DIR}/binary)
set(TRITON_THIRD_PARTY ${BINARY_DIR}/third-party)


@@ -5,8 +5,6 @@
#include "onnxruntime_customop.hpp"
#include <optional>
#include <numeric>
// uplevel the version when supported ort version migrates to newer ones
#define SUPPORT_ORT_API_VERSION_TO 13
namespace Ort {
namespace Custom {
@@ -32,6 +30,9 @@ class TensorBase {
ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
}
}
ONNXTensorElementDataType Type() const {
return type_;
}
int64_t NumberOfElement() const {
if (shape_.has_value()) {
return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
@@ -51,12 +52,16 @@ class TensorBase {
return "empty";
}
}
virtual const void* DataRaw() const = 0;
virtual size_t SizeInBytes() const = 0;
protected:
const OrtW::CustomOpApi& api_;
OrtKernelContext& ctx_;
size_t indice_;
bool is_input_;
std::optional<std::vector<int64_t>> shape_;
ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
};
template <typename T>
@@ -93,12 +98,22 @@ class Tensor : public TensorBase {
const_value_ = api_.KernelContext_GetInput(&ctx_, indice);
auto* info = api_.GetTensorTypeAndShape(const_value_);
shape_ = api_.GetTensorShape(info);
type_ = api_.GetTensorElementType(info);
api_.ReleaseTensorTypeAndShapeInfo(info);
}
}
const TT* Data() const {
return api_.GetTensorData<TT>(const_value_);
}
const void* DataRaw() const override {
return reinterpret_cast<const void*>(Data());
}
size_t SizeInBytes() const override {
return NumberOfElement() * sizeof(TT);
}
TT* Allocate(const std::vector<int64_t>& shape) {
if (!data_) {
OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
@@ -148,6 +163,7 @@ class Tensor<std::string> : public TensorBase {
auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
auto* info = api_.GetTensorTypeAndShape(const_value);
shape_ = api_.GetTensorShape(info);
type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
api_.ReleaseTensorTypeAndShapeInfo(info);
size_t num_chars;
@@ -173,6 +189,18 @@ class Tensor<std::string> : public TensorBase {
const strings& Data() const {
return input_strings_;
}
const void* DataRaw() const override {
if (input_strings_.size() != 1) {
ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
}
return reinterpret_cast<const void*>(input_strings_[0].c_str());
}
size_t SizeInBytes() const override {
if (input_strings_.size() != 1) {
ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
}
return input_strings_[0].size();
}
void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
std::vector<const char*> raw;
for (const auto& s : ss) {
@@ -220,6 +248,7 @@ class Tensor<std::string_view> : public TensorBase {
auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
auto* info = api_.GetTensorTypeAndShape(const_value);
shape_ = api_.GetTensorShape(info);
type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
api_.ReleaseTensorTypeAndShapeInfo(info);
size_t num_chars;
@@ -251,6 +280,18 @@ class Tensor<std::string_view> : public TensorBase {
const string_views& Data() const {
return input_string_views_;
}
const void* DataRaw() const override {
if (input_string_views_.size() != 1) {
ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
}
return reinterpret_cast<const void*>(input_string_views_[0].data());
}
size_t SizeInBytes() const override {
if (input_string_views_.size() != 1) {
ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
}
return input_string_views_[0].size();
}
const Span<std::string_view>& AsSpan() {
ORTX_CXX_API_THROW("span for TensorT of string view not implemented", ORT_RUNTIME_EXCEPTION);
}
@@ -267,6 +308,103 @@ class Tensor<std::string_view> : public TensorBase {
};
using TensorPtr = std::unique_ptr<Custom::TensorBase>;
using TensorPtrs = std::vector<TensorPtr>;
// Represent variadic input or output
struct Variadic : public TensorBase {
Variadic(const OrtW::CustomOpApi& api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : TensorBase(api,
ctx,
indice,
is_input) {
#if ORT_API_VERSION < 14
ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION);
#endif
if (is_input) {
auto input_count = api_.KernelContext_GetInputCount(&ctx_);
for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input);
auto* info = api_.GetTensorTypeAndShape(const_value);
auto type = api_.GetTensorElementType(info);
api_.ReleaseTensorTypeAndShapeInfo(info);
TensorPtr tensor;
switch (type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
tensor = std::make_unique<Custom::Tensor<bool>>(api, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
tensor = std::make_unique<Custom::Tensor<float>>(api, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
tensor = std::make_unique<Custom::Tensor<double>>(api, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
tensor = std::make_unique<Custom::Tensor<uint8_t>>(api, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
tensor = std::make_unique<Custom::Tensor<int8_t>>(api, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
tensor = std::make_unique<Custom::Tensor<uint16_t>>(api, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
tensor = std::make_unique<Custom::Tensor<int16_t>>(api, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
tensor = std::make_unique<Custom::Tensor<uint32_t>>(api, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
tensor = std::make_unique<Custom::Tensor<int32_t>>(api, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
tensor = std::make_unique<Custom::Tensor<uint64_t>>(api, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
tensor = std::make_unique<Custom::Tensor<int64_t>>(api, ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
tensor = std::make_unique<Custom::Tensor<std::string_view>>(api, ctx, ith_input, true);
break;
default:
ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
break;
}
tensors_.emplace_back(tensor.release());
} // for
}
}
template<typename T>
T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
auto tensor = std::make_unique<Tensor<T>>(api_, ctx_, ith_output, false);
auto raw_output = tensor.get()->Allocate(shape);
tensors_.emplace_back(tensor.release());
return raw_output;
}
Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
auto tensor = std::make_unique<Tensor<std::string>>(api_, ctx_, ith_output, false);
Tensor<std::string>& output = *tensor;
tensors_.emplace_back(tensor.release());
return output;
}
const void* DataRaw() const override {
ORTX_CXX_API_THROW("DataRaw() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
return nullptr;
}
size_t SizeInBytes() const override {
ORTX_CXX_API_THROW("SizeInBytes() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
return 0;
}
size_t Size() const {
return tensors_.size();
}
const TensorPtr& operator[](size_t indice) const {
return tensors_.at(indice);
}
private:
TensorPtrs tensors_;
};
struct OrtLiteCustomOp : public OrtCustomOp {
using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
@@ -287,6 +425,44 @@ struct OrtLiteCustomOp : public OrtCustomOp {
return std::tuple_cat(current, next);
}
#if ORT_API_VERSION >= 14
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const Variadic*>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const Variadic&>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, Variadic*>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, Variadic&>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
#endif
#define CREATE_TUPLE_INPUT(data_type) \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
@@ -446,6 +622,48 @@ struct OrtLiteCustomOp : public OrtCustomOp {
ParseArgs<Ts...>(input_types, output_types);
}
#if ORT_API_VERSION >= 14
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
if (!input_types.empty()) {
ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
}
input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
ParseArgs<Ts...>(input_types, output_types);
}
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic*>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
if (!input_types.empty()) {
ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
}
input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
ParseArgs<Ts...>(input_types, output_types);
}
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
if (!output_types.empty()) {
ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
}
output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
ParseArgs<Ts...>(input_types, output_types);
}
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic*>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
if (!output_types.empty()) {
ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
}
output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
ParseArgs<Ts...>(input_types, output_types);
}
#endif
#define PARSE_INPUT_BASE(pack_type, onnx_type) \
template <typename T, typename... Ts> \
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
@@ -508,7 +726,7 @@ struct OrtLiteCustomOp : public OrtCustomOp {
OrtLiteCustomOp(const char* op_name,
const char* execution_provider) : op_name_(op_name),
execution_provider_(execution_provider) {
OrtCustomOp::version = MIN_ORT_VERSION_SUPPORTED;
OrtCustomOp::version = ORT_API_VERSION;
OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
@@ -533,13 +751,45 @@ struct OrtLiteCustomOp : public OrtCustomOp {
return self->output_types_[indice];
};
#if ORT_API_VERSION >= 14
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t) {
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
return (self->input_types_.empty() || self->input_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
};
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
return (self->output_types_.empty() || self->output_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
};
OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
return 1;
};
OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
return 0;
};
OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
return 1;
};
OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
return 0;
};
OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) {
return OrtMemTypeDefault;
};
#else
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) {
return INPUT_OUTPUT_OPTIONAL;
};
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp*, size_t) {
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
return INPUT_OUTPUT_OPTIONAL;
};
#endif
}
const std::string op_name_;
@@ -617,6 +867,11 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp {
void init(CustomComputeFn<Args...>) {
ParseArgs<Args...>(input_types_, output_types_);
if (!input_types_.empty() && input_types_[0] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ||
!output_types_.empty() && output_types_[0] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
OrtCustomOp::version = 14;
}
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
auto kernel = reinterpret_cast<Kernel*>(op_kernel);
std::vector<TensorPtr> tensors;
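
For reference, a minimal sketch of a kernel written against the Variadic API added above. The kernel name and its registration are hypothetical; the Compute signature mirrors the AzureTritonInvoker later in this diff, and Shape(), DataRaw(), SizeInBytes() and AllocateOutput are the accessors introduced here.

// Hypothetical kernel: forwards every float input tensor to the matching output,
// exercising the Variadic, DataRaw() and SizeInBytes() additions from this diff.
#include <cstring>
#include "onnxruntime_customop.hpp"

struct VariadicCopy {
  VariadicCopy(const OrtApi&, const OrtKernelInfo&) {}
  void Compute(const ortc::Variadic& inputs, ortc::Variadic& outputs) {
    for (size_t i = 0; i < inputs.Size(); ++i) {
      // allocate an output of the same shape, then copy the raw bytes across
      auto* dst = outputs.AllocateOutput<float>(i, inputs[i]->Shape());
      std::memcpy(dst, inputs[i]->DataRaw(), inputs[i]->SizeInBytes());
    }
  }
};
// Registered like the Azure ops below, e.g. CustomCpuStruct("VariadicCopy", VariadicCopy).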


@@ -87,6 +87,7 @@ class CuopContainer {
#define CustomCpuFunc(name, f) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp(name, "CPUExecutionProvider", f)); }
#define CustomCpuStruct(name, s) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp<s>(name, "CPUExecutionProvider")); }
#define CustomAzureStruct(name, s) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp<s>(name, "AzureExecutionProvider")); }
template <typename F>
void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
@@ -171,3 +172,7 @@ extern FxLoadCustomOpFactory LoadCustomOpClasses_Vision;
#ifdef ENABLE_DR_LIBS
extern FxLoadCustomOpFactory LoadCustomOpClasses_Audio;
#endif
#if ENABLE_AZURE
extern FxLoadCustomOpFactory LoadCustomOpClasses_Azure;
#endif
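
As a usage sketch, a function-based CPU op can be wired through these macros the same way the Azure loaders below use CustomAzureStruct/CustomCpuStruct. The Negate op and its loader are hypothetical, assuming the Tensor accessors shown in this diff.

// Hypothetical: expose a free function as a CPU custom op via CustomCpuFunc.
void Negate(const ortc::Tensor<float>& in, ortc::Tensor<float>& out) {
  const float* src = in.Data();
  float* dst = out.Allocate(in.Shape());  // allocate output with the input's shape
  for (int64_t i = 0; i < in.NumberOfElement(); ++i) dst[i] = -src[i];
}

const std::vector<const OrtCustomOp*>& NegateLoader() {
  static OrtOpLoader op_loader(CustomCpuFunc("Negate", Negate));
  return op_loader.GetCustomOps();
}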


@@ -0,0 +1,369 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define CURL_STATICLIB
#include "http_client.h"
#include "curl/curl.h"
#include "azure_invokers.hpp"
#include <sstream>
constexpr const char* kUri = "model_uri";
constexpr const char* kModelName = "model_name";
constexpr const char* kModelVer = "model_version";
constexpr const char* kVerbose = "verbose";
struct StringBuffer {
StringBuffer() = default;
~StringBuffer() = default;
std::stringstream ss_;
};
// apply the callback only when the response is guaranteed to be a '\0'-terminated string
static size_t WriteStringCallback(void* contents, size_t size, size_t nmemb, void* userp) {
try {
size_t realsize = size * nmemb;
auto buffer = reinterpret_cast<struct StringBuffer*>(userp);
buffer->ss_.write(reinterpret_cast<const char*>(contents), realsize);
return realsize;
} catch (...) {
// exception caught, abort write
return CURLcode::CURLE_WRITE_ERROR;
}
}
using CurlWriteCallBack = size_t (*)(void*, size_t, size_t, void*);
class CurlHandler {
public:
CurlHandler(CurlWriteCallBack call_back) : curl_(curl_easy_init(), curl_easy_cleanup),
headers_(nullptr, curl_slist_free_all),
from_holder_(from_, curl_formfree) {
curl_easy_setopt(curl_.get(), CURLOPT_BUFFERSIZE, 102400L);
curl_easy_setopt(curl_.get(), CURLOPT_NOPROGRESS, 1L);
curl_easy_setopt(curl_.get(), CURLOPT_USERAGENT, "curl/7.83.1");
curl_easy_setopt(curl_.get(), CURLOPT_MAXREDIRS, 50L);
curl_easy_setopt(curl_.get(), CURLOPT_FTP_SKIP_PASV_IP, 1L);
curl_easy_setopt(curl_.get(), CURLOPT_TCP_KEEPALIVE, 1L);
curl_easy_setopt(curl_.get(), CURLOPT_WRITEFUNCTION, call_back);
}
~CurlHandler() = default;
void AddHeader(const char* data) {
headers_.reset(curl_slist_append(headers_.release(), data));
}
template <typename... Args>
void AddForm(Args... args) {
curl_formadd(&from_, &last_, args...);
}
template <typename T>
void SetOption(CURLoption opt, T val) {
curl_easy_setopt(curl_.get(), opt, val);
}
CURLcode Perform() {
SetOption(CURLOPT_HTTPHEADER, headers_.get());
SetOption(CURLOPT_HTTPPOST, from_);
return curl_easy_perform(curl_.get());
}
private:
std::unique_ptr<CURL, decltype(curl_easy_cleanup)*> curl_;
std::unique_ptr<curl_slist, decltype(curl_slist_free_all)*> headers_;
curl_httppost* from_{};
curl_httppost* last_{};
std::unique_ptr<curl_httppost, decltype(curl_formfree)*> from_holder_;
};
AzureAudioInvoker::AzureAudioInvoker(const OrtApi& api,
const OrtKernelInfo& info) : BaseKernel(api, info) {
model_uri_ = TryToGetAttributeWithDefault<std::string>(kUri, "");
model_name_ = TryToGetAttributeWithDefault<std::string>(kModelName, "");
verbose_ = TryToGetAttributeWithDefault<bool>(kVerbose, false);
}
void AzureAudioInvoker::Compute(const ortc::Tensor<std::string>& auth_token,
const ortc::Tensor<uint8_t>& audio,
ortc::Tensor<std::string>& text) {
CurlHandler curl_handler(WriteStringCallback);
StringBuffer string_buffer;
std::string full_auth = std::string{"Authorization: Bearer "} + auth_token.Data()[0];
curl_handler.AddHeader(full_auth.c_str());
curl_handler.AddHeader("Content-Type: multipart/form-data");
curl_handler.AddForm(CURLFORM_COPYNAME, "model", CURLFORM_COPYCONTENTS, model_name_.c_str(), CURLFORM_END);
curl_handler.AddForm(CURLFORM_COPYNAME, "response_format", CURLFORM_COPYCONTENTS, "text", CURLFORM_END);
curl_handler.AddForm(CURLFORM_COPYNAME, "file", CURLFORM_BUFFER, "non_exist.wav", CURLFORM_BUFFERPTR, audio.Data(),
CURLFORM_BUFFERLENGTH, audio.NumberOfElement(), CURLFORM_END);
curl_handler.SetOption(CURLOPT_URL, model_uri_.c_str());
curl_handler.SetOption(CURLOPT_VERBOSE, verbose_);
curl_handler.SetOption(CURLOPT_WRITEDATA, (void*)&string_buffer);
auto curl_ret = curl_handler.Perform();
if (CURLE_OK != curl_ret) {
ORTX_CXX_API_THROW(curl_easy_strerror(curl_ret), ORT_FAIL);
}
text.SetStringOutput(std::vector<std::string>{string_buffer.ss_.str()}, std::vector<int64_t>{1L});
}
#if ORT_API_VERSION >= 14
namespace tc = triton::client;
AzureTritonInvoker::AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
//todo - check ver_str against 1.14 to prevent invalid access on newer APIs such as "KernelInfo_GetInputCount"
//const auto* ort_api_base = OrtGetApiBase();
//if (!ort_api_base) {
// ORTX_CXX_API_THROW("failed to get ort base api", ORT_RUNTIME_EXCEPTION);
//}
//std::string ver_str = ort_api_base->GetVersionString();
model_uri_ = TryToGetAttributeWithDefault<std::string>(kUri, "");
model_name_ = TryToGetAttributeWithDefault<std::string>(kModelName, "");
model_ver_ = TryToGetAttributeWithDefault<std::string>(kModelVer, "0");
verbose_ = TryToGetAttributeWithDefault<std::string>(kVerbose, "0");
auto err = tc::InferenceServerHttpClient::Create(&triton_client_, model_uri_, verbose_ != "0");
OrtStatusPtr status = {};
size_t input_count = {};
status = api_.KernelInfo_GetInputCount(&info_, &input_count);
if (status) {
ORTX_CXX_API_THROW("failed to get input count", ORT_RUNTIME_EXCEPTION);
}
for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
char input_name[1024] = {};
size_t name_size = 1024;
status = api_.KernelInfo_GetInputName(&info_, ith_input, input_name, &name_size);
if (status) {
ORTX_CXX_API_THROW("failed to get input name", ORT_RUNTIME_EXCEPTION);
}
input_names_.push_back(input_name);
}
size_t output_count = {};
status = api_.KernelInfo_GetOutputCount(&info_, &output_count);
if (status) {
ORTX_CXX_API_THROW("failed to get output count", ORT_RUNTIME_EXCEPTION);
}
for (size_t ith_output = 0; ith_output < output_count; ++ith_output) {
char output_name[1024] = {};
size_t name_size = 1024;
status = api_.KernelInfo_GetOutputName(&info_, ith_output, output_name, &name_size);
if (status) {
ORTX_CXX_API_THROW("failed to get output name", ORT_RUNTIME_EXCEPTION);
}
output_names_.push_back(output_name);
}
}
std::string MapDataType(ONNXTensorElementDataType onnx_data_type) {
std::string triton_data_type;
switch (onnx_data_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
triton_data_type = "FP32";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
triton_data_type = "UINT8";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
triton_data_type = "INT8";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
triton_data_type = "UINT16";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
triton_data_type = "INT16";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
triton_data_type = "INT32";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
triton_data_type = "INT64";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
triton_data_type = "BYTES";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
triton_data_type = "BOOL";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
triton_data_type = "FP16";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
triton_data_type = "FP64";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
triton_data_type = "UINT32";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
triton_data_type = "UINT64";
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
triton_data_type = "BF16";
break;
default:
break;
}
return triton_data_type;
}
int8_t* CreateNonStrTensor(const std::string& data_type,
ortc::Variadic& outputs,
size_t i,
const std::vector<int64_t>& shape) {
if (data_type == "FP32") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<float>(i, shape));
} else if (data_type == "UINT8") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint8_t>(i, shape));
} else if (data_type == "INT8") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int8_t>(i, shape));
} else if (data_type == "UINT16") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint16_t>(i, shape));
} else if (data_type == "INT16") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int16_t>(i, shape));
} else if (data_type == "INT32") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int32_t>(i, shape));
} else if (data_type == "UINT32") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint32_t>(i, shape));
} else if (data_type == "INT64") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<int64_t>(i, shape));
} else if (data_type == "UINT64") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<uint64_t>(i, shape));
} else if (data_type == "BOOL") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<bool>(i, shape));
} else if (data_type == "FP64") {
return reinterpret_cast<int8_t*>(outputs.AllocateOutput<double>(i, shape));
} else {
return {};
}
}
#define CHECK_TRITON_ERR(ret, msg) \
if (!ret.IsOk()) { \
return ORTX_CXX_API_THROW("Triton err: " + ret.Message(), ORT_RUNTIME_EXCEPTION); \
}
void AzureTritonInvoker::Compute(const ortc::Variadic& inputs,
ortc::Variadic& outputs) {
if (inputs.Size() < 1 ||
inputs[0]->Type() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
ORTX_CXX_API_THROW("invalid inputs, auto token missing", ORT_RUNTIME_EXCEPTION);
}
if (inputs.Size() != input_names_.size()) {
ORTX_CXX_API_THROW("input count mismatch", ORT_RUNTIME_EXCEPTION);
}
auto auth_token = reinterpret_cast<const char*>(inputs[0]->DataRaw());
std::vector<std::unique_ptr<tc::InferInput>> triton_input_vec;
std::vector<tc::InferInput*> triton_inputs;
std::vector<std::unique_ptr<const tc::InferRequestedOutput>> triton_output_vec;
std::vector<const tc::InferRequestedOutput*> triton_outputs;
tc::Error err;
for (size_t ith_input = 1; ith_input < inputs.Size(); ++ith_input) {
tc::InferInput* triton_input = {};
std::string triton_data_type = MapDataType(inputs[ith_input]->Type());
if (triton_data_type.empty()) {
ORTX_CXX_API_THROW("unknow onnx data type", ORT_RUNTIME_EXCEPTION);
}
err = tc::InferInput::Create(&triton_input, input_names_[ith_input], inputs[ith_input]->Shape(), triton_data_type);
CHECK_TRITON_ERR(err, "failed to create triton input");
triton_input_vec.emplace_back(triton_input);
triton_inputs.push_back(triton_input);
// todo - test string
const float* data_raw = reinterpret_cast<const float*>(inputs[ith_input]->DataRaw());
size_t size_in_bytes = inputs[ith_input]->SizeInBytes();
err = triton_input->AppendRaw(reinterpret_cast<const uint8_t*>(data_raw), size_in_bytes);
CHECK_TRITON_ERR(err, "failed to append raw data to input");
}
for (size_t ith_output = 0; ith_output < output_names_.size(); ++ith_output) {
tc::InferRequestedOutput* triton_output = {};
err = tc::InferRequestedOutput::Create(&triton_output, output_names_[ith_output]);
CHECK_TRITON_ERR(err, "failed to create triton output");
triton_output_vec.emplace_back(triton_output);
triton_outputs.push_back(triton_output);
}
std::unique_ptr<tc::InferResult> results_ptr;
tc::InferResult* results = {};
tc::InferOptions options(model_name_);
options.model_version_ = model_ver_;
options.client_timeout_ = 0;
tc::Headers http_headers;
http_headers["Authorization"] = std::string{"Bearer "} + auth_token;
err = triton_client_->Infer(&results, options, triton_inputs, triton_outputs,
http_headers, tc::Parameters(),
tc::InferenceServerHttpClient::CompressionType::NONE, // support compression in config?
tc::InferenceServerHttpClient::CompressionType::NONE);
results_ptr.reset(results);
CHECK_TRITON_ERR(err, "failed to do triton inference");
size_t output_index = 0;
auto iter = output_names_.begin();
while (iter != output_names_.end()) {
std::vector<int64_t> shape;
err = results_ptr->Shape(*iter, &shape);
CHECK_TRITON_ERR(err, "failed to get output shape");
std::string type;
err = results_ptr->Datatype(*iter, &type);
CHECK_TRITON_ERR(err, "failed to get output type");
if ("BYTES" == type) {
std::vector<std::string> output_strings;
err = results_ptr->StringData(*iter, &output_strings);
CHECK_TRITON_ERR(err, "failed to get output as string");
auto& string_tensor = outputs.AllocateStringTensor(output_index);
string_tensor.SetStringOutput(output_strings, shape);
} else {
const uint8_t* raw_data = {};
size_t raw_size;
err = results_ptr->RawData(*iter, &raw_data, &raw_size);
CHECK_TRITON_ERR(err, "failed to get output raw data");
auto* output_raw = CreateNonStrTensor(type, outputs, output_index, shape);
memcpy(output_raw, raw_data, raw_size);
}
++output_index;
++iter;
}
}
const std::vector<const OrtCustomOp*>& AzureInvokerLoader() {
static OrtOpLoader op_loader(CustomAzureStruct("AzureAudioInvoker", AzureAudioInvoker),
CustomAzureStruct("AzureTritonInvoker", AzureTritonInvoker)
#ifdef TEST_AZURE_INVOKERS_AS_CPU_OP
,CustomCpuStruct("AzureAudioInvoker", AzureAudioInvoker)
,CustomCpuStruct("AzureTritonInvoker", AzureTritonInvoker)
#endif
);
return op_loader.GetCustomOps();
}
#else
const std::vector<const OrtCustomOp*>& AzureInvokerLoader() {
static OrtOpLoader op_loader(CustomAzureStruct("AzureAudioInvoker", AzureAudioInvoker)
#ifdef TEST_AZURE_INVOKERS_AS_CPU_OP
,CustomCpuStruct("AzureAudioInvoker", AzureAudioInvoker)
#endif
);
return op_loader.GetCustomOps();
}
#endif
FxLoadCustomOpFactory LoadCustomOpClasses_Azure = AzureInvokerLoader;
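
For orientation, a minimal standalone sketch of driving the CurlHandler/StringBuffer helpers defined above, roughly as AzureAudioInvoker::Compute does; the URL and token are placeholders, not part of this commit.

// Hypothetical standalone use of the helpers above.
StringBuffer response;
CurlHandler handler(WriteStringCallback);
handler.AddHeader("Authorization: Bearer <token>");         // placeholder token
handler.SetOption(CURLOPT_URL, "https://example.invalid");  // placeholder endpoint
handler.SetOption(CURLOPT_WRITEDATA, (void*)&response);
if (handler.Perform() != CURLE_OK) {
  // curl_easy_strerror(...) gives a readable message, as Compute() shows above
}
std::string body = response.ss_.str();  // accumulated response text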


@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
struct AzureAudioInvoker : public BaseKernel {
AzureAudioInvoker(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Tensor<std::string>& auth_token,
const ortc::Tensor<uint8_t>& raw_audio_data,
ortc::Tensor<std::string>& text);
private:
std::string model_uri_;
std::string model_name_;
bool verbose_;
};
#if ORT_API_VERSION >= 14
struct AzureTritonInvoker : public BaseKernel {
AzureTritonInvoker(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Variadic& inputs,
ortc::Variadic& outputs);
private:
std::string model_uri_;
std::string model_name_;
std::string model_ver_;
std::string verbose_;
std::unique_ptr<triton::client::InferenceServerHttpClient> triton_client_;
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
};
#endif
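
A hedged sketch of consuming these ops from the ONNX Runtime C++ API: RegisterCustomOps is the entry point extended in the last file of this diff, while the model path is a placeholder.

#include "onnxruntime_cxx_api.h"

// RegisterCustomOps is exported by the extensions shared library.
extern "C" OrtStatus* RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api);

int main() {
  Ort::Env env;
  Ort::SessionOptions so;
  // register AzureAudioInvoker/AzureTritonInvoker (and the other custom ops)
  Ort::ThrowOnError(RegisterCustomOps(so, OrtGetApiBase()));
  Ort::Session session(env, "azure_invoker_model.onnx", so);  // placeholder model
  return 0;
}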


@@ -121,6 +121,10 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
#if defined(ENABLE_DR_LIBS)
,
LoadCustomOpClasses_Audio
#endif
#if defined(ENABLE_AZURE)
,
LoadCustomOpClasses_Azure
#endif
};