TFLite: Add fused_activation_function for ADD, SUB, MUL, DIV (#3372)
This commit is contained in:
Родитель
df6957a5ea
Коммит
5050ab5e16
|
@ -298,6 +298,12 @@ class OperatorConverter(object):
|
||||||
"""Generic method to Convert TFLite elemwise"""
|
"""Generic method to Convert TFLite elemwise"""
|
||||||
try:
|
try:
|
||||||
from tflite.Operator import Operator
|
from tflite.Operator import Operator
|
||||||
|
from tflite.AddOptions import AddOptions
|
||||||
|
from tflite.SubOptions import SubOptions
|
||||||
|
from tflite.MulOptions import MulOptions
|
||||||
|
from tflite.DivOptions import DivOptions
|
||||||
|
from tflite.BuiltinOptions import BuiltinOptions
|
||||||
|
from tflite.ActivationFunctionType import ActivationFunctionType
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("The tflite package must be installed")
|
raise ImportError("The tflite package must be installed")
|
||||||
|
|
||||||
|
@ -320,6 +326,26 @@ class OperatorConverter(object):
|
||||||
rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
|
rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
|
||||||
dtype=rhs_type_str)
|
dtype=rhs_type_str)
|
||||||
out = relay_op(lhs_expr, rhs_expr)
|
out = relay_op(lhs_expr, rhs_expr)
|
||||||
|
|
||||||
|
# Options (fused_activation_function)
|
||||||
|
options = None
|
||||||
|
if op.BuiltinOptionsType() == BuiltinOptions.AddOptions:
|
||||||
|
options = AddOptions()
|
||||||
|
elif op.BuiltinOptionsType() == BuiltinOptions.SubOptions:
|
||||||
|
options = SubOptions()
|
||||||
|
elif op.BuiltinOptionsType() == BuiltinOptions.MulOptions:
|
||||||
|
options = MulOptions()
|
||||||
|
elif op.BuiltinOptionsType() == BuiltinOptions.DivOptions:
|
||||||
|
options = DivOptions()
|
||||||
|
|
||||||
|
if options is not None:
|
||||||
|
op_options = op.BuiltinOptions()
|
||||||
|
options.Init(op_options.Bytes, op_options.Pos)
|
||||||
|
fused_activation_fn = options.FusedActivationFunction()
|
||||||
|
# if we have activation fn
|
||||||
|
if fused_activation_fn != ActivationFunctionType.NONE:
|
||||||
|
out = self.convert_fused_activation_function(out, fused_activation_fn)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def convert_add(self, op):
|
def convert_add(self, op):
|
||||||
|
|
|
@ -21,6 +21,7 @@ TFLite testcases
|
||||||
This article is a test script to test TFLite operator with Relay.
|
This article is a test script to test TFLite operator with Relay.
|
||||||
"""
|
"""
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
from functools import partial
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tvm
|
import tvm
|
||||||
from tvm import relay
|
from tvm import relay
|
||||||
|
@ -146,6 +147,20 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
|
||||||
tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
|
tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
def with_fused_activation_function(input_tensor, fn_name):
|
||||||
|
if fn_name is None or fn_name == "NONE":
|
||||||
|
return input_tensor
|
||||||
|
if fn_name == "RELU":
|
||||||
|
return nn_ops.relu(input_tensor)
|
||||||
|
if fn_name == "RELU6":
|
||||||
|
return nn_ops.relu6(input_tensor)
|
||||||
|
if fn_name == "RELU_N1_TO_1":
|
||||||
|
return math_ops.maximum(-1, math_ops.minimum(input_tensor, 1))
|
||||||
|
if fn_name == "TANH":
|
||||||
|
return math_ops.tanh(input_tensor)
|
||||||
|
raise AssertionError("Unknown fused_activation_function {}".format(fn_name))
|
||||||
|
|
||||||
|
|
||||||
#######################################################################
|
#######################################################################
|
||||||
# Pooling
|
# Pooling
|
||||||
# -------
|
# -------
|
||||||
|
@ -313,7 +328,7 @@ def test_forward_concatenation():
|
||||||
# Element-wise
|
# Element-wise
|
||||||
# ---
|
# ---
|
||||||
|
|
||||||
def _test_elemwise(math_op, data):
|
def _test_elemwise(math_op, data, fused_activation_function=None):
|
||||||
""" One iteration of add """
|
""" One iteration of add """
|
||||||
|
|
||||||
assert len(data) == 2
|
assert len(data) == 2
|
||||||
|
@ -323,12 +338,14 @@ def _test_elemwise(math_op, data):
|
||||||
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'),
|
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'),
|
||||||
array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')]
|
array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')]
|
||||||
out = math_op(in_data[0], in_data[1])
|
out = math_op(in_data[0], in_data[1])
|
||||||
|
out = with_fused_activation_function(out, fused_activation_function)
|
||||||
compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])
|
compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])
|
||||||
|
|
||||||
# Test with tensor and constant
|
# Test with tensor and constant
|
||||||
with tf.Graph().as_default():
|
with tf.Graph().as_default():
|
||||||
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
|
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
|
||||||
out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
|
out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
|
||||||
|
out = with_fused_activation_function(out, fused_activation_function)
|
||||||
compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
|
compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
|
||||||
|
|
||||||
|
|
||||||
|
@ -336,31 +353,31 @@ def _test_elemwise(math_op, data):
|
||||||
# Add
|
# Add
|
||||||
# ---
|
# ---
|
||||||
|
|
||||||
def _test_add(data):
|
def _test_add(data, fused_activation_function=None):
|
||||||
""" One iteration of add """
|
""" One iteration of add """
|
||||||
return _test_elemwise(math_ops.add, data)
|
return _test_elemwise(math_ops.add, data, fused_activation_function)
|
||||||
|
|
||||||
#######################################################################
|
#######################################################################
|
||||||
# Subtract
|
# Subtract
|
||||||
# --------
|
# --------
|
||||||
|
|
||||||
def _test_sub(data):
|
def _test_sub(data, fused_activation_function=None):
|
||||||
""" One iteration of subtract """
|
""" One iteration of subtract """
|
||||||
return _test_elemwise(math_ops.subtract, data)
|
return _test_elemwise(math_ops.subtract, data, fused_activation_function)
|
||||||
#######################################################################
|
#######################################################################
|
||||||
# Mul
|
# Mul
|
||||||
# ---
|
# ---
|
||||||
def _test_mul(data):
|
def _test_mul(data, fused_activation_function=None):
|
||||||
""" One iteration of mul """
|
""" One iteration of mul """
|
||||||
return _test_elemwise(math_ops.multiply, data)
|
return _test_elemwise(math_ops.multiply, data, fused_activation_function)
|
||||||
|
|
||||||
#######################################################################
|
#######################################################################
|
||||||
# Divide
|
# Divide
|
||||||
# ------
|
# ------
|
||||||
|
|
||||||
def _test_div(data):
|
def _test_div(data, fused_activation_function=None):
|
||||||
""" One iteration of divide """
|
""" One iteration of divide """
|
||||||
return _test_elemwise(math_ops.divide, data)
|
return _test_elemwise(math_ops.divide, data, fused_activation_function)
|
||||||
#######################################################################
|
#######################################################################
|
||||||
# Power
|
# Power
|
||||||
# -----
|
# -----
|
||||||
|
@ -386,17 +403,25 @@ def _test_minimum(data):
|
||||||
def _test_forward_elemwise(testop):
|
def _test_forward_elemwise(testop):
|
||||||
""" Elewise"""
|
""" Elewise"""
|
||||||
testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
|
testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
|
||||||
np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3))])
|
np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3))])
|
||||||
testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)),
|
testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)),
|
||||||
np.arange(6.0, dtype=np.float32).reshape((2, 1, 3))])
|
np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3))])
|
||||||
testop([np.arange(3.0, dtype=np.float32).reshape((1, 3)),
|
testop([np.arange(3.0, dtype=np.float32).reshape((1, 3)),
|
||||||
np.arange(3.0, dtype=np.float32).reshape((1, 3))])
|
np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3))])
|
||||||
|
|
||||||
def test_all_elemwise():
|
def test_all_elemwise():
|
||||||
_test_forward_elemwise(_test_add)
|
_test_forward_elemwise(_test_add)
|
||||||
|
_test_forward_elemwise(partial(_test_add, fused_activation_function="RELU"))
|
||||||
|
_test_forward_elemwise(partial(_test_add, fused_activation_function="RELU6"))
|
||||||
_test_forward_elemwise(_test_sub)
|
_test_forward_elemwise(_test_sub)
|
||||||
|
_test_forward_elemwise(partial(_test_sub, fused_activation_function="RELU"))
|
||||||
|
_test_forward_elemwise(partial(_test_sub, fused_activation_function="RELU6"))
|
||||||
_test_forward_elemwise(_test_mul)
|
_test_forward_elemwise(_test_mul)
|
||||||
|
_test_forward_elemwise(partial(_test_mul, fused_activation_function="RELU"))
|
||||||
|
_test_forward_elemwise(partial(_test_mul, fused_activation_function="RELU6"))
|
||||||
_test_forward_elemwise(_test_div)
|
_test_forward_elemwise(_test_div)
|
||||||
|
_test_forward_elemwise(partial(_test_div, fused_activation_function="RELU"))
|
||||||
|
_test_forward_elemwise(partial(_test_div, fused_activation_function="RELU6"))
|
||||||
_test_forward_elemwise(_test_pow)
|
_test_forward_elemwise(_test_pow)
|
||||||
_test_forward_elemwise(_test_maximum)
|
_test_forward_elemwise(_test_maximum)
|
||||||
_test_forward_elemwise(_test_minimum)
|
_test_forward_elemwise(_test_minimum)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче