[FRONTEND][ONNX]LRN support for ONNX (#1518)
* LRN support for ONNX * [ONNX] Updated lrn testcases
This commit is contained in:
Родитель
a8574e7bb8
Коммит
0241fdc5d2
|
@ -499,6 +499,23 @@ class Gather(OnnxOpConverter):
|
|||
params[name] = indices
|
||||
return _sym.take(inputs[0], gather_indices, axis=axis)
|
||||
|
||||
class LRN(OnnxOpConverter):
|
||||
""" Operator converter for Local Response Normalization.
|
||||
"""
|
||||
@classmethod
|
||||
def _impl_v1(cls, inputs, attr, params):
|
||||
"""LRN support only NCHW format
|
||||
https://github.com/onnx/onnx/blob/master/docs/Operators.md#LRN
|
||||
"""
|
||||
axis = 1
|
||||
alpha = attr.get('alpha', 0.0001)
|
||||
beta = attr.get('beta', 0.75)
|
||||
bias = attr.get('bias', 1.0)
|
||||
nsize = attr.get('size')
|
||||
return _sym.lrn(inputs[0], size=nsize, axis=axis,
|
||||
alpha=alpha, beta=beta, bias=bias)
|
||||
|
||||
|
||||
# compatible operators that do NOT require any conversion.
|
||||
_identity_list = []
|
||||
|
||||
|
@ -586,7 +603,7 @@ def _get_convert_map(opset):
|
|||
# 'LpNormalization'
|
||||
'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
|
||||
'Flatten': Renamer('flatten'),
|
||||
# 'LRN'
|
||||
'LRN': LRN.get_converter(opset),
|
||||
|
||||
# defs/reduction
|
||||
'ReduceMax': AttrCvt('max', {'axes', 'axis'}),
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import numpy as np
|
||||
import math
|
||||
import nnvm
|
||||
import tvm
|
||||
from tvm.contrib import graph_runtime
|
||||
|
@ -312,6 +313,58 @@ def test_matmul():
|
|||
|
||||
np.testing.assert_allclose(np.matmul(a_array, b_array), tvm_out.asnumpy(), rtol=1e-5, atol=1e-5)
|
||||
|
||||
def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
|
||||
in_array = np.random.uniform(size=shape).astype(dtype)
|
||||
|
||||
if alpha == None and beta == None and bias==None:
|
||||
alpha = 0.0001
|
||||
beta = 0.75
|
||||
bias = 1.0
|
||||
node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], size=nsize)
|
||||
else:
|
||||
node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], alpha=alpha,
|
||||
beta=beta, bias=bias, size=nsize)
|
||||
|
||||
graph = helper.make_graph([node],
|
||||
"lrn_test",
|
||||
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(shape))],
|
||||
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))])
|
||||
model = helper.make_model(graph, producer_name='lrn_test')
|
||||
|
||||
def _get_python_lrn():
|
||||
square_sum = np.zeros(shape).astype(dtype)
|
||||
for n, c, h, w in np.ndindex(in_array.shape):
|
||||
square_sum[n, c, h, w] = sum(in_array[n,
|
||||
max(0, c - int(math.floor((nsize - 1) / 2))): \
|
||||
min(5, c + int(math.ceil((nsize - 1) / 2)) + 1),
|
||||
h,
|
||||
w] ** 2)
|
||||
py_out = in_array / ((bias + (alpha / nsize) * square_sum) ** beta)
|
||||
return py_out
|
||||
|
||||
for target, ctx in ctx_list():
|
||||
new_sym, params = nnvm.frontend.from_onnx(model)
|
||||
|
||||
input_name = model.graph.input[0].name
|
||||
shape_dict = {input_name: in_array.shape}
|
||||
dtype_dict = {input_name: dtype}
|
||||
graph, lib, params = nnvm.compiler.build(new_sym, target,
|
||||
shape_dict, dtype_dict, params=params)
|
||||
m = graph_runtime.create(graph, lib, ctx)
|
||||
# set inputs
|
||||
m.set_input(input_name, tvm.nd.array(in_array.astype(dtype)))
|
||||
m.set_input(**params)
|
||||
m.run()
|
||||
# get outputs
|
||||
tvm_out = m.get_output(0, tvm.nd.empty(shape, dtype))
|
||||
py_out = _get_python_lrn()
|
||||
np.testing.assert_allclose(py_out, tvm_out.asnumpy(), rtol=1e-5, atol=1e-5)
|
||||
|
||||
def test_lrn():
|
||||
verify_lrn((5, 5, 5, 5), 3, 'float32')
|
||||
verify_lrn((5, 5, 5, 5), 3, 'float32', alpha=0.0002, beta=0.5, bias=2.0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# verify_super_resolution_example()
|
||||
# verify_squeezenet1_1()
|
||||
|
@ -328,3 +381,4 @@ if __name__ == '__main__':
|
|||
test_clip()
|
||||
test_matmul()
|
||||
test_gather()
|
||||
test_lrn()
|
||||
|
|
Загрузка…
Ссылка в новой задаче