The new pre/post processing API, replacing ONNXCompose (#205)
* traced processing module
* before debugging
* updates
* temporary
* the trace mode pass
* code adjustments for the CI pipeline
* only torch 1.11 supports prim::PythonOp
* extending the sequence processing module
This commit is contained in:
Parent
2e2ee11772
Commit
4bb3a22c45
|
@ -10,6 +10,7 @@ The entry point to onnxruntime custom op library
|
|||
__author__ = "Microsoft"
|
||||
|
||||
import pathlib
|
||||
import inspect
|
||||
from ._version import __version__
|
||||
from ._ocos import get_library_path # noqa
|
||||
from ._ocos import Opdef, PyCustomOpDef # noqa
|
||||
|
@ -19,22 +20,16 @@ from ._ocos import expand_onnx_inputs # noqa
|
|||
from ._ocos import hook_model_op # noqa
|
||||
from ._ocos import default_opset_domain # noqa
|
||||
from ._cuops import * # noqa
|
||||
from ._ortapi2 import OrtPyFunction as PyOrtFunction, optimize_model, make_onnx_model
|
||||
from ._ortapi2 import OrtPyFunction as PyOrtFunction
|
||||
from ._ortapi2 import OrtPyFunction, optimize_model, make_onnx_model
|
||||
|
||||
|
||||
onnx_op = Opdef.declare
|
||||
PyOp = PyCustomOpDef
|
||||
|
||||
|
||||
# ONNX-Compose depends on PyTorch, which is optional for onnxruntime-extensions.
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
from .compose import ONNXCompose
|
||||
|
||||
|
||||
def get_test_data_file(case_file, *sub_dirs):
|
||||
# a convenience helper for locating unit test data files.
|
||||
def get_test_data_file(*sub_dirs):
|
||||
case_file = inspect.currentframe().f_back.f_code.co_filename
|
||||
test_dir = pathlib.Path(case_file).parent
|
||||
return str(test_dir.joinpath(*sub_dirs))
|
||||
|
|
|
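A minimal usage sketch of the revised helper: the caller's file is now inferred from the calling frame, so a test passes only the sub-directory parts. The 'data'/'mobilev2.onnx' names are taken from the tests further down and are purely illustrative.

    from onnxruntime_extensions import get_test_data_file

    # resolves 'data/mobilev2.onnx' relative to the file that makes this call;
    # the joined path is returned whether or not the file exists
    model_path = get_test_data_file('data', 'mobilev2.onnx')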
@ -35,6 +35,8 @@ def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(),
|
|||
|
||||
class OrtPyFunction:
|
||||
|
||||
__name__ = 'OrtPyFunction'
|
||||
|
||||
@classmethod
|
||||
def get_ort_session_options(cls):
|
||||
# ONNXRuntime has an issue with reusing the SessionOptions object.
|
||||
|
|
|
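For context, this is the call pattern OrtPyFunction is used with throughout this change: wrap an ONNX model (a file path or a ModelProto) and invoke it like a plain function on numpy inputs. The model path and input shape below are placeholders.

    import numpy as np
    from onnxruntime_extensions import OrtPyFunction

    func = OrtPyFunction.from_model('model.onnx')   # placeholder path; an onnx.ModelProto also works
    outputs = func(np.zeros((1, 3, 224, 224), dtype=np.float32))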
@ -4,7 +4,7 @@ import torch
|
|||
import numpy
|
||||
from torch.onnx import TrainingMode, export as _export
|
||||
from ._ortapi2 import OrtPyFunction
|
||||
from .pnp import ONNXModelUtils, ProcessingModule, ProcessingScriptModule
|
||||
from .pnp import ONNXModelUtils, _ProcessingModule, ProcessingScriptModule
|
||||
|
||||
|
||||
def _is_numpy_object(x):
|
||||
|
@ -44,9 +44,9 @@ class ONNXCompose:
|
|||
"""
|
||||
def __init__(self, models=None, preprocessors=None, postprocessors=None):
|
||||
|
||||
assert isinstance(preprocessors, ProcessingModule),\
|
||||
assert isinstance(preprocessors, _ProcessingModule),\
|
||||
'preprocessors must be subclassing from ProcessingModule'
|
||||
assert postprocessors is None or isinstance(postprocessors, ProcessingModule),\
|
||||
assert postprocessors is None or isinstance(postprocessors, _ProcessingModule),\
|
||||
'postprocessors must be subclassing from ProcessingModule'
|
||||
self.models = models
|
||||
self.preprocessors = preprocessors
|
||||
|
@ -93,7 +93,7 @@ class ONNXCompose:
|
|||
model_l.append(post_m)
|
||||
|
||||
if output_file is not None:
|
||||
# also output the pre/post-processing model for debugging
|
||||
# also output the pre- / post-processing model for debugging
|
||||
idx = 0
|
||||
for _mdl in model_l:
|
||||
if _mdl is self.models and isinstance(_mdl, onnx.ModelProto):
|
||||
|
|
|
@ -1,6 +1,13 @@
|
|||
from ._utils import ONNXModelUtils
|
||||
from ._base import ProcessingModule, ProcessingScriptModule, CustomFunction
|
||||
from ._functions import * # noqa
|
||||
# The onnxruntime-extensions pre/post-processing frontend depends on PyTorch.
|
||||
try:
|
||||
import torch
|
||||
except ImportError as e:
|
||||
print("No torch installation found, which is required by the pre&post scripting!")
|
||||
raise e
|
||||
|
||||
from ._imagenet import PreMobileNet, PostMobileNet
|
||||
from ._base import _ProcessingModule, ProcessingScriptModule, CustomFunction
|
||||
from ._torchext import * # noqa
|
||||
from ._unifier import export
|
||||
|
||||
from ._imagenet import * # noqa
|
||||
from ._nlp import PreHuggingFaceGPT2
|
||||
|
|
|
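A quick smoke test of the package surface after this rewrite; every name below is re-exported by the imports above.

    from onnxruntime_extensions import pnp

    print(pnp.ONNXModelUtils, pnp.ProcessingScriptModule, pnp.CustomFunction)
    print(pnp.PreMobileNet, pnp.PostMobileNet, pnp.PreHuggingFaceGPT2)
    print(pnp.export)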
@ -3,14 +3,55 @@ import onnx
|
|||
import torch
|
||||
from typing import Any
|
||||
from onnx.onnx_pb import TensorProto
|
||||
from torch.onnx import TrainingMode, export as _export
|
||||
|
||||
|
||||
class ProcessingModule(torch.nn.Module):
|
||||
def _export_f(model, args=None,
|
||||
opset_version=None,
|
||||
output_path=None,
|
||||
output_seq=0,
|
||||
export_params=True,
|
||||
verbose=False,
|
||||
input_names=None,
|
||||
output_names=None,
|
||||
operator_export_type=None,
|
||||
do_constant_folding=True,
|
||||
dynamic_axes=None,
|
||||
keep_initializers_as_inputs=None,
|
||||
custom_opsets=None):
|
||||
|
||||
with io.BytesIO() as f:
|
||||
_export(model, args, f,
|
||||
export_params=export_params, verbose=verbose,
|
||||
training=TrainingMode.EVAL, input_names=input_names,
|
||||
output_names=output_names,
|
||||
operator_export_type=operator_export_type, opset_version=opset_version,
|
||||
do_constant_folding=do_constant_folding,
|
||||
dynamic_axes=dynamic_axes,
|
||||
keep_initializers_as_inputs=keep_initializers_as_inputs,
|
||||
custom_opsets=custom_opsets)
|
||||
|
||||
mdl = onnx.load_model(io.BytesIO(f.getvalue()))
|
||||
if output_path is not None:
|
||||
if output_seq > 0:
|
||||
output_path.replace('.onnx', '.{}.onnx'.format(output_seq))
|
||||
onnx.save_model(mdl, output_path)
|
||||
return mdl
|
||||
|
||||
|
||||
class _ProcessingModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(_ProcessingModule, self).__init__()
|
||||
_ProcessingModule.register_customops()
|
||||
|
||||
@staticmethod
|
||||
@torch.jit.unused
|
||||
def _argsort(g, x, dim, descending):
|
||||
return g.op('ai.onnx.contrib::ArgSort', x, dim)
|
||||
|
||||
@classmethod
|
||||
@torch.jit.unused
|
||||
def register_customops(cls):
|
||||
if hasattr(cls, 'loaded'):
|
||||
return True
|
||||
|
@ -21,30 +62,32 @@ class ProcessingModule(torch.nn.Module):
|
|||
cls.loaded = True
|
||||
return True
|
||||
|
||||
def export(self, opset_version, *args, **kwargs):
|
||||
@torch.jit.unused
|
||||
def export(self, *args, opset_version=None, script_mode=False, output_path=None, output_seq=0, **kwargs):
|
||||
if opset_version is None:
|
||||
raise RuntimeError('opset_version must be specified for export.')
|
||||
mod = self
|
||||
script_model = kwargs.pop('script_mode', False)
|
||||
if script_model:
|
||||
if script_mode and not isinstance(mod, torch.jit.ScriptModule):
|
||||
mod = torch.jit.script(mod)
|
||||
|
||||
ofname = kwargs.pop('ofname', None)
|
||||
|
||||
with io.BytesIO() as f:
|
||||
torch.onnx.export(mod, args, f,
|
||||
training=torch.onnx.TrainingMode.EVAL,
|
||||
opset_version=opset_version,
|
||||
**kwargs)
|
||||
|
||||
mdl = onnx.load_model(io.BytesIO(f.getvalue()))
|
||||
if ofname is not None:
|
||||
ofname.replace('.onnx', '.1.onnx')
|
||||
onnx.save_model(mdl, ofname)
|
||||
return mdl
|
||||
return _export_f(mod,
|
||||
*args,
|
||||
opset_version=opset_version,
|
||||
output_path=output_path,
|
||||
output_seq=output_seq, **kwargs)
|
||||
|
||||
|
||||
class ProcessingScriptModule(ProcessingModule):
|
||||
def export(self, opset_version, *args, **kwargs):
|
||||
return super().export(opset_version, *args, script_mode=True, **kwargs)
|
||||
class ProcessingTracedModule(_ProcessingModule):
|
||||
pass
|
||||
|
||||
|
||||
class ProcessingScriptModule(_ProcessingModule):
|
||||
def __init__(self):
|
||||
super(ProcessingScriptModule, self).__init__()
|
||||
|
||||
@torch.jit.unused
|
||||
def export(self, *args, opset_version=None, **kwargs):
|
||||
return super().export(*args, opset_version=opset_version, script_mode=True, **kwargs)
|
||||
|
||||
|
||||
class CustomFunction(torch.autograd.Function):
|
||||
|
|
|
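A hedged sketch of the new export flow on a trivial traced module. ProcessingTracedModule is defined in this file; whether it is re-exported at the pnp level is not visible in this diff, so the import below goes through the private _base module, and the Scale module with its arguments is made up for illustration.

    import torch
    from onnxruntime_extensions.pnp._base import ProcessingTracedModule

    class Scale(ProcessingTracedModule):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2.0

    # export(*args, opset_version=..., output_path=...) traces the module with the
    # sample input, returns an onnx.ModelProto and optionally saves it to disk
    mdl = Scale().export(torch.rand(2, 3), opset_version=12, output_path='scale.onnx')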
@ -1,7 +1,8 @@
|
|||
import torch
|
||||
from typing import Tuple
|
||||
from torch.nn.functional import interpolate
|
||||
from ._base import ProcessingModule
|
||||
from ._functions import onnx_where, onnx_greater
|
||||
from ._base import ProcessingTracedModule
|
||||
from ._torchext import onnx_where, onnx_greater
|
||||
|
||||
|
||||
def _resize_param(img, size):
|
||||
|
@ -11,7 +12,7 @@ def _resize_param(img, size):
|
|||
return onnx_where(onnx_greater(scale_x, scale_y), scale_x, scale_y)
|
||||
|
||||
|
||||
class ImageNetPreProcessing(ProcessingModule):
|
||||
class ImageNetPreProcessing(ProcessingTracedModule):
|
||||
def __init__(self, size, resize_image=True):
|
||||
super(ImageNetPreProcessing, self).__init__()
|
||||
self.target_size = size
|
||||
|
@ -47,17 +48,9 @@ class ImageNetPreProcessing(ProcessingModule):
|
|||
x /= torch.reshape(torch.tensor(std), (3, 1, 1))
|
||||
return x
|
||||
|
||||
def export(self, opset_version, *args, **kwargs):
|
||||
name_i = 'image'
|
||||
return super().export(opset_version, *args,
|
||||
input_names=[name_i],
|
||||
dynamic_axes={name_i: [0, 1]},
|
||||
**kwargs)
|
||||
|
||||
|
||||
class ImageNetPostProcessing(ProcessingModule):
|
||||
def forward(self, scores):
|
||||
ProcessingModule.register_customops()
|
||||
class ImageNetPostProcessing(ProcessingTracedModule):
|
||||
def forward(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
probabilities = torch.softmax(scores, dim=1)
|
||||
top10_prob, top10_ids = probabilities.topk(k=10, dim=1, largest=True, sorted=True)
|
||||
return top10_ids, top10_prob
|
||||
|
|
|
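A hedged sketch of the traced ImageNet pieces defined above, mirroring the dummy shapes the tests below use when tracing them; the zero tensors are placeholders, not real data, and PreMobileNet/PostMobileNet are presumed to build on these classes.

    import torch
    from onnxruntime_extensions import pnp

    pre = pnp.PreMobileNet(224)
    post = pnp.ImageNetPostProcessing()

    proc = pre(torch.zeros(224, 224, 3, dtype=torch.float32))     # dummy HWC image
    ids, probs = post(torch.zeros(1, 1000, dtype=torch.float32))  # dummy classifier scores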
@ -1,7 +1,7 @@
|
|||
import json
|
||||
|
||||
from ._base import ProcessingModule, tensor_data_type as _dt
|
||||
from ._functions import create_op_function
|
||||
from ._base import ProcessingTracedModule, tensor_data_type as _dt
|
||||
from ._torchext import create_op_function
|
||||
from ._onnx_ops import schema
|
||||
from .._ocos import default_opset_domain
|
||||
|
||||
|
@ -44,7 +44,7 @@ def _get_bound_object(func):
|
|||
return func.__self__
|
||||
|
||||
|
||||
class PreHuggingFaceGPT2(ProcessingModule):
|
||||
class PreHuggingFaceGPT2(ProcessingTracedModule):
|
||||
def __init__(self, hf_tok=None, vocab_file=None, merges_file=None, padding_length=-1):
|
||||
super(PreHuggingFaceGPT2, self).__init__()
|
||||
if hf_tok is None:
|
||||
|
@ -58,5 +58,5 @@ class PreHuggingFaceGPT2(ProcessingModule):
|
|||
def forward(self, text):
|
||||
return self.onnx_gpt2_tokenize(text)
|
||||
|
||||
def export(self, opset_version, *args, **kwargs):
|
||||
def export(self, *args, opset_version=0, **kwargs):
|
||||
return _get_bound_object(self.onnx_gpt2_tokenize).build_model(opset_version, *args)
|
||||
|
|
|
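A usage sketch for the traced GPT-2 tokenizer module above, mirroring the test further down; the vocab/merges paths are placeholders and must point to real BPE files.

    from onnxruntime_extensions import pnp

    tok = pnp.PreHuggingFaceGPT2(vocab_file='gpt2.vocab', merges_file='gpt2.merges.txt')
    encoded = tok(["Test a sentence"])                              # runs the GPT2Tokenizer custom op eagerly
    tok_model = tok.export(["Test a sentence"], opset_version=12)   # one-node ONNX graph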
@ -4,8 +4,10 @@ import numpy as np
|
|||
from typing import Any
|
||||
from onnx import helper
|
||||
from onnx import onnx_pb as onnx_proto
|
||||
from distutils.version import LooseVersion
|
||||
from torch.onnx import register_custom_op_symbolic
|
||||
|
||||
from ._base import CustomFunction
|
||||
from ._base import CustomFunction, ProcessingTracedModule
|
||||
from ._onnx_ops import ox as _ox, schema as _schema
|
||||
from ._onnx_ops import ONNXElementContainer, make_model_ex
|
||||
from .._ortapi2 import OrtPyFunction, get_opset_version_from_ort
|
||||
|
@ -35,7 +37,7 @@ def _to_onnx_type(dtype):
|
|||
return ty_dict.get(dtype, onnx_proto.TensorProto.STRING)
|
||||
|
||||
|
||||
class ONNXOpFunction(CustomFunction):
|
||||
class OnnxOpFunction(CustomFunction):
|
||||
@classmethod
|
||||
def get_next_id_name(cls, name_base):
|
||||
name = 'cls' if name_base is None else name_base
|
||||
|
@ -54,6 +56,8 @@ class ONNXOpFunction(CustomFunction):
|
|||
@classmethod
|
||||
def build_model(cls, opset_version, *args):
|
||||
# build the one node graph
|
||||
if isinstance(args[0], list):
|
||||
args = [np.asarray(_i) for _i in args]
|
||||
ec = ONNXElementContainer(get_opset_version_from_ort() if opset_version is None else opset_version)
|
||||
attrs = cls.attrs
|
||||
vi_inputs = [helper.make_tensor_value_info(
|
||||
|
@ -101,7 +105,7 @@ class ONNXOpFunction(CustomFunction):
|
|||
def create_op_function(op_type: str, func, **attrs):
|
||||
if _ox.is_raw(func):
|
||||
func = _schema(func.__func__)
|
||||
cls = type(_ox.get_unique_operator_type_name(op_type), (ONNXOpFunction, ),
|
||||
cls = type(_ox.get_unique_operator_type_name(op_type), (OnnxOpFunction, ),
|
||||
dict(
|
||||
op_type=op_type,
|
||||
opb_func=func,
|
||||
|
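For context, create_op_function (backed by the renamed OnnxOpFunction class) turns a single ONNX operator into a callable that works inside traced modules; onnx_where and onnx_greater just below this hunk are instances of it. A hedged sketch, importing from the internal _torchext module exactly as _imagenet.py does:

    import torch
    from onnxruntime_extensions.pnp._torchext import onnx_where, onnx_greater

    scale_x = torch.tensor([2.0])
    scale_y = torch.tensor([1.5])
    # pick the larger of the two scales, mirroring _resize_param in _imagenet.py
    larger = onnx_where(onnx_greater(scale_x, scale_y), scale_x, scale_y)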
@ -118,19 +122,19 @@ onnx_greater = create_op_function('Greater', _ox.greater)
|
|||
class PythonOpFunction:
|
||||
"""
|
||||
PythonOpFunction wraps a generic Python function whose forward operation is skipped by the torch.onnx exporter when
|
||||
converting in the script mode, which cannot support the API from external package, like Numpy.
|
||||
Autograd.Function cannot be used torch.jit.script.
|
||||
converting in script mode, since the exporter cannot support APIs from external packages, like NumPy.
|
||||
Note also that Autograd.Function cannot be used with torch.jit.script.
|
||||
"""
|
||||
id_func_map = {}
|
||||
current_func_id = 0
|
||||
|
||||
@staticmethod
|
||||
def get_next_id():
|
||||
def _get_next_id():
|
||||
PythonOpFunction.current_func_id += 1
|
||||
return PythonOpFunction.current_func_id
|
||||
|
||||
@staticmethod
|
||||
@torch.jit.ignore(drop_on_export=False)
|
||||
@torch.jit.ignore
|
||||
def _pass_through_call(*args, **kwargs):
|
||||
func_id = args[0]
|
||||
func = PythonOpFunction.id_func_map[func_id]
|
||||
|
@ -143,7 +147,102 @@ class PythonOpFunction:
|
|||
@classmethod
|
||||
def get_id(cls):
|
||||
if not hasattr(cls, 'func_id'):
|
||||
_id = PythonOpFunction.get_next_id()
|
||||
_id = PythonOpFunction._get_next_id()
|
||||
setattr(cls, 'func_id', _id)
|
||||
PythonOpFunction.id_func_map[_id] = cls
|
||||
return cls.func_id
|
||||
|
||||
|
||||
class _OnnxModelFunction:
|
||||
id_object_map = {} # cannot use the string directly since jit.script doesn't support the data type
|
||||
str_model_function_id = '_model_function_id'
|
||||
str_model_id = '_model_id'
|
||||
str_model_attached = '_model_attached'
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _invoke_onnx_model(model_id: int, *args, **kwargs):
|
||||
model_or_path = _OnnxModelFunction.id_object_map.get(model_id)
|
||||
if model_or_path is None:
|
||||
raise ValueError("cannot find id={} registered!".format(model_id))
|
||||
func = OrtPyFunction.from_model(model_or_path)
|
||||
return torch.from_numpy(
|
||||
func(*list(_i.numpy() if isinstance(_i, torch.Tensor) else _i for _i in args), **kwargs))
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def invoke_onnx_model1(model_id: int, arg0):
|
||||
return _invoke_onnx_model(model_id, arg0)
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def invoke_onnx_model2(model_id: int, arg0, arg1):
|
||||
return _invoke_onnx_model(model_id, arg0, arg1)
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def invoke_onnx_model3(model_id: int, arg0, arg1, arg2):
|
||||
return _invoke_onnx_model(model_id, arg0, arg1, arg2)
|
||||
|
||||
|
||||
class _OnnxTracedFunction(CustomFunction):
|
||||
@classmethod
|
||||
def forward(cls, ctx: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
return _invoke_onnx_model(args[0].item(), *args[1:], **kwargs)
|
||||
|
||||
@classmethod
|
||||
def symbolic(cls, g, *args):
|
||||
return g.op('ai.onnx.contrib::_ModelFunctionCall', *args)
|
||||
|
||||
|
||||
def create_model_function(model_or_path):
|
||||
_id = id(model_or_path)
|
||||
assert _id != 0, "internal error: the id of a Python object is 0."
|
||||
_OnnxModelFunction.id_object_map[_id] = model_or_path
|
||||
return _id
|
||||
|
||||
|
||||
def get_id_models():
|
||||
return _OnnxModelFunction.id_object_map
|
||||
|
||||
|
||||
class SequenceProcessingModule(ProcessingTracedModule):
|
||||
def __init__(self, *models):
|
||||
super(SequenceProcessingModule, self).__init__()
|
||||
self.model_list = models
|
||||
self.model_function_ids = []
|
||||
for mdl_ in models:
|
||||
if isinstance(mdl_, onnx.ModelProto):
|
||||
self.model_function_ids.append(create_model_function(mdl_))
|
||||
else:
|
||||
self.model_function_ids.append(0)
|
||||
|
||||
def forward(self, *args):
|
||||
outputs = args
|
||||
for idx_, mdl_ in enumerate(self.model_list):
|
||||
if not isinstance(outputs, tuple):
|
||||
outputs = (outputs, )
|
||||
if self.model_function_ids[idx_] != 0:
|
||||
outputs = _OnnxTracedFunction.apply(torch.tensor(self.model_function_ids[idx_]), *outputs)
|
||||
else:
|
||||
outputs = self.model_list[idx_].forward(*outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def _symbolic_pythonop(g: torch._C.Graph, n: torch._C.Node, *args, **kwargs):
|
||||
name = kwargs["name"]
|
||||
if name.startswith(invoke_onnx_model1.__name__[:-1]):
|
||||
# id = torch.onnx.symbolic_helper._maybe_get_scalar(args[0]).item()
|
||||
ret = g.op("ai.onnx.contrib::_ModelFunctionCall", *args)
|
||||
else:
|
||||
# Logs a warning and returns None
|
||||
import warnings
|
||||
return warnings.warn("prim::PythonOp: unknown node kind: " + name)
|
||||
# Copy type and shape from original node.
|
||||
ret.setType(args[-1].type())
|
||||
return ret
|
||||
|
||||
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.11"):
|
||||
register_custom_op_symbolic("prim::PythonOp", _symbolic_pythonop, 1)
|
|
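SequenceProcessingModule is the piece that replaces ONNXCompose: an onnx.ModelProto placed in the sequence is registered via create_model_function and invoked through _OnnxTracedFunction, while plain torch modules run directly. A sketch mirroring the ImageNet test below; 'mobilev2.onnx' is a placeholder path.

    import onnx
    import torch
    from onnxruntime_extensions import pnp

    core = onnx.load_model('mobilev2.onnx')    # placeholder classifier model
    pipeline = pnp.SequenceProcessingModule(pnp.PreMobileNet(224), core, pnp.PostMobileNet())
    ids, probs = pipeline.forward(torch.zeros(224, 224, 3, dtype=torch.float32))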
@ -0,0 +1,40 @@
|
|||
import onnx
|
||||
|
||||
from .._ortapi2 import get_opset_version_from_ort
|
||||
from ._utils import ONNXModelUtils
|
||||
from ._torchext import get_id_models
|
||||
|
||||
|
||||
def export(m, *args,
|
||||
opset_version=0,
|
||||
output_path=None,
|
||||
export_params=True,
|
||||
verbose=False,
|
||||
input_names=None,
|
||||
output_names=None,
|
||||
operator_export_type=None,
|
||||
do_constant_folding=True,
|
||||
dynamic_axes=None,
|
||||
keep_initializers_as_inputs=None,
|
||||
custom_opsets=None,
|
||||
io_mapping=None):
|
||||
"""
|
||||
export all models and modules into a merged ONNX model.
|
||||
"""
|
||||
if opset_version == 0:
|
||||
opset_version = get_opset_version_from_ort()
|
||||
model = m.export(*args, opset_version=opset_version,
|
||||
output_path=output_path,
|
||||
export_params=export_params,
|
||||
verbose=verbose,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
operator_export_type=operator_export_type,
|
||||
do_constant_folding=do_constant_folding,
|
||||
dynamic_axes=dynamic_axes,
|
||||
keep_initializers_as_inputs=keep_initializers_as_inputs,
|
||||
custom_opsets=custom_opsets)
|
||||
full_m = ONNXModelUtils.unfold_model(model, get_id_models(), io_mapping)
|
||||
if output_path is not None:
|
||||
onnx.save_model(full_m, output_path)
|
||||
return full_m
|
|
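A usage sketch of this unified entry point, again mirroring the ImageNet test: export the whole pipeline into one merged ONNX model and run it through OrtPyFunction. The file name and tensor are illustrative, and 'pipeline' refers to the SequenceProcessingModule sketched earlier.

    import torch
    from onnxruntime_extensions import pnp, OrtPyFunction

    img = torch.zeros(224, 224, 3, dtype=torch.float32)
    merged = pnp.export(pipeline, img,
                        opset_version=11,
                        output_path='full_pipeline.onnx',
                        input_names=['image'],
                        dynamic_axes={'image': [0, 1]})
    runner = OrtPyFunction.from_model(merged)
    ids, probs = runner(img.numpy())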
@ -1,8 +1,26 @@
|
|||
import copy
|
||||
import onnx
|
||||
from onnx import numpy_helper
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
class _Container:
|
||||
def __init__(self):
|
||||
self.parent = None
|
||||
self.initializer = []
|
||||
self.value_info = []
|
||||
self.nodes = []
|
||||
self.node_domain_version_pair_sets = {}
|
||||
|
||||
def add_model(self, oxml):
|
||||
self.initializer.extend(oxml.graph.initializer)
|
||||
self.value_info.extend(oxml.graph.value_info)
|
||||
self.nodes.extend(oxml.graph.node)
|
||||
self.node_domain_version_pair_sets.update(
|
||||
[(opset_.domain, opset_.version) for opset_ in oxml.opset_import])
|
||||
return self
|
||||
|
||||
|
||||
class ONNXModelUtils:
|
||||
@staticmethod
|
||||
def merge_name(prefix, name):
|
||||
|
@ -57,24 +75,86 @@ class ONNXModelUtils:
|
|||
node.attribute.extend(attr_list)
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def get_model_name_abbr(node):
|
||||
no = node.name.split('_')[-1]
|
||||
return 'm_' + no
|
||||
|
||||
@staticmethod
|
||||
def get_model_id_from_arg0(nodes, node):
|
||||
arg0_name = node.input[0]
|
||||
c_node = [n_ for n_ in nodes if
|
||||
n_.op_type == 'Constant' and n_.output[0] == arg0_name]
|
||||
assert len(c_node) == 1, 'internal error, multiple nodes with the same output.'
|
||||
c_node = c_node[0]
|
||||
tensor_value = onnx.helper.get_attribute_value(c_node.attribute[0])
|
||||
_id = numpy_helper.to_array(tensor_value).item()
|
||||
return _id
|
||||
|
||||
@classmethod
|
||||
def unfold_model_node(cls, container):
|
||||
top_containter = container
|
||||
while top_containter.parent is not None: # only one opset_import in the model.
|
||||
top_containter = top_containter.parent
|
||||
def _unfold_model_node(cls, container, name, model, io_mapping=None):
|
||||
top_container = container
|
||||
while top_container.parent is not None: # only one opset_import in the model.
|
||||
top_container = top_container.parent
|
||||
|
||||
nodes = container.nodes
|
||||
model_nodes = {node.name: node for node in nodes if hasattr(node, 'model')}
|
||||
onnx_nodes = [nd_ for nd_ in nodes if nd_.name not in model_nodes]
|
||||
renamed_nodes = cls._rename_graph(model.graph, name, container)
|
||||
onnx_nodes = [cls._process_node_body(nd_, name) for nd_ in renamed_nodes]
|
||||
|
||||
for node in model_nodes.values():
|
||||
renamed_nodes = cls._rename_graph(node.model.graph, node.name, container)
|
||||
onnx_nodes.extend(cls._process_node_body(nd_, node.name) for nd_ in renamed_nodes)
|
||||
|
||||
top_containter.node_domain_version_pair_sets.update(
|
||||
[(opset_.domain, opset_.version) for opset_ in node.model.opset_import])
|
||||
top_container.node_domain_version_pair_sets.update(
|
||||
[(opset_.domain, opset_.version) for opset_ in model.opset_import])
|
||||
return onnx_nodes
|
||||
|
||||
@classmethod
|
||||
def unfold_model(cls, oxml, id_to_model, io_mapping=None):
|
||||
container = _Container().add_model(oxml)
|
||||
nodes = []
|
||||
for _nid, _node in enumerate(oxml.graph.node):
|
||||
if _node.op_type != '_ModelFunctionCall':
|
||||
nodes.append(_node)
|
||||
else:
|
||||
model_id = cls.get_model_id_from_arg0(list(oxml.graph.node), _node)
|
||||
if model_id not in id_to_model:
|
||||
raise RuntimeError("Cannot find the model id({}) in the table".format(model_id))
|
||||
|
||||
prefix = cls.get_model_name_abbr(_node)
|
||||
nest_model = id_to_model[model_id]
|
||||
|
||||
input_mapping = []
|
||||
output_mapping = []
|
||||
for idx_, in_ in enumerate(nest_model.graph.input):
|
||||
_renamed_in = "{}_{}".format(prefix, in_.name)
|
||||
_nd = onnx.helper.make_node('Identity',
|
||||
[_node.input[idx_ + 1]], # the first arg is model id, skip it.
|
||||
[_renamed_in],
|
||||
name='i_' + _renamed_in)
|
||||
input_mapping.append(_nd)
|
||||
nds = cls._unfold_model_node(container,
|
||||
prefix,
|
||||
nest_model,
|
||||
io_mapping)
|
||||
for idx_, out_ in enumerate(nest_model.graph.output):
|
||||
_renamed_out = "{}_{}".format(prefix, out_.name)
|
||||
_nd = onnx.helper.make_node('Identity',
|
||||
[_renamed_out],
|
||||
[_node.output[idx_]],
|
||||
name='o_' + _renamed_out)
|
||||
output_mapping.append(_nd)
|
||||
if io_mapping is not None:
|
||||
assert callable(io_mapping), "io_mapping must be a callable that builds the linkage between the models"
|
||||
input_mapping, output_mapping = io_mapping(input_mapping, output_mapping)
|
||||
# note: the order of these list operations is important; it avoids the need for a topological sort.
|
||||
nodes.extend(input_mapping)
|
||||
nodes.extend(nds)
|
||||
nodes.extend(output_mapping)
|
||||
|
||||
intlzs = cls._remove_unused_initializers(nodes, container.initializer)
|
||||
oxml = copy.deepcopy(oxml)
|
||||
del oxml.graph.node[:]
|
||||
oxml.graph.node.extend(nodes)
|
||||
del oxml.graph.initializer[:]
|
||||
oxml.graph.initializer.extend(intlzs)
|
||||
return oxml
|
||||
|
||||
@classmethod
|
||||
def topological_sort(cls, container, nodes, inputs, outputs):
|
||||
op_output_map = {}
|
||||
|
@ -136,12 +216,14 @@ class ONNXModelUtils:
|
|||
return sorted_nodes
|
||||
|
||||
@staticmethod
|
||||
def _remove_unused_initializers(nodes, initializers, reversed_names):
|
||||
def _remove_unused_initializers(nodes, initializers, reserved_names=None):
|
||||
if reserved_names is None:
|
||||
reserved_names = set()
|
||||
nodes_input_set = set()
|
||||
for nd_ in nodes:
|
||||
nodes_input_set.update(n_ for n_ in nd_.input)
|
||||
|
||||
return [intlz_ for intlz_ in initializers if intlz_.name in nodes_input_set or intlz_.name in reversed_names]
|
||||
return [intlz_ for intlz_ in initializers if intlz_.name in nodes_input_set or intlz_.name in reserved_names]
|
||||
|
||||
@classmethod
|
||||
def join_models(cls, *models, io_mapping=None):
|
||||
|
@ -171,11 +253,9 @@ class ONNXModelUtils:
|
|||
port_mapping[iname] = oname
|
||||
|
||||
nodes = []
|
||||
Container = namedtuple('Container', ['initializer', 'value_info'])
|
||||
container = Container(initializer=[], value_info=[])
|
||||
container = _Container()
|
||||
for _idx, _m in enumerate(models):
|
||||
container.initializer.extend(_m.graph.initializer)
|
||||
container.value_info.extend(_m.graph.value_info)
|
||||
container.add_model(_m)
|
||||
nodes += cls._rename_graph(_m.graph, mdl_prefix[_idx], container)
|
||||
|
||||
for _n in nodes:
|
||||
|
@ -199,7 +279,7 @@ class ONNXModelUtils:
|
|||
_opset.append(_ops)
|
||||
name = name + '_' + _mdl.graph.name if name else _mdl.graph.name
|
||||
|
||||
inits = cls._remove_unused_initializers(nodes, container.initializer, set())
|
||||
inits = cls._remove_unused_initializers(nodes, container.initializer)
|
||||
helper = onnx.helper
|
||||
g = helper.make_graph(nodes, name, inputs, outputs,
|
||||
initializer=inits,
|
||||
|
|
|
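A hedged sketch of ONNXModelUtils.join_models, which this change moves onto the shared _Container helper; it presumably concatenates the models and wires matching output/input names between neighbours (the optional io_mapping callable customizes that wiring). The model files are placeholders.

    import onnx
    from onnxruntime_extensions.pnp import ONNXModelUtils

    pre_model = onnx.load_model('preprocess.onnx')
    core_model = onnx.load_model('core.onnx')
    joined = ONNXModelUtils.join_models(pre_model, core_model)   # io_mapping=None keeps the default wiring
    onnx.save_model(joined, 'joined.onnx')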
@ -2,10 +2,10 @@ import onnx
|
|||
import numpy
|
||||
import torch
|
||||
import unittest
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
from PIL import Image
|
||||
from distutils.version import LooseVersion
|
||||
from onnxruntime_extensions import PyOrtFunction, ONNXCompose
|
||||
from onnxruntime_extensions import OrtPyFunction
|
||||
from onnxruntime_extensions import pnp, get_test_data_file
|
||||
from transformers import GPT2Config, GPT2LMHeadModel
|
||||
|
||||
|
@ -13,7 +13,6 @@ from transformers import GPT2Config, GPT2LMHeadModel
|
|||
class _GPT2LMHeadModel(GPT2LMHeadModel):
|
||||
""" Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -25,34 +24,50 @@ class _GPT2LMHeadModel(GPT2LMHeadModel):
|
|||
return result[0]
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _broadcasting_add(input_list: List[torch.Tensor]) -> torch.Tensor:
|
||||
return input_list[1] + input_list[0]
|
||||
|
||||
|
||||
class _SequenceTensorModel(pnp.ProcessingScriptModule):
|
||||
def forward(self, img_list: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
return img_list[0], img_list[1]
|
||||
def forward(self, img_list: List[torch.Tensor]) -> torch.Tensor:
|
||||
return _broadcasting_add(img_list)
|
||||
|
||||
|
||||
class _AddModel(torch.nn.Module):
|
||||
def forward(self, input_list: List[torch.Tensor]) -> torch.Tensor:
|
||||
return input_list[1] + input_list[0] # test broadcasting.
|
||||
class _MobileNetProcessingModule(pnp.ProcessingScriptModule):
|
||||
def __init__(self, oxml):
|
||||
super(_MobileNetProcessingModule, self).__init__()
|
||||
self.model_function_id = pnp.create_model_function(oxml)
|
||||
self.pre_proc = torch.jit.trace(pnp.PreMobileNet(224), torch.zeros(224, 224, 3, dtype=torch.float32))
|
||||
self.post_proc = torch.jit.trace(pnp.ImageNetPostProcessing(), torch.zeros(1, 1000, dtype=torch.float32))
|
||||
|
||||
def forward(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
proc_input = self.pre_proc(img)
|
||||
return self.post_proc.forward(pnp.invoke_onnx_model1(self.model_function_id, proc_input))
|
||||
|
||||
|
||||
@unittest.skipIf(LooseVersion(torch.__version__) < LooseVersion("1.9"), 'Only tested the lastest PyTorch')
|
||||
@unittest.skipIf(LooseVersion(torch.__version__) < LooseVersion("1.9"), 'Does not work with older PyTorch')
|
||||
class TestPreprocessing(unittest.TestCase):
|
||||
def test_imagenet_preprocessing(self):
|
||||
mnv2 = onnx.load_model(get_test_data_file(__file__, 'data', 'mobilev2.onnx'))
|
||||
mnv2 = onnx.load_model(get_test_data_file('data', 'mobilev2.onnx'))
|
||||
|
||||
# load an image
|
||||
img = Image.open(get_test_data_file(__file__, 'data', 'pineapple.jpg'))
|
||||
img = numpy.asarray(img.convert('RGB'))
|
||||
img = Image.open(get_test_data_file('data', 'pineapple.jpg'))
|
||||
img = torch.from_numpy(numpy.asarray(img.convert('RGB')))
|
||||
|
||||
full_models = ONNXCompose(
|
||||
mnv2,
|
||||
preprocessors=pnp.PreMobileNet(224),
|
||||
postprocessors=pnp.PostMobileNet()
|
||||
)
|
||||
|
||||
ids, probabilities = full_models.predict(torch.from_numpy(img))
|
||||
full_model_func = PyOrtFunction.from_model(full_models.export(opset_version=11))
|
||||
actual_ids, actual_result = full_model_func(img)
|
||||
full_models = pnp.SequenceProcessingModule(pnp.PreMobileNet(224),
|
||||
mnv2,
|
||||
pnp.PostMobileNet())
|
||||
ids, probabilities = full_models.forward(img)
|
||||
name_i = 'image'
|
||||
full_model_func = OrtPyFunction.from_model(
|
||||
pnp.export(full_models,
|
||||
img,
|
||||
opset_version=11,
|
||||
output_path='temp_imagenet.onnx',
|
||||
input_names=[name_i],
|
||||
dynamic_axes={name_i: [0, 1]}))
|
||||
actual_ids, actual_result = full_model_func(img.numpy())
|
||||
numpy.testing.assert_allclose(probabilities.numpy(), actual_result, rtol=1e-3)
|
||||
self.assertEqual(ids[0, 0].item(), 953) # 953 is pineapple class id in the imagenet dataset
|
||||
|
||||
|
@ -60,33 +75,59 @@ class TestPreprocessing(unittest.TestCase):
|
|||
cfg = GPT2Config(n_layer=3)
|
||||
gpt2_m = _GPT2LMHeadModel(cfg)
|
||||
gpt2_m.eval().to('cpu')
|
||||
full_model = ONNXCompose(
|
||||
gpt2_m,
|
||||
preprocessors=pnp.PreHuggingFaceGPT2(vocab_file=get_test_data_file(__file__, 'data', 'gpt2.vocab'),
|
||||
merges_file=get_test_data_file(__file__, 'data', 'gpt2.merges.txt')))
|
||||
|
||||
test_sentence = ["Test a sentence"]
|
||||
expected = full_model.predict(test_sentence)
|
||||
model = full_model.export(opset_version=12, do_constant_folding=False)
|
||||
mfunc = PyOrtFunction.from_model(model)
|
||||
actuals = mfunc(test_sentence)
|
||||
# the random weight may generate a large diff in result, test the shape only.
|
||||
self.assertTrue(numpy.allclose(expected.size(), actuals.shape))
|
||||
tok = pnp.PreHuggingFaceGPT2(vocab_file=get_test_data_file('data', 'gpt2.vocab'),
|
||||
merges_file=get_test_data_file('data', 'gpt2.merges.txt'))
|
||||
inputs = tok.forward(test_sentence)
|
||||
pnp.export(tok, [test_sentence], opset_version=12, output_path='temp_tok2.onnx')
|
||||
# TODO: the following test doesn't work due to a GPT-2 exporting error.
|
||||
# pnp.export(pnp.SequenceProcessingModule(gpt2_m), inputs, opset_version=12, do_constant_folding=False)
|
||||
# full_model = pnp.SequenceProcessingModule(
|
||||
# tok,
|
||||
# gpt2_m)
|
||||
# expected = full_model.forward(test_sentence)
|
||||
# model = pnp.export(full_model, test_sentence, opset_version=12, do_constant_folding=False)
|
||||
# mfunc = OrtPyFunction.from_model(model)
|
||||
# actuals = mfunc(test_sentence)
|
||||
# # the random weight may generate a large diff in result, test the shape only.
|
||||
# self.assertTrue(numpy.allclose(expected.size(), actuals.shape))
|
||||
|
||||
def test_sequence_tensor(self):
|
||||
seq_m = ONNXCompose(torch.jit.script(_AddModel()), _SequenceTensorModel(), None)
|
||||
test_input = [numpy.array([1]), numpy.array([3, 4]), numpy.array([5, 6])]
|
||||
res = seq_m.predict(test_input)
|
||||
seq_m = _SequenceTensorModel()
|
||||
test_input = [torch.from_numpy(_i) for _i in [
|
||||
numpy.array([1]), numpy.array([3, 4]), numpy.array([5, 6])]]
|
||||
res = seq_m.forward(test_input)
|
||||
numpy.testing.assert_allclose(res, numpy.array([4, 5]))
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.11"):
|
||||
# The ONNX exporter fixing for sequence tensor only released in 1.11 and the above.
|
||||
oxml = seq_m.export(12, output_file='temp_seqtest.onnx')
|
||||
# The fix for sequence tensor support is only available in PyTorch 1.11 and above.
|
||||
oxml = pnp.export(seq_m,
|
||||
test_input,
|
||||
opset_version=12,
|
||||
output_path='temp_seqtest.onnx')
|
||||
# TODO: ORT doesn't accept the default empty element type of a sequence type.
|
||||
oxml.graph.input[0].type.sequence_type.elem_type.CopyFrom(
|
||||
onnx.helper.make_tensor_type_proto(onnx.onnx_pb.TensorProto.INT32, []))
|
||||
mfunc = PyOrtFunction.from_model(oxml)
|
||||
mfunc = OrtPyFunction.from_model(oxml)
|
||||
o_res = mfunc(test_input)
|
||||
numpy.testing.assert_allclose(res, o_res)
|
||||
|
||||
@unittest.skipIf(LooseVersion(torch.__version__) < LooseVersion("1.11"),
|
||||
'PythonOp fix requires PyTorch 1.11')
|
||||
def test_functional_processing(self):
|
||||
# load an image
|
||||
img = Image.open(get_test_data_file('data', 'pineapple.jpg')).convert('RGB')
|
||||
img = torch.from_numpy(numpy.asarray(img))
|
||||
|
||||
pipeline = _MobileNetProcessingModule(onnx.load_model(get_test_data_file('data', 'mobilev2.onnx')))
|
||||
ids, probabilities = pipeline.forward(img)
|
||||
|
||||
full_model_func = OrtPyFunction.from_model(
|
||||
pnp.export(pipeline, img, opset_version=11, output_path='temp_func.onnx'))
|
||||
actual_ids, actual_result = full_model_func(img)
|
||||
numpy.testing.assert_allclose(probabilities.numpy(), actual_result, rtol=1e-3)
|
||||
self.assertEqual(ids[0, 0].item(), 953) # 953 is pineapple class id in the imagenet dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -1,32 +1,33 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
OPMAP_TO_CMAKE_FLAGS = {'BlingFireSentenceBreaker': 'OCOS_ENABLE_BLINGFIRE',
|
||||
'GPT2Tokenizer': 'OCOS_ENABLE_GPT2_TOKENIZER',
|
||||
'WordpieceTokenizer': 'OCOS_ENABLE_WORDPIECE_TOKENIZER',
|
||||
'StringConcat': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringECMARegexReplace': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringECMARegexSplitWithOffsets': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringEqual': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringToHashBucket': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringToHashBucketFast': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringJoin': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringLength': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringLower': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringRegexReplace': 'OCOS_ENABLE_RE2_REGEX',
|
||||
'StringRegexSplitWithOffsets': 'OCOS_ENABLE_RE2_REGEX',
|
||||
'StringSplit': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringToVector': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringUpper': 'OCOS_ENABLE_TF_STRING',
|
||||
'SegmentExtraction': 'OCOS_ENABLE_MATH',
|
||||
'StringMapping' : 'OCOS_ENABLE_TF_STRING',
|
||||
'VectorToString': 'OCOS_ENABLE_TF_STRING',
|
||||
'MaskedFill': 'OCOS_ENABLE_TF_STRING',
|
||||
'BertTokenizer': 'OCOS_ENABLE_BERT_TOKENIZER',
|
||||
'BasicTokenizer': 'OCOS_ENABLE_BERT_TOKENIZER',
|
||||
'BertTokenizerDecoder': 'OCOS_ENABLE_BERT_TOKENIZER',
|
||||
'SentencepieceTokenizer': 'OCOS_ENABLE_SPM_TOKENIZER'
|
||||
}
|
||||
OPMAP_TO_CMAKE_FLAGS = {
|
||||
'BlingFireSentenceBreaker': 'OCOS_ENABLE_BLINGFIRE',
|
||||
'GPT2Tokenizer': 'OCOS_ENABLE_GPT2_TOKENIZER',
|
||||
'WordpieceTokenizer': 'OCOS_ENABLE_WORDPIECE_TOKENIZER',
|
||||
'StringConcat': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringECMARegexReplace': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringECMARegexSplitWithOffsets': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringEqual': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringToHashBucket': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringToHashBucketFast': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringJoin': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringLength': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringLower': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringRegexReplace': 'OCOS_ENABLE_RE2_REGEX',
|
||||
'StringRegexSplitWithOffsets': 'OCOS_ENABLE_RE2_REGEX',
|
||||
'StringSplit': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringToVector': 'OCOS_ENABLE_TF_STRING',
|
||||
'StringUpper': 'OCOS_ENABLE_TF_STRING',
|
||||
'SegmentExtraction': 'OCOS_ENABLE_MATH',
|
||||
'StringMapping': 'OCOS_ENABLE_TF_STRING',
|
||||
'VectorToString': 'OCOS_ENABLE_TF_STRING',
|
||||
'MaskedFill': 'OCOS_ENABLE_TF_STRING',
|
||||
'BertTokenizer': 'OCOS_ENABLE_BERT_TOKENIZER',
|
||||
'BasicTokenizer': 'OCOS_ENABLE_BERT_TOKENIZER',
|
||||
'BertTokenizerDecoder': 'OCOS_ENABLE_BERT_TOKENIZER',
|
||||
'SentencepieceTokenizer': 'OCOS_ENABLE_SPM_TOKENIZER'
|
||||
}
|
||||
|
||||
|
||||
def gen_cmake_oplist(opconfig_file, oplist_cmake_file='_selectedoplist.cmake'):
|
||||
|
@ -41,7 +42,7 @@ def gen_cmake_oplist(opconfig_file, oplist_cmake_file='_selectedoplist.cmake'):
|
|||
ext_domain_cnt += 1
|
||||
items = _ln.strip().split(';')
|
||||
if len(items) < 3:
|
||||
raise RuntimeError("The malformated operator config file.")
|
||||
raise RuntimeError("The malformed operator config file.")
|
||||
for _op in items[2].split(','):
|
||||
if not _op:
|
||||
continue # is None or ""
|
||||
|
|
|
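A hedged example of the operator config format this tool consumes, inferred from the parsing above: each custom-op line is split on ';', must carry at least three fields, and the third field lists the operator names that are translated into the cmake flags in OPMAP_TO_CMAKE_FLAGS. Treating the first two fields as domain and opset version is an assumption, and the file name is made up; the call assumes it runs alongside this script.

    # hypothetical required_operators.config content:
    #   ai.onnx.contrib;1;GPT2Tokenizer,StringUpper,SegmentExtraction
    # which would emit a _selectedoplist.cmake enabling OCOS_ENABLE_GPT2_TOKENIZER,
    # OCOS_ENABLE_TF_STRING and OCOS_ENABLE_MATH
    with open('required_operators.config', 'w') as f:
        f.write('ai.onnx.contrib;1;GPT2Tokenizer,StringUpper,SegmentExtraction\n')

    gen_cmake_oplist('required_operators.config')   # defined in this script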
@ -1,8 +1,8 @@
|
|||
####################################################################################
|
||||
###
|
||||
### !!! This script is replaced by the latest onnxruntime contrib op solution, which is
|
||||
### https://github.com/microsoft/onnxruntime/blob/ad9d2e2e891714e0911ccc3fa8b70f42025b4d56/onnxruntime/python/tools/transformers/convert_beam_search.py
|
||||
###
|
||||
#
|
||||
# !!! This script is replaced by the latest onnxruntime contrib op solution, which is
|
||||
# https://github.com/microsoft/onnxruntime/blob/ad9d2e2e891714e0911ccc3fa8b70f42025b4d56/onnxruntime/python/tools/transformers/convert_beam_search.py
|
||||
#
|
||||
###################################################################################
|
||||
|
||||
import os
|
||||
|
@ -14,7 +14,6 @@ import onnxruntime_extensions as _ortex
|
|||
from transformers import AutoConfig
|
||||
from distutils.version import StrictVersion
|
||||
|
||||
|
||||
if StrictVersion(_ort.__version__) < StrictVersion('1.8.1'):
|
||||
raise RuntimeError('Full GPT-2 model is only available on onnxruntime 1.8.1 and higher versions.')
|
||||
|
||||
|
@ -154,12 +153,12 @@ def _beam_search(tokenizer, func_one_step,
|
|||
factor = (~step.type(torch.bool)).type(torch.int64)
|
||||
prev_attention_mask = prev_attention_mask.repeat(factor * (batch_size * beam_size - 1) + 1, 1).to(device)
|
||||
attention_mask = torch.cat(
|
||||
[
|
||||
prev_attention_mask,
|
||||
torch.ones([batch_size * beam_size, 1], dtype=torch.float),
|
||||
],
|
||||
1,
|
||||
).to(device)
|
||||
[
|
||||
prev_attention_mask,
|
||||
torch.ones([batch_size * beam_size, 1], dtype=torch.float),
|
||||
],
|
||||
1,
|
||||
).to(device)
|
||||
|
||||
beam_select_idx = outputs[input_unfinished_sents_id - 2].clone().detach().to(device)
|
||||
input_log_probs = outputs[input_unfinished_sents_id - 1].clone().detach().to(device)
|
||||
|
|