optimize spm tokenizer for long text (#799)
* optimize spm tokenizer for long text * refine the split logic * re-trigger CI pipeline.
This commit is contained in:
Родитель
6f532376c9
Коммит
b8b2ebfb85
|
@ -416,7 +416,53 @@ std::vector<int64_t> KernelBpeTokenizer::SpmTokenize(ustring& input,
|
|||
// Get byte encodings prior to performing BPE
|
||||
std::list<std::pair<uint32_t, uint32_t>> byte_list;
|
||||
|
||||
while (res.size() < max_length && char_pos < ustr.length()) {
|
||||
while (res.size() < max_length && char_pos <= ustr.length()) {
|
||||
bool split_now = false;
|
||||
if (char_pos == ustr.length()) {
|
||||
split_now = true;
|
||||
}
|
||||
|
||||
// temporary split logic; will be replaced by a regex-based split once it is implemented
|
||||
if (!split_now && byte_list.size() > 10) {
|
||||
auto is_split_char = [](char32_t ch) {
|
||||
return ch == U' ' || ch == U'\n' || ch == U'\r' || ch == U'▁';
|
||||
};
|
||||
if (!is_split_char(ustr[char_pos - 1]) && is_split_char(ustr[char_pos])) {
|
||||
split_now = true;
|
||||
}
|
||||
// split immediately to avoid an overly long byte_list in extreme cases, which is slow.
|
||||
if (!split_now && byte_list.size() > 100) {
|
||||
split_now = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (split_now) {
|
||||
// Perform BPE
|
||||
bbpe_tokenizer_->PerformBPE(byte_list);
|
||||
|
||||
// Add output to result
|
||||
for (auto p : byte_list) {
|
||||
if (res.size() >= max_length) {
|
||||
break;
|
||||
}
|
||||
|
||||
res.push_back(p.first);
|
||||
|
||||
if (compute_offset_mapping) {
|
||||
offset_mapping.emplace_back(std::make_pair(
|
||||
offset,
|
||||
ort_extensions::narrow<size_t>(offset + (size_t)p.second)));
|
||||
offset += ((size_t)p.second);
|
||||
}
|
||||
}
|
||||
|
||||
byte_list.clear();
|
||||
}
|
||||
|
||||
if (char_pos == ustr.length()) {
|
||||
break;
|
||||
}
|
||||
|
||||
auto chr = ustr[char_pos];
|
||||
if (chr == U' ') {
|
||||
chr = 0x2581; // UTF-8 string '\xe2\x96\x81'
|
||||
|
@ -436,26 +482,6 @@ std::vector<int64_t> KernelBpeTokenizer::SpmTokenize(ustring& input,
|
|||
|
||||
char_pos++;
|
||||
}
|
||||
{
|
||||
// Perform BPE
|
||||
bbpe_tokenizer_->PerformBPE(byte_list);
|
||||
|
||||
// Add output to result
|
||||
for (auto p : byte_list) {
|
||||
if (res.size() >= max_length) {
|
||||
break;
|
||||
}
|
||||
|
||||
res.push_back(p.first);
|
||||
|
||||
if (compute_offset_mapping) {
|
||||
offset_mapping.emplace_back(std::make_pair(
|
||||
offset,
|
||||
ort_extensions::narrow<size_t>(offset + (size_t)p.second)));
|
||||
offset += ((size_t)p.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (compute_offset_mapping) {
|
||||
// Add offset mappings for input in this instance to list of offset mappings for all inputs
|
||||
|
@ -463,6 +489,12 @@ std::vector<int64_t> KernelBpeTokenizer::SpmTokenize(ustring& input,
|
|||
}
|
||||
}
|
||||
|
||||
if (res.size() > 0 && res.front() == bos_token_id_) {
|
||||
if (add_bos_token_.has_value() && add_bos_token_.value() == false) {
|
||||
res.erase(res.begin());
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче