[NNVM][FRONTEND] Tensorflow frontend support (#1188)

This commit is contained in:
Siva 2018-06-13 11:07:35 +05:30 коммит произвёл Tianqi Chen
Родитель 7afeab0795
Коммит a81ebd9040
8 изменённых файлов: 1507 добавлений и 0 удалений

Просмотреть файл

@ -270,6 +270,10 @@ def build(graph, target=None, shape=None, dtype="float32",
# Apply optimization
with target:
graph = optimize(graph, shape, dtype, layout)
# Clear extra params without nodes.
_remove_noref_params(params, graph)
# Precompute prune
if params and cfg.pass_enabled("PrecomputePrune"):
graph, params = precompute_prune(graph, params)
@ -296,6 +300,24 @@ def build(graph, target=None, shape=None, dtype="float32",
params.update(init_var)
return graph, libmod, params
def _remove_noref_params(params, graph):
""" Helper to clear non referenced params
Parameters
----------
graph : Graph
The input graph
params: dict of str to ndarray
The parameter dictionary
"""
arg_list = set(graph.symbol.list_input_names())
if params:
param_keys = list(params.keys())
for key in param_keys:
if key not in arg_list:
params.pop(key)
def _run_graph(graph, params):
"""Helper utility to build and run and get outputs, only use cpu mode.

Просмотреть файл

@ -5,3 +5,4 @@ from .onnx import from_onnx
from .coreml import from_coreml
from .keras import from_keras
from .darknet import from_darknet
from .tensorflow import from_tensorflow

Просмотреть файл

@ -0,0 +1,651 @@
# pylint: disable=import-self, invalid-name, unused-argument
"""TF: Tensorflow frontend."""
from __future__ import absolute_import as _abs
from __future__ import print_function
# Numpy support
import numpy as np
import tvm
from .. import symbol as _sym
from .. import graph as _graph
from .. compiler import graph_util
from .common import get_nnvm_op, AttrConverter as AttrConvert
__all__ = ['from_tensorflow']
class AttrCvt(object):
"""A Wrapper to handle some common jobs:
"""
def __init__(self, op_name, transforms=None,
excludes=None, disables=None, ignores=None,
extras=None, custom_check=None):
self._op_name = op_name
self._transforms = transforms if transforms else {}
self._excludes = excludes if excludes else []
self._disables = disables if disables else []
self._ignores = ignores if ignores else []
self._extras = extras if extras else {}
self._custom_check = custom_check
def __call__(self, inputs, attrs, *args):
self._ignores.append('_output_shapes')
self._ignores.append('_input_shapes')
self._ignores.append('T')
self._ignores.append('use_cudnn_on_gpu')
return AttrConvert(self._op_name, self._transforms, self._excludes,
self._disables, self._ignores, self._extras,
self._custom_check)(inputs, attrs, *args)
def _get_pad_pair(input1d, kernel1d, stride1d):
if input1d % stride1d == 0:
pad = max(kernel1d - stride1d, 0)
else:
pad = max(kernel1d - (input1d % stride1d), 0)
pad_before = pad // 2
pad_after = pad - pad_before
return [pad_before, pad_after]
def _math_name_picker(surfix):
def _impl(attr):
return 'broadcast_' + surfix
return _impl
def _dimension_picker(prefix, surfix=''):
def _impl(attr):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
else:
raise NotImplementedError("Only 2d kernel supported.")
return _impl
def _dimension_constraint():
def _dim_check(attrs):
if len(attrs['kernel_shape']) == 2:
return True
return False
return _dim_check, "Only 2d kernel supported."
def _infer_channels(inputs, params, transpose=False):
"""A hack for getting 'channles' or 'units' since tensorflow don't provide
these attributes. We check the shape of weights provided to get the number.
"""
g = _graph.create(inputs)
shape_dict = {k: v.shape for k, v in params.items()}
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
return channels
def _rsqrt():
def _impl(inputs, attr, *args):
return AttrCvt(op_name="__pow_scalar__", extras={'scalar': -0.5})(inputs, attr)
return _impl
def _elemwise(name):
def _impl(inputs, attr, *args):
assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs))
op_name = _math_name_picker(name)(attr)
axis = int(attr.get('axis', 0))
conv_ops = ["conv2d", "conv2d_transpose"]
if op_name == 'broadcast_add' and inputs[0].attr('op_name') in conv_ops:
# TODO: remove hard coded infershape
inputs[1] = _sym.expand_dims(inputs[1], axis=axis, num_newaxis=2)
return get_nnvm_op(op_name)(*inputs)
return _impl
def _pooling(name):
def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
if attr['data_format'] == 'NHWC':
attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
elif attr['data_format'] == 'NCHW':
attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3])
else:
raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))
# Fix strides
attr['strides'] = (attr['strides'][1], attr['strides'][2])
# Fix padding
input_shapes = attr['_input_shapes'][inputs[0]]
attr['padding'] = attr['padding'].decode("utf-8")
if attr['padding'] == 'VALID':
attr['padding'] = [0, 0]
elif attr['padding'] == 'SAME':
stride_h, stride_w = attr['strides']
kernel_h, kernel_w = attr['kernel_shape']
if attr['data_format'] == 'NHWC':
in_h = input_shapes[0][1]
in_w = input_shapes[0][2]
else:
in_h = input_shapes[0][2]
in_w = input_shapes[0][3]
pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
if attr['data_format'] == 'NHWC':
inputs[0] = _sym.pad(data=inputs[0],
pad_width=((0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]),
(0, 0)))
else:
inputs[0] = _sym.pad(data=inputs[0],
pad_width=((0, 0),
(0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1])))
attr['padding'] = [0, 0]
else:
raise TypeError("Unsupported padding type : {}".format(attr['padding']))
return AttrCvt(
op_name=_dimension_picker(name),
transforms={
'kernel_shape':'pool_size',
'data_format':'layout'},
ignores=['ksize'],
extras={'ceil_mode': False},
custom_check=_dimension_constraint())(inputs, attr)
return _impl
def _conv():
def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
# Extract kernel shape from params
conv_param_weights = params[inputs[1].list_output_names()[0]]
if attr['data_format'] == 'NHWC':
attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
attr['channels'] = conv_param_weights.shape[3]
if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
elif attr['data_format'] == 'NCHW':
attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
attr['channels'] = conv_param_weights.shape[1]
if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
else:
raise TypeError("Unsupported data format type : {}".format(attr['data_format']))
# Fix strides
attr['strides'] = (attr['strides'][1], attr['strides'][2])
# Fix padding
input_shapes = attr['_input_shapes'][inputs[0]]
attr['padding'] = attr['padding'].decode("utf-8")
if attr['padding'] == 'VALID':
attr['padding'] = [0, 0]
elif attr['padding'] == 'SAME':
stride_h, stride_w = attr['strides']
kernel_h, kernel_w = attr['kernel_shape']
if attr['data_format'] == 'NHWC':
in_h = input_shapes[0][1]
in_w = input_shapes[0][2]
else:
in_h = input_shapes[0][2]
in_w = input_shapes[0][3]
pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
if attr['data_format'] == 'NHWC':
inputs[0] = _sym.pad(data=inputs[0],
pad_width=((0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]),
(0, 0)))
else:
inputs[0] = _sym.pad(data=inputs[0],
pad_width=((0, 0),
(0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1])))
attr['padding'] = [0, 0]
else:
raise TypeError("Unsupported padding type : {}".format(attr['padding']))
if 'kernel_layout' not in attr:
attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
return AttrCvt(
op_name=_dimension_picker('conv'),
transforms={
'kernel_shape': 'kernel_size',
'data_format': 'layout',
'dilations': ('dilation', (0, 0)),
'group': ('groups', 1)},
extras={'use_bias': len(inputs) == 3},
custom_check=_dimension_constraint())(inputs, attr)
return _impl
def _decode_image():
def _impl(inputs, attr, params):
# Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
print("DecodeJpeg: It's a pass through, please handle preprocessing before input")
return inputs[0]
return _impl
def _cast():
def _impl(inputs, attr, params):
# Convert from tensorflow Dtype to str
attr['DstT'] = attr['DstT'].name
return AttrCvt(op_name='cast', transforms={'DstT': 'dtype'}, ignores=['SrcT'])(inputs, attr)
return _impl
def _expand_dims():
def _impl(inputs, attr, params):
dim_input = inputs.pop(1)
axis = params[dim_input.list_output_names()[0]]
params.pop(dim_input.list_output_names()[0])
return AttrCvt(op_name="expand_dims", ignores=['Tdim'],
extras={'axis': axis.asnumpy()[0]})(inputs, attr)
return _impl
def _resize_bilinear():
def _impl(inputs, attr, params):
# Change this when we have corresponding resize bilinear operation.
print("ResizeBilinear:Only NN (nearest neighbor) supported in symetric mode of dimensions")
print("Change this when we have corresponding resize bilinear operation")
# NHWC
input_shape = attr['_input_shapes'][inputs[0]][0]
in_hw = (input_shape[1], input_shape[2])
out_hw = params[inputs[1].list_output_names()[0]]
inputs.pop(1)
attr['layout'] = 'NHWC'
if in_hw[0] < 0 or in_hw[1] < 0:
scale = 1
else:
# Considering height alone for scale
scale = out_hw[0] / in_hw[0]
return AttrCvt(op_name="upsampling",
ignores=['Tdim', 'align_corners'],
extras={'scale': scale})(inputs, attr)
return _impl
def _check_numerics():
def _impl(inputs, attr, params):
# Making a copy node assuming no need to verify
return AttrCvt(op_name="copy", ignores=['message'])(inputs, attr)
return _impl
def _matmul():
def _impl(inputs, attr, params):
channels = _infer_channels(inputs[1], params, not attr['transpose_b'])
if attr['transpose_a']:
inputs[0] = _sym.transpose(inputs[0], axis(1, 0))
if not attr['transpose_b']:
inputs[1] = _sym.transpose(inputs[1], axes=(1, 0))
return AttrCvt(op_name="dense",
extras={'use_bias': False, 'units': channels},
ignores=['transpose_a', 'transpose_b', 'T'])(inputs, attr)
return _impl
def _identity():
def _impl(inputs, attr, params):
return inputs[0]
return _impl
def _concatV2():
def _impl(inputs, attr, params):
pop_node = inputs.pop(len(inputs)-1)
axis = params[pop_node.list_output_names()[0]]
params.pop(pop_node.list_output_names()[0])
return AttrCvt(
op_name="concatenate", ignores=['T', 'N', 'Tidx'],
extras={'axis': axis.asnumpy()[0]})(inputs, attr)
return _impl
def _concat():
def _impl(inputs, attr, params):
pop_node = inputs.pop(0)
axis = params[pop_node.list_output_names()[0]]
params.pop(pop_node.list_output_names()[0])
return AttrCvt(
op_name="concatenate", ignores=['N'],
extras={'axis': axis.asnumpy()[0]})(inputs, attr)
return _impl
def _reshape():
def _impl(inputs, attr, params):
pop_node = inputs.pop(1)
shape_arg = params[pop_node.list_output_names()[0]]
params.pop(pop_node.list_output_names()[0])
return AttrCvt(
op_name="reshape",
extras={'shape':tuple(shape_arg.asnumpy())},
ignores=['Tshape'])(inputs, attr)
return _impl
def _bias_add():
def _impl(inputs, attr, params):
return _sym.broadcast_add(inputs[0], inputs[1])
return _impl
def _squeeze():
def _impl(inputs, attr, params):
return AttrCvt(
op_name="squeeze",
transforms={'squeeze_dims':'axis'},
ignores=['T'])(inputs, attr)
return _impl
def _batch_norm():
def _impl(inputs, attr, params):
# Rearrange inputs from
# (data, moving_mean, moving_variance, beta, gamma)
# to
# (data, gamma, beta, moving_mean, moving_var)
new_inputs = [inputs[0], inputs[4], inputs[3], inputs[1], inputs[2]]
return AttrCvt(
op_name='batch_norm',
transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'},
extras={'axis': 3}, # Fix axis
disables=['momentum'])(new_inputs, attr)
return _impl
# compatible operators that do NOT require any conversion.
_identity_list = []
# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
_convert_map = {
'AvgPool' : _pooling('avg_pool'),
'BatchNormWithGlobalNormalization' : _batch_norm(),
'BiasAdd' : _bias_add(),
'Cast' : _cast(),
'CheckNumerics' : _check_numerics(),
'Concat' : _concat(),
'ConcatV2' : _concatV2(),
'Conv2D' : _conv(),
'DecodeJpeg' : _decode_image(),
'ExpandDims' : _expand_dims(),
'Identity' : _identity(),
'MatMul' : _matmul(),
'MaxPool' : _pooling('max_pool'),
'Mul' : _elemwise('mul'),
'Relu' : AttrCvt('relu'),
'Reshape' : _reshape(),
'ResizeBilinear' : _resize_bilinear(),
'Softmax' : AttrCvt('softmax', {'axis': ('axis', 1)}),
'Sub' : _elemwise('sub'),
'Add' : _elemwise('add'),
'Rsqrt' : _rsqrt(),
'Squeeze' : _squeeze(),
}
class GraphProto(object):
""" A helper class for handling nnvm graph copying from Tensorflow GraphDef.
Definition:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto
"""
def __init__(self):
self._nodes = {}
self._params = {}
self._renames = {}
self._replacements = {}
self._output_shapes = {}
self._num_input = 0
self._num_param = 0
self._input_node = ''
def from_tensorflow(self, graph):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef.
Follow the tensorflow graph definition to parse and convert it to NNVM.
Some of the assumptions listed below.
-> First Const or Placeholder node will be considered as graph input.
-> Rest all Const nodes are params.
-> Last node is assumed as graph output.
-> _output_shapes : Attribute should present in the tenserflow forzen graph.
-> DecodeJpeg, ResizeBilinear: These are dummy operators.
Hence user should handle preprocessing outside.
-> CheckNumerics: No implementation as of now for this.
Just copies input to output.
Parameters
----------
graph : tensorflow graph definition object
The loaded tensorflow GraphDef
Returns
-------
sym : nnvm.sym.Symbol
The returned nnvm symbol
params : dict
A dict of name: tvm.nd.array pairs, used as pretrained weights
"""
# Parse throught all nodes and start extracting
# params aka Const nodes
# input nodes : First const node
# normal nodes : other normal nodes
try:
from tensorflow.python.framework import tensor_util
except ImportError as e:
raise ImportError(
"Unable to import tensorflow which is required {}".format(e))
for node in graph.node:
# Tensorflow doesn't have seperate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
if node.op == "Placeholder":
# Assuming only one input graph with type 'Placeholder'
self._input_node = node.name
self._num_input += 1
self._nodes[node.name] = _sym.Variable(name=node.name)
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) \
for shape in self._parse_attr(node.attr)['_output_shapes']]
elif node.op == "Const":
# Assuming first Const node as Graph Input node
if self._input_node == '':
self._input_node = node.name
self._num_input += 1
self._nodes[node.name] = _sym.Variable(name=node.name)
else:
# Rest all nodes are Param nodes, lets parse
self._num_param += 1
for key, value in node.attr.items():
self._parse_param(key, value, node.name)
if node.name not in self._nodes:
raise NotImplementedError( \
"Const {} couldn't be converted to Param.".format(node.name))
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) \
for shape in self._parse_attr(node.attr)['_output_shapes']]
else:
attr = self._parse_attr(node.attr)
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) for shape in attr['_output_shapes']]
# Pass the parsed shapes instead
attr["_output_shapes"] = self._output_shapes[node.name]
try:
inputs = [self._nodes[i] for i in node.input]
input_shapes = {}
for i in node.input:
if i not in self._params:
input_shapes[self._nodes[i]] = self._output_shapes[i]
attr['_input_shapes'] = input_shapes
except KeyError:
# TODO: Need to find clean way to handle '^CheckNumerics'
print("Some Exception while inputs list:", node.input, " ignoring...")
inputs = self._fix_extranodes(node.op, attr, inputs)
op = self._convert_operator(node.op, inputs, attr)
# Assuming only one output.
self._nodes[node.name] = op
node_output = op
# Assume the final node is the output node
out = node_output
return out, self._params
def _parse_param(self, key, value, name):
try:
from tensorflow.python.framework import tensor_util
except ImportError as e:
raise ImportError(
"Unable to import tensorflow which is required {}".format(e))
if key == 'value':
np_array = tensor_util.MakeNdarray(value.tensor)
array_ndim = len(np_array.shape)
if array_ndim == 0:
new_array = np.empty([1], dtype=np_array.dtype)
new_array[0] = np_array
self._params[name] = tvm.nd.array(new_array)
else:
self._params[name] = tvm.nd.array(np_array)
self._nodes[name] = _sym.Variable(name=name,
shape=self._params[name].shape)
else:
if key != 'dtype' and key != '_output_shapes':
raise NotImplementedError \
("Other attributes for a Const(param) Node {} ? .".format(key))
def _get_attr(self, buf):
"""Returns the value of the attr of this buf with the given `name`.
Args:
buf: attrvalue protobuf.
Returns:
The value of the attr, as a Python object.
Raises:
ValueError: If this op does not have an attr with the given `name`.
"""
fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
x = buf
ret = []
try:
from tensorflow.python.framework import dtypes
except ImportError as e:
raise ImportError(
"Unable to import tensorflow which is required {}".format(e))
# Treat an empty oneof value as an empty list.
if not x.WhichOneof("value"):
return ret
if x.HasField("list"):
for f in fields:
if getattr(x.list, f):
if f == "type":
ret = [dtypes.as_dtype(x) for x in list(getattr(x.list, f))]
else:
ret = list(getattr(x.list, f))
else:
for f in fields:
if x.HasField(f):
if f == "type":
ret = dtypes.as_dtype(getattr(x, f))
else:
ret = getattr(x, f)
return ret
def _parse_attr(self, attr_proto):
"""Convert a list of AttributeProto to a dict, with names as keys."""
attrs = {}
for key, value in attr_proto.items():
attrs[key] = self._get_attr(value)
return attrs
def _convert_operator(self, op_name, inputs, attrs, identity_list=None, convert_map=None):
"""Convert from Tensorflow operator to nnvm operator.
The converter must specify conversions explicity for incompatible name, and
apply handlers to operator attributes.
Parameters
----------
op_name : str
Operator name, such as Conv2D, AvgPool
inputs : list of nnvm.Symbol
List of input symbols.
attrs : dict
Dict of operator attributes
identity_list : list
List of operators that don't require conversion
convert_map : dict
Dict of name : callable, where name is the op's name that
require conversion to nnvm, callable are functions which
take attrs and return (new_op_name, new_attrs)
Returns
-------
sym : nnvm.Symbol
Converted nnvm Symbol
"""
identity_list = identity_list if identity_list else _identity_list
convert_map = convert_map if convert_map else _convert_map
if op_name in identity_list:
sym = get_nnvm_op(op_name)(*inputs, **attrs)
elif op_name in convert_map:
sym = convert_map[op_name](inputs, attrs, self._params)
else:
raise NotImplementedError("Operator {} not implemented.".format(op_name))
return sym
def _fix_extranodes(self, op_name, attr, inputs):
if op_name == "Softmax":
# Require some times flatten of data before it goes to softmax
# Need to relook into this with latest softmax axis support.
op = AttrCvt(op_name='flatten')(inputs, {})
node_output = op.list_output_names()
for k, i in zip(list(node_output), range(len(node_output))):
self._nodes[k] = op[i]
inputs = [op]
return inputs
def from_tensorflow(graph):
""" Load tensorflow graph which is a python tensorflow graph object into nnvm graph.
The companion parameters will be handled automatically.
Parameters
----------
graph : GraphDef object
Tensorflow GraphDef
Returns
-------
sym : nnvm.Symbol
Compatible nnvm symbol
params : dict of str to tvm.ndarray
Dict of converted parameters stored in tvm.ndarray format
"""
g = GraphProto()
sym, params = g.from_tensorflow(graph)
return sym, params

