The new pre/post processing API, replacing ONNXCompose (#205)

* traced processing module

* before debugging.

* updates

* temporary

* the trace mode pass

* code adjusting for ci pipeline.

* only torch 1.11 support prim:pythonop

* extending sequence processing module.
This commit is contained in:
Wenbing Li 2022-03-08 16:32:59 -08:00 коммит произвёл GitHub
Родитель 2e2ee11772
Коммит 4bb3a22c45
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
13 изменённых файлов: 459 добавлений и 159 удалений

Просмотреть файл

@ -10,6 +10,7 @@ The entry point to onnxruntime custom op library
__author__ = "Microsoft"
import pathlib
import inspect
from ._version import __version__
from ._ocos import get_library_path # noqa
from ._ocos import Opdef, PyCustomOpDef # noqa
@ -19,22 +20,16 @@ from ._ocos import expand_onnx_inputs # noqa
from ._ocos import hook_model_op # noqa
from ._ocos import default_opset_domain # noqa
from ._cuops import * # noqa
from ._ortapi2 import OrtPyFunction as PyOrtFunction, optimize_model, make_onnx_model
from ._ortapi2 import OrtPyFunction as PyOrtFunction
from ._ortapi2 import OrtPyFunction, optimize_model, make_onnx_model
onnx_op = Opdef.declare
PyOp = PyCustomOpDef
# ONNX-Compose depends PyTorch, which is optional for onnxruntime-extensions.
try:
import torch
except ImportError:
pass
else:
from .compose import ONNXCompose
def get_test_data_file(case_file, *sub_dirs):
# do a favour for the unit test.
def get_test_data_file(*sub_dirs):
case_file = inspect.currentframe().f_back.f_code.co_filename
test_dir = pathlib.Path(case_file).parent
return str(test_dir.joinpath(*sub_dirs))

Просмотреть файл

@ -35,6 +35,8 @@ def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(),
class OrtPyFunction:
__name__ = 'OrtPyFunction'
@classmethod
def get_ort_session_options(cls):
# ONNXRuntime has an issue to support reusing the SessionOptions object.

Просмотреть файл

@ -4,7 +4,7 @@ import torch
import numpy
from torch.onnx import TrainingMode, export as _export
from ._ortapi2 import OrtPyFunction
from .pnp import ONNXModelUtils, ProcessingModule, ProcessingScriptModule
from .pnp import ONNXModelUtils, _ProcessingModule, ProcessingScriptModule
def _is_numpy_object(x):
@ -44,9 +44,9 @@ class ONNXCompose:
"""
def __init__(self, models=None, preprocessors=None, postprocessors=None):
assert isinstance(preprocessors, ProcessingModule),\
assert isinstance(preprocessors, _ProcessingModule),\
'preprocessors must be subclassing from ProcessingModule'
assert postprocessors is None or isinstance(postprocessors, ProcessingModule),\
assert postprocessors is None or isinstance(postprocessors, _ProcessingModule),\
'postprocessors must be subclassing from ProcessingModule'
self.models = models
self.preprocessors = preprocessors
@ -93,7 +93,7 @@ class ONNXCompose:
model_l.append(post_m)
if output_file is not None:
# also output the pre/post-processing model for debugging
# also output the pre- / post-processing model for debugging
idx = 0
for _mdl in model_l:
if _mdl is self.models and isinstance(_mdl, onnx.ModelProto):

Просмотреть файл

@ -1,6 +1,13 @@
from ._utils import ONNXModelUtils
from ._base import ProcessingModule, ProcessingScriptModule, CustomFunction
from ._functions import * # noqa
# onnxruntime-extensions pre&post processing frontend depends on the PyTorch
try:
import torch
except ImportError as e:
print("No torch installation found, which is required by the pre&post scripting!")
raise e
from ._imagenet import PreMobileNet, PostMobileNet
from ._base import _ProcessingModule, ProcessingScriptModule, CustomFunction
from ._torchext import * # noqa
from ._unifier import export
from ._imagenet import * # noqa
from ._nlp import PreHuggingFaceGPT2

Просмотреть файл

@ -3,14 +3,55 @@ import onnx
import torch
from typing import Any
from onnx.onnx_pb import TensorProto
from torch.onnx import TrainingMode, export as _export
class ProcessingModule(torch.nn.Module):
def _export_f(model, args=None,
opset_version=None,
output_path=None,
output_seq=0,
export_params=True,
verbose=False,
input_names=None,
output_names=None,
operator_export_type=None,
do_constant_folding=True,
dynamic_axes=None,
keep_initializers_as_inputs=None,
custom_opsets=None):
with io.BytesIO() as f:
_export(model, args, f,
export_params=export_params, verbose=verbose,
training=TrainingMode.EVAL, input_names=input_names,
output_names=output_names,
operator_export_type=operator_export_type, opset_version=opset_version,
do_constant_folding=do_constant_folding,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
custom_opsets=custom_opsets)
mdl = onnx.load_model(io.BytesIO(f.getvalue()))
if output_path is not None:
if output_seq > 0:
output_path.replace('.onnx', '.{}.onnx'.format(output_seq))
onnx.save_model(mdl, output_path)
return mdl
class _ProcessingModule(torch.nn.Module):
def __init__(self):
super(_ProcessingModule, self).__init__()
_ProcessingModule.register_customops()
@staticmethod
@torch.jit.unused
def _argsort(g, x, dim, descending):
return g.op('ai.onnx.contrib::ArgSort', x, dim)
@classmethod
@torch.jit.unused
def register_customops(cls):
if hasattr(cls, 'loaded'):
return True
@ -21,30 +62,32 @@ class ProcessingModule(torch.nn.Module):
cls.loaded = True
return True
def export(self, opset_version, *args, **kwargs):
@torch.jit.unused
def export(self, *args, opset_version=None, script_mode=False, output_path=None, output_seq=0, **kwargs):
if opset_version is None:
raise RuntimeError('No opset_version found in the kwargs.')
mod = self
script_model = kwargs.pop('script_mode', False)
if script_model:
if script_mode and not isinstance(mod, torch.jit.ScriptModule):
mod = torch.jit.script(mod)
ofname = kwargs.pop('ofname', None)
with io.BytesIO() as f:
torch.onnx.export(mod, args, f,
training=torch.onnx.TrainingMode.EVAL,
opset_version=opset_version,
**kwargs)
mdl = onnx.load_model(io.BytesIO(f.getvalue()))
if ofname is not None:
ofname.replace('.onnx', '.1.onnx')
onnx.save_model(mdl, ofname)
return mdl
return _export_f(mod,
*args,
opset_version=opset_version,
output_path=output_path,
output_seq=output_seq, **kwargs)
class ProcessingScriptModule(ProcessingModule):
def export(self, opset_version, *args, **kwargs):
return super().export(opset_version, *args, script_mode=True, **kwargs)
class ProcessingTracedModule(_ProcessingModule):
pass
class ProcessingScriptModule(_ProcessingModule):
def __init__(self):
super(ProcessingScriptModule, self).__init__()
@torch.jit.unused
def export(self, *args, opset_version=None, **kwargs):
return super().export(*args, opset_version=opset_version, script_mode=True, **kwargs)
class CustomFunction(torch.autograd.Function):

Просмотреть файл

@ -1,7 +1,8 @@
import torch
from typing import Tuple
from torch.nn.functional import interpolate
from ._base import ProcessingModule
from ._functions import onnx_where, onnx_greater
from ._base import ProcessingTracedModule
from ._torchext import onnx_where, onnx_greater
def _resize_param(img, size):
@ -11,7 +12,7 @@ def _resize_param(img, size):
return onnx_where(onnx_greater(scale_x, scale_y), scale_x, scale_y)
class ImageNetPreProcessing(ProcessingModule):
class ImageNetPreProcessing(ProcessingTracedModule):
def __init__(self, size, resize_image=True):
super(ImageNetPreProcessing, self).__init__()
self.target_size = size
@ -47,17 +48,9 @@ class ImageNetPreProcessing(ProcessingModule):
x /= torch.reshape(torch.tensor(std), (3, 1, 1))
return x
def export(self, opset_version, *args, **kwargs):
name_i = 'image'
return super().export(opset_version, *args,
input_names=[name_i],
dynamic_axes={name_i: [0, 1]},
**kwargs)
class ImageNetPostProcessing(ProcessingModule):
def forward(self, scores):
ProcessingModule.register_customops()
class ImageNetPostProcessing(ProcessingTracedModule):
def forward(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
probabilities = torch.softmax(scores, dim=1)
top10_prob, top10_ids = probabilities.topk(k=10, dim=1, largest=True, sorted=True)
return top10_ids, top10_prob

Просмотреть файл

@ -1,7 +1,7 @@
import json
from ._base import ProcessingModule, tensor_data_type as _dt
from ._functions import create_op_function
from ._base import ProcessingTracedModule, tensor_data_type as _dt
from ._torchext import create_op_function
from ._onnx_ops import schema
from .._ocos import default_opset_domain
@ -44,7 +44,7 @@ def _get_bound_object(func):
return func.__self__
class PreHuggingFaceGPT2(ProcessingModule):
class PreHuggingFaceGPT2(ProcessingTracedModule):
def __init__(self, hf_tok=None, vocab_file=None, merges_file=None, padding_length=-1):
super(PreHuggingFaceGPT2, self).__init__()
if hf_tok is None:
@ -58,5 +58,5 @@ class PreHuggingFaceGPT2(ProcessingModule):
def forward(self, text):
return self.onnx_gpt2_tokenize(text)
def export(self, opset_version, *args, **kwargs):
def export(self, *args, opset_version=0, **kwargs):
return _get_bound_object(self.onnx_gpt2_tokenize).build_model(opset_version, *args)

Просмотреть файл

@ -4,8 +4,10 @@ import numpy as np
from typing import Any
from onnx import helper
from onnx import onnx_pb as onnx_proto
from distutils.version import LooseVersion
from torch.onnx import register_custom_op_symbolic
from ._base import CustomFunction
from ._base import CustomFunction, ProcessingTracedModule
from ._onnx_ops import ox as _ox, schema as _schema
from ._onnx_ops import ONNXElementContainer, make_model_ex
from .._ortapi2 import OrtPyFunction, get_opset_version_from_ort
@ -35,7 +37,7 @@ def _to_onnx_type(dtype):
return ty_dict.get(dtype, onnx_proto.TensorProto.STRING)
class ONNXOpFunction(CustomFunction):
class OnnxOpFunction(CustomFunction):
@classmethod
def get_next_id_name(cls, name_base):
name = 'cls' if name_base is None else name_base
@ -54,6 +56,8 @@ class ONNXOpFunction(CustomFunction):
@classmethod
def build_model(cls, opset_version, *args):
# build the one node graph
if isinstance(args[0], list):
args = [np.asarray(_i) for _i in args]
ec = ONNXElementContainer(get_opset_version_from_ort() if opset_version is None else opset_version)
attrs = cls.attrs
vi_inputs = [helper.make_tensor_value_info(
@ -101,7 +105,7 @@ class ONNXOpFunction(CustomFunction):
def create_op_function(op_type: str, func, **attrs):
if _ox.is_raw(func):
func = _schema(func.__func__)
cls = type(_ox.get_unique_operator_type_name(op_type), (ONNXOpFunction, ),
cls = type(_ox.get_unique_operator_type_name(op_type), (OnnxOpFunction, ),
dict(
op_type=op_type,
opb_func=func,
@ -118,19 +122,19 @@ onnx_greater = create_op_function('Greater', _ox.greater)
class PythonOpFunction:
"""
PythonOpFunction wraps a generic Python function which skips forward operation on torch.onnx.exporter
converting in the script mode, which cannot support the API from external package, like Numpy.
Autograd.Function cannot be used torch.jit.script.
converting in the script mode, since the exporter cannot support the APIs from external package, like Numpy.
BTW, Autograd.Function cannot be used torch.jit.script.
"""
id_func_map = {}
current_func_id = 0
@staticmethod
def get_next_id():
def _get_next_id():
PythonOpFunction.current_func_id += 1
return PythonOpFunction.current_func_id
@staticmethod
@torch.jit.ignore(drop_on_export=False)
@torch.jit.ignore
def _pass_through_call(*args, **kwargs):
func_id = args[0]
func = PythonOpFunction.id_func_map[func_id]
@ -143,7 +147,102 @@ class PythonOpFunction:
@classmethod
def get_id(cls):
if not hasattr(cls, 'func_id'):
_id = PythonOpFunction.get_next_id()
_id = PythonOpFunction._get_next_id()
setattr(cls, 'func_id', _id)
PythonOpFunction.id_func_map[_id] = cls
return cls.func_id
class _OnnxModelFunction:
id_object_map = {} # cannot use the string directly since jit.script doesn't support the data type
str_model_function_id = '_model_function_id'
str_model_id = '_model_id'
str_model_attached = '_model_attached'
@torch.jit.ignore
def _invoke_onnx_model(model_id: int, *args, **kwargs):
model_or_path = _OnnxModelFunction.id_object_map.get(model_id)
if model_or_path is None:
raise ValueError("cannot find id={} registered!".format(model_id))
func = OrtPyFunction.from_model(model_or_path)
return torch.from_numpy(
func(*list(_i.numpy() if isinstance(_i, torch.Tensor) else _i for _i in args), **kwargs))
@torch.jit.ignore
def invoke_onnx_model1(model_id: int, arg0):
return _invoke_onnx_model(model_id, arg0)
@torch.jit.ignore
def invoke_onnx_model2(model_id: int, arg0, arg1):
return _invoke_onnx_model(model_id, arg0, arg1)
@torch.jit.ignore
def invoke_onnx_model3(model_id: int, arg0, arg1, arg2):
return _invoke_onnx_model(model_id, arg0, arg1, arg2)
class _OnnxTracedFunction(CustomFunction):
@classmethod
def forward(cls, ctx: Any, *args: Any, **kwargs: Any) -> Any:
return _invoke_onnx_model(args[0].item(), *args[1:], **kwargs)
@classmethod
def symbolic(cls, g, *args):
return g.op('ai.onnx.contrib::_ModelFunctionCall', *args)
def create_model_function(model_or_path):
_id = id(model_or_path)
assert _id != 0, "internal error: the id of a Python object is 0."
_OnnxModelFunction.id_object_map[_id] = model_or_path
return _id
def get_id_models():
return _OnnxModelFunction.id_object_map
class SequenceProcessingModule(ProcessingTracedModule):
def __init__(self, *models):
super(SequenceProcessingModule, self).__init__()
self.model_list = models
self.model_function_ids = []
for mdl_ in models:
if isinstance(mdl_, onnx.ModelProto):
self.model_function_ids.append(create_model_function(mdl_))
else:
self.model_function_ids.append(0)
def forward(self, *args):
outputs = args
for idx_, mdl_ in enumerate(self.model_list):
if not isinstance(outputs, tuple):
outputs = (outputs, )
if self.model_function_ids[idx_] != 0:
outputs = _OnnxTracedFunction.apply(torch.tensor(self.model_function_ids[idx_]), *outputs)
else:
outputs = self.model_list[idx_].forward(*outputs)
return outputs
def _symbolic_pythonop(g: torch._C.Graph, n: torch._C.Node, *args, **kwargs):
name = kwargs["name"]
if name.startswith(invoke_onnx_model1.__name__[:-1]):
# id = torch.onnx.symbolic_helper._maybe_get_scalar(args[0]).item()
ret = g.op("ai.onnx.contrib::_ModelFunctionCall", *args)
else:
# Logs a warning and returns None
import warnings
return warnings.warn("prim::PythonOp", "unknown node kind: " + name)
# Copy type and shape from original node.
ret.setType(args[-1].type())
return ret
if LooseVersion(torch.__version__) >= LooseVersion("1.11"):
register_custom_op_symbolic("prim::PythonOp", _symbolic_pythonop, 1)

Просмотреть файл

@ -0,0 +1,40 @@
import onnx
from .._ortapi2 import get_opset_version_from_ort
from ._utils import ONNXModelUtils
from ._torchext import get_id_models
def export(m, *args,
opset_version=0,
output_path=None,
export_params=True,
verbose=False,
input_names=None,
output_names=None,
operator_export_type=None,
do_constant_folding=True,
dynamic_axes=None,
keep_initializers_as_inputs=None,
custom_opsets=None,
io_mapping=None):
"""
export all models and modules into a merged ONNX model.
"""
if opset_version == 0:
opset_version = get_opset_version_from_ort()
model = m.export(*args, opset_version=opset_version,
output_path=output_path,
export_params=export_params,
verbose=verbose,
input_names=input_names,
output_names=output_names,
operator_export_type=operator_export_type,
do_constant_folding=do_constant_folding,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
custom_opsets=custom_opsets)
full_m = ONNXModelUtils.unfold_model(model, get_id_models(), io_mapping)
if output_path is not None:
onnx.save_model(full_m, output_path)
return full_m

Просмотреть файл

@ -1,8 +1,26 @@
import copy
import onnx
from onnx import numpy_helper
from collections import namedtuple
class _Container:
def __init__(self):
self.parent = None
self.initializer=[]
self.value_info=[]
self.nodes = []
self.node_domain_version_pair_sets = {}
def add_model(self, oxml):
self.initializer.extend(oxml.graph.initializer)
self.value_info.extend(oxml.graph.value_info)
self.nodes.extend(oxml.graph.node)
self.node_domain_version_pair_sets.update(
[(opset_.domain, opset_.version) for opset_ in oxml.opset_import])
return self
class ONNXModelUtils:
@staticmethod
def merge_name(prefix, name):
@ -57,24 +75,86 @@ class ONNXModelUtils:
node.attribute.extend(attr_list)
return node
@staticmethod
def get_model_name_abbr(node):
no = node.name.split('_')[-1]
return 'm_' + no
@staticmethod
def get_model_id_from_arg0(nodes, node):
arg0_name = node.input[0]
c_node = [n_ for n_ in nodes if
n_.op_type == 'Constant' and n_.output[0] == arg0_name]
assert len(c_node) == 1, 'internal error, multiple nodes with the same output.'
c_node = c_node[0]
tensor_value = onnx.helper.get_attribute_value(c_node.attribute[0])
_id = numpy_helper.to_array(tensor_value).item()
return _id
@classmethod
def unfold_model_node(cls, container):
top_containter = container
while top_containter.parent is not None: # only one opset_import in the model.
top_containter = top_containter.parent
def _unfold_model_node(cls, container, name, model, io_mapping=None):
top_container = container
while top_container.parent is not None: # only one opset_import in the model.
top_container = top_container.parent
nodes = container.nodes
model_nodes = {node.name: node for node in nodes if hasattr(node, 'model')}
onnx_nodes = [nd_ for nd_ in nodes if nd_.name not in model_nodes]
renamed_nodes = cls._rename_graph(model.graph, name, container)
onnx_nodes = [cls._process_node_body(nd_, name) for nd_ in renamed_nodes]
for node in model_nodes.values():
renamed_nodes = cls._rename_graph(node.model.graph, node.name, container)
onnx_nodes.extend(cls._process_node_body(nd_, node.name) for nd_ in renamed_nodes)
top_containter.node_domain_version_pair_sets.update(
[(opset_.domain, opset_.version) for opset_ in node.model.opset_import])
top_container.node_domain_version_pair_sets.update(
[(opset_.domain, opset_.version) for opset_ in model.opset_import])
return onnx_nodes
@classmethod
def unfold_model(cls, oxml, id_to_model, io_mapping=None):
container = _Container().add_model(oxml)
nodes = []
for _nid, _node in enumerate(oxml.graph.node):
if _node.op_type != '_ModelFunctionCall':
nodes.append(_node)
else:
model_id = cls.get_model_id_from_arg0(list(oxml.graph.node), _node)
if model_id not in id_to_model:
raise RuntimeError("Cannot find the model id({}) in the table".format(model_id))
prefix = cls.get_model_name_abbr(_node)
nest_model = id_to_model[model_id]
input_mapping = []
output_mapping = []
for idx_, in_ in enumerate(nest_model.graph.input):
_renamed_in = "{}_{}".format(prefix, in_.name)
_nd = onnx.helper.make_node('Identity',
[_node.input[idx_ + 1]], # the first arg is model id, skip it.
[_renamed_in],
name='i_' + _renamed_in)
input_mapping.append(_nd)
nds = cls._unfold_model_node(container,
prefix,
nest_model,
io_mapping)
for idx_, out_ in enumerate(nest_model.graph.output):
_renamed_out = "{}_{}".format(prefix, out_.name)
_nd = onnx.helper.make_node('Identity',
[_renamed_out],
[_node.output[idx_]],
name='o_' + _renamed_out)
output_mapping.append(_nd)
if io_mapping is not None:
assert callable(io_mapping), "io_mapping is a custom function to build the linkage of the models"
input_mapping, output_mapping = io_mapping(input_mapping, output_mapping)
# attention: the order of the list operations is important, which avoids the topological sort.
nodes.extend(input_mapping)
nodes.extend(nds)
nodes.extend(output_mapping)
intlzs = cls._remove_unused_initializers(nodes, container.initializer)
oxml = copy.deepcopy(oxml)
del oxml.graph.node[:]
oxml.graph.node.extend(nodes)
del oxml.graph.initializer[:]
oxml.graph.initializer.extend(intlzs)
return oxml
@classmethod
def topological_sort(cls, container, nodes, inputs, outputs):
op_output_map = {}
@ -136,12 +216,14 @@ class ONNXModelUtils:
return sorted_nodes
@staticmethod
def _remove_unused_initializers(nodes, initializers, reversed_names):
def _remove_unused_initializers(nodes, initializers, reserved_names=None):
if reserved_names is None:
reserved_names = set()
nodes_input_set = set()
for nd_ in nodes:
nodes_input_set.update(n_ for n_ in nd_.input)
return [intlz_ for intlz_ in initializers if intlz_.name in nodes_input_set or intlz_.name in reversed_names]
return [intlz_ for intlz_ in initializers if intlz_.name in nodes_input_set or intlz_.name in reserved_names]
@classmethod
def join_models(cls, *models, io_mapping=None):
@ -171,11 +253,9 @@ class ONNXModelUtils:
port_mapping[iname] = oname
nodes = []
Container = namedtuple('Container', ['initializer', 'value_info'])
container = Container(initializer=[], value_info=[])
container = _Container()
for _idx, _m in enumerate(models):
container.initializer.extend(_m.graph.initializer)
container.value_info.extend(_m.graph.value_info)
container.initializer.add_model(_m)
nodes += cls._rename_graph(_m.graph, mdl_prefix[_idx], container)
for _n in nodes:
@ -199,7 +279,7 @@ class ONNXModelUtils:
_opset.append(_ops)
name = name + '_' + _mdl.graph.name if name else _mdl.graph.name
inits = cls._remove_unused_initializers(nodes, container.initializer, set())
inits = cls._remove_unused_initializers(nodes, container.initializer)
helper = onnx.helper
g = helper.make_graph(nodes, name, inputs, outputs,
initializer=inits,

Просмотреть файл

@ -2,10 +2,10 @@ import onnx
import numpy
import torch
import unittest
from typing import List
from typing import List, Tuple
from PIL import Image
from distutils.version import LooseVersion
from onnxruntime_extensions import PyOrtFunction, ONNXCompose
from onnxruntime_extensions import OrtPyFunction
from onnxruntime_extensions import pnp, get_test_data_file
from transformers import GPT2Config, GPT2LMHeadModel
@ -13,7 +13,6 @@ from transformers import GPT2Config, GPT2LMHeadModel
class _GPT2LMHeadModel(GPT2LMHeadModel):
""" Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state.
"""
def __init__(self, config):
super().__init__(config)
@ -25,34 +24,50 @@ class _GPT2LMHeadModel(GPT2LMHeadModel):
return result[0]
@torch.jit.script
def _broadcasting_add(input_list: List[torch.Tensor]) -> torch.Tensor:
return input_list[1] + input_list[0]
class _SequenceTensorModel(pnp.ProcessingScriptModule):
def forward(self, img_list: List[torch.Tensor]) -> List[torch.Tensor]:
return img_list[0], img_list[1]
def forward(self, img_list: List[torch.Tensor]) -> torch.Tensor:
return _broadcasting_add(img_list)
class _AddModel(torch.nn.Module):
def forward(self, input_list: List[torch.Tensor]) -> torch.Tensor:
return input_list[1] + input_list[0] # test broadcasting.
class _MobileNetProcessingModule(pnp.ProcessingScriptModule):
def __init__(self, oxml):
super(_MobileNetProcessingModule, self).__init__()
self.model_function_id = pnp.create_model_function(oxml)
self.pre_proc = torch.jit.trace(pnp.PreMobileNet(224), torch.zeros(224, 224, 3, dtype=torch.float32))
self.post_proc = torch.jit.trace(pnp.ImageNetPostProcessing(), torch.zeros(1, 1000, dtype=torch.float32))
def forward(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
proc_input = self.pre_proc(img)
return self.post_proc.forward(pnp.invoke_onnx_model1(self.model_function_id, proc_input))
@unittest.skipIf(LooseVersion(torch.__version__) < LooseVersion("1.9"), 'Only tested the lastest PyTorch')
@unittest.skipIf(LooseVersion(torch.__version__) < LooseVersion("1.9"), 'Not works with older PyTorch')
class TestPreprocessing(unittest.TestCase):
def test_imagenet_preprocessing(self):
mnv2 = onnx.load_model(get_test_data_file(__file__, 'data', 'mobilev2.onnx'))
mnv2 = onnx.load_model(get_test_data_file('data', 'mobilev2.onnx'))
# load an image
img = Image.open(get_test_data_file(__file__, 'data', 'pineapple.jpg'))
img = numpy.asarray(img.convert('RGB'))
img = Image.open(get_test_data_file('data', 'pineapple.jpg'))
img = torch.from_numpy(numpy.asarray(img.convert('RGB')))
full_models = ONNXCompose(
mnv2,
preprocessors=pnp.PreMobileNet(224),
postprocessors=pnp.PostMobileNet()
)
ids, probabilities = full_models.predict(torch.from_numpy(img))
full_model_func = PyOrtFunction.from_model(full_models.export(opset_version=11))
actual_ids, actual_result = full_model_func(img)
full_models = pnp.SequenceProcessingModule(pnp.PreMobileNet(224),
mnv2,
pnp.PostMobileNet())
ids, probabilities = full_models.forward(img)
name_i = 'image'
full_model_func = OrtPyFunction.from_model(
pnp.export(full_models,
img,
opset_version=11,
output_path='temp_imagenet.onnx',
input_names=[name_i],
dynamic_axes={name_i: [0, 1]}))
actual_ids, actual_result = full_model_func(img.numpy())
numpy.testing.assert_allclose(probabilities.numpy(), actual_result, rtol=1e-3)
self.assertEqual(ids[0, 0].item(), 953) # 953 is pineapple class id in the imagenet dataset
@ -60,33 +75,59 @@ class TestPreprocessing(unittest.TestCase):
cfg = GPT2Config(n_layer=3)
gpt2_m = _GPT2LMHeadModel(cfg)
gpt2_m.eval().to('cpu')
full_model = ONNXCompose(
gpt2_m,
preprocessors=pnp.PreHuggingFaceGPT2(vocab_file=get_test_data_file(__file__, 'data', 'gpt2.vocab'),
merges_file=get_test_data_file(__file__, 'data', 'gpt2.merges.txt')))
test_sentence = ["Test a sentence"]
expected = full_model.predict(test_sentence)
model = full_model.export(opset_version=12, do_constant_folding=False)
mfunc = PyOrtFunction.from_model(model)
actuals = mfunc(test_sentence)
# the random weight may generate a large diff in result, test the shape only.
self.assertTrue(numpy.allclose(expected.size(), actuals.shape))
tok = pnp.PreHuggingFaceGPT2(vocab_file=get_test_data_file('data', 'gpt2.vocab'),
merges_file=get_test_data_file('data', 'gpt2.merges.txt'))
inputs = tok.forward(test_sentence)
pnp.export(tok, [test_sentence], opset_version=12, output_path='temp_tok2.onnx')
# TODO: the following test doesn't work due to GPT-2 exporting error.
# pnp.export(pnp.SequenceProcessingModule(gpt2_m), inputs, opset_version=12, do_constant_folding=False)
# full_model = pnp.SequenceProcessingModule(
# tok,
# gpt2_m)
# expected = full_model.forward(test_sentence)
# model = pnp.export(full_model, test_sentence, opset_version=12, do_constant_folding=False)
# mfunc = OrtPyFunction.from_model(model)
# actuals = mfunc(test_sentence)
# # the random weight may generate a large diff in result, test the shape only.
# self.assertTrue(numpy.allclose(expected.size(), actuals.shape))
def test_sequence_tensor(self):
seq_m = ONNXCompose(torch.jit.script(_AddModel()), _SequenceTensorModel(), None)
test_input = [numpy.array([1]), numpy.array([3, 4]), numpy.array([5, 6])]
res = seq_m.predict(test_input)
seq_m = _SequenceTensorModel()
test_input = [torch.from_numpy(_i) for _i in [
numpy.array([1]), numpy.array([3, 4]), numpy.array([5, 6])]]
res = seq_m.forward(test_input)
numpy.testing.assert_allclose(res, numpy.array([4, 5]))
if LooseVersion(torch.__version__) >= LooseVersion("1.11"):
# The ONNX exporter fixing for sequence tensor only released in 1.11 and the above.
oxml = seq_m.export(12, output_file='temp_seqtest.onnx')
# The fixing for the sequence tensor support is only released in 1.11 and the above.
oxml = pnp.export(seq_m,
test_input,
opset_version=12,
output_path='temp_seqtest.onnx')
# TODO: ORT doesn't accept the default empty element type of a sequence type.
oxml.graph.input[0].type.sequence_type.elem_type.CopyFrom(
onnx.helper.make_tensor_type_proto(onnx.onnx_pb.TensorProto.INT32, []))
mfunc = PyOrtFunction.from_model(oxml)
mfunc = OrtPyFunction.from_model(oxml)
o_res = mfunc(test_input)
numpy.testing.assert_allclose(res, o_res)
@unittest.skipIf(LooseVersion(torch.__version__) < LooseVersion("1.11"),
'PythonOp bug fixing on Pytorch 1.11')
def test_functional_processing(self):
# load an image
img = Image.open(get_test_data_file('data', 'pineapple.jpg')).convert('RGB')
img = torch.from_numpy(numpy.asarray(img))
pipeline = _MobileNetProcessingModule(onnx.load_model(get_test_data_file('data', 'mobilev2.onnx')))
ids, probabilities = pipeline.forward(img)
full_model_func = OrtPyFunction.from_model(
pnp.export(pipeline, img, opset_version=11, output_path='temp_func.onnx'))
actual_ids, actual_result = full_model_func(img)
numpy.testing.assert_allclose(probabilities.numpy(), actual_result, rtol=1e-3)
self.assertEqual(ids[0, 0].item(), 953) # 953 is pineapple class id in the imagenet dataset
if __name__ == "__main__":
unittest.main()

Просмотреть файл

@ -1,32 +1,33 @@
import os
import sys
OPMAP_TO_CMAKE_FLAGS = {'BlingFireSentenceBreaker': 'OCOS_ENABLE_BLINGFIRE',
'GPT2Tokenizer': 'OCOS_ENABLE_GPT2_TOKENIZER',
'WordpieceTokenizer': 'OCOS_ENABLE_WORDPIECE_TOKENIZER',
'StringConcat': 'OCOS_ENABLE_TF_STRING',
'StringECMARegexReplace': 'OCOS_ENABLE_TF_STRING',
'StringECMARegexSplitWithOffsets': 'OCOS_ENABLE_TF_STRING',
'StringEqual': 'OCOS_ENABLE_TF_STRING',
'StringToHashBucket': 'OCOS_ENABLE_TF_STRING',
'StringToHashBucketFast': 'OCOS_ENABLE_TF_STRING',
'StringJoin': 'OCOS_ENABLE_TF_STRING',
'StringLength': 'OCOS_ENABLE_TF_STRING',
'StringLower': 'OCOS_ENABLE_TF_STRING',
'StringRegexReplace': 'OCOS_ENABLE_RE2_REGEX',
'StringRegexSplitWithOffsets': 'OCOS_ENABLE_RE2_REGEX',
'StringSplit': 'OCOS_ENABLE_TF_STRING',
'StringToVector': 'OCOS_ENABLE_TF_STRING',
'StringUpper': 'OCOS_ENABLE_TF_STRING',
'SegmentExtraction': 'OCOS_ENABLE_MATH',
'StringMapping' : 'OCOS_ENABLE_TF_STRING',
'VectorToString': 'OCOS_ENABLE_TF_STRING',
'MaskedFill': 'OCOS_ENABLE_TF_STRING',
'BertTokenizer': 'OCOS_ENABLE_BERT_TOKENIZER',
'BasicTokenizer': 'OCOS_ENABLE_BERT_TOKENIZER',
'BertTokenizerDecoder': 'OCOS_ENABLE_BERT_TOKENIZER',
'SentencepieceTokenizer': 'OCOS_ENABLE_SPM_TOKENIZER'
}
OPMAP_TO_CMAKE_FLAGS = {
'BlingFireSentenceBreaker': 'OCOS_ENABLE_BLINGFIRE',
'GPT2Tokenizer': 'OCOS_ENABLE_GPT2_TOKENIZER',
'WordpieceTokenizer': 'OCOS_ENABLE_WORDPIECE_TOKENIZER',
'StringConcat': 'OCOS_ENABLE_TF_STRING',
'StringECMARegexReplace': 'OCOS_ENABLE_TF_STRING',
'StringECMARegexSplitWithOffsets': 'OCOS_ENABLE_TF_STRING',
'StringEqual': 'OCOS_ENABLE_TF_STRING',
'StringToHashBucket': 'OCOS_ENABLE_TF_STRING',
'StringToHashBucketFast': 'OCOS_ENABLE_TF_STRING',
'StringJoin': 'OCOS_ENABLE_TF_STRING',
'StringLength': 'OCOS_ENABLE_TF_STRING',
'StringLower': 'OCOS_ENABLE_TF_STRING',
'StringRegexReplace': 'OCOS_ENABLE_RE2_REGEX',
'StringRegexSplitWithOffsets': 'OCOS_ENABLE_RE2_REGEX',
'StringSplit': 'OCOS_ENABLE_TF_STRING',
'StringToVector': 'OCOS_ENABLE_TF_STRING',
'StringUpper': 'OCOS_ENABLE_TF_STRING',
'SegmentExtraction': 'OCOS_ENABLE_MATH',
'StringMapping': 'OCOS_ENABLE_TF_STRING',
'VectorToString': 'OCOS_ENABLE_TF_STRING',
'MaskedFill': 'OCOS_ENABLE_TF_STRING',
'BertTokenizer': 'OCOS_ENABLE_BERT_TOKENIZER',
'BasicTokenizer': 'OCOS_ENABLE_BERT_TOKENIZER',
'BertTokenizerDecoder': 'OCOS_ENABLE_BERT_TOKENIZER',
'SentencepieceTokenizer': 'OCOS_ENABLE_SPM_TOKENIZER'
}
def gen_cmake_oplist(opconfig_file, oplist_cmake_file='_selectedoplist.cmake'):
@ -41,7 +42,7 @@ def gen_cmake_oplist(opconfig_file, oplist_cmake_file='_selectedoplist.cmake'):
ext_domain_cnt += 1
items = _ln.strip().split(';')
if len(items) < 3:
raise RuntimeError("The malformated operator config file.")
raise RuntimeError("The malformed operator config file.")
for _op in items[2].split(','):
if not _op:
continue # is None or ""

Просмотреть файл

@ -1,8 +1,8 @@
####################################################################################
###
### !!! This script is replaced by the latest onnxruntime contrib op solution, which is
### https://github.com/microsoft/onnxruntime/blob/ad9d2e2e891714e0911ccc3fa8b70f42025b4d56/onnxruntime/python/tools/transformers/convert_beam_search.py
###
#
# !!! This script is replaced by the latest onnxruntime contrib op solution, which is
# https://github.com/microsoft/onnxruntime/blob/ad9d2e2e891714e0911ccc3fa8b70f42025b4d56/onnxruntime/python/tools/transformers/convert_beam_search.py
#
###################################################################################
import os
@ -14,7 +14,6 @@ import onnxruntime_extensions as _ortex
from transformers import AutoConfig
from distutils.version import StrictVersion
if StrictVersion(_ort.__version__) < StrictVersion('1.8.1'):
raise RuntimeError('Full GPT-2 model is only available on onxruntime 1.8.1 and higher version.')
@ -154,12 +153,12 @@ def _beam_search(tokenizer, func_one_step,
factor = (~step.type(torch.bool)).type(torch.int64)
prev_attention_mask = prev_attention_mask.repeat(factor * (batch_size * beam_size - 1) + 1, 1).to(device)
attention_mask = torch.cat(
[
prev_attention_mask,
torch.ones([batch_size * beam_size, 1], dtype=torch.float),
],
1,
).to(device)
[
prev_attention_mask,
torch.ones([batch_size * beam_size, 1], dtype=torch.float),
],
1,
).to(device)
beam_select_idx = outputs[input_unfinished_sents_id - 2].clone().detach().to(device)
input_log_probs = outputs[input_unfinished_sents_id - 1].clone().detach().to(device)