зеркало из https://github.com/microsoft/MMdnn.git
Implementing TensorFlow pytest.
This commit is contained in:
Родитель
09798b678a
Коммит
4c794431b8
|
@ -28,30 +28,6 @@ def _convert(args):
|
|||
|
||||
elif args.srcFramework == 'caffe2':
|
||||
raise NotImplementedError("Caffe2 is not supported yet.")
|
||||
'''
|
||||
assert args.inputShape != None
|
||||
from dlconv.caffe2.conversion.transformer import Caffe2Transformer
|
||||
transformer = Caffe2Transformer(args.network, args.weights, args.inputShape, 'tensorflow')
|
||||
|
||||
graph = transformer.transform_graph()
|
||||
data = transformer.transform_data()
|
||||
|
||||
from dlconv.common.writer import JsonFormatter, ModelSaver, PyWriter
|
||||
JsonFormatter(graph).dump(args.dstPath + ".json")
|
||||
print ("IR saved as [{}.json].".format(args.dstPath))
|
||||
|
||||
prototxt = graph.as_graph_def().SerializeToString()
|
||||
with open(args.dstPath + ".pb", 'wb') as of:
|
||||
of.write(prototxt)
|
||||
print ("IR saved as [{}.pb].".format(args.dstPath))
|
||||
|
||||
import numpy as np
|
||||
with open(args.dstPath + ".npy", 'wb') as of:
|
||||
np.save(of, data)
|
||||
print ("IR weights saved as [{}.npy].".format(args.dstPath))
|
||||
|
||||
return 0
|
||||
'''
|
||||
|
||||
elif args.srcFramework == 'keras':
|
||||
if args.network != None:
|
||||
|
@ -66,14 +42,10 @@ def _convert(args):
|
|||
if args.dstNodeName is None:
|
||||
raise ValueError("Need to provide the output node of Tensorflow model.")
|
||||
|
||||
if args.weights is None:
|
||||
# only convert network structure
|
||||
model = args.network
|
||||
else:
|
||||
model = (args.network, args.weights)
|
||||
assert args.network or args.frozen_pb
|
||||
|
||||
from mmdnn.conversion.tensorflow.tensorflow_parser import TensorflowParser
|
||||
parser = TensorflowParser(model, args.dstNodeName)
|
||||
parser = TensorflowParser(args.network, args.weights, args.frozen_pb, args.dstNodeName)
|
||||
|
||||
elif args.srcFramework == 'mxnet':
|
||||
assert args.inputShape != None
|
||||
|
@ -97,10 +69,7 @@ def _convert(args):
|
|||
else:
|
||||
raise ValueError("Unknown framework [{}].".format(args.srcFramework))
|
||||
|
||||
parser.gen_IR()
|
||||
parser.save_to_json(args.dstPath + ".json")
|
||||
parser.save_to_proto(args.dstPath + ".pb")
|
||||
parser.save_weights(args.dstPath + ".npy")
|
||||
parser.run(args.dstPath)
|
||||
|
||||
return 0
|
||||
|
||||
|
@ -140,6 +109,13 @@ def _main():
|
|||
default=None,
|
||||
help="[Tensorflow] Output nodes' name of the graph.")
|
||||
|
||||
parser.add_argument(
|
||||
'--frozen_pb',
|
||||
type=_text_type,
|
||||
default=None,
|
||||
help="[Tensorflow] frozen model file.")
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
'--inputShape',
|
||||
nargs='+',
|
||||
|
|
|
@ -61,7 +61,8 @@ def extract_model(args):
|
|||
extractor = keras_extractor()
|
||||
|
||||
elif args.framework == 'tensorflow' or args.framework == 'tf':
|
||||
pass
|
||||
from mmdnn.conversion.examples.tensorflow.extractor import tensorflow_extractor
|
||||
extractor = tensorflow_extractor()
|
||||
|
||||
elif args.framework == 'mxnet':
|
||||
from mmdnn.conversion.examples.mxnet.extractor import mxnet_extractor
|
||||
|
@ -72,7 +73,7 @@ def extract_model(args):
|
|||
else:
|
||||
raise ValueError("Unknown framework [{}].".format(args.framework))
|
||||
|
||||
files = extractor.download(args.network,args.path)
|
||||
files = extractor.download(args.network, args.path)
|
||||
|
||||
if files and args.image:
|
||||
predict = extractor.inference(args.network, args.path, args.image)
|
||||
|
|
|
@ -18,6 +18,13 @@ class Parser(object):
|
|||
self.weights = dict()
|
||||
|
||||
|
||||
def run(self, dest_path):
|
||||
self.gen_IR()
|
||||
self.save_to_json(dest_path + ".json")
|
||||
self.save_to_proto(dest_path + ".pb")
|
||||
self.save_weights(dest_path + ".npy")
|
||||
|
||||
|
||||
@property
|
||||
def src_graph(self):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -177,16 +177,15 @@ class TensorflowParser(Parser):
|
|||
output_node.real_name = source_node.name
|
||||
|
||||
|
||||
def __init__(self, input_args, dest_nodes = None):
|
||||
def __init__(self, meta_file, checkpoint_file, frozen_file, dest_nodes = None):
|
||||
super(TensorflowParser, self).__init__()
|
||||
|
||||
# load model files into Keras graph
|
||||
from six import string_types as _string_types
|
||||
if isinstance(input_args, _string_types):
|
||||
model = TensorflowParser._load_meta(input_args)
|
||||
elif isinstance(input_args, tuple):
|
||||
model = TensorflowParser._load_meta(input_args[0])
|
||||
self.ckpt_data = TensorflowParser._load_weights(input_args[1])
|
||||
# load model files into TensorFlow graph
|
||||
if meta_file:
|
||||
model = TensorflowParser._load_meta(meta_file)
|
||||
|
||||
if checkpoint_file:
|
||||
self.ckpt_data = TensorflowParser._load_weights(checkpoint_file)
|
||||
self.weight_loaded = True
|
||||
|
||||
if dest_nodes != None:
|
||||
|
@ -194,7 +193,7 @@ class TensorflowParser(Parser):
|
|||
model = extract_sub_graph(model, dest_nodes.split(','))
|
||||
|
||||
# Build network graph
|
||||
self.tf_graph = TensorflowGraph(model)
|
||||
self.tf_graph = TensorflowGraph(model)
|
||||
self.tf_graph.build()
|
||||
|
||||
|
||||
|
|
|
@ -83,9 +83,7 @@ class TestModels(CorrectnessTest):
|
|||
|
||||
# original to IR
|
||||
parser = Keras2Parser(model_filename)
|
||||
parser.gen_IR()
|
||||
parser.save_to_proto(TestModels.tmpdir + architecture_name + "_converted.pb")
|
||||
parser.save_weights(TestModels.tmpdir + architecture_name + "_converted.npy")
|
||||
parser.run(TestModels.tmpdir + architecture_name + "_converted")
|
||||
del parser
|
||||
return original_predict
|
||||
|
||||
|
@ -106,9 +104,7 @@ class TestModels(CorrectnessTest):
|
|||
model = (architecture_file, prefix, epoch, [3, 224, 224])
|
||||
|
||||
parser = MXNetParser(model)
|
||||
parser.gen_IR()
|
||||
parser.save_to_proto(TestModels.tmpdir + architecture_name + "_converted.pb")
|
||||
parser.save_weights(TestModels.tmpdir + architecture_name + "_converted.npy")
|
||||
parser.run(TestModels.tmpdir + architecture_name + "_converted")
|
||||
del parser
|
||||
|
||||
return original_predict
|
||||
|
@ -306,6 +302,7 @@ class TestModels(CorrectnessTest):
|
|||
self._compare_outputs(original_predict, converted_predict)
|
||||
|
||||
|
||||
os.remove(self.tmpdir + network_name + "_converted.json")
|
||||
os.remove(self.tmpdir + network_name + "_converted.pb")
|
||||
os.remove(self.tmpdir + network_name + "_converted.npy")
|
||||
print("Testing {} model {} passed.".format(original_framework, network_name))
|
||||
|
|
Загрузка…
Ссылка в новой задаче