[Relay][Frontend] Caffe2 Support (#2507)
* [Relay][Frontend] Add Caffe2 Support * [Relay][Frontend] Add Caffe2 Support (fix unused import) * [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import) * [Relay][Frontend] Add Caffe2 Support (fix model install and reflect code reviews) * [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import) * [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import) * [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import) * [Relay][Frontend] Add Caffe2 Support (fix caffe2 frontend import) * [Relay][Frontend] Add Caffe2 Support (rename function name in test_forward) * [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import) * [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import) * [Doc] Caffe2 frontend tutorial * [Doc] Caffe2 frontend tutorial * [Doc] Caffe2 frontend tutorial * [Relay][Frontend] Add Caffe2 Support (remove unused file)
This commit is contained in:
Родитель
e012f819b1
Коммит
b3b3d28a18
|
@ -67,6 +67,9 @@ RUN bash /install/ubuntu_install_onnx.sh
|
|||
COPY install/ubuntu_install_tflite.sh /install/ubuntu_install_tflite.sh
|
||||
RUN bash /install/ubuntu_install_tflite.sh
|
||||
|
||||
COPY install/ubuntu_install_caffe2.sh /install/ubuntu_install_caffe2.sh
|
||||
RUN bash /install/ubuntu_install_caffe2.sh
|
||||
|
||||
RUN pip3 install Pillow
|
||||
|
||||
COPY install/ubuntu_install_vulkan.sh /install/ubuntu_install_vulkan.sh
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
python3 -m caffe2.python.models.download -i -f squeezenet
|
||||
python3 -m caffe2.python.models.download -i -f resnet50
|
||||
python3 -m caffe2.python.models.download -i -f vgg19
|
|
@ -12,3 +12,4 @@ from .keras import from_keras
|
|||
from .onnx import from_onnx
|
||||
from .tflite import from_tflite
|
||||
from .coreml import from_coreml
|
||||
from .caffe2 import from_caffe2
|
||||
|
|
|
@ -0,0 +1,565 @@
|
|||
# pylint: disable=import-self, invalid-name, line-too-long, unused-argument
|
||||
"""Caffe2 frontend"""
|
||||
from __future__ import absolute_import as _abs
|
||||
from .. import ir_pass
|
||||
from .. import expr as _expr
|
||||
from .. import op as _op
|
||||
from ... import nd as _nd
|
||||
from .common import AttrCvt, Renamer
|
||||
from .common import get_relay_op, new_var, infer_channels
|
||||
|
||||
__all__ = ['from_caffe2']
|
||||
|
||||
def dimension_picker(prefix, surfix=''):
    """Build a callable that names the 2-D variant of an operator.

    Parameters
    ----------
    prefix : str
        Text placed in front of the ``'2d'`` marker, e.g. ``'conv'``.
    surfix : str, optional
        Text appended after the ``'2d'`` marker.

    Returns
    -------
    callable
        Function taking an attribute dict. It returns
        ``prefix + '2d' + surfix`` when ``attr['kernel_shape']`` has exactly
        two entries and raises ``NotImplementedError`` otherwise.
    """
    def _impl(attr):
        # Only two-dimensional kernels have a mapping.
        if len(attr['kernel_shape']) != 2:
            raise NotImplementedError("Only 2d kernel supported.")
        return prefix + '2d' + surfix

    return _impl
|
||||
|
||||
|
||||
def revert_caffe2_pad(pads):
    """Reduce a Caffe2 padding specification to two values.

    Parameters
    ----------
    pads : sequence
        Padding with either four or two entries.

    Returns
    -------
    sequence
        The first two entries when four are given; ``pads`` itself when it
        already has two entries.

    Raises
    ------
    ValueError
        If ``pads`` has neither two nor four entries.
    """
    pad_count = len(pads)
    if pad_count not in (2, 4):
        raise ValueError("Invalid caffe2 type padding: {}".format(pads))
    # A four-entry spec keeps only its leading pair.
    return pads[:2] if pad_count == 4 else pads
|
||||
|
||||
|
||||
def dimension_constraint():
    """Return a ``(check, message)`` pair enforcing 2-D kernels.

    Returns
    -------
    tuple
        ``check`` takes an attribute dict and returns ``True`` when
        ``args['kernel_shape']`` has exactly two entries, ``False`` otherwise;
        ``message`` is the text describing the constraint.
    """
    def _dim_check(args):
        return len(args['kernel_shape']) == 2

    return _dim_check, "Only 2d kernel supported."
|
||||
|
||||
|
||||
def _clean_up_pool_args(args):
|
||||
""" A helper function to clean up common arguments in conv and pooling ops.
|
||||
"""
|
||||
assert isinstance(args, dict)
|
||||
|
||||
if 'stride_h' in args and 'stride_w' in args:
|
||||
assert 'stride' not in args and 'strides' not in args
|
||||
args['strides'] = [args['stride_h'], args['stride_w']]
|
||||
args.pop('stride_h')
|
||||
args.pop('stride_w')
|
||||
elif 'stride' in args:
|
||||
args['strides'] = [args['stride'], args['stride']]
|
||||
args.pop('stride')
|
||||
|
||||
# rename 'kernel', 'kernels', to 'kernel_shape'
|
||||
if 'kernel_h' in args and 'kernel_w' in args:
|
||||
assert 'kernel' not in args and 'kernels' not in args
|
||||
args['kernel_shape'] = [args['kernel_h'], args['kernel_w']]
|
||||
args.pop('kernel_h')
|
||||
args.pop('kernel_w')
|
||||
elif 'kernel' in args:
|
||||
args['kernel_shape'] = [args['kernel'], args['kernel']]
|
||||
args.pop('kernel')
|
||||
elif 'kernels' in args:
|
||||
args['kernel_shape'] = args['kernels']
|
||||
args.pop('kernels')
|
||||
|
||||
if 'pad_t' in args and 'pad_l' in args and 'pad_b' in args and 'pad_r' in args:
|
||||
assert 'pad' not in args and 'pads' not in args
|
||||
args['pads'] = [
|
||||
args['pad_t'], args['pad_l'], args['pad_b'], args['pad_r']
|
||||
]
|
||||
for pad in ['pad_t', 'pad_l', 'pad_b', 'pad_r']:
|
||||
args.pop(pad)
|
||||
elif 'pad' in args:
|
||||
args['pads'] = [args['pad'], args['pad']]
|
||||
args.pop('pad')
|
||||
|
||||
if 'dilation_h' in args and 'dilation_w' in args:
|
||||
assert 'dilation' not in args and 'dilations' not in args
|
||||
args['dilations'] = [args['dilation_h'], args['dilation_w']]
|
||||
args.pop('dilation_h')
|
||||
args.pop('dilation_w')
|
||||
elif 'dilation' in args:
|
||||
args['dilations'] = [args['dilation'], args['dilation']]
|
||||
args.pop('dilation')
|
||||
|
||||
return args
|
||||
|
||||
|
||||
class Caffe2OpConverter(object):
    """Base class for Caffe2 operator converters.

    Subclasses provide the conversion as a ``_impl`` attribute.
    """

    @classmethod
    def get_converter(cls):
        """Return the converter of this class.

        Returns
        -------
        callable
            The ``_impl`` attribute of ``cls``.

        Raises
        ------
        NotImplementedError
            If ``cls`` has no ``_impl`` attribute.
        """
        try:
            return cls._impl
        except AttributeError:
            raise NotImplementedError('{} not implemented'.format(
                cls.__name__))
|
||||
|
||||
|
||||
_caffe2_internal_args = [
|
||||
# nnpack args
|
||||
'algo',
|
||||
'convolution_transform_strategy',
|
||||
'float16_compute',
|
||||
'shared_buffer',
|
||||
|
||||
# training args
|
||||
'init_params',
|
||||
'cudnn_exhaustive_search',
|
||||
'exhaustive_search',
|
||||
|
||||
# training args
|
||||
'adj',
|
||||
'hwgq',
|
||||
|
||||
# args that we don't care
|
||||
'legacy_pad',
|
||||
]
|
||||
|
||||
|
||||
class Elemwise(Caffe2OpConverter):
    """Shared converter logic for two-input element-wise operators."""

    # Operator suffix; subclasses override it (e.g. 'add').
    name = ''

    @classmethod
    def _math_name_picker(cls, suffix):
        """Return a callable mapping an attribute dict to an operator name.

        The name is ``'broadcast_' + suffix`` when the ``broadcast``
        attribute is truthy and ``'elemwise_' + suffix`` otherwise.
        """
        def _impl(attr):
            kind = 'broadcast_' if attr.get('broadcast', 0) else 'elemwise_'
            return kind + suffix

        return _impl

    @classmethod
    def _impl(cls, inputs, args, params):
        """Convert a binary element-wise op; ``inputs`` must hold two entries."""
        assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(
            len(inputs))
        pick_name = cls._math_name_picker(cls.name)
        op_name = pick_name(args)
        axis = int(args.get('axis', 0))
        conv_ops = ["conv2d", "conv2d_transpose"]
        # NOTE(review): ``.attr('op_name')`` looks like a symbol-style API;
        # confirm the expressions passed in here actually provide it.
        if op_name == 'broadcast_add' and inputs[0].attr('op_name') in conv_ops:
            # TODO(zhreshold): remove hard coded infershape
            inputs[1] = _op.expand_dims(inputs[1], axis=axis, num_newaxis=2)
        return get_relay_op(op_name)(*inputs)
|
||||
|
||||
|
||||
class Add(Elemwise):
    """Operator converter for Add; uses the ``'add'`` suffix of Elemwise."""

    name = 'add'
|
||||
|
||||
|
||||
class Pool(Caffe2OpConverter):
    """Shared converter logic for pooling operators."""

    # Operator prefix such as 'avg_pool'; set by subclasses.
    name = ''

    @classmethod
    def _impl(cls, inputs, args, params):
        """Convert a pooling op, using the ``global_`` variant when requested."""
        _clean_up_pool_args(args)

        if args.get('global_pooling') == 1:
            # Global pooling takes no further attributes.
            global_name = dimension_picker('global_' + cls.name)(args)
            return get_relay_op(global_name)(*inputs)

        converter = AttrCvt(
            op_name=dimension_picker(cls.name),
            transforms={
                'kernel_shape': 'pool_size',
                'pads': ('padding', (0, 0), revert_caffe2_pad),
                'strides': 'strides',
            },
            ignores=['dilations', 'order', 'legacy_pad', 'global_pooling'],
            extras={'ceil_mode': False},
            custom_check=dimension_constraint())
        return converter(inputs, args, params)
|
||||
|
||||
|
||||
class AveragePool(Pool):
    """Operator converter for AveragePool; uses the ``'avg_pool'`` prefix."""

    name = 'avg_pool'
|
||||
|
||||
|
||||
class MaxPool(Pool):
    """Operator converter for MaxPool; uses the ``'max_pool'`` prefix."""

    name = 'max_pool'
|
||||
|
||||
|
||||
class Conv(Caffe2OpConverter):
    """Operator converter for Conv."""

    @classmethod
    def _impl(cls, inputs, args, params):
        """Convert Conv from ``inputs[:2]``, adding ``inputs[2]`` as bias if present."""
        # The channel count is inferred from the second input.
        args['channels'] = infer_channels(inputs[1])
        _clean_up_pool_args(args)

        def _as_text(order):
            # The layout may arrive as bytes; decode it to str.
            return order if isinstance(order, str) else order.decode('UTF-8')

        converter = AttrCvt(
            op_name=dimension_picker('conv'),
            transforms={
                'group': ('groups', 1),
                'kernel_shape': 'kernel_size',
                'pads': ('padding', (0, 0), revert_caffe2_pad),
                'strides': 'strides',
                'dilations': ('dilation', (1, 1)),
                'order': ('data_layout', ("NCHW"), _as_text),
            },
            excludes=[],
            ignores=[],
            custom_check=dimension_constraint())
        out = converter(inputs[:2], args, params)

        # A third input, when given, is the bias.
        if len(inputs) == 3:
            out = _op.nn.bias_add(out, inputs[2])
        return out
|
||||
|
||||
|
||||
class Concat(Caffe2OpConverter):
    """Operator converter for Concat."""

    @classmethod
    def _impl(cls, inputs, args, params):
        """Convert Concat to ``concatenate``, deriving the axis from ``order``."""
        axis_by_order = {'NCHW': 1, 'NHWC': 3}

        def _get_axis_from_order_str(order):
            # The storage order may arrive as bytes; decode it first.
            if not isinstance(order, str):
                order = order.decode('UTF-8')
            if order not in axis_by_order:
                raise RuntimeError(
                    "Unsupported storage order: {} in caffe2".format(order))
            return axis_by_order[order]

        converter = AttrCvt(
            op_name='concatenate',
            transforms={
                'order': ('axis', (1), _get_axis_from_order_str),
            },
            excludes=['add_axis'])
        # All inputs are passed together as one tuple argument.
        return converter((inputs,), args, params)
|
||||
|
||||
|
||||
class NormalizePlanarYUV(Caffe2OpConverter):
    """Operator converter for NormalizePlanarYUV.

    caffe2 definition: https://github.com/pytorch/pytorch/blob/master/caffe2/operators/norm_planar_yuv_op.cc
    """

    @classmethod
    def _impl(cls, inputs, args, params):
        """Return ``(inputs[0] - mean) / std`` with mean/std taken from inputs 1 and 2."""
        assert len(inputs) == 3
        data, raw_mean, raw_std = inputs[0], inputs[1], inputs[2]
        # Append two new axes at position 2 so mean/std broadcast against data.
        mean = _op.expand_dims(raw_mean, axis=2, num_newaxis=2)
        std = _op.expand_dims(raw_std, axis=2, num_newaxis=2)
        centered = _op.subtract(data, mean)
        return _op.broadcast_divide(centered, std)
|
||||
|
||||
|
||||
class ResizeNearest(Caffe2OpConverter):
    """Operator converter for Upsample (nearest mode)."""

    @classmethod
    def _impl(cls, inputs, args, params):
        """Convert to nearest-neighbour upsampling; both scales must be equal."""
        width_scale = args.get('width_scale', 1)
        height_scale = args.get('height_scale', 1)
        # A single scale is passed on, so the two must agree.
        assert width_scale == height_scale

        return _op.nn.upsampling(
            inputs[0], scale=int(width_scale), method="NEAREST_NEIGHBOR")
|
||||
|
||||
|
||||
class Sum(Caffe2OpConverter):
    """Operator converter for Sum."""

    @classmethod
    def _impl(cls, inputs, args, params):
        """Add all inputs together, left to right.

        Parameters
        ----------
        inputs : list
            Operands to add; must hold at least one entry. The list is not
            modified.
        args : dict
            Operator attributes (unused).
        params : dict
            Model parameters (unused).

        Returns
        -------
        The accumulated ``_op.add`` expression, or ``inputs[0]`` itself when
        there is a single input.

        Raises
        ------
        IndexError
            If ``inputs`` is empty.
        """
        # Accumulate in a local so the caller's list keeps its original
        # operands instead of being overwritten with partial sums.
        total = inputs[0]
        for operand in inputs[1:]:
            total = _op.add(total, operand)
        return total
|
||||
|
||||
|
||||
class Softmax(Caffe2OpConverter):
    """Operator converter for Softmax."""

    @classmethod
    def _impl(cls, inputs, args, params):
        """Convert Softmax, writing ``axis = 1`` into ``args`` when it is absent."""
        # set default value when axis is not set in the model
        args.setdefault('axis', 1)
        converter = AttrCvt(
            'softmax', transforms={'axis': ('axis', args['axis'])})
        return converter(inputs, args, params)
|
||||
|
||||
|
||||
class FC(Caffe2OpConverter):
    """Operator converter for FC."""

    @classmethod
    def _impl(cls, inputs, args, params):
        """Convert FC to ``batch_flatten`` + ``dense``, plus ``bias_add`` for a third input."""
        # The flattened data replaces inputs[0] in the caller's list.
        inputs[0] = _op.nn.batch_flatten(inputs[0])
        weight = inputs[1]
        res = _op.nn.dense(inputs[0], weight, units=infer_channels(weight))
        # A third input, when given, is the bias.
        if len(inputs) == 3:
            res = _op.nn.bias_add(res, inputs[2])
        return res
|
||||
|
||||
|
||||
class SpatialBN(Caffe2OpConverter):
    """Operator converter for SpatialBN."""

    @classmethod
    def _impl(cls, inputs, args, params):
        """Convert SpatialBN to ``batch_norm`` through AttrCvt."""
        ignored_attrs = [
            'order', 'spatial', 'is_test', 'consumed_inputs', 'num_batches'
        ]
        converter = AttrCvt(
            op_name='batch_norm',
            disables=['momentum'],
            ignores=ignored_attrs)
        return converter(inputs, args, params)
|
||||
|
||||
|
||||
# compatible operators that do NOT require any conversion.
|
||||
_identity_list = []
|
||||
|
||||
# _convert_map defines maps of name to converter functor(callable)
|
||||
# for 1 to 1 mapping, use Renamer if nothing but name is different
|
||||
# use AttrCvt if attributes need to be converted
|
||||
# for 1 to N mapping(composed), use custom callable functions
|
||||
# for N to 1 mapping, currently not supported(?)
|
||||
|
||||
# Minimal set of ops for squeezenet and resnet50
|
||||
def _get_convert_map():
    """Build the table mapping Caffe2 operator names to converter callables.

    Returns
    -------
    dict
        A new dict on every call, keyed by Caffe2 operator type.
    """
    convert_map = {}

    # caffe2 common operators followed by the class-based nn operators
    class_based = (
        ('Add', Add),
        ('Sum', Sum),
        ('Softmax', Softmax),
        ('AveragePool', AveragePool),
        ('MaxPool', MaxPool),
        ('Conv', Conv),
        ('Concat', Concat),
        ('FC', FC),
        ('SpatialBN', SpatialBN),
        ('ResizeNearest', ResizeNearest),
    )
    for op_type, converter_cls in class_based:
        convert_map[op_type] = converter_cls.get_converter()

    # nn operators that only need renaming or attribute conversion
    convert_map['Relu'] = AttrCvt('relu', {}, ignores=['order'])
    convert_map['Sigmoid'] = Renamer('sigmoid')
    convert_map['Dropout'] = AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test'])

    # c2 image preprocessing ops
    convert_map['NormalizePlanarYUV'] = NormalizePlanarYUV.get_converter()

    return convert_map
|
||||
|
||||
|
||||
class Caffe2NetDef(object):
    """A helper class that builds a Relay function from Caffe2 NetDef objects.

    Definition: https://github.com/pytorch/pytorch/blob/master/caffe2/proto/caffe2.proto
    """

    def __init__(self, shape, dtype):
        """Store the user-supplied input shapes and dtypes.

        Parameters
        ----------
        shape : dict of str to tuple, or None
            Shapes of graph inputs keyed by blob name.
        dtype : str or dict of str to str
            Either one dtype for every input or a per-blob mapping.
        """
        self._nodes = {}
        self._params = {}
        self._visited_nodes = set()
        self._ops = {}
        # ``shape`` may be None; normalise it to a dict so the membership
        # test in from_caffe2 cannot raise TypeError.
        self._shape = shape if shape else {}
        self._dtype = dtype

    def from_caffe2(self, init_net, predict_net):
        """Construct Relay expression from caffe2 graph.

        Parameters
        ----------
        init_net : protobuf object
        predict_net : protobuf object

        Returns
        -------
        func : tvm.relay.expr.Function
            Compatible relay function
        params : dict
            A dict of name: tvm.nd.array pairs, used as pretrained weights
        """
        from caffe2.python import workspace
        workspace.RunNetOnce(init_net)

        # Input: the first input blob of the first op is treated as the
        # graph input and is therefore never stored as a parameter.
        input_name = predict_net.op[0].input[0]

        # Params: every workspace blob consumed by some op, except the input.
        self._params = {}
        used_blobs = set()
        for c2_op in predict_net.op:
            used_blobs.update(c2_op.input)
        for blob in workspace.Blobs():
            if blob in used_blobs and blob != input_name:
                self._params[blob] = _nd.array(workspace.FetchBlob(blob))

        # Variables: one var per external input.
        self._nodes = {}
        for blob in predict_net.external_input:
            if blob in self._params:
                param = self._params[blob]
                self._nodes[blob] = new_var(blob, shape=param.shape, dtype=param.dtype)
            else:
                shape = self._shape[blob] if blob in self._shape else ()
                if isinstance(self._dtype, dict) and blob in self._dtype:
                    dtype = str(self._dtype[blob])
                elif isinstance(self._dtype, str):
                    dtype = self._dtype
                else:
                    dtype = "float32"
                self._nodes[blob] = new_var(blob, shape=shape, dtype=dtype)

        # Ops: index every op by the blobs it produces so _get_node can
        # convert a producer on demand.
        for c2_op in predict_net.op:
            for blob in c2_op.output:
                self._ops[blob] = c2_op

        for c2_op in predict_net.op:
            self._process_op(c2_op)

        # Outputs
        out = [self._nodes[blob] for blob in predict_net.external_output]
        outputs = _expr.Tuple(out) if len(out) > 1 else out[0]

        func = _expr.Function(ir_pass.free_vars(outputs), outputs)

        return func, self._params

    def _get_node(self, blob):
        """Return the converted node for ``blob``, converting its producer if needed.

        Raises AssertionError when ``blob`` is requested a second time before
        it has been produced, which indicates a cyclic dependency.
        """
        if blob in self._nodes:
            return self._nodes[blob]

        assert blob not in self._visited_nodes, 'Cyclic dependency in the graph (in {})'.format(
            blob)
        self._visited_nodes.add(blob)

        self._process_op(self._ops[blob])
        return self._nodes[blob]

    def _process_op(self, c2_op):
        """Convert one Caffe2 op and record its outputs in ``self._nodes``."""
        op_type = c2_op.type
        args = self._parse_arg(c2_op.arg)
        inputs = [self._get_node(i) for i in c2_op.input]
        tvm_op = self._convert_operator(op_type, inputs, args)

        if not isinstance(tvm_op, _expr.TupleWrapper):
            self._nodes[c2_op.output[0]] = tvm_op
        else:
            # Multi-output op: bind each output blob to its tuple field.
            for index, blob in zip(range(len(tvm_op)), c2_op.output):
                self._nodes[blob] = tvm_op[index]

    def _parse_arg(self, arg):
        """Convert a list of Argument to a dict, with names as keys.

        Scalar fields (``f``, ``i``, ``s``) are stored as-is and repeated
        fields (``floats``, ``ints``, ``strings``) as tuples.

        Raises
        ------
        NotImplementedError
            If an argument uses the ``n`` or ``nets`` field.
        ValueError
            If no supported field of an argument carries a value.
        """
        args = {}
        for a in arg:
            for f in ['f', 'i', 's']:
                if a.HasField(f):
                    args[a.name] = getattr(a, f)
            for f in ['floats', 'ints', 'strings']:
                values = list(getattr(a, f))
                if values:
                    assert a.name not in args, "Only one type of attr is allowed"
                    args[a.name] = tuple(values)
            if a.HasField('n'):
                raise NotImplementedError(
                    "Field {} is not supported in relay.".format('n'))
            if list(getattr(a, 'nets')):
                raise NotImplementedError(
                    "Field {} is not supported in relay.".format('nets'))
            if a.name not in args:
                raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
        return args

    def _convert_operator(self,
                          op_type,
                          inputs,
                          args,
                          identity_list=None,
                          convert_map=None):
        """Convert from Caffe2 operator to Relay operator.
        The converter must specify conversions explicitly for incompatible name, and
        apply handlers to operator attributes.

        Parameters
        ----------
        op_type : str
            Operator name, such as Convolution, FullyConnected
        inputs : list
            Converted inputs of the operator.
        args : dict
            Dict of operator attributes
        identity_list : list
            List of operators that don't require conversion
        convert_map : dict
            Dict of name : callable, where name is the op's name that
            requires conversion to relay and the callable is invoked as
            ``callable(inputs, args, params)``

        Returns
        -------
        func
            The converted Relay operator call.

        Raises
        ------
        NotImplementedError
            If ``op_type`` is in neither ``identity_list`` nor ``convert_map``.
        """
        identity_list = identity_list if identity_list else _identity_list
        convert_map = convert_map if convert_map else _get_convert_map()
        if op_type in identity_list:
            func = get_relay_op(op_type)(*inputs, **args)
        elif op_type in convert_map:
            # TODO: byte strings in args are not sanitized here; converters
            # that need text decode the values themselves.
            func = convert_map[op_type](inputs, args, self._params)
        else:
            raise NotImplementedError(
                "Operator {} not implemented.".format(op_type))
        return func
|
||||
|
||||
|
||||
def from_caffe2(init_net, predict_net, shape=None, dtype="float32"):
    """Load caffe2 graph which contains init_net and predict_net into Relay Function.

    Parameters
    ----------
    init_net : protobuf object
        Caffe2 NetDef containing the weights

    predict_net : protobuf object
        Caffe2 NetDef containing the graph

    shape : dict of str to tuple
        The input shape to the graph

    dtype : str or dict of str to str
        The input types to the graph

    Returns
    -------
    sym : tvm.relay.expr.Function
        Compatible relay function

    params : dict of str to tvm.ndarray
        Dict of converted parameters stored in tvm.ndarray format
    """
    # A fresh converter per call keeps conversion state isolated.
    return Caffe2NetDef(shape, dtype).from_caffe2(init_net, predict_net)
|
|
@ -0,0 +1,29 @@
|
|||
"""Store for caffe2 examples and common models."""
|
||||
from __future__ import absolute_import as _abs
|
||||
import os
|
||||
import sys
|
||||
import importlib
|
||||
from . import squeezenet
|
||||
from caffe2.python.models.download import ModelDownloader
|
||||
|
||||
# Names of the Caffe2 models made available by this package.
models = ['squeezenet', 'resnet50', 'vgg19']
|
||||
|
||||
mf = ModelDownloader()
|
||||
|
||||
class Model:
    """Holds the three values returned by ``mf.get_c2_model`` for one model."""

    def __init__(self, model_name):
        nets = mf.get_c2_model(model_name)
        self.init_net, self.predict_net, self.value_info = nets
|
||||
|
||||
# Expose every model as a module attribute named ``c2_<model>``: use the
# module shipped under caffe2.python.models when it can be imported, and
# otherwise build a Model through the downloader.
# Names are bound through globals(), the supported way to create module
# attributes dynamically; writing through locals() is not.
for model in models:
    try:
        globals()['c2_' + model] = importlib.import_module('caffe2.python.models.' + model)
    except ImportError:
        globals()['c2_' + model] = Model(model)
|
||||
|
||||
# squeezenet
|
||||
def relay_squeezenet():
    """Return ``squeezenet.get_workload()`` called with its default arguments."""
    workload = squeezenet.get_workload()
    return workload
|
|
@ -0,0 +1,132 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# coding: utf-8
|
||||
# pylint: disable=unused-argument
|
||||
|
||||
"""
|
||||
Symbol of SqueezeNet
|
||||
|
||||
Reference:
|
||||
Iandola, Forrest N., et al.
|
||||
"SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size." (2016).
|
||||
"""
|
||||
|
||||
from tvm import relay
|
||||
from tvm.relay.testing import create_workload
|
||||
|
||||
# Helpers
|
||||
def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels, prefix=""):
    """Build a fire module: a 1x1 squeeze conv feeding parallel 1x1 and 3x3 expand convs.

    The two expand outputs are concatenated along axis 1.
    """
    squeezed = _make_fire_conv(net, squeeze_channels, 1, 0, "%s/squeeze1x1" % prefix)

    expand_1x1 = _make_fire_conv(squeezed, expand1x1_channels, 1, 0, "%s/expand1x1" % prefix)
    expand_3x3 = _make_fire_conv(squeezed, expand3x3_channels, 3, 1, "%s/expand3x3" % prefix)
    # NOTE : Assume NCHW layout here
    return relay.concatenate((expand_1x1, expand_3x3), axis=1)
