Add a bbpe tokenizer decoder for Whisper model (#376)

* initial PR

* add the attributes for op

* cmake update

* add the missing symbol

* add a unit test case

* fix the unit test

* fix some corner cases.

* format Python code with autopep8
Wenbing Li 2023-03-08 15:00:01 -08:00 committed by GitHub
Parent 6b88f4e31f
Commit 3b0bd66e9e
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 372 additions and 56 deletions

View file

@@ -362,16 +362,13 @@ if(OCOS_ENABLE_BLINGFIRE)
endif()
if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
if(NOT TARGET nlohmann_json)
set(JSON_BuildTests OFF CACHE INTERNAL "")
message(STATUS "Fetch json")
include(json)
endif()
message(STATUS "Fetch json")
include(json)
endif()
if(_HAS_TOKENIZER)
message(STATUS "Tokenizer needed.")
file(GLOB tokenizer_TARGET_SRC "operators/tokenizer/tokenizers.*")
file(GLOB tokenizer_TARGET_SRC "operators/tokenizer/tokenizers.*" "operators/tokenizer/*.hpp")
list(APPEND TARGET_SRC ${tokenizer_TARGET_SRC})
endif()
@@ -412,8 +409,7 @@ target_include_directories(ocos_operators PUBLIC
${ONNXRUNTIME_INCLUDE_DIR}
${PROJECT_SOURCE_DIR}/includes
${PROJECT_SOURCE_DIR}/base
${PROJECT_SOURCE_DIR}/operators
${PROJECT_SOURCE_DIR}/operators/tokenizer)
${PROJECT_SOURCE_DIR}/operators)
set(ocos_libraries)
set(OCOS_COMPILE_DEFINITIONS)
@@ -424,6 +420,8 @@ endif()
if(_HAS_TOKENIZER)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_TOKENIZER)
target_include_directories(ocos_operators PUBLIC
${PROJECT_SOURCE_DIR}/operators/tokenizer)
endif()
if(OCOS_ENABLE_TF_STRING)
@@ -491,7 +489,7 @@ if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
list(APPEND ocos_libraries nlohmann_json::nlohmann_json)
endif()
target_include_directories(ocos_operators PRIVATE ${GSL_INCLUDE_DIR})
target_include_directories(noexcep_operators PUBLIC ${GSL_INCLUDE_DIR})
list(APPEND ocos_libraries Microsoft.GSL::GSL)
list(REMOVE_DUPLICATES OCOS_COMPILE_DEFINITIONS)

View file

@@ -2,26 +2,7 @@
// Licensed under the MIT License.
#include <sstream>
#include "ocos.h"
bool BaseKernel::HasAttribute(const char* name) const noexcept {
size_t size;
std::string out;
// Crashes here.
OrtStatus* status = api_.KernelInfoGetAttribute_string(&info_, name, nullptr, &size);
auto r = api_.GetErrorCode(status);
bool has = (r == ORT_INVALID_ARGUMENT) || (r == ORT_OK);
if (has) {
api_.ReleaseStatus(status);
return has;
}
const char* error = api_.GetErrorMessage(status);
if (strstr(error, "No attribute") == error) {
api_.ReleaseStatus(status);
return false;
}
api_.ReleaseStatus(status);
return true;
}
#include "narrow.h"
OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept {
if (status == nullptr) {
@@ -72,6 +53,17 @@ bool BaseKernel::TryToGetAttribute(const char* name, float& value) const noexcept
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(&info_, name, &value)) == ORT_OK;
}
template <>
bool BaseKernel::TryToGetAttribute(const char* name, int& value) const noexcept {
int64_t origin_value = 0;
if (GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(&info_, name, &origin_value)) != ORT_OK) {
return false;
}
value = ort_extensions::narrow<int>(origin_value);
return true;
}
template <>
bool BaseKernel::TryToGetAttribute(const char* name, bool& value) const noexcept {
int64_t origin_value = 0;

View file

@@ -7,6 +7,9 @@ for /f "tokens=* USEBACKQ" %%i in (
IF NOT DEFINED VSINSTALLDIR GOTO :NOT_FOUND
IF "%1" == "-A" GOTO :VSDEV_CMD
set GEN_PLATFORM=-A x64
:VSDEV_CMD
set GENERATOR="Visual Studio 16 2019"
IF "%VisualStudioVersion:~0,2%" == "16" GOTO :START_BUILD
@@ -15,7 +18,7 @@ set GENERATOR="Visual Studio 17 2022"
:START_BUILD
set cmake_exe="%VSINSTALLDIR%Common7\IDE\CommonExtensions\Microsoft\CMake\CMake\bin\cmake.exe"
mkdir .\out\Windows\ 2>NUL
%cmake_exe% -G %GENERATOR% -A x64 %* -B out\Windows -S .
%cmake_exe% -G %GENERATOR% %GEN_PLATFORM% %* -B out\Windows -S .
IF %ERRORLEVEL% NEQ 0 EXIT /B %ERRORLEVEL%
%cmake_exe% --build out\Windows --config RelWithDebInfo
IF %ERRORLEVEL% NEQ 0 EXIT /B %ERRORLEVEL%

cmake/externals/json.cmake vendored
View file

@@ -2,6 +2,8 @@ FetchContent_Declare(json
GIT_REPOSITORY https://github.com/nlohmann/json.git
GIT_TAG v3.10.5)
set(JSON_BuildTests OFF CACHE INTERNAL "")
FetchContent_GetProperties(json)
if(NOT json_POPULATED)
FetchContent_Populate(json)

cmake/externals/opencv.cmake vendored
View file

@@ -111,6 +111,13 @@ if(IOS)
set(CPU_BASELINE DETECT)
endif()
if (MSVC AND CMAKE_GENERATOR_PLATFORM)
string(TOLOWER ${CMAKE_GENERATOR_PLATFORM} _GEN_PLATFORM)
if (${_GEN_PLATFORM} MATCHES "arm|arm64")
set(OPENCV_SKIP_SYSTEM_PROCESSOR_DETECTION ON)
endif()
endif()
FetchContent_Declare(
opencv
GIT_REPOSITORY https://github.com/opencv/opencv.git
@@ -150,4 +157,3 @@ if (CMAKE_SYSTEM_NAME MATCHES "Windows")
set_target_properties(${p} PROPERTIES FOLDER "externals/opencv")
endforeach()
endif()

View file

@@ -23,14 +23,12 @@ struct BaseKernel {
BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept : api_(api), info_(info), ort_(api_) {
}
bool HasAttribute(const char* name) const noexcept;
template <class T>
bool TryToGetAttribute(const char* name, T& value) const noexcept;
template <class T>
T TryToGetAttributeWithDefault(const char* name, T default_value) const noexcept {
T& result = default_value;
T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
T result = default_value;
TryToGetAttribute(name, result);
return result;
}

View file

@@ -55,6 +55,18 @@ class GPT2Tokenizer(CustomOp):
]
class BpeDecoder(CustomOp):
@classmethod
def get_inputs(cls):
return [
cls.io_def("ids", onnx.TensorProto.INT64, [None])
]
@classmethod
def get_outputs(cls):
return [cls.io_def('str', onnx_proto.TensorProto.STRING, [None])]
class VectorToString(CustomOp):
@classmethod
@@ -387,3 +399,7 @@ Opdef.create(_argsort_op,
op_type='ArgSort',
inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_int64],
outputs=[PyCustomOpDef.dt_int64])
class CustomOpConverter:
pass
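
For orientation, the graph that SingleOpGraph builds from this schema is a single node. A minimal standalone sketch, assuming the extensions' default custom-op domain ai.onnx.contrib and eliding the real attribute payloads (the '...' strings are placeholders, not working values):

from onnx import helper, TensorProto

# hypothetical reconstruction of the BpeDecoder single-op graph
node = helper.make_node(
    'BpeDecoder', ['ids'], ['str'], domain='ai.onnx.contrib',
    id_vocab='...', byte_decoder='...',  # real payloads come from a converter
    skip_special_tokens=1)
graph = helper.make_graph(
    [node], 'bpe_decoder',
    [helper.make_tensor_value_info('ids', TensorProto.INT64, [None])],
    [helper.make_tensor_value_info('str', TensorProto.STRING, [None])])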

View file

@@ -18,7 +18,9 @@ def get_opset_version_from_ort():
"1.9": 15,
"1.10": 15,
"1.11": 16,
"1.12": 17
"1.12": 17,
"1.13": 17,
"1.14": 18,
}
ort_ver_string = '.'.join(_ort.__version__.split('.')[0:2])
@@ -32,7 +34,8 @@ def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(),
) else onnx.helper.make_model
model = fn_mm(graph, opset_imports=[
onnx.helper.make_operatorsetid('ai.onnx', opset_version)])
model.opset_import.extend([onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)])
model.opset_import.extend(
[onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)])
return model
@@ -54,14 +57,23 @@ class OrtPyFunction:
self.default_inputs = {}
def create_from_customop(self, op_type, *args, **kwargs):
graph = SingleOpGraph.build_my_graph(op_type, *args, **kwargs)
cvt = kwargs.get('cvt', None)
if cvt is None:
cvt = args[0] if len(args) > 0 and isinstance(
args[0], CustomOpConverter) else None
args = args[1:] if cvt is not None else args
else:
del kwargs['cvt']
new_kwargs = kwargs if cvt is None else cvt(**kwargs)
graph = SingleOpGraph.build_my_graph(op_type, *args, **new_kwargs)
self._bind(make_onnx_model(graph))
return self
def add_default_input(self, **kwargs):
inputs = {
ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else \
np.asarray(list(val_), dtype=np.uint8) for ky_, val_ in kwargs.items()
ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else
np.asarray(list(val_), dtype=np.uint8) for ky_, val_ in kwargs.items()
}
self.default_inputs.update(inputs)
@@ -87,7 +99,8 @@ class OrtPyFunction:
def _ensure_ort_session(self):
if self.ort_session is None:
sess = _ort.InferenceSession(self.onnx_model.SerializeToString(), self.get_ort_session_options())
sess = _ort.InferenceSession(
self.onnx_model.SerializeToString(), self.get_ort_session_options())
self.ort_session = sess
return self.ort_session
@@ -113,7 +126,8 @@ class OrtPyFunction:
# an annoying mismatch: numpy defaults to int32 while pytorch uses int64,
# so cast the input here automatically.
feed[i_.name] = \
ts_x.astype(np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
ts_x.astype(
np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
idx += 1
# feed.update(kwargs)
@@ -121,7 +135,8 @@ class OrtPyFunction:
def __call__(self, *args, **kwargs):
self._ensure_ort_session()
outputs = self.ort_session.run(None, self._argument_map(*args, **kwargs))
outputs = self.ort_session.run(
None, self._argument_map(*args, **kwargs))
return outputs[0] if len(outputs) == 1 else tuple(outputs)
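
To make the new cvt plumbing concrete: a converter method can arrive either as the cvt keyword or as the first positional argument, and its return value replaces the user kwargs as node attributes. A minimal usage sketch, assuming HF transformers is installed (from_customop, the entry point the unit test below uses, forwards here):

from transformers import GPT2Tokenizer  # any HF byte-level BPE tokenizer should do
from onnxruntime_extensions import PyOrtFunction
from onnxruntime_extensions.cvt import HFTokenizerConverter

hf_tok = GPT2Tokenizer.from_pretrained("gpt2")
# bpe_decoder(**kwargs) turns the HF tokenizer state into BpeDecoder attributes
fn = PyOrtFunction.from_customop(
    "BpeDecoder", cvt=HFTokenizerConverter(hf_tok).bpe_decoder,
    skip_special_tokens=True)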

View file

@@ -0,0 +1,43 @@
import json
from ._cuops import CustomOpConverter
class HFTokenizerConverter(CustomOpConverter):
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def bpe_tokenizer(self, **kwargs):
hf_gpt2_tokenizer = self.tokenizer
attrs = {'vocab': json.dumps(
hf_gpt2_tokenizer.encoder, separators=(',', ':'))}
sorted_merges = {v_: k_ for k_,
v_ in hf_gpt2_tokenizer.bpe_ranks.items()}
attrs['merges'] = '\n'.join("{} {}".format(
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
attrs.update(**kwargs)
return attrs
def bpe_decoder(self, **kwargs):
decoder = self.tokenizer.decoder
id_vocab = "\n".join([decoder[_idx] for _idx in sorted(decoder)])
# with open("id_vocab.txt", "w", encoding="utf-8") as f:
# f.write(id_vocab)
byte_decoder = self.tokenizer.byte_decoder
str_byte_decoder = "\n".join(["{}\t{}".format(
ord(_c), str(byte_decoder[_c])) for _c in byte_decoder])
# with open("byte_decoder.txt", "w", encoding="utf-8") as f:
# f.write(str_byte_decoder)
all_special_ids = self.tokenizer.all_special_ids
added_tokens = self.tokenizer.added_tokens_decoder
str_all_special_ids = "\n".join([str(_id) for _id in all_special_ids])
str_added_tokens = "\n".join(
["{}\t{}".format(str(_id), added_tokens[_id]) for _id in added_tokens])
kwargs.update({
"id_vocab": id_vocab,
"byte_decoder": str_byte_decoder,
"added_tokens": str_added_tokens,
"all_special_ids": str_all_special_ids,
"skip_special_tokens": kwargs.get("skip_special_tokens", False)
})
return kwargs
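
A quick sketch of what bpe_decoder returns, with illustrative values (the real payloads depend on the tokenizer):

conv = HFTokenizerConverter(hf_tok)  # hf_tok: any HF byte-level BPE tokenizer
attrs = conv.bpe_decoder(skip_special_tokens=True)
sorted(attrs)
# ['added_tokens', 'all_special_ids', 'byte_decoder', 'id_vocab', 'skip_special_tokens']
# every payload is a flat string, so it can ride on an ONNX node attribute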

View file

@@ -81,7 +81,7 @@ int64_t CommonRaggedTensorToDense::GetMaxCol(int64_t n, const int64_t* p_indices
KernelRaggedTensorToDense::KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info)
: CommonRaggedTensorToDense(api, info) {
missing_value_ = HasAttribute("missing_value") ? ort_.KernelInfoGetAttribute<int64_t>(&info, "missing_value") : -1;
missing_value_ = TryToGetAttributeWithDefault("missing_value", -1);
}
void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {

View file

@@ -10,7 +10,7 @@
KernelStringRegexReplace::KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {
global_replace_ = HasAttribute("global_replace") ? ort_.KernelInfoGetAttribute<int64_t>(&info_, "global_replace") : 1;
global_replace_ = TryToGetAttributeWithDefault("global_replace", 1);
}
void KernelStringRegexReplace::Compute(OrtKernelContext* context) {

View file

@@ -24,9 +24,7 @@ KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api
model_ = std::shared_ptr<void>(model_ptr, FreeModel);
if (HasAttribute("max_sentence")) {
max_sentence = static_cast<int>(ort_.KernelInfoGetAttribute<int64_t>(&info, "max_sentence"));
}
max_sentence = TryToGetAttributeWithDefault("max_sentence", -1);
}
void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {

View file

@@ -0,0 +1,193 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
#include "ustring.h"
#include "narrow.h"
#include <string>
#include <vector>
#include <locale>
#include <codecvt>
#include <set>
#include <map>
#include <unordered_map>
struct KernelBpeDecoder : public BaseKernel {
public:
KernelBpeDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "id_vocab");
if (vocab.empty()) {
ORTX_CXX_API_THROW("[BPEDecoder]id vocab text cannot be empty.", ORT_INVALID_ARGUMENT);
}
BuildIdVocab(vocab);
std::string byte_decoder = ort_.KernelInfoGetAttribute<std::string>(&info, "byte_decoder");
if (byte_decoder.empty()) {
ORTX_CXX_API_THROW("[BPEDecoder]byte_decoder cannot be empty.", ORT_INVALID_ARGUMENT);
} else {
auto um = ParseId2String(byte_decoder);
std::transform(um.begin(), um.end(),
std::inserter(byte_decoder_, byte_decoder_.end()),
[](const auto& p) { return std::make_pair(static_cast<char32_t>(p.first),
ort_extensions::narrow<unsigned char>(std::stoul(p.second))); });
}
std::string added_tokens = TryToGetAttributeWithDefault<std::string>("added_tokens", "");
if (!added_tokens.empty()) {
auto um = ParseId2String(added_tokens);
added_tokens_ = std::map<int64_t, std::string>(um.begin(), um.end());
}
std::string all_special_ids = TryToGetAttributeWithDefault<std::string>("all_special_ids", "");
if (!all_special_ids.empty()) {
auto um = ParseId2String(all_special_ids);
std::transform(um.begin(), um.end(),
std::inserter(all_special_ids_, all_special_ids_.end()),
[](const auto& p) { return p.first; });
}
en_normalization_ = TryToGetAttributeWithDefault<int64_t>("en_normalization", 0);
skip_special_tokens_ = TryToGetAttributeWithDefault<int64_t>("skip_special_tokens", 0);
whitespace_token_ = TryToGetAttributeWithDefault<int64_t>("whitespace_token", 0);
bos_token_ = TryToGetAttributeWithDefault("bos_token", std::string("<|endoftext|>"));
eos_token_ = TryToGetAttributeWithDefault("eos_token", std::string("<|endoftext|>"));
unk_token_ = TryToGetAttributeWithDefault("unk_token", std::string("<|endoftext|>"));
}
std::unordered_map<int64_t, std::string> ParseId2String(const std::string& s_attr) {
std::unordered_map<int64_t, std::string> result;
result.reserve(s_attr.size() / 4);
std::stringstream ss(s_attr);
std::string line;
std::string token;
while (std::getline(ss, line, '\n')) {
size_t pos_end = 0;
int64_t v = std::stoll(line, &pos_end);
if (pos_end >= line.size() || line[pos_end] != '\t') {
token.clear();
} else {
token = line.substr(pos_end + 1);
}
result.emplace(v, token);
}
return result;
}
void BuildIdVocab(const std::string& vocab) {
arr_vocab_.reserve(vocab.size() / 2); // give a rough estimation.
std::u32string u_vocab = ustring(vocab);
std::u32string_view uv_vocab(u_vocab);
size_t last_pos = 0;
auto ccount = uv_vocab.size();
for (size_t n = 0; n < ccount; ++n) {
if (uv_vocab[n] == char32_t('\n')) {
std::u32string_view s_tok = uv_vocab.substr(last_pos, n - last_pos);
arr_vocab_.emplace_back(ustring(s_tok));
last_pos = n + 1;
} else if (n == ccount - 1) {
std::u32string_view s_tok = uv_vocab.substr(last_pos, n - last_pos + 1);
arr_vocab_.emplace_back(ustring(s_tok));
}
}
arr_vocab_.shrink_to_fit();
}
void Compute(OrtKernelContext* context) {
const OrtValue* ids = ort_.KernelContext_GetInput(context, 0);
const int64_t* p_ids = ort_.GetTensorData<int64_t>(ids);
OrtTensorDimensions ids_dim(ort_, ids);
if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
ORTX_CXX_API_THROW("[BpeDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
}
std::string text;
bool f_special_last = false;
bool f_special = false;
auto count = static_cast<size_t>(ids_dim.Size());
for (size_t tok_idx = 0; tok_idx < count; ++tok_idx) {
const auto token = *(p_ids + tok_idx);
std::string decoded_token;
f_special = all_special_ids_.count(token) ? true : false;
if (skip_special_tokens_ && f_special) {
f_special_last = f_special;
continue;
}
if (added_tokens_.count(token)) {
const std::string ws = added_tokens_.at(token);
decoded_token = (std::string)ws;
} else {
const auto str = arr_vocab_[token];
for (auto wchr : str) {
unsigned char uchr = byte_decoder_.at(wchr);
decoded_token.push_back(uchr);
}
}
if (whitespace_token_ &&
f_special && (tok_idx > 0 && !f_special_last)) {
text.push_back(' ');
}
text.append(decoded_token);
if (whitespace_token_ &&
f_special && tok_idx != count - 1) {
text.push_back(' ');
}
f_special_last = f_special;
}
std::vector<int64_t> output_dim = {1};
std::vector<std::string> result = {text};
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
FillTensorDataString(api_, ort_, context, result, output);
}
private:
std::string bos_token_;
std::string eos_token_;
std::string unk_token_;
// Since ORT API doesn't support boolean type in ONNX node attribute,
// all flag attributes here are defined as int64 type to be more explicit.
int64_t en_normalization_ = 0;
int64_t skip_special_tokens_ = 0;
int64_t whitespace_token_ = 0;
std::vector<ustring> arr_vocab_;
std::map<char32_t, unsigned char> byte_decoder_;
std::map<int64_t, std::string> added_tokens_;
std::set<int64_t> all_special_ids_;
};
struct CustomOpBpeDecoder : OrtW::CustomOpBase<CustomOpBpeDecoder, KernelBpeDecoder> {
const char* GetName() const {
return "BpeDecoder";
}
size_t GetInputTypeCount() const {
return 1;
}
ONNXTensorElementDataType GetInputType(size_t index) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}
size_t GetOutputTypeCount() const {
return 1;
}
ONNXTensorElementDataType GetOutputType(size_t index) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
}
};
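
For reference, the attribute wire formats that ParseId2String and BuildIdVocab above decode, sketched in Python with illustrative GPT-2-style values (not copied from a real vocabulary):

id_vocab = "!\n\"\n#"                    # line n holds the surface form of token id n
byte_decoder = "33\t33\n288\t32"         # '<codepoint>\t<byte>' per line; U+0120 'Ġ' -> 0x20
added_tokens = "50257\t<|startoftranscript|>"  # '<id>\t<token>' per line
all_special_ids = "50256\n50257"         # one id per line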

View file

@@ -2,7 +2,7 @@
// Licensed under the MIT License.
#pragma once
#include "bpetokenizer.hpp"
#include "bpe_tokenizer.hpp"
struct KernelClipBpeTokenizer : BaseKernel {
KernelClipBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);

View file

@@ -2,7 +2,7 @@
// Licensed under the MIT License.
#pragma once
#include "bpetokenizer.hpp"
#include "bpe_tokenizer.hpp"
struct KernelBpeTokenizer : BaseKernel {
KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);

View file

@@ -2,7 +2,7 @@
// Licensed under the MIT License.
#pragma once
#include "bpetokenizer.hpp"
#include "bpe_tokenizer.hpp"
struct KernelRobertaBpeTokenizer : BaseKernel {
KernelRobertaBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);

View file

@@ -31,11 +31,10 @@ struct KernelSentencepieceDecoder : BaseKernel {
ORTX_CXX_API_THROW("[SentencePieceDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
}
auto count = ids_dim[0];
std::string decoded_string;
std::vector<int64_t> output_dim = {1};
std::vector<int> tids;
std::transform(p_ids, p_ids + count,
std::transform(p_ids, p_ids + ids_dim.Size(),
std::back_inserter(tids),
[](auto _id) { return static_cast<int>(_id); });
auto status = tokenizer_.Decode(tids, &decoded_string);

View file

@@ -5,6 +5,7 @@
#include "gpt2_tokenizer.hpp"
#include "clip_tokenizer.hpp"
#include "roberta_tokenizer.hpp"
#include "bpe_decoder.hpp"
#endif
#ifdef ENABLE_SPM_TOKENIZER
@@ -33,6 +34,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer = LoadCustomOpClasses<
, CustomOpBpeTokenizer
, CustomOpClipBpeTokenizer
, CustomOpRobertaBpeTokenizer
, CustomOpBpeDecoder
#endif
#ifdef ENABLE_SPM_TOKENIZER

View file

@@ -11,9 +11,7 @@ KernelWordpieceTokenizer::KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo& info)
std::string vocab_as_string = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
std::string suffix_indicator = ort_.KernelInfoGetAttribute<std::string>(&info, "suffix_indicator");
std::string unk = ort_.KernelInfoGetAttribute<std::string>(&info, "unknown_token");
max_input_chars_per_word_ = HasAttribute("max_input_chars_per_word")
? ort_.KernelInfoGetAttribute<int64_t>(&info, "max_input_chars_per_word")
: 200;
max_input_chars_per_word_ = TryToGetAttributeWithDefault("max_input_chars_per_word", 200);
suffix_indicator_ = ustring(suffix_indicator);
unk_token_ = ustring(unk);

View file

@@ -0,0 +1,53 @@
import unittest
import numpy as np
from transformers import AutoProcessor
from onnxruntime_extensions import PyOrtFunction
from onnxruntime_extensions.cvt import HFTokenizerConverter
class TestBpeTokenizer(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.hf_processor = AutoProcessor.from_pretrained(
"openai/whisper-tiny.en")
cls.tokenizer_cvt = HFTokenizerConverter(cls.hf_processor.tokenizer)
return super().setUpClass()
def test_bpe_tokenizer(self):
fn_tokenizer = PyOrtFunction.from_customop(
"GPT2Tokenizer",
cvt=(self.tokenizer_cvt).bpe_tokenizer)
test_str = " Lennils, pictures are a sort of upguards and atom paintings, and Mason's exquisite idles"
test_ids = self.hf_processor.tokenizer.encode(test_str, add_special_tokens=False)
input_ids, _ = fn_tokenizer([test_str])
np.testing.assert_array_equal(input_ids[0], test_ids)
def test_en_decoder(self):
special_tokens = False
fn_decoder = PyOrtFunction.from_customop(
"BpeDecoder",
cvt=self.tokenizer_cvt.bpe_decoder,
skip_special_tokens=not special_tokens)
test_str = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
test_token_ids = self.hf_processor.tokenizer.encode(test_str)
expected_str = self.hf_processor.tokenizer.decode(
test_token_ids, skip_special_tokens=not special_tokens)
self.assertEqual(fn_decoder(np.asarray(test_token_ids))[0], expected_str)
def test_en_decoder_with_special(self):
special_tokens = True
fn_decoder = PyOrtFunction.from_customop(
"BpeDecoder",
cvt=self.tokenizer_cvt.bpe_decoder,
skip_special_tokens=not special_tokens)
test_str = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
test_token_ids = self.hf_processor.tokenizer.encode(test_str)
expected_str = self.hf_processor.tokenizer.decode(
test_token_ids, skip_special_tokens=not special_tokens)
actual_str = fn_decoder(np.asarray(test_token_ids))
self.assertEqual(actual_str[0], expected_str)
if __name__ == "__main__":
unittest.main()
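
Beyond the assertions above, the decoder graph can be persisted for a Whisper post-processing pipeline; a sketch assuming onnx_model (populated via _bind in pyfunc.py) holds the complete single-op model, with a hypothetical output file name:

import onnx

fn_decoder = PyOrtFunction.from_customop(
    "BpeDecoder",
    cvt=HFTokenizerConverter(
        AutoProcessor.from_pretrained("openai/whisper-tiny.en").tokenizer).bpe_decoder,
    skip_special_tokens=True)
# the saved model needs onnxruntime-extensions registered when it is loaded back
onnx.save(fn_decoder.onnx_model, "whisper_bpe_decoder.onnx")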