support the merges array in tokenizer.json (#817)

This commit is contained in:
Wenbing Li 2024-09-26 11:01:13 -07:00 коммит произвёл GitHub
Родитель e424838708
Коммит 2c3e936cfc
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
2 изменённых файлов: 25 добавлений и 14 удалений

Просмотреть файл

@ -145,23 +145,32 @@ class BpeModel {
uint32_t index = 0;
auto merge_item = merges_json.begin();
while (merge_item != merges_json.end()) {
std::string line = merge_item.value();
line.erase(std::remove(line.begin(), line.end(), '\r'), line.end());
if (line.empty()) continue;
if ((line[0] == '#') && (index == 0)) continue;
auto pos = line.find(' ');
if (pos == std::string::npos) {
return {
kOrtxErrorCorruptData,
"Cannot know how to parse line: " + line,
};
std::string w1, w2;
if (merge_item->is_string()) {
std::string line = merge_item.value();
line.erase(std::remove(line.begin(), line.end(), '\r'), line.end());
if (line.empty()) continue;
if ((line[0] == '#') && (index == 0)) continue;
auto pos = line.find(' ');
if (pos == std::string::npos) {
return {
kOrtxErrorCorruptData,
"Cannot know how to parse line: " + line,
};
}
w1 = line.substr(0, pos);
w2 = line.substr(pos + 1);
} else if (merge_item->is_array()) {
w1 = merge_item->at(0).get<std::string>();
w2 = merge_item->at(1).get<std::string>();
} else {
return {kOrtxErrorCorruptData, "Cannot know how to parse line: " + merge_item->dump()};
}
std::string w1 = line.substr(0, pos);
std::string w2 = line.substr(pos + 1);
int token_length = ort_extensions::narrow<int>(w1.length() + w2.length());
if (w2.find("</w>") != std::string::npos || w1.find("</w>") != std::string::npos) {
token_length -= 4;
}
auto iw1 = GetTokenId(w1);
auto iw2 = GetTokenId(w2);
auto iww = GetTokenId(w1 + w2);

Просмотреть файл

@ -18,6 +18,7 @@
#include "op_def_struct.h"
#include "base64.h"
#include "ustring.h"
#include "narrow.h"
#include "nlohmann/json.hpp"
#include "trietree.hpp"
#include "tokenizer_jsconfig.hpp"
@ -485,7 +486,7 @@ class SpmUgmDecoder {
std::copy(ids_dim.begin(), ids_dim.begin() + ids_dim.size() - 1, output_dim.begin());
}
size_t seq_len = ids_dim.back();
int64_t seq_len = ids_dim.back();
size_t string_batch = ids.NumberOfElement() / seq_len;
std::vector<std::string> decoded_strings;
@ -495,7 +496,7 @@ class SpmUgmDecoder {
std::string text;
for (int64_t i = 0; i < seq_len; ++i) {
std::string token;
Id2Token(p_ids[i], token, nullptr);
Id2Token(ort_extensions::narrow<extTokenId_t>(p_ids[i]), token, nullptr);
if (token.find(spm_escaped_space) == 0) {
token = ws + token.substr(spm_escaped_space.length());
}
@ -508,6 +509,7 @@ class SpmUgmDecoder {
text = text.substr(1);
}
}
decoded_strings.push_back(text);
}