Re-organize the source code folder structure (#88)

* Reorg the code folder structure

* update the math test case

* Add a matrix inverse op.

* turn off the ctest by default.

* disable jpeg lib in dlib for Linux build issue.

* Linux build fixing

* typo

* enable dlib library on Win32 build

* rename ocos to operators

* add the missing operator folder
This commit is contained in:
Wenbing Li 2021-05-04 17:12:28 -07:00 коммит произвёл GitHub
Родитель 2243847f22
Коммит c891e5d732
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
66 изменённых файлов: 1327 добавлений и 1144 удалений

Просмотреть файл

@ -7,6 +7,15 @@ if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING "Choose build type: Debug Release RelWithDebInfo." FORCE)
endif()
project(onnxruntime_extensions)
set(CPACK_PACKAGE_NAME "onnxruntime_extensions")
set(CPACK_PACKAGE_VERSION_MAJOR "0")
set(CPACK_PACKAGE_VERSION_MINOR "2")
set(CPACK_PACKAGE_VERSION_PATCH "0")
set(VERSION ${CPACK_PACKAGE_VERSION_MAJOR}.${CPACK_PACKAGE_VERSION_MINOR}.${CPACK_PACKAGE_VERSION_PATCH})
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
@ -15,15 +24,15 @@ include(CheckLanguage)
option(CC_OPTIMIZE "Allow compiler optimizations, Set to OFF to disable" ON)
option(OCOS_ENABLE_PYTHON "Enable Python component building" OFF)
option(OCOS_ENABLE_CTEST "Enable C++ test" ON)
option(OCOS_ENABLE_CTEST "Enable C++ test" OFF)
option(OCOS_ENABLE_TF_STRING "Enable String Operator Set" ON)
option(OCOS_ENABLE_GPT2_TOKENIZER "Enable the GPT2 tokenizer building" ON)
option(OCOS_ENABLE_SPM_TOKENIZER "Enable the SentencePiece tokenizer building" ON)
option(OCOS_ENABLE_BERT_TOKENIZER "Enable the BertTokenizer building" ON)
option(OCOS_ENABLE_MATH "Enable the math tensor operators building" ON)
option(OCOS_ENABLE_STATIC_LIB "Enable generating static library" OFF)
if(NOT CC_OPTIMIZE)
message("!!!THE COMPILER OPTIMIZATION HAS BEEN DISABLED, DEBUG-ONLY!!!")
string(REGEX REPLACE "([\-\/]O[123])" "" CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELWITHDEBINFO}")
@ -54,13 +63,13 @@ include(FetchContent)
if (OCOS_ENABLE_TF_STRING)
if (NOT TARGET re2::re2)
set(RE2_BUILD_TESTING OFF CACHE INTERNAL "")
message(STATUS "fetch googlere2")
message(STATUS "Fetch googlere2")
include(googlere2)
FetchContent_GetProperties(googlere2)
endif()
if (NOT TARGET farmhash)
message(STATUS "fetch farmhash")
message(STATUS "Fetch farmhash")
include(farmhash)
FetchContent_GetProperties(farmhash)
endif()
@ -70,22 +79,32 @@ if (OCOS_ENABLE_TF_STRING)
endif()
endif()
file(GLOB TARGET_SRC "./ocos/*.cc" "./ocos/*.h*" "./ocos/utils/*.h*" "./ocos/utils/*.cc")
file(GLOB TARGET_SRC "operators/*.cc" "operators/*.h")
if (OCOS_ENABLE_TF_STRING)
file(GLOB TARGET_SRC_KERNELS "./ocos/kernels/*.cc" "./ocos/kernels/*.h*")
file(GLOB TARGET_SRC_KERNELS "operators/text/*.cc" "operators/text/*.h*")
file(GLOB TARGET_SRC_HASH "${farmhash_SOURCE_DIR}/src/farmhash.*")
list(APPEND TARGET_SRC ${TARGET_SRC_KERNELS} ${TARGET_SRC_HASH})
endif()
# Math operators: bring in dlib and add the operators/math sources to TARGET_SRC.
if (OCOS_ENABLE_MATH)
# dlib feature switches, written to the cache before the dlib module is included.
# NOTE(review): presumably read by dlib's own CMake when it is added to the build - confirm.
set(DLIB_NO_GUI_SUPPORT ON CACHE INTERNAL "")
set(DLIB_USE_CUDA OFF CACHE INTERNAL "")
set(DLIB_USE_LAPACK OFF CACHE INTERNAL "")
set(DLIB_USE_BLAS OFF CACHE INTERNAL "")
include(dlib)
# Collect every .cc and header under operators/math and append them to the library sources.
file(GLOB TARGET_SRC_MATH "operators/math/*.cc" "operators/math/*.h*")
list(APPEND TARGET_SRC ${TARGET_SRC_MATH})
endif()
if (OCOS_ENABLE_GPT2_TOKENIZER)
# GPT2
if (NOT TARGET nlohmann_json)
set(JSON_BuildTests OFF CACHE INTERNAL "")
message(STATUS "fetch json")
message(STATUS "Fetch json")
include(json)
endif()
file(GLOB tok_TARGET_SRC "tokenizer/gpt*.cc" "tokenizer/unicode*.*")
file(GLOB tok_TARGET_SRC "operators/tokenizer/gpt*.cc" "operators/tokenizer/unicode*.*")
list(APPEND TARGET_SRC ${tok_TARGET_SRC})
endif()
@ -93,23 +112,23 @@ if (OCOS_ENABLE_SPM_TOKENIZER)
# SentencePiece
set(SPM_ENABLE_TCMALLOC OFF CACHE INTERNAL "")
set(SPM_ENABLE_SHARED OFF CACHE INTERNAL "")
message(STATUS "fetch sentencepiece")
message(STATUS "Fetch sentencepiece")
include(sentencepieceproject)
file(GLOB stpiece_TARGET_SRC "sentencepiece/*.cc" "tokenizer/sentencepiece*")
file(GLOB stpiece_TARGET_SRC "operators/tokenizer/sentencepiece/*.cc" "operators/tokenizer/sentencepiece*")
# NOTE(review): list(REMOVE_ITEM) removes entries that are exactly equal to the
# values given, so INCLUDE, REGEX and the pattern are treated as literal values
# here and this line most likely removes nothing. If regex-based exclusion was
# intended, list(FILTER <list> EXCLUDE REGEX <pattern>) is the command for it -
# confirm the intent first, because enabling the filter changes the source set.
list(REMOVE_ITEM stpiece_TARGET_SRC INCLUDE REGEX ".*((spm)|(train)).*")
list(APPEND TARGET_SRC ${stpiece_TARGET_SRC})
endif()
if (OCOS_ENABLE_BERT_TOKENIZER)
# Bert
file(GLOB bert_TARGET_SRC "tokenizer/wordpiece*.*")
file(GLOB bert_TARGET_SRC "operators/tokenizer/wordpiece*.*")
list(APPEND TARGET_SRC ${bert_TARGET_SRC})
endif()
add_compile_options("$<$<C_COMPILER_ID:MSVC>:/utf-8>")
add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/utf-8>")
add_library(ortcustomops_static STATIC ${TARGET_SRC})
target_include_directories(ortcustomops_static PUBLIC tokenizer)
target_include_directories(ortcustomops_static PUBLIC operators/tokenizer)
set(ocos_libraries ortcustomops_static)
if (OCOS_ENABLE_TF_STRING)
@ -119,7 +138,7 @@ endif()
target_include_directories(ortcustomops_static PUBLIC
${PROJECT_SOURCE_DIR}/includes
${PROJECT_SOURCE_DIR}/includes/onnxruntime
${PROJECT_SOURCE_DIR}/ocos)
${PROJECT_SOURCE_DIR}/operators)
set(OCOS_COMPILE_DEFINITIONS "")
@ -130,6 +149,15 @@ if (OCOS_ENABLE_TF_STRING)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_TF_STRING)
endif()
# Math operators: expose the dlib headers to ortcustomops_static and define ENABLE_MATH.
if (OCOS_ENABLE_MATH)
# NOTE(review): dlib_SOURCE_DIR is expected to come from the FetchContent
# population of dlib performed by the dlib module - confirm.
target_include_directories(ortcustomops_static PUBLIC ${dlib_SOURCE_DIR})
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_MATH)
# The dlib matrix implementation is all in the headers, no library compiling needed.
# On Windows the dlib::dlib target is additionally appended to the libraries to link.
if (WIN32)
list(APPEND ocos_libraries dlib::dlib)
endif()
endif()
if (OCOS_ENABLE_GPT2_TOKENIZER)
# GPT2
target_include_directories(ortcustomops_static PRIVATE ${json_SOURCE_DIR}/single_include)
@ -158,7 +186,7 @@ target_compile_definitions(ortcustomops_static PRIVATE ${OCOS_COMPILE_DEFINITION
file(GLOB shared_TARGET_SRC "shared/*.cc" "shared/*.h")
if(OCOS_ENABLE_PYTHON)
file(GLOB TARGET_SRC_PYOPS "./ocos/pyfunc/*.cc" "./ocos/pyfunc/*.h*")
file(GLOB TARGET_SRC_PYOPS "pyop/*.cc" "pyop/*.h")
set(Python3_FIND_REGISTRY NEVER CACHE STRING "...")
if(NOT "${Python3_FIND_REGISTRY}" STREQUAL "NEVER")
message(FATAL_ERROR "Python3_FIND_REGISTRY is not NEVER")
@ -201,7 +229,7 @@ target_compile_definitions(ortcustomops PRIVATE ${OCOS_COMPILE_DEFINITIONS})
target_link_libraries(ortcustomops PRIVATE ${ocos_libraries})
if(OCOS_ENABLE_PYTHON)
message(STATUS "fetch pybind11")
message(STATUS "Fetch pybind11")
include(pybind11)
set(NUMPY_NOT_FOUND false)
exec_program("${Python3_EXECUTABLE}"
@ -228,10 +256,6 @@ if(OCOS_ENABLE_PYTHON)
endif()
endif()
set(CPACK_PROJECT_NAME ${PROJECT_NAME})
set(CPACK_PROJECT_VERSION ${PROJECT_VERSION})
message(STATUS "fetch CPack")
include(CPack)
# test section
if (OCOS_ENABLE_CTEST)
@ -242,29 +266,33 @@ if (OCOS_ENABLE_CTEST)
endif()
enable_testing()
message(STATUS "fetch CTest")
message(STATUS "Fetch CTest")
include(CTest)
set(TEST_SRC_DIR ${PROJECT_SOURCE_DIR}/test)
message(STATUS "fetch googletest")
message(STATUS "Fetch googletest")
include(googletest)
file(GLOB static_TEST_SRC "${TEST_SRC_DIR}/static_test/*.cc")
add_executable(ortcustomops_static_test ${static_TEST_SRC})
target_link_libraries(ortcustomops_static_test gtest_main ${ocos_libraries})
add_test(NAME ortcustomops_static_test COMMAND $<TARGET_FILE:ortcustomops_static_test>)
# needs to link with stdc++fs in Linux
if(UNIX AND NOT APPLE)
set(FS_STDLIB stdc++fs)
endif()
file(GLOB shared_TEST_SRC "${TEST_SRC_DIR}/shared_test/*.cc")
add_executable(ortcustomops_test ${shared_TEST_SRC})
if (ONNXRUNTIME_LIB_DIR)
target_link_directories(ortcustomops_test PRIVATE ${ONNXRUNTIME_LIB_DIR})
target_link_libraries(ortcustomops_test ortcustomops onnxruntime gtest_main ${ocos_libraries})
if (WIN32)
file(TO_CMAKE_PATH "${ONNXRUNTIME_LIB_DIR}/*" ONNXRUNTIME_LIB_FILEPATTERN)
file(GLOB ONNXRUNTIME_LIB_FILES CONFIGURE_DEPENDS "${ONNXRUNTIME_LIB_FILEPATTERN}")
add_custom_command(
TARGET ortcustomops_test POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${ONNXRUNTIME_LIB_FILES} $<TARGET_FILE_DIR:ortcustomops_test>)
endif()
endif()
target_link_libraries(ortcustomops_test ortcustomops onnxruntime gtest_main ${ocos_libraries} ${FS_STDLIB})
if (WIN32)
file(TO_CMAKE_PATH "${ONNXRUNTIME_LIB_DIR}/*" ONNXRUNTIME_LIB_FILEPATTERN)
file(GLOB ONNXRUNTIME_LIB_FILES CONFIGURE_DEPENDS "${ONNXRUNTIME_LIB_FILEPATTERN}")
add_custom_command(
TARGET ortcustomops_test POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${ONNXRUNTIME_LIB_FILES} $<TARGET_FILE_DIR:ortcustomops_test>)
endif()
set(TEST_DATA_SRC ${TEST_SRC_DIR}/data)

Просмотреть файл

@ -42,7 +42,7 @@ jobs:
displayName: Unpack ONNXRuntime package.
- script: |
sh ./build.sh -DONNXRUNTIME_LIB_DIR=onnxruntime-linux-x64-$(ort.version)/lib
sh ./build.sh -DONNXRUNTIME_LIB_DIR=onnxruntime-linux-x64-$(ort.version)/lib -DOCOS_ENABLE_CTEST=ON
displayName: build the customop library with onnxruntime
- script: |
@ -133,7 +133,7 @@ jobs:
displayName: Unpack ONNXRuntime package.
- script: |
sh ./build.sh -DONNXRUNTIME_LIB_DIR=onnxruntime-osx-x64-$(ort.version)/lib
sh ./build.sh -DONNXRUNTIME_LIB_DIR=onnxruntime-osx-x64-$(ort.version)/lib -DOCOS_ENABLE_CTEST=ON
displayName: build the customop library with onnxruntime
- script: |
@ -249,7 +249,7 @@ jobs:
displayName: Unpack ONNXRuntime package.
- script: |
call .\build.bat -DONNXRUNTIME_LIB_DIR=.\onnxruntime-win-x64-$(ort.version)\lib
call .\build.bat -DONNXRUNTIME_LIB_DIR=.\onnxruntime-win-x64-$(ort.version)\lib -DOCOS_ENABLE_CTEST=ON
displayName: build the customop library with onnxruntime
- script: |
@ -341,6 +341,6 @@ jobs:
displayName: Setup emscripten pipeline
- script: |
sh ./build.sh -DCMAKE_TOOLCHAIN_FILE=cmake/deps/emsdk/upstream/emscripten/cmake/Modules/Platform/Emscripten.cmake -DOCOS_ENABLE_SPM_TOKENIZER=OFF -DOCOS_ENABLE_PYTHON=OFF -DOCOS_ENABLE_CTEST=OFF
sh ./build.sh -DCMAKE_TOOLCHAIN_FILE=cmake/deps/emsdk/upstream/emscripten/cmake/Modules/Platform/Emscripten.cmake -DOCOS_ENABLE_SPM_TOKENIZER=OFF -DOCOS_ENABLE_PYTHON=OFF
displayName: build the customop library with onnxruntime
# TODO add unittest for webassembly
# TODO add unittest for webassembly

@ -1 +0,0 @@
Subproject commit 2e7eaf7233144e5e25b1c1338890bbab5d011815

14
cmake/externals/dlib.cmake поставляемый Normal file
Просмотреть файл

@ -0,0 +1,14 @@
# Declare where dlib comes from; nothing is downloaded until it is populated below.
FetchContent_Declare(
  dlib
  GIT_REPOSITORY https://github.com/davisking/dlib.git
  GIT_TAG        v19.22
)

if (NOT WIN32)
  # Non-Windows: only download the sources (once); dlib is not added to the build.
  FetchContent_GetProperties(dlib)
  if (NOT dlib_POPULATED)
    # Fetch the content using previously declared details
    FetchContent_Populate(dlib)
  endif()
else()
  # Windows: download dlib and add it to the build.
  FetchContent_MakeAvailable(dlib)
endif()

Просмотреть файл

@ -3,10 +3,13 @@
#pragma once
#include <vector>
#define ORT_API_MANUAL_INIT
#include "onnxruntime_cxx_api.h"
#undef ORT_API_MANUAL_INIT
typedef const OrtCustomOp** (*FxLoadCustomOpFactory)();
#if defined(ENABLE_GPT2_TOKENIZER)
const OrtCustomOp** LoadTokenizerSchemaList();
#endif // ENABLE_GPT2_TOKENIZER
@ -16,6 +19,7 @@ const OrtCustomOp* FetchPyCustomOps(size_t& count);
bool EnablePyCustomOps(bool enable = true);
#endif
// A helper API to support test kernels.
// Must be invoked before RegisterCustomOps.
extern "C" bool AddExternalCustomOp(const OrtCustomOp* c_op);
@ -52,3 +56,33 @@ struct OrtTensorDimensions : std::vector<int64_t> {
return s;
}
};
template <class... Args>
class CuopContainer {
public:
CuopContainer() : ocos_list_({[]() { return new Args; }()...}) {
ocos_list_.push_back(nullptr);
}
~CuopContainer() {
// skip the last null pointer.
for (auto i = 0; i < ocos_list_.size() - 1; i++) {
delete ocos_list_[i];
}
ocos_list_.clear();
}
const OrtCustomOp** GetList() {
return &const_cast<const OrtCustomOp*&>(ocos_list_.front());
}
private:
std::vector<OrtCustomOp*> ocos_list_;
};
// Returns the nullptr-terminated list of custom ops for the given op classes.
// The ops live in a function-local static container, so they are created on
// the first call for a given Args... and reused by every later call.
template <typename... Args>
const OrtCustomOp** LoadCustomOpClasses() {
  // Initialization of a local static is left to the C++ runtime, which also
  // takes care of concurrent first calls.
  static CuopContainer<Args...> container;
  const OrtCustomOp** op_list = container.GetList();
  return op_list;
}

Просмотреть файл

@ -1,29 +0,0 @@
#include "string_utils.h"
std::vector<std::string_view> SplitString(const std::string_view& str, const std::string_view& seps, bool remove_empty_entries) {
std::vector<std::string_view> result;
std::string ::size_type pre_pos = 0;
while (true) {
auto next_pos = str.find_first_of(seps, pre_pos);
if (next_pos == std::string::npos) {
auto sub_str = str.substr(pre_pos, next_pos);
// sub_str is empty means the last sep reach the end of string
if (!sub_str.empty()) {
result.push_back(sub_str);
}
break;
}
if (pre_pos != next_pos || !remove_empty_entries) {
auto sub_str = str.substr(pre_pos, next_pos - pre_pos);
result.push_back(sub_str);
}
pre_pos = next_pos + 1;
}
return result;
}

Просмотреть файл

Просмотреть файл

Просмотреть файл

@ -0,0 +1,64 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <dlib/matrix.h>
#include "ocos.h"
struct KernelInverse : BaseKernel {
KernelInverse(OrtApi api) : BaseKernel(api) {
}
void Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const float* X = ort_.GetTensorData<float>(input_X);
// Setup output
OrtTensorDimensions dimensions(ort_, input_X);
if (dimensions.size() != 2) {
throw std::runtime_error("Only 2-d matrix supported.");
}
OrtValue* output0 = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
float* out0 = ort_.GetTensorMutableData<float>(output0);
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output0);
int64_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
dlib::matrix<float> dm(dimensions[0], dimensions[1]);
// Do computation
for (int64_t i = 0; i < size; i++) {
out0[i] = dm(i / dimensions[1], i % dimensions[1]);
}
}
};
// Schema of the "Inverse" custom op: one float tensor in, one float tensor out,
// computed by KernelInverse.
struct CustomOpInverse : Ort::CustomOpBase<CustomOpInverse, KernelInverse> {
  // Allocates the kernel; `info` is not used by this op.
  void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
    (void)info;
    KernelInverse* kernel = new KernelInverse(api);
    return kernel;
  }

  // Operator name used when registering the op.
  const char* GetName() const { return "Inverse"; }

  // Single input, float for every index.
  size_t GetInputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetInputType(size_t index) const {
    (void)index;
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  // Single output, float for every index.
  size_t GetOutputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetOutputType(size_t index) const {
    (void)index;
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }
};

