[RELAY][FRONTEND]Onnx to relay frontend (#2302)
This commit is contained in:
Родитель
312802f341
Коммит
30a5a6007d
|
@ -35,7 +35,7 @@ Our goal is to build the shared libraries:
|
|||
.. code:: bash
|
||||
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y python python-dev python-setuptools gcc libtinfo-dev zlib1g-dev
|
||||
sudo apt-get install -y python python-dev python-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake
|
||||
|
||||
The minimal building requirements are
|
||||
|
||||
|
|
|
@ -910,7 +910,7 @@ def test_single_ops():
|
|||
model = helper.make_model(graph, producer_name='_test')
|
||||
for target, ctx in ctx_list():
|
||||
tvm_out = get_tvm_output(model, [x], target, ctx)
|
||||
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
|
||||
tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol)
|
||||
|
||||
x = np.random.uniform(size=in_shape).astype(dtype)
|
||||
verify_single_ops("Neg",x, -x)
|
||||
|
@ -918,13 +918,13 @@ def test_single_ops():
|
|||
verify_single_ops("Reciprocal",x, 1/x, rtol=1e-5, atol=1e-5)
|
||||
verify_single_ops("Sqrt",x, np.sqrt(x), rtol=1e-5, atol=1e-5)
|
||||
verify_single_ops("Relu",x, np.maximum(x, 0))
|
||||
verify_single_ops("Exp",x, np.exp(x))
|
||||
verify_single_ops("Log",x, np.log(x))
|
||||
verify_single_ops("Log",x, np.log(x))
|
||||
verify_single_ops("Tanh",x, np.tanh(x))
|
||||
verify_single_ops("Sigmoid",x, 1 / (1 + np.exp(-x)))
|
||||
verify_single_ops("Softsign",x, x / (1 + np.abs(x)))
|
||||
verify_single_ops("SoftPlus",x, np.log(1 + np.exp(x)))
|
||||
verify_single_ops("Exp",x, np.exp(x), rtol=1e-5, atol=1e-5)
|
||||
verify_single_ops("Log",x, np.log(x), rtol=1e-5, atol=1e-5)
|
||||
verify_single_ops("Log",x, np.log(x), rtol=1e-5, atol=1e-5)
|
||||
verify_single_ops("Tanh",x, np.tanh(x), rtol=1e-5, atol=1e-5)
|
||||
verify_single_ops("Sigmoid",x, 1 / (1 + np.exp(-x)), rtol=1e-5, atol=1e-5)
|
||||
verify_single_ops("Softsign",x, x / (1 + np.abs(x)), rtol=1e-5, atol=1e-5)
|
||||
verify_single_ops("SoftPlus",x, np.log(1 + np.exp(x)), rtol=1e-5, atol=1e-5)
|
||||
|
||||
def test_leaky_relu():
|
||||
def leaky_relu_x(x, alpha):
|
||||
|
|
|
@ -465,6 +465,14 @@ def const(value, dtype=None):
|
|||
"""
|
||||
if isinstance(value, (_base.numeric_types, (bool, list))):
|
||||
value = _np.array(value, dtype=dtype)
|
||||
if not dtype:
|
||||
# when dtype is None: int maps to "int32", float maps to "float32"
|
||||
map_dtype = {
|
||||
_np.dtype('int64'): _np.int32,
|
||||
_np.dtype('float64'): _np.float32
|
||||
}.get(value.dtype, None)
|
||||
if map_dtype:
|
||||
value = value.astype(map_dtype)
|
||||
if isinstance(value, (_np.ndarray, _np.generic)):
|
||||
value = _nd.array(value)
|
||||
|
||||
|
|
|
@ -9,3 +9,4 @@ from __future__ import absolute_import
|
|||
|
||||
from .mxnet import from_mxnet
|
||||
from .keras import from_keras
|
||||
from .onnx import from_onnx
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
"""Common utilities"""
|
||||
from __future__ import absolute_import as _abs
|
||||
import logging
|
||||
from topi.util import get_const_tuple
|
||||
from .. import expr as _expr
|
||||
from .. import expr as _expr
|
||||
from .. import ir_pass
|
||||
from .. import op as _op
|
||||
|
||||
|
||||
class RequiredAttr(object):
|
||||
|
@ -204,6 +209,30 @@ class StrAttrsDict(object):
|
|||
raise AttributeError("Required attribute {} not found.".format(key))
|
||||
return default
|
||||
|
||||
def get_relay_op(op_name):
|
||||
"""Get the callable function from Relay based on operator name.
|
||||
Parameters
|
||||
----------
|
||||
op_name : str
|
||||
The Relay operator name.
|
||||
"""
|
||||
if '.' in op_name:
|
||||
# explicit hierachical modules
|
||||
op = _op
|
||||
try:
|
||||
for opn in op_name.split('.'):
|
||||
op = getattr(op, opn)
|
||||
except AttributeError:
|
||||
op = None
|
||||
else:
|
||||
# try search op in various modules
|
||||
for candidate in (_op, _op.nn, _op.image):
|
||||
op = getattr(candidate, op_name, None)
|
||||
if op is not None:
|
||||
break
|
||||
if not op:
|
||||
raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
|
||||
return op
|
||||
|
||||
class ExprTable(object):
|
||||
"""Table storing Relay expressions by names."""
|
||||
|
@ -227,3 +256,156 @@ class ExprTable(object):
|
|||
def set_expr(self, name, expr):
|
||||
assert isinstance(expr, _expr.Expr)
|
||||
self.exprs[name] = expr
|
||||
|
||||
|
||||
class AttrCvt(object):
|
||||
"""Common attribute conveter. An AttrConverter instance is a callable:
|
||||
```
|
||||
attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
|
||||
new_op_name, new_attr = attr_converter(attrs)
|
||||
```
|
||||
|
||||
Parameters
|
||||
----------
|
||||
op_name : str or callable
|
||||
If set as str, returned operator name is the str.
|
||||
If set as callable, returned operator is the str returned by calling:
|
||||
`op_name = func(attr)`
|
||||
transforms : dict of `new_name, or (new_name, default_value, transform function)`
|
||||
If only a new_name is provided, it's like renaming the attribute name.
|
||||
If default_value if provded, then the attribute is considered as optional.
|
||||
If transform function is provided, the original attribute value is handled
|
||||
by transform function.
|
||||
excludes : list
|
||||
A list of excluded attributes that should `NOT` appear.
|
||||
Raise NotImplementedError if occured.
|
||||
disables : list
|
||||
A list of attributes that is disabled in relay. Log warnings.
|
||||
ignores : list
|
||||
A list of attributes that is ignored in relay. Debug level logging.
|
||||
extras : dict
|
||||
A series of additional attributes should be added anyway to the returned
|
||||
attribute dict.
|
||||
custom_check : callable
|
||||
A custom function takes attribute, and return True/False.
|
||||
Raise RuntimeError if not bool(True) returned.
|
||||
"""
|
||||
def __init__(self, op_name, transforms=None,
|
||||
excludes=None, disables=None, ignores=None,
|
||||
extras=None, custom_check=None):
|
||||
self._op_name = op_name
|
||||
self._transforms = transforms if transforms else {}
|
||||
self._excludes = excludes if excludes else []
|
||||
self._disables = disables if disables else []
|
||||
self._ignores = ignores if ignores else []
|
||||
self._extras = extras if extras else {}
|
||||
self._custom_check = custom_check
|
||||
|
||||
def __call__(self, inputs, attrs, *args):
|
||||
# apply custom check
|
||||
if self._custom_check:
|
||||
func, msg = self._custom_check
|
||||
if not func(attrs):
|
||||
raise RuntimeError("Check failed: {}".format(msg))
|
||||
# get new op_name
|
||||
if isinstance(self._op_name, str):
|
||||
op_name = self._op_name
|
||||
else:
|
||||
assert callable(self._op_name), "op_name can either be string or callable"
|
||||
op_name = self._op_name(attrs)
|
||||
# convert attributes
|
||||
new_attrs = {}
|
||||
for k in attrs.keys():
|
||||
if k in self._excludes:
|
||||
raise NotImplementedError("Attribute {} not supported yet.".format(k))
|
||||
elif k in self._disables:
|
||||
logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
|
||||
elif k in self._ignores:
|
||||
logging.debug("Attribute %s is ignored in relay.sym.%s", k, op_name)
|
||||
elif k in self._transforms:
|
||||
new_name, defaults, transform = self._parse_default(self._transforms[k])
|
||||
if defaults is None:
|
||||
new_attr = self._required_attr(attrs, k)
|
||||
else:
|
||||
new_attr = attrs.get(k, None)
|
||||
if new_attr is None:
|
||||
new_attrs[new_name] = defaults
|
||||
else:
|
||||
new_attrs[new_name] = transform(new_attr)
|
||||
else:
|
||||
# copy
|
||||
new_attrs[k] = attrs[k]
|
||||
# add extras
|
||||
new_attrs.update(self._extras)
|
||||
return get_relay_op(op_name)(*inputs, **new_attrs)
|
||||
|
||||
def _parse_default(self, target):
|
||||
"""Helper function to parse default values."""
|
||||
if not isinstance(target, (list, tuple)):
|
||||
k, v, t = target, None, lambda x: x
|
||||
elif len(target) == 1:
|
||||
k, v, t = target[0], None, lambda x: x
|
||||
elif len(target) == 2:
|
||||
k, v, t = target[0], target[1], lambda x: x
|
||||
elif len(target) > 2:
|
||||
k, v, t = target[0], target[1], target[2]
|
||||
else:
|
||||
k = None # should raise
|
||||
if not isinstance(k, str):
|
||||
msg = "{} is not a valid target, (name, default) expected.".format(target)
|
||||
raise ValueError(msg)
|
||||
return k, v, t
|
||||
|
||||
def _parse_bool(self, value):
|
||||
"""Helper function to parse default boolean values."""
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in ['true', '1', 't', 'y', 'yes']
|
||||
return bool(value)
|
||||
|
||||
def _required_attr(self, attr, key):
|
||||
"""Wrapper for getting required attributes."""
|
||||
assert isinstance(attr, dict)
|
||||
if key not in attr:
|
||||
raise AttributeError("Required attribute {} not found.".format(key))
|
||||
return attr[key]
|
||||
|
||||
def get_name(node):
|
||||
name = ''
|
||||
if hasattr(node, "name_hint"):
|
||||
name = node.name_hint
|
||||
return name
|
||||
|
||||
def infer_shape(inputs):
|
||||
"""A method to get the output shape of an intermediate node in the graph."""
|
||||
out_type = ir_pass.infer_type(inputs)
|
||||
out_shapes = get_const_tuple(out_type.checked_type.shape)
|
||||
return out_shapes
|
||||
|
||||
def infer_channels(inputs, transpose=False):
|
||||
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
|
||||
these attributes. We check the shape of weights provided to get the number.
|
||||
"""
|
||||
out_type = ir_pass.infer_type(inputs)
|
||||
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
|
||||
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
|
||||
return channels
|
||||
|
||||
def new_var(name_hint,
|
||||
type_annotation=None,
|
||||
shape=None,
|
||||
dtype="float32"):
|
||||
return _expr.var(name_hint, type_annotation, shape, dtype)
|
||||
|
||||
class Renamer(object):
|
||||
"""A simply renamer for operators.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
new_name : str
|
||||
The new name for the operator
|
||||
"""
|
||||
def __init__(self, new_name):
|
||||
self._new_name = new_name
|
||||
|
||||
def __call__(self, inputs, attrs, *args):
|
||||
return get_relay_op(self._new_name)(*inputs, **attrs)
|
||||
|
|
|
@ -4,15 +4,7 @@ from __future__ import absolute_import as _abs
|
|||
|
||||
from .. import expr as _expr
|
||||
from .. import op as _op
|
||||
|
||||
def _get_relay_op(op_name):
|
||||
op = _op
|
||||
for path in op_name.split("."):
|
||||
op = getattr(op, path)
|
||||
if not op:
|
||||
raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
|
||||
return op
|
||||
|
||||
from .common import get_relay_op
|
||||
|
||||
def _warn_not_used(attr, op='nnvm'):
|
||||
import warnings
|
||||
|
@ -22,7 +14,7 @@ def _warn_not_used(attr, op='nnvm'):
|
|||
|
||||
def _rename(new_op):
|
||||
if isinstance(new_op, str):
|
||||
new_op = _get_relay_op(new_op)
|
||||
new_op = get_relay_op(new_op)
|
||||
# attrs are ignored.
|
||||
def impl(inputs, _, _dtype='float32'):
|
||||
return new_op(*inputs)
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -32,3 +32,6 @@ python3 -m nose -v tests/python/frontend/mxnet || exit -1
|
|||
|
||||
echo "Running relay Keras frontend test..."
|
||||
python3 -m nose -v tests/python/frontend/keras || exit -1
|
||||
|
||||
echo "Running relay ONNX frondend test..."
|
||||
python3 -m nose -v tests/python/frontend/onnx || exit -1
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
"""
|
||||
Compile ONNX Models
|
||||
===================
|
||||
**Author**: `Joshua Z. Zhang <https://zhreshold.github.io/>`_
|
||||
|
||||
This article is an introductory tutorial to deploy ONNX models with Relay.
|
||||
|
||||
For us to begin with, ONNX package must be installed.
|
||||
|
||||
A quick solution is to install protobuf compiler, and
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install onnx --user
|
||||
|
||||
or please refer to offical site.
|
||||
https://github.com/onnx/onnx
|
||||
"""
|
||||
import onnx
|
||||
import numpy as np
|
||||
import tvm
|
||||
import tvm.relay as relay
|
||||
|
||||
def download(url, path, overwrite=False):
|
||||
import os
|
||||
if os.path.isfile(path) and not overwrite:
|
||||
print('File {} existed, skip.'.format(path))
|
||||
return
|
||||
print('Downloading from url {} to {}'.format(url, path))
|
||||
try:
|
||||
import urllib.request
|
||||
urllib.request.urlretrieve(url, path)
|
||||
except:
|
||||
import urllib
|
||||
urllib.urlretrieve(url, path)
|
||||
|
||||
######################################################################
|
||||
# Load pretrained ONNX model
|
||||
# ---------------------------------------------
|
||||
# The example super resolution model used here is exactly the same model in onnx tutorial
|
||||
# http://pytorch.org/tutorials/advanced/super_resolution_with_caffe2.html
|
||||
# we skip the pytorch model construction part, and download the saved onnx model
|
||||
model_url = ''.join(['https://gist.github.com/zhreshold/',
|
||||
'bcda4716699ac97ea44f791c24310193/raw/',
|
||||
'93672b029103648953c4e5ad3ac3aadf346a4cdc/',
|
||||
'super_resolution_0.2.onnx'])
|
||||
download(model_url, 'super_resolution.onnx', False)
|
||||
# now you have super_resolution.onnx on disk
|
||||
onnx_model = onnx.load('super_resolution.onnx')
|
||||
|
||||
######################################################################
|
||||
# Load a test image
|
||||
# ---------------------------------------------
|
||||
# A single cat dominates the examples!
|
||||
from PIL import Image
|
||||
img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
|
||||
download(img_url, 'cat.png')
|
||||
img = Image.open('cat.png').resize((224, 224))
|
||||
img_ycbcr = img.convert("YCbCr") # convert to YCbCr
|
||||
img_y, img_cb, img_cr = img_ycbcr.split()
|
||||
x = np.array(img_y)[np.newaxis, np.newaxis, :, :]
|
||||
|
||||
######################################################################
|
||||
# Compile the model with relay
|
||||
# ---------------------------------------------
|
||||
target = 'llvm'
|
||||
|
||||
input_name = '1'
|
||||
shape_dict = {input_name: x.shape}
|
||||
sym, params = relay.frontend.from_onnx(onnx_model, shape_dict)
|
||||
|
||||
with relay.build_config(opt_level=1):
|
||||
intrp = relay.build_module.create_executor('graph', sym, tvm.cpu(0), target)
|
||||
|
||||
######################################################################
|
||||
# Execute on TVM
|
||||
# ---------------------------------------------
|
||||
tvm_output = intrp.evaluate(sym)(tvm.nd.array(x.astype(dtype)), **params).asnumpy()
|
||||
|
||||
######################################################################
|
||||
# Display results
|
||||
# ---------------------------------------------
|
||||
# We put input and output image neck to neck
|
||||
from matplotlib import pyplot as plt
|
||||
out_y = Image.fromarray(np.uint8((tvm_output[0, 0]).clip(0, 255)), mode='L')
|
||||
out_cb = img_cb.resize(out_y.size, Image.BICUBIC)
|
||||
out_cr = img_cr.resize(out_y.size, Image.BICUBIC)
|
||||
result = Image.merge('YCbCr', [out_y, out_cb, out_cr]).convert('RGB')
|
||||
canvas = np.full((672, 672*2, 3), 255)
|
||||
canvas[0:224, 0:224, :] = np.asarray(img)
|
||||
canvas[:, 672:, :] = np.asarray(result)
|
||||
plt.imshow(canvas.astype(np.uint8))
|
||||
plt.show()
|
Загрузка…
Ссылка в новой задаче