Ignore all streaming output of invalid utf-8 string (#704)
* Ignore all streaming output of invalid utf-8 string
* Update bpe_streaming.hpp
* add the phi-3 tokenizer test
* add a streaming test for phi-3 model
* fix the utf-8 validation
* fix the utf-8 validation 2
* fix the utf-8 validation 3
* fix the utf-8 validation 4
This commit is contained in:
Parent: e645cdab8d
Commit: c58c930739
@@ -77,25 +77,57 @@ class ustring : public std::u32string {
   }
 
   static bool ValidateUTF8(const std::string& data) {
-    int cnt = 0;
-    for (auto i = 0; i < data.size(); i++) {
-      int x = data[i];
-      if (!cnt) {
-        if ((x >> 5) == 0b110) {
-          cnt = 1;
-        } else if ((x >> 4) == 0b1110) {
-          cnt = 2;
-        } else if ((x >> 3) == 0b11110) {
-          cnt = 3;
-        } else if ((x >> 7) != 0) {
-          return false;
-        }
-      } else {
-        if ((x >> 6) != 0b10) return false;
-        cnt--;
-      }
-    }
-    return cnt == 0;
+    const unsigned char* s = reinterpret_cast<const unsigned char*>(data.c_str());
+    const unsigned char* s_end = s + data.size();
+    if (*s_end != '\0')
+      return false;
+
+    while (*s) {
+      if (*s < 0x80)
+        /* 0xxxxxxx */
+        s++;
+      else if ((s[0] & 0xe0) == 0xc0) {
+        /* 110XXXXx 10xxxxxx */
+        if (s + 1 >= s_end) {
+          return false;
+        }
+        if ((s[1] & 0xc0) != 0x80 ||
+            (s[0] & 0xfe) == 0xc0) /* overlong? */
+          return false;
+        else
+          s += 2;
+      } else if ((s[0] & 0xf0) == 0xe0) {
+        /* 1110XXXX 10Xxxxxx 10xxxxxx */
+        if (s + 2 >= s_end) {
+          return false;
+        }
+        if ((s[1] & 0xc0) != 0x80 ||
+            (s[2] & 0xc0) != 0x80 ||
+            (s[0] == 0xe0 && (s[1] & 0xe0) == 0x80) || /* overlong? */
+            (s[0] == 0xed && (s[1] & 0xe0) == 0xa0) || /* surrogate? */
+            (s[0] == 0xef && s[1] == 0xbf &&
+             (s[2] & 0xfe) == 0xbe)) /* U+FFFE or U+FFFF? */
+          return false;
+        else
+          s += 3;
+      } else if ((s[0] & 0xf8) == 0xf0) {
+        /* 11110XXX 10XXxxxx 10xxxxxx 10xxxxxx */
+        if (s + 3 >= s_end) {
+          return false;
+        }
+        if ((s[1] & 0xc0) != 0x80 ||
+            (s[2] & 0xc0) != 0x80 ||
+            (s[3] & 0xc0) != 0x80 ||
+            (s[0] == 0xf0 && (s[1] & 0xf0) == 0x80) || /* overlong? */
+            (s[0] == 0xf4 && s[1] > 0x8f) || s[0] > 0xf4) /* > U+10FFFF? */
+          return false;
+        else
+          s += 4;
+      } else
+        return false;
+    }
+    return true;
   }
 
  private:
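The new validator rejects truncated sequences, overlong encodings, UTF-16 surrogate halves, and code points above U+10FFFF. A minimal spot-check of that behavior; the "ustring.h" include path is an assumption, only ValidateUTF8 itself comes from the hunk above:

#include <cassert>
#include <string>
#include "ustring.h"  // assumed header name for the ustring class above

int main() {
  assert(ustring::ValidateUTF8("abc"));                // plain ASCII
  assert(ustring::ValidateUTF8("\xe3\x81\x93"));       // complete 3-byte sequence
  assert(!ustring::ValidateUTF8("\xe3\x81"));          // truncated: continuation byte missing
  assert(!ustring::ValidateUTF8("\xc0\xaf"));          // overlong 2-byte encoding of '/'
  assert(!ustring::ValidateUTF8("\xed\xa0\x80"));      // UTF-16 surrogate half
  assert(!ustring::ValidateUTF8("\xf5\x80\x80\x80"));  // above U+10FFFF
  return 0;
}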
@@ -342,8 +342,9 @@ std::vector<int64_t> KernelBpeTokenizer::SpmTokenize(ustring& input,
 
   size_t max_length = static_cast<size_t>(max_length_i64);
   // Parse input
+  bool add_dummy_prefix = false;
   if (ModelName() == kModel_Llama) {
-    input.insert(input.begin(), 0x2581);
+    add_dummy_prefix = true;
   }
   auto special_token_split_res = bbpe_tokenizer_->SplitByAddedAndSpecial(input);
   for (auto& seg_id : special_token_split_res) {
@@ -356,6 +357,10 @@ std::vector<int64_t> KernelBpeTokenizer::SpmTokenize(ustring& input,
 
     // Note: keep ptr to make sure the string_view is valid in the following process
     std::u32string ustr(seg_id.first);
+    if (add_dummy_prefix) {
+      ustr.insert(ustr.begin(), 0x2581);  // UTF-8 string '\xe2\x96\x81'
+      add_dummy_prefix = false;           // only add dummy prefix once
+    }
 
     size_t offset = 0;
     size_t char_pos = 0;
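The two hunks above move the SentencePiece dummy prefix (U+2581, the '▁' metasymbol) from the raw input to the first text segment after special-token splitting, presumably so a leading special token such as <|user|> still matches. A self-contained sketch of the resulting behavior; the segment values are hypothetical, and the skip for special segments is inferred from the test expectations rather than shown in this hunk:

#include <iostream>
#include <string>
#include <vector>

int main() {
  // Hypothetical result of SplitByAddedAndSpecial: one special-token segment,
  // then a plain-text segment.
  std::vector<std::u32string> segments = {U"<|user|>", U"こんにちは"};
  bool is_special[] = {true, false};

  bool add_dummy_prefix = true;  // set for Llama-family models, per the hunk above
  for (size_t n = 0; n < segments.size(); ++n) {
    if (is_special[n]) continue;  // special tokens pass through untouched
    if (add_dummy_prefix) {
      segments[n].insert(segments[n].begin(), char32_t(0x2581));  // '▁'
      add_dummy_prefix = false;  // only the first text segment gets the prefix
    }
  }
  std::cout << (segments[1][0] == 0x2581) << "\n";  // prints 1
  return 0;
}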
@@ -176,15 +176,21 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
       }
     }  // end case of whitespace_token_
 
-    if (!bpe_state->incomplete_utf8_.empty()) {
-      token = bpe_state->incomplete_utf8_ + token;
-      bpe_state->incomplete_utf8_.clear();
-    } else {
-      if (!token.empty() && ustring::UTF8Len(token.front()) > token.size()) {
-        bpe_state->incomplete_utf8_ = token;
-        token = "";
-      }
-    }
+    bpe_state->incomplete_utf8_ += token;
+    token.clear();
+    std::string& s_utf8 = bpe_state->incomplete_utf8_;
+    size_t utf8_len = 1;
+    size_t utf8_all_len = 0;
+    for (size_t i = 0; i < s_utf8.size(); i += utf8_len) {
+      utf8_len = ustring::UTF8Len(s_utf8[i]);
+      if (utf8_len <= s_utf8.size() - i) {
+        utf8_all_len += utf8_len;
+        auto _t = s_utf8.substr(i, utf8_len);
+        token += ustring::ValidateUTF8(_t) ? _t : "";
+      }
+    }
+
+    s_utf8 = s_utf8.substr(utf8_all_len);
 
     bpe_state->f_special_last = f_special;
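The streaming decoder now funnels every decoded fragment through the incomplete_utf8_ buffer: complete sequences are emitted only if they validate, and a trailing partial sequence is held back until the next call. The same idea as a standalone helper; the function and names here are illustrative, not the library's API, and the validation is simplified for the sketch:

#include <iostream>
#include <string>

// Stand-in for ustring::UTF8Len: sequence length implied by a lead byte.
static size_t Utf8Len(unsigned char lead) {
  if (lead < 0x80) return 1;
  if ((lead & 0xe0) == 0xc0) return 2;
  if ((lead & 0xf0) == 0xe0) return 3;
  return 4;
}

// Simplified stand-in for ustring::ValidateUTF8 (the full check is in the first hunk):
// continuation bytes must match 10xxxxxx.
static bool ValidateUtf8(const std::string& s) {
  for (size_t i = 1; i < s.size(); ++i)
    if ((static_cast<unsigned char>(s[i]) & 0xc0) != 0x80) return false;
  return true;
}

// Append the new fragment to the pending buffer, emit complete valid sequences,
// drop complete invalid ones, and keep a trailing partial sequence buffered.
std::string EmitValidUtf8(std::string& pending, const std::string& fragment) {
  pending += fragment;
  std::string out;
  size_t consumed = 0;
  for (size_t i = 0; i < pending.size();) {
    size_t len = Utf8Len(static_cast<unsigned char>(pending[i]));
    if (len > pending.size() - i) break;  // incomplete: wait for the next fragment
    std::string seq = pending.substr(i, len);
    if (ValidateUtf8(seq)) out += seq;    // invalid sequences are silently ignored
    i += len;
    consumed = i;
  }
  pending.erase(0, consumed);
  return out;
}

int main() {
  std::string pending;
  std::cout << EmitValidUtf8(pending, "A\xe3\x81");  // prints "A"; partial bytes stay buffered
  std::cout << EmitValidUtf8(pending, "\x93");       // completes the 3-byte sequence, prints it
  std::cout << "\n";
  return 0;
}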
@@ -312,7 +312,7 @@ class BpeModel {
     if (it != vocab_map_.end()) {
       return it->second;
     } else {
-      return unk_id_;
+      return bpe::kInvalidTokenId;
     }
   }
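With this change the vocab lookup reports a miss explicitly instead of silently mapping to unk_id_, so callers can fall back (for example, to byte-level tokens) only when a lookup really failed. A self-contained illustration with stubbed types; the sentinel's numeric value here is an assumption:

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

using extTokenId_t = uint32_t;
namespace bpe {
constexpr extTokenId_t kInvalidTokenId = 0xFFFFFFFF;  // assumed sentinel value
}

extTokenId_t GetTokenId(const std::unordered_map<std::string, extTokenId_t>& vocab,
                        const std::string& s) {
  auto it = vocab.find(s);
  if (it != vocab.end()) {
    return it->second;
  } else {
    return bpe::kInvalidTokenId;  // was: return unk_id_;
  }
}

int main() {
  std::unordered_map<std::string, extTokenId_t> vocab = {{"<unk>", 0}, {"hello", 42}};
  std::cout << GetTokenId(vocab, "hello") << "\n";                          // 42
  std::cout << (GetTokenId(vocab, "bye") == bpe::kInvalidTokenId) << "\n";  // 1
  return 0;
}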
File diff not shown because it is too large.
@@ -0,0 +1,130 @@
{
  "add_bos_token": true,
  "add_eos_token": false,
  "added_tokens_decoder": {
    "0": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "32000": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "32001": {
      "content": "<|assistant|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "32002": {
      "content": "<|placeholder1|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "32003": {
      "content": "<|placeholder2|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "32004": {
      "content": "<|placeholder3|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "32005": {
      "content": "<|placeholder4|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "32006": {
      "content": "<|system|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "32007": {
      "content": "<|end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "32008": {
      "content": "<|placeholder5|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "32009": {
      "content": "<|placeholder6|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "32010": {
      "content": "<|user|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<s>",
  "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
  "clean_up_tokenization_spaces": false,
  "eos_token": "<|endoftext|>",
  "legacy": false,
  "model_max_length": 131072,
  "pad_token": "<|endoftext|>",
  "padding_side": "left",
  "sp_model_kwargs": {},
  "tokenizer_class": "LlamaTokenizer",
  "unk_token": "<unk>",
  "use_default_system_prompt": false
}
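For reference, the chat_template above turns a message list into the tagged prompt format that the new Phi3Tokenizer tests tokenize. Constructing the equivalent prompt for a single user turn by hand, as an illustration only; a real pipeline would render the Jinja template:

#include <iostream>
#include <string>

int main() {
  // One user message, rendered per the template:
  // bos_token, then '<|user|>' + '\n' + content + '<|end|>' + '\n' + '<|assistant|>' + '\n'
  std::string bos_token = "<s>";
  std::string content = "こんにちは。データ分析するにはなにをすればいい?";
  std::string prompt = bos_token + "<|user|>\n" + content + "<|end|>\n" + "<|assistant|>\n";
  std::cout << prompt;
  return 0;
}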
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
 #include <filesystem>
@@ -65,7 +65,6 @@ TEST(CApiTest, StreamApiTest) {
   OrtxDispose(&tokenizer);
 }
 
-
 TEST(OrtxTokenizerTest, ClipTokenizer) {
   auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
   auto status = tokenizer->Load("data/clip");
@@ -119,7 +118,6 @@ TEST(OrtxTokenizerTest, TicTokenTokenizer) {
   EXPECT_EQ(out_text[0], input[0]);
 }
 
-
 TEST(OrtxTokenizerTest, GemmaTokenizer) {
   auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
   auto status = tokenizer->Load("data/gemma");
@@ -161,6 +159,41 @@ TEST(OrtxTokenizerTest, GemmaTokenizer) {
   EXPECT_EQ(out_text[1], input[1]);
 }
 
+TEST(OrtxTokenizerTest, Phi3Tokenizer) {
+  auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
+  auto status = tokenizer->Load("data/phi-3");
+  if (!status.IsOk()) {
+    std::cout << status.ToString() << std::endl;
+  }
+
+  std::vector<std::string_view> input = {
+      "分析",
+      " こんにちは",  // an extra space at the beginning
+      "<|user|>こんにちは。データ分析するにはなにをすればいい?<|end|><|assistant|>"};
+  std::vector<extTokenId_t> EXPECTED_IDS_0 = {1, 29871, 30748, 233, 161, 147};
+  std::vector<extTokenId_t> EXPECTED_IDS_1 = {1, 259, 30589, 30389, 30353, 30644, 30449};
+  std::vector<extTokenId_t> EXPECTED_IDS_2 = {1, 32010, 29871, 30589, 30389, 30353,
+      30644, 30449, 30267, 30597, 30185, 30369, 30748, 233, 161, 147, 30427, 30332, 30353,
+      30449, 30371, 30353, 30396, 30427, 30553, 31254, 30298, 30298, 30882, 32007, 32001};
+
+  std::vector<std::vector<extTokenId_t>> token_ids;
+  status = tokenizer->Tokenize(input, token_ids);
+  EXPECT_TRUE(status.IsOk());
+  EXPECT_EQ(token_ids.size(), input.size());
+  DumpTokenIds(token_ids);
+  EXPECT_EQ(token_ids[0], EXPECTED_IDS_0);
+  EXPECT_EQ(token_ids[1], EXPECTED_IDS_1);
+  EXPECT_EQ(token_ids[2], EXPECTED_IDS_2);
+
+  std::vector<std::string> out_text;
+  std::vector<ort_extensions::span<extTokenId_t const>> token_ids_span = {
+      EXPECTED_IDS_0, EXPECTED_IDS_1};
+  status = tokenizer->Detokenize(token_ids_span, out_text);
+  EXPECT_TRUE(status.IsOk());
+  EXPECT_EQ(out_text[0], input[0]);
+  EXPECT_EQ(out_text[1], input[1]);
+}
+
 static const char* kPromptText = R"(```python
 def print_prime(n):
     """
@@ -224,7 +257,7 @@ TEST(OrtxTokenizerStreamTest, CodeGenTokenizer) {
   std::string text;
   std::unique_ptr<ort_extensions::BPEDecoderState> decoder_cache;
   // token_ids[0].insert(token_ids[0].begin() + 2, 607);  // <0x20>
-  token_ids[0] = {921, 765, 2130, 588, 262, 6123, 447, 251, 2130, 588, 262};
+  token_ids[0] = {564, 921, 765, 2130, 588, 262, 6123, 447, 251, 2130, 588, 262};
   for (const auto& token_id : token_ids[0]) {
     std::string token;
     status = tokenizer->Id2Token(token_id, token, decoder_cache);
@@ -269,3 +302,39 @@ TEST(OrtxTokenizerStreamTest, Llama2Tokenizer) {
   // std::cout << "\"" << std::endl;
   EXPECT_EQ(std::string(text), std::string(input[0]) + " ");  // from the extra byte token
 }
+
+TEST(OrtxTokenizerStreamTest, Phi3Tokenizer) {
+  // test the phi-3 tokenizer with the BPE class, instead of the sentencepiece wrapper.
+  auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
+  auto status = tokenizer->Load("data/phi-3");
+  if (!status.IsOk()) {
+    std::cout << status.ToString() << std::endl;
+  }
+
+  // validate tokenizer is not null
+  EXPECT_TRUE(tokenizer != nullptr);
+
+  std::vector<std::string_view> input = {
+      R"(こんにちは。データ分析にはいくつかのステップがあります。まずは目的を明確にします。次に、データを収集し、クリーニングを行います。その後、データを構造化し、その後、データを分析します。これらのステップを実行することで、データを有意的に分析することができます。)"};
+  std::vector<std::vector<extTokenId_t>> token_ids;
+  status = tokenizer->Tokenize(input, token_ids);
+  EXPECT_TRUE(status.IsOk());
+  // Add an extra byte token for decoding tests
+  token_ids[0].push_back(35);  // <0x20>
+  DumpTokenIds(token_ids);
+
+  std::string text;
+  std::unique_ptr<ort_extensions::BPEDecoderState> decoder_cache;
+  // std::cout << "\"";
+  for (const auto& token_id : token_ids[0]) {
+    std::string token;
+    auto status = tokenizer->Id2Token(token_id, token, decoder_cache);
+    EXPECT_TRUE(status.IsOk());
+    // std::cout << token;
+    text.append(token);
+  }
+
+  // std::cout << "\"" << std::endl;
+  EXPECT_EQ(std::string(text), std::string(input[0]) + " ");  // from the extra byte token
+}
@@ -8,7 +8,7 @@ include(FetchContent)
 FetchContent_Declare(
   ortx
   GIT_REPOSITORY https://github.com/microsoft/onnxruntime-extensions.git
-  GIT_TAG a7043c56e4f19c4bf11642d390f7b502f80a34ba)
+  GIT_TAG main)
 
 set(OCOS_BUILD_PRESET token_api_only)
 FetchContent_MakeAvailable(ortx)