Support 'added_token' attribute for BPE tokenizer and some code refactoring. (#591)
* Fix CodeGenTokenizer issues and the related code refactoring.
* refactor the trie-tree
* temp check-ins
* code complete
* correctness fixing
* Update _hf_cvt.py
* more test cases fixing
* more refinement
* linux crash fixing
* Update test_autotokenizer.py
This commit is contained in:
Parent: e951e72a85
Commit: d1148aea4e
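What this change enables, end to end: a slow Hugging Face BPE tokenizer whose added_tokens_encoder is not empty can now be converted to an ONNX Runtime custom-op graph and produce the same ids as the original tokenizer. A rough usage sketch based on the updated test at the bottom of this commit; the import path of gen_processing_models and ort_inference is assumed to be the onnxruntime_extensions package, as in the test module:

import numpy as np
from transformers import GPT2Tokenizer
from onnxruntime_extensions import gen_processing_models, ort_inference  # assumed import path

# The converter requires the slow tokenizer classes (it rejects the *Fast variants).
tokenizer = GPT2Tokenizer.from_pretrained("Xenova/gpt-4", use_fast=False)
text = "Testing words with apostrophes such as you're, i'm, don't, etc."

ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={"WITH_DEFAULT_INPUTS": True})
actual_ids, *_ = ort_inference(ort_tok, [text])

expected_ids = tokenizer.encode(text, return_tensors="np")
np.testing.assert_array_equal(expected_ids[0], actual_ids)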
@@ -2,7 +2,6 @@
// Licensed under the MIT License.
#pragma once

#include "ocos.h"
#include <vector>
#include <string_view>

@@ -85,7 +84,7 @@ class ustring : public std::u32string {
using u32string = std::u32string;
static u32string FromUTF8(const std::string_view& utf8) {
u32string ucs32;
ucs32.reserve(utf8.length() / 2); // a rough estimation for less memory allocation.
ucs32.reserve(utf8.length() / 2); // a rough estimation for less memory allocation.
for (size_t i = 0; i < utf8.size();) {
char32_t codepoint = 0;
if ((utf8[i] & 0x80) == 0) {

@@ -30,7 +30,7 @@ class API {
// To use ONNX C ABI in a way like OrtW::API::CreateStatus.
public:
static API& instance(const OrtApi* ort_api = nullptr) noexcept {
static API self(*ort_api);
static API self(ort_api);
return self;
}

@@ -54,15 +54,15 @@ class API {
return &api_;
}

API(const OrtApi& api) : api_(api) {
if (&api == nullptr) {
API(const OrtApi* api) : api_(*api) {
if (api == nullptr) {
ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
}
}

const OrtApi& api_;
};

template <>
inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
return instance()->KernelInfoGetAttribute_int64(&info, name, &value);

@@ -107,21 +107,32 @@ inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
return API::CreateStatus(code, msg);
}

inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
return API::CreateStatus(code, msg.c_str());
}

inline void ReleaseStatus(OrtStatusPtr& status) {
API::ReleaseStatus(status);
status = nullptr;
}

} // namespace OrtW

#define ORTX_RETURN_IF_ERROR(expr) \
do { \
auto _status = (expr); \
if (_status != nullptr) { \
return _status; \
} \
} while (0)

namespace Ort {
namespace Custom {

#ifdef USE_CUDA
///////////////////////////////////////////////////////////////////////////
// TODO: include the definition from the header file in ONNXRuntime
struct CudaContext {};
struct CudaContext {};

#endif // USE_CUDA

@@ -23,36 +23,28 @@ class HFTokenizerConverter(CustomOpConverter):
def __init__(self, tokenizer):
self.tokenizer = tokenizer

def bpe_tokenizer(self, **kwargs):
hf_gpt2_tokenizer = self.tokenizer
attrs = None
@staticmethod
def convert_bpe_vocab(hf_tokenizer):
attrs = {'vocab': json.dumps(
hf_tokenizer.encoder, separators=(',', ':'))}
if hf_tokenizer.added_tokens_encoder:
# ids = sorted(hf_tokenizer.added_tokens_encoder.values())
# if not ids == list(range(min(ids), max(ids) + 1)):
# raise RuntimeError(f"{hf_tokenizer.__name__}: the ids in added_tokens_encoder are not consecutive")
token_map = [f"{_k}={_v}" for _k, _v in hf_tokenizer.added_tokens_encoder.items()]
attrs.update({"added_token": "\n".join(token_map)})

if type(self.tokenizer).__name__.endswith('Fast'):
raise ValueError('Please use the slow version of the tokenizer (ex: GPT2Tokenizer).')
elif(self.tokenizer.name_or_path.endswith('gpt-4')):
# Fill vocab gap for GPT4Tokenizer to create continuous domain
vocab_dict = hf_gpt2_tokenizer.encoder
partial_values = list(vocab_dict.values())

max_vocab = partial_values[-1]
all_values = np.arange(max_vocab + 1)

missing_values = set(all_values) - set(partial_values)

for v in missing_values:
vocab_dict[str(uuid.uuid4())] = int(v)

vocab_dict = dict(sorted(vocab_dict.items(), key=lambda item: item[1]))

attrs = {'vocab': json.dumps(
vocab_dict, separators=(',', ':'))}
else:
attrs = {'vocab': json.dumps(
hf_gpt2_tokenizer.encoder, separators=(',', ':'))}

sorted_merges = {v_: k_ for k_, v_ in hf_gpt2_tokenizer.bpe_ranks.items()}
sorted_merges = {v_: k_ for k_, v_ in hf_tokenizer.bpe_ranks.items()}
attrs['merges'] = '\n'.join("{} {}".format(
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
return attrs

def bpe_tokenizer(self, **kwargs):
hf_gpt2_tokenizer = self.tokenizer
if type(self.tokenizer).__name__.endswith('Fast'):
raise ValueError('Please use the slow version of the tokenizer (ex: GPT2Tokenizer).')

attrs = self.convert_bpe_vocab(hf_gpt2_tokenizer)
attrs.update(**kwargs)
return attrs

@@ -101,12 +93,7 @@ class HFTokenizerConverter(CustomOpConverter):
if type(self.tokenizer).__name__.endswith('Fast'):
raise ValueError('Please use the slow version of the tokenizer (ex: CLIPTokenizer).')

attrs = {'vocab': json.dumps(
hf_clip_tokenizer.encoder, separators=(',', ':'))}
sorted_merges = {v_: k_ for k_,
v_ in hf_clip_tokenizer.bpe_ranks.items()}
attrs['merges'] = '\n'.join("{} {}".format(
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
attrs = self.convert_bpe_vocab(hf_clip_tokenizer)
attrs.update(**kwargs)
return attrs

@@ -116,12 +103,7 @@ class HFTokenizerConverter(CustomOpConverter):
if type(self.tokenizer).__name__.endswith('Fast'):
raise ValueError('Please use the slow version of the tokenizer (ex: RobertaTokenizer).')

attrs = {'vocab': json.dumps(
hf_roberta_tokenizer.encoder, separators=(',', ':'))}
sorted_merges = {v_: k_ for k_,
v_ in hf_roberta_tokenizer.bpe_ranks.items()}
attrs['merges'] = '\n'.join("{} {}".format(
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
attrs = self.convert_bpe_vocab(hf_roberta_tokenizer)
attrs.update(**kwargs)
return attrs

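The new added_token attribute produced by convert_bpe_vocab above is simply the tokenizer's added_tokens_encoder flattened into one token=id pair per line; BpeModel::LoadAddedTokens further down splits each line on the last '=' and parses the id. A minimal sketch of the serialization, with a made-up added-token map and a two-entry vocab purely for illustration:

import json

added_tokens_encoder = {"<|im_start|>": 100264, "<|im_end|>": 100265}  # hypothetical added tokens

attrs = {"vocab": json.dumps({"hello": 0, "world": 1}, separators=(",", ":"))}
if added_tokens_encoder:
    token_map = [f"{k}={v}" for k, v in added_tokens_encoder.items()]
    attrs["added_token"] = "\n".join(token_map)

print(attrs["added_token"])
# <|im_start|>=100264
# <|im_end|>=100265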
@@ -16,19 +16,27 @@
#include <algorithm>
#include <sstream>

struct KernelBpeDecoder : public BaseKernel {
struct KernelBpeDecoder {
public:
KernelBpeDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "id_vocab");
if (vocab.empty()) {
ORTX_CXX_API_THROW("[BPEDecoder]id vocab text cannot be empty.", ORT_INVALID_ARGUMENT);
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
// note: if the attribute doesn't exist in op node, GetOpAttribute doesn't return a failed status;
std::string vocab;
OrtStatusPtr status = OrtW::GetOpAttribute(info, "id_vocab", vocab);
if (status != nullptr || vocab.empty()) {
if (status == nullptr) {
status = OrtW::CreateStatus("[BPEDecoder]id vocab text cannot be empty.", ORT_INVALID_ARGUMENT);
}
return status;
}
BuildIdVocab(vocab);

std::string byte_decoder = ort_.KernelInfoGetAttribute<std::string>(&info, "byte_decoder");
if (byte_decoder.empty()) {
ORTX_CXX_API_THROW("[BPEDecoder]byte_decoder cannot be empty.", ORT_INVALID_ARGUMENT);
std::string byte_decoder;
status = OrtW::GetOpAttribute(info, "byte_decoder", byte_decoder);
if (status != nullptr || byte_decoder.empty()) {
if (status == nullptr) {
status = OrtW::CreateStatus("[BPEDecoder]byte_decoder cannot be empty.", ORT_INVALID_ARGUMENT);
}
return status;
} else {
auto um = ParseId2String(byte_decoder);
std::transform(um.begin(), um.end(),

@@ -37,13 +45,15 @@ struct KernelBpeDecoder : public BaseKernel {
ort_extensions::narrow<unsigned char>(std::stoul(p.second))); });
}

std::string added_tokens = TryToGetAttributeWithDefault<std::string>("added_tokens", "");
std::string added_tokens;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "added_tokens", added_tokens));
if (!added_tokens.empty()) {
auto um = ParseId2String(added_tokens);
added_tokens_ = std::map<int64_t, std::string>(um.begin(), um.end());
}

std::string all_special_ids = TryToGetAttributeWithDefault<std::string>("all_special_ids", "");
std::string all_special_ids;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "all_special_ids", all_special_ids));
if (!all_special_ids.empty()) {
auto um = ParseId2String(all_special_ids);
std::transform(um.begin(), um.end(),

@@ -51,12 +61,14 @@ struct KernelBpeDecoder : public BaseKernel {
[](const auto& p) { return p.first; });
}

en_normalization_ = TryToGetAttributeWithDefault<int64_t>("en_normalization", 0);
skip_special_tokens_ = TryToGetAttributeWithDefault<int64_t>("skip_special_tokens", 0);
whitespace_token_ = TryToGetAttributeWithDefault<int64_t>("whitespace_token", 0);
bos_token_ = TryToGetAttributeWithDefault("bos_token", std::string("<|endoftext|>"));
eos_token_ = TryToGetAttributeWithDefault("eos_token", std::string("<|endoftext|>"));
unk_token_ = TryToGetAttributeWithDefault("unk_token", std::string("<|endoftext|>"));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "en_normalization", en_normalization_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "skip_special_tokens", skip_special_tokens_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "whitespace_token", whitespace_token_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "bos_token", bos_token_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "eos_token", eos_token_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "unk_token", unk_token_));

return status;
}

std::unordered_map<int64_t, std::string> ParseId2String(const std::string& s_attr) {

@@ -102,8 +114,8 @@ struct KernelBpeDecoder : public BaseKernel {
arr_vocab_.shrink_to_fit();
}

void Compute(const ortc::Tensor<int64_t>& ids,
ortc::Tensor<std::string>& output) const {
OrtStatusPtr Compute(const ortc::Tensor<int64_t>& ids,
ortc::Tensor<std::string>& output) const {
const int64_t* p_ids = ids.Data();
const auto& ids_dim = ids.Shape();
std::vector<int64_t> output_dim = {1};

@@ -168,12 +180,13 @@ struct KernelBpeDecoder : public BaseKernel {
p_ids += seq_len;
}
output.SetStringOutput(decoded_strings, output_dim);
return nullptr;
}

private:
std::string bos_token_;
std::string eos_token_;
std::string unk_token_;
std::string bos_token_{"<|endoftext|>"};
std::string eos_token_{"<|endoftext|>"};
std::string unk_token_{"<|endoftext|>"};

// Since ORT API doesn't support boolean type in ONNX node attribute,
// all flag attributes here are defined as int64 type to be more explicit.

@@ -6,6 +6,8 @@

#include <optional>

using namespace ort_extensions;

std::string BpeModelConf::GetSpecialTokens() const {
std::string special_tokens = unk_token_; // unk_token_ is required
auto add_token = [](std::string& sp, const char* tok) {

@@ -87,43 +89,53 @@ ustring RemoveConsecutiveSpaces(const ustring& input) {
return result;
}

KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info, const BpeModelConf& conf)
: BaseKernel(api, info),
bpe_conf_(conf) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
KernelBpeTokenizer::KernelBpeTokenizer(const BpeModelConf& conf)
: bpe_conf_(conf){};

OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
// note: if the attribute doesn't exist in op node, GetOpAttribute doesn't return a failed status;
std::string vocab;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "vocab", vocab));
if (vocab.empty()) {
ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
return OrtW::CreateStatus("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
}

std::string merges = ort_.KernelInfoGetAttribute<std::string>(&info, "merges");
std::string merges;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "merges", merges));
if (merges.empty()) {
ORTX_CXX_API_THROW("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
}

if (!TryToGetAttribute<int64_t>("padding_length", padding_length_)) {
padding_length_ = -1;
return OrtW::CreateStatus("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
}

ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "padding_length", padding_length_));
if (padding_length_ != -1 && padding_length_ <= 0) {
ORTX_CXX_API_THROW("padding_length should be more than 0 or equal -1", ORT_INVALID_ARGUMENT);
return OrtW::CreateStatus("padding_length should be more than 0 or equal -1", ORT_INVALID_ARGUMENT);
}

std::stringstream vocabu_stream(vocab);
std::stringstream merges_stream(merges);
bbpe_tokenizer_ = std::make_unique<BpeModel>();
bbpe_tokenizer_->Load(vocabu_stream, merges_stream, conf.unk_token_, conf.GetSpecialTokens().c_str());
auto status = bbpe_tokenizer_->Load(vocabu_stream, merges_stream, bpe_conf_.unk_token_, bpe_conf_.GetSpecialTokens().c_str());
if (status != nullptr) {
return status;
}

std::string added_token;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "added_token", added_token));
ORTX_RETURN_IF_ERROR(bbpe_tokenizer_->LoadAddedTokens(added_token.c_str()));

// TODO: need to check if the special token ids are the same as the ones in HFTokenizer
unk_token_id_ = bbpe_tokenizer_->GetTokenId(conf.unk_token_);
if (conf.bos_token_ != nullptr) {
bos_token_id_ = bbpe_tokenizer_->GetTokenId(conf.bos_token_);
unk_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.unk_token_);
if (bpe_conf_.bos_token_ != nullptr) {
bos_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.bos_token_);
}
if (conf.eos_token_ != nullptr) {
eos_token_id_ = bbpe_tokenizer_->GetTokenId(conf.eos_token_);
if (bpe_conf_.eos_token_ != nullptr) {
eos_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.eos_token_);
}
if (conf.pad_token_ != nullptr) {
pad_token_id_ = bbpe_tokenizer_->GetTokenId(conf.pad_token_);
if (bpe_conf_.pad_token_ != nullptr) {
pad_token_id_ = bbpe_tokenizer_->GetTokenId(bpe_conf_.pad_token_);
}

return nullptr;
}

std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,

@@ -171,21 +183,20 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
}

// Parse input
auto special_token_split_res = bbpe_tokenizer_->SplitBySpecialTokens(input);
TokenWithRegularExp regcmp;
auto special_token_split_res = bbpe_tokenizer_->SplitByAddedAndSpecial(input);
bpe::TokenWithRegularExp regcmp;

for (auto& seg_id : special_token_split_res) {
if (static_cast<int64_t>(res.size()) >= max_length) break;

if (seg_id.second != -1) {
if (seg_id.second != bpe::kInvalidTokenId) {
res.push_back(seg_id.second);
continue;
}

auto cur_input = std::move(seg_id.first);
// Note: keep ptr to make sure the string_view is valid in the following process
const char32_t* ptr = cur_input.c_str();
regcmp.Set(ptr);
std::u32string str(seg_id.first);
regcmp.Set(str.c_str());

size_t offset = 0;
OffsetMappingType offset_mapping;

@@ -199,7 +210,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,

while (static_cast<int64_t>(res.size()) < max_length) {
auto [b, tok] = regcmp.GetNextToken();

if (!b) break;

std::string utf8_token = std::string(ustring(tok));

@@ -271,13 +282,14 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
// Add EOS token to result
res.push_back(eos_token_id_);
}

return res;
}

void KernelBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
OrtStatusPtr KernelBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
// Setup inputs
std::vector<std::string> str_input{input.Data()};
std::list<OffsetMappingType> offset_map;

@@ -356,11 +368,13 @@ void KernelBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
idx++;
}
}

return nullptr;
}

static const auto kGPT2Confinguration = BpeModelConf();
GPT2Tokenizer::GPT2Tokenizer(const OrtApi& api, const OrtKernelInfo& info)
: KernelBpeTokenizer(api, info, kGPT2Confinguration) {}
GPT2Tokenizer::GPT2Tokenizer()
: KernelBpeTokenizer(kGPT2Confinguration) {}

static const auto kRobertaConfiguration = BpeModelConf{
BpeModelConf::kModel_Roberta, // name

@@ -369,8 +383,8 @@ static const auto kRobertaConfiguration = BpeModelConf{
"</s>", // eos_token
"<pad>"}; // pad_token

RobertaTokenizer::RobertaTokenizer(const OrtApi& api, const OrtKernelInfo& info)
: KernelBpeTokenizer(api, info, kRobertaConfiguration) {}
RobertaTokenizer::RobertaTokenizer()
: KernelBpeTokenizer(kRobertaConfiguration) {}

static const auto kCLIPConfiguration = BpeModelConf{
BpeModelConf::kModel_CLIP, // name

@@ -379,5 +393,5 @@ static const auto kCLIPConfiguration = BpeModelConf{
"<|endoftext|>", // eos_token
"<|endoftext|>"}; // pad_token

CLIPTokenizer::CLIPTokenizer(const OrtApi& api, const OrtKernelInfo& info)
: KernelBpeTokenizer(api, info, kCLIPConfiguration) {}
CLIPTokenizer::CLIPTokenizer()
: KernelBpeTokenizer(kCLIPConfiguration) {}

@@ -23,15 +23,18 @@ struct BpeModelConf {
std::string GetSpecialTokens() const;
};

namespace ort_extensions {
class BpeModel;
}

struct KernelBpeTokenizer : BaseKernel {
KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info, const BpeModelConf& conf);
struct KernelBpeTokenizer {
KernelBpeTokenizer(const BpeModelConf& conf);
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info);

void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
OrtStatusPtr Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;

const char* ModelName() const { return bpe_conf_.name_; }

@@ -41,10 +44,12 @@ struct KernelBpeTokenizer : BaseKernel {
int64_t max_length,
bool compute_offset_mapping,
std::list<OffsetMappingType>& offset_map) const;
int64_t padding_length_;
std::unique_ptr<BpeModel> bbpe_tokenizer_;
const BpeModelConf& bpe_conf_;

private:
const BpeModelConf& bpe_conf_;
std::unique_ptr<ort_extensions::BpeModel> bbpe_tokenizer_;

int64_t padding_length_ = -1;
uint32_t unk_token_id_{};
uint32_t bos_token_id_{};
uint32_t eos_token_id_{};

@@ -52,34 +57,34 @@ struct KernelBpeTokenizer : BaseKernel {
};

struct GPT2Tokenizer : KernelBpeTokenizer {
GPT2Tokenizer(const OrtApi& api, const OrtKernelInfo& info);
GPT2Tokenizer();
// required by LiteCustomOp which neede a explicit Compute declaration for non-MSVC compiler.
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
OrtStatusPtr Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
};

struct RobertaTokenizer : KernelBpeTokenizer {
RobertaTokenizer(const OrtApi& api, const OrtKernelInfo& info);
RobertaTokenizer();
// required by LiteCustomOp which neede a explicit Compute declaration for non-MSVC compiler.
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
OrtStatusPtr Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
};

struct CLIPTokenizer : KernelBpeTokenizer {
CLIPTokenizer(const OrtApi& api, const OrtKernelInfo& info);
CLIPTokenizer();
// required by LiteCustomOp which neede a explicit Compute declaration for non-MSVC compiler.
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
OrtStatusPtr Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
};

@@ -12,18 +12,23 @@
#include <unordered_map>
#include <iostream>
#include <utility>
#include <charconv>
#include <limits>

#include "nlohmann/json.hpp"
#include "bpe_utils.hpp"
#include "trietree.hpp"

namespace ort_extensions {

class BpeModel {
public:
BpeModel() = default;

void Load(std::istream& vocab_stream,
std::istream& merges_stream,
const char* unk_token,
const char* special_tokens) {
OrtStatusPtr Load(std::istream& vocab_stream,
std::istream& merges_stream,
const char* unk_token,
const char* special_tokens) {
nlohmann::json tok_json;
vocab_stream >> tok_json;
vocab_map_ = std::move(tok_json.get<std::unordered_map<std::string, uint32_t>>());

@@ -34,6 +39,7 @@ class BpeModel {
} else {
auto id = ort_extensions::narrow<uint32_t>(vocab_map_.size());
vocab_map_[unk_token] = id;
unk_id_ = id;
}

CreateByteEncoder();

@@ -46,7 +52,7 @@ class BpeModel {
if ((line[0] == '#') && (index == 0)) continue;
auto pos = line.find(' ');
if (pos == std::string::npos) {
ORTX_CXX_API_THROW("Cannot know how to parse line: " + line, ORT_INVALID_ARGUMENT);
return OrtW::CreateStatus("Cannot know how to parse line: " + line, ORT_INVALID_ARGUMENT);
}
std::string w1 = line.substr(0, pos);
std::string w2 = line.substr(pos + 1);

@@ -54,9 +60,9 @@ class BpeModel {
if (w2.find("</w>") != std::string::npos || w1.find("</w>") != std::string::npos) {
token_length -= 4;
}
auto iw1 = GetVocabIndex(w1);
auto iw2 = GetVocabIndex(w2);
auto iww = GetVocabIndex(w1 + w2);
auto iw1 = GetTokenId(w1);
auto iw2 = GetTokenId(w2);
auto iww = GetTokenId(w1 + w2);
BpeNode value{iww, index++, token_length};
bpe_rank_[GetRankKey(iw1, iw2)] = value;
}

@@ -80,8 +86,62 @@ class BpeModel {

id2token_map_.resize(vocab_map_.size());
for (const auto& [t, i] : vocab_map_) {
if (i > static_cast<uint32_t>(std::numeric_limits<int32_t>::max())) {
continue; // safe purpose.
}
if (i > id2token_map_.size()) {
id2token_map_.resize(i + 1);
}
id2token_map_[i] = t;
}

return nullptr;
}

OrtStatusPtr LoadAddedTokens(const char* added_tokens) {
int id = bpe::kInvalidTokenId;
std::istringstream strm_tokens(added_tokens);
std::string line;
while (!strm_tokens.eof()) {
std::getline(strm_tokens, line);
line.erase(std::remove(line.begin(), line.end(), '\r'), line.end());
if (line.empty()) continue;
// seperate the key and value by =
auto pos = line.rfind("=");
if (pos == std::string::npos) {
return OrtW::CreateStatus("Error on parse a added_token line: " + line, ORT_INVALID_ARGUMENT);
}
auto token = line.substr(0, pos);
auto id_str = line.substr(pos + 1); // 1 is the length of "="
auto [ptr, ec] = std::from_chars(id_str.data(), id_str.data() + id_str.length(), id);
if (ec != std::errc()) {
return OrtW::CreateStatus("Cannot convert to an integer from " + id_str, ORT_INVALID_ARGUMENT);
}

added_tokens_.Add(ustring(token), 0, std::make_optional(id));
}

return nullptr;
}

// REF: https://github.com/huggingface/transformers/blob/c9e72f55b2dc4b9be4edb986dce0552582b328f2/src/transformers/tokenization_utils.py#L52
bpe::TokenPairs SplitByAddedAndSpecial(const ustring& input) const {
// split by added tokens
bpe::TokenPairs added_result;
bpe::TokenPairs final_result;
added_tokens_.Split(input, added_result);
for (const auto& [token, id] : added_result) {
if (id != bpe::kInvalidTokenId) {
final_result.emplace_back(token, id);
} else {
auto special_result = special_tokens_.SplitBySpecialTokens(token);
for (const auto& [token, id] : special_result) {
final_result.emplace_back(token, id);
}
}
}

return final_result;
}

void bpe(std::list<std::pair<uint32_t, uint32_t>>& vals) const {

@@ -94,9 +154,15 @@ class BpeModel {
for (auto it = vals.begin(); it != vals.end(); ++it) {
auto it2 = it;
++it2;
if (it2 == vals.end()) break;
if (it2 == vals.end()) {
break;
}

auto map_it = bpe_rank_.find(GetRankKey(it->first, it2->first));
if (map_it == bpe_rank_.end()) continue;
if (map_it == bpe_rank_.end()) {
continue;
}

if (minval > map_it->second.value) {
ori_id1 = it->first;
ori_id2 = it2->first;

@@ -105,7 +171,10 @@ class BpeModel {
aim_id = map_it->second.id;
}
}
if (pos_it == vals.end()) break;

if (pos_it == vals.end()) {
break;
}

token_length = pos_it->second;
pos_it = vals.erase(pos_it);

@@ -129,11 +198,6 @@ class BpeModel {
return byte_encoder_;
}

auto SplitBySpecialTokens(const ustring& input) const {
return special_tokens_.SplitBySpecialTokens(input);
}

// Returns token if key was found in vocab, and unk_id_ otherwise
uint32_t GetTokenId(const std::string& key) {
auto it = vocab_map_.find(key);
if (it != end(vocab_map_)) {

@@ -163,21 +227,13 @@ class BpeModel {
)
*/
if ((i >= 0 && i < 33) || (i >= 127 && i < 161) || (i == 173)) {
byte_encoder_[i] = GetVocabIndex(ustring::EncodeUTF8Char(index++));
byte_encoder_[i] = GetTokenId(ustring::EncodeUTF8Char(index++));
} else {
byte_encoder_[i] = GetVocabIndex(ustring::EncodeUTF8Char(i));
byte_encoder_[i] = GetTokenId(ustring::EncodeUTF8Char(i));
}
}
}

uint32_t GetVocabIndex(const std::string& str) {
auto it = vocab_map_.find(str);
if (it == vocab_map_.end()) {
ORTX_CXX_API_THROW("Cannot find word in vocabulary: " + str, ORT_INVALID_ARGUMENT);
}
return it->second;
}

private:
std::map<uint64_t, BpeNode> bpe_rank_;

@@ -186,5 +242,8 @@ class BpeModel {
std::vector<std::string> id2token_map_;

uint32_t unk_id_ = std::numeric_limits<uint32_t>::max();
SpecialTokenMap special_tokens_;
bpe::SpecialTokenMap special_tokens_;
TrieTree<char32_t> added_tokens_;
};

} // namespace ort_extensions

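SplitByAddedAndSpecial above is a two-stage split: the added-token trie claims its matches first, and only the segments it leaves untouched are re-split by the existing SpecialTokenMap. A small Python sketch of that control flow (the two split_by_* callables stand in for the trie and the special-token map):

INVALID_ID = -1  # mirrors bpe::kInvalidTokenId

def split_by_added_and_special(text, split_by_added, split_by_special):
    # Both callables return (segment, token_id) pairs, with token_id == INVALID_ID
    # for plain-text segments that still need further tokenization.
    final_result = []
    for segment, token_id in split_by_added(text):
        if token_id != INVALID_ID:
            final_result.append((segment, token_id))        # an added token, keep its id
        else:
            final_result.extend(split_by_special(segment))  # fall back to special tokens
    return final_result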
@@ -6,35 +6,43 @@
#include "ocos.h"
#include "narrow.h"

#include <cassert>
#include <algorithm>
#include "ustring.h"

#include "unicode.h"

namespace ort_extensions {
namespace bpe {

using TokenPairs = std::vector<std::pair<std::u32string_view, int>>;
using u32string_view = std::u32string_view;

constexpr int kInvalidTokenId = -1;

class SpecialTokenMap {
public:
void Add(ustring p_str, int p_id) {
auto it = token_map_.find(p_str);
if (it != token_map_.end()) {
if (it->second != p_id) {
ORTX_CXX_API_THROW("Duplicate special tokens.", ORT_INVALID_ARGUMENT);
}
assert(it->second == p_id && "Duplicate special tokens.");
} else {
token_map_[p_str] = p_id;
token_list_.push_back(SpecialTokenInfo(std::move(p_str), p_id));
}
}

std::vector<std::pair<ustring, int>> SplitBySpecialTokens(ustring input) const {
std::vector<std::pair<ustring, int>> res;
res.emplace_back(std::move(input), -1);
TokenPairs SplitBySpecialTokens(const std::u32string_view& input) const {
TokenPairs res;
res.emplace_back(input, kInvalidTokenId);
for (const auto& st : token_list_) {
std::vector<std::pair<ustring, int>> new_split_res;
TokenPairs new_split_res;
for (auto& str : res) {
if (str.second != -1) {
new_split_res.push_back(std::move(str));
if (str.second != kInvalidTokenId) {
new_split_res.emplace_back(str);
continue;
}

auto it = str.first.begin();
size_t search_pos = 0;
while (it != str.first.end()) {

@@ -46,21 +54,27 @@ class SpecialTokenMap {
std::boyer_moore_searcher(st.str.begin(), st.str.end()));
#endif
if (search_it == str.first.end()) {
new_split_res.emplace_back(str.first.substr(search_pos), -1);
new_split_res.emplace_back(u32string_view(
str.first.data() + search_pos, str.first.size() - search_pos),
kInvalidTokenId);
break;
}

auto prefixLen = search_it - it;
if (prefixLen != 0) {
new_split_res.emplace_back(str.first.substr(search_pos, prefixLen), -1);
new_split_res.emplace_back(u32string_view(str.first.data() + search_pos, prefixLen), kInvalidTokenId);
search_pos += prefixLen;
}
new_split_res.emplace_back(str.first.substr(search_pos, st.str.size()), st.id);

new_split_res.emplace_back(u32string_view(str.first.data() + search_pos, st.str.size()), st.id);
it = search_it + st.str.size();
search_pos += st.str.size();
}
}

std::swap(new_split_res, res);
}

return res;
}

@@ -101,7 +115,6 @@ class TokenWithRegularExp {

private:
std::u32string_view TryMatch() {

// python pattern:
// 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+

@@ -198,8 +211,7 @@ class TokenWithRegularExp {
for (; i < m_text.size(); ++i) {
if (!IsZ(m_text[i])) break;
}
if ((i > 1) && (i != m_text.size())) //\s+(?!\S)
{
if ((i > 1) && (i != m_text.size())) { //\s+(?!\S)
i--;
std::u32string_view res = m_text.substr(0, i);
m_text = m_text.substr(i);

@@ -230,3 +242,6 @@ class TokenWithRegularExp {
private:
std::u32string_view m_text;
};

} // namespace bpe
} // namespace ort_extensions

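TokenWithRegularExp hand-codes the GPT-2 pre-tokenization pattern quoted in the comment above rather than using a regex engine. For reference, the same split written with the third-party Python regex module (which supports the \p{...} classes):

import regex  # pip install regex; the stdlib re module has no \p{...} support

GPT2_PATTERN = regex.compile(
    r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+")

print(GPT2_PATTERN.findall("you're, i'm, don't"))
# ['you', "'re", ',', ' i', "'m", ',', ' don', "'t"]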
@@ -35,10 +35,10 @@
FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer = []() -> CustomOpArray& {
static OrtOpLoader op_loader(
#ifdef ENABLE_GPT2_TOKENIZER
CustomCpuStruct("GPT2Tokenizer", GPT2Tokenizer),
CustomCpuStruct("CLIPTokenizer", CLIPTokenizer),
CustomCpuStruct("RobertaTokenizer", RobertaTokenizer),
CustomCpuStruct("BpeDecoder", KernelBpeDecoder),
CustomCpuStructV2("GPT2Tokenizer", GPT2Tokenizer),
CustomCpuStructV2("CLIPTokenizer", CLIPTokenizer),
CustomCpuStructV2("RobertaTokenizer", RobertaTokenizer),
CustomCpuStructV2("BpeDecoder", KernelBpeDecoder),
#endif

#ifdef ENABLE_SPM_TOKENIZER

@@ -14,66 +14,32 @@
#include <optional>

#include "unescape.h"
#include "trietree.hpp"

// This Trie Tree is C++ implementation of
// https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/rwkv_tokenizer.py
// Perf optimized by leveraging C++ features, but the algorithm is the same.
class TrieTree {
class RWKVTrieTree : public ort_extensions::TrieTree<char> {
public:
static constexpr int kMaxTokenLength_ = 128;

TrieTree(unsigned char ch = 0) : ch_(ch), to_(256) {}
RWKVTrieTree(char ch = 0) : TrieTree(ch) {}

// keep the same function for source code understanding.
void add(const std::string& key, int idx = 0,
std::optional<int> value = std::optional<int>()) {
if (idx == key.length()) {
if (!value) {
value = key[0];
}
value_ = value;
return;
}

unsigned char ch = static_cast<unsigned char>(key[idx]);
if (to_[ch] == nullptr) {
to_[ch] = std::make_unique<TrieTree>(ch);
}
to_[ch]->add(key, idx + 1, value);
Add(key, idx, value);
}

int find_longest(const std::string& key, size_t& idx) {
const TrieTree* u = this;
unsigned char ch = key[idx];

int tok_id = 0;
size_t idx_end = idx;
while (u->to_[ch]) {
u = u->to_[ch].get();
idx += 1;
if (u->value_) {
tok_id = *u->value_;
idx_end = idx;
}
if (idx == key.length()) {
break;
}
ch = key[idx];
}

idx = idx_end;
return tok_id;
return FindLongest(key, idx);
}

private:
std::vector<std::unique_ptr<TrieTree>> to_;
std::optional<int> value_;
unsigned char ch_;
};

class TrieTokenizer {
private:
std::map<int, std::string> idx2token;
TrieTree root;
RWKVTrieTree root;

public:
TrieTokenizer(const std::string& text_tokens) {

@@ -210,7 +176,7 @@ struct KernelTrieDetokenizer : public BaseKernel {
if (ustring::ValidateUTF8(raw_string)) {
output[n] = raw_string;
} else {
output[n] = "\ufffd"; // bad utf-8 string
output[n] = "\ufffd"; // bad utf-8 string
failed = true;
}
}

@@ -0,0 +1,126 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
#include "narrow.h"

#include <vector>
#include <set>
#include <map>
#include <string>
#include <optional>

namespace ort_extensions {

template <typename CharT, typename ValueT = int>
class TrieTree {
public:
static constexpr int kMaxTokenLength_ = 128;

TrieTree(CharT ch = 0, ValueT invalid_id = -1) : ch_(ch), invalid_id_(invalid_id) {}

void Add(const std::basic_string<CharT>& key, int idx = 0,
const std::optional<ValueT>& value = std::nullopt) noexcept {
if (idx == key.length()) {
if (!value) {
value_ = std::make_optional(narrow<ValueT>(key[0]));
} else {
value_ = value;
}
} else {
auto ch = key[idx];
if (to_.count(ch) == 0) {
to_[ch] = std::make_unique<TrieTree>(ch);
}
to_[ch]->Add(key, idx + 1, value);
}
}

ValueT FindLongest(const std::basic_string<CharT>& key, size_t& idx) const noexcept {
const TrieTree* u = this;
CharT ch = key[idx];

ValueT tok_id = invalid_id_;
size_t idx_end = idx;
while (u->to_.count(ch)) {
u = u->to_.at(ch).get();
idx += 1;
if (u->value_) {
tok_id = *u->value_;
idx_end = idx;
}
if (idx == key.length()) {
break;
}
ch = key[idx];
}

idx = idx_end;
return tok_id;
}

int Split(const std::basic_string<CharT>& input,
std::vector<std::pair<std::basic_string_view<CharT>, ValueT>>& tokens) const noexcept {
size_t seg_idx = 0;
size_t tok_idx = 0;

while (tok_idx < input.length()) {
// variable u is the tree root.
const TrieTree* u = this;
auto ch = input[tok_idx];
size_t tok_len = 0;
size_t idx_end = tok_idx;
ValueT tok_id = invalid_id_;

// try to match a longest token
while (u->to_.count(ch)) {
tok_len += 1;
u = u->to_.at(ch).get();
if (u->value_) {
tok_id = *u->value_;
idx_end = tok_idx + 1;
}

tok_idx += 1;
if (tok_idx == input.length()) {
break;
}
ch = input[tok_idx];
}

tok_idx += 1;
if (tok_id == invalid_id_) {
if (tok_idx < input.length()) {
continue;
} else {
tok_idx += 1; // Assign tok_idx to input.length()
idx_end = tok_idx;
}
}

auto token_begin_idx = tok_idx - tok_len - 1; // since the tok_idx already moved forward by 1
tok_len = idx_end - token_begin_idx;
if (token_begin_idx > seg_idx || tok_len == 0) {
tokens.emplace_back(std::basic_string_view<CharT>(input.data() + seg_idx, token_begin_idx - seg_idx),
invalid_id_);
}
if (tok_id != invalid_id_) {
tokens.emplace_back(std::basic_string_view<CharT>(input.data() + token_begin_idx, tok_len), tok_id);
tok_idx = idx_end;
}

// reset state for next match
seg_idx = tok_idx;
}

return 0;
}

private:
std::map<CharT, std::unique_ptr<TrieTree>> to_;
std::optional<ValueT> value_;
const CharT ch_;
const ValueT invalid_id_;
};

} // namespace ort_extensions

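The new TrieTree template above generalizes the RWKV-specific trie: FindLongest walks the children greedily and remembers the last node that carried a value, and Split reuses that walk to cut the input into plain-text segments and added-token hits. A Python sketch of the same longest-match idea, illustrating the algorithm only (not the C++ API), with a hypothetical added token:

class Trie:
    def __init__(self):
        self.children = {}
        self.value = None  # token id if a key ends at this node

    def add(self, key, value):
        node = self
        for ch in key:
            node = node.children.setdefault(ch, Trie())
        node.value = value

    def find_longest(self, text, start):
        # Return (token_id, end) for the longest key matching at `start`, or (None, start).
        node, best_id, best_end = self, None, start
        for i in range(start, len(text)):
            node = node.children.get(text[i])
            if node is None:
                break
            if node.value is not None:
                best_id, best_end = node.value, i + 1
        return best_id, best_end

trie = Trie()
trie.add("<|endoftext|>", 50256)  # hypothetical added token
print(trie.find_longest("a<|endoftext|>b", 1))  # (50256, 14)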
@@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest
import pkg_resources

import numpy as np
from transformers import AutoTokenizer, GPT2Tokenizer

@@ -76,7 +75,7 @@ class TestAutoTokenizer(unittest.TestCase):
pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0])
actual_ids = ort_tok([text])[0]
np.testing.assert_array_equal(ids, actual_ids)

def test_gpt2_tokenizer(self):
tokenizer = GPT2Tokenizer.from_pretrained("Xenova/gpt-4", use_fast=False)
text = "Testing words with apostrophes such as you're, i'm, don't, etc."

@@ -96,7 +95,7 @@ class TestAutoTokenizer(unittest.TestCase):
" add words that should not exist and be tokenized to , such as saoneuhaoesuth")
ids = tokenizer.encode(text, return_tensors="np")

ort_tok, _ = gen_processing_models(tokenizer,pre_kwargs={"WITH_DEFAULT_INPUTS": True})
ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={"WITH_DEFAULT_INPUTS": True})
actual_ids, *_ = ort_inference(ort_tok, [text])
np.testing.assert_array_equal(ids[0], actual_ids)

@@ -124,8 +123,7 @@ class TestAutoTokenizer(unittest.TestCase):
ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={})
actual_ids, *_ = ort_inference(ort_tok, [code])
self.assertEqual(len(ids['input_ids'].shape), len(actual_ids.shape))
# TODO: not matched.
# np.testing.assert_array_equal(ids['input_ids'], actual_ids)
np.testing.assert_array_equal(ids['input_ids'], actual_ids)

if __name__ == '__main__':