Load the tokenizer data from the memory (#836)
This commit is contained in:
Родитель
14f280adf6
Коммит
3da0d3c929
|
@ -13,6 +13,33 @@ typedef OrtxObject OrtxStringArray;
|
|||
typedef OrtxObject OrtxTokenId2DArray;
|
||||
typedef OrtxObject OrtxDetokenizerCache;
|
||||
|
||||
struct OrtxTokenizerBlob {
|
||||
const char* config_json_blob;
|
||||
const char* vocab_json_blob;
|
||||
const char* token_module_blob;
|
||||
const char* raw_model_blob;
|
||||
const char* reserved_blob_1;
|
||||
|
||||
const size_t config_blob_len;
|
||||
const size_t vocab_blob_len;
|
||||
const size_t token_module_blob_len;
|
||||
const size_t raw_model_blob_len;
|
||||
const size_t reserved_blob_1_len;
|
||||
|
||||
#ifdef __cplusplus
|
||||
OrtxTokenizerBlob(const std::string_view& config_json_blob,
|
||||
const std::string_view& vocab_json_blob,
|
||||
const std::string_view& token_module_blob,
|
||||
const std::string_view& raw_model_blob)
|
||||
: config_json_blob(config_json_blob.data()), vocab_json_blob(vocab_json_blob.data()),
|
||||
token_module_blob(token_module_blob.data()), raw_model_blob(raw_model_blob.data()),
|
||||
config_blob_len(config_json_blob.size()),
|
||||
vocab_blob_len(vocab_json_blob.size()), token_module_blob_len(token_module_blob.size()),
|
||||
raw_model_blob_len(raw_model_blob.size()), reserved_blob_1(nullptr),
|
||||
reserved_blob_1_len(0) {}
|
||||
#endif
|
||||
};
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
@ -26,6 +53,15 @@ extern "C" {
|
|||
*/
|
||||
extError_t ORTX_API_CALL OrtxCreateTokenizer(OrtxTokenizer** tokenizer, const char* tokenizer_path);
|
||||
|
||||
/** \brief Create a tokenizer object with the specified tokenizer blob
|
||||
*
|
||||
* \param tokenizer Pointer to store the created tokenizer object
|
||||
* \param tokenizer_blob Pointer to the tokenizer blob
|
||||
* \return Error code indicating the success or failure of the operation
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxCreateTokenizerFromBlob(OrtxTokenizer** tokenizer, const struct OrtxTokenizerBlob* tokenizer_blob);
|
||||
|
||||
|
||||
/** \brief Tokenize the input using the specified tokenizer
|
||||
*
|
||||
* \param tokenizer Pointer to the tokenizer object
|
||||
|
|
|
@ -650,53 +650,6 @@ std::string JsonFastTokenizer::TokenBytesToString(std::vector<uint8_t>& bytes) {
|
|||
return result;
|
||||
}
|
||||
|
||||
// Custom hash function for the vector key
|
||||
struct VectorHash {
|
||||
size_t operator()(const std::vector<uint8_t>& v) const {
|
||||
std::hash<uint8_t> hasher;
|
||||
size_t seed = 0;
|
||||
for (uint8_t i : v) {
|
||||
seed ^= hasher(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
||||
}
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
// Custom equality function for the vector key
|
||||
struct VectorEqual {
|
||||
bool operator()(const std::vector<uint8_t>& a, const std::vector<uint8_t>& b) const {
|
||||
return a == b;
|
||||
}
|
||||
};
|
||||
|
||||
OrtxStatus JsonFastTokenizer::LoadAddedTokens(const json& tok_json, const ort_extensions::TokenJsonConfig& config) {
|
||||
auto added_tokens = tok_json.find("added_tokens");
|
||||
if (added_tokens != tok_json.end()) {
|
||||
for (const auto& token : *added_tokens) {
|
||||
AddedToken added_token;
|
||||
added_token.id_ = token.value("id", 0);
|
||||
added_token.token_type_ = token.value("__type", "");
|
||||
added_token.content_ = token.value("content", "");
|
||||
added_token.lstrip_ = token.value("lstrip", false);
|
||||
added_token.normalized_ = token.value("normalized", false);
|
||||
added_token.rstrip_ = token.value("rstrip", false);
|
||||
added_token.single_word_ = token.value("single_word", false);
|
||||
added_token.special_ = token.value("special", false);
|
||||
|
||||
added_tokens_.emplace_back(added_token);
|
||||
if (added_token.content_ == config.bos_token_) {
|
||||
bos_token_id_ = added_token.id_;
|
||||
} else if (added_token.content_ == config.eos_token_) {
|
||||
eos_token_id_ = added_token.id_;
|
||||
} else if (added_token.content_ == config.pad_token_) {
|
||||
pad_token_id_ = added_token.id_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bbpe_tokenizer_->LoadAddedTokens(added_tokens_);
|
||||
}
|
||||
|
||||
// Helper methods (to be added to the class declaration)
|
||||
void JsonFastTokenizer::LoadSpmModelParams(const json& tok_json) {
|
||||
auto decoder_node = tok_json.find("decoder");
|
||||
|
@ -722,7 +675,29 @@ void JsonFastTokenizer::LoadSpmModelParams(const json& tok_json) {
|
|||
}
|
||||
}
|
||||
|
||||
void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::TokenJsonConfig& config) {
|
||||
void JsonFastTokenizer::UpdateTokenizer(const TokenJsonConfig& config, const json& tok_json) {
|
||||
added_tokens_ = config.added_tokens_;
|
||||
auto added_tokens = tok_json.find("added_tokens");
|
||||
if (added_tokens != tok_json.end()) {
|
||||
for (const auto& token : *added_tokens) {
|
||||
added_tokens_.emplace_back(TokenJsonConfig::ParseAddedToken(token));
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& added_token : added_tokens_) {
|
||||
if (added_token.content_ == config.bos_token_) {
|
||||
bos_token_id_ = added_token.id_;
|
||||
} else if (added_token.content_ == config.eos_token_) {
|
||||
eos_token_id_ = added_token.id_;
|
||||
} else if (added_token.content_ == config.pad_token_) {
|
||||
pad_token_id_ = added_token.id_;
|
||||
}
|
||||
}
|
||||
|
||||
bbpe_tokenizer_->LoadAddedTokens(added_tokens_);
|
||||
add_bos_token_ = config.add_bos_token_;
|
||||
add_eos_token_ = config.add_eos_token_;
|
||||
|
||||
if (!config.add_bos_token_ && !config.bos_token_.empty()) {
|
||||
auto post_processor = tok_json.find("post_processor");
|
||||
if (post_processor != tok_json.end()) {
|
||||
|
@ -738,14 +713,14 @@ void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort
|
|||
}
|
||||
|
||||
OrtxStatus JsonFastTokenizer::Load(const ort_extensions::TokenJsonConfig& config) {
|
||||
std::string voc_file = config.GetVocabDataFile();
|
||||
std::ifstream ifs = path(voc_file).open();
|
||||
if (!ifs.is_open()) {
|
||||
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open json file: " + voc_file);
|
||||
std::unique_ptr<std::istream> vocab_stream;
|
||||
auto status = config.OpenVocabFile(vocab_stream);
|
||||
if (!status.IsOk()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
nlohmann::json tok_json;
|
||||
ifs >> tok_json;
|
||||
*vocab_stream >> tok_json;
|
||||
|
||||
const char token_sub[] = "Tokenizer";
|
||||
model_name_ = config.tokenizer_class_.substr(0, config.tokenizer_class_.find(token_sub));
|
||||
|
@ -767,30 +742,40 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::TokenJsonConfig& config
|
|||
}
|
||||
|
||||
bbpe_tokenizer_ = std::make_unique<BpeModel>();
|
||||
OrtxStatus status = bbpe_tokenizer_->Load(*model_node,
|
||||
status = bbpe_tokenizer_->Load(*model_node,
|
||||
bpe_conf_.get().GetSpecialTokens().c_str(),
|
||||
bpe_conf_.get().spm_model_);
|
||||
if (!status.IsOk()) {
|
||||
if (status.IsOk()) {
|
||||
UpdateTokenizer(config, tok_json);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
status = LoadAddedTokens(tok_json, config);
|
||||
if (!status.IsOk()) {
|
||||
return status;
|
||||
// Custom hash function for the vector key
|
||||
struct VectorHash {
|
||||
size_t operator()(const std::vector<uint8_t>& v) const {
|
||||
std::hash<uint8_t> hasher;
|
||||
size_t seed = 0;
|
||||
for (uint8_t i : v) {
|
||||
seed ^= hasher(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
||||
}
|
||||
|
||||
add_bos_token_ = config.add_bos_token_;
|
||||
add_eos_token_ = config.add_eos_token_;
|
||||
UpdateTokenAdditionFlags(tok_json, config);
|
||||
|
||||
return status;
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
// Custom equality function for the vector key
|
||||
struct VectorEqual {
|
||||
bool operator()(const std::vector<uint8_t>& a, const std::vector<uint8_t>& b) const {
|
||||
return a == b;
|
||||
}
|
||||
};
|
||||
|
||||
OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::TokenJsonConfig& config) {
|
||||
std::string voc_file = config.GetVocabDataFile();
|
||||
std::ifstream ifs = path(voc_file).open();
|
||||
if (!ifs.is_open()) {
|
||||
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open json file: " + voc_file);
|
||||
std::unique_ptr<std::istream> vocab_stream;
|
||||
auto status = config.OpenVocabFile(vocab_stream);
|
||||
if (!status.IsOk()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, uint32_t> vocab;
|
||||
|
@ -798,7 +783,7 @@ OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::TokenJson
|
|||
std::unordered_map<std::vector<uint8_t>, uint32_t, VectorHash, VectorEqual> bpe_ranks;
|
||||
|
||||
std::string line;
|
||||
while (std::getline(ifs, line)) {
|
||||
while (std::getline(*vocab_stream, line)) {
|
||||
if (!line.empty()) {
|
||||
std::istringstream lineStream(line);
|
||||
std::string token;
|
||||
|
@ -857,7 +842,8 @@ OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::TokenJson
|
|||
|
||||
// Populate merges
|
||||
for (auto& val : byte_merges) {
|
||||
merges.push_back({JsonFastTokenizer::TokenBytesToString(std::get<0>(val)), JsonFastTokenizer::TokenBytesToString(std::get<1>(val))});
|
||||
merges.push_back({JsonFastTokenizer::TokenBytesToString(std::get<0>(val)),
|
||||
JsonFastTokenizer::TokenBytesToString(std::get<1>(val))});
|
||||
}
|
||||
|
||||
const char token_sub[] = "Tokenizer";
|
||||
|
@ -871,32 +857,12 @@ OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::TokenJson
|
|||
// re-bind the configuration object
|
||||
bpe_conf_ = json_conf_;
|
||||
|
||||
OrtxStatus status = bbpe_tokenizer_->Load(vocab,
|
||||
merges,
|
||||
bpe_conf_.get().GetSpecialTokens().c_str(),
|
||||
false);
|
||||
status = bbpe_tokenizer_->Load(vocab, merges, bpe_conf_.get().GetSpecialTokens().c_str(), false);
|
||||
|
||||
if (!status.IsOk()) {
|
||||
return status;
|
||||
if (status.IsOk()) {
|
||||
UpdateTokenizer(config, json());
|
||||
}
|
||||
|
||||
std::string module_file = config.GetTikTokenModuleFile();
|
||||
std::ifstream module_ifs = path(module_file).open();
|
||||
if (!module_ifs.is_open()) {
|
||||
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open module file: " + module_file);
|
||||
}
|
||||
|
||||
nlohmann::json tok_json;
|
||||
module_ifs >> tok_json;
|
||||
status = LoadAddedTokens(tok_json, config);
|
||||
if (!status.IsOk()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
add_bos_token_ = config.add_bos_token_;
|
||||
add_eos_token_ = config.add_eos_token_;
|
||||
UpdateTokenAdditionFlags(tok_json, config);
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
|
|
|
@ -128,8 +128,7 @@ class JsonFastTokenizer : public KernelBpeTokenizer {
|
|||
private:
|
||||
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
|
||||
void LoadSpmModelParams(const json& tok_json);
|
||||
void UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::TokenJsonConfig& config);
|
||||
OrtxStatus LoadAddedTokens(const json& tok_json, const ort_extensions::TokenJsonConfig& config);
|
||||
void UpdateTokenizer(const ort_extensions::TokenJsonConfig& config, const json& tok_json);
|
||||
|
||||
BpeModelConf json_conf_;
|
||||
std::vector<ort_extensions::AddedToken> added_tokens_;
|
||||
|
|
|
@ -258,12 +258,10 @@ class BpeModel {
|
|||
return {};
|
||||
}
|
||||
|
||||
OrtxStatus LoadAddedTokens(const std::vector<AddedToken>& added_tokens) {
|
||||
void LoadAddedTokens(const std::vector<AddedToken>& added_tokens) {
|
||||
for (const auto& token : added_tokens) {
|
||||
added_tokens_.Add(ustring(token.content_), 0, token.id_);
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<std::string> BuildDecoder() const { return id2token_map_; }
|
||||
|
|
|
@ -21,61 +21,7 @@ class TokenJsonConfig final {
|
|||
using json_pointer = nlohmann::json_pointer<std::string>;
|
||||
|
||||
public:
|
||||
OrtxStatus Load(const std::string& json_path) {
|
||||
if (json_path.empty()) {
|
||||
return OrtxStatus(kOrtxErrorInvalidArgument, "json_path is empty.");
|
||||
}
|
||||
|
||||
ortx::path tok_dir(json_path);
|
||||
ortx::path vocab_path(json_path);
|
||||
ortx::path tok_path_obj(json_path);
|
||||
if (tok_path_obj.is_directory()) {
|
||||
vocab_path = tok_dir / kDefaultVocabFile;
|
||||
} else {
|
||||
if (!tok_path_obj.exists()) {
|
||||
return OrtxStatus(kOrtxErrorInvalidFile, "Invalid file: " + tok_path_obj.string());
|
||||
}
|
||||
|
||||
tok_dir = ortx::path(tok_path_obj.parent_path());
|
||||
}
|
||||
|
||||
auto config_path = tok_dir / "tokenizer_config.json";
|
||||
std::ifstream ifs = config_path.open();
|
||||
if (!ifs.is_open()) {
|
||||
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open a json file: " + config_path.string());
|
||||
}
|
||||
|
||||
nlohmann::json json_config = nlohmann::json::parse(ifs);
|
||||
auto module_cfg = tok_dir / "tokenizer_module.json";
|
||||
if (module_cfg.exists()) {
|
||||
module_path_ = module_cfg.string();
|
||||
std::ifstream module_ifs = module_cfg.open();
|
||||
nlohmann::json module_config = nlohmann::json::parse(module_ifs);
|
||||
json_config.update(module_config);
|
||||
}
|
||||
|
||||
model_max_length_ = json_config.value("model_max_length", 1e+30);
|
||||
std::string tiktoken_file = json_config.value("tiktoken_file", "");
|
||||
if (!tiktoken_file.empty()) {
|
||||
auto tktok_path = tok_dir / tiktoken_file;
|
||||
if (tktok_path.exists()) {
|
||||
vocab_path_ = tktok_path.string();
|
||||
} else {
|
||||
return OrtxStatus(kOrtxErrorInvalidFile, "Invalid file: " + tiktoken_file);
|
||||
}
|
||||
} else {
|
||||
if (ortx::path(vocab_path).exists()) {
|
||||
vocab_path_ = vocab_path.string();
|
||||
} else {
|
||||
return OrtxStatus(kOrtxErrorInvalidFile, "Invalid file: " + vocab_path.string());
|
||||
}
|
||||
}
|
||||
|
||||
tokenizer_class_ = json_config.value("tokenizer_class", "");
|
||||
if (tokenizer_class_.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
OrtxStatus ParseTokensFromConfig(const json& json_config) {
|
||||
add_bos_token_ = json_config.value("add_bos_token", false);
|
||||
add_eos_token_ = json_config.value("add_eos_token", false);
|
||||
clean_up_tokenization_spaces_ = json_config.value("clean_up_tokenization_spaces", false);
|
||||
|
@ -101,9 +47,135 @@ class TokenJsonConfig final {
|
|||
return {};
|
||||
}
|
||||
|
||||
const std::string& GetVocabDataFile() const { return vocab_path_; }
|
||||
OrtxStatus OpenVocabFile(std::unique_ptr<std::istream>& vocab_stream) const {
|
||||
if (blob_ != nullptr) {
|
||||
if (blob_->vocab_blob_len == 0) {
|
||||
if (blob_->raw_model_blob_len == 0) {
|
||||
return OrtxStatus(kOrtxErrorInvalidArgument, "vocab_blob_len and raw_model_blob_len are both 0.");
|
||||
}
|
||||
std::string vocab_str(blob_->raw_model_blob, blob_->raw_model_blob_len);
|
||||
vocab_stream = std::make_unique<std::istringstream>(vocab_str);
|
||||
} else {
|
||||
if (blob_->raw_model_blob_len > 0) {
|
||||
return OrtxStatus(kOrtxErrorInvalidArgument, "vocab_blob_len and raw_model_blob_len are both non-zero.");
|
||||
}
|
||||
std::string vocab_str(blob_->vocab_json_blob, blob_->vocab_blob_len);
|
||||
vocab_stream = std::make_unique<std::istringstream>(vocab_str);
|
||||
}
|
||||
}
|
||||
else {
|
||||
auto ifs = std::make_unique<std::ifstream>(vocab_path_);
|
||||
if (!ifs->is_open()) {
|
||||
return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, vocab_path_ + ": does not exist.");
|
||||
}
|
||||
vocab_stream = std::move(ifs);
|
||||
}
|
||||
|
||||
const std::string& GetTikTokenModuleFile() const { return module_path_; }
|
||||
return {};
|
||||
}
|
||||
|
||||
OrtxStatus LoadFromBlob(const OrtxTokenizerBlob& blob) {
|
||||
std::string config_str(blob.config_json_blob, blob.config_blob_len);
|
||||
std::istringstream config_ifs(config_str);
|
||||
json json_config = json::parse(config_ifs, nullptr, false, true);
|
||||
if (json_config.is_discarded()) {
|
||||
return OrtxStatus(kOrtxErrorInvalidArgument, "Failed to parse config json.");
|
||||
}
|
||||
|
||||
if (blob.token_module_blob_len > 0) {
|
||||
std::string tokenizer_str(blob.token_module_blob, blob.token_module_blob_len);
|
||||
std::istringstream tokenizer_ifs(tokenizer_str);
|
||||
json json_tokenizer = json::parse(tokenizer_ifs, nullptr, false, true);
|
||||
if (json_tokenizer.is_discarded()) {
|
||||
return OrtxStatus(kOrtxErrorInvalidArgument, "Failed to parse tokenizer json.");
|
||||
}
|
||||
LoadAddedTokens(json_tokenizer);
|
||||
json_config.update(json_tokenizer);
|
||||
}
|
||||
|
||||
blob_ = &blob;
|
||||
model_max_length_ = json_config.value("model_max_length", 1e+30);
|
||||
std::string tiktoken_file = json_config.value("tiktoken_file", "");
|
||||
if (!tiktoken_file.empty()) {
|
||||
if (blob.raw_model_blob_len == 0) {
|
||||
return OrtxStatus(kOrtxErrorInvalidArgument, "missing tiktoken file content in blob.raw_model_blob.");
|
||||
}
|
||||
}
|
||||
|
||||
tokenizer_class_ = json_config.value("tokenizer_class", "");
|
||||
if (!tokenizer_class_.empty()) {
|
||||
return ParseTokensFromConfig(json_config);
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
OrtxStatus Load(const std::string& json_path) {
|
||||
if (json_path.empty()) {
|
||||
return OrtxStatus(kOrtxErrorInvalidArgument, "json_path is empty.");
|
||||
}
|
||||
|
||||
ortx::path tok_dir(json_path);
|
||||
ortx::path vocab_path(json_path);
|
||||
ortx::path tok_path_obj(json_path);
|
||||
if (tok_path_obj.is_directory()) {
|
||||
vocab_path = tok_dir / kDefaultVocabFile;
|
||||
} else {
|
||||
if (!tok_path_obj.exists()) {
|
||||
return OrtxStatus(kOrtxErrorInvalidFile, "Invalid file: " + tok_path_obj.string());
|
||||
}
|
||||
|
||||
tok_dir = ortx::path(tok_path_obj.parent_path());
|
||||
}
|
||||
|
||||
auto config_path = tok_dir / "tokenizer_config.json";
|
||||
std::ifstream ifs = config_path.open();
|
||||
if (!ifs.is_open()) {
|
||||
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open a json file: " + config_path.string());
|
||||
}
|
||||
|
||||
json json_config = json::parse(ifs, nullptr, false, true);
|
||||
if (json_config.is_discarded()) {
|
||||
return OrtxStatus(kOrtxErrorInvalidArgument, "Failed to parse config json.");
|
||||
}
|
||||
|
||||
auto module_cfg = tok_dir / "tokenizer_module.json";
|
||||
if (module_cfg.exists()) {
|
||||
std::ifstream module_ifs = module_cfg.open();
|
||||
json json_module = json::parse(module_ifs, nullptr, false, true);
|
||||
if (json_module.is_discarded()) {
|
||||
return OrtxStatus(kOrtxErrorInvalidArgument, "Failed to parse tokenizer module json.");
|
||||
}
|
||||
LoadAddedTokens(json_module);
|
||||
json_config.update(json_module);
|
||||
}
|
||||
|
||||
model_max_length_ = json_config.value("model_max_length", 1e+30);
|
||||
std::string tiktoken_file = json_config.value("tiktoken_file", "");
|
||||
if (!tiktoken_file.empty()) {
|
||||
auto tktok_path = tok_dir / tiktoken_file;
|
||||
if (tktok_path.exists()) {
|
||||
vocab_path_ = tktok_path.string();
|
||||
} else {
|
||||
return OrtxStatus(kOrtxErrorInvalidFile, "Invalid file: " + tiktoken_file);
|
||||
}
|
||||
} else {
|
||||
if (ortx::path(vocab_path).exists()) {
|
||||
vocab_path_ = vocab_path.string();
|
||||
} else {
|
||||
return OrtxStatus(kOrtxErrorInvalidFile, "Invalid file: " + vocab_path.string());
|
||||
}
|
||||
}
|
||||
|
||||
tokenizer_class_ = json_config.value("tokenizer_class", "");
|
||||
if (!tokenizer_class_.empty()) {
|
||||
return ParseTokensFromConfig(json_config);
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
const std::string& GetVocabDataFile() const { return vocab_path_; }
|
||||
|
||||
public:
|
||||
bool add_bos_token_{};
|
||||
|
@ -117,9 +189,34 @@ class TokenJsonConfig final {
|
|||
std::string unk_token_;
|
||||
std::string pad_token_;
|
||||
|
||||
std::vector<ort_extensions::AddedToken> added_tokens_;
|
||||
|
||||
static AddedToken ParseAddedToken(const json& token) {
|
||||
AddedToken added_token;
|
||||
added_token.id_ = token.value("id", 0);
|
||||
added_token.token_type_ = token.value("__type", "");
|
||||
added_token.content_ = token.value("content", "");
|
||||
added_token.lstrip_ = token.value("lstrip", false);
|
||||
added_token.normalized_ = token.value("normalized", false);
|
||||
added_token.rstrip_ = token.value("rstrip", false);
|
||||
added_token.single_word_ = token.value("single_word", false);
|
||||
added_token.special_ = token.value("special", false);
|
||||
return added_token;
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
void LoadAddedTokens(const json& tok_json) {
|
||||
auto added_tokens = tok_json.find("added_tokens");
|
||||
if (added_tokens != tok_json.end()) {
|
||||
for (const auto& token : *added_tokens) {
|
||||
added_tokens_.emplace_back(ParseAddedToken(token));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string vocab_path_;
|
||||
std::string module_path_;
|
||||
const OrtxTokenizerBlob* blob_{nullptr};
|
||||
};
|
||||
|
||||
} // namespace ort_extensions
|
||||
|
|
|
@ -113,22 +113,18 @@ struct SpmUgmTokenizer {
|
|||
}
|
||||
|
||||
OrtxStatus Load(const TokenJsonConfig& config) {
|
||||
ortx::path vocab_path(config.GetVocabDataFile());
|
||||
if (!vocab_path.exists()) {
|
||||
return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, "Vocabulary file does not exist.");
|
||||
std::unique_ptr<std::istream> vocab_stream;
|
||||
auto status = config.OpenVocabFile(vocab_stream);
|
||||
if (!status.IsOk()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
auto ifs = vocab_path.open();
|
||||
if (!ifs.is_open()) {
|
||||
return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, "Failed to open vocabulary file.");
|
||||
}
|
||||
|
||||
nlohmann::json j_vocab = json::parse(ifs, nullptr, false, true);
|
||||
nlohmann::json j_vocab = json::parse(*vocab_stream, nullptr, false, true);
|
||||
if (j_vocab.is_discarded()) {
|
||||
return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, "Failed to parse vocabulary file.");
|
||||
}
|
||||
|
||||
OrtxStatus status = LoadConfig(j_vocab);
|
||||
status = LoadConfig(j_vocab);
|
||||
if (!status.IsOk()) {
|
||||
return status;
|
||||
}
|
||||
|
|
|
@ -46,6 +46,24 @@ extError_t ORTX_API_CALL OrtxCreateTokenizer(OrtxTokenizer** tokenizer, const ch
|
|||
return status.Code();
|
||||
}
|
||||
|
||||
extError_t ORTX_API_CALL OrtxCreateTokenizerFromBlob(OrtxTokenizer** tokenizer, const OrtxTokenizerBlob* blob) {
|
||||
// test if the tokenizer_path is a valid directory
|
||||
if (blob == nullptr) {
|
||||
ReturnableStatus::last_error_message_ = "The tokenizer blob is null";
|
||||
return kOrtxErrorInvalidArgument;
|
||||
}
|
||||
|
||||
ReturnableStatus status;
|
||||
auto ptr = std::make_unique<ort_extensions::TokenizerImpl>();
|
||||
status = ptr->Load(*blob);
|
||||
if (status.IsOk()) {
|
||||
*tokenizer = static_cast<OrtxTokenizer*>(ptr.release());
|
||||
return extError_t();
|
||||
}
|
||||
|
||||
return status.Code();
|
||||
}
|
||||
|
||||
extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size,
|
||||
OrtxTokenId2DArray** output) {
|
||||
if (tokenizer == nullptr || input == nullptr || output == nullptr) {
|
||||
|
|
|
@ -21,8 +21,7 @@ std::set<std::string> TokenizerImpl::supported_bpe_models_ = {
|
|||
"CodeLlamaTokenizer",
|
||||
"CodeGenTokenizer",
|
||||
"GPT2Tokenizer",
|
||||
"Qwen2Tokenizer",
|
||||
"T5Tokenizer"
|
||||
"Qwen2Tokenizer"
|
||||
};
|
||||
|
||||
std::set<std::string> TokenizerImpl::supported_ugm_models_ = {
|
||||
|
@ -33,17 +32,11 @@ TokenizerImpl::TokenizerImpl()
|
|||
: OrtxObjectImpl(extObjectKind_t::kOrtxKindTokenizer) {};
|
||||
TokenizerImpl::~TokenizerImpl() {};
|
||||
|
||||
OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
|
||||
tok_config_ = std::make_shared<ort_extensions::TokenJsonConfig>();
|
||||
auto status = tok_config_->Load(tok_path);
|
||||
if (!status.IsOk()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
OrtxStatus TokenizerImpl::LoadTokenizer(const OrtxTokenizerBlob* blob) {
|
||||
if (tok_config_->tokenizer_class_.empty() ||
|
||||
supported_ugm_models_.count(tok_config_->tokenizer_class_)) {
|
||||
auto tokenizer = std::make_unique<SpmUgmTokenizer>();
|
||||
status = tokenizer->Load(*tok_config_);
|
||||
auto status = tokenizer->Load(*tok_config_);
|
||||
if (!status.IsOk()) {
|
||||
return status;
|
||||
}
|
||||
|
@ -65,12 +58,21 @@ OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
|
|||
return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class");
|
||||
}
|
||||
|
||||
auto vocab_file_path = ortx::path(tok_config_->GetVocabDataFile());
|
||||
auto tokenizer = std::make_unique<JsonFastTokenizer>();
|
||||
auto fx_load = &JsonFastTokenizer::Load;
|
||||
if (blob == nullptr) {
|
||||
auto vocab_file_path = ortx::path(tok_config_->GetVocabDataFile());
|
||||
// vocab file is checked in TokenJsonConfig::Load
|
||||
auto fx_load = vocab_file_path.extension() == ".json"?
|
||||
&JsonFastTokenizer::Load: &JsonFastTokenizer::LoadTikTokenBase64;
|
||||
status = (tokenizer.get()->*fx_load)(*tok_config_);
|
||||
if (vocab_file_path.extension() != ".json") {
|
||||
fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
|
||||
}
|
||||
} else {
|
||||
if (blob->raw_model_blob_len > 0) {
|
||||
fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
|
||||
}
|
||||
}
|
||||
|
||||
auto status = (tokenizer.get()->*fx_load)(*tok_config_);
|
||||
if (!status.IsOk()) {
|
||||
return status;
|
||||
}
|
||||
|
@ -86,6 +88,26 @@ OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
|
|||
return status;
|
||||
}
|
||||
|
||||
OrtxStatus TokenizerImpl::Load(const OrtxTokenizerBlob& blob) {
|
||||
tok_config_ = std::make_shared<ort_extensions::TokenJsonConfig>();
|
||||
auto status = tok_config_->LoadFromBlob(blob);
|
||||
if (!status.IsOk()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return LoadTokenizer(&blob);
|
||||
}
|
||||
|
||||
OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
|
||||
tok_config_ = std::make_shared<ort_extensions::TokenJsonConfig>();
|
||||
auto status = tok_config_->Load(tok_path);
|
||||
if (!status.IsOk()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return LoadTokenizer();
|
||||
}
|
||||
|
||||
OrtxStatus TokenizerImpl::BatchEncode(const std::vector<std::string_view>& input,
|
||||
std::vector<std::vector<extTokenId_t>>& t_ids) const {
|
||||
for (const auto& s : input) {
|
||||
|
|
|
@ -23,6 +23,7 @@ class TokenizerImpl : public OrtxObjectImpl {
|
|||
|
||||
public:
|
||||
OrtxStatus Load(const std::string& tok_path);
|
||||
OrtxStatus Load(const OrtxTokenizerBlob& blob);
|
||||
|
||||
OrtxStatus Tokenize(const std::vector<std::string_view>& input, std::vector<std::vector<extTokenId_t>>& t_ids) const {
|
||||
return BatchEncode(input, t_ids);
|
||||
|
@ -60,6 +61,8 @@ class TokenizerImpl : public OrtxObjectImpl {
|
|||
std::vector<std::vector<extTokenId_t>>& t_ids) const;
|
||||
|
||||
private:
|
||||
OrtxStatus LoadTokenizer(const OrtxTokenizerBlob* blob = nullptr);
|
||||
|
||||
using bpe_tokenizer_t = std::unique_ptr<JsonFastTokenizer>;
|
||||
using ugm_tokenizer_t = std::unique_ptr<SpmUgmTokenizer>;
|
||||
std::variant<bpe_tokenizer_t, ugm_tokenizer_t> tokenizer_;
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -416,7 +416,7 @@ TEST(OrtxTokenizerTest, CodeGenTokenizer) {
|
|||
EXPECT_TRUE(status.IsOk());
|
||||
EXPECT_EQ(out_text1.size(), 1);
|
||||
std::string out_text_ref = out_text1.back();
|
||||
std::cout << out_text_ref << std::endl;
|
||||
// std::cout << out_text_ref << std::endl;
|
||||
EXPECT_EQ(out_text_ref.substr(out_text_ref.length() - 3, 3), "\ufffd");
|
||||
}
|
||||
|
||||
|
@ -503,8 +503,7 @@ TEST(OrtxTokenizerStreamTest, Phi3Tokenizer) {
|
|||
|
||||
std::vector<std::string_view> input = {
|
||||
R"(こんにちは。データ分析にはいくつかのステップがあります。まずは目的を明確にします。次に、データを収集し、クリーニングを行います。)"
|
||||
R"(その後、データを構造化し、その後、データを分析します。これらのステップを実行することで、データを有意的に分析することができます。)"
|
||||
};
|
||||
R"(その後、データを構造化し、その後、データを分析します。これらのステップを実行することで、データを有意的に分析することができます。)"};
|
||||
std::vector<std::vector<extTokenId_t>> token_ids;
|
||||
status = tokenizer->Tokenize(input, token_ids);
|
||||
EXPECT_TRUE(status.IsOk());
|
||||
|
@ -569,8 +568,8 @@ TEST(OrtxTokenizerTest, SpmUgmTokenizer) {
|
|||
|
||||
// expected ids was generated using the following command:
|
||||
// AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
|
||||
EXPECT_EQ(ids_vec, std::vector<extTokenId_t>({
|
||||
0, 87, 1884, 122395, 759, 99942, 10269, 136, 7068, 4, 6, 62668, 5364, 245875, 354, 11716, 2}));
|
||||
EXPECT_EQ(ids_vec, std::vector<extTokenId_t>({0, 87, 1884, 122395, 759, 99942, 10269, 136, 7068, 4, 6, 62668, 5364,
|
||||
245875, 354, 11716, 2}));
|
||||
|
||||
OrtxObjectPtr<OrtxStringArray> decoded_text;
|
||||
OrtxDetokenize(tokenizer.get(), token_ids.get(), ort_extensions::ptr(decoded_text));
|
||||
|
@ -580,11 +579,91 @@ TEST(OrtxTokenizerTest, SpmUgmTokenizer) {
|
|||
OrtxStringArrayGetItem(decoded_text.get(), 0, &text);
|
||||
// because the tokenization remove the character from the string, the decoded text is not the same as the input text.
|
||||
std::string filtered_text(input[0]);
|
||||
filtered_text.erase(std::remove_if(
|
||||
filtered_text.begin(), filtered_text.end(), [](unsigned char chr){ return chr < 0x20; }), filtered_text.end());
|
||||
filtered_text.erase(
|
||||
std::remove_if(filtered_text.begin(), filtered_text.end(), [](unsigned char chr) { return chr < 0x20; }),
|
||||
filtered_text.end());
|
||||
// remove the consecutive spaces
|
||||
filtered_text.erase(std::unique(filtered_text.begin(), filtered_text.end(),
|
||||
[](char lhs, char rhs) { return lhs == ' ' && rhs == ' '; }), filtered_text.end());
|
||||
[](char lhs, char rhs) { return lhs == ' ' && rhs == ' '; }),
|
||||
filtered_text.end());
|
||||
|
||||
EXPECT_STREQ(filtered_text.c_str(), text);
|
||||
}
|
||||
|
||||
static std::string ReadFile(const std::string& filepath) {
|
||||
std::ifstream file(filepath.data(), std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
return "";
|
||||
}
|
||||
std::ostringstream ss;
|
||||
ss << file.rdbuf();
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
TEST(OrtxTokenizerTest, Phi3_Small_Tokenizer_Blob) {
|
||||
std::string config_blob = ReadFile("data/tokenizer/phi-3-small/tokenizer_config.json");
|
||||
ASSERT_FALSE(config_blob.empty()) << "Failed to read config blob file, stopping the test.";
|
||||
|
||||
std::string raw_model_blob = ReadFile("data/tokenizer/phi-3-small/cl100k_base.tiktoken");
|
||||
ASSERT_FALSE(raw_model_blob.empty()) << "Failed to read raw model blob file, stopping the test.";
|
||||
|
||||
std::string module_blob = ReadFile("data/tokenizer/phi-3-small/tokenizer_module.json");
|
||||
ASSERT_FALSE(module_blob.empty()) << "Failed to read module blob file, stopping the test.";
|
||||
|
||||
struct OrtxTokenizerBlob blobs(config_blob, "", module_blob, raw_model_blob);
|
||||
|
||||
OrtxObjectPtr<OrtxTokenizer> tokenizer(OrtxCreateTokenizerFromBlob, &blobs);
|
||||
ASSERT_EQ(tokenizer.Code(), kOrtxOK) << "Failed to create tokenizer, stopping the test.";
|
||||
|
||||
// validate tokenizer is not null
|
||||
ASSERT_NE(tokenizer.get(), nullptr) << "Tokenizer is null, stopping the test.";
|
||||
|
||||
std::vector<extTokenId_t> EXPECTED_IDS_0 = {2028, 374, 264, 1296, 13};
|
||||
const char* input[] = {"This is a test.",
|
||||
"the second one",
|
||||
"I like walking my cute dog\n and\x17 then",
|
||||
"Hey<|endoftext|>. \t\t \n\nyou é @#😈 🤗! , 1234 15 5,61"};
|
||||
|
||||
OrtxObjectPtr<OrtxTokenId2DArray> token_ids;
|
||||
OrtxTokenize(tokenizer.get(), input, 4, ort_extensions::ptr(token_ids));
|
||||
EXPECT_EQ(token_ids.Code(), kOrtxOK);
|
||||
|
||||
size_t length = 0;
|
||||
const extTokenId_t* ids = nullptr;
|
||||
OrtxTokenId2DArrayGetItem(token_ids.get(), 0, &ids, &length);
|
||||
std::vector<extTokenId_t> ids_vec(ids, ids + length);
|
||||
EXPECT_EQ(ids_vec, EXPECTED_IDS_0);
|
||||
}
|
||||
|
||||
TEST(OrtxTokenizerTest, Phi3TokenizerBlob) {
|
||||
std::string config_blob = ReadFile("data/phi-3/tokenizer_config.json");
|
||||
ASSERT_FALSE(config_blob.empty()) << "Failed to read config blob file, stopping the test.";
|
||||
|
||||
std::string vocab_blob = ReadFile("data/phi-3/tokenizer.json");
|
||||
ASSERT_FALSE(vocab_blob.empty()) << "Failed to read vocab blob file, stopping the test.";
|
||||
|
||||
struct OrtxTokenizerBlob blob(config_blob, vocab_blob, "", "");
|
||||
|
||||
OrtxObjectPtr<OrtxTokenizer> tokenizer(OrtxCreateTokenizerFromBlob, &blob);
|
||||
ASSERT_EQ(tokenizer.Code(), kOrtxOK) << "Failed to create tokenizer, stopping the test.";
|
||||
|
||||
// validate tokenizer is not null
|
||||
ASSERT_NE(tokenizer.get(), nullptr) << "Tokenizer is null, stopping the test.";
|
||||
|
||||
const char* input[] = {"I like walking my cute dog\n and\x17 then, 生活的真谛是 \t\t\t\t \n\n61"};
|
||||
OrtxObjectPtr<OrtxTokenId2DArray> token_ids;
|
||||
OrtxTokenize(tokenizer.get(), input, 1, ort_extensions::ptr(token_ids));
|
||||
EXPECT_EQ(token_ids.Code(), kOrtxOK);
|
||||
|
||||
size_t length = 0;
|
||||
const extTokenId_t* ids = nullptr;
|
||||
OrtxTokenId2DArrayGetItem(token_ids.get(), 0, &ids, &length);
|
||||
std::vector<extTokenId_t> ids_vec(ids, ids + length);
|
||||
|
||||
// expected ids was generated using the following command:
|
||||
// AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
|
||||
EXPECT_EQ(ids_vec,
|
||||
std::vector<extTokenId_t>({1, 306, 763, 22049, 590, 274, 1082, 11203, 13, 322, 26,
|
||||
769, 29892, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392,
|
||||
259, 12, 12, 12, 12, 29871, 13, 13, 29953, 29896}));
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче