[Relay][Frontend] Support TF Gather (#2935)
* [Relay][Frontend] Support TF Gather * fix comments
This commit is contained in:
Родитель
4968279f87
Коммит
38151abd72
|
@ -673,10 +673,13 @@ def _square():
|
|||
return _op.multiply(inputs[0], inputs[0])
|
||||
return _impl
|
||||
|
||||
def _gather_v2():
|
||||
"Tensorflow now support only gatherv2"
|
||||
def _gather():
|
||||
"GatherV2, Gather"
|
||||
def _impl(inputs, attr, params):
|
||||
axis = params[inputs.pop(2).name_hint].asnumpy()[0]
|
||||
|
||||
axis = 0
|
||||
if len(inputs) > 2:
|
||||
axis = params[inputs.pop(2).name_hint].asnumpy()[0]
|
||||
new_input = []
|
||||
new_input.append(inputs.pop(0))
|
||||
new_input.append(inputs.pop(0))
|
||||
|
@ -1013,7 +1016,8 @@ _convert_map = {
|
|||
'Shape' : _shape(),
|
||||
'Sigmoid' : AttrCvt('sigmoid'),
|
||||
'Fill' : _fill(),
|
||||
'GatherV2' : _gather_v2(),
|
||||
'GatherV2' : _gather(),
|
||||
'Gather' : _gather(),
|
||||
'StridedSlice' : _stridedSlice(),
|
||||
'LRN' : _lrn(),
|
||||
'Pad' : _pad('Pad'),
|
||||
|
|
|
@ -19,8 +19,8 @@ from tensorflow.python.ops import math_ops
|
|||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
import tvm.relay.testing.tf as tf_testing
|
||||
|
||||
#######################################################################
|
||||
|
@ -473,11 +473,11 @@ def test_forward_stridedslice():
|
|||
|
||||
|
||||
#######################################################################
|
||||
# Gather
|
||||
# ------
|
||||
# Gather, GatherV2
|
||||
# ----------------
|
||||
|
||||
def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
|
||||
""" One iteration of a Gather """
|
||||
""" One iteration of a GatherV2 """
|
||||
|
||||
tf.reset_default_graph()
|
||||
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
|
||||
|
@ -497,7 +497,7 @@ def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
|
|||
compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], 'GatherV2:0')
|
||||
|
||||
def test_forward_gather():
|
||||
'''test gather layer'''
|
||||
'''test GatherV2 layer'''
|
||||
_test_gather((4,), (1,), 1, 0, 'int32')
|
||||
_test_gather((4,), (1,), 1, 0, 'float32')
|
||||
_test_gather((1,4), (1,), [0], 0, 'int32')
|
||||
|
@ -509,6 +509,44 @@ def test_forward_gather():
|
|||
_test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32')
|
||||
_test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')
|
||||
|
||||
|
||||
def _test_gather_v1(ip_shape, indice_shape, indice_value, dtype):
|
||||
""" One iteration of a Gather"""
|
||||
tf.reset_default_graph()
|
||||
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
|
||||
indices = tf.placeholder("int32", indice_shape, name="indices")
|
||||
tf.gather(in_data, indices)
|
||||
np_data = np.random.uniform(size=ip_shape).astype(dtype)
|
||||
|
||||
def _fill_indices(indice_value):
|
||||
indices = np.array(ip_shape, dtype=dtype)
|
||||
if isinstance(indice_value, int):
|
||||
indices = np.array([indice_value], dtype='int32')
|
||||
else:
|
||||
indices = np.asarray(indice_value, dtype='int32')
|
||||
return indices
|
||||
np_indices = _fill_indices(indice_value)
|
||||
|
||||
compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], 'Gather:0')
|
||||
|
||||
|
||||
def test_forward_gather_v1():
|
||||
'''test gather layer'''
|
||||
|
||||
if tf.__version__ < LooseVersion('1.7'):
|
||||
_test_gather_v1((4,), (1, 2, 2), [[[1, 0], [0, 1]]], 'float32')
|
||||
_test_gather_v1((4,), (1,), 1, 'int32')
|
||||
_test_gather_v1((4,), (1,), 1, 'float32')
|
||||
_test_gather_v1((1, 4), (1,), [0], 'int32')
|
||||
_test_gather_v1((4,), (1, 2, 2), [[[1, 0], [0, 1]]], 'float32')
|
||||
_test_gather_v1((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 'int32')
|
||||
_test_gather_v1((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 'int32')
|
||||
_test_gather_v1((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 'float32')
|
||||
_test_gather_v1((3, 3, 3), (1, 1, 2), [[[1, 0]]], 'int32')
|
||||
_test_gather_v1((3, 3, 3), (1, 1, 2), [[[1, 0]]], 'int32')
|
||||
_test_gather_v1((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 'float32')
|
||||
|
||||
|
||||
#######################################################################
|
||||
# Split
|
||||
# -----
|
||||
|
@ -1213,6 +1251,7 @@ if __name__ == '__main__':
|
|||
test_forward_crop()
|
||||
test_forward_pad()
|
||||
test_forward_gather()
|
||||
test_forward_gather_v1()
|
||||
test_forward_stridedslice()
|
||||
test_forward_split()
|
||||
test_forward_unstack()
|
||||
|
|
Загрузка…
Ссылка в новой задаче