|
||||
|
||||
|
||||
def _make_fire_conv(net, channels, kernel_size, padding=0, prefix=""):
    """Apply conv2d, bias_add and relu using vars named ``<prefix>_weight`` / ``<prefix>_bias``."""
    weight = relay.var("%s_weight" % prefix)
    conv = relay.nn.conv2d(net, weight,
                           channels=channels,
                           kernel_size=(kernel_size, kernel_size),
                           padding=(padding, padding))
    bias = relay.var("%s_bias" % prefix)
    biased = relay.nn.bias_add(conv, bias)
    return relay.nn.relu(biased)
|
||||
|
||||
|
||||
# Net
|
||||
def get_net(batch_size, image_shape, num_classes, dtype):
    """Get symbol of SqueezeNet

    Parameters
    ----------
    batch_size : int
        The batch size used in the model

    image_shape : tuple
        The input image shape

    num_classes: int
        The number of classification results

    dtype : str
        The data type

    Returns
    -------
    net : relay.Function
        Function over the free variables of the network.
    """
    def _max_pool(x):
        return relay.nn.max_pool2d(x, pool_size=(3, 3), strides=(2, 2))

    data_shape = (batch_size,) + image_shape
    net = relay.var("data", shape=data_shape, dtype=dtype)

    # Stem: conv1 + bias + relu, then the first max-pool.
    net = relay.nn.conv2d(net, relay.var("conv1_weight"),
                          channels=64,
                          kernel_size=(3, 3),
                          strides=(2, 2),
                          padding=(0, 0))
    net = relay.nn.bias_add(net, relay.var("conv1_bias"))
    net = relay.nn.relu(net)
    net = _max_pool(net)

    # Fire modules, with a max-pool after fire3 and after fire5.
    for fire_name in ("fire2", "fire3"):
        net = _make_fire(net, 16, 64, 64, fire_name)
    net = _max_pool(net)
    for fire_name in ("fire4", "fire5"):
        net = _make_fire(net, 32, 128, 128, fire_name)
    net = _max_pool(net)
    for fire_name in ("fire6", "fire7"):
        net = _make_fire(net, 48, 192, 192, fire_name)
    for fire_name in ("fire8", "fire9"):
        net = _make_fire(net, 64, 256, 256, fire_name)

    # Classifier head.
    net = relay.nn.dropout(net, rate=0.5)
    net = relay.nn.conv2d(net, relay.var('conv10_weight'), channels=num_classes, kernel_size=(1, 1))
    net = relay.nn.bias_add(net, relay.var("conv10_bias"))
    net = relay.nn.relu(net)
    net = relay.nn.global_avg_pool2d(net)
    net = relay.nn.softmax(net, axis=1)

    free_vars = relay.ir_pass.free_vars(net)
    return relay.Function(free_vars, net)
