зеркало из https://github.com/microsoft/MMdnn.git
fix conflict in tensorflow
This commit is contained in:
Коммит
2e10150d9e
|
@ -279,7 +279,7 @@ class TestModels(CorrectnessTest):
|
|||
'imagenet1k-resnext-50' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit
|
||||
},
|
||||
'caffe' : {
|
||||
'vgg19' : [KerasEmit],
|
||||
# 'vgg19' : [KerasEmit],
|
||||
# 'alexnet' : [KerasEmit],
|
||||
# 'inception_v1' : [CntkEmit],
|
||||
# 'resnet152' : [CntkEmit],
|
||||
|
@ -317,68 +317,59 @@ class TestModels(CorrectnessTest):
|
|||
|
||||
print("Testing {} model all passed.".format(original_framework))
|
||||
|
||||
# def test_keras(self):
|
||||
# # keras original
|
||||
# ensure_dir(self.cachedir)
|
||||
# ensure_dir(self.tmpdir)
|
||||
# original_framework = 'keras'
|
||||
def test_keras(self):
|
||||
# keras original
|
||||
ensure_dir(self.cachedir)
|
||||
ensure_dir(self.tmpdir)
|
||||
original_framework = 'keras'
|
||||
|
||||
# for network_name in self.test_table[original_framework].keys():
|
||||
# print("Testing {} model {} start.".format(original_framework, network_name))
|
||||
for network_name in self.test_table[original_framework].keys():
|
||||
print("Testing {} model {} start.".format(original_framework, network_name))
|
||||
|
||||
# # get original model prediction result
|
||||
# original_predict = self.KerasParse(network_name, self.image_path)
|
||||
# get original model prediction result
|
||||
original_predict = self.KerasParse(network_name, self.image_path)
|
||||
|
||||
# for emit in self.test_table[original_framework][network_name]:
|
||||
# converted_predict = emit.__func__(
|
||||
# original_framework,
|
||||
# network_name,
|
||||
# self.tmpdir + network_name + "_converted.pb",
|
||||
# self.tmpdir + network_name + "_converted.npy",
|
||||
# self.image_path)
|
||||
for emit in self.test_table[original_framework][network_name]:
|
||||
converted_predict = emit.__func__(
|
||||
original_framework,
|
||||
network_name,
|
||||
self.tmpdir + network_name + "_converted.pb",
|
||||
self.tmpdir + network_name + "_converted.npy",
|
||||
self.image_path)
|
||||
|
||||
# self._compare_outputs(original_predict, converted_predict)
|
||||
self._compare_outputs(original_predict, converted_predict)
|
||||
|
||||
|
||||
# os.remove(self.tmpdir + network_name + "_converted.pb")
|
||||
# os.remove(self.tmpdir + network_name + "_converted.npy")
|
||||
# print("Testing {} model {} passed.".format(original_framework, network_name))
|
||||
|
||||
<<<<<<< HEAD
|
||||
# print("Testing {} model all passed.".format(original_framework))
|
||||
=======
|
||||
os.remove(self.tmpdir + network_name + "_converted.json")
|
||||
os.remove(self.tmpdir + network_name + "_converted.pb")
|
||||
os.remove(self.tmpdir + network_name + "_converted.npy")
|
||||
print("Testing {} model {} passed.".format(original_framework, network_name))
|
||||
>>>>>>> 4c794431b8c07d1be1dffc1a9c705cb57831b31e
|
||||
|
||||
|
||||
# def test_mxnet(self):
|
||||
# # mxnet original
|
||||
# ensure_dir(self.cachedir)
|
||||
# ensure_dir(self.tmpdir)
|
||||
# original_framework = 'mxnet'
|
||||
def test_mxnet(self):
|
||||
# mxnet original
|
||||
ensure_dir(self.cachedir)
|
||||
ensure_dir(self.tmpdir)
|
||||
original_framework = 'mxnet'
|
||||
|
||||
# for network_name in self.test_table[original_framework].keys():
|
||||
# print("Testing {} model {} start.".format(original_framework, network_name))
|
||||
for network_name in self.test_table[original_framework].keys():
|
||||
print("Testing {} model {} start.".format(original_framework, network_name))
|
||||
|
||||
# # get original model prediction result
|
||||
# original_predict = self.MXNetParse(network_name, self.image_path)
|
||||
# get original model prediction result
|
||||
original_predict = self.MXNetParse(network_name, self.image_path)
|
||||
|
||||
# for emit in self.test_table[original_framework][network_name]:
|
||||
# converted_predict = emit.__func__(
|
||||
# original_framework,
|
||||
# network_name,
|
||||
# self.tmpdir + network_name + "_converted.pb",
|
||||
# self.tmpdir + network_name + "_converted.npy",
|
||||
# self.image_path)
|
||||
for emit in self.test_table[original_framework][network_name]:
|
||||
converted_predict = emit.__func__(
|
||||
original_framework,
|
||||
network_name,
|
||||
self.tmpdir + network_name + "_converted.pb",
|
||||
self.tmpdir + network_name + "_converted.npy",
|
||||
self.image_path)
|
||||
|
||||
# self._compare_outputs(original_predict, converted_predict)
|
||||
self._compare_outputs(original_predict, converted_predict)
|
||||
|
||||
|
||||
# os.remove(self.tmpdir + network_name + "_converted.pb")
|
||||
# os.remove(self.tmpdir + network_name + "_converted.npy")
|
||||
# print("Testing {} model {} passed.".format(original_framework, network_name))
|
||||
os.remove(self.tmpdir + network_name + "_converted.pb")
|
||||
os.remove(self.tmpdir + network_name + "_converted.npy")
|
||||
print("Testing {} model {} passed.".format(original_framework, network_name))
|
||||
|
||||
# print("Testing {} model all passed.".format(original_framework))
|
||||
print("Testing {} model all passed.".format(original_framework))
|
||||
|
|
Загрузка…
Ссылка в новой задаче