This commit is contained in:
Kit 2018-02-06 14:35:13 +08:00
Родитель e7dbbf68af
Коммит e0e9ea05a2
4 изменённых файлов: 169 добавлений и 14 удалений

Просмотреть файл

@ -82,7 +82,7 @@ class TestKit(object):
'vgg16' : lambda path : TestKit.ZeroCenter(path, 224, True),
'vgg19' : lambda path : TestKit.ZeroCenter(path, 224, True),
'inception_v3' : lambda path : TestKit.Standard(path, 299),
'resnet' : lambda path : TestKit.ZeroCenter(path, 224, True),
'resnet50' : lambda path : TestKit.ZeroCenter(path, 224, True),
'xception' : lambda path : TestKit.Standard(path, 299),
'mobilenet' : lambda path : TestKit.Standard(path, 224),
'inception_resnet_v2' : lambda path : TestKit.Standard(path, 299),

Просмотреть файл

@ -6,6 +6,7 @@
from __future__ import absolute_import
import os
import keras
from keras import backend as K
from mmdnn.conversion.examples.imagenet_test import TestKit
from mmdnn.conversion.examples.extractor import base_extractor
@ -37,27 +38,20 @@ class keras_extractor(base_extractor):
}
@classmethod
def download(cls, architecture, path='test/model/'):
def download(cls, architecture, path="./"):
if cls.sanity_check(architecture):
if os.path.exists(path + 'imagenet_{}.h5'.format(architecture)) == False:
print("No model before")
output_filename = path + 'imagenet_{}.h5'.format(architecture)
if os.path.exists(output_filename) == False:
model = cls.architecture_map[architecture]()
output_filename = path + 'imagenet_{}.h5'.format(architecture)
model.save(output_filename)
print("Keras model {} is saved in [{}]".format(architecture, output_filename))
K.clear_session()
del model
return output_filename
# # save network structure as JSON
# json_string = model.to_json()
# with open("imagenet_{}.json".format(architecture), "w") as of:
# of.write(json_string)
# print("Network structure is saved as [imagenet_{}.json].".format(architecture))
# model.save_weights('imagenet_{}.h5'.format(architecture))
# print("Network weights are saved as [imagenet_{}.h5].".format(architecture))
else:
output_filename = path + 'imagenet_{}.h5'.format(architecture)
print("File [{}] existed, skip download.".format(output_filename))
return output_filename
else:
return None
@ -73,6 +67,8 @@ class keras_extractor(base_extractor):
img = np.expand_dims(img, axis=0)
predict = model.predict(img)
predict = np.squeeze(predict)
K.clear_session()
del model
return predict
else:

Просмотреть файл

@ -6,6 +6,7 @@
import os
from six import string_types as _string_types
import keras as _keras
from keras import backend as _K
from mmdnn.conversion.keras.keras2_graph import Keras2Graph
import mmdnn.conversion.common.IR.graph_pb2 as graph_pb2
from mmdnn.conversion.common.IR.graph_pb2 import NodeDef, GraphDef, DataType
@ -122,6 +123,8 @@ class Keras2Parser(Parser):
print("KerasParser has not supported operator [%s]." % (node_type))
self.rename_UNKNOWN(current_node)
_K.clear_session()
@staticmethod
def _set_output_shape(source_node, IR_node):
@ -355,6 +358,8 @@ class Keras2Parser(Parser):
def rename_UNKNOWN(self, source_node):
print (source_node.layer.get_config())
# only for training
IR_node = self.IR_graph.node.add()
@ -599,6 +604,13 @@ class Keras2Parser(Parser):
def rename_Lambda(self, source_node):
# print (source_node.layer.function)
# import marshal
# raw_code = marshal.dumps(source_node.layer.function.__code__)
# print (raw_code)
# print (source_node.layer.get_config())
raise NotImplementedError("Lambda layer in keras is not supported yet.")
IR_node = self.IR_graph.node.add()
# name, op
@ -674,3 +686,14 @@ class Keras2Parser(Parser):
def custom_relu6(x):
return _keras.relu(x, max_value=6)
def rename_Cropping2D(self, source_node):
IR_node = self.IR_graph.node.add()
# name, op
Keras2Parser._copy_and_reop(source_node, IR_node)
# input edge
self.convert_inedge(source_node, IR_node)
assert False

Просмотреть файл