|
||||
|
||||
|
||||
def get_workload(batch_size=1,
                 image_shape=(3, 224, 224),
                 num_classes=1000,
                 dtype="float32"):
    """Get benchmark workload for SqueezeNet

    Parameters
    ----------
    batch_size : int, optional
        The batch size used in the model

    num_classes : int, optional
        Number of classes

    image_shape : tuple, optional
        The input image shape

    dtype : str, optional
        The data type

    Returns
    -------
    net : relay.Function
        The computational graph

    params : dict of str to NDArray
        The parameters.
    """
    return create_workload(get_net(batch_size, image_shape, num_classes, dtype))
|
|
@ -0,0 +1,87 @@
|
|||
import numpy as np
|
||||
import tvm
|
||||
from tvm.contrib import graph_runtime
|
||||
from tvm.relay.testing.config import ctx_list
|
||||
from tvm import relay
|
||||
from model_zoo import c2_squeezenet, c2_resnet50, c2_vgg19
|
||||
from caffe2.python import workspace
|
||||
|
||||
|
||||
def get_tvm_output(model,
                   input_data,
                   target,
                   ctx,
                   output_shape,
                   output_dtype='float32'):
    """Convert ``model`` with the Caffe2 frontend, build it, run it and return the output(s).

    Returns a list of numpy arrays when both ``output_shape`` and
    ``output_dtype`` are lists, otherwise a single numpy array.
    """
    # Supporting multiple inputs in caffe2 is a bit tricky, because the
    # input names can appear at the beginning or end of
    # model.predict_net.external_input.
    assert isinstance(input_data, np.ndarray)

    # here we use the first input blob to the first op to get the input name
    input_name = model.predict_net.op[0].input[0]
    shape_dict = {input_name: input_data.shape}
    dtype_dict = {input_name: input_data.dtype}
    func, params = relay.frontend.from_caffe2(model.init_net, model.predict_net, shape_dict, dtype_dict)
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(func, target, params=params)

    module = graph_runtime.create(graph, lib, ctx)

    # set inputs
    module.set_input(input_name, tvm.nd.array(input_data.astype(input_data.dtype)))
    module.set_input(**params)

    # execute
    module.run()

    # get outputs
    multi_output = isinstance(output_shape, list) and isinstance(output_dtype, list)
    if not multi_output:
        single = module.get_output(0, tvm.nd.empty((output_shape),
                                                   output_dtype))
        return single.asnumpy()

    tvm_output_list = []
    for index, out_shape in enumerate(output_shape):
        fetched = module.get_output(index, tvm.nd.empty((out_shape), output_dtype[index]))
        tvm_output_list.append(fetched.asnumpy())
    return tvm_output_list
