Implementing TensorFlow pytest.

This commit is contained in:
Kit 2018-02-08 18:59:48 +08:00
Родитель 09798b678a
Коммит 4c794431b8
5 изменённых файлов: 31 добавлений и 51 удалений

Просмотреть файл

@ -28,30 +28,6 @@ def _convert(args):
elif args.srcFramework == 'caffe2':
raise NotImplementedError("Caffe2 is not supported yet.")
'''
assert args.inputShape != None
from dlconv.caffe2.conversion.transformer import Caffe2Transformer
transformer = Caffe2Transformer(args.network, args.weights, args.inputShape, 'tensorflow')
graph = transformer.transform_graph()
data = transformer.transform_data()
from dlconv.common.writer import JsonFormatter, ModelSaver, PyWriter
JsonFormatter(graph).dump(args.dstPath + ".json")
print ("IR saved as [{}.json].".format(args.dstPath))
prototxt = graph.as_graph_def().SerializeToString()
with open(args.dstPath + ".pb", 'wb') as of:
of.write(prototxt)
print ("IR saved as [{}.pb].".format(args.dstPath))
import numpy as np
with open(args.dstPath + ".npy", 'wb') as of:
np.save(of, data)
print ("IR weights saved as [{}.npy].".format(args.dstPath))
return 0
'''
elif args.srcFramework == 'keras':
if args.network != None:
@ -66,14 +42,10 @@ def _convert(args):
if args.dstNodeName is None:
raise ValueError("Need to provide the output node of Tensorflow model.")
if args.weights is None:
# only convert network structure
model = args.network
else:
model = (args.network, args.weights)
assert args.network or args.frozen_pb
from mmdnn.conversion.tensorflow.tensorflow_parser import TensorflowParser
parser = TensorflowParser(model, args.dstNodeName)
parser = TensorflowParser(args.network, args.weights, args.frozen_pb, args.dstNodeName)
elif args.srcFramework == 'mxnet':
assert args.inputShape != None
@ -97,10 +69,7 @@ def _convert(args):
else:
raise ValueError("Unknown framework [{}].".format(args.srcFramework))
parser.gen_IR()
parser.save_to_json(args.dstPath + ".json")
parser.save_to_proto(args.dstPath + ".pb")
parser.save_weights(args.dstPath + ".npy")
parser.run(args.dstPath)
return 0
@ -140,6 +109,13 @@ def _main():
default=None,
help="[Tensorflow] Output nodes' name of the graph.")
parser.add_argument(
'--frozen_pb',
type=_text_type,
default=None,
help="[Tensorflow] frozen model file.")
parser.add_argument(
'--inputShape',
nargs='+',

Просмотреть файл

@ -61,7 +61,8 @@ def extract_model(args):
extractor = keras_extractor()
elif args.framework == 'tensorflow' or args.framework == 'tf':
pass
from mmdnn.conversion.examples.tensorflow.extractor import tensorflow_extractor
extractor = tensorflow_extractor()
elif args.framework == 'mxnet':
from mmdnn.conversion.examples.mxnet.extractor import mxnet_extractor
@ -72,7 +73,7 @@ def extract_model(args):
else:
raise ValueError("Unknown framework [{}].".format(args.framework))
files = extractor.download(args.network,args.path)
files = extractor.download(args.network, args.path)
if files and args.image:
predict = extractor.inference(args.network, args.path, args.image)

Просмотреть файл

@ -18,6 +18,13 @@ class Parser(object):
self.weights = dict()
def run(self, dest_path):
self.gen_IR()
self.save_to_json(dest_path + ".json")
self.save_to_proto(dest_path + ".pb")
self.save_weights(dest_path + ".npy")
@property
def src_graph(self):
raise NotImplementedError

Просмотреть файл

@ -177,16 +177,15 @@ class TensorflowParser(Parser):
output_node.real_name = source_node.name
def __init__(self, input_args, dest_nodes = None):
def __init__(self, meta_file, checkpoint_file, frozen_file, dest_nodes = None):
super(TensorflowParser, self).__init__()
# load model files into Keras graph
from six import string_types as _string_types
if isinstance(input_args, _string_types):
model = TensorflowParser._load_meta(input_args)
elif isinstance(input_args, tuple):
model = TensorflowParser._load_meta(input_args[0])
self.ckpt_data = TensorflowParser._load_weights(input_args[1])
# load model files into TensorFlow graph
if meta_file:
model = TensorflowParser._load_meta(meta_file)
if checkpoint_file:
self.ckpt_data = TensorflowParser._load_weights(checkpoint_file)
self.weight_loaded = True
if dest_nodes != None:
@ -194,7 +193,7 @@ class TensorflowParser(Parser):
model = extract_sub_graph(model, dest_nodes.split(','))
# Build network graph
self.tf_graph = TensorflowGraph(model)
self.tf_graph = TensorflowGraph(model)
self.tf_graph.build()

Просмотреть файл

@ -83,9 +83,7 @@ class TestModels(CorrectnessTest):
# original to IR
parser = Keras2Parser(model_filename)
parser.gen_IR()
parser.save_to_proto(TestModels.tmpdir + architecture_name + "_converted.pb")
parser.save_weights(TestModels.tmpdir + architecture_name + "_converted.npy")
parser.run(TestModels.tmpdir + architecture_name + "_converted")
del parser
return original_predict
@ -106,9 +104,7 @@ class TestModels(CorrectnessTest):
model = (architecture_file, prefix, epoch, [3, 224, 224])
parser = MXNetParser(model)
parser.gen_IR()
parser.save_to_proto(TestModels.tmpdir + architecture_name + "_converted.pb")
parser.save_weights(TestModels.tmpdir + architecture_name + "_converted.npy")
parser.run(TestModels.tmpdir + architecture_name + "_converted")
del parser
return original_predict
@ -306,6 +302,7 @@ class TestModels(CorrectnessTest):
self._compare_outputs(original_predict, converted_predict)
os.remove(self.tmpdir + network_name + "_converted.json")
os.remove(self.tmpdir + network_name + "_converted.pb")
os.remove(self.tmpdir + network_name + "_converted.npy")
print("Testing {} model {} passed.".format(original_framework, network_name))