[FRONTEND][TENSORFLOW] GPU support for tensorflow models. (#1718)
This commit is contained in:
Родитель
ae5a28dba4
Коммит
fdf795a076
|
@ -35,6 +35,7 @@ class AttrCvt(object):
|
|||
self._ignores.append('use_cudnn_on_gpu')
|
||||
self._ignores.append('_node_name')
|
||||
self._ignores.append('is_training')
|
||||
self._ignores.append('_target_layout')
|
||||
# Retain the names
|
||||
try:
|
||||
attrs['name'] = attrs['_node_name']
|
||||
|
@ -121,6 +122,9 @@ def _pooling(name):
|
|||
def _impl(inputs, attr, params):
|
||||
|
||||
attr['data_format'] = attr['data_format'].decode("utf-8")
|
||||
flip_layout = False
|
||||
|
||||
input_shape = attr['_input_shapes'][inputs[0]][0]
|
||||
|
||||
if attr['data_format'] == 'NHWC':
|
||||
attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
|
||||
|
@ -129,11 +133,17 @@ def _pooling(name):
|
|||
else:
|
||||
raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))
|
||||
|
||||
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
|
||||
tmp_shape = attr['_input_shapes'][inputs[0]][0]
|
||||
input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
|
||||
inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
|
||||
attr['data_format'] = "NCHW"
|
||||
flip_layout = True
|
||||
|
||||
# Fix strides
|
||||
attr['strides'] = (attr['strides'][1], attr['strides'][2])
|
||||
|
||||
# Fix padding
|
||||
input_shapes = attr['_input_shapes'][inputs[0]]
|
||||
attr['padding'] = attr['padding'].decode("utf-8")
|
||||
|
||||
if attr['padding'] == 'VALID':
|
||||
|
@ -142,11 +152,11 @@ def _pooling(name):
|
|||
stride_h, stride_w = attr['strides']
|
||||
kernel_h, kernel_w = attr['kernel_shape']
|
||||
if attr['data_format'] == 'NHWC':
|
||||
in_h = input_shapes[0][1]
|
||||
in_w = input_shapes[0][2]
|
||||
in_h = input_shape[1]
|
||||
in_w = input_shape[2]
|
||||
else:
|
||||
in_h = input_shapes[0][2]
|
||||
in_w = input_shapes[0][3]
|
||||
in_h = input_shape[2]
|
||||
in_w = input_shape[3]
|
||||
|
||||
pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
|
||||
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
|
||||
|
@ -158,7 +168,7 @@ def _pooling(name):
|
|||
if name == "avg_pool":
|
||||
attr['count_include_pad'] = False
|
||||
|
||||
return AttrCvt(
|
||||
out = AttrCvt(
|
||||
op_name=_dimension_picker(name),
|
||||
transforms={
|
||||
'kernel_shape':'pool_size',
|
||||
|
@ -166,33 +176,53 @@ def _pooling(name):
|
|||
ignores=['ksize'],
|
||||
extras={'ceil_mode': False},
|
||||
custom_check=_dimension_constraint())(inputs, attr)
|
||||
|
||||
if flip_layout:
|
||||
out = _sym.transpose(out, axes=(0, 2, 3, 1))
|
||||
|
||||
return out
|
||||
return _impl
|
||||
|
||||
def _conv(opname):
|
||||
def _impl(inputs, attr, params):
|
||||
attr['data_format'] = attr['data_format'].decode("utf-8")
|
||||
input_shapes = attr['_input_shapes'][inputs[0]]
|
||||
flip_layout = False
|
||||
|
||||
# Extract kernel shape from params
|
||||
conv_param_weights = params[inputs[1].list_output_names()[0]]
|
||||
input_shape = attr['_input_shapes'][inputs[0]][0]
|
||||
weights_shape = params[inputs[1].list_output_names()[0]].shape
|
||||
|
||||
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
|
||||
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
|
||||
inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
|
||||
if opname == 'conv':
|
||||
weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)]
|
||||
inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
|
||||
else:
|
||||
weights_shape = [weights_shape[ii] for ii in (2, 3, 0, 1)]
|
||||
inputs[1] = _sym.transpose(inputs[1], axes=(2, 3, 0, 1))
|
||||
|
||||
attr['data_format'] = "NCHW"
|
||||
flip_layout = True
|
||||
|
||||
if attr['data_format'] == 'NHWC':
|
||||
kernel_h, kernel_w, _, depth_mult = conv_param_weights.shape
|
||||
attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
|
||||
kernel_h, kernel_w, _, depth_mult = weights_shape
|
||||
attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
|
||||
if opname == 'conv':
|
||||
attr['channels'] = conv_param_weights.shape[3]
|
||||
attr['channels'] = weights_shape[3]
|
||||
else:
|
||||
attr['channels'] = input_shapes[0][3] * depth_mult
|
||||
attr['channels'] = input_shape[3] * depth_mult
|
||||
|
||||
if 'dilations' in attr:
|
||||
attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
|
||||
elif attr['data_format'] == 'NCHW':
|
||||
depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
|
||||
attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
|
||||
depth_mult, _, kernel_h, kernel_w = weights_shape
|
||||
attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
|
||||
if opname == 'conv':
|
||||
attr['channels'] = conv_param_weights.shape[1]
|
||||
attr['channels'] = weights_shape[0]
|
||||
else:
|
||||
attr['channels'] = input_shapes[0][1] * depth_mult
|
||||
attr['channels'] = input_shape[0] * depth_mult
|
||||
if attr['channels'] < 0:
|
||||
attr['channels'] *= -1
|
||||
|
||||
if 'dilations' in attr:
|
||||
attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
|
||||
|
@ -215,11 +245,11 @@ def _conv(opname):
|
|||
stride_h, stride_w = attr['strides']
|
||||
kernel_h, kernel_w = attr['kernel_shape']
|
||||
if attr['data_format'] == 'NHWC':
|
||||
in_h = input_shapes[0][1]
|
||||
in_w = input_shapes[0][2]
|
||||
in_h = input_shape[1]
|
||||
in_w = input_shape[2]
|
||||
else:
|
||||
in_h = input_shapes[0][2]
|
||||
in_w = input_shapes[0][3]
|
||||
in_h = input_shape[2]
|
||||
in_w = input_shape[3]
|
||||
|
||||
pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
|
||||
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
|
||||
|
@ -248,7 +278,7 @@ def _conv(opname):
|
|||
else:
|
||||
attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
|
||||
|
||||
return AttrCvt(
|
||||
out = AttrCvt(
|
||||
op_name=_dimension_picker('conv'),
|
||||
transforms={
|
||||
'kernel_shape': 'kernel_size',
|
||||
|
@ -257,6 +287,11 @@ def _conv(opname):
|
|||
'group': ('groups', 1)},
|
||||
extras={'use_bias': len(inputs) == 3},
|
||||
custom_check=_dimension_constraint())(inputs, attr)
|
||||
|
||||
if flip_layout:
|
||||
out = _sym.transpose(out, axes=(0, 2, 3, 1))
|
||||
|
||||
return out
|
||||
return _impl
|
||||
|
||||
def _decode_image():
|
||||
|
@ -305,7 +340,7 @@ def _matmul():
|
|||
def _impl(inputs, attr, params):
|
||||
channels = _infer_channels(inputs[1], params, not attr['transpose_b'])
|
||||
if attr['transpose_a']:
|
||||
inputs[0] = _sym.transpose(inputs[0], axis(1, 0))
|
||||
inputs[0] = _sym.transpose(inputs[0], axes(1, 0))
|
||||
if not attr['transpose_b']:
|
||||
inputs[1] = _sym.transpose(inputs[1], axes=(1, 0))
|
||||
return AttrCvt(op_name="dense",
|
||||
|
@ -948,7 +983,7 @@ class GraphProto(object):
|
|||
self._num_param = 0
|
||||
self._num_rnn_layer = False
|
||||
|
||||
def from_tensorflow(self, graph):
|
||||
def from_tensorflow(self, graph, layout="NHWC"):
|
||||
"""Construct nnvm nodes from tensorflow graph definition - GraphDef.
|
||||
|
||||
Follow the tensorflow graph definition to parse and convert it to NNVM.
|
||||
|
@ -1036,6 +1071,9 @@ class GraphProto(object):
|
|||
# Pass the node name too in attr
|
||||
attr["_node_name"] = node.name
|
||||
|
||||
# Pass the target layout
|
||||
attr["_target_layout"] = layout
|
||||
|
||||
#ToDo: Some of the tensorflow operators internaly maintain
|
||||
#execution layers and its output name will the layer number along with
|
||||
#graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
|
||||
|
@ -1265,7 +1303,7 @@ class GraphProto(object):
|
|||
|
||||
return inputs
|
||||
|
||||
def from_tensorflow(graph):
|
||||
def from_tensorflow(graph, layout="NHWC"):
|
||||
""" Load tensorflow graph which is a python tensorflow graph object into nnvm graph.
|
||||
The companion parameters will be handled automatically.
|
||||
|
||||
|
@ -1283,5 +1321,5 @@ def from_tensorflow(graph):
|
|||
Dict of converted parameters stored in tvm.ndarray format
|
||||
"""
|
||||
g = GraphProto()
|
||||
sym, params = g.from_tensorflow(graph)
|
||||
sym, params = g.from_tensorflow(graph, layout)
|
||||
return sym, params
|
||||
|
|
|
@ -26,11 +26,15 @@ import nnvm.testing.tf
|
|||
#######################################################################
|
||||
# Generic run functions for TVM & tensorflow
|
||||
# ------------------------------------------
|
||||
def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype):
|
||||
def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype, target='llvm'):
|
||||
""" Generic function to compile on nnvm and execute on tvm """
|
||||
|
||||
sym, params = nnvm.frontend.from_tensorflow(graph_def)
|
||||
target = 'llvm'
|
||||
layout = None
|
||||
if target == "cuda":
|
||||
layout = "NCHW"
|
||||
|
||||
sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout)
|
||||
target_host = 'llvm'
|
||||
if isinstance(input_data, list):
|
||||
shape_dict = {}
|
||||
dtype_dict = {}
|
||||
|
@ -41,10 +45,10 @@ def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype)
|
|||
shape_dict = {input_node: input_data.shape}
|
||||
dtype_dict = {input_node: input_data.dtype}
|
||||
|
||||
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict,
|
||||
graph, lib, params = nnvm.compiler.build(sym, target=target, target_host=target_host, shape=shape_dict,
|
||||
dtype=dtype_dict, params=params)
|
||||
|
||||
ctx = tvm.cpu(0)
|
||||
ctx = tvm.context(target, 0)
|
||||
from tvm.contrib import graph_runtime
|
||||
m = graph_runtime.create(graph, lib, ctx)
|
||||
# set inputs
|
||||
|
@ -106,9 +110,17 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False)
|
|||
)
|
||||
|
||||
tf_output = run_tf_graph(sess, in_data, in_name, out_name)
|
||||
tvm_output = run_tvm_graph(final_graph_def, in_data,
|
||||
in_node, tf_output.shape, tf_output.dtype)
|
||||
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
|
||||
|
||||
for device in ["llvm", "cuda"]:
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
continue
|
||||
|
||||
tvm_output = run_tvm_graph(final_graph_def, in_data,
|
||||
in_node, tf_output.shape, tf_output.dtype, target=device)
|
||||
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
|
||||
|
||||
sess.close()
|
||||
|
||||
#######################################################################
|
||||
|
|
|
@ -50,6 +50,16 @@ map_proto_url = os.path.join(repo_base, map_proto)
|
|||
lable_map = 'imagenet_synset_to_human_label_map.txt'
|
||||
lable_map_url = os.path.join(repo_base, lable_map)
|
||||
|
||||
# Target settings
|
||||
# Use these commented settings to build for cuda.
|
||||
#target = 'cuda'
|
||||
#target_host = 'llvm'
|
||||
#layout = "NCHW"
|
||||
#ctx = tvm.gpu(0)
|
||||
target = 'llvm'
|
||||
target_host = 'llvm'
|
||||
layout = None
|
||||
ctx = tvm.cpu(0)
|
||||
|
||||
######################################################################
|
||||
# Download required files
|
||||
|
@ -99,7 +109,7 @@ x = np.array(image)
|
|||
# Results:
|
||||
# sym: nnvm graph for given tensorflow protobuf.
|
||||
# params: params converted from tensorflow params (tensor protobuf).
|
||||
sym, params = nnvm.frontend.from_tensorflow(graph_def)
|
||||
sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout)
|
||||
|
||||
print ("Tensorflow protobuf imported as nnvm graph")
|
||||
######################################################################
|
||||
|
@ -113,18 +123,16 @@ print ("Tensorflow protobuf imported as nnvm graph")
|
|||
# lib: target library which can be deployed on target with tvm runtime.
|
||||
|
||||
import nnvm.compiler
|
||||
target = 'llvm'
|
||||
shape_dict = {'DecodeJpeg/contents': x.shape}
|
||||
dtype_dict = {'DecodeJpeg/contents': 'uint8'}
|
||||
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype=dtype_dict, params=params)
|
||||
graph, lib, params = nnvm.compiler.build(sym, shape=shape_dict, target=target, target_host=target_host, dtype=dtype_dict, params=params)
|
||||
|
||||
######################################################################
|
||||
# Execute the portable graph on TVM
|
||||
# ---------------------------------
|
||||
# Now we can try deploying the NNVM compiled model on cpu target.
|
||||
# Now we can try deploying the NNVM compiled model on target.
|
||||
|
||||
from tvm.contrib import graph_runtime
|
||||
ctx = tvm.cpu(0)
|
||||
dtype = 'uint8'
|
||||
m = graph_runtime.create(graph, lib, ctx)
|
||||
# set inputs
|
||||
|
|
Загрузка…
Ссылка в новой задаче