зеркало из https://github.com/microsoft/MMdnn.git
Refactor pytest.
This commit is contained in:
Родитель
e7dbbf68af
Коммит
e0e9ea05a2
|
@ -82,7 +82,7 @@ class TestKit(object):
|
|||
'vgg16' : lambda path : TestKit.ZeroCenter(path, 224, True),
|
||||
'vgg19' : lambda path : TestKit.ZeroCenter(path, 224, True),
|
||||
'inception_v3' : lambda path : TestKit.Standard(path, 299),
|
||||
'resnet' : lambda path : TestKit.ZeroCenter(path, 224, True),
|
||||
'resnet50' : lambda path : TestKit.ZeroCenter(path, 224, True),
|
||||
'xception' : lambda path : TestKit.Standard(path, 299),
|
||||
'mobilenet' : lambda path : TestKit.Standard(path, 224),
|
||||
'inception_resnet_v2' : lambda path : TestKit.Standard(path, 299),
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
from __future__ import absolute_import
|
||||
import os
|
||||
import keras
|
||||
from keras import backend as K
|
||||
from mmdnn.conversion.examples.imagenet_test import TestKit
|
||||
from mmdnn.conversion.examples.extractor import base_extractor
|
||||
|
||||
|
@ -37,27 +38,20 @@ class keras_extractor(base_extractor):
|
|||
}
|
||||
|
||||
|
||||
|
||||
@classmethod
|
||||
def download(cls, architecture, path='test/model/'):
|
||||
def download(cls, architecture, path="./"):
|
||||
if cls.sanity_check(architecture):
|
||||
if os.path.exists(path + 'imagenet_{}.h5'.format(architecture)) == False:
|
||||
print("No model before")
|
||||
output_filename = path + 'imagenet_{}.h5'.format(architecture)
|
||||
if os.path.exists(output_filename) == False:
|
||||
model = cls.architecture_map[architecture]()
|
||||
output_filename = path + 'imagenet_{}.h5'.format(architecture)
|
||||
model.save(output_filename)
|
||||
print("Keras model {} is saved in [{}]".format(architecture, output_filename))
|
||||
K.clear_session()
|
||||
del model
|
||||
return output_filename
|
||||
# # save network structure as JSON
|
||||
# json_string = model.to_json()
|
||||
# with open("imagenet_{}.json".format(architecture), "w") as of:
|
||||
# of.write(json_string)
|
||||
# print("Network structure is saved as [imagenet_{}.json].".format(architecture))
|
||||
|
||||
# model.save_weights('imagenet_{}.h5'.format(architecture))
|
||||
# print("Network weights are saved as [imagenet_{}.h5].".format(architecture))
|
||||
else:
|
||||
output_filename = path + 'imagenet_{}.h5'.format(architecture)
|
||||
print("File [{}] existed, skip download.".format(output_filename))
|
||||
return output_filename
|
||||
else:
|
||||
return None
|
||||
|
@ -73,6 +67,8 @@ class keras_extractor(base_extractor):
|
|||
img = np.expand_dims(img, axis=0)
|
||||
predict = model.predict(img)
|
||||
predict = np.squeeze(predict)
|
||||
K.clear_session()
|
||||
del model
|
||||
return predict
|
||||
|
||||
else:
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
import os
|
||||
from six import string_types as _string_types
|
||||
import keras as _keras
|
||||
from keras import backend as _K
|
||||
from mmdnn.conversion.keras.keras2_graph import Keras2Graph
|
||||
import mmdnn.conversion.common.IR.graph_pb2 as graph_pb2
|
||||
from mmdnn.conversion.common.IR.graph_pb2 import NodeDef, GraphDef, DataType
|
||||
|
@ -122,6 +123,8 @@ class Keras2Parser(Parser):
|
|||
print("KerasParser has not supported operator [%s]." % (node_type))
|
||||
self.rename_UNKNOWN(current_node)
|
||||
|
||||
_K.clear_session()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _set_output_shape(source_node, IR_node):
|
||||
|
@ -355,6 +358,8 @@ class Keras2Parser(Parser):
|
|||
|
||||
|
||||
def rename_UNKNOWN(self, source_node):
|
||||
print (source_node.layer.get_config())
|
||||
|
||||
# only for training
|
||||
IR_node = self.IR_graph.node.add()
|
||||
|
||||
|
@ -599,6 +604,13 @@ class Keras2Parser(Parser):
|
|||
|
||||
|
||||
def rename_Lambda(self, source_node):
|
||||
# print (source_node.layer.function)
|
||||
# import marshal
|
||||
# raw_code = marshal.dumps(source_node.layer.function.__code__)
|
||||
# print (raw_code)
|
||||
# print (source_node.layer.get_config())
|
||||
raise NotImplementedError("Lambda layer in keras is not supported yet.")
|
||||
|
||||
IR_node = self.IR_graph.node.add()
|
||||
|
||||
# name, op
|
||||
|
@ -674,3 +686,14 @@ class Keras2Parser(Parser):
|
|||
|
||||
def custom_relu6(x):
|
||||
return _keras.relu(x, max_value=6)
|
||||
|
||||
|
||||
def rename_Cropping2D(self, source_node):
|
||||
IR_node = self.IR_graph.node.add()
|
||||
|
||||
# name, op
|
||||
Keras2Parser._copy_and_reop(source_node, IR_node)
|
||||
|
||||
# input edge
|
||||
self.convert_inedge(source_node, IR_node)
|
||||
assert False
|
||||
|
|
|
@ -0,0 +1,136 @@
|
|||
import os
|
||||
import unittest
|
||||
import numpy as np
|
||||
from imp import reload
|
||||
from mmdnn.conversion.examples.imagenet_test import TestKit
|
||||
from mmdnn.conversion.examples.keras.extractor import keras_extractor
|
||||
from mmdnn.conversion.keras.keras2_parser import Keras2Parser
|
||||
from mmdnn.conversion.cntk.cntk_emitter import CntkEmitter
|
||||
|
||||
def _compute_SNR(x,y):
|
||||
noise = x - y
|
||||
noise_var = np.sum(noise ** 2)/len(noise) + 1e-7
|
||||
signal_energy = np.sum(y ** 2)/len(y)
|
||||
max_signal_energy = np.amax(y ** 2)
|
||||
SNR = 10 * np.log10(signal_energy/noise_var)
|
||||
PSNR = 10 * np.log10(max_signal_energy/noise_var)
|
||||
return SNR, PSNR
|
||||
|
||||
|
||||
def _compute_max_relative_error(x,y):
|
||||
rerror = 0
|
||||
index = 0
|
||||
for i in range(len(x)):
|
||||
den = max(1.0, np.abs(x[i]), np.abs(y[i]))
|
||||
if np.abs(x[i]/den - y[i]/den) > rerror:
|
||||
rerror = np.abs(x[i]/den - y[i]/den)
|
||||
index = i
|
||||
return rerror, index
|
||||
|
||||
|
||||
class CorrectnessTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(self):
|
||||
""" Set up the unit test by loading common utilities.
|
||||
"""
|
||||
self.err_thresh = 0.0015
|
||||
self.snr_thresh = 12
|
||||
self.psnr_thresh = 30
|
||||
|
||||
def _compare_outputs(self, original_predict, converted_predict):
|
||||
self.assertEquals(len(original_predict), len(converted_predict))
|
||||
error, ind = _compute_max_relative_error(converted_predict, original_predict)
|
||||
SNR, PSNR = _compute_SNR(converted_predict, original_predict)
|
||||
print("error:", error)
|
||||
print("SNR:", SNR)
|
||||
print("PSNR:", PSNR)
|
||||
self.assertGreater(SNR, self.snr_thresh)
|
||||
self.assertGreater(PSNR, self.psnr_thresh)
|
||||
self.assertLess(error, self.err_thresh)
|
||||
|
||||
|
||||
class TestModels(CorrectnessTest):
|
||||
|
||||
image_path = "mmdnn/conversion/examples/data/seagull.jpg"
|
||||
cachedir = "tests/cache/"
|
||||
tmpdir = "tests/tmp/"
|
||||
|
||||
@staticmethod
|
||||
def KerasParse(architecture_name, image_path):
|
||||
# get original model prediction result
|
||||
original_predict = keras_extractor.inference(architecture_name, image_path)
|
||||
|
||||
# download model
|
||||
model_filename = keras_extractor.download(architecture_name, TestModels.cachedir)
|
||||
|
||||
# original to IR
|
||||
parser = Keras2Parser(model_filename)
|
||||
parser.gen_IR()
|
||||
parser.save_to_proto(TestModels.tmpdir + architecture_name + "_converted.pb")
|
||||
parser.save_weights(TestModels.tmpdir + architecture_name + "_converted.npy")
|
||||
del parser
|
||||
return original_predict
|
||||
|
||||
|
||||
@staticmethod
|
||||
def CntkEmit(original_framework, architecture_name, architecture_path, weight_path, image_path):
|
||||
print("Testing {} from {} to CNTK.".format(architecture_name, original_framework))
|
||||
|
||||
# IR to code
|
||||
emitter = CntkEmitter((architecture_path, weight_path))
|
||||
emitter.run("converted_model.py", None, 'test')
|
||||
del emitter
|
||||
|
||||
# import converted model
|
||||
import converted_model
|
||||
reload (converted_model)
|
||||
model_converted = converted_model.KitModel(TestModels.tmpdir + architecture_name + "_converted.npy")
|
||||
|
||||
func = TestKit.preprocess_func[original_framework][architecture_name]
|
||||
img = func(image_path)
|
||||
predict = model_converted.eval({model_converted.arguments[0]:[img]})
|
||||
converted_predict = np.squeeze(predict)
|
||||
del model_converted
|
||||
del converted_model
|
||||
os.remove("converted_model.py")
|
||||
return converted_predict
|
||||
|
||||
|
||||
test_table = {
|
||||
'keras': {
|
||||
'vgg16' : [CntkEmit],
|
||||
'vgg19' : [CntkEmit],
|
||||
'inception_v3' : [CntkEmit],
|
||||
'resnet50' : [CntkEmit],
|
||||
'densenet' : [CntkEmit],
|
||||
'xception' : [],
|
||||
# 'nasnet' : [],
|
||||
}
|
||||
}
|
||||
|
||||
def test_keras(self):
|
||||
# keras original
|
||||
original_framework = 'keras'
|
||||
|
||||
for network_name in self.test_table[original_framework].keys():
|
||||
print("Testing {} model {} start.".format(original_framework, network_name))
|
||||
|
||||
# get original model prediction result
|
||||
original_predict = self.KerasParse(network_name, self.image_path)
|
||||
|
||||
for emit in self.test_table[original_framework][network_name]:
|
||||
converted_predict = emit.__func__(
|
||||
original_framework,
|
||||
network_name,
|
||||
self.tmpdir + network_name + "_converted.pb",
|
||||
self.tmpdir + network_name + "_converted.npy",
|
||||
self.image_path)
|
||||
|
||||
self._compare_outputs(original_predict, converted_predict)
|
||||
|
||||
|
||||
os.remove(self.tmpdir + network_name + "_converted.pb")
|
||||
os.remove(self.tmpdir + network_name + "_converted.npy")
|
||||
print("Testing {} model {} passed.".format(original_framework, network_name))
|
||||
|
||||
print("Testing {} model all passed.".format(original_framework))
|
Загрузка…
Ссылка в новой задаче