# ort-customops/test/test_torch_ops.py
#
# 126 lines
# 3.6 KiB
# Python

import io
import onnx
import numpy
import unittest
import platform
import torch
import torchvision
import onnxruntime as _ort
from distutils.version import LooseVersion
from torch.onnx import register_custom_op_symbolic
from onnxruntime_extensions import (
PyOp,
onnx_op,
PyOrtFunction,
hook_model_op,
get_library_path as _get_library_path)
def my_inverse(g, self):
    """Symbolic function that emits a custom ``Inverse`` node.

    Args:
        g: Graph-like object; only its ``op`` method is used.
        self: The value passed through as the node's single input.

    Returns:
        The result of ``g.op`` for an ``ai.onnx.contrib::Inverse`` node.
    """
    inverse_node = g.op("ai.onnx.contrib::Inverse", self)
    return inverse_node
# Register my_inverse under the symbolic name that matches the installed
# PyTorch: '::inverse' before 1.13, '::linalg_inv' from 1.13 onwards.
# The last argument is the opset version.
if LooseVersion(torch.__version__) < LooseVersion("1.13"):
    register_custom_op_symbolic('::inverse', my_inverse, 1)
else:
    register_custom_op_symbolic('::linalg_inv', my_inverse, 1)
def my_all(g, self):
    """Symbolic function that emits a custom ``All`` node.

    Args:
        g: Graph-like object; only its ``op`` method is used.
        self: The value passed through as the node's single input.

    Returns:
        The result of ``g.op`` for an ``ai.onnx.contrib::All`` node.
    """
    all_node = g.op("ai.onnx.contrib::All", self)
    return all_node
# Register my_all as the symbolic function for '::all'; the last argument is
# the opset version.
register_custom_op_symbolic('::all', my_all, 1)
class CustomTorchOp(torch.autograd.Function):
    """Autograd function that adds 10 to its input.

    ``symbolic`` maps it to a ``torchcustom::Add10`` node.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``x + 10``; ``ctx`` is not used."""
        shifted = x + 10
        return shifted

    @staticmethod
    def symbolic(g, input):
        """Emit a ``torchcustom::Add10`` node taking ``input`` as its only input."""
        add10_node = g.op("torchcustom::Add10", input)
        return add10_node
class CustomInverse(torch.nn.Module):
    """Module that chains ``torch.inverse``, ``CustomTorchOp`` and ``torch.all``."""

    def forward(self, x, y):
        """Return ``(CustomTorchOp.apply(torch.inverse(x)), torch.all(y))``."""
        inverted = torch.inverse(x)
        shifted = CustomTorchOp.apply(inverted)
        all_flag = torch.all(y)
        return shifted, all_flag
class TestPyTorchCustomOp(unittest.TestCase):
    """Tests for PyTorch models exported to ONNX with Python-implemented custom ops."""

    # Set to True by on_hook so test_pyop_hooking can tell whether the hook ran.
    _hooked = False

    @classmethod
    def setUpClass(cls):
        """Register the Python implementations of the custom ops used by the tests."""

        @onnx_op(op_type="Inverse")
        def inverse(x):
            # the user custom op implementation here:
            return numpy.linalg.inv(x)

        @onnx_op(op_type='All', inputs=[PyOp.dt_bool], outputs=[PyOp.dt_bool])
        def op_all(x):
            return numpy.all(x)

        @onnx_op(op_type='torchcustom::Add10')
        def op_add10(x):
            return x + 10

    def test_custom_pythonop_pytorch(self):
        """Export CustomInverse to ONNX and compare the first output with PyTorch's."""
        # Function-scope imports: only this test needs them.
        import os
        import tempfile

        # register_custom_op_symbolic(
        #   '<namespace>::inverse', my_inverse, <opset_version>)
        x0, x1 = torch.randn(3, 3), torch.tensor([True, False])

        # Export model to ONNX
        f = io.BytesIO()
        torch.onnx.export(CustomInverse(), (x0, x1), f, opset_version=12)
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        self.assertIn('domain: "ai.onnx.contrib"', str(onnx_model))

        model = CustomInverse()
        # Still exercise on-disk serialization, but inside a temporary
        # directory so the test leaves no file behind in the working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            onnx.save_model(
                onnx_model, os.path.join(tmp_dir, 'temp_pytorchcustomop.onnx'))
        pt_outputs = model(x0, x1)

        run_ort = PyOrtFunction.from_model(onnx_model)
        ort_outputs = run_ort(x0.numpy(), x1.numpy())

        # Validate PyTorch and ONNX Runtime results
        numpy.testing.assert_allclose(pt_outputs[0].numpy(),
                                      ort_outputs[0], rtol=1e-03, atol=1e-05)

    @staticmethod
    def on_hook(*x):
        """Hook callback: record that it was called and return its arguments unchanged."""
        TestPyTorchCustomOp._hooked = True
        return x

    @unittest.skipIf(
        (platform.system() == 'Darwin') or (LooseVersion(_ort.__version__) > LooseVersion("1.11")),
        "pytorch.onnx crashed for this case! and test asserts with higher versions of ort"
    )
    def test_pyop_hooking(self) -> None:
        """Hook node 5 of an exported MobileNetV2 and check the hook runs during inference."""
        model = torchvision.models.mobilenet_v2(pretrained=False)
        x = torch.rand(1, 3, 224, 224)
        with io.BytesIO() as f:
            torch.onnx.export(model, (x, ), f)
            model = onnx.load_model_from_string(f.getvalue())

            # assertEqual reports the actual op type if the graph layout differs.
            self.assertEqual(model.graph.node[5].op_type, 'Conv')
            hkd_model = hook_model_op(
                model,
                model.graph.node[5].name,
                TestPyTorchCustomOp.on_hook,
                [PyOp.dt_float] * 3)

            so = _ort.SessionOptions()
            so.register_custom_ops_library(_get_library_path())
            sess = _ort.InferenceSession(hkd_model.SerializeToString(), so)
            # Reset the flag so a previous run cannot satisfy the assertion.
            TestPyTorchCustomOp._hooked = False
            sess.run(None, {'input.1': x.numpy()})
            self.assertTrue(TestPyTorchCustomOp._hooked)
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()