7
operators/math/math.cc Normal file
Просмотреть файл

@ -0,0 +1,7 @@
#include "ocos.h"
#include "negpos.hpp"
#include "inverse.hpp"
// Explicitly instantiate the op-list loader for the math operator set
// (CustomOpNegPos and CustomOpInverse) in this translation unit.
template const OrtCustomOp** LoadCustomOpClasses<CustomOpNegPos, CustomOpInverse>();
// Factory pointer through which the math operator list is obtained.
FxLoadCustomOpFactory LoadCustomOpClasses_Math = &LoadCustomOpClasses<CustomOpNegPos, CustomOpInverse>;

Просмотреть файл

@ -3,7 +3,7 @@
#pragma once
#include "kernels.h"
#include "ocos.h"
struct KernelNegPos : BaseKernel {
KernelNegPos(OrtApi api) : BaseKernel(api) {

Просмотреть файл

@ -1,79 +1,78 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "ocos.h"
#include "utils/string_utils.h"
bool BaseKernel::HasAttribute(const char* name) const {
if (info_ == nullptr) {
throw std::runtime_error("Kernel was incorrectly initialized, pointer info_ cannot be null.");
}
size_t size;
std::string out;
// Crashes here.
OrtStatus* status = api_.KernelInfoGetAttribute_string(info_, name, nullptr, &size);
auto r = api_.GetErrorCode(status);
bool has = (r == ORT_INVALID_ARGUMENT) || (r == ORT_OK);
if (has) {
api_.ReleaseStatus(status);
return has;
}
const char* error = api_.GetErrorMessage(status);
if (strstr(error, "No attribute") == error) {
api_.ReleaseStatus(status);
return false;
}
api_.ReleaseStatus(status);
return true;
}
OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) {
if (status == nullptr) {
return ORT_OK;
}
auto error_code = api_.GetErrorCode(status);
api_.ReleaseStatus(status);
return error_code;
}
template <>
bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
if (info_ == nullptr) {
throw std::runtime_error("Kernel was incorrectly initialized, pointer info_ cannot be null.");
}
size_t size = 0;
OrtStatus* status = api_.KernelInfoGetAttribute_string(info_, name, nullptr, &size);
// The status should be ORT_INVALID_ARGUMENT because the size is insufficient to hold the string
if (GetErrorCodeAndRelease(status) != ORT_INVALID_ARGUMENT) {
return false;
}
value.resize(size);
status = api_.KernelInfoGetAttribute_string(info_, name, &value[0], &size);
if (GetErrorCodeAndRelease(status) != ORT_OK) {
return false;
}
value.resize(size - 1);
return true;
}
template <>
bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) {
if (info_ == nullptr) {
throw std::runtime_error("Kernel was incorrectly initialized, pointer info_ cannot be null.");
}
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &value)) == ORT_OK;
}
template <>
bool BaseKernel::TryToGetAttribute(const char* name, float& value) {
if (info_ == nullptr) {
throw std::runtime_error("Kernel was incorrectly initialized, pointer info_ cannot be null.");
}
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(info_, name, &value)) == ORT_OK;
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "ocos.h"
// Reports whether the kernel info carries an attribute called `name`.
//
// The attribute is probed as a string with a null buffer. The probe counts as
// "present" when it succeeds or fails with ORT_INVALID_ARGUMENT (buffer too
// small). For any other error the attribute is considered absent only when the
// error message starts with "No attribute".
//
// Throws std::runtime_error when the kernel was built without an OrtKernelInfo.
bool BaseKernel::HasAttribute(const char* name) const {
  if (info_ == nullptr) {
    throw std::runtime_error("Kernel was incorrectly initialized, pointer info_ cannot be null.");
  }
  // The capacity must be a defined value: with a null buffer, an indeterminate
  // size could make the callee believe the buffer is large enough.
  size_t size = 0;
  OrtStatus* status = api_.KernelInfoGetAttribute_string(info_, name, nullptr, &size);
  if (status == nullptr) {
    // A null status means success; there is nothing to inspect or release.
    return true;
  }

  const auto r = api_.GetErrorCode(status);
  bool has = true;
  if (r != ORT_INVALID_ARGUMENT && r != ORT_OK) {
    const char* error = api_.GetErrorMessage(status);
    // Absent only when the message begins with "No attribute".
    has = strstr(error, "No attribute") != error;
  }
  api_.ReleaseStatus(status);
  return has;
}
// Converts a status into its error code and releases the status object.
// A null status is reported as ORT_OK.
OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) {
  OrtErrorCode code = ORT_OK;
  if (status != nullptr) {
    code = api_.GetErrorCode(status);
    api_.ReleaseStatus(status);
  }
  return code;
}
// Reads the string attribute `name` into `value`.
// Returns false when either the size probe or the actual read does not report
// the expected status; throws std::runtime_error when info_ is null.
template <>
bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
  if (info_ == nullptr) {
    throw std::runtime_error("Kernel was incorrectly initialized, pointer info_ cannot be null.");
  }

  // Step 1: probe with a null buffer to learn the required size.
  // The status should be ORT_INVALID_ARGUMENT because the size is insufficient to hold the string
  size_t required = 0;
  const OrtErrorCode probe_code =
      GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_string(info_, name, nullptr, &required));
  if (probe_code != ORT_INVALID_ARGUMENT) {
    return false;
  }

  // Step 2: read the attribute into a buffer of the reported size.
  value.resize(required);
  const OrtErrorCode read_code =
      GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_string(info_, name, &value[0], &required));
  if (read_code != ORT_OK) {
    return false;
  }

  // Drop the last character of the reported size from the result.
  value.resize(required - 1);
  return true;
}
// Reads the int64 attribute `name` into `value`; returns true only when the
// lookup reports ORT_OK. Throws std::runtime_error when info_ is null.
template <>
bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) {
  if (info_ == nullptr) {
    throw std::runtime_error("Kernel was incorrectly initialized, pointer info_ cannot be null.");
  }

  const OrtErrorCode code = GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &value));
  return code == ORT_OK;
}
// Reads the float attribute `name` into `value`; returns true only when the
// lookup reports ORT_OK. Throws std::runtime_error when info_ is null.
template <>
bool BaseKernel::TryToGetAttribute(const char* name, float& value) {
  if (info_ == nullptr) {
    throw std::runtime_error("Kernel was incorrectly initialized, pointer info_ cannot be null.");
  }

  const OrtErrorCode code = GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(info_, name, &value));
  return code == ORT_OK;
}

