diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py
index ab566467..9c9fac89 100644
--- a/nnvm/python/nnvm/frontend/tensorflow.py
+++ b/nnvm/python/nnvm/frontend/tensorflow.py
@@ -35,6 +35,7 @@ class AttrCvt(object):
         self._ignores.append('use_cudnn_on_gpu')
         self._ignores.append('_node_name')
         self._ignores.append('is_training')
+        self._ignores.append('_target_layout')
         # Retain the names
         try:
             attrs['name'] = attrs['_node_name']
@@ -121,6 +122,9 @@ def _pooling(name):
     def _impl(inputs, attr, params):
 
         attr['data_format'] = attr['data_format'].decode("utf-8")
+        flip_layout = False
+
+        input_shape = attr['_input_shapes'][inputs[0]][0]
 
         if attr['data_format'] == 'NHWC':
             attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
@@ -129,11 +133,17 @@ def _pooling(name):
         else:
             raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))
 
+        if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
+            tmp_shape = attr['_input_shapes'][inputs[0]][0]
+            input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
+            inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
+            attr['data_format'] = "NCHW"
+            flip_layout = True
+
         # Fix strides
         attr['strides'] = (attr['strides'][1], attr['strides'][2])
 
         # Fix padding
-        input_shapes = attr['_input_shapes'][inputs[0]]
         attr['padding'] = attr['padding'].decode("utf-8")
 
         if attr['padding'] == 'VALID':
@@ -142,11 +152,11 @@ def _pooling(name):
             stride_h, stride_w = attr['strides']
             kernel_h, kernel_w = attr['kernel_shape']
             if attr['data_format'] == 'NHWC':
-                in_h = input_shapes[0][1]
-                in_w = input_shapes[0][2]
+                in_h = input_shape[1]
+                in_w = input_shape[2]
             else:
-                in_h = input_shapes[0][2]
-                in_w = input_shapes[0][3]
+                in_h = input_shape[2]
+                in_w = input_shape[3]
 
             pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
             pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
@@ -158,7 +168,7 @@ def _pooling(name):
         if name == "avg_pool":
             attr['count_include_pad'] = False
 
-        return AttrCvt(
+        out = AttrCvt(
             op_name=_dimension_picker(name),
             transforms={
                 'kernel_shape':'pool_size',
@@ -166,33 +176,53 @@ def _pooling(name):
             ignores=['ksize'],
             extras={'ceil_mode': False},
             custom_check=_dimension_constraint())(inputs, attr)
+
+        if flip_layout:
+            out = _sym.transpose(out, axes=(0, 2, 3, 1))
+
+        return out
     return _impl
 
 def _conv(opname):
     def _impl(inputs, attr, params):
 
         attr['data_format'] = attr['data_format'].decode("utf-8")
-        input_shapes = attr['_input_shapes'][inputs[0]]
+        flip_layout = False
 
-        # Extract kernel shape from params
-        conv_param_weights = params[inputs[1].list_output_names()[0]]
+        input_shape = attr['_input_shapes'][inputs[0]][0]
+        weights_shape = params[inputs[1].list_output_names()[0]].shape
+
+        if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
+            input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
+            inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
+            if opname == 'conv':
+                weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)]
+                inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
+            else:
+                weights_shape = [weights_shape[ii] for ii in (2, 3, 0, 1)]
+                inputs[1] = _sym.transpose(inputs[1], axes=(2, 3, 0, 1))
+
+            attr['data_format'] = "NCHW"
+            flip_layout = True
 
         if attr['data_format'] == 'NHWC':
-            kernel_h, kernel_w, _, depth_mult = conv_param_weights.shape
-            attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
+            kernel_h, kernel_w, _, depth_mult = weights_shape
+            attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
             if opname == 'conv':
-                attr['channels'] = conv_param_weights.shape[3]
+                attr['channels'] = weights_shape[3]
             else:
-                attr['channels'] = input_shapes[0][3] * depth_mult
+                attr['channels'] = input_shape[3] * depth_mult
             if 'dilations' in attr:
                 attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
         elif attr['data_format'] == 'NCHW':
-            depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
-            attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
+            depth_mult, _, kernel_h, kernel_w = weights_shape
+            attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
             if opname == 'conv':
-                attr['channels'] = conv_param_weights.shape[1]
+                attr['channels'] = weights_shape[0]
             else:
-                attr['channels'] = input_shapes[0][1] * depth_mult
+                attr['channels'] = input_shape[0] * depth_mult
+            if attr['channels'] < 0:
+                attr['channels'] *= -1
 
             if 'dilations' in attr:
                 attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
@@ -215,11 +245,11 @@ def _conv(opname):
             stride_h, stride_w = attr['strides']
             kernel_h, kernel_w = attr['kernel_shape']
             if attr['data_format'] == 'NHWC':
-                in_h = input_shapes[0][1]
-                in_w = input_shapes[0][2]
+                in_h = input_shape[1]
+                in_w = input_shape[2]
             else:
-                in_h = input_shapes[0][2]
-                in_w = input_shapes[0][3]
+                in_h = input_shape[2]
+                in_w = input_shape[3]
 
             pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
             pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
@@ -248,7 +278,7 @@ def _conv(opname):
         else:
             attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
 
-        return AttrCvt(
+        out = AttrCvt(
             op_name=_dimension_picker('conv'),
             transforms={
                 'kernel_shape': 'kernel_size',
@@ -257,6 +287,11 @@ def _conv(opname):
                 'group': ('groups', 1)},
             extras={'use_bias': len(inputs) == 3},
             custom_check=_dimension_constraint())(inputs, attr)
+
+        if flip_layout:
+            out = _sym.transpose(out, axes=(0, 2, 3, 1))
+
+        return out
     return _impl
 
 def _decode_image():
@@ -305,7 +340,7 @@ def _matmul():
     def _impl(inputs, attr, params):
         channels = _infer_channels(inputs[1], params, not attr['transpose_b'])
         if attr['transpose_a']:
-            inputs[0] = _sym.transpose(inputs[0], axis(1, 0))
+            inputs[0] = _sym.transpose(inputs[0], axes=(1, 0))
         if not attr['transpose_b']:
             inputs[1] = _sym.transpose(inputs[1], axes=(1, 0))
         return AttrCvt(op_name="dense",
@@ -948,7 +983,7 @@ class GraphProto(object):
         self._num_param = 0
         self._num_rnn_layer = False
 
-    def from_tensorflow(self, graph):
+    def from_tensorflow(self, graph, layout="NHWC"):
         """Construct nnvm nodes from tensorflow graph definition - GraphDef.
 
         Follow the tensorflow graph definition to parse and convert it to NNVM.
@@ -1036,6 +1071,9 @@ class GraphProto(object):
                 # Pass the node name too in attr
                 attr["_node_name"] = node.name
 
+                # Pass the target layout
+                attr["_target_layout"] = layout
+
                 #ToDo: Some of the tensorflow operators internaly maintain
                 #execution layers and its output name will the layer number along with
                 #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
@@ -1265,7 +1303,7 @@ class GraphProto(object):
 
         return inputs
 
-def from_tensorflow(graph):
+def from_tensorflow(graph, layout="NHWC"):
    """Load tensorflow graph which is a python tensorflow graph object into nnvm graph.
    The companion parameters will be handled automatically.
 
@@ -1283,5 +1321,5 @@ def from_tensorflow(graph):
         Dict of converted parameters stored in tvm.ndarray format
     """
     g = GraphProto()
-    sym, params = g.from_tensorflow(graph)
+    sym, params = g.from_tensorflow(graph, layout)
     return sym, params
diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py
index 61625950..ad7f41a8 100644
--- a/nnvm/tests/python/frontend/tensorflow/test_forward.py
+++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py
@@ -26,11 +26,15 @@ import nnvm.testing.tf
 
 #######################################################################
 # Generic run functions for TVM & tensorflow
 # ------------------------------------------
-def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype):
+def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype, target='llvm'):
     """ Generic function to compile on nnvm and execute on tvm """
-    sym, params = nnvm.frontend.from_tensorflow(graph_def)
-    target = 'llvm'
+    layout = None
+    if target == "cuda":
+        layout = "NCHW"
+
+    sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout)
+    target_host = 'llvm'
     if isinstance(input_data, list):
         shape_dict = {}
         dtype_dict = {}
@@ -41,10 +45,10 @@ def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype):
         shape_dict = {input_node: input_data.shape}
         dtype_dict = {input_node: input_data.dtype}
 
-    graph, lib, params = nnvm.compiler.build(sym, target, shape_dict,
+    graph, lib, params = nnvm.compiler.build(sym, target=target, target_host=target_host, shape=shape_dict,
                                              dtype=dtype_dict, params=params)
 
-    ctx = tvm.cpu(0)
+    ctx = tvm.context(target, 0)
     from tvm.contrib import graph_runtime
     m = graph_runtime.create(graph, lib, ctx)
     # set inputs
@@ -106,9 +110,17 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False):
             )
 
         tf_output = run_tf_graph(sess, in_data, in_name, out_name)
-        tvm_output = run_tvm_graph(final_graph_def, in_data,
-                                   in_node, tf_output.shape, tf_output.dtype)
-        np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
+
+        for device in ["llvm", "cuda"]:
+            ctx = tvm.context(device, 0)
+            if not ctx.exist:
+                print("Skip because %s is not enabled" % device)
+                continue
+
+            tvm_output = run_tvm_graph(final_graph_def, in_data,
+                                       in_node, tf_output.shape, tf_output.dtype, target=device)
+            np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
+
         sess.close()
 
 #######################################################################
diff --git a/tutorials/nnvm/from_tensorflow.py b/tutorials/nnvm/from_tensorflow.py
index 033cdd8a..7cd7e784 100644
--- a/tutorials/nnvm/from_tensorflow.py
+++ b/tutorials/nnvm/from_tensorflow.py
@@ -50,6 +50,16 @@ map_proto_url = os.path.join(repo_base, map_proto)
 lable_map = 'imagenet_synset_to_human_label_map.txt'
 lable_map_url = os.path.join(repo_base, lable_map)
 
+# Target settings
+# Use these commented settings to build for cuda.
+#target = 'cuda'
+#target_host = 'llvm'
+#layout = "NCHW"
+#ctx = tvm.gpu(0)
+target = 'llvm'
+target_host = 'llvm'
+layout = None
+ctx = tvm.cpu(0)
 
 ######################################################################
 # Download required files
@@ -99,7 +109,7 @@ x = np.array(image)
 # Results:
 #   sym: nnvm graph for given tensorflow protobuf.
 #   params: params converted from tensorflow params (tensor protobuf).
-sym, params = nnvm.frontend.from_tensorflow(graph_def)
+sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout)
 print ("Tensorflow protobuf imported as nnvm graph")
 
 ######################################################################
@@ -113,18 +123,16 @@ print ("Tensorflow protobuf imported as nnvm graph")
 #   lib: target library which can be deployed on target with tvm runtime.
 
 import nnvm.compiler
-target = 'llvm'
 shape_dict = {'DecodeJpeg/contents': x.shape}
 dtype_dict = {'DecodeJpeg/contents': 'uint8'}
-graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype=dtype_dict, params=params)
+graph, lib, params = nnvm.compiler.build(sym, shape=shape_dict, target=target, target_host=target_host, dtype=dtype_dict, params=params)
 
 ######################################################################
 # Execute the portable graph on TVM
 # ---------------------------------
-# Now we can try deploying the NNVM compiled model on cpu target.
+# Now we can try deploying the NNVM compiled model on target.
 
 from tvm.contrib import graph_runtime
-ctx = tvm.cpu(0)
 dtype = 'uint8'
 m = graph_runtime.create(graph, lib, ctx)
 # set inputs
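Below is a minimal usage sketch of the new layout argument on the cuda path, mirroring the tutorial changes above. It assumes graph_def (a frozen TensorFlow GraphDef) and the uint8 image array x are already prepared as in the tutorial, and that the model output shape is (1, 1008) as in the Inception example; adjust these names and shapes for other models.

    import tvm
    import nnvm.compiler
    import nnvm.frontend
    from tvm.contrib import graph_runtime

    # Ask the frontend to convert the graph to NCHW internally; convolutions,
    # pools and their weights are transposed by the frontend as in this patch.
    sym, params = nnvm.frontend.from_tensorflow(graph_def, layout="NCHW")

    shape_dict = {'DecodeJpeg/contents': x.shape}
    dtype_dict = {'DecodeJpeg/contents': 'uint8'}
    graph, lib, params = nnvm.compiler.build(sym, shape=shape_dict,
                                             target='cuda', target_host='llvm',
                                             dtype=dtype_dict, params=params)

    # Run on the GPU context matching the build target.
    ctx = tvm.context('cuda', 0)
    m = graph_runtime.create(graph, lib, ctx)
    m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype('uint8')))
    m.set_input(**params)
    m.run()
    tvm_output = m.get_output(0, tvm.nd.empty((1, 1008), 'float32'))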