|
||||
|
||||
|
||||
def get_caffe2_output(model, x, dtype='float32'):
    """Run ``model`` in the Caffe2 workspace on ``x`` and return its first external output."""
    predict_net = model.predict_net
    workspace.RunNetOnce(model.init_net)

    # Feed the data to the first input blob of the first op.
    workspace.FeedBlob(predict_net.op[0].input[0], x.astype(dtype))
    workspace.RunNetOnce(predict_net)

    return workspace.FetchBlob(predict_net.external_output[0])
|
||||
|
||||
|
||||
def verify_caffe2_forward_impl(model, data_shape, out_shape):
    """Check that TVM and Caffe2 agree on ``model`` for random float32 input."""
    dtype = 'float32'
    random_input = np.random.uniform(size=data_shape).astype(dtype)
    expected = get_caffe2_output(model, random_input, dtype)
    for target, ctx in ctx_list():
        actual = get_tvm_output(model, random_input, target, ctx, out_shape, dtype)
        tvm.testing.assert_allclose(expected, actual, rtol=1e-5, atol=1e-5)
|
||||
|
||||
|
||||
def test_forward_squeezenet1_1():
    """Forward check of c2_squeezenet against Caffe2."""
    input_shape = (1, 3, 224, 224)
    output_shape = (1, 1000, 1, 1)
    verify_caffe2_forward_impl(c2_squeezenet, input_shape, output_shape)
