update inputshape readme and multiple input

namizzz 2018-07-02 17:26:44 +08:00
Parent 0ccc2100c3
Commit d4c29b1566
6 changed files with 99 additions and 57 deletions

View file

@@ -5,12 +5,15 @@ from six import text_type as _text_type
def _convert(args):
if args.inputShape != None:
inputshape = [int(x) for x in args.inputShape]
inputshape = []
for x in args.inputShape:
shape = x.split(',')
inputshape.append([int(x) for x in shape])
else:
inputshape = None
if args.srcFramework == 'caffe':
from mmdnn.conversion.caffe.transformer import CaffeTransformer
transformer = CaffeTransformer(args.network, args.weights, "tensorflow", inputshape, phase = args.caffePhase)
transformer = CaffeTransformer(args.network, args.weights, "tensorflow", inputshape[0], phase = args.caffePhase)
graph = transformer.transform_graph()
data = transformer.transform_data()
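The hunk above replaces the single flat shape with one list per `--inputShape` argument. A minimal sketch of the new parsing, assuming `args.inputShape` holds the space-separated strings collected by argparse (the helper name is made up for illustration):

```python
# Illustrative reimplementation of the new --inputShape handling.
def parse_input_shapes(raw_shapes):
    """Turn ['224,224,3', '4'] into [[224, 224, 3], [4]]; None stays None."""
    if raw_shapes is None:
        return None
    inputshape = []
    for item in raw_shapes:
        inputshape.append([int(dim) for dim in item.split(',')])
    return inputshape

print(parse_input_shapes(['224,224,3', '4']))  # [[224, 224, 3], [4]]
print(parse_input_shapes(None))                # None
```

Single-input frameworks (Caffe, MXNet, PyTorch, Torch7) then read `inputshape[0]`, while the TensorFlow parsers receive the whole list.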
@@ -48,25 +51,27 @@ def _convert(args):
# assert args.network or args.frozen_pb
if args.frozen_pb:
if args.inNodeName is None:
raise ValueError("Need to provide the input node of Tensorflow model.")
from mmdnn.conversion.tensorflow.tensorflow_frozenparser import TensorflowParser2
parser = TensorflowParser2(args.frozen_pb, inputshape, args.inNodeName, args.dstNodeName)
else:
from mmdnn.conversion.tensorflow.tensorflow_parser import TensorflowParser
if args.inNodeName and inputshape:
parser = TensorflowParser(args.network, args.weights, args.dstNodeName, inputshape, args.inNodeName)
if args.inNodeName and inputshape[0]:
parser = TensorflowParser(args.network, args.weights, args.dstNodeName, inputshape[0], args.inNodeName)
else:
parser = TensorflowParser(args.network, args.weights, args.dstNodeName)
elif args.srcFramework == 'mxnet':
assert inputshape != None
if args.weights == None:
model = (args.network, inputshape)
model = (args.network, inputshape[0])
else:
import re
if re.search('.', args.weights):
args.weights = args.weights[:-7]
prefix, epoch = args.weights.rsplit('-', 1)
model = (args.network, prefix, epoch, inputshape)
model = (args.network, prefix, epoch, inputshape[0])
from mmdnn.conversion.mxnet.mxnet_parser import MXNetParser
parser = MXNetParser(model)
@@ -79,13 +84,13 @@ def _convert(args):
elif args.srcFramework == 'pytorch':
assert inputshape != None
from mmdnn.conversion.pytorch.pytorch_parser import PytorchParser
parser = PytorchParser(args.network, inputshape)
parser = PytorchParser(args.network, inputshape[0])
elif args.srcFramework == 'torch' or args.srcFramework == 'torch7':
from mmdnn.conversion.torch.torch_parser import TorchParser
model = args.network or args.weights
assert model != None
parser = TorchParser(model, inputshape)
parser = TorchParser(model, inputshape[0])
elif args.srcFramework == 'onnx':
from mmdnn.conversion.onnx.onnx_parser import ONNXParser
@@ -138,12 +143,14 @@ def _get_parser():
parser.add_argument(
'--inNodeName', '-inode',
nargs='+',
type=_text_type,
default='input',
default=None,
help="[Tensorflow] Input nodes' name of the graph.")
parser.add_argument(
'--dstNodeName', '-node',
nargs='+',
type=_text_type,
default=None,
help="[Tensorflow] Output nodes' name of the graph.")
@@ -160,7 +167,7 @@ def _get_parser():
nargs='+',
type=_text_type,
default=None,
help='[MXNet/Caffe2/Torch7] Input shape of model (channel, height, width)')
help='[Tensorflow/MXNet/Caffe2/Torch7] Input shape of model (channel, height, width)')
# Caffe
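For reference, a self-contained sketch of how these `nargs='+'` flags behave for a model with two inputs. The flag names and help strings mirror the ones above (with plain `str` standing in for `_text_type`); the parser object and example values are purely illustrative:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--inNodeName', '-inode', nargs='+', type=str, default=None,
                    help="[Tensorflow] Input nodes' name of the graph.")
parser.add_argument('--dstNodeName', '-node', nargs='+', type=str, default=None,
                    help="[Tensorflow] Output nodes' name of the graph.")
parser.add_argument('--inputShape', nargs='+', type=str, default=None,
                    help='[Tensorflow/MXNet/Caffe2/Torch7] Input shape of model (channel, height, width)')

args = parser.parse_args(
    '--inputShape 224,224,3 4 --inNodeName image style --dstNodeName output'.split())
print(args.inputShape)   # ['224,224,3', '4']
print(args.inNodeName)   # ['image', 'style']
```

Each comma-separated entry in `args.inputShape` is then expanded into a list of integers as shown in the `_convert` hunk above.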

View file

@@ -57,6 +57,7 @@ CNTK model file is saved as [caffe_resnet152.dnn], generated by [069867aa7f674b7
Then you get the CNTK model *caffe_resnet152.dnn* converted from Caffe. Temporary files are removed automatically.
If you want to specify a fixed input shape, you can use `--inputShape 224,224,3`.
---
## Step-by-step conversion (for debugging)

View file

@@ -89,26 +89,48 @@ TensorBoard 0.4.0rc3 at http://kit-station:6006 (Press CTRL+C to quit)
## One-step conversion
Above MMdnn@0.1.4, we provide one command to achieve the conversion
Since MMdnn@0.1.4, we provide one command to achieve the conversion.
For the checkpoint format:
```bash
$ mmconvert -sf tensorflow -in imagenet_inception_v1.ckpt.meta -iw inception_v1.ckpt -df cntk -om tf_inception_v1.dnn --inputShape 3 224 224 --dstNodeName MMdnn_Output
$ mmconvert -sf tensorflow -in ./model.ckpt.meta -iw ./model.ckpt -df caffe --inputShape 224,224,3 --dstNodeName FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_5_3x3_s2_128/Relu6 --inNodeName Preprocessor/sub -om mobilenet_v1_caffe
.
.
.
Parse file [imagenet_inception_v1.ckpt.meta] with binary format successfully.
Tensorflow model file [imagenet_inception_v1.ckpt.meta] loaded successfully.
Tensorflow checkpoint file [inception_v1.ckpt] loaded successfully. [231] variables loaded.
IR network structure is saved as [b136dfb62a964486bb2bc2e27cece6e8.json].
IR network structure is saved as [b136dfb62a964486bb2bc2e27cece6e8.pb].
IR weights are saved as [b136dfb62a964486bb2bc2e27cece6e8.npy].
Parse file [b136dfb62a964486bb2bc2e27cece6e8.pb] with binary format successfully.
Target network code snippet is saved as [b136dfb62a964486bb2bc2e27cece6e8.py].
CNTK model file is saved as [tf_inception_v1.dnn], generated by [b136dfb62a964486bb2bc2e27cece6e8.py] and [b136dfb62a964486bb2bc2e27cece6e8.npy].
Parse file [./model.ckpt.meta] with binary format successfully.
Tensorflow model file [./model.ckpt.meta] loaded successfully.
Tensorflow checkpoint file [./model.ckpt] loaded successfully. [200] variables loaded.
IR network structure is saved as [61ecc03803a747429a9d4ff6dc346c21.json].
IR network structure is saved as [61ecc03803a747429a9d4ff6dc346c21.pb].
IR weights are saved as [61ecc03803a747429a9d4ff6dc346c21.npy].
Parse file [61ecc03803a747429a9d4ff6dc346c21.pb] with binary format successfully.
Target network code snippet is saved as [61ecc03803a747429a9d4ff6dc346c21.py].
Target weights are saved as [61ecc03803a747429a9d4ff6dc346c21.npy].
Caffe model files are saved as [mobilenet_v1_caffe.prototxt] and [mobilenet_v1_caffe.caffemodel], generated by [61ecc03803a747429a9d4ff6dc346c21.py] and [61ecc03803a747429a9d4ff6dc346c21.npy].
```
Then you get the CNTK original model *tf_inception_v1.dnn* converted from Tensorflow. Temporal files are removed automatically.
Then you get the Caffe model files *mobilenet_v1_caffe.prototxt* and *mobilenet_v1_caffe.caffemodel* converted from Tensorflow. Temporary files are removed automatically.
For the frozen protobuf format:
```bash
$ mmconvert -sf tensorflow --frozen_pb entropy.pb -df caffe --inputShape 108,140,1 --dstNodeName Dense2/fc5/BiasAdd --inNodeName X -om entropy_caffe
.
.
.
IR network structure is saved as [2217c0216dd445cca7e44255d989c6c3.json].
IR network structure is saved as [2217c0216dd445cca7e44255d989c6c3.pb].
IR weights are saved as [2217c0216dd445cca7e44255d989c6c3.npy].
Parse file [2217c0216dd445cca7e44255d989c6c3.pb] with binary format successfully.
Target network code snippet is saved as [2217c0216dd445cca7e44255d989c6c3.py].
Target weights are saved as [2217c0216dd445cca7e44255d989c6c3.npy].
Caffe model files are saved as [entropy_caffe.prototxt] and [entropy_caffe.caffemodel], generated by [2217c0216dd445cca7e44255d989c6c3.py] and [2217c0216dd445cca7e44255d989c6c3.npy].
```
If there is more than one input node, separate the values with spaces, e.g. `--inputShape 224,224,3 4 --inNodeName image style`.
---
@@ -118,15 +140,16 @@ Then you get the CNTK original model *tf_inception_v1.dnn* converted from Tensor
You can convert only the network structure to IR for visualization or training in other frameworks.
We use resnet_v2_152 model as an example.
We use the MobilenetV1 model as an example.
```bash
$ mmtoir -f tensorflow -d resnet152 -n imagenet_resnet_v2_152.ckpt.meta --dstNodeName MMdnn_Output
Parse file [imagenet_resnet_v2_152.ckpt.meta] with binary format successfully.
Tensorflow model file [imagenet_resnet_v2_152.ckpt.meta] loaded successfully.
IR network structure is saved as [resnet152.json].
IR network structure is saved as [resnet152.pb].
$ mmtoir -f tensorflow -n ./model.ckpt.meta --dstNodeName FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_5_3x3_s2_128/Relu6 --inputShape 224,224,3 --inNodeName Preprocessor/sub -d mobilenet_v1
Parse file [./model.ckpt.meta] with binary format successfully.
Tensorflow model file [./model.ckpt.meta] loaded successfully.
IR network structure is saved as [mobilenet_v1.json].
IR network structure is saved as [mobilenet_v1.pb].
Warning: weights are not loaded.
```
@@ -135,28 +158,29 @@ Warning: weights are not loaded.
You can use the following bash command to convert the checkpoint files to the IR architecture files [*mobilenet_v1_tf.pb*], [*mobilenet_v1_tf.json*] and the IR weights file [*mobilenet_v1_tf.npy*]
```bash
$ mmtoir -f tensorflow -d resnet152 -n imagenet_resnet_v2_152.ckpt.meta -w imagenet_resnet_v2_152.ckpt --dstNodeName MMdnn_Output
$ mmtoir -f tensorflow -n ./model.ckpt.meta -w ./model.ckpt -d mobilenet_v1_tf --inputShape 224,224,3 --dstNodeName FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_5_3x3_s2_128/Relu6 --inNodeName Preprocessor/sub
Parse file [imagenet_resnet_v2_152.ckpt.meta] with binary format successfully.
Tensorflow model file [imagenet_resnet_v2_152.ckpt.meta] loaded successfully.
Tensorflow checkpoint file [imagenet_resnet_v2_152.ckpt] loaded successfully. [816] variables loaded.
IR network structure is saved as [resnet152.json].
IR network structure is saved as [resnet152.pb].
IR weights are saved as [resnet152.npy].
Parse file [./model.ckpt.meta] with binary format successfully.
Tensorflow model file [./model.ckpt.meta] loaded successfully.
Tensorflow checkpoint file [./model.ckpt] loaded successfully. [200] variables loaded.
IR network structure is saved as [mobilenet_v1_tf.json].
IR network structure is saved as [mobilenet_v1_tf.pb].
IR weights are saved as [mobilenet_v1_tf.npy].
```
### Convert frozen protobuf model file from Tensorflow to IR
### Convert frozen protobuf model file (.pb) from Tensorflow to IR
You can convert a frozen protobuf model file to IR for visualization or training in other frameworks.
We use resnet_v2_152 model as an example.
We use inception_v1 as an example.
```bash
$ mmtoir -f tensorflow --frozen_pb inception_v1_2016_08_28_frozen.pb -d inception_v1 --dstNodeName InceptionV1/Logits/Predictions/Reshape_1 --inputShape 224 224 3
$ mmtoir -f tensorflow --frozen_pb ./tests/cache/inception_v1_2016_08_28_frozen.pb -d inception_v1_part --dstNodeName InceptionV1/Logits/Predictions/Reshape_1 --inputShape 28 28 192 --inNodeName InceptionV1/InceptionV1/MaxPool_3a_3x3/MaxPool
IR network structure is saved as [inception_v1.json].
IR network structure is saved as [inception_v1.pb].
IR weights are saved as [inception_v1.npy].
IR network structure is saved as [inception_v1_part.json].
IR network structure is saved as [inception_v1_part.pb].
IR weights are saved as [inception_v1_part.npy].
```
### Convert models from IR to Tensorflow code snippet

View file

@@ -95,16 +95,19 @@ class TensorflowParser2(Parser):
original_gdef = tensorflow.GraphDef()
original_gdef.ParseFromString(serialized)
# model = original_gdef
in_type_list = {}
for n in original_gdef.node:
if n.name in in_nodes:
in_type_list[n.name] = n.attr['dtype'].type
from tensorflow.python.tools import strip_unused_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import gfile
input_node_names = in_nodes.split(',')
output_node_names = dest_nodes.split(',')
original_gdef = strip_unused_lib.strip_unused(
input_graph_def = original_gdef,
input_node_names = input_node_names,
output_node_names = output_node_names,
input_node_names = in_nodes,
output_node_names = dest_nodes,
placeholder_type_enum = dtypes.float32.as_datatype_enum)
# Save it to an output file
frozen_model_file = './frozen.pb'
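Because `in_nodes` and `dest_nodes` now arrive as lists of node names, they can be passed to `strip_unused` directly instead of being split from comma-separated strings. A minimal sketch of that call under the TF 1.x API used here (the file and node names are made up):

```python
import tensorflow as tf
from tensorflow.python.tools import strip_unused_lib
from tensorflow.python.framework import dtypes

# Load a frozen graph (file name is illustrative).
graph_def = tf.GraphDef()
with open('frozen_model.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

in_nodes = ['image', 'style']      # already lists, no .split(',') needed
dest_nodes = ['output']
stripped = strip_unused_lib.strip_unused(
    input_graph_def=graph_def,
    input_node_names=in_nodes,
    output_node_names=dest_nodes,
    placeholder_type_enum=dtypes.float32.as_datatype_enum)
```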
@@ -118,14 +121,24 @@ class TensorflowParser2(Parser):
output_shape_map = dict()
input_shape_map = dict()
with tensorflow.Graph().as_default() as g:
x = tensorflow.placeholder(tensorflow.float32, shape = [None] + inputshape)
tensorflow.import_graph_def(model, name='', input_map={in_nodes + ':0' : x})
input_map = {}
for i in range(len(inputshape)):
if in_type_list[in_nodes[i]] == 1:
dtype = tensorflow.float32
elif in_type_list[in_nodes[i]] == 3:
dtype = tensorflow.int32
x = tensorflow.placeholder(dtype, shape = [None] + inputshape[i])
input_map[in_nodes[i] + ':0'] = x
tensorflow.import_graph_def(model, name='', input_map=input_map)
with tensorflow.Session(graph = g) as sess:
meta_graph_def = tensorflow.train.export_meta_graph(filename='./my-model.meta')
model = meta_graph_def.graph_def
self.tf_graph = TensorflowGraph(model)
self.tf_graph.build()
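The placeholder logic above creates one input per entry of `in_nodes`, choosing the dtype from the DataType enum recorded from the original graph (1 is DT_FLOAT, 3 is DT_INT32), and remaps them all at once through `input_map`. A condensed sketch of the pattern, with a hypothetical helper name:

```python
import tensorflow as tf

def import_with_placeholders(graph_def, in_nodes, inputshape, in_type_list):
    """Map each input node to a fresh placeholder whose batch dimension is None."""
    with tf.Graph().as_default() as g:
        input_map = {}
        for i, node_name in enumerate(in_nodes):
            if in_type_list[node_name] == 1:      # DT_FLOAT
                dtype = tf.float32
            elif in_type_list[node_name] == 3:    # DT_INT32
                dtype = tf.int32
            else:
                raise ValueError('Unhandled input dtype for %s' % node_name)
            x = tf.placeholder(dtype, shape=[None] + inputshape[i])
            input_map[node_name + ':0'] = x
        tf.import_graph_def(graph_def, name='', input_map=input_map)
    return g
```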
@@ -510,12 +523,11 @@ class TensorflowParser2(Parser):
self.set_weight(source_node.name, 'mean', mean)
def rename_Placeholder(self, source_node):
# print(source_node.layer.attr["shape"].shape)
# print(source_node.layer)
if source_node.layer.attr["shape"].shape.unknown_rank == True:
return
IR_node = self._convert_identity_operation(source_node, new_op='DataInput')
TensorflowParser2._copy_shape(source_node, IR_node)
IR_node.attr['shape'].shape.dim[0].size = -1
IR_node.attr['_output_shapes'].list.shape[0].dim[0].size = -1

View file

@@ -197,12 +197,10 @@ class TensorflowParser(Parser):
from tensorflow.python.tools import strip_unused_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import gfile
input_node_names = in_nodes.split(',')
output_node_names = dest_nodes.split(',')
model = strip_unused_lib.strip_unused(
input_graph_def = model,
input_node_names = input_node_names,
output_node_names = output_node_names,
input_node_names = in_nodes,
output_node_names = dest_nodes,
placeholder_type_enum = dtypes.float32.as_datatype_enum)
input_list = [None]
@@ -212,7 +210,7 @@ class TensorflowParser(Parser):
# Build network graph
self.tf_graph = TensorflowGraph(model)
for node in self.tf_graph.model.node:
if node.name in input_node_names:
if node.name in in_nodes:
node.attr['shape'].list.shape.extend([tensor_input.as_proto()])
node.attr['_output_shapes'].list.shape.pop() #unknown_rank pop
node.attr['_output_shapes'].list.shape.extend([tensor_input.as_proto()])
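The `tensor_input` used above is presumably a `tf.TensorShape` built elsewhere in the parser from `input_list` (an open batch dimension plus the user-supplied shape); its proto is attached to every node named in `in_nodes`. A small illustration of the `as_proto()` call this relies on, with example dimensions:

```python
import tensorflow as tf

inputshape = [224, 224, 3]
tensor_input = tf.TensorShape([None] + inputshape)
print(tensor_input.as_proto())
# TensorShapeProto with dims -1, 224, 224, 3 (the unknown batch dimension becomes -1)
```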

View file

@@ -131,7 +131,7 @@ class TestModels(CorrectnessTest):
# original to IR
IR_file = TestModels.tmpdir + 'tensorflow_frozen_' + architecture_name + "_converted"
parser = TensorflowParser2(
TestModels.cachedir + para[0], para[1], para[2].split(':')[0], para[3].split(':')[0])
TestModels.cachedir + para[0], [para[1]], [para[2].split(':')[0]], [para[3].split(':')[0]])
parser.run(IR_file)
del parser
del TensorflowParser2
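The test now wraps each argument in a single-element list to match the new `TensorflowParser2` signature, which expects lists of shapes and node names. A hypothetical call showing the expected argument structure (the file name, input shape and node names are placeholders loosely borrowed from the frozen inception_v1 example above):

```python
from mmdnn.conversion.tensorflow.tensorflow_frozenparser import TensorflowParser2

parser = TensorflowParser2(
    'inception_v1_2016_08_28_frozen.pb',             # frozen graph file
    [[224, 224, 3]],                                  # one shape list per input node
    ['input'],                                        # input node names
    ['InceptionV1/Logits/Predictions/Reshape_1'])     # output node names
parser.run('tensorflow_frozen_inception_v1_converted')  # emit the IR files
del parser
```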
@@ -1020,7 +1020,7 @@ class TestModels(CorrectnessTest):
except ImportError:
print('Please install Paddlepaddle! Or Paddlepaddle is not supported in your platform.', file=sys.stderr)
def test_pytorch(self):