ort-customops/test/test_onnxprocess.py

90 lines
3.2 KiB
Python

import io
import onnx
import unittest
import platform
import torchvision
import numpy as np
from onnxruntime_extensions import PyOrtFunction, hook_model_op, PyOp
from onnxruntime_extensions.onnxprocess import torch_wrapper as torch
from onnxruntime_extensions.onnxprocess import trace_for_onnx, pyfunc_from_model
# NOTE(review): the message mentions Windows CI, but the condition skips
# Python 3.7 on every platform -- confirm whether that is intended.
@unittest.skipIf(platform.python_version_tuple()[0:2] == (
    '3', '7'), 'Windows CI pipeline failed on the version temporarily.')
class TestTorchE2E(unittest.TestCase):
    """End-to-end tests for tracing ``torch_wrapper`` code into ONNX models.

    Each test traces a small computation inside ``trace_for_onnx``, saves the
    traced graph as an ONNX model, runs that model through ``PyOrtFunction``
    and compares the result with the value computed while tracing.
    """

    @classmethod
    def setUpClass(cls):
        # MobileNetV2 with pretrained weights, shared by the class and used by
        # test_imagenet_postprocess. NOTE(review): fetching pretrained weights
        # may need network access on first use.
        cls.mobilenet = torchvision.models.mobilenet_v2(pretrained=True)
        # Filled in by on_hook with the first input of the hooked 'argmax' op.
        cls.argmax_input = None

    @staticmethod
    def on_hook(*x):
        """Record the first input of the hooked op and return the inputs.

        Stores ``x[0]`` on the class so a test can compare it against the
        value computed at trace time, then returns the received argument
        tuple as-is (presumably forwarded as the op's inputs by
        ``hook_model_op`` -- confirm).
        """
        TestTorchE2E.argmax_input = x[0]
        return x

    def test_range(self):
        """A traced loop that emits its iteration counter yields range(num)."""
        num = 10
        f = io.BytesIO()
        with trace_for_onnx(num, names=['count']) as tc_sess:
            num_in = tc_sess.get_inputs()[0]
            done = torch.tensor(True)
            st_0 = torch.tensor(0)
            cfg = torch.control_flow()
            # Each loop step yields the iteration counter followed by the
            # loop-carried state values, which are not needed here.
            for iter_num, *_ in cfg.loop(num_in, done, st_0):
                # `+ 0` presumably emits a fresh tensor instead of aliasing
                # the loop counter -- confirm before simplifying.
                cfg.flow_output(done, st_0, iter_num + 0)
            # The last value returned by finalize() is the scan output
            # collected from `iter_num + 0` above.
            *_, rout = cfg.finalize()
            tc_sess.save_as_onnx(f, rout)

        m = onnx.load_model_from_string(f.getvalue())
        # Debug dump of the traced model into the current working directory.
        onnx.save_model(m, 'temp_range.onnx')
        fu_m = PyOrtFunction.from_model(m)
        result = fu_m(num)
        np.testing.assert_array_equal(result, np.arange(num))

    def test_sequence(self):
        """The traced shape of a zeros tensor follows the string batch size."""
        input_text = ['test sentence', 'sentence 2']
        f = io.BytesIO()
        with trace_for_onnx(input_text, names=['in_text']) as tc_sess:
            tc_inputs = tc_sess.get_inputs()[0]
            batchsize = tc_inputs.size()[0]
            shape = [batchsize, 2]
            # The model output is the *size* of the zeros tensor, not its data.
            fuse_output = torch.zeros(*shape).size()
            tc_sess.save_as_onnx(f, fuse_output)

        m = onnx.load_model_from_string(f.getvalue())
        # Debug dump of the traced model into the current working directory.
        onnx.save_model(m, 'temp_test00.onnx')
        fu_m = PyOrtFunction.from_model(m)
        result = fu_m(input_text)
        np.testing.assert_array_equal(result, [2, 2])

    def test_imagenet_postprocess(self):
        """MobileNetV2 plus traced softmax/argmax post-processing matches eager.

        The core model is exported with ``torch.onnx.export``, wrapped with
        ``pyfunc_from_model`` and extended by tracing softmax + argmax. The
        resulting model's 'argmax' op is hooked so that its input can be
        compared against the probabilities computed at trace time.
        """
        mb_core_path = 'temp_mobilev2.onnx'
        mb_full_path = 'temp_mobilev2_full.onnx'
        dummy_input = torch.randn(10, 3, 224, 224)
        np_input = dummy_input.numpy()
        torch.onnx.export(self.mobilenet, dummy_input, mb_core_path, opset_version=11)
        mbnet2 = pyfunc_from_model(mb_core_path)

        with trace_for_onnx(dummy_input, names=['b10_input']) as tc_sess:
            scores = mbnet2(*tc_sess.get_inputs())
            probabilities = torch.softmax(scores, dim=1)
            batch_top1 = probabilities.argmax(dim=1)
            # Trace-time values kept for comparison with the hooked model run.
            np_probabilities = probabilities.numpy()
            np_output = batch_top1.numpy()
            tc_sess.save_as_onnx(mb_full_path, batch_top1)

        # Hook the 'argmax' op so on_hook captures its (float) input.
        hkdmdl = hook_model_op(onnx.load_model(mb_full_path), 'argmax', self.on_hook, [PyOp.dt_float])
        mbnet2_full = PyOrtFunction.from_model(hkdmdl)
        batch_top1_2 = mbnet2_full(np_input)
        # Fail clearly if the hook never ran, instead of comparing against None.
        self.assertIsNotNone(self.argmax_input, "the 'argmax' hook was not invoked")
        np.testing.assert_allclose(np_probabilities, self.argmax_input, rtol=1e-5)
        np.testing.assert_array_equal(batch_top1_2, np_output)
if __name__ == "__main__":
    # Run the tests in this module when executed directly as a script.
    unittest.main()