|
||||
|
||||
|
||||
def test_forward_resnet50():
    """Forward check of c2_resnet50 against Caffe2."""
    input_shape = (1, 3, 224, 224)
    output_shape = (1, 1000)
    verify_caffe2_forward_impl(c2_resnet50, input_shape, output_shape)
|
||||
|
||||
|
||||
def test_forward_vgg19():
    """Forward check of c2_vgg19 against Caffe2."""
    input_shape = (1, 3, 224, 224)
    output_shape = (1, 1000)
    verify_caffe2_forward_impl(c2_vgg19, input_shape, output_shape)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_forward_squeezenet1_1()
|
||||
test_forward_resnet50()
|
||||
test_forward_vgg19()
|
|
@ -0,0 +1,21 @@
|
|||
"""Test graph equality of caffe2 models."""
|
||||
from tvm import relay
|
||||
from model_zoo import c2_squeezenet, relay_squeezenet
|
||||
|
||||
|
||||
def compare_graph(f1, f2):
    """Assert that ``f1`` and ``f2`` are alpha-equal after type inference."""
    typed_f1 = relay.ir_pass.infer_type(f1)
    typed_f2 = relay.ir_pass.infer_type(f2)
    assert relay.ir_pass.alpha_equal(typed_f1, typed_f2)
