This commit is contained in:
Kit 2018-02-09 10:59:00 +08:00
Родитель 444ab6a55c
Коммит c5b608fe08
3 изменённых файлов: 144 добавлений и 24 удалений

Просмотреть файл

@ -31,30 +31,6 @@ def extract_model(args):
elif args.framework == 'caffe2':
raise NotImplementedError("Caffe2 is not supported yet.")
'''
assert args.inputShape != None
from dlconv.caffe2.conversion.transformer import Caffe2Transformer
transformer = Caffe2Transformer(args.network, args.weights, args.inputShape, 'tensorflow')
graph = transformer.transform_graph()
data = transformer.transform_data()
from dlconv.common.writer import JsonFormatter, ModelSaver, PyWriter
JsonFormatter(graph).dump(args.dstPath + ".json")
print ("IR saved as [{}.json].".format(args.dstPath))
prototxt = graph.as_graph_def().SerializeToString()
with open(args.dstPath + ".pb", 'wb') as of:
of.write(prototxt)
print ("IR saved as [{}.pb].".format(args.dstPath))
import numpy as np
with open(args.dstPath + ".npy", 'wb') as of:
np.save(of, data)
print ("IR weights saved as [{}.npy].".format(args.dstPath))
return 0
'''
elif args.framework == 'keras':
from mmdnn.conversion.examples.keras.extractor import keras_extractor

Просмотреть файл

@ -0,0 +1,123 @@
#----------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
#----------------------------------------------------------------------------------------------
from __future__ import absolute_import
import tensorflow as tf
from tensorflow.contrib.slim.python.slim.nets import vgg
from tensorflow.contrib.slim.python.slim.nets import inception
from tensorflow.contrib.slim.python.slim.nets import resnet_v1
from tensorflow.contrib.slim.python.slim.nets import resnet_v2
slim = tf.contrib.slim
from mmdnn.conversion.examples.imagenet_test import TestKit
from mmdnn.conversion.examples.extractor import base_extractor
from mmdnn.conversion.common.utils import download_file
class tensorflow_extractor(base_extractor):
architecture_map = {
'inception_v1' : {
'url' : 'http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz',
'filename' : 'inception_v1.ckpt',
'builder' : lambda : inception.inception_v1,
'arg_scope' : inception.inception_v3_arg_scope,
'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 224, 224, 3]),
'num_classes' : 1001,
},
'inception_v3' : {
'url' : 'http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz',
'filename' : 'inception_v3.ckpt',
'builder' : lambda : inception.inception_v3,
'arg_scope' : inception.inception_v3_arg_scope,
'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
'num_classes' : 1001,
}
}
@classmethod
def handle_checkpoint(cls, architecture, path):
with slim.arg_scope(cls.architecture_map[architecture]['arg_scope']()):
data_input = cls.architecture_map[architecture]['input']()
logits, endpoints = cls.architecture_map[architecture]['builder']()(
data_input,
num_classes=cls.architecture_map[architecture]['num_classes'],
is_training=False)
labels = tf.squeeze(logits, name='MMdnn_Output')
init = tf.global_variables_initializer()
with tf.Session() as sess:
writer = tf.summary.FileWriter('./graphs', sess.graph)
writer.close()
sess.run(init)
saver = tf.train.Saver()
saver.restore(sess, path + cls.architecture_map[architecture]['filename'])
save_path = saver.save(sess, path + "imagenet_{}.ckpt".format(architecture))
print("Model saved in file: %s" % save_path)
import tensorflow.contrib.keras as keras
keras.backend.clear_session()
@classmethod
def handle_frozen_graph(cls, architecture, path):
raise NotImplementedError()
@classmethod
def download(cls, architecture, path="./"):
if cls.sanity_check(architecture):
architecture_file = download_file(cls.architecture_map[architecture]['url'], directory=path, auto_unzip=True)
if not architecture_file:
return None
if cls.architecture_map[architecture]['filename'].endswith('ckpt'):
cls.handle_checkpoint(architecture, path)
elif cls.architecture_map[architecture]['filename'].endswith('pb'):
cls.handle_frozen_graph(architecture, path)
else:
raise ValueError("Unknown file name [{}].".format(cls.architecture_map[architecture]['filename']))
return architecture_file
else:
return None
@classmethod
def inference(cls, architecture, path, image_path):
if cls.download(architecture, path):
import numpy as np
func = TestKit.preprocess_func['tensorflow'][architecture]
img = func(image_path)
img = np.expand_dims(img, axis=0)
with slim.arg_scope(cls.architecture_map[architecture]['arg_scope']()):
data_input = cls.architecture_map[architecture]['input']()
logits, endpoints = cls.architecture_map[architecture]['builder']()(
data_input,
num_classes=cls.architecture_map[architecture]['num_classes'],
is_training=False)
labels = tf.squeeze(logits)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
saver = tf.train.Saver()
saver.restore(sess, path + cls.architecture_map[architecture]['filename'])
predict = sess.run(logits, feed_dict = {data_input : img})
import tensorflow.contrib.keras as keras
keras.backend.clear_session()
predict = np.squeeze(predict)
return predict
else:
return None

Просмотреть файл

@ -74,6 +74,27 @@ class TestModels(CorrectnessTest):
cachedir = "tests/cache/"
tmpdir = "tests/tmp/"
@staticmethod
def TensorFlowParse(architecture_name, image_path):
from mmdnn.conversion.examples.tensorflow.extractor import tensorflow_extractor
from mmdnn.conversion.tensorflow.tensorflow_parser import TensorflowParser
# get original model prediction result
original_predict = tensorflow_extractor.inference(architecture_name, TestModels.cachedir, image_path)
# original to IR
parser = TensorflowParser(
TestModels.cachedir + "imagenet_" + architecture_name + ".ckpt.meta",
TestModels.cachedir + "imagenet_" + architecture_name + ".ckpt",
None,
"MMdnn_Output")
parser.run(TestModels.tmpdir + architecture_name + "_converted")
del parser
del TensorflowParser
del tensorflow_extractor
return original_predict
@staticmethod
def KerasParse(architecture_name, image_path):
# get original model prediction result