Add BertTokenizer (#135)
* init
* update
* update
* update
* update
* update
* update
* Modify relative path of generated cmake file.
* update
* update
* fix the bug
* update
* fix bugs

Co-authored-by: Ze Tao <zetao@microsoft.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
Co-authored-by: Zuwei Zhao <zuzhao@microsoft.com>
This commit is contained in:
Parent: 6c3b496e3f
Commit: aef5ef1ef1
@@ -159,7 +159,7 @@ endif()

if (OCOS_ENABLE_BERT_TOKENIZER)
  # Bert
  set(_HAS_TOKENIZER ON)
  file(GLOB bert_TARGET_SRC "operators/tokenizer/wordpiece*.*")
  file(GLOB bert_TARGET_SRC "operators/tokenizer/wordpiece*.*" "operators/tokenizer/basic_tokenizer.*" "operators/tokenizer/bert_tokenizer.*")
  list(APPEND TARGET_SRC ${bert_TARGET_SRC})
endif()
@@ -22,9 +22,19 @@ struct BaseKernel {
  BaseKernel(OrtApi api, const OrtKernelInfo* info) : api_(api), info_(info), ort_(api_) {}

  bool HasAttribute(const char* name) const;

  template <class T>
  bool TryToGetAttribute(const char* name, T& value);

  template <class T>
  T TryToGetAttributeWithDefault(const char* name, T default_value) {
    T& result = default_value;
    TryToGetAttribute(name, result);
    return result;
  }

  void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim, const std::vector<int64_t>& data);

 protected:
  OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status);
  OrtApi api_;  // keep a copy of the OrtApi struct; ort_ holds a reference to it
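For orientation, a minimal sketch (hypothetical kernel, not part of this commit) of how a derived kernel would consume the new attribute helpers; the attribute names and defaults below are illustrative only:

// sketch only: assumes ocos.h is included so BaseKernel is available
struct KernelExample : BaseKernel {
  KernelExample(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
    // a missing attribute falls back to the supplied default instead of throwing
    lower_case_ = TryToGetAttributeWithDefault("do_lower_case", true);
    // optional attributes can also be probed explicitly before reading them
    if (HasAttribute("suffix_indicator")) {
      TryToGetAttribute("suffix_indicator", suffix_);
    }
  }

 private:
  bool lower_case_ = true;
  std::string suffix_ = "##";
};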
@@ -81,7 +81,7 @@ class StringToVector(CustomOp):
        for k_, v_ in attrs.items():
            if k_ == 'map' and isinstance(v_, dict):
                attr_data[k_] = '\n'.join(k + "\t" + " ".join([str(i) for i in v]) for k, v in v_.items())
            elif k_ == 'unk' and isinstance(v_, list):
            elif k_ == 'unk' and isinstance(v_, list):
                attr_data[k_] = ' '.join(str(i) for i in v_)
            else:
                attr_data[k_] = v_
@@ -109,6 +109,30 @@ class BlingFireSentenceBreaker(CustomOp):
        return attrs_data


class BertTokenizer(CustomOp):
    @classmethod
    def get_inputs(cls):
        return [cls.io_def("text", onnx.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('input_ids', onnx_proto.TensorProto.INT64, [None]),
                cls.io_def('token_type_ids', onnx_proto.TensorProto.INT64, [None]),
                cls.io_def('attention_mask', onnx_proto.TensorProto.INT64, [None])]

    @classmethod
    def serialize_attr(cls, attrs):
        attrs_data = {}
        for k_, v_ in attrs.items():
            if k_ == 'vocab_file':
                with open(v_, "r", encoding='utf-8') as model_file:
                    lines = model_file.readlines()
                attrs_data[k_] = '\n'.join(lines)
            else:
                attrs_data[k_] = v_
        return attrs_data


class SentencepieceTokenizer(CustomOp):
    @classmethod
    def get_inputs(cls):
@@ -3,7 +3,6 @@
#include <sstream>
#include "ocos.h"


bool BaseKernel::HasAttribute(const char* name) const {
  if (info_ == nullptr) {
    ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
@@ -36,6 +35,14 @@ OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) {
  return error_code;
}

void BaseKernel::SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim, const std::vector<int64_t>& data) {
  OrtValue* output = ort_.KernelContext_GetOutput(ctx, output_idx, dim.data(), dim.size());
  int64_t* data_ptr = ort_.GetTensorMutableData<int64_t>(output);
  for (int i = 0; i < data.size(); i++) {
    data_ptr[i] = data[i];
  }
}

template <>
bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
  if (info_ == nullptr) {
@@ -76,4 +83,19 @@ bool BaseKernel::TryToGetAttribute(const char* name, float& value) {
  }

  return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(info_, name, &value)) == ORT_OK;
}
}

template <>
bool BaseKernel::TryToGetAttribute(const char* name, bool& value) {
  if (info_ == nullptr) {
    ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
  }

  int64_t origin_value = 0;
  if (GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &origin_value)) != ORT_OK) {
    return false;
  }

  value = origin_value == 1;
  return true;
}
|
@ -8,6 +8,7 @@ std::vector<std::string_view> SplitString(const std::string_view& str, const std
|
|||
std::vector<std::string_view> result;
|
||||
std::string ::size_type pre_pos = 0;
|
||||
|
||||
//TODO: bug fix
|
||||
while (true) {
|
||||
auto next_pos = str.find_first_of(seps, pre_pos);
|
||||
|
||||
|
@@ -32,6 +33,35 @@ std::vector<std::string_view> SplitString(const std::string_view& str, const std
  return result;
}

bool IsCJK(char32_t c) {
  return (c >= 0x4E00 && c <= 0x9FFF)
         || (c >= 0x3400 && c <= 0x4DBF)
         || (c >= 0x20000 && c <= 0x2A6DF)
         || (c >= 0x2A700 && c <= 0x2B73F)
         || (c >= 0x2B740 && c <= 0x2B81F)
         || (c >= 0x2B820 && c <= 0x2CEAF)
         || (c >= 0xF900 && c <= 0xFAFF)
         || (c >= 0x2F800 && c <= 0x2FA1F);
}

bool IsAccent(char32_t c) {
  // only a subset of combining accents is supported for now
  // TODO: support more accents
  return c >= 0x300 && c <= 0x36F;
}

char32_t StripAccent(char32_t c) {
  // "ÀÁÂÃÄÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖ×ØÙÚÛÜÝÞßàáâãäåæçèéêëìíîïðñòóôõö÷øùúûüýþÿ"
  const char32_t* tr = U"AAAAAAÆCEEEEIIIIÐNOOOOO×ØUUUUYÞßaaaaaaæceeeeiiiiðnooooo÷øuuuuyþy";
  if (c < 192 || c > 255) {
    return c;
  }

  return tr[c - 192];
}

#ifdef ENABLE_TF_STRING
// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/hash.cc#L28
static inline uint64_t ByteAs64(char c) { return static_cast<uint64_t>(c) & 0xff; }

@@ -100,4 +130,4 @@ uint64_t Hash64Fast(const char* data, size_t n) {
  return static_cast<int64_t>(util::Fingerprint64(data, n));
}

#endif  // ENABLE_TF_STRING
#endif  // ENABLE_TF_STRING
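The following is a hedged, illustrative snippet (not part of the commit) showing the intended behaviour of the new character classification helpers:

#include <cassert>
#include "string_utils.h"

int main() {
  assert(StripAccent(U'É') == U'E');  // 0xC9 falls inside the 192..255 lookup range
  assert(StripAccent(U'z') == U'z');  // code points outside that range pass through unchanged
  assert(IsCJK(U'易'));               // U+6613 lies in the CJK Unified Ideographs block
  assert(IsAccent(0x0301));           // U+0301, combining acute accent
  return 0;
}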
@@ -49,8 +49,11 @@ std::string MakeString(const Args&... args) {

std::vector<std::string_view> SplitString(const std::string_view& str, const std::string_view& seps, bool remove_empty_entries = false);

void char2unicode(const std::string& src, std::vector<uint32_t>& result);
void unicode2char(const std::vector<uint32_t>& src, std::string& result);

bool IsCJK(char32_t c);

bool IsAccent(char32_t c);

char32_t StripAccent(char32_t c);

uint64_t Hash64(const char* data, size_t n, uint64_t seed);
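As a small illustration (assumed example, not from the commit), SplitString with remove_empty_entries set is what later turns the newline-joined vocab attribute into individual tokens:

#include <string>
#include "string_utils.h"

void split_demo() {
  std::string vocab = "[UNK]\n[CLS]\n\n[SEP]";
  std::vector<std::string_view> tokens = SplitString(vocab, "\n", true);
  // tokens == {"[UNK]", "[CLS]", "[SEP]"}; the empty entry is dropped
}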
@@ -0,0 +1,128 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "string_utils.h"
#include "basic_tokenizer.hpp"
#include "string_tensor.h"
#include <vector>
#include <locale>
#include <codecvt>
#include <algorithm>

BasicTokenizer::BasicTokenizer(bool do_lower_case, bool tokenize_chinese_chars, bool strip_accents, bool tokenize_punctuation, bool remove_control_chars) :
    do_lower_case_(do_lower_case), tokenize_chinese_chars_(tokenize_chinese_chars), strip_accents_(strip_accents), tokenize_punctuation_(tokenize_punctuation),
    remove_control_chars_(remove_control_chars) {}

std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
  std::vector<ustring> result;
  ustring token;
  auto push_current_token_and_clear = [&result, &token]() {
    if (!token.empty()) {
      result.push_back(token);
      token.clear();
    }
  };

  auto push_single_char_and_clear = [&result, &token](char32_t c) {
    token.push_back(c);
    result.push_back(token);
    token.clear();
  };

  // strip accents first
  if (strip_accents_) {
    for (auto& c : text) {
      c = StripAccent(c);
    }
  }

  if (do_lower_case_) {
    for (auto& c : text) {
      c = ::tolower(c);
    }
  }

  for (auto c : text) {
    if (tokenize_chinese_chars_ && IsCJK(c)) {
      push_current_token_and_clear();
      push_single_char_and_clear(c);
      continue;
    }

    if (strip_accents_ && IsAccent(c)) {
      continue;
    }

    if (tokenize_punctuation_ && ::ispunct(c)) {
      push_current_token_and_clear();
      push_single_char_and_clear(c);
      continue;
    }

    // split on whitespace
    if (::isspace(c)) {
      push_current_token_and_clear();
      continue;
    }

    // iscntrl treats \t, \f, \n and \r as control characters,
    // but those have already been consumed by the isspace(c) check above
    if (remove_control_chars_ && ::iscntrl(c)) {
      continue;
    }

    token.push_back(c);
  }

  push_current_token_and_clear();
  return result;
}

KernelBasicTokenizer::KernelBasicTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
  bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
  bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
  bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
  bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false);
  bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true);

  tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, tokenize_punctuation, remove_control_chars);
}

void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
  // Setup inputs
  const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
  std::vector<std::string> input_data;
  GetTensorMutableDataString(api_, ort_, context, input, input_data);

  OrtTensorDimensions dimensions(ort_, input);
  if (dimensions.size() != 1 && dimensions[0] != 1) {
    ORT_CXX_API_THROW("[BasicTokenizer]: only supports string scalar input.", ORT_INVALID_GRAPH);
  }

  OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
  std::vector<ustring> result = tokenizer_->Tokenize(ustring(input_data[0]));

  FillTensorDataString(api_, ort_, context, result, output);
}

void* CustomOpBasicTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
  return new KernelBasicTokenizer(api, info);
};

const char* CustomOpBasicTokenizer::GetName() const { return "BasicTokenizer"; };

size_t CustomOpBasicTokenizer::GetInputTypeCount() const {
  return 1;
};

ONNXTensorElementDataType CustomOpBasicTokenizer::GetInputType(size_t /*index*/) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};

size_t CustomOpBasicTokenizer::GetOutputTypeCount() const {
  return 1;
};

ONNXTensorElementDataType CustomOpBasicTokenizer::GetOutputType(size_t /*index*/) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
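A short usage sketch of the class above (illustrative input and flags, not part of the commit):

#include "basic_tokenizer.hpp"

void basic_tokenizer_demo() {
  // lower-casing, accent stripping, CJK/punctuation splitting and control-char removal all enabled
  BasicTokenizer tokenizer(true, true, true, true, true);
  std::vector<ustring> tokens = tokenizer.Tokenize(ustring("Héllo, 网易!"));
  // expected tokens: "hello", ",", "网", "易", "!"
}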
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "ocos.h"
#include "string_utils.h"
#include "ustring.h"

class BasicTokenizer {
 public:
  BasicTokenizer(bool do_lower_case, bool tokenize_chinese_chars, bool strip_accents, bool tokenize_punctuation, bool remove_control_chars);
  std::vector<ustring> Tokenize(ustring text);

 private:
  bool do_lower_case_;
  bool strip_accents_;
  bool tokenize_chinese_chars_;
  bool tokenize_punctuation_;
  bool remove_control_chars_;
};

struct KernelBasicTokenizer : BaseKernel {
  KernelBasicTokenizer(OrtApi api, const OrtKernelInfo* info);
  void Compute(OrtKernelContext* context);

 private:
  std::shared_ptr<BasicTokenizer> tokenizer_;
};

struct CustomOpBasicTokenizer : Ort::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
  void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
  const char* GetName() const;
  size_t GetInputTypeCount() const;
  ONNXTensorElementDataType GetInputType(size_t index) const;
  size_t GetOutputTypeCount() const;
  ONNXTensorElementDataType GetOutputType(size_t index) const;
};
@@ -0,0 +1,237 @@
#include "bert_tokenizer.hpp"

#include <utility>

WordpieceTokenizer::WordpieceTokenizer(std::shared_ptr<std::unordered_map<ustring, int32_t>> vocab, ustring unk_token,
                                       ustring suffix_indicator, int max_input_chars_per_word) : vocab_(std::move(vocab)), unk_token_(unk_token),
                                       suffix_indicator_(std::move(suffix_indicator)), max_input_chars_per_word_(max_input_chars_per_word) {
  auto it = vocab_->find(unk_token);
  if (it == vocab_->end()) {
    ORT_CXX_API_THROW("[WordpieceTokenizer]: cannot find unk_token in vocab", ORT_RUNTIME_EXCEPTION);
  }
  unk_token_id_ = it->second;
}

std::vector<ustring> WordpieceTokenizer::Tokenize(const ustring& text) {
  std::vector<ustring> result;
  ustring token;
  for (auto c : text) {
    if (c == U' ' && !token.empty()) {
      GreedySearch(token, result);
      token.clear();
      continue;
    }

    token.push_back(c);
  }

  if (!token.empty()) {
    GreedySearch(token, result);
  }

  return result;
}

std::vector<ustring> WordpieceTokenizer::Tokenize(const std::vector<ustring>& tokens) {
  std::vector<ustring> result;
  for (const auto& token : tokens) {
    GreedySearch(token, result);
  }

  return result;
}

std::vector<int64_t> WordpieceTokenizer::Encode(const std::vector<ustring>& tokens) {
  std::vector<int64_t> ids;
  for (const auto& token : tokens) {
    auto it = vocab_->find(token);
    if (it == vocab_->end()) {
      ids.push_back(unk_token_id_);
      continue;
    }

    ids.push_back(it->second);
  }
  return ids;
}
void WordpieceTokenizer::GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result) {
  if (token.size() > max_input_chars_per_word_) {
    tokenized_result.push_back(unk_token_);
    return;
  }

  int start = 0;
  int end = -1;
  ustring substr;
  for (; start < token.size();) {
    end = token.size();
    bool is_found = false;
    // look for the longest matching sub-token in the vocab
    for (; start < end;) {
      substr = static_cast<const ustring>(token.substr(start, end - start));
      if (start > 0) {
        substr = static_cast<const ustring>(suffix_indicator_ + substr);
      }
      auto it = vocab_->find(substr);
      if (it != vocab_->end()) {
        is_found = true;
        break;
      }
      end -= 1;
    }
    // no sub-token of the remaining text is in the vocab
    if (!is_found) {
      tokenized_result.push_back(unk_token_);
      break;
    }

    tokenized_result.push_back(substr);
    start = end;
  }
}
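To make the greedy longest-match behaviour concrete, here is a hedged sketch with a toy vocabulary (illustrative only, not from the commit):

#include <memory>
#include <unordered_map>
#include "bert_tokenizer.hpp"

void wordpiece_demo() {
  auto vocab = std::make_shared<std::unordered_map<ustring, int32_t>>(
      std::unordered_map<ustring, int32_t>{{ustring("un"), 0},
                                           {ustring("##want"), 1},
                                           {ustring("##ed"), 2},
                                           {ustring("[UNK]"), 3}});
  WordpieceTokenizer tokenizer(vocab, ustring("[UNK]"), ustring("##"));
  std::vector<ustring> pieces = tokenizer.Tokenize(ustring("unwanted"));
  // GreedySearch picks "un", then "##want", then "##ed"
  std::vector<int64_t> ids = tokenizer.Encode(pieces);  // {0, 1, 2}
}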
BertTokenizer::BertTokenizer(std::string vocab, bool do_lower_case, bool do_basic_tokenize, ustring unk_token, ustring sep_token,
                             ustring pad_token, ustring cls_token, ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
                             ustring suffix_indicator) : do_basic_tokenize_(do_basic_tokenize) {
  auto tokens = SplitString(vocab, "\n", true);

  vocab_ = std::make_shared<std::unordered_map<ustring, int32_t>>();
  for (int i = 0; i < tokens.size(); i++) {
    (*vocab_)[ustring(tokens[i])] = i;
  }

  if (do_basic_tokenize) {
    basic_tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, true, true);
  }
  wordpiece_tokenizer_ = std::make_shared<WordpieceTokenizer>(vocab_, unk_token, suffix_indicator);

  unk_token_id_ = FindSpecialToken(unk_token);
  sep_token_id_ = FindSpecialToken(sep_token);
  pad_token_id_ = FindSpecialToken(pad_token);
  cls_token_id_ = FindSpecialToken(cls_token);
  mask_token_id_ = FindSpecialToken(mask_token);
}

std::vector<ustring> BertTokenizer::Tokenize(const ustring& text) {
  if (do_basic_tokenize_) {
    return wordpiece_tokenizer_->Tokenize(basic_tokenizer_->Tokenize(text));
  }
  return wordpiece_tokenizer_->Tokenize(text);
}

std::vector<int64_t> BertTokenizer::Encode(const std::vector<ustring>& tokens) {
  return wordpiece_tokenizer_->Encode(tokens);
}

std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids) {
  std::vector<int64_t> result;
  result.reserve(ids.size() + 2);
  result.push_back(cls_token_id_);
  result.insert(result.end(), ids.begin(), ids.end());
  result.push_back(sep_token_id_);
  return result;
}

std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
  std::vector<int64_t> result;
  result.reserve(ids1.size() + ids2.size() + 3);
  result.push_back(cls_token_id_);
  result.insert(result.end(), ids1.begin(), ids1.end());
  result.push_back(sep_token_id_);
  result.insert(result.end(), ids2.begin(), ids2.end());
  result.push_back(sep_token_id_);
  return result;
}

std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids) {
  return std::vector<int64_t>(ids.size() + 2, 0);
}

std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
  std::vector<int64_t> result;
  result.reserve(ids1.size() + ids2.size() + 3);
  result.insert(result.end(), ids1.size() + 2, 0);
  result.insert(result.end(), ids2.size() + 1, 1);
  return result;
}

int32_t BertTokenizer::FindSpecialToken(ustring token) {
  auto it = vocab_->find(token);
  if (it == vocab_->end()) {
    ORT_CXX_API_THROW("[BertTokenizer]: cannot find special token: " + std::string(token), ORT_RUNTIME_EXCEPTION);
  }
  return it->second;
}
KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
  std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
  bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
  bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
  std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
  std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
  std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
  std::string cls_token = TryToGetAttributeWithDefault("cls_token", std::string("[CLS]"));
  std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
  bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
  bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
  std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));

  tokenizer_ = std::make_shared<BertTokenizer>(vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
                                               ustring(sep_token), ustring(pad_token), ustring(cls_token),
                                               ustring(mask_token), tokenize_chinese_chars, strip_accents, ustring(suffix_indicator));
}

void KernelBertTokenizer::Compute(OrtKernelContext* context) {
  // Setup inputs
  const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
  std::vector<std::string> input_data;
  GetTensorMutableDataString(api_, ort_, context, input, input_data);

  if (input_data.size() != 1 && input_data.size() != 2) {
    ORT_CXX_API_THROW("[BertTokenizer]: only supports one or two queries.", ORT_INVALID_GRAPH);
  }
  std::vector<int64_t> input_ids;
  std::vector<int64_t> token_type_ids;

  if (input_data.size() == 1) {
    std::vector<int64_t> encode = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[0])));
    input_ids = tokenizer_->AddSpecialToken(encode);
    token_type_ids = tokenizer_->GenerateTypeId(encode);
  } else {
    std::vector<int64_t> encode1 = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[0])));
    std::vector<int64_t> encode2 = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[1])));
    input_ids = tokenizer_->AddSpecialToken(encode1, encode2);
    token_type_ids = tokenizer_->GenerateTypeId(encode1, encode2);
  }

  std::vector<int64_t> attention_mask(input_ids.size(), 1);

  std::vector<int64_t> output_dim({static_cast<int64_t>(input_ids.size())});

  SetOutput(context, 0, output_dim, input_ids);
  SetOutput(context, 1, output_dim, token_type_ids);
  SetOutput(context, 2, output_dim, attention_mask);
}

void* CustomOpBertTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
  return new KernelBertTokenizer(api, info);
};

const char* CustomOpBertTokenizer::GetName() const { return "BertTokenizer"; };

size_t CustomOpBertTokenizer::GetInputTypeCount() const {
  return 1;
};

ONNXTensorElementDataType CustomOpBertTokenizer::GetInputType(size_t /*index*/) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};

size_t CustomOpBertTokenizer::GetOutputTypeCount() const {
  return 3;
};

ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /*index*/) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
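End to end, the pieces above compose as follows; a hedged sketch with a tiny inline vocabulary (illustrative only, not from the commit):

#include "bert_tokenizer.hpp"

void bert_tokenizer_demo() {
  // vocab ids: [UNK]=0, [SEP]=1, [PAD]=2, [CLS]=3, [MASK]=4, hello=5, world=6
  std::string vocab = "[UNK]\n[SEP]\n[PAD]\n[CLS]\n[MASK]\nhello\nworld";
  BertTokenizer tokenizer(vocab, true, true, ustring("[UNK]"), ustring("[SEP]"),
                          ustring("[PAD]"), ustring("[CLS]"), ustring("[MASK]"),
                          true, true, ustring("##"));
  std::vector<int64_t> ids = tokenizer.Encode(tokenizer.Tokenize(ustring("hello world")));  // {5, 6}
  std::vector<int64_t> input_ids = tokenizer.AddSpecialToken(ids);  // {3, 5, 6, 1}, i.e. [CLS] hello world [SEP]
  std::vector<int64_t> type_ids = tokenizer.GenerateTypeId(ids);    // {0, 0, 0, 0}
  // KernelBertTokenizer::Compute additionally emits an attention_mask of ones with the same length
}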
@@ -0,0 +1,71 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <unordered_map>
#include <vector>
#include "ocos.h"
#include "ustring.h"
#include "string_utils.h"
#include "string_tensor.h"
#include "basic_tokenizer.hpp"

// TODO: merge with the implementation of word piece tokenizer
class WordpieceTokenizer {
 public:
  WordpieceTokenizer(std::shared_ptr<std::unordered_map<ustring, int32_t>> vocab, ustring unk_token, ustring suffix_indicator, int max_input_chars_per_word = 100);
  std::vector<ustring> Tokenize(const ustring& text);
  std::vector<ustring> Tokenize(const std::vector<ustring>& tokens);
  std::vector<int64_t> Encode(const std::vector<ustring>& tokens);

 private:
  int64_t max_input_chars_per_word_;
  ustring suffix_indicator_;
  ustring unk_token_;
  int64_t unk_token_id_;
  std::shared_ptr<std::unordered_map<ustring, int32_t>> vocab_;

  void GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result);
};

class BertTokenizer {
 public:
  BertTokenizer(std::string vocab, bool do_lower_case, bool do_basic_tokenize,
                ustring unk_token, ustring sep_token, ustring pad_token, ustring cls_token,
                ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
                ustring suffix_indicator);
  std::vector<ustring> Tokenize(const ustring& text);
  std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
  std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids);
  std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);
  std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids);
  std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);

 private:
  int32_t unk_token_id_;
  int32_t sep_token_id_;
  int32_t pad_token_id_;
  int32_t cls_token_id_;
  int32_t mask_token_id_;
  bool do_basic_tokenize_;
  std::shared_ptr<std::unordered_map<ustring, int32_t>> vocab_;
  std::shared_ptr<BasicTokenizer> basic_tokenizer_;
  std::shared_ptr<WordpieceTokenizer> wordpiece_tokenizer_;

  int32_t FindSpecialToken(ustring token);
};

struct KernelBertTokenizer : BaseKernel {
  KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info);
  void Compute(OrtKernelContext* context);

 private:
  std::shared_ptr<BertTokenizer> tokenizer_;
};

struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
  void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
  const char* GetName() const;
  size_t GetInputTypeCount() const;
  ONNXTensorElementDataType GetInputType(size_t index) const;
  size_t GetOutputTypeCount() const;
  ONNXTensorElementDataType GetOutputType(size_t index) const;
};
@@ -200,4 +200,4 @@ ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetOutputType(size_t index
    default:
      throw std::runtime_error(MakeString("[WordpieceTokenizer] Unexpected output index ", index));
  }
};
};
@@ -43,4 +43,4 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string,
                                        std::vector<int64_t>& rows,
                                        const int64_t* existing_rows = nullptr,
                                        int64_t n_existing_rows = 0,
                                        int64_t max_input_chars_per_word = 200);
                                        int64_t max_input_chars_per_word = 200);
@@ -3,7 +3,7 @@
#include <iostream>
#include "ustring.h"

ustring::ustring(): std::u32string() {
ustring::ustring() : std::u32string() {
}

ustring::ustring(char* str) {

@@ -26,17 +26,27 @@ ustring::ustring(const std::string& str) {
  assign(str_cvt.from_bytes(str));
}

ustring::ustring(char32_t* str):std::u32string(str) {}
ustring::ustring(char32_t* str) : std::u32string(str) {}

ustring::ustring(const char32_t* str):std::u32string(str) {}
ustring::ustring(const char32_t* str) : std::u32string(str) {}

ustring::ustring(std::u32string& str):std::u32string(str) {}
ustring::ustring(std::u32string& str) : std::u32string(str) {}

ustring::ustring(std::u32string&& str):std::u32string(str) {}
ustring::ustring(std::u32string&& str) : std::u32string(str) {}

ustring::ustring(const std::u32string& str):std::u32string(str) {}
ustring::ustring(const std::u32string& str) : std::u32string(str) {}

ustring::ustring(const std::u32string&& str):std::u32string(str) {}
ustring::ustring(const std::u32string&& str) : std::u32string(str) {}

ustring::ustring(std::string_view& str) {
  utf8_converter str_cvt;
  assign(str_cvt.from_bytes(str.data(), str.data() + str.size()));
}

ustring::ustring(const std::string_view& str) {
  utf8_converter str_cvt;
  assign(str_cvt.from_bytes(str.data(), str.data() + str.size()));
}

ustring::ustring(std::u32string_view& str):std::u32string(str) {}

@@ -55,3 +65,4 @@ ustring::operator std::string() const {
  utf8_converter str_cvt;
  return str_cvt.to_bytes(*this);
}
@@ -20,6 +20,8 @@ class ustring : public std::u32string {
  explicit ustring(std::u32string&& str);
  explicit ustring(const std::u32string& str);
  explicit ustring(const std::u32string&& str);
  explicit ustring(std::string_view& str);
  explicit ustring(const std::string_view& str);
  explicit ustring(std::u32string_view& str);
  explicit ustring(std::u32string_view&& str);
  explicit ustring(const std::u32string_view& str);

@@ -36,8 +38,8 @@ namespace std {
template <>
struct hash<ustring> {
  size_t operator()(const ustring& __str) const noexcept {
    hash<u32string> hash;
    return hash(__str);
    hash<u32string> standard_hash;
    return standard_hash(static_cast<u32string>(__str));
  }
};
}  // namespace std
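For context, a brief assumed example (not from the commit) of why the specialization matters: it is what lets the tokenizers key their vocabulary maps on ustring directly:

#include <cstdint>
#include <unordered_map>
#include "ustring.h"

void vocab_lookup_demo() {
  std::unordered_map<ustring, int32_t> vocab;  // relies on std::hash<ustring>
  vocab[ustring("hello")] = 42;
  auto it = vocab.find(ustring("hello"));
  // it->second == 42
}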
@@ -22,6 +22,15 @@
#include "text/string_ecmaregex_replace.hpp"
#include "text/string_ecmaregex_split.hpp"

#ifdef ENABLE_BERT_TOKENIZER
#include "bert_tokenizer.hpp"
#include "basic_tokenizer.hpp"
#endif

#ifdef ENABLE_BERT_TOKENIZER
CustomOpBasicTokenizer c_CustomOpBasicTokenizer;
CustomOpBertTokenizer c_CustomOpBertTokenizer;
#endif

#ifdef ENABLE_TF_STRING
CustomOpSegmentSum c_CustomOpSegmentSum;

@@ -49,6 +58,10 @@ CustomOpStringRegexSplitWithOffsets c_CustomOpStringRegexSplitWithOffsets;
#endif

OrtCustomOp* operator_lists[] = {
#ifdef ENABLE_BERT_TOKENIZER
    &c_CustomOpBasicTokenizer,
    &c_CustomOpBertTokenizer,
#endif
#ifdef ENABLE_TF_STRING
    &c_CustomOpRaggedTensorToDense,
    &c_CustomOpRaggedTensorToSparse,
(The diff of one file is not shown because it is too large.)
@@ -4,6 +4,7 @@
#include "gtest/gtest.h"
#include "string_utils.h"
#include "wordpiece_tokenizer.hpp"
#include "bert_tokenizer.hpp"

TEST(tokenizer, bert_word_split) {
  ustring ind("##");

@@ -43,6 +44,14 @@ std::unordered_map<std::u32string, int32_t> get_vocabulary_basic() {
  return vocab;
}

std::vector<ustring> ustring_vector_convertor(std::vector<std::string> input) {
  std::vector<ustring> result;
  for (const auto& str : input) {
    result.emplace_back(str);
  }
  return result;
}

TEST(tokenizer, wordpiece_basic_tokenizer) {
  auto vocab = get_vocabulary_basic();
  std::vector<ustring> text = {ustring("UNwant\u00E9d,running")};
@@ -110,10 +119,26 @@ TEST(tokenizer, bert_wordpiece_tokenizer_rows) {
  std::vector<int64_t> existing_indices({0, 2, 3});
  std::vector<ustring> text = {ustring("unwanted"), ustring("running"), ustring("running")};
  KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[UNK]"), text, tokens, indices, rows,
                                     existing_indices.data(), existing_indices.size());
                                     existing_indices.data(), existing_indices.size());
  EXPECT_EQ(tokens, std::vector<ustring>({ustring("un"), ustring("##want"), ustring("##ed"),
                                          ustring("runn"), ustring("##ing"),
                                          ustring("runn"), ustring("##ing")}));
  EXPECT_EQ(indices, std::vector<int32_t>({7, 4, 5, 8, 9, 8, 9}));
  EXPECT_EQ(rows, std::vector<int64_t>({0, 5, 7}));
}

TEST(tokenizer, basic_tokenizer_chinese) {
  ustring test_case = ustring("ÀÁÂÃÄÅÇÈÉÊËÌÍÎÑÒÓÔÕÖÚÜ\t䗓𨖷虴𨀐辘𧄋脟𩑢𡗶镇伢𧎼䪱轚榶𢑌㺽𤨡!#$%&(Tom@microsoft.com)*+,-./:;<=>?@[\\]^_`{|}~");
  std::vector<ustring> expect_result = ustring_vector_convertor({"aaaaaaceeeeiiinooooouu", "䗓", "𨖷", "虴", "𨀐", "辘", "𧄋", "脟", "𩑢", "𡗶", "镇", "伢", "𧎼", "䪱", "轚", "榶", "𢑌", "㺽", "𤨡", "!", "#", "$", "%", "&", "(", "tom", "@", "microsoft", ".", "com", ")", "*", "+", ",", "-", ".", "/", ":", ";", "<", "=", ">", "?", "@", "[", "\\", "]", "^", "_", "`", "{", "|", "}", "~"});
  BasicTokenizer tokenizer(true, true, true, true, true);
  auto result = tokenizer.Tokenize(test_case);
  EXPECT_EQ(result, expect_result);
}

TEST(tokenizer, basic_tokenizer_russia) {
  ustring test_case = ustring("A $100,000 price-tag@big>small на русском языке");
  std::vector<ustring> expect_result = ustring_vector_convertor({"a", "$", "100", ",", "000", "price", "-", "tag", "@", "big", ">", "small", "на", "русском", "языке"});
  BasicTokenizer tokenizer(true, true, true, true, true);
  auto result = tokenizer.Tokenize(test_case);
  EXPECT_EQ(result, expect_result);
}
@@ -5,6 +5,7 @@
#include "re2/re2.h"
#include "nlohmann/json.hpp"
#include "string_utils.h"
#include "ustring.h"


TEST(utils, make_string) {
@@ -0,0 +1,51 @@
from pathlib import Path
import unittest
import numpy as np
import transformers
from onnxruntime_extensions import PyOrtFunction, BertTokenizer

bert_cased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
bert_uncased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')


def _get_test_data_file(*sub_dirs):
    test_dir = Path(__file__).parent
    return str(test_dir.joinpath(*sub_dirs))


def _run_basic_case(input, vocab_path):
    t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0)
    result = t2stc([input])
    expect_result = bert_cased_tokenizer.encode_plus(input)
    np.testing.assert_array_equal(result[0], expect_result['input_ids'])
    np.testing.assert_array_equal(result[1], expect_result['token_type_ids'])
    np.testing.assert_array_equal(result[2], expect_result['attention_mask'])


def _run_combined_case(input, vocab_path):
    t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0)
    result = t2stc(input)
    expect_result = bert_cased_tokenizer.encode_plus(input[0], input[1])
    np.testing.assert_array_equal(result[0], expect_result['input_ids'])
    np.testing.assert_array_equal(result[1], expect_result['token_type_ids'])
    np.testing.assert_array_equal(result[2], expect_result['attention_mask'])


class TestBertTokenizer(unittest.TestCase):

    def test_text_to_case1(self):
        _run_basic_case(input="Input 'text' must not be empty.",
                        vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
        _run_basic_case(input="网易云音乐",
                        vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
        _run_basic_case(input="网 易 云 音 乐",
                        vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
        _run_basic_case(input="cat is playing toys",
                        vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
        _run_basic_case(input="cat isnot playing toyssss",
                        vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
        _run_basic_case(input="cat isnot playing toyssss",
                        vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
        _run_combined_case(["网 易 云 音 乐", "cat isnot playing toyssss"],
                           vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))


if __name__ == "__main__":
    unittest.main()
@@ -61,7 +61,7 @@ if __name__ == '__main__':
    if len(sys.argv) == 2:
        print('[onnxruntime-extensions] Generating _selectedoplist.cmake file to folder: ${PROJECT_SOURCE_DIR}/cmake/')
        current_dir = os.path.dirname(__file__)
        target_cmake_path = os.path.abspath(os.path.join(current_dir, '../../cmake/_selectedoplist.cmake'))
        target_cmake_path = os.path.abspath(os.path.join(current_dir, '../cmake/_selectedoplist.cmake'))
        print('[onnxruntime-extensions] Target cmake file path: ', target_cmake_path)

        gen_cmake_oplist(sys.argv[1], target_cmake_path)