Refactor the header file directory and integrate the eager tensor implementation (#689)
* refactor the header file in include folder * fix the basic-token eager unit test case * a more flexible way to handle string tensor shape. * fix the unit test path issue * remove the multi-inherits to avoid issue during pointer casting * add api cmake build support * undo some temporary changes * code refinement * fix variadic arg * only expose the context for ort version >= 17 * fix a shape bug * fix the cuda build issue * change ifdef condition of GetAllocator * finalize the ort c abi wrapper file name * fix the iOS build break * align gtest version with triton * Update ext_apple_framework.cmake for iOS header files --------- Co-authored-by: Cheng Tang <chenta@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
This commit is contained in:
Родитель
fe8cd9ee8d
Коммит
646462790b
|
@ -2,7 +2,7 @@
|
|||
# turn off readability-braces-around-statements to allow single line statement like 'if (x == y) doSomething();'
|
||||
Checks: '-*,cppcoreguidelines-*,google-*,readability-*,modernize-*,-readability-braces-around-statements,-google-runtime-references,-cppcoreguidelines-pro-type-reinterpret-cast'
|
||||
WarningsAsErrors: ''
|
||||
HeaderFilterRegex: 'includes\/.*'
|
||||
HeaderFilterRegex: 'include\/.*'
|
||||
AnalyzeTemporaryDtors: false
|
||||
FormatStyle: none
|
||||
CheckOptions:
|
||||
|
|
|
@ -67,6 +67,7 @@ option(OCOS_ENABLE_AZURE "Enable the operators for azure execution provider" OFF
|
|||
|
||||
option(OCOS_ENABLE_STATIC_LIB "Enable generating static library" OFF)
|
||||
option(OCOS_ENABLE_SELECTED_OPLIST "Enable including the selected_ops tool file" OFF)
|
||||
option(OCOS_ENABLE_C_API "Enable building the C API" OFF)
|
||||
|
||||
option(OCOS_BUILD_PYTHON "Enable building the Python package" OFF)
|
||||
option(OCOS_BUILD_JAVA "Enable building the Java package" OFF)
|
||||
|
@ -81,7 +82,8 @@ set(OCOS_ONNXRUNTIME_VERSION "" CACHE STRING
|
|||
"The version of ONNX Runtime being used in the build. Format is <major>.<minor>.<patch>. e.g. 1.15.1")
|
||||
set(OCOS_ONNXRUNTIME_PKG_URI "" CACHE STRING
|
||||
"Specify the onnxruntime C++ shared library zip package path, like ./onnxruntime-win-x64-1.16.0.zip")
|
||||
|
||||
set(OCOS_BUILD_PRESET "" CACHE STRING
|
||||
"Specify the build preset cmake settings file path, like 'token_api_only' which includes ./cmake/presets/token_api_only.cmake")
|
||||
# TODO: Remove the following statements if AzureOp build is enabled by default.
|
||||
# If build_buildid environment varaible is set, which means this is a CI build, then always enable AzureOp.
|
||||
# or it is enabled when OCOS_ENABLE_AZURE is set, which means the user explicitly enables it.
|
||||
|
@ -188,15 +190,26 @@ if(NOT PROJECT_IS_TOP_LEVEL AND ONNXRUNTIME_ROOT)
|
|||
set(_ONNXRUNTIME_EMBEDDED TRUE)
|
||||
endif()
|
||||
|
||||
if(OCOS_ENABLE_SELECTED_OPLIST)
|
||||
# Need to ensure _selectedoplist.cmake file is already generated in folder: ${PROJECT_SOURCE_DIR}/cmake/
|
||||
# You could run gen_selectedops.py in folder: tools/ to generate _selectedoplist.cmake
|
||||
message(STATUS "Looking for the _selectedoplist.cmake")
|
||||
|
||||
if (OCOS_ENABLE_SELECTED_OPLIST OR OCOS_BUILD_PRESET)
|
||||
disable_all_operators()
|
||||
include(_selectedoplist)
|
||||
# Include the selected_ops case, no way to run the unit tests, so disable it,
|
||||
# even the user explicitly set it to ON. (it is rare, most of the time, it is set by default)
|
||||
set(OCOS_ENABLE_CTEST OFF CACHE BOOL "" FORCE)
|
||||
if(OCOS_ENABLE_SELECTED_OPLIST)
|
||||
# Need to ensure _selectedoplist.cmake file is already generated in folder: ${PROJECT_SOURCE_DIR}/cmake/
|
||||
# You could run gen_selectedops.py in folder: tools/ to generate _selectedoplist.cmake
|
||||
message(STATUS "Looking for the _selectedoplist.cmake")
|
||||
include(_selectedoplist)
|
||||
# Include the selected_ops case, no way to run the unit tests, so disable it,
|
||||
# even the user explicitly set it to ON. (it is rare, most of the time, it is set by default)
|
||||
set(OCOS_ENABLE_CTEST OFF CACHE BOOL "" FORCE)
|
||||
endif()
|
||||
if (OCOS_BUILD_PRESET)
|
||||
set(_BUILD_PRESET "${PROJECT_SOURCE_DIR}/cmake/presets/${OCOS_BUILD_PRESET}.cmake")
|
||||
if (EXISTS ${_BUILD_PRESET})
|
||||
include(${_BUILD_PRESET})
|
||||
else()
|
||||
message(FATAL_ERROR "The specified build preset file does not exist: ${_BUILD_PRESET}")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(_OCOS_EXCEPTIONS_REQUIRED OFF)
|
||||
|
@ -300,7 +313,7 @@ endif()
|
|||
|
||||
# ### scan all source files
|
||||
file(GLOB TARGET_SRC_NOEXCEPTION "base/*.h" "base/*.cc")
|
||||
file(GLOB TARGET_SRC "operators/*.cc" "operators/*.h" "includes/*.h*")
|
||||
file(GLOB TARGET_SRC "operators/*.cc" "operators/*.h" "include/*.h" "include/*.hpp")
|
||||
|
||||
if(OCOS_ENABLE_TF_STRING)
|
||||
set(farmhash_SOURCE_DIR ${PROJECT_SOURCE_DIR}/cmake/externals/farmhash)
|
||||
|
@ -551,13 +564,15 @@ standardize_output_folder(ocos_operators)
|
|||
|
||||
target_include_directories(noexcep_operators PUBLIC
|
||||
${ONNXRUNTIME_INCLUDE_DIR}
|
||||
${PROJECT_SOURCE_DIR}/includes
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/include/custom_op
|
||||
${PROJECT_SOURCE_DIR}/base
|
||||
${PROJECT_SOURCE_DIR}/operators)
|
||||
|
||||
target_include_directories(ocos_operators PUBLIC
|
||||
${ONNXRUNTIME_INCLUDE_DIR}
|
||||
${PROJECT_SOURCE_DIR}/includes
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/include/custom_op
|
||||
${PROJECT_SOURCE_DIR}/base
|
||||
${PROJECT_SOURCE_DIR}/operators)
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
include *.txt
|
||||
global-include *.def
|
||||
recursive-include cmake *.*
|
||||
recursive-include includes *.*
|
||||
recursive-include include *.*
|
||||
recursive-include operators *.*
|
||||
recursive-include pyop *.*
|
||||
recursive-include shared *.*
|
||||
|
|
|
@ -190,5 +190,4 @@ uint64_t Hash64Fast(const char* data, size_t n) {
|
|||
return static_cast<int64_t>(util::Fingerprint64(data, n));
|
||||
}
|
||||
|
||||
|
||||
#endif // ENABLE_TF_STRING
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#pragma once
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include "onnxruntime_cpp_api_legacy.hpp"
|
||||
#include "ort_c_to_cpp.h"
|
||||
|
||||
template <typename T>
|
||||
inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
|
||||
|
|
|
@ -15,7 +15,9 @@ set(APPLE_FRAMEWORK_VERSION "${VERSION}")
|
|||
|
||||
# public header files
|
||||
set(APPLE_FRAMEWORK_HEADERS
|
||||
"${PROJECT_SOURCE_DIR}/includes/onnxruntime_extensions.h")
|
||||
"${PROJECT_SOURCE_DIR}/include/onnxruntime_extensions.h"
|
||||
"${PROJECT_SOURCE_DIR}/include/ortx_tokenizer.h"
|
||||
"${PROJECT_SOURCE_DIR}/include/ortx_op_registry.h")
|
||||
|
||||
# generated framework directory
|
||||
set(APPLE_FRAMEWORK_DIRECTORY
|
||||
|
|
|
@ -47,16 +47,12 @@ function(add_test_target)
|
|||
# add a test executable
|
||||
|
||||
add_executable(${ARG_TARGET})
|
||||
|
||||
standardize_output_folder(${ARG_TARGET})
|
||||
|
||||
add_test(NAME ${ARG_TARGET}
|
||||
COMMAND ${ARG_TARGET})
|
||||
|
||||
target_sources(${ARG_TARGET} PRIVATE
|
||||
${ARG_TEST_SOURCES}
|
||||
"${TEST_SRC_DIR}/unittest_main/test_main.cc")
|
||||
|
||||
target_link_libraries(${ARG_TARGET} PRIVATE
|
||||
${ARG_LIBRARIES}
|
||||
gtest gmock)
|
||||
|
@ -132,6 +128,7 @@ file(GLOB static_TEST_SRC "${TEST_SRC_DIR}/static_test/*.cc")
|
|||
add_test_target(TARGET ocos_test
|
||||
TEST_SOURCES ${static_TEST_SRC}
|
||||
LIBRARIES ortcustomops ${ocos_libraries})
|
||||
target_compile_definitions(ocos_test PRIVATE ${OCOS_COMPILE_DEFINITIONS})
|
||||
|
||||
# -- shared test (needs onnxruntime) --
|
||||
SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
|
||||
|
@ -201,6 +198,20 @@ else()
|
|||
)
|
||||
endif()
|
||||
endblock()
|
||||
|
||||
block()
|
||||
file(GLOB tokenizer_TEST_SRC
|
||||
"${TEST_SRC_DIR}/tokenizer_test/*.cc"
|
||||
"${TEST_SRC_DIR}/tokenizer_test/*.hpp")
|
||||
|
||||
add_test_target(TARGET tokenizer_api_test
|
||||
TEST_SOURCES ${tokenizer_TEST_SRC}
|
||||
LIBRARIES onnxruntime_extensions ${ocos_libraries}
|
||||
TEST_DATA_DIRECTORIES ${TEST_SRC_DIR}/data)
|
||||
|
||||
target_compile_definitions(tokenizer_api_test PRIVATE ${OCOS_COMPILE_DEFINITIONS})
|
||||
|
||||
endblock()
|
||||
endif()
|
||||
|
||||
endif()
|
|
@ -1,7 +1,7 @@
|
|||
FetchContent_Declare(
|
||||
googletest
|
||||
GIT_REPOSITORY https://github.com/google/googletest.git
|
||||
GIT_TAG release-1.11.0
|
||||
URL https://github.com/google/googletest/archive/9406a60c7839052e4944ea4dbc8344762a89f9bd.zip
|
||||
URL_HASH SHA1=06096d3900c356e468ba060a609642c635131106
|
||||
)
|
||||
|
||||
FetchContent_MakeAvailable(googletest)
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
set(OCOS_ENABLE_GPT2_TOKENIZER ON CACHE INTERNAL "" FORCE)
|
||||
set(OCOS_ENABLE_C_API ON CACHE INTERNAL "" FORCE)
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,35 @@
|
|||
#pragma once
|
||||
#include <optional>
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
|
||||
namespace Ort {
|
||||
namespace Custom {
|
||||
|
||||
// this is for the ORT custom op template magic
|
||||
struct Arg {
|
||||
virtual ~Arg() = default;
|
||||
};
|
||||
|
||||
class KernelContext : public Arg{
|
||||
public:
|
||||
virtual void* AllocScratchBuffer(size_t size) = 0;
|
||||
virtual void FreeScratchBuffer(void* p) = 0;
|
||||
// TODO: threadpool?
|
||||
};
|
||||
|
||||
#ifdef USE_CUDA
|
||||
class CUDAKernelContext : public KernelContext {
|
||||
public:
|
||||
virtual void* AllocCudaScratchBuffer(size_t size) = 0;
|
||||
virtual void FreeCudaScratchBuffer(void* p) = 0;
|
||||
virtual void* GetCudaStream() const = 0;
|
||||
virtual void* GetCublasHandle() const = 0;
|
||||
virtual int GetCudaDeviceId() const = 0;
|
||||
};
|
||||
#endif
|
||||
|
||||
// TODO: helper func to create context from global ORT env.
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,534 @@
|
|||
#pragma once
|
||||
#include <optional>
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
#include "onnxruntime_f16.h"
|
||||
#include "kernel_context.h"
|
||||
|
||||
namespace Ort {
|
||||
namespace Custom {
|
||||
|
||||
template <typename T>
|
||||
struct Span {
|
||||
const T* data_ = {};
|
||||
size_t size_ = {};
|
||||
void Assign(const T* data, size_t size) {
|
||||
data_ = data;
|
||||
size_ = size;
|
||||
}
|
||||
size_t size() const { return size_; }
|
||||
T operator[](size_t indice) const {
|
||||
return data_[indice];
|
||||
}
|
||||
const T* data() const { return data_; }
|
||||
};
|
||||
|
||||
|
||||
#if ORT_API_VERSION >= 16
|
||||
|
||||
template <>
|
||||
struct Span<MFloat16> {
|
||||
const MFloat16* data_ = {};
|
||||
size_t size_ = {};
|
||||
void Assign(const MFloat16* data, size_t size) {
|
||||
data_ = data;
|
||||
size_ = size;
|
||||
}
|
||||
size_t size() const { return size_; }
|
||||
MFloat16 operator[](size_t indice) const {
|
||||
return data_[indice];
|
||||
}
|
||||
const MFloat16* data() const { return data_; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Span<BFloat16> {
|
||||
const BFloat16* data_ = {};
|
||||
size_t size_ = {};
|
||||
void Assign(const BFloat16* data, size_t size) {
|
||||
data_ = data;
|
||||
size_ = size;
|
||||
}
|
||||
size_t size() const { return size_; }
|
||||
BFloat16 operator[](size_t indice) const {
|
||||
return data_[indice];
|
||||
}
|
||||
const BFloat16* data() const { return data_; }
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
class ITensorStorage{
|
||||
public:
|
||||
virtual const std::vector<int64_t>& Shape() const = 0;
|
||||
virtual const void* DataRaw() const = 0;
|
||||
virtual bool IsInitialized() const = 0;
|
||||
virtual void* Initialize(const std::vector<int64_t>& shape, size_t element_size) = 0;
|
||||
};
|
||||
|
||||
|
||||
class IAllocator {
|
||||
public:
|
||||
virtual void* Alloc(size_t size) = 0;
|
||||
virtual void Free(void* p) = 0;
|
||||
};
|
||||
|
||||
|
||||
class OrtEagerTensorStorage : public ITensorStorage {
|
||||
public:
|
||||
OrtEagerTensorStorage(const std::vector<int64_t>& shape,
|
||||
void* buffer) : buffer_(buffer), shape_(shape){
|
||||
|
||||
}
|
||||
|
||||
OrtEagerTensorStorage(IAllocator* allocator) : allocator_(allocator){
|
||||
}
|
||||
|
||||
virtual ~OrtEagerTensorStorage(){
|
||||
if (allocator_ && buffer_)
|
||||
allocator_->Free(buffer_);
|
||||
}
|
||||
|
||||
const std::vector<int64_t>& Shape() const override {
|
||||
if (!IsInitialized())
|
||||
ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
|
||||
return *shape_;
|
||||
}
|
||||
|
||||
virtual bool IsInitialized() const override {
|
||||
return shape_.has_value();
|
||||
}
|
||||
|
||||
const void* DataRaw() const override {
|
||||
return buffer_;
|
||||
}
|
||||
|
||||
void* Initialize(const std::vector<int64_t>& shape, size_t element_size) override {
|
||||
if (IsInitialized())
|
||||
return buffer_;
|
||||
assert(allocator_);
|
||||
shape_ = shape;
|
||||
int64_t n_elem = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
|
||||
auto buffer_size = n_elem * element_size;
|
||||
buffer_ = allocator_->Alloc(buffer_size);
|
||||
return buffer_;
|
||||
}
|
||||
|
||||
private:
|
||||
void* buffer_ {};
|
||||
std::optional<std::vector<int64_t>> shape_;
|
||||
// caller need to make sure the allocator is alive
|
||||
IAllocator* allocator_;
|
||||
};
|
||||
|
||||
template <typename TT>
|
||||
ONNXTensorElementDataType GetOrtDType(){
|
||||
if constexpr (std::is_same<TT, bool>::value)
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
||||
else if constexpr (std::is_same<TT, float>::value)
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
else if constexpr (std::is_same<TT, double>::value)
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
|
||||
else if constexpr (std::is_same<TT, uint8_t>::value)
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||
else if constexpr (std::is_same<TT, int8_t>::value)
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
|
||||
else if constexpr (std::is_same<TT, uint16_t>::value)
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
|
||||
else if constexpr (std::is_same<TT, int16_t>::value)
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
|
||||
else if constexpr (std::is_same<TT, uint32_t>::value)
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
|
||||
else if constexpr (std::is_same<TT, int32_t>::value)
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
|
||||
else if constexpr (std::is_same<TT, uint64_t>::value)
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
|
||||
else if constexpr (std::is_same<TT, int64_t>::value)
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
else if constexpr (std::is_same<TT, std::string>::value)
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
ORTX_CXX_API_THROW("Unexpected type", ORT_RUNTIME_EXCEPTION);
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
|
||||
class TensorBase : public Arg {
|
||||
public:
|
||||
virtual ~TensorBase() {}
|
||||
|
||||
virtual ONNXTensorElementDataType Type() const = 0;
|
||||
virtual const std::vector<int64_t>& Shape() const = 0;
|
||||
virtual int64_t NumberOfElement() const = 0;
|
||||
virtual const void* DataRaw() const = 0;
|
||||
virtual size_t SizeInBytes() const = 0;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Tensor : public TensorBase {
|
||||
public:
|
||||
using TT = typename std::remove_reference<T>::type;
|
||||
Tensor(std::unique_ptr<ITensorStorage> tensor_storage) : storage_(std::move(tensor_storage)){
|
||||
}
|
||||
|
||||
Tensor(const std::vector<int64_t>& shape, void* buffer) : Tensor(std::make_unique<OrtEagerTensorStorage>(shape, buffer)) {}
|
||||
|
||||
Tensor(IAllocator* allocator) : storage_(std::make_unique<OrtEagerTensorStorage>(allocator)){}
|
||||
|
||||
virtual ~Tensor() = default;
|
||||
|
||||
operator bool() const {
|
||||
return storage_->IsInitialized();
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType Type() const override {
|
||||
return GetOrtDType<T>();
|
||||
}
|
||||
|
||||
const std::vector<int64_t>& Shape() const override {
|
||||
return storage_->Shape();
|
||||
}
|
||||
|
||||
int64_t NumberOfElement() const override {
|
||||
auto& shape = storage_->Shape();
|
||||
return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
|
||||
}
|
||||
|
||||
std::string Shape2Str() const {
|
||||
if (storage_->IsInitialized()) {
|
||||
std::string shape_str;
|
||||
auto& shape = storage_->Shape();
|
||||
for (const auto& dim : shape) {
|
||||
shape_str.append(std::to_string(dim));
|
||||
shape_str.append(", ");
|
||||
}
|
||||
return shape_str;
|
||||
} else {
|
||||
return "empty";
|
||||
}
|
||||
}
|
||||
|
||||
const TT* Data() const {
|
||||
#if ORT_API_VERSION >= 16
|
||||
if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
|
||||
return reinterpret_cast<const TT*>(storage_->DataRaw());
|
||||
else
|
||||
#endif
|
||||
return static_cast<const TT*>(storage_->DataRaw());
|
||||
}
|
||||
|
||||
const void* DataRaw() const override {
|
||||
return storage_->DataRaw();
|
||||
}
|
||||
|
||||
size_t SizeInBytes() const override {
|
||||
return NumberOfElement() * sizeof(TT);
|
||||
}
|
||||
|
||||
TT* Allocate(const std::vector<int64_t>& shape) {
|
||||
// it should be OK to allocate multiple times
|
||||
void* buffer = storage_->Initialize(shape, sizeof(TT));
|
||||
#if ORT_API_VERSION >= 16
|
||||
if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
|
||||
return reinterpret_cast<TT*>(buffer);
|
||||
else
|
||||
#endif
|
||||
return static_cast<TT*>(buffer);
|
||||
}
|
||||
|
||||
const Span<T>& AsSpan() {
|
||||
#if ORT_API_VERSION >= 16
|
||||
if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
|
||||
ORTX_CXX_API_THROW("AsSpan for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
else{
|
||||
#endif
|
||||
auto& shape = storage_->Shape();
|
||||
if (shape.size() != 1) {
|
||||
ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
span_.Assign(Data(), shape[0]);
|
||||
return span_;
|
||||
#if ORT_API_VERSION >= 16
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
const T& AsScalar() {
|
||||
#if ORT_API_VERSION >= 16
|
||||
if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
|
||||
ORTX_CXX_API_THROW("AsScalar for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
else{
|
||||
#endif
|
||||
auto& shape = storage_->Shape();
|
||||
if ((shape.size() == 1 && shape[0] != 1) || shape.size() > 1) {
|
||||
ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
return *Data();
|
||||
#if ORT_API_VERSION >= 16
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<ITensorStorage> storage_;
|
||||
Span<T> span_;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
class IStringTensorStorage{
|
||||
public:
|
||||
using strings = std::vector<T>;
|
||||
virtual const std::vector<int64_t>& Shape() const = 0;
|
||||
virtual const void* DataRaw() const = 0;
|
||||
virtual const strings& Data() const = 0;
|
||||
virtual bool IsInitialized() const = 0;
|
||||
virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) = 0;
|
||||
virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) = 0;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
class EagerStringTensorStorage : public IStringTensorStorage<T>{
|
||||
public:
|
||||
using strings = std::vector<T>;
|
||||
EagerStringTensorStorage(const strings& ss) : input_strings_(ss), shape_(std::vector<int64_t>{static_cast<int64_t>(ss.size())}){}
|
||||
|
||||
EagerStringTensorStorage() {}
|
||||
|
||||
const std::vector<int64_t>& Shape() const override {
|
||||
if (!IsInitialized())
|
||||
ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
|
||||
return *shape_;
|
||||
}
|
||||
|
||||
virtual const void* DataRaw() const override {
|
||||
if (input_strings_.size() != 1) {
|
||||
ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
if constexpr (std::is_same<std::string_view, T>::value)
|
||||
return reinterpret_cast<const void*>(input_strings_[0].data());
|
||||
else
|
||||
return reinterpret_cast<const void*>(input_strings_[0].c_str());
|
||||
}
|
||||
|
||||
virtual bool IsInitialized() const override {
|
||||
return shape_.has_value();
|
||||
}
|
||||
|
||||
virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override {
|
||||
if constexpr (std::is_same<std::string_view, T>::value)
|
||||
ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
|
||||
input_strings_.assign(ss.begin(), ss.end());
|
||||
shape_ = dims;
|
||||
}
|
||||
|
||||
const strings& Data() const override {
|
||||
return input_strings_;
|
||||
}
|
||||
|
||||
virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override {
|
||||
if constexpr (std::is_same<std::string_view, T>::value)
|
||||
ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
|
||||
|
||||
for (const char* s : ss){
|
||||
input_strings_.push_back(s);
|
||||
}
|
||||
shape_ = dims;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<T> input_strings_;
|
||||
std::optional<std::vector<int64_t>> shape_;
|
||||
};
|
||||
|
||||
template <>
|
||||
class Tensor<std::string> : public TensorBase {
|
||||
public:
|
||||
using strings = std::vector<std::string>;
|
||||
|
||||
Tensor(std::unique_ptr<IStringTensorStorage<std::string>> storage) : storage_(std::move(storage)) {}
|
||||
|
||||
Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string>>(ss)) {}
|
||||
|
||||
Tensor() : storage_(std::make_unique<EagerStringTensorStorage<std::string>>()) {}
|
||||
|
||||
ONNXTensorElementDataType Type() const override {
|
||||
return GetOrtDType<std::string>();
|
||||
}
|
||||
|
||||
const strings& Data() const {
|
||||
return storage_->Data();
|
||||
}
|
||||
|
||||
const std::vector<int64_t>& Shape() const override {
|
||||
return storage_->Shape();
|
||||
}
|
||||
|
||||
int64_t NumberOfElement() const override {
|
||||
auto& shape = storage_->Shape();
|
||||
return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
|
||||
}
|
||||
|
||||
std::string Shape2Str() const {
|
||||
if (storage_->IsInitialized()) {
|
||||
std::string shape_str;
|
||||
auto& shape = storage_->Shape();
|
||||
for (const auto& dim : shape) {
|
||||
shape_str.append(std::to_string(dim));
|
||||
shape_str.append(", ");
|
||||
}
|
||||
return shape_str;
|
||||
} else {
|
||||
return "empty";
|
||||
}
|
||||
}
|
||||
|
||||
const void* DataRaw() const override {
|
||||
return storage_->DataRaw();
|
||||
}
|
||||
|
||||
size_t SizeInBytes() const override {
|
||||
auto& ss = storage_->Data();
|
||||
if (ss.size() != 1) {
|
||||
ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
return ss[0].size();
|
||||
}
|
||||
|
||||
void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
|
||||
storage_->SetStringOutput(ss, dims);
|
||||
}
|
||||
void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
|
||||
storage_->SetStringOutput(ss, dims);
|
||||
}
|
||||
const Span<std::string>& AsSpan() {
|
||||
ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
const std::string& AsScalar() {
|
||||
auto& ss = storage_->Data();
|
||||
if (ss.size() != 1) {
|
||||
ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
return ss[0];
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<IStringTensorStorage<std::string>> storage_;
|
||||
};
|
||||
|
||||
|
||||
template <>
|
||||
class Tensor<std::string_view> : public TensorBase {
|
||||
public:
|
||||
using strings = std::vector<std::string_view>;
|
||||
|
||||
Tensor(std::unique_ptr<IStringTensorStorage<std::string_view>> storage) : storage_(std::move(storage)) {}
|
||||
|
||||
Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string_view>>(ss)) {}
|
||||
|
||||
ONNXTensorElementDataType Type() const override {
|
||||
return GetOrtDType<std::string_view>();
|
||||
}
|
||||
|
||||
const strings& Data() const {
|
||||
return storage_->Data();
|
||||
}
|
||||
|
||||
const std::vector<int64_t>& Shape() const override {
|
||||
return storage_->Shape();
|
||||
}
|
||||
|
||||
int64_t NumberOfElement() const override {
|
||||
auto& shape = storage_->Shape();
|
||||
return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
|
||||
}
|
||||
|
||||
std::string Shape2Str() const {
|
||||
if (storage_->IsInitialized()) {
|
||||
std::string shape_str;
|
||||
auto& shape = storage_->Shape();
|
||||
for (const auto& dim : shape) {
|
||||
shape_str.append(std::to_string(dim));
|
||||
shape_str.append(", ");
|
||||
}
|
||||
return shape_str;
|
||||
} else {
|
||||
return "empty";
|
||||
}
|
||||
}
|
||||
|
||||
const void* DataRaw() const override {
|
||||
return storage_->DataRaw();
|
||||
}
|
||||
|
||||
size_t SizeInBytes() const override {
|
||||
auto& ss = storage_->Data();
|
||||
if (ss.size() != 1) {
|
||||
ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
return ss[0].size();
|
||||
}
|
||||
|
||||
void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
|
||||
storage_->SetStringOutput(ss, dims);
|
||||
}
|
||||
void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
|
||||
storage_->SetStringOutput(ss, dims);
|
||||
}
|
||||
const Span<std::string_view>& AsSpan() {
|
||||
ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
const std::string_view& AsScalar() {
|
||||
auto& ss = storage_->Data();
|
||||
if (ss.size() != 1) {
|
||||
ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
return ss[0];
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<IStringTensorStorage<std::string_view>> storage_;
|
||||
};
|
||||
|
||||
|
||||
template<typename ...Args>
|
||||
class NamedArgumentDict{
|
||||
public:
|
||||
using ValueTuple = std::tuple<Args...>;
|
||||
|
||||
NamedArgumentDict(const std::vector<const char*>& keys, const std::tuple<Args...>& args) : entries_(args) {
|
||||
for (const char* key : keys){
|
||||
names_.push_back(key);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T TryToGetAttributeWithDefault(const char* name, const T& default_value) const {
|
||||
return TryToGetAttributeWithDefaultInternal<0>(name, default_value);
|
||||
}
|
||||
|
||||
private:
|
||||
template<size_t I, typename T>
|
||||
typename std::enable_if<I == sizeof...(Args), T>::type
|
||||
TryToGetAttributeWithDefaultInternal(const char* name, const T& default_value) const {
|
||||
return default_value;
|
||||
}
|
||||
|
||||
template<size_t I, typename T>
|
||||
typename std::enable_if<I < sizeof...(Args), T>::type
|
||||
TryToGetAttributeWithDefaultInternal(const char* name, const T& default_value) const {
|
||||
if (names_[I] == name){
|
||||
if constexpr (std::is_same<std::tuple_element_t<I, ValueTuple>, T>::value)
|
||||
return std::get<I>(entries_);
|
||||
else
|
||||
throw std::runtime_error("name matched but type is not");
|
||||
}
|
||||
return TryToGetAttributeWithDefaultInternal<I+1>(name, default_value);
|
||||
}
|
||||
|
||||
std::vector<std::string> names_;
|
||||
std::tuple<Args...> entries_;
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,137 @@
|
|||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
||||
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type_def>*>::value, std::tuple<T, Ts...>>::type
|
||||
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
|
||||
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
}
|
||||
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
||||
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type_def>&>::value, std::tuple<T, Ts...>>::type
|
||||
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
|
||||
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
}
|
||||
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
||||
static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type_def>*>>::value, std::tuple<T, Ts...>>::type
|
||||
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
|
||||
if (ith_input < num_input) {
|
||||
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())};
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
} else {
|
||||
std::tuple<T> current = std::tuple<T>{};
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
}
|
||||
}
|
||||
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
||||
static typename std::enable_if<std::is_same<T, const Custom::Span<data_type_def>*>::value, std::tuple<T, Ts...>>::type
|
||||
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
|
||||
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
|
||||
if (!reinterpret_cast<Custom::OrtTensor<data_type_def>*>(tensors.back().get())->IsCpuTensor()) {
|
||||
ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL);
|
||||
}
|
||||
std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())->AsSpan()};
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
}
|
||||
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
||||
static typename std::enable_if<std::is_same<T, const Custom::Span<data_type_def>&>::value, std::tuple<T, Ts...>>::type
|
||||
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
|
||||
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
|
||||
if (!reinterpret_cast<Custom::OrtTensor<data_type_def>*>(tensors.back().get())->IsCpuTensor()) {
|
||||
ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL);
|
||||
}
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())->AsSpan()};
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
}
|
||||
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
||||
static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type_def>*>>::value, std::tuple<T, Ts...>>::type
|
||||
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
|
||||
if (ith_input < num_input) {
|
||||
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
|
||||
if (!reinterpret_cast<Custom::OrtTensor<data_type_def>*>(tensors.back().get())->IsCpuTensor()) {
|
||||
ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL);
|
||||
}
|
||||
std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())->AsSpan()};
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
} else {
|
||||
std::tuple<T> current = std::tuple<T>{};
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
}
|
||||
}
|
||||
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
||||
static typename std::enable_if<std::is_same<T, data_type_def>::value, std::tuple<T, Ts...>>::type
|
||||
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
|
||||
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
|
||||
if (!reinterpret_cast<Custom::OrtTensor<data_type_def>*>(tensors.back().get())->IsCpuTensor()) {
|
||||
ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL);
|
||||
}
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())->AsScalar()};
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
}
|
||||
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
||||
static typename std::enable_if<std::is_same<T, std::optional<data_type_def>>::value, std::tuple<T, Ts...>>::type
|
||||
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
|
||||
if (ith_input < num_input) {
|
||||
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
|
||||
if (!reinterpret_cast<Custom::OrtTensor<data_type_def>*>(tensors.back().get())->IsCpuTensor()) {
|
||||
ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL);
|
||||
}
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())->AsScalar()};
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
} else {
|
||||
std::tuple<T> current = std::tuple<T>{};
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
}
|
||||
}
|
||||
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
||||
static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type_def>*>::value, std::tuple<T, Ts...>>::type
|
||||
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
|
||||
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_output, false));
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
|
||||
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
}
|
||||
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
||||
static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type_def>&>::value, std::tuple<T, Ts...>>::type
|
||||
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
|
||||
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_output, false));
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
|
||||
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
}
|
||||
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
||||
static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type_def>*>>::value, std::tuple<T, Ts...>>::type
|
||||
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
|
||||
if (ith_output < num_output) {
|
||||
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_output, false));
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())};
|
||||
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
} else {
|
||||
std::tuple<T> current = std::tuple<T>{};
|
||||
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
}
|
||||
}
|
|
@ -10,7 +10,7 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_customop.hpp"
|
||||
#include "op_def_struct.h"
|
||||
|
||||
// A helper API to support test kernels.
|
||||
// Must be invoked before RegisterCustomOps.
|
|
@ -0,0 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ortx_tokenizer.h"
|
||||
|
||||
#include "ortx_op_registry.h"
|
|
@ -18,10 +18,8 @@
|
|||
#include <functional>
|
||||
|
||||
#include "exceptions.h"
|
||||
#include "onnxruntime_no_customop.h"
|
||||
#include "onnxruntime_cpp_api_legacy.hpp"
|
||||
#include "onnxruntime_extensions.h"
|
||||
#include "custom_op_lite.h"
|
||||
#include "custom_op/custom_op_lite.h"
|
||||
|
||||
#define MIN_ORT_VERSION_SUPPORTED 11
|
||||
|
|
@ -5,10 +5,7 @@
|
|||
#include <vector>
|
||||
#include "exceptions.h"
|
||||
|
||||
//
|
||||
// DEPRECATED: All new custom OPs should not use any class/struct/functions from this file.
|
||||
// TODO: Remove this file once all custom OPs are migrated to the new API
|
||||
//
|
||||
// OrtW: ONNX Runtime C ABI Wrapper
|
||||
namespace OrtW {
|
||||
|
||||
struct CustomOpApi {
|
||||
|
@ -30,6 +27,9 @@ struct CustomOpApi {
|
|||
template <typename T>
|
||||
const T* GetTensorData(_Inout_ const OrtValue* value) const;
|
||||
|
||||
void* GetTensorMutableRawData(_Inout_ OrtValue* value) const;
|
||||
const void* GetTensorRawData(_Inout_ const OrtValue* value) const;
|
||||
|
||||
std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
|
||||
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;
|
||||
size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
|
||||
|
@ -48,6 +48,54 @@ struct CustomOpApi {
|
|||
const OrtApi& api_;
|
||||
};
|
||||
|
||||
class API {
|
||||
// To use ONNX C ABI in a way like OrtW::API::CreateStatus.
|
||||
public:
|
||||
static API& instance(const OrtApi* ort_api = nullptr) noexcept {
|
||||
static API self(ort_api);
|
||||
return self;
|
||||
}
|
||||
|
||||
static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
|
||||
return instance()->CreateStatus(code, msg);
|
||||
}
|
||||
|
||||
static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
|
||||
instance()->ReleaseStatus(ptr);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;
|
||||
|
||||
static void ThrowOnError(OrtStatusPtr ptr) {
|
||||
OrtW::ThrowOnError(instance().api_, ptr);
|
||||
}
|
||||
|
||||
// Caller is responsible for releasing OrtMemoryInfo object
|
||||
static OrtStatusPtr CreateOrtMemoryInfo(const char* name, enum OrtAllocatorType type, int id, enum OrtMemType mem_type, OrtMemoryInfo** out) noexcept {
|
||||
return instance()->CreateMemoryInfo(name, type, id, mem_type, out);
|
||||
}
|
||||
#if ORT_API_VERSION >= 15
|
||||
// Caller is responsible for releasing OrtAllocator object: delete static_cast<onnxruntime::OrtAllocatorImpl*> (allocator)
|
||||
static OrtStatusPtr GetOrtAllocator(const OrtKernelContext* context, const OrtMemoryInfo* mem_info, OrtAllocator** out) {
|
||||
return instance()->KernelContext_GetAllocator(context, mem_info, out);
|
||||
}
|
||||
#endif
|
||||
private:
|
||||
const OrtApi* operator->() const {
|
||||
return &api_;
|
||||
}
|
||||
|
||||
API(const OrtApi* api) : api_(*api) {
|
||||
if (api == nullptr) {
|
||||
ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
}
|
||||
|
||||
const OrtApi& api_;
|
||||
};
|
||||
|
||||
|
||||
//
|
||||
// Custom OP API Inlines
|
||||
//
|
||||
|
@ -162,6 +210,16 @@ inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const
|
|||
return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
|
||||
}
|
||||
|
||||
inline void* CustomOpApi::GetTensorMutableRawData(_Inout_ OrtValue* value) const {
|
||||
void* data = nullptr;
|
||||
ThrowOnError(api_.GetTensorMutableData(value, &data));
|
||||
return data;
|
||||
}
|
||||
|
||||
inline const void* CustomOpApi::GetTensorRawData(_Inout_ const OrtValue* value) const {
|
||||
return GetTensorMutableRawData(const_cast<OrtValue*>(value));
|
||||
}
|
||||
|
||||
inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
|
||||
std::vector<int64_t> output(GetDimensionsCount(info));
|
||||
GetDimensions(info, output.data(), output.size());
|
||||
|
@ -197,9 +255,72 @@ inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context,
|
|||
return out;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
|
||||
return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
|
||||
return instance()->KernelInfoGetAttribute_float(&info, name, &value);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline OrtStatusPtr API::KernelInfoGetAttribute<std::string>(const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
|
||||
size_t size = 0;
|
||||
std::string out;
|
||||
// Feed nullptr for the data buffer to query the true size of the string attribute
|
||||
OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size);
|
||||
if (status == nullptr) {
|
||||
out.resize(size);
|
||||
status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size);
|
||||
out.resize(size - 1); // remove the terminating character '\0'
|
||||
}
|
||||
|
||||
if (status == nullptr) {
|
||||
value = std::move(out);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
|
||||
if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
|
||||
// Ideally, we should know which kind of error code can be ignored, but it is not available now.
|
||||
// Just ignore all of them.
|
||||
API::ReleaseStatus(status);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
inline T GetOpAttributeOrDefault(const OrtKernelInfo& info, const char* name, const T& default_value) noexcept {
|
||||
T ret;
|
||||
if (API::KernelInfoGetAttribute(info, name, ret)) {
|
||||
ret = default_value;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
|
||||
return API::CreateStatus(code, msg);
|
||||
}
|
||||
|
||||
inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
|
||||
return API::CreateStatus(code, msg.c_str());
|
||||
}
|
||||
|
||||
inline void ReleaseStatus(OrtStatusPtr& status) {
|
||||
API::ReleaseStatus(status);
|
||||
status = nullptr;
|
||||
}
|
||||
|
||||
} // namespace of OrtW
|
||||
|
||||
|
||||
// Deprecated: No needs to create a new class derived from BaseKernel.
|
||||
struct BaseKernel {
|
||||
BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept
|
||||
: api_(api), info_(info), ort_(api_) {
|
||||
|
@ -226,6 +347,7 @@ struct BaseKernel {
|
|||
const OrtKernelInfo& info_;
|
||||
};
|
||||
|
||||
// Deprecated: Use OrtW::CustomOpApi::KernelInfoGetAttribute instead
|
||||
struct OrtTensorDimensions : std::vector<int64_t> {
|
||||
OrtTensorDimensions() = default;
|
||||
OrtTensorDimensions(const OrtW::CustomOpApi& ort, const OrtValue* value) {
|
|
@ -1,8 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
// Note: The following include path is used for building Swift Package Manager support for ORT Extensions.
|
||||
// The macro is defined in cxxSettings config in Package.swift.
|
||||
// The reason why we need a prefix is that when Xcode includes the package it copies it to an internally generated path with
|
||||
|
@ -15,7 +13,6 @@
|
|||
#include "onnxruntime_c_api.h"
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
|
@ -0,0 +1,199 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// C ABI header file for the onnxruntime-extensions tokenization module
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
|
||||
#if defined(__CYGWIN__) || defined(__MINGW32__)
|
||||
#define ORTX_API_CALL __stdcall
|
||||
#elif defined(_WIN32)
|
||||
#define ORTX_API_CALL _stdcall
|
||||
#define ORTX_MUST_USE_RESULT
|
||||
#elif __APPLE__
|
||||
#define ORTX_API_CALL
|
||||
// To make symbols visible on macOS/iOS
|
||||
#define ORTX_MUST_USE_RESULT __attribute__((warn_unused_result))
|
||||
#else
|
||||
#define ORTX_API_CALL
|
||||
#define ORTX_MUST_USE_RESULT
|
||||
#endif
|
||||
|
||||
typedef enum {
|
||||
kOrtxOK = 0,
|
||||
kOrtxErrorInvalidArgument = 1,
|
||||
kOrtxErrorOutOfMemory = 2,
|
||||
kOrtxErrorInvalidFile = 3,
|
||||
kOrtxErrorNotFound = 4,
|
||||
kOrtxErrorAlreadyExists = 5,
|
||||
kOrtxErrorOutOfRange = 6,
|
||||
kOrtxErrorNotImplemented = 7,
|
||||
kOrtxErrorInternal = 8,
|
||||
kOrtxErrorUnknown = 1000
|
||||
} extError_t;
|
||||
|
||||
typedef enum {
|
||||
kOrtxKindUnknown = 0,
|
||||
|
||||
kOrtxKindBegin = 0x7788, // starting from a number to help validate the object
|
||||
kOrtxKindTokenizer = kOrtxKindBegin,
|
||||
kOrtxKindStringArray = 0x7789,
|
||||
kOrtxKindTokenId2DArray = 0x778A,
|
||||
kOrtxKindDetokenizerCache = 0x778B,
|
||||
kOrtxKindEnd = 0x7999
|
||||
} extObjectKind_t;
|
||||
|
||||
// all object managed by the library should be 'derived' from this struct
|
||||
// which eventually will be released by TfmDispose if C++, or TFM_DISPOSE if C
|
||||
typedef struct {
|
||||
int ext_kind_;
|
||||
} OrtxObject;
|
||||
|
||||
const int API_VERSION = 1;
|
||||
|
||||
// typedefs to create/dispose function flood, and to make the API more C++ friendly with less casting
|
||||
typedef OrtxObject OrtxTokenizer;
|
||||
typedef OrtxObject OrtxStringArray;
|
||||
typedef OrtxObject OrtxTokenId2DArray;
|
||||
typedef OrtxObject OrtxDetokenizerCache;
|
||||
|
||||
// C, instead of C++ doesn't cast automatically,
|
||||
// so we need to use a macro to cast the object to the correct type
|
||||
#define ORTX_DISPOSE(obj) OrtxDispose((OrtxObject**)&obj)
|
||||
|
||||
typedef uint32_t extTokenId_t;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/** \brief Get the current C ABI version of this library
|
||||
*
|
||||
* \snippet{doc} snippets.dox int Return Value
|
||||
*/
|
||||
int ORTX_API_CALL OrtxGetAPIVersion(void);
|
||||
|
||||
/** \brief Get the last error message generated by the library
|
||||
*
|
||||
* \param message Pointer to store the last error message
|
||||
* \return Pointer to the last error message
|
||||
*/
|
||||
const char* ORTX_API_CALL OrtxGetLastErrorMessage(void);
|
||||
|
||||
/** \brief Create a new object of the specified kind
|
||||
*
|
||||
* \param kind The kind of object to create
|
||||
* \param object Pointer to store the created object
|
||||
* \param ... Additional arguments based on the kind of object
|
||||
* \return Error code indicating the success or failure of the operation
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxCreate(extObjectKind_t kind, OrtxObject** object, ...);
|
||||
|
||||
/** \brief Dispose the specified object
|
||||
*
|
||||
* \param object Pointer to the object to dispose
|
||||
* \return Error code indicating the success or failure of the operation
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object);
|
||||
|
||||
/** \brief Dispose the specified object
|
||||
*
|
||||
* \param object Pointer to the object to dispose
|
||||
* \return Error code indicating the success or failure of the operation
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object);
|
||||
|
||||
|
||||
/** \brief Create a tokenizer object with the specified tokenizer path
|
||||
*
|
||||
* \param tokenizer Pointer to store the created tokenizer object
|
||||
* \param tokenizer_path The path to the tokenizer
|
||||
* \return Error code indicating the success or failure of the operation
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxCreateTokenizer(OrtxTokenizer** tokenizer, const char* tokenizer_path);
|
||||
|
||||
/** \brief Tokenize the input using the specified tokenizer
|
||||
*
|
||||
* \param tokenizer Pointer to the tokenizer object
|
||||
* \param input Array of input strings
|
||||
* \param batch_size Number of input strings in the batch
|
||||
* \param output Pointer to store the tokenized result
|
||||
* \return Error code indicating the success or failure of the operation
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxTokenize(
|
||||
const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size, OrtxTokenId2DArray** output);
|
||||
|
||||
/** \brief Detokenize the input using the specified tokenizer
|
||||
*
|
||||
* \param tokenizer Pointer to the tokenizer object
|
||||
* \param input Pointer to the input token IDs
|
||||
* \param output Pointer to store the detokenized result
|
||||
* \return Error code indicating the success or failure of the operation
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxDetokenize(
|
||||
const OrtxTokenizer* tokenizer, const OrtxTokenId2DArray* input, OrtxStringArray** output);
|
||||
|
||||
/** \brief Detokenize the input using the specified tokenizer (1D version)
|
||||
*
|
||||
* \param tokenizer Pointer to the tokenizer object
|
||||
* \param input Pointer to the input token IDs
|
||||
* \param len Length of the input token IDs array
|
||||
* \param output Pointer to store the detokenized result
|
||||
* \return Error code indicating the success or failure of the operation
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxDetokenize1D(
|
||||
const OrtxTokenizer* tokenizer, const extTokenId_t* input, size_t len, OrtxStringArray** output);
|
||||
|
||||
/** \brief Detokenize the input using the specified tokenizer with caching
|
||||
*
|
||||
* \param tokenizer Pointer to the tokenizer object
|
||||
* \param cache Pointer to the detokenizer cache
|
||||
* \param next_id Next token ID to detokenize
|
||||
* \param text_out Pointer to store the detokenized text
|
||||
* \return Error code indicating the success or failure of the operation
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxDetokenizeCached(
|
||||
const OrtxTokenizer* tokenizer, OrtxDetokenizerCache* cache, extTokenId_t next_id, const char** text_out);
|
||||
|
||||
/** \brief Get the length of the string array
|
||||
*
|
||||
* \param string_array Pointer to the string array
|
||||
* \param length Pointer to store the length of the string array
|
||||
* \return Error code indicating the success or failure of the operation
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxStringArrayGetBatch(const OrtxStringArray* string_array, size_t* length);
|
||||
|
||||
/** \brief Get the item at the specified index from the string array
|
||||
*
|
||||
* \param string_array Pointer to the string array
|
||||
* \param index Index of the item to retrieve
|
||||
* \param item Pointer to store the retrieved item
|
||||
* \return Error code indicating the success or failure of the operation
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxStringArrayGetItem(const OrtxStringArray* string_array, size_t index, const char** item);
|
||||
|
||||
/** \brief Get the batch size of the token ID 2D array
|
||||
*
|
||||
* \param token_id_2d_array Pointer to the token ID 2D array
|
||||
* \param length Pointer to store the batch size
|
||||
* \return Error code indicating the success or failure of the operation
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxTokenId2DArrayGetBatch(const OrtxTokenId2DArray* token_id_2d_array, size_t* length);
|
||||
|
||||
/** \brief Get the item at the specified index from the token ID 2D array
|
||||
*
|
||||
* \param token_id_2d_array Pointer to the token ID 2D array
|
||||
* \param index Index of the item to retrieve
|
||||
* \param item Pointer to store the retrieved item
|
||||
* \param length Pointer to store the length of the item
|
||||
* \return Error code indicating the success or failure of the operation
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(
|
||||
const OrtxTokenId2DArray* token_id_2d_array, size_t index, const extTokenId_t** item, size_t* length);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -1,122 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// This file defines API which depends on ONNXRuntime, but not including Custom Op and related facilities
|
||||
// Custom Op and related classes, functions and macros are in onnxruntime_customop.hpp
|
||||
#pragma once
|
||||
#include "exceptions.h"
|
||||
|
||||
// namespace of ORT ABI Wrapper
|
||||
namespace OrtW {
|
||||
|
||||
class API {
|
||||
// To use ONNX C ABI in a way like OrtW::API::CreateStatus.
|
||||
public:
|
||||
static API& instance(const OrtApi* ort_api = nullptr) noexcept {
|
||||
static API self(ort_api);
|
||||
return self;
|
||||
}
|
||||
|
||||
static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
|
||||
return instance()->CreateStatus(code, msg);
|
||||
}
|
||||
|
||||
static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
|
||||
instance()->ReleaseStatus(ptr);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;
|
||||
|
||||
static void ThrowOnError(OrtStatusPtr ptr) {
|
||||
OrtW::ThrowOnError(instance().api_, ptr);
|
||||
}
|
||||
|
||||
// Caller is responsible for releasing OrtMemoryInfo object
|
||||
static OrtStatusPtr CreateOrtMemoryInfo(const char* name, enum OrtAllocatorType type, int id, enum OrtMemType mem_type, OrtMemoryInfo** out) noexcept {
|
||||
return instance()->CreateMemoryInfo(name, type, id, mem_type, out);
|
||||
}
|
||||
#if ORT_API_VERSION >= 15
|
||||
// Caller is responsible for releasing OrtAllocator object: delete static_cast<onnxruntime::OrtAllocatorImpl*> (allocator)
|
||||
static OrtStatusPtr GetOrtAllocator(const OrtKernelContext* context, const OrtMemoryInfo* mem_info, OrtAllocator** out) {
|
||||
return instance()->KernelContext_GetAllocator(context, mem_info, out);
|
||||
}
|
||||
#endif
|
||||
private:
|
||||
const OrtApi* operator->() const {
|
||||
return &api_;
|
||||
}
|
||||
|
||||
API(const OrtApi* api) : api_(*api) {
|
||||
if (api == nullptr) {
|
||||
ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
}
|
||||
|
||||
const OrtApi& api_;
|
||||
};
|
||||
|
||||
template <>
|
||||
inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
|
||||
return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
|
||||
return instance()->KernelInfoGetAttribute_float(&info, name, &value);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline OrtStatusPtr API::KernelInfoGetAttribute<std::string>(const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
|
||||
size_t size = 0;
|
||||
std::string out;
|
||||
// Feed nullptr for the data buffer to query the true size of the string attribute
|
||||
OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size);
|
||||
if (status == nullptr) {
|
||||
out.resize(size);
|
||||
status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size);
|
||||
out.resize(size - 1); // remove the terminating character '\0'
|
||||
}
|
||||
|
||||
if (status == nullptr) {
|
||||
value = std::move(out);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
|
||||
if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
|
||||
// Ideally, we should know which kind of error code can be ignored, but it is not available now.
|
||||
// Just ignore all of them.
|
||||
API::ReleaseStatus(status);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
inline T GetOpAttributeOrDefault(const OrtKernelInfo& info, const char* name, const T& default_value) noexcept {
|
||||
T ret;
|
||||
if (API::KernelInfoGetAttribute(info, name, ret)) {
|
||||
ret = default_value;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
|
||||
return API::CreateStatus(code, msg);
|
||||
}
|
||||
|
||||
inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
|
||||
return API::CreateStatus(code, msg.c_str());
|
||||
}
|
||||
|
||||
inline void ReleaseStatus(OrtStatusPtr& status) {
|
||||
API::ReleaseStatus(status);
|
||||
status = nullptr;
|
||||
}
|
||||
|
||||
} // namespace OrtW
|
||||
|
|
@ -4,13 +4,12 @@
|
|||
#pragma once
|
||||
#include "onnxruntime_f16.h"
|
||||
#include "string_utils.h"
|
||||
#include "onnxruntime_no_customop.h"
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
using namespace Ort::Custom;
|
||||
namespace ortc = Ort::Custom;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
|
||||
__device__ __forceinline__ half operator+(const half& lh, const half& rh) { return half((float)lh + (float)rh); }
|
||||
|
@ -97,81 +96,81 @@ __device__ __forceinline__ half2 operator/(const half2& lh, const half2& rh) {
|
|||
}
|
||||
#endif
|
||||
|
||||
/// Arithmetic for BFloat16
|
||||
/// Arithmetic for ortc::BFloat16
|
||||
|
||||
__device__ __forceinline__ BFloat16 operator+(const BFloat16& a, const BFloat16& b) {
|
||||
__device__ __forceinline__ ortc::BFloat16 operator+(const ortc::BFloat16& a, const ortc::BFloat16& b) {
|
||||
return static_cast<float>(a) + static_cast<float>(b);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ BFloat16 operator-(const BFloat16& a, const BFloat16& b) {
|
||||
__device__ __forceinline__ ortc::BFloat16 operator-(const ortc::BFloat16& a, const ortc::BFloat16& b) {
|
||||
return static_cast<float>(a) - static_cast<float>(b);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ BFloat16 operator*(const BFloat16& a, const BFloat16& b) {
|
||||
__device__ __forceinline__ ortc::BFloat16 operator*(const ortc::BFloat16& a, const ortc::BFloat16& b) {
|
||||
return static_cast<float>(a) * static_cast<float>(b);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ BFloat16 operator/(const BFloat16& a, const BFloat16& b) {
|
||||
__device__ __forceinline__ ortc::BFloat16 operator/(const ortc::BFloat16& a, const ortc::BFloat16& b) {
|
||||
return static_cast<float>(a) / static_cast<float>(b);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ BFloat16 operator-(const BFloat16& a) { return -static_cast<float>(a); }
|
||||
__device__ __forceinline__ ortc::BFloat16 operator-(const ortc::BFloat16& a) { return -static_cast<float>(a); }
|
||||
|
||||
__device__ __forceinline__ BFloat16& operator+=(BFloat16& a, const BFloat16& b) {
|
||||
__device__ __forceinline__ ortc::BFloat16& operator+=(ortc::BFloat16& a, const ortc::BFloat16& b) {
|
||||
a = a + b;
|
||||
return a;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ BFloat16& operator-=(BFloat16& a, const BFloat16& b) {
|
||||
__device__ __forceinline__ ortc::BFloat16& operator-=(ortc::BFloat16& a, const ortc::BFloat16& b) {
|
||||
a = a - b;
|
||||
return a;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ BFloat16& operator*=(BFloat16& a, const BFloat16& b) {
|
||||
__device__ __forceinline__ ortc::BFloat16& operator*=(ortc::BFloat16& a, const ortc::BFloat16& b) {
|
||||
a = a * b;
|
||||
return a;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ BFloat16& operator/=(BFloat16& a, const BFloat16& b) {
|
||||
__device__ __forceinline__ ortc::BFloat16& operator/=(ortc::BFloat16& a, const ortc::BFloat16& b) {
|
||||
a = a / b;
|
||||
return a;
|
||||
}
|
||||
|
||||
/// Arithmetic with floats
|
||||
|
||||
__device__ __forceinline__ float operator+(BFloat16 a, float b) { return a + b; }
|
||||
__device__ __forceinline__ float operator-(BFloat16 a, float b) { return a - b; }
|
||||
__device__ __forceinline__ float operator*(BFloat16 a, float b) { return a * b; }
|
||||
__device__ __forceinline__ float operator/(BFloat16 a, float b) { return a / b; }
|
||||
__device__ __forceinline__ float operator+(ortc::BFloat16 a, float b) { return a + b; }
|
||||
__device__ __forceinline__ float operator-(ortc::BFloat16 a, float b) { return a - b; }
|
||||
__device__ __forceinline__ float operator*(ortc::BFloat16 a, float b) { return a * b; }
|
||||
__device__ __forceinline__ float operator/(ortc::BFloat16 a, float b) { return a / b; }
|
||||
|
||||
__device__ __forceinline__ float operator+(float a, BFloat16 b) { return a + b; }
|
||||
__device__ __forceinline__ float operator-(float a, BFloat16 b) { return a - b; }
|
||||
__device__ __forceinline__ float operator*(float a, BFloat16 b) { return a * b; }
|
||||
__device__ __forceinline__ float operator/(float a, BFloat16 b) { return a / b; }
|
||||
__device__ __forceinline__ float operator+(float a, ortc::BFloat16 b) { return a + b; }
|
||||
__device__ __forceinline__ float operator-(float a, ortc::BFloat16 b) { return a - b; }
|
||||
__device__ __forceinline__ float operator*(float a, ortc::BFloat16 b) { return a * b; }
|
||||
__device__ __forceinline__ float operator/(float a, ortc::BFloat16 b) { return a / b; }
|
||||
|
||||
__device__ __forceinline__ float& operator+=(float& a, const BFloat16& b) { return a += b; }
|
||||
__device__ __forceinline__ float& operator-=(float& a, const BFloat16& b) { return a -= b; }
|
||||
__device__ __forceinline__ float& operator*=(float& a, const BFloat16& b) { return a *= b; }
|
||||
__device__ __forceinline__ float& operator/=(float& a, const BFloat16& b) { return a /= b; }
|
||||
__device__ __forceinline__ float& operator+=(float& a, const ortc::BFloat16& b) { return a += b; }
|
||||
__device__ __forceinline__ float& operator-=(float& a, const ortc::BFloat16& b) { return a -= b; }
|
||||
__device__ __forceinline__ float& operator*=(float& a, const ortc::BFloat16& b) { return a *= b; }
|
||||
__device__ __forceinline__ float& operator/=(float& a, const ortc::BFloat16& b) { return a /= b; }
|
||||
|
||||
/// Arithmetic with doubles
|
||||
|
||||
__device__ __forceinline__ double operator+(BFloat16 a, double b) { return static_cast<double>(a) + b; }
|
||||
__device__ __forceinline__ double operator-(BFloat16 a, double b) { return static_cast<double>(a) - b; }
|
||||
__device__ __forceinline__ double operator*(BFloat16 a, double b) { return static_cast<double>(a) * b; }
|
||||
__device__ __forceinline__ double operator/(BFloat16 a, double b) { return static_cast<double>(a) / b; }
|
||||
__device__ __forceinline__ double operator+(ortc::BFloat16 a, double b) { return static_cast<double>(a) + b; }
|
||||
__device__ __forceinline__ double operator-(ortc::BFloat16 a, double b) { return static_cast<double>(a) - b; }
|
||||
__device__ __forceinline__ double operator*(ortc::BFloat16 a, double b) { return static_cast<double>(a) * b; }
|
||||
__device__ __forceinline__ double operator/(ortc::BFloat16 a, double b) { return static_cast<double>(a) / b; }
|
||||
|
||||
__device__ __forceinline__ double operator+(double a, BFloat16 b) { return a + static_cast<double>(b); }
|
||||
__device__ __forceinline__ double operator-(double a, BFloat16 b) { return a - static_cast<double>(b); }
|
||||
__device__ __forceinline__ double operator*(double a, BFloat16 b) { return a * static_cast<double>(b); }
|
||||
__device__ __forceinline__ double operator/(double a, BFloat16 b) { return a / static_cast<double>(b); }
|
||||
__device__ __forceinline__ double operator+(double a, ortc::BFloat16 b) { return a + static_cast<double>(b); }
|
||||
__device__ __forceinline__ double operator-(double a, ortc::BFloat16 b) { return a - static_cast<double>(b); }
|
||||
__device__ __forceinline__ double operator*(double a, ortc::BFloat16 b) { return a * static_cast<double>(b); }
|
||||
__device__ __forceinline__ double operator/(double a, ortc::BFloat16 b) { return a / static_cast<double>(b); }
|
||||
|
||||
// Overloading < and > operators
|
||||
|
||||
__device__ __forceinline__ bool operator==(BFloat16& lhs, BFloat16& rhs) { return float(lhs) == float(rhs); }
|
||||
__device__ __forceinline__ bool operator!=(BFloat16& lhs, BFloat16& rhs) { return float(lhs) != float(rhs); }
|
||||
__device__ __forceinline__ bool operator>(BFloat16& lhs, BFloat16& rhs) { return float(lhs) > float(rhs); }
|
||||
__device__ __forceinline__ bool operator<(BFloat16& lhs, BFloat16& rhs) { return float(lhs) < float(rhs); }
|
||||
__device__ __forceinline__ bool operator==(ortc::BFloat16& lhs, ortc::BFloat16& rhs) { return float(lhs) == float(rhs); }
|
||||
__device__ __forceinline__ bool operator!=(ortc::BFloat16& lhs, ortc::BFloat16& rhs) { return float(lhs) != float(rhs); }
|
||||
__device__ __forceinline__ bool operator>(ortc::BFloat16& lhs, ortc::BFloat16& rhs) { return float(lhs) > float(rhs); }
|
||||
__device__ __forceinline__ bool operator<(ortc::BFloat16& lhs, ortc::BFloat16& rhs) { return float(lhs) < float(rhs); }
|
||||
|
||||
template <typename T>
|
||||
__device__ __inline T _Tanh(T a);
|
||||
|
@ -191,4 +190,4 @@ __device__ __inline__ half2 _Tanh(half2 a) {
|
|||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ BFloat16 _Tanh(BFloat16 a) { return tanhf(static_cast<float>(a)); }
|
||||
__device__ __inline__ ortc::BFloat16 _Tanh(ortc::BFloat16 a) { return tanhf(static_cast<float>(a)); }
|
||||
|
|
|
@ -81,20 +81,25 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
|
|||
return result;
|
||||
}
|
||||
|
||||
KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
|
||||
bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
|
||||
bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
|
||||
bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false);
|
||||
bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true);
|
||||
// KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
// bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
|
||||
// bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
|
||||
// bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
|
||||
// bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false);
|
||||
// bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true);
|
||||
|
||||
tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents,
|
||||
tokenize_punctuation, remove_control_chars);
|
||||
}
|
||||
// tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents,
|
||||
// tokenize_punctuation, remove_control_chars);
|
||||
// }
|
||||
|
||||
void KernelBasicTokenizer::Compute(std::string_view input,
|
||||
ortc::Tensor<std::string>& output) const {
|
||||
// Setup inputs
|
||||
std::vector<ustring> result = tokenizer_->Tokenize(ustring(input));
|
||||
output.SetStringOutput({result[0].operator std::string()}, {1});
|
||||
std::vector<std::string> tokens;
|
||||
for (const auto& token : result) {
|
||||
tokens.push_back((std::string)token);
|
||||
}
|
||||
|
||||
output.SetStringOutput(tokens, {static_cast<int64_t>(tokens.size())});
|
||||
}
|
||||
|
|
|
@ -21,8 +21,19 @@ class BasicTokenizer {
|
|||
bool remove_control_chars_;
|
||||
};
|
||||
|
||||
struct KernelBasicTokenizer : BaseKernel {
|
||||
KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info);
|
||||
struct KernelBasicTokenizer{
|
||||
template <typename T>
|
||||
KernelBasicTokenizer(const T& dict) {
|
||||
bool do_lower_case = dict.TryToGetAttributeWithDefault("do_lower_case", true);
|
||||
bool tokenize_chinese_chars = dict.TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
|
||||
bool strip_accents = dict.TryToGetAttributeWithDefault("strip_accents", false);
|
||||
bool tokenize_punctuation = dict.TryToGetAttributeWithDefault("tokenize_punctuation", false);
|
||||
bool remove_control_chars = dict.TryToGetAttributeWithDefault("remove_control_chars", true);
|
||||
|
||||
tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents,
|
||||
tokenize_punctuation, remove_control_chars);
|
||||
}
|
||||
|
||||
void Compute(std::string_view input,
|
||||
ortc::Tensor<std::string>& output) const;
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
#include "ortx_common.h"
|
||||
|
||||
#include <optional>
|
||||
#include <limits>
|
||||
|
||||
using namespace ort_extensions;
|
||||
|
||||
|
@ -428,7 +429,7 @@ OrtStatusPtr KernelBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
|
|||
tokenize_results.emplace_back(
|
||||
(this->*tok_fun)(
|
||||
ustr,
|
||||
padding_length_ < 0 ? std::numeric_limits<uint32_t>::max() : padding_length_,
|
||||
padding_length_ < 0 ? (std::numeric_limits<uint32_t>::max)() : padding_length_,
|
||||
compute_offset_mapping,
|
||||
offset_map));
|
||||
}
|
||||
|
@ -436,7 +437,7 @@ OrtStatusPtr KernelBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
|
|||
size_t max_length = 0;
|
||||
if (padding_length_ == -1) {
|
||||
for (auto& res : tokenize_results) {
|
||||
max_length = std::max(max_length, res.size());
|
||||
max_length = (std::max)(max_length, res.size());
|
||||
}
|
||||
} else {
|
||||
max_length = static_cast<size_t>(padding_length_);
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <list>
|
||||
|
||||
struct BpeModelConf {
|
||||
const char* name_{"GPT2"}; // this name may be overridden by the tokenizer's attribute.
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#include "ocos.h"
|
||||
#include "test_kernel.hpp"
|
||||
|
||||
|
||||
TEST(tokenizer_opertors, test_bert_tokenizer) {
|
||||
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
|
||||
|
||||
|
|
|
@ -2,6 +2,9 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#ifdef ENABLE_DLIB
|
||||
|
||||
#include <dlib/matrix.h>
|
||||
|
||||
using namespace dlib;
|
||||
|
@ -20,3 +23,5 @@ TEST(math, matrix_op) {
|
|||
matrix<float> x = inv(M)*y;
|
||||
EXPECT_FLOAT_EQ(x(1, 0), -13.909741);
|
||||
}
|
||||
|
||||
#endif // ENABLE_DLIB
|
||||
|
|
|
@ -7,29 +7,29 @@
|
|||
#include "bert_tokenizer.hpp"
|
||||
|
||||
#include <clocale>
|
||||
#include "tokenizer/basic_tokenizer.hpp"
|
||||
|
||||
|
||||
class LocaleBaseTest : public testing::Test{
|
||||
public:
|
||||
// Remember that SetUp() is run immediately before a test starts.
|
||||
void SetUp() override {
|
||||
class LocaleBaseTest : public testing::Test {
|
||||
public:
|
||||
// Remember that SetUp() is run immediately before a test starts.
|
||||
void SetUp() override {
|
||||
#if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__) && !defined(__GNUC__))
|
||||
default_locale_ = std::locale().name();
|
||||
std::setlocale(LC_CTYPE, "C");
|
||||
default_locale_ = std::locale().name();
|
||||
std::setlocale(LC_CTYPE, "C");
|
||||
#else
|
||||
default_locale_ = std::locale("").name();
|
||||
std::setlocale(LC_CTYPE, "en_US.UTF-8");
|
||||
default_locale_ = std::locale("").name();
|
||||
std::setlocale(LC_CTYPE, "en_US.UTF-8");
|
||||
#endif
|
||||
}
|
||||
// TearDown() is invoked immediately after a test finishes.
|
||||
void TearDown() override {
|
||||
if (!default_locale_.empty()) {
|
||||
std::setlocale(LC_CTYPE, default_locale_.c_str());
|
||||
}
|
||||
// TearDown() is invoked immediately after a test finishes.
|
||||
void TearDown() override {
|
||||
if (!default_locale_.empty()) {
|
||||
std::setlocale(LC_CTYPE, default_locale_.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::string default_locale_;
|
||||
private:
|
||||
std::string default_locale_;
|
||||
};
|
||||
|
||||
TEST(tokenizer, bert_word_split) {
|
||||
|
@ -65,7 +65,7 @@ std::unordered_map<std::u32string, int32_t> get_vocabulary_basic() {
|
|||
};
|
||||
std::unordered_map<std::u32string, int32_t> vocab;
|
||||
for (auto it = vocab_tokens.begin(); it != vocab_tokens.end(); ++it) {
|
||||
vocab[*it] = vocab.size();
|
||||
vocab[*it] = static_cast<int32_t>(vocab.size());
|
||||
}
|
||||
return vocab;
|
||||
}
|
||||
|
@ -104,7 +104,7 @@ std::unordered_map<std::u32string, int32_t> get_vocabulary_wordpiece() {
|
|||
};
|
||||
std::unordered_map<std::u32string, int32_t> vocab;
|
||||
for (auto it = vocab_tokens.begin(); it != vocab_tokens.end(); ++it) {
|
||||
vocab[*it] = vocab.size();
|
||||
vocab[*it] = static_cast<int32_t>(vocab.size());
|
||||
}
|
||||
return vocab;
|
||||
}
|
||||
|
@ -156,9 +156,9 @@ TEST(tokenizer, bert_wordpiece_tokenizer_rows) {
|
|||
TEST_F(LocaleBaseTest, basic_tokenizer_chinese) {
|
||||
ustring test_case = ustring("ÀÁÂÃÄÅÇÈÉÊËÌÍÎÑÒÓÔÕÖÚÜ\t䗓𨖷虴𨀐辘𧄋脟𩑢𡗶镇伢𧎼䪱轚榶𢑌㺽𤨡!#$%&(Tom@microsoft.com)*+,-./:;<=>?@[\\]^_`{|}~");
|
||||
std::vector<ustring> expect_result = ustring_vector_convertor({"aaaaaaceeeeiiinooooouu",
|
||||
"䗓", "𨖷", "虴", "𨀐", "辘", "𧄋", "脟", "𩑢", "𡗶", "镇", "伢", "𧎼", "䪱", "轚", "榶", "𢑌", "㺽", "𤨡",
|
||||
"!", "#", "$", "%", "&", "(", "tom", "@", "microsoft", ".", "com", ")", "*", "+", ",", "-", ".", "/", ":",
|
||||
";", "<", "=", ">", "?", "@", "[", "\\", "]", "^", "_", "`", "{", "|", "}", "~"});
|
||||
"䗓", "𨖷", "虴", "𨀐", "辘", "𧄋", "脟", "𩑢", "𡗶", "镇", "伢", "𧎼", "䪱", "轚", "榶", "𢑌", "㺽", "𤨡",
|
||||
"!", "#", "$", "%", "&", "(", "tom", "@", "microsoft", ".", "com", ")", "*", "+", ",", "-", ".", "/", ":",
|
||||
";", "<", "=", ">", "?", "@", "[", "\\", "]", "^", "_", "`", "{", "|", "}", "~"});
|
||||
BasicTokenizer tokenizer(true, true, true, true, true);
|
||||
auto result = tokenizer.Tokenize(test_case);
|
||||
EXPECT_EQ(result, expect_result);
|
||||
|
@ -245,5 +245,19 @@ TEST(tokenizer, truncation_longest_first) {
|
|||
test_input2 = init_vector1;
|
||||
truncate.Truncate(test_input1, test_input2, 12);
|
||||
EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5}));
|
||||
EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4, 5, 6 ,7}));
|
||||
EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4, 5, 6, 7}));
|
||||
}
|
||||
|
||||
TEST(tokenizer, basic_tok_eager) {
|
||||
std::string test_case = "I mean, you’ll need something to talk about next Sunday, right?";
|
||||
std::vector<std::string> expect_result = {"I", "mean", ",", "you", "’", "ll", "need", "something", "to", "talk", "about", "next", "Sunday", ",", "right", "?"};
|
||||
|
||||
ortc::NamedArgumentDict dict({"do_lower_case", "tokenize_chinese_chars", "strip_accents", "tokenize_punctuation", "remove_control_chars"},
|
||||
std::make_tuple(false, true, true, true, true));
|
||||
|
||||
KernelBasicTokenizer tokenizer(dict);
|
||||
|
||||
ortc::Tensor<std::string> output;
|
||||
tokenizer.Compute(test_case, output);
|
||||
EXPECT_EQ(output.Data(), expect_result);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <filesystem>
|
||||
#include <locale>
|
||||
#include "gtest/gtest.h"
|
||||
#include "ocos.h"
|
||||
|
||||
#include "ortx_tokenizer.h"
|
||||
#include "bpe_kernels.h"
|
||||
|
||||
TEST(bbpe_tokenizer, test_encoder) {
|
||||
EXPECT_EQ(0, ORT_OK);
|
||||
}
|
|
@ -10,25 +10,31 @@
|
|||
#include "exceptions.h"
|
||||
|
||||
namespace {
|
||||
void FixCurrentDir() {
|
||||
void FixCurrentDir(const std::string& init_path = "") {
|
||||
// adjust for the Google Test Adapter in Visual Studio not setting the current path to $(ProjectDir),
|
||||
// which results in us being 2 levels below where the `data` folder is copied to and where the extensions
|
||||
// library is
|
||||
auto cur = std::filesystem::current_path();
|
||||
|
||||
// if init_path is the executable path, then we need to get the directory of the executable
|
||||
auto cur_dir = std::filesystem::current_path();
|
||||
if (!init_path.empty()) {
|
||||
std::filesystem::path init_dir = init_path;
|
||||
cur_dir = init_dir.parent_path();
|
||||
}
|
||||
|
||||
do {
|
||||
auto data_dir = cur / "data";
|
||||
auto data_dir = cur_dir / "data";
|
||||
|
||||
if (std::filesystem::exists(data_dir) && std::filesystem::is_directory(data_dir)) {
|
||||
break;
|
||||
}
|
||||
|
||||
cur = cur.parent_path();
|
||||
ASSERT_NE(cur, cur.root_path()) << "Reached root directory without finding 'data' directory.";
|
||||
cur_dir = cur_dir.parent_path();
|
||||
ASSERT_NE(cur_dir, cur_dir.root_path()) << "Reached root directory without finding 'data' directory.";
|
||||
} while (true);
|
||||
|
||||
// set current path as the extensions library is also loaded from that directory by TestInference
|
||||
std::filesystem::current_path(cur);
|
||||
std::filesystem::current_path(cur_dir);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -38,7 +44,7 @@ int main(int argc, char** argv) {
|
|||
OCOS_TRY {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
|
||||
FixCurrentDir();
|
||||
FixCurrentDir(argv[0]);
|
||||
status = RUN_ALL_TESTS();
|
||||
}
|
||||
OCOS_CATCH(const std::exception& ex) {
|
||||
|
|
Загрузка…
Ссылка в новой задаче