Просмотреть файл

@ -1,8 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "string_common.h"
#include "utils/string_utils.h"
#include "string_utils.h"
#include "string_tensor.h"
void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context,
const OrtValue* value, std::vector<std::string>& output) {

Просмотреть файл

@ -4,8 +4,8 @@
#pragma once
#include <string>
#include "ustring.hpp"
#include "kernels.h"
#include "ustring.h"
#include "ocos.h"
// Retrieves a vector of strings if the input type is std::string.

97
operators/string_utils.cc Normal file
Просмотреть файл

@ -0,0 +1,97 @@
#include "farmhash.h"
#include "string_utils.h"
// Splits `str` at every character that occurs in `seps` and returns views into
// `str` (no copies are made).
//
// Empty pieces between separators are kept unless `remove_empty_entries` is
// true. An empty piece after the last separator is never emitted.
std::vector<std::string_view> SplitString(const std::string_view& str, const std::string_view& seps, bool remove_empty_entries) {
  std::vector<std::string_view> pieces;
  std::string_view::size_type begin = 0;

  // Walk from separator to separator.
  for (auto hit = str.find_first_of(seps, begin); hit != std::string_view::npos;
       hit = str.find_first_of(seps, begin)) {
    const bool piece_is_empty = (hit == begin);
    if (!piece_is_empty || !remove_empty_entries) {
      pieces.push_back(str.substr(begin, hit - begin));
    }
    begin = hit + 1;
  }

  // Whatever follows the last separator; empty means the string ended on one.
  const std::string_view tail = str.substr(begin);
  if (!tail.empty()) {
    pieces.push_back(tail);
  }
  return pieces;
}
// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/hash.cc#L28
// Widens one byte to 64 bits without sign extension.
static inline uint64_t ByteAs64(char c) {
  return static_cast<uint64_t>(static_cast<unsigned char>(c));
}

// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/raw_coding.h#L41
// Reads 4 bytes starting at ptr as a little-endian value (returned in 64 bits).
uint64_t DecodeFixed32(const char* ptr) {
  uint64_t value = 0;
  // Fold from the most significant byte (ptr[3]) down to ptr[0].
  for (int i = 3; i >= 0; --i) {
    value = (value << 8) | static_cast<uint64_t>(static_cast<unsigned char>(ptr[i]));
  }
  return value;
}

// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/raw_coding.h#L55
// Reads 8 bytes starting at ptr as a little-endian 64-bit value.
static uint64_t DecodeFixed64(const char* ptr) {
  const uint64_t low = DecodeFixed32(ptr);
  const uint64_t high = DecodeFixed32(ptr + 4);
  return low | (high << 32);
}

// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/hash.cc#L79
// Hashes the first n bytes of data into 64 bits, mixed with seed.
uint64_t Hash64(const char* data, size_t n, uint64_t seed) {
  const uint64_t m = 0xc6a4a7935bd1e995;
  const int r = 47;

  uint64_t h = seed ^ (n * m);

  // Consume the input in 8-byte little-endian words.
  for (; n >= 8; n -= 8, data += 8) {
    uint64_t k = DecodeFixed64(data);
    k *= m;
    k ^= k >> r;
    k *= m;

    h ^= k;
    h *= m;
  }

  // Mix in the 1..7 leftover bytes, byte i landing at bit offset 8 * i.
  if (n > 0) {
    for (size_t i = 0; i < n; ++i) {
      h ^= ByteAs64(data[i]) << (8 * i);
    }
    h *= m;
  }

  // Final avalanche.
  h ^= h >> r;
  h *= m;
  h ^= h >> r;
  return h;
}
// Hashes the first n bytes of data with farmhash's util::Fingerprint64.
uint64_t Hash64Fast(const char* data, size_t n) {
  // The round trip through int64_t is kept from the original; it does not
  // change the bit pattern that is returned.
  const int64_t fingerprint = static_cast<int64_t>(util::Fingerprint64(data, n));
  return fingerprint;
}

Просмотреть файл

@ -1,53 +1,61 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
template <typename T>
inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
ss << t;
}
template <>
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<int64_t>& t) noexcept {
ss << "[";
for (int i = 0; i < t.size(); i++) {
if (i != 0) {
ss << ", ";
}
ss << t[i];
}
ss << "]";
}
template <>
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<std::string>& t) noexcept {
ss << "[";
for (int i = 0; i < t.size(); i++) {
if (i != 0) {
ss << ", ";
}
ss << t[i];
}
ss << "]";
}
template <typename T, typename... Args>
void MakeStringInternal(std::ostringstream& ss, const T& t, const Args&... args) noexcept {
MakeStringInternal(ss, t);
MakeStringInternal(ss, args...);
}
template <typename... Args>
std::string MakeString(const Args&... args) {
std::ostringstream ss;
MakeStringInternal(ss, args...);
return std::string(ss.str());
}
std::vector<std::string_view> SplitString(const std::string_view& str, const std::string_view& seps, bool remove_empty_entries = false);
void char2unicode(const std::string& src, std::vector<uint32_t>& result);
void unicode2char(const std::vector<uint32_t>& src, std::string& result);
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
// Streams a single value into ss through its operator<<.
template <typename T>
inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
  ss << t;
}

