support the sequence tensor in pnp.ProcessingModule. (#197)

* support the sequence tensor in ProcessingModule.

* version check

* version check 2
Wenbing Li 2022-02-04 14:48:47 -08:00, committed by GitHub
Parent 459c4f7d61
Commit d0ff193eec
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 7 deletions


@@ -4,7 +4,7 @@ import torch
 import numpy
 from torch.onnx import TrainingMode, export as _export
 from ._ortapi2 import OrtPyFunction
-from .pnp import ONNXModelUtils, ProcessingModule
+from .pnp import ONNXModelUtils, ProcessingModule, ProcessingScriptModule
 
 
 def _is_numpy_object(x):
@@ -92,6 +92,15 @@ class ONNXCompose:
         if post_m is not None:
             model_l.append(post_m)
+        if output_file is not None:
+            # also output the pre/post-processing model for debugging
+            idx = 0
+            for _mdl in model_l:
+                if _mdl is self.models and isinstance(_mdl, onnx.ModelProto):
+                    continue
+                onnx.save_model(_mdl, "{}_sub{}.onnx".format(output_file[:-5], idx))
+                idx += 1
+
         full_m = ONNXModelUtils.join_models(*model_l, io_mapping=io_mapping)
         if output_file is not None:
             onnx.save_model(full_m, output_file)
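
A note on the debugging branch added above: output_file[:-5] strips the literal ".onnx" suffix, so each intermediate pre/post-processing model lands next to the final file with a _sub{idx} suffix, while a core model that is already an onnx.ModelProto is skipped. A minimal usage sketch; the three objects below are hypothetical placeholders, not from this diff:

    # assuming core_model is an already-loaded onnx.ModelProto (skipped by
    # the loop above) and pre/post are exported pre/post-processing models
    compose = ONNXCompose(core_model, preprocessor, postprocessor)
    compose.export(12, output_file='pipeline.onnx')
    # expected on disk afterwards:
    #   pipeline_sub0.onnx   the exported pre-processing model
    #   pipeline_sub1.onnx   the exported post-processing model
    #   pipeline.onnx        the joined end-to-end model
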
@@ -115,16 +124,26 @@ class ONNXCompose:
                 return all(_is_array(_x) for _x in x)
             return _is_numpy_object(x) and (not _is_numpy_string_type(x))
 
+        def _from_numpy(x):
+            if isinstance(x, list):
+                return [torch.from_numpy(_x) for _x in x]
+            return torch.from_numpy(x)
+
         # convert the raw value, and special handling for string.
-        n_args = [numpy.array(_arg) if not _is_tensor(_arg) else _arg for _arg in args]
-        n_args = [torch.from_numpy(_arg) if
+        n_args = [numpy.array(_arg) if
+                  not (_is_tensor(_arg) or _is_array(_arg)) else _arg for _arg in args]
+        n_args = [_from_numpy(_arg) if
                   _is_array(_arg) else _arg for _arg in n_args]
 
         self.pre_args = n_args
         inputs = [self.preprocessors.forward(*n_args)]
-        flatten_inputs = []
-        for _i in inputs:
-            flatten_inputs += list(_i) if isinstance(_i, tuple) else [_i]
+        if isinstance(self.preprocessors, ProcessingScriptModule):
+            flatten_inputs = inputs
+        else:
+            flatten_inputs = []
+            for _i in inputs:
+                flatten_inputs += list(_i) if isinstance(_i, tuple) else [_i]
         self.models_args = flatten_inputs
         if isinstance(self.models, torch.nn.Module):
             outputs = self.models.forward(*flatten_inputs)
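
The predict-path change above does two things: _from_numpy lets a plain Python list of numpy arrays enter as a list of tensors, and the ProcessingScriptModule check passes the preprocessor's output through without tuple-flattening, so a sequence stays one sequence argument. A standalone restatement of the helper's behavior, pulled out of the class purely for illustration:

    import numpy
    import torch

    def _from_numpy(x):
        # a list of arrays becomes a list of tensors (the sequence case);
        # a single array becomes a single tensor
        if isinstance(x, list):
            return [torch.from_numpy(_x) for _x in x]
        return torch.from_numpy(x)

    _from_numpy([numpy.array([1]), numpy.array([3, 4])])
    # -> [tensor([1]), tensor([3, 4])]
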


@@ -1,5 +1,5 @@
 from ._utils import ONNXModelUtils
-from ._base import ProcessingModule, CustomFunction
+from ._base import ProcessingModule, ProcessingScriptModule, CustomFunction
 from ._functions import * # noqa
 from ._imagenet import PreMobileNet, PostMobileNet


@@ -42,6 +42,11 @@ class ProcessingModule(torch.nn.Module):
         return mdl
+
+
+class ProcessingScriptModule(ProcessingModule):
+    def export(self, opset_version, *args, **kwargs):
+        return super().export(opset_version, *args, script_mode=True, **kwargs)
 
 
 class CustomFunction(torch.autograd.Function):
     @staticmethod
     def jvp(ctx: Any, *grad_inputs: Any) -> Any:
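
ProcessingScriptModule only overrides export to force script_mode=True; the base implementation is not part of this diff, but presumably script mode compiles the module with torch.jit.script before torch.onnx.export, since scripting (unlike tracing) preserves a List[torch.Tensor] signature as an ONNX sequence input. A hedged sketch of that pattern, not the actual ProcessingModule.export code:

    import io
    import torch

    def export_in_script_mode(module: torch.nn.Module, opset_version: int, *args):
        # scripting keeps control flow and List[Tensor] annotations,
        # which tracing would specialize away to fixed tensor inputs
        scripted = torch.jit.script(module)
        f = io.BytesIO()
        torch.onnx.export(scripted, args, f, opset_version=opset_version)
        return f.getvalue()
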


@@ -2,6 +2,7 @@ import onnx
 import numpy
 import torch
 import unittest
+from typing import List
 from PIL import Image
 from distutils.version import LooseVersion
 from onnxruntime_extensions import PyOrtFunction, ONNXCompose
@@ -24,6 +25,16 @@ class _GPT2LMHeadModel(GPT2LMHeadModel):
         return result[0]
 
 
+class _SequenceTensorModel(pnp.ProcessingScriptModule):
+    def forward(self, img_list: List[torch.Tensor]) -> List[torch.Tensor]:
+        return img_list[0], img_list[1]
+
+
+class _AddModel(torch.nn.Module):
+    def forward(self, input_list: List[torch.Tensor]) -> torch.Tensor:
+        return input_list[1] + input_list[0]  # test broadcasting.
+
+
 @unittest.skipIf(LooseVersion(torch.__version__) < LooseVersion("1.9"), 'Only tested the latest PyTorch')
 class TestPreprocessing(unittest.TestCase):
     def test_imagenet_preprocessing(self):
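
The "test broadcasting" comment in _AddModel refers to elementwise broadcasting: adding a length-1 tensor to a length-2 tensor stretches the singleton, which is exactly the arithmetic the new test below asserts:

    import torch
    torch.tensor([3, 4]) + torch.tensor([1])
    # -> tensor([4, 5]); compare numpy.array([4, 5]) in test_sequence_tensor
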
@@ -61,6 +72,21 @@ class TestPreprocessing(unittest.TestCase):
         # the random weight may generate a large diff in result, test the shape only.
         self.assertTrue(numpy.allclose(expected.size(), actuals.shape))
 
+    def test_sequence_tensor(self):
+        seq_m = ONNXCompose(torch.jit.script(_AddModel()), _SequenceTensorModel(), None)
+        test_input = [numpy.array([1]), numpy.array([3, 4]), numpy.array([5, 6])]
+        res = seq_m.predict(test_input)
+        numpy.testing.assert_allclose(res, numpy.array([4, 5]))
+        if LooseVersion(torch.__version__) >= LooseVersion("1.11"):
+            # the ONNX exporter fix for sequence tensors was only released in 1.11 and above.
+            oxml = seq_m.export(12, output_file='temp_seqtest.onnx')
+            # TODO: ORT doesn't accept the default empty element type of a sequence type.
+            oxml.graph.input[0].type.sequence_type.elem_type.CopyFrom(
+                onnx.helper.make_tensor_type_proto(onnx.onnx_pb.TensorProto.INT32, []))
+            mfunc = PyOrtFunction.from_model(oxml)
+            o_res = mfunc(test_input)
+            numpy.testing.assert_allclose(res, o_res)
 
 
 if __name__ == "__main__":
     unittest.main()
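
On the TODO workaround in test_sequence_tensor: the exporter leaves the sequence input's element type unspecified, and ORT rejects a sequence whose elem_type is empty, so the test patches an INT32 tensor type into the input's sequence_type.elem_type before loading the model. A standalone sketch of the same protobuf surgery on a hypothetical saved model ('model.onnx' is a placeholder path):

    import onnx
    from onnx import helper, onnx_pb

    oxml = onnx.load_model('model.onnx')  # placeholder; any model with a sequence input
    elem = helper.make_tensor_type_proto(onnx_pb.TensorProto.INT32, [])
    # give the sequence input a concrete element type so ORT will accept it
    oxml.graph.input[0].type.sequence_type.elem_type.CopyFrom(elem)
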