optimize spm tokenizer for long text (#799)

* optimize spm tokenizer for long text

* refine the split logic

* re-trigger CI pipeline.
This commit is contained in:
Wenbing Li 2024-08-30 14:58:40 -07:00 коммит произвёл GitHub
Родитель 6f532376c9
Коммит b8b2ebfb85
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
1 изменённых файлов: 53 добавлений и 21 удалений

Просмотреть файл

@ -416,7 +416,53 @@ std::vector<int64_t> KernelBpeTokenizer::SpmTokenize(ustring& input,
// Get byte encodings prior to performing BPE
std::list<std::pair<uint32_t, uint32_t>> byte_list;
while (res.size() < max_length && char_pos < ustr.length()) {
while (res.size() < max_length && char_pos <= ustr.length()) {
bool split_now = false;
if (char_pos == ustr.length()) {
split_now = true;
}
// temporary split logic, will be replaced regex based split after it is implemented
if (!split_now && byte_list.size() > 10) {
auto is_split_char = [](char32_t ch) {
return ch == U' ' || ch == U'\n' || ch == U'\r' || ch == U'';
};
if (!is_split_char(ustr[char_pos - 1]) && is_split_char(ustr[char_pos])) {
split_now = true;
}
// split immediately to avoid too long byte_list for extreme cases, which is slow.
if (!split_now && byte_list.size() > 100) {
split_now = true;
}
}
if (split_now) {
// Perform BPE
bbpe_tokenizer_->PerformBPE(byte_list);
// Add output to result
for (auto p : byte_list) {
if (res.size() >= max_length) {
break;
}
res.push_back(p.first);
if (compute_offset_mapping) {
offset_mapping.emplace_back(std::make_pair(
offset,
ort_extensions::narrow<size_t>(offset + (size_t)p.second)));
offset += ((size_t)p.second);
}
}
byte_list.clear();
}
if (char_pos == ustr.length()) {
break;
}
auto chr = ustr[char_pos];
if (chr == U' ') {
chr = 0x2581; // UTF-8 string '\xe2\x96\x81'
@ -436,26 +482,6 @@ std::vector<int64_t> KernelBpeTokenizer::SpmTokenize(ustring& input,
char_pos++;
}
{
// Perform BPE
bbpe_tokenizer_->PerformBPE(byte_list);
// Add output to result
for (auto p : byte_list) {
if (res.size() >= max_length) {
break;
}
res.push_back(p.first);
if (compute_offset_mapping) {
offset_mapping.emplace_back(std::make_pair(
offset,
ort_extensions::narrow<size_t>(offset + (size_t)p.second)));
offset += ((size_t)p.second);
}
}
}
if (compute_offset_mapping) {
// Add offset mappings for input in this instance to list of offset mappings for all inputs
@ -463,6 +489,12 @@ std::vector<int64_t> KernelBpeTokenizer::SpmTokenize(ustring& input,
}
}
if (res.size() > 0 && res.front() == bos_token_id_) {
if (add_bos_token_.has_value() && add_bos_token_.value() == false) {
res.erase(res.begin());
}
}
return res;
}