Ignore all invalid UTF-8 strings in streaming output (#704)

* Ignore all invalid UTF-8 strings in streaming output

* Update bpe_streaming.hpp

* Add the phi-3 tokenizer test

* Add a streaming test for the phi-3 model

* Fix the UTF-8 validation

* Fix the UTF-8 validation 2

* Fix the UTF-8 validation 3

* Fix the UTF-8 validation 4
Wenbing Li committed 2024-05-06 16:46:55 -07:00 (committed by GitHub)
Parent: e645cdab8d
Commit: c58c930739
No key found matching this signature
GPG key ID: B5690EEEBB952194
8 changed files: 93734 additions and 30 deletions

View file

@@ -77,25 +77,57 @@ class ustring : public std::u32string {
  }

  static bool ValidateUTF8(const std::string& data) {
-    int cnt = 0;
-    for (auto i = 0; i < data.size(); i++) {
-      int x = data[i];
-      if (!cnt) {
-        if ((x >> 5) == 0b110) {
-          cnt = 1;
-        } else if ((x >> 4) == 0b1110) {
-          cnt = 2;
-        } else if ((x >> 3) == 0b11110) {
-          cnt = 3;
-        } else if ((x >> 7) != 0) {
+    const unsigned char* s = reinterpret_cast<const unsigned char*>(data.c_str());
+    const unsigned char* s_end = s + data.size();
+    if (*s_end != '\0')
+      return false;
+
+    while (*s) {
+      if (*s < 0x80)
+        /* 0xxxxxxx */
+        s++;
+      else if ((s[0] & 0xe0) == 0xc0) {
+        /* 110XXXXx 10xxxxxx */
+        if (s + 1 >= s_end) {
          return false;
        }
-      } else {
-        if ((x >> 6) != 0b10) return false;
-        cnt--;
-      }
+        if ((s[1] & 0xc0) != 0x80 ||
+            (s[0] & 0xfe) == 0xc0) /* overlong? */
+          return false;
+        else
+          s += 2;
+      } else if ((s[0] & 0xf0) == 0xe0) {
+        /* 1110XXXX 10Xxxxxx 10xxxxxx */
+        if (s + 2 >= s_end) {
+          return false;
+        }
+        if ((s[1] & 0xc0) != 0x80 ||
+            (s[2] & 0xc0) != 0x80 ||
+            (s[0] == 0xe0 && (s[1] & 0xe0) == 0x80) || /* overlong? */
+            (s[0] == 0xed && (s[1] & 0xe0) == 0xa0) || /* surrogate? */
+            (s[0] == 0xef && s[1] == 0xbf &&
+             (s[2] & 0xfe) == 0xbe)) /* U+FFFE or U+FFFF? */
+          return false;
+        else
+          s += 3;
+      } else if ((s[0] & 0xf8) == 0xf0) {
+        /* 11110XXX 10XXxxxx 10xxxxxx 10xxxxxx */
+        if (s + 3 >= s_end) {
+          return false;
+        }
+        if ((s[1] & 0xc0) != 0x80 ||
+            (s[2] & 0xc0) != 0x80 ||
+            (s[3] & 0xc0) != 0x80 ||
+            (s[0] == 0xf0 && (s[1] & 0xf0) == 0x80) || /* overlong? */
+            (s[0] == 0xf4 && s[1] > 0x8f) || s[0] > 0xf4) /* > U+10FFFF? */
+          return false;
+        else
+          s += 4;
+      } else
+        return false;
    }
-    return cnt == 0;
+
+    return true;
  }

 private:
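Note: the rewritten ustring::ValidateUTF8 rejects truncated sequences, overlong encodings, UTF-16 surrogate code points, U+FFFE/U+FFFF, and anything above U+10FFFF. A quick standalone sanity check of that behavior might look like the sketch below (the include path is an assumption; each check mirrors one of the branches above):

#include <cassert>
#include <string>
#include "ustring.h"  // assumed include path for the ustring class shown above

int main() {
  using std::string;
  assert(ustring::ValidateUTF8("hello"));                         // plain ASCII
  assert(ustring::ValidateUTF8("こんにちは"));                      // complete 3-byte sequences
  assert(!ustring::ValidateUTF8(string("\xe3\x81", 2)));          // truncated first two bytes of "こ"
  assert(!ustring::ValidateUTF8(string("\xc0\xaf", 2)));          // overlong 2-byte encoding of '/'
  assert(!ustring::ValidateUTF8(string("\xed\xa0\x80", 3)));      // UTF-16 surrogate U+D800
  assert(!ustring::ValidateUTF8(string("\xf5\x80\x80\x80", 4)));  // above U+10FFFF
  return 0;
}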

View file

@@ -342,8 +342,9 @@ std::vector<int64_t> KernelBpeTokenizer::SpmTokenize(ustring& input,
  size_t max_length = static_cast<size_t>(max_length_i64);

  // Parse input
+  bool add_dummy_prefix = false;
  if (ModelName() == kModel_Llama) {
-    input.insert(input.begin(), 0x2581);
+    add_dummy_prefix = true;
  }

  auto special_token_split_res = bbpe_tokenizer_->SplitByAddedAndSpecial(input);
  for (auto& seg_id : special_token_split_res) {
@@ -356,6 +357,10 @@ std::vector<int64_t> KernelBpeTokenizer::SpmTokenize(ustring& input,
    // Note: keep ptr to make sure the string_view is valid in the following process
    std::u32string ustr(seg_id.first);
+    if (add_dummy_prefix) {
+      ustr.insert(ustr.begin(), 0x2581);  // UTF-8 string '\xe2\x96\x81'
+      add_dummy_prefix = false;  // only add dummy prefix once
+    }
    size_t offset = 0;
    size_t char_pos = 0;
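Note: previously the SentencePiece-style dummy prefix U+2581 ("▁") was inserted at the front of the whole input before special-token splitting; after this change it is only flagged there and inserted once, into the first plain-text segment, so a leading special token such as <|user|> is left untouched. A simplified standalone illustration of that ordering (hypothetical segment layout, not the kernel's actual data structures):

#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Pretend "<|user|>こんにちは" was already split into a special-token segment
  // and a text segment, as SplitByAddedAndSpecial would do.
  std::vector<std::pair<std::u32string, bool>> segments = {
      {U"<|user|>", /*is_special=*/true},
      {U"こんにちは", /*is_special=*/false},
  };

  bool add_dummy_prefix = true;  // set when the model wants SPM-style prefixing
  for (auto& [seg, is_special] : segments) {
    if (!is_special && add_dummy_prefix) {
      seg.insert(seg.begin(), U'\u2581');  // "▁", UTF-8 bytes \xe2\x96\x81
      add_dummy_prefix = false;            // only the first text segment gets it
    }
  }

  std::cout << "first text segment starts with U+2581: "
            << (segments[1].first.front() == U'\u2581') << "\n";
  return 0;
}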

View file

@@ -176,15 +176,21 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
        }
      }  // end case of whitespace_token_

-      if (!bpe_state->incomplete_utf8_.empty()) {
-        token = bpe_state->incomplete_utf8_ + token;
-        bpe_state->incomplete_utf8_.clear();
-      } else {
-        if (!token.empty() && ustring::UTF8Len(token.front()) > token.size()) {
-          bpe_state->incomplete_utf8_ = token;
-          token = "";
-        }
-      }
+      bpe_state->incomplete_utf8_ += token;
+      token.clear();
+      std::string& s_utf8 = bpe_state->incomplete_utf8_;
+      size_t utf8_len = 1;
+      size_t utf8_all_len = 0;
+      for (size_t i = 0; i < s_utf8.size(); i += utf8_len) {
+        utf8_len = ustring::UTF8Len(s_utf8[i]);
+        if (utf8_len <= s_utf8.size() - i) {
+          utf8_all_len += utf8_len;
+          auto _t = s_utf8.substr(i, utf8_len);
+          token += ustring::ValidateUTF8(_t) ? _t : "";
+        }
+      }
+      s_utf8 = s_utf8.substr(utf8_all_len);

      bpe_state->f_special_last = f_special;
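Note: the decoder now appends every decoded token to incomplete_utf8_ and re-emits only the complete, valid UTF-8 sequences from the front of that buffer, carrying any trailing partial sequence over to the next call and dropping invalid ones. A simplified, standalone sketch of that carry-buffer idea (the helper names are mine, not the extension's API):

#include <cstddef>
#include <iostream>
#include <string>

// Sequence length implied by a UTF-8 lead byte (similar in spirit to ustring::UTF8Len).
static size_t Utf8Len(unsigned char lead) {
  if ((lead & 0x80) == 0x00) return 1;
  if ((lead & 0xe0) == 0xc0) return 2;
  if ((lead & 0xf0) == 0xe0) return 3;
  if ((lead & 0xf8) == 0xf0) return 4;
  return 1;  // invalid lead byte; consume a single byte
}

// Append newly decoded bytes to `carry`, return only the complete sequences,
// and keep an incomplete trailing sequence buffered for the next chunk.
static std::string EmitComplete(std::string& carry, const std::string& chunk) {
  carry += chunk;
  std::string out;
  size_t consumed = 0;
  size_t len = 1;
  for (size_t i = 0; i < carry.size(); i += len) {
    len = Utf8Len(static_cast<unsigned char>(carry[i]));
    if (len > carry.size() - i) break;  // incomplete tail: keep it buffered
    out += carry.substr(i, len);        // a real decoder would also validate the sequence here
    consumed = i + len;
  }
  carry = carry.substr(consumed);
  return out;
}

int main() {
  std::string carry;
  std::string out;
  // "こ" (E3 81 93) arrives split across two streamed tokens.
  out += EmitComplete(carry, "\xe3\x81");  // nothing emitted yet; the bytes stay buffered
  out += EmitComplete(carry, "\x93");      // emits the completed "こ"
  std::cout << out << " (buffered bytes left: " << carry.size() << ")\n";
  return 0;
}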

View file

@@ -312,7 +312,7 @@ class BpeModel {
    if (it != vocab_map_.end()) {
      return it->second;
    } else {
-      return unk_id_;
+      return bpe::kInvalidTokenId;
    }
  }
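Note: with this change an out-of-vocabulary piece is reported as bpe::kInvalidTokenId rather than silently mapped to the <unk> id, so callers can detect and skip it. A standalone sketch of that lookup contract (the sentinel value and helper names below are illustrative assumptions, not the extension's code):

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

using extTokenId_t = uint32_t;
constexpr extTokenId_t kInvalidTokenId = 0xFFFFFFFF;  // assumed sentinel value, for illustration only

extTokenId_t GetTokenId(const std::unordered_map<std::string, extTokenId_t>& vocab,
                        const std::string& piece) {
  auto it = vocab.find(piece);
  return it != vocab.end() ? it->second : kInvalidTokenId;
}

int main() {
  std::unordered_map<std::string, extTokenId_t> vocab = {{"hello", 15043}};
  for (const std::string piece : {"hello", "granola"}) {
    extTokenId_t id = GetTokenId(vocab, piece);
    if (id == kInvalidTokenId)
      std::cout << piece << ": not in vocab, skipped\n";
    else
      std::cout << piece << ": " << id << "\n";
  }
  return 0;
}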

test/data/phi-3/tokenizer.json (new file, 93462 lines added)

The diff is not shown because of its large size.

View file

@@ -0,0 +1,130 @@
{
"add_bos_token": true,
"add_eos_token": false,
"added_tokens_decoder": {
"0": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "<s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32000": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32001": {
"content": "<|assistant|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32002": {
"content": "<|placeholder1|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32003": {
"content": "<|placeholder2|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32004": {
"content": "<|placeholder3|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32005": {
"content": "<|placeholder4|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32006": {
"content": "<|system|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32007": {
"content": "<|end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32008": {
"content": "<|placeholder5|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32009": {
"content": "<|placeholder6|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32010": {
"content": "<|user|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"bos_token": "<s>",
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
"clean_up_tokenization_spaces": false,
"eos_token": "<|endoftext|>",
"legacy": false,
"model_max_length": 131072,
"pad_token": "<|endoftext|>",
"padding_side": "left",
"sp_model_kwargs": {},
"tokenizer_class": "LlamaTokenizer",
"unk_token": "<unk>",
"use_default_system_prompt": false
}
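Note: per the chat_template above, a single user turn is wrapped with the <|user|>, <|end|>, and <|assistant|> tokens declared in added_tokens_decoder. A rough, hand-expanded rendering of that template for one message (a real pipeline would evaluate the Jinja template itself; this only shows the resulting layout):

#include <iostream>
#include <string>

// Hand-expanded form of the chat_template for a single user message.
std::string RenderUserTurn(const std::string& bos_token, const std::string& content) {
  return bos_token + "<|user|>\n" + content + "<|end|>\n<|assistant|>\n";
}

int main() {
  std::cout << RenderUserTurn("<s>", "こんにちは。データ分析するにはなにをすればいい?") << std::endl;
  return 0;
}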

View file

@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

 #include <filesystem>
@@ -65,7 +65,6 @@ TEST(CApiTest, StreamApiTest) {
  OrtxDispose(&tokenizer);
}

TEST(OrtxTokenizerTest, ClipTokenizer) {
  auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
  auto status = tokenizer->Load("data/clip");
@@ -119,7 +118,6 @@ TEST(OrtxTokenizerTest, TicTokenTokenizer) {
  EXPECT_EQ(out_text[0], input[0]);
}

TEST(OrtxTokenizerTest, GemmaTokenizer) {
  auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
  auto status = tokenizer->Load("data/gemma");
@@ -161,6 +159,41 @@ TEST(OrtxTokenizerTest, GemmaTokenizer) {
  EXPECT_EQ(out_text[1], input[1]);
}

TEST(OrtxTokenizerTest, Phi3Tokenizer) {
  auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
  auto status = tokenizer->Load("data/phi-3");
  if (!status.IsOk()) {
    std::cout << status.ToString() << std::endl;
  }

  std::vector<std::string_view> input = {
      "分析",
      " こんにちは",  // an extra space at the beginning
      "<|user|>こんにちは。データ分析するにはなにをすればいい?<|end|><|assistant|>"};
  std::vector<extTokenId_t> EXPECTED_IDS_0 = {1, 29871, 30748, 233, 161, 147};
  std::vector<extTokenId_t> EXPECTED_IDS_1 = {1, 259, 30589, 30389, 30353, 30644, 30449};
  std::vector<extTokenId_t> EXPECTED_IDS_2 = {1, 32010, 29871, 30589, 30389, 30353,
      30644, 30449, 30267, 30597, 30185, 30369, 30748, 233, 161, 147, 30427, 30332, 30353,
      30449, 30371, 30353, 30396, 30427, 30553, 31254, 30298, 30298, 30882, 32007, 32001};

  std::vector<std::vector<extTokenId_t>> token_ids;
  status = tokenizer->Tokenize(input, token_ids);
  EXPECT_TRUE(status.IsOk());
  EXPECT_EQ(token_ids.size(), input.size());
  DumpTokenIds(token_ids);
  EXPECT_EQ(token_ids[0], EXPECTED_IDS_0);
  EXPECT_EQ(token_ids[1], EXPECTED_IDS_1);
  EXPECT_EQ(token_ids[2], EXPECTED_IDS_2);

  std::vector<std::string> out_text;
  std::vector<ort_extensions::span<extTokenId_t const>> token_ids_span = {
      EXPECTED_IDS_0, EXPECTED_IDS_1};
  status = tokenizer->Detokenize(token_ids_span, out_text);
  EXPECT_TRUE(status.IsOk());
  EXPECT_EQ(out_text[0], input[0]);
  EXPECT_EQ(out_text[1], input[1]);
}
static const char* kPromptText = R"(```python
def print_prime(n):
"""
@@ -224,7 +257,7 @@ TEST(OrtxTokenizerStreamTest, CodeGenTokenizer) {
  std::string text;
  std::unique_ptr<ort_extensions::BPEDecoderState> decoder_cache;
  // token_ids[0].insert(token_ids[0].begin() + 2, 607);  // <0x20>
-  token_ids[0] = {921, 765, 2130, 588, 262, 6123, 447, 251, 2130, 588, 262};
+  token_ids[0] = {564, 921, 765, 2130, 588, 262, 6123, 447, 251, 2130, 588, 262};
  for (const auto& token_id : token_ids[0]) {
    std::string token;
    status = tokenizer->Id2Token(token_id, token, decoder_cache);
@@ -269,3 +302,39 @@ TEST(OrtxTokenizerStreamTest, Llama2Tokenizer) {
  // std::cout << "\"" << std::endl;
  EXPECT_EQ(std::string(text), std::string(input[0]) + " ");  // from the extra byte token */
}

TEST(OrtxTokenizerStreamTest, Phi3Tokenizer) {
  // test the phi-3 tokenizer with the BPE class, instead of the sentencepiece wrapper.
  auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
  auto status = tokenizer->Load("data/phi-3");
  if (!status.IsOk()) {
    std::cout << status.ToString() << std::endl;
  }

  // validate the tokenizer is not null
  EXPECT_TRUE(tokenizer != nullptr);

  std::vector<std::string_view> input = {
      R"(こんにちは。データ分析にはいくつかのステップがあります。まずは目的を明確にします。次に、データを収集し、クリーニングを行い ます。その後、データを構造化し、その後、データを分析します。これらのステップを実行することで、データを有意的に分析することができます。)"};
  std::vector<std::vector<extTokenId_t>> token_ids;
  status = tokenizer->Tokenize(input, token_ids);
  EXPECT_TRUE(status.IsOk());
  // Add an extra byte token for decoding tests
  token_ids[0].push_back(35);  // <0x20>
  DumpTokenIds(token_ids);

  std::string text;
  std::unique_ptr<ort_extensions::BPEDecoderState> decoder_cache;
  // std::cout << "\"";
  for (const auto& token_id : token_ids[0]) {
    std::string token;
    auto status = tokenizer->Id2Token(token_id, token, decoder_cache);
    EXPECT_TRUE(status.IsOk());
    // std::cout << token;
    text.append(token);
  }
  // std::cout << "\"" << std::endl;
  EXPECT_EQ(std::string(text), std::string(input[0]) + " ");  // from the extra byte token */
}

View file

@@ -8,7 +8,7 @@ include(FetchContent)
FetchContent_Declare(
  ortx
  GIT_REPOSITORY https://github.com/microsoft/onnxruntime-extensions.git
-  GIT_TAG a7043c56e4f19c4bf11642d390f7b502f80a34ba)
+  GIT_TAG main)

set(OCOS_BUILD_PRESET token_api_only)
FetchContent_MakeAvailable(ortx)