From 0241fdc5d23a410cff63a0ffeab7c548fadfaa98 Mon Sep 17 00:00:00 2001 From: Albin Joy Date: Tue, 7 Aug 2018 08:06:37 +0530 Subject: [PATCH] [FRONTEND][ONNX]LRN support for ONNX (#1518) * LRN support for ONNX * [ONNX] Updated lrn testcases --- nnvm/python/nnvm/frontend/onnx.py | 19 ++++++- .../python/frontend/onnx/test_forward.py | 54 +++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/nnvm/python/nnvm/frontend/onnx.py b/nnvm/python/nnvm/frontend/onnx.py index cfef11d6..f4062c10 100644 --- a/nnvm/python/nnvm/frontend/onnx.py +++ b/nnvm/python/nnvm/frontend/onnx.py @@ -499,6 +499,23 @@ class Gather(OnnxOpConverter): params[name] = indices return _sym.take(inputs[0], gather_indices, axis=axis) +class LRN(OnnxOpConverter): + """ Operator converter for Local Response Normalization. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + """LRN support only NCHW format + https://github.com/onnx/onnx/blob/master/docs/Operators.md#LRN + """ + axis = 1 + alpha = attr.get('alpha', 0.0001) + beta = attr.get('beta', 0.75) + bias = attr.get('bias', 1.0) + nsize = attr.get('size') + return _sym.lrn(inputs[0], size=nsize, axis=axis, + alpha=alpha, beta=beta, bias=bias) + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -586,7 +603,7 @@ def _get_convert_map(opset): # 'LpNormalization' 'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']), 'Flatten': Renamer('flatten'), - # 'LRN' + 'LRN': LRN.get_converter(opset), # defs/reduction 'ReduceMax': AttrCvt('max', {'axes', 'axis'}), diff --git a/nnvm/tests/python/frontend/onnx/test_forward.py b/nnvm/tests/python/frontend/onnx/test_forward.py index bddf4a87..f4dc3559 100644 --- a/nnvm/tests/python/frontend/onnx/test_forward.py +++ b/nnvm/tests/python/frontend/onnx/test_forward.py @@ -1,4 +1,5 @@ import numpy as np +import math import nnvm import tvm from tvm.contrib import graph_runtime @@ -312,6 +313,58 @@ def test_matmul(): np.testing.assert_allclose(np.matmul(a_array, b_array), tvm_out.asnumpy(), rtol=1e-5, atol=1e-5) +def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None): + in_array = np.random.uniform(size=shape).astype(dtype) + + if alpha == None and beta == None and bias==None: + alpha = 0.0001 + beta = 0.75 + bias = 1.0 + node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], size=nsize) + else: + node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], alpha=alpha, + beta=beta, bias=bias, size=nsize) + + graph = helper.make_graph([node], + "lrn_test", + inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(shape))], + outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))]) + model = helper.make_model(graph, producer_name='lrn_test') + + def _get_python_lrn(): + square_sum = np.zeros(shape).astype(dtype) + for n, c, h, w in np.ndindex(in_array.shape): + square_sum[n, c, h, w] = sum(in_array[n, + max(0, c - int(math.floor((nsize - 1) / 2))): \ + min(5, c + int(math.ceil((nsize - 1) / 2)) + 1), + h, + w] ** 2) + py_out = in_array / ((bias + (alpha / nsize) * square_sum) ** beta) + return py_out + + for target, ctx in ctx_list(): + new_sym, params = nnvm.frontend.from_onnx(model) + + input_name = model.graph.input[0].name + shape_dict = {input_name: in_array.shape} + dtype_dict = {input_name: dtype} + graph, lib, params = nnvm.compiler.build(new_sym, target, + shape_dict, dtype_dict, params=params) + m = graph_runtime.create(graph, lib, ctx) + # set inputs + m.set_input(input_name, tvm.nd.array(in_array.astype(dtype))) + m.set_input(**params) + m.run() + # get outputs + tvm_out = m.get_output(0, tvm.nd.empty(shape, dtype)) + py_out = _get_python_lrn() + np.testing.assert_allclose(py_out, tvm_out.asnumpy(), rtol=1e-5, atol=1e-5) + +def test_lrn(): + verify_lrn((5, 5, 5, 5), 3, 'float32') + verify_lrn((5, 5, 5, 5), 3, 'float32', alpha=0.0002, beta=0.5, bias=2.0) + + if __name__ == '__main__': # verify_super_resolution_example() # verify_squeezenet1_1() @@ -328,3 +381,4 @@ if __name__ == '__main__': test_clip() test_matmul() test_gather() + test_lrn()