support the merges array in tokenizer.json (#817)
Parent: e424838708
Commit: 2c3e936cfc
@@ -145,23 +145,32 @@ class BpeModel {
     uint32_t index = 0;
     auto merge_item = merges_json.begin();
     while (merge_item != merges_json.end()) {
-      std::string line = merge_item.value();
-      line.erase(std::remove(line.begin(), line.end(), '\r'), line.end());
-      if (line.empty()) continue;
-      if ((line[0] == '#') && (index == 0)) continue;
-      auto pos = line.find(' ');
-      if (pos == std::string::npos) {
-        return {
-            kOrtxErrorCorruptData,
-            "Cannot know how to parse line: " + line,
-        };
-      }
-      std::string w1 = line.substr(0, pos);
-      std::string w2 = line.substr(pos + 1);
+      std::string w1, w2;
+      if (merge_item->is_string()) {
+        std::string line = merge_item.value();
+        line.erase(std::remove(line.begin(), line.end(), '\r'), line.end());
+        if (line.empty()) continue;
+        if ((line[0] == '#') && (index == 0)) continue;
+        auto pos = line.find(' ');
+        if (pos == std::string::npos) {
+          return {
+              kOrtxErrorCorruptData,
+              "Cannot know how to parse line: " + line,
+          };
+        }
+        w1 = line.substr(0, pos);
+        w2 = line.substr(pos + 1);
+      } else if (merge_item->is_array()) {
+        w1 = merge_item->at(0).get<std::string>();
+        w2 = merge_item->at(1).get<std::string>();
+      } else {
+        return {kOrtxErrorCorruptData, "Cannot know how to parse line: " + merge_item->dump()};
+      }
       int token_length = ort_extensions::narrow<int>(w1.length() + w2.length());
       if (w2.find("</w>") != std::string::npos || w1.find("</w>") != std::string::npos) {
         token_length -= 4;
       }

       auto iw1 = GetTokenId(w1);
       auto iw2 = GetTokenId(w2);
       auto iww = GetTokenId(w1 + w2);
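This hunk is the heart of the change: older tokenizer.json files serialize "merges" as an array of "w1 w2" strings, while newer releases of the Hugging Face tokenizers library may emit each merge as a two-element array ["w1", "w2"], so the loader now branches on the JSON value type. The standalone sketch below mirrors that branching with nlohmann::json; ParseMerges and the embedded JSON literals are illustrative only and are not part of the onnxruntime-extensions API.

// A minimal sketch of accepting both merge encodings with nlohmann::json.
// ParseMerges is a hypothetical helper, not the BpeModel code in the diff above.
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "nlohmann/json.hpp"

using json = nlohmann::json;

std::vector<std::pair<std::string, std::string>> ParseMerges(const json& merges) {
  std::vector<std::pair<std::string, std::string>> result;
  for (const auto& item : merges) {
    if (item.is_string()) {
      // Legacy form: one "w1 w2" string per merge rule, split at the first space.
      std::string line = item.get<std::string>();
      auto pos = line.find(' ');
      if (pos == std::string::npos) continue;  // skip malformed entries in this sketch
      result.emplace_back(line.substr(0, pos), line.substr(pos + 1));
    } else if (item.is_array() && item.size() == 2) {
      // Newer form: a ["w1", "w2"] pair per merge rule.
      result.emplace_back(item.at(0).get<std::string>(), item.at(1).get<std::string>());
    }
  }
  return result;
}

int main() {
  // The same two merge rules in both encodings.
  json legacy = json::parse(R"(["h e", "he llo"])");
  json paired = json::parse(R"([["h", "e"], ["he", "llo"]])");
  for (const auto& [w1, w2] : ParseMerges(legacy)) std::cout << w1 << " + " << w2 << "\n";
  for (const auto& [w1, w2] : ParseMerges(paired)) std::cout << w1 << " + " << w2 << "\n";
  return 0;
}

Both loops print the same pairs, which is exactly the property the updated loader relies on.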
@@ -18,6 +18,7 @@
 #include "op_def_struct.h"
 #include "base64.h"
 #include "ustring.h"
+#include "narrow.h"
 #include "nlohmann/json.hpp"
 #include "trietree.hpp"
 #include "tokenizer_jsconfig.hpp"
@@ -485,7 +486,7 @@ class SpmUgmDecoder {
       std::copy(ids_dim.begin(), ids_dim.begin() + ids_dim.size() - 1, output_dim.begin());
     }

-    size_t seq_len = ids_dim.back();
+    int64_t seq_len = ids_dim.back();
     size_t string_batch = ids.NumberOfElement() / seq_len;

     std::vector<std::string> decoded_strings;
@@ -495,7 +496,7 @@ class SpmUgmDecoder {
       std::string text;
       for (int64_t i = 0; i < seq_len; ++i) {
         std::string token;
-        Id2Token(p_ids[i], token, nullptr);
+        Id2Token(ort_extensions::narrow<extTokenId_t>(p_ids[i]), token, nullptr);
         if (token.find(spm_escaped_space) == 0) {
           token = ws + token.substr(spm_escaped_space.length());
         }
@@ -508,6 +509,7 @@ class SpmUgmDecoder {
           text = text.substr(1);
         }
       }

       decoded_strings.push_back(text);
     }
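The SpmUgmDecoder hunks above are type-safety cleanups that ride along with the merges change: seq_len becomes int64_t so it matches the loop counter, and the int64_t token id is passed to Id2Token through ort_extensions::narrow<extTokenId_t> instead of an implicit conversion. Assuming that helper behaves like gsl::narrow, i.e. it fails loudly when a value does not fit in the destination type rather than silently truncating, the idea can be sketched as follows; checked_narrow and the uint32_t alias for extTokenId_t are assumptions for illustration, not the contents of narrow.h.

// A minimal sketch of a checked narrowing cast in the spirit of gsl::narrow.
// The real ort_extensions::narrow may differ (e.g. terminate instead of throw,
// and also reject sign-flipping conversions of the same width).
#include <cstdint>
#include <stdexcept>

template <typename To, typename From>
To checked_narrow(From value) {
  To result = static_cast<To>(value);
  // If converting back does not reproduce the original value, information was lost.
  if (static_cast<From>(result) != value) {
    throw std::runtime_error("narrowing conversion lost information");
  }
  return result;
}

int main() {
  using extTokenId_t = uint32_t;  // assumed width for this sketch
  int64_t p_id = 42;              // token ids arrive as int64_t tensor data
  extTokenId_t id = checked_narrow<extTokenId_t>(p_id);  // fine: 42 fits
  // checked_narrow<extTokenId_t>(int64_t{1} << 40);     // would throw: value too large
  return static_cast<int>(id);
}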