Re-organize the source code folder structure (#88)
* Reorg the code folder structure * update the math test case * Add an matrix inverse op. * turn off the ctest by default. * disbable jpeg lib in dlib for Linux build issue. * Linux build fixing * typo * enable dlib library on Win32 build * rename ocos to operators * add the missing operator folder
This commit is contained in:
Родитель
2243847f22
Коммит
c891e5d732
|
@ -7,6 +7,15 @@ if(NOT CMAKE_BUILD_TYPE)
|
|||
set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING "Choose build type: Debug Release RelWithDebInfo." FORCE)
|
||||
endif()
|
||||
|
||||
|
||||
project(onnxruntime_extensions)
|
||||
set(CPACK_PACKAGE_NAME "onnxruntime_extensions")
|
||||
set(CPACK_PACKAGE_VERSION_MAJOR "0")
|
||||
set(CPACK_PACKAGE_VERSION_MINOR "2")
|
||||
set(CPACK_PACKAGE_VERSION_PATCH "0")
|
||||
set(VERSION ${CPACK_PACKAGE_VERSION_MAJOR}.${CPACK_PACKAGE_VERSION_MINOR}.${CPACK_PACKAGE_VERSION_PATCH})
|
||||
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
|
@ -15,15 +24,15 @@ include(CheckLanguage)
|
|||
|
||||
option(CC_OPTIMIZE "Allow compiler optimizations, Set to OFF to disable" ON)
|
||||
option(OCOS_ENABLE_PYTHON "Enable Python component building" OFF)
|
||||
option(OCOS_ENABLE_CTEST "Enable C++ test" ON)
|
||||
option(OCOS_ENABLE_CTEST "Enable C++ test" OFF)
|
||||
option(OCOS_ENABLE_TF_STRING "Enable String Operator Set" ON)
|
||||
option(OCOS_ENABLE_GPT2_TOKENIZER "Enable the GPT2 tokenizer building" ON)
|
||||
option(OCOS_ENABLE_SPM_TOKENIZER "Enable the SentencePiece tokenizer building" ON)
|
||||
option(OCOS_ENABLE_BERT_TOKENIZER "Enable the BertTokenizer building" ON)
|
||||
option(OCOS_ENABLE_MATH "Enable the math tensor operators building" ON)
|
||||
option(OCOS_ENABLE_STATIC_LIB "Enable generating static library" OFF)
|
||||
|
||||
|
||||
|
||||
if(NOT CC_OPTIMIZE)
|
||||
message("!!!THE COMPILER OPTIMIZATION HAS BEEN DISABLED, DEBUG-ONLY!!!")
|
||||
string(REGEX REPLACE "([\-\/]O[123])" "" CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELWITHDEBINFO}")
|
||||
|
@ -54,13 +63,13 @@ include(FetchContent)
|
|||
if (OCOS_ENABLE_TF_STRING)
|
||||
if (NOT TARGET re2::re2)
|
||||
set(RE2_BUILD_TESTING OFF CACHE INTERNAL "")
|
||||
message(STATUS "fetch googlere2")
|
||||
message(STATUS "Fetch googlere2")
|
||||
include(googlere2)
|
||||
FetchContent_GetProperties(googlere2)
|
||||
endif()
|
||||
|
||||
if (NOT TARGET farmhash)
|
||||
message(STATUS "fetch farmhash")
|
||||
message(STATUS "Fetch farmhash")
|
||||
include(farmhash)
|
||||
FetchContent_GetProperties(farmhash)
|
||||
endif()
|
||||
|
@ -70,22 +79,32 @@ if (OCOS_ENABLE_TF_STRING)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
file(GLOB TARGET_SRC "./ocos/*.cc" "./ocos/*.h*" "./ocos/utils/*.h*" "./ocos/utils/*.cc")
|
||||
file(GLOB TARGET_SRC "operators/*.cc" "operators/*.h")
|
||||
if (OCOS_ENABLE_TF_STRING)
|
||||
file(GLOB TARGET_SRC_KERNELS "./ocos/kernels/*.cc" "./ocos/kernels/*.h*")
|
||||
file(GLOB TARGET_SRC_KERNELS "operators/text/*.cc" "operators/text/*.h*")
|
||||
file(GLOB TARGET_SRC_HASH "${farmhash_SOURCE_DIR}/src/farmhash.*")
|
||||
list(APPEND TARGET_SRC ${TARGET_SRC_KERNELS} ${TARGET_SRC_HASH})
|
||||
endif()
|
||||
|
||||
if (OCOS_ENABLE_MATH)
|
||||
set(DLIB_NO_GUI_SUPPORT ON CACHE INTERNAL "")
|
||||
set(DLIB_USE_CUDA OFF CACHE INTERNAL "")
|
||||
set(DLIB_USE_LAPACK OFF CACHE INTERNAL "")
|
||||
set(DLIB_USE_BLAS OFF CACHE INTERNAL "")
|
||||
include(dlib)
|
||||
file(GLOB TARGET_SRC_MATH "operators/math/*.cc" "operators/math/*.h*")
|
||||
list(APPEND TARGET_SRC ${TARGET_SRC_MATH})
|
||||
endif()
|
||||
|
||||
if (OCOS_ENABLE_GPT2_TOKENIZER)
|
||||
# GPT2
|
||||
if (NOT TARGET nlohmann_json)
|
||||
set(JSON_BuildTests OFF CACHE INTERNAL "")
|
||||
message(STATUS "fetch json")
|
||||
message(STATUS "Fetch json")
|
||||
include(json)
|
||||
endif()
|
||||
|
||||
file(GLOB tok_TARGET_SRC "tokenizer/gpt*.cc" "tokenizer/unicode*.*")
|
||||
file(GLOB tok_TARGET_SRC "operators/tokenizer/gpt*.cc" "operators/tokenizer/unicode*.*")
|
||||
list(APPEND TARGET_SRC ${tok_TARGET_SRC})
|
||||
endif()
|
||||
|
||||
|
@ -93,23 +112,23 @@ if (OCOS_ENABLE_SPM_TOKENIZER)
|
|||
# SentencePiece
|
||||
set(SPM_ENABLE_TCMALLOC OFF CACHE INTERNAL "")
|
||||
set(SPM_ENABLE_SHARED OFF CACHE INTERNAL "")
|
||||
message(STATUS "fetch sentencepiece")
|
||||
message(STATUS "Fetch sentencepiece")
|
||||
include(sentencepieceproject)
|
||||
file(GLOB stpiece_TARGET_SRC "sentencepiece/*.cc" "tokenizer/sentencepiece*")
|
||||
file(GLOB stpiece_TARGET_SRC "operators/tokenizer/sentencepiece/*.cc" "operators/tokenizer/sentencepiece*")
|
||||
list(REMOVE_ITEM stpiece_TARGET_SRC INCLUDE REGEX ".*((spm)|(train)).*")
|
||||
list(APPEND TARGET_SRC ${stpiece_TARGET_SRC})
|
||||
endif()
|
||||
|
||||
if (OCOS_ENABLE_BERT_TOKENIZER)
|
||||
# Bert
|
||||
file(GLOB bert_TARGET_SRC "tokenizer/wordpiece*.*")
|
||||
file(GLOB bert_TARGET_SRC "operators/tokenizer/wordpiece*.*")
|
||||
list(APPEND TARGET_SRC ${bert_TARGET_SRC})
|
||||
endif()
|
||||
|
||||
add_compile_options("$<$<C_COMPILER_ID:MSVC>:/utf-8>")
|
||||
add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/utf-8>")
|
||||
add_library(ortcustomops_static STATIC ${TARGET_SRC})
|
||||
target_include_directories(ortcustomops_static PUBLIC tokenizer)
|
||||
target_include_directories(ortcustomops_static PUBLIC operators/tokenizer)
|
||||
|
||||
set(ocos_libraries ortcustomops_static)
|
||||
if (OCOS_ENABLE_TF_STRING)
|
||||
|
@ -119,7 +138,7 @@ endif()
|
|||
target_include_directories(ortcustomops_static PUBLIC
|
||||
${PROJECT_SOURCE_DIR}/includes
|
||||
${PROJECT_SOURCE_DIR}/includes/onnxruntime
|
||||
${PROJECT_SOURCE_DIR}/ocos)
|
||||
${PROJECT_SOURCE_DIR}/operators)
|
||||
|
||||
set(OCOS_COMPILE_DEFINITIONS "")
|
||||
|
||||
|
@ -130,6 +149,15 @@ if (OCOS_ENABLE_TF_STRING)
|
|||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_TF_STRING)
|
||||
endif()
|
||||
|
||||
if (OCOS_ENABLE_MATH)
|
||||
target_include_directories(ortcustomops_static PUBLIC ${dlib_SOURCE_DIR})
|
||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_MATH)
|
||||
# The dlib matrix implementation is all in the headers, no library compiling needed.
|
||||
if (WIN32)
|
||||
list(APPEND ocos_libraries dlib::dlib)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (OCOS_ENABLE_GPT2_TOKENIZER)
|
||||
# GPT2
|
||||
target_include_directories(ortcustomops_static PRIVATE ${json_SOURCE_DIR}/single_include)
|
||||
|
@ -158,7 +186,7 @@ target_compile_definitions(ortcustomops_static PRIVATE ${OCOS_COMPILE_DEFINITION
|
|||
|
||||
file(GLOB shared_TARGET_SRC "shared/*.cc" "shared/*.h")
|
||||
if(OCOS_ENABLE_PYTHON)
|
||||
file(GLOB TARGET_SRC_PYOPS "./ocos/pyfunc/*.cc" "./ocos/pyfunc/*.h*")
|
||||
file(GLOB TARGET_SRC_PYOPS "pyop/*.cc" "pyop/*.h")
|
||||
set(Python3_FIND_REGISTRY NEVER CACHE STRING "...")
|
||||
if(NOT "${Python3_FIND_REGISTRY}" STREQUAL "NEVER")
|
||||
message(FATAL_ERROR "Python3_FIND_REGISTRY is not NEVER")
|
||||
|
@ -201,7 +229,7 @@ target_compile_definitions(ortcustomops PRIVATE ${OCOS_COMPILE_DEFINITIONS})
|
|||
target_link_libraries(ortcustomops PRIVATE ${ocos_libraries})
|
||||
|
||||
if(OCOS_ENABLE_PYTHON)
|
||||
message(STATUS "fetch pybind11")
|
||||
message(STATUS "Fetch pybind11")
|
||||
include(pybind11)
|
||||
set(NUMPY_NOT_FOUND false)
|
||||
exec_program("${Python3_EXECUTABLE}"
|
||||
|
@ -228,10 +256,6 @@ if(OCOS_ENABLE_PYTHON)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
set(CPACK_PROJECT_NAME ${PROJECT_NAME})
|
||||
set(CPACK_PROJECT_VERSION ${PROJECT_VERSION})
|
||||
message(STATUS "fetch CPack")
|
||||
include(CPack)
|
||||
|
||||
# test section
|
||||
if (OCOS_ENABLE_CTEST)
|
||||
|
@ -242,29 +266,33 @@ if (OCOS_ENABLE_CTEST)
|
|||
endif()
|
||||
|
||||
enable_testing()
|
||||
message(STATUS "fetch CTest")
|
||||
message(STATUS "Fetch CTest")
|
||||
include(CTest)
|
||||
|
||||
set(TEST_SRC_DIR ${PROJECT_SOURCE_DIR}/test)
|
||||
message(STATUS "fetch googletest")
|
||||
message(STATUS "Fetch googletest")
|
||||
include(googletest)
|
||||
file(GLOB static_TEST_SRC "${TEST_SRC_DIR}/static_test/*.cc")
|
||||
add_executable(ortcustomops_static_test ${static_TEST_SRC})
|
||||
target_link_libraries(ortcustomops_static_test gtest_main ${ocos_libraries})
|
||||
add_test(NAME ortcustomops_static_test COMMAND $<TARGET_FILE:ortcustomops_static_test>)
|
||||
|
||||
# needs to link with stdc++fs in Linux
|
||||
if(UNIX AND NOT APPLE)
|
||||
set(FS_STDLIB stdc++fs)
|
||||
endif()
|
||||
file(GLOB shared_TEST_SRC "${TEST_SRC_DIR}/shared_test/*.cc")
|
||||
add_executable(ortcustomops_test ${shared_TEST_SRC})
|
||||
if (ONNXRUNTIME_LIB_DIR)
|
||||
target_link_directories(ortcustomops_test PRIVATE ${ONNXRUNTIME_LIB_DIR})
|
||||
target_link_libraries(ortcustomops_test ortcustomops onnxruntime gtest_main ${ocos_libraries})
|
||||
if (WIN32)
|
||||
file(TO_CMAKE_PATH "${ONNXRUNTIME_LIB_DIR}/*" ONNXRUNTIME_LIB_FILEPATTERN)
|
||||
file(GLOB ONNXRUNTIME_LIB_FILES CONFIGURE_DEPENDS "${ONNXRUNTIME_LIB_FILEPATTERN}")
|
||||
add_custom_command(
|
||||
TARGET ortcustomops_test POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy ${ONNXRUNTIME_LIB_FILES} $<TARGET_FILE_DIR:ortcustomops_test>)
|
||||
endif()
|
||||
endif()
|
||||
target_link_libraries(ortcustomops_test ortcustomops onnxruntime gtest_main ${ocos_libraries} ${FS_STDLIB})
|
||||
if (WIN32)
|
||||
file(TO_CMAKE_PATH "${ONNXRUNTIME_LIB_DIR}/*" ONNXRUNTIME_LIB_FILEPATTERN)
|
||||
file(GLOB ONNXRUNTIME_LIB_FILES CONFIGURE_DEPENDS "${ONNXRUNTIME_LIB_FILEPATTERN}")
|
||||
add_custom_command(
|
||||
TARGET ortcustomops_test POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy ${ONNXRUNTIME_LIB_FILES} $<TARGET_FILE_DIR:ortcustomops_test>)
|
||||
endif()
|
||||
|
||||
set(TEST_DATA_SRC ${TEST_SRC_DIR}/data)
|
||||
|
|
|
@ -42,7 +42,7 @@ jobs:
|
|||
displayName: Unpack ONNXRuntime package.
|
||||
|
||||
- script: |
|
||||
sh ./build.sh -DONNXRUNTIME_LIB_DIR=onnxruntime-linux-x64-$(ort.version)/lib
|
||||
sh ./build.sh -DONNXRUNTIME_LIB_DIR=onnxruntime-linux-x64-$(ort.version)/lib -DOCOS_ENABLE_CTEST=ON
|
||||
displayName: build the customop library with onnxruntime
|
||||
|
||||
- script: |
|
||||
|
@ -133,7 +133,7 @@ jobs:
|
|||
displayName: Unpack ONNXRuntime package.
|
||||
|
||||
- script: |
|
||||
sh ./build.sh -DONNXRUNTIME_LIB_DIR=onnxruntime-osx-x64-$(ort.version)/lib
|
||||
sh ./build.sh -DONNXRUNTIME_LIB_DIR=onnxruntime-osx-x64-$(ort.version)/lib -DOCOS_ENABLE_CTEST=ON
|
||||
displayName: build the customop library with onnxruntime
|
||||
|
||||
- script: |
|
||||
|
@ -249,7 +249,7 @@ jobs:
|
|||
displayName: Unpack ONNXRuntime package.
|
||||
|
||||
- script: |
|
||||
call .\build.bat -DONNXRUNTIME_LIB_DIR=.\onnxruntime-win-x64-$(ort.version)\lib
|
||||
call .\build.bat -DONNXRUNTIME_LIB_DIR=.\onnxruntime-win-x64-$(ort.version)\lib -DOCOS_ENABLE_CTEST=ON
|
||||
displayName: build the customop library with onnxruntime
|
||||
|
||||
- script: |
|
||||
|
@ -341,6 +341,6 @@ jobs:
|
|||
displayName: Setup emscripten pipeline
|
||||
|
||||
- script: |
|
||||
sh ./build.sh -DCMAKE_TOOLCHAIN_FILE=cmake/deps/emsdk/upstream/emscripten/cmake/Modules/Platform/Emscripten.cmake -DOCOS_ENABLE_SPM_TOKENIZER=OFF -DOCOS_ENABLE_PYTHON=OFF -DOCOS_ENABLE_CTEST=OFF
|
||||
sh ./build.sh -DCMAKE_TOOLCHAIN_FILE=cmake/deps/emsdk/upstream/emscripten/cmake/Modules/Platform/Emscripten.cmake -DOCOS_ENABLE_SPM_TOKENIZER=OFF -DOCOS_ENABLE_PYTHON=OFF
|
||||
displayName: build the customop library with onnxruntime
|
||||
# TODO add unittest for webassembly
|
||||
# TODO add unittest for webassembly
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
Subproject commit 2e7eaf7233144e5e25b1c1338890bbab5d011815
|
|
@ -0,0 +1,14 @@
|
|||
FetchContent_Declare(dlib
|
||||
GIT_REPOSITORY https://github.com/davisking/dlib.git
|
||||
GIT_TAG v19.22
|
||||
)
|
||||
|
||||
if (WIN32)
|
||||
FetchContent_MakeAvailable(dlib)
|
||||
else()
|
||||
FetchContent_GetProperties(dlib)
|
||||
if(NOT dlib_POPULATED)
|
||||
# Fetch the content using previously declared details
|
||||
FetchContent_Populate(dlib)
|
||||
endif()
|
||||
endif()
|
|
@ -3,10 +3,13 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#define ORT_API_MANUAL_INIT
|
||||
#include "onnxruntime_cxx_api.h"
|
||||
#undef ORT_API_MANUAL_INIT
|
||||
|
||||
typedef const OrtCustomOp** (*FxLoadCustomOpFactory)();
|
||||
|
||||
#if defined(ENABLE_GPT2_TOKENIZER)
|
||||
const OrtCustomOp** LoadTokenizerSchemaList();
|
||||
#endif // ENABLE_GPT2_TOKENIZER
|
||||
|
@ -16,6 +19,7 @@ const OrtCustomOp* FetchPyCustomOps(size_t& count);
|
|||
bool EnablePyCustomOps(bool enable = true);
|
||||
#endif
|
||||
|
||||
|
||||
// A helper API to support test kernels.
|
||||
// Must be invoked before RegisterCustomOps.
|
||||
extern "C" bool AddExternalCustomOp(const OrtCustomOp* c_op);
|
||||
|
@ -52,3 +56,33 @@ struct OrtTensorDimensions : std::vector<int64_t> {
|
|||
return s;
|
||||
}
|
||||
};
|
||||
|
||||
template <class... Args>
|
||||
class CuopContainer {
|
||||
public:
|
||||
CuopContainer() : ocos_list_({[]() { return new Args; }()...}) {
|
||||
ocos_list_.push_back(nullptr);
|
||||
}
|
||||
|
||||
~CuopContainer() {
|
||||
// skip the last null pointer.
|
||||
for (auto i = 0; i < ocos_list_.size() - 1; i++) {
|
||||
delete ocos_list_[i];
|
||||
}
|
||||
|
||||
ocos_list_.clear();
|
||||
}
|
||||
|
||||
const OrtCustomOp** GetList() {
|
||||
return &const_cast<const OrtCustomOp*&>(ocos_list_.front());
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<OrtCustomOp*> ocos_list_;
|
||||
};
|
||||
|
||||
template <typename... Args>
|
||||
const OrtCustomOp** LoadCustomOpClasses() {
|
||||
static CuopContainer<Args...> ctr; // Let C++ runtime take cares of the MP initializing.
|
||||
return ctr.GetList();
|
||||
}
|
||||
|
|
|
@ -1,29 +0,0 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
std::vector<std::string_view> SplitString(const std::string_view& str, const std::string_view& seps, bool remove_empty_entries) {
|
||||
std::vector<std::string_view> result;
|
||||
std::string ::size_type pre_pos = 0;
|
||||
|
||||
while (true) {
|
||||
auto next_pos = str.find_first_of(seps, pre_pos);
|
||||
|
||||
if (next_pos == std::string::npos) {
|
||||
auto sub_str = str.substr(pre_pos, next_pos);
|
||||
// sub_str is empty means the last sep reach the end of string
|
||||
if (!sub_str.empty()) {
|
||||
result.push_back(sub_str);
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
if (pre_pos != next_pos || !remove_empty_entries) {
|
||||
auto sub_str = str.substr(pre_pos, next_pos - pre_pos);
|
||||
result.push_back(sub_str);
|
||||
}
|
||||
|
||||
pre_pos = next_pos + 1;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <dlib/matrix.h>
|
||||
#include "ocos.h"
|
||||
|
||||
|
||||
struct KernelInverse : BaseKernel {
|
||||
KernelInverse(OrtApi api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const float* X = ort_.GetTensorData<float>(input_X);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
if (dimensions.size() != 2) {
|
||||
throw std::runtime_error("Only 2-d matrix supported.");
|
||||
}
|
||||
|
||||
OrtValue* output0 = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
float* out0 = ort_.GetTensorMutableData<float>(output0);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output0);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
dlib::matrix<float> dm(dimensions[0], dimensions[1]);
|
||||
// Do computation
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
out0[i] = dm(i / dimensions[1], i % dimensions[1]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct CustomOpInverse : Ort::CustomOpBase<CustomOpInverse, KernelInverse> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
return new KernelInverse(api);
|
||||
}
|
||||
|
||||
const char* GetName() const {
|
||||
return "Inverse";
|
||||
}
|
||||
|
||||
size_t GetInputTypeCount() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
|
||||
size_t GetOutputTypeCount() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
};
|
|
@ -0,0 +1,7 @@
|
|||
#include "ocos.h"
|
||||
#include "negpos.hpp"
|
||||
#include "inverse.hpp"
|
||||
|
||||
template const OrtCustomOp** LoadCustomOpClasses<CustomOpNegPos, CustomOpInverse>();
|
||||
|
||||
FxLoadCustomOpFactory LoadCustomOpClasses_Math = &LoadCustomOpClasses<CustomOpNegPos, CustomOpInverse>;
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "ocos.h"
|
||||
|
||||
struct KernelNegPos : BaseKernel {
|
||||
KernelNegPos(OrtApi api) : BaseKernel(api) {
|
|
@ -1,79 +1,78 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "ocos.h"
|
||||
#include "utils/string_utils.h"
|
||||
|
||||
bool BaseKernel::HasAttribute(const char* name) const {
|
||||
if (info_ == nullptr) {
|
||||
throw std::runtime_error("Kernel was incorrectly initialized, pointer info_ cannot be null.");
|
||||
}
|
||||
size_t size;
|
||||
std::string out;
|
||||
// Crashes here.
|
||||
OrtStatus* status = api_.KernelInfoGetAttribute_string(info_, name, nullptr, &size);
|
||||
auto r = api_.GetErrorCode(status);
|
||||
bool has = (r == ORT_INVALID_ARGUMENT) || (r == ORT_OK);
|
||||
if (has) {
|
||||
api_.ReleaseStatus(status);
|
||||
return has;
|
||||
}
|
||||
const char* error = api_.GetErrorMessage(status);
|
||||
if (strstr(error, "No attribute") == error) {
|
||||
api_.ReleaseStatus(status);
|
||||
return false;
|
||||
}
|
||||
api_.ReleaseStatus(status);
|
||||
return true;
|
||||
}
|
||||
|
||||
OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) {
|
||||
if (status == nullptr) {
|
||||
return ORT_OK;
|
||||
}
|
||||
auto error_code = api_.GetErrorCode(status);
|
||||
api_.ReleaseStatus(status);
|
||||
return error_code;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
|
||||
if (info_ == nullptr) {
|
||||
throw std::runtime_error("Kernel was incorrectly initialized, pointer info_ cannot be null.");
|
||||
}
|
||||
|
||||
size_t size = 0;
|
||||
OrtStatus* status = api_.KernelInfoGetAttribute_string(info_, name, nullptr, &size);
|
||||
|
||||
// The status should be ORT_INVALID_ARGUMENT because the size is insufficient to hold the string
|
||||
if (GetErrorCodeAndRelease(status) != ORT_INVALID_ARGUMENT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
value.resize(size);
|
||||
status = api_.KernelInfoGetAttribute_string(info_, name, &value[0], &size);
|
||||
if (GetErrorCodeAndRelease(status) != ORT_OK) {
|
||||
return false;
|
||||
}
|
||||
value.resize(size - 1);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) {
|
||||
if (info_ == nullptr) {
|
||||
throw std::runtime_error("Kernel was incorrectly initialized, pointer info_ cannot be null.");
|
||||
}
|
||||
|
||||
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &value)) == ORT_OK;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool BaseKernel::TryToGetAttribute(const char* name, float& value) {
|
||||
if (info_ == nullptr) {
|
||||
throw std::runtime_error("Kernel was incorrectly initialized, pointer info_ cannot be null.");
|
||||
}
|
||||
|
||||
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(info_, name, &value)) == ORT_OK;
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "ocos.h"
|
||||
|
||||
bool BaseKernel::HasAttribute(const char* name) const {
|
||||
if (info_ == nullptr) {
|
||||
throw std::runtime_error("Kernel was incorrectly initialized, pointer info_ cannot be null.");
|
||||
}
|
||||
size_t size;
|
||||
std::string out;
|
||||
// Crashes here.
|
||||
OrtStatus* status = api_.KernelInfoGetAttribute_string(info_, name, nullptr, &size);
|
||||
auto r = api_.GetErrorCode(status);
|
||||
bool has = (r == ORT_INVALID_ARGUMENT) || (r == ORT_OK);
|
||||
if (has) {
|
||||
api_.ReleaseStatus(status);
|
||||
return has;
|
||||
}
|
||||
const char* error = api_.GetErrorMessage(status);
|
||||
if (strstr(error, "No attribute") == error) {
|
||||
api_.ReleaseStatus(status);
|
||||
return false;
|
||||
}
|
||||
api_.ReleaseStatus(status);
|
||||
return true;
|
||||
}
|
||||
|
||||
OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) {
|
||||
if (status == nullptr) {
|
||||
return ORT_OK;
|
||||
}
|
||||
auto error_code = api_.GetErrorCode(status);
|
||||
api_.ReleaseStatus(status);
|
||||
return error_code;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
|
||||
if (info_ == nullptr) {
|
||||
throw std::runtime_error("Kernel was incorrectly initialized, pointer info_ cannot be null.");
|
||||
}
|
||||
|
||||
size_t size = 0;
|
||||
OrtStatus* status = api_.KernelInfoGetAttribute_string(info_, name, nullptr, &size);
|
||||
|
||||
// The status should be ORT_INVALID_ARGUMENT because the size is insufficient to hold the string
|
||||
if (GetErrorCodeAndRelease(status) != ORT_INVALID_ARGUMENT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
value.resize(size);
|
||||
status = api_.KernelInfoGetAttribute_string(info_, name, &value[0], &size);
|
||||
if (GetErrorCodeAndRelease(status) != ORT_OK) {
|
||||
return false;
|
||||
}
|
||||
value.resize(size - 1);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) {
|
||||
if (info_ == nullptr) {
|
||||
throw std::runtime_error("Kernel was incorrectly initialized, pointer info_ cannot be null.");
|
||||
}
|
||||
|
||||
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &value)) == ORT_OK;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool BaseKernel::TryToGetAttribute(const char* name, float& value) {
|
||||
if (info_ == nullptr) {
|
||||
throw std::runtime_error("Kernel was incorrectly initialized, pointer info_ cannot be null.");
|
||||
}
|
||||
|
||||
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(info_, name, &value)) == ORT_OK;
|
||||
}
|
|
@ -1,8 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "string_common.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_utils.h"
|
||||
#include "string_tensor.h"
|
||||
|
||||
void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context,
|
||||
const OrtValue* value, std::vector<std::string>& output) {
|
|
@ -4,8 +4,8 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "ustring.hpp"
|
||||
#include "kernels.h"
|
||||
#include "ustring.h"
|
||||
#include "ocos.h"
|
||||
|
||||
|
||||
// Retrieves a vector of strings if the input type is std::string.
|
|
@ -0,0 +1,97 @@
|
|||
#include "farmhash.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
std::vector<std::string_view> SplitString(const std::string_view& str, const std::string_view& seps, bool remove_empty_entries) {
|
||||
std::vector<std::string_view> result;
|
||||
std::string ::size_type pre_pos = 0;
|
||||
|
||||
while (true) {
|
||||
auto next_pos = str.find_first_of(seps, pre_pos);
|
||||
|
||||
if (next_pos == std::string::npos) {
|
||||
auto sub_str = str.substr(pre_pos, next_pos);
|
||||
// sub_str is empty means the last sep reach the end of string
|
||||
if (!sub_str.empty()) {
|
||||
result.push_back(sub_str);
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
if (pre_pos != next_pos || !remove_empty_entries) {
|
||||
auto sub_str = str.substr(pre_pos, next_pos - pre_pos);
|
||||
result.push_back(sub_str);
|
||||
}
|
||||
|
||||
pre_pos = next_pos + 1;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/hash.cc#L28
|
||||
static inline uint64_t ByteAs64(char c) { return static_cast<uint64_t>(c) & 0xff; }
|
||||
|
||||
// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/raw_coding.h#L41
|
||||
uint64_t DecodeFixed32(const char* ptr) {
|
||||
return ((static_cast<uint64_t>(static_cast<unsigned char>(ptr[0]))) |
|
||||
(static_cast<uint64_t>(static_cast<unsigned char>(ptr[1])) << 8) |
|
||||
(static_cast<uint64_t>(static_cast<unsigned char>(ptr[2])) << 16) |
|
||||
(static_cast<uint64_t>(static_cast<unsigned char>(ptr[3])) << 24));
|
||||
}
|
||||
|
||||
// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/raw_coding.h#L55
|
||||
static uint64_t DecodeFixed64(const char* ptr) {
|
||||
uint64_t lo = DecodeFixed32(ptr);
|
||||
uint64_t hi = DecodeFixed32(ptr + 4);
|
||||
return (hi << 32) | lo;
|
||||
}
|
||||
|
||||
// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/hash.cc#L79
|
||||
uint64_t Hash64(const char* data, size_t n, uint64_t seed) {
|
||||
const uint64_t m = 0xc6a4a7935bd1e995;
|
||||
const int r = 47;
|
||||
|
||||
uint64_t h = seed ^ (n * m);
|
||||
|
||||
while (n >= 8) {
|
||||
uint64_t k = DecodeFixed64(data);
|
||||
data += 8;
|
||||
n -= 8;
|
||||
|
||||
k *= m;
|
||||
k ^= k >> r;
|
||||
k *= m;
|
||||
|
||||
h ^= k;
|
||||
h *= m;
|
||||
}
|
||||
|
||||
switch (n) {
|
||||
case 7:
|
||||
h ^= ByteAs64(data[6]) << 48;
|
||||
case 6:
|
||||
h ^= ByteAs64(data[5]) << 40;
|
||||
case 5:
|
||||
h ^= ByteAs64(data[4]) << 32;
|
||||
case 4:
|
||||
h ^= ByteAs64(data[3]) << 24;
|
||||
case 3:
|
||||
h ^= ByteAs64(data[2]) << 16;
|
||||
case 2:
|
||||
h ^= ByteAs64(data[1]) << 8;
|
||||
case 1:
|
||||
h ^= ByteAs64(data[0]);
|
||||
h *= m;
|
||||
}
|
||||
|
||||
h ^= h >> r;
|
||||
h *= m;
|
||||
h ^= h >> r;
|
||||
|
||||
return h;
|
||||
}
|
||||
|
||||
uint64_t Hash64Fast(const char* data, size_t n) {
|
||||
return static_cast<int64_t>(util::Fingerprint64(data, n));
|
||||
}
|
|
@ -1,53 +1,61 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
template <typename T>
|
||||
inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
|
||||
ss << t;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<int64_t>& t) noexcept {
|
||||
ss << "[";
|
||||
for (int i = 0; i < t.size(); i++) {
|
||||
if (i != 0) {
|
||||
ss << ", ";
|
||||
}
|
||||
ss << t[i];
|
||||
}
|
||||
ss << "]";
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<std::string>& t) noexcept {
|
||||
ss << "[";
|
||||
for (int i = 0; i < t.size(); i++) {
|
||||
if (i != 0) {
|
||||
ss << ", ";
|
||||
}
|
||||
ss << t[i];
|
||||
}
|
||||
ss << "]";
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
void MakeStringInternal(std::ostringstream& ss, const T& t, const Args&... args) noexcept {
|
||||
MakeStringInternal(ss, t);
|
||||
MakeStringInternal(ss, args...);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
std::string MakeString(const Args&... args) {
|
||||
std::ostringstream ss;
|
||||
MakeStringInternal(ss, args...);
|
||||
return std::string(ss.str());
|
||||
}
|
||||
|
||||
std::vector<std::string_view> SplitString(const std::string_view& str, const std::string_view& seps, bool remove_empty_entries = false);
|
||||
|
||||
void char2unicode(const std::string& src, std::vector<uint32_t>& result);
|
||||
void unicode2char(const std::vector<uint32_t>& src, std::string& result);
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
template <typename T>
|
||||
inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
|
||||
ss << t;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<int64_t>& t) noexcept {
|
||||
ss << "[";
|
||||
for (int i = 0; i < t.size(); i++) {
|
||||
if (i != 0) {
|
||||
ss << ", ";
|
||||
}
|
||||
ss << t[i];
|
||||
}
|
||||
ss << "]";
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<std::string>& t) noexcept {
|
||||
ss << "[";
|
||||
for (int i = 0; i < t.size(); i++) {
|
||||
if (i != 0) {
|
||||
ss << ", ";
|
||||
}
|
||||
ss << t[i];
|
||||
}
|
||||
ss << "]";
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
void MakeStringInternal(std::ostringstream& ss, const T& t, const Args&... args) noexcept {
|
||||
MakeStringInternal(ss, t);
|
||||
MakeStringInternal(ss, args...);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
std::string MakeString(const Args&... args) {
|
||||
std::ostringstream ss;
|
||||
MakeStringInternal(ss, args...);
|
||||
return std::string(ss.str());
|
||||
}
|
||||
|
||||
std::vector<std::string_view> SplitString(const std::string_view& str, const std::string_view& seps, bool remove_empty_entries = false);
|
||||
|
||||
void char2unicode(const std::string& src, std::vector<uint32_t>& result);
|
||||
void unicode2char(const std::vector<uint32_t>& src, std::string& result);
|
||||
|
||||
uint64_t Hash64(const char* data, size_t n, uint64_t seed);
|
||||
|
||||
inline uint64_t Hash64(const char* data, size_t n) {
|
||||
return Hash64(data, n, 0xDECAFCAFFE);
|
||||
}
|
||||
|
||||
uint64_t Hash64Fast(const char* data, size_t n);
|
|
@ -4,3 +4,6 @@
|
|||
#pragma once
|
||||
|
||||
#include "ocos.h"
|
||||
|
||||
|
||||
// TO BE DELETED.
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringEqual : BaseKernel {
|
||||
KernelStringEqual(OrtApi api);
|
|
@ -3,8 +3,8 @@
|
|||
#pragma once
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_common.h"
|
||||
#include "string_utils.h"
|
||||
#include "string_tensor.h"
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
class BroadcastIteratorRight {
|
|
@ -1,8 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "string_utils.h"
|
||||
#include "string_tensor.h"
|
||||
#include "op_ragged_tensor.hpp"
|
||||
#include "string_common.h"
|
||||
|
||||
KernelRaggedTensorToSparse::KernelRaggedTensorToSparse(OrtApi api) : BaseKernel(api) {
|
||||
}
|
|
@ -4,7 +4,6 @@
|
|||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "utils/string_utils.h"
|
||||
|
||||
struct KernelRaggedTensorToSparse : BaseKernel {
|
||||
KernelRaggedTensorToSparse(OrtApi api);
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelSegmentSum : BaseKernel {
|
||||
KernelSegmentSum(OrtApi api);
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "string_concat.hpp"
|
||||
#include "string_common.h"
|
||||
#include "string_tensor.h"
|
||||
#include <vector>
|
||||
#include <locale>
|
||||
#include <codecvt>
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringConcat : BaseKernel {
|
||||
KernelStringConcat(OrtApi api);
|
|
@ -1,80 +1,14 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "string_hash.hpp"
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include "re2/re2.h"
|
||||
#include "farmhash.h"
|
||||
#include "string_common.h"
|
||||
#include "string_tensor.h"
|
||||
#include "string_hash.hpp"
|
||||
|
||||
// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/hash.cc#L28
|
||||
static inline uint64_t ByteAs64(char c) { return static_cast<uint64_t>(c) & 0xff; }
|
||||
|
||||
// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/raw_coding.h#L41
|
||||
uint64_t DecodeFixed32(const char* ptr) {
|
||||
return ((static_cast<uint64_t>(static_cast<unsigned char>(ptr[0]))) |
|
||||
(static_cast<uint64_t>(static_cast<unsigned char>(ptr[1])) << 8) |
|
||||
(static_cast<uint64_t>(static_cast<unsigned char>(ptr[2])) << 16) |
|
||||
(static_cast<uint64_t>(static_cast<unsigned char>(ptr[3])) << 24));
|
||||
}
|
||||
|
||||
// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/raw_coding.h#L55
|
||||
static uint64_t DecodeFixed64(const char* ptr) {
|
||||
uint64_t lo = DecodeFixed32(ptr);
|
||||
uint64_t hi = DecodeFixed32(ptr + 4);
|
||||
return (hi << 32) | lo;
|
||||
}
|
||||
|
||||
// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/hash.cc#L79
|
||||
uint64_t Hash64(const char* data, size_t n, uint64_t seed) {
|
||||
const uint64_t m = 0xc6a4a7935bd1e995;
|
||||
const int r = 47;
|
||||
|
||||
uint64_t h = seed ^ (n * m);
|
||||
|
||||
while (n >= 8) {
|
||||
uint64_t k = DecodeFixed64(data);
|
||||
data += 8;
|
||||
n -= 8;
|
||||
|
||||
k *= m;
|
||||
k ^= k >> r;
|
||||
k *= m;
|
||||
|
||||
h ^= k;
|
||||
h *= m;
|
||||
}
|
||||
|
||||
switch (n) {
|
||||
case 7:
|
||||
h ^= ByteAs64(data[6]) << 48;
|
||||
case 6:
|
||||
h ^= ByteAs64(data[5]) << 40;
|
||||
case 5:
|
||||
h ^= ByteAs64(data[4]) << 32;
|
||||
case 4:
|
||||
h ^= ByteAs64(data[3]) << 24;
|
||||
case 3:
|
||||
h ^= ByteAs64(data[2]) << 16;
|
||||
case 2:
|
||||
h ^= ByteAs64(data[1]) << 8;
|
||||
case 1:
|
||||
h ^= ByteAs64(data[0]);
|
||||
h *= m;
|
||||
}
|
||||
|
||||
h ^= h >> r;
|
||||
h *= m;
|
||||
h ^= h >> r;
|
||||
|
||||
return h;
|
||||
}
|
||||
|
||||
uint64_t Hash64Fast(const char* data, size_t n) {
|
||||
return static_cast<int64_t>(util::Fingerprint64(data, n));
|
||||
}
|
||||
|
||||
KernelStringHash::KernelStringHash(OrtApi api) : BaseKernel(api) {
|
||||
}
|
|
@ -4,15 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "utils/string_utils.h"
|
||||
|
||||
uint64_t Hash64(const char* data, size_t n, uint64_t seed);
|
||||
|
||||
inline uint64_t Hash64(const char* data, size_t n) {
|
||||
return Hash64(data, n, 0xDECAFCAFFE);
|
||||
}
|
||||
|
||||
uint64_t Hash64Fast(const char* data, size_t n);
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringHash : BaseKernel {
|
||||
KernelStringHash(OrtApi api);
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "string_join.hpp"
|
||||
#include "string_common.h"
|
||||
#include "string_tensor.h"
|
||||
|
||||
KernelStringJoin::KernelStringJoin(OrtApi api) : BaseKernel(api) {
|
||||
}
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringJoin : BaseKernel {
|
||||
KernelStringJoin(OrtApi api);
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "string_length.hpp"
|
||||
#include "string_common.h"
|
||||
#include "string_tensor.h"
|
||||
#include <vector>
|
||||
#include <locale>
|
||||
#include <codecvt>
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringLength : BaseKernel {
|
||||
KernelStringLength(OrtApi api);
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "string_lower.hpp"
|
||||
#include "string_common.h"
|
||||
#include "string_tensor.h"
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringLower : BaseKernel {
|
||||
KernelStringLower(OrtApi api);
|
|
@ -6,7 +6,7 @@
|
|||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include "re2/re2.h"
|
||||
#include "string_common.h"
|
||||
#include "string_tensor.h"
|
||||
|
||||
KernelStringRegexReplace::KernelStringRegexReplace(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
global_replace_ = HasAttribute("global_replace") ? ort_.KernelInfoGetAttribute<int64_t>(info_, "global_replace") : 1;
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringRegexReplace : BaseKernel {
|
||||
KernelStringRegexReplace(OrtApi api, const OrtKernelInfo* info);
|
|
@ -5,7 +5,7 @@
|
|||
#include "string_regex_split_re.hpp"
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include "string_common.h"
|
||||
#include "string_tensor.h"
|
||||
|
||||
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
}
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
|
||||
struct KernelStringRegexSplitWithOffsets : BaseKernel {
|
|
@ -1,137 +1,137 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "string_split.hpp"
|
||||
#include "string_common.h"
|
||||
|
||||
KernelStringSplit::KernelStringSplit(OrtApi api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelStringSplit::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
|
||||
const OrtValue* input_skip_empty = ort_.KernelContext_GetInput(context, 2);
|
||||
const bool* skip_empty = ort_.GetTensorData<bool>(input_skip_empty);
|
||||
std::vector<std::string> X, sep;
|
||||
GetTensorMutableDataString(api_, ort_, context, input_X, X);
|
||||
GetTensorMutableDataString(api_, ort_, context, input_sep, sep);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions_sep(ort_, input_sep);
|
||||
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
|
||||
throw std::runtime_error("Input 2 is the delimiter, it has 1 element.");
|
||||
OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty);
|
||||
if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1)
|
||||
throw std::runtime_error("Input 3 is skip_empty, it has 1 element.");
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
if (dimensions.size() != 1)
|
||||
throw std::runtime_error("Only 1D tensor are supported as input.");
|
||||
|
||||
std::vector<std::string> words;
|
||||
std::vector<int64_t> indices;
|
||||
int64_t maxc = 0;
|
||||
int64_t col;
|
||||
std::string delimiter = sep[0];
|
||||
if (delimiter.size() == 0) {
|
||||
char word[2] = "a";
|
||||
for (int64_t row = 0; row < dimensions[0]; ++row) {
|
||||
const std::string& str = X[row];
|
||||
if (str.empty())
|
||||
continue;
|
||||
maxc = str.size() > maxc ? str.size() : maxc;
|
||||
for (auto it = str.begin(); it != str.end(); ++it) {
|
||||
word[0] = *it;
|
||||
words.push_back(word);
|
||||
indices.push_back(row);
|
||||
indices.push_back(std::distance(str.begin(), it));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
bool keep = !(*skip_empty);
|
||||
std::size_t current, previous = 0;
|
||||
for (int64_t row = 0; row < dimensions[0]; ++row) {
|
||||
const std::string& str = X[row];
|
||||
if (str.empty())
|
||||
continue;
|
||||
previous = 0;
|
||||
col = 0;
|
||||
current = str.find_first_of(delimiter);
|
||||
while (current != std::string::npos) {
|
||||
if (keep || current > previous) {
|
||||
words.push_back(str.substr(previous, current - previous));
|
||||
indices.push_back(row);
|
||||
indices.push_back(col);
|
||||
++col;
|
||||
}
|
||||
previous = current + 1;
|
||||
current = str.find_first_of(delimiter, previous);
|
||||
}
|
||||
current = str.size();
|
||||
if (keep || current > previous) {
|
||||
words.push_back(str.substr(previous, current - previous));
|
||||
indices.push_back(row);
|
||||
indices.push_back(col);
|
||||
++col;
|
||||
}
|
||||
maxc = col > maxc ? col : maxc;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> shape_indices = {static_cast<int64_t>(indices.size()) / 2, 2};
|
||||
OrtValue* out_indices = ort_.KernelContext_GetOutput(context, 0, shape_indices.data(), shape_indices.size());
|
||||
|
||||
std::vector<int64_t> shape_text(1, words.size());
|
||||
OrtValue* out_text = ort_.KernelContext_GetOutput(context, 1, shape_text.data(), shape_text.size());
|
||||
|
||||
std::vector<int64_t> shape_shape(1, 2);
|
||||
OrtValue* out_shape = ort_.KernelContext_GetOutput(context, 2, shape_shape.data(), shape_shape.size());
|
||||
|
||||
int64_t* p_indices = ort_.GetTensorMutableData<int64_t>(out_indices);
|
||||
int64_t* p_shape = ort_.GetTensorMutableData<int64_t>(out_shape);
|
||||
|
||||
memcpy(p_indices, indices.data(), indices.size() * sizeof(int64_t));
|
||||
p_shape[0] = dimensions[0];
|
||||
p_shape[1] = maxc;
|
||||
FillTensorDataString(api_, ort_, context, words, out_text);
|
||||
}
|
||||
|
||||
void* CustomOpStringSplit::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
return new KernelStringSplit(api);
|
||||
};
|
||||
|
||||
const char* CustomOpStringSplit::GetName() const {
|
||||
return "StringSplit";
|
||||
};
|
||||
|
||||
size_t CustomOpStringSplit::GetInputTypeCount() const {
|
||||
return 3;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpStringSplit::GetInputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 0:
|
||||
case 1:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
case 2:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
||||
default:
|
||||
throw std::runtime_error(MakeString("Unexpected input index ", index));
|
||||
}
|
||||
};
|
||||
|
||||
size_t CustomOpStringSplit::GetOutputTypeCount() const {
|
||||
return 3;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpStringSplit::GetOutputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 0:
|
||||
case 2:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
case 1:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
default:
|
||||
throw std::runtime_error(MakeString("[StringSplit] Unexpected output index ", index));
|
||||
}
|
||||
};
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "string_split.hpp"
|
||||
#include "string_tensor.h"
|
||||
|
||||
KernelStringSplit::KernelStringSplit(OrtApi api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelStringSplit::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
|
||||
const OrtValue* input_skip_empty = ort_.KernelContext_GetInput(context, 2);
|
||||
const bool* skip_empty = ort_.GetTensorData<bool>(input_skip_empty);
|
||||
std::vector<std::string> X, sep;
|
||||
GetTensorMutableDataString(api_, ort_, context, input_X, X);
|
||||
GetTensorMutableDataString(api_, ort_, context, input_sep, sep);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions_sep(ort_, input_sep);
|
||||
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
|
||||
throw std::runtime_error("Input 2 is the delimiter, it has 1 element.");
|
||||
OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty);
|
||||
if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1)
|
||||
throw std::runtime_error("Input 3 is skip_empty, it has 1 element.");
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
if (dimensions.size() != 1)
|
||||
throw std::runtime_error("Only 1D tensor are supported as input.");
|
||||
|
||||
std::vector<std::string> words;
|
||||
std::vector<int64_t> indices;
|
||||
int64_t maxc = 0;
|
||||
int64_t col;
|
||||
std::string delimiter = sep[0];
|
||||
if (delimiter.size() == 0) {
|
||||
char word[2] = "a";
|
||||
for (int64_t row = 0; row < dimensions[0]; ++row) {
|
||||
const std::string& str = X[row];
|
||||
if (str.empty())
|
||||
continue;
|
||||
maxc = str.size() > maxc ? str.size() : maxc;
|
||||
for (auto it = str.begin(); it != str.end(); ++it) {
|
||||
word[0] = *it;
|
||||
words.push_back(word);
|
||||
indices.push_back(row);
|
||||
indices.push_back(std::distance(str.begin(), it));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
bool keep = !(*skip_empty);
|
||||
std::size_t current, previous = 0;
|
||||
for (int64_t row = 0; row < dimensions[0]; ++row) {
|
||||
const std::string& str = X[row];
|
||||
if (str.empty())
|
||||
continue;
|
||||
previous = 0;
|
||||
col = 0;
|
||||
current = str.find_first_of(delimiter);
|
||||
while (current != std::string::npos) {
|
||||
if (keep || current > previous) {
|
||||
words.push_back(str.substr(previous, current - previous));
|
||||
indices.push_back(row);
|
||||
indices.push_back(col);
|
||||
++col;
|
||||
}
|
||||
previous = current + 1;
|
||||
current = str.find_first_of(delimiter, previous);
|
||||
}
|
||||
current = str.size();
|
||||
if (keep || current > previous) {
|
||||
words.push_back(str.substr(previous, current - previous));
|
||||
indices.push_back(row);
|
||||
indices.push_back(col);
|
||||
++col;
|
||||
}
|
||||
maxc = col > maxc ? col : maxc;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> shape_indices = {static_cast<int64_t>(indices.size()) / 2, 2};
|
||||
OrtValue* out_indices = ort_.KernelContext_GetOutput(context, 0, shape_indices.data(), shape_indices.size());
|
||||
|
||||
std::vector<int64_t> shape_text(1, words.size());
|
||||
OrtValue* out_text = ort_.KernelContext_GetOutput(context, 1, shape_text.data(), shape_text.size());
|
||||
|
||||
std::vector<int64_t> shape_shape(1, 2);
|
||||
OrtValue* out_shape = ort_.KernelContext_GetOutput(context, 2, shape_shape.data(), shape_shape.size());
|
||||
|
||||
int64_t* p_indices = ort_.GetTensorMutableData<int64_t>(out_indices);
|
||||
int64_t* p_shape = ort_.GetTensorMutableData<int64_t>(out_shape);
|
||||
|
||||
memcpy(p_indices, indices.data(), indices.size() * sizeof(int64_t));
|
||||
p_shape[0] = dimensions[0];
|
||||
p_shape[1] = maxc;
|
||||
FillTensorDataString(api_, ort_, context, words, out_text);
|
||||
}
|
||||
|
||||
void* CustomOpStringSplit::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
return new KernelStringSplit(api);
|
||||
};
|
||||
|
||||
const char* CustomOpStringSplit::GetName() const {
|
||||
return "StringSplit";
|
||||
};
|
||||
|
||||
size_t CustomOpStringSplit::GetInputTypeCount() const {
|
||||
return 3;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpStringSplit::GetInputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 0:
|
||||
case 1:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
case 2:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
||||
default:
|
||||
throw std::runtime_error(MakeString("Unexpected input index ", index));
|
||||
}
|
||||
};
|
||||
|
||||
size_t CustomOpStringSplit::GetOutputTypeCount() const {
|
||||
return 3;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpStringSplit::GetOutputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 0:
|
||||
case 2:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
case 1:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
default:
|
||||
throw std::runtime_error(MakeString("[StringSplit] Unexpected output index ", index));
|
||||
}
|
||||
};
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringSplit : BaseKernel {
|
||||
KernelStringSplit(OrtApi api);
|
|
@ -1,8 +1,8 @@
|
|||
#include <charconv>
|
||||
#include "kernels.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_utils.h"
|
||||
#include "string_to_vector.hpp"
|
||||
#include "string_common.h"
|
||||
#include "string_tensor.h"
|
||||
|
||||
StringToVectorImpl::StringToVectorImpl(std::string& map, std::string& unk) {
|
||||
ParseMappingTable(map);
|
|
@ -8,7 +8,7 @@
|
|||
#include <vector>
|
||||
#include "kernels.h"
|
||||
#include "farmhash.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
|
||||
class StringToVectorImpl {
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "string_upper.hpp"
|
||||
#include "string_common.h"
|
||||
#include "string_tensor.h"
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringUpper : BaseKernel {
|
||||
KernelStringUpper(OrtApi api);
|
|
@ -1,8 +1,8 @@
|
|||
#include <charconv>
|
||||
#include "kernels.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_utils.h"
|
||||
#include "vector_to_string.hpp"
|
||||
#include "string_common.h"
|
||||
#include "string_tensor.h"
|
||||
|
||||
VectorToStringImpl::VectorToStringImpl(std::string& map, std::string& unk) : unk_value_(unk) {
|
||||
ParseMappingTable(map);
|
|
@ -8,7 +8,7 @@
|
|||
#include <vector>
|
||||
#include "kernels.h"
|
||||
#include "farmhash.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
namespace std {
|
||||
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -4,7 +4,7 @@
|
|||
#include "sentencepiece_processor.h"
|
||||
#include "sentencepiece_model.pb.h"
|
||||
#include "sentencepiece_tokenizer.hpp"
|
||||
#include "kernels/string_common.h"
|
||||
#include "string_tensor.h"
|
||||
#include "base64.h"
|
||||
|
||||
KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
|
@ -3,8 +3,8 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "kernels/kernels.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
#include "sentencepiece_processor.h"
|
||||
|
||||
struct KernelSentencepieceTokenizer : BaseKernel {
|
|
@ -5,9 +5,10 @@
|
|||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "kernels/kernels.h"
|
||||
#include "kernels/string_common.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "ocos.h"
|
||||
#include "ustring.h"
|
||||
#include "string_utils.h"
|
||||
#include "string_tensor.h"
|
||||
|
||||
struct KernelWordpieceTokenizer : BaseKernel {
|
||||
KernelWordpieceTokenizer(OrtApi api, const OrtKernelInfo* info);
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#include <iostream>
|
||||
#include "ustring.hpp"
|
||||
#include "ustring.h"
|
||||
|
||||
ustring::ustring(): std::u32string() {
|
||||
}
|
|
@ -17,10 +17,10 @@
|
|||
#include <pybind11/functional.h>
|
||||
#include <pybind11/numpy.h>
|
||||
#include <thread>
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_utils.h"
|
||||
#include "pykernel.h"
|
||||
#include "kernels/string_hash.hpp"
|
||||
#include "kernels/string_common.h"
|
||||
#include "text/string_hash.hpp"
|
||||
#include "string_tensor.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
|
@ -3,22 +3,23 @@
|
|||
|
||||
#include <set>
|
||||
|
||||
#include "kernels/op_equal.hpp"
|
||||
#include "kernels/op_segment_sum.hpp"
|
||||
#include "kernels/op_ragged_tensor.hpp"
|
||||
#include "kernels/string_hash.hpp"
|
||||
#include "kernels/string_join.hpp"
|
||||
#include "kernels/string_lower.hpp"
|
||||
#include "kernels/string_regex_replace.hpp"
|
||||
#include "kernels/string_regex_split.hpp"
|
||||
#include "kernels/string_split.hpp"
|
||||
#include "kernels/string_to_vector.hpp"
|
||||
#include "kernels/string_upper.hpp"
|
||||
#include "kernels/negpos.hpp"
|
||||
#include "kernels/vector_to_string.hpp"
|
||||
#include "kernels/string_length.hpp"
|
||||
#include "kernels/string_concat.hpp"
|
||||
#include "utils/string_utils.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
#include "text/op_equal.hpp"
|
||||
#include "text/op_segment_sum.hpp"
|
||||
#include "text/op_ragged_tensor.hpp"
|
||||
#include "text/string_hash.hpp"
|
||||
#include "text/string_join.hpp"
|
||||
#include "text/string_lower.hpp"
|
||||
#include "text/string_regex_replace.hpp"
|
||||
#include "text/string_regex_split.hpp"
|
||||
#include "text/string_split.hpp"
|
||||
#include "text/string_to_vector.hpp"
|
||||
#include "text/string_upper.hpp"
|
||||
#include "text/vector_to_string.hpp"
|
||||
#include "text/string_length.hpp"
|
||||
#include "text/string_concat.hpp"
|
||||
|
||||
|
||||
#ifdef ENABLE_SPM_TOKENIZER
|
||||
#include "sentencepiece_tokenizer.hpp"
|
||||
|
@ -37,7 +38,6 @@ CustomOpWordpieceTokenizer c_CustomOpWordpieceTokenizer;
|
|||
#endif
|
||||
|
||||
#ifdef ENABLE_TF_STRING
|
||||
CustomOpNegPos c_CustomOpNegPos;
|
||||
CustomOpSegmentSum c_CustomOpSegmentSum;
|
||||
CustomOpRaggedTensorToDense c_CustomOpRaggedTensorToDense;
|
||||
CustomOpRaggedTensorToSparse c_CustomOpRaggedTensorToSparse;
|
||||
|
@ -67,7 +67,6 @@ OrtCustomOp* operator_lists[] = {
|
|||
#endif
|
||||
|
||||
#ifdef ENABLE_TF_STRING
|
||||
&c_CustomOpNegPos,
|
||||
&c_CustomOpRaggedTensorToDense,
|
||||
&c_CustomOpRaggedTensorToSparse,
|
||||
&c_CustomOpSegmentSum,
|
||||
|
@ -88,6 +87,10 @@ OrtCustomOp* operator_lists[] = {
|
|||
#endif
|
||||
nullptr};
|
||||
|
||||
#if ENABLE_MATH
|
||||
extern FxLoadCustomOpFactory LoadCustomOpClasses_Math;
|
||||
#endif //ENABLE_MATH
|
||||
|
||||
class ExternalCustomOps {
|
||||
public:
|
||||
ExternalCustomOps() {
|
||||
|
@ -117,11 +120,13 @@ class ExternalCustomOps {
|
|||
std::vector<const OrtCustomOp*> op_array_;
|
||||
};
|
||||
|
||||
extern "C" bool AddExternalCustomOp(const OrtCustomOp* c_op) {
|
||||
extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op) {
|
||||
ExternalCustomOps::instance().Add(c_op);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
|
||||
OrtCustomOpDomain* domain = nullptr;
|
||||
const OrtApi* ortApi = api->GetApi(ORT_API_VERSION);
|
||||
|
@ -145,27 +150,29 @@ extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options,
|
|||
}
|
||||
#endif
|
||||
|
||||
OrtCustomOp** ops = operator_lists;
|
||||
while (*ops != nullptr) {
|
||||
if (pyop_nameset.find((*ops)->GetName(*ops)) == pyop_nameset.end()) {
|
||||
if (auto status = ortApi->CustomOpDomain_Add(domain, *ops)) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
++ops;
|
||||
}
|
||||
|
||||
#if defined(ENABLE_GPT2_TOKENIZER)
|
||||
const OrtCustomOp** t_ops = LoadTokenizerSchemaList();
|
||||
while (*t_ops != nullptr) {
|
||||
if (pyop_nameset.find((*t_ops)->GetName(*t_ops)) == pyop_nameset.end()) {
|
||||
if (auto status = ortApi->CustomOpDomain_Add(domain, *t_ops)) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
t_ops++;
|
||||
}
|
||||
static std::vector<FxLoadCustomOpFactory> c_factories = {
|
||||
[]() { return const_cast<const OrtCustomOp**>(operator_lists); }
|
||||
#if defined(ENABLE_MATH)
|
||||
,
|
||||
LoadCustomOpClasses_Math
|
||||
#endif
|
||||
#if defined(ENABLE_GPT2_TOKENIZER)
|
||||
,
|
||||
LoadTokenizerSchemaList
|
||||
#endif
|
||||
};
|
||||
|
||||
for (auto fx : c_factories) {
|
||||
auto ops = fx();
|
||||
while (*ops != nullptr) {
|
||||
if (pyop_nameset.find((*ops)->GetName(*ops)) == pyop_nameset.end()) {
|
||||
if (auto status = ortApi->CustomOpDomain_Add(domain, *ops)) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
++ops;
|
||||
}
|
||||
}
|
||||
|
||||
size_t idx = 0;
|
||||
const OrtCustomOp* e_ops = ExternalCustomOps::instance().GetNextOp(idx);
|
||||
|
|
|
@ -2,12 +2,14 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "onnxruntime_cxx_api.h"
|
||||
|
||||
#include <filesystem>
|
||||
#include "gtest/gtest.h"
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
#include "string_tensor.h"
|
||||
#include "test_kernel.hpp"
|
||||
#include "utils/string_utils.h"
|
||||
#include "kernels/string_common.h"
|
||||
#include <filesystem>
|
||||
|
||||
|
||||
const char* GetLibraryPath() {
|
||||
#if defined(_WIN32)
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "onnxruntime_cxx_api.h"
|
||||
#include <filesystem>
|
||||
#include "gtest/gtest.h"
|
||||
#include "ocos.h"
|
||||
#include "test_kernel.hpp"
|
||||
#include "kernels/string_lower.hpp"
|
||||
#include <filesystem>
|
||||
#include "text/string_lower.hpp"
|
||||
|
||||
|
||||
TEST(utils, test_string_lower) {
|
||||
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include <dlib/matrix.h>
|
||||
|
||||
using namespace dlib;
|
||||
|
||||
TEST(math, matrix_op) {
|
||||
matrix<float> M(3,3);
|
||||
M = 54.2, 7.4, 12.1,
|
||||
1, 2, 3,
|
||||
5.9, 0.05, 1;
|
||||
|
||||
matrix<float,3,1> y;
|
||||
y = 3.5,
|
||||
1.2,
|
||||
7.8;
|
||||
|
||||
matrix<float> x = inv(M)*y;
|
||||
EXPECT_FLOAT_EQ(x(1, 0), -13.909741);
|
||||
}
|
|
@ -2,8 +2,8 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "kernels/string_regex_split_re.hpp"
|
||||
#include "string_utils.h"
|
||||
#include "text/string_regex_split_re.hpp"
|
||||
|
||||
TEST(strings, regex_split) {
|
||||
std::string input = "hello world";
|
||||
|
|
|
@ -2,9 +2,8 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "kernels/string_common.h"
|
||||
#include "../tokenizer/wordpiece_tokenizer.hpp"
|
||||
#include "string_utils.h"
|
||||
#include "wordpiece_tokenizer.hpp"
|
||||
|
||||
TEST(tokenizer, bert_word_split) {
|
||||
ustring ind("##");
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
#include "gtest/gtest.h"
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <cstring>
|
||||
#include "kernels/ustring.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ustring.h"
|
||||
|
||||
void convert_test(const char* const_str) {
|
||||
std::string string(const_str);
|
||||
|
|
|
@ -2,9 +2,10 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "utils/string_utils.h"
|
||||
#include "re2/re2.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "string_utils.h"
|
||||
|
||||
|
||||
TEST(utils, make_string) {
|
||||
std::string res = MakeString("a", "b", 0);
|
||||
|
|
Загрузка…
Ссылка в новой задаче