Onnx opset support (#416)
Parent
f4789db696
Commit
e722dbcbf7
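The heart of this change is the new `OnnxOpConverter` base class: each operator converter exposes `_impl_vN` classmethods and `get_converter(opset)` returns the newest implementation whose version does not exceed the model's opset. Below is a minimal, self-contained sketch of that dispatch; the `FakeReshape` converter is hypothetical and only illustrates the selection rule.

# Sketch of the opset dispatch introduced in this commit.
# The selection logic mirrors OnnxOpConverter.get_converter from the diff;
# FakeReshape is a made-up converter used only to show which _impl_vN wins.
class OnnxOpConverter(object):
    @classmethod
    def get_converter(cls, opset):
        # Collect the opset versions this converter implements (_impl_v1, _impl_v5, ...).
        versions = [int(d.replace('_impl_v', '')) for d in dir(cls) if '_impl_v' in d]
        # Pick the largest implemented version that does not exceed the model's opset.
        versions = sorted(versions + [opset])
        version = versions[max(i for i, v in enumerate(versions) if v == opset) - 1]
        if hasattr(cls, '_impl_v{}'.format(version)):
            return getattr(cls, '_impl_v{}'.format(version))
        raise NotImplementedError(
            'opset version {} of {} not implemented'.format(version, cls.__name__))


class FakeReshape(OnnxOpConverter):
    # Hypothetical converter with two implementations, for illustration only.
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return 'v1'

    @classmethod
    def _impl_v5(cls, inputs, attr, params):
        return 'v5'


print(FakeReshape.get_converter(1)(None, None, None))  # v1
print(FakeReshape.get_converter(4)(None, None, None))  # v1 (newest impl <= opset 4)
print(FakeReshape.get_converter(7)(None, None, None))  # v5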
@@ -4,87 +4,119 @@ from __future__ import absolute_import as _abs
import tvm
from .. import symbol as _sym
from .. import graph as _graph
from .. compiler import graph_util
from ..compiler import graph_util
from .common import get_nnvm_op, Renamer, AttrConverter as AttrCvt

__all__ = ['from_onnx']

def _revert_caffe2_pad(attr):
    """Caffe2 require two times the normal padding."""
    if len(attr) == 4:
        attr = attr[:2]
    elif len(attr) == 2:
        pass
    else:
        raise ValueError("Invalid caffe2 type padding: {}".format(attr))
    return attr

def _math_name_picker(surfix):
    def _impl(attr):
        if attr.get('broadcast', 0):
            return 'broadcast_' + surfix
        return 'elemwise_' + surfix
    return _impl

def _broadcast_constraint():
    def _broadcast_check(attrs):
        if attrs.get('axis', None):
            return False
        return True
    return _broadcast_check, "Specifying broadcast axis not allowed."

def _dimension_picker(prefix, surfix=''):
    def _impl(attr):
        kernel = attr['kernel_shape']
        if len(kernel) == 2:
            return prefix + '2d' + surfix
        else:
            raise NotImplementedError("Only 2d kernel supported.")
    return _impl

def _dimension_constraint():
    def _dim_check(attrs):
        if len(attrs['kernel_shape']) == 2:
            return True
        return False
    return _dim_check, "Only 2d kernel supported."

def _infer_channels(inputs, params, transpose=False):
    """A hack for getting 'channles' or 'units' since onnx don't provide
    these attributes. We check the shape of weights provided to get the number.
class OnnxOpConverter(object):
    """ A helper class for holding onnx op converters.
    """
    g = _graph.create(inputs)
    shape_dict = {k: v.shape for k, v in params.items()}
    _, out_shapes = graph_util.infer_shape(g, **shape_dict)
    channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
    return channels

def _elemwise(name):
    def _impl(inputs, attr, *args):
        assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs))
        op_name = _math_name_picker(name)(attr)
    @classmethod
    def get_converter(cls, opset):
        """ Get converter matches given opset.

        :param opset: opset from model.
        :return: converter, which should be `_impl_vx`. Number x is the biggest
            number smaller than or equal to opset belongs to all support versions.
        """
        versions = [
            int(d.replace('_impl_v', '')) for d in dir(cls) if '_impl_v' in d
        ]
        versions = sorted(versions + [opset])
        version = versions[
            max([i for i, v in enumerate(versions) if v == opset]) - 1]
        if hasattr(cls, '_impl_v{}'.format(version)):
            return getattr(cls, '_impl_v{}'.format(version))
        else:
            raise NotImplementedError(
                'opset version {} of {} not implemented'.format(
                    version, cls.__name__))


class Elemwise(OnnxOpConverter):
    """ A helper class for elemwise op converters.
    """

    name = ''

    @classmethod
    def _math_name_picker(cls, suffix):

        def _impl(attr):
            if attr.get('broadcast', 0):
                return 'broadcast_' + suffix
            return 'elemwise_' + suffix

        return _impl

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(
            len(inputs))
        op_name = cls._math_name_picker(cls.name)(attr)
        axis = int(attr.get('axis', 0))
        conv_ops = ["conv2d", "conv2d_transpose"]
        if op_name == 'broadcast_add' and inputs[0].attr('op_name') in conv_ops:
            # TODO(zhreshold): remove hard coded infershape
            inputs[1] = _sym.expand_dims(inputs[1], axis=axis, num_newaxis=2)
        return get_nnvm_op(op_name)(*inputs)
    return _impl

def _pooling(name):
    return AttrCvt(
        op_name=_dimension_picker(name),
        transforms={
            'kernel_shape': 'pool_size',
            'pads': ('padding', (0, 0), _revert_caffe2_pad)},
        # very weird attributes here in onnx, force check
        ignores=['dilations'],
        # TODO(zhreshold): make sure ceil_mode in onnx, and layout?
        extras={'ceil_mode': False},
        custom_check=_dimension_constraint())

def _conv():
    def _impl(inputs, attr, params):
class Pool(OnnxOpConverter):
    """ A helper class for pool op converters.
    """

    name = ''

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return AttrCvt(
            op_name=_dimension_picker(cls.name),
            transforms={
                'kernel_shape': 'pool_size',
                'pads': ('padding', (0, 0), _revert_caffe2_pad)
            },
            # very weird attributes here in onnx, force check
            ignores=['dilations'],
            # TODO(zhreshold): make sure ceil_mode in onnx, and layout?
            extras={'ceil_mode': False},
            custom_check=_dimension_constraint())(inputs, attr, params)


class Absolute(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _sym.relu(inputs[0]) + _sym.relu(_sym.negative(inputs[0]))


class Add(Elemwise):
    name = 'add'


class AveragePool(Pool):
    name = 'avg_pool'


class BatchNorm(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # TODO(zhreshold): 'spatial' is not properly handled here.
        return AttrCvt(
            op_name='batch_norm',
            disables=['momentum'],
            ignores=['spatial', 'is_test', 'consumed_inputs'])(inputs, attr,
                                                               params)


class Conv(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # get number of channels
        channels = _infer_channels(inputs[1], params)
        attr['channels'] = channels
@@ -94,13 +126,16 @@ def _conv():
                'kernel_shape': 'kernel_size',
                'dilations': ('dilation', (0, 0)),
                'pads': ('padding', (0, 0), _revert_caffe2_pad),
                'group': ('groups', 1)},
                'group': ('groups', 1)
            },
            extras={'use_bias': len(inputs) == 3},
            custom_check=_dimension_constraint())(inputs, attr)
    return _impl
            custom_check=_dimension_constraint())(inputs, attr, params)

def _conv_transpose():
    def _impl(inputs, attr, params):

class ConvTranspose(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # get number of channels
        channels = _infer_channels(inputs[1], params, True)
        attr['channels'] = channels
@@ -111,31 +146,34 @@ def _conv_transpose():
            transforms={
                'kernel_shape': 'kernel_size',
                'dilations': ('dilation', (0, 0)),
                'pads': ('padding', (0, 0), _revert_caffe2_pad)},
                'pads': ('padding', (0, 0), _revert_caffe2_pad)
            },
            disables=['output_shape'],
            extras={'use_bias': len(inputs) == 3},
            custom_check=_dimension_constraint())(inputs, attr)
    return _impl

def _fully_connected():
    def _impl(inputs, attr, params):
        # get number of channels
        channels = _infer_channels(inputs[1], params)
        attr['units'] = channels
        return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)
    return _impl

def _batch_norm():
    # TODO(zhreshold): 'spatial' is not properly handled here.
    return AttrCvt(
        op_name='batch_norm',
        disables=['momentum'],
        ignores=['spatial', 'is_test', 'consumed_inputs'])
            custom_check=_dimension_constraint())(inputs, attr, params)


def _gemm():
    def _impl(inputs, attr, params):
        assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format(len(inputs))
class Div(Elemwise):
    name = 'div'


class Elu(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        return -alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(
            inputs[0])


class Gemm(OnnxOpConverter):
    """ Operator converter for Gemm.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format(
            len(inputs))
        # Y = alpha * A * B + beta * C
        alpha = float(attr.get('alpha', 1.0))
        beta = float(attr.get('beta', 1.0))
@@ -147,81 +185,22 @@ def _gemm():
        inputs[0] = _sym.transpose(inputs[0], axes=(1, 0))
        if not transB:
            inputs[1] = _sym.transpose(inputs[1], axes=(1, 0))
        return _sym.dense(alpha * inputs[0], inputs[1], beta * inputs[2], units=channels)
    return _impl
        return _sym.dense(
            alpha * inputs[0], inputs[1], beta * inputs[2], units=channels)

def _thresholded_relu():
    def _impl(inputs, attr, params):
        alpha = float(attr.get('alpha', 0.0))
        return _sym.relu(inputs[0] - alpha)
    return _impl

def _scaled_tanh():
    def _impl(inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        beta = float(attr.get('beta', 1.0))
        return _sym.tanh(beta * inputs[0]) * alpha
    return _impl
class MaxPool(Pool):
    name = 'max_pool'

def parametric_soft_plus():
    def _impl(inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        beta = float(attr.get('beta', 1.0))
        return _sym.log(_sym.exp(beta * inputs[0]) + 1) * alpha
    return _impl

def _scale():
    def _impl(inputs, attr, params):
        scale = float(attr.get('scale', 1.0))
        return inputs[0] * scale
    return _impl
class Mul(Elemwise):
    name = 'mul'

def _absolute():
    """This is a workaround."""
    def _impl(inputs, attr, params):
        return _sym.relu(inputs[0]) + _sym.relu(_sym.negative(inputs[0]))
    return _impl

def _reciprocal():
    def _impl(inputs, attr, params):
        return 1.0 / inputs[0]
    return _impl
class Pad(OnnxOpConverter):

def _selu():
    def _impl(inputs, attr, params):
        alpha = float(attr.get('alpha', 1.6732))
        gamma = float(attr.get('gamma', 1.0507))
        return gamma * (-alpha * _sym.relu(1 - _sym.exp(inputs[0]))
                        + _sym.relu(inputs[0]))
    return _impl

def _elu():
    def _impl(inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        return -alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0])
    return _impl

def _prelu():
    def _impl(inputs, attr, params):
        assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(len(inputs))
        channels = _infer_channels(inputs[1], params, False)
        if channels == 1:
            return inputs[0] * inputs[1]
        return _sym.broadcast_mul(inputs[0], inputs[1])
    return _impl

def _softsign():
    def _impl(inputs, attr, params):
        return inputs[0] / (1 + _absolute()(inputs, attr, params))
    return _impl

def _softplus():
    def _impl(inputs, attr, params):
        return _sym.log(_sym.exp(inputs[0]) + 1)
    return _impl

def _pad():
    def _impl(inputs, attr, params):
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # get number of channels
        channels = _infer_channels(inputs[1], params, True)
        attr['channels'] = channels
@@ -231,133 +210,300 @@ def _pad():
            op_name='pad',
            transforms={
                'value': 'pad_value',
                'pads': 'pad_width'},
            custom_check=lambda attrs: attrs.get('mode') == 'constant')(inputs, attr)
    return _impl
                'pads': 'pad_width'
            },
            custom_check=lambda attrs: attrs.get('mode') == 'constant')(
                inputs, attr, params)

def _sum():
    def _impl(inputs, attr, params):

class ParametricSoftPlus(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        beta = float(attr.get('beta', 1.0))
        return _sym.log(_sym.exp(beta * inputs[0]) + 1) * alpha


class Prelu(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(
            len(inputs))
        channels = _infer_channels(inputs[1], params, False)
        if channels == 1:
            return inputs[0] * inputs[1]
        return _sym.broadcast_mul(inputs[0], inputs[1])


class Reciprocal(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return 1.0 / inputs[0]


class Reshape(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _sym.reshape(inputs[0], shape=attr['shape'])

    @classmethod
    def _impl_v5(cls, inputs, attr, params):
        return _sym.reshape(
            inputs[0],
            shape=tuple(params[inputs[1].list_output_names()[0]].asnumpy()))


class Scale(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        scale = float(attr.get('scale', 1.0))
        return inputs[0] * scale


class Selu(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.6732))
        gamma = float(attr.get('gamma', 1.0507))
        return gamma * (
            -alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0]))


class ScaledTanh(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        beta = float(attr.get('beta', 1.0))
        return _sym.tanh(beta * inputs[0]) * alpha


class SoftPlus(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _sym.log(_sym.exp(inputs[0]) + 1)


class Softsign(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return inputs[0] / (1 + Absolute.get_converter(1)(inputs, attr, params))


class Sub(Elemwise):
    name = 'sub'


class Sum(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Onnx Sum Operator
        for in_index in range(len(inputs)-1):
            inputs[in_index+1] = _sym.broadcast_add(inputs[in_index], inputs[in_index+1])
        for in_index in range(len(inputs) - 1):
            inputs[in_index + 1] = _sym.broadcast_add(inputs[in_index],
                                                      inputs[in_index + 1])

        return inputs[len(inputs) - 1]


class ThresholdedRelu(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 0.0))
        return _sym.relu(inputs[0] - alpha)


def _revert_caffe2_pad(attr):
    """Caffe2 require two times the normal padding."""
    if len(attr) == 4:
        attr = attr[:2]
    elif len(attr) == 2:
        pass
    else:
        raise ValueError("Invalid caffe2 type padding: {}".format(attr))
    return attr


def _broadcast_constraint():

    def _broadcast_check(attrs):
        if attrs.get('axis', None):
            return False
        return True

    return _broadcast_check, "Specifying broadcast axis not allowed."


def _dimension_picker(prefix, surfix=''):

    def _impl(attr):
        kernel = attr['kernel_shape']
        if len(kernel) == 2:
            return prefix + '2d' + surfix
        else:
            raise NotImplementedError("Only 2d kernel supported.")

        return inputs[len(inputs)-1]
    return _impl


def _dimension_constraint():

    def _dim_check(attrs):
        if len(attrs['kernel_shape']) == 2:
            return True
        return False

    return _dim_check, "Only 2d kernel supported."


def _infer_channels(inputs, params, transpose=False):
    """A hack for getting 'channles' or 'units' since onnx don't provide
    these attributes. We check the shape of weights provided to get the number.
    """
    g = _graph.create(inputs)
    shape_dict = {k: v.shape for k, v in params.items()}
    _, out_shapes = graph_util.infer_shape(g, **shape_dict)
    channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
    return channels


def _fully_connected(opset):

    def _impl(inputs, attr, params):
        # get number of channels
        channels = _infer_channels(inputs[1], params)
        attr['units'] = channels
        return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)

    return _impl


# compatible operators that do NOT require any conversion.
_identity_list = []


# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
_convert_map = {
    # defs/experimental
    'Identity' : Renamer('copy'),
    # 'Affine'
    'ThresholdedRelu': _thresholded_relu(),
    'ScaledTanh' : _scaled_tanh(),
    'ParametricSoftplus': parametric_soft_plus(),
    # 'ConstantFill'
    # 'GivenTensorFill'
    'FC' : AttrCvt('dense', ignores=['axis', 'axis_w']),
    'Scale' : _scale(),
    # 'GRUUnit'
    # 'ATen'
    # 'ImageScaler'
    # 'MeanVarianceNormalization'
    # 'Crop'
    # 'Embedding'
    # 'Upsample'
    'SpatialBN' : _batch_norm(),
def _get_convert_map(opset):
    return {
        # defs/experimental
        'Identity': Renamer('copy'),
        # 'Affine'
        'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
        'ScaledTanh': ScaledTanh.get_converter(opset),
        'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
        # 'ConstantFill'
        # 'GivenTensorFill'
        'FC': AttrCvt('dense', ignores=['axis', 'axis_w']),
        'Scale': Scale.get_converter(opset),
        # 'GRUUnit'
        # 'ATen'
        # 'ImageScaler'
        # 'MeanVarianceNormalization'
        # 'Crop'
        # 'Embedding'
        # 'Upsample'
        'SpatialBN': BatchNorm.get_converter(opset),

    # defs/generator
    # 'Constant'
    # 'RandomUniform'
    # 'RandomNormal'
    # 'RandomUniformLike'
    # 'RandomNormalLike'
        # defs/generator
        # 'Constant'
        # 'RandomUniform'
        # 'RandomNormal'
        # 'RandomUniformLike'
        # 'RandomNormalLike'

    # defs/logical
        # defs/logical

    # defs/math
    'Add' : _elemwise('add'),
    'Sub' : _elemwise('sub'),
    'Mul' : _elemwise('mul'),
    'Div' : _elemwise('div'),
    'Neg' : Renamer('negative'),
    'Abs' : _absolute(),
    'Reciprocal' : _reciprocal(),
    # 'Floor'
    # 'Ceil'
    'Sqrt' : Renamer('sqrt'),
    'Relu' : Renamer('relu'),
    'LeakyRelu' : Renamer('leaky_relu'),
    'Selu' : _selu(),
    'Elu' : _elu(),
    'Exp' : Renamer('exp'),
    'Log' : Renamer('log'),
    'Tanh' : Renamer('tanh'),
    # 'Pow'
    'PRelu' : _prelu(),
    'Sigmoid' : Renamer('sigmoid'),
    # 'HardSigmoid'
    # 'Max' : this is the elemwise maximum
    # 'Min' : this is the elemwise minimum
    'Sum' : _sum(),
    # 'Mean'
    # 'Clip'
    # softmax default axis is different in onnx
    'Softmax' : AttrCvt('softmax', {'axis': ('axis', 1)}),
    'LogSoftmax' : AttrCvt('log_softmax', {'axis': ('axis', 1)}),
    # 'Hardmax'
    'Softsign' : _softsign(),
    'SoftPlus' : _softplus(),
    'Gemm' : _gemm(),
    # 'MatMul' batch stacked dot operation
        # defs/math
        'Add': Add.get_converter(opset),
        'Sub': Sub.get_converter(opset),
        'Mul': Mul.get_converter(opset),
        'Div': Div.get_converter(opset),
        'Neg': Renamer('negative'),
        'Abs': Absolute.get_converter(opset),
        'Reciprocal': Reciprocal.get_converter(opset),
        # 'Floor'
        # 'Ceil'
        'Sqrt': Renamer('sqrt'),
        'Relu': Renamer('relu'),
        'LeakyRelu': Renamer('leaky_relu'),
        'Selu': Selu.get_converter(opset),
        'Elu': Elu.get_converter(opset),
        'Exp': Renamer('exp'),
        'Log': Renamer('log'),
        'Tanh': Renamer('tanh'),
        # 'Pow'
        'PRelu': Prelu.get_converter(opset),
        'Sigmoid': Renamer('sigmoid'),
        # 'HardSigmoid'
        # 'Max' : this is the elemwise maximum
        # 'Min' : this is the elemwise minimum
        'Sum': Sum.get_converter(opset),
        # 'Mean'
        # 'Clip'
        # softmax default axis is different in onnx
        'Softmax': AttrCvt('softmax', {'axis': ('axis', 1)}),
        'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}),
        # 'Hardmax'
        'Softsign': Softsign.get_converter(opset),
        'SoftPlus': SoftPlus.get_converter(opset),
        'Gemm': Gemm.get_converter(opset),
        # 'MatMul' batch stacked dot operation

    # defs/nn
    'AveragePool' : _pooling('avg_pool'),
    'MaxPool' : _pooling('max_pool'),
    'Conv' : _conv(),
    'ConvTranspose' : _conv_transpose(),
    'GlobalAveragePool': Renamer('global_avg_pool2d'),
    'GlobalMaxPool' : Renamer('global_max_pool2d'),
    'BatchNormalization': _batch_norm(),
    # 'InstanceNormalization'
    # 'LpNormalization'
    'Dropout' : AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
    'Flatten' : Renamer('flatten'),
    # 'LRN'
        # defs/nn
        'AveragePool': AveragePool.get_converter(opset),
        'MaxPool': MaxPool.get_converter(opset),
        'Conv': Conv.get_converter(opset),
        'ConvTranspose': ConvTranspose.get_converter(opset),
        'GlobalAveragePool': Renamer('global_avg_pool2d'),
        'GlobalMaxPool': Renamer('global_max_pool2d'),
        'BatchNormalization': BatchNorm.get_converter(opset),
        # 'InstanceNormalization'
        # 'LpNormalization'
        'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
        'Flatten': Renamer('flatten'),
        # 'LRN'

    # defs/reduction
    'ReduceMax' : AttrCvt('max', {'axes', 'axis'}),
    'ReduceMin' : AttrCvt('min', {'axes', 'axis'}),
    'ReduceSum' : AttrCvt('sum', {'axes', 'axis'}),
    # 'ReduceMean'
    # 'ReduceProd'
    # 'ReduceLogSumExp'
    # 'ArgMax'
    # 'ArgMin'
        # defs/reduction
        'ReduceMax': AttrCvt('max', {'axes', 'axis'}),
        'ReduceMin': AttrCvt('min', {'axes', 'axis'}),
        'ReduceSum': AttrCvt('sum', {'axes', 'axis'}),
        # 'ReduceMean'
        # 'ReduceProd'
        # 'ReduceLogSumExp'
        # 'ArgMax'
        # 'ArgMin'

    # defs/tensor
    'Cast' : AttrCvt('cast', {'to': 'dtype'}),
    'Reshape' : Renamer('reshape'),
    'Concat' : Renamer('concatenate'),
    'Split' : AttrCvt('split', {'split': 'indices_or_sections'}),
    # 'Slice'
    'Transpose' : AttrCvt('transpose', {'perm': 'axes'}),
    # 'Gather'
    # 'Squeeze'
    'Pad' : _pad(),
}
        # defs/tensor
        'Cast': AttrCvt('cast', {'to': 'dtype'}),
        'Reshape': Reshape.get_converter(opset),
        'Concat': Renamer('concatenate'),
        'Split': AttrCvt('split', {'split': 'indices_or_sections'}),
        # 'Slice'
        'Transpose': AttrCvt('transpose', {'perm': 'axes'}),
        # 'Gather'
        # 'Squeeze'
        'Pad': Pad.get_converter(opset),
    }


class GraphProto(object):
    """A helper class for handling nnvm graph copying from pb2.GraphProto.
    Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
    """

    def __init__(self):
        self._nodes = {}
        self._params = {}
@@ -365,7 +511,7 @@ class GraphProto(object):
        self._num_input = 0
        self._num_param = 0

    def from_onnx(self, graph):
    def from_onnx(self, graph, opset):
        """Construct nnvm nodes from onnx graph.
        The inputs from onnx graph is vague, only providing "1", "2"...
        For convenience, we rename the `real` input names to "input_0",

@@ -375,6 +521,7 @@ class GraphProto(object):
        ----------
        graph : onnx protobuf object
            The loaded onnx graph
        opset : opset version

        Returns
        -------

@@ -410,7 +557,7 @@ class GraphProto(object):
            op_name = node.op_type
            attr = self._parse_attr(node.attribute)
            inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
            op = self._convert_operator(op_name, inputs, attr)
            op = self._convert_operator(op_name, inputs, attr, opset)
            node_output = self._fix_outputs(op_name, node.output)
            assert len(node_output) == len(op.list_output_names()), (
                "Number of output mismatch {} vs {} in {}.".format(

@@ -438,7 +585,8 @@ class GraphProto(object):
        try:
            from onnx.numpy_helper import to_array
        except ImportError as e:
            raise ImportError("Unable to import onnx which is required {}".format(e))
            raise ImportError(
                "Unable to import onnx which is required {}".format(e))
        np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
        return tvm.nd.array(np_array)

@@ -455,15 +603,23 @@ class GraphProto(object):
                    attrs[a.name] = tuple(getattr(a, f))
            for f in ['t', 'g']:
                if a.HasField(f):
                    raise NotImplementedError("Filed {} is not supported in nnvm.".format(f))
                    raise NotImplementedError(
                        "Filed {} is not supported in nnvm.".format(f))
            for f in ['tensors', 'graphs']:
                if list(getattr(a, f)):
                    raise NotImplementedError("Filed {} is not supported in nnvm.".format(f))
                    raise NotImplementedError(
                        "Filed {} is not supported in nnvm.".format(f))
            if a.name not in attrs:
                raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
        return attrs

    def _convert_operator(self, op_name, inputs, attrs, identity_list=None, convert_map=None):
    def _convert_operator(self,
                          op_name,
                          inputs,
                          attrs,
                          opset,
                          identity_list=None,
                          convert_map=None):
        """Convert from onnx operator to nnvm operator.
        The converter must specify conversions explicity for incompatible name, and
        apply handlers to operator attributes.

@@ -476,6 +632,8 @@ class GraphProto(object):
            List of input symbols.
        attrs : dict
            Dict of operator attributes
        opset : int
            Opset version
        identity_list : list
            List of operators that don't require conversion
        convert_map : dict

@@ -489,13 +647,14 @@ class GraphProto(object):
            Converted nnvm Symbol
        """
        identity_list = identity_list if identity_list else _identity_list
        convert_map = convert_map if convert_map else _convert_map
        convert_map = convert_map if convert_map else _get_convert_map(opset)
        if op_name in identity_list:
            sym = get_nnvm_op(op_name)(*inputs, **attrs)
        elif op_name in convert_map:
            sym = convert_map[op_name](inputs, attrs, self._params)
        else:
            raise NotImplementedError("Operator {} not implemented.".format(op_name))
            raise NotImplementedError(
                "Operator {} not implemented.".format(op_name))
        return sym

    def _fix_outputs(self, op_name, outputs):

@@ -510,7 +669,7 @@ class GraphProto(object):
        return outputs


def from_onnx(graph):
def from_onnx(model):
    """Load onnx graph which is a python protobuf object into nnvm graph.
    The companion parameters will be handled automatically.
    The inputs from onnx graph is vague, only providing "1", "2"...

@@ -519,8 +678,8 @@ def from_onnx(graph):

    Parameters
    ----------
    graph : protobuf object
        ONNX GraphProto, or ONNX ModelProto after ONNX v0.2
    model : protobuf object
        ONNX ModelProto after ONNX v1.1.0

    Returns
    -------

@@ -531,8 +690,7 @@ def from_onnx(graph):
        Dict of converted parameters stored in tvm.ndarray format
    """
    g = GraphProto()
    if hasattr(graph, 'graph'):
        # it's a ModelProto wrapper
        graph = graph.graph
    sym, params = g.from_onnx(graph)
    graph = model.graph
    opset = model.opset_import[0].version if model.opset_import else 1
    sym, params = g.from_onnx(graph, opset)
    return sym, params
@@ -1,5 +1,5 @@
pip2 install onnx>=0.2.0
pip3 install onnx>=0.2.0
pip2 install onnx>=1.1.0
pip3 install onnx>=1.1.0

pip2 install http://download.pytorch.org/whl/cu75/torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl
pip2 install torchvision

@@ -14,8 +14,8 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
        c2_out = prepared_backend.run(W)[0]
        return c2_out

    def get_tvm_output(graph, x, target, ctx, dtype='float32'):
        new_sym, params = nnvm.frontend.from_onnx(graph)
    def get_tvm_output(model, x, target, ctx, dtype='float32'):
        new_sym, params = nnvm.frontend.from_onnx(model)
        shape_dict = {'input_0': x.shape}
        graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
        m = graph_runtime.create(graph, lib, ctx)

@@ -5,8 +5,8 @@ from nnvm.compiler import graph_util, graph_attr
from model_zoo import super_resolution, super_resolution_sym

def compare_graph(onnx_file, nnvm_sym, ishape):
    onnx_graph = onnx.load(onnx_file)
    onnx_sym, params = nnvm.frontend.from_onnx(onnx_graph)
    onnx_model = onnx.load(onnx_file)
    onnx_sym, params = nnvm.frontend.from_onnx(onnx_model)
    g1 = nnvm.graph.create(onnx_sym)
    g2 = nnvm.graph.create(nnvm_sym)
    ishapes = {'input_0': ishape}

@@ -44,9 +44,9 @@ model_url = ''.join(['https://gist.github.com/zhreshold/',
                     'super_resolution_0.2.onnx'])
download(model_url, 'super_resolution.onnx', True)
# now you have super_resolution.onnx on disk
onnx_graph = onnx.load('super_resolution.onnx')
onnx_model = onnx.load('super_resolution.onnx')
# we can load the graph as NNVM compatible model
sym, params = nnvm.frontend.from_onnx(onnx_graph)
sym, params = nnvm.frontend.from_onnx(onnx_model)

######################################################################
# Load a test image
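With this change the frontend expects a full `ModelProto` rather than a bare `GraphProto`, because the opset is read from `model.opset_import`. A minimal usage sketch in the spirit of the updated tutorial follows; it assumes `super_resolution.onnx` is already on disk and the (1, 1, 224, 224) input shape is an assumption, not something stated by this commit.

# Sketch of converting an ONNX model with the opset-aware frontend.
import onnx
import nnvm
import nnvm.compiler

# Load the full ModelProto; from_onnx now reads the opset from
# model.opset_import (falling back to 1 when it is absent).
onnx_model = onnx.load('super_resolution.onnx')
sym, params = nnvm.frontend.from_onnx(onnx_model)

# Compile as usual; 'input_0' is the renamed first graph input.
shape_dict = {'input_0': (1, 1, 224, 224)}
graph, lib, params = nnvm.compiler.build(sym, 'llvm', shape_dict, params=params)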