@ -0,0 +1,136 @@
import os
import unittest
import numpy as np
from imp import reload
from mmdnn.conversion.examples.imagenet_test import TestKit
from mmdnn.conversion.examples.keras.extractor import keras_extractor
from mmdnn.conversion.keras.keras2_parser import Keras2Parser
from mmdnn.conversion.cntk.cntk_emitter import CntkEmitter
def _compute_SNR(x,y):
noise = x - y
noise_var = np.sum(noise ** 2)/len(noise) + 1e-7
signal_energy = np.sum(y ** 2)/len(y)
max_signal_energy = np.amax(y ** 2)
SNR = 10 * np.log10(signal_energy/noise_var)
PSNR = 10 * np.log10(max_signal_energy/noise_var)
return SNR, PSNR
def _compute_max_relative_error(x,y):
rerror = 0
index = 0
for i in range(len(x)):
den = max(1.0, np.abs(x[i]), np.abs(y[i]))
if np.abs(x[i]/den - y[i]/den) > rerror:
rerror = np.abs(x[i]/den - y[i]/den)
index = i
return rerror, index
class CorrectnessTest(unittest.TestCase):
@classmethod
def setUpClass(self):
""" Set up the unit test by loading common utilities.
"""
self.err_thresh = 0.0015
self.snr_thresh = 12
self.psnr_thresh = 30
def _compare_outputs(self, original_predict, converted_predict):
self.assertEquals(len(original_predict), len(converted_predict))
error, ind = _compute_max_relative_error(converted_predict, original_predict)
SNR, PSNR = _compute_SNR(converted_predict, original_predict)
print("error:", error)
print("SNR:", SNR)
print("PSNR:", PSNR)
self.assertGreater(SNR, self.snr_thresh)
self.assertGreater(PSNR, self.psnr_thresh)
self.assertLess(error, self.err_thresh)
class TestModels(CorrectnessTest):
image_path = "mmdnn/conversion/examples/data/seagull.jpg"
cachedir = "tests/cache/"
tmpdir = "tests/tmp/"
@staticmethod
def KerasParse(architecture_name, image_path):
# get original model prediction result
original_predict = keras_extractor.inference(architecture_name, image_path)
# download model
model_filename = keras_extractor.download(architecture_name, TestModels.cachedir)
# original to IR
parser = Keras2Parser(model_filename)
parser.gen_IR()
parser.save_to_proto(TestModels.tmpdir + architecture_name + "_converted.pb")
parser.save_weights(TestModels.tmpdir + architecture_name + "_converted.npy")
del parser
return original_predict
@staticmethod
def CntkEmit(original_framework, architecture_name, architecture_path, weight_path, image_path):
print("Testing {} from {} to CNTK.".format(architecture_name, original_framework))
# IR to code
emitter = CntkEmitter((architecture_path, weight_path))
emitter.run("converted_model.py", None, 'test')
del emitter
# import converted model
import converted_model
reload (converted_model)
model_converted = converted_model.KitModel(TestModels.tmpdir + architecture_name + "_converted.npy")
func = TestKit.preprocess_func[original_framework][architecture_name]
img = func(image_path)
predict = model_converted.eval({model_converted.arguments[0]:[img]})
converted_predict = np.squeeze(predict)
del model_converted
del converted_model
os.remove("converted_model.py")
return converted_predict
test_table = {
'keras': {
'vgg16' : [CntkEmit],
'vgg19' : [CntkEmit],
'inception_v3' : [CntkEmit],
'resnet50' : [CntkEmit],
'densenet' : [CntkEmit],
'xception' : [],
# 'nasnet' : [],
}
}
def test_keras(self):
# keras original
original_framework = 'keras'
for network_name in self.test_table[original_framework].keys():
print("Testing {} model {} start.".format(original_framework, network_name))
# get original model prediction result
original_predict = self.KerasParse(network_name, self.image_path)
for emit in self.test_table[original_framework][network_name]:
converted_predict = emit.__func__(
original_framework,
network_name,
self.tmpdir + network_name + "_converted.pb",
self.tmpdir + network_name + "_converted.npy",
self.image_path)
self._compare_outputs(original_predict, converted_predict)
os.remove(self.tmpdir + network_name + "_converted.pb")
os.remove(self.tmpdir + network_name + "_converted.npy")
print("Testing {} model {} passed.".format(original_framework, network_name))
print("Testing {} model all passed.".format(original_framework))