[NNVM][TENSORFLOW] Mobilenet support. (#1335)
This commit is contained in:
Родитель
ca2ad6d427
Коммит
373a8caa79
|
@ -35,6 +35,11 @@ class AttrCvt(object):
|
|||
self._ignores.append('use_cudnn_on_gpu')
|
||||
self._ignores.append('_node_name')
|
||||
self._ignores.append('is_training')
|
||||
# Retain the names
|
||||
try:
|
||||
attrs['name'] = attrs['_node_name']
|
||||
except KeyError:
|
||||
pass
|
||||
return AttrConvert(self._op_name, self._transforms, self._excludes,
|
||||
self._disables, self._ignores, self._extras,
|
||||
self._custom_check)(inputs, attrs, *args)
|
||||
|
@ -405,13 +410,19 @@ def _concat():
|
|||
|
||||
def _reshape():
|
||||
def _impl(inputs, attr, params):
|
||||
pop_node = inputs.pop(1)
|
||||
shape_arg = params[pop_node.list_output_names()[0]]
|
||||
params.pop(pop_node.list_output_names()[0])
|
||||
return AttrCvt(
|
||||
op_name="reshape",
|
||||
extras={'shape':tuple(shape_arg.asnumpy())},
|
||||
ignores=['Tshape'])(inputs, attr)
|
||||
try:
|
||||
pop_node = inputs[1]
|
||||
shape_arg = params.pop(pop_node.list_output_names()[0])
|
||||
inputs.pop(1)
|
||||
|
||||
return AttrCvt(
|
||||
op_name="reshape",
|
||||
extras={'shape':tuple(shape_arg.asnumpy())},
|
||||
ignores=['Tshape'])(inputs, attr)
|
||||
except KeyError:
|
||||
return AttrCvt(
|
||||
op_name="reshape_like",
|
||||
ignores=['Tshape'])(inputs, attr)
|
||||
return _impl
|
||||
|
||||
def _bias_add():
|
||||
|
@ -427,6 +438,18 @@ def _squeeze():
|
|||
ignores=['T'])(inputs, attr)
|
||||
return _impl
|
||||
|
||||
def _fused_batch_norm():
|
||||
def _impl(inputs, attr, params):
|
||||
# Tensorflow: (data, gamma, beta, moving_mean, moving_variance)
|
||||
# NNVM: (data, gamma, beta, moving_mean, moving_varience)
|
||||
return AttrCvt(
|
||||
op_name='batch_norm',
|
||||
transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'},
|
||||
extras={'axis': 3}, # Fix axis
|
||||
ignores=['data_format'],
|
||||
disables=['momentum'])(inputs, attr)
|
||||
return _impl
|
||||
|
||||
def _batch_norm():
|
||||
def _impl(inputs, attr, params):
|
||||
# Rearrange inputs from
|
||||
|
@ -445,19 +468,14 @@ def _batch_norm():
|
|||
|
||||
def _relu6():
|
||||
def _impl(inputs, attr, params):
|
||||
return _sym.clip(inputs[0], a_min=0, a_max=6)
|
||||
return _sym.clip(inputs[0], a_min=0, a_max=6, name=attr['_node_name'])
|
||||
return _impl
|
||||
|
||||
def _shape():
|
||||
def _impl(inputs, attr, params):
|
||||
input_shapes = attr['_input_shapes'][inputs[0]]
|
||||
|
||||
# Fix the -1 dimensions to 1
|
||||
input_shapes[0] = [1 if x == -1 else x for x in input_shapes[0]]
|
||||
params[attr['_node_name']] = tvm.nd.array(input_shapes[0])
|
||||
|
||||
return _sym.Variable(name=attr['_node_name'],
|
||||
shape=params[attr['_node_name']].shape)
|
||||
# Result of this operator is prominently used by reshape operator.
|
||||
# Just pass the input as it is so that reshape_like can be used there.
|
||||
return inputs[0]
|
||||
return _impl
|
||||
|
||||
# compatible operators that do NOT require any conversion.
|
||||
|
@ -491,7 +509,7 @@ _convert_map = {
|
|||
'Add' : _elemwise('add'),
|
||||
'Rsqrt' : _rsqrt(),
|
||||
'Squeeze' : _squeeze(),
|
||||
'FusedBatchNorm' : _batch_norm(),
|
||||
'FusedBatchNorm' : _fused_batch_norm(),
|
||||
'Relu6' : _relu6(),
|
||||
'DepthwiseConv2dNative' : _depthwise_conv(),
|
||||
'Shape' : _shape(),
|
||||
|
|
|
@ -153,6 +153,35 @@ def read_normalized_tensor_from_image_file(file_name,
|
|||
np_array = normalized.eval()
|
||||
return np_array
|
||||
|
||||
def get_workload(model_path):
|
||||
""" Import workload from frozen protobuf
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_path: str
|
||||
model_path on remote repository to download from.
|
||||
|
||||
Returns
|
||||
-------
|
||||
graph_def: graphdef
|
||||
graph_def is the tensorflow workload for mobilenet.
|
||||
|
||||
"""
|
||||
|
||||
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/'
|
||||
model_name = os.path.basename(model_path)
|
||||
model_url = os.path.join(repo_base, model_path)
|
||||
|
||||
from mxnet.gluon.utils import download
|
||||
download(model_url, model_name)
|
||||
|
||||
# Creates graph from saved graph_def.pb.
|
||||
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
graph = tf.import_graph_def(graph_def, name='')
|
||||
return graph_def
|
||||
|
||||
def get_workload_inception_v3():
|
||||
""" Import Inception V3 workload from frozen protobuf
|
||||
|
||||
|
@ -168,23 +197,15 @@ def get_workload_inception_v3():
|
|||
"""
|
||||
|
||||
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV3/'
|
||||
model_name = 'inception_v3_2016_08_28_frozen-with_shapes.pb'
|
||||
model_url = os.path.join(repo_base, model_name)
|
||||
model_path = 'InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb'
|
||||
|
||||
image_name = 'elephant-299.jpg'
|
||||
image_url = os.path.join(repo_base, image_name)
|
||||
|
||||
from mxnet.gluon.utils import download
|
||||
download(model_url, model_name)
|
||||
download(image_url, image_name)
|
||||
|
||||
normalized = read_normalized_tensor_from_image_file(os.path.join("./", image_name))
|
||||
|
||||
# Creates graph from saved graph_def.pb.
|
||||
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
graph = tf.import_graph_def(graph_def, name='')
|
||||
return (normalized, graph_def)
|
||||
return (normalized, get_workload(model_path))
|
||||
|
||||
def get_workload_inception_v1():
|
||||
""" Import Inception V1 workload from frozen protobuf
|
||||
|
@ -203,13 +224,11 @@ def get_workload_inception_v1():
|
|||
"""
|
||||
|
||||
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
|
||||
model_name = 'classify_image_graph_def-with_shapes.pb'
|
||||
model_url = os.path.join(repo_base, model_name)
|
||||
model_path = 'InceptionV1/classify_image_graph_def-with_shapes.pb'
|
||||
image_name = 'elephant-299.jpg'
|
||||
image_url = os.path.join(repo_base, image_name)
|
||||
|
||||
from mxnet.gluon.utils import download
|
||||
download(model_url, model_name)
|
||||
download(image_url, image_name)
|
||||
|
||||
if not tf.gfile.Exists(os.path.join("./", image_name)):
|
||||
|
@ -221,9 +240,20 @@ def get_workload_inception_v1():
|
|||
tvm_data = Image.open(os.path.join("./", image_name)).resize((299, 299))
|
||||
tvm_data = np.array(tvm_data)
|
||||
|
||||
# Creates graph from saved graph_def.pb.
|
||||
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
graph = tf.import_graph_def(graph_def, name='')
|
||||
return (image_data, tvm_data, graph_def)
|
||||
return (image_data, tvm_data, get_workload(model_path))
|
||||
|
||||
def get_workload_mobilenet():
|
||||
""" Import mobilenet workload from frozen protobuf
|
||||
|
||||
Parameters
|
||||
----------
|
||||
Nothing.
|
||||
|
||||
Returns
|
||||
-------
|
||||
graph_def: graphdef
|
||||
graph_def is the tensorflow workload for mobilenet.
|
||||
|
||||
"""
|
||||
|
||||
return get_workload("MobilenetV1/mobilenet_v1_1.0_224_frozen-with-shapes.pb")
|
||||
|
|
|
@ -406,6 +406,29 @@ def test_forward_inception_v1():
|
|||
|
||||
np.testing.assert_allclose(tf_output, tvm_output, rtol=2e-2, atol=2e-2)
|
||||
|
||||
#######################################################################
|
||||
# Mobilenet
|
||||
# ---------
|
||||
def test_forward_mobilenet():
|
||||
'''test mobilenet model'''
|
||||
with tf.Graph().as_default():
|
||||
graph_def = nnvm.testing.tf.get_workload_mobilenet()
|
||||
# Call the utility to import the graph definition into default graph.
|
||||
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
|
||||
|
||||
data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
|
||||
out_node = 'MobilenetV1/Predictions/Reshape_1'
|
||||
|
||||
with tf.Session() as sess:
|
||||
tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
|
||||
|
||||
out_shape = tf_output.shape
|
||||
tvm_output = run_tvm_graph(graph_def, data, 'input', out_shape, 'float32')
|
||||
top_tvm = np.squeeze(tvm_output).argsort()[-10:][::-1]
|
||||
top_tf = np.squeeze(tf_output).argsort()[-10:][::-1]
|
||||
|
||||
np.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)
|
||||
|
||||
#######################################################################
|
||||
# Main
|
||||
# ----
|
||||
|
@ -419,3 +442,4 @@ if __name__ == '__main__':
|
|||
test_forward_multi_input()
|
||||
test_forward_inception_v3()
|
||||
test_forward_inception_v1()
|
||||
test_forward_mobilenet()
|
||||
|
|
Загрузка…
Ссылка в новой задаче