// A vector of int64 is rendered as "[a, b, c]".
template <>
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<int64_t>& t) noexcept {
  ss << "[";
  for (auto it = t.begin(); it != t.end(); ++it) {
    if (it != t.begin()) {
      ss << ", ";
    }
    ss << *it;
  }
  ss << "]";
}

// A vector of strings is rendered as "[a, b, c]" (elements are not quoted).
template <>
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<std::string>& t) noexcept {
  ss << "[";
  for (auto it = t.begin(); it != t.end(); ++it) {
    if (it != t.begin()) {
      ss << ", ";
    }
    ss << *it;
  }
  ss << "]";
}

// Streams the first value, then each of the remaining ones, left to right.
template <typename T, typename... Args>
void MakeStringInternal(std::ostringstream& ss, const T& t, const Args&... args) noexcept {
  MakeStringInternal(ss, t);
  (MakeStringInternal(ss, args), ...);
}

// Concatenates the textual form of all arguments into one string.
template <typename... Args>
std::string MakeString(const Args&... args) {
  std::ostringstream buffer;
  MakeStringInternal(buffer, args...);
  return buffer.str();
}
// Splits str at every character contained in seps and returns views into str.
// Empty pieces between separators are kept unless remove_empty_entries is true.
std::vector<std::string_view> SplitString(const std::string_view& str, const std::string_view& seps, bool remove_empty_entries = false);
// NOTE(review): presumably converts encoded text to a sequence of code points;
// the definition is not part of this header - confirm the expected encoding.
void char2unicode(const std::string& src, std::vector<uint32_t>& result);
// NOTE(review): presumably the inverse of char2unicode - confirm.
void unicode2char(const std::vector<uint32_t>& src, std::string& result);
// 64-bit hash of the first n bytes of data, mixed with seed.
uint64_t Hash64(const char* data, size_t n, uint64_t seed);
// Convenience overload of Hash64 that uses the fixed seed 0xDECAFCAFFE.
inline uint64_t Hash64(const char* data, size_t n) {
return Hash64(data, n, 0xDECAFCAFFE);
}
// Alternative 64-bit hash of the first n bytes of data; defined out of line.
uint64_t Hash64Fast(const char* data, size_t n);

