Add initial tiktoken and Phi3SmallTokenizer support (#729)

* add initial tiktoken support

* add vector hash and equal for bpe ranks map

* change lambda comparator

* move phi-3-small files

* final changes

* move tiktoken files from data2 to data

* add unit test

* add tokenizer module

* merge json and tiktoken impl

* fix tiktoken encoding problem

* address comments

* remove dummy tokens

---------

Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
Sayan Shaw 2024-08-02 10:24:02 -07:00 committed by GitHub
Parent 46998e96fb
Commit 7851b51ee3
No key found matching this signature
GPG key ID: B5690EEEBB952194
10 changed files with 100631 additions and 19 deletions

View file

@@ -30,7 +30,23 @@ class TokenJsonConfig final {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open a json file: " + file_path.string());
}
- vocab_path_ = (path(json_path) / "tokenizer.json").string();
+ auto vocab_file_path = path(json_path) / "tokenizer.json";
+ vocab_path_ = vocab_file_path.string();
+ std::ifstream vocab_fs = vocab_file_path.open();
+ if (!vocab_fs.is_open()) {
+ // No tokenizer.json file present; search for tokenizer module file
+ auto module_file_path = path(json_path) / "tokenizer_module.json";
+ module_path_ = module_file_path.string();
+ std::ifstream tok_module_ifs = module_file_path.open();
+ if (!tok_module_ifs.is_open()) {
+ return OrtxStatus(kOrtxErrorInvalidFile, "No tokenizer.json or tokenizer_module.json file found.");
+ } else {
+ nlohmann::json tok_module_json_config = nlohmann::json::parse(tok_module_ifs);
+ auto tiktoken_path = tok_module_json_config.value("tiktoken_file", "");
+ vocab_file_path = path(json_path) / tiktoken_path.c_str();
+ vocab_path_ = vocab_file_path.string();
+ }
+ }
nlohmann::json json_config = nlohmann::json::parse(ifs);
add_bos_token_ = json_config.value("add_bos_token", false);
add_eos_token_ = json_config.value("add_eos_token", false);
@@ -66,6 +82,10 @@ class TokenJsonConfig final {
const std::string& GetVocabDataFile() const { return vocab_path_; }
const std::string& GetTikTokenModuleFile() const {
return module_path_;
}
public:
bool add_bos_token_{};
bool add_eos_token_{};
@@ -80,6 +100,7 @@ class TokenJsonConfig final {
private:
std::string vocab_path_;
std::string module_path_;
};
} // namespace ort_extensions::bpe
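The lookup order added above, restated: tokenizer.json wins when present; otherwise tokenizer_module.json names the .tiktoken vocabulary file to load. A minimal standalone sketch of the same resolution, assuming std::filesystem in place of the library's path wrapper (ResolveVocabFile is a hypothetical helper, not part of this commit):

```cpp
#include <filesystem>
#include <fstream>
#include <string>
#include "nlohmann/json.hpp"

// Resolve the vocabulary file for a tokenizer directory, mirroring the
// fallback introduced above. Returns an empty string if neither file exists.
std::string ResolveVocabFile(const std::filesystem::path& dir) {
  auto tokenizer_json = dir / "tokenizer.json";
  if (std::filesystem::exists(tokenizer_json)) return tokenizer_json.string();
  std::ifstream module_ifs(dir / "tokenizer_module.json");
  if (!module_ifs.is_open()) return {};  // neither config present
  auto module_json = nlohmann::json::parse(module_ifs);
  // e.g. "tiktoken_file": "cl100k_base.tiktoken"
  return (dir / module_json.value("tiktoken_file", "")).string();
}
```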

View file

@@ -8,6 +8,8 @@
#include "bpe_json.hpp"
#include "bpe_tokenizer.hpp"
#include "base64.h"
#include <optional>
#include <limits>
@@ -552,6 +554,50 @@ SpmTokenizer::SpmTokenizer()
JsonFastTokenizer::JsonFastTokenizer() : KernelBpeTokenizer(kGPT2Configuration) {}
/*
Read more here: https://github.com/huggingface/transformers/blob/60bb571e993b7d73257fb64044726b569fef9403/src/transformers/convert_slow_tokenizer.py#L1454
Note: this is similar to the BPE CreateByteEncoder, however for decoding the .tiktoken bytes
we need to store the strings rather than their IDs, and thereby need a separate map.
*/
void JsonFastTokenizer::CreateUnicodeByteEncoder() {
char32_t index = 256;
for (char32_t i = 0; i < 256; ++i) {
if ((i >= 0 && i < 33) || (i >= 127 && i < 161) || (i == 173)) {
unicode_byte_encoder_[i] = ustring::EncodeUTF8Char(index++);
} else {
unicode_byte_encoder_[i] = ustring::EncodeUTF8Char(i);
}
}
}
std::string JsonFastTokenizer::TokenBytesToString(std::vector<uint8_t>& bytes) {
std::string result;
for (auto c : bytes) {
result += unicode_byte_encoder_[static_cast<unsigned char>(c)];
}
return result;
}
// Custom hash function for the vector key
struct VectorHash {
size_t operator()(const std::vector<uint8_t>& v) const {
std::hash<uint8_t> hasher;
size_t seed = 0;
for (uint8_t i : v) {
seed ^= hasher(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
};
// Custom equality function for the vector key
struct VectorEqual {
bool operator()(const std::vector<uint8_t>& a, const std::vector<uint8_t>& b) const {
return a == b;
}
};
OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& config) {
std::string voc_file = config.GetVocabDataFile();
std::ifstream ifs = path(voc_file).open();
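Background on the helpers above: CreateUnicodeByteEncoder follows the GPT-2 byte-to-unicode convention, where the 188 printable byte values map to themselves and the other 68 (0x00-0x20, 0x7F-0xA0, and 0xAD) are shifted to code points from U+0100 upward, so raw token bytes can live inside UTF-8 string keys. A self-contained sketch of the mapping rule (ByteToUnicodeTable is illustrative, not the commit's API):

```cpp
#include <array>

// GPT-2 style byte-to-code-point table: printable bytes map to themselves,
// non-printable ones get stand-ins starting at U+0100.
std::array<char32_t, 256> ByteToUnicodeTable() {
  std::array<char32_t, 256> table{};
  char32_t next = 256;
  for (char32_t b = 0; b < 256; ++b) {
    bool remap = b < 33 || (b >= 127 && b < 161) || b == 173;
    table[b] = remap ? next++ : b;
  }
  return table;
}
// Example: table[0x20] == U+0120 ('Ġ'), the leading-space marker familiar
// from GPT-2 vocabularies.
```

VectorHash, for its part, is the boost-style hash_combine idiom: 0x9e3779b9 is the 32-bit golden-ratio constant, mixed with shifts so that byte order affects the final hash.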
@@ -559,6 +605,89 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& c
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open json file: " + voc_file);
}
// consider to use SAX parser for large json file
nlohmann::json tok_json;
std::ifstream module_ifs;
// Following vocab and merges only used for tiktoken case but accessed outside scope below
std::unordered_map<std::string, uint32_t> vocab;
std::vector<std::pair<std::string, std::string>> merges;
if (tiktoken_){
std::string module_file = config.GetTikTokenModuleFile();
module_ifs = path(module_file).open();
if (!module_ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open module file: " + module_file);
}
std::unordered_map<std::vector<uint8_t>, uint32_t, VectorHash, VectorEqual> bpe_ranks;
std::string line;
while (std::getline(ifs, line)) {
if (!line.empty()) {
std::istringstream lineStream(line);
std::string token;
uint32_t rank;
while (lineStream >> token >> rank) {
// Decode base64 token and convert rank to int
std::vector<uint8_t> decoded_token;
base64_decode(token, decoded_token);
// Store bpe token and rank
bpe_ranks[decoded_token] = rank;
}
}
}
std::vector<std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>> byte_merges;
bbpe_tokenizer_ = std::make_unique<BpeModel>();
JsonFastTokenizer::CreateUnicodeByteEncoder();
for (const auto& item : bpe_ranks) {
std::vector<uint8_t> token = item.first;
uint32_t rank = item.second;
vocab[JsonFastTokenizer::TokenBytesToString(token)] = rank;
if (token.size() == 1) {
continue;
}
std::vector<std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>> local;
for (size_t index = 1; index < token.size(); index++) {
std::vector<uint8_t> piece_l(token.begin(), token.begin() + index);
std::vector<uint8_t> piece_r(token.begin() + index, token.end());
if (bpe_ranks.count(piece_l) && bpe_ranks.count(piece_r)) {
local.emplace_back(piece_l, piece_r, rank);
}
}
auto compare_bpe_tuples = [&](const std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>& a,
const std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>& b) {
// Compare tuples based on the ranks in bpe_ranks
return bpe_ranks[std::get<0>(a)] < bpe_ranks[std::get<0>(b)] ||
(bpe_ranks[std::get<0>(a)] == bpe_ranks[std::get<0>(b)] && bpe_ranks[std::get<1>(a)] < bpe_ranks[std::get<1>(b)]);
};
std::sort(local.begin(), local.end(), compare_bpe_tuples);
byte_merges.insert(byte_merges.end(), local.begin(), local.end());
}
// Custom comparator that compares the third element of the tuples
auto compare_merge_tuples = [&](const std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>& a,
const std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>& b) {
return std::get<2>(a) < std::get<2>(b);
};
std::sort(byte_merges.begin(), byte_merges.end(), compare_merge_tuples);
// Populate merges
for (auto& val : byte_merges) {
merges.push_back({JsonFastTokenizer::TokenBytesToString(std::get<0>(val)), JsonFastTokenizer::TokenBytesToString(std::get<1>(val))});
}
}
const char token_sub[] = "Tokenizer";
model_name_ = config.tokenizer_class_.substr(0, config.tokenizer_class_.find(token_sub));
json_conf_.name_ = model_name_.c_str();
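For reference, the file consumed by the getline loop above is the tiktoken plain-text format: one base64-encoded byte sequence and its integer rank per line. A hedged, self-contained sketch of one line's round trip (Base64Decode is a local stand-in for the helper in base64.h, which this commit includes):

```cpp
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Minimal base64 decoder, standing in for base64_decode() from base64.h.
std::vector<uint8_t> Base64Decode(const std::string& s) {
  static const std::string alphabet =
      "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
  std::vector<uint8_t> out;
  uint32_t buf = 0;
  int bits = 0;
  for (char c : s) {
    if (c == '=') break;
    auto pos = alphabet.find(c);
    if (pos == std::string::npos) continue;
    buf = (buf << 6) | static_cast<uint32_t>(pos);
    bits += 6;
    if (bits >= 8) {
      bits -= 8;
      out.push_back(static_cast<uint8_t>((buf >> bits) & 0xFF));
    }
  }
  return out;
}

// Each non-empty line of a .tiktoken file is "<base64 token> <rank>",
// e.g. "IQ== 0" maps the single byte 0x21 ('!') to rank 0.
void ParseTiktokenLine(const std::string& line,
                       std::vector<uint8_t>& token, uint32_t& rank) {
  std::istringstream ls(line);
  std::string b64;
  ls >> b64 >> rank;
  token = Base64Decode(b64);
}

int main() {
  std::vector<uint8_t> token;
  uint32_t rank = 0;
  ParseTiktokenLine("IQ== 0", token, rank);
  std::cout << static_cast<char>(token[0]) << " -> " << rank << "\n";  // ! -> 0
}
```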
@@ -570,18 +699,27 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& c
// re-bind the configuration object
bpe_conf_ = json_conf_;
- // consider to use SAX parser for large json file
- nlohmann::json tok_json;
- ifs >> tok_json;
- auto model_node = tok_json.find("model");
- if (model_node == tok_json.end()) {
- return OrtxStatus(kOrtxErrorCorruptData, "Failed to get model node from tokenizer.json");
- }
+ OrtxStatus status;
+ if (tiktoken_){
+ status = bbpe_tokenizer_->Load(vocab,
+ merges,
+ bpe_conf_.get().GetSpecialTokens().c_str(),
+ false);
- bbpe_tokenizer_ = std::make_unique<BpeModel>();
- auto status = bbpe_tokenizer_->Load(*model_node,
- bpe_conf_.get().GetSpecialTokens().c_str(),
- IsSpmModel(ModelName()));
+ module_ifs >> tok_json;
+ } else {
+ ifs >> tok_json;
+ auto model_node = tok_json.find("model");
+ if (model_node == tok_json.end()) {
+ return OrtxStatus(kOrtxErrorCorruptData, "Failed to get model node from tokenizer.json");
+ }
+ bbpe_tokenizer_ = std::make_unique<BpeModel>();
+ status = bbpe_tokenizer_->Load(*model_node,
+ bpe_conf_.get().GetSpecialTokens().c_str(),
+ IsSpmModel(ModelName()));
+ }
auto added_tokens = tok_json.find("added_tokens");
if (added_tokens != tok_json.end()) {
@@ -640,4 +778,4 @@ OrtxStatus JsonFastTokenizer::Compute(const ortc::Tensor<std::string>& input,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
- }
\ No newline at end of file
+ }
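One subtlety of the loader earlier in this file's diff deserves a worked example: a .tiktoken file carries no merge list, so merges are reconstructed from ranks alone. Every split of a multi-byte token whose two halves are themselves ranked tokens is a candidate merge, and candidates are ordered by the rank of the token they produce. A toy illustration with assumed data (not taken from cl100k_base):

```cpp
#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include <vector>

int main() {
  // Assumed toy ranks; real files map byte sequences to ranks.
  std::map<std::string, int> ranks{{"a", 0}, {"b", 1}, {"ab", 2}, {"abb", 3}};
  std::vector<std::tuple<std::string, std::string, int>> merges;
  for (const auto& [token, rank] : ranks) {
    if (token.size() < 2) continue;
    for (size_t i = 1; i < token.size(); ++i) {
      std::string left = token.substr(0, i), right = token.substr(i);
      if (ranks.count(left) && ranks.count(right))
        merges.emplace_back(left, right, rank);  // this split can produce token
    }
  }
  // Lower rank of the merged token == earlier merge, as in the sort above.
  std::sort(merges.begin(), merges.end(), [](const auto& x, const auto& y) {
    return std::get<2>(x) < std::get<2>(y);
  });
  for (const auto& [l, r, rank] : merges)
    std::cout << l << " + " << r << " -> " << rank << "\n";
  // Prints: a + b -> 2, then ab + b -> 3.
}
```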

View file

@@ -107,6 +107,10 @@ struct SpmTokenizer : KernelBpeTokenizer {
class JsonFastTokenizer : KernelBpeTokenizer {
public:
JsonFastTokenizer();
bool tiktoken_ = false;
std::string unicode_byte_encoder_[256] = {};
void CreateUnicodeByteEncoder();
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
@@ -121,3 +125,23 @@ class JsonFastTokenizer : KernelBpeTokenizer {
BpeModelConf json_conf_;
std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
};
class TikTokenizer : KernelBpeTokenizer {
public:
TikTokenizer();
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
public:
const auto& GetAddedTokens() const { return added_tokens_; }
const ort_extensions::BpeModel& GetEncoder() const { return *bbpe_tokenizer_; }
private:
std::unique_ptr<ort_extensions::BpeModel> bbpe_tokenizer_;
BpeModelConf json_conf_;
std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
};

Просмотреть файл

@@ -194,6 +194,47 @@ class BpeModel {
return {};
}
OrtxStatus Load(std::unordered_map<std::string, uint32_t>& vocab,
std::vector<std::pair<std::string, std::string>>& merges,
const char* /* special_tokens */,
bool spm_converted) {
vocab_map_ = vocab;
if (spm_converted) {
UpdateSpmByteToken(vocab_map_);
} else {
CreateByteEncoder();
}
uint32_t index = 0;
for (auto& tuple : merges){
std::string w1 = tuple.first;
std::string w2 = tuple.second;
int token_length = ort_extensions::narrow<int>(w1.length() + w2.length());
if (w2.find("</w>") != std::string::npos || w1.find("</w>") != std::string::npos) {
token_length -= 4;
}
auto iw1 = GetTokenId(w1);
auto iw2 = GetTokenId(w2);
auto iww = GetTokenId(w1 + w2);
BpeNode value{iww, index++, token_length};
bpe_rank_[GetRankKey(iw1, iw2)] = value;
}
id2token_map_.resize(vocab_map_.size());
for (const auto& [t, i] : vocab_map_) {
if (i > static_cast<uint32_t>((std::numeric_limits<int32_t>::max)())) {
continue; // safe purpose.
}
if (i > id2token_map_.size()) {
id2token_map_.resize(static_cast<size_t>(i) + 1);
}
id2token_map_[i] = t;
}
return {};
}
OrtxStatus LoadAddedTokens(const char* added_tokens) {
int id = bpe::kInvalidTokenId;
std::istringstream strm_tokens(added_tokens);
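A note on the merge table built by the new Load overload above: token_length is decremented by 4 when either side carries the legacy `</w>` end-of-word marker, and bpe_rank_ is keyed on a pair of token ids. GetRankKey itself is outside this hunk; presumably it packs the pair into one 64-bit key, along these lines (an assumption, not the diff's code):

```cpp
#include <cstdint>

// Assumed shape of the pair key used by bpe_rank_: two 32-bit token ids
// packed into one 64-bit value, so a merge lookup is a single map probe.
inline uint64_t GetRankKey(uint32_t left_id, uint32_t right_id) {
  return (static_cast<uint64_t>(left_id) << 32) | right_id;
}
```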

View file

@@ -18,12 +18,24 @@ OrtxStatus TokenizerImpl::Load(const std::string& dir) {
return status;
}
+ auto vocab_file_path = path(dir) / "tokenizer.json";
+ std::ifstream vocab_fs = vocab_file_path.open();
tokenizer_ = std::make_unique<JsonFastTokenizer>();
- // load the tokenizer from a config
- status = tokenizer_->Load(*tok_config_);
- if (status.IsOk()) {
- detokenizer_ = std::make_unique<BpeStreamingDecoder>();
- status = detokenizer_->Load(tok_config_, *tokenizer_);
+ if (!vocab_fs.is_open()) {
+ // No tokenizer.json file present; use TikToken tokenizer
+ tokenizer_->tiktoken_ = true;
+ // load the tokenizer from a config
+ status = tokenizer_->Load(*tok_config_);
+ } else {
+ // load the tokenizer from a config
+ status = tokenizer_->Load(*tok_config_);
+ if (status.IsOk()) {
+ detokenizer_ = std::make_unique<BpeStreamingDecoder>();
+ status = detokenizer_->Load(tok_config_, *tokenizer_);
+ }
}
+ }
return status;
@@ -34,7 +46,7 @@ OrtxStatus TokenizerImpl::BatchEncode(const std::vector<std::string_view>& inpu
for (const auto& s : input) {
ortc::Tensor<int64_t> ts_output(&CppAllocator::Instance());
ortc::Tensor<std::string> ts_input = ortc::Tensor<std::string>(std::vector<std::string>{std::string(s)});
- auto status = tokenizer_->Compute(ts_input, ts_output, std::nullopt, std::nullopt);
+ OrtxStatus status = tokenizer_->Compute(ts_input, ts_output, std::nullopt, std::nullopt);
if (!status.IsOk()) {
return status;

View file

@@ -50,6 +50,7 @@ class TokenizerImpl : public OrtxObjectImpl {
std::vector<std::vector<extTokenId_t>>& t_ids) const;
private:
bool tiktoken = false;
std::string tokenizer_dir_;
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig> tok_config_;
std::unique_ptr<JsonFastTokenizer> tokenizer_;

File diff not shown because it is too large

View file

@@ -0,0 +1,16 @@
{
"added_tokens_decoder": {},
"auto_map": {
"AutoTokenizer": [
"tokenization_phi3_small.Phi3SmallTokenizer",
"tokenization_phi3_small.Phi3SmallTokenizer"
]
},
"bos_token": "<|endoftext|>",
"chat_template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
"clean_up_tokenization_spaces": true,
"eos_token": "<|endoftext|>",
"model_max_length": 8192,
"pad_token": "<|endoftext|>",
"tokenizer_class": "Phi3SmallTokenizer"
}
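The chat_template above is a Jinja template. Rendered for a single user message "Hello" with add_generation_prompt set, it produces:

```
<|endoftext|><|user|>
Hello<|end|>
<|assistant|>
```

With add_generation_prompt unset, the eos_token <|endoftext|> is appended instead of the assistant header.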

View file

@@ -0,0 +1,77 @@
{
"tiktoken_file": "cl100k_base.tiktoken",
"added_tokens": [
{
"id": 100257,
"content": "<|endoftext|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 100258,
"content": "<|fim_prefix|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 100259,
"content": "<|fim_middle|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 100260,
"content": "<|fim_suffix|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 100261,
"content": "<|system|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 100262,
"content": "<|user|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 100263,
"content": "<|assistant|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 100276,
"content": "<|endofprompt|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
}
]
}
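Pieced together from this commit, the phi-3-small test directory is laid out so the tokenizer.json fallback kicks in; the oversized file whose diff was suppressed above is the cl100k_base.tiktoken vocabulary that this module file points at:

```
data/phi-3-small/
├── tokenizer_config.json   # tokenizer_class: Phi3SmallTokenizer
├── tokenizer_module.json   # tiktoken_file + added special tokens (above)
└── cl100k_base.tiktoken    # one "base64-token rank" pair per line
```

No tokenizer.json is present, which is exactly what routes loading through the tiktoken path.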

View file

@@ -151,6 +151,32 @@ TEST(OrtxTokenizerTest, Phi3_S_Tokenizer) {
EXPECT_EQ(out_text[0], input[0]);
}
TEST(OrtxTokenizerTest, Phi3_Small_Tokenizer) {
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
auto status = tokenizer->Load("data/phi-3-small");
if (!status.IsOk()) {
std::cout << status.ToString() << std::endl;
}
// validate tokenizer is not null
EXPECT_NE(tokenizer, nullptr);
std::vector<extTokenId_t> EXPECTED_IDS_0 = {2028, 374, 264, 1296, 13};
std::vector<std::string_view> input = {
"This is a test.",
"the second one",
"I like walking my cute dog\n and\x17 then",
"Hey<|endoftext|>. \t\t \n\nyou é @#😈 🤗! , 1234 15 5,61"};
std::vector<std::vector<extTokenId_t>> token_ids;
status = tokenizer->Tokenize(input, token_ids);
EXPECT_TRUE(status.IsOk());
DumpTokenIds(token_ids);
EXPECT_EQ(token_ids.size(), input.size());
EXPECT_EQ(token_ids[0], EXPECTED_IDS_0);
}
TEST(OrtxTokenizerTest, GemmaTokenizer) {
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
auto status = tokenizer->Load("data/gemma");