[TFLite] Support depthwise convolution multiplier greater than 1 (#3922)
This commit is contained in:
Родитель
54dbcc2872
Коммит
968ffef62b
|
@ -623,8 +623,6 @@ class OperatorConverter(object):
|
|||
conv_options = DepthwiseConv2DOptions()
|
||||
conv_options.Init(op_options.Bytes, op_options.Pos)
|
||||
depth_multiplier = conv_options.DepthMultiplier()
|
||||
assert depth_multiplier == 1, "TF frontend transforms it to be 1 regardless of what " \
|
||||
"original value is set to 0.25, 0.5 or anything else"
|
||||
else:
|
||||
raise tvm.error.OpNotImplemented(
|
||||
'Operator {} is not supported for frontend TFLite.'.format(conv_type))
|
||||
|
@ -636,11 +634,13 @@ class OperatorConverter(object):
|
|||
padding = conv_options.Padding()
|
||||
fused_activation_fn = conv_options.FusedActivationFunction()
|
||||
|
||||
_, input_h, input_w, _ = input_tensor.tensor.ShapeAsNumpy()
|
||||
_, input_h, input_w, input_c = input_tensor.tensor.ShapeAsNumpy()
|
||||
|
||||
if is_depthwise_conv:
|
||||
multiplier, kernel_h, kernel_w, in_channels = weight_tensor.tensor.ShapeAsNumpy()
|
||||
assert multiplier == depth_multiplier
|
||||
# TFLite depthwise convolution kernel layout is:
|
||||
# 1 KH KW C(input_c * depth_multiplier)
|
||||
_, kernel_h, kernel_w, in_channels = weight_tensor.tensor.ShapeAsNumpy()
|
||||
assert in_channels == input_c * depth_multiplier
|
||||
else:
|
||||
output_channels, kernel_h, kernel_w, _ = weight_tensor.tensor.ShapeAsNumpy()
|
||||
|
||||
|
@ -654,7 +654,7 @@ class OperatorConverter(object):
|
|||
'data_layout': 'NHWC'}
|
||||
|
||||
if is_depthwise_conv:
|
||||
params['channels'] = int(in_channels * multiplier)
|
||||
params['channels'] = int(in_channels)
|
||||
params['groups'] = int(in_channels)
|
||||
params['kernel_layout'] = 'HWOI'
|
||||
else:
|
||||
|
@ -669,9 +669,16 @@ class OperatorConverter(object):
|
|||
in_expr = self.get_expr(input_tensor_idx)
|
||||
weight_value = self.get_tensor_value(weight_tensor)
|
||||
|
||||
# TFLite is OC/M KH KW IC, we require KH KW IC OC/M
|
||||
# M means multiplier in depthwise convolution
|
||||
weight_value = weight_value.transpose((1, 2, 3, 0))
|
||||
# TFLite kernel layout:
|
||||
# convolution:
|
||||
# OC KH KW IC, we require KH KW IC OC (HWIO)
|
||||
# depthwise convolution:
|
||||
# 1 KH KW C(input_c * depth_multiplier), we require
|
||||
# KH KW IC M (depth_multiplier) (HWOI)
|
||||
if is_depthwise_conv:
|
||||
weight_value = weight_value.reshape(kernel_h, kernel_w, input_c, depth_multiplier)
|
||||
else:
|
||||
weight_value = weight_value.transpose((1, 2, 3, 0))
|
||||
|
||||
weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
|
||||
|
||||
|
|
|
@ -356,6 +356,7 @@ def test_forward_convolution():
|
|||
_test_convolution([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True)
|
||||
_test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True)
|
||||
_test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True)
|
||||
_test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC', True)
|
||||
|
||||
|
||||
#######################################################################
|
||||
|
|
Загрузка…
Ссылка в новой задаче