initial checkins for onnxcompose (#185)
* initial checkins for onnxcompose * update ci pipeline for the test. * add the missing quotes * Switch to looseVersion for torch version. * testif * padding_length * skip the gpt2 * add onnxruntime 1.9 test package. * fix a memory bug on pyop.
This commit is contained in:
Родитель
dddd85397d
Коммит
9bd453e0f1
|
@ -20,6 +20,7 @@ gen
|
|||
.DS_Store
|
||||
*~
|
||||
.vs
|
||||
*-checkpoint.ipynb
|
||||
.ipynb_checkpoints/
|
||||
tutorials/*.onnx
|
||||
Testing/
|
||||
|
|
|
@ -12,10 +12,12 @@ jobs:
|
|||
matrix:
|
||||
py39-170:
|
||||
python.version: '3.9'
|
||||
torch.version: 'torch==1.10.0+cpu torchvision==0.11.1+cpu torchaudio==0.10.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html'
|
||||
ort.version: '1.7.0'
|
||||
ortlib.version: '38443267'
|
||||
py37-160:
|
||||
python.version: '3.7'
|
||||
torch.version: 'torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html'
|
||||
ort.version: '1.6.0'
|
||||
ortlib.version: '34858191'
|
||||
maxParallel: 2
|
||||
|
@ -63,7 +65,7 @@ jobs:
|
|||
python setup.py develop
|
||||
displayName: Build the library and tests
|
||||
|
||||
- script: python -m pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- script: python -m pip install $(torch.version)
|
||||
displayName: Install pytorch
|
||||
|
||||
- script: python -m pip install -r requirements-dev.txt
|
||||
|
|
|
@ -9,6 +9,7 @@ The entry point to onnxruntime custom op library
|
|||
|
||||
__author__ = "Microsoft"
|
||||
|
||||
import pathlib
|
||||
from ._version import __version__
|
||||
from ._ocos import get_library_path # noqa
|
||||
from ._ocos import Opdef, PyCustomOpDef # noqa
|
||||
|
@ -18,8 +19,22 @@ from ._ocos import expand_onnx_inputs # noqa
|
|||
from ._ocos import hook_model_op # noqa
|
||||
from ._ocos import default_opset_domain # noqa
|
||||
from ._cuops import * # noqa
|
||||
from ._ortapi2 import EagerOp as PyOrtFunction, optimize_model, make_onnx_model
|
||||
from ._ortapi2 import OrtPyFunction as PyOrtFunction, optimize_model, make_onnx_model
|
||||
|
||||
|
||||
onnx_op = Opdef.declare
|
||||
PyOp = PyCustomOpDef
|
||||
|
||||
|
||||
# ONNX-Compose depends PyTorch, which is optional for onnxruntime-extensions.
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
from .compose import ONNXCompose
|
||||
|
||||
|
||||
def get_test_data_file(case_file, *sub_dirs):
|
||||
test_dir = pathlib.Path(case_file).parent
|
||||
return str(test_dir.joinpath(*sub_dirs))
|
||||
|
|
|
@ -4,8 +4,9 @@
|
|||
###############################################################################
|
||||
|
||||
import onnx
|
||||
import numpy
|
||||
from onnx import onnx_pb as onnx_proto
|
||||
from ._ocos import default_opset_domain
|
||||
from ._ocos import default_opset_domain, Opdef, PyCustomOpDef
|
||||
|
||||
|
||||
class CustomOp:
|
||||
|
@ -269,3 +270,14 @@ class SingleOpGraph:
|
|||
@staticmethod
|
||||
def get_op_class(op_type):
|
||||
return globals()[op_type]
|
||||
|
||||
|
||||
# TODO: have a C++ impl.
|
||||
def _argsort_op(x, dim):
|
||||
d = numpy.argsort(x, dim)
|
||||
return d[:, ::-1]
|
||||
|
||||
|
||||
Opdef.create(_argsort_op, op_type='ArgSort',
|
||||
inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_int64],
|
||||
outputs=[PyCustomOpDef.dt_int64])
|
||||
|
|
|
@ -9,7 +9,7 @@ from ._ocos import default_opset_domain, get_library_path # noqa
|
|||
from ._cuops import * # noqa
|
||||
|
||||
|
||||
def _get_opset_version_from_ort():
|
||||
def get_opset_version_from_ort():
|
||||
_ORT_OPSET_SUPPORT_TABLE = {
|
||||
"1.5": 11,
|
||||
"1.6": 12,
|
||||
|
@ -24,16 +24,16 @@ def _get_opset_version_from_ort():
|
|||
|
||||
def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(), extra_opset_version=1):
|
||||
if opset_version == 0:
|
||||
opset_version = _get_opset_version_from_ort()
|
||||
opset_version = get_opset_version_from_ort()
|
||||
fn_mm = onnx.helper.make_model_gen_version if hasattr(onnx.helper, 'make_model_gen_version'
|
||||
) else onnx.helper.make_model
|
||||
) else onnx.helper.make_model
|
||||
model = fn_mm(graph, opset_imports=[
|
||||
onnx.helper.make_operatorsetid('ai.onnx', opset_version)])
|
||||
model.opset_import.extend([onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)])
|
||||
return model
|
||||
|
||||
|
||||
class EagerOp:
|
||||
class OrtPyFunction:
|
||||
|
||||
@classmethod
|
||||
def get_ort_session_options(cls):
|
||||
|
@ -117,11 +117,11 @@ class EagerOp:
|
|||
def __call__(self, *args, **kwargs):
|
||||
self._ensure_ort_session()
|
||||
outputs = self.ort_session.run(None, self._argument_map(*args, **kwargs))
|
||||
return outputs[0] if len(outputs) == 1 else outputs
|
||||
return outputs[0] if len(outputs) == 1 else tuple(outputs)
|
||||
|
||||
|
||||
def optimize_model(model_or_file, output_file):
|
||||
sess_options = EagerOp.get_ort_session_options()
|
||||
sess_options = OrtPyFunction.get_ort_session_options()
|
||||
sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||
sess_options.optimized_model_filepath = output_file
|
||||
_ort.InferenceSession(model_or_file if isinstance(model_or_file, str)
|
||||
|
|
|
@ -4,7 +4,7 @@ import onnx
|
|||
import numpy
|
||||
|
||||
from onnx import onnx_pb, save_tensor, numpy_helper
|
||||
from ._ortapi2 import EagerOp
|
||||
from ._ortapi2 import OrtPyFunction
|
||||
|
||||
|
||||
class ORTExtCommands:
|
||||
|
@ -16,7 +16,7 @@ class ORTExtCommands:
|
|||
"""
|
||||
Run an onnx model with the arguments as its inputs
|
||||
"""
|
||||
op_func = EagerOp.from_model(self._model)
|
||||
op_func = OrtPyFunction.from_model(self._model)
|
||||
np_args = [numpy.asarray(_x) for _x in args]
|
||||
for _idx, _sch in enumerate(op_func.inputs):
|
||||
if _sch.type.tensor_type.elem_type == onnx_pb.TensorProto.FLOAT:
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
import io
|
||||
import onnx
|
||||
import torch
|
||||
import numpy
|
||||
from torch.onnx import TrainingMode, export as _export
|
||||
from .pnp import ONNXModelUtils, ProcessingModule
|
||||
from ._ortapi2 import OrtPyFunction as ONNXPyFunction
|
||||
|
||||
|
||||
def _is_numpy_object(x):
|
||||
return isinstance(x, (numpy.ndarray, numpy.generic))
|
||||
|
||||
|
||||
def _is_numpy_string_type(arr):
|
||||
return arr.dtype.kind in {'U', 'S'}
|
||||
|
||||
|
||||
def _export_f(model, args=None,
|
||||
export_params=True, verbose=False,
|
||||
input_names=None, output_names=None,
|
||||
operator_export_type=None, opset_version=None,
|
||||
do_constant_folding=True,
|
||||
dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None):
|
||||
if isinstance(model, ProcessingModule):
|
||||
# try to call ProcessingModule export
|
||||
m = model.export(opset_version, *args)
|
||||
if m is not None:
|
||||
return m
|
||||
|
||||
with io.BytesIO() as f:
|
||||
_export(model, args, f,
|
||||
export_params=export_params, verbose=verbose,
|
||||
training=TrainingMode.EVAL, input_names=input_names,
|
||||
output_names=output_names,
|
||||
operator_export_type=operator_export_type, opset_version=opset_version,
|
||||
do_constant_folding=do_constant_folding,
|
||||
dynamic_axes=dynamic_axes,
|
||||
keep_initializers_as_inputs=keep_initializers_as_inputs,
|
||||
custom_opsets=custom_opsets)
|
||||
return onnx.load_model(io.BytesIO(f.getvalue()))
|
||||
|
||||
|
||||
class ONNXCompose:
|
||||
"""
|
||||
Merge the pre/post processing Pytorch subclassing modules with the core model.
|
||||
:arg models the core model, can be an ONNX model or a PyTorch ONNX-exportable models
|
||||
:arg preprocessors the preprocessing module
|
||||
:arg postprocessors the postprocessing module
|
||||
"""
|
||||
def __init__(self, models=None, preprocessors=None, postprocessors=None):
|
||||
self.models = models
|
||||
self.preprocessors = preprocessors
|
||||
self.postprocessors = postprocessors
|
||||
|
||||
self.pre_args = None
|
||||
self.models_args = None
|
||||
self.post_args = None
|
||||
|
||||
def export(self, opset_version, output_file=None,
|
||||
export_params=True,
|
||||
verbose=False,
|
||||
input_names=None,
|
||||
output_names=None,
|
||||
operator_export_type=None,
|
||||
do_constant_folding=True,
|
||||
dynamic_axes=None,
|
||||
keep_initializers_as_inputs=None,
|
||||
custom_opsets=None,
|
||||
io_mapping=None):
|
||||
"""
|
||||
export all models and modules into a merged ONNX model.
|
||||
"""
|
||||
post_m = None
|
||||
pre_m = _export_f(self.preprocessors, tuple(self.pre_args),
|
||||
export_params=export_params, verbose=verbose, opset_version=opset_version)
|
||||
|
||||
if isinstance(self.models, torch.nn.Module):
|
||||
core = _export_f(self.models, tuple(self.models_args),
|
||||
export_params=export_params, verbose=verbose, input_names=input_names,
|
||||
output_names=output_names,
|
||||
operator_export_type=operator_export_type, opset_version=opset_version,
|
||||
do_constant_folding=do_constant_folding,
|
||||
dynamic_axes=dynamic_axes,
|
||||
keep_initializers_as_inputs=keep_initializers_as_inputs,
|
||||
custom_opsets=custom_opsets)
|
||||
else:
|
||||
core = self.models
|
||||
|
||||
if self.postprocessors is not None:
|
||||
post_m = _export_f(self.postprocessors, tuple(self.post_args),
|
||||
export_params=export_params, verbose=verbose, opset_version=opset_version)
|
||||
model_l = [core]
|
||||
if pre_m is not None:
|
||||
model_l.insert(0, pre_m)
|
||||
if post_m is not None:
|
||||
model_l.append(post_m)
|
||||
|
||||
full_m = ONNXModelUtils.join_models(*model_l, io_mapping=io_mapping)
|
||||
if output_file is not None:
|
||||
onnx.save_model(full_m, output_file)
|
||||
return full_m
|
||||
|
||||
def predict(self, *args, **kwargs):
|
||||
"""
|
||||
Predict the result through all modules/models
|
||||
:param args: the input arguments for the first preprocessing module.
|
||||
:param kwargs: ignored
|
||||
:return: the result from the last postprocessing module or
|
||||
from the core model if there is no postprocessing module.
|
||||
"""
|
||||
# convert the raw value, and special handling for string.
|
||||
n_args = [numpy.array(_arg) if not isinstance(_arg, torch.Tensor) else _arg for _arg in args]
|
||||
n_args = [torch.from_numpy(_arg) if
|
||||
_is_numpy_object(_arg) and (not _is_numpy_string_type(_arg)) else _arg for _arg in n_args]
|
||||
|
||||
self.pre_args = n_args
|
||||
inputs = [self.preprocessors.forward(*n_args)]
|
||||
flatten_inputs = []
|
||||
for _i in inputs:
|
||||
flatten_inputs += list(_i) if isinstance(_i, tuple) else [_i]
|
||||
self.models_args = flatten_inputs
|
||||
if isinstance(self.models, torch.nn.Module):
|
||||
outputs = self.models.forward(*flatten_inputs)
|
||||
else:
|
||||
f = ONNXPyFunction.from_model(self.models)
|
||||
outputs = [torch.from_numpy(f(*[_i.numpy() for _i in flatten_inputs]))]
|
||||
self.post_args = outputs
|
||||
if self.postprocessors is None:
|
||||
return outputs
|
||||
|
||||
return self.postprocessors.forward(*outputs)
|
|
@ -5,7 +5,7 @@ import warnings
|
|||
import numpy as np
|
||||
from onnx import helper, mapping
|
||||
from collections import namedtuple
|
||||
from .._ortapi2 import EagerOp
|
||||
from .._ortapi2 import OrtPyFunction
|
||||
from ._builder import is_path as _is_path
|
||||
from ._onnx_ops import ONNXElementContainer, make_model_ex
|
||||
from ._tensor import tensor_from_onnx, tensor_from_torch, tensor_set_session
|
||||
|
@ -264,7 +264,7 @@ class ONNXTraceSession:
|
|||
vi_output)
|
||||
result = None
|
||||
try:
|
||||
oxfunc = EagerOp.from_model(oxml)
|
||||
oxfunc = OrtPyFunction.from_model(oxml)
|
||||
result = oxfunc(*[ts_.numpy() for ts_ in ts_from])
|
||||
finally:
|
||||
if result is None:
|
||||
|
|
|
@ -8,7 +8,7 @@ from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme,
|
|||
from torch import strided, memory_format, contiguous_format, StringType # noqa
|
||||
|
||||
from ._onnx_ops import ox as _ox
|
||||
from .._ortapi2 import EagerOp
|
||||
from .._ortapi2 import OrtPyFunction
|
||||
|
||||
|
||||
class _EagerTensor:
|
||||
|
@ -592,7 +592,7 @@ def control_flow():
|
|||
return _ControlFlowContext()
|
||||
|
||||
|
||||
class _TracingEagerOp(EagerOp):
|
||||
class _TracingEagerOp(OrtPyFunction):
|
||||
def __call__(self, *args, **kwargs):
|
||||
np_args = [ts_.numpy() if isinstance(ts_, _EagerTensor) else ts_ for ts_ in args]
|
||||
outseq = super().__call__(*np_args, **kwargs)
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
from ._utils import ONNXModelUtils
|
||||
from ._functions import * # noqa
|
||||
|
||||
from ._imagenet import PreMobileNet, PostMobileNet
|
||||
from ._nlp import PreHuggingFaceGPT2
|
||||
from ._base import ProcessingModule
|
|
@ -0,0 +1,26 @@
|
|||
import torch
|
||||
from onnx.onnx_pb import TensorProto
|
||||
|
||||
|
||||
class ProcessingModule(torch.nn.Module):
|
||||
@staticmethod
|
||||
def _argsort(g, x, dim, descending):
|
||||
return g.op('ai.onnx.contrib::ArgSort', x, dim)
|
||||
|
||||
@classmethod
|
||||
def register_customops(cls):
|
||||
if hasattr(cls, 'loaded'):
|
||||
return True
|
||||
|
||||
torch.onnx.register_custom_op_symbolic('::argsort', cls._argsort, 1)
|
||||
# ... more
|
||||
|
||||
cls.loaded = True
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def export(self, opset_version, *args):
|
||||
return None
|
||||
|
||||
|
||||
tensor_data_type = TensorProto
|
|
@ -0,0 +1,114 @@
|
|||
import onnx
|
||||
import torch
|
||||
import numpy as np
|
||||
from onnx import helper
|
||||
from onnx import onnx_pb as onnx_proto
|
||||
from typing import Any
|
||||
from torch.autograd import Function as _Function
|
||||
from ._onnx_ops import ox as _ox, schema as _schema
|
||||
from ._onnx_ops import ONNXElementContainer, make_model_ex
|
||||
from .._ortapi2 import OrtPyFunction, get_opset_version_from_ort
|
||||
|
||||
|
||||
def _is_numpy_object(x):
|
||||
return isinstance(x, (np.ndarray, np.generic))
|
||||
|
||||
|
||||
def _is_numpy_string_type(arr):
|
||||
return arr.dtype.kind in {'U', 'S'}
|
||||
|
||||
|
||||
def _is_string_type(x):
|
||||
if not _is_numpy_object(x):
|
||||
x = np.array(x)
|
||||
return _is_numpy_string_type(x)
|
||||
|
||||
|
||||
def _to_onnx_type(dtype):
|
||||
ty_dict = {torch.bool: onnx_proto.TensorProto.BOOL,
|
||||
torch.float32: onnx_proto.TensorProto.FLOAT,
|
||||
torch.float64: onnx_proto.TensorProto.DOUBLE,
|
||||
torch.long: onnx_proto.TensorProto.INT64,
|
||||
torch.int32: onnx_proto.TensorProto.INT32}
|
||||
# ...
|
||||
return ty_dict.get(dtype, onnx_proto.TensorProto.STRING)
|
||||
|
||||
|
||||
class ONNXOpFunction(_Function):
|
||||
@classmethod
|
||||
def get_next_id_name(cls, name_base):
|
||||
name = 'cls' if name_base is None else name_base
|
||||
_cid = getattr(cls, '_cid', 1)
|
||||
cls._cid = _cid + 1
|
||||
return "{}_{}".format(name, _cid)
|
||||
|
||||
@staticmethod
|
||||
def jvp(ctx: Any, *grad_inputs: Any) -> Any:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs: Any) -> Any:
|
||||
return grad_outputs
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, opset_version, *args):
|
||||
# build the one node graph
|
||||
ec = ONNXElementContainer(get_opset_version_from_ort() if opset_version is None else opset_version)
|
||||
attrs = cls.attrs
|
||||
vi_inputs = [helper.make_tensor_value_info(
|
||||
'it_' + str(id(_arg)), _to_onnx_type(_arg.dtype), list(_arg.shape))
|
||||
for _arg in args]
|
||||
inputs = [_vi.name for _vi in vi_inputs]
|
||||
if hasattr(cls.opb_func, 'outputs') and len(cls.opb_func.outputs) > 0:
|
||||
vi_outputs = [helper.make_tensor_value_info(
|
||||
cls.get_next_id_name('ot'), *_schm) for _schm in cls.opb_func.outputs]
|
||||
else:
|
||||
vi_outputs = [helper.make_tensor_value_info(
|
||||
cls.get_next_id_name('ot'), onnx_proto.TensorProto.FLOAT, []
|
||||
)]
|
||||
outputs = [_vi.name for _vi in vi_outputs]
|
||||
# build the node
|
||||
opfunc = cls.opb_func
|
||||
opfunc(inputs, outputs, ec, None, **attrs)
|
||||
g = helper.make_graph(ec.nodes, cls.get_next_id_name('g'), vi_inputs, vi_outputs)
|
||||
m = make_model_ex(g, ec.node_domain_version_pair_sets, ec.target_opset)
|
||||
return m
|
||||
|
||||
@classmethod
|
||||
@torch.jit.unused
|
||||
def _onnx_call(cls, ctx, *args) -> Any:
|
||||
m = cls.build_model(None, *args)
|
||||
try:
|
||||
f = OrtPyFunction.from_model(m)
|
||||
result = f(*list(_i.numpy() if isinstance(_i, torch.Tensor) else _i for _i in args))
|
||||
except Exception as e:
|
||||
onnx.save_model(m, 'smoking.onnx')
|
||||
raise e
|
||||
|
||||
results = result if isinstance(result, tuple) else [result]
|
||||
return tuple([torch.from_numpy(_o) for _o in results]) if len(results) > 1 else torch.from_numpy(results[0])
|
||||
|
||||
@classmethod
|
||||
def forward(cls, ctx: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
return cls._onnx_call(ctx, *args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def symbolic(cls, g, *args):
|
||||
return g.op(cls.op_type, *args)
|
||||
|
||||
|
||||
def create_op_function(op_type: str, func, **attrs):
|
||||
if _ox.is_raw(func):
|
||||
func = _schema(func.__func__)
|
||||
cls = type(_ox.get_unique_operator_type_name(op_type), (ONNXOpFunction, ),
|
||||
dict(
|
||||
op_type=op_type,
|
||||
opb_func=func,
|
||||
attrs=attrs
|
||||
))
|
||||
return cls.apply # noqa
|
||||
|
||||
|
||||
onnx_pad = create_op_function('Pad', _ox.pad)
|
||||
onnx_where = create_op_function('Where', _ox.where)
|
||||
onnx_greater = create_op_function('Greater', _ox.greater)
|
|
@ -0,0 +1,110 @@
|
|||
import io
|
||||
import onnx
|
||||
import torch
|
||||
from torch.nn.functional import interpolate
|
||||
from torch.onnx import TrainingMode, export as _export
|
||||
from ._base import ProcessingModule
|
||||
from ._functions import onnx_where, onnx_greater
|
||||
|
||||
|
||||
def _resize_param(img, size):
|
||||
y, x = tuple(img.shape[-2:])
|
||||
scale_y = size / y
|
||||
scale_x = size / x
|
||||
return onnx_where(onnx_greater(scale_x, scale_y), scale_x, scale_y)
|
||||
|
||||
|
||||
class ImagenetPreProcessingLite(ProcessingModule):
|
||||
def __init__(self, size):
|
||||
super(ImagenetPreProcessingLite, self).__init__()
|
||||
self.target_size = size
|
||||
|
||||
def forward(self, img):
|
||||
if not isinstance(img, torch.Tensor):
|
||||
img = torch.tensor(img)
|
||||
img = torch.permute(img, (2, 0, 1))
|
||||
x = img.to(torch.float32).unsqueeze(0)
|
||||
# T.CenterCrop(224),
|
||||
width, height = tuple(self.target_size)
|
||||
img_h, img_w = x.shape[-2:]
|
||||
s_h = torch.div((img_h - height), 2, rounding_mode='trunc')
|
||||
s_w = torch.div((img_w - width), 2, rounding_mode='trunc')
|
||||
x = x[:, :, s_h:s_h + height, s_w:s_w + width]
|
||||
# T.ToTensor(),
|
||||
x /= 255. # ToTensor
|
||||
# T.Normalize(
|
||||
# mean=[0.485, 0.456, 0.406],
|
||||
# std=[0.229, 0.224, 0.225]
|
||||
# )
|
||||
mean = torch.tensor([0.485, 0.456, 0.406])
|
||||
std = torch.tensor([0.229, 0.224, 0.225])
|
||||
x -= torch.reshape(torch.tensor(mean), (3, 1, 1))
|
||||
x /= torch.reshape(torch.tensor(std), (3, 1, 1))
|
||||
return x
|
||||
|
||||
|
||||
class ImagenetPreProcessing(ProcessingModule):
|
||||
def __init__(self, size):
|
||||
super(ImagenetPreProcessing, self).__init__()
|
||||
self.target_size = size
|
||||
|
||||
def forward(self, img):
|
||||
if not isinstance(img, torch.Tensor):
|
||||
img = torch.tensor(img)
|
||||
img = torch.permute(img, (2, 0, 1))
|
||||
# T.Resize(256),
|
||||
img = img.to(torch.float32).unsqueeze(0)
|
||||
scale = _resize_param(img, torch.tensor(256))
|
||||
x = interpolate(img, scale_factor=scale,
|
||||
recompute_scale_factor=True,
|
||||
mode="bilinear", align_corners=False)
|
||||
# T.CenterCrop(224),
|
||||
width, height = self.target_size, self.target_size
|
||||
img_h, img_w = x.shape[-2:]
|
||||
s_h = torch.div((img_h - height), 2, rounding_mode='trunc')
|
||||
s_w = torch.div((img_w - width), 2, rounding_mode='trunc')
|
||||
x = x[:, :, s_h:s_h + height, s_w:s_w + width]
|
||||
# T.ToTensor(),
|
||||
x /= 255. # ToTensor
|
||||
# T.Normalize(
|
||||
# mean=[0.485, 0.456, 0.406],
|
||||
# std=[0.229, 0.224, 0.225]
|
||||
# )
|
||||
mean = [0.485, 0.456, 0.406]
|
||||
std = [0.229, 0.224, 0.225]
|
||||
x -= torch.reshape(torch.tensor(mean), (3, 1, 1))
|
||||
x /= torch.reshape(torch.tensor(std), (3, 1, 1))
|
||||
# x[:, 0, :, :] -= mean[0]
|
||||
# x[:, 1, :, :] -= mean[1]
|
||||
# x[:, 2, :, :] -= mean[2]
|
||||
# x[:, 0, :, :] /= std[0]
|
||||
# x[:, 1, :, :] /= std[1]
|
||||
# x[:, 2, :, :] /= std[2]
|
||||
return x
|
||||
|
||||
|
||||
class ImagePostProcessing(ProcessingModule):
|
||||
def forward(self, scores):
|
||||
ProcessingModule.register_customops()
|
||||
probabilities = torch.softmax(scores, dim=1)
|
||||
ids = probabilities.argsort(dim=1, descending=True)
|
||||
return ids, probabilities
|
||||
|
||||
def export(self, opset_version, *args):
|
||||
with io.BytesIO() as f:
|
||||
name_i = 'image'
|
||||
_export(self, args, f,
|
||||
training=TrainingMode.EVAL,
|
||||
opset_version=opset_version,
|
||||
input_names=[name_i],
|
||||
dynamic_axes={name_i: [0, 1]})
|
||||
return onnx.load_model(io.BytesIO(f.getvalue()))
|
||||
|
||||
|
||||
class PreMobileNet(ImagenetPreProcessing):
|
||||
def __init__(self, size=None):
|
||||
super(PreMobileNet, self).__init__(224 if size is None else size)
|
||||
|
||||
|
||||
class PostMobileNet(ImagePostProcessing):
|
||||
pass
|
|
@ -0,0 +1,62 @@
|
|||
import json
|
||||
|
||||
from ._base import ProcessingModule, tensor_data_type as _dt
|
||||
from ._functions import create_op_function
|
||||
from ._onnx_ops import schema
|
||||
from .._ocos import default_opset_domain
|
||||
|
||||
|
||||
def make_custom_op(ctx, op_type, input_names, output_names, container, operator_name=None, **kwargs):
|
||||
op_name = container.get_unique_operator_name(op_type) if operator_name is None else operator_name
|
||||
container.add_node(op_type, input_names, output_names,
|
||||
op_version=1, name=op_name, op_domain=default_opset_domain(), **kwargs)
|
||||
|
||||
|
||||
@schema(inputs=((_dt.STRING, []),),
|
||||
outputs=((_dt.INT64, []), (_dt.INT64, [])))
|
||||
def gpt2_tokenize(ctx, input_names, output_names, container, operator_name=None, **kwargs):
|
||||
if 'hf_tok' in kwargs:
|
||||
hf_gpt2_tokenizer = kwargs['hf_tok']
|
||||
attrs = {'vocab': json.dumps(hf_gpt2_tokenizer.encoder, separators=(',', ':'))}
|
||||
sorted_merges = {v_: k_ for k_, v_ in hf_gpt2_tokenizer.bpe_ranks.items()}
|
||||
attrs['merges'] = '\n'.join("{} {}".format(*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
|
||||
elif 'vocab' in kwargs:
|
||||
attrs = dict(
|
||||
vocab=kwargs['vocab'],
|
||||
merges=kwargs['merges'])
|
||||
else:
|
||||
raise RuntimeError("Need hf_tok/vocab parameter to build the tokenizer")
|
||||
padding_len = -1
|
||||
if 'padding_length' in kwargs:
|
||||
padding_len = kwargs['padding_length']
|
||||
attrs['padding_length'] = padding_len
|
||||
|
||||
return make_custom_op(ctx, 'GPT2Tokenizer', input_names,
|
||||
output_names, container, operator_name=operator_name, **attrs)
|
||||
|
||||
|
||||
def _get_file_content(path):
|
||||
with open(path, "rb") as file:
|
||||
return file.read()
|
||||
|
||||
|
||||
def _get_bound_object(func):
|
||||
return func.__self__
|
||||
|
||||
|
||||
class PreHuggingFaceGPT2(ProcessingModule):
|
||||
def __init__(self, hf_tok=None, vocab_file=None, merges_file=None, padding_length=-1):
|
||||
super(PreHuggingFaceGPT2, self).__init__()
|
||||
if hf_tok is None:
|
||||
self.onnx_gpt2_tokenize = create_op_function('GPT2Tokenizer', gpt2_tokenize,
|
||||
vocab=_get_file_content(vocab_file),
|
||||
merges=_get_file_content(merges_file),
|
||||
padding_length=padding_length)
|
||||
else:
|
||||
self.onnx_gpt2_tokenize = create_op_function('GPT2Tokenizer', gpt2_tokenize, hf_tok=self.hf_tok)
|
||||
|
||||
def forward(self, text):
|
||||
return self.onnx_gpt2_tokenize(text)
|
||||
|
||||
def export(self, opset_version, *args):
|
||||
return _get_bound_object(self.onnx_gpt2_tokenize).build_model(opset_version, *args)
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,201 @@
|
|||
import copy
|
||||
import onnx
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
class ONNXModelUtils:
|
||||
@staticmethod
|
||||
def merge_name(prefix, name):
|
||||
return "{}_{}".format(prefix, name)
|
||||
|
||||
@staticmethod
|
||||
def _rename_iter(iterables, prefix_name, inplace=False):
|
||||
new_iz = iterables if inplace else [copy.deepcopy(iz_) for iz_ in iterables]
|
||||
for iz_ in new_iz:
|
||||
iz_.name = ONNXModelUtils.merge_name(prefix_name, iz_.name)
|
||||
return new_iz
|
||||
|
||||
@classmethod
|
||||
def _rename_graph(cls, graph, prefix, graph_or_container):
|
||||
def io_rename(node, prefix_name, idx):
|
||||
new_node = copy.deepcopy(node)
|
||||
if not node.name:
|
||||
new_node.name = cls.merge_name(prefix_name, "op{}".format(idx))
|
||||
else:
|
||||
new_node.name = cls.merge_name(prefix_name, node.name)
|
||||
|
||||
del new_node.input[:]
|
||||
new_node.input.extend(ONNXModelUtils.merge_name(prefix_name, nm_) if nm_ else '' for nm_ in node.input)
|
||||
del new_node.output[:]
|
||||
new_node.output.extend(ONNXModelUtils.merge_name(prefix_name, nm_) if nm_ else '' for nm_ in node.output)
|
||||
return new_node
|
||||
|
||||
assert prefix is not None, 'The graph prefix could not be None'
|
||||
graph_or_container.initializer.extend(cls._rename_iter(graph.initializer, prefix))
|
||||
graph_or_container.value_info.extend(cls._rename_iter(graph.value_info, prefix))
|
||||
return list(io_rename(nd_, prefix, idx_) for idx_, nd_ in enumerate(graph.node))
|
||||
|
||||
@classmethod
|
||||
def _process_node_body(cls, node, prefix):
|
||||
if all(attr.name != 'body' for attr in node.attribute):
|
||||
return node
|
||||
|
||||
def _process_attr(attr, prefix_name):
|
||||
if attr.name == 'body':
|
||||
new_attr = copy.deepcopy(attr)
|
||||
del new_attr.g.value_info[:]
|
||||
del new_attr.g.node[:]
|
||||
new_attr.g.node.extend(cls._rename_graph(attr.g, prefix_name, new_attr.g))
|
||||
cls._rename_iter(new_attr.g.input, prefix_name, inplace=True)
|
||||
cls._rename_iter(new_attr.g.output, prefix_name, inplace=True)
|
||||
return new_attr
|
||||
else:
|
||||
return attr
|
||||
|
||||
attr_list = list(_process_attr(attr_, prefix) for attr_ in node.attribute)
|
||||
del node.attribute[:]
|
||||
node.attribute.extend(attr_list)
|
||||
return node
|
||||
|
||||
@classmethod
|
||||
def unfold_model_node(cls, container):
|
||||
top_containter = container
|
||||
while top_containter.parent is not None: # only one opset_import in the model.
|
||||
top_containter = top_containter.parent
|
||||
|
||||
nodes = container.nodes
|
||||
model_nodes = {node.name: node for node in nodes if hasattr(node, 'model')}
|
||||
onnx_nodes = [nd_ for nd_ in nodes if nd_.name not in model_nodes]
|
||||
|
||||
for node in model_nodes.values():
|
||||
renamed_nodes = cls._rename_graph(node.model.graph, node.name, container)
|
||||
onnx_nodes.extend(cls._process_node_body(nd_, node.name) for nd_ in renamed_nodes)
|
||||
|
||||
top_containter.node_domain_version_pair_sets.update(
|
||||
[(opset_.domain, opset_.version) for opset_ in node.model.opset_import])
|
||||
return onnx_nodes
|
||||
|
||||
@classmethod
|
||||
def topological_sort(cls, container, nodes, inputs, outputs):
|
||||
op_output_map = {}
|
||||
DynNode = namedtuple('DynNode', ['name', 'output'])
|
||||
input_nodes = [DynNode(name='placeholder',
|
||||
output=[nm_.name for nm_ in inputs] +
|
||||
[it_.name for it_ in container.initializers])] + \
|
||||
[nd_ for nd_ in nodes if nd_.op_type == 'Constant']
|
||||
|
||||
for nd_ in nodes + input_nodes:
|
||||
for ky_ in nd_.output:
|
||||
op_output_map[ky_] = nd_
|
||||
|
||||
edges = {}
|
||||
for op in nodes:
|
||||
for x in op.input:
|
||||
if x == '':
|
||||
continue
|
||||
try:
|
||||
predecessor = op_output_map[x]
|
||||
except KeyError:
|
||||
raise RuntimeError(
|
||||
"{}: cannot find an operator to produce the tensor: {}".format(op.name, x)) from None
|
||||
|
||||
val = edges.get(predecessor.name, [])
|
||||
val.append(op)
|
||||
edges[predecessor.name] = val
|
||||
|
||||
for y_ in outputs:
|
||||
op = op_output_map[y_.name].name
|
||||
if op not in edges:
|
||||
edges[op] = []
|
||||
|
||||
visited = set()
|
||||
sorted_nodes = []
|
||||
unfinished_nodes = set()
|
||||
|
||||
def recursive_helper(node):
|
||||
if node.name in visited:
|
||||
return
|
||||
|
||||
if node.name in unfinished_nodes:
|
||||
raise RuntimeError("ONNX Graph is not a DAG, the cycle is found at {}".format(node.name))
|
||||
|
||||
unfinished_nodes.add(node.name)
|
||||
if node.name in edges: # if the node's output is not in the Graph output.
|
||||
assert node.name != '', 'this topological-sort depends on the unique node name.'
|
||||
for successor in edges[node.name]:
|
||||
recursive_helper(successor)
|
||||
|
||||
unfinished_nodes.remove(node.name)
|
||||
visited.add(node.name)
|
||||
if node is not input_nodes[0]:
|
||||
sorted_nodes.insert(0, node)
|
||||
|
||||
for nd_ in input_nodes:
|
||||
recursive_helper(nd_)
|
||||
|
||||
return sorted_nodes
|
||||
|
||||
@staticmethod
|
||||
def _remove_unused_initializers(nodes, initializers, reversed_names):
|
||||
nodes_input_set = set()
|
||||
for nd_ in nodes:
|
||||
nodes_input_set.update(n_ for n_ in nd_.input)
|
||||
|
||||
return [intlz_ for intlz_ in initializers if intlz_.name in nodes_input_set or intlz_.name in reversed_names]
|
||||
|
||||
@classmethod
|
||||
def join_models(cls, *models, io_mapping=None):
|
||||
# generate the prefix id for the embedding graph to avoid the name conflict
|
||||
mdl_prefix = []
|
||||
for _i in range(len(models)):
|
||||
mdl_prefix.append("g{}".format(_i + 1))
|
||||
|
||||
inputs = cls._rename_iter(models[0].graph.input, mdl_prefix[0])
|
||||
outputs = cls._rename_iter(models[-1].graph.output, mdl_prefix[-1])
|
||||
|
||||
port_mapping = {}
|
||||
for _idx in range(len(models) - 1):
|
||||
for _i, _x in enumerate(models[_idx + 1].graph.input):
|
||||
iname = cls.merge_name(mdl_prefix[_idx + 1], _x.name)
|
||||
oname = cls.merge_name(mdl_prefix[_idx], models[_idx].graph.output[_i].name)
|
||||
port_mapping[iname] = oname
|
||||
if io_mapping:
|
||||
# TODO: support the prefix of in the argument.
|
||||
port_mapping.update(io_mapping)
|
||||
|
||||
nodes = []
|
||||
Container = namedtuple('Container', ['initializer', 'value_info'])
|
||||
container = Container(initializer=[], value_info=[])
|
||||
for _idx, _m in enumerate(models):
|
||||
container.initializer.extend(_m.graph.initializer)
|
||||
container.value_info.extend(_m.graph.value_info)
|
||||
nodes += cls._rename_graph(_m.graph, mdl_prefix[_idx], container)
|
||||
|
||||
for _n in nodes:
|
||||
replaceable = False
|
||||
for _i in _n.input:
|
||||
if _i in port_mapping:
|
||||
replaceable = True
|
||||
break
|
||||
if replaceable:
|
||||
new_input = copy.deepcopy(_n.input)
|
||||
del _n.input[:]
|
||||
_n.input.extend([port_mapping[_i] if _i in port_mapping else _i for _i in new_input])
|
||||
|
||||
name = ''
|
||||
domains = set()
|
||||
_opset = []
|
||||
for _mdl in models:
|
||||
for _ops in _mdl.opset_import:
|
||||
if _ops.domain not in domains:
|
||||
domains.update([_ops.domain])
|
||||
_opset.append(_ops)
|
||||
name = name + '_' + _mdl.graph.name if name else _mdl.graph.name
|
||||
|
||||
inits = cls._remove_unused_initializers(nodes, container.initializer, set())
|
||||
helper = onnx.helper
|
||||
g = helper.make_graph(nodes, name, inputs, outputs,
|
||||
initializer=inits,
|
||||
value_info=container.value_info)
|
||||
m = helper.make_model(g, opset_imports=_opset)
|
||||
return m
|
|
@ -492,4 +492,8 @@ PYBIND11_MODULE(_ortcustomops, m) {
|
|||
init_numpy();
|
||||
AddGlobalMethods(m);
|
||||
AddObjectMethods(m);
|
||||
auto atexit = py::module_::import("atexit");
|
||||
atexit.attr("register")(py::cpp_function([]() {
|
||||
PyCustomOpDefImpl::op_invoker.reset();
|
||||
}));
|
||||
}
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
[build]
|
||||
build-base = .setuptools-cmake-build
|
||||
# debug = 1
|
||||
debug = 1
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Двоичный файл не отображается.
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 285 KiB |
|
@ -8,6 +8,7 @@ from onnxruntime_extensions import PyOrtFunction, hook_model_op, PyOp
|
|||
from onnxruntime_extensions.onnxprocess import torch_wrapper as torch
|
||||
from onnxruntime_extensions.onnxprocess import trace_for_onnx, pyfunc_from_model
|
||||
|
||||
|
||||
@unittest.skipIf(platform.python_version_tuple()[0:2] == (
|
||||
'3', '7'), 'Windows CI pipeline failed on the version temporarily.')
|
||||
class TestTorchE2E(unittest.TestCase):
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
import onnx
|
||||
import numpy
|
||||
import torch
|
||||
import unittest
|
||||
from PIL import Image
|
||||
from distutils.version import LooseVersion
|
||||
from onnxruntime_extensions import PyOrtFunction, ONNXCompose
|
||||
from onnxruntime_extensions import pnp, get_test_data_file
|
||||
from transformers import GPT2Config, GPT2LMHeadModel
|
||||
|
||||
|
||||
class _GPT2LMHeadModel(GPT2LMHeadModel):
|
||||
""" Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
result = super(_GPT2LMHeadModel, self).forward(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=False)
|
||||
# drop the past states
|
||||
return result[0]
|
||||
|
||||
|
||||
@unittest.skipIf(LooseVersion(torch.__version__) < LooseVersion("1.9"), 'Only tested the lastest PyTorch')
|
||||
class TestPreprocessing(unittest.TestCase):
|
||||
def test_imagenet_preprocessing(self):
|
||||
mnv2 = onnx.load_model(get_test_data_file(__file__, 'data', 'mobilev2.onnx'))
|
||||
|
||||
# load an image
|
||||
img = Image.open(get_test_data_file(__file__, 'data', 'pineapple.jpg'))
|
||||
img = numpy.asarray(img.convert('RGB'))
|
||||
|
||||
full_models = ONNXCompose(
|
||||
mnv2,
|
||||
preprocessors=pnp.PreMobileNet(224),
|
||||
postprocessors=pnp.PostMobileNet()
|
||||
)
|
||||
|
||||
ids, probabilities = full_models.predict(torch.from_numpy(img))
|
||||
full_model_func = PyOrtFunction.from_model(full_models.export(opset_version=11))
|
||||
actual_ids, actual_result = full_model_func(img)
|
||||
numpy.testing.assert_allclose(probabilities.numpy(), actual_result, rtol=1e-3)
|
||||
self.assertEqual(ids[0, 0].item(), 953) # 953 is pineapple class id in the imagenet dataset
|
||||
|
||||
@unittest.skip("TODO, fix it on the CI environment.")
|
||||
def test_gpt2_preprocessing(self):
|
||||
cfg = GPT2Config(n_layer=3)
|
||||
gpt2_m = _GPT2LMHeadModel(cfg)
|
||||
gpt2_m.eval().to('cpu')
|
||||
full_model = ONNXCompose(
|
||||
gpt2_m,
|
||||
preprocessors=pnp.PreHuggingFaceGPT2(vocab_file=get_test_data_file(__file__, 'data', 'gpt2.vocab'),
|
||||
merges_file=get_test_data_file(__file__, 'data', 'gpt2.merges.txt')))
|
||||
test_sentence = ["Test a sentence"]
|
||||
expected = full_model.predict(test_sentence)
|
||||
model = full_model.export(opset_version=12, do_constant_folding=False)
|
||||
mfunc = PyOrtFunction.from_model(model)
|
||||
actuals = mfunc(test_sentence)
|
||||
# the random weight may generate a large diff in result, test the shape only.
|
||||
self.assertTrue(numpy.allclose(expected.size(), actuals.shape))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Загрузка…
Ссылка в новой задаче