|
||||
|
||||
|
||||
def test_squeeze_net():
    """Compare the converted Caffe2 SqueezeNet with the hand-written Relay one."""
    shape_dict = {'data': (1, 3, 224, 224)}
    dtype_dict = {'data': 'float32'}
    converted = relay.frontend.from_caffe2(
        c2_squeezenet.init_net, c2_squeezenet.predict_net, shape_dict, dtype_dict)
    from_c2_func, _ = converted
    relay_func, _ = relay_squeezenet()
    compare_graph(from_c2_func, relay_func)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_squeeze_net()
|
|
@ -47,3 +47,7 @@ python3 -m nose -v tests/python/frontend/nnvm_to_relay || exit -1
|
|||
|
||||
echo "Running relay TFLite frontend test..."
|
||||
python3 -m nose -v tests/python/frontend/tflite || exit -1
|
||||
|
||||
echo "Running relay caffe2 frondend test..."
|
||||
python3 -m nose -v tests/python/frontend/caffe2 || exit -1
|
||||
|
||||
|
|
|
@ -0,0 +1,130 @@
|
|||
"""
|
||||
Compile Caffe2 Models
|
||||
=====================
|
||||
**Author**: `Hiroyuki Makino <https://makihiro.github.io/>`_
|
||||
|
||||
This article is an introductory tutorial to deploy Caffe2 models with Relay.
|
||||
|
||||
To begin, Caffe2 must be installed.
|
||||
|
||||
A quick solution is to install via conda
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# for cpu
|
||||
conda install pytorch-nightly-cpu -c pytorch
|
||||
# for gpu with CUDA 8
|
||||
conda install pytorch-nightly cuda80 -c pytorch
|
||||
|
||||
or refer to the official site:
|
||||
https://caffe2.ai/docs/getting-started.html
|
||||
"""
|
||||
######################################################################
|
||||
# Utils for downloading files
|
||||
# ----------------------------
|
||||
def download(url, path, overwrite=False):
    """Download ``url`` to ``path`` unless the file is already there.

    Parameters
    ----------
    url : str
        Address to fetch.
    path : str
        Destination file name.
    overwrite : bool, optional
        When False (the default) an existing file at ``path`` is kept and
        nothing is downloaded.
    """
    import os
    if os.path.isfile(path) and not overwrite:
        print('File {} exists, skip.'.format(path))
        return
    print('Downloading from url {} to {}'.format(url, path))
    # Only the import differs between Python versions, so only the import is
    # guarded. A failing download now raises its own error instead of being
    # swallowed and retried through the Python 2 API.
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve
    urlretrieve(url, path)
