optimize spm tokenizer for long text (#799)
* optimize spm tokenizer for long text * refine the split logic * re-trigger CI pipeline.
This commit is contained in:
Родитель
6f532376c9
Коммит
b8b2ebfb85
|
@ -416,7 +416,53 @@ std::vector<int64_t> KernelBpeTokenizer::SpmTokenize(ustring& input,
|
||||||
// Get byte encodings prior to performing BPE
|
// Get byte encodings prior to performing BPE
|
||||||
std::list<std::pair<uint32_t, uint32_t>> byte_list;
|
std::list<std::pair<uint32_t, uint32_t>> byte_list;
|
||||||
|
|
||||||
while (res.size() < max_length && char_pos < ustr.length()) {
|
while (res.size() < max_length && char_pos <= ustr.length()) {
|
||||||
|
bool split_now = false;
|
||||||
|
if (char_pos == ustr.length()) {
|
||||||
|
split_now = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// temporary split logic, will be replaced regex based split after it is implemented
|
||||||
|
if (!split_now && byte_list.size() > 10) {
|
||||||
|
auto is_split_char = [](char32_t ch) {
|
||||||
|
return ch == U' ' || ch == U'\n' || ch == U'\r' || ch == U'▁';
|
||||||
|
};
|
||||||
|
if (!is_split_char(ustr[char_pos - 1]) && is_split_char(ustr[char_pos])) {
|
||||||
|
split_now = true;
|
||||||
|
}
|
||||||
|
// split immediately to avoid too long byte_list for extreme cases, which is slow.
|
||||||
|
if (!split_now && byte_list.size() > 100) {
|
||||||
|
split_now = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (split_now) {
|
||||||
|
// Perform BPE
|
||||||
|
bbpe_tokenizer_->PerformBPE(byte_list);
|
||||||
|
|
||||||
|
// Add output to result
|
||||||
|
for (auto p : byte_list) {
|
||||||
|
if (res.size() >= max_length) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
res.push_back(p.first);
|
||||||
|
|
||||||
|
if (compute_offset_mapping) {
|
||||||
|
offset_mapping.emplace_back(std::make_pair(
|
||||||
|
offset,
|
||||||
|
ort_extensions::narrow<size_t>(offset + (size_t)p.second)));
|
||||||
|
offset += ((size_t)p.second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
byte_list.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (char_pos == ustr.length()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
auto chr = ustr[char_pos];
|
auto chr = ustr[char_pos];
|
||||||
if (chr == U' ') {
|
if (chr == U' ') {
|
||||||
chr = 0x2581; // UTF-8 string '\xe2\x96\x81'
|
chr = 0x2581; // UTF-8 string '\xe2\x96\x81'
|
||||||
|
@ -436,26 +482,6 @@ std::vector<int64_t> KernelBpeTokenizer::SpmTokenize(ustring& input,
|
||||||
|
|
||||||
char_pos++;
|
char_pos++;
|
||||||
}
|
}
|
||||||
{
|
|
||||||
// Perform BPE
|
|
||||||
bbpe_tokenizer_->PerformBPE(byte_list);
|
|
||||||
|
|
||||||
// Add output to result
|
|
||||||
for (auto p : byte_list) {
|
|
||||||
if (res.size() >= max_length) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
res.push_back(p.first);
|
|
||||||
|
|
||||||
if (compute_offset_mapping) {
|
|
||||||
offset_mapping.emplace_back(std::make_pair(
|
|
||||||
offset,
|
|
||||||
ort_extensions::narrow<size_t>(offset + (size_t)p.second)));
|
|
||||||
offset += ((size_t)p.second);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (compute_offset_mapping) {
|
if (compute_offset_mapping) {
|
||||||
// Add offset mappings for input in this instance to list of offset mappings for all inputs
|
// Add offset mappings for input in this instance to list of offset mappings for all inputs
|
||||||
|
@ -463,6 +489,12 @@ std::vector<int64_t> KernelBpeTokenizer::SpmTokenize(ustring& input,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (res.size() > 0 && res.front() == bos_token_id_) {
|
||||||
|
if (add_bos_token_.has_value() && add_bos_token_.value() == false) {
|
||||||
|
res.erase(res.begin());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче