* init

* update

* update

* update

* update

* update

* update

* Modify relative path of generated cmake file.

* update

* update

* fix the bug

* update

* fix bugs

Co-authored-by: Ze Tao <zetao@microsoft.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
Co-authored-by: Zuwei Zhao <zuzhao@microsoft.com>
Mojimi 2021-08-27 04:50:03 +08:00 committed by GitHub
Parent 6c3b496e3f
Commit aef5ef1ef1
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 29681 additions and 20 deletions

View File

@@ -159,7 +159,7 @@ endif()
if (OCOS_ENABLE_BERT_TOKENIZER)
# Bert
set(_HAS_TOKENIZER ON)
- file(GLOB bert_TARGET_SRC "operators/tokenizer/wordpiece*.*")
+ file(GLOB bert_TARGET_SRC "operators/tokenizer/wordpiece*.*" "operators/tokenizer/basic_tokenizer.*" "operators/tokenizer/bert_tokenizer.*")
list(APPEND TARGET_SRC ${bert_TARGET_SRC})
endif()

View File

@@ -22,9 +22,19 @@ struct BaseKernel {
BaseKernel(OrtApi api, const OrtKernelInfo* info) : api_(api), info_(info), ort_(api_) {}
bool HasAttribute(const char* name) const;
template <class T>
bool TryToGetAttribute(const char* name, T& value);
template <class T>
T TryToGetAttributeWithDefault(const char* name, T default_value) {
T& result = default_value;
TryToGetAttribute(name, result);
return result;
}
void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim, const std::vector<int64_t>& data);
protected:
OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status);
OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_

View File

@@ -81,7 +81,7 @@ class StringToVector(CustomOp):
for k_, v_ in attrs.items():
if k_ == 'map' and isinstance(v_, dict):
attr_data[k_] = '\n'.join(k + "\t" + " ".join([str(i) for i in v]) for k, v in v_.items())
- elif k_ == 'unk' and isinstance(v_, list):
+ elif k_ == 'unk' and isinstance(v_, list):
attr_data[k_] = ' '.join(str(i) for i in v_)
else:
attr_data[k_] = v_
@@ -109,6 +109,30 @@ class BlingFireSentenceBreaker(CustomOp):
return attrs_data
class BertTokenizer(CustomOp):
@classmethod
def get_inputs(cls):
return [cls.io_def("text", onnx.TensorProto.STRING, [None])]
@classmethod
def get_outputs(cls):
return [cls.io_def('input_ids', onnx_proto.TensorProto.INT64, [None]),
cls.io_def('token_type_ids', onnx_proto.TensorProto.INT64, [None]),
cls.io_def('attention_mask', onnx_proto.TensorProto.INT64, [None])]
@classmethod
def serialize_attr(cls, attrs):
attrs_data = {}
for k_, v_ in attrs.items():
if k_ == 'vocab_file':
with open(v_, "r", encoding='utf-8') as model_file:
lines = model_file.readlines()
attrs_data[k_] = '\n'.join(lines)
else:
attrs_data[k_] = v_
return attrs_data
class SentencepieceTokenizer(CustomOp):
@classmethod
def get_inputs(cls):

View File

@@ -3,7 +3,6 @@
#include <sstream>
#include "ocos.h"
bool BaseKernel::HasAttribute(const char* name) const {
if (info_ == nullptr) {
ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
@@ -36,6 +35,14 @@ OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) {
return error_code;
}
void BaseKernel::SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim, const std::vector<int64_t>& data) {
OrtValue* output = ort_.KernelContext_GetOutput(ctx, output_idx, dim.data(), dim.size());
int64_t * data_ptr = ort_.GetTensorMutableData<int64_t>(output);
for (int i = 0; i < data.size(); i++) {
data_ptr[i] = data[i];
}
}
template <>
bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
if (info_ == nullptr) {
@@ -76,4 +83,19 @@ bool BaseKernel::TryToGetAttribute(const char* name, float& value) {
}
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(info_, name, &value)) == ORT_OK;
- }
+ }
template <>
bool BaseKernel::TryToGetAttribute(const char* name, bool& value) {
if (info_ == nullptr) {
ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
}
int64_t origin_value = 0;
if (GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &origin_value)) != ORT_OK) {
return false;
}
value = origin_value == 1;
return true;
}
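Note: ONNX operator attributes have no boolean type, so the boolean flags used by the tokenizers travel as int64 and the TryToGetAttribute<bool> specialization above decodes 1 as true; TryToGetAttributeWithDefault then layers an optional-attribute fallback on top. A minimal standalone sketch of that convention (hypothetical names, plain std containers instead of the OrtApi getters):

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

// Stand-in for the kernel's attribute source: boolean flags are stored as int64.
using AttrMap = std::unordered_map<std::string, int64_t>;

// Same contract as TryToGetAttribute<bool>: report presence, decode 1 as true.
bool TryGetBool(const AttrMap& attrs, const std::string& name, bool& value) {
  auto it = attrs.find(name);
  if (it == attrs.end()) return false;
  value = (it->second == 1);
  return true;
}

// Same contract as TryToGetAttributeWithDefault: keep the default when the attribute is absent.
bool GetBoolWithDefault(const AttrMap& attrs, const std::string& name, bool default_value) {
  bool result = default_value;
  TryGetBool(attrs, name, result);
  return result;
}

int main() {
  AttrMap attrs{{"do_lower_case", 0}};
  std::cout << GetBoolWithDefault(attrs, "do_lower_case", true) << "\n";           // 0: explicit attribute wins
  std::cout << GetBoolWithDefault(attrs, "tokenize_chinese_chars", true) << "\n";  // 1: falls back to the default
}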

View File

@@ -8,6 +8,7 @@ std::vector<std::string_view> SplitString(const std::string_view& str, const std
std::vector<std::string_view> result;
std::string ::size_type pre_pos = 0;
//TODO: bug fix
while (true) {
auto next_pos = str.find_first_of(seps, pre_pos);
@@ -32,6 +33,35 @@ std::vector<std::string_view> SplitString(const std::string_view& str, const std
return result;
}
bool IsCJK(char32_t c) {
return (c >= 0x4E00 && c <= 0x9FFF)
|| (c >= 0x3400 && c <= 0x4DBF)
|| (c >= 0x20000 && c <= 0x2A6DF)
|| (c >= 0x2A700 && c <= 0x2B73F)
|| (c >= 0x2B740 && c <= 0x2B81F)
|| (c >= 0x2B820 && c <= 0x2CEAF)
|| (c >= 0xF900 && c <= 0xFAFF)
|| (c >= 0x2F800 && c <= 0x2FA1F);
}
bool IsAccent(char32_t c)
{
// only support part of accent
// [TODO] support more accent
return c >= 0x300 && c <= 0x36F;
}
char32_t StripAccent(char32_t c)
{
// "ÀÁÂÃÄÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖ×ØÙÚÛÜÝÞßàáâãäåæçèéêëìíîïðñòóôõö÷øùúûüýþÿ"
const char* tr = "AAAAAAÆCEEEEIIIIÐNOOOOO×ØUUUUYÞßaaaaaaæceeeeiiiiðnooooo÷øuuuuyþy";
if (c < 192 || c > 255) {
return c;
}
return tr[c - 192];
}
#ifdef ENABLE_TF_STRING
// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/hash.cc#L28
static inline uint64_t ByteAs64(char c) { return static_cast<uint64_t>(c) & 0xff; }
@@ -100,4 +130,4 @@ uint64_t Hash64Fast(const char* data, size_t n) {
return static_cast<int64_t>(util::Fingerprint64(data, n));
}
- #endif // ENABLE_TF_STRING
+ #endif // ENABLE_TF_STRING
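Note: the new classifiers are plain Unicode range checks. IsCJK covers the CJK Unified Ideographs block, extensions A through E and the two compatibility-ideograph blocks; IsAccent covers only the Combining Diacritical Marks block U+0300-U+036F (hence the TODO); StripAccent folds the Latin-1 accented letters U+00C0-U+00FF onto base letters via a lookup table. A small self-contained sketch of the range checks (simplified re-implementation, not the library functions themselves):

#include <iostream>

// Reduced versions of the IsCJK / IsAccent range checks above.
bool LooksCJK(char32_t c) {
  return (c >= 0x4E00 && c <= 0x9FFF)    // CJK Unified Ideographs
      || (c >= 0x3400 && c <= 0x4DBF);   // Extension A
}

bool LooksCombiningAccent(char32_t c) {
  return c >= 0x0300 && c <= 0x036F;     // Combining Diacritical Marks
}

int main() {
  std::cout << LooksCJK(U'\u6613') << " " << LooksCJK(U'e') << "\n";  // 1 0  (U+6613 is a CJK ideograph)
  std::cout << LooksCombiningAccent(U'\u0301') << "\n";               // 1    (combining acute accent)
}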

View File

@@ -49,8 +49,11 @@ std::string MakeString(const Args&... args) {
std::vector<std::string_view> SplitString(const std::string_view& str, const std::string_view& seps, bool remove_empty_entries = false);
void char2unicode(const std::string& src, std::vector<uint32_t>& result);
void unicode2char(const std::vector<uint32_t>& src, std::string& result);
bool IsCJK(char32_t c);
bool IsAccent(char32_t c);
char32_t StripAccent(char32_t c);
uint64_t Hash64(const char* data, size_t n, uint64_t seed);

View File

@@ -0,0 +1,128 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "string_utils.h"
#include "basic_tokenizer.hpp"
#include "string_tensor.h"
#include <vector>
#include <locale>
#include <codecvt>
#include <algorithm>
BasicTokenizer::BasicTokenizer(bool do_lower_case, bool tokenize_chinese_chars, bool strip_accents, bool tokenize_punctuation, bool remove_control_chars):
do_lower_case_(do_lower_case), tokenize_chinese_chars_(tokenize_chinese_chars), strip_accents_(strip_accents), tokenize_punctuation_(tokenize_punctuation),
remove_control_chars_(remove_control_chars){}
std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
std::vector<ustring> result;
ustring token;
auto push_current_token_and_clear = [&result, &token]() {
if (!token.empty()) {
result.push_back(token);
token.clear();
}
};
auto push_single_char_and_clear = [&result, &token](char32_t c) {
token.push_back(c);
result.push_back(token);
token.clear();
};
// strip accent first
if (strip_accents_) {
for (auto& c : text) {
c = StripAccent(c);
}
}
if (do_lower_case_) {
for (auto& c : text) {
c = ::tolower(c);
}
}
for (auto c : text) {
if (tokenize_chinese_chars_ && IsCJK(c)) {
push_current_token_and_clear();
push_single_char_and_clear(c);
continue;
}
if (strip_accents_ && IsAccent(c)) {
continue;
}
if (tokenize_punctuation_ && ::ispunct(c)) {
push_current_token_and_clear();
push_single_char_and_clear(c);
continue;
}
// split by space
if (::isspace(c)) {
push_current_token_and_clear();
continue;
}
// iscntrl treats \t \f \n \r as control characters,
// but those have already been filtered out by the isspace(c) branch above
if (remove_control_chars_ && ::iscntrl(c)) {
continue;
}
token.push_back(c);
}
push_current_token_and_clear();
return result;
}
KernelBasicTokenizer::KernelBasicTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false);
bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true);
tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, tokenize_punctuation, remove_control_chars);
}
void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> input_data;
GetTensorMutableDataString(api_, ort_, context, input, input_data);
OrtTensorDimensions dimensions(ort_, input);
if (dimensions.size() != 1 || dimensions[0] != 1) {
ORT_CXX_API_THROW("[BasicTokenizer]: only support string scalar.", ORT_INVALID_GRAPH);
}
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
std::vector<ustring> result = tokenizer_->Tokenize(ustring(input_data[0]));
FillTensorDataString(api_, ort_, context, result, output);
}
void* CustomOpBasicTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
return new KernelBasicTokenizer(api, info);
};
const char* CustomOpBasicTokenizer::GetName() const { return "BasicTokenizer"; };
size_t CustomOpBasicTokenizer::GetInputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpBasicTokenizer::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
size_t CustomOpBasicTokenizer::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpBasicTokenizer::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
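Note: BasicTokenizer::Tokenize is a single pass that accumulates characters into the current token and flushes it at every boundary: whitespace separates tokens, while punctuation and (with tokenize_chinese_chars) each CJK character are emitted as single-character tokens. A condensed standalone sketch of that loop, assuming ASCII-only lower-casing and simplified classifiers:

#include <cctype>
#include <iostream>
#include <string>
#include <vector>

static bool IsCjkChar(char32_t c) { return c >= 0x4E00 && c <= 0x9FFF; }  // main block only

// Simplified flush-on-boundary loop mirroring BasicTokenizer::Tokenize.
std::vector<std::u32string> Tokenize(std::u32string text) {
  std::vector<std::u32string> result;
  std::u32string token;
  auto flush = [&]() { if (!token.empty()) { result.push_back(token); token.clear(); } };

  for (char32_t& c : text)                                  // lower-case (ASCII only here)
    if (c < 128) c = static_cast<char32_t>(std::tolower(static_cast<int>(c)));

  for (char32_t c : text) {
    bool ascii = c < 128;
    if (IsCjkChar(c) || (ascii && std::ispunct(static_cast<int>(c)))) {
      flush();
      result.push_back(std::u32string(1, c));               // CJK / punctuation become single-char tokens
    } else if (ascii && std::isspace(static_cast<int>(c))) {
      flush();                                              // whitespace only separates tokens
    } else {
      token.push_back(c);
    }
  }
  flush();
  return result;
}

int main() {
  for (const auto& t : Tokenize(U"Hello, \u4E16\u754C!"))
    std::cout << t.size() << "-char token\n";               // "hello" "," U+4E16 U+754C "!" -> 5 1 1 1 1
}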

View File

@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
#include "string_utils.h"
#include "ustring.h"
class BasicTokenizer {
public:
BasicTokenizer(bool do_lower_case, bool tokenize_chinese_chars, bool strip_accents, bool tokenize_punctuation, bool remove_control_chars);
std::vector<ustring> Tokenize(ustring text);
private:
bool do_lower_case_;
bool strip_accents_;
bool tokenize_chinese_chars_;
bool tokenize_punctuation_;
bool remove_control_chars_;
};
struct KernelBasicTokenizer : BaseKernel {
KernelBasicTokenizer(OrtApi api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
private:
std::shared_ptr<BasicTokenizer> tokenizer_;
};
struct CustomOpBasicTokenizer : Ort::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};

View File

@@ -0,0 +1,237 @@
#include "bert_tokenizer.hpp"
#include <utility>
WordpieceTokenizer::WordpieceTokenizer(std::shared_ptr<std::unordered_map<ustring, int32_t>> vocab, ustring unk_token,
ustring suffix_indicator, int max_input_chars_per_word): vocab_(std::move(vocab)), unk_token_(unk_token),
suffix_indicator_(std::move(suffix_indicator)), max_input_chars_per_word_(max_input_chars_per_word) {
auto it = vocab_->find(unk_token);
if (it == vocab_->end()) {
ORT_CXX_API_THROW("[WordpieceTokenizer]: can not find unk_token in vocal", ORT_RUNTIME_EXCEPTION);
}
unk_token_id_ = it->second;
}
std::vector<ustring> WordpieceTokenizer::Tokenize(const ustring& text) {
std::vector<ustring> result;
ustring token;
for (auto c : text) {
if (c == U' ' && !token.empty()) {
GreedySearch(token, result);
token.clear();
continue;
}
token.push_back(c);
}
if (!token.empty()) {
GreedySearch(token, result);
}
return result;
}
std::vector<ustring> WordpieceTokenizer::Tokenize(const std::vector<ustring>& tokens) {
std::vector<ustring> result;
for (const auto& token : tokens) {
GreedySearch(token, result);
}
return result;
}
std::vector<int64_t> WordpieceTokenizer::Encode(const std::vector<ustring>& tokens) {
std::vector<int64_t> ids;
for (const auto& token : tokens) {
auto it = vocab_->find(token);
if (it == vocab_->end()) {
ids.push_back(unk_token_id_);
continue;
}
ids.push_back(it->second);
}
return ids;
}
void WordpieceTokenizer::GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result) {
if (token.size() > max_input_chars_per_word_) {
tokenized_result.push_back(unk_token_);
return;
}
int start = 0;
int end = -1;
ustring substr;
for (; start < token.size();) {
end = token.size();
bool is_found = false;
// try to find the longest matching sub-token in the vocab
for (; start < end;) {
substr = static_cast<const ustring>(token.substr(start, end - start));
if (start > 0) {
substr = static_cast<const ustring>(suffix_indicator_ + substr);
}
auto it = vocab_->find(substr);
if (it != vocab_->end()) {
is_found = true;
break;
}
end -= 1;
}
// token not found in vocab
if (!is_found) {
tokenized_result.push_back(unk_token_);
break;
}
tokenized_result.push_back(substr);
start = end;
}
}
BertTokenizer::BertTokenizer(std::string vocab, bool do_lower_case, bool do_basic_tokenize, ustring unk_token, ustring sep_token,
ustring pad_token, ustring cls_token, ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
ustring suffix_indicator) : do_basic_tokenize_(do_basic_tokenize) {
auto tokens = SplitString(vocab, "\n", true);
vocab_ = std::make_shared<std::unordered_map<ustring, int32_t>>();
for (int i = 0; i < tokens.size(); i++) {
(*vocab_)[ustring(tokens[i])] = i;
}
if (do_basic_tokenize) {
basic_tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, true, true);
}
wordpiece_tokenizer_ = std::make_shared<WordpieceTokenizer>(vocab_, unk_token, suffix_indicator);
unk_token_id_ = FindSpecialToken(unk_token);
sep_token_id_ = FindSpecialToken(sep_token);
pad_token_id_ = FindSpecialToken(pad_token);
cls_token_id_ = FindSpecialToken(cls_token);
mask_token_id_ = FindSpecialToken(mask_token);
}
std::vector<ustring> BertTokenizer::Tokenize(const ustring& text) {
if (do_basic_tokenize_) {
return wordpiece_tokenizer_->Tokenize(basic_tokenizer_->Tokenize(text));
}
return wordpiece_tokenizer_->Tokenize(text);
}
std::vector<int64_t> BertTokenizer::Encode(const std::vector<ustring>& tokens) {
return wordpiece_tokenizer_->Encode(tokens);
}
std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids) {
std::vector<int64_t> result;
result.reserve(ids.size() + 2);
result.push_back(cls_token_id_);
result.insert(result.end(), ids.begin(), ids.end());
result.push_back(sep_token_id_);
return result;
}
std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
std::vector<int64_t> result;
result.reserve(ids1.size() + ids2.size() + 3);
result.push_back(cls_token_id_);
result.insert(result.end(), ids1.begin(), ids1.end());
result.push_back(sep_token_id_);
result.insert(result.end(), ids2.begin(), ids2.end());
result.push_back(sep_token_id_);
return result;
}
std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids) {
return std::vector<int64_t>(ids.size() + 2, 0);
}
std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
std::vector<int64_t> result;
result.reserve(ids1.size() + ids2.size() + 3);
result.insert(result.end(), ids1.size() + 2, 0);
result.insert(result.end(), ids2.size() + 1, 1);
return result;
}
int32_t BertTokenizer::FindSpecialToken(ustring token) {
auto it = vocab_->find(token);
if (it == vocab_->end()) {
ORT_CXX_API_THROW("[BertTokenizer]: can not find special tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
}
return it->second;
}
KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
std::string cls_token = TryToGetAttributeWithDefault("cls_token", std::string("[CLS]"));
std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
tokenizer_ = std::make_shared<BertTokenizer>(vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
ustring(sep_token), ustring(pad_token),ustring(cls_token),
ustring(mask_token), tokenize_chinese_chars, strip_accents, ustring(suffix_indicator));
}
void KernelBertTokenizer::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> input_data;
GetTensorMutableDataString(api_, ort_, context, input, input_data);
if (input_data.size() != 1 && input_data.size() != 2) {
ORT_CXX_API_THROW("[BertTokenizer]: only support one or two query.", ORT_INVALID_GRAPH);
}
std::vector<int64_t> input_ids;
std::vector<int64_t> token_type_ids;
if (input_data.size() == 1) {
std::vector<int64_t> encode = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[0])));
input_ids = tokenizer_->AddSpecialToken(encode);
token_type_ids = tokenizer_->GenerateTypeId(encode);
} else {
std::vector<int64_t> encode1 = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[0])));
std::vector<int64_t> encode2 = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[1])));
input_ids = tokenizer_->AddSpecialToken(encode1, encode2);
token_type_ids = tokenizer_->GenerateTypeId(encode1, encode2);
}
std::vector<int64_t> attention_mask(input_ids.size(), 1);
std::vector<int64_t> output_dim({static_cast<int64_t>(input_ids.size())});
SetOutput(context, 0, output_dim, input_ids);
SetOutput(context, 1, output_dim, token_type_ids);
SetOutput(context, 2, output_dim, attention_mask);
}
void* CustomOpBertTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
return new KernelBertTokenizer(api, info);
};
const char* CustomOpBertTokenizer::GetName() const { return "BertTokenizer"; };
size_t CustomOpBertTokenizer::GetInputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpBertTokenizer::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
size_t CustomOpBertTokenizer::GetOutputTypeCount() const {
return 3;
};
ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
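Note: GreedySearch is the usual WordPiece longest-match-first loop: at each offset it tries the longest remaining substring and shrinks it from the right until a vocabulary hit (prefixing the "##" suffix indicator for non-initial pieces); if nothing matches it emits the unknown token and stops. A self-contained sketch over std::string (byte strings instead of ustring, hypothetical vocabulary) that reproduces the "unwanted" -> "un ##want ##ed" case from the unit tests further down:

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Greedy longest-match-first split, mirroring WordpieceTokenizer::GreedySearch.
std::vector<std::string> GreedyWordpiece(const std::string& word,
                                         const std::unordered_map<std::string, int>& vocab,
                                         const std::string& suffix = "##",
                                         const std::string& unk = "[UNK]") {
  std::vector<std::string> pieces;
  size_t start = 0;
  while (start < word.size()) {
    size_t end = word.size();
    std::string match;
    while (start < end) {
      std::string candidate = word.substr(start, end - start);
      if (start > 0) candidate = suffix + candidate;        // non-initial pieces carry "##"
      if (vocab.count(candidate)) { match = candidate; break; }
      --end;                                                // shrink from the right
    }
    if (match.empty()) { pieces.push_back(unk); break; }    // no match for the remainder
    pieces.push_back(match);
    start = end;
  }
  return pieces;
}

int main() {
  std::unordered_map<std::string, int> vocab{{"un", 7}, {"##want", 4}, {"##ed", 5}, {"[UNK]", 0}};
  for (const auto& piece : GreedyWordpiece("unwanted", vocab))
    std::cout << piece << " ";                              // un ##want ##ed
  std::cout << "\n";
}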

View File

@@ -0,0 +1,71 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <unordered_map>
#include <vector>
#include "ocos.h"
#include "ustring.h"
#include "string_utils.h"
#include "string_tensor.h"
#include "basic_tokenizer.hpp"
// TODO: merge with the implementation of word piece tokenizer
class WordpieceTokenizer{
public:
WordpieceTokenizer(std::shared_ptr<std::unordered_map<ustring, int32_t>> vocab, ustring unk_token, ustring suffix_indicator, int max_input_chars_per_word = 100);
std::vector<ustring> Tokenize(const ustring& text);
std::vector<ustring> Tokenize(const std::vector<ustring>& tokens);
std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
private:
int64_t max_input_chars_per_word_;
ustring suffix_indicator_;
ustring unk_token_;
int64_t unk_token_id_;
std::shared_ptr<std::unordered_map<ustring, int32_t>> vocab_;
void GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result);
};
class BertTokenizer {
public:
BertTokenizer(std::string vocab, bool do_lower_case, bool do_basic_tokenize,
ustring unk_token, ustring sep_token, ustring pad_token, ustring cls_token,
ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
ustring suffix_indicator);
std::vector<ustring> Tokenize(const ustring& text);
std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids);
std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);
std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids);
std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);
private:
int32_t unk_token_id_;
int32_t sep_token_id_;
int32_t pad_token_id_;
int32_t cls_token_id_;
int32_t mask_token_id_;
bool do_basic_tokenize_;
std::shared_ptr<std::unordered_map<ustring, int32_t>> vocab_;
std::shared_ptr<BasicTokenizer> basic_tokenizer_;
std::shared_ptr<WordpieceTokenizer> wordpiece_tokenizer_;
int32_t FindSpecialToken(ustring token);
};
struct KernelBertTokenizer : BaseKernel {
KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
private:
std::shared_ptr<BertTokenizer> tokenizer_;
};
struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
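Note: AddSpecialToken and GenerateTypeId produce the standard BERT layout: [CLS] A [SEP] for a single sequence, [CLS] A [SEP] B [SEP] for a pair, token_type_ids of 0 over the first segment (including [CLS] and the first [SEP]) and 1 over the second, and an all-ones attention_mask; this is what KernelBertTokenizer::Compute writes to its three outputs. A small sketch of that bookkeeping with placeholder ids (not the class above):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t kCls = 101, kSep = 102;                  // placeholder special-token ids
  std::vector<int64_t> a{2023, 2003}, b{2307, 999};      // already-encoded word pieces of the two queries

  // [CLS] A [SEP] B [SEP]
  std::vector<int64_t> input_ids{kCls};
  input_ids.insert(input_ids.end(), a.begin(), a.end());
  input_ids.push_back(kSep);
  input_ids.insert(input_ids.end(), b.begin(), b.end());
  input_ids.push_back(kSep);

  // 0s over the first segment (plus [CLS] and the first [SEP]), 1s over the second (plus the final [SEP]).
  std::vector<int64_t> token_type_ids(a.size() + 2, 0);
  token_type_ids.insert(token_type_ids.end(), b.size() + 1, 1);

  std::vector<int64_t> attention_mask(input_ids.size(), 1);  // every position holds a real token

  std::cout << input_ids.size() << " " << token_type_ids.size() << " "
            << attention_mask.size() << "\n";                // 7 7 7
}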

View File

@@ -200,4 +200,4 @@ ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetOutputType(size_t index
default:
throw std::runtime_error(MakeString("[WordpieceTokenizer] Unexpected output index ", index));
}
- };
+ };

View File

@@ -43,4 +43,4 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string,
std::vector<int64_t>& rows,
const int64_t* existing_rows = nullptr,
int64_t n_existing_rows = 0,
- int64_t max_input_chars_per_word = 200);
+ int64_t max_input_chars_per_word = 200);

View File

@@ -3,7 +3,7 @@
#include <iostream>
#include "ustring.h"
- ustring::ustring(): std::u32string() {
+ ustring::ustring() : std::u32string() {
}
ustring::ustring(char* str) {
@@ -26,17 +26,27 @@ ustring::ustring(const std::string& str) {
assign(str_cvt.from_bytes(str));
}
- ustring::ustring(char32_t* str):std::u32string(str) {}
+ ustring::ustring(char32_t* str) : std::u32string(str) {}
- ustring::ustring(const char32_t* str):std::u32string(str) {}
+ ustring::ustring(const char32_t* str) : std::u32string(str) {}
- ustring::ustring(std::u32string& str):std::u32string(str) {}
+ ustring::ustring(std::u32string& str) : std::u32string(str) {}
- ustring::ustring(std::u32string&& str):std::u32string(str) {}
+ ustring::ustring(std::u32string&& str) : std::u32string(str) {}
- ustring::ustring(const std::u32string& str):std::u32string(str) {}
+ ustring::ustring(const std::u32string& str) : std::u32string(str) {}
- ustring::ustring(const std::u32string&& str):std::u32string(str) {}
+ ustring::ustring(const std::u32string&& str) : std::u32string(str) {}
ustring::ustring(std::string_view& str) {
utf8_converter str_cvt;
assign(str_cvt.from_bytes(str.data(), str.data() + str.size()));
}
ustring::ustring(const std::string_view& str) {
utf8_converter str_cvt;
assign(str_cvt.from_bytes(str.data(), str.data() + str.size()));
}
ustring::ustring(std::u32string_view& str):std::u32string(str) {}
@@ -55,3 +65,4 @@ ustring::operator std::string() const {
utf8_converter str_cvt;
return str_cvt.to_bytes(*this);
}

View File

@@ -20,6 +20,8 @@ class ustring : public std::u32string {
explicit ustring(std::u32string&& str);
explicit ustring(const std::u32string& str);
explicit ustring(const std::u32string&& str);
explicit ustring(std::string_view& str);
explicit ustring(const std::string_view& str);
explicit ustring(std::u32string_view& str);
explicit ustring(std::u32string_view&& str);
explicit ustring(const std::u32string_view& str);
@@ -36,8 +38,8 @@ namespace std {
template <>
struct hash<ustring> {
size_t operator()(const ustring& __str) const noexcept {
- hash<u32string> hash;
- return hash(__str);
+ hash<u32string> standard_hash;
+ return standard_hash(static_cast<u32string>(__str));
}
};
} // namespace std

View File

@@ -22,6 +22,15 @@
#include "text/string_ecmaregex_replace.hpp"
#include "text/string_ecmaregex_split.hpp"
#ifdef ENABLE_BERT_TOKENIZER
#include "bert_tokenizer.hpp"
#include "basic_tokenizer.hpp"
#endif
#ifdef ENABLE_BERT_TOKENIZER
CustomOpBasicTokenizer c_CustomOpBasicTokenizer;
CustomOpBertTokenizer c_CustomOpBertTokenizer;
#endif
#ifdef ENABLE_TF_STRING
CustomOpSegmentSum c_CustomOpSegmentSum;
@@ -49,6 +58,10 @@ CustomOpStringRegexSplitWithOffsets c_CustomOpStringRegexSplitWithOffsets;
#endif
OrtCustomOp* operator_lists[] = {
#ifdef ENABLE_BERT_TOKENIZER
&c_CustomOpBasicTokenizer,
&c_CustomOpBertTokenizer,
#endif
#ifdef ENABLE_TF_STRING
&c_CustomOpRaggedTensorToDense,
&c_CustomOpRaggedTensorToSparse,

File diff not shown because it is too large

View File

@@ -4,6 +4,7 @@
#include "gtest/gtest.h"
#include "string_utils.h"
#include "wordpiece_tokenizer.hpp"
#include "bert_tokenizer.hpp"
TEST(tokenizer, bert_word_split) {
ustring ind("##");
@@ -43,6 +44,14 @@ std::unordered_map<std::u32string, int32_t> get_vocabulary_basic() {
return vocab;
}
std::vector<ustring> ustring_vector_convertor(std::vector<std::string> input) {
std::vector<ustring> result;
for (const auto& str : input) {
result.emplace_back(str);
}
return result;
}
TEST(tokenizer, wordpiece_basic_tokenizer) {
auto vocab = get_vocabulary_basic();
std::vector<ustring> text = {ustring("UNwant\u00E9d,running")};
@@ -110,10 +119,26 @@ TEST(tokenizer, bert_wordpiece_tokenizer_rows) {
std::vector<int64_t> existing_indices({0, 2, 3});
std::vector<ustring> text = {ustring("unwanted"), ustring("running"), ustring("running")};
KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[UNK]"), text, tokens, indices, rows,
- existing_indices.data(), existing_indices.size());
+ existing_indices.data(), existing_indices.size());
EXPECT_EQ(tokens, std::vector<ustring>({ustring("un"), ustring("##want"), ustring("##ed"),
ustring("runn"), ustring("##ing"),
ustring("runn"), ustring("##ing")}));
EXPECT_EQ(indices, std::vector<int32_t>({7, 4, 5, 8, 9, 8, 9}));
EXPECT_EQ(rows, std::vector<int64_t>({0, 5, 7}));
}
TEST(tokenizer, basic_tokenizer_chinese) {
ustring test_case = ustring("ÀÁÂÃÄÅÇÈÉÊËÌÍÎÑÒÓÔÕÖÚÜ\t䗓𨖷虴𨀐辘𧄋脟𩑢𡗶镇伢𧎼䪱轚榶𢑌㺽𤨡!#$%&(Tom@microsoft.com)*+,-./:;<=>?@[\\]^_`{|}~");
std::vector<ustring> expect_result = ustring_vector_convertor({"aaaaaaceeeeiiinooooouu", "", "𨖷", "", "𨀐", "", "𧄋", "", "𩑢", "𡗶", "", "", "𧎼", "", "", "", "𢑌", "", "𤨡", "!", "#", "$", "%", "&", "(", "tom", "@", "microsoft", ".", "com", ")", "*", "+", ",", "-", ".", "/", ":", ";", "<", "=", ">", "?", "@", "[", "\\", "]", "^", "_", "`", "{", "|", "}", "~"});
BasicTokenizer tokenizer(true, true, true, true, true);
auto result = tokenizer.Tokenize(test_case);
EXPECT_EQ(result, expect_result);
}
TEST(tokenizer, basic_tokenizer_russia) {
ustring test_case = ustring("A $100,000 price-tag@big>small на русском языке");
std::vector<ustring> expect_result = ustring_vector_convertor({"a", "$", "100", ",", "000", "price", "-", "tag", "@", "big", ">", "small", "на", "русском", "языке"});
BasicTokenizer tokenizer(true, true, true, true, true);
auto result = tokenizer.Tokenize(test_case);
EXPECT_EQ(result, expect_result);
}

View File

@@ -5,6 +5,7 @@
#include "re2/re2.h"
#include "nlohmann/json.hpp"
#include "string_utils.h"
#include "ustring.h"
TEST(utils, make_string) {

View File

@@ -0,0 +1,51 @@
from pathlib import Path
import unittest
import numpy as np
import transformers
from onnxruntime_extensions import PyOrtFunction, BertTokenizer
bert_cased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
bert_uncased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
def _get_test_data_file(*sub_dirs):
test_dir = Path(__file__).parent
return str(test_dir.joinpath(*sub_dirs))
def _run_basic_case(input, vocab_path):
t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0)
result = t2stc([input])
expect_result = bert_cased_tokenizer.encode_plus(input)
np.testing.assert_array_equal(result[0], expect_result['input_ids'])
np.testing.assert_array_equal(result[1], expect_result['token_type_ids'])
np.testing.assert_array_equal(result[2], expect_result['attention_mask'])
def _run_combined_case(input, vocab_path):
t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0)
result = t2stc(input)
expect_result = bert_cased_tokenizer.encode_plus(input[0], input[1])
np.testing.assert_array_equal(result[0], expect_result['input_ids'])
np.testing.assert_array_equal(result[1], expect_result['token_type_ids'])
np.testing.assert_array_equal(result[2], expect_result['attention_mask'])
class TestBertTokenizer(unittest.TestCase):
def test_text_to_case1(self):
_run_basic_case(input="Input 'text' must not be empty.",
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
_run_basic_case(input="网易云音乐", vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
_run_basic_case(input="网 易 云 音 乐",
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
_run_basic_case(input="cat is playing toys",
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
_run_basic_case(input="cat isnot playing toyssss",
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
_run_basic_case(input="cat isnot playing toyssss",
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
_run_combined_case(["网 易 云 音 乐", "cat isnot playing toyssss"], vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
if __name__ == "__main__":
unittest.main()

View File

@@ -61,7 +61,7 @@ if __name__ == '__main__':
if len(sys.argv) == 2:
print('[onnxruntime-extensions] Generating _selectedoplist.cmake file to folder: ${PROJECT_SOURCE_DIR}/cmake/')
current_dir = os.path.dirname(__file__)
- target_cmake_path = os.path.abspath(os.path.join(current_dir, '../../cmake/_selectedoplist.cmake'))
+ target_cmake_path = os.path.abspath(os.path.join(current_dir, '../cmake/_selectedoplist.cmake'))
print('[onnxruntime-extensions] Target cmake file path: ', target_cmake_path)
gen_cmake_oplist(sys.argv[1], target_cmake_path)