Fix the exception on invalid trie-tokenizer input (#575)
* fix the exception on invalid trie-tokenizer input
* remove unused import
Parent: 46a37c3902
Commit: 68b9d1dc47
@@ -59,6 +59,28 @@ class ustring : public std::u32string {
     return std::string(utf8_buf);
   }
 
+  static bool ValidateUTF8(const std::string& data) {
+    int cnt = 0;
+    for (size_t i = 0; i < data.size(); i++) {
+      int x = static_cast<unsigned char>(data[i]);  // avoid sign extension for bytes >= 0x80
+      if (!cnt) {
+        if ((x >> 5) == 0b110) {
+          cnt = 1;
+        } else if ((x >> 4) == 0b1110) {
+          cnt = 2;
+        } else if ((x >> 3) == 0b11110) {
+          cnt = 3;
+        } else if ((x >> 7) != 0) {
+          return false;
+        }
+      } else {
+        if ((x >> 6) != 0b10) return false;
+        cnt--;
+      }
+    }
+    return cnt == 0;
+  }
+
 private:
  using u32string = std::u32string;
  static u32string FromUTF8(const std::string_view& utf8) {
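Note: ValidateUTF8 is a structural check only. cnt counts the continuation bytes still owed by the last leading byte: a byte matching 0b110xxxxx owes one, 0b1110xxxx owes two, 0b11110xxx owes three, and any other byte with the high bit set outside a sequence is rejected. The cast to unsigned char matters because plain char may be signed, and sign extension would make every byte >= 0x80 fail these comparisons. For reference, the same state machine in Python (the function name is ours, not from the commit; iterating over bytes yields unsigned ints, so no cast is needed):

    def validate_utf8(data: bytes) -> bool:
        """Structural UTF-8 check mirroring ustring::ValidateUTF8."""
        cnt = 0  # continuation bytes still expected
        for x in data:
            if cnt == 0:
                if x >> 5 == 0b110:        # leading byte of a 2-byte sequence
                    cnt = 1
                elif x >> 4 == 0b1110:     # leading byte of a 3-byte sequence
                    cnt = 2
                elif x >> 3 == 0b11110:    # leading byte of a 4-byte sequence
                    cnt = 3
                elif x >> 7 != 0:          # stray continuation byte
                    return False
            else:
                if x >> 6 != 0b10:         # continuations must be 0b10xxxxxx
                    return False
                cnt -= 1
        return cnt == 0                    # reject a truncated trailing sequence

    assert validate_utf8("héllo".encode("utf-8"))
    assert not validate_utf8(b"\xc3")      # truncated 2-byte sequence

Like the C++ original, this accepts overlong encodings and surrogate ranges; it only verifies the leading/continuation byte structure.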
@@ -38,7 +38,8 @@ from ._ocos import hook_model_op
 from ._ocos import default_opset_domain
 from ._cuops import * # noqa
 from ._ortapi2 import OrtPyFunction as PyOrtFunction # backward compatibility
-from ._ortapi2 import OrtPyFunction, ort_inference, optimize_model, make_onnx_model, ONNXRuntimeError
+from ._ortapi2 import OrtPyFunction, ort_inference, optimize_model, make_onnx_model
+from ._ortapi2 import ONNXRuntimeError, ONNXRuntimeException
 from .cvt import gen_processing_models
 
 # rename the implementation with a more formal name
@@ -196,3 +196,4 @@ def optimize_model(model_or_file, output_file):
 
 
 ONNXRuntimeError = _ort.capi.onnxruntime_pybind11_state.Fail
+ONNXRuntimeException = _ort.capi.onnxruntime_pybind11_state.RuntimeException
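These module-level aliases expose onnxruntime's pybind exception types under stable names: Fail backs ONNXRuntimeError, and RuntimeException backs ONNXRuntimeException, which is how the ORT_RUNTIME_EXCEPTION thrown by the detokenizer kernel below surfaces in Python (see the new test at the end of this diff). A minimal sketch of catching the new exception from user code (the detok call is a placeholder for any OrtPyFunction custom op):

    from onnxruntime_extensions import ONNXRuntimeException

    try:
        strings = detok(token_ids)  # placeholder: some OrtPyFunction call
    except ONNXRuntimeException as e:
        print("custom op failed at runtime:", e)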
@@ -200,14 +200,24 @@ struct KernelTrieDetokenizer : public BaseKernel {
     }
 
     std::vector<std::string> output(output_dim[0]);
+    bool failed = false;
     for (auto n = 0; n < output_dim[0]; n++) {
       std::vector<int> ids;
       for (auto i = 0; i < ids_dim[1]; i++) {
         ids.push_back(ort_extensions::narrow<int>(p_ids[n * ids_dim[1] + i]));
       }
-      output[n] = tokenizer->decodeBytes(ids);
+      auto raw_string = tokenizer->decodeBytes(ids);
+      if (ustring::ValidateUTF8(raw_string)) {
+        output[n] = raw_string;
+      } else {
+        output[n] = "\ufffd";  // bad utf-8 string
+        failed = true;
+      }
     }
 
     text.SetStringOutput(output, output_dim);
+    if (failed) {
+      ORTX_CXX_API_THROW("[KernelTrieDetokenizer] the input ids cannot be parsed as a valid utf-8 string", ORT_RUNTIME_EXCEPTION);
+    }
   }
 };
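Design note: the kernel fills every output slot before reporting failure; rows that fail validation get U+FFFD, the Unicode replacement character, and the exception is thrown only after SetStringOutput has populated the tensor. To see why a raw byte string can be undecodable, here is the same failure in plain Python (illustrative, not from the commit):

    # A lone continuation byte (0b10xxxxxx) has no leading byte, so UTF-8
    # decoding fails; errors="replace" substitutes U+FFFD, just as the
    # kernel does for a bad row.
    raw = b"\x80"
    try:
        raw.decode("utf-8")
    except UnicodeDecodeError:
        print(raw.decode("utf-8", errors="replace"))  # prints '\ufffd'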
@@ -80,12 +80,12 @@ class TestAutoTokenizer(unittest.TestCase):
         actual_ids = ort_tok([text])[0]
         np.testing.assert_array_equal(ids, actual_ids)
 
-    def test_xmlroberta_tokenizer(self):
+    def test_xlm_roberta_tokenizer(self):
         tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base", use_fast=False)
         # TODO: if there is <unk> in text, the result is not matched.
         text = (
             'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will'
-            " add words that should not exsist and be tokenized to , such as saoneuhaoesuth")
+            " add words that should not exist and be tokenized to , such as saoneuhaoesuth")
         ids = tokenizer.encode(text, return_tensors="np")
 
         ort_tok, _ = gen_processing_models(tokenizer,pre_kwargs={"WITH_DEFAULT_INPUTS": True})
@@ -16,11 +16,11 @@ from onnxruntime_extensions import (
     get_library_path as _get_library_path)
 
 
-_is_tensorflow_avaliable = False
+_is_tensorflow_available = False
 try:
     import tensorflow as tf
     from tensorflow_text import SentencepieceTokenizer
-    _is_tensorflow_avaliable = True
+    _is_tensorflow_available = True
 except ImportError:
     pass
 
@@ -282,7 +282,7 @@ def _create_test_model_sentencepiece_fairseq(
     return model
 
 
-@unittest.skipIf(not _is_tensorflow_avaliable, "tensorflow/tensorflow-text is unavailable")
+@unittest.skipIf(not _is_tensorflow_available, "tensorflow/tensorflow-text is unavailable")
 class TestPythonOpSentencePiece(unittest.TestCase):
 
     @classmethod
@@ -7,8 +7,9 @@ import os
 import tempfile
 import requests
 
+import numpy as np
 from unittest import TestCase, main as unittest_main
-from onnxruntime_extensions import OrtPyFunction, util
+from onnxruntime_extensions import OrtPyFunction, util, ONNXRuntimeException
 
 
 # to avoid to install rwkv LM package, we copy the tokenizer code here.
@@ -151,6 +152,11 @@ class TestTrieTokenizer(TestCase):
         detok = OrtPyFunction.from_customop("TrieDetokenizer", vocab=vocab_data, cpu_only=True)
         self.assertEqual(list(detok(tokens)), ["I love you"])
 
+    def test_invalid_utf8(self):
+        vocab_data = util.read_file(self.vocab_file, 'rb')
+        detok = OrtPyFunction.from_customop("TrieDetokenizer", vocab=vocab_data, cpu_only=True)
+        self.assertRaises(ONNXRuntimeException, detok, np.array([[148]], np.int64))
+
     def test_parity(self):
         test_sentences = [
             "I am a girl",
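The new test drives the whole path: id 148 decodes to bytes that fail ValidateUTF8, the kernel throws ORT_RUNTIME_EXCEPTION, and the Python binding raises ONNXRuntimeException. An equivalent phrasing using assertRaises as a context manager, if preferred stylistically (not part of the commit):

    def test_invalid_utf8(self):
        vocab_data = util.read_file(self.vocab_file, 'rb')
        detok = OrtPyFunction.from_customop("TrieDetokenizer", vocab=vocab_data, cpu_only=True)
        with self.assertRaises(ONNXRuntimeException):
            detok(np.array([[148]], np.int64))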