[NNVM][TENSORFLOW]Local Response Normalization added for tensorflow (#1522)
This commit is contained in:
Родитель
32076df815
Коммит
a2870fefc8
|
@ -468,6 +468,20 @@ def _fill():
|
|||
ignores=['index_type', 'T'])(new_inputs, attr)
|
||||
return _impl
|
||||
|
||||
def _lrn():
|
||||
def _impl(inputs, attr, params):
|
||||
new_inputs = []
|
||||
attr_new = {}
|
||||
depth_radius = attr.get('depth_radius', 5)
|
||||
size = (depth_radius * 2) + 1
|
||||
attr_new['axis'] = 3 # Fix axis, NHWC format
|
||||
attr_new['size'] = size
|
||||
attr_new['bias'] = attr.get('bias', 1)
|
||||
attr_new['alpha'] = attr.get('alpha', 1) * size
|
||||
attr_new['beta'] = attr.get('beta', 0.5)
|
||||
return AttrCvt(op_name='lrn')(new_inputs, attr_new)
|
||||
return _impl
|
||||
|
||||
def _gather_v2():
|
||||
"Tensorflow now support only gatherv2"
|
||||
def _impl(inputs, attr, params):
|
||||
|
@ -680,6 +694,7 @@ _convert_map = {
|
|||
'Fill' : _fill(),
|
||||
'GatherV2' : _gather_v2(),
|
||||
'StridedSlice' : _stridedSlice(),
|
||||
'LRN' : _lrn(),
|
||||
}
|
||||
|
||||
# _convert_map_rnn defines maps of rnn operator name to
|
||||
|
|
|
@ -855,6 +855,40 @@ def test_forward_ptb():
|
|||
assert(tvm_sample_str == tf_sample_str)
|
||||
|
||||
#######################################################################
|
||||
# LRN (Local Response Normalization)
|
||||
# ----------------------------------
|
||||
|
||||
def _test_lrn(ishape, size, axis, bias, alpha, beta):
|
||||
""" testing local response normalization """
|
||||
lrn_depth_radius = size / 2
|
||||
|
||||
inp_array = np.random.uniform(size=ishape).astype(np.float32)
|
||||
|
||||
with tf.Graph().as_default():
|
||||
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype, name="lrn0_data")
|
||||
nn_ops.local_response_normalization(in1,
|
||||
name="lrn",
|
||||
depth_radius=lrn_depth_radius,
|
||||
bias=bias,
|
||||
alpha=alpha,
|
||||
beta=beta)
|
||||
|
||||
with tf.Session() as sess:
|
||||
graph_def = tf.graph_util.convert_variables_to_constants(
|
||||
sess,
|
||||
sess.graph.as_graph_def(add_shapes=True),
|
||||
['lrn'],)
|
||||
|
||||
tf_output = run_tf_graph(sess, inp_array, 'lrn0_data:0', 'lrn:0')
|
||||
tvm_output = run_tvm_graph(graph_def,
|
||||
inp_array,
|
||||
"lrn0_data", tf_output.shape, tf_output.dtype)
|
||||
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3)
|
||||
sess.close()
|
||||
|
||||
def test_forward_lrn():
|
||||
_test_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5)
|
||||
|
||||
# Main
|
||||
# ----
|
||||
if __name__ == '__main__':
|
||||
|
@ -875,3 +909,4 @@ if __name__ == '__main__':
|
|||
test_forward_stridedslice()
|
||||
test_forward_gather()
|
||||
test_forward_ptb()
|
||||
test_forward_lrn()
|
||||
|
|
Загрузка…
Ссылка в новой задаче