Просмотреть файл

@ -0,0 +1,229 @@
# pylint: disable=invalid-name, unused-variable, unused-argument, no-init
"""
Tensorflow Model Helpers
========================
Some helper definitions for tensorflow models.
"""
import re
import os.path
import numpy as np
# Tensorflow imports
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
######################################################################
# Some helper functions
# ---------------------
def ProcessGraphDefParam(graph_def):
"""Type-checks and possibly canonicalizes `graph_def`.
Parameters
----------
graph_def : Obj
tensorflow graph definition.
Returns
-------
graph_def : Obj
tensorflow graph devinition
"""
if not isinstance(graph_def, graph_pb2.GraphDef):
# `graph_def` could be a dynamically-created message, so try a duck-typed
# approach
try:
old_graph_def = graph_def
graph_def = graph_pb2.GraphDef()
graph_def.MergeFrom(old_graph_def)
except TypeError:
raise TypeError('graph_def must be a GraphDef proto.')
return graph_def
class NodeLookup(object):
"""Converts integer node ID's to human readable labels."""
def __init__(self,
label_lookup_path=None,
uid_lookup_path=None):
self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
def load(self, label_lookup_path, uid_lookup_path):
"""Loads a human readable English name for each softmax node.
Parameters
----------
label_lookup_path: String
File containing String UID to integer node ID mapping .
uid_lookup_path: String
File containing String UID to human-readable string mapping.
Returns
-------
node_id_to_name: dict
dict from integer node ID to human-readable string.
"""
if not tf.gfile.Exists(uid_lookup_path):
tf.logging.fatal('File does not exist %s', uid_lookup_path)
if not tf.gfile.Exists(label_lookup_path):
tf.logging.fatal('File does not exist %s', label_lookup_path)
# Loads mapping from string UID to human-readable string
proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
uid_to_human = {}
p = re.compile(r'[n\d]*[ \S,]*')
for line in proto_as_ascii_lines:
parsed_items = p.findall(line)
uid = parsed_items[0]
human_string = parsed_items[2]
uid_to_human[uid] = human_string
# Loads mapping from string UID to integer node ID.
node_id_to_uid = {}
proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
for line in proto_as_ascii:
if line.startswith(' target_class:'):
target_class = int(line.split(': ')[1])
if line.startswith(' target_class_string:'):
target_class_string = line.split(': ')[1]
node_id_to_uid[target_class] = target_class_string[1:-2]
# Loads the final mapping of integer node ID to human-readable string
node_id_to_name = {}
for key, val in node_id_to_uid.items():
if val not in uid_to_human:
tf.logging.fatal('Failed to locate: %s', val)
name = uid_to_human[val]
node_id_to_name[key] = name
return node_id_to_name
def id_to_string(self, node_id):
if node_id not in self.node_lookup:
return ''
return self.node_lookup[node_id]
def read_normalized_tensor_from_image_file(file_name,
input_height=299,
input_width=299,
input_mean=0,
input_std=255):
""" Preprocessing of image
Parameters
----------
file_name: String
Image filename.
input_height: int
model input height.
input_width: int
model input width
input_mean: int
Mean to be substracted in normalization.
input_std: int
Standard deviation used in normalization.
Returns
-------
np_array: Numpy array
Normalized image data as a numpy array.
"""
input_name = "file_reader"
output_name = "normalized"
file_reader = tf.read_file(file_name, input_name)
image_reader = tf.image.decode_jpeg(file_reader, channels=3,
name='jpeg_reader')
float_caster = tf.cast(image_reader, tf.float32)
dims_expander = tf.expand_dims(float_caster, 0)
resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
tf.InteractiveSession()
np_array = normalized.eval()
return np_array
def get_workload_inception_v3():
""" Import Inception V3 workload from frozen protobuf
Parameters
----------
Nothing.
Returns
-------
(normalized, graph_def) : Tuple
normalized is normalized input for graph testing.
graph_def is the tensorflow workload for Inception V3.
"""
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV3/'
model_name = 'inception_v3_2016_08_28_frozen-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
image_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, image_name)
from mxnet.gluon.utils import download
download(model_url, model_name)
download(image_url, image_name)
normalized = read_normalized_tensor_from_image_file(os.path.join("./", image_name))
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
return (normalized, graph_def)
def get_workload_inception_v1():
""" Import Inception V1 workload from frozen protobuf
Parameters
----------
Nothing.
Returns
-------
(image_data, tvm_data, graph_def) : Tuple
image_data is raw encoded image data for TF input.
tvm_data is the decoded image data for TVM input.
graph_def is the tensorflow workload for Inception V1.
"""
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
model_name = 'classify_image_graph_def-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
image_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, image_name)
from mxnet.gluon.utils import download
download(model_url, model_name)
download(image_url, image_name)
if not tf.gfile.Exists(os.path.join("./", image_name)):
tf.logging.fatal('File does not exist %s', image)
image_data = tf.gfile.FastGFile(os.path.join("./", image_name), 'rb').read()
# TVM doesn't handle decode, hence decode it.
from PIL import Image
tvm_data = Image.open(os.path.join("./", image_name)).resize((299, 299))
tvm_data = np.array(tvm_data)
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
return (image_data, tvm_data, graph_def)

