Add TrieTokenizer for RWKV-like LLM models (#509)

* Add TrieTokenizer for RWKV-like LLM models

* add more tests

* fix the Windows build

* download the vocab file instead of checking it in

* a small bug fix
Wenbing Li 2023-08-08 16:47:38 -07:00 committed by GitHub
Parent c8bb9e8abd
Commit 978ada6d60
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
9 changed files: 666 additions and 30 deletions

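Once the extension is built with the new CMake option, the two ops can be exercised directly from Python. A minimal usage sketch based on the test added in this commit (OrtPyFunction and util come from onnxruntime_extensions; the vocab file is the RWKV "world" vocabulary that the test downloads):

# sketch: tokenize and detokenize with the new custom ops, assuming a build
# with OCOS_ENABLE_TRIE_TOKENIZER=ON and a local copy of the RWKV vocab file
from onnxruntime_extensions import OrtPyFunction, util

vocab_data = util.read_file("rwkv_vocab_v20230424.txt", 'rb')
tok = OrtPyFunction.from_customop("TrieTokenizer", vocab=vocab_data)
detok = OrtPyFunction.from_customop("TrieDetokenizer", vocab=vocab_data)

ids = tok(["I love you"])      # -> [[74, 31337, 22799]]
print(list(detok(ids)))        # -> ['I love you']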
CMakeLists.txt

@@ -48,6 +48,7 @@ option(OCOS_ENABLE_CPP_EXCEPTIONS "Enable C++ Exception" ON)
option(OCOS_ENABLE_TF_STRING "Enable String Operator Set" ON)
option(OCOS_ENABLE_RE2_REGEX "Enable StringRegexReplace and StringRegexSplit" ON)
option(OCOS_ENABLE_GPT2_TOKENIZER "Enable the GPT2 tokenizer building" ON)
option(OCOS_ENABLE_TRIE_TOKENIZER "Enable the TrieTokenizer building" ON)
option(OCOS_ENABLE_SPM_TOKENIZER "Enable the SentencePiece tokenizer building" ON)
option(OCOS_ENABLE_WORDPIECE_TOKENIZER "Enable the WordpieceTokenizer building" ON)
option(OCOS_ENABLE_BERT_TOKENIZER "Enable the BertTokenizer building" ON)
@@ -72,6 +73,7 @@ function(disable_all_operators)
set(OCOS_ENABLE_TF_STRING OFF CACHE INTERNAL "" FORCE)
set(OCOS_ENABLE_WORDPIECE_TOKENIZER OFF CACHE INTERNAL "" FORCE)
set(OCOS_ENABLE_GPT2_TOKENIZER OFF CACHE INTERNAL "" FORCE)
set(OCOS_ENABLE_TRIE_TOKENIZER OFF CACHE INTERNAL "" FORCE)
set(OCOS_ENABLE_SPM_TOKENIZER OFF CACHE INTERNAL "" FORCE)
set(OCOS_ENABLE_BERT_TOKENIZER OFF CACHE INTERNAL "" FORCE)
set(OCOS_ENABLE_BLINGFIRE OFF CACHE INTERNAL "" FORCE)
@@ -346,6 +348,13 @@ if(OCOS_ENABLE_GPT2_TOKENIZER)
list(APPEND TARGET_SRC ${tok_TARGET_SRC})
endif()
if(OCOS_ENABLE_TRIE_TOKENIZER)
# Trie Tokenizer
set(_HAS_TOKENIZER ON)
file(GLOB tok_TARGET_SRC "operators/tokenizer/trie_tokenizer.hpp" "operators/tokenizer/unescape.h")
list(APPEND TARGET_SRC ${tok_TARGET_SRC})
endif()
if(OCOS_ENABLE_SPM_TOKENIZER)
# SentencePiece
set(_HAS_TOKENIZER ON)
@@ -512,6 +521,10 @@ if(OCOS_ENABLE_GPT2_TOKENIZER)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_GPT2_TOKENIZER)
endif()
if(OCOS_ENABLE_TRIE_TOKENIZER)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_TRIE_TOKENIZER)
endif()
if(OCOS_ENABLE_WORDPIECE_TOKENIZER)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_WORDPIECE_TOKENIZER)
endif()
@@ -797,3 +810,30 @@ if(OCOS_ENABLE_CTEST)
add_test(NAME extensions_test COMMAND $<TARGET_FILE:extensions_test>)
endif()
endif()
if(OCOS_ENABLE_AZURE)
add_dependencies(ocos_operators triton)
target_include_directories(ocos_operators PUBLIC ${TRITON_BIN}/include ${TRITON_THIRD_PARTY}/curl/include)
target_link_directories(ocos_operators PUBLIC ${TRITON_BIN}/lib ${TRITON_BIN}/lib64 ${TRITON_THIRD_PARTY}/curl/lib ${TRITON_THIRD_PARTY}/curl/lib64)
if (ocos_target_platform STREQUAL "AMD64")
set(vcpkg_target_platform "x86")
else()
set(vcpkg_target_platform ${ocos_target_platform})
endif()
if (WIN32)
target_link_directories(ocos_operators PUBLIC ${VCPKG_SRC}/installed/${vcpkg_target_platform}-windows-static/lib)
target_link_libraries(ocos_operators PUBLIC libcurl httpclient_static ws2_32 crypt32 Wldap32)
else()
find_package(ZLIB REQUIRED)
find_package(OpenSSL REQUIRED)
target_link_libraries(ocos_operators PUBLIC httpclient_static curl ZLIB::ZLIB OpenSSL::Crypto OpenSSL::SSL)
endif() #if (WIN32)
endif()


@@ -29,11 +29,16 @@ target_compile_definitions(extensions_pydll PRIVATE
target_link_libraries(extensions_pydll PRIVATE Python3::Module ocos_operators)
if(NOT "${OCOS_EXTENTION_NAME}" STREQUAL "")
if(OCOS_PYTHON_MODULE_PATH)
get_filename_component(OCOS_PYTHON_MODULE_NAME ${OCOS_PYTHON_MODULE_PATH} NAME)
if(NOT WIN32)
set_target_properties(extensions_pydll PROPERTIES
LIBRARY_OUTPUT_NAME ${OCOS_EXTENTION_NAME}
LIBRARY_OUTPUT_NAME ${OCOS_PYTHON_MODULE_NAME}
PREFIX ""
SUFFIX "")
endif()
add_custom_command(TARGET extensions_pydll POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy $<TARGET_FILE:extensions_pydll> ${OCOS_PYTHON_MODULE_PATH}
COMMENT "Copying $<TARGET_FILE:extensions_pydll> to ${OCOS_PYTHON_MODULE_PATH}")
endif()


@@ -350,6 +350,26 @@ class SentencepieceDecoder(CustomOp):
        return [cls.io_def('str', onnx_proto.TensorProto.STRING, [None])]


class TrieTokenizer(CustomOp):
    @classmethod
    def get_inputs(cls):
        return [cls.io_def('str', onnx_proto.TensorProto.STRING, ['N'])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def("ids", onnx.TensorProto.INT64, ['N', None])]


class TrieDetokenizer(CustomOp):
    @classmethod
    def get_inputs(cls):
        return [cls.io_def("ids", onnx.TensorProto.INT64, ['N', None])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('str', onnx_proto.TensorProto.STRING, [None])]


class Inverse(CustomOp):
    @classmethod

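The two schema classes above only declare the op signatures (string batch in, padded int64 ids out, and the reverse). A hypothetical sketch of hand-building a one-node ONNX model against that schema; the ai.onnx.contrib domain and the opset numbers are assumptions, not taken from this diff:

# hypothetical: build a single TrieTokenizer node matching the declared schema
import onnx
from onnx import helper, TensorProto

node = helper.make_node(
    "TrieTokenizer", ["str"], ["ids"],
    domain="ai.onnx.contrib",                                   # assumed custom-op domain
    vocab=open("rwkv_vocab_v20230424.txt", "rb").read())        # string attribute read by the kernel

graph = helper.make_graph(
    [node], "trie_tok",
    [helper.make_tensor_value_info("str", TensorProto.STRING, ["N"])],
    [helper.make_tensor_value_info("ids", TensorProto.INT64, ["N", None])])
model = helper.make_model(graph, opset_imports=[
    helper.make_opsetid("", 14), helper.make_opsetid("ai.onnx.contrib", 1)])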

@@ -27,42 +27,44 @@
#include "bert_tokenizer_decoder.hpp"
#endif
#ifdef ENABLE_TRIE_TOKENIZER
#include "trie_tokenizer.hpp"
#endif

const std::vector<const OrtCustomOp*>& TokenizerLoader() {
FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer = []() -> CustomOpArray& {
static OrtOpLoader op_loader(
[]() { return nullptr; }
#ifdef ENABLE_GPT2_TOKENIZER
,
CustomCpuStruct("GPT2Tokenizer", KernelBpeTokenizer),
CustomCpuStruct("CLIPTokenizer", KernelClipBpeTokenizer),
CustomCpuStruct("RobertaTokenizer", KernelRobertaBpeTokenizer),
CustomCpuStruct("BpeDecoder", KernelBpeDecoder)
CustomCpuStruct("BpeDecoder", KernelBpeDecoder),
#endif
#ifdef ENABLE_SPM_TOKENIZER
,
CustomCpuStruct("SentencepieceTokenizer", KernelSentencepieceTokenizer),
CustomCpuStruct("SentencepieceDecoder", KernelSentencepieceDecoder)
CustomCpuStruct("SentencepieceDecoder", KernelSentencepieceDecoder),
#endif
#ifdef ENABLE_TRIE_TOKENIZER
CustomCpuStruct("TrieTokenizer", KernelTrieTokenizer),
CustomCpuStruct("TrieDetokenizer", KernelTrieDetokenizer),
#endif
#ifdef ENABLE_WORDPIECE_TOKENIZER
,
CustomCpuStruct("WordpieceTokenizer", KernelWordpieceTokenizer)
CustomCpuStruct("WordpieceTokenizer", KernelWordpieceTokenizer),
#endif
#ifdef ENABLE_BERT_TOKENIZER
,
CustomCpuStruct("BasicTokenizer", KernelBasicTokenizer),
CustomCpuStruct("BertTokenizer", KernelBertTokenizer),
CustomCpuStruct("BertTokenizerDecoder", KernelBertTokenizerDecoder),
CustomCpuStruct("HfBertTokenizer", KernelHfBertTokenizer)
CustomCpuStruct("HfBertTokenizer", KernelHfBertTokenizer),
#endif
#ifdef ENABLE_BLINGFIRE
,
CustomCpuStruct("BlingFireSentenceBreaker", KernelBlingFireSentenceBreaker)
CustomCpuStruct("BlingFireSentenceBreaker", KernelBlingFireSentenceBreaker),
#endif
);
return op_loader.GetCustomOps();
}
[]() { return nullptr; });
FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer = TokenizerLoader;
return op_loader.GetCustomOps();
};

operators/tokenizer/trie_tokenizer.hpp (new file)

@@ -0,0 +1,213 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "ocos.h"
#include "narrow.h"

#include <vector>
#include <set>
#include <map>
#include <string>
#include <memory>
#include <sstream>
#include <charconv>
#include <optional>

#include "unescape.h"

// This Trie Tree is a C++ implementation of
// https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/rwkv_tokenizer.py
// Perf optimized by leveraging C++ features, but the algorithm is the same.
class TrieTree {
 public:
  static constexpr int kMaxTokenLength_ = 128;

  TrieTree(unsigned char ch = 0) : ch_(ch), to_(256) {}

  void add(const std::string& key, int idx = 0,
           std::optional<int> value = std::optional<int>()) {
    if (idx == key.length()) {
      if (!value) {
        value = key[0];
      }
      value_ = value;
      return;
    }

    unsigned char ch = static_cast<unsigned char>(key[idx]);
    if (to_[ch] == nullptr) {
      to_[ch] = std::make_unique<TrieTree>(ch);
    }
    to_[ch]->add(key, idx + 1, value);
  }

  int find_longest(const std::string& key, size_t& idx) {
    const TrieTree* u = this;
    unsigned char ch = key[idx];

    int tok_id = 0;
    size_t idx_end = idx;
    while (u->to_[ch]) {
      u = u->to_[ch].get();
      idx += 1;
      if (u->value_) {
        tok_id = *u->value_;
        idx_end = idx;
      }
      if (idx == key.length()) {
        break;
      }
      ch = key[idx];
    }

    idx = idx_end;
    return tok_id;
  }

 private:
  std::vector<std::unique_ptr<TrieTree>> to_;
  std::optional<int> value_;
  unsigned char ch_;
};

class TrieTokenizer {
 private:
  std::map<int, std::string> idx2token;
  TrieTree root;

 public:
  TrieTokenizer(const std::string& text_tokens) {
    std::istringstream file(text_tokens);
    std::string line;

    while (std::getline(file, line)) {
      auto l_ws = line.find(' ');
      auto r_ws = line.rfind(' ');
      if (l_ws == std::string::npos || r_ws == std::string::npos || l_ws == r_ws) {
        ORTX_CXX_API_THROW(MakeString("[TrieTokenizer] vocab line: ", line), ORT_RUNTIME_EXCEPTION);
      }

      int idx = 0;
      std::from_chars(line.data(), line.data() + line.size(), idx);
      if (idx == 0) {
        ORTX_CXX_API_THROW(MakeString("[TrieTokenizer] bad index in vocab line: ", line), ORT_RUNTIME_EXCEPTION);
      }

      std::string raw = line.substr(line.find(' ') + 1, line.rfind(' ') - line.find(' ') - 1);
      std::string x;
      int key_length = 0;
      if (ort_extensions::UnquoteString(raw, x)) {
        std::from_chars(line.data() + r_ws + 1, line.data() + line.size(), key_length);
      }
      if (x.length() != key_length) {
        ORTX_CXX_API_THROW(MakeString("[TrieTokenizer] bad len in vocab line: ", line), ORT_RUNTIME_EXCEPTION);
      }

      idx2token[idx] = x;
    }

    for (const auto& kv : idx2token) {
      root.add(kv.second, 0, kv.first);
    }
  }

  std::vector<int> encodeBytes(const std::string& src) {
    size_t idx = 0;
    std::vector<int> tokens;
    while (idx < src.length()) {
      auto result = root.find_longest(src, idx);
      tokens.push_back(result);
    }

    return tokens;
  }

  std::string decodeBytes(const std::vector<int>& tokens) {
    std::string result;
    for (const auto& i : tokens) {
      result += idx2token[i];
    }
    return result;
  }
};
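encodeBytes above is a plain greedy longest-match over the trie: at each position it consumes the longest vocabulary entry that matches and emits its id. A toy illustration of that behaviour (the vocabulary and ids here are made up):

# toy illustration of greedy longest-match tokenization; vocab/ids are made up
vocab = {"a": 1, "ab": 2, "abc": 3, "b": 4, "c": 5}

def encode(text):
    ids, i = [], 0
    while i < len(text):
        # take the longest vocab entry matching at position i
        match = max((t for t in vocab if text.startswith(t, i)), key=len)
        ids.append(vocab[match])
        i += len(match)
    return ids

print(encode("abcab"))   # [3, 2]: "abc" then "ab"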
struct KernelTrieTokenizer : public BaseKernel {
 private:
  std::shared_ptr<TrieTokenizer> tokenizer;

 public:
  KernelTrieTokenizer(const OrtApi& api, const OrtKernelInfo& info)
      : BaseKernel(api, info) {
    std::string text_tokens = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
    tokenizer = std::make_shared<TrieTokenizer>(text_tokens);
  };

  void Compute(const ortc::Tensor<std::string>& input,
               ortc::Tensor<int64_t>& tokenize_output) const {
    std::vector<std::string> str_input{input.Data()};
    const auto& input_dim = input.Shape();

    size_t max_length = 0;
    std::vector<std::vector<int64_t>> tokenize_results;
    for (auto& str : str_input) {
      auto tokens = tokenizer->encodeBytes(str);
      std::vector<int64_t> tokens_int64(tokens.begin(), tokens.end());
      max_length = std::max(max_length, tokens_int64.size());
      tokenize_results.emplace_back(tokens_int64);
    }

    std::vector<int64_t> output_dim = input_dim;
    output_dim.push_back(max_length);
    auto* token = tokenize_output.Allocate(output_dim);

    int idx = 0;
    for (auto& res : tokenize_results) {
      for (int64_t id : res) {
        token[idx] = id;
        idx++;
      }
      for (size_t i = res.size(); i < max_length; i++) {
        token[idx] = 0;
        idx++;
      }
    }

    for (auto& result : tokenize_results) {
      result.resize(max_length, 0);
    }
  }
};
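Because every string in the batch can produce a different number of ids, the kernel pads each row with 0 up to the longest row, so the output is always a dense [N, max_length] tensor. A small sketch of that padding (the id values are placeholders):

# sketch of the padded [N, max_length] output produced by the tokenizer kernel
import numpy as np

rows = [[11, 12, 13], [21]]            # per-string token ids (placeholder values)
max_len = max(len(r) for r in rows)
padded = np.zeros((len(rows), max_len), dtype=np.int64)
for n, r in enumerate(rows):
    padded[n, :len(r)] = r
print(padded)                          # [[11 12 13] [21  0  0]]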
struct KernelTrieDetokenizer : public BaseKernel {
 private:
  std::shared_ptr<TrieTokenizer> tokenizer;

 public:
  KernelTrieDetokenizer(const OrtApi& api, const OrtKernelInfo& info)
      : BaseKernel(api, info) {
    std::string text_tokens = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
    tokenizer = std::make_shared<TrieTokenizer>(text_tokens);
  };

  void Compute(const ortc::Tensor<int64_t>& tokens, ortc::Tensor<std::string>& text) const {
    const int64_t* p_ids = tokens.Data();
    const auto& ids_dim = tokens.Shape();
    std::vector<int64_t> output_dim = {1};
    if (ids_dim.size() > 1) {
      output_dim.resize(ids_dim.size() - 1);
      std::copy(ids_dim.begin(), ids_dim.begin() + ids_dim.size() - 1, output_dim.begin());
    }

    std::vector<std::string> output(output_dim[0]);
    for (auto n = 0; n < output_dim[0]; n++) {
      std::vector<int> ids;
      for (auto i = 0; i < ids_dim[1]; i++) {
        ids.push_back(ort_extensions::narrow<int>(p_ids[n * ids_dim[1] + i]));
      }
      output[n] = tokenizer->decodeBytes(ids);
    }

    text.SetStringOutput(output, output_dim);
  }
};

operators/tokenizer/unescape.h (new file)

@@ -0,0 +1,190 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <string>
#include <string_view>
#include <vector>

namespace ort_extensions {

inline bool IsDigit(char c) { return c >= '0' && c <= '9'; }

inline bool IsHexDigit(char c) { return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); }

inline unsigned int hex_digit_to_int(char c) {
  unsigned int x = static_cast<unsigned char>(c);
  if (x > '9') {
    x += 9;
  }
  return x & 0xf;
}

inline bool IsSurrogate(char32_t c) {
  return c >= 0xD800 && c <= 0xDFFF;
}

size_t EncodeUTF8Char(char* buffer, char32_t utf8_char) {
  if (utf8_char <= 0x7F) {
    *buffer = static_cast<char>(utf8_char);
    return 1;
  } else if (utf8_char <= 0x7FF) {
    buffer[1] = static_cast<char>(0x80 | (utf8_char & 0x3F));
    utf8_char >>= 6;
    buffer[0] = static_cast<char>(0xC0 | utf8_char);
    return 2;
  } else if (utf8_char <= 0xFFFF) {
    buffer[2] = static_cast<char>(0x80 | (utf8_char & 0x3F));
    utf8_char >>= 6;
    buffer[1] = static_cast<char>(0x80 | (utf8_char & 0x3F));
    utf8_char >>= 6;
    buffer[0] = static_cast<char>(0xE0 | utf8_char);
    return 3;
  } else {
    buffer[3] = static_cast<char>(0x80 | (utf8_char & 0x3F));
    utf8_char >>= 6;
    buffer[2] = static_cast<char>(0x80 | (utf8_char & 0x3F));
    utf8_char >>= 6;
    buffer[1] = static_cast<char>(0x80 | (utf8_char & 0x3F));
    utf8_char >>= 6;
    buffer[0] = static_cast<char>(0xF0 | utf8_char);
    return 4;
  }
}

// Unescape a Python escaped string
inline bool Unescape(const std::string_view& source, std::string& unescaped, bool is_bytes) {
  // reserve enough space for the worst case, and final size will be calculated at the end.
  unescaped.resize(source.length());

  char* d = unescaped.data();
  const char* p = source.data();
  const char* end = p + source.size();
  const char* last_byte = end - 1;

  while (p == d && p < end && *p != '\\') p++, d++;

  while (p < end) {
    if (*p != '\\') {
      *d++ = *p++;
    } else {
      if (++p > last_byte) {
        return false;
      }
      switch (*p) {
        case 'n':
          *d++ = '\n';
          break;
        case 'r':
          *d++ = '\r';
          break;
        case 't':
          *d++ = '\t';
          break;
        case '\\':
          *d++ = '\\';
          break;
        case '\'':
          *d++ = '\'';
          break;
        case '"':
          *d++ = '\"';
          break;
        case 'x':
        case 'X': {
          if (p >= last_byte) {
            return false;
          } else if (!IsHexDigit(static_cast<unsigned char>(p[1]))) {
            return false;
          }
          unsigned int ch = 0;
          const char* hex_start = p;
          while (p < last_byte &&
                 IsHexDigit(static_cast<unsigned char>(p[1])))
            ch = (ch << 4) + hex_digit_to_int(*++p);
          if (ch > 0xFF && !is_bytes) {
            return false;
          }
          if (is_bytes) {
            *d++ = static_cast<char>(ch);
          } else {
            d += EncodeUTF8Char(d, static_cast<char32_t>(ch));
          }
          break;
        }
        case 'u': {
          char32_t rune = 0;
          const char* hex_start = p;
          if (p + 4 >= end) {
            return false;
          }
          for (int i = 0; i < 4; ++i) {
            if (IsHexDigit(static_cast<unsigned char>(p[1]))) {
              rune = (rune << 4) + hex_digit_to_int(*++p);
            } else {
              return false;
            }
          }
          if (IsSurrogate(rune)) {
            return false;
          }
          d += EncodeUTF8Char(d, rune);
          break;
        }
        case 'U': {
          char32_t rune = 0;
          const char* hex_start = p;
          if (p + 8 >= end) {
            return false;
          }
          for (int i = 0; i < 8; ++i) {
            if (IsHexDigit(static_cast<unsigned char>(p[1]))) {
              uint32_t newrune = (rune << 4) + hex_digit_to_int(*++p);
              if (newrune > 0x10FFFF) {
                return false;
              } else {
                rune = newrune;
              }
            } else {
              return false;
            }
          }
          if (IsSurrogate(rune)) {
            return false;
          }
          d += EncodeUTF8Char(d, rune);
          break;
        }
        default: {
          return false;
        }
      }
      p++;
    }
  }

  unescaped.resize(d - unescaped.data());
  return true;
}

inline bool UnquoteString(const std::string& str, std::string& unquoted) {
  bool is_bytes = false;
  int idx_0 = 0;
  if (str.front() == 'b') {
    is_bytes = true;
    idx_0 = 1;
  }
  std::string str_view(str.data() + idx_0, str.length() - idx_0);
  if (str_view.length() < 2) {
    return false;
  }

  if ((str_view.front() != '\"' && str_view.front() != '\'') || str_view.back() != str.back()) {
    return false;
  }

  // unescape the string
  return Unescape(std::string_view(str_view.data() + 1, str_view.length() - 2), unquoted, is_bytes);
}
} // namespace ort_extensions
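UnquoteString and Unescape together recover the raw token bytes from the Python-style quoting used in the vocab file, which is what the reference tokenizer in the test below gets from eval() on each line. A short illustration of the kind of lines they have to handle (these sample lines are made up, not taken from the real vocab file):

# illustration: parse made-up vocab-style lines the way the reference tokenizer does
samples = ["33 ' ' 1", "2 '\\x01' 1", "5000 b'\\xe4\\xbd\\xa0' 3"]
for line in samples:
    idx = int(line[:line.index(' ')])
    tok = eval(line[line.index(' '):line.rindex(' ')])     # what UnquoteString/Unescape emulate in C++
    tok = tok.encode("utf-8") if isinstance(tok, str) else tok
    assert len(tok) == int(line[line.rindex(' '):])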


@@ -6,3 +6,4 @@ protobuf < 4.0.0
onnxruntime >=1.12.0
transformers >=4.9.2
tensorflow_text >=2.5.0;python_version < '3.11'
requests >=2.26.0

setup.py

@@ -74,13 +74,13 @@ class BuildCMakeExt(_build_ext):
        project_dir = pathlib.Path().absolute()
        build_temp = pathlib.Path(self.build_temp)
        build_temp.mkdir(parents=True, exist_ok=True)
        ext_fullpath = pathlib.Path(self.get_ext_fullpath(extension.name))
        ext_fullpath = pathlib.Path(self.get_ext_fullpath(extension.name)).absolute()

        config = 'RelWithDebInfo' if self.debug else 'Release'
        cmake_args = [
            '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + str(ext_fullpath.parent.absolute()),
            '-DOCOS_BUILD_PYTHON=ON',
            '-DOCOS_EXTENTION_NAME=' + ext_fullpath.name,
            '-DOCOS_PYTHON_MODULE_PATH=' + str(ext_fullpath),
            '-DCMAKE_BUILD_TYPE=' + config
        ]
@@ -154,16 +154,6 @@
        if not self.dry_run:
            self.spawn([cmake_exe, '--build', str(build_temp)] + build_args)

        if sys.platform == "win32":
            config_dir = '.'
            if not (build_temp / 'build.ninja').exists():
                config_dir = config
            self.copy_file(build_temp / 'bin' / config_dir / 'extensions_pydll.dll', ext_fullpath,
                           link='hard' if self.debug else None)
        else:
            self.copy_file(build_temp / 'lib' / ext_fullpath.name, ext_fullpath,
                           link='sym' if self.debug else None)


class Build(_build):
    def initialize_options(self) -> None:
class Build(_build):
def initialize_options(self) -> None:

test/test_trie_tokenizer.py (new file)

@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###########################################################################
import os
import tempfile
import requests
from unittest import TestCase, main as unittest_main
from onnxruntime_extensions import OrtPyFunction, util
# to avoid installing the rwkv LM package, we copy the tokenizer code here.
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
class TRIE:
    __slots__ = tuple("ch,to,values,front".split(","))
    to: list
    values: set

    def __init__(self, front=None, ch=None):
        self.ch = ch
        self.to = [None for ch in range(256)]
        self.values = set()
        self.front = front

    def __repr__(self):
        fr = self
        ret = []
        while (fr != None):
            if (fr.ch != None):
                ret.append(fr.ch)
            fr = fr.front
        return "<TRIE %s %s>" % (ret[::-1], self.values)

    def add(self, key: bytes, idx: int = 0, val=None):
        if (idx == len(key)):
            if (val is None):
                val = key
            self.values.add(val)
            return self
        ch = key[idx]
        if (self.to[ch] is None):
            self.to[ch] = TRIE(front=self, ch=ch)
        return self.to[ch].add(key, idx=idx + 1, val=val)

    def find_longest(self, key: bytes, idx: int = 0):
        u: TRIE = self
        ch: int = key[idx]
        while (u.to[ch] is not None):
            u = u.to[ch]
            idx += 1
            if (u.values):
                ret = idx, u, u.values
            if (idx == len(key)):
                break
            ch = key[idx]
        return ret


class TRIE_TOKENIZER():
    def __init__(self, file_name):
        self.idx2token = {}
        sorted = []  # must be already sorted
        with open(file_name, "r", encoding="utf-8") as f:
            lines = f.readlines()
        for l in lines:
            idx = int(l[:l.index(' ')])
            x = eval(l[l.index(' '):l.rindex(' ')])
            x = x.encode("utf-8") if isinstance(x, str) else x
            assert isinstance(x, bytes)
            assert len(x) == int(l[l.rindex(' '):])
            sorted += [x]
            self.idx2token[idx] = x

        self.token2idx = {}
        for k, v in self.idx2token.items():
            self.token2idx[v] = int(k)

        self.root = TRIE()
        for t, i in self.token2idx.items():
            _ = self.root.add(t, val=(t, i))

    def encodeBytes(self, src: bytes):
        idx: int = 0
        tokens = []
        while (idx < len(src)):
            _idx: int = idx
            idx, _, values = self.root.find_longest(src, idx)
            assert (idx != _idx)
            _, token = next(iter(values))
            tokens.append(token)
        return tokens

    def decodeBytes(self, tokens):
        return b''.join(map(lambda i: self.idx2token[i], tokens))

    def encode(self, src):
        return self.encodeBytes(src.encode("utf-8"))

    def decode(self, tokens):
        try:
            return self.decodeBytes(tokens).decode('utf-8')
        except:
            return '\ufffd'  # bad utf-8

    def printTokens(self, tokens):
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode('utf-8')
            except:
                pass
            print(f'{repr(s)}{i}', end=' ')
        print()


########################################################################################################

class TestTrieTokenizer(TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        url = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/tokenizer/rwkv_vocab_v20230424.txt"
        # Create a temporary directory and file path
        temp_dir = tempfile.mkdtemp()
        file_name = os.path.basename(url)  # Gets the file name from the URL
        cls.vocab_file = os.path.join(temp_dir, file_name)
        response = requests.get(url)
        with open(cls.vocab_file, "wb") as f:
            f.write(response.content)

    def test_trie_tokenizer(self):
        tokr = TRIE_TOKENIZER(self.vocab_file)
        src = "I love you"
        tokens = tokr.encode(src)
        self.assertEqual(tokens, [74, 31337, 22799])
        self.assertEqual(tokr.decode(tokens), src)

    def test_ort_trie_tokenizer(self):
        vocab_data = util.read_file(self.vocab_file, 'rb')
        tokr = OrtPyFunction.from_customop("TrieTokenizer", vocab=vocab_data)
        tokens = tokr(["I love you"])
        self.assertEqual(list(tokens[0]), [74, 31337, 22799])

        detok = OrtPyFunction.from_customop("TrieDetokenizer", vocab=vocab_data)
        self.assertEqual(list(detok(tokens)), ["I love you"])

    def test_parity(self):
        test_sentences = [
            "I am a girl",
            "我是个女孩",
            "私は女の子です",
            "广东人爱吃云吞面,还有腌面、竹升面,车仔面、油渣面、普宁面线、伊面等各种圆扁粗细,加碱水,不加碱水的面",
            "我是个人类",
            "I am a human",
            "that dog is so cute",
            "私はねこむすめです、にゃん♪",
            "宇宙级特大事件!号外号外!"
        ]
        tokr = TRIE_TOKENIZER(self.vocab_file)
        ortx_tokr = OrtPyFunction.from_customop("TrieTokenizer", vocab=util.read_file(self.vocab_file, 'rb'))
        for s in test_sentences:
            self.assertEqual(tokr.encode(s), list(ortx_tokr([s])[0]))


if __name__ == "__main__":
    unittest_main()