Просмотреть файл

@ -4,3 +4,6 @@
#pragma once
#include "ocos.h"
// TO BE DELETED.

Просмотреть файл

Просмотреть файл

@ -4,7 +4,7 @@
#pragma once
#include "kernels.h"
#include "utils/string_utils.h"
#include "string_utils.h"
struct KernelStringEqual : BaseKernel {
KernelStringEqual(OrtApi api);

Просмотреть файл

@ -3,8 +3,8 @@
#pragma once
#include <vector>
#include <string>
#include "utils/string_utils.h"
#include "string_common.h"
#include "string_utils.h"
#include "string_tensor.h"
template <typename T1, typename T2, typename T3>
class BroadcastIteratorRight {

Просмотреть файл

@ -1,8 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "string_utils.h"
#include "string_tensor.h"
#include "op_ragged_tensor.hpp"
#include "string_common.h"
KernelRaggedTensorToSparse::KernelRaggedTensorToSparse(OrtApi api) : BaseKernel(api) {
}

Просмотреть файл

@ -4,7 +4,6 @@
#pragma once
#include "kernels.h"
#include "utils/string_utils.h"
struct KernelRaggedTensorToSparse : BaseKernel {
KernelRaggedTensorToSparse(OrtApi api);

Просмотреть файл

Просмотреть файл

@ -4,7 +4,7 @@
#pragma once
#include "kernels.h"
#include "utils/string_utils.h"
#include "string_utils.h"
struct KernelSegmentSum : BaseKernel {
KernelSegmentSum(OrtApi api);

Просмотреть файл

@ -2,7 +2,7 @@
// Licensed under the MIT License.
#include "string_concat.hpp"
#include "string_common.h"
#include "string_tensor.h"
#include <vector>
#include <locale>
#include <codecvt>

Просмотреть файл

@ -4,7 +4,7 @@
#pragma once
#include "kernels.h"
#include "utils/string_utils.h"
#include "string_utils.h"
struct KernelStringConcat : BaseKernel {
KernelStringConcat(OrtApi api);

Просмотреть файл

@ -1,80 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "string_hash.hpp"
#include <vector>
#include <cmath>
#include <algorithm>
#include "re2/re2.h"
#include "farmhash.h"
#include "string_common.h"
#include "string_tensor.h"
#include "string_hash.hpp"
// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/hash.cc#L28
static inline uint64_t ByteAs64(char c) { return static_cast<uint64_t>(c) & 0xff; }
// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/raw_coding.h#L41
uint64_t DecodeFixed32(const char* ptr) {
return ((static_cast<uint64_t>(static_cast<unsigned char>(ptr[0]))) |
(static_cast<uint64_t>(static_cast<unsigned char>(ptr[1])) << 8) |
(static_cast<uint64_t>(static_cast<unsigned char>(ptr[2])) << 16) |
(static_cast<uint64_t>(static_cast<unsigned char>(ptr[3])) << 24));
}
// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/raw_coding.h#L55
static uint64_t DecodeFixed64(const char* ptr) {
uint64_t lo = DecodeFixed32(ptr);
uint64_t hi = DecodeFixed32(ptr + 4);
return (hi << 32) | lo;
}
// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/hash.cc#L79
uint64_t Hash64(const char* data, size_t n, uint64_t seed) {
const uint64_t m = 0xc6a4a7935bd1e995;
const int r = 47;
uint64_t h = seed ^ (n * m);
while (n >= 8) {
uint64_t k = DecodeFixed64(data);
data += 8;
n -= 8;
k *= m;
k ^= k >> r;
k *= m;
h ^= k;
h *= m;
}
switch (n) {
case 7:
h ^= ByteAs64(data[6]) << 48;
case 6:
h ^= ByteAs64(data[5]) << 40;
case 5:
h ^= ByteAs64(data[4]) << 32;
case 4:
h ^= ByteAs64(data[3]) << 24;
case 3:
h ^= ByteAs64(data[2]) << 16;
case 2:
h ^= ByteAs64(data[1]) << 8;
case 1:
h ^= ByteAs64(data[0]);
h *= m;
}
h ^= h >> r;
h *= m;
h ^= h >> r;
return h;
}
uint64_t Hash64Fast(const char* data, size_t n) {
return static_cast<int64_t>(util::Fingerprint64(data, n));
}
KernelStringHash::KernelStringHash(OrtApi api) : BaseKernel(api) {
}

Просмотреть файл

@ -4,15 +4,7 @@
#pragma once
#include "kernels.h"
#include "utils/string_utils.h"
uint64_t Hash64(const char* data, size_t n, uint64_t seed);
inline uint64_t Hash64(const char* data, size_t n) {
return Hash64(data, n, 0xDECAFCAFFE);
}
uint64_t Hash64Fast(const char* data, size_t n);
#include "string_utils.h"
struct KernelStringHash : BaseKernel {
KernelStringHash(OrtApi api);

Просмотреть файл

@ -2,7 +2,7 @@
// Licensed under the MIT License.
#include "string_join.hpp"
#include "string_common.h"
#include "string_tensor.h"
KernelStringJoin::KernelStringJoin(OrtApi api) : BaseKernel(api) {
}

Просмотреть файл

@ -4,7 +4,7 @@
#pragma once
#include "kernels.h"
#include "utils/string_utils.h"
#include "string_utils.h"
struct KernelStringJoin : BaseKernel {
KernelStringJoin(OrtApi api);

Просмотреть файл

@ -2,7 +2,7 @@
// Licensed under the MIT License.
#include "string_length.hpp"
#include "string_common.h"
#include "string_tensor.h"
#include <vector>
#include <locale>
#include <codecvt>

Просмотреть файл

@ -4,7 +4,7 @@
#pragma once
#include "kernels.h"
#include "utils/string_utils.h"
#include "string_utils.h"
struct KernelStringLength : BaseKernel {
KernelStringLength(OrtApi api);

Просмотреть файл

@ -2,7 +2,7 @@
// Licensed under the MIT License.
#include "string_lower.hpp"
#include "string_common.h"
#include "string_tensor.h"
#include <vector>
#include <cmath>
#include <algorithm>

Просмотреть файл

@ -4,7 +4,7 @@
#pragma once
#include "kernels.h"
#include "utils/string_utils.h"
#include "string_utils.h"
struct KernelStringLower : BaseKernel {
KernelStringLower(OrtApi api);

Просмотреть файл

@ -6,7 +6,7 @@
#include <cmath>
#include <algorithm>
#include "re2/re2.h"
#include "string_common.h"
#include "string_tensor.h"
KernelStringRegexReplace::KernelStringRegexReplace(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
global_replace_ = HasAttribute("global_replace") ? ort_.KernelInfoGetAttribute<int64_t>(info_, "global_replace") : 1;

Просмотреть файл

@ -4,7 +4,7 @@
#pragma once
#include "kernels.h"
#include "utils/string_utils.h"
#include "string_utils.h"
struct KernelStringRegexReplace : BaseKernel {
KernelStringRegexReplace(OrtApi api, const OrtKernelInfo* info);

Просмотреть файл

@ -5,7 +5,7 @@
#include "string_regex_split_re.hpp"
#include <vector>
#include <cmath>
#include "string_common.h"
#include "string_tensor.h"
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
}

Просмотреть файл

@ -4,7 +4,7 @@
#pragma once
#include "kernels.h"
#include "utils/string_utils.h"
#include "string_utils.h"
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
struct KernelStringRegexSplitWithOffsets : BaseKernel {

Просмотреть файл

@ -1,137 +1,137 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "string_split.hpp"
#include "string_common.h"
// Forwards `api` to the BaseKernel constructor; the body itself does nothing.
KernelStringSplit::KernelStringSplit(OrtApi api)
    : BaseKernel(api) {}
// Splits every string of a 1-D input tensor into pieces.
//
// Inputs (read from `context`):
//   0: the strings to split; must be a 1-D tensor.
//   1: the delimiter; must be a 1-D tensor holding exactly one string.
//      Every character of that string is an individual split point
//      (std::string::find_first_of). An empty delimiter splits each input
//      string into its single bytes.
//   2: skip_empty; must be a 1-D tensor holding exactly one bool. When true,
//      empty pieces are dropped. It is not consulted for an empty delimiter.
// Outputs:
//   0: int64 tensor of shape [N, 2] with the (row, column) of each piece.
//   1: string tensor of shape [N] with the pieces themselves.
//   2: int64 tensor of shape [2]: {number of input rows, largest column count}.
// Rows holding an empty string are skipped and contribute no pieces.
// Throws std::runtime_error when an input does not have the expected shape.
void KernelStringSplit::Compute(OrtKernelContext* context) {
  // Setup inputs
  const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
  const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
  const OrtValue* input_skip_empty = ort_.KernelContext_GetInput(context, 2);
  const bool* skip_empty = ort_.GetTensorData<bool>(input_skip_empty);
  std::vector<std::string> X, sep;
  GetTensorMutableDataString(api_, ort_, context, input_X, X);
  GetTensorMutableDataString(api_, ort_, context, input_sep, sep);

  // Validate the input shapes before touching sep[0] or *skip_empty.
  OrtTensorDimensions dimensions_sep(ort_, input_sep);
  if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
    throw std::runtime_error("Input 2 is the delimiter, it has 1 element.");
  OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty);
  if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1)
    throw std::runtime_error("Input 3 is skip_empty, it has 1 element.");
  OrtTensorDimensions dimensions(ort_, input_X);
  if (dimensions.size() != 1)
    throw std::runtime_error("Only 1D tensor are supported as input.");

  std::vector<std::string> words;
  std::vector<int64_t> indices;  // flattened (row, column) pairs
  int64_t maxc = 0;              // largest number of columns seen in any row
  const std::string& delimiter = sep[0];

  if (delimiter.empty()) {
    // No delimiter: every byte of the string becomes its own piece.
    char word[2] = "a";  // one-character, NUL-terminated scratch buffer
    for (int64_t row = 0; row < dimensions[0]; ++row) {
      const std::string& str = X[row];
      if (str.empty())
        continue;
      const int64_t length = static_cast<int64_t>(str.size());
      if (length > maxc)
        maxc = length;
      for (int64_t pos = 0; pos < length; ++pos) {
        word[0] = str[pos];
        words.push_back(word);
        indices.push_back(row);
        indices.push_back(pos);
      }
    }
  } else {
    const bool keep = !(*skip_empty);
    for (int64_t row = 0; row < dimensions[0]; ++row) {
      const std::string& str = X[row];
      if (str.empty())
        continue;
      std::size_t previous = 0;
      int64_t col = 0;
      std::size_t current = str.find_first_of(delimiter);
      while (current != std::string::npos) {
        if (keep || current > previous) {
          words.push_back(str.substr(previous, current - previous));
          indices.push_back(row);
          indices.push_back(col);
          ++col;
        }
        previous = current + 1;
        current = str.find_first_of(delimiter, previous);
      }
      // Trailing piece after the last delimiter (or the whole string).
      current = str.size();
      if (keep || current > previous) {
        words.push_back(str.substr(previous, current - previous));
        indices.push_back(row);
        indices.push_back(col);
        ++col;
      }
      if (col > maxc)
        maxc = col;
    }
  }

  // Setup outputs
  std::vector<int64_t> shape_indices = {static_cast<int64_t>(indices.size()) / 2, 2};
  OrtValue* out_indices = ort_.KernelContext_GetOutput(context, 0, shape_indices.data(), shape_indices.size());
  std::vector<int64_t> shape_text(1, words.size());
  OrtValue* out_text = ort_.KernelContext_GetOutput(context, 1, shape_text.data(), shape_text.size());
  std::vector<int64_t> shape_shape(1, 2);
  OrtValue* out_shape = ort_.KernelContext_GetOutput(context, 2, shape_shape.data(), shape_shape.size());

  int64_t* p_indices = ort_.GetTensorMutableData<int64_t>(out_indices);
  int64_t* p_shape = ort_.GetTensorMutableData<int64_t>(out_shape);
  // memcpy must never receive a null pointer, even for a zero-byte copy, and
  // an empty vector's data() may be null (e.g. every input string is empty).
  if (!indices.empty()) {
    memcpy(p_indices, indices.data(), indices.size() * sizeof(int64_t));
  }
  p_shape[0] = dimensions[0];
  p_shape[1] = maxc;
  FillTensorDataString(api_, ort_, context, words, out_text);
}
// Allocates a KernelStringSplit with `new` and hands it back as void*.
// The kernel-info argument is not used.
void* CustomOpStringSplit::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
  KernelStringSplit* kernel = new KernelStringSplit(api);
  return kernel;
}
// Returns the operator name, "StringSplit".
const char* CustomOpStringSplit::GetName() const {
  const char* const op_name = "StringSplit";
  return op_name;
}
// Returns the number of inputs the operator declares: 3.
size_t CustomOpStringSplit::GetInputTypeCount() const {
  const size_t input_count = 3;
  return input_count;
}
// Returns the element type of input `index`:
//   0 and 1 -> string, 2 -> bool.
// Throws std::runtime_error for any other index. The message carries the
// "[StringSplit]" prefix so it matches the one raised by GetOutputType.
ONNXTensorElementDataType CustomOpStringSplit::GetInputType(size_t index) const {
  switch (index) {
    case 0:
    case 1:
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
    case 2:
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
    default:
      throw std::runtime_error(MakeString("[StringSplit] Unexpected input index ", index));
  }
}
// Returns the number of outputs the operator declares: 3.
size_t CustomOpStringSplit::GetOutputTypeCount() const {
  const size_t output_count = 3;
  return output_count;
}
// Returns the element type of output `index`:
//   0 and 2 -> int64, 1 -> string.
// Throws std::runtime_error for any other index.
ONNXTensorElementDataType CustomOpStringSplit::GetOutputType(size_t index) const {
  if (index == 0 || index == 2) {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  }
  if (index == 1) {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  }
  throw std::runtime_error(MakeString("[StringSplit] Unexpected output index ", index));
}
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "string_split.hpp"
#include "string_tensor.h"
// Forwards `api` to the BaseKernel constructor; the body itself does nothing.
KernelStringSplit::KernelStringSplit(OrtApi api)
    : BaseKernel(api) {}
// Splits every string of a 1-D input tensor into pieces.
//
// Inputs (read from `context`):
//   0: the strings to split; must be a 1-D tensor.
//   1: the delimiter; must be a 1-D tensor holding exactly one string.
//      Every character of that string is an individual split point
//      (std::string::find_first_of). An empty delimiter splits each input
//      string into its single bytes.
//   2: skip_empty; must be a 1-D tensor holding exactly one bool. When true,
//      empty pieces are dropped. It is not consulted for an empty delimiter.
// Outputs:
//   0: int64 tensor of shape [N, 2] with the (row, column) of each piece.
//   1: string tensor of shape [N] with the pieces themselves.
//   2: int64 tensor of shape [2]: {number of input rows, largest column count}.
// Rows holding an empty string are skipped and contribute no pieces.
// Throws std::runtime_error when an input does not have the expected shape.
void KernelStringSplit::Compute(OrtKernelContext* context) {
  // Setup inputs
  const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
  const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
  const OrtValue* input_skip_empty = ort_.KernelContext_GetInput(context, 2);
  const bool* skip_empty = ort_.GetTensorData<bool>(input_skip_empty);
  std::vector<std::string> X, sep;
  GetTensorMutableDataString(api_, ort_, context, input_X, X);
  GetTensorMutableDataString(api_, ort_, context, input_sep, sep);

  // Validate the input shapes before touching sep[0] or *skip_empty.
  OrtTensorDimensions dimensions_sep(ort_, input_sep);
  if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
    throw std::runtime_error("Input 2 is the delimiter, it has 1 element.");
  OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty);
  if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1)
    throw std::runtime_error("Input 3 is skip_empty, it has 1 element.");
  OrtTensorDimensions dimensions(ort_, input_X);
  if (dimensions.size() != 1)
    throw std::runtime_error("Only 1D tensor are supported as input.");

