Reduce bert tokenize memory usage (#156)
* add BertTokenizerVocab
* improve format

Co-authored-by: Ze Tao <zetao@microsoft.com>
Parent: d8cdb8e042
Commit: 2d6cf0b4ea
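The gist of the change: the old code built a std::unordered_map<ustring, int32_t> whose keys are separate ustring copies of every vocabulary entry (ustring is a UTF-32 string in this codebase), while the new BertTokenizerVocab keeps the vocab file content in one std::string (raw_vocab_) and maps std::string_view keys that point into that buffer, so no per-token key copy is made. A minimal stand-alone sketch of the same idea, assuming a newline-separated vocab; the TinyVocab class below is illustrative only, not code from this commit:

// Hypothetical sketch of the memory-saving layout used by BertTokenizerVocab:
// one owning string plus a map of string_views into it.
#include <cstdint>
#include <iostream>
#include <string>
#include <string_view>
#include <unordered_map>

class TinyVocab {
 public:
  explicit TinyVocab(std::string raw) : raw_(std::move(raw)) {
    // Each key is a string_view into raw_; raw_ is owned by this object and
    // never modified afterwards, so the views stay valid.
    size_t start = 0;
    int32_t id = 0;
    while (start <= raw_.size()) {
      size_t end = raw_.find('\n', start);
      if (end == std::string::npos) end = raw_.size();
      vocab_.emplace(std::string_view(raw_).substr(start, end - start), id++);
      start = end + 1;
    }
  }

  // Returns false instead of throwing when the token is unknown.
  bool FindTokenId(std::string_view token, int32_t& id) const {
    auto it = vocab_.find(token);
    if (it == vocab_.end()) return false;
    id = it->second;
    return true;
  }

 private:
  std::string raw_;                                      // single owning buffer
  std::unordered_map<std::string_view, int32_t> vocab_;  // views into raw_
};

int main() {
  TinyVocab vocab("[PAD]\n[UNK]\nhello\nworld");
  int32_t id = -1;
  if (vocab.FindTokenId("hello", id)) std::cout << id << "\n";  // prints 2
}

The one property such a layout must guarantee is that the owning string outlives (and is never reallocated under) the string_views; that holds here because both are members of the same object and the buffer is not modified after construction.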
@@ -2,14 +2,47 @@
+#include <utility>
 
-WordpieceTokenizer::WordpieceTokenizer(std::shared_ptr<std::unordered_map<ustring, int32_t>> vocab, ustring unk_token,
-                                       ustring suffix_indicator, int max_input_chars_per_word): vocab_(std::move(vocab)), unk_token_(unk_token),
-                                       suffix_indicator_(std::move(suffix_indicator)), max_input_chars_per_word_(max_input_chars_per_word) {
-  auto it = vocab_->find(unk_token);
-  if (it == vocab_->end()) {
-    ORT_CXX_API_THROW("[WordpieceTokenizer]: can not find unk_token in vocal", ORT_RUNTIME_EXCEPTION);
+BertTokenizerVocab::BertTokenizerVocab(std::string vocab) : raw_vocab_(vocab) {
+  auto tokens = SplitString(raw_vocab_, "\n", true);
+
+  for (int i = 0; i < tokens.size(); i++) {
+    (vocab_)[tokens[i]] = i;
   }
-  unk_token_id_ = it->second;
+}
+
+bool BertTokenizerVocab::FindToken(const ustring& token) {
+  auto utf8_token = std::string(token);
+
+  return vocab_.find(utf8_token) != vocab_.end();
+}
+
+bool BertTokenizerVocab::FindTokenId(const ustring& token, int32_t& token_id) {
+  auto utf8_token = std::string(token);
+
+  auto it = vocab_.find(utf8_token);
+  if (it == vocab_.end()) {
+    return false;
+  }
+
+  token_id = it->second;
+  return true;
+}
+
+int32_t BertTokenizerVocab::FindTokenId(const ustring& token) {
+  auto utf8_token = std::string(token);
+
+  auto it = vocab_.find(utf8_token);
+  if (it == vocab_.end()) {
+    ORT_CXX_API_THROW("[BertTokenizerVocab]: can not find tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
+  }
+
+  return it->second;
+}
+
+WordpieceTokenizer::WordpieceTokenizer(std::shared_ptr<BertTokenizerVocab> vocab, ustring unk_token,
+                                       ustring suffix_indicator, int max_input_chars_per_word): vocab_(std::move(vocab)), unk_token_(std::move(unk_token)),
+                                       suffix_indicator_(std::move(suffix_indicator)), max_input_chars_per_word_(max_input_chars_per_word) {
+  unk_token_id_ = vocab_->FindTokenId(unk_token_);
 }
 
 std::vector<ustring> WordpieceTokenizer::Tokenize(const ustring& text) {
@@ -44,13 +77,13 @@ std::vector<ustring> WordpieceTokenizer::Tokenize(const std::vector<ustring>& to
 std::vector<int64_t> WordpieceTokenizer::Encode(const std::vector<ustring>& tokens) {
   std::vector<int64_t> ids;
   for (const auto& token : tokens) {
-    auto it = vocab_->find(token);
-    if (it == vocab_->end()) {
+    int32_t token_id = -1;
+    if (!vocab_->FindTokenId(token, token_id)) {
       ids.push_back(unk_token_id_);
       continue;
     }
 
-    ids.push_back(it->second);
+    ids.push_back(token_id);
   }
   return ids;
 }
@@ -67,14 +100,13 @@ void WordpieceTokenizer::GreedySearch(const ustring& token, std::vector<ustring>
   for (; start < token.size();) {
     end = token.size();
     bool is_found = false;
-    // try to found longest matched sub-token in vocab
+    // try to found the longest matched sub-token in vocab
     for (; start < end;) {
       substr = static_cast<const ustring>(token.substr(start, end - start));
       if (start > 0) {
         substr = static_cast<const ustring>(suffix_indicator_ + substr);
       }
-      auto it = vocab_->find(substr);
-      if (it != vocab_->end()) {
+      if (vocab_->FindToken(substr)) {
         is_found = true;
         break;
       }
@@ -91,90 +123,6 @@ void WordpieceTokenizer::GreedySearch(const ustring& token, std::vector<ustring>
   }
 }
 
-BertTokenizer::BertTokenizer(std::string vocab, bool do_lower_case, bool do_basic_tokenize, ustring unk_token, ustring sep_token,
-                             ustring pad_token, ustring cls_token, ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
-                             ustring suffix_indicator) : do_basic_tokenize_(do_basic_tokenize) {
-  auto tokens = SplitString(vocab, "\n", true);
-
-  vocab_ = std::make_shared<std::unordered_map<ustring, int32_t>>();
-  for (int i = 0; i < tokens.size(); i++) {
-    (*vocab_)[ustring(tokens[i])] = i;
-  }
-
-  if (do_basic_tokenize) {
-    basic_tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, true, true);
-  }
-  wordpiece_tokenizer_ = std::make_shared<WordpieceTokenizer>(vocab_, unk_token, suffix_indicator);
-
-  unk_token_id_ = FindSpecialToken(unk_token);
-  sep_token_id_ = FindSpecialToken(sep_token);
-  pad_token_id_ = FindSpecialToken(pad_token);
-  cls_token_id_ = FindSpecialToken(cls_token);
-  mask_token_id_ = FindSpecialToken(mask_token);
-}
-std::vector<ustring> BertTokenizer::Tokenize(const ustring& text) {
-  if (do_basic_tokenize_) {
-    return wordpiece_tokenizer_->Tokenize(basic_tokenizer_->Tokenize(text));
-  }
-  return wordpiece_tokenizer_->Tokenize(text);
-}
-
-std::vector<int64_t> BertTokenizer::Encode(const std::vector<ustring>& tokens) {
-  return wordpiece_tokenizer_->Encode(tokens);
-}
-
-std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids) {
-  std::vector<int64_t> result;
-  result.reserve(ids.size() + 2);
-  result.push_back(cls_token_id_);
-  result.insert(result.end(), ids.begin(), ids.end());
-  result.push_back(sep_token_id_);
-  return result;
-}
-
-std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
-  std::vector<int64_t> result;
-  result.reserve(ids1.size() + ids2.size() + 3);
-  result.push_back(cls_token_id_);
-  result.insert(result.end(), ids1.begin(), ids1.end());
-  result.push_back(sep_token_id_);
-  result.insert(result.end(), ids2.begin(), ids2.end());
-  result.push_back(sep_token_id_);
-  return result;
-}
-
-std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids) {
-  return std::vector<int64_t>(ids.size() + 2, 0);
-}
-
-std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
-  std::vector<int64_t> result;
-  result.reserve(ids1.size() + ids2.size() + 3);
-  result.insert(result.end(), ids1.size() + 2, 0);
-  result.insert(result.end(), ids2.size() + 1, 1);
-  return result;
-}
-
-int32_t BertTokenizer::FindSpecialToken(ustring token) {
-  auto it = vocab_->find(token);
-  if (it == vocab_->end()) {
-    ORT_CXX_API_THROW("[BertTokenizer]: can not find special tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
-  }
-  return it->second;
-}
-
-TruncateStrategy::TruncateStrategy(std::string strategy_name) {
-  if (strategy_name == "longest_first") {
-    strategy_ = TruncateStrategyType::LONGEST_FIRST;
-  } else if (strategy_name == "only_first") {
-    strategy_ = TruncateStrategyType::ONLY_FIRST;
-  } else if (strategy_name == "only_second") {
-    strategy_ = TruncateStrategyType::ONLY_SECOND;
-  } else if (strategy_name == "longest_from_back") {
-    strategy_ = TruncateStrategyType::LONGEST_FROM_BACK;
-  }
-}
-
 void TruncateStrategy::Truncate(std::vector<int64_t>& ids, int64_t max_len) {
   if (max_len < 0 || max_len >= ids.size()) {
     return;
@@ -184,7 +132,6 @@ void TruncateStrategy::Truncate(std::vector<int64_t>& ids, int64_t max_len) {
 }
 
 void TruncateStrategy::Truncate(std::vector<int64_t>& input1, std::vector<int64_t>& input2, int64_t max_len) {
-
   if (max_len < 0 || (input1.size() + input2.size() <= max_len)) {
     return;
   }
@@ -224,6 +171,77 @@ void TruncateStrategy::Truncate(std::vector<int64_t>& input1, std::vector<int64_
   }
 }
 
+BertTokenizer::BertTokenizer(std::string vocab, bool do_lower_case, bool do_basic_tokenize, ustring unk_token, ustring sep_token,
+                             ustring pad_token, ustring cls_token, ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
+                             ustring suffix_indicator) : do_basic_tokenize_(do_basic_tokenize) {
+  vocab_ = std::make_shared<BertTokenizerVocab>(vocab);
+
+  if (do_basic_tokenize) {
+    basic_tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, true, true);
+  }
+  wordpiece_tokenizer_ = std::make_shared<WordpieceTokenizer>(vocab_, unk_token, suffix_indicator);
+
+  unk_token_id_ = vocab_->FindTokenId(unk_token);
+  sep_token_id_ = vocab_->FindTokenId(sep_token);
+  pad_token_id_ = vocab_->FindTokenId(pad_token);
+  cls_token_id_ = vocab_->FindTokenId(cls_token);
+  mask_token_id_ = vocab_->FindTokenId(mask_token);
+}
+std::vector<ustring> BertTokenizer::Tokenize(const ustring& text) {
+  if (do_basic_tokenize_) {
+    return wordpiece_tokenizer_->Tokenize(basic_tokenizer_->Tokenize(text));
+  }
+  return wordpiece_tokenizer_->Tokenize(text);
+}
+
+std::vector<int64_t> BertTokenizer::Encode(const std::vector<ustring>& tokens) {
+  return wordpiece_tokenizer_->Encode(tokens);
+}
+
+std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids) {
+  std::vector<int64_t> result;
+  result.reserve(ids.size() + 2);
+  result.push_back(cls_token_id_);
+  result.insert(result.end(), ids.begin(), ids.end());
+  result.push_back(sep_token_id_);
+  return result;
+}
+
+std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
+  std::vector<int64_t> result;
+  result.reserve(ids1.size() + ids2.size() + 3);
+  result.push_back(cls_token_id_);
+  result.insert(result.end(), ids1.begin(), ids1.end());
+  result.push_back(sep_token_id_);
+  result.insert(result.end(), ids2.begin(), ids2.end());
+  result.push_back(sep_token_id_);
+  return result;
+}
+
+std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids) {
+  return std::vector<int64_t>(ids.size() + 2, 0);
+}
+
+std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
+  std::vector<int64_t> result;
+  result.reserve(ids1.size() + ids2.size() + 3);
+  result.insert(result.end(), ids1.size() + 2, 0);
+  result.insert(result.end(), ids2.size() + 1, 1);
+  return result;
+}
+
+TruncateStrategy::TruncateStrategy(std::string strategy_name) {
+  if (strategy_name == "longest_first") {
+    strategy_ = TruncateStrategyType::LONGEST_FIRST;
+  } else if (strategy_name == "only_first") {
+    strategy_ = TruncateStrategyType::ONLY_FIRST;
+  } else if (strategy_name == "only_second") {
+    strategy_ = TruncateStrategyType::ONLY_SECOND;
+  } else if (strategy_name == "longest_from_back") {
+    strategy_ = TruncateStrategyType::LONGEST_FROM_BACK;
+  }
+}
+
 KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
   std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
   bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
@@ -239,9 +257,8 @@ KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info)
   std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name", std::string("longest_first"));
   max_length_ = TryToGetAttributeWithDefault("max_length", int64_t(-1));
 
-
   tokenizer_ = std::make_shared<BertTokenizer>(vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
-                                               ustring(sep_token), ustring(pad_token),ustring(cls_token),
+                                               ustring(sep_token), ustring(pad_token), ustring(cls_token),
                                                ustring(mask_token), tokenize_chinese_chars, strip_accents, ustring(suffix_indicator));
 
   truncate_ = std::make_shared<TruncateStrategy>(truncation_strategy_name);
@@ -307,5 +324,3 @@ size_t CustomOpBertTokenizer::GetOutputTypeCount() const {
 ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /*index*/) const {
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
 };
-
-
@@ -11,51 +11,18 @@
 #include "string_tensor.h"
 #include "basic_tokenizer.hpp"
 
-// TODO: merge with the implementation of word piece tokenizer
-class WordpieceTokenizer{
+class BertTokenizerVocab {
  public:
-  WordpieceTokenizer(std::shared_ptr<std::unordered_map<ustring, int32_t>> vocab, ustring unk_token, ustring suffix_indicator, int max_input_chars_per_word = 100);
-  std::vector<ustring> Tokenize(const ustring& text);
-  std::vector<ustring> Tokenize(const std::vector<ustring>& tokens);
-  std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
+  explicit BertTokenizerVocab(std::string vocab);
+  bool FindToken(const ustring& token);
+  bool FindTokenId(const ustring& token, int32_t& token_id);
+  int32_t FindTokenId(const ustring& token);
 
  private:
-  int64_t max_input_chars_per_word_;
-  ustring suffix_indicator_;
-  ustring unk_token_;
-  int64_t unk_token_id_;
-  std::shared_ptr<std::unordered_map<ustring, int32_t>> vocab_;
-
-  void GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result);
+  std::string raw_vocab_;
+  std::unordered_map<std::string_view, int32_t> vocab_;
 };
 
-class BertTokenizer {
- public:
-  BertTokenizer(std::string vocab, bool do_lower_case, bool do_basic_tokenize,
-                ustring unk_token, ustring sep_token, ustring pad_token, ustring cls_token,
-                ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
-                ustring suffix_indicator);
-  std::vector<ustring> Tokenize(const ustring& text);
-  std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
-  std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids);
-  std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);
-  std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids);
-  std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);
- private:
-  int32_t unk_token_id_;
-  int32_t sep_token_id_;
-  int32_t pad_token_id_;
-  int32_t cls_token_id_;
-  int32_t mask_token_id_;
-  bool do_basic_tokenize_;
-  std::shared_ptr<std::unordered_map<ustring, int32_t>> vocab_;
-  std::shared_ptr<BasicTokenizer> basic_tokenizer_;
-  std::shared_ptr<WordpieceTokenizer> wordpiece_tokenizer_;
-
-  int32_t FindSpecialToken(ustring token);
-};
-
-
 class TruncateStrategy {
  public:
   explicit TruncateStrategy(std::string strategy_name);
@@ -63,18 +30,61 @@ class TruncateStrategy {
   void Truncate(std::vector<int64_t>& input1, std::vector<int64_t>& input2, int64_t max_len);
 
  private:
-  enum TruncateStrategyType{
+  enum TruncateStrategyType {
    LONGEST_FIRST,
    ONLY_FIRST,
    ONLY_SECOND,
    LONGEST_FROM_BACK
-  }strategy_;
+  } strategy_;
 };
 
+// TODO: merge with the implementation of word piece tokenizer
+class WordpieceTokenizer {
+ public:
+  WordpieceTokenizer(std::shared_ptr<BertTokenizerVocab> vocab, ustring unk_token, ustring suffix_indicator, int max_input_chars_per_word = 100);
+  std::vector<ustring> Tokenize(const ustring& text);
+  std::vector<ustring> Tokenize(const std::vector<ustring>& tokens);
+  std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
+
+ private:
+  int64_t max_input_chars_per_word_;
+  ustring suffix_indicator_;
+  ustring unk_token_;
+  int32_t unk_token_id_;
+  std::shared_ptr<BertTokenizerVocab> vocab_;
+
+  void GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result);
+};
+
+class BertTokenizer {
+ public:
+  BertTokenizer(std::string vocab, bool do_lower_case, bool do_basic_tokenize,
+                ustring unk_token, ustring sep_token, ustring pad_token, ustring cls_token,
+                ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
+                ustring suffix_indicator);
+  std::vector<ustring> Tokenize(const ustring& text);
+  std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
+  std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids);
+  std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);
+  std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids);
+  std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);
+
+ private:
+  int32_t unk_token_id_;
+  int32_t sep_token_id_;
+  int32_t pad_token_id_;
+  int32_t cls_token_id_;
+  int32_t mask_token_id_;
+  bool do_basic_tokenize_;
+  std::shared_ptr<BertTokenizerVocab> vocab_;
+  std::shared_ptr<BasicTokenizer> basic_tokenizer_;
+  std::shared_ptr<WordpieceTokenizer> wordpiece_tokenizer_;
+};
+
 struct KernelBertTokenizer : BaseKernel {
   KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info);
   void Compute(OrtKernelContext* context);
 
  private:
   std::shared_ptr<BertTokenizer> tokenizer_;
   std::shared_ptr<TruncateStrategy> truncate_;
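For reference, the AddSpecialToken / GenerateTypeId pair that this commit moves later in the implementation file lays a sentence pair out as [CLS] A [SEP] B [SEP], with ids1.size() + 2 segment ids of 0 followed by ids2.size() + 1 segment ids of 1. A small stand-alone illustration of that layout (the cls/sep and word ids below are made up for the example, not taken from a real vocab):

// Hypothetical illustration of the id layout produced by
// BertTokenizer::AddSpecialToken / GenerateTypeId for a sentence pair.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t cls_id = 101, sep_id = 102;        // assumed special-token ids
  std::vector<int64_t> ids1 = {7592, 2088};        // first sentence (made-up ids)
  std::vector<int64_t> ids2 = {2129, 2024, 2017};  // second sentence (made-up ids)

  // AddSpecialToken(ids1, ids2): [CLS] ids1 [SEP] ids2 [SEP]
  std::vector<int64_t> input_ids;
  input_ids.reserve(ids1.size() + ids2.size() + 3);
  input_ids.push_back(cls_id);
  input_ids.insert(input_ids.end(), ids1.begin(), ids1.end());
  input_ids.push_back(sep_id);
  input_ids.insert(input_ids.end(), ids2.begin(), ids2.end());
  input_ids.push_back(sep_id);

  // GenerateTypeId(ids1, ids2): (|ids1| + 2) zeros, then (|ids2| + 1) ones
  std::vector<int64_t> type_ids;
  type_ids.insert(type_ids.end(), ids1.size() + 2, 0);
  type_ids.insert(type_ids.end(), ids2.size() + 1, 1);

  for (auto v : input_ids) std::cout << v << ' ';  // 101 7592 2088 102 2129 2024 2017 102
  std::cout << '\n';
  for (auto v : type_ids) std::cout << v << ' ';   // 0 0 0 0 1 1 1 1
  std::cout << '\n';
}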