Make the unit tests compatible with the ONNXRuntime-GPU package, and some clean-ups. (#457)
Parent: 70411fdd96
Commit: 93f239c143
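The recurring change in the hunks below pins each test's InferenceSession to the CPU execution provider, so the same tests run unchanged when the onnxruntime-gpu wheel is installed. A minimal sketch of that pattern, using the public onnxruntime and onnxruntime-extensions APIs (the helper name is illustrative, not part of this commit):

    import onnxruntime as _ort
    from onnxruntime_extensions import get_library_path as _get_library_path

    def _make_cpu_session(model):
        so = _ort.SessionOptions()
        # Load the custom-ops shared library so the extension operators resolve.
        so.register_custom_ops_library(_get_library_path())
        # Explicitly request the CPU provider; with the GPU package installed,
        # omitting `providers` would otherwise select (or warn about) CUDA.
        return _ort.InferenceSession(model.SerializeToString(), so,
                                     providers=["CPUExecutionProvider"])
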
@@ -1,35 +0,0 @@
-import inspect
-from ._ocos import default_opset_domain
-from . import _cuops
-
-
-ALL_CUSTOM_OPS = {_name: _obj for _name, _obj in inspect.getmembers(_cuops)
-                  if (inspect.isclass(_obj) and issubclass(_obj, _cuops.CustomOp))}
-
-
-OPMAP_TO_CMAKE_FLAGS = {'GPT2Tokenizer': 'OCOS_ENABLE_GPT2_TOKENIZER',
-                        'BlingFireSentenceBreaker': 'OCOS_ENABLE_BLINGFIRE'
-                        }
-
-
-def gen_cmake_oplist(opconfig_file, oplist_cmake_file = '_selectedoplist.cmake'):
-
-    ext_domain = default_opset_domain()
-    with open(oplist_cmake_file, 'w') as f:
-        print("# Auto-Generated File, not edited!!!", file=f)
-        with open(opconfig_file, 'r') as opfile:
-            for _ln in opfile:
-                if _ln.startswith(ext_domain):
-                    items = _ln.strip().split(';')
-                    if len(items) < 3:
-                        raise RuntimeError("The malformated operator config file.")
-                    for _op in items[2].split(','):
-                        if not _op:
-                            continue  # is None or ""
-                        if _op not in OPMAP_TO_CMAKE_FLAGS:
-                            raise RuntimeError("Cannot find the custom operator({})\'s build flags, "
-                                               + "Please update the OPMAP_TO_CMAKE_FLAGS dictionary.".format(_op))
-                        print("set({} ON CACHE INTERNAL \"\")".format(OPMAP_TO_CMAKE_FLAGS[_op]), file=f)
-        print("# End of Building the Operator CMake variables", file=f)
-
-    print('The cmake tool file has been generated successfully.')

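For context, the helper deleted above turned an operator-config line into CMake cache flags. A hypothetical invocation as it worked before this commit (the file names and the version field are assumptions, not taken from the diff):

    # A config line of the form "<domain>;<version>;<op1>,<op2>" in the
    # "ai.onnx.contrib" domain, e.g.
    #     ai.onnx.contrib;1;GPT2Tokenizer
    # was translated into
    #     set(OCOS_ENABLE_GPT2_TOKENIZER ON CACHE INTERNAL "")
    from onnxruntime_extensions import cmake_helper
    cmake_helper.gen_cmake_oplist("my.op.config", "_selectedoplist.cmake")
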
@@ -12,6 +12,7 @@ from onnxruntime_extensions import (
     PyOrtFunction)
 from onnxruntime_extensions.cvt import HFTokenizerConverter
 
+
 def _get_file_content(path):
     with open(path, "rb") as file:
         return file.read()

@@ -34,7 +35,8 @@ def _create_test_model(**kwargs):
     if kwargs["attention_mask"]:
         if kwargs["offset_map"]:
             node = [helper.make_node(
-                'CLIPTokenizer', ['string_input'], ['input_ids', 'attention_mask', 'offset_mapping'], vocab=_get_file_content(vocab_file),
+                'CLIPTokenizer', ['string_input'],
+                ['input_ids', 'attention_mask', 'offset_mapping'], vocab=_get_file_content(vocab_file),
                 merges=_get_file_content(merges_file), name='bpetok', padding_length=max_length,
                 domain='ai.onnx.contrib')]
 

@@ -73,10 +75,11 @@ class TestCLIPTokenizer(unittest.TestCase):
         cls.tokenizer_cvt = HFTokenizerConverter(cls.slow_tokenizer)
 
     def _run_tokenizer(self, test_sentence, padding_length=-1):
-        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=padding_length, attention_mask=True, offset_map=True)
+        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges,
+                                   max_length=padding_length, attention_mask=True, offset_map=True)
         so = _ort.SessionOptions()
         so.register_custom_ops_library(_get_library_path())
-        sess = _ort.InferenceSession(model.SerializeToString(), so)
+        sess = _ort.InferenceSession(model.SerializeToString(), so, providers=["CPUExecutionProvider"])
         input_text = np.array(test_sentence)
         input_ids, attention_mask, offset_mapping = sess.run(None, {'string_input': input_text})
         print("\nTest Sentence: " + str(test_sentence))

@@ -111,7 +114,9 @@ class TestCLIPTokenizer(unittest.TestCase):
         self._run_tokenizer(["One Microsoft Way, Redmond, WA"])
 
     def test_converter(self):
-        fn_tokenizer = PyOrtFunction.from_customop("CLIPTokenizer", cvt=(self.tokenizer_cvt).clip_tokenizer)
+        fn_tokenizer = PyOrtFunction.from_customop("CLIPTokenizer",
+                                                   cvt=(self.tokenizer_cvt).clip_tokenizer,
+                                                   cpu_only=True)
         test_str = "I can feel the magic, can you?"
         fn_out = fn_tokenizer([test_str])
         clip_out = self.tokenizer(test_str, return_offsets_mapping=True)

@@ -120,16 +125,20 @@ class TestCLIPTokenizer(unittest.TestCase):
         expect_offset_mapping = clip_out['offset_mapping']
         np.testing.assert_array_equal(fn_out[0].reshape((fn_out[0].size,)), expect_input_ids)
         np.testing.assert_array_equal(fn_out[1].reshape((fn_out[1].size,)), expect_attention_mask)
-        np.testing.assert_array_equal(fn_out[2].reshape((fn_out[2].shape[1], fn_out[2].shape[2])), expect_offset_mapping)
+        np.testing.assert_array_equal(fn_out[2].reshape((fn_out[2].shape[1], fn_out[2].shape[2])),
+                                      expect_offset_mapping)
 
     def test_optional_outputs(self):
-        # Test for models without offset mapping and without both attention mask and offset mapping (input id output is always required)
-        model1 = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=-1, attention_mask=True, offset_map=False)
-        model2 = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=-1, attention_mask=False, offset_map=False)
+        # Test for models without offset mapping and without both attention mask and offset mapping
+        # (input id output is always required)
+        model1 = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges,
+                                    max_length=-1, attention_mask=True, offset_map=False)
+        model2 = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges,
+                                    max_length=-1, attention_mask=False, offset_map=False)
         so = _ort.SessionOptions()
         so.register_custom_ops_library(_get_library_path())
-        sess1 = _ort.InferenceSession(model1.SerializeToString(), so)
-        sess2 = _ort.InferenceSession(model2.SerializeToString(), so)
+        sess1 = _ort.InferenceSession(model1.SerializeToString(), so, providers=["CPUExecutionProvider"])
+        sess2 = _ort.InferenceSession(model2.SerializeToString(), so, providers=["CPUExecutionProvider"])
         input_text = np.array(["Hello World"])
         outputs1 = sess1.run(None, {'string_input': input_text})
         outputs2 = sess2.run(None, {'string_input': input_text})

@@ -142,10 +151,9 @@ class TestCLIPTokenizer(unittest.TestCase):
         clip_out = self.tokenizer(["Hello World"], return_offsets_mapping=True)
         expect_input_ids = clip_out['input_ids']
         expect_attention_mask = clip_out['attention_mask']
         expect_offset_mapping = clip_out['offset_mapping']
         np.testing.assert_array_equal(expect_input_ids, outputs1[0])
         np.testing.assert_array_equal(expect_attention_mask, outputs1[1])
         np.testing.assert_array_equal(expect_input_ids, outputs2[0])
-
 
 
 if __name__ == "__main__":
     unittest.main()

@@ -1,29 +0,0 @@
-import os
-import unittest
-from pathlib import Path
-from onnxruntime_extensions import cmake_helper
-
-
-def _get_test_data_file(*sub_dirs):
-    test_dir = Path(__file__).parent
-    return str(test_dir.joinpath(*sub_dirs))
-
-
-class TestCMakeHelper(unittest.TestCase):
-    def test_cmake_file_gen(self):
-        cfgfile = _get_test_data_file('data', 'test.op.config')
-        cfile = '_selectedoplist.cmake'
-        cmake_helper.gen_cmake_oplist(cfgfile, cfile)
-        found = False
-        with open(cfile, 'r') as f:
-            for _ln in f:
-                if _ln.strip() == "set(OCOS_ENABLE_GPT2_TOKENIZER ON CACHE INTERNAL \"\")":
-                    found = True
-                    break
-
-        os.remove(cfile)
-        self.assertTrue(found)
-
-
-if __name__ == "__main__":
-    unittest.main()

@@ -18,7 +18,7 @@ class TestOpenCV(unittest.TestCase):
         try:
             rdr = OrtPyFunction.from_customop("ImageReader")
             img_nhwc = rdr([img_file])
-        except ONNXRuntimeError as e:
+        except ONNXRuntimeError:
             pass
 
         if img_nhwc is not None:

@@ -59,9 +59,9 @@ class TestOpenCV(unittest.TestCase):
         expected = np.asarray(expected, dtype=np.uint8).copy()
 
         # Convert the image to BGR format since cv2 is default BGR format.
-        red = expected[:,:,0].copy()
-        expected[:,:,0] = expected[:,:,2].copy()
-        expected[:,:,2] = red
+        red = expected[:, :, 0].copy()
+        expected[:, :, 0] = expected[:, :, 2].copy()
+        expected[:, :, 2] = red
 
         self.assertEqual(actual.shape[0], expected.shape[0])
         self.assertEqual(actual.shape[1], expected.shape[1])

@@ -90,10 +90,11 @@ class TestGPT2Tokenizer(unittest.TestCase):
         return super().tearDown()
 
     def _run_tokenizer(self, test_sentence, padding_length=-1):
-        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=padding_length, attention_mask=True)
+        model = _create_test_model(vocab_file=self.tokjson,
+                                   merges_file=self.merges, max_length=padding_length, attention_mask=True)
         so = _ort.SessionOptions()
         so.register_custom_ops_library(_get_library_path())
-        sess = _ort.InferenceSession(model.SerializeToString(), so)
+        sess = _ort.InferenceSession(model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input_text = np.array(test_sentence)
         input_ids, attention_mask = sess.run(None, {'string_input': input_text})
         expect_input_ids, expect_attention_mask = self.tokenizer.tokenizer_sentence(test_sentence, padding_length)

@@ -118,10 +119,11 @@ class TestGPT2Tokenizer(unittest.TestCase):
         enable_py_op(False)
 
         # Test for model without attention mask (input id output is always required)
-        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=-1, attention_mask=False)
+        model = _create_test_model(vocab_file=self.tokjson,
+                                   merges_file=self.merges, max_length=-1, attention_mask=False)
         so = _ort.SessionOptions()
         so.register_custom_ops_library(_get_library_path())
-        sess = _ort.InferenceSession(model.SerializeToString(), so)
+        sess = _ort.InferenceSession(model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input_text = np.array(["Hello World"])
         outputs = sess.run(None, {'string_input': input_text})
 

@@ -133,7 +135,6 @@ class TestGPT2Tokenizer(unittest.TestCase):
         expect_input_ids = gpt2_out[0]
         np.testing.assert_array_equal(expect_input_ids, outputs[0])
 
-
     def test_tokenizer_pyop(self):
         self._run_tokenizer(["I can feel the magic, can you?"])
         self._run_tokenizer(["Hey Cortana"])

@@ -52,7 +52,7 @@ class TestMathOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_segment_sum("")
         self.assertIn('op_type: "SegmentSum"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         data = np.array([[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]], dtype=np.float32)
         segment_ids = np.array([0, 0, 1], dtype=np.int64)
         exp = np.array([[5, 5, 5, 5], [5, 6, 7, 8]], dtype=np.float32)

@@ -65,7 +65,7 @@ class TestMathOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_segment_sum("Py")
         self.assertIn('op_type: "PySegmentSum"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         data = np.array([[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]], dtype=np.float32)
         segment_ids = np.array([0, 0, 1], dtype=np.int64)
         exp = np.array([[5, 5, 5, 5], [5, 6, 7, 8]], dtype=np.float32)

@@ -154,7 +154,7 @@ class TestPythonOp(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model()
         self.assertIn('op_type: "PyReverseMatrix"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input_1 = np.array(
             [1, 2, 3, 4, 5, 6]).astype(np.float32).reshape([3, 2])
         txout = sess.run(None, {'input_1': input_1})

@@ -165,7 +165,7 @@ class TestPythonOp(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_double('Py')
         self.assertIn('op_type: "PyAddEpsilon"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input_1 = np.array([[0., 1., 1.5], [7., 8., -5.5]])
         txout = sess.run(None, {'input_1': input_1})
         diff = txout[0] - input_1 - 1e-3

@@ -176,7 +176,7 @@ class TestPythonOp(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_2outputs('Py')
         self.assertIn('op_type: "PyNegPos"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         x = np.array([[0., 1., 1.5], [7., 8., -5.5]]).astype(np.float32)
         neg, pos = sess.run(None, {'x': x})
         diff = x - (neg + pos)

@@ -187,7 +187,7 @@ class TestPythonOp(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_2outputs("")
         self.assertIn('op_type: "NegPos"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         x = np.array([[0., 1., 1.5], [7., 8., -5.5]]).astype(np.float32)
         neg, pos = sess.run(None, {'x': x})
         diff = x - (neg + pos)

@@ -210,7 +210,7 @@ class TestPythonOp(unittest.TestCase):
         onnx_content = _create_test_model_test()
         self.assertIn('op_type: "CustomOpOne"', str(onnx_content))
         ser = onnx_content.SerializeToString()
-        sess0 = _ort.InferenceSession(ser, so)
+        sess0 = _ort.InferenceSession(ser, so, providers=['CPUExecutionProvider'])
         res = sess0.run(None, {
             'input_1': np.random.rand(3, 5).astype(np.float32),
             'input_2': np.random.rand(3, 5).astype(np.float32)})

@@ -221,7 +221,7 @@ class TestPythonOp(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_join()
         self.assertIn('op_type: "PyOpJoin"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         arr = np.array([["a", "b"]], dtype=object)
         txout = sess.run(None, {'input_1': arr})
         exp = np.array(["a;b"], dtype=object)

@@ -73,10 +73,11 @@ class TestRobertaTokenizer(unittest.TestCase):
         cls.tokenizer_cvt = HFTokenizerConverter(cls.slow_tokenizer)
 
     def _run_tokenizer(self, test_sentence, padding_length=-1):
-        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=padding_length, attention_mask=True, offset_map=True)
+        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges,
+                                   max_length=padding_length, attention_mask=True, offset_map=True)
         so = _ort.SessionOptions()
         so.register_custom_ops_library(_get_library_path())
-        sess = _ort.InferenceSession(model.SerializeToString(), so)
+        sess = _ort.InferenceSession(model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input_text = np.array(test_sentence)
         input_ids, attention_mask, offset_mapping = sess.run(None, {'string_input': input_text})
         print("\nTest Sentence: " + str(test_sentence))

@@ -128,8 +129,8 @@ class TestRobertaTokenizer(unittest.TestCase):
         model2 = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=-1, attention_mask=False, offset_map=False)
         so = _ort.SessionOptions()
         so.register_custom_ops_library(_get_library_path())
-        sess1 = _ort.InferenceSession(model1.SerializeToString(), so)
-        sess2 = _ort.InferenceSession(model2.SerializeToString(), so)
+        sess1 = _ort.InferenceSession(model1.SerializeToString(), so, providers=['CPUExecutionProvider'])
+        sess2 = _ort.InferenceSession(model2.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input_text = np.array(["Hello World"])
         outputs1 = sess1.run(None, {'string_input': input_text})
         outputs2 = sess2.run(None, {'string_input': input_text})

@@ -142,10 +143,9 @@ class TestRobertaTokenizer(unittest.TestCase):
         roberta_out = self.tokenizer(["Hello World"], return_offsets_mapping=True)
         expect_input_ids = roberta_out['input_ids']
         expect_attention_mask = roberta_out['attention_mask']
         expect_offset_mapping = roberta_out['offset_mapping']
         np.testing.assert_array_equal(expect_input_ids, outputs1[0])
         np.testing.assert_array_equal(expect_attention_mask, outputs1[1])
         np.testing.assert_array_equal(expect_input_ids, outputs2[0])
-
 
 
 if __name__ == "__main__":
     unittest.main()

@@ -29,7 +29,7 @@ def _run_string_concat(input1, input2):
 
     so = _ort.SessionOptions()
     so.register_custom_ops_library(_get_library_path())
-    sess = _ort.InferenceSession(model.SerializeToString(), so)
+    sess = _ort.InferenceSession(model.SerializeToString(), so, providers=['CPUExecutionProvider'])
     result = sess.run(None, {'input_1': input1, 'input_2': input2})
 
     # verify

@@ -83,7 +83,7 @@ class TestStringECMARegex(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_replace("")
         self.assertIn('op_type: "StringECMARegexReplace"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         pattern = np.array([r"def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):"])
         rewrite = np.array([r"static PyObject* py_$1(void) {"])
         text = np.array([["def myfunc():"], ["def dummy():"]])

@@ -99,7 +99,7 @@ class TestStringECMARegex(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_replace("", global_replace=False)
         self.assertIn('op_type: "StringECMARegexReplace"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         pattern = np.array([r"def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):"])
         rewrite = np.array([r"static PyObject* py_$1(void) {"])
         text = np.array([["def myfunc():def myfunc():"], ["def dummy():def dummy():"]])

@@ -115,7 +115,7 @@ class TestStringECMARegex(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_replace("")
         self.assertIn('op_type: "StringECMARegexReplace"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         pattern = np.array([r"def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):"])
         rewrite = np.array([r"static PyObject* py_$1(void) {"])
         text = np.array([["def myfunc():"], ["def dummy():" * 2]])

@@ -132,7 +132,7 @@ class TestStringECMARegex(unittest.TestCase):
         onnx_model = _create_test_model_string_replace(
             "", "ai.onnx.contrib", True, True
         )
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
 
         pattern = np.array(
             [

@@ -157,7 +157,7 @@ class TestStringECMARegex(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_regex_split("")
         self.assertIn('op_type: "StringECMARegexSplitWithOffsets"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input = np.array(["hello there", "hello there"])
         pattern = np.array(["(\\s)"])
 

@@ -26,7 +26,7 @@ def _run_string_length(input):
 
     so = _ort.SessionOptions()
     so.register_custom_ops_library(_get_library_path())
-    sess = _ort.InferenceSession(model.SerializeToString(), so)
+    sess = _ort.InferenceSession(model.SerializeToString(), so, providers=['CPUExecutionProvider'])
     result = sess.run(None, {'input_1': input})
 
     # verify

@@ -441,7 +441,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_upper('')
         self.assertIn('op_type: "StringUpper"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input_1 = np.array([["Abc"]])
         txout = sess.run(None, {'input_1': input_1})
         self.assertEqual(txout[0].tolist(), np.array([["ABC"]]).tolist())

@@ -451,7 +451,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_lower('')
         self.assertIn('op_type: "StringLower"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input_1 = np.array([["Abc"]])
         txout = sess.run(None, {'input_1': input_1})
         self.assertEqual(txout[0].tolist(), np.array([["abc"]]).tolist())

@@ -461,7 +461,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_upper('')
         self.assertIn('op_type: "StringUpper"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input_1 = np.array([["R"], ["Abcé"], ["ABC"], ["A"]])
         txout = sess.run(None, {'input_1': input_1})
         self.assertEqual(

@@ -473,7 +473,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_lower('')
         self.assertIn('op_type: "StringLower"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input_1 = np.array([["R"], ["Abce"], ["ABC"], ["A"]])
         txout = sess.run(None, {'input_1': input_1})
         self.assertEqual(

@@ -497,7 +497,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_upper('Py')
         self.assertIn('op_type: "PyStringUpper"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input_1 = np.array([["Abc"]])
         txout = sess.run(None, {'input_1': input_1})
         self.assertEqual(txout[0].tolist(), np.array([["ABC"]]).tolist())

@@ -507,7 +507,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_lower('Py')
         self.assertIn('op_type: "PyStringLower"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input_1 = np.array([["Abc"]])
         txout = sess.run(None, {'input_1': input_1})
         self.assertEqual(txout[0].tolist(), np.array([["abc"]]).tolist())

@@ -517,7 +517,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_upper('Py')
         self.assertIn('op_type: "PyStringUpper"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input_1 = np.array([["Abcé"]])
         txout = sess.run(None, {'input_1': input_1})
         self.assertEqual(txout[0].tolist(),

@@ -528,7 +528,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_lower('Py')
         self.assertIn('op_type: "PyStringLower"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input_1 = np.array([["Abcé"]])
         txout = sess.run(None, {'input_1': input_1})
         self.assertEqual(txout[0].tolist(),

@@ -539,7 +539,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_join('Py')
         self.assertIn('op_type: "PyStringJoin"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         text = np.vstack([np.array([["a", "b", "c"]]),
                           np.array([["aa", "bb", ""]])])
         self.assertEqual(text.shape, (2, 3))

@@ -560,7 +560,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_join('Py')
         self.assertIn('op_type: "PyStringJoin"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         text = np.vstack([np.array([["a", "b", "c"]]),
                           np.array([["aa", "bb", ""]])]).reshape((2, 3, 1))
         sep = np.array([";"])

@@ -575,7 +575,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_join('Py')
         self.assertIn('op_type: "PyStringJoin"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         text = np.array(["a", "b", "cc"])
         sep = np.array([";"])
         axis = np.array([0], dtype=np.int64)

@@ -589,7 +589,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_join('')
         self.assertIn('op_type: "StringJoin"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         text = np.vstack([np.array([["a", "b", "c"]]),
                           np.array([["aa", "bb", ""]])])
         sep = np.array([";"])

@@ -607,7 +607,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_join('')
         self.assertIn('op_type: "StringJoin"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         text = np.array(["a", "b", "cc"])
         sep = np.array([";"])
         axis = np.array([0], dtype=np.int64)

@@ -620,7 +620,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_join('')
         self.assertIn('op_type: "StringJoin"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         text = np.array([""])
         sep = np.array([" "])
         axis = np.array([0], dtype=np.int64)

@@ -633,7 +633,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_join('')
         self.assertIn('op_type: "StringJoin"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         text = np.array("a scalar string")
         sep = np.array([" "])
         axis = np.array([0], dtype=np.int64)

@@ -646,7 +646,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_join('')
         self.assertIn('op_type: "StringJoin"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         text = np.array(["a", "b", "c", "d", "e", "f", "g", "h"]).reshape((
             2, 2, 2))
         sep = np.array([";"])

@@ -671,7 +671,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_replace('')
         self.assertIn('op_type: "StringRegexReplace"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
         rewrite = np.array([r'static PyObject* py_\1(void) {'])
         text = np.array([['def myfunc():'], ['def dummy():']])

@@ -687,7 +687,7 @@ class TestPythonOpString(unittest.TestCase):
         onnx_model = _create_test_model_string_replace(
             '', global_replace=False)
         self.assertIn('op_type: "StringRegexReplace"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
         rewrite = np.array([r'static PyObject* py_\1(void) {'])
         text = np.array([['def myfunc():def myfunc():'],

@@ -703,7 +703,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_replace('')
         self.assertIn('op_type: "StringRegexReplace"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
         rewrite = np.array([r'static PyObject* py_\1(void) {'])
         text = np.array([['def myfunc():'], ['def dummy():' * 2]])

@@ -718,7 +718,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_replace('Py')
         self.assertIn('op_type: "PyStringRegexReplace"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
         rewrite = np.array([r'static PyObject*\npy_\1(void)\n{'])
         text = np.array([['def myfunc():'], ['def dummy():']])

@@ -733,7 +733,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_replace('Py')
         self.assertIn('op_type: "PyStringRegexReplace"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
         rewrite = np.array([r'static PyObject*\npy_\1(void)\n{'])
         text = np.array([['def myfunc():'], ['def dummy():' * 2]])

@@ -748,7 +748,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_to_hash('Py', kind='crc32')
         self.assertIn('op_type: "PyStringToCRC32"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         text = np.array([["abc", "abcdé"], ["$$^l!%*ù", ""]])
         num_buckets = np.array([44], dtype=np.uint32)
         res = self._string_to_crc32(text, num_buckets)

@@ -765,7 +765,7 @@ class TestPythonOpString(unittest.TestCase):
         onnx_model = _create_test_model_string_to_hash(
             '', kind='hash_bucket')
         self.assertIn('op_type: "StringToHashBucket"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         raw = ["abc", "abcdé", "$$^l!%*ù", "", "a", "A"]
         text = np.array(raw).reshape((3, 2))
         num_buckets = np.array([NUM_BUCKETS], dtype=np.int64)

@@ -791,7 +791,7 @@ class TestPythonOpString(unittest.TestCase):
         onnx_model = _create_test_model_string_to_hash(
             '', kind='hash_bucket_fast')
         self.assertIn('op_type: "StringToHashBucketFast"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         raw = ["abc", "abcdé", "$$^l!%*ù", "", "a", "A"]
         text = np.array(raw).reshape((3, 2))
         num_buckets = np.array([NUM_BUCKETS], dtype=np.int64)

@@ -817,7 +817,7 @@ class TestPythonOpString(unittest.TestCase):
         onnx_model = _create_test_model_string_to_hash(
             'Py', kind='hash_bucket')
         self.assertIn('op_type: "PyStringToHashBucket"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         raw = ["abc", "abcdé", "$$^l!%*ù", "", "a", "A"]
         text = np.array(raw).reshape((3, 2))
         num_buckets = np.array([NUM_BUCKETS], dtype=np.int64)

@@ -850,7 +850,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_equal('Py')
         self.assertIn('op_type: "PyStringEqual"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
 
         for x, y in self.enumerate_matrix_couples():
             txout = sess.run(None, {'x': x, 'y': y})

@@ -863,7 +863,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_equal('')
         self.assertIn('op_type: "StringEqual"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
 
         for x, y in self.enumerate_matrix_couples():
             txout = sess.run(None, {'x': x, 'y': y})

@@ -876,7 +876,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_split('Py')
         self.assertIn('op_type: "PyStringSplit"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input = np.array(["a,,b", "", "aa,b,c", "dddddd"])
         delimiter = np.array([","])
 

@@ -908,7 +908,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_split('')
         self.assertIn('op_type: "StringSplit"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input = np.array(["a,,b", "", "aa,b,c", "dddddd"])
         delimiter = np.array([","])
 

@@ -956,7 +956,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_split('')
         self.assertIn('op_type: "StringSplit"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input = np.array(["a*b", "a,*b", "aa,b,,c", 'z', "dddddd,", "**"])
         delimiter = np.array([",*"])
 

@@ -1009,7 +1009,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = _create_test_model_string_split('')
         self.assertIn('op_type: "StringSplit"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input = np.array(["a*b", "a,*b"])
         delimiter = np.array([""])
 

@@ -1051,7 +1051,7 @@ class TestPythonOpString(unittest.TestCase):
         onnx_model = _create_test_model_string_regex_split('')
         self.assertIn('op_type: "StringRegexSplitWithOffsets"',
                       str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
         input = np.array(["hello there", "hello there"])
         pattern = np.array(["(\\s)"])
 

@@ -1114,7 +1114,7 @@ class TestPythonOpString(unittest.TestCase):
         so.register_custom_ops_library(_get_library_path())
         cc_onnx_model = _create_test_model_wordpiece('')
         self.assertIn('op_type: "WordpieceTokenizer"', str(cc_onnx_model))
-        cc_sess = _ort.InferenceSession(cc_onnx_model.SerializeToString(), so)
+        cc_sess = _ort.InferenceSession(cc_onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
 
         inputs = dict(text=np.array(["unwanted running",
                                      "unwantedX running"], dtype=object))

@@ -1149,7 +1149,6 @@ class TestPythonOpString(unittest.TestCase):
                 value_dtype=tf.int64)
             res = tf.lookup.StaticVocabularyTable(init, num_oov, lookup_key_dtype=tf.string)
             res.__len__ = lambda self: len(vocab)
-
 
         vocab_table = _CreateTable(["want", "##want", "##ed", "wa", "un", "runn", "##ing"])
 

@@ -15,7 +15,8 @@ tools_dir = os.path.join(ort_ext_root, "tools")
 test_data_dir = os.path.join(ort_ext_root, "test", "data")
 sys.path.append(tools_dir)
 
-import gen_customop_template
+import gen_customop_template  # noqa: E402
+
 
 # create generic custom op models with some basic math ops for testing purposes
 def _create_test_model_1():

@@ -34,6 +35,7 @@ def _create_test_model_1():
     model = make_onnx_model(graph)
     return model
 
+
 def _create_test_model_2(prefix=""):
     nodes = [
         helper.make_node("Identity", ["data"], ["id1"]),

@@ -51,8 +53,19 @@ def _create_test_model_2(prefix=""):
     model = make_onnx_model(graph)
     return model
 
 
 class TestCustomOpTemplate(unittest.TestCase):
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        # remove generated files
+        template_output_path = os.path.join(test_data_dir, "generated")
+        if os.path.exists(template_output_path):
+            for file in os.listdir(template_output_path):
+                os.remove(os.path.join(template_output_path, file))
+            os.rmdir(template_output_path)
+        return super().tearDownClass()
+
     # check input and output type count of models extracted by template generator
     def check_io_count(self, model_name, output_path, expected_input_count, expected_output_count):
         model_path = os.path.join(test_data_dir, "generated", model_name)

@@ -63,14 +76,19 @@ class TestCustomOpTemplate(unittest.TestCase):
     def test_template(self):
         template_output_path = os.path.join(test_data_dir, "generated")
         os.mkdir(template_output_path)
 
         onnx.save(_create_test_model_1(), os.path.join(template_output_path, "test_model_1.onnx"))
         test1_template_output_path = os.path.join(template_output_path, "custom_op_template_test1.hpp")
-        self.check_io_count(model_name = "test_model_1.onnx", output_path = test1_template_output_path, expected_input_count = 1, expected_output_count = 1)
+        self.check_io_count(model_name="test_model_1.onnx",
+                            output_path=test1_template_output_path,
+                            expected_input_count=1, expected_output_count=1)
+
 
         onnx.save(_create_test_model_2(), os.path.join(template_output_path, "test_model_2.onnx"))
         test2_template_output_path = os.path.join(template_output_path, "custom_op_template_test2.hpp")
-        self.check_io_count(model_name = "test_model_2.onnx", output_path = test2_template_output_path, expected_input_count = 2, expected_output_count = 1)
+        self.check_io_count(model_name="test_model_2.onnx",
+                            output_path=test2_template_output_path,
+                            expected_input_count=2, expected_output_count=1)
 
 
 if __name__ == "__main__":
     unittest.main()

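A quick sanity check, not part of this commit, that the CPU provider these tests now request is actually available regardless of which onnxruntime wheel is installed:

    import onnxruntime as _ort
    # Both the CPU and the GPU wheels ship the CPU execution provider.
    assert "CPUExecutionProvider" in _ort.get_available_providers()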