  std::vector<std::string> words;
  std::vector<int64_t> indices;  // flattened (row, column) pairs
  int64_t maxc = 0;              // largest number of columns seen in any row
  const std::string& delimiter = sep[0];

  if (delimiter.empty()) {
    // No delimiter: every byte of the string becomes its own piece.
    char word[2] = "a";  // one-character, NUL-terminated scratch buffer
    for (int64_t row = 0; row < dimensions[0]; ++row) {
      const std::string& str = X[row];
      if (str.empty())
        continue;
      const int64_t length = static_cast<int64_t>(str.size());
      if (length > maxc)
        maxc = length;
      for (int64_t pos = 0; pos < length; ++pos) {
        word[0] = str[pos];
        words.push_back(word);
        indices.push_back(row);
        indices.push_back(pos);
      }
    }
  } else {
    const bool keep = !(*skip_empty);
    for (int64_t row = 0; row < dimensions[0]; ++row) {
      const std::string& str = X[row];
      if (str.empty())
        continue;
      std::size_t previous = 0;
      int64_t col = 0;
      std::size_t current = str.find_first_of(delimiter);
      while (current != std::string::npos) {
        if (keep || current > previous) {
          words.push_back(str.substr(previous, current - previous));
          indices.push_back(row);
          indices.push_back(col);
          ++col;
        }
        previous = current + 1;
        current = str.find_first_of(delimiter, previous);
      }
      // Trailing piece after the last delimiter (or the whole string).
      current = str.size();
      if (keep || current > previous) {
        words.push_back(str.substr(previous, current - previous));
        indices.push_back(row);
        indices.push_back(col);
        ++col;
      }
      if (col > maxc)
        maxc = col;
    }
  }

