Support 'added_token' attribute for BPE tokenizer and some code refactoring. (#591)

* Fix CodeGenTokenizer issues and the related code refactoring.

* refactor the trie-tree

* temp check-ins

* code complete

* correctness fixing

* Update _hf_cvt.py

* more test cases fixing

* more refinement

* linux crash fixing

* Update test_autotokenizer.py
Wenbing Li 2023-11-04 22:56:26 -07:00 committed by GitHub
Parent e951e72a85
Commit d1148aea4e
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
12 changed files: 411 additions and 223 deletions

View file

@ -2,7 +2,6 @@
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
#include <vector>
#include <string_view>
@ -85,7 +84,7 @@ class ustring : public std::u32string {
using u32string = std::u32string;
static u32string FromUTF8(const std::string_view& utf8) {
u32string ucs32;
ucs32.reserve(utf8.length() / 2); // a rough estimation for less memory allocation.
ucs32.reserve(utf8.length() / 2); // a rough estimation for less memory allocation.
for (size_t i = 0; i < utf8.size();) {
char32_t codepoint = 0;
if ((utf8[i] & 0x80) == 0) {

View file

@ -30,7 +30,7 @@ class API {
// To use ONNX C ABI in a way like OrtW::API::CreateStatus.
public:
static API& instance(const OrtApi* ort_api = nullptr) noexcept {
static API self(*ort_api);
static API self(ort_api);
return self;
}
@ -54,15 +54,15 @@ class API {
return &api_;
}
API(const OrtApi& api) : api_(api) {
if (&api == nullptr) {
API(const OrtApi* api) : api_(*api) {
if (api == nullptr) {
ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
}
}
const OrtApi& api_;
};
template <>
inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
@ -107,21 +107,32 @@ inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
return API::CreateStatus(code, msg);
}
inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
return API::CreateStatus(code, msg.c_str());
}
inline void ReleaseStatus(OrtStatusPtr& status) {
API::ReleaseStatus(status);
status = nullptr;
}
} // namespace OrtW
#define ORTX_RETURN_IF_ERROR(expr) \
do { \
auto _status = (expr); \
if (_status != nullptr) { \
return _status; \
} \
} while (0)
namespace Ort {
namespace Custom {
#ifdef USE_CUDA
///////////////////////////////////////////////////////////////////////////
// TODO: include the definition from the header file in ONNXRuntime
struct CudaContext {};
struct CudaContext {};
#endif // USE_CUDA
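
For orientation, here is a minimal sketch (a hypothetical kernel, not part of this diff) of the pattern the refactoring moves the tokenizer kernels to: attributes are read in OnModelAttach rather than in a throwing constructor, and both OnModelAttach and Compute report failures by returning an OrtStatusPtr.

#include "ocos.h"

struct KernelExample {
  OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
    // GetOpAttribute succeeds and leaves the default value when the attribute is absent.
    ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "greeting", greeting_));
    if (greeting_.empty()) {
      return OrtW::CreateStatus("[KernelExample] greeting cannot be empty.", ORT_INVALID_ARGUMENT);
    }
    return nullptr;
  }

  OrtStatusPtr Compute(const ortc::Tensor<std::string>& /*input*/,
                       ortc::Tensor<std::string>& output) const {
    std::vector<std::string> result{greeting_};
    std::vector<int64_t> dim{1};
    output.SetStringOutput(result, dim);
    return nullptr;  // nullptr signals success
  }

 private:
  std::string greeting_;
};

// Such a struct would be registered the same way as the tokenizer ops below,
// e.g. CustomCpuStructV2("ExampleOp", KernelExample).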

View file

@ -23,36 +23,28 @@ class HFTokenizerConverter(CustomOpConverter):
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def bpe_tokenizer(self, **kwargs):
hf_gpt2_tokenizer = self.tokenizer
attrs = None
@staticmethod
def convert_bpe_vocab(hf_tokenizer):
attrs = {'vocab': json.dumps(
hf_tokenizer.encoder, separators=(',', ':'))}
if hf_tokenizer.added_tokens_encoder:
# ids = sorted(hf_tokenizer.added_tokens_encoder.values())
# if not ids == list(range(min(ids), max(ids) + 1)):
# raise RuntimeError(f"{hf_tokenizer.__name__}: the ids in added_tokens_encoder are not consecutive")
token_map = [f"{_k}={_v}" for _k, _v in hf_tokenizer.added_tokens_encoder.items()]
attrs.update({"added_token": "\n".join(token_map)})
if type(self.tokenizer).__name__.endswith('Fast'):
raise ValueError('Please use the slow version of the tokenizer (ex: GPT2Tokenizer).')
elif(self.tokenizer.name_or_path.endswith('gpt-4')):
# Fill vocab gap for GPT4Tokenizer to create continuous domain
vocab_dict = hf_gpt2_tokenizer.encoder
partial_values = list(vocab_dict.values())
max_vocab = partial_values[-1]
all_values = np.arange(max_vocab + 1)
missing_values = set(all_values) - set(partial_values)
for v in missing_values:
vocab_dict[str(uuid.uuid4())] = int(v)
vocab_dict = dict(sorted(vocab_dict.items(), key=lambda item: item[1]))
attrs = {'vocab': json.dumps(
vocab_dict, separators=(',', ':'))}
else:
attrs = {'vocab': json.dumps(
hf_gpt2_tokenizer.encoder, separators=(',', ':'))}
sorted_merges = {v_: k_ for k_, v_ in hf_gpt2_tokenizer.bpe_ranks.items()}
sorted_merges = {v_: k_ for k_, v_ in hf_tokenizer.bpe_ranks.items()}
attrs['merges'] = '\n'.join("{} {}".format(
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
return attrs
def bpe_tokenizer(self, **kwargs):
hf_gpt2_tokenizer = self.tokenizer
if type(self.tokenizer).__name__.endswith('Fast'):
raise ValueError('Please use the slow version of the tokenizer (ex: GPT2Tokenizer).')
attrs = self.convert_bpe_vocab(hf_gpt2_tokenizer)
attrs.update(**kwargs)
return attrs
@ -101,12 +93,7 @@ class HFTokenizerConverter(CustomOpConverter):
if type(self.tokenizer).__name__.endswith('Fast'):
raise ValueError('Please use the slow version of the tokenizer (ex: CLIPTokenizer).')
attrs = {'vocab': json.dumps(
hf_clip_tokenizer.encoder, separators=(',', ':'))}
sorted_merges = {v_: k_ for k_,
v_ in hf_clip_tokenizer.bpe_ranks.items()}
attrs['merges'] = '\n'.join("{} {}".format(
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
attrs = self.convert_bpe_vocab(hf_clip_tokenizer)
attrs.update(**kwargs)
return attrs
@ -116,12 +103,7 @@ class HFTokenizerConverter(CustomOpConverter):
if type(self.tokenizer).__name__.endswith('Fast'):
raise ValueError('Please use the slow version of the tokenizer (ex: RobertaTokenizer).')
attrs = {'vocab': json.dumps(
hf_roberta_tokenizer.encoder, separators=(',', ':'))}
sorted_merges = {v_: k_ for k_,
v_ in hf_roberta_tokenizer.bpe_ranks.items()}
attrs['merges'] = '\n'.join("{} {}".format(
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
attrs = self.convert_bpe_vocab(hf_roberta_tokenizer)
attrs.update(**kwargs)
return attrs

View file

@ -16,19 +16,27 @@
#include <algorithm>
#include <sstream>
struct KernelBpeDecoder : public BaseKernel {
struct KernelBpeDecoder {
public:
KernelBpeDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "id_vocab");
if (vocab.empty()) {
ORTX_CXX_API_THROW("[BPEDecoder]id vocab text cannot be empty.", ORT_INVALID_ARGUMENT);
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
// note: if the attribute doesn't exist in op node, GetOpAttribute doesn't return a failed status;
std::string vocab;
OrtStatusPtr status = OrtW::GetOpAttribute(info, "id_vocab", vocab);
if (status != nullptr || vocab.empty()) {
if (status == nullptr) {
status = OrtW::CreateStatus("[BPEDecoder]id vocab text cannot be empty.", ORT_INVALID_ARGUMENT);
}
return status;
}
BuildIdVocab(vocab);
std::string byte_decoder = ort_.KernelInfoGetAttribute<std::string>(&info, "byte_decoder");
if (byte_decoder.empty()) {
ORTX_CXX_API_THROW("[BPEDecoder]byte_decoder cannot be empty.", ORT_INVALID_ARGUMENT);
std::string byte_decoder;
status = OrtW::GetOpAttribute(info, "byte_decoder", byte_decoder);
if (status != nullptr || byte_decoder.empty()) {
if (status == nullptr) {
status = OrtW::CreateStatus("[BPEDecoder]byte_decoder cannot be empty.", ORT_INVALID_ARGUMENT);
}
return status;
} else {
auto um = ParseId2String(byte_decoder);
std::transform(um.begin(), um.end(),
@ -37,13 +45,15 @@ struct KernelBpeDecoder : public BaseKernel {
ort_extensions::narrow<unsigned char>(std::stoul(p.second))); });
}
std::string added_tokens = TryToGetAttributeWithDefault<std::string>("added_tokens", "");
std::string added_tokens;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "added_tokens", added_tokens));
if (!added_tokens.empty()) {
auto um = ParseId2String(added_tokens);
added_tokens_ = std::map<int64_t, std::string>(um.begin(), um.end());
}
std::string all_special_ids = TryToGetAttributeWithDefault<std::string>("all_special_ids", "");
std::string all_special_ids;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "all_special_ids", all_special_ids));
if (!all_special_ids.empty()) {
auto um = ParseId2String(all_special_ids);
std::transform(um.begin(), um.end(),
@ -51,12 +61,14 @@ struct KernelBpeDecoder : public BaseKernel {
[](const auto& p) { return p.first; });
}
en_normalization_ = TryToGetAttributeWithDefault<int64_t>("en_normalization", 0);
skip_special_tokens_ = TryToGetAttributeWithDefault<int64_t>("skip_special_tokens", 0);
whitespace_token_ = TryToGetAttributeWithDefault<int64_t>("whitespace_token", 0);
bos_token_ = TryToGetAttributeWithDefault("bos_token", std::string("<|endoftext|>"));
eos_token_ = TryToGetAttributeWithDefault("eos_token", std::string("<|endoftext|>"));
unk_token_ = TryToGetAttributeWithDefault("unk_token", std::string("<|endoftext|>"));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "en_normalization", en_normalization_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "skip_special_tokens", skip_special_tokens_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "whitespace_token", whitespace_token_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "bos_token", bos_token_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "eos_token", eos_token_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "unk_token", unk_token_));
return status;
}
std::unordered_map<int64_t, std::string> ParseId2String(const std::string& s_attr) {
@ -102,8 +114,8 @@ struct KernelBpeDecoder : public BaseKernel {
arr_vocab_.shrink_to_fit();
}
void Compute(const ortc::Tensor<int64_t>& ids,
ortc::Tensor<std::string>& output) const {
OrtStatusPtr Compute(const ortc::Tensor<int64_t>& ids,
ortc::Tensor<std::string>& output) const {
const int64_t* p_ids = ids.Data();
const auto& ids_dim = ids.Shape();
std::vector<int64_t> output_dim = {1};
@ -168,12 +180,13 @@ struct KernelBpeDecoder : public BaseKernel {
p_ids += seq_len;
}
output.SetStringOutput(decoded_strings, output_dim);
return nullptr;
}
private:
std::string bos_token_;
std::string eos_token_;
std::string unk_token_;
std::string bos_token_{"<|endoftext|>"};
std::string eos_token_{"<|endoftext|>"};
std::string unk_token_{"<|endoftext|>"};
// Since ORT API doesn't support boolean type in ONNX node attribute,
// all flag attributes here are defined as int64 type to be more explicit.

View file

@ -6,6 +6,8 @@
#include <optional>
using namespace ort_extensions;
std::string BpeModelConf::GetSpecialTokens() const {
std::string special_tokens = unk_token_; // unk_token_ is required
auto add_token = [](std::string& sp, const char* tok) {
@ -87,43 +89,53 @@ ustring RemoveConsecutiveSpaces(const ustring& input) {
return result;
}
KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info, const BpeModelConf& conf)
: BaseKernel(api, info),
bpe_conf_(conf) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
KernelBpeTokenizer::KernelBpeTokenizer(const BpeModelConf& conf)
: bpe_conf_(conf){};
OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
// note: if the attribute doesn't exist in op node, GetOpAttribute doesn't return a failed status;
std::string vocab;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "vocab", vocab));
if (vocab.empty()) {
ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
return OrtW::CreateStatus("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
}
std::string merges = ort_.KernelInfoGetAttribute<std::string>(&info, "merges");
std::string merges;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "merges", merges));
if (merges.empty()) {
ORTX_CXX_API_THROW("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
}
if (!TryToGetAttribute<int64_t>("padding_length", padding_length_)) {
padding_length_ = -1;
return OrtW::CreateStatus("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
}
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "padding_length", padding_length_));
if (padding_length_ != -1 && padding_length_ <= 0) {
ORTX_CXX_API_THROW("padding_length should be more than 0 or equal -1", ORT_INVALID_ARGUMENT);
return OrtW::CreateStatus("padding_length should be more than 0 or equal -1", ORT_INVALID_ARGUMENT);
}
std::stringstream vocabu_stream(vocab);
std::stringstream merges_stream(merges);
bbpe_tokenizer_ = std::make_unique<BpeModel>();
bbpe_tokenizer_->Load(vocabu_stream, merges_stream, conf.unk_token_, conf.GetSpecialTokens().c_str());
auto status = bbpe_tokenizer_->Load(vocabu_stream, merges_stream, bpe_conf_.unk_token_, bpe_conf_.GetSpecialTokens().c_str());
if (status != nullptr) {
return status;
}
std::string added_token;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "added_token", added_token));
ORTX_RETURN_IF_ERROR(bbpe_tokenizer_->LoadAddedTokens(added_token.c_str()));
// TODO: need to check if the special token ids are the same as the ones in HFTokenizer
unk_token_id_ = bbpe_tokenizer_->GetTokenId(conf.unk_token_);
if (conf.bos_token_ != nullptr) {
bos_token_id_ = bbpe_tokenizer_->GetTokenId(conf.bos_token_);
unk_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.unk_token_);
if (bpe_conf_.bos_token_ != nullptr) {
bos_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.bos_token_);
}
if (conf.eos_token_ != nullptr) {
eos_token_id_ = bbpe_tokenizer_->GetTokenId(conf.eos_token_);
if (bpe_conf_.eos_token_ != nullptr) {
eos_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.eos_token_);
}
if (conf.pad_token_ != nullptr) {
pad_token_id_ = bbpe_tokenizer_->GetTokenId(conf.pad_token_);
if (bpe_conf_.pad_token_ != nullptr) {
pad_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.pad_token_);
}
return nullptr;
}
std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
@ -171,21 +183,20 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
}
// Parse input
auto special_token_split_res = bbpe_tokenizer_->SplitBySpecialTokens(input);
TokenWithRegularExp regcmp;
auto special_token_split_res = bbpe_tokenizer_->SplitByAddedAndSpecial(input);
bpe::TokenWithRegularExp regcmp;
for (auto& seg_id : special_token_split_res) {
if (static_cast<int64_t>(res.size()) >= max_length) break;
if (seg_id.second != -1) {
if (seg_id.second != bpe::kInvalidTokenId) {
res.push_back(seg_id.second);
continue;
}
auto cur_input = std::move(seg_id.first);
// Note: keep ptr to make sure the string_view is valid in the following process
const char32_t* ptr = cur_input.c_str();
regcmp.Set(ptr);
std::u32string str(seg_id.first);
regcmp.Set(str.c_str());
size_t offset = 0;
OffsetMappingType offset_mapping;
@ -199,7 +210,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
while (static_cast<int64_t>(res.size()) < max_length) {
auto [b, tok] = regcmp.GetNextToken();
if (!b) break;
std::string utf8_token = std::string(ustring(tok));
@ -271,13 +282,14 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
// Add EOS token to result
res.push_back(eos_token_id_);
}
return res;
}
void KernelBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
OrtStatusPtr KernelBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
// Setup inputs
std::vector<std::string> str_input{input.Data()};
std::list<OffsetMappingType> offset_map;
@ -356,11 +368,13 @@ void KernelBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
idx++;
}
}
return nullptr;
}
static const auto kGPT2Confinguration = BpeModelConf();
GPT2Tokenizer::GPT2Tokenizer(const OrtApi& api, const OrtKernelInfo& info)
: KernelBpeTokenizer(api, info, kGPT2Confinguration) {}
GPT2Tokenizer::GPT2Tokenizer()
: KernelBpeTokenizer(kGPT2Confinguration) {}
static const auto kRobertaConfiguration = BpeModelConf{
BpeModelConf::kModel_Roberta, // name
@ -369,8 +383,8 @@ static const auto kRobertaConfiguration = BpeModelConf{
"</s>", // eos_token
"<pad>"}; // pad_token
RobertaTokenizer::RobertaTokenizer(const OrtApi& api, const OrtKernelInfo& info)
: KernelBpeTokenizer(api, info, kRobertaConfiguration) {}
RobertaTokenizer::RobertaTokenizer()
: KernelBpeTokenizer(kRobertaConfiguration) {}
static const auto kCLIPConfiguration = BpeModelConf{
BpeModelConf::kModel_CLIP, // name
@ -379,5 +393,5 @@ static const auto kCLIPConfiguration = BpeModelConf{
"<|endoftext|>", // eos_token
"<|endoftext|>"}; // pad_token
CLIPTokenizer::CLIPTokenizer(const OrtApi& api, const OrtKernelInfo& info)
: KernelBpeTokenizer(api, info, kCLIPConfiguration) {}
CLIPTokenizer::CLIPTokenizer()
: KernelBpeTokenizer(kCLIPConfiguration) {}

View file

@ -23,15 +23,18 @@ struct BpeModelConf {
std::string GetSpecialTokens() const;
};
namespace ort_extensions {
class BpeModel;
}
struct KernelBpeTokenizer : BaseKernel {
KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info, const BpeModelConf& conf);
struct KernelBpeTokenizer {
KernelBpeTokenizer(const BpeModelConf& conf);
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
OrtStatusPtr Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
const char* ModelName() const { return bpe_conf_.name_; }
@ -41,10 +44,12 @@ struct KernelBpeTokenizer : BaseKernel {
int64_t max_length,
bool compute_offset_mapping,
std::list<OffsetMappingType>& offset_map) const;
int64_t padding_length_;
std::unique_ptr<BpeModel> bbpe_tokenizer_;
const BpeModelConf& bpe_conf_;
private:
const BpeModelConf& bpe_conf_;
std::unique_ptr<ort_extensions::BpeModel> bbpe_tokenizer_;
int64_t padding_length_ = -1;
uint32_t unk_token_id_{};
uint32_t bos_token_id_{};
uint32_t eos_token_id_{};
@ -52,34 +57,34 @@ struct KernelBpeTokenizer : BaseKernel {
};
struct GPT2Tokenizer : KernelBpeTokenizer {
GPT2Tokenizer(const OrtApi& api, const OrtKernelInfo& info);
GPT2Tokenizer();
// required by LiteCustomOp which needs an explicit Compute declaration for non-MSVC compilers.
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
OrtStatusPtr Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
};
struct RobertaTokenizer : KernelBpeTokenizer {
RobertaTokenizer(const OrtApi& api, const OrtKernelInfo& info);
RobertaTokenizer();
// required by LiteCustomOp which needs an explicit Compute declaration for non-MSVC compilers.
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
OrtStatusPtr Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
};
struct CLIPTokenizer : KernelBpeTokenizer {
CLIPTokenizer(const OrtApi& api, const OrtKernelInfo& info);
CLIPTokenizer();
// required by LiteCustomOp which needs an explicit Compute declaration for non-MSVC compilers.
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
OrtStatusPtr Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
};

View file

@ -12,18 +12,23 @@
#include <unordered_map>
#include <iostream>
#include <utility>
#include <charconv>
#include <limits>
#include "nlohmann/json.hpp"
#include "bpe_utils.hpp"
#include "trietree.hpp"
namespace ort_extensions {
class BpeModel {
public:
BpeModel() = default;
void Load(std::istream& vocab_stream,
std::istream& merges_stream,
const char* unk_token,
const char* special_tokens) {
OrtStatusPtr Load(std::istream& vocab_stream,
std::istream& merges_stream,
const char* unk_token,
const char* special_tokens) {
nlohmann::json tok_json;
vocab_stream >> tok_json;
vocab_map_ = std::move(tok_json.get<std::unordered_map<std::string, uint32_t>>());
@ -34,6 +39,7 @@ class BpeModel {
} else {
auto id = ort_extensions::narrow<uint32_t>(vocab_map_.size());
vocab_map_[unk_token] = id;
unk_id_ = id;
}
CreateByteEncoder();
@ -46,7 +52,7 @@ class BpeModel {
if ((line[0] == '#') && (index == 0)) continue;
auto pos = line.find(' ');
if (pos == std::string::npos) {
ORTX_CXX_API_THROW("Cannot know how to parse line: " + line, ORT_INVALID_ARGUMENT);
return OrtW::CreateStatus("Cannot know how to parse line: " + line, ORT_INVALID_ARGUMENT);
}
std::string w1 = line.substr(0, pos);
std::string w2 = line.substr(pos + 1);
@ -54,9 +60,9 @@ class BpeModel {
if (w2.find("</w>") != std::string::npos || w1.find("</w>") != std::string::npos) {
token_length -= 4;
}
auto iw1 = GetVocabIndex(w1);
auto iw2 = GetVocabIndex(w2);
auto iww = GetVocabIndex(w1 + w2);
auto iw1 = GetTokenId(w1);
auto iw2 = GetTokenId(w2);
auto iww = GetTokenId(w1 + w2);
BpeNode value{iww, index++, token_length};
bpe_rank_[GetRankKey(iw1, iw2)] = value;
}
@ -80,8 +86,62 @@ class BpeModel {
id2token_map_.resize(vocab_map_.size());
for (const auto& [t, i] : vocab_map_) {
if (i > static_cast<uint32_t>(std::numeric_limits<int32_t>::max())) {
continue; // safe purpose.
}
if (i > id2token_map_.size()) {
id2token_map_.resize(i + 1);
}
id2token_map_[i] = t;
}
return nullptr;
}
OrtStatusPtr LoadAddedTokens(const char* added_tokens) {
int id = bpe::kInvalidTokenId;
std::istringstream strm_tokens(added_tokens);
std::string line;
while (!strm_tokens.eof()) {
std::getline(strm_tokens, line);
line.erase(std::remove(line.begin(), line.end(), '\r'), line.end());
if (line.empty()) continue;
// separate the key and the value by '='
auto pos = line.rfind("=");
if (pos == std::string::npos) {
return OrtW::CreateStatus("Error on parse a added_token line: " + line, ORT_INVALID_ARGUMENT);
}
auto token = line.substr(0, pos);
auto id_str = line.substr(pos + 1); // 1 is the length of "="
auto [ptr, ec] = std::from_chars(id_str.data(), id_str.data() + id_str.length(), id);
if (ec != std::errc()) {
return OrtW::CreateStatus("Cannot convert to an integer from " + id_str, ORT_INVALID_ARGUMENT);
}
added_tokens_.Add(ustring(token), 0, std::make_optional(id));
}
return nullptr;
}
// REF: https://github.com/huggingface/transformers/blob/c9e72f55b2dc4b9be4edb986dce0552582b328f2/src/transformers/tokenization_utils.py#L52
bpe::TokenPairs SplitByAddedAndSpecial(const ustring& input) const {
// split by added tokens
bpe::TokenPairs added_result;
bpe::TokenPairs final_result;
added_tokens_.Split(input, added_result);
for (const auto& [token, id] : added_result) {
if (id != bpe::kInvalidTokenId) {
final_result.emplace_back(token, id);
} else {
auto special_result = special_tokens_.SplitBySpecialTokens(token);
for (const auto& [token, id] : special_result) {
final_result.emplace_back(token, id);
}
}
}
return final_result;
}
void bpe(std::list<std::pair<uint32_t, uint32_t>>& vals) const {
@ -94,9 +154,15 @@ class BpeModel {
for (auto it = vals.begin(); it != vals.end(); ++it) {
auto it2 = it;
++it2;
if (it2 == vals.end()) break;
if (it2 == vals.end()) {
break;
}
auto map_it = bpe_rank_.find(GetRankKey(it->first, it2->first));
if (map_it == bpe_rank_.end()) continue;
if (map_it == bpe_rank_.end()) {
continue;
}
if (minval > map_it->second.value) {
ori_id1 = it->first;
ori_id2 = it2->first;
@ -105,7 +171,10 @@ class BpeModel {
aim_id = map_it->second.id;
}
}
if (pos_it == vals.end()) break;
if (pos_it == vals.end()) {
break;
}
token_length = pos_it->second;
pos_it = vals.erase(pos_it);
@ -129,11 +198,6 @@ class BpeModel {
return byte_encoder_;
}
auto SplitBySpecialTokens(const ustring& input) const {
return special_tokens_.SplitBySpecialTokens(input);
}
// Returns token if key was found in vocab, and unk_id_ otherwise
uint32_t GetTokenId(const std::string& key) {
auto it = vocab_map_.find(key);
if (it != end(vocab_map_)) {
@ -163,21 +227,13 @@ class BpeModel {
)
*/
if ((i >= 0 && i < 33) || (i >= 127 && i < 161) || (i == 173)) {
byte_encoder_[i] = GetVocabIndex(ustring::EncodeUTF8Char(index++));
byte_encoder_[i] = GetTokenId(ustring::EncodeUTF8Char(index++));
} else {
byte_encoder_[i] = GetVocabIndex(ustring::EncodeUTF8Char(i));
byte_encoder_[i] = GetTokenId(ustring::EncodeUTF8Char(i));
}
}
}
uint32_t GetVocabIndex(const std::string& str) {
auto it = vocab_map_.find(str);
if (it == vocab_map_.end()) {
ORTX_CXX_API_THROW("Cannot find word in vocabulary: " + str, ORT_INVALID_ARGUMENT);
}
return it->second;
}
private:
std::map<uint64_t, BpeNode> bpe_rank_;
@ -186,5 +242,8 @@ class BpeModel {
std::vector<std::string> id2token_map_;
uint32_t unk_id_ = std::numeric_limits<uint32_t>::max();
SpecialTokenMap special_tokens_;
bpe::SpecialTokenMap special_tokens_;
TrieTree<char32_t> added_tokens_;
};
} // namespace ort_extensions
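
As an aside, a rough usage sketch (hypothetical token names and ids) of the new added-token path: the converter emits the "added_token" attribute as newline-separated "token=id" pairs, LoadAddedTokens() parses them into the trie, and SplitByAddedAndSpecial() matches them before the regular regex/BPE pass. It assumes the BpeModel and ustring headers above are already included.

OrtStatusPtr DemoAddedTokens(ort_extensions::BpeModel& model) {
  const char* added_token = "<|im_start|>=100264\n<|im_end|>=100265";  // made-up tokens and ids
  ORTX_RETURN_IF_ERROR(model.LoadAddedTokens(added_token));

  ustring input(std::string("<|im_start|>hello world<|im_end|>"));
  ort_extensions::bpe::TokenPairs segments = model.SplitByAddedAndSpecial(input);
  // segments (roughly): {"<|im_start|>", 100264}, {"hello world", kInvalidTokenId}, {"<|im_end|>", 100265}
  // Only the kInvalidTokenId segments continue through the regular tokenization path.
  return nullptr;
}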

View file

@ -6,35 +6,43 @@
#include "ocos.h"
#include "narrow.h"
#include <cassert>
#include <algorithm>
#include "ustring.h"
#include "unicode.h"
namespace ort_extensions {
namespace bpe {
using TokenPairs = std::vector<std::pair<std::u32string_view, int>>;
using u32string_view = std::u32string_view;
constexpr int kInvalidTokenId = -1;
class SpecialTokenMap {
public:
void Add(ustring p_str, int p_id) {
auto it = token_map_.find(p_str);
if (it != token_map_.end()) {
if (it->second != p_id) {
ORTX_CXX_API_THROW("Duplicate special tokens.", ORT_INVALID_ARGUMENT);
}
assert(it->second == p_id && "Duplicate special tokens.");
} else {
token_map_[p_str] = p_id;
token_list_.push_back(SpecialTokenInfo(std::move(p_str), p_id));
}
}
std::vector<std::pair<ustring, int>> SplitBySpecialTokens(ustring input) const {
std::vector<std::pair<ustring, int>> res;
res.emplace_back(std::move(input), -1);
TokenPairs SplitBySpecialTokens(const std::u32string_view& input) const {
TokenPairs res;
res.emplace_back(input, kInvalidTokenId);
for (const auto& st : token_list_) {
std::vector<std::pair<ustring, int>> new_split_res;
TokenPairs new_split_res;
for (auto& str : res) {
if (str.second != -1) {
new_split_res.push_back(std::move(str));
if (str.second != kInvalidTokenId) {
new_split_res.emplace_back(str);
continue;
}
auto it = str.first.begin();
size_t search_pos = 0;
while (it != str.first.end()) {
@ -46,21 +54,27 @@ class SpecialTokenMap {
std::boyer_moore_searcher(st.str.begin(), st.str.end()));
#endif
if (search_it == str.first.end()) {
new_split_res.emplace_back(str.first.substr(search_pos), -1);
new_split_res.emplace_back(u32string_view(
str.first.data() + search_pos, str.first.size() - search_pos),
kInvalidTokenId);
break;
}
auto prefixLen = search_it - it;
if (prefixLen != 0) {
new_split_res.emplace_back(str.first.substr(search_pos, prefixLen), -1);
new_split_res.emplace_back(u32string_view(str.first.data() + search_pos, prefixLen), kInvalidTokenId);
search_pos += prefixLen;
}
new_split_res.emplace_back(str.first.substr(search_pos, st.str.size()), st.id);
new_split_res.emplace_back(u32string_view(str.first.data() + search_pos, st.str.size()), st.id);
it = search_it + st.str.size();
search_pos += st.str.size();
}
}
std::swap(new_split_res, res);
}
return res;
}
@ -101,7 +115,6 @@ class TokenWithRegularExp {
private:
std::u32string_view TryMatch() {
// python pattern:
// 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
@ -198,8 +211,7 @@ class TokenWithRegularExp {
for (; i < m_text.size(); ++i) {
if (!IsZ(m_text[i])) break;
}
if ((i > 1) && (i != m_text.size())) //\s+(?!\S)
{
if ((i > 1) && (i != m_text.size())) { //\s+(?!\S)
i--;
std::u32string_view res = m_text.substr(0, i);
m_text = m_text.substr(i);
@ -230,3 +242,6 @@ class TokenWithRegularExp {
private:
std::u32string_view m_text;
};
} // namespace bpe
} // namespace ort_extensions
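
A small sketch (illustrative token id) of the reworked SpecialTokenMap: SplitBySpecialTokens now returns u32string_view segments over the input instead of ustring copies, and Add() only asserts on a conflicting duplicate instead of throwing.

void DemoSpecialTokenSplit() {
  using namespace ort_extensions::bpe;
  SpecialTokenMap special;
  special.Add(ustring(std::string("<|endoftext|>")), 50256);
  special.Add(ustring(std::string("<|endoftext|>")), 50256);  // same token and id again: silently ignored

  ustring text(std::string("foo<|endoftext|>bar"));
  TokenPairs parts = special.SplitBySpecialTokens(text);
  // parts (roughly): {"foo", kInvalidTokenId}, {"<|endoftext|>", 50256}, {"bar", kInvalidTokenId}
  // Each u32string_view in `parts` aliases `text`, so `text` must outlive `parts`.
}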

View file

@ -35,10 +35,10 @@
FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer = []() -> CustomOpArray& {
static OrtOpLoader op_loader(
#ifdef ENABLE_GPT2_TOKENIZER
CustomCpuStruct("GPT2Tokenizer", GPT2Tokenizer),
CustomCpuStruct("CLIPTokenizer", CLIPTokenizer),
CustomCpuStruct("RobertaTokenizer", RobertaTokenizer),
CustomCpuStruct("BpeDecoder", KernelBpeDecoder),
CustomCpuStructV2("GPT2Tokenizer", GPT2Tokenizer),
CustomCpuStructV2("CLIPTokenizer", CLIPTokenizer),
CustomCpuStructV2("RobertaTokenizer", RobertaTokenizer),
CustomCpuStructV2("BpeDecoder", KernelBpeDecoder),
#endif
#ifdef ENABLE_SPM_TOKENIZER

View file

@ -14,66 +14,32 @@
#include <optional>
#include "unescape.h"
#include "trietree.hpp"
// This Trie Tree is C++ implementation of
// https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/rwkv_tokenizer.py
// Perf optimized by leveraging C++ features, but the algorithm is the same.
class TrieTree {
class RWKVTrieTree : public ort_extensions::TrieTree<char> {
public:
static constexpr int kMaxTokenLength_ = 128;
TrieTree(unsigned char ch = 0) : ch_(ch), to_(256) {}
RWKVTrieTree(char ch = 0) : TrieTree(ch) {}
// keep the same function for source code understanding.
void add(const std::string& key, int idx = 0,
std::optional<int> value = std::optional<int>()) {
if (idx == key.length()) {
if (!value) {
value = key[0];
}
value_ = value;
return;
}
unsigned char ch = static_cast<unsigned char>(key[idx]);
if (to_[ch] == nullptr) {
to_[ch] = std::make_unique<TrieTree>(ch);
}
to_[ch]->add(key, idx + 1, value);
Add(key, idx, value);
}
int find_longest(const std::string& key, size_t& idx) {
const TrieTree* u = this;
unsigned char ch = key[idx];
int tok_id = 0;
size_t idx_end = idx;
while (u->to_[ch]) {
u = u->to_[ch].get();
idx += 1;
if (u->value_) {
tok_id = *u->value_;
idx_end = idx;
}
if (idx == key.length()) {
break;
}
ch = key[idx];
}
idx = idx_end;
return tok_id;
return FindLongest(key, idx);
}
private:
std::vector<std::unique_ptr<TrieTree>> to_;
std::optional<int> value_;
unsigned char ch_;
};
class TrieTokenizer {
private:
std::map<int, std::string> idx2token;
TrieTree root;
RWKVTrieTree root;
public:
TrieTokenizer(const std::string& text_tokens) {
@ -210,7 +176,7 @@ struct KernelTrieDetokenizer : public BaseKernel {
if (ustring::ValidateUTF8(raw_string)) {
output[n] = raw_string;
} else {
output[n] = "\ufffd"; // bad utf-8 string
output[n] = "\ufffd"; // bad utf-8 string
failed = true;
}
}

View file

@ -0,0 +1,126 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
#include "narrow.h"
#include <vector>
#include <set>
#include <map>
#include <string>
#include <optional>
namespace ort_extensions {
template <typename CharT, typename ValueT = int>
class TrieTree {
public:
static constexpr int kMaxTokenLength_ = 128;
TrieTree(CharT ch = 0, ValueT invalid_id = -1) : ch_(ch), invalid_id_(invalid_id) {}
void Add(const std::basic_string<CharT>& key, int idx = 0,
const std::optional<ValueT>& value = std::nullopt) noexcept {
if (idx == key.length()) {
if (!value) {
value_ = std::make_optional(narrow<ValueT>(key[0]));
} else {
value_ = value;
}
} else {
auto ch = key[idx];
if (to_.count(ch) == 0) {
to_[ch] = std::make_unique<TrieTree>(ch);
}
to_[ch]->Add(key, idx + 1, value);
}
}
ValueT FindLongest(const std::basic_string<CharT>& key, size_t& idx) const noexcept {
const TrieTree* u = this;
CharT ch = key[idx];
ValueT tok_id = invalid_id_;
size_t idx_end = idx;
while (u->to_.count(ch)) {
u = u->to_.at(ch).get();
idx += 1;
if (u->value_) {
tok_id = *u->value_;
idx_end = idx;
}
if (idx == key.length()) {
break;
}
ch = key[idx];
}
idx = idx_end;
return tok_id;
}
int Split(const std::basic_string<CharT>& input,
std::vector<std::pair<std::basic_string_view<CharT>, ValueT>>& tokens) const noexcept {
size_t seg_idx = 0;
size_t tok_idx = 0;
while (tok_idx < input.length()) {
// variable u is the tree root.
const TrieTree* u = this;
auto ch = input[tok_idx];
size_t tok_len = 0;
size_t idx_end = tok_idx;
ValueT tok_id = invalid_id_;
// try to match the longest token
while (u->to_.count(ch)) {
tok_len += 1;
u = u->to_.at(ch).get();
if (u->value_) {
tok_id = *u->value_;
idx_end = tok_idx + 1;
}
tok_idx += 1;
if (tok_idx == input.length()) {
break;
}
ch = input[tok_idx];
}
tok_idx += 1;
if (tok_id == invalid_id_) {
if (tok_idx < input.length()) {
continue;
} else {
tok_idx += 1; // Assign tok_idx to input.length()
idx_end = tok_idx;
}
}
auto token_begin_idx = tok_idx - tok_len - 1; // since the tok_idx already moved forward by 1
tok_len = idx_end - token_begin_idx;
if (token_begin_idx > seg_idx || tok_len == 0) {
tokens.emplace_back(std::basic_string_view<CharT>(input.data() + seg_idx, token_begin_idx - seg_idx),
invalid_id_);
}
if (tok_id != invalid_id_) {
tokens.emplace_back(std::basic_string_view<CharT>(input.data() + token_begin_idx, tok_len), tok_id);
tok_idx = idx_end;
}
// reset state for next match
seg_idx = tok_idx;
}
return 0;
}
private:
std::map<CharT, std::unique_ptr<TrieTree>> to_;
std::optional<ValueT> value_;
const CharT ch_;
const ValueT invalid_id_;
};
} // namespace ort_extensions
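
For reference, a quick illustration (made-up keys and ids) of the generalized TrieTree template, which now backs both the char-based RWKV tokenizer and the char32_t-based added-token lookup in BpeModel.

#include "trietree.hpp"
#include <string>
#include <string_view>
#include <utility>
#include <vector>

void DemoTrieTree() {
  ort_extensions::TrieTree<char32_t> trie;  // ValueT defaults to int, invalid id defaults to -1
  trie.Add(U"ab", 0, 1);                    // key "ab"  -> id 1
  trie.Add(U"abc", 0, 2);                   // key "abc" -> id 2

  // FindLongest is greedy: the longest key that carries a value wins.
  std::u32string key = U"abcd";
  size_t pos = 0;
  int id = trie.FindLongest(key, pos);      // id == 2 ("abc"), pos == 3 (one past the match)

  // Split walks the whole input, emitting unmatched spans with the invalid id (-1).
  std::u32string text = U"xxabcd";
  std::vector<std::pair<std::u32string_view, int>> pieces;
  trie.Split(text, pieces);
  // pieces (roughly): {"xx", -1}, {"abc", 2}, {"d", -1}; the views alias `text`.
}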

View file

@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest
import pkg_resources
import numpy as np
from transformers import AutoTokenizer, GPT2Tokenizer
@ -76,7 +75,7 @@ class TestAutoTokenizer(unittest.TestCase):
pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0])
actual_ids = ort_tok([text])[0]
np.testing.assert_array_equal(ids, actual_ids)
def test_gpt2_tokenizer(self):
tokenizer = GPT2Tokenizer.from_pretrained("Xenova/gpt-4", use_fast=False)
text = "Testing words with apostrophes such as you're, i'm, don't, etc."
@ -96,7 +95,7 @@ class TestAutoTokenizer(unittest.TestCase):
" add words that should not exist and be tokenized to , such as saoneuhaoesuth")
ids = tokenizer.encode(text, return_tensors="np")
ort_tok, _ = gen_processing_models(tokenizer,pre_kwargs={"WITH_DEFAULT_INPUTS": True})
ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={"WITH_DEFAULT_INPUTS": True})
actual_ids, *_ = ort_inference(ort_tok, [text])
np.testing.assert_array_equal(ids[0], actual_ids)
@ -124,8 +123,7 @@ class TestAutoTokenizer(unittest.TestCase):
ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={})
actual_ids, *_ = ort_inference(ort_tok, [code])
self.assertEqual(len(ids['input_ids'].shape), len(actual_ids.shape))
# TODO: not matched.
# np.testing.assert_array_equal(ids['input_ids'], actual_ids)
np.testing.assert_array_equal(ids['input_ids'], actual_ids)
if __name__ == '__main__':