Fix the exception on invalid trie-tokenizer input (#575)

* fix the exception on invalid trie-tokenizer input

* remove unused import
This commit is contained in:
Wenbing Li 2023-10-16 17:03:02 -07:00 коммит произвёл GitHub
Родитель 46a37c3902
Коммит 68b9d1dc47
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 48 добавлений и 8 удалений

Просмотреть файл

@ -59,6 +59,28 @@ class ustring : public std::u32string {
return std::string(utf8_buf);
}
// Returns true when `data` is structurally well-formed UTF-8: every lead byte
// (110xxxxx, 1110xxxx, 11110xxx) is followed by exactly the number of
// continuation bytes (10xxxxxx) it announces, and no continuation byte appears
// on its own. An empty string is considered valid.
//
// This is a structural check only; it does not reject overlong encodings,
// UTF-16 surrogate code points, or code points above U+10FFFF.
static bool ValidateUTF8(const std::string& data) {
  int pending = 0;  // continuation bytes still expected for the current sequence
  for (const char ch : data) {
    // Plain `char` is signed on most ABIs; widening it directly to int would
    // make every byte >= 0x80 negative, so none of the bit-pattern tests below
    // could match and all non-ASCII text would be reported as invalid.
    const unsigned int x = static_cast<unsigned char>(ch);
    if (pending == 0) {
      if ((x >> 5) == 0b110) {
        pending = 1;
      } else if ((x >> 4) == 0b1110) {
        pending = 2;
      } else if ((x >> 3) == 0b11110) {
        pending = 3;
      } else if ((x >> 7) != 0) {
        // High bit set but not a recognised lead byte: a stray continuation
        // byte or one of 0xF8..0xFF.
        return false;
      }
    } else {
      if ((x >> 6) != 0b10) return false;
      --pending;
    }
  }
  // A sequence truncated at the end of the input leaves `pending` non-zero.
  return pending == 0;
}
private:
using u32string = std::u32string;
static u32string FromUTF8(const std::string_view& utf8) {

Просмотреть файл

@ -38,7 +38,8 @@ from ._ocos import hook_model_op
from ._ocos import default_opset_domain
from ._cuops import * # noqa
from ._ortapi2 import OrtPyFunction as PyOrtFunction # backward compatibility
from ._ortapi2 import OrtPyFunction, ort_inference, optimize_model, make_onnx_model, ONNXRuntimeError
from ._ortapi2 import OrtPyFunction, ort_inference, optimize_model, make_onnx_model
from ._ortapi2 import ONNXRuntimeError, ONNXRuntimeException
from .cvt import gen_processing_models
# rename the implementation with a more formal name

Просмотреть файл

@ -196,3 +196,4 @@ def optimize_model(model_or_file, output_file):
ONNXRuntimeError = _ort.capi.onnxruntime_pybind11_state.Fail
ONNXRuntimeException = _ort.capi.onnxruntime_pybind11_state.RuntimeException

Просмотреть файл

@ -200,14 +200,24 @@ struct KernelTrieDetokenizer : public BaseKernel {
}
std::vector<std::string> output(output_dim[0]);
bool failed = false;
for (auto n = 0; n < output_dim[0]; n++) {
std::vector<int> ids;
for (auto i = 0; i < ids_dim[1]; i++) {
ids.push_back(ort_extensions::narrow<int>(p_ids[n * ids_dim[1] + i]));
}
output[n] = tokenizer->decodeBytes(ids);
auto raw_string = tokenizer->decodeBytes(ids);
if (ustring::ValidateUTF8(raw_string)) {
output[n] = raw_string;
} else {
output[n] = "\ufffd"; // bad utf-8 string
failed = true;
}
}
text.SetStringOutput(output, output_dim);
if (failed) {
ORTX_CXX_API_THROW("[KernelTrieDetokenizer] the input ids cannot be parsed as a valid utf-8 string", ORT_RUNTIME_EXCEPTION);
}
}
};

Просмотреть файл

@ -80,12 +80,12 @@ class TestAutoTokenizer(unittest.TestCase):
actual_ids = ort_tok([text])[0]
np.testing.assert_array_equal(ids, actual_ids)
def test_xmlroberta_tokenizer(self):
def test_xlm_roberta_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base", use_fast=False)
# TODO: if there is <unk> in text, the result is not matched.
text = (
'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will'
" add words that should not exsist and be tokenized to , such as saoneuhaoesuth")
" add words that should not exist and be tokenized to , such as saoneuhaoesuth")
ids = tokenizer.encode(text, return_tensors="np")
ort_tok, _ = gen_processing_models(tokenizer,pre_kwargs={"WITH_DEFAULT_INPUTS": True})

Просмотреть файл

@ -16,11 +16,11 @@ from onnxruntime_extensions import (
get_library_path as _get_library_path)
_is_tensorflow_avaliable = False
_is_tensorflow_available = False
try:
import tensorflow as tf
from tensorflow_text import SentencepieceTokenizer
_is_tensorflow_avaliable = True
_is_tensorflow_available = True
except ImportError:
pass
@ -282,7 +282,7 @@ def _create_test_model_sentencepiece_fairseq(
return model
@unittest.skipIf(not _is_tensorflow_avaliable, "tensorflow/tensorflow-text is unavailable")
@unittest.skipIf(not _is_tensorflow_available, "tensorflow/tensorflow-text is unavailable")
class TestPythonOpSentencePiece(unittest.TestCase):
@classmethod

Просмотреть файл

@ -7,8 +7,9 @@ import os
import tempfile
import requests
import numpy as np
from unittest import TestCase, main as unittest_main
from onnxruntime_extensions import OrtPyFunction, util
from onnxruntime_extensions import OrtPyFunction, util, ONNXRuntimeException
# to avoid to install rwkv LM package, we copy the tokenizer code here.
@ -151,6 +152,11 @@ class TestTrieTokenizer(TestCase):
detok = OrtPyFunction.from_customop("TrieDetokenizer", vocab=vocab_data, cpu_only=True)
self.assertEqual(list(detok(tokens)), ["I love you"])
def test_invalid_utf8(self):
# Build a CPU-only TrieDetokenizer custom op from the vocabulary file, read in 'rb' mode.
vocab_data = util.read_file(self.vocab_file, 'rb')
detok = OrtPyFunction.from_customop("TrieDetokenizer", vocab=vocab_data, cpu_only=True)
# Detokenizing the single id 148 is expected to raise ONNXRuntimeException.
# NOTE(review): presumably id 148 decodes to bytes that are not valid UTF-8 on their own -- confirm against the vocab file.
self.assertRaises(ONNXRuntimeException, detok, np.array([[148]], np.int64))
def test_parity(self):
test_sentences = [
"I am a girl",