Add TextToSentenceTokenizer (#113)
* add dependency
* init
* add test
* implement
* finished
* find file path through pathlib
* rename the op
* update reminder

Co-authored-by: Ze Tao <zetao@microsoft.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
Parent: 800e360ef3
Commit: 3e82549bcb
@@ -28,6 +28,7 @@ option(OCOS_ENABLE_TF_STRING "Enable String Operator Set" ON)
option(OCOS_ENABLE_GPT2_TOKENIZER "Enable the GPT2 tokenizer building" ON)
option(OCOS_ENABLE_SPM_TOKENIZER "Enable the SentencePiece tokenizer building" ON)
option(OCOS_ENABLE_BERT_TOKENIZER "Enable the BertTokenizer building" ON)
option(OCOS_ENABLE_BLINGFIRE "Enable the Blingfire building" ON)
option(OCOS_ENABLE_MATH "Enable the math tensor operators building" ON)
option(OCOS_ENABLE_STATIC_LIB "Enable generating static library" OFF)

@@ -124,6 +125,12 @@ if (OCOS_ENABLE_BERT_TOKENIZER)
  list(APPEND TARGET_SRC ${bert_TARGET_SRC})
endif()

if (OCOS_ENABLE_BLINGFIRE)
  file(GLOB blingfire_TARGET_SRC "operators/tokenizer/blingfire*.*")
  list(APPEND TARGET_SRC ${blingfire_TARGET_SRC})
endif()

add_compile_options("$<$<C_COMPILER_ID:MSVC>:/utf-8>")
add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/utf-8>")
add_library(ocos_operators STATIC ${TARGET_SRC})

@@ -171,6 +178,12 @@ if (OCOS_ENABLE_SPM_TOKENIZER)
  list(APPEND ocos_libraries sentencepiece-static)
endif()

if (OCOS_ENABLE_BLINGFIRE)
  include(blingfire)
  list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_BLINGFIRE)
  list(APPEND ocos_libraries bingfirtinydll_static)
endif()

if (OCOS_ENABLE_TF_STRING)
  target_compile_definitions(ocos_operators PRIVATE
    NOMINMAX

@@ -190,6 +203,9 @@ if(OCOS_ENABLE_PYTHON)
    message(FATAL_ERROR "Python3_FIND_REGISTRY is not NEVER")
  endif()
  find_package(Python3 COMPONENTS Interpreter Development.Module NumPy)
  if (NOT Python3_FOUND)
    message(FATAL_ERROR "Python3 or NumPy not found!")
  endif()
  if (WIN32)
    list(APPEND shared_TARGET_SRC "${PROJECT_SOURCE_DIR}/onnxruntime_extensions/ortcustomops.def")
  endif()
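The CMake hunks above add BlingFire as one more optional component: OCOS_ENABLE_BLINGFIRE pulls the blingfire* sources into TARGET_SRC, defines ENABLE_BLINGFIRE for the registration code, and links the fetched static library. As a usage note that is not part of this commit, the switch is toggled at configure time like the other OCOS_* options; a minimal sketch, with an arbitrary build directory name:

    cmake -S . -B build -DOCOS_ENABLE_BLINGFIRE=OFF
    cmake --build build --config Release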
@@ -0,0 +1,15 @@
FetchContent_Declare(
  Blingfire
  GIT_REPOSITORY https://github.com/microsoft/BlingFire.git
  GIT_TAG master
)


FetchContent_GetProperties(Blingfire)

if (NOT blingfire_POPULATED)
  FetchContent_Populate(Blingfire)

  # enable size optimization build
  add_subdirectory(${blingfire_SOURCE_DIR} ${blingfire_BINARY_DIR} EXCLUDE_FROM_ALL)
endif ()
@@ -5,7 +5,7 @@

import onnx
from onnx import onnx_pb as onnx_proto
from ._ocos import default_opset_domain
from ._ocos import default_opset_domain, get_library_path # noqa


class CustomOp:

@@ -88,6 +88,28 @@ class StringToVector(CustomOp):
        return attr_data


class BlingFireSentenceBreaker(CustomOp):
    @classmethod
    def get_inputs(cls):
        return [cls.io_def("text", onnx.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('sentence', onnx_proto.TensorProto.STRING, [])]

    @classmethod
    def serialize_attr(cls, attrs):
        attrs_data = {}
        for k_, v_ in attrs.items():
            if k_ == 'model':
                with open(v_, "rb") as model_file:
                    attrs_data[k_] = model_file.read()
            else:
                attrs_data[k_] = v_
        return attrs_data

# TODO: list all custom operators schema here:
# ...
# ...
class SentencepieceTokenizer(CustomOp):
    @classmethod
    def get_inputs(cls):
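The BlingFireSentenceBreaker class added in the Python hunk above is only a description of the operator: get_inputs/get_outputs give the node signature, and serialize_attr reads the BlingFire binary model so its bytes can be stored as a node attribute. Below is a minimal sketch, not from this commit, of turning that description into a single-node ONNX model by hand; the "ai.onnx.contrib" domain is assumed to match what default_opset_domain() returns, the opset numbers are placeholders, and build_sentencebreaker_model is a hypothetical helper name:

    from onnx import helper, TensorProto
    from onnxruntime_extensions.eager_op import BlingFireSentenceBreaker

    def build_sentencebreaker_model(bf_model_path, domain="ai.onnx.contrib"):
        # serialize_attr turns {'model': <path>} into {'model': <raw bytes>},
        # which make_node stores as a STRING attribute on the custom node.
        attrs = BlingFireSentenceBreaker.serialize_attr({"model": bf_model_path})
        node = helper.make_node("BlingFireSentenceBreaker", ["text"], ["sentence"],
                                domain=domain, **attrs)
        graph = helper.make_graph(
            [node], "blingfire_sentencebreaker",
            [helper.make_tensor_value_info("text", TensorProto.STRING, [None])],
            [helper.make_tensor_value_info("sentence", TensorProto.STRING, [None])])
        return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 12),
                                                       helper.make_opsetid(domain, 1)])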
@@ -42,6 +42,7 @@ void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context,
  }
}


void FillTensorDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context,
                          const std::vector<ustring>& value, OrtValue* output) {
  std::vector<std::string> utf8_strings;
@@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "blingfire_sentencebreaker.hpp"
#include "string_tensor.h"
#include <vector>
#include <locale>
#include <codecvt>
#include <algorithm>

KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
  model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
  if (model_data_.empty()) {
    throw std::runtime_error("vocabulary shouldn't be empty.");
  }

  void* model_ptr = SetModel(reinterpret_cast<unsigned char*>(model_data_.data()), model_data_.size());

  if (model_ptr == nullptr) {
    throw std::runtime_error("Invalid model");
  }

  model_ = std::shared_ptr<void>(model_ptr, FreeModel);

  if (HasAttribute("max_sentence")) {
    max_sentence = ort_.KernelInfoGetAttribute<int64_t>(info, "max_sentence");
  }
}

void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
  // Setup inputs
  const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
  OrtTensorDimensions dimensions(ort_, input);

  if (dimensions.Size() != 1 && dimensions[0] != 1) {
    throw std::runtime_error("We only support string scalar.");
  }

  std::vector<std::string> input_data;
  GetTensorMutableDataString(api_, ort_, context, input, input_data);

  std::string& input_string = input_data[0];
  int max_length = 2 * input_string.size() + 1;
  std::string output_str;
  output_str.reserve(max_length);

  int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), input_string.size(), output_str.data(), nullptr, nullptr, max_length, model_.get());
  if (output_length < 0) {
    throw std::runtime_error(MakeString("splitting input:\"", input_string, "\" failed"));
  }

  // inline split output_str by newline '\n'
  std::vector<char*> output_sentences;
  bool head_flag = true;
  for (int i = 0; i < output_length; i++) {
    if (head_flag) {
      output_sentences.push_back(&output_str[i]);
      head_flag = false;
    }

    if (output_str[i] == '\n') {
      head_flag = true;
      output_str[i] = '\0';
    }
  }

  std::vector<int64_t> output_dimensions(1);
  output_dimensions[0] = output_sentences.size();

  OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dimensions.data(), output_dimensions.size());
  Ort::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size()));
}

void* CustomOpBlingFireSentenceBreaker::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
  return new KernelBlingFireSentenceBreaker(api, info);
};

const char* CustomOpBlingFireSentenceBreaker::GetName() const { return "BlingFireSentenceBreaker"; };

size_t CustomOpBlingFireSentenceBreaker::GetInputTypeCount() const {
  return 1;
};

ONNXTensorElementDataType CustomOpBlingFireSentenceBreaker::GetInputType(size_t /*index*/) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};

size_t CustomOpBlingFireSentenceBreaker::GetOutputTypeCount() const {
  return 1;
};

ONNXTensorElementDataType CustomOpBlingFireSentenceBreaker::GetOutputType(size_t /*index*/) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "ocos.h"
#include "string_utils.h"

extern "C" const int TextToSentencesWithOffsetsWithModel(
    const char* pInUtf8Str, int InUtf8StrByteCount,
    char* pOutUtf8Str, int* pStartOffsets, int* pEndOffsets,
    const int MaxOutUtf8StrByteCount, void* hModel);

extern "C" int FreeModel(void* ModelPtr);

extern "C" void* SetModel(const unsigned char* pImgBytes, int ModelByteCount);

struct KernelBlingFireSentenceBreaker : BaseKernel {
  KernelBlingFireSentenceBreaker(OrtApi api, const OrtKernelInfo* info);
  void Compute(OrtKernelContext* context);
 private:
  using ModelPtr = std::shared_ptr<void>;
  ModelPtr model_;
  std::string model_data_;
  int max_sentence;
};

struct CustomOpBlingFireSentenceBreaker : Ort::CustomOpBase<CustomOpBlingFireSentenceBreaker, KernelBlingFireSentenceBreaker> {
  void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
  const char* GetName() const;
  size_t GetInputTypeCount() const;
  ONNXTensorElementDataType GetInputType(size_t index) const;
  size_t GetOutputTypeCount() const;
  ONNXTensorElementDataType GetOutputType(size_t index) const;
};
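The three extern "C" declarations in the header above are the whole BlingFire surface the kernel relies on: SetModel loads a model blob, TextToSentencesWithOffsetsWithModel writes the sentences into one output buffer separated by '\n', and FreeModel releases the handle. The following standalone sketch, not part of the commit, shows that lifecycle outside of onnxruntime; SplitSentences is a hypothetical helper, error handling is minimal, and the declarations are simply mirrored from the header:

    #include <string>
    #include <vector>

    extern "C" void* SetModel(const unsigned char* pImgBytes, int ModelByteCount);
    extern "C" const int TextToSentencesWithOffsetsWithModel(
        const char* pInUtf8Str, int InUtf8StrByteCount,
        char* pOutUtf8Str, int* pStartOffsets, int* pEndOffsets,
        const int MaxOutUtf8StrByteCount, void* hModel);
    extern "C" int FreeModel(void* ModelPtr);

    // Splits UTF-8 text into sentences with a BlingFire model blob already read into memory.
    std::vector<std::string> SplitSentences(const std::string& text, const std::string& model_bytes) {
      std::vector<std::string> sentences;
      void* model = SetModel(reinterpret_cast<const unsigned char*>(model_bytes.data()),
                             static_cast<int>(model_bytes.size()));
      if (model == nullptr) {
        return sentences;
      }

      // The library joins sentences with '\n' in the output buffer; size it the
      // same way the kernel above does (2 * input length + 1).
      std::vector<char> out(2 * text.size() + 1);
      int written = TextToSentencesWithOffsetsWithModel(
          text.data(), static_cast<int>(text.size()),
          out.data(), nullptr, nullptr, static_cast<int>(out.size()), model);
      if (written > 0) {
        std::string joined(out.data(), written);
        // Drop any trailing separator/terminator the returned length may include.
        while (!joined.empty() && (joined.back() == '\0' || joined.back() == '\n')) {
          joined.pop_back();
        }
        size_t begin = 0;
        while (begin < joined.size()) {
          size_t end = joined.find('\n', begin);
          if (end == std::string::npos) {
            end = joined.size();
          }
          if (end > begin) {
            sentences.push_back(joined.substr(begin, end - begin));
          }
          begin = end + 1;
        }
      }
      FreeModel(model);
      return sentences;
    }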
@@ -29,6 +29,10 @@
#include "wordpiece_tokenizer.hpp"
#endif

#ifdef ENABLE_BLINGFIRE
#include "blingfire_sentencebreaker.hpp"
#endif

#ifdef ENABLE_SPM_TOKENIZER
CustomOpSentencepieceTokenizer c_CustomOpSentencepieceTokenizer;
#endif

@@ -55,6 +59,7 @@ CustomOpStringUpper c_CustomOpStringUpper;
CustomOpVectorToString c_CustomOpVectorToString;
CustomOpStringLength c_CustomOpStringLength;
CustomOpStringConcat c_CustomOpStringConcat;
CustomOpBlingFireSentenceBreaker c_CustomOpTextToSentences;
#endif

OrtCustomOp* operator_lists[] = {

@@ -84,6 +89,7 @@ OrtCustomOp* operator_lists[] = {
    &c_CustomOpVectorToString,
    &c_CustomOpStringLength,
    &c_CustomOpStringConcat,
    &c_CustomOpTextToSentences,
#endif
    nullptr};
Binary file not shown (presumably the default_sentence_break_model.bin test model referenced by the test below).
@@ -0,0 +1,29 @@
from pathlib import Path
import unittest
import numpy as np
from onnxruntime_extensions.eager_op import EagerOp, BlingFireSentenceBreaker

def _get_test_data_file(*sub_dirs):
    test_dir = Path(__file__).parent
    return str(test_dir.joinpath(*sub_dirs))

def _run_blingfire_sentencebreaker(input, output, model_path):
    t2stc = EagerOp.from_customop(BlingFireSentenceBreaker, model=model_path)
    result = t2stc(input)
    np.testing.assert_array_equal(result, output)


class TestBlingFireSentenceBreaker(unittest.TestCase):

    def test_text_to_case1(self):
        inputs = np.array([
            "This is the Bling-Fire tokenizer. Autophobia, also called monophobia, isolophobia, or eremophobia, is the specific phobia of isolation. 2007年9月日历表_2007年9月农历阳历一览表-万年历. I saw a girl with a telescope. Я увидел девушку с телескопом."])
        outputs = np.array(["This is the Bling-Fire tokenizer.",
                            "Autophobia, also called monophobia, isolophobia, or eremophobia, is the specific phobia of isolation. 2007年9月日历表_2007年9月农历阳历一览表-万年历.",
                            "I saw a girl with a telescope.",
                            "Я увидел девушку с телескопом."])
        _run_blingfire_sentencebreaker(input=inputs, output=outputs, model_path=_get_test_data_file('data', 'default_sentence_break_model.bin'))


if __name__ == "__main__":
    unittest.main()
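The test above drives the operator through the package's EagerOp helper. A hedged sketch of the lower-level path, for when the model is consumed from a plain onnxruntime session: register the compiled custom-op library via get_library_path() (the import added next to default_opset_domain above) and run an ONNX file that contains a BlingFireSentenceBreaker node; the file name below is a placeholder, for example a model saved from the builder sketch shown earlier:

    import numpy as np
    import onnxruntime as ort
    from onnxruntime_extensions import get_library_path

    so = ort.SessionOptions()
    so.register_custom_ops_library(get_library_path())  # load the compiled ortcustomops library
    sess = ort.InferenceSession("sentencebreaker.onnx", so)  # placeholder model path

    text = np.array(["This is the Bling-Fire tokenizer. It also breaks the next sentence."])
    (sentences,) = sess.run(None, {"text": text})
    print(sentences)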