* add dependency

* init

* add test

* implement

* finished

* find file path through pathlib

* rename the op

* update reminder

Co-authored-by: Ze Tao <zetao@microsoft.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
This commit is contained in:
Mojimi 2021-06-24 14:29:16 +08:00 коммит произвёл GitHub
Родитель 800e360ef3
Коммит 3e82549bcb
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
9 изменённых файлов: 219 добавлений и 1 удалений

Просмотреть файл

@ -28,6 +28,7 @@ option(OCOS_ENABLE_TF_STRING "Enable String Operator Set" ON)
# Feature toggles: each option gates one operator family; the matching
# if(OCOS_ENABLE_*) blocks below add its sources and link dependencies.
option(OCOS_ENABLE_GPT2_TOKENIZER "Enable the GPT2 tokenizer building" ON)
option(OCOS_ENABLE_SPM_TOKENIZER "Enable the SentencePiece tokenizer building" ON)
option(OCOS_ENABLE_BERT_TOKENIZER "Enable the BertTokenizer building" ON)
option(OCOS_ENABLE_BLINGFIRE "Enable the Blingfire building" ON)
option(OCOS_ENABLE_MATH "Enable the math tensor operators building" ON)
option(OCOS_ENABLE_STATIC_LIB "Enable generating static library" OFF)
@ -124,6 +125,12 @@ if (OCOS_ENABLE_BERT_TOKENIZER)
list(APPEND TARGET_SRC ${bert_TARGET_SRC})
endif()
if (OCOS_ENABLE_BLINGFIRE)
# Collect the blingfire kernel sources/headers; same GLOB convention the
# other tokenizer blocks in this file use.
file(GLOB blingfire_TARGET_SRC "operators/tokenizer/blingfire*.*")
list(APPEND TARGET_SRC ${blingfire_TARGET_SRC})
endif()
# /utf-8 sets MSVC's source and execution character sets to UTF-8.
# Directory-scoped deliberately so it applies to every target declared below
# — NOTE(review): confirm nothing below needs a different charset.
add_compile_options("$<$<C_COMPILER_ID:MSVC>:/utf-8>")
add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/utf-8>")
add_library(ocos_operators STATIC ${TARGET_SRC})
@ -171,6 +178,12 @@ if (OCOS_ENABLE_SPM_TOKENIZER)
list(APPEND ocos_libraries sentencepiece-static)
endif()
if (OCOS_ENABLE_BLINGFIRE)
# Fetches and builds BlingFire (cmake/externals/blingfire.cmake).
include(blingfire)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_BLINGFIRE)
# NOTE(review): "bingfirtinydll_static" looks misspelled but appears to be the
# upstream BlingFire target name — verify against BlingFire's CMakeLists
# before "correcting" it.
list(APPEND ocos_libraries bingfirtinydll_static)
endif()
if (OCOS_ENABLE_TF_STRING)
target_compile_definitions(ocos_operators PRIVATE
NOMINMAX
@ -190,6 +203,9 @@ if(OCOS_ENABLE_PYTHON)
message(FATAL_ERROR "Python3_FIND_REGISTRY is not NEVER")
endif()
find_package(Python3 COMPONENTS Interpreter Development.Module NumPy)
# Manual check instead of REQUIRED so we fail with one clear message covering
# both the interpreter and the NumPy component.
if (NOT Python3_FOUND)
message(FATAL_ERROR "Python3 or NumPy not found!")
endif()
if (WIN32)
# The .def file fixes the exported-symbol list of the Windows shared library.
list(APPEND shared_TARGET_SRC "${PROJECT_SOURCE_DIR}/onnxruntime_extensions/ortcustomops.def")
endif()

15
cmake/externals/blingfire.cmake поставляемый Normal file
Просмотреть файл

@ -0,0 +1,15 @@
# Fetch and build Microsoft BlingFire from source so the sentence-breaker
# operator can link against its static library.
# Safe to include more than once: FetchContent_Populate is guarded below.
include(FetchContent)

FetchContent_Declare(
  Blingfire
  GIT_REPOSITORY https://github.com/microsoft/BlingFire.git
  # TODO(review): pin to a release tag or commit hash — tracking `master`
  # makes the build non-reproducible.
  GIT_TAG master
)

FetchContent_GetProperties(Blingfire)
if (NOT blingfire_POPULATED)
  FetchContent_Populate(Blingfire)
  # EXCLUDE_FROM_ALL: build only the BlingFire targets we actually link.
  add_subdirectory(${blingfire_SOURCE_DIR} ${blingfire_BINARY_DIR} EXCLUDE_FROM_ALL)
endif ()

Просмотреть файл

@ -5,7 +5,7 @@
import onnx
from onnx import onnx_pb as onnx_proto
from ._ocos import default_opset_domain
from ._ocos import default_opset_domain, get_library_path # noqa
class CustomOp:
@ -88,6 +88,28 @@ class StringToVector(CustomOp):
return attr_data
class BlingFireSentenceBreaker(CustomOp):
    """Schema for the BlingFireSentenceBreaker custom op, which splits an
    input text string into sentences using a BlingFire model blob."""

    @classmethod
    def get_inputs(cls):
        # Single string tensor holding the text to split.
        # Use onnx_proto consistently (get_outputs already does); it is the
        # same protobuf module as onnx.TensorProto.
        return [cls.io_def("text", onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        # One string tensor containing the extracted sentences.
        return [cls.io_def('sentence', onnx_proto.TensorProto.STRING, [])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Serialize op attributes.

        The 'model' attribute arrives as a file path; its binary content is
        inlined so the resulting ONNX graph is self-contained. All other
        attributes pass through unchanged.
        """
        attrs_data = {}
        for k_, v_ in attrs.items():
            if k_ == 'model':
                # Read the BlingFire model file as raw bytes.
                with open(v_, "rb") as model_file:
                    attrs_data[k_] = model_file.read()
            else:
                attrs_data[k_] = v_
        return attrs_data
# TODO: list all custom operators schema here:
# ...
# ...
class SentencepieceTokenizer(CustomOp):
@classmethod
def get_inputs(cls):

Просмотреть файл

@ -42,6 +42,7 @@ void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKer
}
}
void FillTensorDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context,
const std::vector<ustring>& value, OrtValue* output) {
std::vector<std::string> utf8_strings;

Просмотреть файл

@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "blingfire_sentencebreaker.hpp"
#include "string_tensor.h"
#include <vector>
#include <locale>
#include <codecvt>
#include <algorithm>
// Loads the BlingFire model from the required "model" attribute (raw bytes)
// and reads the optional "max_sentence" attribute (-1 when absent).
// Throws std::runtime_error if the model attribute is empty or invalid.
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
  model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
  if (model_data_.empty()) {
    // Fixed message: the attribute holds the model blob, not a vocabulary.
    throw std::runtime_error("model shouldn't be empty.");
  }

  void* model_ptr = SetModel(reinterpret_cast<unsigned char*>(model_data_.data()), model_data_.size());
  if (model_ptr == nullptr) {
    throw std::runtime_error("Invalid model");
  }

  // shared_ptr with FreeModel as deleter ties the handle's lifetime to the
  // kernel; model_data_ must outlive it in case the handle aliases the bytes.
  model_ = std::shared_ptr<void>(model_ptr, FreeModel);

  if (HasAttribute("max_sentence")) {
    // NOTE(review): narrowing int64_t -> int; fine for realistic caps.
    max_sentence = ort_.KernelInfoGetAttribute<int64_t>(info, "max_sentence");
  }
}
// Splits the single input string into sentences and writes them as a 1-D
// string tensor. Throws std::runtime_error on unsupported shapes or when the
// BlingFire call fails.
void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
  // Setup inputs
  const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
  OrtTensorDimensions dimensions(ort_, input);

  // NOTE(review): this rejects only tensors where BOTH the element count and
  // the leading dimension differ from 1 — confirm that is the intended
  // "string scalar" check (a [1, N] tensor slips through).
  if (dimensions.Size() != 1 && dimensions[0] != 1) {
    throw std::runtime_error("We only support string scalar.");
  }

  std::vector<std::string> input_data;
  GetTensorMutableDataString(api_, ort_, context, input, input_data);
  std::string& input_string = input_data[0];

  // Worst case: every input byte becomes its own one-byte sentence followed
  // by a '\n' separator, plus one terminating byte.
  int max_length = static_cast<int>(2 * input_string.size() + 1);
  std::string output_str;
  // BUGFIX: was reserve(); the C API writes through data(), so the bytes must
  // be within the string's size — writing into reserved-but-unsized storage
  // is undefined behavior.
  output_str.resize(max_length);

  int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), input_string.size(), output_str.data(), nullptr, nullptr, max_length, model_.get());
  if (output_length < 0) {
    throw std::runtime_error(MakeString("splitting input:\"", input_string, "\" failed"));
  }

  // Split output_str in place on '\n': each newline becomes '\0' so the
  // recorded sentence starts are NUL-terminated C strings.
  std::vector<char*> output_sentences;
  bool head_flag = true;
  for (int i = 0; i < output_length; i++) {
    if (head_flag) {
      output_sentences.push_back(&output_str[i]);
      head_flag = false;
    }

    if (output_str[i] == '\n') {
      head_flag = true;
      output_str[i] = '\0';
    }
  }

  std::vector<int64_t> output_dimensions(1);
  output_dimensions[0] = static_cast<int64_t>(output_sentences.size());
  OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dimensions.data(), output_dimensions.size());
  Ort::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size()));
}
// Factory: ORT calls this to create one kernel instance per session.
// (Stray semicolons after the function bodies removed — they were empty
// declarations.)
void* CustomOpBlingFireSentenceBreaker::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
  return new KernelBlingFireSentenceBreaker(api, info);
}

// Operator name under which the op is registered.
const char* CustomOpBlingFireSentenceBreaker::GetName() const { return "BlingFireSentenceBreaker"; }

// One input: the string tensor holding the text to split.
size_t CustomOpBlingFireSentenceBreaker::GetInputTypeCount() const {
  return 1;
}

ONNXTensorElementDataType CustomOpBlingFireSentenceBreaker::GetInputType(size_t /*index*/) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
}

// One output: the string tensor of extracted sentences.
size_t CustomOpBlingFireSentenceBreaker::GetOutputTypeCount() const {
  return 1;
}

ONNXTensorElementDataType CustomOpBlingFireSentenceBreaker::GetOutputType(size_t /*index*/) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
}

Просмотреть файл

@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "ocos.h"
#include "string_utils.h"

// C API exported by the BlingFire library. Writes sentence-separated UTF-8
// text into pOutUtf8Str; the .cpp treats a negative return as failure and a
// non-negative one as the number of output bytes.
extern "C" const int TextToSentencesWithOffsetsWithModel(
    const char* pInUtf8Str, int InUtf8StrByteCount,
    char* pOutUtf8Str, int* pStartOffsets, int* pEndOffsets,
    const int MaxOutUtf8StrByteCount, void* hModel);

// Releases a model handle obtained from SetModel (used as the shared_ptr
// deleter in the kernel).
extern "C" int FreeModel(void* ModelPtr);

// Creates a BlingFire model handle from an in-memory blob; returns nullptr
// on failure.
extern "C" void* SetModel(const unsigned char* pImgBytes, int ModelByteCount);

// ORT custom-op kernel wrapping BlingFire sentence breaking.
struct KernelBlingFireSentenceBreaker : BaseKernel {
  KernelBlingFireSentenceBreaker(OrtApi api, const OrtKernelInfo* info);
  void Compute(OrtKernelContext* context);

 private:
  using ModelPtr = std::shared_ptr<void>;
  ModelPtr model_;          // opaque BlingFire handle, freed via FreeModel
  std::string model_data_;  // raw model bytes read from the "model" attribute
  int max_sentence;         // from "max_sentence" attribute; -1 when absent
};

// Schema/registration descriptor: one string input, one string output.
struct CustomOpBlingFireSentenceBreaker : Ort::CustomOpBase<CustomOpBlingFireSentenceBreaker, KernelBlingFireSentenceBreaker> {
  void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
  const char* GetName() const;
  size_t GetInputTypeCount() const;
  ONNXTensorElementDataType GetInputType(size_t index) const;
  size_t GetOutputTypeCount() const;
  ONNXTensorElementDataType GetOutputType(size_t index) const;
};

Просмотреть файл

@ -29,6 +29,10 @@
#include "wordpiece_tokenizer.hpp"
#endif
#ifdef ENABLE_BLINGFIRE
#include "blingfire_sentencebreaker.hpp"
#endif
#ifdef ENABLE_SPM_TOKENIZER
CustomOpSentencepieceTokenizer c_CustomOpSentencepieceTokenizer;
#endif
@ -55,6 +59,7 @@ CustomOpStringUpper c_CustomOpStringUpper;
CustomOpVectorToString c_CustomOpVectorToString;
CustomOpStringLength c_CustomOpStringLength;
CustomOpStringConcat c_CustomOpStringConcat;
CustomOpBlingFireSentenceBreaker c_CustomOpTextToSentences;
#endif
OrtCustomOp* operator_lists[] = {
@ -84,6 +89,7 @@ OrtCustomOp* operator_lists[] = {
&c_CustomOpVectorToString,
&c_CustomOpStringLength,
&c_CustomOpStringConcat,
&c_CustomOpTextToSentences,
#endif
nullptr};

Двоичные данные
test/data/default_sentence_break_model.bin Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -0,0 +1,29 @@
from pathlib import Path
import unittest
import numpy as np
from onnxruntime_extensions.eager_op import EagerOp, BlingFireSentenceBreaker
def _get_test_data_file(*sub_dirs):
test_dir = Path(__file__).parent
return str(test_dir.joinpath(*sub_dirs))
def _run_blingfire_sentencebreaker(input, output, model_path):
    # Instantiate the eager wrapper with the given model, run it on the
    # input, and assert the result matches the expected sentence array.
    breaker = EagerOp.from_customop(BlingFireSentenceBreaker, model=model_path)
    actual = breaker(input)
    np.testing.assert_array_equal(actual, output)
class TestBlingFireSentenceBreaker(unittest.TestCase):
    # End-to-end check of the BlingFireSentenceBreaker custom op using the
    # default model under test/data.

    def test_text_to_case1(self):
        # Mixed-language paragraph (English, Chinese, Russian) exercising
        # multilingual sentence boundaries.
        inputs = np.array([
            "This is the Bling-Fire tokenizer. Autophobia, also called monophobia, isolophobia, or eremophobia, is the specific phobia of isolation. 2007年9月日历表_2007年9月农历阳历一览表-万年历. I saw a girl with a telescope. Я увидел девушку с телескопом."])
        # NOTE(review): the second expected element keeps the phobia sentence
        # and the Chinese calendar fragment together — this encodes the
        # model's actual splitting behavior, presumably intentional.
        outputs = np.array(["This is the Bling-Fire tokenizer.",
                            "Autophobia, also called monophobia, isolophobia, or eremophobia, is the specific phobia of isolation. 2007年9月日历表_2007年9月农历阳历一览表-万年历.",
                            "I saw a girl with a telescope.",
                            "Я увидел девушку с телескопом."])
        _run_blingfire_sentencebreaker(input=inputs, output=outputs, model_path=_get_test_data_file('data', 'default_sentence_break_model.bin'))

if __name__ == "__main__":
    unittest.main()