TensorFlow -> MXNet mobilenet converts passed. correctness not tested.

This commit is contained in:
Kit 2018-02-09 18:26:59 +08:00
Родитель 18736e4e4d
Коммит 538f79252c
3 изменённых файлов: 10 добавлений и 7 удалений

Просмотреть файл

@ -25,7 +25,7 @@ class TestKit(object):
'resnet' : [(22, 11.756789), (147, 8.5718527), (24, 6.1751032), (88, 4.3121386), (141, 4.1778097)],
'resnet_v1_101' : [(21, 14.384739), (23, 14.262486), (144, 14.068737), (94, 12.17205), (134, 12.064575)],
'inception_v3' : [(22, 9.4921198), (24, 4.0932288), (25, 3.700398), (23, 3.3715961), (147, 3.3620636)],
'mobilenet' : [(22, 16.223597), (24, 14.54775), (147, 13.173758), (145, 11.36431), (728, 11.083847)]
'mobilenet_v1_1.0' : [(22, 16.223597), (24, 14.54775), (147, 13.173758), (145, 11.36431), (728, 11.083847)]
},
'keras' : {
'vgg16' : [(21, 0.81199354), (562, 0.019326132), (23, 0.018279659), (144, 0.012460723), (22, 0.012429929)],

Просмотреть файл

@ -383,12 +383,13 @@ def predict(model, labels, url):
else:
num_filter = IR_node.IR_layer.attr["kernel_shape"].list.i[-1]
no_bias = not IR_node.IR_layer.attr["use_bias"].b
if not no_bias and self.weight_loaded:
use_bias = IR_node.get_attr('use_bias', False)
if use_bias and self.weight_loaded:
self.output_weights[IR_node.name + "_bias"] = weight_dict['bias']
if pattern == "DepthwiseConv":
num_group = num_filter
num_group = IR_node.IR_layer.attr["kernel_shape"].list.i[-2]
num_filter = num_filter * num_group
pattern = "Convolution"
else:
num_group = IR_node.get_attr('group', 1)
@ -404,6 +405,8 @@ def predict(model, labels, url):
if self.weight_loaded:
# if layout not in MXNetEmitter.channels_last:
weights = MXNetEmitter.transpose(weights, dim)
if num_group > 1:
weights = np.swapaxes(weights, 0, 1)
self.output_weights[IR_node.name + "_weight"] = weights
code = ""
@ -418,7 +421,7 @@ def predict(model, labels, url):
tuple(pad),
num_filter,
num_group,
no_bias,
not use_bias,
layout,
IR_node.name)
else:
@ -432,7 +435,7 @@ def predict(model, labels, url):
dilate,
num_filter,
num_group,
no_bias,
not use_bias,
layout,
IR_node.name)

Просмотреть файл

@ -323,7 +323,7 @@ class TestModels(CorrectnessTest):
'resnet_v2_50' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit
'resnet_v2_152' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit
'mobilenet_v1_1.0' : [TensorflowEmit, KerasEmit],
# 'inception_resnet_v2' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
# 'inception_resnet_v2' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit], # TODO
# 'nasnet-a_large' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO
},
}