|
||||
|
||||
######################################################################
|
||||
# Load pretrained Caffe2 model
|
||||
# ----------------------------
|
||||
# We load a pretrained resnet50 classification model provided by Caffe2.
|
||||
from caffe2.python.models.download import ModelDownloader
|
||||
mf = ModelDownloader()
|
||||
|
||||
class Model:
    """Holds the three values returned by ``mf.get_c2_model`` for one model."""

    def __init__(self, model_name):
        nets = mf.get_c2_model(model_name)
        self.init_net, self.predict_net, self.value_info = nets
|
||||
|
||||
resnet50 = Model('resnet50')
|
||||
|
||||
######################################################################
|
||||
# Load a test image
|
||||
# ------------------
|
||||
# A single cat dominates the examples!
|
||||
from PIL import Image
|
||||
from matplotlib import pyplot as plt
|
||||
import numpy as np
|
||||
img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
|
||||
download(img_url, 'cat.png')
|
||||
img = Image.open('cat.png').resize((224, 224))
|
||||
plt.imshow(img)
|
||||
plt.show()
|
||||
# input preprocess
|
||||
def transform_image(image):
    """Normalize an image and return it with a leading axis as float32.

    A per-channel mean is subtracted along the last axis and the result is
    divided by a per-channel scale. The axes are then reordered with
    ``transpose((2, 0, 1))`` and a new leading axis is added.
    """
    channel_mean = np.array([123., 117., 104.])
    channel_std = np.array([58.395, 57.12, 57.375])
    normalized = (np.array(image) - channel_mean) / channel_std
    channels_first = normalized.transpose((2, 0, 1))
    return channels_first[np.newaxis, :].astype('float32')
|
||||
|
||||
data = transform_image(img)
|
||||
|
||||
######################################################################
|
||||
# Compile the model on Relay
|
||||
# --------------------------
|
||||
|
||||
# Caffe2 input tensor name, shape and type
|
||||
input_name = resnet50.predict_net.op[0].input[0]
|
||||
shape_dict = {input_name: data.shape}
|
||||
dtype_dict = {input_name: data.dtype}
|
||||
|
||||
# parse Caffe2 model and convert into Relay computation graph
|
||||
from tvm import relay
|
||||
func, params = relay.frontend.from_caffe2(resnet50.init_net, resnet50.predict_net, shape_dict, dtype_dict)
|
||||
|
||||
# compile the model
|
||||
# target x86 cpu
|
||||
target = 'llvm'
|
||||
with relay.build_config(opt_level=3):
|
||||
graph, lib, params = relay.build(func, target, params=params)
|
||||
|
||||
######################################################################
|
||||
# Execute on TVM
|
||||
# ---------------
|
||||
# The process is no different from other examples.
|
||||
import tvm
|
||||
from tvm.contrib import graph_runtime
|
||||
# context x86 cpu, use tvm.gpu(0) if you run on GPU
|
||||
ctx = tvm.cpu(0)
|
||||
# create a runtime executor module
|
||||
m = graph_runtime.create(graph, lib, ctx)
|
||||
# set inputs
|
||||
m.set_input(input_name, tvm.nd.array(data.astype('float32')))
|
||||
# set related params
|
||||
m.set_input(**params)
|
||||
# execute
|
||||
m.run()
|
||||
# get outputs
|
||||
tvm_out = m.get_output(0)
|
||||
top1_tvm = np.argmax(tvm_out.asnumpy()[0])
|
||||
|
||||
#####################################################################
|
||||
# Look up synset name
|
||||
# -------------------
|
||||
# Look up prediction top 1 index in 1000 class synset.
|
||||
import ast
from caffe2.python import workspace
synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
                      '4d0b62f3d01426887599d4f7ede23ee5/raw/',
                      '596b27d23537e5a1b5751d2b0481ef172f58b539/',
                      'imagenet1000_clsid_to_human.txt'])
synset_name = 'synset.txt'
download(synset_url, synset_name)
with open(synset_name) as f:
    # The downloaded text is parsed as a Python literal. ast.literal_eval
    # accepts literals only, so remote content is never executed as code
    # the way eval() would execute it.
    synset = ast.literal_eval(f.read())
print('Relay top-1 id: {}, class name: {}'.format(top1_tvm, synset[top1_tvm]))
# confirm correctness with caffe2 output
p = workspace.Predictor(resnet50.init_net, resnet50.predict_net)
caffe2_out = p.run({input_name: data})
top1_caffe2 = np.argmax(caffe2_out)
print('Caffe2 top-1 id: {}, class name: {}'.format(top1_caffe2, synset[top1_caffe2]))
|
Загрузка…
Ссылка в новой задаче