Add a decoder for Unigram tokenizer and unify some classes among tokenizers (#816)

* rename and formalize the file names

* add the decoder impl

* fix a typo
Wenbing Li 2024-09-25 10:25:06 -07:00 committed by GitHub
Parent 6b94f4d7a5
Commit f204a4c791
No known key found for this signature
GPG key ID: B5690EEEBB952194
21 changed files with 357 additions and 236 deletions

View File

@@ -15,7 +15,7 @@
#include "ustring.h"
#include "narrow.h"
#include "tokjson_types.h"
#include "tokenizer_common.h"
struct KernelBpeDecoder {
public:

View File

@@ -1,16 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "file_sys.h"
#include "bpe_kernels.h"
#include "bpe_jsoncfg.hpp"
#include "bpe_tokenizer.hpp"
#include <limits>
#include <optional>
#include "base64.h"
#include <optional>
#include <limits>
#include "file_sys.h"
#include "bpe_kernels.h"
#include "tokenizer_jsconfig.hpp"
#include "bpe_tokenizer_model.hpp"
using namespace ort_extensions;
@@ -673,11 +671,11 @@ struct VectorEqual {
}
};
OrtxStatus JsonFastTokenizer::LoadAddedTokens(const json& tok_json, const ort_extensions::bpe::TokenJsonConfig& config) {
OrtxStatus JsonFastTokenizer::LoadAddedTokens(const json& tok_json, const ort_extensions::TokenJsonConfig& config) {
auto added_tokens = tok_json.find("added_tokens");
if (added_tokens != tok_json.end()) {
for (const auto& token : *added_tokens) {
bpe::AddedToken added_token;
AddedToken added_token;
added_token.id_ = token.value("id", 0);
added_token.token_type_ = token.value("__type", "");
added_token.content_ = token.value("content", "");
@@ -721,7 +719,7 @@ bool JsonFastTokenizer::CheckForSpmModel(const json& tok_json) {
return false;
}
void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::bpe::TokenJsonConfig& config) {
void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::TokenJsonConfig& config) {
if (!config.add_bos_token_ && !config.bos_token_.empty()) {
auto post_processor = tok_json.find("post_processor");
if (post_processor != tok_json.end()) {
@@ -736,7 +734,7 @@ void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort
}
}
OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& config) {
OrtxStatus JsonFastTokenizer::Load(const ort_extensions::TokenJsonConfig& config) {
std::string voc_file = config.GetVocabDataFile();
std::ifstream ifs = path(voc_file).open();
if (!ifs.is_open()) {
@@ -785,7 +783,7 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& c
return status;
}
OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::bpe::TokenJsonConfig& config) {
OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::TokenJsonConfig& config) {
std::string voc_file = config.GetVocabDataFile();
std::ifstream ifs = path(voc_file).open();
if (!ifs.is_open()) {
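For reference, each entry that LoadAddedTokens reads from the "added_tokens" array maps directly onto an AddedToken; a minimal sketch with a hypothetical entry, showing only the fields read above:

#include "nlohmann/json.hpp"
// a hypothetical added_tokens entry from tokenizer.json
auto token = nlohmann::json::parse(R"({"id": 32000, "__type": "AddedToken", "content": "<|endoftext|>"})");
ort_extensions::AddedToken added_token;
added_token.id_ = token.value("id", 0);              // 0 if the field is absent
added_token.token_type_ = token.value("__type", ""); // e.g. "AddedToken"
added_token.content_ = token.value("content", "");   // the literal token text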

View File

@@ -8,12 +8,7 @@
#include <vector>
#include <functional>
#include "ortx_tokenizer.h"
#include "ext_status.h"
#include "op_def_struct.h"
#include "nlohmann/json_fwd.hpp"
#include "tokjson_types.h"
#include "ustring.h"
#include "tokenizer_common.h"
struct BpeModelConf {
@@ -116,8 +111,8 @@ struct SpmTokenizer : KernelBpeTokenizer {
class JsonFastTokenizer : public KernelBpeTokenizer {
public:
JsonFastTokenizer();
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus LoadTikTokenBase64(const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus Load(const ort_extensions::TokenJsonConfig& config);
OrtxStatus LoadTikTokenBase64(const ort_extensions::TokenJsonConfig& config);
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask = std::nullopt,
@@ -133,9 +128,9 @@ class JsonFastTokenizer : public KernelBpeTokenizer {
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
// template functions to avoid including the huge json header file
bool CheckForSpmModel(const json& tok_json);
void UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus LoadAddedTokens(const json& tok_json, const ort_extensions::bpe::TokenJsonConfig& config);
void UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::TokenJsonConfig& config);
OrtxStatus LoadAddedTokens(const json& tok_json, const ort_extensions::TokenJsonConfig& config);
BpeModelConf json_conf_;
std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
std::vector<ort_extensions::AddedToken> added_tokens_;
};
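Load and LoadTikTokenBase64 now take the same unified TokenJsonConfig, so a caller can pick the loader from the vocabulary file extension with a member-function pointer; a minimal sketch, assuming config is an already-loaded TokenJsonConfig and vocab_path an ortx::path to its vocabulary file:

auto tokenizer = std::make_unique<JsonFastTokenizer>();
// a .json vocabulary is the HF tokenizer.json format; anything else is
// treated as a tiktoken base64 vocabulary
auto fx_load = vocab_path.extension() == ".json" ? &JsonFastTokenizer::Load
                                                 : &JsonFastTokenizer::LoadTikTokenBase64;
OrtxStatus status = (tokenizer.get()->*fx_load)(config);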

View File

@@ -5,26 +5,19 @@
#include "bpe_kernels.h"
#include "bpe_decoder.hpp"
#include "bpe_jsoncfg.hpp"
#include "bpe_tokenizer.hpp"
namespace ort_extensions {
struct BPEDecoderState {
bool f_special_last{};
std::string incomplete_utf8_;
};
} // namespace ort_extensions
#include "tokenizer_jsconfig.hpp"
#include "bpe_tokenizer_model.hpp"
class BpeStreamingDecoder : public KernelBpeDecoder {
public:
BpeStreamingDecoder() = default;
~BpeStreamingDecoder() override = default;
using BPEDecoderState = ort_extensions::BPEDecoderState;
using BPEDecoderState = ort_extensions::TokenizerDecodingState;
// shares data between the encoder and the decoder
OrtxStatus Load(
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig const> ptr_config,
std::shared_ptr<ort_extensions::TokenJsonConfig const> ptr_config,
const JsonFastTokenizer& encoder) {
const auto& tok_config = *ptr_config;
bos_token_ = tok_config.bos_token_;
@@ -258,5 +251,5 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
extTokenId_t eos_token_id_{0};
bool add_dummy_prefix_ = false;
bool spm_model_{};
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig const> tok_config_;
std::shared_ptr<ort_extensions::TokenJsonConfig const> tok_config_;
};
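A minimal sketch of wiring the streaming decoder to an encoder, assuming encoder is a successfully loaded JsonFastTokenizer and tok_config the shared_ptr<TokenJsonConfig const> it was loaded from:

auto decoder = std::make_unique<BpeStreamingDecoder>();
// the decoder reuses the encoder's vocabulary instead of re-parsing the JSON
OrtxStatus status = decoder->Load(tok_config, encoder);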

View File

@@ -18,7 +18,7 @@
#include "nlohmann/json.hpp"
#include "bpe_utils.hpp"
#include "trietree.hpp"
#include "tokjson_types.h"
#include "tokenizer_common.h"
namespace ort_extensions {
@@ -249,7 +249,7 @@ class BpeModel {
return {};
}
OrtxStatus LoadAddedTokens(const std::vector<bpe::AddedToken>& added_tokens) {
OrtxStatus LoadAddedTokens(const std::vector<AddedToken>& added_tokens) {
for (const auto& token : added_tokens) {
added_tokens_.Add(ustring(token.content_), 0, token.id_);
}

View File

@@ -4,7 +4,7 @@
#include "sentencepiece_processor.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece.pb.h"
#include "sentencepiece_tokenizer.hpp"
#include "sentencepiece_tokenizer.h"
#include "string_tensor.h"
#include "base64.h"
#include "narrow.h"

View File

@@ -5,11 +5,17 @@
#include <string>
#include "ortx_tokenizer.h"
#include "ext_status.h"
#include "op_def_struct.h"
#include "nlohmann/json_fwd.hpp"
#include "ustring.h"
namespace ort_extensions {
class BpeModel;
namespace bpe {
struct AddedToken final {
uint32_t id_{};
std::string token_type_;
@@ -23,7 +29,10 @@ struct AddedToken final {
class TokenJsonConfig; // forward declaration
} // namespace bpe
struct TokenizerDecodingState {
bool f_special_last{};
std::string incomplete_utf8_;
};
constexpr std::string_view spm_escaped_space = "\xE2\x96\x81";
} // namespace ort_extensions
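TokenizerDecodingState generalizes the old BPE-only BPEDecoderState so that every detokenizer can carry incremental state (a trailing-special-token flag plus buffered incomplete UTF-8 bytes) across Id2Token calls; a minimal streaming-decode sketch, assuming detok is any decoder exposing the Id2Token signature used in this change and token_ids a std::vector<extTokenId_t>:

std::unique_ptr<ort_extensions::TokenizerDecodingState> state;
std::string text;
for (extTokenId_t id : token_ids) {
  std::string token;
  auto* state_ptr = state.get();
  detok->Id2Token(id, token, &state_ptr);  // may allocate a fresh state on the first call
  if (state_ptr != state.get()) {
    state.reset(state_ptr);  // take ownership of the newly created state
  }
  text += token;  // bytes of an incomplete UTF-8 sequence stay buffered in the state
}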

View File

@@ -3,14 +3,14 @@
#pragma once
#include "ocos.h"
#include "file_sys.h"
#include "nlohmann/json.hpp"
#include "tokjson_types.h"
#include "tokenizer_common.h"
namespace ort_extensions::bpe {
namespace ort_extensions {
// TokenJsonConfig: Handles loading and parsing of JSON configuration files for tokenizers
class TokenJsonConfig final {
public:
static constexpr const char* kDefaultVocabFile = "tokenizer.json";
@@ -26,6 +26,7 @@ class TokenJsonConfig final {
return OrtxStatus(kOrtxErrorInvalidArgument, "json_path is empty.");
}
ortx::path tok_dir(json_path);
ortx::path vocab_path(json_path);
ortx::path tok_path_obj(json_path);
@@ -122,4 +123,4 @@ class TokenJsonConfig final {
std::string module_path_;
};
} // namespace ort_extensions::bpe
} // namespace ort_extensions
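After the namespace change, loading a configuration reads as below; a minimal sketch with a hypothetical tokenizer directory, which per the defaults above is expected to contain tokenizer.json:

ort_extensions::TokenJsonConfig config;
OrtxStatus status = config.Load("/path/to/tokenizer_dir");  // hypothetical path
if (status.IsOk()) {
  std::string vocab_file = config.GetVocabDataFile();  // resolves to .../tokenizer.json
}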

View File

@@ -4,18 +4,18 @@
#include "ocos.h"
#ifdef ENABLE_GPT2_TOKENIZER
#include "bpe_tokenizer.hpp"
#include "bpe_kernels.h"
#include "bpe_tokenizer_model.hpp"
#include "bpe_decoder.hpp"
#endif
#ifdef ENABLE_SPM_TOKENIZER
#include "sentencepiece_tokenizer.hpp"
#include "sentencepiece_tokenizer.h"
#include "sentencepiece_decoder.hpp"
#endif
#ifdef ENABLE_WORDPIECE_TOKENIZER
#include "wordpiece_tokenizer.hpp"
#include "wordpiece_tokenizer.h"
#endif
#ifdef ENABLE_BLINGFIRE

View File

@@ -13,7 +13,7 @@
#include <charconv>
#include <optional>
#include "unescape.h"
#include "unescape.hpp"
#include "trietree.hpp"
// This Trie Tree is C++ implementation of
@@ -40,6 +40,7 @@ class TrieTokenizer {
private:
std::map<int, std::string> idx2token;
RWKVTrieTree root;
using UnescapeUtils = ort_extensions::UnescapeUtils;
public:
TrieTokenizer(const std::string& text_tokens) {
@@ -62,7 +63,7 @@ class TrieTokenizer {
std::string raw = line.substr(line.find(' ') + 1, line.rfind(' ') - line.find(' ') - 1);
std::string x;
int key_length = 0;
if (ort_extensions::UnquoteString(raw, x)) {
if (UnescapeUtils::UnquoteString(raw, x)) {
std::from_chars(line.data() + r_ws + 1, line.data() + line.size(), key_length);
}
if (x.length() != key_length) {

View File

@@ -20,10 +20,12 @@
#include "ustring.h"
#include "nlohmann/json.hpp"
#include "trietree.hpp"
#include "bpe_jsoncfg.hpp"
#include "tokenizer_jsconfig.hpp"
namespace ort_extensions {
class SpmUgmDecoder; // forward declaration
struct SpmUgmTokenizer {
using json = nlohmann::json;
using VocabTrieTree = ort_extensions::TrieTree<char, extTokenId_t, -1>;
@@ -109,7 +111,7 @@ struct SpmUgmTokenizer {
return {};
}
OrtxStatus Load(const bpe::TokenJsonConfig& config) {
OrtxStatus Load(const TokenJsonConfig& config) {
ortx::path vocab_path(config.GetVocabDataFile());
if (!vocab_path.exists()) {
return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, "Vocabulary file does not exist.");
@@ -417,6 +419,7 @@ struct SpmUgmTokenizer {
}
}
friend class SpmUgmDecoder;
// escaped space symbol - U+2581 (Lower One Eighth Block)
static constexpr double unknown_token_score_penalty_ = 10.0;
@@ -454,4 +457,84 @@ struct SpmUgmTokenizer {
std::string unk_token_ = "<unk>";
};
class SpmUgmDecoder {
public:
SpmUgmDecoder() {
}
OrtxStatus Load(const TokenJsonConfig& config, const SpmUgmTokenizer& tokenizer) {
auto vocab_size = tokenizer.vocab_.size();
vocab_.resize(vocab_size);
for (auto iter = tokenizer.vocab_.begin(); iter != tokenizer.vocab_.end(); ++iter) {
vocab_[std::get<0>(iter->second)] = iter->first;
}
unknown_token_ = tokenizer.unk_token_;
special_token_ids_ = tokenizer.special_token_ids_;
tokenizer_add_space_prefix_ = tokenizer.tokenizer_add_space_prefix_;
return {};
}
OrtxStatus Compute(const ortc::Tensor<int64_t>& ids, ortc::Tensor<std::string>& output) const {
const int64_t* p_ids = ids.Data();
const auto& ids_dim = ids.Shape();
std::vector<int64_t> output_dim = {1};
if (ids_dim.size() > 1) {
output_dim.resize(ids_dim.size() - 1);
std::copy(ids_dim.begin(), ids_dim.begin() + ids_dim.size() - 1, output_dim.begin());
}
size_t seq_len = ids_dim.back();
size_t string_batch = ids.NumberOfElement() / seq_len;
std::vector<std::string> decoded_strings;
decoded_strings.reserve(string_batch);
const std::string ws = " ";
for (auto n = string_batch; n > 0; n--) {
std::string text;
for (int64_t i = 0; i < seq_len; ++i) {
std::string token;
Id2Token(p_ids[i], token, nullptr);
if (token.find(spm_escaped_space) == 0) {
token = ws + token.substr(spm_escaped_space.length());
}
text += token;
}
if (tokenizer_add_space_prefix_) {
if (text.length() > 0 && text[0] == ' ') {
text = text.substr(1);
}
}
decoded_strings.push_back(text);
p_ids += seq_len;  // advance to the next sequence of the batch
}
output.SetStringOutput(decoded_strings, output_dim);
return {};
}
OrtxStatus Id2Token(extTokenId_t id, std::string& token, TokenizerDecodingState** /* state */) const {
if (special_token_ids_.count(id)) {
token = "";
return {};
}
if (id >= vocab_.size()) {
token = unknown_token_;
return {};
}
token = vocab_[id];
return {};
}
private:
bool tokenizer_add_space_prefix_ = true;
std::vector<std::string> vocab_;
std::string unknown_token_ = "<unk>";
std::set<extTokenId_t> special_token_ids_;
};
} // namespace ort_extensions
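A minimal sketch of driving the new decoder, assuming tokenizer is a loaded SpmUgmTokenizer and config its TokenJsonConfig; the tensor construction mirrors the call site in tokenizer_impl.cc:

SpmUgmDecoder decoder;
OrtxStatus status = decoder.Load(config, tokenizer);  // copies vocab and flags from the encoder
std::vector<int64_t> ids = {0, 87, 1884, 2};  // hypothetical token ids
ortc::Tensor<int64_t> ts_input(std::vector<int64_t>{1, static_cast<int64_t>(ids.size())},
                               (void*)ids.data());
ortc::Tensor<std::string> ts_output;
// pieces prefixed with "\xE2\x96\x81" (U+2581) come back with a leading space
status = decoder.Compute(ts_input, ts_output);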

View File

@@ -1,161 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ustring.h"
#include <vector>
namespace ort_extensions {
inline bool IsDigit(char c) { return c >= '0' && c <= '9'; }
inline bool IsHexDigit(char c) { return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); }
inline unsigned int hex_digit_to_int(char c) {
unsigned int x = static_cast<unsigned char>(c);
if (x > '9') {
x += 9;
}
return x & 0xf;
}
inline bool IsSurrogate(char32_t c) {
return c >= 0xD800 && c <= 0xDFFF;
}
// Unescape a Python escaped string
inline bool Unescape(const std::string_view& source, std::string& unescaped, bool is_bytes) {
// reserve enough space for the worst case, and final size will be calculated at the end.
unescaped.resize(source.length());
char* d = unescaped.data();
const char* p = source.data();
const char* end = p + source.size();
const char* last_byte = end - 1;
while (p == d && p < end && *p != '\\') p++, d++;
while (p < end) {
if (*p != '\\') {
*d++ = *p++;
} else {
if (++p > last_byte) {
return false;
}
switch (*p) {
case 'n':
*d++ = '\n';
break;
case 'r':
*d++ = '\r';
break;
case 't':
*d++ = '\t';
break;
break;
case '\\':
*d++ = '\\';
break;
case '\'':
*d++ = '\'';
break;
case '"':
*d++ = '\"';
break;
case 'x':
case 'X': {
if (p >= last_byte) {
return false;
} else if (!IsHexDigit(static_cast<unsigned char>(p[1]))) {
return false;
}
unsigned int ch = 0;
const char* hex_start = p;
while (p < last_byte &&
IsHexDigit(static_cast<unsigned char>(p[1])))
ch = (ch << 4) + hex_digit_to_int(*++p);
if (ch > 0xFF && !is_bytes) {
return false;
}
if (is_bytes) {
*d++ = static_cast<char>(ch);
} else {
d += ustring::EncodeUTF8Char(d, static_cast<char32_t>(ch));
}
break;
}
case 'u': {
char32_t rune = 0;
const char* hex_start = p;
if (p + 4 >= end) {
return false;
}
for (int i = 0; i < 4; ++i) {
if (IsHexDigit(static_cast<unsigned char>(p[1]))) {
rune = (rune << 4) + hex_digit_to_int(*++p);
} else {
return false;
}
}
if (IsSurrogate(rune)) {
return false;
}
d += ustring::EncodeUTF8Char(d, rune);
break;
}
case 'U': {
char32_t rune = 0;
const char* hex_start = p;
if (p + 8 >= end) {
return false;
}
for (int i = 0; i < 8; ++i) {
if (IsHexDigit(static_cast<unsigned char>(p[1]))) {
uint32_t newrune = (rune << 4) + hex_digit_to_int(*++p);
if (newrune > 0x10FFFF) {
return false;
} else {
rune = newrune;
}
} else {
return false;
}
}
if (IsSurrogate(rune)) {
return false;
}
d += ustring::EncodeUTF8Char(d, rune);
break;
}
default: {
return false;
}
}
p++;
}
}
unescaped.resize(d - unescaped.data());
return true;
}
inline bool UnquoteString(const std::string& str, std::string& unquoted) {
bool is_bytes = false;
int idx_0 = 0;
if (str.front() == 'b') {
is_bytes = true;
idx_0 = 1;
}
std::string str_view(str.data() + idx_0, str.length() - idx_0);
if (str_view.length() < 2) {
return false;
}
if ((str_view.front() != '\"' && str_view.front() != '\'') || str_view.back() != str.back()) {
return false;
}
// unescape the string
return Unescape(std::string_view(str_view.data() + 1, str_view.length() - 2), unquoted, is_bytes);
}
} // namespace ort_extensions

View File

@@ -0,0 +1,168 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <vector>
#include <string>
#include <string_view>
#include "ustring.h"
namespace ort_extensions {
class UnescapeUtils {
public:
static bool IsDigit(char c) { return c >= '0' && c <= '9'; }
static bool IsHexDigit(char c) { return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); }
static unsigned int hex_digit_to_int(char c) {
unsigned int x = static_cast<unsigned char>(c);
if (x > '9') {
x += 9;
}
return x & 0xf;
}
static bool IsSurrogate(char32_t c) {
return c >= 0xD800 && c <= 0xDFFF;
}
// Unescape a Python escaped string
static bool Unescape(const std::string_view& source, std::string& unescaped, bool is_bytes) {
// reserve enough space for the worst case, and final size will be calculated at the end.
unescaped.resize(source.length());
char* d = unescaped.data();
const char* p = source.data();
const char* end = p + source.size();
const char* last_byte = end - 1;
while (p == d && p < end && *p != '\\') p++, d++;
while (p < end) {
if (*p != '\\') {
*d++ = *p++;
} else {
if (++p > last_byte) {
return false;
}
switch (*p) {
case 'n':
*d++ = '\n';
break;
case 'r':
*d++ = '\r';
break;
case 't':
*d++ = '\t';
break;
case '\\':
*d++ = '\\';
break;
case '\'':
*d++ = '\'';
break;
case '"':
*d++ = '\"';
break;
case 'x':
case 'X': {
if (p >= last_byte) {
return false;
} else if (!IsHexDigit(static_cast<unsigned char>(p[1]))) {
return false;
}
unsigned int ch = 0;
const char* hex_start = p;
while (p < last_byte &&
IsHexDigit(static_cast<unsigned char>(p[1])))
ch = (ch << 4) + hex_digit_to_int(*++p);
if (ch > 0xFF && !is_bytes) {
return false;
}
if (is_bytes) {
*d++ = static_cast<char>(ch);
} else {
d += ustring::EncodeUTF8Char(d, static_cast<char32_t>(ch));
}
break;
}
case 'u': {
char32_t rune = 0;
const char* hex_start = p;
if (p + 4 >= end) {
return false;
}
for (int i = 0; i < 4; ++i) {
if (IsHexDigit(static_cast<unsigned char>(p[1]))) {
rune = (rune << 4) + hex_digit_to_int(*++p);
} else {
return false;
}
}
if (IsSurrogate(rune)) {
return false;
}
d += ustring::EncodeUTF8Char(d, rune);
break;
}
case 'U': {
char32_t rune = 0;
const char* hex_start = p;
if (p + 8 >= end) {
return false;
}
for (int i = 0; i < 8; ++i) {
if (IsHexDigit(static_cast<unsigned char>(p[1]))) {
uint32_t newrune = (rune << 4) + hex_digit_to_int(*++p);
if (newrune > 0x10FFFF) {
return false;
} else {
rune = newrune;
}
} else {
return false;
}
}
if (IsSurrogate(rune)) {
return false;
}
d += ustring::EncodeUTF8Char(d, rune);
break;
}
default: {
return false;
}
}
p++;
}
}
unescaped.resize(d - unescaped.data());
return true;
}
static bool UnquoteString(const std::string& str, std::string& unquoted) {
bool is_bytes = false;
int idx_0 = 0;
if (str.front() == 'b') {
is_bytes = true;
idx_0 = 1;
}
std::string str_view(str.data() + idx_0, str.length() - idx_0);
if (str_view.length() < 2) {
return false;
}
if ((str_view.front() != '\"' && str_view.front() != '\'') || str_view.back() != str.back()) {
return false;
}
// unescape the string
return Unescape(std::string_view(str_view.data() + 1, str_view.length() - 2), unquoted, is_bytes);
}
};
} // namespace ort_extensions
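Behavior is unchanged by the move; the free functions simply became static members. A minimal sketch of the two entry points:

std::string out;
// quoted Python-style literal: a leading b selects raw byte semantics for \xNN
ort_extensions::UnescapeUtils::UnquoteString("'a\\tb'", out);       // out == "a\tb"
ort_extensions::UnescapeUtils::UnquoteString("b'\\xC3\\xA9'", out); // out holds bytes 0xC3 0xA9
// already-unquoted input; with is_bytes=false, \uXXXX is encoded as UTF-8
ort_extensions::UnescapeUtils::Unescape("caf\\u00e9", out, false);  // out == "café"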

View File

@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "wordpiece_tokenizer.hpp"
#include "wordpiece_tokenizer.h"
#include "nlohmann/json.hpp"
KernelWordpieceTokenizer::KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo& info)

View File

@@ -13,7 +13,7 @@ class DetokenizerCache : public OrtxObjectImpl {
DetokenizerCache() : OrtxObjectImpl(extObjectKind_t::kOrtxKindDetokenizerCache) {}
~DetokenizerCache() override = default;
std::unique_ptr<BPEDecoderState> decoder_state_{};
std::unique_ptr<TokenizerDecodingState> decoder_state_{};
std::string last_text_{}; // last detokenized text
};

View File

@@ -2,7 +2,7 @@
// Licensed under the MIT License.
#include "bpe_kernels.h"
#include "bpe_tokenizer.hpp"
#include "bpe_tokenizer_model.hpp"
#include "bpe_decoder.hpp"
#include "ugm_kernels.hpp"
@@ -16,7 +16,7 @@ TokenizerImpl::TokenizerImpl()
TokenizerImpl::~TokenizerImpl() {};
OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
tok_config_ = std::make_shared<ort_extensions::bpe::TokenJsonConfig>();
tok_config_ = std::make_shared<ort_extensions::TokenJsonConfig>();
auto status = tok_config_->Load(tok_path);
if (!status.IsOk()) {
return status;
@@ -25,8 +25,18 @@ OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
if (tok_config_->tokenizer_class_.empty()) {
auto tokenizer = std::make_unique<SpmUgmTokenizer>();
status = tokenizer->Load(*tok_config_);
if (!status.IsOk()) {
return status;
}
auto detok = std::make_unique<SpmUgmDecoder>();
if (status.IsOk()) {
status = detok->Load(*tok_config_, *tokenizer);
}
if (status.IsOk()) {
tokenizer_ = std::move(tokenizer);
detokenizer_ = std::move(detok);
}
return status;
@@ -38,14 +48,16 @@ OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
auto fx_load = vocab_file_path.extension() == ".json"?
&JsonFastTokenizer::Load: &JsonFastTokenizer::LoadTikTokenBase64;
status = (tokenizer.get()->*fx_load)(*tok_config_);
if (status.IsOk()) {
detokenizer_ = std::make_unique<BpeStreamingDecoder>();
status = detokenizer_->Load(tok_config_, *tokenizer);
if (!status.IsOk()) {
return status;
}
auto detok = std::make_unique<BpeStreamingDecoder>();
status = detok->Load(tok_config_, *tokenizer);
if (status.IsOk()) {
tokenizer_ = std::move(tokenizer);
detokenizer_ = std::move(detok);
}
return status;
@@ -81,7 +93,9 @@ OrtxStatus TokenizerImpl::BatchDecode(const std::vector<span<extTokenId_t const>
std::transform(s.begin(), s.end(), ids.begin(), [](extTokenId_t v) { return static_cast<int64_t>(v); });
ortc::Tensor<int64_t> ts_input(std::vector<int64_t>{1, static_cast<int64_t>(ids.size())}, (void*)ids.data());
ortc::Tensor<std::string> ts_output;
OrtxStatus status = detokenizer_->Compute(ts_input, ts_output);
OrtxStatus status = std::visit([&](auto& detokenizer) {
return detokenizer->Compute(ts_input, ts_output); }, detokenizer_);
if (!status.IsOk()) {
return status;
}
@@ -90,8 +104,9 @@ OrtxStatus TokenizerImpl::BatchDecode(const std::vector<span<extTokenId_t const>
return {};
}
OrtxStatus TokenizerImpl::Id2Token(extTokenId_t id, std::string& token, BPEDecoderState** state) const {
return detokenizer_->Id2Token(id, token, state);
OrtxStatus TokenizerImpl::Id2Token(extTokenId_t id, std::string& token, TokenizerDecodingState** state) const {
return std::visit([&](auto& detokenizer) {
return detokenizer->Id2Token(id, token, state); }, detokenizer_);
}
static std::map<std::string, std::string> LANGUAGES = {

View File

@@ -7,7 +7,7 @@
#include "bpe_kernels.h"
#include "ugm_kernels.hpp"
#include "bpe_jsoncfg.hpp"
#include "tokenizer_jsconfig.hpp"
#include "bpe_streaming.hpp"
#include "c_api_utils.hpp"
@@ -34,8 +34,8 @@ class TokenizerImpl : public OrtxObjectImpl {
return {};
}
OrtxStatus Id2Token(extTokenId_t id, std::string& token, std::unique_ptr<BPEDecoderState>& cache) const {
BPEDecoderState* state_ptr = cache.get();
OrtxStatus Id2Token(extTokenId_t id, std::string& token, std::unique_ptr<TokenizerDecodingState>& cache) const {
TokenizerDecodingState* state_ptr = cache.get();
OrtxStatus status = Id2Token(id, token, &state_ptr);
if (status.IsOk()) {
if (state_ptr != cache.get()) {
@@ -51,7 +51,7 @@ class TokenizerImpl : public OrtxObjectImpl {
OrtxStatus BatchDecode(const std::vector<span<extTokenId_t const>>& t_ids, std::vector<std::string>& t_text) const;
OrtxStatus Id2Token(extTokenId_t id, std::string& token, BPEDecoderState** state) const;
OrtxStatus Id2Token(extTokenId_t id, std::string& token, TokenizerDecodingState** state) const;
OrtxStatus GetDecoderPromptIds(size_t batch_size, const char* lang, const char* task, int no_timestamps,
std::vector<std::vector<extTokenId_t>>& t_ids) const;
@@ -61,8 +61,11 @@ class TokenizerImpl : public OrtxObjectImpl {
using ugm_tokenizer_t = std::unique_ptr<SpmUgmTokenizer>;
std::variant<bpe_tokenizer_t, ugm_tokenizer_t> tokenizer_;
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig> tok_config_;
std::unique_ptr<BpeStreamingDecoder> detokenizer_;
using bpe_decoder_t = std::unique_ptr<BpeStreamingDecoder>;
using ugm_decoder_t = std::unique_ptr<SpmUgmDecoder>;
std::variant<bpe_decoder_t, ugm_decoder_t> detokenizer_;
std::shared_ptr<ort_extensions::TokenJsonConfig> tok_config_;
};
} // namespace ort_extensions

View File

@@ -335,7 +335,7 @@ TEST(OrtxTokenizerStreamTest, CodeGenTokenizer) {
EXPECT_EQ(token_ids.size(), 1);
std::string text;
std::unique_ptr<ort_extensions::BPEDecoderState> decoder_cache;
std::unique_ptr<ort_extensions::TokenizerDecodingState> decoder_cache;
// token_ids[0].insert(token_ids[0].begin() + 2, 607); // <0x20>
token_ids[0] = {564, 921, 765, 2130, 588, 262, 6123, 447, 251, 2130, 588, 262};
for (const auto& token_id : token_ids[0]) {
@@ -369,7 +369,7 @@ TEST(OrtxTokenizerStreamTest, Llama2Tokenizer) {
DumpTokenIds(token_ids);
std::string text;
std::unique_ptr<ort_extensions::BPEDecoderState> decoder_cache;
std::unique_ptr<ort_extensions::TokenizerDecodingState> decoder_cache;
// std::cout << "\"";
for (const auto& token_id : token_ids[0]) {
std::string token;
@@ -406,7 +406,7 @@ TEST(OrtxTokenizerStreamTest, Phi3Tokenizer) {
DumpTokenIds(token_ids);
std::string text;
std::unique_ptr<ort_extensions::BPEDecoderState> decoder_cache;
std::unique_ptr<ort_extensions::TokenizerDecodingState> decoder_cache;
// std::cout << "\"";
for (const auto& token_id : token_ids[0]) {
std::string token;
@@ -464,4 +464,20 @@ TEST(OrtxTokenizerTest, SpmUgmTokenizer) {
// AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
EXPECT_EQ(ids_vec, std::vector<extTokenId_t>({
0, 87, 1884, 122395, 759, 99942, 10269, 136, 7068, 4, 6, 62668, 5364, 245875, 354, 11716, 2}));
OrtxObjectPtr<OrtxStringArray> decoded_text;
OrtxDetokenize(tokenizer.get(), token_ids.get(), ort_extensions::ptr(decoded_text));
EXPECT_EQ(decoded_text.Code(), kOrtxOK);
const char* text = nullptr;
OrtxStringArrayGetItem(decoded_text.get(), 0, &text);
// because tokenization removes some characters from the string, the decoded text is not the same as the input text.
std::string filtered_text(input[0]);
filtered_text.erase(std::remove_if(
filtered_text.begin(), filtered_text.end(), [](unsigned char chr){ return chr < 0x20; }), filtered_text.end());
// collapse consecutive spaces into a single space
filtered_text.erase(std::unique(filtered_text.begin(), filtered_text.end(),
[](char lhs, char rhs) { return lhs == ' ' && rhs == ' '; }), filtered_text.end());
EXPECT_STREQ(filtered_text.c_str(), text);
}

View File

@@ -3,7 +3,7 @@
#include "gtest/gtest.h"
#include "string_utils.h"
#include "wordpiece_tokenizer.hpp"
#include "wordpiece_tokenizer.h"
#include "bert_tokenizer.hpp"
#include <clocale>