Reduce bert tokenize memory usage (#156)

* add BertTokenizerVocab

* improve format

Co-authored-by: Ze Tao <zetao@microsoft.com>
This commit is contained in:
Mojimi 2021-09-28 02:19:57 +08:00 committed by GitHub
Parent d8cdb8e042
Commit 2d6cf0b4ea
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 171 additions and 146 deletions
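The diff below replaces the shared std::unordered_map<ustring, int32_t> vocabulary, which kept a UTF-32 ustring copy of every token, with a BertTokenizerVocab that stores the raw newline-separated vocab file once and indexes it with std::string_view keys pointing into that buffer. A minimal standalone sketch of the same idea follows; the ViewVocab name and the toy vocab contents are illustrative and not part of this commit.

#include <cstdint>
#include <iostream>
#include <string>
#include <string_view>
#include <unordered_map>

// Sketch only: keep one owned buffer and key the map with views into it,
// instead of storing an owning copy of every token.
class ViewVocab {  // illustrative stand-in for BertTokenizerVocab
 public:
  explicit ViewVocab(std::string vocab) : raw_vocab_(std::move(vocab)) {
    size_t start = 0;
    int32_t id = 0;
    while (start <= raw_vocab_.size()) {
      size_t end = raw_vocab_.find('\n', start);
      if (end == std::string::npos) end = raw_vocab_.size();
      // The key is a view into raw_vocab_: no per-token allocation.
      vocab_[std::string_view(raw_vocab_).substr(start, end - start)] = id++;
      if (end == raw_vocab_.size()) break;
      start = end + 1;
    }
  }
  ViewVocab(const ViewVocab&) = delete;  // views would dangle if the buffer moved

  bool Find(std::string_view token, int32_t& id) const {
    auto it = vocab_.find(token);
    if (it == vocab_.end()) return false;
    id = it->second;
    return true;
  }

 private:
  std::string raw_vocab_;                                 // single owning buffer
  std::unordered_map<std::string_view, int32_t> vocab_;   // views into raw_vocab_
};

int main() {
  ViewVocab vocab("[PAD]\n[UNK]\nhello\n##world");
  int32_t id = -1;
  if (vocab.Find("##world", id)) std::cout << "##world -> " << id << '\n';  // prints 3
}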

View File

@@ -2,14 +2,47 @@
#include <utility>
WordpieceTokenizer::WordpieceTokenizer(std::shared_ptr<std::unordered_map<ustring, int32_t>> vocab, ustring unk_token,
ustring suffix_indicator, int max_input_chars_per_word): vocab_(std::move(vocab)), unk_token_(unk_token),
suffix_indicator_(std::move(suffix_indicator)), max_input_chars_per_word_(max_input_chars_per_word) {
auto it = vocab_->find(unk_token);
if (it == vocab_->end()) {
ORT_CXX_API_THROW("[WordpieceTokenizer]: can not find unk_token in vocal", ORT_RUNTIME_EXCEPTION);
BertTokenizerVocab::BertTokenizerVocab(std::string vocab) : raw_vocab_(vocab) {
auto tokens = SplitString(raw_vocab_, "\n", true);
for (int i = 0; i < tokens.size(); i++) {
(vocab_)[tokens[i]] = i;
}
unk_token_id_ = it->second;
}
bool BertTokenizerVocab::FindToken(const ustring& token) {
auto utf8_token = std::string(token);
return vocab_.find(utf8_token) != vocab_.end();
}
bool BertTokenizerVocab::FindTokenId(const ustring& token, int32_t& token_id) {
auto utf8_token = std::string(token);
auto it = vocab_.find(utf8_token);
if (it == vocab_.end()) {
return false;
}
token_id = it->second;
return true;
}
int32_t BertTokenizerVocab::FindTokenId(const ustring& token) {
auto utf8_token = std::string(token);
auto it = vocab_.find(utf8_token);
if (it == vocab_.end()) {
ORT_CXX_API_THROW("[BertTokenizerVocab]: can not find tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
}
return it->second;
}
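The three lookups above differ only in how a miss is handled: FindToken reports existence, FindTokenId(token, token_id) returns false without throwing, and FindTokenId(token) throws for tokens that must be present, such as [UNK]. A small usage sketch, assuming the BertTokenizerVocab and ustring declarations from this file's header are in scope and that ustring is constructible from a narrow string literal as elsewhere in this repo; the vocab contents are made up.

#include <cstdint>
#include <string>

// Usage sketch for the lookups above.
void LookupExamples() {
  BertTokenizerVocab vocab("[PAD]\n[UNK]\nhello\n##world");

  // Existence check only.
  bool has_hello = vocab.FindToken(ustring("hello"));              // true

  // Non-throwing lookup: reports a miss through the return value.
  int32_t token_id = -1;
  bool found = vocab.FindTokenId(ustring("missing"), token_id);    // false

  // Throwing lookup: for tokens that must exist, e.g. [UNK] or [SEP].
  int32_t unk_id = vocab.FindTokenId(ustring("[UNK]"));            // 1

  (void)has_hello; (void)found; (void)unk_id;
}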
WordpieceTokenizer::WordpieceTokenizer(std::shared_ptr<BertTokenizerVocab> vocab, ustring unk_token,
ustring suffix_indicator, int max_input_chars_per_word): vocab_(std::move(vocab)), unk_token_(std::move(unk_token)),
suffix_indicator_(std::move(suffix_indicator)), max_input_chars_per_word_(max_input_chars_per_word) {
unk_token_id_ = vocab_->FindTokenId(unk_token_);
}
std::vector<ustring> WordpieceTokenizer::Tokenize(const ustring& text) {
@@ -44,13 +77,13 @@ std::vector<ustring> WordpieceTokenizer::Tokenize(const std::vector<ustring>& to
std::vector<int64_t> WordpieceTokenizer::Encode(const std::vector<ustring>& tokens) {
std::vector<int64_t> ids;
for (const auto& token : tokens) {
auto it = vocab_->find(token);
if (it == vocab_->end()) {
int32_t token_id = -1;
if (!vocab_->FindTokenId(token, token_id)) {
ids.push_back(unk_token_id_);
continue;
}
ids.push_back(it->second);
ids.push_back(token_id);
}
return ids;
}
@@ -67,14 +100,13 @@ void WordpieceTokenizer::GreedySearch(const ustring& token, std::vector<ustring>
for (; start < token.size();) {
end = token.size();
bool is_found = false;
// try to find longest matched sub-token in vocab
// try to find the longest matched sub-token in vocab
for (; start < end;) {
substr = static_cast<const ustring>(token.substr(start, end - start));
if (start > 0) {
substr = static_cast<const ustring>(suffix_indicator_ + substr);
}
auto it = vocab_->find(substr);
if (it != vocab_->end()) {
if (vocab_->FindToken(substr)) {
is_found = true;
break;
}
@@ -91,90 +123,6 @@ void WordpieceTokenizer::GreedySearch(const ustring& token, std::vector<ustring>
}
}
BertTokenizer::BertTokenizer(std::string vocab, bool do_lower_case, bool do_basic_tokenize, ustring unk_token, ustring sep_token,
ustring pad_token, ustring cls_token, ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
ustring suffix_indicator) : do_basic_tokenize_(do_basic_tokenize) {
auto tokens = SplitString(vocab, "\n", true);
vocab_ = std::make_shared<std::unordered_map<ustring, int32_t>>();
for (int i = 0; i < tokens.size(); i++) {
(*vocab_)[ustring(tokens[i])] = i;
}
if (do_basic_tokenize) {
basic_tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, true, true);
}
wordpiece_tokenizer_ = std::make_shared<WordpieceTokenizer>(vocab_, unk_token, suffix_indicator);
unk_token_id_ = FindSpecialToken(unk_token);
sep_token_id_ = FindSpecialToken(sep_token);
pad_token_id_ = FindSpecialToken(pad_token);
cls_token_id_ = FindSpecialToken(cls_token);
mask_token_id_ = FindSpecialToken(mask_token);
}
std::vector<ustring> BertTokenizer::Tokenize(const ustring& text) {
if (do_basic_tokenize_) {
return wordpiece_tokenizer_->Tokenize(basic_tokenizer_->Tokenize(text));
}
return wordpiece_tokenizer_->Tokenize(text);
}
std::vector<int64_t> BertTokenizer::Encode(const std::vector<ustring>& tokens) {
return wordpiece_tokenizer_->Encode(tokens);
}
std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids) {
std::vector<int64_t> result;
result.reserve(ids.size() + 2);
result.push_back(cls_token_id_);
result.insert(result.end(), ids.begin(), ids.end());
result.push_back(sep_token_id_);
return result;
}
std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
std::vector<int64_t> result;
result.reserve(ids1.size() + ids2.size() + 3);
result.push_back(cls_token_id_);
result.insert(result.end(), ids1.begin(), ids1.end());
result.push_back(sep_token_id_);
result.insert(result.end(), ids2.begin(), ids2.end());
result.push_back(sep_token_id_);
return result;
}
std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids) {
return std::vector<int64_t>(ids.size() + 2, 0);
}
std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
std::vector<int64_t> result;
result.reserve(ids1.size() + ids2.size() + 3);
result.insert(result.end(), ids1.size() + 2, 0);
result.insert(result.end(), ids2.size() + 1, 1);
return result;
}
int32_t BertTokenizer::FindSpecialToken(ustring token) {
auto it = vocab_->find(token);
if (it == vocab_->end()) {
ORT_CXX_API_THROW("[BertTokenizer]: can not find special tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
}
return it->second;
}
TruncateStrategy::TruncateStrategy(std::string strategy_name) {
if (strategy_name == "longest_first") {
strategy_ = TruncateStrategyType::LONGEST_FIRST;
} else if (strategy_name == "only_first") {
strategy_ = TruncateStrategyType::ONLY_FIRST;
} else if (strategy_name == "only_second") {
strategy_ = TruncateStrategyType::ONLY_SECOND;
} else if (strategy_name == "longest_from_back") {
strategy_ = TruncateStrategyType::LONGEST_FROM_BACK;
}
}
void TruncateStrategy::Truncate(std::vector<int64_t>& ids, int64_t max_len) {
if (max_len < 0 || max_len >= ids.size()) {
return;
@@ -184,7 +132,6 @@ void TruncateStrategy::Truncate(std::vector<int64_t>& ids, int64_t max_len) {
}
void TruncateStrategy::Truncate(std::vector<int64_t>& input1, std::vector<int64_t>& input2, int64_t max_len) {
if (max_len < 0 || (input1.size() + input2.size() <= max_len)) {
return;
}
@@ -224,6 +171,77 @@ void TruncateStrategy::Truncate(std::vector<int64_t>& input1, std::vector<int64_
}
}
BertTokenizer::BertTokenizer(std::string vocab, bool do_lower_case, bool do_basic_tokenize, ustring unk_token, ustring sep_token,
ustring pad_token, ustring cls_token, ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
ustring suffix_indicator) : do_basic_tokenize_(do_basic_tokenize) {
vocab_ = std::make_shared<BertTokenizerVocab>(vocab);
if (do_basic_tokenize) {
basic_tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, true, true);
}
wordpiece_tokenizer_ = std::make_shared<WordpieceTokenizer>(vocab_, unk_token, suffix_indicator);
unk_token_id_ = vocab_->FindTokenId(unk_token);
sep_token_id_ = vocab_->FindTokenId(sep_token);
pad_token_id_ = vocab_->FindTokenId(pad_token);
cls_token_id_ = vocab_->FindTokenId(cls_token);
mask_token_id_ = vocab_->FindTokenId(mask_token);
}
std::vector<ustring> BertTokenizer::Tokenize(const ustring& text) {
if (do_basic_tokenize_) {
return wordpiece_tokenizer_->Tokenize(basic_tokenizer_->Tokenize(text));
}
return wordpiece_tokenizer_->Tokenize(text);
}
std::vector<int64_t> BertTokenizer::Encode(const std::vector<ustring>& tokens) {
return wordpiece_tokenizer_->Encode(tokens);
}
std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids) {
std::vector<int64_t> result;
result.reserve(ids.size() + 2);
result.push_back(cls_token_id_);
result.insert(result.end(), ids.begin(), ids.end());
result.push_back(sep_token_id_);
return result;
}
std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
std::vector<int64_t> result;
result.reserve(ids1.size() + ids2.size() + 3);
result.push_back(cls_token_id_);
result.insert(result.end(), ids1.begin(), ids1.end());
result.push_back(sep_token_id_);
result.insert(result.end(), ids2.begin(), ids2.end());
result.push_back(sep_token_id_);
return result;
}
std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids) {
return std::vector<int64_t>(ids.size() + 2, 0);
}
std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
std::vector<int64_t> result;
result.reserve(ids1.size() + ids2.size() + 3);
result.insert(result.end(), ids1.size() + 2, 0);
result.insert(result.end(), ids2.size() + 1, 1);
return result;
}
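Given the [CLS] A [SEP] B [SEP] layout produced by AddSpecialToken above, GenerateTypeId assigns 0 to the first segment including [CLS] and the first [SEP], and 1 to the second segment including the final [SEP]. A shape-only illustration with made-up ids:

#include <cstdint>
#include <vector>

// Shape illustration for the pair overloads above; the id values are made up.
void TypeIdShapeExample() {
  std::vector<int64_t> ids1{7, 8};       // segment A
  std::vector<int64_t> ids2{9, 10, 11};  // segment B
  // AddSpecialToken(ids1, ids2): [CLS] 7 8 [SEP] 9 10 11 [SEP]  -> 8 tokens
  // GenerateTypeId(ids1, ids2):  0 0 0 0 1 1 1 1                -> 8 type ids
  (void)ids1; (void)ids2;
}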
TruncateStrategy::TruncateStrategy(std::string strategy_name) {
if (strategy_name == "longest_first") {
strategy_ = TruncateStrategyType::LONGEST_FIRST;
} else if (strategy_name == "only_first") {
strategy_ = TruncateStrategyType::ONLY_FIRST;
} else if (strategy_name == "only_second") {
strategy_ = TruncateStrategyType::ONLY_SECOND;
} else if (strategy_name == "longest_from_back") {
strategy_ = TruncateStrategyType::LONGEST_FROM_BACK;
}
}
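TruncateStrategy maps the attribute string to one of four strategies; the trimming logic itself is mostly elided in this diff. A usage sketch showing only the call shape, assuming the declaration above is in scope; the longest_first behavior noted in the comment is a hedged reading of the usual semantics, not something shown in this hunk.

#include <cstdint>
#include <vector>

// Usage sketch; only the call shape is shown, the trimming bodies are elided above.
void TruncateExample() {
  TruncateStrategy truncate("longest_first");

  std::vector<int64_t> ids1{1, 2, 3, 4, 5};
  std::vector<int64_t> ids2{6, 7};
  // Trim the pair so ids1.size() + ids2.size() <= 6; with longest_first the
  // longer input is expected to lose tokens first.
  truncate.Truncate(ids1, ids2, 6);
}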
KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
@@ -239,9 +257,8 @@ KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info)
std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name", std::string("longest_first"));
max_length_ = TryToGetAttributeWithDefault("max_length", int64_t(-1));
tokenizer_ = std::make_shared<BertTokenizer>(vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
ustring(sep_token), ustring(pad_token),ustring(cls_token),
ustring(sep_token), ustring(pad_token), ustring(cls_token),
ustring(mask_token), tokenize_chinese_chars, strip_accents, ustring(suffix_indicator));
truncate_ = std::make_shared<TruncateStrategy>(truncation_strategy_name);
@@ -307,5 +324,3 @@ size_t CustomOpBertTokenizer::GetOutputTypeCount() const {
ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};

View File

@@ -11,51 +11,18 @@
#include "string_tensor.h"
#include "basic_tokenizer.hpp"
// TODO: merge with the implementation of word piece tokenizer
class WordpieceTokenizer{
class BertTokenizerVocab {
public:
WordpieceTokenizer(std::shared_ptr<std::unordered_map<ustring, int32_t>> vocab, ustring unk_token, ustring suffix_indicator, int max_input_chars_per_word = 100);
std::vector<ustring> Tokenize(const ustring& text);
std::vector<ustring> Tokenize(const std::vector<ustring>& tokens);
std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
explicit BertTokenizerVocab(std::string vocab);
bool FindToken(const ustring& token);
bool FindTokenId(const ustring& token, int32_t& token_id);
int32_t FindTokenId(const ustring& token);
private:
int64_t max_input_chars_per_word_;
ustring suffix_indicator_;
ustring unk_token_;
int64_t unk_token_id_;
std::shared_ptr<std::unordered_map<ustring, int32_t>> vocab_;
void GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result);
std::string raw_vocab_;
std::unordered_map<std::string_view, int32_t> vocab_;
};
class BertTokenizer {
public:
BertTokenizer(std::string vocab, bool do_lower_case, bool do_basic_tokenize,
ustring unk_token, ustring sep_token, ustring pad_token, ustring cls_token,
ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
ustring suffix_indicator);
std::vector<ustring> Tokenize(const ustring& text);
std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids);
std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);
std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids);
std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);
private:
int32_t unk_token_id_;
int32_t sep_token_id_;
int32_t pad_token_id_;
int32_t cls_token_id_;
int32_t mask_token_id_;
bool do_basic_tokenize_;
std::shared_ptr<std::unordered_map<ustring, int32_t>> vocab_;
std::shared_ptr<BasicTokenizer> basic_tokenizer_;
std::shared_ptr<WordpieceTokenizer> wordpiece_tokenizer_;
int32_t FindSpecialToken(ustring token);
};
class TruncateStrategy {
public:
explicit TruncateStrategy(std::string strategy_name);
@@ -63,18 +30,61 @@ class TruncateStrategy {
void Truncate(std::vector<int64_t>& input1, std::vector<int64_t>& input2, int64_t max_len);
private:
enum TruncateStrategyType{
enum TruncateStrategyType {
LONGEST_FIRST,
ONLY_FIRST,
ONLY_SECOND,
LONGEST_FROM_BACK
}strategy_;
} strategy_;
};
// TODO: merge with the implementation of word piece tokenizer
class WordpieceTokenizer {
public:
WordpieceTokenizer(std::shared_ptr<BertTokenizerVocab> vocab, ustring unk_token, ustring suffix_indicator, int max_input_chars_per_word = 100);
std::vector<ustring> Tokenize(const ustring& text);
std::vector<ustring> Tokenize(const std::vector<ustring>& tokens);
std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
private:
int64_t max_input_chars_per_word_;
ustring suffix_indicator_;
ustring unk_token_;
int32_t unk_token_id_;
std::shared_ptr<BertTokenizerVocab> vocab_;
void GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result);
};
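WordpieceTokenizer::GreedySearch, shown partially in the diff above, repeatedly takes the longest vocab entry matching at the current position, prefixes non-initial pieces with the suffix indicator (typically "##"), and falls back to the unknown token when nothing matches. A minimal standalone sketch of that longest-match loop over plain std::string, independent of the ustring and vocab classes in this file:

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Minimal greedy longest-match wordpiece split; mirrors the shape of
// WordpieceTokenizer::GreedySearch but uses std::string for brevity.
std::vector<std::string> GreedyWordpiece(const std::string& token,
                                         const std::unordered_set<std::string>& vocab,
                                         const std::string& suffix = "##",
                                         const std::string& unk = "[UNK]") {
  std::vector<std::string> pieces;
  size_t start = 0;
  while (start < token.size()) {
    size_t end = token.size();
    bool found = false;
    // Shrink the candidate from the right until it appears in the vocab.
    for (; start < end; --end) {
      std::string substr = token.substr(start, end - start);
      if (start > 0) substr = suffix + substr;  // mark non-initial pieces
      if (vocab.count(substr)) {
        pieces.push_back(substr);
        found = true;
        break;
      }
    }
    if (!found) return {unk};  // no piece matched: emit the unknown token
    start = end;
  }
  return pieces;
}

int main() {
  std::unordered_set<std::string> vocab{"un", "##aff", "##able", "[UNK]"};
  for (const auto& p : GreedyWordpiece("unaffable", vocab))
    std::cout << p << ' ';  // prints: un ##aff ##able
}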
class BertTokenizer {
public:
BertTokenizer(std::string vocab, bool do_lower_case, bool do_basic_tokenize,
ustring unk_token, ustring sep_token, ustring pad_token, ustring cls_token,
ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
ustring suffix_indicator);
std::vector<ustring> Tokenize(const ustring& text);
std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids);
std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);
std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids);
std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);
private:
int32_t unk_token_id_;
int32_t sep_token_id_;
int32_t pad_token_id_;
int32_t cls_token_id_;
int32_t mask_token_id_;
bool do_basic_tokenize_;
std::shared_ptr<BertTokenizerVocab> vocab_;
std::shared_ptr<BasicTokenizer> basic_tokenizer_;
std::shared_ptr<WordpieceTokenizer> wordpiece_tokenizer_;
};
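Putting the pieces together, a typical call sequence for the BertTokenizer declared above is Tokenize, Encode, AddSpecialToken, then GenerateTypeId. The sketch below is illustrative only: the toy vocab, the constructor argument values, and the ids noted in the comments are assumptions, not values taken from this commit.

#include <cstdint>
#include <string>
#include <vector>

// End-to-end usage sketch for the classes declared above.
void BertTokenizerExample() {
  std::string vocab = "[PAD]\n[UNK]\n[CLS]\n[SEP]\n[MASK]\nhello\nworld";

  BertTokenizer tokenizer(vocab,
                          /*do_lower_case=*/true,
                          /*do_basic_tokenize=*/true,
                          ustring("[UNK]"), ustring("[SEP]"), ustring("[PAD]"),
                          ustring("[CLS]"), ustring("[MASK]"),
                          /*tokenize_chinese_chars=*/true,
                          /*strip_accents=*/false,
                          /*suffix_indicator=*/ustring("##"));

  std::vector<ustring> tokens = tokenizer.Tokenize(ustring("hello world"));
  std::vector<int64_t> ids = tokenizer.Encode(tokens);               // [5, 6] with this toy vocab
  std::vector<int64_t> input_ids = tokenizer.AddSpecialToken(ids);   // [CLS] hello world [SEP]
  std::vector<int64_t> type_ids = tokenizer.GenerateTypeId(ids);     // 0 0 0 0
  (void)input_ids; (void)type_ids;
}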
struct KernelBertTokenizer : BaseKernel {
KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info);
KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
private:
std::shared_ptr<BertTokenizer> tokenizer_;
std::shared_ptr<TruncateStrategy> truncate_;