Add initial tiktoken and Phi3SmallTokenizer support (#729)
* add initial tiktoken support
* add vector hash and equal for bpe ranks map
* change lambda comparator
* move phi-3-small files
* final changes
* move tiktoken files from data2 to data
* add unit test
* add tokenizer module
* merge json and tiktoken impl
* fix tiktoken encoding problem
* address comments
* remove dummy tokens

---------

Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
This commit is contained in:
Parent: 46998e96fb
Commit: 7851b51ee3
@@ -30,7 +30,23 @@ class TokenJsonConfig final {
       return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open a json file: " + file_path.string());
     }

-    vocab_path_ = (path(json_path) / "tokenizer.json").string();
+    auto vocab_file_path = path(json_path) / "tokenizer.json";
+    vocab_path_ = vocab_file_path.string();
+    std::ifstream vocab_fs = vocab_file_path.open();
+    if (!vocab_fs.is_open()) {
+      // No tokenizer.json file present; search for tokenizer module file
+      auto module_file_path = path(json_path) / "tokenizer_module.json";
+      module_path_ = module_file_path.string();
+      std::ifstream tok_module_ifs = module_file_path.open();
+      if (!tok_module_ifs.is_open()) {
+        return OrtxStatus(kOrtxErrorInvalidFile, "No tokenizer.json or tokenizer_module.json file found.");
+      } else {
+        nlohmann::json tok_module_json_config = nlohmann::json::parse(tok_module_ifs);
+        auto tiktoken_path = tok_module_json_config.value("tiktoken_file", "");
+        vocab_file_path = path(json_path) / tiktoken_path.c_str();
+        vocab_path_ = vocab_file_path.string();
+      }
+    }
     nlohmann::json json_config = nlohmann::json::parse(ifs);
     add_bos_token_ = json_config.value("add_bos_token", false);
     add_eos_token_ = json_config.value("add_eos_token", false);
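For orientation, the lookup order this hunk introduces can be read as a small standalone routine. The sketch below is illustrative only: ResolveVocabFile is a hypothetical name, and it uses C++17 std::filesystem instead of the repo's own path/ifstream wrappers.

    #include <filesystem>
    #include <string>

    // Sketch only: mirrors TokenJsonConfig::Load's file resolution above.
    std::filesystem::path ResolveVocabFile(const std::filesystem::path& dir,
                                           const std::string& tiktoken_file) {
      auto tokenizer_json = dir / "tokenizer.json";
      if (std::filesystem::exists(tokenizer_json)) {
        return tokenizer_json;  // regular HF fast-tokenizer case
      }
      // Otherwise tokenizer_module.json's "tiktoken_file" entry names the
      // fallback vocabulary, e.g. "cl100k_base.tiktoken".
      return dir / tiktoken_file;
    }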
@@ -66,6 +82,10 @@ class TokenJsonConfig final {

   const std::string& GetVocabDataFile() const { return vocab_path_; }

+  const std::string& GetTikTokenModuleFile() const {
+    return module_path_;
+  }
+
  public:
   bool add_bos_token_{};
   bool add_eos_token_{};
@@ -80,6 +100,7 @@ class TokenJsonConfig final {

  private:
   std::string vocab_path_;
+  std::string module_path_;
 };

 }  // namespace ort_extensions::bpe

@@ -8,6 +8,8 @@
 #include "bpe_json.hpp"
 #include "bpe_tokenizer.hpp"

+#include "base64.h"
+
 #include <optional>
 #include <limits>

@@ -552,6 +554,50 @@ SpmTokenizer::SpmTokenizer()

 JsonFastTokenizer::JsonFastTokenizer() : KernelBpeTokenizer(kGPT2Configuration) {}

+/*
+Read more here: https://github.com/huggingface/transformers/blob/60bb571e993b7d73257fb64044726b569fef9403/src/transformers/convert_slow_tokenizer.py#L1454
+
+Note: this is similar to the BPE CreateByteEncoder, however for decoding the .tiktoken bytes
+we need to store the strings rather than their IDs, and thereby need a separate map.
+*/
+void JsonFastTokenizer::CreateUnicodeByteEncoder() {
+  char32_t index = 256;
+  for (char32_t i = 0; i < 256; ++i) {
+    if ((i >= 0 && i < 33) || (i >= 127 && i < 161) || (i == 173)) {
+      unicode_byte_encoder_[i] = ustring::EncodeUTF8Char(index++);
+    } else {
+      unicode_byte_encoder_[i] = ustring::EncodeUTF8Char(i);
+    }
+  }
+}
+
+std::string JsonFastTokenizer::TokenBytesToString(std::vector<uint8_t>& bytes) {
+  std::string result;
+  for (auto c : bytes) {
+    result += unicode_byte_encoder_[static_cast<unsigned char>(c)];
+  }
+  return result;
+}
+
+// Custom hash function for the vector key
+struct VectorHash {
+  size_t operator()(const std::vector<uint8_t>& v) const {
+    std::hash<uint8_t> hasher;
+    size_t seed = 0;
+    for (uint8_t i : v) {
+      seed ^= hasher(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+    }
+    return seed;
+  }
+};
+
+// Custom equality function for the vector key
+struct VectorEqual {
+  bool operator()(const std::vector<uint8_t>& a, const std::vector<uint8_t>& b) const {
+    return a == b;
+  }
+};
+
 OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& config) {
   std::string voc_file = config.GetVocabDataFile();
   std::ifstream ifs = path(voc_file).open();
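For readers new to byte-level BPE, here is a self-contained sketch of the mapping CreateUnicodeByteEncoder builds (my illustration, not part of the commit): the 188 "printable" bytes map to themselves, while the 68 excluded bytes (controls 0x00-0x20, 0x7F-0xA0, and 0xAD) are remapped to consecutive code points starting at U+0100, so every byte sequence has a printable string form.

    #include <cstdio>
    #include <string>

    // Sketch only: same byte-to-unicode rule as CreateUnicodeByteEncoder above.
    static std::string EncodeUtf8(char32_t cp) {
      std::string s;
      if (cp < 0x80) {
        s += static_cast<char>(cp);
      } else {  // code points here never exceed U+0143, so two bytes suffice
        s += static_cast<char>(0xC0 | (cp >> 6));
        s += static_cast<char>(0x80 | (cp & 0x3F));
      }
      return s;
    }

    int main() {
      std::string table[256];
      char32_t next = 256;
      for (char32_t i = 0; i < 256; ++i) {
        bool excluded = (i < 33) || (i >= 127 && i < 161) || (i == 173);
        table[i] = EncodeUtf8(excluded ? next++ : i);
      }
      // Space (0x20) becomes U+0120 "G-with-dot" and newline (0x0A) becomes
      // U+010A, the familiar markers seen in GPT-2-style vocabularies.
      std::printf("0x20 -> %s, 0x0A -> %s\n", table[0x20].c_str(), table[0x0A].c_str());
      return 0;
    }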
@@ -559,6 +605,89 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& c
     return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open json file: " + voc_file);
   }

+  // consider to use SAX parser for large json file
+  nlohmann::json tok_json;
+  std::ifstream module_ifs;
+
+  // Following vocab and merges only used for tiktoken case but accessed outside scope below
+  std::unordered_map<std::string, uint32_t> vocab;
+  std::vector<std::pair<std::string, std::string>> merges;
+
+  if (tiktoken_) {
+    std::string module_file = config.GetTikTokenModuleFile();
+
+    module_ifs = path(module_file).open();
+    if (!module_ifs.is_open()) {
+      return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open module file: " + module_file);
+    }
+
+    std::unordered_map<std::vector<uint8_t>, uint32_t, VectorHash, VectorEqual> bpe_ranks;
+
+    std::string line;
+    while (std::getline(ifs, line)) {
+      if (!line.empty()) {
+        std::istringstream lineStream(line);
+        std::string token;
+        uint32_t rank;
+        while (lineStream >> token >> rank) {
+          // Decode base64 token and convert rank to int
+          std::vector<uint8_t> decoded_token;
+          base64_decode(token, decoded_token);
+          // Store bpe token and rank
+          bpe_ranks[decoded_token] = rank;
+        }
+      }
+    }
+
+    std::vector<std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>> byte_merges;
+
+    bbpe_tokenizer_ = std::make_unique<BpeModel>();
+    JsonFastTokenizer::CreateUnicodeByteEncoder();
+
+    for (const auto& item : bpe_ranks) {
+      std::vector<uint8_t> token = item.first;
+      uint32_t rank = item.second;
+      vocab[JsonFastTokenizer::TokenBytesToString(token)] = rank;
+
+      if (token.size() == 1) {
+        continue;
+      }
+
+      std::vector<std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>> local;
+      for (size_t index = 1; index < token.size(); index++) {
+        std::vector<uint8_t> piece_l(token.begin(), token.begin() + index);
+        std::vector<uint8_t> piece_r(token.begin() + index, token.end());
+        if (bpe_ranks.count(piece_l) && bpe_ranks.count(piece_r)) {
+          local.emplace_back(piece_l, piece_r, rank);
+        }
+      }
+
+      auto compare_bpe_tuples = [&](const std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>& a,
+                                    const std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>& b) {
+        // Compare based on the ranks in bpe_ranks
+        return bpe_ranks[std::get<0>(a)] < bpe_ranks[std::get<0>(b)] ||
+               (bpe_ranks[std::get<0>(a)] == bpe_ranks[std::get<0>(b)] && bpe_ranks[std::get<1>(a)] < bpe_ranks[std::get<1>(b)]);
+      };
+
+      std::sort(local.begin(), local.end(), compare_bpe_tuples);
+
+      byte_merges.insert(byte_merges.end(), local.begin(), local.end());
+    }
+
+    // Custom comparator that compares the third element of the tuples
+    auto compare_merge_tuples = [&](const std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>& a,
+                                    const std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>& b) {
+      return std::get<2>(a) < std::get<2>(b);
+    };
+
+    std::sort(byte_merges.begin(), byte_merges.end(), compare_merge_tuples);
+
+    // Populate merges
+    for (auto& val : byte_merges) {
+      merges.push_back({JsonFastTokenizer::TokenBytesToString(std::get<0>(val)), JsonFastTokenizer::TokenBytesToString(std::get<1>(val))});
+    }
+  }
+
   const char token_sub[] = "Tokenizer";
   model_name_ = config.tokenizer_class_.substr(0, config.tokenizer_class_.find(token_sub));
   json_conf_.name_ = model_name_.c_str();
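Two reference points for the block above (my notes, not from the commit). First, the .tiktoken vocabulary format: each line is a base64-encoded token followed by its rank, e.g. the first line of cl100k_base.tiktoken is "IQ== 0", meaning the single byte 0x21 ("!") has rank 0. Second, because a .tiktoken file carries ranks but no explicit merge list, the loop reconstructs one: every split of a multi-byte token whose two halves are themselves ranked tokens becomes a candidate merge, candidates are ordered by the ranks of their halves, and the combined list is finally sorted by the rank of the merged token, since a lower rank means the merge fires earlier. As a tiny worked example under assumed ranks {a:0, b:1, c:2, ab:3, abc:4}: the token "abc" has splits (a, bc) and (ab, c); "bc" is not in the vocabulary, so only (ab, c) survives and enters the merge list with priority 4.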
@@ -570,18 +699,27 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& c
   // re-bind the configuration object
   bpe_conf_ = json_conf_;

-  // consider to use SAX parser for large json file
-  nlohmann::json tok_json;
-  ifs >> tok_json;
-  auto model_node = tok_json.find("model");
-  if (model_node == tok_json.end()) {
-    return OrtxStatus(kOrtxErrorCorruptData, "Failed to get model node from tokenizer.json");
-  }
-
-  bbpe_tokenizer_ = std::make_unique<BpeModel>();
-  auto status = bbpe_tokenizer_->Load(*model_node,
-                                      bpe_conf_.get().GetSpecialTokens().c_str(),
-                                      IsSpmModel(ModelName()));
+  OrtxStatus status;
+  if (tiktoken_) {
+    status = bbpe_tokenizer_->Load(vocab,
+                                   merges,
+                                   bpe_conf_.get().GetSpecialTokens().c_str(),
+                                   false);
+
+    module_ifs >> tok_json;
+  } else {
+    ifs >> tok_json;
+    auto model_node = tok_json.find("model");
+    if (model_node == tok_json.end()) {
+      return OrtxStatus(kOrtxErrorCorruptData, "Failed to get model node from tokenizer.json");
+    }
+
+    bbpe_tokenizer_ = std::make_unique<BpeModel>();
+    status = bbpe_tokenizer_->Load(*model_node,
+                                   bpe_conf_.get().GetSpecialTokens().c_str(),
+                                   IsSpmModel(ModelName()));
+  }
+

   auto added_tokens = tok_json.find("added_tokens");
   if (added_tokens != tok_json.end()) {
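Note how the two branches above converge: the tiktoken branch streams tokenizer_module.json into tok_json (via module_ifs) while the JSON branch streams tokenizer.json itself, so the added_tokens handling that follows works unchanged for both paths. The tiktoken branch also does not re-create bbpe_tokenizer_, because the tiktoken block earlier in Load already constructed it.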
@@ -640,4 +778,4 @@ OrtxStatus JsonFastTokenizer::Compute(const ortc::Tensor<std::string>& input,
                                       std::optional<ortc::Tensor<int64_t>*> attention_mask,
                                       std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
   return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
 }
 }

@@ -107,6 +107,10 @@ struct SpmTokenizer : KernelBpeTokenizer {
 class JsonFastTokenizer : KernelBpeTokenizer {
  public:
   JsonFastTokenizer();
+  bool tiktoken_ = false;
+  std::string unicode_byte_encoder_[256] = {};
+  void CreateUnicodeByteEncoder();
+  std::string TokenBytesToString(std::vector<uint8_t>& bytes);
   OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
   OrtxStatus Compute(const ortc::Tensor<std::string>& input,
                      ortc::Tensor<int64_t>& tokenize_output,
@@ -121,3 +125,23 @@ class JsonFastTokenizer : KernelBpeTokenizer {
   BpeModelConf json_conf_;
   std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
 };
+
+class TikTokenizer : KernelBpeTokenizer {
+ public:
+  TikTokenizer();
+  std::string TokenBytesToString(std::vector<uint8_t>& bytes);
+  OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
+  OrtxStatus Compute(const ortc::Tensor<std::string>& input,
+                     ortc::Tensor<int64_t>& tokenize_output,
+                     std::optional<ortc::Tensor<int64_t>*> attention_mask,
+                     std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
+
+ public:
+  const auto& GetAddedTokens() const { return added_tokens_; }
+  const ort_extensions::BpeModel& GetEncoder() const { return *bbpe_tokenizer_; }
+
+ private:
+  std::unique_ptr<ort_extensions::BpeModel> bbpe_tokenizer_;
+  BpeModelConf json_conf_;
+  std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
+};

@@ -194,6 +194,47 @@ class BpeModel {
     return {};
   }

+  OrtxStatus Load(std::unordered_map<std::string, uint32_t>& vocab,
+                  std::vector<std::pair<std::string, std::string>>& merges,
+                  const char* /* special_tokens */,
+                  bool spm_converted) {
+    vocab_map_ = vocab;
+
+    if (spm_converted) {
+      UpdateSpmByteToken(vocab_map_);
+    } else {
+      CreateByteEncoder();
+    }
+
+    uint32_t index = 0;
+    for (auto& tuple : merges) {
+      std::string w1 = tuple.first;
+      std::string w2 = tuple.second;
+      int token_length = ort_extensions::narrow<int>(w1.length() + w2.length());
+      if (w2.find("</w>") != std::string::npos || w1.find("</w>") != std::string::npos) {
+        token_length -= 4;
+      }
+      auto iw1 = GetTokenId(w1);
+      auto iw2 = GetTokenId(w2);
+      auto iww = GetTokenId(w1 + w2);
+      BpeNode value{iww, index++, token_length};
+      bpe_rank_[GetRankKey(iw1, iw2)] = value;
+    }
+
+    id2token_map_.resize(vocab_map_.size());
+    for (const auto& [t, i] : vocab_map_) {
+      if (i > static_cast<uint32_t>((std::numeric_limits<int32_t>::max)())) {
+        continue;  // safe purpose.
+      }
+      if (i > id2token_map_.size()) {
+        id2token_map_.resize(static_cast<size_t>(i) + 1);
+      }
+      id2token_map_[i] = t;
+    }
+
+    return {};
+  }
+
   OrtxStatus LoadAddedTokens(const char* added_tokens) {
     int id = bpe::kInvalidTokenId;
     std::istringstream strm_tokens(added_tokens);

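A note on the bpe_rank_ table this new overload fills: keying by GetRankKey(iw1, iw2) gives O(1) lookup of "can this adjacent token pair merge, and with what priority". Below is a minimal sketch of the idea; it assumes (this is an assumption, GetRankKey's definition is elsewhere in the class) that the key simply packs the two 32-bit token ids into one 64-bit value, and MiniBpeNode/PackRankKey are hypothetical names.

    #include <cstdint>
    #include <unordered_map>

    // Assumed key shape: left id in the high word, right id in the low word.
    static uint64_t PackRankKey(uint32_t left, uint32_t right) {
      return (static_cast<uint64_t>(left) << 32) | right;
    }

    struct MiniBpeNode {
      uint32_t id;     // id of the merged token
      uint32_t order;  // merge priority: lower fires first
      int length;      // length of the merged token
    };

    int main() {
      std::unordered_map<uint64_t, MiniBpeNode> bpe_rank;
      bpe_rank[PackRankKey(7, 9)] = MiniBpeNode{42, 0, 2};  // merge #0: 7 + 9 -> 42
      auto it = bpe_rank.find(PackRankKey(7, 9));           // encoder-side pair lookup
      return (it != bpe_rank.end()) ? 0 : 1;
    }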
@@ -18,12 +18,24 @@ OrtxStatus TokenizerImpl::Load(const std::string& dir) {
     return status;
   }

+  auto vocab_file_path = path(dir) / "tokenizer.json";
+  std::ifstream vocab_fs = vocab_file_path.open();
+
   tokenizer_ = std::make_unique<JsonFastTokenizer>();
-  // load the tokenizer from a config
-  status = tokenizer_->Load(*tok_config_);
-  if (status.IsOk()) {
-    detokenizer_ = std::make_unique<BpeStreamingDecoder>();
-    status = detokenizer_->Load(tok_config_, *tokenizer_);
+  if (!vocab_fs.is_open()) {
+    // No tokenizer.json file present; use TikToken tokenizer
+    tokenizer_->tiktoken_ = true;
+
+    // load the tokenizer from a config
+    status = tokenizer_->Load(*tok_config_);
+  } else {
+    // load the tokenizer from a config
+    status = tokenizer_->Load(*tok_config_);
+
+    if (status.IsOk()) {
+      detokenizer_ = std::make_unique<BpeStreamingDecoder>();
+      status = detokenizer_->Load(tok_config_, *tokenizer_);
+    }
   }

   return status;
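One behavioral consequence of the branch above, worth noting: on the TikToken path no BpeStreamingDecoder is constructed, so detokenization is only wired up when a tokenizer.json is present.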
@@ -34,7 +46,7 @@ OrtxStatus TokenizerImpl::BatchEncode(const std::vector<std::string_view>& input
   for (const auto& s : input) {
     ortc::Tensor<int64_t> ts_output(&CppAllocator::Instance());
     ortc::Tensor<std::string> ts_input = ortc::Tensor<std::string>(std::vector<std::string>{std::string(s)});
-    auto status = tokenizer_->Compute(ts_input, ts_output, std::nullopt, std::nullopt);
+    OrtxStatus status = tokenizer_->Compute(ts_input, ts_output, std::nullopt, std::nullopt);

     if (!status.IsOk()) {
       return status;

@@ -50,6 +50,7 @@ class TokenizerImpl : public OrtxObjectImpl {
                      std::vector<std::vector<extTokenId_t>>& t_ids) const;

  private:
+  bool tiktoken = false;
   std::string tokenizer_dir_;
   std::shared_ptr<ort_extensions::bpe::TokenJsonConfig> tok_config_;
   std::unique_ptr<JsonFastTokenizer> tokenizer_;

(The diff for one file is not shown because of its large size.)
@@ -0,0 +1,16 @@
+{
+  "added_tokens_decoder": {},
+  "auto_map": {
+    "AutoTokenizer": [
+      "tokenization_phi3_small.Phi3SmallTokenizer",
+      "tokenization_phi3_small.Phi3SmallTokenizer"
+    ]
+  },
+  "bos_token": "<|endoftext|>",
+  "chat_template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
+  "clean_up_tokenization_spaces": true,
+  "eos_token": "<|endoftext|>",
+  "model_max_length": 8192,
+  "pad_token": "<|endoftext|>",
+  "tokenizer_class": "Phi3SmallTokenizer"
+}

@@ -0,0 +1,77 @@
+{
+  "tiktoken_file": "cl100k_base.tiktoken",
+  "added_tokens": [
+    {
+      "id": 100257,
+      "content": "<|endoftext|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 100258,
+      "content": "<|fim_prefix|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 100259,
+      "content": "<|fim_middle|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 100260,
+      "content": "<|fim_suffix|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 100261,
+      "content": "<|system|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 100262,
+      "content": "<|user|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 100263,
+      "content": "<|assistant|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 100276,
+      "content": "<|endofprompt|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    }
+  ]
+}

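For orientation (my note, not from the commit): cl100k_base's ordinary BPE ranks run from 0 through 100255, so the special-token ids declared here begin just past the base vocabulary, and 100276 for <|endofprompt|> matches the id tiktoken itself assigns in cl100k_base.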
@@ -151,6 +151,32 @@ TEST(OrtxTokenizerTest, Phi3_S_Tokenizer) {
   EXPECT_EQ(out_text[0], input[0]);
 }

+TEST(OrtxTokenizerTest, Phi3_Small_Tokenizer) {
+  auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
+  auto status = tokenizer->Load("data/phi-3-small");
+  if (!status.IsOk()) {
+    std::cout << status.ToString() << std::endl;
+  }
+
+  // validate tokenizer is not null
+  EXPECT_NE(tokenizer, nullptr);
+
+  std::vector<extTokenId_t> EXPECTED_IDS_0 = {2028, 374, 264, 1296, 13};
+  std::vector<std::string_view> input = {
+      "This is a test.",
+      "the second one",
+      "I like walking my cute dog\n and\x17 then",
+      "Hey<|endoftext|>. \t\t \n\nyou é @#😈 🤗! , 1234 15 5,61"};
+  std::vector<std::vector<extTokenId_t>>
+      token_ids;
+  status = tokenizer->Tokenize(input, token_ids);
+  EXPECT_TRUE(status.IsOk());
+  DumpTokenIds(token_ids);
+
+  EXPECT_EQ(token_ids.size(), input.size());
+  EXPECT_EQ(token_ids[0], EXPECTED_IDS_0);
+}
+
 TEST(OrtxTokenizerTest, GemmaTokenizer) {
   auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
   auto status = tokenizer->Load("data/gemma");
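Sanity check on the expectations (my note): EXPECTED_IDS_0 holds the cl100k_base encoding of "This is a test." ("This" = 2028, " is" = 374, " a" = 264, " test" = 1296, "." = 13), which is what a Phi-3-small tokenizer built on cl100k_base should produce for the first input.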