[Relay][Frontend] Support TF Gather (#2935)

* [Relay][Frontend] Support TF Gather

* fix comments
This commit is contained in:
Yong Wu 2019-04-02 22:05:27 -07:00 коммит произвёл Siva
Родитель 4968279f87
Коммит 38151abd72
2 изменённых файлов: 52 добавлений и 9 удалений

Просмотреть файл

@ -673,10 +673,13 @@ def _square():
return _op.multiply(inputs[0], inputs[0])
return _impl
def _gather_v2():
"Tensorflow now support only gatherv2"
def _gather():
"GatherV2, Gather"
def _impl(inputs, attr, params):
axis = params[inputs.pop(2).name_hint].asnumpy()[0]
axis = 0
if len(inputs) > 2:
axis = params[inputs.pop(2).name_hint].asnumpy()[0]
new_input = []
new_input.append(inputs.pop(0))
new_input.append(inputs.pop(0))
@ -1013,7 +1016,8 @@ _convert_map = {
'Shape' : _shape(),
'Sigmoid' : AttrCvt('sigmoid'),
'Fill' : _fill(),
'GatherV2' : _gather_v2(),
'GatherV2' : _gather(),
'Gather' : _gather(),
'StridedSlice' : _stridedSlice(),
'LRN' : _lrn(),
'Pad' : _pad('Pad'),

Просмотреть файл

@ -19,8 +19,8 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops import init_ops
from tensorflow.core.framework import graph_pb2
from distutils.version import LooseVersion
import tvm.relay.testing.tf as tf_testing
#######################################################################
@ -473,11 +473,11 @@ def test_forward_stridedslice():
#######################################################################
# Gather
# ------
# Gather, GatherV2
# ----------------
def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
""" One iteration of a Gather """
""" One iteration of a GatherV2 """
tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
@ -497,7 +497,7 @@ def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], 'GatherV2:0')
def test_forward_gather():
'''test gather layer'''
'''test GatherV2 layer'''
_test_gather((4,), (1,), 1, 0, 'int32')
_test_gather((4,), (1,), 1, 0, 'float32')
_test_gather((1,4), (1,), [0], 0, 'int32')
@ -509,6 +509,44 @@ def test_forward_gather():
_test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32')
_test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')
def _test_gather_v1(ip_shape, indice_shape, indice_value, dtype):
""" One iteration of a Gather"""
tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
indices = tf.placeholder("int32", indice_shape, name="indices")
tf.gather(in_data, indices)
np_data = np.random.uniform(size=ip_shape).astype(dtype)
def _fill_indices(indice_value):
indices = np.array(ip_shape, dtype=dtype)
if isinstance(indice_value, int):
indices = np.array([indice_value], dtype='int32')
else:
indices = np.asarray(indice_value, dtype='int32')
return indices
np_indices = _fill_indices(indice_value)
compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], 'Gather:0')
def test_forward_gather_v1():
'''test gather layer'''
if tf.__version__ < LooseVersion('1.7'):
_test_gather_v1((4,), (1, 2, 2), [[[1, 0], [0, 1]]], 'float32')
_test_gather_v1((4,), (1,), 1, 'int32')
_test_gather_v1((4,), (1,), 1, 'float32')
_test_gather_v1((1, 4), (1,), [0], 'int32')
_test_gather_v1((4,), (1, 2, 2), [[[1, 0], [0, 1]]], 'float32')
_test_gather_v1((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 'int32')
_test_gather_v1((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 'int32')
_test_gather_v1((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 'float32')
_test_gather_v1((3, 3, 3), (1, 1, 2), [[[1, 0]]], 'int32')
_test_gather_v1((3, 3, 3), (1, 1, 2), [[[1, 0]]], 'int32')
_test_gather_v1((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 'float32')
#######################################################################
# Split
# -----
@ -1213,6 +1251,7 @@ if __name__ == '__main__':
test_forward_crop()
test_forward_pad()
test_forward_gather()
test_forward_gather_v1()
test_forward_stridedslice()
test_forward_split()
test_forward_unstack()