From b8b2ebfb852c13a2a42814d9d9154171982f85de Mon Sep 17 00:00:00 2001 From: Wenbing Li <10278425+wenbingl@users.noreply.github.com> Date: Fri, 30 Aug 2024 14:58:40 -0700 Subject: [PATCH] optimize spm tokenizer for long text (#799) * optimize spm tokenizer for long text * refine the split logic * re-trigger CI pipeline. --- operators/tokenizer/bpe_kernels.cc | 74 +++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 21 deletions(-) diff --git a/operators/tokenizer/bpe_kernels.cc b/operators/tokenizer/bpe_kernels.cc index 09f97377..52a13cf2 100644 --- a/operators/tokenizer/bpe_kernels.cc +++ b/operators/tokenizer/bpe_kernels.cc @@ -416,7 +416,53 @@ std::vector KernelBpeTokenizer::SpmTokenize(ustring& input, // Get byte encodings prior to performing BPE std::list> byte_list; - while (res.size() < max_length && char_pos < ustr.length()) { + while (res.size() < max_length && char_pos <= ustr.length()) { + bool split_now = false; + if (char_pos == ustr.length()) { + split_now = true; + } + + // temporary split logic, will be replaced regex based split after it is implemented + if (!split_now && byte_list.size() > 10) { + auto is_split_char = [](char32_t ch) { + return ch == U' ' || ch == U'\n' || ch == U'\r' || ch == U'▁'; + }; + if (!is_split_char(ustr[char_pos - 1]) && is_split_char(ustr[char_pos])) { + split_now = true; + } + // split immediately to avoid too long byte_list for extreme cases, which is slow. + if (!split_now && byte_list.size() > 100) { + split_now = true; + } + } + + if (split_now) { + // Perform BPE + bbpe_tokenizer_->PerformBPE(byte_list); + + // Add output to result + for (auto p : byte_list) { + if (res.size() >= max_length) { + break; + } + + res.push_back(p.first); + + if (compute_offset_mapping) { + offset_mapping.emplace_back(std::make_pair( + offset, + ort_extensions::narrow(offset + (size_t)p.second))); + offset += ((size_t)p.second); + } + } + + byte_list.clear(); + } + + if (char_pos == ustr.length()) { + break; + } + auto chr = ustr[char_pos]; if (chr == U' ') { chr = 0x2581; // UTF-8 string '\xe2\x96\x81' @@ -436,26 +482,6 @@ std::vector KernelBpeTokenizer::SpmTokenize(ustring& input, char_pos++; } - { - // Perform BPE - bbpe_tokenizer_->PerformBPE(byte_list); - - // Add output to result - for (auto p : byte_list) { - if (res.size() >= max_length) { - break; - } - - res.push_back(p.first); - - if (compute_offset_mapping) { - offset_mapping.emplace_back(std::make_pair( - offset, - ort_extensions::narrow(offset + (size_t)p.second))); - offset += ((size_t)p.second); - } - } - } if (compute_offset_mapping) { // Add offset mappings for input in this instance to list of offset mappings for all inputs @@ -463,6 +489,12 @@ std::vector KernelBpeTokenizer::SpmTokenize(ustring& input, } } + if (res.size() > 0 && res.front() == bos_token_id_) { + if (add_bos_token_.has_value() && add_bos_token_.value() == false) { + res.erase(res.begin()); + } + } + return res; }