TFLite: Add fused_activation_function for ADD, SUB, MUL, DIV (#3372)

This commit is contained in:
Alexander Pivovarov 2019-06-17 12:36:31 -07:00 коммит произвёл Yao Wang
Родитель df6957a5ea
Коммит 5050ab5e16
2 изменённых файлов: 63 добавлений и 12 удалений

Просмотреть файл

@ -298,6 +298,12 @@ class OperatorConverter(object):
"""Generic method to Convert TFLite elemwise""" """Generic method to Convert TFLite elemwise"""
try: try:
from tflite.Operator import Operator from tflite.Operator import Operator
from tflite.AddOptions import AddOptions
from tflite.SubOptions import SubOptions
from tflite.MulOptions import MulOptions
from tflite.DivOptions import DivOptions
from tflite.BuiltinOptions import BuiltinOptions
from tflite.ActivationFunctionType import ActivationFunctionType
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
@ -320,6 +326,26 @@ class OperatorConverter(object):
rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor), rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
dtype=rhs_type_str) dtype=rhs_type_str)
out = relay_op(lhs_expr, rhs_expr) out = relay_op(lhs_expr, rhs_expr)
# Options (fused_activation_function)
options = None
if op.BuiltinOptionsType() == BuiltinOptions.AddOptions:
options = AddOptions()
elif op.BuiltinOptionsType() == BuiltinOptions.SubOptions:
options = SubOptions()
elif op.BuiltinOptionsType() == BuiltinOptions.MulOptions:
options = MulOptions()
elif op.BuiltinOptionsType() == BuiltinOptions.DivOptions:
options = DivOptions()
if options is not None:
op_options = op.BuiltinOptions()
options.Init(op_options.Bytes, op_options.Pos)
fused_activation_fn = options.FusedActivationFunction()
# if we have activation fn
if fused_activation_fn != ActivationFunctionType.NONE:
out = self.convert_fused_activation_function(out, fused_activation_fn)
return out return out
def convert_add(self, op): def convert_add(self, op):

Просмотреть файл

@ -21,6 +21,7 @@ TFLite testcases
This article is a test script to test TFLite operator with Relay. This article is a test script to test TFLite operator with Relay.
""" """
from __future__ import print_function from __future__ import print_function
from functools import partial
import numpy as np import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
@ -146,6 +147,20 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
def with_fused_activation_function(input_tensor, fn_name):
if fn_name is None or fn_name == "NONE":
return input_tensor
if fn_name == "RELU":
return nn_ops.relu(input_tensor)
if fn_name == "RELU6":
return nn_ops.relu6(input_tensor)
if fn_name == "RELU_N1_TO_1":
return math_ops.maximum(-1, math_ops.minimum(input_tensor, 1))
if fn_name == "TANH":
return math_ops.tanh(input_tensor)
raise AssertionError("Unknown fused_activation_function {}".format(fn_name))
####################################################################### #######################################################################
# Pooling # Pooling
# ------- # -------
@ -313,7 +328,7 @@ def test_forward_concatenation():
# Element-wise # Element-wise
# --- # ---
def _test_elemwise(math_op, data): def _test_elemwise(math_op, data, fused_activation_function=None):
""" One iteration of add """ """ One iteration of add """
assert len(data) == 2 assert len(data) == 2
@ -323,12 +338,14 @@ def _test_elemwise(math_op, data):
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'), in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'),
array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')] array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')]
out = math_op(in_data[0], in_data[1]) out = math_op(in_data[0], in_data[1])
out = with_fused_activation_function(out, fused_activation_function)
compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out]) compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])
# Test with tensor and constant # Test with tensor and constant
with tf.Graph().as_default(): with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')] in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype)) out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
out = with_fused_activation_function(out, fused_activation_function)
compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out]) compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
@ -336,31 +353,31 @@ def _test_elemwise(math_op, data):
# Add # Add
# --- # ---
def _test_add(data): def _test_add(data, fused_activation_function=None):
""" One iteration of add """ """ One iteration of add """
return _test_elemwise(math_ops.add, data) return _test_elemwise(math_ops.add, data, fused_activation_function)
####################################################################### #######################################################################
# Subtract # Subtract
# -------- # --------
def _test_sub(data): def _test_sub(data, fused_activation_function=None):
""" One iteration of subtract """ """ One iteration of subtract """
return _test_elemwise(math_ops.subtract, data) return _test_elemwise(math_ops.subtract, data, fused_activation_function)
####################################################################### #######################################################################
# Mul # Mul
# --- # ---
def _test_mul(data): def _test_mul(data, fused_activation_function=None):
""" One iteration of mul """ """ One iteration of mul """
return _test_elemwise(math_ops.multiply, data) return _test_elemwise(math_ops.multiply, data, fused_activation_function)
####################################################################### #######################################################################
# Divide # Divide
# ------ # ------
def _test_div(data): def _test_div(data, fused_activation_function=None):
""" One iteration of divide """ """ One iteration of divide """
return _test_elemwise(math_ops.divide, data) return _test_elemwise(math_ops.divide, data, fused_activation_function)
####################################################################### #######################################################################
# Power # Power
# ----- # -----
@ -386,17 +403,25 @@ def _test_minimum(data):
def _test_forward_elemwise(testop): def _test_forward_elemwise(testop):
""" Elewise""" """ Elewise"""
testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)), testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3))]) np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3))])
testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)), testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)),
np.arange(6.0, dtype=np.float32).reshape((2, 1, 3))]) np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3))])
testop([np.arange(3.0, dtype=np.float32).reshape((1, 3)), testop([np.arange(3.0, dtype=np.float32).reshape((1, 3)),
np.arange(3.0, dtype=np.float32).reshape((1, 3))]) np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3))])
def test_all_elemwise(): def test_all_elemwise():
_test_forward_elemwise(_test_add) _test_forward_elemwise(_test_add)
_test_forward_elemwise(partial(_test_add, fused_activation_function="RELU"))
_test_forward_elemwise(partial(_test_add, fused_activation_function="RELU6"))
_test_forward_elemwise(_test_sub) _test_forward_elemwise(_test_sub)
_test_forward_elemwise(partial(_test_sub, fused_activation_function="RELU"))
_test_forward_elemwise(partial(_test_sub, fused_activation_function="RELU6"))
_test_forward_elemwise(_test_mul) _test_forward_elemwise(_test_mul)
_test_forward_elemwise(partial(_test_mul, fused_activation_function="RELU"))
_test_forward_elemwise(partial(_test_mul, fused_activation_function="RELU6"))
_test_forward_elemwise(_test_div) _test_forward_elemwise(_test_div)
_test_forward_elemwise(partial(_test_div, fused_activation_function="RELU"))
_test_forward_elemwise(partial(_test_div, fused_activation_function="RELU6"))
_test_forward_elemwise(_test_pow) _test_forward_elemwise(_test_pow)
_test_forward_elemwise(_test_maximum) _test_forward_elemwise(_test_maximum)
_test_forward_elemwise(_test_minimum) _test_forward_elemwise(_test_minimum)