2020-10-23 11:11:11 +03:00
|
|
|
import io
|
2021-02-10 22:02:06 +03:00
|
|
|
import onnx
|
2020-10-23 11:11:11 +03:00
|
|
|
import numpy
|
2021-02-10 22:02:06 +03:00
|
|
|
import unittest
|
|
|
|
import platform
|
2021-04-28 20:56:30 +03:00
|
|
|
import torch
|
|
|
|
import torchvision
|
2021-02-10 22:02:06 +03:00
|
|
|
import onnxruntime as _ort
|
2020-10-23 11:11:11 +03:00
|
|
|
|
Update CI build workflow matrix
Upgrade the onnxruntime headers from v1.6 to v1.9.
Update the workflow matrix so it is consistent across platforms and uses
newer versions of the dependencies. Currently supported matrix:
+------------+------------------------+-----------------------+-----------------------+----------------------+
|Python | 3.7 | 3.8 | 3.9 | 3.10 |
+------------+------------------------+-----------------------+-----------------------+----------------------+
|Onnxruntime | 1.9.0 (Sept 22, 2021) | 1.10.0 (Dec 7, 2021) | 1.11.0 (Mar 26, 2022) | 1.12.1 (Aug 4, 2022) |
|Torch | 1.9.1 (Sept 22, 2021) | 1.10.0 (Oct 21, 2021) | 1.11.0 (Mar 10, 2022) | 1.12.1 (Aug 5, 2022) |
|TorchVision | 0.10.1 (Jun 15, 2021) | 0.11.1 (Oct 21, 2021) | 0.12.0 (Mar 10, 2022) | 0.13.1 (Aug 5, 2022) |
|TorchAudio | 0.9.0 (Jun 15, 2021) | 0.10.0 (Oct 21, 2021) | 0.11.0 (Mar 10, 2022) | 0.12.1 (Aug 5, 2022) |
+------------+------------------------+-----------------------+-----------------------+----------------------+
Release versions strictly follow the convention of onnxruntime being one
release ahead of all its dependencies.
2022-09-02 04:03:27 +03:00
|
|
|
from distutils.version import LooseVersion
|
2020-10-23 11:11:11 +03:00
|
|
|
from torch.onnx import register_custom_op_symbolic
|
2021-05-12 22:02:57 +03:00
|
|
|
from onnxruntime_extensions import (
|
2021-03-12 21:39:21 +03:00
|
|
|
PyOp,
|
2020-10-23 11:11:11 +03:00
|
|
|
onnx_op,
|
2022-06-20 22:38:06 +03:00
|
|
|
PyOrtFunction,
|
2021-02-10 22:02:06 +03:00
|
|
|
hook_model_op,
|
2020-10-23 11:11:11 +03:00
|
|
|
get_library_path as _get_library_path)
|
|
|
|
|
|
|
|
|
|
|
|
def my_inverse(g, self):
    """ONNX symbolic for torch's matrix inverse.

    Emits a single ``ai.onnx.contrib::Inverse`` node on the exporter's graph
    context ``g`` with ``self`` (the value being inverted) as its only input,
    and returns whatever ``g.op`` produces.
    """
    contrib_op = "ai.onnx.contrib::Inverse"
    node = g.op(contrib_op, self)
    return node
|
|
|
|
|
|
|
|
|
2023-02-07 05:23:56 +03:00
|
|
|
# Judging by this version gate, torch >= 1.13 exports torch.inverse through
# aten::linalg_inv while older releases use aten::inverse, so hook whichever
# symbol the installed torch emits. A leading '::' selects torch's default
# (aten) namespace.
register_custom_op_symbolic(
    symbolic_name=('::linalg_inv'
                   if LooseVersion(torch.__version__) >= LooseVersion("1.13")
                   else '::inverse'),
    symbolic_fn=my_inverse,
    opset_version=1)
|
2021-04-28 20:56:30 +03:00
|
|
|
|
|
|
|
|
|
|
|
def my_all(g, self):
    """ONNX symbolic for ``torch.all``.

    Emits a single ``ai.onnx.contrib::All`` node on the exporter's graph
    context ``g`` with ``self`` as its only input, and returns whatever
    ``g.op`` produces.
    """
    contrib_op = "ai.onnx.contrib::All"
    node = g.op(contrib_op, self)
    return node
|
|
|
|
|
|
|
|
|
|
|
|
# Route torch.all (aten::all; the leading '::' selects the default aten
# namespace) to the my_all symbolic, registered for opset version 1.
register_custom_op_symbolic(
    symbolic_name='::all', symbolic_fn=my_all, opset_version=1)
|
|
|
|
|
|
|
|
|
2021-08-26 21:04:45 +03:00
|
|
|
class CustomTorchOp(torch.autograd.Function):
    """Autograd function that adds 10 to its input.

    Eager execution runs ``forward``. During ONNX export, the ``symbolic``
    hook emits a ``torchcustom::Add10`` node in place of the traced
    arithmetic.
    """

    @staticmethod
    def forward(ctx, x):
        """Eager computation: return ``x`` shifted by 10 (``ctx`` is unused)."""
        shifted = x + 10
        return shifted

    @staticmethod
    def symbolic(g, input):
        """Export hook: emit one ``torchcustom::Add10`` node fed by ``input``."""
        node = g.op("torchcustom::Add10", input)
        return node
|
|
|
|
|
|
|
|
|
2020-10-23 11:11:11 +03:00
|
|
|
class CustomInverse(torch.nn.Module):
    """Two-output test model that exercises both custom export paths.

    Output 0 is ``torch.inverse(x)`` passed through ``CustomTorchOp`` (adds
    10). Output 1 is ``torch.all(y)``.
    """

    def forward(self, x, y):
        inverted = torch.inverse(x)
        shifted = CustomTorchOp.apply(inverted)
        all_true = torch.all(y)
        return shifted, all_true
