Refactor the header file directory and integrate the eager tensor implementation (#689)

* refactor the header file in include folder

* fix the basic-token eager unit test case

* a more flexible way to handle string tensor shape.

* fix the unit test path issue

* remove the multi-inherits to avoid issue during pointer casting

* add api cmake build support

* undo some temporary changes

* code refinement

* fix variadic arg

* only expose the context for ort version >= 17

* fix a shape bug

* fix the cuda build issue

* change ifdef condition of GetAllocator

* finalize the ort c abi wrapper file name

* fix the iOS build break

* align gtest version with triton

* Update ext_apple_framework.cmake for iOS header files

---------

Co-authored-by: Cheng Tang <chenta@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
This commit is contained in:
Wenbing Li 2024-04-17 12:58:19 -07:00 коммит произвёл GitHub
Родитель fe8cd9ee8d
Коммит 646462790b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
33 изменённых файлов: 2295 добавлений и 1393 удалений

Просмотреть файл

@ -2,7 +2,7 @@
# turn off readability-braces-around-statements to allow single line statement like 'if (x == y) doSomething();'
Checks: '-*,cppcoreguidelines-*,google-*,readability-*,modernize-*,-readability-braces-around-statements,-google-runtime-references,-cppcoreguidelines-pro-type-reinterpret-cast'
WarningsAsErrors: ''
HeaderFilterRegex: 'includes\/.*'
HeaderFilterRegex: 'include\/.*'
AnalyzeTemporaryDtors: false
FormatStyle: none
CheckOptions:

Просмотреть файл

@ -67,6 +67,7 @@ option(OCOS_ENABLE_AZURE "Enable the operators for azure execution provider" OFF
option(OCOS_ENABLE_STATIC_LIB "Enable generating static library" OFF)
option(OCOS_ENABLE_SELECTED_OPLIST "Enable including the selected_ops tool file" OFF)
option(OCOS_ENABLE_C_API "Enable building the C API" OFF)
option(OCOS_BUILD_PYTHON "Enable building the Python package" OFF)
option(OCOS_BUILD_JAVA "Enable building the Java package" OFF)
@ -81,7 +82,8 @@ set(OCOS_ONNXRUNTIME_VERSION "" CACHE STRING
"The version of ONNX Runtime being used in the build. Format is <major>.<minor>.<patch>. e.g. 1.15.1")
set(OCOS_ONNXRUNTIME_PKG_URI "" CACHE STRING
"Specify the onnxruntime C++ shared library zip package path, like ./onnxruntime-win-x64-1.16.0.zip")
set(OCOS_BUILD_PRESET "" CACHE STRING
"Specify the build preset cmake settings file path, like 'token_api_only' which includes ./cmake/presets/token_api_only.cmake")
# TODO: Remove the following statements if AzureOp build is enabled by default.
# If build_buildid environment varaible is set, which means this is a CI build, then always enable AzureOp.
# or it is enabled when OCOS_ENABLE_AZURE is set, which means the user explicitly enables it.
@ -188,16 +190,27 @@ if(NOT PROJECT_IS_TOP_LEVEL AND ONNXRUNTIME_ROOT)
set(_ONNXRUNTIME_EMBEDDED TRUE)
endif()
if (OCOS_ENABLE_SELECTED_OPLIST OR OCOS_BUILD_PRESET)
disable_all_operators()
if(OCOS_ENABLE_SELECTED_OPLIST)
# Need to ensure _selectedoplist.cmake file is already generated in folder: ${PROJECT_SOURCE_DIR}/cmake/
# You could run gen_selectedops.py in folder: tools/ to generate _selectedoplist.cmake
message(STATUS "Looking for the _selectedoplist.cmake")
disable_all_operators()
include(_selectedoplist)
# Include the selected_ops case, no way to run the unit tests, so disable it,
# even the user explicitly set it to ON. (it is rare, most of the time, it is set by default)
set(OCOS_ENABLE_CTEST OFF CACHE BOOL "" FORCE)
endif()
if (OCOS_BUILD_PRESET)
set(_BUILD_PRESET "${PROJECT_SOURCE_DIR}/cmake/presets/${OCOS_BUILD_PRESET}.cmake")
if (EXISTS ${_BUILD_PRESET})
include(${_BUILD_PRESET})
else()
message(FATAL_ERROR "The specified build preset file does not exist: ${_BUILD_PRESET}")
endif()
endif()
endif()
set(_OCOS_EXCEPTIONS_REQUIRED OFF)
if (OCOS_ENABLE_GPT2_TOKENIZER OR
@ -300,7 +313,7 @@ endif()
# ### scan all source files
file(GLOB TARGET_SRC_NOEXCEPTION "base/*.h" "base/*.cc")
file(GLOB TARGET_SRC "operators/*.cc" "operators/*.h" "includes/*.h*")
file(GLOB TARGET_SRC "operators/*.cc" "operators/*.h" "include/*.h" "include/*.hpp")
if(OCOS_ENABLE_TF_STRING)
set(farmhash_SOURCE_DIR ${PROJECT_SOURCE_DIR}/cmake/externals/farmhash)
@ -551,13 +564,15 @@ standardize_output_folder(ocos_operators)
target_include_directories(noexcep_operators PUBLIC
${ONNXRUNTIME_INCLUDE_DIR}
${PROJECT_SOURCE_DIR}/includes
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/include/custom_op
${PROJECT_SOURCE_DIR}/base
${PROJECT_SOURCE_DIR}/operators)
target_include_directories(ocos_operators PUBLIC
${ONNXRUNTIME_INCLUDE_DIR}
${PROJECT_SOURCE_DIR}/includes
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/include/custom_op
${PROJECT_SOURCE_DIR}/base
${PROJECT_SOURCE_DIR}/operators)

Просмотреть файл

@ -1,7 +1,7 @@
include *.txt
global-include *.def
recursive-include cmake *.*
recursive-include includes *.*
recursive-include include *.*
recursive-include operators *.*
recursive-include pyop *.*
recursive-include shared *.*

Просмотреть файл

@ -190,5 +190,4 @@ uint64_t Hash64Fast(const char* data, size_t n) {
return static_cast<int64_t>(util::Fingerprint64(data, n));
}
#endif // ENABLE_TF_STRING

Просмотреть файл

@ -3,7 +3,7 @@
#pragma once
#include <sstream>
#include <vector>
#include "onnxruntime_cpp_api_legacy.hpp"
#include "ort_c_to_cpp.h"
template <typename T>
inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {

Просмотреть файл

@ -15,7 +15,9 @@ set(APPLE_FRAMEWORK_VERSION "${VERSION}")
# public header files
set(APPLE_FRAMEWORK_HEADERS
"${PROJECT_SOURCE_DIR}/includes/onnxruntime_extensions.h")
"${PROJECT_SOURCE_DIR}/include/onnxruntime_extensions.h"
"${PROJECT_SOURCE_DIR}/include/ortx_tokenizer.h"
"${PROJECT_SOURCE_DIR}/include/ortx_op_registry.h")
# generated framework directory
set(APPLE_FRAMEWORK_DIRECTORY

Просмотреть файл

@ -47,16 +47,12 @@ function(add_test_target)
# add a test executable
add_executable(${ARG_TARGET})
standardize_output_folder(${ARG_TARGET})
add_test(NAME ${ARG_TARGET}
COMMAND ${ARG_TARGET})
target_sources(${ARG_TARGET} PRIVATE
${ARG_TEST_SOURCES}
"${TEST_SRC_DIR}/unittest_main/test_main.cc")
target_link_libraries(${ARG_TARGET} PRIVATE
${ARG_LIBRARIES}
gtest gmock)
@ -132,6 +128,7 @@ file(GLOB static_TEST_SRC "${TEST_SRC_DIR}/static_test/*.cc")
add_test_target(TARGET ocos_test
TEST_SOURCES ${static_TEST_SRC}
LIBRARIES ortcustomops ${ocos_libraries})
target_compile_definitions(ocos_test PRIVATE ${OCOS_COMPILE_DEFINITIONS})
# -- shared test (needs onnxruntime) --
SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
@ -201,6 +198,20 @@ else()
)
endif()
endblock()
block()
file(GLOB tokenizer_TEST_SRC
"${TEST_SRC_DIR}/tokenizer_test/*.cc"
"${TEST_SRC_DIR}/tokenizer_test/*.hpp")
add_test_target(TARGET tokenizer_api_test
TEST_SOURCES ${tokenizer_TEST_SRC}
LIBRARIES onnxruntime_extensions ${ocos_libraries}
TEST_DATA_DIRECTORIES ${TEST_SRC_DIR}/data)
target_compile_definitions(tokenizer_api_test PRIVATE ${OCOS_COMPILE_DEFINITIONS})
endblock()
endif()
endif()

4
cmake/externals/googletest.cmake поставляемый
Просмотреть файл

@ -1,7 +1,7 @@
FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG release-1.11.0
URL https://github.com/google/googletest/archive/9406a60c7839052e4944ea4dbc8344762a89f9bd.zip
URL_HASH SHA1=06096d3900c356e468ba060a609642c635131106
)
FetchContent_MakeAvailable(googletest)

Просмотреть файл

@ -0,0 +1,5 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
set(OCOS_ENABLE_GPT2_TOKENIZER ON CACHE INTERNAL "" FORCE)
set(OCOS_ENABLE_C_API ON CACHE INTERNAL "" FORCE)

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -0,0 +1,35 @@
#pragma once
#include <optional>
#include <numeric>
#include <type_traits>
namespace Ort {
namespace Custom {
// this is for the ORT custom op template magic
struct Arg {
virtual ~Arg() = default;
};
class KernelContext : public Arg{
public:
virtual void* AllocScratchBuffer(size_t size) = 0;
virtual void FreeScratchBuffer(void* p) = 0;
// TODO: threadpool?
};
#ifdef USE_CUDA
class CUDAKernelContext : public KernelContext {
public:
virtual void* AllocCudaScratchBuffer(size_t size) = 0;
virtual void FreeCudaScratchBuffer(void* p) = 0;
virtual void* GetCudaStream() const = 0;
virtual void* GetCublasHandle() const = 0;
virtual int GetCudaDeviceId() const = 0;
};
#endif
// TODO: helper func to create context from global ORT env.
}
}

Просмотреть файл

Просмотреть файл

@ -0,0 +1,534 @@
#pragma once
#include <optional>
#include <numeric>
#include <type_traits>
#include "onnxruntime_f16.h"
#include "kernel_context.h"
namespace Ort {
namespace Custom {
template <typename T>
struct Span {
const T* data_ = {};
size_t size_ = {};
void Assign(const T* data, size_t size) {
data_ = data;
size_ = size;
}
size_t size() const { return size_; }
T operator[](size_t indice) const {
return data_[indice];
}
const T* data() const { return data_; }
};
#if ORT_API_VERSION >= 16
template <>
struct Span<MFloat16> {
const MFloat16* data_ = {};
size_t size_ = {};
void Assign(const MFloat16* data, size_t size) {
data_ = data;
size_ = size;
}
size_t size() const { return size_; }
MFloat16 operator[](size_t indice) const {
return data_[indice];
}
const MFloat16* data() const { return data_; }
};
template <>
struct Span<BFloat16> {
const BFloat16* data_ = {};
size_t size_ = {};
void Assign(const BFloat16* data, size_t size) {
data_ = data;
size_ = size;
}
size_t size() const { return size_; }
BFloat16 operator[](size_t indice) const {
return data_[indice];
}
const BFloat16* data() const { return data_; }
};
#endif
class ITensorStorage{
public:
virtual const std::vector<int64_t>& Shape() const = 0;
virtual const void* DataRaw() const = 0;
virtual bool IsInitialized() const = 0;
virtual void* Initialize(const std::vector<int64_t>& shape, size_t element_size) = 0;
};
class IAllocator {
public:
virtual void* Alloc(size_t size) = 0;
virtual void Free(void* p) = 0;
};
class OrtEagerTensorStorage : public ITensorStorage {
public:
OrtEagerTensorStorage(const std::vector<int64_t>& shape,
void* buffer) : buffer_(buffer), shape_(shape){
}
OrtEagerTensorStorage(IAllocator* allocator) : allocator_(allocator){
}
virtual ~OrtEagerTensorStorage(){
if (allocator_ && buffer_)
allocator_->Free(buffer_);
}
const std::vector<int64_t>& Shape() const override {
if (!IsInitialized())
ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
return *shape_;
}
virtual bool IsInitialized() const override {
return shape_.has_value();
}
const void* DataRaw() const override {
return buffer_;
}
void* Initialize(const std::vector<int64_t>& shape, size_t element_size) override {
if (IsInitialized())
return buffer_;
assert(allocator_);
shape_ = shape;
int64_t n_elem = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
auto buffer_size = n_elem * element_size;
buffer_ = allocator_->Alloc(buffer_size);
return buffer_;
}
private:
void* buffer_ {};
std::optional<std::vector<int64_t>> shape_;
// caller need to make sure the allocator is alive
IAllocator* allocator_;
};
template <typename TT>
ONNXTensorElementDataType GetOrtDType(){
if constexpr (std::is_same<TT, bool>::value)
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
else if constexpr (std::is_same<TT, float>::value)
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
else if constexpr (std::is_same<TT, double>::value)
return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
else if constexpr (std::is_same<TT, uint8_t>::value)
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
else if constexpr (std::is_same<TT, int8_t>::value)
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
else if constexpr (std::is_same<TT, uint16_t>::value)
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
else if constexpr (std::is_same<TT, int16_t>::value)
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
else if constexpr (std::is_same<TT, uint32_t>::value)
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
else if constexpr (std::is_same<TT, int32_t>::value)
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
else if constexpr (std::is_same<TT, uint64_t>::value)
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
else if constexpr (std::is_same<TT, int64_t>::value)
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
else if constexpr (std::is_same<TT, std::string>::value)
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
ORTX_CXX_API_THROW("Unexpected type", ORT_RUNTIME_EXCEPTION);
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
class TensorBase : public Arg {
public:
virtual ~TensorBase() {}
virtual ONNXTensorElementDataType Type() const = 0;
virtual const std::vector<int64_t>& Shape() const = 0;
virtual int64_t NumberOfElement() const = 0;
virtual const void* DataRaw() const = 0;
virtual size_t SizeInBytes() const = 0;
};
template <typename T>
class Tensor : public TensorBase {
public:
using TT = typename std::remove_reference<T>::type;
Tensor(std::unique_ptr<ITensorStorage> tensor_storage) : storage_(std::move(tensor_storage)){
}
Tensor(const std::vector<int64_t>& shape, void* buffer) : Tensor(std::make_unique<OrtEagerTensorStorage>(shape, buffer)) {}
Tensor(IAllocator* allocator) : storage_(std::make_unique<OrtEagerTensorStorage>(allocator)){}
virtual ~Tensor() = default;
operator bool() const {
return storage_->IsInitialized();
}
ONNXTensorElementDataType Type() const override {
return GetOrtDType<T>();
}
const std::vector<int64_t>& Shape() const override {
return storage_->Shape();
}
int64_t NumberOfElement() const override {
auto& shape = storage_->Shape();
return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
}
std::string Shape2Str() const {
if (storage_->IsInitialized()) {
std::string shape_str;
auto& shape = storage_->Shape();
for (const auto& dim : shape) {
shape_str.append(std::to_string(dim));
shape_str.append(", ");
}
return shape_str;
} else {
return "empty";
}
}
const TT* Data() const {
#if ORT_API_VERSION >= 16
if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
return reinterpret_cast<const TT*>(storage_->DataRaw());
else
#endif
return static_cast<const TT*>(storage_->DataRaw());
}
const void* DataRaw() const override {
return storage_->DataRaw();
}
size_t SizeInBytes() const override {
return NumberOfElement() * sizeof(TT);
}
TT* Allocate(const std::vector<int64_t>& shape) {
// it should be OK to allocate multiple times
void* buffer = storage_->Initialize(shape, sizeof(TT));
#if ORT_API_VERSION >= 16
if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
return reinterpret_cast<TT*>(buffer);
else
#endif
return static_cast<TT*>(buffer);
}
const Span<T>& AsSpan() {
#if ORT_API_VERSION >= 16
if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
ORTX_CXX_API_THROW("AsSpan for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
}
else{
#endif
auto& shape = storage_->Shape();
if (shape.size() != 1) {
ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
}
span_.Assign(Data(), shape[0]);
return span_;
#if ORT_API_VERSION >= 16
}
#endif
}
const T& AsScalar() {
#if ORT_API_VERSION >= 16
if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
ORTX_CXX_API_THROW("AsScalar for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
}
else{
#endif
auto& shape = storage_->Shape();
if ((shape.size() == 1 && shape[0] != 1) || shape.size() > 1) {
ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
}
return *Data();
#if ORT_API_VERSION >= 16
}
#endif
}
private:
std::unique_ptr<ITensorStorage> storage_;
Span<T> span_;
};
template<typename T>
class IStringTensorStorage{
public:
using strings = std::vector<T>;
virtual const std::vector<int64_t>& Shape() const = 0;
virtual const void* DataRaw() const = 0;
virtual const strings& Data() const = 0;
virtual bool IsInitialized() const = 0;
virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) = 0;
virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) = 0;
};
template<typename T>
class EagerStringTensorStorage : public IStringTensorStorage<T>{
public:
using strings = std::vector<T>;
EagerStringTensorStorage(const strings& ss) : input_strings_(ss), shape_(std::vector<int64_t>{static_cast<int64_t>(ss.size())}){}
EagerStringTensorStorage() {}
const std::vector<int64_t>& Shape() const override {
if (!IsInitialized())
ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
return *shape_;
}
virtual const void* DataRaw() const override {
if (input_strings_.size() != 1) {
ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
}
if constexpr (std::is_same<std::string_view, T>::value)
return reinterpret_cast<const void*>(input_strings_[0].data());
else
return reinterpret_cast<const void*>(input_strings_[0].c_str());
}
virtual bool IsInitialized() const override {
return shape_.has_value();
}
virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override {
if constexpr (std::is_same<std::string_view, T>::value)
ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
input_strings_.assign(ss.begin(), ss.end());
shape_ = dims;
}
const strings& Data() const override {
return input_strings_;
}
virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override {
if constexpr (std::is_same<std::string_view, T>::value)
ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
for (const char* s : ss){
input_strings_.push_back(s);
}
shape_ = dims;
}
private:
std::vector<T> input_strings_;
std::optional<std::vector<int64_t>> shape_;
};
template <>
class Tensor<std::string> : public TensorBase {
public:
using strings = std::vector<std::string>;
Tensor(std::unique_ptr<IStringTensorStorage<std::string>> storage) : storage_(std::move(storage)) {}
Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string>>(ss)) {}
Tensor() : storage_(std::make_unique<EagerStringTensorStorage<std::string>>()) {}
ONNXTensorElementDataType Type() const override {
return GetOrtDType<std::string>();
}
const strings& Data() const {
return storage_->Data();
}
const std::vector<int64_t>& Shape() const override {
return storage_->Shape();
}
int64_t NumberOfElement() const override {
auto& shape = storage_->Shape();
return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
}
std::string Shape2Str() const {
if (storage_->IsInitialized()) {
std::string shape_str;
auto& shape = storage_->Shape();
for (const auto& dim : shape) {
shape_str.append(std::to_string(dim));
shape_str.append(", ");
}
return shape_str;
} else {
return "empty";
}
}
const void* DataRaw() const override {
return storage_->DataRaw();
}
size_t SizeInBytes() const override {
auto& ss = storage_->Data();
if (ss.size() != 1) {
ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
}
return ss[0].size();
}
void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
storage_->SetStringOutput(ss, dims);
}
void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
storage_->SetStringOutput(ss, dims);
}
const Span<std::string>& AsSpan() {
ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
}
const std::string& AsScalar() {
auto& ss = storage_->Data();
if (ss.size() != 1) {
ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
}
return ss[0];
}
private:
std::unique_ptr<IStringTensorStorage<std::string>> storage_;
};
template <>
class Tensor<std::string_view> : public TensorBase {
public:
using strings = std::vector<std::string_view>;
Tensor(std::unique_ptr<IStringTensorStorage<std::string_view>> storage) : storage_(std::move(storage)) {}
Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string_view>>(ss)) {}
ONNXTensorElementDataType Type() const override {
return GetOrtDType<std::string_view>();
}
const strings& Data() const {
return storage_->Data();
}
const std::vector<int64_t>& Shape() const override {
return storage_->Shape();
}
int64_t NumberOfElement() const override {
auto& shape = storage_->Shape();
return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
}
std::string Shape2Str() const {
if (storage_->IsInitialized()) {
std::string shape_str;
auto& shape = storage_->Shape();
for (const auto& dim : shape) {
shape_str.append(std::to_string(dim));
shape_str.append(", ");
}
return shape_str;
} else {
return "empty";
}
}
const void* DataRaw() const override {
return storage_->DataRaw();
}
size_t SizeInBytes() const override {
auto& ss = storage_->Data();
if (ss.size() != 1) {
ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
}
return ss[0].size();
}
void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
storage_->SetStringOutput(ss, dims);
}
void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
storage_->SetStringOutput(ss, dims);
}
const Span<std::string_view>& AsSpan() {
ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
}
const std::string_view& AsScalar() {
auto& ss = storage_->Data();
if (ss.size() != 1) {
ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
}
return ss[0];
}
private:
std::unique_ptr<IStringTensorStorage<std::string_view>> storage_;
};
template<typename ...Args>
class NamedArgumentDict{
public:
using ValueTuple = std::tuple<Args...>;
NamedArgumentDict(const std::vector<const char*>& keys, const std::tuple<Args...>& args) : entries_(args) {
for (const char* key : keys){
names_.push_back(key);
}
}
template<typename T>
T TryToGetAttributeWithDefault(const char* name, const T& default_value) const {
return TryToGetAttributeWithDefaultInternal<0>(name, default_value);
}
private:
template<size_t I, typename T>
typename std::enable_if<I == sizeof...(Args), T>::type
TryToGetAttributeWithDefaultInternal(const char* name, const T& default_value) const {
return default_value;
}
template<size_t I, typename T>
typename std::enable_if<I < sizeof...(Args), T>::type
TryToGetAttributeWithDefaultInternal(const char* name, const T& default_value) const {
if (names_[I] == name){
if constexpr (std::is_same<std::tuple_element_t<I, ValueTuple>, T>::value)
return std::get<I>(entries_);
else
throw std::runtime_error("name matched but type is not");
}
return TryToGetAttributeWithDefaultInternal<I+1>(name, default_value);
}
std::vector<std::string> names_;
std::tuple<Args...> entries_;
};
}
}

