Load the tokenizer data from memory (#836)

Wenbing Li 2024-11-09 10:15:21 -08:00 committed by GitHub
Parent 14f280adf6
Commit 3da0d3c929
No known key found for this signature
GPG key ID: B5690EEEBB952194
11 changed files with 404 additions and 334741 deletions


@@ -13,6 +13,33 @@ typedef OrtxObject OrtxStringArray;
typedef OrtxObject OrtxTokenId2DArray;
typedef OrtxObject OrtxDetokenizerCache;
struct OrtxTokenizerBlob {
const char* config_json_blob;
const char* vocab_json_blob;
const char* token_module_blob;
const char* raw_model_blob;
const char* reserved_blob_1;
const size_t config_blob_len;
const size_t vocab_blob_len;
const size_t token_module_blob_len;
const size_t raw_model_blob_len;
const size_t reserved_blob_1_len;
#ifdef __cplusplus
OrtxTokenizerBlob(const std::string_view& config_json_blob,
const std::string_view& vocab_json_blob,
const std::string_view& token_module_blob,
const std::string_view& raw_model_blob)
: config_json_blob(config_json_blob.data()), vocab_json_blob(vocab_json_blob.data()),
token_module_blob(token_module_blob.data()), raw_model_blob(raw_model_blob.data()),
config_blob_len(config_json_blob.size()),
vocab_blob_len(vocab_json_blob.size()), token_module_blob_len(token_module_blob.size()),
raw_model_blob_len(raw_model_blob.size()), reserved_blob_1(nullptr),
reserved_blob_1_len(0) {}
#endif
};
#ifdef __cplusplus
extern "C" {
@@ -26,6 +53,15 @@ extern "C" {
*/
extError_t ORTX_API_CALL OrtxCreateTokenizer(OrtxTokenizer** tokenizer, const char* tokenizer_path);
/** \brief Create a tokenizer object with the specified tokenizer blob
*
* \param tokenizer Pointer to store the created tokenizer object
* \param tokenizer_blob Pointer to the tokenizer blob
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxCreateTokenizerFromBlob(OrtxTokenizer** tokenizer, const struct OrtxTokenizerBlob* tokenizer_blob);
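For orientation, here is a minimal sketch of how a caller might drive the new entry point, assuming the two JSON files have already been read into memory. The helper name is hypothetical, and since OrtxTokenizerBlob stores raw pointers, the strings must outlive the creation call:

#include <string>

// Hypothetical helper: build a tokenizer from in-memory JSON buffers.
OrtxTokenizer* CreateTokenizerFromMemory(const std::string& config_json,
                                         const std::string& vocab_json) {
  // The string_view constructor fills the blob pointers and lengths;
  // the unused slots (token module, raw model) stay empty.
  OrtxTokenizerBlob blob(config_json, vocab_json,
                         /*token_module_blob=*/"", /*raw_model_blob=*/"");
  OrtxTokenizer* tokenizer = nullptr;
  extError_t err = OrtxCreateTokenizerFromBlob(&tokenizer, &blob);
  return err == kOrtxOK ? tokenizer : nullptr;
}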
/** \brief Tokenize the input using the specified tokenizer
*
* \param tokenizer Pointer to the tokenizer object


@@ -650,53 +650,6 @@ std::string JsonFastTokenizer::TokenBytesToString(std::vector<uint8_t>& bytes) {
return result;
}
// Custom hash function for the vector key
struct VectorHash {
size_t operator()(const std::vector<uint8_t>& v) const {
std::hash<uint8_t> hasher;
size_t seed = 0;
for (uint8_t i : v) {
seed ^= hasher(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
};
// Custom equality function for the vector key
struct VectorEqual {
bool operator()(const std::vector<uint8_t>& a, const std::vector<uint8_t>& b) const {
return a == b;
}
};
OrtxStatus JsonFastTokenizer::LoadAddedTokens(const json& tok_json, const ort_extensions::TokenJsonConfig& config) {
auto added_tokens = tok_json.find("added_tokens");
if (added_tokens != tok_json.end()) {
for (const auto& token : *added_tokens) {
AddedToken added_token;
added_token.id_ = token.value("id", 0);
added_token.token_type_ = token.value("__type", "");
added_token.content_ = token.value("content", "");
added_token.lstrip_ = token.value("lstrip", false);
added_token.normalized_ = token.value("normalized", false);
added_token.rstrip_ = token.value("rstrip", false);
added_token.single_word_ = token.value("single_word", false);
added_token.special_ = token.value("special", false);
added_tokens_.emplace_back(added_token);
if (added_token.content_ == config.bos_token_) {
bos_token_id_ = added_token.id_;
} else if (added_token.content_ == config.eos_token_) {
eos_token_id_ = added_token.id_;
} else if (added_token.content_ == config.pad_token_) {
pad_token_id_ = added_token.id_;
}
}
}
return bbpe_tokenizer_->LoadAddedTokens(added_tokens_);
}
// Helper methods (to be added to the class declaration)
void JsonFastTokenizer::LoadSpmModelParams(const json& tok_json) {
auto decoder_node = tok_json.find("decoder");
@@ -722,7 +675,29 @@ void JsonFastTokenizer::LoadSpmModelParams(const json& tok_json) {
}
}
void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::TokenJsonConfig& config) {
void JsonFastTokenizer::UpdateTokenizer(const TokenJsonConfig& config, const json& tok_json) {
added_tokens_ = config.added_tokens_;
auto added_tokens = tok_json.find("added_tokens");
if (added_tokens != tok_json.end()) {
for (const auto& token : *added_tokens) {
added_tokens_.emplace_back(TokenJsonConfig::ParseAddedToken(token));
}
}
for (const auto& added_token : added_tokens_) {
if (added_token.content_ == config.bos_token_) {
bos_token_id_ = added_token.id_;
} else if (added_token.content_ == config.eos_token_) {
eos_token_id_ = added_token.id_;
} else if (added_token.content_ == config.pad_token_) {
pad_token_id_ = added_token.id_;
}
}
bbpe_tokenizer_->LoadAddedTokens(added_tokens_);
add_bos_token_ = config.add_bos_token_;
add_eos_token_ = config.add_eos_token_;
if (!config.add_bos_token_ && !config.bos_token_.empty()) {
auto post_processor = tok_json.find("post_processor");
if (post_processor != tok_json.end()) {
@@ -738,14 +713,14 @@ void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort
}
OrtxStatus JsonFastTokenizer::Load(const ort_extensions::TokenJsonConfig& config) {
std::string voc_file = config.GetVocabDataFile();
std::ifstream ifs = path(voc_file).open();
if (!ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open json file: " + voc_file);
std::unique_ptr<std::istream> vocab_stream;
auto status = config.OpenVocabFile(vocab_stream);
if (!status.IsOk()) {
return status;
}
nlohmann::json tok_json;
ifs >> tok_json;
*vocab_stream >> tok_json;
const char token_sub[] = "Tokenizer";
model_name_ = config.tokenizer_class_.substr(0, config.tokenizer_class_.find(token_sub));
@@ -767,30 +742,40 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::TokenJsonConfig& config
}
bbpe_tokenizer_ = std::make_unique<BpeModel>();
OrtxStatus status = bbpe_tokenizer_->Load(*model_node,
status = bbpe_tokenizer_->Load(*model_node,
bpe_conf_.get().GetSpecialTokens().c_str(),
bpe_conf_.get().spm_model_);
if (!status.IsOk()) {
return status;
}
status = LoadAddedTokens(tok_json, config);
if (!status.IsOk()) {
return status;
}
add_bos_token_ = config.add_bos_token_;
add_eos_token_ = config.add_eos_token_;
UpdateTokenAdditionFlags(tok_json, config);
return status;
}
if (status.IsOk()) {
UpdateTokenizer(config, tok_json);
}
return status;
}
// Custom hash function for the vector key
struct VectorHash {
size_t operator()(const std::vector<uint8_t>& v) const {
std::hash<uint8_t> hasher;
size_t seed = 0;
for (uint8_t i : v) {
seed ^= hasher(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
};
// Custom equality function for the vector key
struct VectorEqual {
bool operator()(const std::vector<uint8_t>& a, const std::vector<uint8_t>& b) const {
return a == b;
}
};
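These two functors exist only so that std::unordered_map can key on raw byte sequences; a short sketch of the intended instantiation, mirroring the bpe_ranks declaration in LoadTikTokenBase64 below (the inserted pair is illustrative):

#include <cstdint>
#include <unordered_map>
#include <vector>

// VectorHash folds each byte into the seed with the boost-style hash_combine
// constant 0x9e3779b9; VectorEqual compares the byte vectors element-wise.
std::unordered_map<std::vector<uint8_t>, uint32_t, VectorHash, VectorEqual> ranks;
// ranks[{0x61, 0x62}] = 0;  // e.g. the byte pair "ab" mapped to merge rank 0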
OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::TokenJsonConfig& config) {
std::string voc_file = config.GetVocabDataFile();
std::ifstream ifs = path(voc_file).open();
if (!ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open json file: " + voc_file);
std::unique_ptr<std::istream> vocab_stream;
auto status = config.OpenVocabFile(vocab_stream);
if (!status.IsOk()) {
return status;
}
std::unordered_map<std::string, uint32_t> vocab;
@@ -798,7 +783,7 @@ OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::TokenJson
std::unordered_map<std::vector<uint8_t>, uint32_t, VectorHash, VectorEqual> bpe_ranks;
std::string line;
while (std::getline(ifs, line)) {
while (std::getline(*vocab_stream, line)) {
if (!line.empty()) {
std::istringstream lineStream(line);
std::string token;
@@ -857,7 +842,8 @@ OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::TokenJson
// Populate merges
for (auto& val : byte_merges) {
merges.push_back({JsonFastTokenizer::TokenBytesToString(std::get<0>(val)), JsonFastTokenizer::TokenBytesToString(std::get<1>(val))});
merges.push_back({JsonFastTokenizer::TokenBytesToString(std::get<0>(val)),
JsonFastTokenizer::TokenBytesToString(std::get<1>(val))});
}
const char token_sub[] = "Tokenizer";
@@ -871,32 +857,12 @@ OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::TokenJson
// re-bind the configuration object
bpe_conf_ = json_conf_;
OrtxStatus status = bbpe_tokenizer_->Load(vocab,
merges,
bpe_conf_.get().GetSpecialTokens().c_str(),
false);
if (!status.IsOk()) {
return status;
}
status = bbpe_tokenizer_->Load(vocab, merges, bpe_conf_.get().GetSpecialTokens().c_str(), false);
if (status.IsOk()) {
UpdateTokenizer(config, json());
}
std::string module_file = config.GetTikTokenModuleFile();
std::ifstream module_ifs = path(module_file).open();
if (!module_ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open module file: " + module_file);
}
nlohmann::json tok_json;
module_ifs >> tok_json;
status = LoadAddedTokens(tok_json, config);
if (!status.IsOk()) {
return status;
}
add_bos_token_ = config.add_bos_token_;
add_eos_token_ = config.add_eos_token_;
UpdateTokenAdditionFlags(tok_json, config);
return status;
}
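For reference, LoadTikTokenBase64 reads tiktoken's plain-text vocabulary format: one base64-encoded token and its integer rank per line. A hedged sketch of the per-line parse, using a representative cl100k_base.tiktoken entry ("IQ==" is the base64 encoding of "!"):

#include <cstdint>
#include <sstream>
#include <string>

std::istringstream lineStream("IQ== 0");  // "<base64 token> <rank>"
std::string b64_token;
uint32_t rank = 0;
lineStream >> b64_token >> rank;  // the loader then base64-decodes the token bytes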


@@ -128,8 +128,7 @@ class JsonFastTokenizer : public KernelBpeTokenizer {
private:
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
void LoadSpmModelParams(const json& tok_json);
void UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::TokenJsonConfig& config);
OrtxStatus LoadAddedTokens(const json& tok_json, const ort_extensions::TokenJsonConfig& config);
void UpdateTokenizer(const ort_extensions::TokenJsonConfig& config, const json& tok_json);
BpeModelConf json_conf_;
std::vector<ort_extensions::AddedToken> added_tokens_;


@@ -258,12 +258,10 @@ class BpeModel {
return {};
}
OrtxStatus LoadAddedTokens(const std::vector<AddedToken>& added_tokens) {
void LoadAddedTokens(const std::vector<AddedToken>& added_tokens) {
for (const auto& token : added_tokens) {
added_tokens_.Add(ustring(token.content_), 0, token.id_);
}
return {};
}
std::vector<std::string> BuildDecoder() const { return id2token_map_; }


@@ -21,61 +21,7 @@ class TokenJsonConfig final {
using json_pointer = nlohmann::json_pointer<std::string>;
public:
OrtxStatus Load(const std::string& json_path) {
if (json_path.empty()) {
return OrtxStatus(kOrtxErrorInvalidArgument, "json_path is empty.");
}
ortx::path tok_dir(json_path);
ortx::path vocab_path(json_path);
ortx::path tok_path_obj(json_path);
if (tok_path_obj.is_directory()) {
vocab_path = tok_dir / kDefaultVocabFile;
} else {
if (!tok_path_obj.exists()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Invalid file: " + tok_path_obj.string());
}
tok_dir = ortx::path(tok_path_obj.parent_path());
}
auto config_path = tok_dir / "tokenizer_config.json";
std::ifstream ifs = config_path.open();
if (!ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open a json file: " + config_path.string());
}
nlohmann::json json_config = nlohmann::json::parse(ifs);
auto module_cfg = tok_dir / "tokenizer_module.json";
if (module_cfg.exists()) {
module_path_ = module_cfg.string();
std::ifstream module_ifs = module_cfg.open();
nlohmann::json module_config = nlohmann::json::parse(module_ifs);
json_config.update(module_config);
}
model_max_length_ = json_config.value("model_max_length", 1e+30);
std::string tiktoken_file = json_config.value("tiktoken_file", "");
if (!tiktoken_file.empty()) {
auto tktok_path = tok_dir / tiktoken_file;
if (tktok_path.exists()) {
vocab_path_ = tktok_path.string();
} else {
return OrtxStatus(kOrtxErrorInvalidFile, "Invalid file: " + tiktoken_file);
}
} else {
if (ortx::path(vocab_path).exists()) {
vocab_path_ = vocab_path.string();
} else {
return OrtxStatus(kOrtxErrorInvalidFile, "Invalid file: " + vocab_path.string());
}
}
tokenizer_class_ = json_config.value("tokenizer_class", "");
if (tokenizer_class_.empty()) {
return {};
}
OrtxStatus ParseTokensFromConfig(const json& json_config) {
add_bos_token_ = json_config.value("add_bos_token", false);
add_eos_token_ = json_config.value("add_eos_token", false);
clean_up_tokenization_spaces_ = json_config.value("clean_up_tokenization_spaces", false);
@@ -101,9 +47,135 @@ class TokenJsonConfig final {
return {};
}
const std::string& GetVocabDataFile() const { return vocab_path_; }
OrtxStatus OpenVocabFile(std::unique_ptr<std::istream>& vocab_stream) const {
if (blob_ != nullptr) {
if (blob_->vocab_blob_len == 0) {
if (blob_->raw_model_blob_len == 0) {
return OrtxStatus(kOrtxErrorInvalidArgument, "vocab_blob_len and raw_model_blob_len are both 0.");
}
std::string vocab_str(blob_->raw_model_blob, blob_->raw_model_blob_len);
vocab_stream = std::make_unique<std::istringstream>(vocab_str);
} else {
if (blob_->raw_model_blob_len > 0) {
return OrtxStatus(kOrtxErrorInvalidArgument, "vocab_blob_len and raw_model_blob_len are both non-zero.");
}
std::string vocab_str(blob_->vocab_json_blob, blob_->vocab_blob_len);
vocab_stream = std::make_unique<std::istringstream>(vocab_str);
}
} else {
auto ifs = std::make_unique<std::ifstream>(vocab_path_);
if (!ifs->is_open()) {
return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, vocab_path_ + ": does not exist.");
}
vocab_stream = std::move(ifs);
}
return {};
}
const std::string& GetTikTokenModuleFile() const { return module_path_; }
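Note the invariant OpenVocabFile enforces for in-memory loading: exactly one of vocab_json_blob and raw_model_blob may be non-empty. A sketch of the two valid blob shapes, assuming config_json, vocab_json, module_json, and raw_model hold the respective file contents (these shapes match the test blobs added later in this commit):

// JSON vocabulary in memory (a tokenizer.json); the raw-model slot stays empty.
OrtxTokenizerBlob json_vocab_blob(config_json, vocab_json, "", "");
// tiktoken raw model in memory (e.g. cl100k_base.tiktoken); the vocab slot stays empty.
OrtxTokenizerBlob tiktoken_blob(config_json, "", module_json, raw_model);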
OrtxStatus LoadFromBlob(const OrtxTokenizerBlob& blob) {
std::string config_str(blob.config_json_blob, blob.config_blob_len);
std::istringstream config_ifs(config_str);
json json_config = json::parse(config_ifs, nullptr, false, true);
if (json_config.is_discarded()) {
return OrtxStatus(kOrtxErrorInvalidArgument, "Failed to parse config json.");
}
if (blob.token_module_blob_len > 0) {
std::string tokenizer_str(blob.token_module_blob, blob.token_module_blob_len);
std::istringstream tokenizer_ifs(tokenizer_str);
json json_tokenizer = json::parse(tokenizer_ifs, nullptr, false, true);
if (json_tokenizer.is_discarded()) {
return OrtxStatus(kOrtxErrorInvalidArgument, "Failed to parse tokenizer json.");
}
LoadAddedTokens(json_tokenizer);
json_config.update(json_tokenizer);
}
blob_ = &blob;
model_max_length_ = json_config.value("model_max_length", 1e+30);
std::string tiktoken_file = json_config.value("tiktoken_file", "");
if (!tiktoken_file.empty()) {
if (blob.raw_model_blob_len == 0) {
return OrtxStatus(kOrtxErrorInvalidArgument, "missing tiktoken file content in blob.raw_model_blob.");
}
}
tokenizer_class_ = json_config.value("tokenizer_class", "");
if (!tokenizer_class_.empty()) {
return ParseTokensFromConfig(json_config);
}
return {};
}
OrtxStatus Load(const std::string& json_path) {
if (json_path.empty()) {
return OrtxStatus(kOrtxErrorInvalidArgument, "json_path is empty.");
}
ortx::path tok_dir(json_path);
ortx::path vocab_path(json_path);
ortx::path tok_path_obj(json_path);
if (tok_path_obj.is_directory()) {
vocab_path = tok_dir / kDefaultVocabFile;
} else {
if (!tok_path_obj.exists()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Invalid file: " + tok_path_obj.string());
}
tok_dir = ortx::path(tok_path_obj.parent_path());
}
auto config_path = tok_dir / "tokenizer_config.json";
std::ifstream ifs = config_path.open();
if (!ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open a json file: " + config_path.string());
}
json json_config = json::parse(ifs, nullptr, false, true);
if (json_config.is_discarded()) {
return OrtxStatus(kOrtxErrorInvalidArgument, "Failed to parse config json.");
}
auto module_cfg = tok_dir / "tokenizer_module.json";
if (module_cfg.exists()) {
std::ifstream module_ifs = module_cfg.open();
json json_module = json::parse(module_ifs, nullptr, false, true);
if (json_module.is_discarded()) {
return OrtxStatus(kOrtxErrorInvalidArgument, "Failed to parse tokenizer module json.");
}
LoadAddedTokens(json_module);
json_config.update(json_module);
}
model_max_length_ = json_config.value("model_max_length", 1e+30);
std::string tiktoken_file = json_config.value("tiktoken_file", "");
if (!tiktoken_file.empty()) {
auto tktok_path = tok_dir / tiktoken_file;
if (tktok_path.exists()) {
vocab_path_ = tktok_path.string();
} else {
return OrtxStatus(kOrtxErrorInvalidFile, "Invalid file: " + tiktoken_file);
}
} else {
if (ortx::path(vocab_path).exists()) {
vocab_path_ = vocab_path.string();
} else {
return OrtxStatus(kOrtxErrorInvalidFile, "Invalid file: " + vocab_path.string());
}
}
tokenizer_class_ = json_config.value("tokenizer_class", "");
if (!tokenizer_class_.empty()) {
return ParseTokensFromConfig(json_config);
}
return {};
}
const std::string& GetVocabDataFile() const { return vocab_path_; }
public:
bool add_bos_token_{};
@@ -117,9 +189,34 @@ class TokenJsonConfig final {
std::string unk_token_;
std::string pad_token_;
std::vector<ort_extensions::AddedToken> added_tokens_;
static AddedToken ParseAddedToken(const json& token) {
AddedToken added_token;
added_token.id_ = token.value("id", 0);
added_token.token_type_ = token.value("__type", "");
added_token.content_ = token.value("content", "");
added_token.lstrip_ = token.value("lstrip", false);
added_token.normalized_ = token.value("normalized", false);
added_token.rstrip_ = token.value("rstrip", false);
added_token.single_word_ = token.value("single_word", false);
added_token.special_ = token.value("special", false);
return added_token;
}
private:
void LoadAddedTokens(const json& tok_json) {
auto added_tokens = tok_json.find("added_tokens");
if (added_tokens != tok_json.end()) {
for (const auto& token : *added_tokens) {
added_tokens_.emplace_back(ParseAddedToken(token));
}
}
}
std::string vocab_path_;
std::string module_path_;
const OrtxTokenizerBlob* blob_{nullptr};
};
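ParseAddedToken consumes one entry of the added_tokens array found in tokenizer.json or tokenizer_module.json; a minimal sketch with illustrative values, showing exactly the fields the parser reads:

#include "nlohmann/json.hpp"

auto token_json = nlohmann::json::parse(R"({
  "id": 32000, "__type": "AddedToken", "content": "<|endoftext|>",
  "lstrip": false, "normalized": false, "rstrip": false,
  "single_word": false, "special": true
})");
ort_extensions::AddedToken added =
    ort_extensions::TokenJsonConfig::ParseAddedToken(token_json);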
} // namespace ort_extensions


@@ -113,22 +113,18 @@ struct SpmUgmTokenizer {
}
OrtxStatus Load(const TokenJsonConfig& config) {
ortx::path vocab_path(config.GetVocabDataFile());
if (!vocab_path.exists()) {
return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, "Vocabulary file does not exist.");
std::unique_ptr<std::istream> vocab_stream;
auto status = config.OpenVocabFile(vocab_stream);
if (!status.IsOk()) {
return status;
}
auto ifs = vocab_path.open();
if (!ifs.is_open()) {
return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, "Failed to open vocabulary file.");
}
nlohmann::json j_vocab = json::parse(ifs, nullptr, false, true);
nlohmann::json j_vocab = json::parse(*vocab_stream, nullptr, false, true);
if (j_vocab.is_discarded()) {
return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, "Failed to parse vocabulary file.");
}
OrtxStatus status = LoadConfig(j_vocab);
status = LoadConfig(j_vocab);
if (!status.IsOk()) {
return status;
}


@@ -46,6 +46,24 @@ extError_t ORTX_API_CALL OrtxCreateTokenizer(OrtxTokenizer** tokenizer, const ch
return status.Code();
}
extError_t ORTX_API_CALL OrtxCreateTokenizerFromBlob(OrtxTokenizer** tokenizer, const OrtxTokenizerBlob* blob) {
// validate the blob argument before loading
if (blob == nullptr) {
ReturnableStatus::last_error_message_ = "The tokenizer blob is null";
return kOrtxErrorInvalidArgument;
}
ReturnableStatus status;
auto ptr = std::make_unique<ort_extensions::TokenizerImpl>();
status = ptr->Load(*blob);
if (status.IsOk()) {
*tokenizer = static_cast<OrtxTokenizer*>(ptr.release());
return extError_t();
}
return status.Code();
}
extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size,
OrtxTokenId2DArray** output) {
if (tokenizer == nullptr || input == nullptr || output == nullptr) {


@@ -21,8 +21,7 @@ std::set<std::string> TokenizerImpl::supported_bpe_models_ = {
"CodeLlamaTokenizer",
"CodeGenTokenizer",
"GPT2Tokenizer",
"Qwen2Tokenizer",
"T5Tokenizer"
"Qwen2Tokenizer"
};
std::set<std::string> TokenizerImpl::supported_ugm_models_ = {
@@ -33,17 +32,11 @@ TokenizerImpl::TokenizerImpl()
: OrtxObjectImpl(extObjectKind_t::kOrtxKindTokenizer) {};
TokenizerImpl::~TokenizerImpl() {};
OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
tok_config_ = std::make_shared<ort_extensions::TokenJsonConfig>();
auto status = tok_config_->Load(tok_path);
if (!status.IsOk()) {
return status;
}
OrtxStatus TokenizerImpl::LoadTokenizer(const OrtxTokenizerBlob* blob) {
if (tok_config_->tokenizer_class_.empty() ||
supported_ugm_models_.count(tok_config_->tokenizer_class_)) {
auto tokenizer = std::make_unique<SpmUgmTokenizer>();
status = tokenizer->Load(*tok_config_);
auto status = tokenizer->Load(*tok_config_);
if (!status.IsOk()) {
return status;
}
@@ -65,12 +58,21 @@ OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class");
}
auto tokenizer = std::make_unique<JsonFastTokenizer>();
auto vocab_file_path = ortx::path(tok_config_->GetVocabDataFile());
// vocab file is checked in TokenJsonConfig::Load
auto fx_load = vocab_file_path.extension() == ".json"?
&JsonFastTokenizer::Load: &JsonFastTokenizer::LoadTikTokenBase64;
status = (tokenizer.get()->*fx_load)(*tok_config_);
auto fx_load = &JsonFastTokenizer::Load;
if (blob == nullptr) {
auto vocab_file_path = ortx::path(tok_config_->GetVocabDataFile());
if (vocab_file_path.extension() != ".json") {
fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
}
} else {
if (blob->raw_model_blob_len > 0) {
fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
}
}
auto status = (tokenizer.get()->*fx_load)(*tok_config_);
if (!status.IsOk()) {
return status;
}
@@ -86,6 +88,26 @@ OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
return status;
}
OrtxStatus TokenizerImpl::Load(const OrtxTokenizerBlob& blob) {
tok_config_ = std::make_shared<ort_extensions::TokenJsonConfig>();
auto status = tok_config_->LoadFromBlob(blob);
if (!status.IsOk()) {
return status;
}
return LoadTokenizer(&blob);
}
OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
tok_config_ = std::make_shared<ort_extensions::TokenJsonConfig>();
auto status = tok_config_->Load(tok_path);
if (!status.IsOk()) {
return status;
}
return LoadTokenizer();
}
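Both overloads now funnel into LoadTokenizer; a sketch of the two resulting entry paths (the directory path and buffer names are illustrative):

#include <memory>

// From a directory on disk (the pre-existing path):
auto impl = std::make_unique<ort_extensions::TokenizerImpl>();
OrtxStatus s1 = impl->Load("/path/to/tokenizer_dir");  // parses tokenizer_config.json and friends
// From memory (the path added by this commit):
OrtxTokenizerBlob blob(config_json, vocab_json, "", "");  // buffers assumed in scope
OrtxStatus s2 = impl->Load(blob);  // LoadFromBlob, then LoadTokenizer(&blob)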
OrtxStatus TokenizerImpl::BatchEncode(const std::vector<std::string_view>& input,
std::vector<std::vector<extTokenId_t>>& t_ids) const {
for (const auto& s : input) {

Просмотреть файл

@@ -23,6 +23,7 @@ class TokenizerImpl : public OrtxObjectImpl {
public:
OrtxStatus Load(const std::string& tok_path);
OrtxStatus Load(const OrtxTokenizerBlob& blob);
OrtxStatus Tokenize(const std::vector<std::string_view>& input, std::vector<std::vector<extTokenId_t>>& t_ids) const {
return BatchEncode(input, t_ids);
@@ -60,6 +61,8 @@ class TokenizerImpl : public OrtxObjectImpl {
std::vector<std::vector<extTokenId_t>>& t_ids) const;
private:
OrtxStatus LoadTokenizer(const OrtxTokenizerBlob* blob = nullptr);
using bpe_tokenizer_t = std::unique_ptr<JsonFastTokenizer>;
using ugm_tokenizer_t = std::unique_ptr<SpmUgmTokenizer>;
std::variant<bpe_tokenizer_t, ugm_tokenizer_t> tokenizer_;

File diff not shown because it is too large


@@ -416,7 +416,7 @@ TEST(OrtxTokenizerTest, CodeGenTokenizer) {
EXPECT_TRUE(status.IsOk());
EXPECT_EQ(out_text1.size(), 1);
std::string out_text_ref = out_text1.back();
std::cout << out_text_ref << std::endl;
// std::cout << out_text_ref << std::endl;
EXPECT_EQ(out_text_ref.substr(out_text_ref.length() - 3, 3), "\ufffd");
}
@@ -503,8 +503,7 @@ TEST(OrtxTokenizerStreamTest, Phi3Tokenizer) {
std::vector<std::string_view> input = {
R"(こんにちは。データ分析にはいくつかのステップがあります。まずは目的を明確にします。次に、データを収集し、クリーニングを行います。)"
R"(その後、データを構造化し、その後、データを分析します。これらのステップを実行することで、データを有意的に分析することができます。)"
};
R"(その後、データを構造化し、その後、データを分析します。これらのステップを実行することで、データを有意的に分析することができます。)"};
std::vector<std::vector<extTokenId_t>> token_ids;
status = tokenizer->Tokenize(input, token_ids);
EXPECT_TRUE(status.IsOk());
@@ -569,8 +568,8 @@ TEST(OrtxTokenizerTest, SpmUgmTokenizer) {
// expected ids were generated using the following command:
// AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
EXPECT_EQ(ids_vec, std::vector<extTokenId_t>({
0, 87, 1884, 122395, 759, 99942, 10269, 136, 7068, 4, 6, 62668, 5364, 245875, 354, 11716, 2}));
EXPECT_EQ(ids_vec, std::vector<extTokenId_t>({0, 87, 1884, 122395, 759, 99942, 10269, 136, 7068, 4, 6, 62668, 5364,
245875, 354, 11716, 2}));
OrtxObjectPtr<OrtxStringArray> decoded_text;
OrtxDetokenize(tokenizer.get(), token_ids.get(), ort_extensions::ptr(decoded_text));
@@ -580,11 +579,91 @@ TEST(OrtxTokenizerTest, SpmUgmTokenizer) {
OrtxStringArrayGetItem(decoded_text.get(), 0, &text);
// because the tokenization remove the character from the string, the decoded text is not the same as the input text.
std::string filtered_text(input[0]);
filtered_text.erase(std::remove_if(
filtered_text.begin(), filtered_text.end(), [](unsigned char chr){ return chr < 0x20; }), filtered_text.end());
filtered_text.erase(
std::remove_if(filtered_text.begin(), filtered_text.end(), [](unsigned char chr) { return chr < 0x20; }),
filtered_text.end());
// remove the consecutive spaces
filtered_text.erase(std::unique(filtered_text.begin(), filtered_text.end(),
[](char lhs, char rhs) { return lhs == ' ' && rhs == ' '; }), filtered_text.end());
[](char lhs, char rhs) { return lhs == ' ' && rhs == ' '; }),
filtered_text.end());
EXPECT_STREQ(filtered_text.c_str(), text);
}
static std::string ReadFile(const std::string& filepath) {
std::ifstream file(filepath.data(), std::ios::binary);
if (!file.is_open()) {
return "";
}
std::ostringstream ss;
ss << file.rdbuf();
return ss.str();
};
TEST(OrtxTokenizerTest, Phi3_Small_Tokenizer_Blob) {
std::string config_blob = ReadFile("data/tokenizer/phi-3-small/tokenizer_config.json");
ASSERT_FALSE(config_blob.empty()) << "Failed to read config blob file, stopping the test.";
std::string raw_model_blob = ReadFile("data/tokenizer/phi-3-small/cl100k_base.tiktoken");
ASSERT_FALSE(raw_model_blob.empty()) << "Failed to read raw model blob file, stopping the test.";
std::string module_blob = ReadFile("data/tokenizer/phi-3-small/tokenizer_module.json");
ASSERT_FALSE(module_blob.empty()) << "Failed to read module blob file, stopping the test.";
struct OrtxTokenizerBlob blobs(config_blob, "", module_blob, raw_model_blob);
OrtxObjectPtr<OrtxTokenizer> tokenizer(OrtxCreateTokenizerFromBlob, &blobs);
ASSERT_EQ(tokenizer.Code(), kOrtxOK) << "Failed to create tokenizer, stopping the test.";
// validate tokenizer is not null
ASSERT_NE(tokenizer.get(), nullptr) << "Tokenizer is null, stopping the test.";
std::vector<extTokenId_t> EXPECTED_IDS_0 = {2028, 374, 264, 1296, 13};
const char* input[] = {"This is a test.",
"the second one",
"I like walking my cute dog\n and\x17 then",
"Hey<|endoftext|>. \t\t \n\nyou é @#😈 🤗! , 1234 15 5,61"};
OrtxObjectPtr<OrtxTokenId2DArray> token_ids;
OrtxTokenize(tokenizer.get(), input, 4, ort_extensions::ptr(token_ids));
EXPECT_EQ(token_ids.Code(), kOrtxOK);
size_t length = 0;
const extTokenId_t* ids = nullptr;
OrtxTokenId2DArrayGetItem(token_ids.get(), 0, &ids, &length);
std::vector<extTokenId_t> ids_vec(ids, ids + length);
EXPECT_EQ(ids_vec, EXPECTED_IDS_0);
}
TEST(OrtxTokenizerTest, Phi3TokenizerBlob) {
std::string config_blob = ReadFile("data/phi-3/tokenizer_config.json");
ASSERT_FALSE(config_blob.empty()) << "Failed to read config blob file, stopping the test.";
std::string vocab_blob = ReadFile("data/phi-3/tokenizer.json");
ASSERT_FALSE(vocab_blob.empty()) << "Failed to read vocab blob file, stopping the test.";
struct OrtxTokenizerBlob blob(config_blob, vocab_blob, "", "");
OrtxObjectPtr<OrtxTokenizer> tokenizer(OrtxCreateTokenizerFromBlob, &blob);
ASSERT_EQ(tokenizer.Code(), kOrtxOK) << "Failed to create tokenizer, stopping the test.";
// validate tokenizer is not null
ASSERT_NE(tokenizer.get(), nullptr) << "Tokenizer is null, stopping the test.";
const char* input[] = {"I like walking my cute dog\n and\x17 then, 生活的真谛是 \t\t\t\t \n\n61"};
OrtxObjectPtr<OrtxTokenId2DArray> token_ids;
OrtxTokenize(tokenizer.get(), input, 1, ort_extensions::ptr(token_ids));
EXPECT_EQ(token_ids.Code(), kOrtxOK);
size_t length = 0;
const extTokenId_t* ids = nullptr;
OrtxTokenId2DArrayGetItem(token_ids.get(), 0, &ids, &length);
std::vector<extTokenId_t> ids_vec(ids, ids + length);
// expected ids were generated with the AutoTokenizer for the Phi-3 model under test
EXPECT_EQ(ids_vec,
std::vector<extTokenId_t>({1, 306, 763, 22049, 590, 274, 1082, 11203, 13, 322, 26,
769, 29892, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392,
259, 12, 12, 12, 12, 29871, 13, 13, 29953, 29896}));
}