[TFLite] Support depthwise convolution multiplier greater than 1 (#3922)

This commit is contained in:
Zhao Wu 2019-09-11 12:09:25 +08:00 коммит произвёл Thierry Moreau
Родитель 54dbcc2872
Коммит 968ffef62b
2 изменённых файлов: 17 добавлений и 9 удалений

Просмотреть файл

@ -623,8 +623,6 @@ class OperatorConverter(object):
conv_options = DepthwiseConv2DOptions()
conv_options.Init(op_options.Bytes, op_options.Pos)
depth_multiplier = conv_options.DepthMultiplier()
assert depth_multiplier == 1, "TF frontend transforms it to be 1 regardless of what " \
"original value is set to 0.25, 0.5 or anything else"
else:
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend TFLite.'.format(conv_type))
@ -636,11 +634,13 @@ class OperatorConverter(object):
padding = conv_options.Padding()
fused_activation_fn = conv_options.FusedActivationFunction()
_, input_h, input_w, _ = input_tensor.tensor.ShapeAsNumpy()
_, input_h, input_w, input_c = input_tensor.tensor.ShapeAsNumpy()
if is_depthwise_conv:
multiplier, kernel_h, kernel_w, in_channels = weight_tensor.tensor.ShapeAsNumpy()
assert multiplier == depth_multiplier
# TFLite depthwise convolution kernel layout is:
# 1 KH KW C(input_c * depth_multiplier)
_, kernel_h, kernel_w, in_channels = weight_tensor.tensor.ShapeAsNumpy()
assert in_channels == input_c * depth_multiplier
else:
output_channels, kernel_h, kernel_w, _ = weight_tensor.tensor.ShapeAsNumpy()
@ -654,7 +654,7 @@ class OperatorConverter(object):
'data_layout': 'NHWC'}
if is_depthwise_conv:
params['channels'] = int(in_channels * multiplier)
params['channels'] = int(in_channels)
params['groups'] = int(in_channels)
params['kernel_layout'] = 'HWOI'
else:
@ -669,9 +669,16 @@ class OperatorConverter(object):
in_expr = self.get_expr(input_tensor_idx)
weight_value = self.get_tensor_value(weight_tensor)
# TFLite is OC/M KH KW IC, we require KH KW IC OC/M
# M means multiplier in depthwise convolution
weight_value = weight_value.transpose((1, 2, 3, 0))
# TFLite kernel layout:
# convolution:
# OC KH KW IC, we require KH KW IC OC (HWIO)
# depthwise convolution:
# 1 KH KW C(input_c * depth_multiplier), we require
# KH KW IC M (depth_multiplier) (HWOI)
if is_depthwise_conv:
weight_value = weight_value.reshape(kernel_h, kernel_w, input_c, depth_multiplier)
else:
weight_value = weight_value.transpose((1, 2, 3, 0))
weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)

Просмотреть файл

@ -356,6 +356,7 @@ def test_forward_convolution():
_test_convolution([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True)
_test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True)
_test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True)
_test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC', True)
#######################################################################