Просмотреть файл

@ -52,6 +52,15 @@ reg.register_schedule("_assign", _fschedule_broadcast)
reg.register_pattern("copy", OpPattern.ELEMWISE)
reg.register_schedule("copy", _fschedule_broadcast)
# cast
@reg.register_compute("cast")
def compute_cast(attrs, inputs, _):
"""Compute definition of cast"""
dtype = attrs.get_string("dtype")
return topi.cast(inputs[0], dtype)
reg.register_pattern("cast", OpPattern.ELEMWISE)
reg.register_schedule("cast", _fschedule_broadcast)
# exp
reg.register_pattern("exp", OpPattern.ELEMWISE)
reg.register_schedule("exp", _fschedule_broadcast)

Просмотреть файл

@ -0,0 +1,421 @@
# pylint: disable=import-self, invalid-name, unused-argument
"""
Tensorflow testcases
====================
This article is a test script to test tensorflow operator with NNVM.
"""
from __future__ import print_function
import numpy as np
import nnvm.compiler
import tvm
import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.core.framework import graph_pb2
import nnvm.testing.tf
#######################################################################
# Generic run functions for TVM & tensorflow
# ------------------------------------------
def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype):
""" Generic function to compile on nnvm and execute on tvm """
sym, params = nnvm.frontend.from_tensorflow(graph_def)
target = 'llvm'
if isinstance(input_data, list):
shape_dict = {}
dtype_dict = {}
for i, e in enumerate(input_node):
shape_dict[e] = input_data[i].shape
dtype_dict[e] = input_data[i].dtype
else:
shape_dict = {input_node: input_data.shape}
dtype_dict = {input_node: input_data.dtype}
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict,
dtype=dtype_dict, params=params)
ctx = tvm.cpu(0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
if isinstance(input_data, list):
for i, e in enumerate(input_node):
m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
else:
m.set_input(input_node, tvm.nd.array(input_data.astype(input_data.dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
tvm_output = m.get_output(0, tvm.nd.empty((output_shape), output_dtype))
return tvm_output.asnumpy()
def run_tf_graph(sess, input_data, input_node, output_node):
""" Generic function to execute tensorflow """
tensor = sess.graph.get_tensor_by_name(output_node)
if isinstance(input_data, list):
input_dict = {}
for i, e in enumerate(input_node):
input_dict[e] = input_data[i]
else:
input_dict = {input_node: input_data}
output_data = sess.run(tensor, input_dict)
return output_data
#######################################################################
# Pooling
# -------
def _test_pooling(input_shape, **kwargs):
""" One iteration of pool operation with given shapes and attributes """
x = -np.arange(
np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
with tf.Graph().as_default():
in_data = constant_op.constant(x, shape=input_shape, dtype='float32')
# pylint: disable=unused-variable
pool = nn_ops.pool(in_data, **kwargs)
# pylint: enable=unused-variable
if kwargs['pooling_type'] == 'MAX':
out_node = 'max_pool'
out_name = 'max_pool:0'
else:
out_node = 'avg_pool'
out_name = 'avg_pool:0'
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
[out_node],
)
tf_output = run_tf_graph(sess, x, 'Const:0', out_name)
tvm_output = run_tvm_graph(graph_def, x.astype('float32'),
"Const", tf_output.shape, 'float32')
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3)
sess.close()
def test_forward_pooling():
""" Pooling """
_test_pooling(input_shape=[2, 9, 10, 2],
window_shape=[1, 1],
padding='SAME',
pooling_type='MAX',
dilation_rate=[1, 1],
strides=[1, 1])
_test_pooling(input_shape=[2, 9, 10, 2],
window_shape=[1, 1],
padding='SAME',
pooling_type='AVG',
dilation_rate=[1, 1],
strides=[1, 1])
_test_pooling(input_shape=[2, 10, 9, 2],
window_shape=[1, 1],
padding='SAME',
pooling_type='MAX',
dilation_rate=[1, 1],
strides=[1, 1])
_test_pooling(input_shape=[2, 10, 9, 2],
window_shape=[1, 1],
padding='SAME',
pooling_type='AVG',
dilation_rate=[1, 1],
strides=[1, 1])
#######################################################################
# Convolution
# -----------
def _test_convolution(tensor_in_sizes, filter_in_sizes,
dilations, strides, padding, data_format):
""" One iteration of convolution with given shapes and attributes """
total_size_1 = 1
total_size_2 = 1
for s in tensor_in_sizes:
total_size_1 *= s
for s in filter_in_sizes:
total_size_2 *= s
# Initializes the input tensor with array containing incrementing
# numbers from 1.
data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
with tf.Graph().as_default():
in_data = constant_op.constant(data_array, shape=tensor_in_sizes, dtype='float32')
in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
strides = [1] + strides + [1]
dilations = [1] + dilations + [1]
# pylint: disable=unused-variable
conv = nn_ops.conv2d(in_data,
in_filter,
strides=strides,
padding=padding,
data_format=data_format)
# pylint: enable=unused-variable
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['Conv2D'],
)
tf_output = run_tf_graph(sess, np.reshape(data_array, tensor_in_sizes),
'Const:0', 'Conv2D:0')
tvm_output = run_tvm_graph(graph_def,
np.reshape(data_array, tensor_in_sizes).astype('float32'),
"Const", tf_output.shape, 'float32')
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3)
sess.close()
def test_forward_convolution():
_test_convolution([4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution([4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
_test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC')
#######################################################################
# Reshape
# -------
def _test_reshape(data, out_shape):
""" One iteration of reshape operation with given data and out shape """
with tf.Graph().as_default():
in_data = constant_op.constant(data, shape=data.shape, dtype=data.dtype)
# pylint: disable=unused-variable
reshape_out = array_ops.reshape(in_data, out_shape)
# pylint: enable=unused-variable
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['Reshape'],
)
tf_output = run_tf_graph(sess, data,
'Const:0', 'Reshape:0')
tvm_output = run_tvm_graph(graph_def,
data,
"Const", tf_output.shape, data.dtype)
np.testing.assert_allclose(tf_output, tvm_output)
sess.close()
def test_forward_reshape():
_test_reshape(np.arange(6.0), [2, 3])
_test_reshape(np.arange(6), [-1, 2])
_test_reshape(np.arange(6), [3, -1])
_test_reshape(np.arange(6), [-1])
#######################################################################
# Squeeze
# -------
def _test_squeeze(data, squeeze_dims=None):
""" One iteration of squeeze """
if squeeze_dims is None:
squeeze_dims = []
with tf.Graph().as_default():
in_data = constant_op.constant(data, shape=data.shape, dtype=data.dtype)
# pylint: disable=unused-variable
if squeeze_dims:
squeeze_out = array_ops.squeeze(in_data, squeeze_dims)
else:
squeeze_out = array_ops.squeeze(in_data)
# pylint: enable=unused-variable
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['Squeeze'],
)
tf_output = run_tf_graph(sess, data,
'Const:0', 'Squeeze:0')
tvm_output = run_tvm_graph(graph_def,
data,
"Const", tf_output.shape, data.dtype)
np.testing.assert_allclose(tf_output, tvm_output)
sess.close()
def test_forward_squeeze():
""" Squeeze """
# Nothing to squeeze.
_test_squeeze(np.arange(2).reshape((2)))
_test_squeeze(np.arange(6).reshape((2, 3)))
# Squeeze the middle element away.
_test_squeeze(np.arange(4).reshape((2, 1, 2)))
# Squeeze on both ends.
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)))
# Positive squeeze dim index.
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [0])
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [2, 4])
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [0, 4, 2])
# Negative squeeze dim index.
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-1])
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5])
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1])
#######################################################################
# ConcatV2
# --------
def _test_concat_v2(data, dim):
""" One iteration of ConcatV2 """
with tf.Graph().as_default():
# pylint: disable=unused-variable
concat_out = gen_array_ops._concat_v2(data, dim)
# pylint: enable=unused-variable
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['ConcatV2'],
)
tf_output = run_tf_graph(sess, data,
['ConcatV2/values_0:0', 'ConcatV2/values_1:0'], 'ConcatV2:0')
tvm_output = run_tvm_graph(graph_def,
data,
["ConcatV2/values_0", 'ConcatV2/values_1'],
tf_output.shape, tf_output.dtype)
np.testing.assert_allclose(tf_output, tvm_output)
sess.close()
def _test_forward_concat_v2():
t1 = np.array([])
t2 = np.array([])
test_concat_v2([t1, t2], 0)
t1 = np.array([[1, 2, 3], [4, 5, 6]])
t2 = np.array([[7, 8, 9], [10, 11, 12]])
_test_concat_v2([t1, t2], 1)
#######################################################################
# Multi Input to graph
# --------------------
def test_forward_multi_input():
with tf.Graph().as_default():
in1 = tf.placeholder(tf.int32, shape=[3, 3], name='in1')
in2 = tf.placeholder(tf.int32, shape=[3, 3], name='in2')
in3 = tf.placeholder(tf.int32, shape=[3, 3], name='in3')
in4 = tf.placeholder(tf.int32, shape=[3, 3], name='in4')
out1 = tf.add(in1, in2, name='out1')
out2 = tf.subtract(in3, in4, name='out2')
out = tf.multiply(out1, out2, name='out')
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['out'],
)
in_data = np.arange(9, dtype='int32').reshape([3, 3])
tf_output = run_tf_graph(sess, [in_data, in_data, in_data, in_data ],
['in1:0', 'in2:0', 'in3:0', 'in4:0'], 'out:0')
tvm_output = run_tvm_graph(graph_def,
[in_data, in_data, in_data, in_data ],
['in1', 'in2', 'in3', 'in4'],
tf_output.shape, tf_output.dtype)
np.testing.assert_allclose(tf_output, tvm_output)
sess.close()
#######################################################################
# Inception V3
# ------------
def test_forward_inception_v3():
'''test inception V3 model'''
with tf.Graph().as_default():
(data, graph_def) = nnvm.testing.tf.get_workload_inception_v3()
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
tvm_output = run_tvm_graph(graph_def, data, 'input', (1, 1001), 'float32')
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input:0', 'InceptionV3/Predictions/Reshape_1:0')
top_tvm = np.squeeze(tvm_output).argsort()[-3:][::-1]
top_tf = np.squeeze(tf_output).argsort()[-3:][::-1]
# TVM implementation of SAME padding some times make a slight deviation.
# Hence check for top predictions.
top_tvm = np.sort(top_tvm)
top_tf = np.sort(top_tf)
np.testing.assert_allclose(top_tf, top_tvm)
#######################################################################
# Inception V1
# ------------
def test_forward_inception_v1():
'''test inception V1 model'''
with tf.Graph().as_default():
(data, tvm_data, graph_def) = nnvm.testing.tf.get_workload_inception_v1()
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents', (1, 1008), 'float32')
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0')
np.testing.assert_allclose(tf_output, tvm_output, rtol=2e-2, atol=2e-2)
#######################################################################
# Main
# ----
if __name__ == '__main__':
test_forward_convolution()
test_forward_pooling()
test_forward_reshape()
test_forward_squeeze()
if tf.__version__ == '1.4.1':
_test_forward_concat_v2()
test_forward_multi_input()
test_forward_inception_v3()
test_forward_inception_v1()

