TFLite: Add fused_activation_function for ADD, SUB, MUL, DIV (#3372)
This commit is contained in:
Родитель
df6957a5ea
Коммит
5050ab5e16
|
@ -298,6 +298,12 @@ class OperatorConverter(object):
|
|||
"""Generic method to Convert TFLite elemwise"""
|
||||
try:
|
||||
from tflite.Operator import Operator
|
||||
from tflite.AddOptions import AddOptions
|
||||
from tflite.SubOptions import SubOptions
|
||||
from tflite.MulOptions import MulOptions
|
||||
from tflite.DivOptions import DivOptions
|
||||
from tflite.BuiltinOptions import BuiltinOptions
|
||||
from tflite.ActivationFunctionType import ActivationFunctionType
|
||||
except ImportError:
|
||||
raise ImportError("The tflite package must be installed")
|
||||
|
||||
|
@ -320,6 +326,26 @@ class OperatorConverter(object):
|
|||
rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
|
||||
dtype=rhs_type_str)
|
||||
out = relay_op(lhs_expr, rhs_expr)
|
||||
|
||||
# Options (fused_activation_function)
|
||||
options = None
|
||||
if op.BuiltinOptionsType() == BuiltinOptions.AddOptions:
|
||||
options = AddOptions()
|
||||
elif op.BuiltinOptionsType() == BuiltinOptions.SubOptions:
|
||||
options = SubOptions()
|
||||
elif op.BuiltinOptionsType() == BuiltinOptions.MulOptions:
|
||||
options = MulOptions()
|
||||
elif op.BuiltinOptionsType() == BuiltinOptions.DivOptions:
|
||||
options = DivOptions()
|
||||
|
||||
if options is not None:
|
||||
op_options = op.BuiltinOptions()
|
||||
options.Init(op_options.Bytes, op_options.Pos)
|
||||
fused_activation_fn = options.FusedActivationFunction()
|
||||
# if we have activation fn
|
||||
if fused_activation_fn != ActivationFunctionType.NONE:
|
||||
out = self.convert_fused_activation_function(out, fused_activation_fn)
|
||||
|
||||
return out
|
||||
|
||||
def convert_add(self, op):
|
||||
|
|
|
@ -21,6 +21,7 @@ TFLite testcases
|
|||
This article is a test script to test TFLite operator with Relay.
|
||||
"""
|
||||
from __future__ import print_function
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
import tvm
|
||||
from tvm import relay
|
||||
|
@ -146,6 +147,20 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
|
|||
tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def with_fused_activation_function(input_tensor, fn_name):
|
||||
if fn_name is None or fn_name == "NONE":
|
||||
return input_tensor
|
||||
if fn_name == "RELU":
|
||||
return nn_ops.relu(input_tensor)
|
||||
if fn_name == "RELU6":
|
||||
return nn_ops.relu6(input_tensor)
|
||||
if fn_name == "RELU_N1_TO_1":
|
||||
return math_ops.maximum(-1, math_ops.minimum(input_tensor, 1))
|
||||
if fn_name == "TANH":
|
||||
return math_ops.tanh(input_tensor)
|
||||
raise AssertionError("Unknown fused_activation_function {}".format(fn_name))
|
||||
|
||||
|
||||
#######################################################################
|
||||
# Pooling
|
||||
# -------
|
||||
|
@ -313,7 +328,7 @@ def test_forward_concatenation():
|
|||
# Element-wise
|
||||
# ---
|
||||
|
||||
def _test_elemwise(math_op, data):
|
||||
def _test_elemwise(math_op, data, fused_activation_function=None):
|
||||
""" One iteration of add """
|
||||
|
||||
assert len(data) == 2
|
||||
|
@ -323,12 +338,14 @@ def _test_elemwise(math_op, data):
|
|||
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'),
|
||||
array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')]
|
||||
out = math_op(in_data[0], in_data[1])
|
||||
out = with_fused_activation_function(out, fused_activation_function)
|
||||
compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])
|
||||
|
||||
# Test with tensor and constant
|
||||
with tf.Graph().as_default():
|
||||
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
|
||||
out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
|
||||
out = with_fused_activation_function(out, fused_activation_function)
|
||||
compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
|
||||
|
||||
|
||||
|
@ -336,31 +353,31 @@ def _test_elemwise(math_op, data):
|
|||
# Add
|
||||
# ---
|
||||
|
||||
def _test_add(data):
|
||||
def _test_add(data, fused_activation_function=None):
|
||||
""" One iteration of add """
|
||||
return _test_elemwise(math_ops.add, data)
|
||||
return _test_elemwise(math_ops.add, data, fused_activation_function)
|
||||
|
||||
#######################################################################
|
||||
# Subtract
|
||||
# --------
|
||||
|
||||
def _test_sub(data):
|
||||
def _test_sub(data, fused_activation_function=None):
|
||||
""" One iteration of subtract """
|
||||
return _test_elemwise(math_ops.subtract, data)
|
||||
return _test_elemwise(math_ops.subtract, data, fused_activation_function)
|
||||
#######################################################################
|
||||
# Mul
|
||||
# ---
|
||||
def _test_mul(data):
|
||||
def _test_mul(data, fused_activation_function=None):
|
||||
""" One iteration of mul """
|
||||
return _test_elemwise(math_ops.multiply, data)
|
||||
return _test_elemwise(math_ops.multiply, data, fused_activation_function)
|
||||
|
||||
#######################################################################
|
||||
# Divide
|
||||
# ------
|
||||
|
||||
def _test_div(data):
|
||||
def _test_div(data, fused_activation_function=None):
|
||||
""" One iteration of divide """
|
||||
return _test_elemwise(math_ops.divide, data)
|
||||
return _test_elemwise(math_ops.divide, data, fused_activation_function)
|
||||
#######################################################################
|
||||
# Power
|
||||
# -----
|
||||
|
@ -386,17 +403,25 @@ def _test_minimum(data):
|
|||
def _test_forward_elemwise(testop):
|
||||
""" Elewise"""
|
||||
testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
|
||||
np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3))])
|
||||
np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3))])
|
||||
testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)),
|
||||
np.arange(6.0, dtype=np.float32).reshape((2, 1, 3))])
|
||||
np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3))])
|
||||
testop([np.arange(3.0, dtype=np.float32).reshape((1, 3)),
|
||||
np.arange(3.0, dtype=np.float32).reshape((1, 3))])
|
||||
np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3))])
|
||||
|
||||
def test_all_elemwise():
|
||||
_test_forward_elemwise(_test_add)
|
||||
_test_forward_elemwise(partial(_test_add, fused_activation_function="RELU"))
|
||||
_test_forward_elemwise(partial(_test_add, fused_activation_function="RELU6"))
|
||||
_test_forward_elemwise(_test_sub)
|
||||
_test_forward_elemwise(partial(_test_sub, fused_activation_function="RELU"))
|
||||
_test_forward_elemwise(partial(_test_sub, fused_activation_function="RELU6"))
|
||||
_test_forward_elemwise(_test_mul)
|
||||
_test_forward_elemwise(partial(_test_mul, fused_activation_function="RELU"))
|
||||
_test_forward_elemwise(partial(_test_mul, fused_activation_function="RELU6"))
|
||||
_test_forward_elemwise(_test_div)
|
||||
_test_forward_elemwise(partial(_test_div, fused_activation_function="RELU"))
|
||||
_test_forward_elemwise(partial(_test_div, fused_activation_function="RELU6"))
|
||||
_test_forward_elemwise(_test_pow)
|
||||
_test_forward_elemwise(_test_maximum)
|
||||
_test_forward_elemwise(_test_minimum)
|
||||
|
|
Загрузка…
Ссылка в новой задаче