This commit is contained in:
Wenbing Li 2024-09-18 23:08:23 +00:00
Родитель a10cac9e42
Коммит 7459afac24
3 изменённых файлов: 121 добавлений и 1000318 удалений

Просмотреть файл

@ -1,5 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// The implementation is inspired by llama.cpp ugm tokenizer and huggingface FastTokenizer
#pragma once
@ -35,10 +36,10 @@ struct SpmUgmTokenizer : public TokenizerKernelBase {
auto id = token.value()["id"].get<extTokenId_t>();
bool is_special = token.value()["special"].get<bool>();
if (is_special) {
special_token_ids.insert(id);
special_token_ids_.insert(id);
}
auto word = token.value()["content"].get<std::string>();
user_defined_token_matcher.Add(word, 0, id);
user_defined_token_matcher_.Add(word, 0, id);
}
}
return {};
@ -50,41 +51,61 @@ struct SpmUgmTokenizer : public TokenizerKernelBase {
auto iter = normalizer->find("precompiled_charsmap");
if (iter != normalizer->end()) {
auto charsmap = iter->get<std::string>();
if (!base64_decode(charsmap, charsmap_data)) {
if (!base64_decode(charsmap, charsmap_data_)) {
return OrtxStatus(extError_t::kOrtxErrorCorruptData, "Failed to decode charsmap.");
}
// std::cout << "charsmap size: " << charsmap_data.size() << std::endl;
// for (size_t i = 0; i < charsmap_data.size() && i < 100; ++i) {
// std::cout << int(charsmap_data[i]) << " ";
// std::cout << "charsmap size: " << charsmap_data_.size() << std::endl;
// for (size_t i = 0; i < charsmap_data_.size() && i < 100; ++i) {
// std::cout << int(charsmap_data_[i]) << " ";
// }
size_t charsmap_offset = 0;
// First four bytes of precompiled_charsmap contains length of binary
// blob containing XOR-compressed compact double array (XCDA) entries
uint32_t xcda_blob_size = *(const uint32_t*)&charsmap_data[0];
uint32_t xcda_blob_size = *(const uint32_t*)&charsmap_data_[0];
charsmap_offset += sizeof(xcda_blob_size);
if (xcda_blob_size + charsmap_offset >= charsmap_data.size()) {
if (xcda_blob_size + charsmap_offset >= charsmap_data_.size()) {
return OrtxStatus(extError_t::kOrtxErrorCorruptData, "Index out of array bounds in precompiled charsmap!");
}
// Next xcda_blob_size bytes contain entries of XOR-compressed compact
// double array (XCDA). Each entry is bit-packed into a 32-bit integer.
xcda_array = (const uint32_t*)&charsmap_data[charsmap_offset];
xcda_array_size = xcda_blob_size / sizeof(uint32_t);
xcda_array_ = (const uint32_t*)&charsmap_data_[charsmap_offset];
xcda_array_size_ = xcda_blob_size / sizeof(uint32_t);
charsmap_offset += xcda_blob_size;
// Remaining bytes of precompiled charsmap contain null-terminated
// replacement strings for prefixes matched by the XCDA.
prefix_replacements = reinterpret_cast<const char*>(&charsmap_data[charsmap_offset]);
prefix_replacements_size = charsmap_data.size() - charsmap_offset;
prefix_replacements_ = reinterpret_cast<const char*>(&charsmap_data_[charsmap_offset]);
prefix_replacements_size_ = charsmap_data_.size() - charsmap_offset;
}
}
return {};
}
OrtxStatus LoadConfig(const json& config) {
auto pretokenizer_node = config.find("pretokenizer");
if (pretokenizer_node != config.end()) {
auto pretokenizers_node = pretokenizer_node->find("pretokenizers");
if (pretokenizers_node != pretokenizer_node->end()) {
for (const auto& pretokenizer : pretokenizers_node->items()) {
if (pretokenizer.value().contains("type")) {
auto type = pretokenizer.value()["type"].get<std::string>();
if (type == "Metaspace") {
tokenizer_escape_whitespaces_ = true;
}
}
if (pretokenizer.value().contains("add_prefix_space")) {
tokenizer_add_space_prefix_ = pretokenizer.value()["add_prefix_space"].get<bool>();
}
}
}
}
return {};
}
OrtxStatus Load(const bpe::TokenJsonConfig& config) {
ortx::path vocab_path(config.GetVocabDataFile());
@ -102,7 +123,12 @@ struct SpmUgmTokenizer : public TokenizerKernelBase {
return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, "Failed to parse vocabulary file.");
}
OrtxStatus status = LoadCharsMap(j_vocab);
OrtxStatus status = LoadConfig(j_vocab);
if (!status.IsOk()) {
return status;
}
status = LoadCharsMap(j_vocab);
if (!status.IsOk()) {
return status;
}
@ -117,6 +143,11 @@ struct SpmUgmTokenizer : public TokenizerKernelBase {
return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, "Model node not found in vocabulary file.");
}
auto unk_id_iter = model_node->find("unk_id");
if (unk_id_iter != model_node->end()) {
special_unk_id_ = unk_id_iter->get<extTokenId_t>();
}
auto vocab_node = model_node->find("vocab");
if (vocab_node == model_node->end()) {
return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, "Vocabulary not found in model node.");
@ -128,62 +159,23 @@ struct SpmUgmTokenizer : public TokenizerKernelBase {
auto score = entry.value()[1].get<double>();
vocab_[tkn] = std::make_tuple(id++, score);
}
scores_.resize(id);
double min_score = std::numeric_limits<double>::max();
for (const auto& entry : vocab_) {
scores_[std::get<0>(entry.second)] = std::get<1>(entry.second);
token_matcher.Add(entry.first, 0, std::get<0>(entry.second));
token_matcher_.Add(entry.first, 0, std::get<0>(entry.second));
min_score = std::min<double>(min_score, std::get<1>(entry.second));
}
// auto vocab = config["vocab"];
// if (vocab.precompiled_charsmap.size() > 0) {
// size_t charsmap_offset = 0;
// // First four bytes of precompiled_charsmap contains length of binary
// // blob containing XOR-compressed compact double array (XCDA) entries
// uint32_t xcda_blob_size = *(const uint32_t*)&vocab.precompiled_charsmap[0];
// charsmap_offset += sizeof(xcda_blob_size);
// if (xcda_blob_size + charsmap_offset >= vocab.precompiled_charsmap.size()) {
// throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
// }
// // Next xcda_blob_size bytes contain entries of XOR-compressed compact
// // double array (XCDA). Each entry is bit-packed into a 32-bit integer.
// xcda_array = (const uint32_t*)&vocab.precompiled_charsmap[charsmap_offset];
// xcda_array_size = xcda_blob_size / sizeof(uint32_t);
// charsmap_offset += xcda_blob_size;
// // Remaining bytes of precompiled charsmap contain null-terminated
// // replacement strings for prefixes matched by the XCDA.
// prefix_replacements = &vocab.precompiled_charsmap[charsmap_offset];
// prefix_replacements_size = vocab.precompiled_charsmap.size() - charsmap_offset;
// }
// for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
// const auto& token_data = vocab.id_to_token[id];
// if (llama_is_normal_token(vocab, id)) {
// min_score = std::min<float>(min_score, token_data.score);
// max_score = std::max<float>(max_score, token_data.score);
// }
// if (llama_is_normal_token(vocab, id) || llama_is_user_defined_token(vocab, id) ||
// llama_is_unused_token(vocab, id)) {
// token_matcher.insert(token_data.text.data(), token_data.text.size(), id);
// }
// if (llama_is_user_defined_token(vocab, id)) {
// user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size());
// }
// }
// unknown_token_score = min_score - unknown_token_score_penalty;
unknown_token_score_ = min_score - unknown_token_score_penalty_;
return status;
}
extTokenId_t GetTokenId(const std::string& token) const {
auto iter = vocab_.find(token);
if (iter == vocab_.end()) {
return special_unk_id;
return special_unk_id_;
}
return std::get<0>(iter->second);
}
@ -193,7 +185,6 @@ struct SpmUgmTokenizer : public TokenizerKernelBase {
return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, "Input tensor must have rank 1.");
}
// normalize the input first
std::string normalized;
Normalize(input.AsScalar(), &normalized);
size_t input_len = normalized.size();
@ -201,70 +192,55 @@ struct SpmUgmTokenizer : public TokenizerKernelBase {
return {};
}
// initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
std::vector<struct best_tokenization> tokenization_results(input_len + 1, {0, -FLT_MAX, special_unk_id});
// at the beginning tokenization score is zero
tokenization_results[0] = {0, 0, special_unk_id};
std::vector<struct BestTokenization> tokenization_results(input_len + 1, {0, -DBL_MAX, special_unk_id_});
tokenization_results[0] = {0, 0, special_unk_id_};
for (size_t input_offset = 0; input_offset < input_len;) {
size_t prefix_offset = input_offset;
// calculate how many code units are in the currently processed UTF code point
size_t n_utf8_code_units = std::min<size_t>(ustring::UTF8Len(normalized[input_offset]), input_len - input_offset);
// traverse the token matcher trie to find a matching token
bool single_codepoint_token_found = false;
const struct best_tokenization& current_best = tokenization_results[input_offset];
auto node = token_matcher.Find(normalized[prefix_offset++]);
const struct BestTokenization& current_best = tokenization_results[input_offset];
auto node = token_matcher_.Find(normalized[prefix_offset++]);
while (prefix_offset <= input_len && node != NULL) {
// check if we found valid token in prefix
if (node->HasValue()) {
// check if it corresponds to the whole UTF code point
if (prefix_offset - input_offset == n_utf8_code_units) {
single_codepoint_token_found = true;
}
extTokenId_t token_id = node->Value();
const auto& token_data = scores_[token_id];
// we set the user-defined token scores to 0 to make them more likely to be selected
// (normal token scores are log probabilities, so they are negative)
// score type is double here to make tokenization results exactly
// the same as in the HF tokenizer using SentencePiece
const double token_score = special_token_ids.count(token_id) > 0 ? 0.0 : token_data;
const double token_score = special_token_ids_.count(token_id) > 0 ? 0.0 : token_data;
const double challenger_score = current_best.score_sum + token_score;
struct best_tokenization& current_champ = tokenization_results[prefix_offset];
struct BestTokenization& current_champ = tokenization_results[prefix_offset];
if (challenger_score > current_champ.score_sum) {
struct best_tokenization challenger = {input_offset, (float)challenger_score, token_id};
struct BestTokenization challenger = {input_offset, (float)challenger_score, token_id};
current_champ = challenger;
}
}
node = node->Find(normalized[prefix_offset++]);
}
// if we didn't find a valid token corresponding to the whole UTF code point
// then use unknown token as the tokenization of this UTF code point
if (!single_codepoint_token_found) {
const double challenger_score = current_best.score_sum + unknown_token_score;
const double challenger_score = current_best.score_sum + unknown_token_score_;
prefix_offset = input_offset + n_utf8_code_units;
struct best_tokenization& current_champ = tokenization_results[prefix_offset];
struct BestTokenization& current_champ = tokenization_results[prefix_offset];
if (challenger_score > current_champ.score_sum) {
struct best_tokenization challenger = {input_offset, (float)challenger_score, special_unk_id};
struct BestTokenization challenger = {input_offset, (float)challenger_score, special_unk_id_};
current_champ = challenger;
}
}
// move to the next UTF code point
input_offset += n_utf8_code_units;
}
std::vector<extTokenId_t> output;
output.reserve(input_len);
// now backtrack from the end to gather token ids of the best tokenization
// merge sequences of consecutive unknown tokens into single unknown tokens
bool is_prev_unknown = false;
for (struct best_tokenization& tokenization = tokenization_results[input_len];;
for (struct BestTokenization& tokenization = tokenization_results[input_len];;
tokenization = tokenization_results[tokenization.input_offset]) {
bool is_unknown = tokenization.token_id == special_unk_id;
bool is_unknown = tokenization.token_id == special_unk_id_;
if (!(is_prev_unknown && is_unknown)) {
output.push_back(tokenization.token_id);
}
@ -274,18 +250,24 @@ struct SpmUgmTokenizer : public TokenizerKernelBase {
is_prev_unknown = is_unknown;
}
// // reverse the output since we added tokens starting from the end of the input
// std::reverse(output.begin() + output_size, output.end());
bool add_bos = GetTokenId(bos_token_) != special_unk_id_;
bool add_eos = GetTokenId(eos_token_) != special_unk_id_;
auto output_size = static_cast<int64_t>(output.size());
int64_t* id_output = tokenize_output.Allocate({output_size});
int64_t* id_output = tokenize_output.Allocate({output_size + add_bos + add_eos});
if (add_bos) {
*id_output = GetTokenId(bos_token_);
id_output++;
}
std::transform(output.begin(), output.end(), id_output, [](extTokenId_t id) { return static_cast<int64_t>(id); });
std::reverse(id_output, id_output + output_size);
if (add_eos) {
*(id_output + output_size) = GetTokenId(eos_token_);
}
return {};
}
private:
// helper structure for returning normalization results
struct normalization_result {
struct NormalizationResult {
const char* normalized;
size_t normalized_len;
size_t consumed_input;
@ -295,11 +277,11 @@ struct SpmUgmTokenizer : public TokenizerKernelBase {
normalized->clear();
normalized->reserve(input.size() * 3);
const std::string space = tokenizer_escape_whitespaces ? escaped_space : " ";
const std::string space = tokenizer_escape_whitespaces_ ? std::string(escaped_space_) : " ";
bool shall_prepend_space = !tokenizer_treat_whitespace_as_suffix && tokenizer_add_space_prefix;
bool shall_append_space = tokenizer_treat_whitespace_as_suffix && tokenizer_add_space_prefix;
bool shall_merge_spaces = tokenizer_remove_extra_whitespaces;
bool shall_prepend_space = !tokenizer_treat_whitespace_as_suffix_ && tokenizer_add_space_prefix_;
bool shall_append_space = tokenizer_treat_whitespace_as_suffix_ && tokenizer_add_space_prefix_;
bool shall_merge_spaces = tokenizer_remove_extra_whitespaces_;
bool is_space_prepended = false;
bool processing_non_ws = false;
@ -307,7 +289,7 @@ struct SpmUgmTokenizer : public TokenizerKernelBase {
size_t input_len = input.size();
for (size_t input_offset = 0; input_offset < input_len;) {
auto norm_res = normalize_prefix(input, input_offset);
auto norm_res = NormalizePrefix(input, input_offset);
for (size_t i = 0; i < norm_res.normalized_len; i++) {
char c = norm_res.normalized[i];
if (c != ' ') {
@ -349,7 +331,7 @@ struct SpmUgmTokenizer : public TokenizerKernelBase {
struct XcdaArrayView {
public:
XcdaArrayView(const uint32_t* xcda_array, size_t xcda_array_size)
: xcda_array(xcda_array), xcda_array_size(xcda_array_size) {}
: xcda_array_(xcda_array), xcda_array_size_(xcda_array_size) {}
uint32_t GetBase(size_t index) {
uint32_t packed_node = GetNode(index);
return (packed_node >> 10) << ((packed_node & (1U << 9)) >> 6);
@ -369,40 +351,34 @@ struct SpmUgmTokenizer : public TokenizerKernelBase {
private:
uint32_t GetNode(size_t index) {
if (index > xcda_array_size) {
throw std::runtime_error("Index out of array bounds in XCDA array!");
if (index > xcda_array_size_) {
ORTX_CXX_API_THROW("[UgmTok]Index out of array bounds in XCDA array!", ORT_RUNTIME_EXCEPTION);
}
return xcda_array[index];
return xcda_array_[index];
}
const uint32_t* xcda_array;
size_t xcda_array_size;
const uint32_t* xcda_array_;
size_t xcda_array_size_;
};
struct normalization_result normalize_prefix(const std::string& input, size_t input_offset) const {
struct NormalizationResult NormalizePrefix(const std::string& input, size_t input_offset) const {
if (input_offset == input.size()) {
return {&input[input_offset], 0, 0};
}
std::string prefix = input.substr(input_offset);
size_t prefix_off = 0;
// if input prefix matches some user-defined token return this token as normalization result
auto user_defined_token_match = user_defined_token_matcher.FindLongest(prefix, prefix_off);
if (user_defined_token_match != user_defined_token_matcher.kInvalidId_) {
auto user_defined_token_match = user_defined_token_matcher_.FindLongest(prefix, prefix_off);
if (user_defined_token_match != user_defined_token_matcher_.kInvalidId_) {
return {&input[input_offset], prefix_off + input_offset, prefix_off + input_offset};
}
size_t longest_prefix_length = 0;
size_t longest_prefix_offset = 0;
if (xcda_array_size > 0) {
XcdaArrayView xcda_view(xcda_array, xcda_array_size);
if (xcda_array_size_ > 0) {
XcdaArrayView xcda_view(xcda_array_, xcda_array_size_);
// Find the longest normalized sequence matching the input prefix by walking
// the XOR-compressed compact double array (XCDA) starting from the root node
// We find the index of the next node by calculating BASE[s] ^ c where s is
// the index of the previous node and c is a numerical character value
uint32_t node_index = 0;
// get BASE of the root node
node_index = xcda_view.GetBase(node_index);
for (size_t prefix_offset = input_offset; prefix_offset < input.size(); prefix_offset++) {
unsigned char c = input[prefix_offset];
@ -410,30 +386,23 @@ struct SpmUgmTokenizer : public TokenizerKernelBase {
break;
}
node_index ^= c;
// if value of LCHECK is not c it means that this is not a child of
// the previous node, so we stop matching
if (xcda_view.GetLcheck(node_index) != c) {
break;
}
bool is_leaf = xcda_view.IsLeaf(node_index);
// get BASE of the current node
node_index ^= xcda_view.GetBase(node_index);
// if LEAF of the current node is true, it means that its BASE points to the node
// containing index of replacement sequence for currently matched input prefix
if (is_leaf) {
longest_prefix_length = prefix_offset - input_offset + 1;
// get index of replacement sequence for currently matched input prefix
longest_prefix_offset = xcda_view.GetValue(node_index);
}
}
}
if (longest_prefix_length > 0) {
// we have a match, so return the replacement sequence
if (longest_prefix_offset >= prefix_replacements_size) {
throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
if (longest_prefix_offset >= prefix_replacements_size_) {
ORTX_CXX_API_THROW("[UgmTok]Index out of array bounds in precompiled charsmap!", ORT_RUNTIME_EXCEPTION);
}
const char* prefix_replacement = &prefix_replacements[longest_prefix_offset];
const char* prefix_replacement = &prefix_replacements_[longest_prefix_offset];
return {prefix_replacement, strlen(prefix_replacement), longest_prefix_length};
} else {
// if yes, return this sequence unmodified
@ -447,41 +416,41 @@ struct SpmUgmTokenizer : public TokenizerKernelBase {
}
// escaped space symbol - U+2581 (Lower One Eighth Block)
const std::string escaped_space = "\xE2\x96\x81";
static constexpr double unknown_token_score_penalty_ = 10.0;
static constexpr std::string_view escaped_space_ = "\xE2\x96\x81";
std::vector<uint8_t> charsmap_data;
const char* prefix_replacements = NULL;
size_t prefix_replacements_size = 0;
std::vector<uint8_t> charsmap_data_;
const char* prefix_replacements_ = NULL;
size_t prefix_replacements_size_ = 0;
const uint32_t* xcda_array = NULL;
size_t xcda_array_size = 0;
const uint32_t* xcda_array_ = NULL;
size_t xcda_array_size_ = 0;
VocabTrieTree user_defined_token_matcher;
VocabTrieTree user_defined_token_matcher_;
// this structure stores the best tokenization so far at input_offset
struct best_tokenization {
struct BestTokenization {
size_t input_offset;
float score_sum;
double score_sum;
extTokenId_t token_id;
};
float min_score = FLT_MAX;
float max_score = -FLT_MAX;
float unknown_token_score_penalty = 10.0;
float unknown_token_score;
extTokenId_t special_unk_id_ = -1;
double unknown_token_score_;
Vocab vocab_;
std::vector<double> scores_;
std::set<extTokenId_t> special_token_ids;
VocabTrieTree token_matcher;
std::set<extTokenId_t> special_token_ids_;
VocabTrieTree token_matcher_;
public:
bool tokenizer_escape_whitespaces = false;
bool tokenizer_treat_whitespace_as_suffix = false;
bool tokenizer_add_space_prefix = false;
bool tokenizer_remove_extra_whitespaces = false;
extTokenId_t special_unk_id = -1;
bool tokenizer_escape_whitespaces_ = true;
bool tokenizer_treat_whitespace_as_suffix_ = false;
bool tokenizer_add_space_prefix_ = true;
bool tokenizer_remove_extra_whitespaces_ = true;
std::string bos_token_ = "<s>";
std::string eos_token_ = "</s>";
std::string pad_token_ = "<pad>";
std::string unk_token_ = "<unk>";
};
} // namespace ort_extensions

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -450,14 +450,18 @@ TEST(OrtxTokenizerTest, SpmUgmTokenizer) {
OrtxObjectPtr<OrtxTokenizer> tokenizer(OrtxCreateTokenizer, "data/tokenizer/fairseq/xlm-roberta-base");
EXPECT_EQ(tokenizer.Code(), kOrtxOK);
const char* input[] = {"hello world"};
const char* input[] = {"I like walking my cute dog\n and\x17 then, 生活的真谛是 \t\t\t\t \n\n61"};
OrtxObjectPtr<OrtxTokenId2DArray> token_ids;
OrtxTokenize(tokenizer.get(), input, 1, ort_extensions::ptr(token_ids));
EXPECT_EQ(token_ids.Code(), kOrtxOK);
size_t length = 0;
const extTokenId_t* ids = NULL;
const extTokenId_t* ids = nullptr;
OrtxTokenId2DArrayGetItem(token_ids.get(), 0, &ids, &length);
std::vector<extTokenId_t> ids_vec(ids, ids + length);
EXPECT_EQ(ids_vec, std::vector<extTokenId_t>({127, 13817, uint32_t(-1), 32554}));
// expected ids was generated using the following command:
// AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
EXPECT_EQ(ids_vec, std::vector<extTokenId_t>({
0, 87, 1884, 122395, 759, 99942, 10269, 136, 7068, 4, 6, 62668, 5364, 245875, 354, 11716, 2}));
}