This commit is contained in:
Ru ZHANG 2018-02-09 10:32:19 +08:00
Родитель 2e10150d9e
Коммит faae5ff444
2 изменённых файлов: 7 добавлений и 9 удалений

Просмотреть файл

@ -64,7 +64,6 @@ def _main():
# net = caffe.Net(arch_fn, weight_fn, caffe.TEST)
func = TestKit.preprocess_func['caffe'][args.network]
img = func(args.image)
print(img.size)
img = np.transpose(img, (2, 0, 1))
img = np.expand_dims(img, 0)
net.blobs['data'].data[...] = img

Просмотреть файл

@ -275,15 +275,15 @@ class TestModels(CorrectnessTest):
'imagenet1k-inception-bn' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
'imagenet1k-resnet-152' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
'squeezenet_v1.1' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
'imagenet1k-resnext-101-64x4d' : [TensorflowEmit, PytorchEmit], # TODO: CntkEmit
'imagenet1k-resnext-50' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit
'imagenet1k-resnext-101-64x4d' : [CntkEmit, TensorflowEmit, PytorchEmit], # Keras is too slow
'imagenet1k-resnext-50' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
},
'caffe' : {
# 'vgg19' : [KerasEmit],
'vgg19' : [KerasEmit],
# 'alexnet' : [KerasEmit],
# 'inception_v1' : [CntkEmit],
# 'resnet152' : [CntkEmit],
# 'squeezenet' : [CntkEmit]
'inception_v1' : [CntkEmit],
'resnet152' : [CntkEmit],
'squeezenet' : [CntkEmit]
}
}
@ -339,7 +339,6 @@ class TestModels(CorrectnessTest):
self._compare_outputs(original_predict, converted_predict)
os.remove(self.tmpdir + network_name + "_converted.json")
os.remove(self.tmpdir + network_name + "_converted.pb")
os.remove(self.tmpdir + network_name + "_converted.npy")
print("Testing {} model {} passed.".format(original_framework, network_name))
@ -367,7 +366,7 @@ class TestModels(CorrectnessTest):
self._compare_outputs(original_predict, converted_predict)
os.remove(self.tmpdir + network_name + "_converted.json")
os.remove(self.tmpdir + network_name + "_converted.pb")
os.remove(self.tmpdir + network_name + "_converted.npy")
print("Testing {} model {} passed.".format(original_framework, network_name))