зеркало из https://github.com/microsoft/MMdnn.git
TensorFlow -> MXNet mobilenet converts passed. correctness not tested.
This commit is contained in:
Родитель
18736e4e4d
Коммит
538f79252c
|
@ -25,7 +25,7 @@ class TestKit(object):
|
|||
'resnet' : [(22, 11.756789), (147, 8.5718527), (24, 6.1751032), (88, 4.3121386), (141, 4.1778097)],
|
||||
'resnet_v1_101' : [(21, 14.384739), (23, 14.262486), (144, 14.068737), (94, 12.17205), (134, 12.064575)],
|
||||
'inception_v3' : [(22, 9.4921198), (24, 4.0932288), (25, 3.700398), (23, 3.3715961), (147, 3.3620636)],
|
||||
'mobilenet' : [(22, 16.223597), (24, 14.54775), (147, 13.173758), (145, 11.36431), (728, 11.083847)]
|
||||
'mobilenet_v1_1.0' : [(22, 16.223597), (24, 14.54775), (147, 13.173758), (145, 11.36431), (728, 11.083847)]
|
||||
},
|
||||
'keras' : {
|
||||
'vgg16' : [(21, 0.81199354), (562, 0.019326132), (23, 0.018279659), (144, 0.012460723), (22, 0.012429929)],
|
||||
|
|
|
@ -383,12 +383,13 @@ def predict(model, labels, url):
|
|||
else:
|
||||
num_filter = IR_node.IR_layer.attr["kernel_shape"].list.i[-1]
|
||||
|
||||
no_bias = not IR_node.IR_layer.attr["use_bias"].b
|
||||
if not no_bias and self.weight_loaded:
|
||||
use_bias = IR_node.get_attr('use_bias', False)
|
||||
if use_bias and self.weight_loaded:
|
||||
self.output_weights[IR_node.name + "_bias"] = weight_dict['bias']
|
||||
|
||||
if pattern == "DepthwiseConv":
|
||||
num_group = num_filter
|
||||
num_group = IR_node.IR_layer.attr["kernel_shape"].list.i[-2]
|
||||
num_filter = num_filter * num_group
|
||||
pattern = "Convolution"
|
||||
else:
|
||||
num_group = IR_node.get_attr('group', 1)
|
||||
|
@ -404,6 +405,8 @@ def predict(model, labels, url):
|
|||
if self.weight_loaded:
|
||||
# if layout not in MXNetEmitter.channels_last:
|
||||
weights = MXNetEmitter.transpose(weights, dim)
|
||||
if num_group > 1:
|
||||
weights = np.swapaxes(weights, 0, 1)
|
||||
self.output_weights[IR_node.name + "_weight"] = weights
|
||||
|
||||
code = ""
|
||||
|
@ -418,7 +421,7 @@ def predict(model, labels, url):
|
|||
tuple(pad),
|
||||
num_filter,
|
||||
num_group,
|
||||
no_bias,
|
||||
not use_bias,
|
||||
layout,
|
||||
IR_node.name)
|
||||
else:
|
||||
|
@ -432,7 +435,7 @@ def predict(model, labels, url):
|
|||
dilate,
|
||||
num_filter,
|
||||
num_group,
|
||||
no_bias,
|
||||
not use_bias,
|
||||
layout,
|
||||
IR_node.name)
|
||||
|
||||
|
|
|
@ -323,7 +323,7 @@ class TestModels(CorrectnessTest):
|
|||
'resnet_v2_50' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit
|
||||
'resnet_v2_152' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit
|
||||
'mobilenet_v1_1.0' : [TensorflowEmit, KerasEmit],
|
||||
# 'inception_resnet_v2' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
|
||||
# 'inception_resnet_v2' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit], # TODO
|
||||
# 'nasnet-a_large' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO
|
||||
},
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче