This commit is contained in:
Ru ZHANG 2018-02-09 10:28:52 +08:00
Родитель 779a252690 4c794431b8
Коммит 2e10150d9e
1 изменённых файлов: 39 добавлений и 48 удалений

Просмотреть файл

@ -279,7 +279,7 @@ class TestModels(CorrectnessTest):
'imagenet1k-resnext-50' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit
},
'caffe' : {
'vgg19' : [KerasEmit],
# 'vgg19' : [KerasEmit],
# 'alexnet' : [KerasEmit],
# 'inception_v1' : [CntkEmit],
# 'resnet152' : [CntkEmit],
@ -317,68 +317,59 @@ class TestModels(CorrectnessTest):
print("Testing {} model all passed.".format(original_framework))
# def test_keras(self):
# # keras original
# ensure_dir(self.cachedir)
# ensure_dir(self.tmpdir)
# original_framework = 'keras'
def test_keras(self):
# keras original
ensure_dir(self.cachedir)
ensure_dir(self.tmpdir)
original_framework = 'keras'
# for network_name in self.test_table[original_framework].keys():
# print("Testing {} model {} start.".format(original_framework, network_name))
for network_name in self.test_table[original_framework].keys():
print("Testing {} model {} start.".format(original_framework, network_name))
# # get original model prediction result
# original_predict = self.KerasParse(network_name, self.image_path)
# get original model prediction result
original_predict = self.KerasParse(network_name, self.image_path)
# for emit in self.test_table[original_framework][network_name]:
# converted_predict = emit.__func__(
# original_framework,
# network_name,
# self.tmpdir + network_name + "_converted.pb",
# self.tmpdir + network_name + "_converted.npy",
# self.image_path)
for emit in self.test_table[original_framework][network_name]:
converted_predict = emit.__func__(
original_framework,
network_name,
self.tmpdir + network_name + "_converted.pb",
self.tmpdir + network_name + "_converted.npy",
self.image_path)
# self._compare_outputs(original_predict, converted_predict)
self._compare_outputs(original_predict, converted_predict)
# os.remove(self.tmpdir + network_name + "_converted.pb")
# os.remove(self.tmpdir + network_name + "_converted.npy")
# print("Testing {} model {} passed.".format(original_framework, network_name))
<<<<<<< HEAD
# print("Testing {} model all passed.".format(original_framework))
=======
os.remove(self.tmpdir + network_name + "_converted.json")
os.remove(self.tmpdir + network_name + "_converted.pb")
os.remove(self.tmpdir + network_name + "_converted.npy")
print("Testing {} model {} passed.".format(original_framework, network_name))
>>>>>>> 4c794431b8c07d1be1dffc1a9c705cb57831b31e
# def test_mxnet(self):
# # mxnet original
# ensure_dir(self.cachedir)
# ensure_dir(self.tmpdir)
# original_framework = 'mxnet'
def test_mxnet(self):
# mxnet original
ensure_dir(self.cachedir)
ensure_dir(self.tmpdir)
original_framework = 'mxnet'
# for network_name in self.test_table[original_framework].keys():
# print("Testing {} model {} start.".format(original_framework, network_name))
for network_name in self.test_table[original_framework].keys():
print("Testing {} model {} start.".format(original_framework, network_name))
# # get original model prediction result
# original_predict = self.MXNetParse(network_name, self.image_path)
# get original model prediction result
original_predict = self.MXNetParse(network_name, self.image_path)
# for emit in self.test_table[original_framework][network_name]:
# converted_predict = emit.__func__(
# original_framework,
# network_name,
# self.tmpdir + network_name + "_converted.pb",
# self.tmpdir + network_name + "_converted.npy",
# self.image_path)
for emit in self.test_table[original_framework][network_name]:
converted_predict = emit.__func__(
original_framework,
network_name,
self.tmpdir + network_name + "_converted.pb",
self.tmpdir + network_name + "_converted.npy",
self.image_path)
# self._compare_outputs(original_predict, converted_predict)
self._compare_outputs(original_predict, converted_predict)
# os.remove(self.tmpdir + network_name + "_converted.pb")
# os.remove(self.tmpdir + network_name + "_converted.npy")
# print("Testing {} model {} passed.".format(original_framework, network_name))
os.remove(self.tmpdir + network_name + "_converted.pb")
os.remove(self.tmpdir + network_name + "_converted.npy")
print("Testing {} model {} passed.".format(original_framework, network_name))
# print("Testing {} model all passed.".format(original_framework))
print("Testing {} model all passed.".format(original_framework))