зеркало из https://github.com/microsoft/MMdnn.git
Add tensorflow extractor
This commit is contained in:
Родитель
444ab6a55c
Коммит
c5b608fe08
|
@ -31,30 +31,6 @@ def extract_model(args):
|
|||
|
||||
elif args.framework == 'caffe2':
|
||||
raise NotImplementedError("Caffe2 is not supported yet.")
|
||||
'''
|
||||
assert args.inputShape != None
|
||||
from dlconv.caffe2.conversion.transformer import Caffe2Transformer
|
||||
transformer = Caffe2Transformer(args.network, args.weights, args.inputShape, 'tensorflow')
|
||||
|
||||
graph = transformer.transform_graph()
|
||||
data = transformer.transform_data()
|
||||
|
||||
from dlconv.common.writer import JsonFormatter, ModelSaver, PyWriter
|
||||
JsonFormatter(graph).dump(args.dstPath + ".json")
|
||||
print ("IR saved as [{}.json].".format(args.dstPath))
|
||||
|
||||
prototxt = graph.as_graph_def().SerializeToString()
|
||||
with open(args.dstPath + ".pb", 'wb') as of:
|
||||
of.write(prototxt)
|
||||
print ("IR saved as [{}.pb].".format(args.dstPath))
|
||||
|
||||
import numpy as np
|
||||
with open(args.dstPath + ".npy", 'wb') as of:
|
||||
np.save(of, data)
|
||||
print ("IR weights saved as [{}.npy].".format(args.dstPath))
|
||||
|
||||
return 0
|
||||
'''
|
||||
|
||||
elif args.framework == 'keras':
|
||||
from mmdnn.conversion.examples.keras.extractor import keras_extractor
|
||||
|
|
|
@ -0,0 +1,123 @@
|
|||
#----------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
#----------------------------------------------------------------------------------------------
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.slim.python.slim.nets import vgg
|
||||
from tensorflow.contrib.slim.python.slim.nets import inception
|
||||
from tensorflow.contrib.slim.python.slim.nets import resnet_v1
|
||||
from tensorflow.contrib.slim.python.slim.nets import resnet_v2
|
||||
slim = tf.contrib.slim
|
||||
|
||||
from mmdnn.conversion.examples.imagenet_test import TestKit
|
||||
from mmdnn.conversion.examples.extractor import base_extractor
|
||||
from mmdnn.conversion.common.utils import download_file
|
||||
|
||||
|
||||
class tensorflow_extractor(base_extractor):
|
||||
|
||||
architecture_map = {
|
||||
'inception_v1' : {
|
||||
'url' : 'http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz',
|
||||
'filename' : 'inception_v1.ckpt',
|
||||
'builder' : lambda : inception.inception_v1,
|
||||
'arg_scope' : inception.inception_v3_arg_scope,
|
||||
'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 224, 224, 3]),
|
||||
'num_classes' : 1001,
|
||||
},
|
||||
'inception_v3' : {
|
||||
'url' : 'http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz',
|
||||
'filename' : 'inception_v3.ckpt',
|
||||
'builder' : lambda : inception.inception_v3,
|
||||
'arg_scope' : inception.inception_v3_arg_scope,
|
||||
'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
|
||||
'num_classes' : 1001,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def handle_checkpoint(cls, architecture, path):
|
||||
with slim.arg_scope(cls.architecture_map[architecture]['arg_scope']()):
|
||||
data_input = cls.architecture_map[architecture]['input']()
|
||||
logits, endpoints = cls.architecture_map[architecture]['builder']()(
|
||||
data_input,
|
||||
num_classes=cls.architecture_map[architecture]['num_classes'],
|
||||
is_training=False)
|
||||
labels = tf.squeeze(logits, name='MMdnn_Output')
|
||||
|
||||
init = tf.global_variables_initializer()
|
||||
with tf.Session() as sess:
|
||||
writer = tf.summary.FileWriter('./graphs', sess.graph)
|
||||
writer.close()
|
||||
sess.run(init)
|
||||
saver = tf.train.Saver()
|
||||
saver.restore(sess, path + cls.architecture_map[architecture]['filename'])
|
||||
save_path = saver.save(sess, path + "imagenet_{}.ckpt".format(architecture))
|
||||
print("Model saved in file: %s" % save_path)
|
||||
|
||||
import tensorflow.contrib.keras as keras
|
||||
keras.backend.clear_session()
|
||||
|
||||
|
||||
@classmethod
|
||||
def handle_frozen_graph(cls, architecture, path):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@classmethod
|
||||
def download(cls, architecture, path="./"):
|
||||
if cls.sanity_check(architecture):
|
||||
architecture_file = download_file(cls.architecture_map[architecture]['url'], directory=path, auto_unzip=True)
|
||||
if not architecture_file:
|
||||
return None
|
||||
|
||||
if cls.architecture_map[architecture]['filename'].endswith('ckpt'):
|
||||
cls.handle_checkpoint(architecture, path)
|
||||
|
||||
elif cls.architecture_map[architecture]['filename'].endswith('pb'):
|
||||
cls.handle_frozen_graph(architecture, path)
|
||||
|
||||
else:
|
||||
raise ValueError("Unknown file name [{}].".format(cls.architecture_map[architecture]['filename']))
|
||||
|
||||
return architecture_file
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@classmethod
|
||||
def inference(cls, architecture, path, image_path):
|
||||
if cls.download(architecture, path):
|
||||
import numpy as np
|
||||
func = TestKit.preprocess_func['tensorflow'][architecture]
|
||||
img = func(image_path)
|
||||
img = np.expand_dims(img, axis=0)
|
||||
|
||||
with slim.arg_scope(cls.architecture_map[architecture]['arg_scope']()):
|
||||
data_input = cls.architecture_map[architecture]['input']()
|
||||
logits, endpoints = cls.architecture_map[architecture]['builder']()(
|
||||
data_input,
|
||||
num_classes=cls.architecture_map[architecture]['num_classes'],
|
||||
is_training=False)
|
||||
labels = tf.squeeze(logits)
|
||||
|
||||
init = tf.global_variables_initializer()
|
||||
with tf.Session() as sess:
|
||||
sess.run(init)
|
||||
saver = tf.train.Saver()
|
||||
saver.restore(sess, path + cls.architecture_map[architecture]['filename'])
|
||||
predict = sess.run(logits, feed_dict = {data_input : img})
|
||||
|
||||
import tensorflow.contrib.keras as keras
|
||||
keras.backend.clear_session()
|
||||
|
||||
predict = np.squeeze(predict)
|
||||
return predict
|
||||
|
||||
else:
|
||||
return None
|
|
@ -74,6 +74,27 @@ class TestModels(CorrectnessTest):
|
|||
cachedir = "tests/cache/"
|
||||
tmpdir = "tests/tmp/"
|
||||
|
||||
@staticmethod
|
||||
def TensorFlowParse(architecture_name, image_path):
|
||||
from mmdnn.conversion.examples.tensorflow.extractor import tensorflow_extractor
|
||||
from mmdnn.conversion.tensorflow.tensorflow_parser import TensorflowParser
|
||||
|
||||
# get original model prediction result
|
||||
original_predict = tensorflow_extractor.inference(architecture_name, TestModels.cachedir, image_path)
|
||||
|
||||
# original to IR
|
||||
parser = TensorflowParser(
|
||||
TestModels.cachedir + "imagenet_" + architecture_name + ".ckpt.meta",
|
||||
TestModels.cachedir + "imagenet_" + architecture_name + ".ckpt",
|
||||
None,
|
||||
"MMdnn_Output")
|
||||
parser.run(TestModels.tmpdir + architecture_name + "_converted")
|
||||
del parser
|
||||
del TensorflowParser
|
||||
del tensorflow_extractor
|
||||
return original_predict
|
||||
|
||||
|
||||
@staticmethod
|
||||
def KerasParse(architecture_name, image_path):
|
||||
# get original model prediction result
|
||||
|
|
Загрузка…
Ссылка в новой задаче