reformat test code and verify the pipeline (#251)
* reformat test code and verify the pipeline * upgrade googletest version * fix the merge issue * more formating
This commit is contained in:
Родитель
1a04abdf3e
Коммит
292a0297b4
|
@ -9,8 +9,6 @@ The entry point to onnxruntime custom op library
|
|||
|
||||
__author__ = "Microsoft"
|
||||
|
||||
import pathlib
|
||||
import inspect
|
||||
from ._version import __version__
|
||||
from ._ocos import get_library_path # noqa
|
||||
from ._ocos import Opdef, PyCustomOpDef # noqa
|
||||
|
@ -27,9 +25,3 @@ from ._ortapi2 import OrtPyFunction, optimize_model, make_onnx_model, ONNXRuntim
|
|||
onnx_op = Opdef.declare
|
||||
PyOp = PyCustomOpDef
|
||||
|
||||
|
||||
# do a favour for the unit test.
|
||||
def get_test_data_file(*sub_dirs):
|
||||
case_file = inspect.currentframe().f_back.f_code.co_filename
|
||||
test_dir = pathlib.Path(case_file).parent
|
||||
return str(test_dir.joinpath(*sub_dirs).resolve())
|
||||
|
|
|
@ -10,6 +10,7 @@ from ._ocos import default_opset_domain, Opdef, PyCustomOpDef
|
|||
|
||||
|
||||
class CustomOp:
|
||||
|
||||
@classmethod
|
||||
def op_type(cls):
|
||||
rcls = cls
|
||||
|
@ -18,10 +19,12 @@ class CustomOp:
|
|||
return rcls.__name__
|
||||
|
||||
@classmethod
|
||||
def get_inputs(cls): return None
|
||||
def get_inputs(cls):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls): return None
|
||||
def get_outputs(cls):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def serialize_attr(cls, attrs):
|
||||
|
@ -37,17 +40,23 @@ class CustomOp:
|
|||
|
||||
|
||||
class GPT2Tokenizer(CustomOp):
|
||||
|
||||
@classmethod
|
||||
def get_inputs(cls):
|
||||
return [cls.io_def('input_text', onnx_proto.TensorProto.STRING, [None])]
|
||||
return [
|
||||
cls.io_def('input_text', onnx_proto.TensorProto.STRING, [None])
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls):
|
||||
return [cls.io_def("input_ids", onnx.TensorProto.INT64, [None, None]),
|
||||
cls.io_def('attention_mask', onnx.TensorProto.INT64, [None, None])]
|
||||
return [
|
||||
cls.io_def("input_ids", onnx.TensorProto.INT64, [None, None]),
|
||||
cls.io_def('attention_mask', onnx.TensorProto.INT64, [None, None])
|
||||
]
|
||||
|
||||
|
||||
class VectorToString(CustomOp):
|
||||
|
||||
@classmethod
|
||||
def get_inputs(cls):
|
||||
return [cls.io_def("token_ids", onnx.TensorProto.INT64, [])]
|
||||
|
@ -61,7 +70,9 @@ class VectorToString(CustomOp):
|
|||
attr_data = {}
|
||||
for k_, v_ in attrs.items():
|
||||
if k_ == 'map' and isinstance(v_, dict):
|
||||
attr_data[k_] = '\n'.join(k + "\t" + " ".join([str(i) for i in v]) for k, v in v_.items())
|
||||
attr_data[k_] = '\n'.join(k + "\t" +
|
||||
" ".join([str(i) for i in v])
|
||||
for k, v in v_.items())
|
||||
elif k_ == 'map' and isinstance(v_, str):
|
||||
attr_data[k_] = v_
|
||||
else:
|
||||
|
@ -70,6 +81,7 @@ class VectorToString(CustomOp):
|
|||
|
||||
|
||||
class StringMapping(CustomOp):
|
||||
|
||||
@classmethod
|
||||
def get_inputs(cls):
|
||||
return [cls.io_def("input", onnx.TensorProto.STRING, [])]
|
||||
|
@ -92,10 +104,13 @@ class StringMapping(CustomOp):
|
|||
|
||||
|
||||
class MaskedFill(CustomOp):
|
||||
|
||||
@classmethod
|
||||
def get_inputs(cls):
|
||||
return [cls.io_def("value", onnx.TensorProto.STRING, [None]),
|
||||
cls.io_def("mask", onnx.TensorProto.BOOL, [None])]
|
||||
return [
|
||||
cls.io_def("value", onnx.TensorProto.STRING, [None]),
|
||||
cls.io_def("mask", onnx.TensorProto.BOOL, [None])
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls):
|
||||
|
@ -103,6 +118,7 @@ class MaskedFill(CustomOp):
|
|||
|
||||
|
||||
class StringToVector(CustomOp):
|
||||
|
||||
@classmethod
|
||||
def get_inputs(cls):
|
||||
return [cls.io_def("text", onnx.TensorProto.STRING, [None])]
|
||||
|
@ -116,7 +132,9 @@ class StringToVector(CustomOp):
|
|||
attr_data = {}
|
||||
for k_, v_ in attrs.items():
|
||||
if k_ == 'map' and isinstance(v_, dict):
|
||||
attr_data[k_] = '\n'.join(k + "\t" + " ".join([str(i) for i in v]) for k, v in v_.items())
|
||||
attr_data[k_] = '\n'.join(k + "\t" +
|
||||
" ".join([str(i) for i in v])
|
||||
for k, v in v_.items())
|
||||
elif k_ == 'map' and isinstance(v_, str):
|
||||
attr_data[k_] = v_
|
||||
elif k_ == 'unk' and isinstance(v_, list):
|
||||
|
@ -127,6 +145,7 @@ class StringToVector(CustomOp):
|
|||
|
||||
|
||||
class BlingFireSentenceBreaker(CustomOp):
|
||||
|
||||
@classmethod
|
||||
def get_inputs(cls):
|
||||
return [cls.io_def("text", onnx.TensorProto.STRING, [None])]
|
||||
|
@ -148,26 +167,32 @@ class BlingFireSentenceBreaker(CustomOp):
|
|||
|
||||
|
||||
class SegmentExtraction(CustomOp):
|
||||
|
||||
@classmethod
|
||||
def get_inputs(cls):
|
||||
return [cls.io_def("input", onnx.TensorProto.INT64, [None, None])]
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls):
|
||||
return [cls.io_def('position', onnx_proto.TensorProto.INT64, [None, 2]),
|
||||
cls.io_def('value', onnx_proto.TensorProto.INT64, [None])]
|
||||
return [
|
||||
cls.io_def('position', onnx_proto.TensorProto.INT64, [None, 2]),
|
||||
cls.io_def('value', onnx_proto.TensorProto.INT64, [None])
|
||||
]
|
||||
|
||||
|
||||
class BertTokenizer(CustomOp):
|
||||
|
||||
@classmethod
|
||||
def get_inputs(cls):
|
||||
return [cls.io_def("text", onnx.TensorProto.STRING, [None])]
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls):
|
||||
return [cls.io_def('input_ids', onnx_proto.TensorProto.INT64, [None]),
|
||||
cls.io_def('token_type_ids', onnx_proto.TensorProto.INT64, [None]),
|
||||
cls.io_def('attention_mask', onnx_proto.TensorProto.INT64, [None])]
|
||||
return [
|
||||
cls.io_def('input_ids', onnx_proto.TensorProto.INT64, [None]),
|
||||
cls.io_def('token_type_ids', onnx_proto.TensorProto.INT64, [None]),
|
||||
cls.io_def('attention_mask', onnx_proto.TensorProto.INT64, [None])
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def serialize_attr(cls, attrs):
|
||||
|
@ -183,11 +208,14 @@ class BertTokenizer(CustomOp):
|
|||
|
||||
|
||||
class StringECMARegexReplace(CustomOp):
|
||||
|
||||
@classmethod
|
||||
def get_inputs(cls):
|
||||
return [cls.io_def("input", onnx.TensorProto.STRING, [None]),
|
||||
cls.io_def("pattern", onnx.TensorProto.STRING, [None]),
|
||||
cls.io_def("rewrite", onnx.TensorProto.STRING, [None])]
|
||||
return [
|
||||
cls.io_def("input", onnx.TensorProto.STRING, [None]),
|
||||
cls.io_def("pattern", onnx.TensorProto.STRING, [None]),
|
||||
cls.io_def("rewrite", onnx.TensorProto.STRING, [None])
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls):
|
||||
|
@ -195,10 +223,13 @@ class StringECMARegexReplace(CustomOp):
|
|||
|
||||
|
||||
class BertTokenizerDecoder(CustomOp):
|
||||
|
||||
@classmethod
|
||||
def get_inputs(cls):
|
||||
return [cls.io_def("ids", onnx.TensorProto.INT64, [None]),
|
||||
cls.io_def("position", onnx.TensorProto.INT64, [None, None])]
|
||||
return [
|
||||
cls.io_def("ids", onnx.TensorProto.INT64, [None]),
|
||||
cls.io_def("position", onnx.TensorProto.INT64, [None, None])
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls):
|
||||
|
@ -218,6 +249,7 @@ class BertTokenizerDecoder(CustomOp):
|
|||
|
||||
|
||||
class SentencepieceTokenizer(CustomOp):
|
||||
|
||||
@classmethod
|
||||
def get_inputs(cls):
|
||||
return [
|
||||
|
@ -236,6 +268,7 @@ class SentencepieceTokenizer(CustomOp):
|
|||
cls.io_def('indices', onnx_proto.TensorProto.INT64, [None])
|
||||
]
|
||||
|
||||
|
||||
class Inverse(CustomOp):
|
||||
@classmethod
|
||||
def get_inputs(cls):
|
||||
|
@ -279,7 +312,9 @@ class GaussianBlur(CustomOp):
|
|||
cls.io_def('gb_nhwc', onnx_proto.TensorProto.FLOAT, [None, None, None, None])
|
||||
]
|
||||
|
||||
|
||||
class SingleOpGraph:
|
||||
|
||||
@classmethod
|
||||
def get_next_id(cls):
|
||||
if not hasattr(cls, '_id_counter'):
|
||||
|
@ -296,16 +331,14 @@ class SingleOpGraph:
|
|||
inputs = op_class.get_inputs()
|
||||
outputs = op_class.get_outputs()
|
||||
attrs = op_class.serialize_attr(kwargs)
|
||||
cuop = onnx.helper.make_node(op_type,
|
||||
[i_.name for i_ in inputs],
|
||||
cuop = onnx.helper.make_node(op_type, [i_.name for i_ in inputs],
|
||||
[o_.name for o_ in outputs],
|
||||
"{}_{}".format(op_type, cls.get_next_id()),
|
||||
"{}_{}".format(op_type,
|
||||
cls.get_next_id()),
|
||||
**attrs,
|
||||
domain=default_opset_domain())
|
||||
graph = onnx.helper.make_graph([cuop],
|
||||
"og_{}_{}".format(op_type, cls.get_next_id()),
|
||||
inputs,
|
||||
outputs)
|
||||
graph = onnx.helper.make_graph([cuop], "og_{}_{}".format(
|
||||
op_type, cls.get_next_id()), inputs, outputs)
|
||||
return graph
|
||||
|
||||
@staticmethod
|
||||
|
@ -319,6 +352,7 @@ def _argsort_op(x, dim):
|
|||
return d[:, ::-1]
|
||||
|
||||
|
||||
Opdef.create(_argsort_op, op_type='ArgSort',
|
||||
Opdef.create(_argsort_op,
|
||||
op_type='ArgSort',
|
||||
inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_int64],
|
||||
outputs=[PyCustomOpDef.dt_int64])
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
import pathlib
|
||||
import inspect
|
||||
|
||||
|
||||
# some util function for testing and tools
|
||||
|
||||
def get_test_data_file(*sub_dirs):
|
||||
case_file = inspect.currentframe().f_back.f_code.co_filename
|
||||
test_dir = pathlib.Path(case_file).parent
|
||||
return str(test_dir.joinpath(*sub_dirs).resolve())
|
||||
|
||||
|
||||
def read_file(path):
|
||||
with open(str(path)) as file_content:
|
||||
return file_content.read()
|
|
@ -1,55 +1,74 @@
|
|||
from pathlib import Path
|
||||
# coding: utf-8
|
||||
import unittest
|
||||
import numpy as np
|
||||
import transformers
|
||||
from onnxruntime_extensions import PyOrtFunction, BertTokenizer
|
||||
from onnxruntime_extensions import PyOrtFunction, BertTokenizer, util
|
||||
|
||||
|
||||
def _get_test_data_file(*sub_dirs):
|
||||
test_dir = Path(__file__).parent
|
||||
return str(test_dir.joinpath(*sub_dirs))
|
||||
|
||||
|
||||
bert_cased_tokenizer = transformers.BertTokenizer(_get_test_data_file('data', 'bert_basic_cased_vocab.txt'), False,
|
||||
strip_accents=True)
|
||||
bert_cased_tokenizer = transformers.BertTokenizer(
|
||||
util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
|
||||
False,
|
||||
strip_accents=True,
|
||||
)
|
||||
|
||||
|
||||
def _run_basic_case(input, vocab_path):
|
||||
t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0, strip_accents=1)
|
||||
t2stc = PyOrtFunction.from_customop(
|
||||
BertTokenizer, vocab_file=vocab_path, do_lower_case=0, strip_accents=1
|
||||
)
|
||||
result = t2stc([input])
|
||||
expect_result = bert_cased_tokenizer.encode_plus(input)
|
||||
np.testing.assert_array_equal(result[0], expect_result['input_ids'])
|
||||
np.testing.assert_array_equal(result[1], expect_result['token_type_ids'])
|
||||
np.testing.assert_array_equal(result[2], expect_result['attention_mask'])
|
||||
np.testing.assert_array_equal(result[0], expect_result["input_ids"])
|
||||
np.testing.assert_array_equal(result[1], expect_result["token_type_ids"])
|
||||
np.testing.assert_array_equal(result[2], expect_result["attention_mask"])
|
||||
|
||||
|
||||
def _run_combined_case(input, vocab_path):
|
||||
t2stc = PyOrtFunction.from_customop(BertTokenizer, vocab_file=vocab_path, do_lower_case=0, strip_accents=1)
|
||||
t2stc = PyOrtFunction.from_customop(
|
||||
BertTokenizer, vocab_file=vocab_path, do_lower_case=0, strip_accents=1
|
||||
)
|
||||
result = t2stc(input)
|
||||
expect_result = bert_cased_tokenizer.encode_plus(input[0], input[1])
|
||||
np.testing.assert_array_equal(result[0], expect_result['input_ids'])
|
||||
np.testing.assert_array_equal(result[1], expect_result['token_type_ids'])
|
||||
np.testing.assert_array_equal(result[2], expect_result['attention_mask'])
|
||||
np.testing.assert_array_equal(result[0], expect_result["input_ids"])
|
||||
np.testing.assert_array_equal(result[1], expect_result["token_type_ids"])
|
||||
np.testing.assert_array_equal(result[2], expect_result["attention_mask"])
|
||||
|
||||
|
||||
class TestBertTokenizer(unittest.TestCase):
|
||||
|
||||
def test_text_to_case1(self):
|
||||
_run_basic_case(input="Input 'text' must not be empty.",
|
||||
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
_run_basic_case(
|
||||
input="ÀÁÂÃÄÅÇÈÉÊËÌÍÎÑÒÓÔÕÖÚÜ\t䗓𨖷虴𨀐辘𧄋脟𩑢𡗶镇伢𧎼䪱轚榶𢑌㺽𤨡!#$%&(Tom@microsoft.com)*+,-./:;<=>?@[\\]^_`{|}~",
|
||||
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
_run_basic_case(input="网易云音乐", vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
_run_basic_case(input="本想好好的伤感 想放任 但是没泪痕", vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
_run_basic_case(input="网 易 云 音 乐",
|
||||
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
_run_basic_case(input="cat is playing toys",
|
||||
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
_run_basic_case(input="cat isnot playing toyssss",
|
||||
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
_run_combined_case(["网 易 云 音 乐", "cat isnot playing toyssss"],
|
||||
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
input="Input 'text' must not be empty.",
|
||||
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
|
||||
)
|
||||
_run_basic_case(
|
||||
input="ÀÁÂÃÄÅÇÈÉÊËÌÍÎÑÒÓÔÕÖÚÜ\t"
|
||||
+ "䗓𨖷虴𨀐辘𧄋脟𩑢𡗶镇伢𧎼䪱轚榶𢑌㺽𤨡!#$%&(Tom@microsoft.com)*+,-./:;<=>?@[\\]^_`{|}~",
|
||||
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
|
||||
)
|
||||
_run_basic_case(
|
||||
input="网易云音乐",
|
||||
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
|
||||
)
|
||||
_run_basic_case(
|
||||
input="本想好好的伤感 想放任 但是没泪痕",
|
||||
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
|
||||
)
|
||||
_run_basic_case(
|
||||
input="网 易 云 音 乐",
|
||||
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
|
||||
)
|
||||
_run_basic_case(
|
||||
input="cat is playing toys",
|
||||
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
|
||||
)
|
||||
_run_basic_case(
|
||||
input="cat isnot playing toyssss",
|
||||
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
|
||||
)
|
||||
_run_combined_case(
|
||||
["网 易 云 音 乐", "cat isnot playing toyssss"],
|
||||
vocab_path=util.get_test_data_file("data", "bert_basic_cased_vocab.txt"),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,18 +1,14 @@
|
|||
from pathlib import Path
|
||||
# coding: utf-8
|
||||
import unittest
|
||||
import numpy as np
|
||||
import transformers
|
||||
from onnxruntime_extensions import util
|
||||
from onnxruntime_extensions import PyOrtFunction, BertTokenizerDecoder
|
||||
|
||||
bert_cased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
|
||||
bert_uncased_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
|
||||
|
||||
def _get_test_data_file(*sub_dirs):
|
||||
test_dir = Path(__file__).parent
|
||||
return str(test_dir.joinpath(*sub_dirs))
|
||||
|
||||
|
||||
def _run_basic_case(input, vocab_path):
|
||||
t2stc = PyOrtFunction.from_customop(BertTokenizerDecoder, vocab_file=vocab_path)
|
||||
ids = np.array(bert_cased_tokenizer.encode(input), dtype=np.int64)
|
||||
|
@ -45,20 +41,20 @@ class TestBertTokenizerDecoder(unittest.TestCase):
|
|||
|
||||
def test_text_to_case1(self):
|
||||
_run_basic_case(input="Input 'text' must not be empty.",
|
||||
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
_run_basic_case(input="网易云音乐", vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
vocab_path=util.get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
_run_basic_case(input="网易云音乐", vocab_path=util.get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
_run_basic_case(input="网 易 云 音 乐",
|
||||
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
vocab_path=util.get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
_run_basic_case(input="cat is playing toys",
|
||||
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
vocab_path=util.get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
_run_basic_case(input="cat isnot playing toyssss",
|
||||
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
vocab_path=util.get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
|
||||
_run_indices_case(input="cat isnot playing toyssss", indices=[[]],
|
||||
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
vocab_path=util.get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
|
||||
_run_indices_case(input="cat isnot playing toyssss", indices=[[1, 2], [3, 5]],
|
||||
vocab_path=_get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
vocab_path=util.get_test_data_file('data', 'bert_basic_cased_vocab.txt'))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,14 +1,10 @@
|
|||
from pathlib import Path
|
||||
# coding: utf-8
|
||||
import unittest
|
||||
import numpy as np
|
||||
from onnxruntime_extensions import util
|
||||
from onnxruntime_extensions import PyOrtFunction, BlingFireSentenceBreaker
|
||||
|
||||
|
||||
def _get_test_data_file(*sub_dirs):
|
||||
test_dir = Path(__file__).parent
|
||||
return str(test_dir.joinpath(*sub_dirs))
|
||||
|
||||
|
||||
def _run_blingfire_sentencebreaker(input, output, model_path):
|
||||
t2stc = PyOrtFunction.from_customop(BlingFireSentenceBreaker, model=model_path)
|
||||
result = t2stc(input)
|
||||
|
@ -16,28 +12,51 @@ def _run_blingfire_sentencebreaker(input, output, model_path):
|
|||
|
||||
|
||||
class TestBlingFireSentenceBreaker(unittest.TestCase):
|
||||
|
||||
def test_text_to_case1(self):
|
||||
inputs = np.array([
|
||||
"This is the Bling-Fire tokenizer. Autophobia, also called monophobia, isolophobia, or eremophobia, is the specific phobia of isolation. 2007年9月日历表_2007年9月农历阳历一览表-万年历. I saw a girl with a telescope. Я увидел девушку с телескопом."])
|
||||
outputs = np.array(["This is the Bling-Fire tokenizer.",
|
||||
"Autophobia, also called monophobia, isolophobia, or eremophobia, is the specific phobia of isolation. 2007年9月日历表_2007年9月农历阳历一览表-万年历.",
|
||||
"I saw a girl with a telescope.",
|
||||
"Я увидел девушку с телескопом."])
|
||||
_run_blingfire_sentencebreaker(input=inputs, output=outputs, model_path=_get_test_data_file('data', 'default_sentence_break_model.bin'))
|
||||
inputs = np.array(
|
||||
[
|
||||
"This is the Bling-Fire tokenizer. Autophobia, also called monophobia, isolophobia, "
|
||||
+ "or eremophobia, is the specific phobia of isolation."
|
||||
+ " 2007年9月日历表_2007年9月农历阳历一览表-万年历. "
|
||||
+ "I saw a girl with a telescope. Я увидел девушку с телескопом."
|
||||
]
|
||||
)
|
||||
outputs = np.array(
|
||||
[
|
||||
"This is the Bling-Fire tokenizer.",
|
||||
"Autophobia, also called monophobia, isolophobia, or eremophobia, "
|
||||
+ "is the specific phobia of isolation."
|
||||
+ " 2007年9月日历表_2007年9月农历阳历一览表-万年历.",
|
||||
"I saw a girl with a telescope.",
|
||||
"Я увидел девушку с телескопом.",
|
||||
]
|
||||
)
|
||||
_run_blingfire_sentencebreaker(
|
||||
input=inputs,
|
||||
output=outputs,
|
||||
model_path=util.get_test_data_file("data", "default_sentence_break_model.bin"),
|
||||
)
|
||||
|
||||
def test_text_to_case2(self):
|
||||
# input is empty
|
||||
inputs = np.array([""])
|
||||
outputs = np.array([""])
|
||||
_run_blingfire_sentencebreaker(input=inputs, output=outputs, model_path=_get_test_data_file('data', 'default_sentence_break_model.bin'))
|
||||
_run_blingfire_sentencebreaker(
|
||||
input=inputs,
|
||||
output=outputs,
|
||||
model_path=util.get_test_data_file("data", "default_sentence_break_model.bin"),
|
||||
)
|
||||
|
||||
def test_text_to_case3(self):
|
||||
# input is whitespace
|
||||
inputs = np.array([" "])
|
||||
# output of blingfire sbd.bin model
|
||||
outputs = np.array([""])
|
||||
_run_blingfire_sentencebreaker(input=inputs, output=outputs, model_path=_get_test_data_file('data', 'default_sentence_break_model.bin'))
|
||||
_run_blingfire_sentencebreaker(
|
||||
input=inputs,
|
||||
output=outputs,
|
||||
model_path=util.get_test_data_file("data", "default_sentence_break_model.bin"),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from onnxruntime_extensions import get_test_data_file, OrtPyFunction, ONNXRuntimeError
|
||||
from onnxruntime_extensions import OrtPyFunction, ONNXRuntimeError, util
|
||||
|
||||
|
||||
class TestOpenCV(unittest.TestCase):
|
||||
|
@ -10,7 +10,7 @@ class TestOpenCV(unittest.TestCase):
|
|||
pass
|
||||
|
||||
def test_image_reader(self):
|
||||
img_file = get_test_data_file('data', 'pineapple.jpg')
|
||||
img_file = util.get_test_data_file('data', 'pineapple.jpg')
|
||||
|
||||
img_nhwc = None
|
||||
# since the ImageReader is not included the offical release due to code compliance issue,
|
||||
|
@ -34,7 +34,7 @@ class TestOpenCV(unittest.TestCase):
|
|||
np.testing.assert_array_equal(actual, expected)
|
||||
|
||||
def test_gaussian_blur(self):
|
||||
img_file = get_test_data_file('data', 'pineapple.jpg')
|
||||
img_file = util.get_test_data_file('data', 'pineapple.jpg')
|
||||
img = Image.open(img_file).convert('RGB')
|
||||
img_arr = np.asarray(img, dtype=np.float32) / 255.
|
||||
img_arr = np.expand_dims(img_arr, 0)
|
||||
|
|
|
@ -6,6 +6,7 @@ from pathlib import Path
|
|||
from onnx import helper, onnx_pb as onnx_proto
|
||||
from transformers import GPT2Tokenizer
|
||||
from onnxruntime_extensions import (
|
||||
util,
|
||||
make_onnx_model,
|
||||
enable_py_op,
|
||||
get_library_path as _get_library_path)
|
||||
|
@ -65,8 +66,8 @@ class MyGPT2Tokenizer:
|
|||
class TestGPT2Tokenizer(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.tokjson = _get_test_data_file('data', 'gpt2.vocab')
|
||||
cls.merges = _get_test_data_file('data', 'gpt2.merges.txt')
|
||||
cls.tokjson = util.get_test_data_file('data', 'gpt2.vocab')
|
||||
cls.merges = util.get_test_data_file('data', 'gpt2.merges.txt')
|
||||
cls.tokenizer = MyGPT2Tokenizer(cls.tokjson, cls.merges)
|
||||
|
||||
# @onnx_op(op_type="GPT2Tokenizer",
|
||||
|
|
|
@ -1,14 +1,8 @@
|
|||
from pathlib import Path
|
||||
import unittest
|
||||
import numpy as np
|
||||
from onnxruntime_extensions import PyOrtFunction, MaskedFill
|
||||
|
||||
|
||||
def _get_test_data_file(*sub_dirs):
|
||||
test_dir = Path(__file__).parent
|
||||
return str(test_dir.joinpath(*sub_dirs))
|
||||
|
||||
|
||||
def read_file(path):
|
||||
with open(path) as file_content:
|
||||
return file_content.read()
|
||||
|
|
|
@ -5,42 +5,43 @@ from onnx import helper, onnx_pb as onnx_proto
|
|||
import onnxruntime as _ort
|
||||
from onnxruntime_extensions import (
|
||||
OrtPyFunction,
|
||||
onnx_op, PyCustomOpDef, make_onnx_model,
|
||||
get_library_path as _get_library_path)
|
||||
PyCustomOpDef,
|
||||
onnx_op,
|
||||
make_onnx_model,
|
||||
get_library_path as _get_library_path,
|
||||
)
|
||||
|
||||
|
||||
def _create_test_model_segment_sum(prefix, domain='ai.onnx.contrib'):
|
||||
nodes = []
|
||||
nodes.append(helper.make_node('Identity', ['data'], ['id1']))
|
||||
nodes.append(helper.make_node('Identity', ['segment_ids'], ['id2']))
|
||||
nodes.append(
|
||||
helper.make_node(
|
||||
'%sSegmentSum' % prefix, ['id1', 'id2'], ['z'], domain=domain))
|
||||
def _create_test_model_segment_sum(prefix, domain="ai.onnx.contrib"):
|
||||
nodes = [
|
||||
helper.make_node("Identity", ["data"], ["id1"]),
|
||||
helper.make_node("Identity", ["segment_ids"], ["id2"]),
|
||||
helper.make_node("%sSegmentSum" % prefix, ["id1", "id2"], ["z"], domain=domain),
|
||||
]
|
||||
|
||||
input0 = helper.make_tensor_value_info(
|
||||
'data', onnx_proto.TensorProto.FLOAT, [])
|
||||
input0 = helper.make_tensor_value_info("data", onnx_proto.TensorProto.FLOAT, [])
|
||||
input1 = helper.make_tensor_value_info(
|
||||
'segment_ids', onnx_proto.TensorProto.INT64, [])
|
||||
output0 = helper.make_tensor_value_info(
|
||||
'z', onnx_proto.TensorProto.FLOAT, [])
|
||||
"segment_ids", onnx_proto.TensorProto.INT64, []
|
||||
)
|
||||
output0 = helper.make_tensor_value_info("z", onnx_proto.TensorProto.FLOAT, [])
|
||||
|
||||
graph = helper.make_graph(nodes, 'test0', [input0, input1], [output0])
|
||||
graph = helper.make_graph(nodes, "test0", [input0, input1], [output0])
|
||||
model = make_onnx_model(graph)
|
||||
return model
|
||||
|
||||
|
||||
class TestMathOpString(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
@onnx_op(op_type="PySegmentSum",
|
||||
inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_int64],
|
||||
outputs=[PyCustomOpDef.dt_float])
|
||||
@onnx_op(
|
||||
op_type="PySegmentSum",
|
||||
inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_int64],
|
||||
outputs=[PyCustomOpDef.dt_float],
|
||||
)
|
||||
def segment_sum(data, segment_ids):
|
||||
# segment_ids is sorted
|
||||
nb_seg = segment_ids[-1] + 1
|
||||
sh = (nb_seg, ) + data.shape[1:]
|
||||
sh = (nb_seg,) + data.shape[1:]
|
||||
res = np.zeros(sh, dtype=data.dtype)
|
||||
for seg, row in zip(segment_ids, data):
|
||||
res[seg] += row
|
||||
|
@ -49,33 +50,32 @@ class TestMathOpString(unittest.TestCase):
|
|||
def test_segment_sum_cc(self):
|
||||
so = _ort.SessionOptions()
|
||||
so.register_custom_ops_library(_get_library_path())
|
||||
onnx_model = _create_test_model_segment_sum('')
|
||||
onnx_model = _create_test_model_segment_sum("")
|
||||
self.assertIn('op_type: "SegmentSum"', str(onnx_model))
|
||||
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
|
||||
data = np.array([[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]],
|
||||
dtype=np.float32)
|
||||
data = np.array([[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]], dtype=np.float32)
|
||||
segment_ids = np.array([0, 0, 1], dtype=np.int64)
|
||||
exp = np.array([[5, 5, 5, 5], [5, 6, 7, 8]], dtype=np.float32)
|
||||
txout = sess.run(None, {'data': data, 'segment_ids': segment_ids})
|
||||
txout = sess.run(None, {"data": data, "segment_ids": segment_ids})
|
||||
self.assertEqual(exp.shape, txout[0].shape)
|
||||
self.assertEqual(exp.tolist(), txout[0].tolist())
|
||||
|
||||
def test_segment_sum_python(self):
|
||||
so = _ort.SessionOptions()
|
||||
so.register_custom_ops_library(_get_library_path())
|
||||
onnx_model = _create_test_model_segment_sum('Py')
|
||||
onnx_model = _create_test_model_segment_sum("Py")
|
||||
self.assertIn('op_type: "PySegmentSum"', str(onnx_model))
|
||||
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
|
||||
data = np.array([[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]],
|
||||
dtype=np.float32)
|
||||
data = np.array([[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]], dtype=np.float32)
|
||||
segment_ids = np.array([0, 0, 1], dtype=np.int64)
|
||||
exp = np.array([[5, 5, 5, 5], [5, 6, 7, 8]], dtype=np.float32)
|
||||
txout = sess.run(None, {'data': data, 'segment_ids': segment_ids})
|
||||
txout = sess.run(None, {"data": data, "segment_ids": segment_ids})
|
||||
self.assertEqual(exp.shape, txout[0].shape)
|
||||
self.assertEqual(exp.tolist(), txout[0].tolist())
|
||||
|
||||
try:
|
||||
from tensorflow.raw_ops import SegmentSum
|
||||
|
||||
dotf = True
|
||||
except ImportError:
|
||||
dotf = False
|
||||
|
|
|
@ -6,20 +6,20 @@ from typing import List, Tuple
|
|||
from PIL import Image
|
||||
from distutils.version import LooseVersion
|
||||
from onnxruntime_extensions import OrtPyFunction
|
||||
from onnxruntime_extensions import pnp, get_test_data_file
|
||||
from onnxruntime_extensions import pnp, util
|
||||
from transformers import GPT2Config, GPT2LMHeadModel
|
||||
|
||||
|
||||
class _GPT2LMHeadModel(GPT2LMHeadModel):
|
||||
""" Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state.
|
||||
"""
|
||||
"""Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
result = super(_GPT2LMHeadModel, self).forward(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=False)
|
||||
result = super(_GPT2LMHeadModel, self).forward(
|
||||
input_ids, attention_mask=attention_mask, return_dict=False
|
||||
)
|
||||
# drop the past states
|
||||
return result[0]
|
||||
|
||||
|
@ -38,56 +38,76 @@ class _MobileNetProcessingModule(pnp.ProcessingScriptModule):
|
|||
def __init__(self, oxml):
|
||||
super(_MobileNetProcessingModule, self).__init__()
|
||||
self.model_function_id = pnp.create_model_function(oxml)
|
||||
self.pre_proc = torch.jit.trace(pnp.PreMobileNet(224), torch.zeros(224, 224, 3, dtype=torch.float32))
|
||||
self.post_proc = torch.jit.trace(pnp.ImageNetPostProcessing(), torch.zeros(1, 1000, dtype=torch.float32))
|
||||
self.pre_proc = torch.jit.trace(
|
||||
pnp.PreMobileNet(224), torch.zeros(224, 224, 3, dtype=torch.float32)
|
||||
)
|
||||
self.post_proc = torch.jit.trace(
|
||||
pnp.ImageNetPostProcessing(), torch.zeros(1, 1000, dtype=torch.float32)
|
||||
)
|
||||
|
||||
def forward(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
proc_input = self.pre_proc(img)
|
||||
return self.post_proc.forward(pnp.invoke_onnx_model1(self.model_function_id, proc_input))
|
||||
return self.post_proc.forward(
|
||||
pnp.invoke_onnx_model1(self.model_function_id, proc_input)
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(LooseVersion(torch.__version__) < LooseVersion("1.9"), 'Not works with older PyTorch')
|
||||
@unittest.skipIf(
|
||||
LooseVersion(torch.__version__) < LooseVersion("1.9"),
|
||||
"Not works with older PyTorch",
|
||||
)
|
||||
class TestPreprocessing(unittest.TestCase):
|
||||
def test_imagenet_preprocessing(self):
|
||||
mnv2 = onnx.load_model(get_test_data_file('data', 'mobilev2.onnx'))
|
||||
mnv2 = onnx.load_model(util.get_test_data_file("data", "mobilev2.onnx"))
|
||||
|
||||
# load an image
|
||||
img = Image.open(get_test_data_file('data', 'pineapple.jpg'))
|
||||
img = torch.from_numpy(numpy.asarray(img.convert('RGB')))
|
||||
img = Image.open(util.get_test_data_file("data", "pineapple.jpg"))
|
||||
img = torch.from_numpy(numpy.asarray(img.convert("RGB")))
|
||||
|
||||
full_models = pnp.SequentialProcessingModule(pnp.PreMobileNet(224),
|
||||
mnv2,
|
||||
pnp.PostMobileNet())
|
||||
full_models = pnp.SequentialProcessingModule(
|
||||
pnp.PreMobileNet(224), mnv2, pnp.PostMobileNet()
|
||||
)
|
||||
ids, probabilities = full_models.forward(img)
|
||||
name_i = 'image'
|
||||
name_i = "image"
|
||||
full_model_func = OrtPyFunction.from_model(
|
||||
pnp.export(full_models,
|
||||
img,
|
||||
opset_version=11,
|
||||
output_path='temp_imagenet.onnx',
|
||||
input_names=[name_i],
|
||||
dynamic_axes={name_i: [0, 1]}))
|
||||
pnp.export(
|
||||
full_models,
|
||||
img,
|
||||
opset_version=11,
|
||||
output_path="temp_imagenet.onnx",
|
||||
input_names=[name_i],
|
||||
dynamic_axes={name_i: [0, 1]},
|
||||
)
|
||||
)
|
||||
actual_ids, actual_result = full_model_func(img.numpy())
|
||||
numpy.testing.assert_allclose(probabilities.numpy(), actual_result, rtol=1e-3)
|
||||
self.assertEqual(ids[0, 0].item(), 953) # 953 is pineapple class id in the imagenet dataset
|
||||
self.assertEqual(
|
||||
ids[0, 0].item(), 953
|
||||
) # 953 is pineapple class id in the imagenet dataset
|
||||
|
||||
def test_gpt2_preprocessing(self):
|
||||
cfg = GPT2Config(n_layer=3)
|
||||
gpt2_m = _GPT2LMHeadModel(cfg)
|
||||
gpt2_m.eval().to('cpu')
|
||||
gpt2_m.eval().to("cpu")
|
||||
|
||||
test_sentence = ["Test a sentence"]
|
||||
tok = pnp.PreHuggingFaceGPT2(vocab_file=get_test_data_file('data', 'gpt2.vocab'),
|
||||
merges_file=get_test_data_file('data', 'gpt2.merges.txt'))
|
||||
tok = pnp.PreHuggingFaceGPT2(
|
||||
vocab_file=util.get_test_data_file("data", "gpt2.vocab"),
|
||||
merges_file=util.get_test_data_file("data", "gpt2.merges.txt"),
|
||||
)
|
||||
inputs = tok.forward(test_sentence)
|
||||
pnp.export(tok, test_sentence, opset_version=12, output_path='temp_tok2.onnx')
|
||||
pnp.export(tok, test_sentence, opset_version=12, output_path="temp_tok2.onnx")
|
||||
|
||||
with open('temp_gpt2lmh.onnx', 'wb') as f:
|
||||
torch.onnx.export(gpt2_m, inputs, f, opset_version=12, do_constant_folding=False)
|
||||
with open("temp_gpt2lmh.onnx", "wb") as f:
|
||||
torch.onnx.export(
|
||||
gpt2_m, inputs, f, opset_version=12, do_constant_folding=False
|
||||
)
|
||||
pnp.export(gpt2_m, *inputs, opset_version=12, do_constant_folding=False)
|
||||
full_model = pnp.SequentialProcessingModule(tok, gpt2_m)
|
||||
expected = full_model.forward(test_sentence)
|
||||
model = pnp.export(full_model, test_sentence, opset_version=12, do_constant_folding=False)
|
||||
model = pnp.export(
|
||||
full_model, test_sentence, opset_version=12, do_constant_folding=False
|
||||
)
|
||||
mfunc = OrtPyFunction.from_model(model)
|
||||
actuals = mfunc(test_sentence)
|
||||
# the random weight may generate a large diff in result, test the shape only.
|
||||
|
@ -95,40 +115,53 @@ class TestPreprocessing(unittest.TestCase):
|
|||
|
||||
def test_sequence_tensor(self):
|
||||
seq_m = _SequenceTensorModel()
|
||||
test_input = [torch.from_numpy(_i) for _i in [
|
||||
numpy.array([1]).astype(numpy.int64),
|
||||
numpy.array([3, 4]).astype(numpy.int64),
|
||||
numpy.array([5, 6]).astype(numpy.int64)]]
|
||||
test_input = [
|
||||
torch.from_numpy(_i)
|
||||
for _i in [
|
||||
numpy.array([1]).astype(numpy.int64),
|
||||
numpy.array([3, 4]).astype(numpy.int64),
|
||||
numpy.array([5, 6]).astype(numpy.int64),
|
||||
]
|
||||
]
|
||||
res = seq_m.forward(test_input)
|
||||
numpy.testing.assert_allclose(res, numpy.array([4, 5]))
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.11"):
|
||||
# The fixing for the sequence tensor support is only released in 1.11 and the above.
|
||||
oxml = pnp.export(seq_m,
|
||||
test_input,
|
||||
opset_version=12,
|
||||
output_path='temp_seqtest.onnx')
|
||||
oxml = pnp.export(
|
||||
seq_m, test_input, opset_version=12, output_path="temp_seqtest.onnx"
|
||||
)
|
||||
# TODO: ORT doesn't accept the default empty element type of a sequence type.
|
||||
oxml.graph.input[0].type.sequence_type.elem_type.CopyFrom(
|
||||
onnx.helper.make_tensor_type_proto(onnx.onnx_pb.TensorProto.INT64, []))
|
||||
onnx.helper.make_tensor_type_proto(onnx.onnx_pb.TensorProto.INT64, [])
|
||||
)
|
||||
mfunc = OrtPyFunction.from_model(oxml)
|
||||
o_res = mfunc([_i.numpy() for _i in test_input])
|
||||
numpy.testing.assert_allclose(res, o_res)
|
||||
|
||||
@unittest.skipIf(LooseVersion(torch.__version__) < LooseVersion("1.11"),
|
||||
'PythonOp bug fixing on Pytorch 1.11')
|
||||
@unittest.skipIf(
|
||||
LooseVersion(torch.__version__) < LooseVersion("1.11"),
|
||||
"PythonOp bug fixing on Pytorch 1.11",
|
||||
)
|
||||
def test_functional_processing(self):
|
||||
# load an image
|
||||
img = Image.open(get_test_data_file('data', 'pineapple.jpg')).convert('RGB')
|
||||
img = Image.open(util.get_test_data_file("data", "pineapple.jpg")).convert(
|
||||
"RGB"
|
||||
)
|
||||
img = torch.from_numpy(numpy.asarray(img))
|
||||
|
||||
pipeline = _MobileNetProcessingModule(onnx.load_model(get_test_data_file('data', 'mobilev2.onnx')))
|
||||
pipeline = _MobileNetProcessingModule(
|
||||
onnx.load_model(util.get_test_data_file("data", "mobilev2.onnx"))
|
||||
)
|
||||
ids, probabilities = pipeline.forward(img)
|
||||
|
||||
full_model_func = OrtPyFunction.from_model(
|
||||
pnp.export(pipeline, img, opset_version=11, output_path='temp_func.onnx'))
|
||||
pnp.export(pipeline, img, opset_version=11, output_path="temp_func.onnx")
|
||||
)
|
||||
actual_ids, actual_result = full_model_func(img.numpy())
|
||||
numpy.testing.assert_allclose(probabilities.numpy(), actual_result, rtol=1e-3)
|
||||
self.assertEqual(ids[0, 0].item(), 953) # 953 is pineapple class id in the imagenet dataset
|
||||
self.assertEqual(
|
||||
ids[0, 0].item(), 953
|
||||
) # 953 is pineapple class id in the imagenet dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,14 +1,8 @@
|
|||
from pathlib import Path
|
||||
import unittest
|
||||
import numpy as np
|
||||
from onnxruntime_extensions import PyOrtFunction, SegmentExtraction
|
||||
|
||||
|
||||
def _get_test_data_file(*sub_dirs):
|
||||
test_dir = Path(__file__).parent
|
||||
return str(test_dir.joinpath(*sub_dirs))
|
||||
|
||||
|
||||
def _run_segment_extraction(input, expect_position, expect_value):
|
||||
t2stc = PyOrtFunction.from_customop(SegmentExtraction)
|
||||
position, value = t2stc(input)
|
||||
|
|
|
@ -47,7 +47,7 @@ def _create_test_model_sentencepiece(
|
|||
],
|
||||
outputs=['out0', 'out1'],
|
||||
name='SentencepieceTokenizeOpName',
|
||||
domain='ai.onnx.contrib',
|
||||
domain=domain,
|
||||
))
|
||||
inputs = [
|
||||
mkv('model', onnx_proto.TensorProto.UINT8, [None]),
|
||||
|
@ -91,6 +91,7 @@ def _create_test_model_sentencepiece(
|
|||
model = make_onnx_model(graph)
|
||||
return model
|
||||
|
||||
|
||||
def _create_test_model_ragged_to_sparse(
|
||||
prefix, model_b64, domain='ai.onnx.contrib'):
|
||||
nodes = []
|
||||
|
@ -109,7 +110,7 @@ def _create_test_model_ragged_to_sparse(
|
|||
],
|
||||
outputs=['tokout0', 'tokout1'],
|
||||
name='SentencepieceTokenizeOpName',
|
||||
domain='ai.onnx.contrib',
|
||||
domain=domain,
|
||||
))
|
||||
inputs = [
|
||||
mkv('model', onnx_proto.TensorProto.UINT8, [None]),
|
||||
|
@ -194,7 +195,7 @@ def _create_test_model_ragged_to_dense(
|
|||
outputs=['tokout0', 'tokout1'],
|
||||
model=model_b64,
|
||||
name='SentencepieceTokenizeOpName',
|
||||
domain='ai.onnx.contrib',
|
||||
domain=domain,
|
||||
))
|
||||
inputs = [
|
||||
mkv('inputs', onnx_proto.TensorProto.STRING, [None]),
|
||||
|
|
|
@ -1,71 +1,78 @@
|
|||
# coding: utf-8
|
||||
import unittest
|
||||
import re
|
||||
import numpy as np
|
||||
from onnx import helper, onnx_pb as onnx_proto
|
||||
import onnxruntime as _ort
|
||||
from onnxruntime_extensions import (make_onnx_model,
|
||||
get_library_path as _get_library_path)
|
||||
from onnxruntime_extensions import (
|
||||
make_onnx_model,
|
||||
get_library_path as _get_library_path,
|
||||
)
|
||||
|
||||
|
||||
def _create_test_model_string_replace(prefix, domain='ai.onnx.contrib',
|
||||
global_replace=True, ignore_case=False):
|
||||
nodes = []
|
||||
nodes.append(
|
||||
helper.make_node('Identity', ['text'], ['id1']))
|
||||
nodes.append(
|
||||
helper.make_node('Identity', ['pattern'], ['id2']))
|
||||
nodes.append(
|
||||
helper.make_node('Identity', ['rewrite'], ['id3']))
|
||||
nodes.append(
|
||||
def _create_test_model_string_replace(
|
||||
prefix, domain="ai.onnx.contrib", global_replace=True, ignore_case=False
|
||||
):
|
||||
nodes = [
|
||||
helper.make_node("Identity", ["text"], ["id1"]),
|
||||
helper.make_node("Identity", ["pattern"], ["id2"]),
|
||||
helper.make_node("Identity", ["rewrite"], ["id3"]),
|
||||
helper.make_node(
|
||||
'%sStringECMARegexReplace' % prefix, ['id1', 'id2', 'id3'],
|
||||
['customout'], domain=domain,
|
||||
"%sStringECMARegexReplace" % prefix,
|
||||
["id1", "id2", "id3"],
|
||||
["customout"],
|
||||
domain=domain,
|
||||
global_replace=1 if global_replace else 0,
|
||||
ignore_case=1 if ignore_case else 0))
|
||||
ignore_case=1 if ignore_case else 0,
|
||||
),
|
||||
]
|
||||
|
||||
input0 = helper.make_tensor_value_info(
|
||||
'text', onnx_proto.TensorProto.STRING, [None, 1])
|
||||
"text", onnx_proto.TensorProto.STRING, [None, 1]
|
||||
)
|
||||
input1 = helper.make_tensor_value_info(
|
||||
'pattern', onnx_proto.TensorProto.STRING, [1])
|
||||
"pattern", onnx_proto.TensorProto.STRING, [1]
|
||||
)
|
||||
input2 = helper.make_tensor_value_info(
|
||||
'rewrite', onnx_proto.TensorProto.STRING, [1])
|
||||
"rewrite", onnx_proto.TensorProto.STRING, [1]
|
||||
)
|
||||
output0 = helper.make_tensor_value_info(
|
||||
'customout', onnx_proto.TensorProto.STRING, [None, 1])
|
||||
"customout", onnx_proto.TensorProto.STRING, [None, 1]
|
||||
)
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes, 'test0', [input0, input1, input2], [output0])
|
||||
graph = helper.make_graph(nodes, "test0", [input0, input1, input2], [output0])
|
||||
model = make_onnx_model(graph)
|
||||
return model
|
||||
|
||||
|
||||
def _create_test_model_string_regex_split(prefix, domain='ai.onnx.contrib'):
|
||||
def _create_test_model_string_regex_split(prefix, domain="ai.onnx.contrib"):
|
||||
nodes = []
|
||||
nodes.append(helper.make_node('Identity', ['input'], ['id1']))
|
||||
nodes.append(helper.make_node('Identity', ['pattern'], ['id2']))
|
||||
nodes.append(helper.make_node('Identity', ['keep_pattern'], ['id3']))
|
||||
nodes.append(helper.make_node("Identity", ["input"], ["id1"]))
|
||||
nodes.append(helper.make_node("Identity", ["pattern"], ["id2"]))
|
||||
nodes.append(helper.make_node("Identity", ["keep_pattern"], ["id3"]))
|
||||
nodes.append(
|
||||
helper.make_node(
|
||||
'%sStringECMARegexSplitWithOffsets' % prefix, ['id1', 'id2', 'id3'],
|
||||
['tokens', 'begins', 'ends', 'row_indices'], domain=domain))
|
||||
"%sStringECMARegexSplitWithOffsets" % prefix,
|
||||
["id1", "id2", "id3"],
|
||||
["tokens", "begins", "ends", "row_indices"],
|
||||
domain=domain,
|
||||
)
|
||||
)
|
||||
|
||||
input0 = helper.make_tensor_value_info(
|
||||
'input', onnx_proto.TensorProto.STRING, [])
|
||||
input1 = helper.make_tensor_value_info(
|
||||
'pattern', onnx_proto.TensorProto.STRING, [])
|
||||
input0 = helper.make_tensor_value_info("input", onnx_proto.TensorProto.STRING, [])
|
||||
input1 = helper.make_tensor_value_info("pattern", onnx_proto.TensorProto.STRING, [])
|
||||
input2 = helper.make_tensor_value_info(
|
||||
'keep_pattern', onnx_proto.TensorProto.STRING, [])
|
||||
output0 = helper.make_tensor_value_info(
|
||||
'tokens', onnx_proto.TensorProto.STRING, [])
|
||||
output1 = helper.make_tensor_value_info(
|
||||
'begins', onnx_proto.TensorProto.INT64, [])
|
||||
output2 = helper.make_tensor_value_info(
|
||||
'ends', onnx_proto.TensorProto.INT64, [])
|
||||
"keep_pattern", onnx_proto.TensorProto.STRING, []
|
||||
)
|
||||
output0 = helper.make_tensor_value_info("tokens", onnx_proto.TensorProto.STRING, [])
|
||||
output1 = helper.make_tensor_value_info("begins", onnx_proto.TensorProto.INT64, [])
|
||||
output2 = helper.make_tensor_value_info("ends", onnx_proto.TensorProto.INT64, [])
|
||||
output3 = helper.make_tensor_value_info(
|
||||
'row_indices', onnx_proto.TensorProto.INT64, [])
|
||||
"row_indices", onnx_proto.TensorProto.INT64, []
|
||||
)
|
||||
|
||||
graph = helper.make_graph(nodes, 'test0', [input0, input1, input2],
|
||||
[output0, output1, output2, output3])
|
||||
graph = helper.make_graph(
|
||||
nodes, "test0", [input0, input1, input2], [output0, output1, output2, output3]
|
||||
)
|
||||
model = make_onnx_model(graph)
|
||||
return model
|
||||
|
||||
|
@ -74,75 +81,82 @@ class TestStringECMARegex(unittest.TestCase):
|
|||
def test_string_replace_cc(self):
|
||||
so = _ort.SessionOptions()
|
||||
so.register_custom_ops_library(_get_library_path())
|
||||
onnx_model = _create_test_model_string_replace('')
|
||||
onnx_model = _create_test_model_string_replace("")
|
||||
self.assertIn('op_type: "StringECMARegexReplace"', str(onnx_model))
|
||||
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
|
||||
pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
|
||||
rewrite = np.array([r'static PyObject* py_$1(void) {'])
|
||||
text = np.array([['def myfunc():'], ['def dummy():']])
|
||||
txout = sess.run(
|
||||
None, {'text': text, 'pattern': pattern, 'rewrite': rewrite})
|
||||
exp = [['static PyObject* py_myfunc(void) {'],
|
||||
['static PyObject* py_dummy(void) {']]
|
||||
pattern = np.array([r"def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):"])
|
||||
rewrite = np.array([r"static PyObject* py_$1(void) {"])
|
||||
text = np.array([["def myfunc():"], ["def dummy():"]])
|
||||
txout = sess.run(None, {"text": text, "pattern": pattern, "rewrite": rewrite})
|
||||
exp = [
|
||||
["static PyObject* py_myfunc(void) {"],
|
||||
["static PyObject* py_dummy(void) {"],
|
||||
]
|
||||
self.assertEqual(exp, txout[0].tolist())
|
||||
|
||||
def test_string_replace_cc_first(self):
|
||||
so = _ort.SessionOptions()
|
||||
so.register_custom_ops_library(_get_library_path())
|
||||
onnx_model = _create_test_model_string_replace(
|
||||
'', global_replace=False)
|
||||
onnx_model = _create_test_model_string_replace("", global_replace=False)
|
||||
self.assertIn('op_type: "StringECMARegexReplace"', str(onnx_model))
|
||||
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
|
||||
pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
|
||||
rewrite = np.array([r'static PyObject* py_$1(void) {'])
|
||||
text = np.array([['def myfunc():def myfunc():'],
|
||||
['def dummy():def dummy():']])
|
||||
txout = sess.run(
|
||||
None, {'text': text, 'pattern': pattern, 'rewrite': rewrite})
|
||||
exp = [['static PyObject* py_myfunc(void) {def myfunc():'],
|
||||
['static PyObject* py_dummy(void) {def dummy():']]
|
||||
pattern = np.array([r"def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):"])
|
||||
rewrite = np.array([r"static PyObject* py_$1(void) {"])
|
||||
text = np.array([["def myfunc():def myfunc():"], ["def dummy():def dummy():"]])
|
||||
txout = sess.run(None, {"text": text, "pattern": pattern, "rewrite": rewrite})
|
||||
exp = [
|
||||
["static PyObject* py_myfunc(void) {def myfunc():"],
|
||||
["static PyObject* py_dummy(void) {def dummy():"],
|
||||
]
|
||||
self.assertEqual(exp, txout[0].tolist())
|
||||
|
||||
def test_string_replace_cc_x2(self):
|
||||
so = _ort.SessionOptions()
|
||||
so.register_custom_ops_library(_get_library_path())
|
||||
onnx_model = _create_test_model_string_replace('')
|
||||
onnx_model = _create_test_model_string_replace("")
|
||||
self.assertIn('op_type: "StringECMARegexReplace"', str(onnx_model))
|
||||
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
|
||||
pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
|
||||
rewrite = np.array([r'static PyObject* py_$1(void) {'])
|
||||
text = np.array([['def myfunc():'], ['def dummy():' * 2]])
|
||||
txout = sess.run(
|
||||
None, {'text': text, 'pattern': pattern, 'rewrite': rewrite})
|
||||
exp = [['static PyObject* py_myfunc(void) {'],
|
||||
['static PyObject* py_dummy(void) {' * 2]]
|
||||
pattern = np.array([r"def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):"])
|
||||
rewrite = np.array([r"static PyObject* py_$1(void) {"])
|
||||
text = np.array([["def myfunc():"], ["def dummy():" * 2]])
|
||||
txout = sess.run(None, {"text": text, "pattern": pattern, "rewrite": rewrite})
|
||||
exp = [
|
||||
["static PyObject* py_myfunc(void) {"],
|
||||
["static PyObject* py_dummy(void) {" * 2],
|
||||
]
|
||||
self.assertEqual(exp, txout[0].tolist())
|
||||
|
||||
def test_string_replace_uncased(self):
|
||||
so = _ort.SessionOptions()
|
||||
so.register_custom_ops_library(_get_library_path())
|
||||
onnx_model = _create_test_model_string_replace('', 'ai.onnx.contrib', True, True)
|
||||
onnx_model = _create_test_model_string_replace(
|
||||
"", "ai.onnx.contrib", True, True
|
||||
)
|
||||
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
|
||||
|
||||
pattern = np.array(
|
||||
[r"([a-z0-9!#$%&'*+\/=?^_`{|.}~-]+@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?)"])
|
||||
rewrite = np.array([r'EMAIL'])
|
||||
text = np.array([['The email is Micro$oft@microsoft.com, G00Gle@google.Internal.com'],
|
||||
['The email is wangli@51biz.com 1234556@qq.com']])
|
||||
[
|
||||
r"([a-z0-9!#$%&'*+\/=?^_`{|.}~-]+@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?)"
|
||||
]
|
||||
)
|
||||
rewrite = np.array([r"EMAIL"])
|
||||
text = np.array(
|
||||
[
|
||||
["The email is Micro$oft@microsoft.com, G00Gle@google.Internal.com"],
|
||||
["The email is wangli@51biz.com 1234556@qq.com"],
|
||||
]
|
||||
)
|
||||
|
||||
txout = sess.run(
|
||||
None, {'text': text, 'pattern': pattern, 'rewrite': rewrite})
|
||||
txout = sess.run(None, {"text": text, "pattern": pattern, "rewrite": rewrite})
|
||||
|
||||
exp = [['The email is EMAIL, EMAIL'],
|
||||
['The email is EMAIL EMAIL']]
|
||||
exp = [["The email is EMAIL, EMAIL"], ["The email is EMAIL EMAIL"]]
|
||||
self.assertEqual(exp, txout[0].tolist())
|
||||
|
||||
def test_string_regex_split_cc(self):
|
||||
so = _ort.SessionOptions()
|
||||
so.register_custom_ops_library(_get_library_path())
|
||||
onnx_model = _create_test_model_string_regex_split('')
|
||||
self.assertIn('op_type: "StringECMARegexSplitWithOffsets"',
|
||||
str(onnx_model))
|
||||
onnx_model = _create_test_model_string_regex_split("")
|
||||
self.assertIn('op_type: "StringECMARegexSplitWithOffsets"', str(onnx_model))
|
||||
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
|
||||
input = np.array(["hello there", "hello there"])
|
||||
pattern = np.array(["(\\s)"])
|
||||
|
@ -150,11 +164,10 @@ class TestStringECMARegex(unittest.TestCase):
|
|||
# keep_pattern not empty
|
||||
keep_pattern = np.array(["\\s"])
|
||||
txout = sess.run(
|
||||
None, {'input': input, 'pattern': pattern,
|
||||
'keep_pattern': keep_pattern})
|
||||
None, {"input": input, "pattern": pattern, "keep_pattern": keep_pattern}
|
||||
)
|
||||
|
||||
exp_text = np.array(['hello', ' ', 'there',
|
||||
'hello', ' ', ' ', 'there'])
|
||||
exp_text = np.array(["hello", " ", "there", "hello", " ", " ", "there"])
|
||||
exp_begins = np.array([0, 5, 6, 0, 5, 6, 7])
|
||||
exp_ends = np.array([5, 6, 11, 5, 6, 7, 12])
|
||||
exp_rows = np.array([0, 3, 7])
|
||||
|
@ -165,15 +178,22 @@ class TestStringECMARegex(unittest.TestCase):
|
|||
self.assertEqual(exp_rows.tolist(), txout[3].tolist())
|
||||
|
||||
try:
|
||||
from tensorflow_text.python.ops.regex_split_ops import gen_regex_split_ops as lib_gen_regex_split_ops
|
||||
from tensorflow_text.python.ops.regex_split_ops import (
|
||||
gen_regex_split_ops as lib_gen_regex_split_ops,
|
||||
)
|
||||
|
||||
use_tf = True
|
||||
except ImportError:
|
||||
use_tf = False
|
||||
|
||||
if use_tf:
|
||||
tf_tokens, tf_begins, tf_ends, tf_rows = lib_gen_regex_split_ops.regex_split_with_offsets(input, "(\\s)",
|
||||
"\\s")
|
||||
ltk = [s.decode('utf-8') for s in tf_tokens.numpy()]
|
||||
(
|
||||
tf_tokens,
|
||||
tf_begins,
|
||||
tf_ends,
|
||||
tf_rows,
|
||||
) = lib_gen_regex_split_ops.regex_split_with_offsets(input, "(\\s)", "\\s")
|
||||
ltk = [s.decode("utf-8") for s in tf_tokens.numpy()]
|
||||
self.assertEqual(ltk, txout[0].tolist())
|
||||
self.assertEqual(tf_begins.numpy().tolist(), txout[1].tolist())
|
||||
self.assertEqual(tf_ends.numpy().tolist(), txout[2].tolist())
|
||||
|
@ -182,9 +202,9 @@ class TestStringECMARegex(unittest.TestCase):
|
|||
# keep_pattern empty
|
||||
keep_pattern = np.array([""])
|
||||
txout = sess.run(
|
||||
None, {'input': input, 'pattern': pattern,
|
||||
'keep_pattern': keep_pattern})
|
||||
exp_text = np.array(['hello', 'there', 'hello', 'there'])
|
||||
None, {"input": input, "pattern": pattern, "keep_pattern": keep_pattern}
|
||||
)
|
||||
exp_text = np.array(["hello", "there", "hello", "there"])
|
||||
exp_begins = np.array([0, 6, 0, 7])
|
||||
exp_ends = np.array([5, 11, 5, 12])
|
||||
exp_rows = np.array([0, 2, 4])
|
||||
|
@ -195,9 +215,13 @@ class TestStringECMARegex(unittest.TestCase):
|
|||
self.assertEqual(exp_rows.tolist(), txout[3].tolist())
|
||||
|
||||
if use_tf:
|
||||
tf_tokens, tf_begins, tf_ends, tf_rows = lib_gen_regex_split_ops.regex_split_with_offsets(input, "(\\s)",
|
||||
"")
|
||||
ltk = [s.decode('utf-8') for s in tf_tokens.numpy()]
|
||||
(
|
||||
tf_tokens,
|
||||
tf_begins,
|
||||
tf_ends,
|
||||
tf_rows,
|
||||
) = lib_gen_regex_split_ops.regex_split_with_offsets(input, "(\\s)", "")
|
||||
ltk = [s.decode("utf-8") for s in tf_tokens.numpy()]
|
||||
self.assertEqual(ltk, txout[0].tolist())
|
||||
self.assertEqual(tf_begins.numpy().tolist(), txout[1].tolist())
|
||||
self.assertEqual(tf_ends.numpy().tolist(), txout[2].tolist())
|
||||
|
|
|
@ -1,19 +1,9 @@
|
|||
from pathlib import Path
|
||||
import unittest
|
||||
import numpy as np
|
||||
from onnxruntime_extensions import util
|
||||
from onnxruntime_extensions import PyOrtFunction, StringMapping
|
||||
|
||||
|
||||
def _get_test_data_file(*sub_dirs):
|
||||
test_dir = Path(__file__).parent
|
||||
return str(test_dir.joinpath(*sub_dirs))
|
||||
|
||||
|
||||
def read_file(path):
|
||||
with open(path) as file_content:
|
||||
return file_content.read()
|
||||
|
||||
|
||||
def _run_string_mapping(input, output, map):
|
||||
v2str = PyOrtFunction.from_customop(StringMapping, map=map)
|
||||
result = v2str(input)
|
||||
|
@ -30,7 +20,7 @@ class TestStringMapping(unittest.TestCase):
|
|||
def test_string_mapping_case2(self):
|
||||
_run_string_mapping(input=np.array(["a", "b", "c", "excel spreadsheet"]),
|
||||
output=np.array(["a", "b", "c", "excel"]),
|
||||
map=read_file(_get_test_data_file("data", "string_mapping.txt")))
|
||||
map=util.read_file(util.get_test_data_file("data", "string_mapping.txt")))
|
||||
|
||||
def test_string_mapping_case3(self):
|
||||
_run_string_mapping(
|
||||
|
@ -39,7 +29,7 @@ class TestStringMapping(unittest.TestCase):
|
|||
"powerpointpresentation"]),
|
||||
output=np.array(
|
||||
["a", "b", "c", "excel", "image", "imag", "ppt", "powerpointpresentation"]),
|
||||
map=read_file(_get_test_data_file("data", "string_mapping.txt")))
|
||||
map=util.read_file(util.get_test_data_file("data", "string_mapping.txt")))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
# coding: utf-8
|
||||
import io
|
||||
import json
|
||||
import sys
|
||||
import unittest
|
||||
|
@ -1168,6 +1167,7 @@ class TestPythonOpString(unittest.TestCase):
|
|||
split_unknown_characters=False))
|
||||
|
||||
ltk = [s.decode('utf-8') for s in tf_tokens.numpy()]
|
||||
txout = cc_txout
|
||||
check(ltk, txout[0])
|
||||
check(tf_rows.numpy(), txout[1])
|
||||
check(tf_begins.numpy(), txout[2])
|
||||
|
|
|
@ -10,24 +10,48 @@ def _run_string_to_vector(input, output, map, unk):
|
|||
|
||||
|
||||
class TestStringToVector(unittest.TestCase):
|
||||
|
||||
def test_string_to_vector1(self):
|
||||
_run_string_to_vector(input=np.array(["a", "b", "c", "unknown_word"]),
|
||||
output=np.array([[0], [2], [3], [-1]], dtype=np.int64),
|
||||
map={"a": [0], "b": [2], "c": [3]},
|
||||
unk=[-1])
|
||||
_run_string_to_vector(
|
||||
input=np.array(["a", "b", "c", "unknown_word"]),
|
||||
output=np.array([[0], [2], [3], [-1]], dtype=np.int64),
|
||||
map={"a": [0], "b": [2], "c": [3]},
|
||||
unk=[-1],
|
||||
)
|
||||
|
||||
def test_string_to_vector2(self):
|
||||
_run_string_to_vector(input=np.array(["a", "b", "c", "unknown_word"]),
|
||||
output=np.array([[0, 1, 2], [1, 2, 3], [2, 3, 4], [-1, -1, -1]], dtype=np.int64),
|
||||
map={"a": [0, 1, 2], "b": [1, 2, 3], "c": [2, 3, 4]},
|
||||
unk=[-1, -1, -1])
|
||||
_run_string_to_vector(
|
||||
input=np.array(["a", "b", "c", "unknown_word"]),
|
||||
output=np.array(
|
||||
[[0, 1, 2], [1, 2, 3], [2, 3, 4], [-1, -1, -1]], dtype=np.int64
|
||||
),
|
||||
map={"a": [0, 1, 2], "b": [1, 2, 3], "c": [2, 3, 4]},
|
||||
unk=[-1, -1, -1],
|
||||
)
|
||||
|
||||
def test_string_to_vector3(self):
|
||||
_run_string_to_vector(input=np.array(["a", "b", "c", "unknown_word", "你好", "下午", "测试"]),
|
||||
output=np.array([[0, 1, 2], [1, 2, 3], [2, 3, 4], [-1, -1, -1], [6, 6, 6], [7, 8, 9], [-1, -1, -1]], dtype=np.int64),
|
||||
map={"a": [0, 1, 2], "b": [1, 2, 3], "c": [2, 3, 4], "你好": [6, 6, 6], "下午": [7, 8, 9]},
|
||||
unk=[-1, -1, -1])
|
||||
_run_string_to_vector(
|
||||
input=np.array(["a", "b", "c", "unknown_word", "你好", "下午", "测试"]),
|
||||
output=np.array(
|
||||
[
|
||||
[0, 1, 2],
|
||||
[1, 2, 3],
|
||||
[2, 3, 4],
|
||||
[-1, -1, -1],
|
||||
[6, 6, 6],
|
||||
[7, 8, 9],
|
||||
[-1, -1, -1],
|
||||
],
|
||||
dtype=np.int64,
|
||||
),
|
||||
map={
|
||||
"a": [0, 1, 2],
|
||||
"b": [1, 2, 3],
|
||||
"c": [2, 3, 4],
|
||||
"你好": [6, 6, 6],
|
||||
"下午": [7, 8, 9],
|
||||
},
|
||||
unk=[-1, -1, -1],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -7,16 +7,14 @@ import torch
|
|||
import torchvision
|
||||
import onnxruntime as _ort
|
||||
|
||||
from onnx import load
|
||||
from torch.onnx import register_custom_op_symbolic
|
||||
from onnxruntime_extensions import (
|
||||
PyOp,
|
||||
onnx_op,
|
||||
PyOrtFunction,
|
||||
hook_model_op,
|
||||
get_library_path as _get_library_path)
|
||||
|
||||
from onnxruntime_extensions import PyOrtFunction
|
||||
|
||||
|
||||
def my_inverse(g, self):
|
||||
return g.op("ai.onnx.contrib::Inverse", self)
|
||||
|
@ -78,7 +76,7 @@ class TestPyTorchCustomOp(unittest.TestCase):
|
|||
# Export model to ONNX
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(CustomInverse(), (x0, x1), f, opset_version=12)
|
||||
onnx_model = load(io.BytesIO(f.getvalue()))
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
self.assertIn('domain: "ai.onnx.contrib"', str(onnx_model))
|
||||
|
||||
model = CustomInverse()
|
||||
|
|
Загрузка…
Ссылка в новой задаче