support the sequence tensor in pnp.ProcessingModule. (#197)
* support the sequence tensor in ProcessingModule. * version check * version check 2
This commit is contained in:
Родитель
459c4f7d61
Коммит
d0ff193eec
|
@ -4,7 +4,7 @@ import torch
|
|||
import numpy
|
||||
from torch.onnx import TrainingMode, export as _export
|
||||
from ._ortapi2 import OrtPyFunction
|
||||
from .pnp import ONNXModelUtils, ProcessingModule
|
||||
from .pnp import ONNXModelUtils, ProcessingModule, ProcessingScriptModule
|
||||
|
||||
|
||||
def _is_numpy_object(x):
|
||||
|
@ -92,6 +92,15 @@ class ONNXCompose:
|
|||
if post_m is not None:
|
||||
model_l.append(post_m)
|
||||
|
||||
if output_file is not None:
|
||||
# also output the pre/post-processing model for debugging
|
||||
idx = 0
|
||||
for _mdl in model_l:
|
||||
if _mdl is self.models and isinstance(_mdl, onnx.ModelProto):
|
||||
continue
|
||||
onnx.save_model(_mdl, "{}_sub{}.onnx".format(output_file[:-5], idx))
|
||||
idx += 1
|
||||
|
||||
full_m = ONNXModelUtils.join_models(*model_l, io_mapping=io_mapping)
|
||||
if output_file is not None:
|
||||
onnx.save_model(full_m, output_file)
|
||||
|
@ -115,16 +124,26 @@ class ONNXCompose:
|
|||
return all(_is_array(_x) for _x in x)
|
||||
return _is_numpy_object(x) and (not _is_numpy_string_type(x))
|
||||
|
||||
def _from_numpy(x):
|
||||
if isinstance(x, list):
|
||||
return [torch.from_numpy(_x) for _x in x]
|
||||
return torch.from_numpy(x)
|
||||
|
||||
# convert the raw value, and special handling for string.
|
||||
n_args = [numpy.array(_arg) if not _is_tensor(_arg) else _arg for _arg in args]
|
||||
n_args = [torch.from_numpy(_arg) if
|
||||
n_args = [numpy.array(_arg) if
|
||||
not (_is_tensor(_arg) or _is_array(_arg)) else _arg for _arg in args]
|
||||
n_args = [_from_numpy(_arg) if
|
||||
_is_array(_arg) else _arg for _arg in n_args]
|
||||
|
||||
self.pre_args = n_args
|
||||
inputs = [self.preprocessors.forward(*n_args)]
|
||||
flatten_inputs = []
|
||||
for _i in inputs:
|
||||
flatten_inputs += list(_i) if isinstance(_i, tuple) else [_i]
|
||||
if isinstance(self.preprocessors, ProcessingScriptModule):
|
||||
flatten_inputs = inputs
|
||||
else:
|
||||
flatten_inputs = []
|
||||
for _i in inputs:
|
||||
flatten_inputs += list(_i) if isinstance(_i, tuple) else [_i]
|
||||
|
||||
self.models_args = flatten_inputs
|
||||
if isinstance(self.models, torch.nn.Module):
|
||||
outputs = self.models.forward(*flatten_inputs)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from ._utils import ONNXModelUtils
|
||||
from ._base import ProcessingModule, CustomFunction
|
||||
from ._base import ProcessingModule, ProcessingScriptModule, CustomFunction
|
||||
from ._functions import * # noqa
|
||||
|
||||
from ._imagenet import PreMobileNet, PostMobileNet
|
||||
|
|
|
@ -42,6 +42,11 @@ class ProcessingModule(torch.nn.Module):
|
|||
return mdl
|
||||
|
||||
|
||||
class ProcessingScriptModule(ProcessingModule):
|
||||
def export(self, opset_version, *args, **kwargs):
|
||||
return super().export(opset_version, *args, script_mode=True, **kwargs)
|
||||
|
||||
|
||||
class CustomFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def jvp(ctx: Any, *grad_inputs: Any) -> Any:
|
||||
|
|
|
@ -2,6 +2,7 @@ import onnx
|
|||
import numpy
|
||||
import torch
|
||||
import unittest
|
||||
from typing import List
|
||||
from PIL import Image
|
||||
from distutils.version import LooseVersion
|
||||
from onnxruntime_extensions import PyOrtFunction, ONNXCompose
|
||||
|
@ -24,6 +25,16 @@ class _GPT2LMHeadModel(GPT2LMHeadModel):
|
|||
return result[0]
|
||||
|
||||
|
||||
class _SequenceTensorModel(pnp.ProcessingScriptModule):
|
||||
def forward(self, img_list: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
return img_list[0], img_list[1]
|
||||
|
||||
|
||||
class _AddModel(torch.nn.Module):
|
||||
def forward(self, input_list: List[torch.Tensor]) -> torch.Tensor:
|
||||
return input_list[1] + input_list[0] # test broadcasting.
|
||||
|
||||
|
||||
@unittest.skipIf(LooseVersion(torch.__version__) < LooseVersion("1.9"), 'Only tested the lastest PyTorch')
|
||||
class TestPreprocessing(unittest.TestCase):
|
||||
def test_imagenet_preprocessing(self):
|
||||
|
@ -61,6 +72,21 @@ class TestPreprocessing(unittest.TestCase):
|
|||
# the random weight may generate a large diff in result, test the shape only.
|
||||
self.assertTrue(numpy.allclose(expected.size(), actuals.shape))
|
||||
|
||||
def test_sequence_tensor(self):
|
||||
seq_m = ONNXCompose(torch.jit.script(_AddModel()), _SequenceTensorModel(), None)
|
||||
test_input = [numpy.array([1]), numpy.array([3, 4]), numpy.array([5, 6])]
|
||||
res = seq_m.predict(test_input)
|
||||
numpy.testing.assert_allclose(res, numpy.array([4, 5]))
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.11"):
|
||||
# The ONNX exporter fixing for sequence tensor only released in 1.11 and the above.
|
||||
oxml = seq_m.export(12, output_file='temp_seqtest.onnx')
|
||||
# TODO: ORT doesn't accept the default empty element type of a sequence type.
|
||||
oxml.graph.input[0].type.sequence_type.elem_type.CopyFrom(
|
||||
onnx.helper.make_tensor_type_proto(onnx.onnx_pb.TensorProto.INT32, []))
|
||||
mfunc = PyOrtFunction.from_model(oxml)
|
||||
o_res = mfunc(test_input)
|
||||
numpy.testing.assert_allclose(res, o_res)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Загрузка…
Ссылка в новой задаче