|
2020-10-23 11:11:11 +03:00
|
|
|
|
|
|
|
|
2020-10-30 13:20:18 +03:00
|
|
|
class TestPyTorchCustomOp(unittest.TestCase):
    """End-to-end tests for PyTorch -> ONNX export backed by Python custom ops.

    ``setUpClass`` declares the Python kernels (via ``onnx_op``) for the custom
    ONNX nodes emitted by the symbolics registered in this module.
    ``test_custom_pythonop_pytorch`` compares ONNX Runtime results against
    eager PyTorch. ``test_pyop_hooking`` checks that a Python hook inserted in
    front of a node actually fires.
    """

    # Set to True by ``on_hook``; ``test_pyop_hooking`` resets it right before
    # running the session so a stale value cannot make the test pass.
    _hooked = False

    @classmethod
    def setUpClass(cls):
        """Declare the Python implementations of the custom ops used below.

        The decorated local functions are never called here. Presumably
        ``onnx_op`` registers them as a side effect of decoration, since they
        are otherwise unused.
        """

        @onnx_op(op_type="Inverse")
        def inverse(x):
            # The user's custom op implementation goes here:
            return numpy.linalg.inv(x)

        @onnx_op(op_type='All', inputs=[PyOp.dt_bool], outputs=[PyOp.dt_bool])
        def op_all(x):
            return numpy.all(x)

        @onnx_op(op_type='torchcustom::Add10')
        def op_add10(x):
            return x + 10

    def test_custom_pythonop_pytorch(self):
        """Export ``CustomInverse`` and check ORT matches eager PyTorch on both outputs."""
        # The custom symbolics are registered at module level, following the pattern:
        # register_custom_op_symbolic(
        #   '<namespace>::inverse', my_inverse, <opset_version>)

        # Shift the random matrix by 3*I so it is comfortably invertible; an
        # ill-conditioned draw would make the float32 inverse comparison flaky.
        x0 = torch.randn(3, 3) + 3 * torch.eye(3)
        x1 = torch.tensor([True, False])
        model = CustomInverse()

        # Export model to ONNX
        f = io.BytesIO()
        torch.onnx.export(model, (x0, x1), f, opset_version=12)
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        self.assertIn('domain: "ai.onnx.contrib"', str(onnx_model))

        # NOTE(review): written to the current working directory and never
        # removed; looks like a debugging artifact. Confirm nothing else reads
        # it before dropping it.
        onnx.save_model(onnx_model, 'temp_pytorchcustomop.onnx')
        pt_outputs = model(x0, x1)

        run_ort = PyOrtFunction.from_model(onnx_model)
        ort_outputs = run_ort(x0.numpy(), x1.numpy())

        # Validate PyTorch and ONNX Runtime results.
        # Output 0 is float (inverse + 10) and needs a tolerance compare.
        numpy.testing.assert_allclose(pt_outputs[0].numpy(),
                                      ort_outputs[0], rtol=1e-03, atol=1e-05)
        # Output 1 (torch.all -> custom All op) is boolean and must match exactly.
        numpy.testing.assert_array_equal(pt_outputs[1].numpy(), ort_outputs[1])

    @staticmethod
    def on_hook(*x):
        """Hook callback: record that it ran and pass its inputs through unchanged."""
        TestPyTorchCustomOp._hooked = True
        return x

    # NOTE(review): LooseVersion("1.11.0") > LooseVersion("1.11") is True, so
    # ORT 1.11.x is skipped as well, not only releases after 1.11. Confirm
    # that is intended.
    @unittest.skipIf(
        (platform.system() == 'Darwin') or (LooseVersion(_ort.__version__) > LooseVersion("1.11")),
        "pytorch.onnx crashed for this case! and test asserts with higher versions of ort"
    )
    def test_pyop_hooking(self):  # type: () -> None
        """Hook a Conv node of MobileNetV2 with ``on_hook`` and check it fires in ORT."""
        # NOTE(review): ``pretrained`` is deprecated from torchvision 0.13 in
        # favour of ``weights``. It is kept because older torchvision releases
        # in the CI matrix do not accept ``weights``.
        model = torchvision.models.mobilenet_v2(pretrained=False)
        x = torch.rand(1, 3, 224, 224)
        with io.BytesIO() as f:
            torch.onnx.export(model, (x, ), f)
            model = onnx.load_model_from_string(f.getvalue())

        # The hook target is chosen by position in the exported graph, so first
        # make sure that position really holds a Conv node.
        target = model.graph.node[5]
        self.assertEqual(target.op_type, 'Conv')
        # NOTE(review): the three dt_float entries presumably describe the
        # Conv's three float inputs (data, weight, bias). Confirm against
        # hook_model_op.
        hkd_model = hook_model_op(model, target.name, TestPyTorchCustomOp.on_hook, [PyOp.dt_float] * 3)

        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        sess = _ort.InferenceSession(hkd_model.SerializeToString(), so)
        # Ask the session for the graph input name rather than hard-coding the
        # exporter's auto-generated 'input.1', which may differ between torch
        # versions.
        input_name = sess.get_inputs()[0].name
        TestPyTorchCustomOp._hooked = False
        sess.run(None, {input_name: x.numpy()})
        self.assertTrue(TestPyTorchCustomOp._hooked)
|
|
|
|
|
2020-10-23 11:11:11 +03:00
|
|
|
|
2020-10-30 13:20:18 +03:00
|
|
|
if __name__ == '__main__':
    # Executed directly as a script: discover and run the TestCases above.
    unittest.main()
|