Просмотреть файл

@ -18,3 +18,6 @@ python3 -m nose -v nnvm/tests/python/frontend/mxnet || exit -1
echo "Running Keras frontend test..."
python3 -m nose -v nnvm/tests/python/frontend/keras || exit -1
echo "Running Tensorflow frontend test..."
python3 -m nose -v nnvm/tests/python/frontend/tensorflow || exit -1

Просмотреть файл

@ -0,0 +1,171 @@
"""
Compile Tensorflow Models
=========================
This article is an introductory tutorial to deploy tensorflow models with NNVM.
For us to begin with, tensorflow module is required to be installed.
A quick solution is to install tensorlfow from
https://www.tensorflow.org/install/install_sources
"""
import nnvm
import tvm
import numpy as np
import os.path
# Tensorflow imports
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
import nnvm.testing.tf
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
img_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, img_name)
model_name = 'classify_image_graph_def-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
map_proto = 'imagenet_2012_challenge_label_map_proto.pbtxt'
map_proto_url = os.path.join(repo_base, map_proto)
lable_map = 'imagenet_synset_to_human_label_map.txt'
lable_map_url = os.path.join(repo_base, lable_map)
######################################################################
# Download processed tensorflow model
# -----------------------------------
# In this section, we download a pretrained Tensorflow model and classify an image.
from mxnet.gluon.utils import download
download(image_url, img_name)
download(model_url, model_name)
download(map_proto_url, map_proto)
download(lable_map_url, lable_map)
######################################################################
# Creates graph from saved graph_def.pb.
# --------------------------------------
with tf.gfile.FastGFile(os.path.join(
"./", model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
######################################################################
# Decode image
# ------------
from PIL import Image
image = Image.open(img_name).resize((299, 299))
def transform_image(image):
image = np.array(image)
return image
x = transform_image(image)
######################################################################
# Import the graph to NNVM
# ------------------------
sym, params = nnvm.frontend.from_tensorflow(graph_def)
######################################################################
# Now compile the graph through NNVM
import nnvm.compiler
target = 'llvm'
shape_dict = {'DecodeJpeg/contents': x.shape}
dtype_dict = {'DecodeJpeg/contents': 'uint8'}
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype=dtype_dict, params=params)
######################################################################
# Execute the portable graph on TVM
# ---------------------------------
# Now, we would like to reproduce the same forward computation using TVM.
from tvm.contrib import graph_runtime
ctx = tvm.cpu(0)
dtype = 'uint8'
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
tvm_output = m.get_output(0, tvm.nd.empty(((1, 1008)), 'float32'))
######################################################################
# Process the output to human readable
# ------------------------------------
predictions = tvm_output.asnumpy()
predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup.
node_lookup = nnvm.testing.tf.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
uid_lookup_path=os.path.join("./", lable_map))
top_k = predictions.argsort()[-5:][::-1]
for node_id in top_k:
human_string = node_lookup.id_to_string(node_id)
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
######################################################################
# Run the same graph with tensorflow and dump output.
# ---------------------------------------------------
def create_graph():
"""Creates a graph from saved GraphDef file and returns a saver."""
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(model_name, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
def run_inference_on_image(image):
"""Runs inference on an image.
Parameters
----------
image: String
Image file name.
Returns
-------
Nothing
"""
if not tf.gfile.Exists(image):
tf.logging.fatal('File does not exist %s', image)
image_data = tf.gfile.FastGFile(image, 'rb').read()
# Creates graph from saved GraphDef.
create_graph()
with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
predictions = sess.run(softmax_tensor,
{'DecodeJpeg/contents:0': image_data})
predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup.
node_lookup = nnvm.testing.tf.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
uid_lookup_path=os.path.join("./", lable_map))
top_k = predictions.argsort()[-5:][::-1]
print ("===== TENSORFLOW RESULTS =======")
for node_id in top_k:
human_string = node_lookup.id_to_string(node_id)
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
run_inference_on_image (img_name)