2020-10-30 13:20:18 +03:00
|
|
|
import unittest
|
|
|
|
from onnx import load
|
2020-10-23 11:11:11 +03:00
|
|
|
import torch
|
|
|
|
import onnxruntime as _ort
|
|
|
|
import io
|
|
|
|
import numpy
|
|
|
|
|
|
|
|
from torch.onnx import register_custom_op_symbolic
|
2020-12-23 07:16:30 +03:00
|
|
|
from onnxruntime_customops import (
|
2020-10-23 11:11:11 +03:00
|
|
|
onnx_op,
|
|
|
|
get_library_path as _get_library_path)
|
|
|
|
|
|
|
|
|
|
|
|
def my_inverse(g, self):
    """Symbolic function that emits the contrib ``Inverse`` node.

    Args:
        g: graph-like object exposing ``op(name, *inputs)``.
        self: the value to be inverted, forwarded unchanged as the
            node's only input.

    Returns:
        Whatever ``g.op`` returns for the ``ai.onnx.contrib::Inverse`` node.
    """
    op_name = "ai.onnx.contrib::Inverse"
    return g.op(op_name, self)
|
|
|
|
|
|
|
|
|
|
|
|
class CustomInverse(torch.nn.Module):
    """Minimal module whose forward pass uses ``torch.inverse``."""

    def forward(self, x):
        """Return ``torch.inverse(x) + x``."""
        inverted = torch.inverse(x)
        return inverted + x
|
|
|
|
|
|
|
|
|
2020-10-30 13:20:18 +03:00
|
|
|
class TestPyTorchCustomOp(unittest.TestCase):
    """End-to-end check of a Python custom op exported from PyTorch.

    ``torch.inverse`` is exported through ``my_inverse`` as an
    ``ai.onnx.contrib::Inverse`` node, the exported model is run by
    ONNX Runtime with the custom-ops library loaded, and the output is
    compared with the result PyTorch computes for the same input.
    """

    @classmethod
    def setUpClass(cls):
        """Define the Python ``Inverse`` implementation via ``onnx_op``."""

        @onnx_op(op_type="Inverse")
        def inverse(x):
            # the user custom op implementation here:
            return numpy.linalg.inv(x)

    def test_custom_pythonop_pytorch(self):
        """Export ``CustomInverse`` and compare ONNX Runtime with PyTorch."""
        # Call shape:
        #   register_custom_op_symbolic(
        #       '<namespace>::inverse', my_inverse, <opset_version>)
        register_custom_op_symbolic('::inverse', my_inverse, 1)

        x = torch.randn(3, 3)
        model_inputs = (x,)
        model = CustomInverse()

        # Export model to ONNX
        f = io.BytesIO()
        torch.onnx.export(model, model_inputs, f)
        # Serialize once; the same bytes feed both the ONNX loader and the
        # ONNX Runtime session below.
        model_bytes = f.getvalue()
        onnx_model = load(io.BytesIO(model_bytes))
        self.assertIn('domain: "ai.onnx.contrib"', str(onnx_model))

        pt_outputs = model(x)

        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())

        # Run the exported model with ONNX Runtime
        ort_sess = _ort.InferenceSession(model_bytes, so)
        # Pair each declared session input with its tensor by position.
        ort_inputs = {
            node.name: tensor.cpu().numpy()
            for node, tensor in zip(ort_sess.get_inputs(), model_inputs)
        }
        ort_outputs = ort_sess.run(None, ort_inputs)

        # Validate PyTorch and ONNX Runtime results
        numpy.testing.assert_allclose(
            pt_outputs.cpu().numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)
|
2020-10-23 11:11:11 +03:00
|
|
|
|
|
|
|
|
2020-10-30 13:20:18 +03:00
|
|
|
# Run the test suite only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    unittest.main()
|