Просмотреть файл

@ -0,0 +1,137 @@
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type_def>*>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type_def>&>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type_def>*>>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
if (ith_input < num_input) {
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
} else {
std::tuple<T> current = std::tuple<T>{};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
}
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const Custom::Span<data_type_def>*>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
if (!reinterpret_cast<Custom::OrtTensor<data_type_def>*>(tensors.back().get())->IsCpuTensor()) {
ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL);
}
std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())->AsSpan()};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const Custom::Span<data_type_def>&>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
if (!reinterpret_cast<Custom::OrtTensor<data_type_def>*>(tensors.back().get())->IsCpuTensor()) {
ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL);
}
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())->AsSpan()};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type_def>*>>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
if (ith_input < num_input) {
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
if (!reinterpret_cast<Custom::OrtTensor<data_type_def>*>(tensors.back().get())->IsCpuTensor()) {
ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL);
}
std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())->AsSpan()};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
} else {
std::tuple<T> current = std::tuple<T>{};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
}
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, data_type_def>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
if (!reinterpret_cast<Custom::OrtTensor<data_type_def>*>(tensors.back().get())->IsCpuTensor()) {
ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL);
}
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())->AsScalar()};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, std::optional<data_type_def>>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
if (ith_input < num_input) {
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
if (!reinterpret_cast<Custom::OrtTensor<data_type_def>*>(tensors.back().get())->IsCpuTensor()) {
ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL);
}
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())->AsScalar()};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
} else {
std::tuple<T> current = std::tuple<T>{};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
}
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type_def>*>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_output, false));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type_def>&>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_output, false));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type_def>*>>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
if (ith_output < num_output) {
tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_output, false));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())};
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
} else {
std::tuple<T> current = std::tuple<T>{};
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
}

