From 373a8caa79293bcfb9a026fa73e9f25885240dd0 Mon Sep 17 00:00:00 2001 From: Siva Date: Tue, 26 Jun 2018 22:45:24 +0530 Subject: [PATCH] [NNVM][TENSORFLOW] Mobilenet support. (#1335) --- nnvm/python/nnvm/frontend/tensorflow.py | 52 +++++++++----- nnvm/python/nnvm/testing/tf.py | 70 +++++++++++++------ .../frontend/tensorflow/test_forward.py | 24 +++++++ 3 files changed, 109 insertions(+), 37 deletions(-) diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 3ab99761..71536517 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -35,6 +35,11 @@ class AttrCvt(object): self._ignores.append('use_cudnn_on_gpu') self._ignores.append('_node_name') self._ignores.append('is_training') + # Retain the names + try: + attrs['name'] = attrs['_node_name'] + except KeyError: + pass return AttrConvert(self._op_name, self._transforms, self._excludes, self._disables, self._ignores, self._extras, self._custom_check)(inputs, attrs, *args) @@ -405,13 +410,19 @@ def _concat(): def _reshape(): def _impl(inputs, attr, params): - pop_node = inputs.pop(1) - shape_arg = params[pop_node.list_output_names()[0]] - params.pop(pop_node.list_output_names()[0]) - return AttrCvt( - op_name="reshape", - extras={'shape':tuple(shape_arg.asnumpy())}, - ignores=['Tshape'])(inputs, attr) + try: + pop_node = inputs[1] + shape_arg = params.pop(pop_node.list_output_names()[0]) + inputs.pop(1) + + return AttrCvt( + op_name="reshape", + extras={'shape':tuple(shape_arg.asnumpy())}, + ignores=['Tshape'])(inputs, attr) + except KeyError: + return AttrCvt( + op_name="reshape_like", + ignores=['Tshape'])(inputs, attr) return _impl def _bias_add(): @@ -427,6 +438,18 @@ def _squeeze(): ignores=['T'])(inputs, attr) return _impl +def _fused_batch_norm(): + def _impl(inputs, attr, params): + # Tensorflow: (data, gamma, beta, moving_mean, moving_variance) + # NNVM: (data, gamma, beta, moving_mean, moving_varience) + return AttrCvt( + op_name='batch_norm', + transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'}, + extras={'axis': 3}, # Fix axis + ignores=['data_format'], + disables=['momentum'])(inputs, attr) + return _impl + def _batch_norm(): def _impl(inputs, attr, params): # Rearrange inputs from @@ -445,19 +468,14 @@ def _batch_norm(): def _relu6(): def _impl(inputs, attr, params): - return _sym.clip(inputs[0], a_min=0, a_max=6) + return _sym.clip(inputs[0], a_min=0, a_max=6, name=attr['_node_name']) return _impl def _shape(): def _impl(inputs, attr, params): - input_shapes = attr['_input_shapes'][inputs[0]] - - # Fix the -1 dimensions to 1 - input_shapes[0] = [1 if x == -1 else x for x in input_shapes[0]] - params[attr['_node_name']] = tvm.nd.array(input_shapes[0]) - - return _sym.Variable(name=attr['_node_name'], - shape=params[attr['_node_name']].shape) + # Result of this operator is prominently used by reshape operator. + # Just pass the input as it is so that reshape_like can be used there. + return inputs[0] return _impl # compatible operators that do NOT require any conversion. @@ -491,7 +509,7 @@ _convert_map = { 'Add' : _elemwise('add'), 'Rsqrt' : _rsqrt(), 'Squeeze' : _squeeze(), - 'FusedBatchNorm' : _batch_norm(), + 'FusedBatchNorm' : _fused_batch_norm(), 'Relu6' : _relu6(), 'DepthwiseConv2dNative' : _depthwise_conv(), 'Shape' : _shape(), diff --git a/nnvm/python/nnvm/testing/tf.py b/nnvm/python/nnvm/testing/tf.py index 3421573e..1762ce56 100644 --- a/nnvm/python/nnvm/testing/tf.py +++ b/nnvm/python/nnvm/testing/tf.py @@ -153,6 +153,35 @@ def read_normalized_tensor_from_image_file(file_name, np_array = normalized.eval() return np_array +def get_workload(model_path): + """ Import workload from frozen protobuf + + Parameters + ---------- + model_path: str + model_path on remote repository to download from. + + Returns + ------- + graph_def: graphdef + graph_def is the tensorflow workload for mobilenet. + + """ + + repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/' + model_name = os.path.basename(model_path) + model_url = os.path.join(repo_base, model_path) + + from mxnet.gluon.utils import download + download(model_url, model_name) + + # Creates graph from saved graph_def.pb. + with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + graph = tf.import_graph_def(graph_def, name='') + return graph_def + def get_workload_inception_v3(): """ Import Inception V3 workload from frozen protobuf @@ -168,23 +197,15 @@ def get_workload_inception_v3(): """ repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV3/' - model_name = 'inception_v3_2016_08_28_frozen-with_shapes.pb' - model_url = os.path.join(repo_base, model_name) + model_path = 'InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb' + image_name = 'elephant-299.jpg' image_url = os.path.join(repo_base, image_name) - from mxnet.gluon.utils import download - download(model_url, model_name) download(image_url, image_name) - normalized = read_normalized_tensor_from_image_file(os.path.join("./", image_name)) - # Creates graph from saved graph_def.pb. - with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f: - graph_def = tf.GraphDef() - graph_def.ParseFromString(f.read()) - graph = tf.import_graph_def(graph_def, name='') - return (normalized, graph_def) + return (normalized, get_workload(model_path)) def get_workload_inception_v1(): """ Import Inception V1 workload from frozen protobuf @@ -203,13 +224,11 @@ def get_workload_inception_v1(): """ repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/' - model_name = 'classify_image_graph_def-with_shapes.pb' - model_url = os.path.join(repo_base, model_name) + model_path = 'InceptionV1/classify_image_graph_def-with_shapes.pb' image_name = 'elephant-299.jpg' image_url = os.path.join(repo_base, image_name) from mxnet.gluon.utils import download - download(model_url, model_name) download(image_url, image_name) if not tf.gfile.Exists(os.path.join("./", image_name)): @@ -221,9 +240,20 @@ def get_workload_inception_v1(): tvm_data = Image.open(os.path.join("./", image_name)).resize((299, 299)) tvm_data = np.array(tvm_data) - # Creates graph from saved graph_def.pb. - with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f: - graph_def = tf.GraphDef() - graph_def.ParseFromString(f.read()) - graph = tf.import_graph_def(graph_def, name='') - return (image_data, tvm_data, graph_def) + return (image_data, tvm_data, get_workload(model_path)) + +def get_workload_mobilenet(): + """ Import mobilenet workload from frozen protobuf + + Parameters + ---------- + Nothing. + + Returns + ------- + graph_def: graphdef + graph_def is the tensorflow workload for mobilenet. + + """ + + return get_workload("MobilenetV1/mobilenet_v1_1.0_224_frozen-with-shapes.pb") diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index 4e742a4a..6dc8cfab 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -406,6 +406,29 @@ def test_forward_inception_v1(): np.testing.assert_allclose(tf_output, tvm_output, rtol=2e-2, atol=2e-2) +####################################################################### +# Mobilenet +# --------- +def test_forward_mobilenet(): + '''test mobilenet model''' + with tf.Graph().as_default(): + graph_def = nnvm.testing.tf.get_workload_mobilenet() + # Call the utility to import the graph definition into default graph. + graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) + + data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32') + out_node = 'MobilenetV1/Predictions/Reshape_1' + + with tf.Session() as sess: + tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0') + + out_shape = tf_output.shape + tvm_output = run_tvm_graph(graph_def, data, 'input', out_shape, 'float32') + top_tvm = np.squeeze(tvm_output).argsort()[-10:][::-1] + top_tf = np.squeeze(tf_output).argsort()[-10:][::-1] + + np.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5) + ####################################################################### # Main # ---- @@ -419,3 +442,4 @@ if __name__ == '__main__': test_forward_multi_input() test_forward_inception_v3() test_forward_inception_v1() + test_forward_mobilenet()