  // Setup outputs
  std::vector<int64_t> shape_indices = {static_cast<int64_t>(indices.size()) / 2, 2};
  OrtValue* out_indices = ort_.KernelContext_GetOutput(context, 0, shape_indices.data(), shape_indices.size());
  std::vector<int64_t> shape_text(1, words.size());
  OrtValue* out_text = ort_.KernelContext_GetOutput(context, 1, shape_text.data(), shape_text.size());
  std::vector<int64_t> shape_shape(1, 2);
  OrtValue* out_shape = ort_.KernelContext_GetOutput(context, 2, shape_shape.data(), shape_shape.size());

  int64_t* p_indices = ort_.GetTensorMutableData<int64_t>(out_indices);
  int64_t* p_shape = ort_.GetTensorMutableData<int64_t>(out_shape);
  // memcpy must never receive a null pointer, even for a zero-byte copy, and
  // an empty vector's data() may be null (e.g. every input string is empty).
  if (!indices.empty()) {
    memcpy(p_indices, indices.data(), indices.size() * sizeof(int64_t));
  }
  p_shape[0] = dimensions[0];
  p_shape[1] = maxc;
  FillTensorDataString(api_, ort_, context, words, out_text);
}
// Allocates a KernelStringSplit with `new` and hands it back as void*.
// The kernel-info argument is not used.
void* CustomOpStringSplit::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
  KernelStringSplit* kernel = new KernelStringSplit(api);
  return kernel;
}
// Returns the operator name, "StringSplit".
const char* CustomOpStringSplit::GetName() const {
  const char* const op_name = "StringSplit";
  return op_name;
}
// Returns the number of inputs the operator declares: 3.
size_t CustomOpStringSplit::GetInputTypeCount() const {
  const size_t input_count = 3;
  return input_count;
}
// Returns the element type of input `index`:
//   0 and 1 -> string, 2 -> bool.
// Throws std::runtime_error for any other index. The message carries the
// "[StringSplit]" prefix so it matches the one raised by GetOutputType.
ONNXTensorElementDataType CustomOpStringSplit::GetInputType(size_t index) const {
  switch (index) {
    case 0:
    case 1:
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
    case 2:
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
    default:
      throw std::runtime_error(MakeString("[StringSplit] Unexpected input index ", index));
  }
}
// Returns the number of outputs the operator declares: 3.
size_t CustomOpStringSplit::GetOutputTypeCount() const {
  const size_t output_count = 3;
  return output_count;
}
// Returns the element type of output `index`:
//   0 and 2 -> int64, 1 -> string.
// Throws std::runtime_error for any other index.
ONNXTensorElementDataType CustomOpStringSplit::GetOutputType(size_t index) const {
  if (index == 0 || index == 2) {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  }
  if (index == 1) {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  }
  throw std::runtime_error(MakeString("[StringSplit] Unexpected output index ", index));
}

Просмотреть файл

@ -4,7 +4,7 @@
#pragma once
#include "kernels.h"
#include "utils/string_utils.h"
#include "string_utils.h"
struct KernelStringSplit : BaseKernel {
KernelStringSplit(OrtApi api);

Просмотреть файл

@ -1,8 +1,8 @@
#include <charconv>
#include "kernels.h"
#include "utils/string_utils.h"
#include "string_utils.h"
#include "string_to_vector.hpp"
#include "string_common.h"
#include "string_tensor.h"
StringToVectorImpl::StringToVectorImpl(std::string& map, std::string& unk) {
ParseMappingTable(map);

Просмотреть файл

@ -8,7 +8,7 @@
#include <vector>
#include "kernels.h"
#include "farmhash.h"
#include "utils/string_utils.h"
#include "string_utils.h"
class StringToVectorImpl {

Просмотреть файл

@ -2,7 +2,7 @@
// Licensed under the MIT License.
#include "string_upper.hpp"
#include "string_common.h"
#include "string_tensor.h"
#include <vector>
#include <cmath>
#include <algorithm>

Просмотреть файл

@ -4,7 +4,7 @@
#pragma once
#include "kernels.h"
#include "utils/string_utils.h"
#include "string_utils.h"
struct KernelStringUpper : BaseKernel {
KernelStringUpper(OrtApi api);

Просмотреть файл

@ -1,8 +1,8 @@
#include <charconv>
#include "kernels.h"
#include "utils/string_utils.h"
#include "string_utils.h"
#include "vector_to_string.hpp"
#include "string_common.h"
#include "string_tensor.h"
VectorToStringImpl::VectorToStringImpl(std::string& map, std::string& unk) : unk_value_(unk) {
ParseMappingTable(map);

Просмотреть файл

@ -8,7 +8,7 @@
#include <vector>
#include "kernels.h"
#include "farmhash.h"
#include "utils/string_utils.h"
#include "string_utils.h"
namespace std {

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -4,7 +4,7 @@
#include "sentencepiece_processor.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_tokenizer.hpp"
#include "kernels/string_common.h"
#include "string_tensor.h"
#include "base64.h"
KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {

Просмотреть файл

@ -3,8 +3,8 @@
#pragma once
#include "kernels/kernels.h"
#include "utils/string_utils.h"
#include "ocos.h"
#include "string_utils.h"
#include "sentencepiece_processor.h"
struct KernelSentencepieceTokenizer : BaseKernel {

Просмотреть файл

Просмотреть файл

Просмотреть файл

@ -5,9 +5,10 @@
#include <unordered_map>
#include <vector>
#include "kernels/kernels.h"
#include "kernels/string_common.h"
#include "utils/string_utils.h"
#include "ocos.h"
#include "ustring.h"
#include "string_utils.h"
#include "string_tensor.h"
struct KernelWordpieceTokenizer : BaseKernel {
KernelWordpieceTokenizer(OrtApi api, const OrtKernelInfo* info);

Просмотреть файл

@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <iostream>
#include "ustring.hpp"
#include "ustring.h"
ustring::ustring(): std::u32string() {
}

Просмотреть файл

Просмотреть файл

@ -17,10 +17,10 @@
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <thread>
#include "utils/string_utils.h"
#include "string_utils.h"
#include "pykernel.h"
#include "kernels/string_hash.hpp"
#include "kernels/string_common.h"
#include "text/string_hash.hpp"
#include "string_tensor.h"
namespace py = pybind11;

Просмотреть файл

Просмотреть файл

@ -3,22 +3,23 @@
#include <set>
#include "kernels/op_equal.hpp"
#include "kernels/op_segment_sum.hpp"
#include "kernels/op_ragged_tensor.hpp"
#include "kernels/string_hash.hpp"
#include "kernels/string_join.hpp"
#include "kernels/string_lower.hpp"
#include "kernels/string_regex_replace.hpp"
#include "kernels/string_regex_split.hpp"
#include "kernels/string_split.hpp"
#include "kernels/string_to_vector.hpp"
#include "kernels/string_upper.hpp"
#include "kernels/negpos.hpp"
#include "kernels/vector_to_string.hpp"
#include "kernels/string_length.hpp"
#include "kernels/string_concat.hpp"
#include "utils/string_utils.h"
#include "string_utils.h"
#include "text/op_equal.hpp"
#include "text/op_segment_sum.hpp"
#include "text/op_ragged_tensor.hpp"
#include "text/string_hash.hpp"
#include "text/string_join.hpp"
#include "text/string_lower.hpp"
#include "text/string_regex_replace.hpp"
#include "text/string_regex_split.hpp"
#include "text/string_split.hpp"
#include "text/string_to_vector.hpp"
#include "text/string_upper.hpp"
#include "text/vector_to_string.hpp"
#include "text/string_length.hpp"
#include "text/string_concat.hpp"
#ifdef ENABLE_SPM_TOKENIZER
#include "sentencepiece_tokenizer.hpp"
@ -37,7 +38,6 @@ CustomOpWordpieceTokenizer c_CustomOpWordpieceTokenizer;
#endif
#ifdef ENABLE_TF_STRING
CustomOpNegPos c_CustomOpNegPos;
CustomOpSegmentSum c_CustomOpSegmentSum;
CustomOpRaggedTensorToDense c_CustomOpRaggedTensorToDense;
CustomOpRaggedTensorToSparse c_CustomOpRaggedTensorToSparse;
@ -67,7 +67,6 @@ OrtCustomOp* operator_lists[] = {
#endif
#ifdef ENABLE_TF_STRING
&c_CustomOpNegPos,
&c_CustomOpRaggedTensorToDense,
&c_CustomOpRaggedTensorToSparse,
&c_CustomOpSegmentSum,
@ -88,6 +87,10 @@ OrtCustomOp* operator_lists[] = {
#endif
nullptr};
#if ENABLE_MATH
extern FxLoadCustomOpFactory LoadCustomOpClasses_Math;
#endif //ENABLE_MATH
class ExternalCustomOps {
public:
ExternalCustomOps() {
@ -117,11 +120,13 @@ class ExternalCustomOps {
std::vector<const OrtCustomOp*> op_array_;
};
extern "C" bool AddExternalCustomOp(const OrtCustomOp* c_op) {
extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op) {
ExternalCustomOps::instance().Add(c_op);
return true;
}
extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
OrtCustomOpDomain* domain = nullptr;
const OrtApi* ortApi = api->GetApi(ORT_API_VERSION);
@ -145,27 +150,29 @@ extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options,
}
#endif
OrtCustomOp** ops = operator_lists;
while (*ops != nullptr) {
if (pyop_nameset.find((*ops)->GetName(*ops)) == pyop_nameset.end()) {
if (auto status = ortApi->CustomOpDomain_Add(domain, *ops)) {
return status;
}
}
++ops;
}
#if defined(ENABLE_GPT2_TOKENIZER)
const OrtCustomOp** t_ops = LoadTokenizerSchemaList();
while (*t_ops != nullptr) {
if (pyop_nameset.find((*t_ops)->GetName(*t_ops)) == pyop_nameset.end()) {
if (auto status = ortApi->CustomOpDomain_Add(domain, *t_ops)) {
return status;
}
}
t_ops++;
}
static std::vector<FxLoadCustomOpFactory> c_factories = {
[]() { return const_cast<const OrtCustomOp**>(operator_lists); }
#if defined(ENABLE_MATH)
,
LoadCustomOpClasses_Math
#endif
#if defined(ENABLE_GPT2_TOKENIZER)
,
LoadTokenizerSchemaList
#endif
};
for (auto fx : c_factories) {
auto ops = fx();
while (*ops != nullptr) {
if (pyop_nameset.find((*ops)->GetName(*ops)) == pyop_nameset.end()) {
if (auto status = ortApi->CustomOpDomain_Add(domain, *ops)) {
return status;
}
}
++ops;
}
}
size_t idx = 0;
const OrtCustomOp* e_ops = ExternalCustomOps::instance().GetNextOp(idx);

Просмотреть файл

@ -2,12 +2,14 @@
// Licensed under the MIT License.
#include "onnxruntime_cxx_api.h"
#include <filesystem>
#include "gtest/gtest.h"
#include "ocos.h"
#include "string_utils.h"
#include "string_tensor.h"
#include "test_kernel.hpp"
#include "utils/string_utils.h"
#include "kernels/string_common.h"
#include <filesystem>
const char* GetLibraryPath() {
#if defined(_WIN32)

Просмотреть файл

@ -1,12 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "onnxruntime_cxx_api.h"
#include <filesystem>
#include "gtest/gtest.h"
#include "ocos.h"
#include "test_kernel.hpp"
#include "kernels/string_lower.hpp"
#include <filesystem>
#include "text/string_lower.hpp"
TEST(utils, test_string_lower) {
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");

Просмотреть файл

@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include <dlib/matrix.h>
using namespace dlib;
// Computes inv(M) * y for a fixed 3x3 matrix M and 3x1 vector y, then checks
// the second component of the result against a precomputed value.
TEST(math, matrix_op) {
  // 3x1 right-hand side, filled with the comma-separated assignment syntax.
  matrix<float, 3, 1> rhs;
  rhs = 3.5,
        1.2,
        7.8;

  // 3x3 coefficient matrix, filled row by row.
  matrix<float> coeffs(3, 3);
  coeffs = 54.2, 7.4, 12.1,
           1, 2, 3,
           5.9, 0.05, 1;

  matrix<float> solution = inv(coeffs) * rhs;
  EXPECT_FLOAT_EQ(solution(1, 0), -13.909741);
}

Просмотреть файл

@ -2,8 +2,8 @@
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "utils/string_utils.h"
#include "kernels/string_regex_split_re.hpp"
#include "string_utils.h"
#include "text/string_regex_split_re.hpp"
TEST(strings, regex_split) {
std::string input = "hello world";

Просмотреть файл

@ -2,9 +2,8 @@
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "utils/string_utils.h"
#include "kernels/string_common.h"
#include "../tokenizer/wordpiece_tokenizer.hpp"
#include "string_utils.h"
#include "wordpiece_tokenizer.hpp"
TEST(tokenizer, bert_word_split) {
ustring ind("##");

Просмотреть файл

@ -1,6 +1,9 @@
#include "gtest/gtest.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cstring>
#include "kernels/ustring.hpp"
#include "gtest/gtest.h"
#include "ustring.h"
void convert_test(const char* const_str) {
std::string string(const_str);

Просмотреть файл

@ -2,9 +2,10 @@
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "utils/string_utils.h"
#include "re2/re2.h"
#include "nlohmann/json.hpp"
#include "string_utils.h"
TEST(utils, make_string) {
std::string res = MakeString("a", "b", 0);