Просмотреть файл

Просмотреть файл

@ -10,7 +10,7 @@
#include <string>
#include <vector>
#include "onnxruntime_customop.hpp"
#include "op_def_struct.h"
// A helper API to support test kernels.
// Must be invoked before RegisterCustomOps.

Просмотреть файл

@ -0,0 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ortx_tokenizer.h"
#include "ortx_op_registry.h"

Просмотреть файл

@ -18,10 +18,8 @@
#include <functional>
#include "exceptions.h"
#include "onnxruntime_no_customop.h"
#include "onnxruntime_cpp_api_legacy.hpp"
#include "onnxruntime_extensions.h"
#include "custom_op_lite.h"
#include "custom_op/custom_op_lite.h"
#define MIN_ORT_VERSION_SUPPORTED 11

Просмотреть файл

@ -5,10 +5,7 @@
#include <vector>
#include "exceptions.h"
//
// DEPRECATED: All new custom OPs should not use any class/struct/functions from this file.
// TODO: Remove this file once all custom OPs are migrated to the new API
//
// OrtW: ONNX Runtime C ABI Wrapper
namespace OrtW {
struct CustomOpApi {
@ -30,6 +27,9 @@ struct CustomOpApi {
template <typename T>
const T* GetTensorData(_Inout_ const OrtValue* value) const;
void* GetTensorMutableRawData(_Inout_ OrtValue* value) const;
const void* GetTensorRawData(_Inout_ const OrtValue* value) const;
std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;
size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
@ -48,6 +48,54 @@ struct CustomOpApi {
const OrtApi& api_;
};
class API {
// To use ONNX C ABI in a way like OrtW::API::CreateStatus.
public:
static API& instance(const OrtApi* ort_api = nullptr) noexcept {
static API self(ort_api);
return self;
}
static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
return instance()->CreateStatus(code, msg);
}
static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
instance()->ReleaseStatus(ptr);
}
template <typename T>
static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;
static void ThrowOnError(OrtStatusPtr ptr) {
OrtW::ThrowOnError(instance().api_, ptr);
}
// Caller is responsible for releasing OrtMemoryInfo object
static OrtStatusPtr CreateOrtMemoryInfo(const char* name, enum OrtAllocatorType type, int id, enum OrtMemType mem_type, OrtMemoryInfo** out) noexcept {
return instance()->CreateMemoryInfo(name, type, id, mem_type, out);
}
#if ORT_API_VERSION >= 15
// Caller is responsible for releasing OrtAllocator object: delete static_cast<onnxruntime::OrtAllocatorImpl*> (allocator)
static OrtStatusPtr GetOrtAllocator(const OrtKernelContext* context, const OrtMemoryInfo* mem_info, OrtAllocator** out) {
return instance()->KernelContext_GetAllocator(context, mem_info, out);
}
#endif
private:
const OrtApi* operator->() const {
return &api_;
}
API(const OrtApi* api) : api_(*api) {
if (api == nullptr) {
ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
}
}
const OrtApi& api_;
};
//
// Custom OP API Inlines
//
@ -162,6 +210,16 @@ inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const
return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
}
inline void* CustomOpApi::GetTensorMutableRawData(_Inout_ OrtValue* value) const {
void* data = nullptr;
ThrowOnError(api_.GetTensorMutableData(value, &data));
return data;
}
inline const void* CustomOpApi::GetTensorRawData(_Inout_ const OrtValue* value) const {
return GetTensorMutableRawData(const_cast<OrtValue*>(value));
}
inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
std::vector<int64_t> output(GetDimensionsCount(info));
GetDimensions(info, output.data(), output.size());
@ -197,9 +255,72 @@ inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context,
return out;
}
template <>
inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
}
template <>
inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
return instance()->KernelInfoGetAttribute_float(&info, name, &value);
}
template <>
inline OrtStatusPtr API::KernelInfoGetAttribute<std::string>(const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
size_t size = 0;
std::string out;
// Feed nullptr for the data buffer to query the true size of the string attribute
OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size);
if (status == nullptr) {
out.resize(size);
status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size);
out.resize(size - 1); // remove the terminating character '\0'
}
if (status == nullptr) {
value = std::move(out);
}
return status;
}
template <class T>
inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
// Ideally, we should know which kind of error code can be ignored, but it is not available now.
// Just ignore all of them.
API::ReleaseStatus(status);
}
return nullptr;
}
template <class T>
inline T GetOpAttributeOrDefault(const OrtKernelInfo& info, const char* name, const T& default_value) noexcept {
T ret;
if (API::KernelInfoGetAttribute(info, name, ret)) {
ret = default_value;
}
return ret;
}
inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
return API::CreateStatus(code, msg);
}
inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
return API::CreateStatus(code, msg.c_str());
}
inline void ReleaseStatus(OrtStatusPtr& status) {
API::ReleaseStatus(status);
status = nullptr;
}
} // namespace of OrtW
// Deprecated: No needs to create a new class derived from BaseKernel.
struct BaseKernel {
BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept
: api_(api), info_(info), ort_(api_) {
@ -226,6 +347,7 @@ struct BaseKernel {
const OrtKernelInfo& info_;
};
// Deprecated: Use OrtW::CustomOpApi::KernelInfoGetAttribute instead
struct OrtTensorDimensions : std::vector<int64_t> {
OrtTensorDimensions() = default;
OrtTensorDimensions(const OrtW::CustomOpApi& ort, const OrtValue* value) {

Просмотреть файл

@ -1,8 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
// Note: The following include path is used for building Swift Package Manager support for ORT Extensions.
// The macro is defined in cxxSettings config in Package.swift.
// The reason why we need a prefix is that when Xcode includes the package it copies it to an internally generated path with
@ -15,7 +13,6 @@
#include "onnxruntime_c_api.h"
#endif
#ifdef __cplusplus
extern "C" {
#endif

199
include/ortx_tokenizer.h Normal file
Просмотреть файл

@ -0,0 +1,199 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// C ABI header file for the onnxruntime-extensions tokenization module
#pragma once
#include <stdint.h>
#include <stddef.h>
#if defined(__CYGWIN__) || defined(__MINGW32__)
#define ORTX_API_CALL __stdcall
#elif defined(_WIN32)
#define ORTX_API_CALL _stdcall
#define ORTX_MUST_USE_RESULT
#elif __APPLE__
#define ORTX_API_CALL
// To make symbols visible on macOS/iOS
#define ORTX_MUST_USE_RESULT __attribute__((warn_unused_result))
#else
#define ORTX_API_CALL
#define ORTX_MUST_USE_RESULT
#endif
typedef enum {
kOrtxOK = 0,
kOrtxErrorInvalidArgument = 1,
kOrtxErrorOutOfMemory = 2,
kOrtxErrorInvalidFile = 3,
kOrtxErrorNotFound = 4,
kOrtxErrorAlreadyExists = 5,
kOrtxErrorOutOfRange = 6,
kOrtxErrorNotImplemented = 7,
kOrtxErrorInternal = 8,
kOrtxErrorUnknown = 1000
} extError_t;
typedef enum {
kOrtxKindUnknown = 0,
kOrtxKindBegin = 0x7788, // starting from a number to help validate the object
kOrtxKindTokenizer = kOrtxKindBegin,
kOrtxKindStringArray = 0x7789,
kOrtxKindTokenId2DArray = 0x778A,
kOrtxKindDetokenizerCache = 0x778B,
kOrtxKindEnd = 0x7999
} extObjectKind_t;
// all object managed by the library should be 'derived' from this struct
// which eventually will be released by TfmDispose if C++, or TFM_DISPOSE if C
typedef struct {
int ext_kind_;
} OrtxObject;
const int API_VERSION = 1;
// typedefs to create/dispose function flood, and to make the API more C++ friendly with less casting
typedef OrtxObject OrtxTokenizer;
typedef OrtxObject OrtxStringArray;
typedef OrtxObject OrtxTokenId2DArray;
typedef OrtxObject OrtxDetokenizerCache;
// C, instead of C++ doesn't cast automatically,
// so we need to use a macro to cast the object to the correct type
#define ORTX_DISPOSE(obj) OrtxDispose((OrtxObject**)&obj)
typedef uint32_t extTokenId_t;
#ifdef __cplusplus
extern "C" {
#endif
/** \brief Get the current C ABI version of this library
*
* \snippet{doc} snippets.dox int Return Value
*/
int ORTX_API_CALL OrtxGetAPIVersion(void);
/** \brief Get the last error message generated by the library
*
* \param message Pointer to store the last error message
* \return Pointer to the last error message
*/
const char* ORTX_API_CALL OrtxGetLastErrorMessage(void);
/** \brief Create a new object of the specified kind
*
* \param kind The kind of object to create
* \param object Pointer to store the created object
* \param ... Additional arguments based on the kind of object
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxCreate(extObjectKind_t kind, OrtxObject** object, ...);
/** \brief Dispose the specified object
*
* \param object Pointer to the object to dispose
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object);
/** \brief Dispose the specified object
*
* \param object Pointer to the object to dispose
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object);
/** \brief Create a tokenizer object with the specified tokenizer path
*
* \param tokenizer Pointer to store the created tokenizer object
* \param tokenizer_path The path to the tokenizer
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxCreateTokenizer(OrtxTokenizer** tokenizer, const char* tokenizer_path);
/** \brief Tokenize the input using the specified tokenizer
*
* \param tokenizer Pointer to the tokenizer object
* \param input Array of input strings
* \param batch_size Number of input strings in the batch
* \param output Pointer to store the tokenized result
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxTokenize(
const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size, OrtxTokenId2DArray** output);
/** \brief Detokenize the input using the specified tokenizer
*
* \param tokenizer Pointer to the tokenizer object
* \param input Pointer to the input token IDs
* \param output Pointer to store the detokenized result
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxDetokenize(
const OrtxTokenizer* tokenizer, const OrtxTokenId2DArray* input, OrtxStringArray** output);
/** \brief Detokenize the input using the specified tokenizer (1D version)
*
* \param tokenizer Pointer to the tokenizer object
* \param input Pointer to the input token IDs
* \param len Length of the input token IDs array
* \param output Pointer to store the detokenized result
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxDetokenize1D(
const OrtxTokenizer* tokenizer, const extTokenId_t* input, size_t len, OrtxStringArray** output);
/** \brief Detokenize the input using the specified tokenizer with caching
*
* \param tokenizer Pointer to the tokenizer object
* \param cache Pointer to the detokenizer cache
* \param next_id Next token ID to detokenize
* \param text_out Pointer to store the detokenized text
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxDetokenizeCached(
const OrtxTokenizer* tokenizer, OrtxDetokenizerCache* cache, extTokenId_t next_id, const char** text_out);
/** \brief Get the length of the string array
*
* \param string_array Pointer to the string array
* \param length Pointer to store the length of the string array
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxStringArrayGetBatch(const OrtxStringArray* string_array, size_t* length);
/** \brief Get the item at the specified index from the string array
*
* \param string_array Pointer to the string array
* \param index Index of the item to retrieve
* \param item Pointer to store the retrieved item
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxStringArrayGetItem(const OrtxStringArray* string_array, size_t index, const char** item);
/** \brief Get the batch size of the token ID 2D array
*
* \param token_id_2d_array Pointer to the token ID 2D array
* \param length Pointer to store the batch size
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxTokenId2DArrayGetBatch(const OrtxTokenId2DArray* token_id_2d_array, size_t* length);
/** \brief Get the item at the specified index from the token ID 2D array
*
* \param token_id_2d_array Pointer to the token ID 2D array
* \param index Index of the item to retrieve
* \param item Pointer to store the retrieved item
* \param length Pointer to store the length of the item
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(
const OrtxTokenId2DArray* token_id_2d_array, size_t index, const extTokenId_t** item, size_t* length);
#ifdef __cplusplus
}
#endif

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -1,122 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file defines API which depends on ONNXRuntime, but not including Custom Op and related facilities
// Custom Op and related classes, functions and macros are in onnxruntime_customop.hpp
#pragma once
#include "exceptions.h"
// namespace of ORT ABI Wrapper
namespace OrtW {
class API {
// To use ONNX C ABI in a way like OrtW::API::CreateStatus.
public:
static API& instance(const OrtApi* ort_api = nullptr) noexcept {
static API self(ort_api);
return self;
}
static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
return instance()->CreateStatus(code, msg);
}
static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
instance()->ReleaseStatus(ptr);
}
template <typename T>
static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;
static void ThrowOnError(OrtStatusPtr ptr) {
OrtW::ThrowOnError(instance().api_, ptr);
}
// Caller is responsible for releasing OrtMemoryInfo object
static OrtStatusPtr CreateOrtMemoryInfo(const char* name, enum OrtAllocatorType type, int id, enum OrtMemType mem_type, OrtMemoryInfo** out) noexcept {
return instance()->CreateMemoryInfo(name, type, id, mem_type, out);
}
#if ORT_API_VERSION >= 15
// Caller is responsible for releasing OrtAllocator object: delete static_cast<onnxruntime::OrtAllocatorImpl*> (allocator)
static OrtStatusPtr GetOrtAllocator(const OrtKernelContext* context, const OrtMemoryInfo* mem_info, OrtAllocator** out) {
return instance()->KernelContext_GetAllocator(context, mem_info, out);
}
#endif
private:
const OrtApi* operator->() const {
return &api_;
}
API(const OrtApi* api) : api_(*api) {
if (api == nullptr) {
ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
}
}
const OrtApi& api_;
};
template <>
inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
}
template <>
inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
return instance()->KernelInfoGetAttribute_float(&info, name, &value);
}
template <>
inline OrtStatusPtr API::KernelInfoGetAttribute<std::string>(const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
size_t size = 0;
std::string out;
// Feed nullptr for the data buffer to query the true size of the string attribute
OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size);
if (status == nullptr) {
out.resize(size);
status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size);
out.resize(size - 1); // remove the terminating character '\0'
}
if (status == nullptr) {
value = std::move(out);
}
return status;
}
template <class T>
inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
// Ideally, we should know which kind of error code can be ignored, but it is not available now.
// Just ignore all of them.
API::ReleaseStatus(status);
}
return nullptr;
}
template <class T>
inline T GetOpAttributeOrDefault(const OrtKernelInfo& info, const char* name, const T& default_value) noexcept {
T ret;
if (API::KernelInfoGetAttribute(info, name, ret)) {
ret = default_value;
}
return ret;
}
inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
return API::CreateStatus(code, msg);
}
inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
return API::CreateStatus(code, msg.c_str());
}
inline void ReleaseStatus(OrtStatusPtr& status) {
API::ReleaseStatus(status);
status = nullptr;
}
} // namespace OrtW

Просмотреть файл

@ -4,13 +4,12 @@
#pragma once
#include "onnxruntime_f16.h"
#include "string_utils.h"
#include "onnxruntime_no_customop.h"
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <stdexcept>
#include <string>
using namespace Ort::Custom;
namespace ortc = Ort::Custom;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
__device__ __forceinline__ half operator+(const half& lh, const half& rh) { return half((float)lh + (float)rh); }
@ -97,81 +96,81 @@ __device__ __forceinline__ half2 operator/(const half2& lh, const half2& rh) {
}
#endif
/// Arithmetic for BFloat16
/// Arithmetic for ortc::BFloat16
__device__ __forceinline__ BFloat16 operator+(const BFloat16& a, const BFloat16& b) {
__device__ __forceinline__ ortc::BFloat16 operator+(const ortc::BFloat16& a, const ortc::BFloat16& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
__device__ __forceinline__ BFloat16 operator-(const BFloat16& a, const BFloat16& b) {
__device__ __forceinline__ ortc::BFloat16 operator-(const ortc::BFloat16& a, const ortc::BFloat16& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
__device__ __forceinline__ BFloat16 operator*(const BFloat16& a, const BFloat16& b) {
__device__ __forceinline__ ortc::BFloat16 operator*(const ortc::BFloat16& a, const ortc::BFloat16& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
__device__ __forceinline__ BFloat16 operator/(const BFloat16& a, const BFloat16& b) {
__device__ __forceinline__ ortc::BFloat16 operator/(const ortc::BFloat16& a, const ortc::BFloat16& b) {
return static_cast<float>(a) / static_cast<float>(b);
}
__device__ __forceinline__ BFloat16 operator-(const BFloat16& a) { return -static_cast<float>(a); }
__device__ __forceinline__ ortc::BFloat16 operator-(const ortc::BFloat16& a) { return -static_cast<float>(a); }
__device__ __forceinline__ BFloat16& operator+=(BFloat16& a, const BFloat16& b) {
__device__ __forceinline__ ortc::BFloat16& operator+=(ortc::BFloat16& a, const ortc::BFloat16& b) {
a = a + b;
return a;
}
__device__ __forceinline__ BFloat16& operator-=(BFloat16& a, const BFloat16& b) {
__device__ __forceinline__ ortc::BFloat16& operator-=(ortc::BFloat16& a, const ortc::BFloat16& b) {
a = a - b;
return a;
}
__device__ __forceinline__ BFloat16& operator*=(BFloat16& a, const BFloat16& b) {
__device__ __forceinline__ ortc::BFloat16& operator*=(ortc::BFloat16& a, const ortc::BFloat16& b) {
a = a * b;
return a;
}
__device__ __forceinline__ BFloat16& operator/=(BFloat16& a, const BFloat16& b) {
__device__ __forceinline__ ortc::BFloat16& operator/=(ortc::BFloat16& a, const ortc::BFloat16& b) {
a = a / b;
return a;
}
/// Arithmetic with floats
__device__ __forceinline__ float operator+(BFloat16 a, float b) { return a + b; }
__device__ __forceinline__ float operator-(BFloat16 a, float b) { return a - b; }
__device__ __forceinline__ float operator*(BFloat16 a, float b) { return a * b; }
__device__ __forceinline__ float operator/(BFloat16 a, float b) { return a / b; }
__device__ __forceinline__ float operator+(ortc::BFloat16 a, float b) { return a + b; }
__device__ __forceinline__ float operator-(ortc::BFloat16 a, float b) { return a - b; }
__device__ __forceinline__ float operator*(ortc::BFloat16 a, float b) { return a * b; }
__device__ __forceinline__ float operator/(ortc::BFloat16 a, float b) { return a / b; }
__device__ __forceinline__ float operator+(float a, BFloat16 b) { return a + b; }
__device__ __forceinline__ float operator-(float a, BFloat16 b) { return a - b; }
__device__ __forceinline__ float operator*(float a, BFloat16 b) { return a * b; }
__device__ __forceinline__ float operator/(float a, BFloat16 b) { return a / b; }
__device__ __forceinline__ float operator+(float a, ortc::BFloat16 b) { return a + b; }
__device__ __forceinline__ float operator-(float a, ortc::BFloat16 b) { return a - b; }
__device__ __forceinline__ float operator*(float a, ortc::BFloat16 b) { return a * b; }
__device__ __forceinline__ float operator/(float a, ortc::BFloat16 b) { return a / b; }
__device__ __forceinline__ float& operator+=(float& a, const BFloat16& b) { return a += b; }
__device__ __forceinline__ float& operator-=(float& a, const BFloat16& b) { return a -= b; }
__device__ __forceinline__ float& operator*=(float& a, const BFloat16& b) { return a *= b; }
__device__ __forceinline__ float& operator/=(float& a, const BFloat16& b) { return a /= b; }
__device__ __forceinline__ float& operator+=(float& a, const ortc::BFloat16& b) { return a += b; }
__device__ __forceinline__ float& operator-=(float& a, const ortc::BFloat16& b) { return a -= b; }
__device__ __forceinline__ float& operator*=(float& a, const ortc::BFloat16& b) { return a *= b; }
__device__ __forceinline__ float& operator/=(float& a, const ortc::BFloat16& b) { return a /= b; }
/// Arithmetic with doubles
__device__ __forceinline__ double operator+(BFloat16 a, double b) { return static_cast<double>(a) + b; }
__device__ __forceinline__ double operator-(BFloat16 a, double b) { return static_cast<double>(a) - b; }
__device__ __forceinline__ double operator*(BFloat16 a, double b) { return static_cast<double>(a) * b; }
__device__ __forceinline__ double operator/(BFloat16 a, double b) { return static_cast<double>(a) / b; }
__device__ __forceinline__ double operator+(ortc::BFloat16 a, double b) { return static_cast<double>(a) + b; }
__device__ __forceinline__ double operator-(ortc::BFloat16 a, double b) { return static_cast<double>(a) - b; }
__device__ __forceinline__ double operator*(ortc::BFloat16 a, double b) { return static_cast<double>(a) * b; }
__device__ __forceinline__ double operator/(ortc::BFloat16 a, double b) { return static_cast<double>(a) / b; }
__device__ __forceinline__ double operator+(double a, BFloat16 b) { return a + static_cast<double>(b); }
__device__ __forceinline__ double operator-(double a, BFloat16 b) { return a - static_cast<double>(b); }
__device__ __forceinline__ double operator*(double a, BFloat16 b) { return a * static_cast<double>(b); }
__device__ __forceinline__ double operator/(double a, BFloat16 b) { return a / static_cast<double>(b); }
__device__ __forceinline__ double operator+(double a, ortc::BFloat16 b) { return a + static_cast<double>(b); }
__device__ __forceinline__ double operator-(double a, ortc::BFloat16 b) { return a - static_cast<double>(b); }
__device__ __forceinline__ double operator*(double a, ortc::BFloat16 b) { return a * static_cast<double>(b); }
__device__ __forceinline__ double operator/(double a, ortc::BFloat16 b) { return a / static_cast<double>(b); }
// Overloading < and > operators
__device__ __forceinline__ bool operator==(BFloat16& lhs, BFloat16& rhs) { return float(lhs) == float(rhs); }
__device__ __forceinline__ bool operator!=(BFloat16& lhs, BFloat16& rhs) { return float(lhs) != float(rhs); }
__device__ __forceinline__ bool operator>(BFloat16& lhs, BFloat16& rhs) { return float(lhs) > float(rhs); }
__device__ __forceinline__ bool operator<(BFloat16& lhs, BFloat16& rhs) { return float(lhs) < float(rhs); }
__device__ __forceinline__ bool operator==(ortc::BFloat16& lhs, ortc::BFloat16& rhs) { return float(lhs) == float(rhs); }
__device__ __forceinline__ bool operator!=(ortc::BFloat16& lhs, ortc::BFloat16& rhs) { return float(lhs) != float(rhs); }
__device__ __forceinline__ bool operator>(ortc::BFloat16& lhs, ortc::BFloat16& rhs) { return float(lhs) > float(rhs); }
__device__ __forceinline__ bool operator<(ortc::BFloat16& lhs, ortc::BFloat16& rhs) { return float(lhs) < float(rhs); }
template <typename T>
__device__ __inline T _Tanh(T a);
@ -191,4 +190,4 @@ __device__ __inline__ half2 _Tanh(half2 a) {
}
template <>
__device__ __inline__ BFloat16 _Tanh(BFloat16 a) { return tanhf(static_cast<float>(a)); }
__device__ __inline__ ortc::BFloat16 _Tanh(ortc::BFloat16 a) { return tanhf(static_cast<float>(a)); }

Просмотреть файл

@ -81,20 +81,25 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
return result;
}
KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false);
bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true);
// KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
// bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
// bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
// bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
// bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false);
// bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true);
tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents,
tokenize_punctuation, remove_control_chars);
}
// tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents,
// tokenize_punctuation, remove_control_chars);
// }
void KernelBasicTokenizer::Compute(std::string_view input,
ortc::Tensor<std::string>& output) const {
// Setup inputs
std::vector<ustring> result = tokenizer_->Tokenize(ustring(input));
output.SetStringOutput({result[0].operator std::string()}, {1});
std::vector<std::string> tokens;
for (const auto& token : result) {
tokens.push_back((std::string)token);
}
output.SetStringOutput(tokens, {static_cast<int64_t>(tokens.size())});
}

Просмотреть файл

@ -21,8 +21,19 @@ class BasicTokenizer {
bool remove_control_chars_;
};
struct KernelBasicTokenizer : BaseKernel {
KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info);
struct KernelBasicTokenizer{
template <typename T>
KernelBasicTokenizer(const T& dict) {
bool do_lower_case = dict.TryToGetAttributeWithDefault("do_lower_case", true);
bool tokenize_chinese_chars = dict.TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
bool strip_accents = dict.TryToGetAttributeWithDefault("strip_accents", false);
bool tokenize_punctuation = dict.TryToGetAttributeWithDefault("tokenize_punctuation", false);
bool remove_control_chars = dict.TryToGetAttributeWithDefault("remove_control_chars", true);
tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents,
tokenize_punctuation, remove_control_chars);
}
void Compute(std::string_view input,
ortc::Tensor<std::string>& output) const;

Просмотреть файл

@ -6,6 +6,7 @@
#include "ortx_common.h"
#include <optional>
#include <limits>
using namespace ort_extensions;
@ -428,7 +429,7 @@ OrtStatusPtr KernelBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
tokenize_results.emplace_back(
(this->*tok_fun)(
ustr,
padding_length_ < 0 ? std::numeric_limits<uint32_t>::max() : padding_length_,
padding_length_ < 0 ? (std::numeric_limits<uint32_t>::max)() : padding_length_,
compute_offset_mapping,
offset_map));
}
@ -436,7 +437,7 @@ OrtStatusPtr KernelBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
size_t max_length = 0;
if (padding_length_ == -1) {
for (auto& res : tokenize_results) {
max_length = std::max(max_length, res.size());
max_length = (std::max)(max_length, res.size());
}
} else {
max_length = static_cast<size_t>(padding_length_);

Просмотреть файл

@ -8,6 +8,7 @@
#include <string>
#include <vector>
#include <list>
struct BpeModelConf {
const char* name_{"GPT2"}; // this name may be overridden by the tokenizer's attribute.

Просмотреть файл

@ -7,6 +7,7 @@
#include "ocos.h"
#include "test_kernel.hpp"
TEST(tokenizer_opertors, test_bert_tokenizer) {
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");

Просмотреть файл

@ -2,6 +2,9 @@
// Licensed under the MIT License.
#include "gtest/gtest.h"
#ifdef ENABLE_DLIB
#include <dlib/matrix.h>
using namespace dlib;
@ -20,3 +23,5 @@ TEST(math, matrix_op) {
matrix<float> x = inv(M)*y;
EXPECT_FLOAT_EQ(x(1, 0), -13.909741);
}
#endif // ENABLE_DLIB

Просмотреть файл

@ -7,7 +7,7 @@
#include "bert_tokenizer.hpp"
#include <clocale>
#include "tokenizer/basic_tokenizer.hpp"
class LocaleBaseTest : public testing::Test {
public:
@ -65,7 +65,7 @@ std::unordered_map<std::u32string, int32_t> get_vocabulary_basic() {
};
std::unordered_map<std::u32string, int32_t> vocab;
for (auto it = vocab_tokens.begin(); it != vocab_tokens.end(); ++it) {
vocab[*it] = vocab.size();
vocab[*it] = static_cast<int32_t>(vocab.size());
}
return vocab;
}
@ -104,7 +104,7 @@ std::unordered_map<std::u32string, int32_t> get_vocabulary_wordpiece() {
};
std::unordered_map<std::u32string, int32_t> vocab;
for (auto it = vocab_tokens.begin(); it != vocab_tokens.end(); ++it) {
vocab[*it] = vocab.size();
vocab[*it] = static_cast<int32_t>(vocab.size());
}
return vocab;
}
@ -247,3 +247,17 @@ TEST(tokenizer, truncation_longest_first) {
EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5}));
EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4, 5, 6, 7}));
}
TEST(tokenizer, basic_tok_eager) {
std::string test_case = "I mean, youll need something to talk about next Sunday, right?";
std::vector<std::string> expect_result = {"I", "mean", ",", "you", "", "ll", "need", "something", "to", "talk", "about", "next", "Sunday", ",", "right", "?"};
ortc::NamedArgumentDict dict({"do_lower_case", "tokenize_chinese_chars", "strip_accents", "tokenize_punctuation", "remove_control_chars"},
std::make_tuple(false, true, true, true, true));
KernelBasicTokenizer tokenizer(dict);
ortc::Tensor<std::string> output;
tokenizer.Compute(test_case, output);
EXPECT_EQ(output.Data(), expect_result);
}

Просмотреть файл

@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <filesystem>
#include <locale>
#include "gtest/gtest.h"
#include "ocos.h"
#include "ortx_tokenizer.h"
#include "bpe_kernels.h"
TEST(bbpe_tokenizer, test_encoder) {
EXPECT_EQ(0, ORT_OK);
}

Просмотреть файл

@ -10,25 +10,31 @@
#include "exceptions.h"
namespace {
void FixCurrentDir() {
void FixCurrentDir(const std::string& init_path = "") {
// adjust for the Google Test Adapter in Visual Studio not setting the current path to $(ProjectDir),
// which results in us being 2 levels below where the `data` folder is copied to and where the extensions
// library is
auto cur = std::filesystem::current_path();
// if init_path is the executable path, then we need to get the directory of the executable
auto cur_dir = std::filesystem::current_path();
if (!init_path.empty()) {
std::filesystem::path init_dir = init_path;
cur_dir = init_dir.parent_path();
}
do {
auto data_dir = cur / "data";
auto data_dir = cur_dir / "data";
if (std::filesystem::exists(data_dir) && std::filesystem::is_directory(data_dir)) {
break;
}
cur = cur.parent_path();
ASSERT_NE(cur, cur.root_path()) << "Reached root directory without finding 'data' directory.";
cur_dir = cur_dir.parent_path();
ASSERT_NE(cur_dir, cur_dir.root_path()) << "Reached root directory without finding 'data' directory.";
} while (true);
// set current path as the extensions library is also loaded from that directory by TestInference
std::filesystem::current_path(cur);
std::filesystem::current_path(cur_dir);
}
} // namespace
@ -38,7 +44,7 @@ int main(int argc, char** argv) {
OCOS_TRY {
::testing::InitGoogleTest(&argc, argv);
FixCurrentDir();
FixCurrentDir(argv[0]);
status = RUN_ALL_TESTS();
}
OCOS_CATCH(const std::exception& ex) {