diff --git a/mmdnn/conversion/_script/convertToIR.py b/mmdnn/conversion/_script/convertToIR.py
index 444dffc..ef12f0f 100644
--- a/mmdnn/conversion/_script/convertToIR.py
+++ b/mmdnn/conversion/_script/convertToIR.py
@@ -5,12 +5,15 @@ from six import text_type as _text_type
 
 def _convert(args):
     if args.inputShape != None:
-        inputshape = [int(x) for x in args.inputShape]
+        inputshape = []
+        for x in args.inputShape:
+            shape = x.split(',')
+            inputshape.append([int(x) for x in shape])
     else:
         inputshape = None
 
     if args.srcFramework == 'caffe':
         from mmdnn.conversion.caffe.transformer import CaffeTransformer
-        transformer = CaffeTransformer(args.network, args.weights, "tensorflow", inputshape, phase = args.caffePhase)
+        transformer = CaffeTransformer(args.network, args.weights, "tensorflow", inputshape[0], phase = args.caffePhase)
         graph = transformer.transform_graph()
         data = transformer.transform_data()
@@ -48,25 +51,27 @@ def _convert(args):
         # assert args.network or args.frozen_pb
         if args.frozen_pb:
+            if args.inNodeName is None:
+                raise ValueError("Need to provide the input node of Tensorflow model.")
             from mmdnn.conversion.tensorflow.tensorflow_frozenparser import TensorflowParser2
             parser = TensorflowParser2(args.frozen_pb, inputshape, args.inNodeName, args.dstNodeName)
         else:
             from mmdnn.conversion.tensorflow.tensorflow_parser import TensorflowParser
-            if args.inNodeName and inputshape:
-                parser = TensorflowParser(args.network, args.weights, args.dstNodeName, inputshape, args.inNodeName)
+            if args.inNodeName and inputshape[0]:
+                parser = TensorflowParser(args.network, args.weights, args.dstNodeName, inputshape[0], args.inNodeName)
             else:
                 parser = TensorflowParser(args.network, args.weights, args.dstNodeName)
 
     elif args.srcFramework == 'mxnet':
         assert inputshape != None
         if args.weights == None:
-            model = (args.network, inputshape)
+            model = (args.network, inputshape[0])
         else:
             import re
             if re.search('.', args.weights):
                 args.weights = args.weights[:-7]
             prefix, epoch = args.weights.rsplit('-', 1)
-            model = (args.network, prefix, epoch, inputshape)
+            model = (args.network, prefix, epoch, inputshape[0])
 
         from mmdnn.conversion.mxnet.mxnet_parser import MXNetParser
         parser = MXNetParser(model)
@@ -79,13 +84,13 @@ def _convert(args):
     elif args.srcFramework == 'pytorch':
         assert inputshape != None
         from mmdnn.conversion.pytorch.pytorch_parser import PytorchParser
-        parser = PytorchParser(args.network, inputshape)
+        parser = PytorchParser(args.network, inputshape[0])
 
     elif args.srcFramework == 'torch' or args.srcFramework == 'torch7':
         from mmdnn.conversion.torch.torch_parser import TorchParser
         model = args.network or args.weights
         assert model != None
-        parser = TorchParser(model, inputshape)
+        parser = TorchParser(model, inputshape[0])
 
     elif args.srcFramework == 'onnx':
         from mmdnn.conversion.onnx.onnx_parser import ONNXParser
@@ -138,12 +143,14 @@ def _get_parser():
 
     parser.add_argument(
         '--inNodeName', '-inode',
+        nargs='+',
         type=_text_type,
-        default='input',
+        default=None,
         help="[Tensorflow] Input nodes' name of the graph.")
 
     parser.add_argument(
         '--dstNodeName', '-node',
+        nargs='+',
         type=_text_type,
         default=None,
         help="[Tensorflow] Output nodes' name of the graph.")
@@ -160,7 +167,7 @@
         nargs='+',
         type=_text_type,
         default=None,
-        help='[MXNet/Caffe2/Torch7] Input shape of model (channel, height, width)')
+        help='[Tensorflow/MXNet/Caffe2/Torch7] Input shape of model (channel, height, width)')
 
 
     # Caffe
diff --git a/mmdnn/conversion/caffe/README.md b/mmdnn/conversion/caffe/README.md
index 9b3e81d..64df74e 100644
--- a/mmdnn/conversion/caffe/README.md
+++ b/mmdnn/conversion/caffe/README.md
@@ -57,6 +57,7 @@ CNTK model file is saved as [caffe_resnet152.dnn], generated by [069867aa7f674b7
 
 Then you get the CNTK original model *caffe_resnet152.dnn* converted from Caffe. Temporal files are removed automatically.
 
+If you want to assume a fixed input shape, you can use `--inputShape 224,224,3`.
 ---
 
 ## Step-by-step conversion (for debugging)
diff --git a/mmdnn/conversion/tensorflow/README.md b/mmdnn/conversion/tensorflow/README.md
index 9bf745f..4854215 100644
--- a/mmdnn/conversion/tensorflow/README.md
+++ b/mmdnn/conversion/tensorflow/README.md
@@ -89,26 +89,48 @@ TensorBoard 0.4.0rc3 at http://kit-station:6006 (Press CTRL+C to quit)
 
 ## One-step conversion
 
-Above MMdnn@0.1.4, we provide one command to achieve the conversion
+Above MMdnn@0.1.4, we provide one command to achieve the conversion.
+For checkpoint format:
 
 ```bash
-$ mmconvert -sf tensorflow -in imagenet_inception_v1.ckpt.meta -iw inception_v1.ckpt -df cntk -om tf_inception_v1.dnn --inputShape 3 224 224 --dstNodeName MMdnn_Output
+$ mmconvert -sf tensorflow -in ./model.ckpt.meta -iw ./model.ckpt -df caffe --inputShape 224,224,3 --dstNodeName FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_5_3x3_s2_128/Relu6 --inNodeName Preprocessor/sub -om mobilenet_v1_caffe
 .
 .
 .
-Parse file [imagenet_inception_v1.ckpt.meta] with binary format successfully.
-Tensorflow model file [imagenet_inception_v1.ckpt.meta] loaded successfully.
-Tensorflow checkpoint file [inception_v1.ckpt] loaded successfully. [231] variables loaded.
-IR network structure is saved as [b136dfb62a964486bb2bc2e27cece6e8.json].
-IR network structure is saved as [b136dfb62a964486bb2bc2e27cece6e8.pb].
-IR weights are saved as [b136dfb62a964486bb2bc2e27cece6e8.npy].
-Parse file [b136dfb62a964486bb2bc2e27cece6e8.pb] with binary format successfully.
-Target network code snippet is saved as [b136dfb62a964486bb2bc2e27cece6e8.py].
-CNTK model file is saved as [tf_inception_v1.dnn], generated by [b136dfb62a964486bb2bc2e27cece6e8.py] and [b136dfb62a964486bb2bc2e27cece6e8.npy].
+Parse file [./model.ckpt.meta] with binary format successfully.
+Tensorflow model file [./model.ckpt.meta] loaded successfully.
+Tensorflow checkpoint file [./model.ckpt] loaded successfully. [200] variables loaded.
+IR network structure is saved as [61ecc03803a747429a9d4ff6dc346c21.json].
+IR network structure is saved as [61ecc03803a747429a9d4ff6dc346c21.pb].
+IR weights are saved as [61ecc03803a747429a9d4ff6dc346c21.npy].
+Parse file [61ecc03803a747429a9d4ff6dc346c21.pb] with binary format successfully.
+Target network code snippet is saved as [61ecc03803a747429a9d4ff6dc346c21.py].
+Target weights are saved as [61ecc03803a747429a9d4ff6dc346c21.npy].
+Caffe model files are saved as [mobilenet_v1_caffe.prototxt] and [mobilenet_v1_caffe.caffemodel], generated by [61ecc03803a747429a9d4ff6dc346c21.py] and [61ecc03803a747429a9d4ff6dc346c21.npy].
 ```
 
-Then you get the CNTK original model *tf_inception_v1.dnn* converted from Tensorflow. Temporal files are removed automatically.
+Then you get the original Caffe model files *mobilenet_v1_caffe.prototxt* and *mobilenet_v1_caffe.caffemodel* converted from Tensorflow. Temporary files are removed automatically.
+
+For frozen protobuf format:
+
+
+```bash
+$ mmconvert -sf tensorflow --frozen_pb entropy.pb -df caffe --inputShape 108,140,1 --dstNodeName Dense2/fc5/BiasAdd --inNodeName X -om entropy_caffe
+.
+.
+.
+IR network structure is saved as [2217c0216dd445cca7e44255d989c6c3.json].
+IR network structure is saved as [2217c0216dd445cca7e44255d989c6c3.pb].
+IR weights are saved as [2217c0216dd445cca7e44255d989c6c3.npy].
+Parse file [2217c0216dd445cca7e44255d989c6c3.pb] with binary format successfully.
+Target network code snippet is saved as [2217c0216dd445cca7e44255d989c6c3.py].
+Target weights are saved as [2217c0216dd445cca7e44255d989c6c3.npy].
+Caffe model files are saved as [entropy_caffe.prototxt] and [entropy_caffe.caffemodel], generated by [2217c0216dd445cca7e44255d989c6c3.py] and [2217c0216dd445cca7e44255d989c6c3.npy].
+
+```
+
+If there is more than one input node, you can use spaces to separate them, e.g. `--inputShape 224,224,3 4 --inNodeName image style`.
 
 ---
 
@@ -118,15 +140,16 @@ Then you get the CNTK original model *tf_inception_v1.dnn* converted from Tensor
 
 You can convert only network structure to IR for visualization or training in other frameworks.
 
-We use resnet_v2_152 model as an example.
+We use MobilenetV1 model as an example.
 
 ```bash
-$ mmtoir -f tensorflow -d resnet152 -n imagenet_resnet_v2_152.ckpt.meta --dstNodeName MMdnn_Output
-Parse file [imagenet_resnet_v2_152.ckpt.meta] with binary format successfully.
-Tensorflow model file [imagenet_resnet_v2_152.ckpt.meta] loaded successfully.
-IR network structure is saved as [resnet152.json].
-IR network structure is saved as [resnet152.pb].
+$ mmtoir -f tensorflow -n ./model.ckpt.meta --dstNodeName FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_5_3x3_s2_128/Relu6 --inputShape 224,224,3 --inNodeName Preprocessor/sub -d mobilenet_v1
+
+Parse file [./model.ckpt.meta] with binary format successfully.
+Tensorflow model file [./model.ckpt.meta] loaded successfully.
+IR network structure is saved as [mobilenet_v1.json].
+IR network structure is saved as [mobilenet_v1.pb].
 Warning: weights are not loaded.
 ```
 
@@ -135,28 +158,29 @@ Warning: weights are not loaded.
 
 You can use following bash command to convert the checkpoint files to IR architecture file [*resnet152.pb*], [*resnet152.json*] and IR weights file [*resnet152.npy*]
 
 ```bash
-$ mmtoir -f tensorflow -d resnet152 -n imagenet_resnet_v2_152.ckpt.meta -w imagenet_resnet_v2_152.ckpt --dstNodeName MMdnn_Output
+$ mmtoir -f tensorflow -n ./model.ckpt.meta -w ./model.ckpt -d mobilenet_v1_tf --inputShape 224,224,3 --dstNodeName FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_5_3x3_s2_128/Relu6 --inNodeName Preprocessor/sub
-Parse file [imagenet_resnet_v2_152.ckpt.meta] with binary format successfully.
-Tensorflow model file [imagenet_resnet_v2_152.ckpt.meta] loaded successfully.
-Tensorflow checkpoint file [imagenet_resnet_v2_152.ckpt] loaded successfully. [816] variables loaded.
-IR network structure is saved as [resnet152.json].
-IR network structure is saved as [resnet152.pb].
-IR weights are saved as [resnet152.npy].
+
+Parse file [./model.ckpt.meta] with binary format successfully.
+Tensorflow model file [./model.ckpt.meta] loaded successfully.
+Tensorflow checkpoint file [./model.ckpt] loaded successfully. [200] variables loaded.
+IR network structure is saved as [mobilenet_v1_tf.json].
+IR network structure is saved as [mobilenet_v1_tf.pb].
+IR weights are saved as [mobilenet_v1_tf.npy].
 ```
 
-### Convert frozen protobuf model file from Tensorflow to IR
+### Convert frozen protobuf model file (.pb) from Tensorflow to IR
 
 You can convert frozen protobuf model file to IR for visualization or training in other frameworks.
 
-We use resnet_v2_152 model as an example.
+We use inception_v1 as an example.
 
 ```bash
-$ mmtoir -f tensorflow --frozen_pb inception_v1_2016_08_28_frozen.pb -d inception_v1 --dstNodeName InceptionV1/Logits/Predictions/Reshape_1 --inputShape 224 224 3
+$ mmtoir -f tensorflow --frozen_pb ./tests/cache/inception_v1_2016_08_28_frozen.pb -d inception_v1_part --dstNodeName InceptionV1/Logits/Predictions/Reshape_1 --inputShape 28,28,192 --inNodeName InceptionV1/InceptionV1/MaxPool_3a_3x3/MaxPool
-IR network structure is saved as [inception_v1.json].
-IR network structure is saved as [inception_v1.pb].
-IR weights are saved as [inception_v1.npy].
+IR network structure is saved as [inception_v1_part.json].
+IR network structure is saved as [inception_v1_part.pb].
+IR weights are saved as [inception_v1_part.npy].
 ```
 
 ### Convert models from IR to Tensorflow code snippet
diff --git a/mmdnn/conversion/tensorflow/tensorflow_frozenparser.py b/mmdnn/conversion/tensorflow/tensorflow_frozenparser.py
index a0cc821..506b8ff 100644
--- a/mmdnn/conversion/tensorflow/tensorflow_frozenparser.py
+++ b/mmdnn/conversion/tensorflow/tensorflow_frozenparser.py
@@ -95,16 +95,19 @@ class TensorflowParser2(Parser):
 
         original_gdef = tensorflow.GraphDef()
         original_gdef.ParseFromString(serialized)
-        # model = original_gdef
+
+        in_type_list = {}
+        for n in original_gdef.node:
+            if n.name in in_nodes:
+                in_type_list[n.name] = n.attr['dtype'].type
+
         from tensorflow.python.tools import strip_unused_lib
         from tensorflow.python.framework import dtypes
         from tensorflow.python.platform import gfile
-        input_node_names = in_nodes.split(',')
-        output_node_names = dest_nodes.split(',')
         original_gdef = strip_unused_lib.strip_unused(
                 input_graph_def = original_gdef,
-                input_node_names = input_node_names,
-                output_node_names = output_node_names,
+                input_node_names = in_nodes,
+                output_node_names = dest_nodes,
                 placeholder_type_enum = dtypes.float32.as_datatype_enum)
         # Save it to an output file
         frozen_model_file = './frozen.pb'
@@ -118,14 +121,24 @@
 
         output_shape_map = dict()
         input_shape_map = dict()
+
         with tensorflow.Graph().as_default() as g:
-            x = tensorflow.placeholder(tensorflow.float32, shape = [None] + inputshape)
-            tensorflow.import_graph_def(model, name='', input_map={in_nodes + ':0' : x})
+            input_map = {}
+            for i in range(len(inputshape)):
+                if in_type_list[in_nodes[i]] == 1:
+                    dtype = tensorflow.float32
+                elif in_type_list[in_nodes[i]] == 3:
+                    dtype = tensorflow.int32
+                x = tensorflow.placeholder(dtype, shape = [None] + inputshape[i])
+                input_map[in_nodes[i] + ':0'] = x
+
+            tensorflow.import_graph_def(model, name='', input_map=input_map)
             with tensorflow.Session(graph = g) as sess:
                 meta_graph_def = tensorflow.train.export_meta_graph(filename='./my-model.meta')
                 model = meta_graph_def.graph_def
+
         self.tf_graph = TensorflowGraph(model)
         self.tf_graph.build()
@@ -510,12 +523,11 @@
         self.set_weight(source_node.name, 'mean', mean)
 
     def rename_Placeholder(self, source_node):
-        # print(source_node.layer.attr["shape"].shape)
+        # print(source_node.layer)
         if source_node.layer.attr["shape"].shape.unknown_rank == True:
             return
         IR_node = self._convert_identity_operation(source_node, new_op='DataInput')
         TensorflowParser2._copy_shape(source_node, IR_node)
-        IR_node.attr['shape'].shape.dim[0].size = -1
         IR_node.attr['_output_shapes'].list.shape[0].dim[0].size = -1
 
diff --git a/mmdnn/conversion/tensorflow/tensorflow_parser.py b/mmdnn/conversion/tensorflow/tensorflow_parser.py
index 5277179..09cc367 100644
--- a/mmdnn/conversion/tensorflow/tensorflow_parser.py
+++ b/mmdnn/conversion/tensorflow/tensorflow_parser.py
@@ -197,12 +197,10 @@ class TensorflowParser(Parser):
             from tensorflow.python.tools import strip_unused_lib
             from tensorflow.python.framework import dtypes
             from tensorflow.python.platform import gfile
-            input_node_names = in_nodes.split(',')
-            output_node_names = dest_nodes.split(',')
             model = strip_unused_lib.strip_unused(
                     input_graph_def = model,
-                    input_node_names = input_node_names,
-                    output_node_names = output_node_names,
+                    input_node_names = in_nodes,
+                    output_node_names = dest_nodes,
                     placeholder_type_enum = dtypes.float32.as_datatype_enum)
 
             input_list = [None]
@@ -212,7 +210,7 @@
         # Build network graph
         self.tf_graph = TensorflowGraph(model)
         for node in self.tf_graph.model.node:
-            if node.name in input_node_names:
+            if node.name in in_nodes:
                 node.attr['shape'].list.shape.extend([tensor_input.as_proto()])
                 node.attr['_output_shapes'].list.shape.pop() #unknown_rank pop
                 node.attr['_output_shapes'].list.shape.extend([tensor_input.as_proto()])
diff --git a/tests/test_conversion_imagenet.py b/tests/test_conversion_imagenet.py
index 836e8ea..badb7e9 100644
--- a/tests/test_conversion_imagenet.py
+++ b/tests/test_conversion_imagenet.py
@@ -131,7 +131,7 @@ class TestModels(CorrectnessTest):
         # original to IR
         IR_file = TestModels.tmpdir + 'tensorflow_frozen_' + architecture_name + "_converted"
         parser = TensorflowParser2(
-            TestModels.cachedir + para[0], para[1], para[2].split(':')[0], para[3].split(':')[0])
+            TestModels.cachedir + para[0], [para[1]], [para[2].split(':')[0]], [para[3].split(':')[0]])
         parser.run(IR_file)
         del parser
         del TensorflowParser2
@@ -1020,7 +1020,7 @@ class TestModels(CorrectnessTest):
         except ImportError:
             print('Please install Paddlepaddle! Or Paddlepaddle is not supported in your platform.', file=sys.stderr)
 
-
+
     def test_pytorch(self):
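
Reviewer note: with this patch, `--inputShape` takes one comma-separated shape per input node, with spaces separating multiple inputs, and single-input parsers (Caffe, MXNet, PyTorch, Torch7) receive only the first shape. Below is a minimal, standalone sketch of that parsing behaviour as implemented in `convertToIR.py`; the helper name `parse_input_shapes` is illustrative and not part of the patch.

```python
# Sketch of the new --inputShape handling: each space-separated argument is one
# input's shape, and commas separate the dimensions within that shape.
def parse_input_shapes(raw_shapes):
    """['224,224,3', '4'] -> [[224, 224, 3], [4]]; None stays None."""
    if raw_shapes is None:
        return None
    return [[int(dim) for dim in shape.split(',')] for shape in raw_shapes]


if __name__ == '__main__':
    # Mirrors `mmconvert ... --inputShape 224,224,3 4 --inNodeName image style`
    shapes = parse_input_shapes(['224,224,3', '4'])
    print(shapes)      # [[224, 224, 3], [4]]
    # Single-input framework paths in the patch use shapes[0]:
    print(shapes[0])   # [224, 224, 3]
```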