Add a bbpe tokenizer decoder for Whisper model (#376)
* initial PR * add the attributes for op * cmake update * add the missing symbol * add a unit test case * fix the unit test * fix some corner case. * format Python code with autopep8
This commit is contained in:
Родитель
6b88f4e31f
Коммит
3b0bd66e9e
|
@ -362,16 +362,13 @@ if(OCOS_ENABLE_BLINGFIRE)
|
|||
endif()
|
||||
|
||||
if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
|
||||
if(NOT TARGET nlohmann_json)
|
||||
set(JSON_BuildTests OFF CACHE INTERNAL "")
|
||||
message(STATUS "Fetch json")
|
||||
include(json)
|
||||
endif()
|
||||
message(STATUS "Fetch json")
|
||||
include(json)
|
||||
endif()
|
||||
|
||||
if(_HAS_TOKENIZER)
|
||||
message(STATUS "Tokenizer needed.")
|
||||
file(GLOB tokenizer_TARGET_SRC "operators/tokenizer/tokenizers.*")
|
||||
file(GLOB tokenizer_TARGET_SRC "operators/tokenizer/tokenizers.*" "operators/tokenizer/*.hpp")
|
||||
list(APPEND TARGET_SRC ${tokenizer_TARGET_SRC})
|
||||
endif()
|
||||
|
||||
|
@ -412,8 +409,7 @@ target_include_directories(ocos_operators PUBLIC
|
|||
${ONNXRUNTIME_INCLUDE_DIR}
|
||||
${PROJECT_SOURCE_DIR}/includes
|
||||
${PROJECT_SOURCE_DIR}/base
|
||||
${PROJECT_SOURCE_DIR}/operators
|
||||
${PROJECT_SOURCE_DIR}/operators/tokenizer)
|
||||
${PROJECT_SOURCE_DIR}/operators)
|
||||
|
||||
set(ocos_libraries)
|
||||
set(OCOS_COMPILE_DEFINITIONS)
|
||||
|
@ -424,6 +420,8 @@ endif()
|
|||
|
||||
if(_HAS_TOKENIZER)
|
||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_TOKENIZER)
|
||||
target_include_directories(ocos_operators PUBLIC
|
||||
${PROJECT_SOURCE_DIR}/operators/tokenizer)
|
||||
endif()
|
||||
|
||||
if(OCOS_ENABLE_TF_STRING)
|
||||
|
@ -491,7 +489,7 @@ if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
|
|||
list(APPEND ocos_libraries nlohmann_json::nlohmann_json)
|
||||
endif()
|
||||
|
||||
target_include_directories(ocos_operators PRIVATE ${GSL_INCLUDE_DIR})
|
||||
target_include_directories(noexcep_operators PUBLIC ${GSL_INCLUDE_DIR})
|
||||
list(APPEND ocos_libraries Microsoft.GSL::GSL)
|
||||
|
||||
list(REMOVE_DUPLICATES OCOS_COMPILE_DEFINITIONS)
|
||||
|
|
32
base/ocos.cc
32
base/ocos.cc
|
@ -2,26 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
#include <sstream>
|
||||
#include "ocos.h"
|
||||
|
||||
bool BaseKernel::HasAttribute(const char* name) const noexcept {
|
||||
size_t size;
|
||||
std::string out;
|
||||
// Crashes here.
|
||||
OrtStatus* status = api_.KernelInfoGetAttribute_string(&info_, name, nullptr, &size);
|
||||
auto r = api_.GetErrorCode(status);
|
||||
bool has = (r == ORT_INVALID_ARGUMENT) || (r == ORT_OK);
|
||||
if (has) {
|
||||
api_.ReleaseStatus(status);
|
||||
return has;
|
||||
}
|
||||
const char* error = api_.GetErrorMessage(status);
|
||||
if (strstr(error, "No attribute") == error) {
|
||||
api_.ReleaseStatus(status);
|
||||
return false;
|
||||
}
|
||||
api_.ReleaseStatus(status);
|
||||
return true;
|
||||
}
|
||||
#include "narrow.h"
|
||||
|
||||
OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept {
|
||||
if (status == nullptr) {
|
||||
|
@ -72,6 +53,17 @@ bool BaseKernel::TryToGetAttribute(const char* name, float& value) const noexcep
|
|||
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(&info_, name, &value)) == ORT_OK;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool BaseKernel::TryToGetAttribute(const char* name, int& value) const noexcept {
|
||||
int64_t origin_value = 0;
|
||||
if (GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(&info_, name, &origin_value)) != ORT_OK) {
|
||||
return false;
|
||||
}
|
||||
|
||||
value = ort_extensions::narrow<int>(origin_value);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool BaseKernel::TryToGetAttribute(const char* name, bool& value) const noexcept {
|
||||
int64_t origin_value = 0;
|
||||
|
|
|
@ -7,6 +7,9 @@ for /f "tokens=* USEBACKQ" %%i in (
|
|||
|
||||
IF NOT DEFINED VSINSTALLDIR GOTO :NOT_FOUND
|
||||
|
||||
IF "%1" == "-A" GOTO :VSDEV_CMD
|
||||
set GEN_PLATFORM=-A x64
|
||||
|
||||
:VSDEV_CMD
|
||||
set GENERATOR="Visual Studio 16 2019"
|
||||
IF "%VisualStudioVersion:~0,2%" == "16" GOTO :START_BUILD
|
||||
|
@ -15,7 +18,7 @@ set GENERATOR="Visual Studio 17 2022"
|
|||
:START_BUILD
|
||||
set cmake_exe="%VSINSTALLDIR%Common7\IDE\CommonExtensions\Microsoft\CMake\CMake\bin\cmake.exe"
|
||||
mkdir .\out\Windows\ 2>NUL
|
||||
%cmake_exe% -G %GENERATOR% -A x64 %* -B out\Windows -S .
|
||||
%cmake_exe% -G %GENERATOR% %GEN_PLATFORM% %* -B out\Windows -S .
|
||||
IF %ERRORLEVEL% NEQ 0 EXIT /B %ERRORLEVEL%
|
||||
%cmake_exe% --build out\Windows --config RelWithDebInfo
|
||||
IF %ERRORLEVEL% NEQ 0 EXIT /B %ERRORLEVEL%
|
||||
|
|
|
@ -2,6 +2,8 @@ FetchContent_Declare(json
|
|||
GIT_REPOSITORY https://github.com/nlohmann/json.git
|
||||
GIT_TAG v3.10.5)
|
||||
|
||||
set(JSON_BuildTests OFF CACHE INTERNAL "")
|
||||
|
||||
FetchContent_GetProperties(json)
|
||||
if(NOT json_POPULATED)
|
||||
FetchContent_Populate(json)
|
||||
|
|
|
@ -111,6 +111,13 @@ if(IOS)
|
|||
set(CPU_BASELINE DETECT)
|
||||
endif()
|
||||
|
||||
if (MSVC AND CMAKE_GENERATOR_PLATFORM)
|
||||
string(TOLOWER ${CMAKE_GENERATOR_PLATFORM} _GEN_PLATFORM)
|
||||
if (${_GEN_PLATFORM} MATCHES "arm|arm64")
|
||||
set(OPENCV_SKIP_SYSTEM_PROCESSOR_DETECTION ON)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
FetchContent_Declare(
|
||||
opencv
|
||||
GIT_REPOSITORY https://github.com/opencv/opencv.git
|
||||
|
@ -150,4 +157,3 @@ if (CMAKE_SYSTEM_NAME MATCHES "Windows")
|
|||
set_target_properties(${p} PROPERTIES FOLDER "externals/opencv")
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
|
|
|
@ -23,14 +23,12 @@ struct BaseKernel {
|
|||
BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept : api_(api), info_(info), ort_(api_) {
|
||||
}
|
||||
|
||||
bool HasAttribute(const char* name) const noexcept;
|
||||
|
||||
template <class T>
|
||||
bool TryToGetAttribute(const char* name, T& value) const noexcept;
|
||||
|
||||
template <class T>
|
||||
T TryToGetAttributeWithDefault(const char* name, T default_value) const noexcept {
|
||||
T& result = default_value;
|
||||
T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
|
||||
T result = default_value;
|
||||
TryToGetAttribute(name, result);
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -55,6 +55,18 @@ class GPT2Tokenizer(CustomOp):
|
|||
]
|
||||
|
||||
|
||||
class BpeDecoder(CustomOp):
|
||||
@classmethod
|
||||
def get_inputs(cls):
|
||||
return [
|
||||
cls.io_def("ids", onnx.TensorProto.INT64, [None])
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls):
|
||||
return [cls.io_def('str', onnx_proto.TensorProto.STRING, [None])]
|
||||
|
||||
|
||||
class VectorToString(CustomOp):
|
||||
|
||||
@classmethod
|
||||
|
@ -387,3 +399,7 @@ Opdef.create(_argsort_op,
|
|||
op_type='ArgSort',
|
||||
inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_int64],
|
||||
outputs=[PyCustomOpDef.dt_int64])
|
||||
|
||||
|
||||
class CustomOpConverter:
|
||||
pass
|
||||
|
|
|
@ -18,7 +18,9 @@ def get_opset_version_from_ort():
|
|||
"1.9": 15,
|
||||
"1.10": 15,
|
||||
"1.11": 16,
|
||||
"1.12": 17
|
||||
"1.12": 17,
|
||||
"1.13": 17,
|
||||
"1.14": 18,
|
||||
}
|
||||
|
||||
ort_ver_string = '.'.join(_ort.__version__.split('.')[0:2])
|
||||
|
@ -32,7 +34,8 @@ def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(),
|
|||
) else onnx.helper.make_model
|
||||
model = fn_mm(graph, opset_imports=[
|
||||
onnx.helper.make_operatorsetid('ai.onnx', opset_version)])
|
||||
model.opset_import.extend([onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)])
|
||||
model.opset_import.extend(
|
||||
[onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)])
|
||||
return model
|
||||
|
||||
|
||||
|
@ -54,14 +57,23 @@ class OrtPyFunction:
|
|||
self.default_inputs = {}
|
||||
|
||||
def create_from_customop(self, op_type, *args, **kwargs):
|
||||
graph = SingleOpGraph.build_my_graph(op_type, *args, **kwargs)
|
||||
cvt = kwargs.get('cvt', None)
|
||||
if cvt is None:
|
||||
cvt = args[0] if len(args) > 0 and isinstance(
|
||||
args[0], CustomOpConverter) else None
|
||||
args = args[1:]
|
||||
else:
|
||||
del kwargs['cvt']
|
||||
|
||||
new_kwargs = kwargs if cvt is None else cvt(**kwargs)
|
||||
graph = SingleOpGraph.build_my_graph(op_type, *args, **new_kwargs)
|
||||
self._bind(make_onnx_model(graph))
|
||||
return self
|
||||
|
||||
def add_default_input(self, **kwargs):
|
||||
inputs = {
|
||||
ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else \
|
||||
np.asarray(list(val_), dtype=np.uint8) for ky_, val_ in kwargs.items()
|
||||
ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else
|
||||
np.asarray(list(val_), dtype=np.uint8) for ky_, val_ in kwargs.items()
|
||||
}
|
||||
|
||||
self.default_inputs.update(inputs)
|
||||
|
@ -87,7 +99,8 @@ class OrtPyFunction:
|
|||
|
||||
def _ensure_ort_session(self):
|
||||
if self.ort_session is None:
|
||||
sess = _ort.InferenceSession(self.onnx_model.SerializeToString(), self.get_ort_session_options())
|
||||
sess = _ort.InferenceSession(
|
||||
self.onnx_model.SerializeToString(), self.get_ort_session_options())
|
||||
self.ort_session = sess
|
||||
|
||||
return self.ort_session
|
||||
|
@ -113,7 +126,8 @@ class OrtPyFunction:
|
|||
# an annoying bug is numpy by default is int32, while pytorch is int64.
|
||||
# so cast the input here automatically.
|
||||
feed[i_.name] = \
|
||||
ts_x.astype(np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
|
||||
ts_x.astype(
|
||||
np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
|
||||
idx += 1
|
||||
|
||||
# feed.update(kwargs)
|
||||
|
@ -121,7 +135,8 @@ class OrtPyFunction:
|
|||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self._ensure_ort_session()
|
||||
outputs = self.ort_session.run(None, self._argument_map(*args, **kwargs))
|
||||
outputs = self.ort_session.run(
|
||||
None, self._argument_map(*args, **kwargs))
|
||||
return outputs[0] if len(outputs) == 1 else tuple(outputs)
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
import json
|
||||
from ._cuops import CustomOpConverter
|
||||
|
||||
|
||||
class HFTokenizerConverter(CustomOpConverter):
|
||||
def __init__(self, tokenizer):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def bpe_tokenizer(self, **kwargs):
|
||||
hf_gpt2_tokenizer = self.tokenizer
|
||||
attrs = {'vocab': json.dumps(
|
||||
hf_gpt2_tokenizer.encoder, separators=(',', ':'))}
|
||||
sorted_merges = {v_: k_ for k_,
|
||||
v_ in hf_gpt2_tokenizer.bpe_ranks.items()}
|
||||
attrs['merges'] = '\n'.join("{} {}".format(
|
||||
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
|
||||
attrs.update(**kwargs)
|
||||
return attrs
|
||||
|
||||
def bpe_decoder(self, **kwargs):
|
||||
decoder = self.tokenizer.decoder
|
||||
id_vocab = "\n".join([decoder[_idx] for _idx in sorted(decoder)])
|
||||
# with open("id_vocab.txt", "w", encoding="utf-8") as f:
|
||||
# f.write(id_vocab)
|
||||
byte_decoder = self.tokenizer.byte_decoder
|
||||
str_byte_decoder = "\n".join(["{}\t{}".format(
|
||||
ord(_c), str(byte_decoder[_c])) for _c in byte_decoder])
|
||||
# with open("byte_decoder.txt", "w", encoding="utf-8") as f:
|
||||
# f.write(str_byte_decoder)
|
||||
all_special_ids = self.tokenizer.all_special_ids
|
||||
added_tokens = self.tokenizer.added_tokens_decoder
|
||||
str_all_special_ids = "\n".join([str(_id) for _id in all_special_ids])
|
||||
str_added_tokens = "\n".join(
|
||||
["{}\t{}".format(str(_id), added_tokens[_id]) for _id in added_tokens])
|
||||
kwargs.update({
|
||||
"id_vocab": id_vocab,
|
||||
"byte_decoder": str_byte_decoder,
|
||||
"added_tokens": str_added_tokens,
|
||||
"all_special_ids": str_all_special_ids,
|
||||
"skip_special_tokens": kwargs.get("skip_special_tokens", False)
|
||||
})
|
||||
|
||||
return kwargs
|
|
@ -81,7 +81,7 @@ int64_t CommonRaggedTensorToDense::GetMaxCol(int64_t n, const int64_t* p_indices
|
|||
|
||||
KernelRaggedTensorToDense::KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: CommonRaggedTensorToDense(api, info) {
|
||||
missing_value_ = HasAttribute("missing_value") ? ort_.KernelInfoGetAttribute<int64_t>(&info, "missing_value") : -1;
|
||||
missing_value_ = TryToGetAttributeWithDefault("missing_value", -1) ;
|
||||
}
|
||||
|
||||
void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
|
||||
KernelStringRegexReplace::KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info) {
|
||||
global_replace_ = HasAttribute("global_replace") ? ort_.KernelInfoGetAttribute<int64_t>(&info_, "global_replace") : 1;
|
||||
global_replace_ = TryToGetAttributeWithDefault("global_replace",1);
|
||||
}
|
||||
|
||||
void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
|
||||
|
|
|
@ -24,9 +24,7 @@ KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api
|
|||
|
||||
model_ = std::shared_ptr<void>(model_ptr, FreeModel);
|
||||
|
||||
if (HasAttribute("max_sentence")) {
|
||||
max_sentence = static_cast<int>(ort_.KernelInfoGetAttribute<int64_t>(&info, "max_sentence"));
|
||||
}
|
||||
max_sentence = TryToGetAttributeWithDefault("max_sentence", -1);
|
||||
}
|
||||
|
||||
void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
|
||||
|
|
|
@ -0,0 +1,193 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ocos.h"
|
||||
#include "ustring.h"
|
||||
#include "narrow.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <locale>
|
||||
#include <codecvt>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
|
||||
struct KernelBpeDecoder : public BaseKernel {
|
||||
public:
|
||||
KernelBpeDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "id_vocab");
|
||||
if (vocab.empty()) {
|
||||
ORTX_CXX_API_THROW("[BPEDecoder]id vocab text cannot be empty.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
BuildIdVocab(vocab);
|
||||
|
||||
std::string byte_decoder = ort_.KernelInfoGetAttribute<std::string>(&info, "byte_decoder");
|
||||
if (byte_decoder.empty()) {
|
||||
ORTX_CXX_API_THROW("[BPEDecoder]byte_decoder cannot be empty.", ORT_INVALID_ARGUMENT);
|
||||
} else {
|
||||
auto um = ParseId2String(byte_decoder);
|
||||
std::transform(um.begin(), um.end(),
|
||||
std::inserter(byte_decoder_, byte_decoder_.end()),
|
||||
[](const auto& p) { return std::make_pair(static_cast<char32_t>(p.first),
|
||||
ort_extensions::narrow<unsigned char>(std::stoul(p.second))); });
|
||||
}
|
||||
|
||||
std::string added_tokens = TryToGetAttributeWithDefault<std::string>("added_tokens", "");
|
||||
if (!added_tokens.empty()) {
|
||||
auto um = ParseId2String(added_tokens);
|
||||
added_tokens_ = std::map<int64_t, std::string>(um.begin(), um.end());
|
||||
}
|
||||
|
||||
std::string all_special_ids = TryToGetAttributeWithDefault<std::string>("all_special_ids", "");
|
||||
if (!all_special_ids.empty()) {
|
||||
auto um = ParseId2String(all_special_ids);
|
||||
std::transform(um.begin(), um.end(),
|
||||
std::inserter(all_special_ids_, all_special_ids_.end()),
|
||||
[](const auto& p) { return p.first; });
|
||||
}
|
||||
|
||||
en_normalization_ = TryToGetAttributeWithDefault<int64_t>("en_normalization", 0);
|
||||
skip_special_tokens_ = TryToGetAttributeWithDefault<int64_t>("skip_special_tokens", 0);
|
||||
whitespace_token_ = TryToGetAttributeWithDefault<int64_t>("whitespace_token", 0);
|
||||
bos_token_ = TryToGetAttributeWithDefault("bos_token", std::string("<|endoftext|>"));
|
||||
eos_token_ = TryToGetAttributeWithDefault("eos_token", std::string("<|endoftext|>"));
|
||||
unk_token_ = TryToGetAttributeWithDefault("unk_token", std::string("<|endoftext|>"));
|
||||
}
|
||||
|
||||
std::unordered_map<int64_t, std::string> ParseId2String(const std::string& s_attr) {
|
||||
std::unordered_map<int64_t, std::string> result;
|
||||
result.reserve(s_attr.size() / 4);
|
||||
std::stringstream ss(s_attr);
|
||||
|
||||
std::string line;
|
||||
std::string token;
|
||||
while (std::getline(ss, line, '\n')) {
|
||||
size_t pos_end = 0;
|
||||
int64_t v = std::stoll(line, &pos_end);
|
||||
if (pos_end >= line.size() || line[pos_end] != '\t') {
|
||||
token.clear();
|
||||
} else {
|
||||
token = line.substr(pos_end + 1);
|
||||
}
|
||||
result.emplace(v, token);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void BuildIdVocab(const std::string& vocab) {
|
||||
arr_vocab_.reserve(vocab.size() / 2); // give a rough estimation.
|
||||
|
||||
std::u32string u_vocab = ustring(vocab);
|
||||
std::u32string_view uv_vocab(u_vocab);
|
||||
size_t last_pos = 0;
|
||||
|
||||
auto ccount = uv_vocab.size();
|
||||
for (size_t n = 0; n < ccount; ++n) {
|
||||
if (uv_vocab[n] == char32_t('\n')) {
|
||||
std::u32string_view s_tok = uv_vocab.substr(last_pos, n - last_pos);
|
||||
arr_vocab_.emplace_back(ustring(s_tok));
|
||||
last_pos = n + 1;
|
||||
} else if (n == ccount - 1) {
|
||||
std::u32string_view s_tok = uv_vocab.substr(last_pos, n - last_pos + 1);
|
||||
arr_vocab_.emplace_back(ustring(s_tok));
|
||||
}
|
||||
}
|
||||
|
||||
arr_vocab_.shrink_to_fit();
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
const OrtValue* ids = ort_.KernelContext_GetInput(context, 0);
|
||||
const int64_t* p_ids = ort_.GetTensorData<int64_t>(ids);
|
||||
OrtTensorDimensions ids_dim(ort_, ids);
|
||||
|
||||
if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
|
||||
ORTX_CXX_API_THROW("[BpeDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
|
||||
}
|
||||
|
||||
std::string text;
|
||||
bool f_special_last = false;
|
||||
bool f_special = false;
|
||||
auto count = static_cast<size_t>(ids_dim.Size());
|
||||
|
||||
for (size_t tok_idx = 0; tok_idx < count; ++tok_idx) {
|
||||
const auto token = *(p_ids + tok_idx);
|
||||
std::string decoded_token;
|
||||
f_special = all_special_ids_.count(token) ? true : false;
|
||||
if (skip_special_tokens_ && f_special) {
|
||||
f_special_last = f_special;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (added_tokens_.count(token)) {
|
||||
const std::string ws = added_tokens_.at(token);
|
||||
decoded_token = (std::string)ws;
|
||||
} else {
|
||||
const auto str = arr_vocab_[token];
|
||||
for (auto wchr : str) {
|
||||
unsigned char uchr = byte_decoder_.at(wchr);
|
||||
decoded_token.push_back(uchr);
|
||||
}
|
||||
}
|
||||
|
||||
if (whitespace_token_ &&
|
||||
f_special && (tok_idx > 0 && !f_special_last)) {
|
||||
text.push_back(' ');
|
||||
}
|
||||
|
||||
text.append(decoded_token);
|
||||
|
||||
if (whitespace_token_ &&
|
||||
f_special && tok_idx != count - 1) {
|
||||
text.push_back(' ');
|
||||
}
|
||||
|
||||
f_special_last = f_special;
|
||||
}
|
||||
|
||||
std::vector<int64_t> output_dim = {1};
|
||||
std::vector<std::string> result = {text};
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
|
||||
FillTensorDataString(api_, ort_, context, result, output);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string bos_token_;
|
||||
std::string eos_token_;
|
||||
std::string unk_token_;
|
||||
|
||||
// Since ORT API doesn't support boolean type in ONNX node attribute,
|
||||
// all flag attributes here are defined as int64 type to be more explicit.
|
||||
int64_t en_normalization_ = 0;
|
||||
int64_t skip_special_tokens_ = 0;
|
||||
int64_t whitespace_token_ = 0;
|
||||
std::vector<ustring> arr_vocab_;
|
||||
std::map<char32_t, unsigned char> byte_decoder_;
|
||||
std::map<int64_t, std::string> added_tokens_;
|
||||
std::set<int64_t> all_special_ids_;
|
||||
};
|
||||
|
||||
struct CustomOpBpeDecoder : OrtW::CustomOpBase<CustomOpBpeDecoder, KernelBpeDecoder> {
|
||||
const char* GetName() const {
|
||||
return "BpeDecoder";
|
||||
}
|
||||
|
||||
size_t GetInputTypeCount() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
}
|
||||
|
||||
size_t GetOutputTypeCount() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
}
|
||||
};
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "bpetokenizer.hpp"
|
||||
#include "bpe_tokenizer.hpp"
|
||||
|
||||
struct KernelClipBpeTokenizer : BaseKernel {
|
||||
KernelClipBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "bpetokenizer.hpp"
|
||||
#include "bpe_tokenizer.hpp"
|
||||
|
||||
struct KernelBpeTokenizer : BaseKernel {
|
||||
KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "bpetokenizer.hpp"
|
||||
#include "bpe_tokenizer.hpp"
|
||||
|
||||
struct KernelRobertaBpeTokenizer : BaseKernel {
|
||||
KernelRobertaBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
|
||||
|
|
|
@ -31,11 +31,10 @@ struct KernelSentencepieceDecoder : BaseKernel {
|
|||
ORTX_CXX_API_THROW("[SentencePieceDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
|
||||
}
|
||||
|
||||
auto count = ids_dim[0];
|
||||
std::string decoded_string;
|
||||
std::vector<int64_t> output_dim = {1};
|
||||
std::vector<int> tids;
|
||||
std::transform(p_ids, p_ids + count,
|
||||
std::transform(p_ids, p_ids + ids_dim.Size(),
|
||||
std::back_inserter(tids),
|
||||
[](auto _id) { return static_cast<int>(_id); });
|
||||
auto status = tokenizer_.Decode(tids, &decoded_string);
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
#include "gpt2_tokenizer.hpp"
|
||||
#include "clip_tokenizer.hpp"
|
||||
#include "roberta_tokenizer.hpp"
|
||||
#include "bpe_decoder.hpp"
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_SPM_TOKENIZER
|
||||
|
@ -33,6 +34,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer = LoadCustomOpClasses<
|
|||
, CustomOpBpeTokenizer
|
||||
, CustomOpClipBpeTokenizer
|
||||
, CustomOpRobertaBpeTokenizer
|
||||
, CustomOpBpeDecoder
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_SPM_TOKENIZER
|
||||
|
|
|
@ -11,9 +11,7 @@ KernelWordpieceTokenizer::KernelWordpieceTokenizer(const OrtApi& api, const OrtK
|
|||
std::string vocab_as_string = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
|
||||
std::string suffix_indicator = ort_.KernelInfoGetAttribute<std::string>(&info, "suffix_indicator");
|
||||
std::string unk = ort_.KernelInfoGetAttribute<std::string>(&info, "unknown_token");
|
||||
max_input_chars_per_word_ = HasAttribute("max_input_chars_per_word")
|
||||
? ort_.KernelInfoGetAttribute<int64_t>(&info, "max_input_chars_per_word")
|
||||
: 200;
|
||||
max_input_chars_per_word_ = TryToGetAttributeWithDefault("max_input_chars_per_word", 200);
|
||||
suffix_indicator_ = ustring(suffix_indicator);
|
||||
unk_token_ = ustring(unk);
|
||||
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
|
||||
from transformers import AutoProcessor
|
||||
from onnxruntime_extensions import PyOrtFunction
|
||||
from onnxruntime_extensions.cvt import HFTokenizerConverter
|
||||
|
||||
|
||||
class TestBpeTokenizer(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.hf_processor = AutoProcessor.from_pretrained(
|
||||
"openai/whisper-tiny.en")
|
||||
cls.tokenizer_cvt = HFTokenizerConverter(cls.hf_processor.tokenizer)
|
||||
return super().setUpClass()
|
||||
|
||||
def test_bpe_tokenizer(self):
|
||||
fn_tokenizer = PyOrtFunction.from_customop(
|
||||
"GPT2Tokenizer",
|
||||
cvt=(self.tokenizer_cvt).bpe_tokenizer)
|
||||
test_str = " Lennils, pictures are a sort of upguards and atom paintings, and Mason's exquisite idles"
|
||||
test_ids = self.hf_processor.tokenizer.encode(test_str)
|
||||
self.assertTrue(fn_tokenizer(test_ids), test_str)
|
||||
|
||||
def test_en_decoder(self):
|
||||
special_tokens = False
|
||||
fn_decoder = PyOrtFunction.from_customop(
|
||||
"BpeDecoder",
|
||||
cvt=self.tokenizer_cvt.bpe_decoder,
|
||||
skip_special_tokens=not special_tokens)
|
||||
test_str = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
|
||||
test_token_ids = self.hf_processor.tokenizer.encode(test_str)
|
||||
expected_str = self.hf_processor.tokenizer.decode(
|
||||
test_token_ids, skip_special_tokens=not special_tokens)
|
||||
self.assertEqual(fn_decoder(np.asarray(test_token_ids)), expected_str)
|
||||
|
||||
def test_en_decoder_with_special(self):
|
||||
special_tokens = True
|
||||
fn_decoder = PyOrtFunction.from_customop(
|
||||
"BpeDecoder",
|
||||
cvt=self.tokenizer_cvt.bpe_decoder,
|
||||
skip_special_tokens=not special_tokens)
|
||||
test_str = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
|
||||
test_token_ids = self.hf_processor.tokenizer.encode(test_str)
|
||||
expected_str = self.hf_processor.tokenizer.decode(
|
||||
test_token_ids, skip_special_tokens=not special_tokens)
|
||||
actual_str = fn_decoder(np.asarray(test_token_ids))
|
||||
self.assertEqual(actual_str[0], expected_str)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Загрузка…
Ссылка в новой задаче