зеркало из https://github.com/microsoft/MMdnn.git
update inputshape readme and multiple input
This commit is contained in:
Родитель
0ccc2100c3
Коммит
d4c29b1566
|
@ -5,12 +5,15 @@ from six import text_type as _text_type
|
||||||
|
|
||||||
def _convert(args):
|
def _convert(args):
|
||||||
if args.inputShape != None:
|
if args.inputShape != None:
|
||||||
inputshape = [int(x) for x in args.inputShape]
|
inputshape = []
|
||||||
|
for x in args.inputShape:
|
||||||
|
shape = x.split(',')
|
||||||
|
inputshape.append([int(x) for x in shape])
|
||||||
else:
|
else:
|
||||||
inputshape = None
|
inputshape = None
|
||||||
if args.srcFramework == 'caffe':
|
if args.srcFramework == 'caffe':
|
||||||
from mmdnn.conversion.caffe.transformer import CaffeTransformer
|
from mmdnn.conversion.caffe.transformer import CaffeTransformer
|
||||||
transformer = CaffeTransformer(args.network, args.weights, "tensorflow", inputshape, phase = args.caffePhase)
|
transformer = CaffeTransformer(args.network, args.weights, "tensorflow", inputshape[0], phase = args.caffePhase)
|
||||||
graph = transformer.transform_graph()
|
graph = transformer.transform_graph()
|
||||||
data = transformer.transform_data()
|
data = transformer.transform_data()
|
||||||
|
|
||||||
|
@ -48,25 +51,27 @@ def _convert(args):
|
||||||
|
|
||||||
# assert args.network or args.frozen_pb
|
# assert args.network or args.frozen_pb
|
||||||
if args.frozen_pb:
|
if args.frozen_pb:
|
||||||
|
if args.inNodeName is None:
|
||||||
|
raise ValueError("Need to provide the input node of Tensorflow model.")
|
||||||
from mmdnn.conversion.tensorflow.tensorflow_frozenparser import TensorflowParser2
|
from mmdnn.conversion.tensorflow.tensorflow_frozenparser import TensorflowParser2
|
||||||
parser = TensorflowParser2(args.frozen_pb, inputshape, args.inNodeName, args.dstNodeName)
|
parser = TensorflowParser2(args.frozen_pb, inputshape, args.inNodeName, args.dstNodeName)
|
||||||
else:
|
else:
|
||||||
from mmdnn.conversion.tensorflow.tensorflow_parser import TensorflowParser
|
from mmdnn.conversion.tensorflow.tensorflow_parser import TensorflowParser
|
||||||
if args.inNodeName and inputshape:
|
if args.inNodeName and inputshape[0]:
|
||||||
parser = TensorflowParser(args.network, args.weights, args.dstNodeName, inputshape, args.inNodeName)
|
parser = TensorflowParser(args.network, args.weights, args.dstNodeName, inputshape[0], args.inNodeName)
|
||||||
else:
|
else:
|
||||||
parser = TensorflowParser(args.network, args.weights, args.dstNodeName)
|
parser = TensorflowParser(args.network, args.weights, args.dstNodeName)
|
||||||
|
|
||||||
elif args.srcFramework == 'mxnet':
|
elif args.srcFramework == 'mxnet':
|
||||||
assert inputshape != None
|
assert inputshape != None
|
||||||
if args.weights == None:
|
if args.weights == None:
|
||||||
model = (args.network, inputshape)
|
model = (args.network, inputshape[0])
|
||||||
else:
|
else:
|
||||||
import re
|
import re
|
||||||
if re.search('.', args.weights):
|
if re.search('.', args.weights):
|
||||||
args.weights = args.weights[:-7]
|
args.weights = args.weights[:-7]
|
||||||
prefix, epoch = args.weights.rsplit('-', 1)
|
prefix, epoch = args.weights.rsplit('-', 1)
|
||||||
model = (args.network, prefix, epoch, inputshape)
|
model = (args.network, prefix, epoch, inputshape[0])
|
||||||
|
|
||||||
from mmdnn.conversion.mxnet.mxnet_parser import MXNetParser
|
from mmdnn.conversion.mxnet.mxnet_parser import MXNetParser
|
||||||
parser = MXNetParser(model)
|
parser = MXNetParser(model)
|
||||||
|
@ -79,13 +84,13 @@ def _convert(args):
|
||||||
elif args.srcFramework == 'pytorch':
|
elif args.srcFramework == 'pytorch':
|
||||||
assert inputshape != None
|
assert inputshape != None
|
||||||
from mmdnn.conversion.pytorch.pytorch_parser import PytorchParser
|
from mmdnn.conversion.pytorch.pytorch_parser import PytorchParser
|
||||||
parser = PytorchParser(args.network, inputshape)
|
parser = PytorchParser(args.network, inputshape[0])
|
||||||
|
|
||||||
elif args.srcFramework == 'torch' or args.srcFramework == 'torch7':
|
elif args.srcFramework == 'torch' or args.srcFramework == 'torch7':
|
||||||
from mmdnn.conversion.torch.torch_parser import TorchParser
|
from mmdnn.conversion.torch.torch_parser import TorchParser
|
||||||
model = args.network or args.weights
|
model = args.network or args.weights
|
||||||
assert model != None
|
assert model != None
|
||||||
parser = TorchParser(model, inputshape)
|
parser = TorchParser(model, inputshape[0])
|
||||||
|
|
||||||
elif args.srcFramework == 'onnx':
|
elif args.srcFramework == 'onnx':
|
||||||
from mmdnn.conversion.onnx.onnx_parser import ONNXParser
|
from mmdnn.conversion.onnx.onnx_parser import ONNXParser
|
||||||
|
@ -138,12 +143,14 @@ def _get_parser():
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--inNodeName', '-inode',
|
'--inNodeName', '-inode',
|
||||||
|
nargs='+',
|
||||||
type=_text_type,
|
type=_text_type,
|
||||||
default='input',
|
default=None,
|
||||||
help="[Tensorflow] Input nodes' name of the graph.")
|
help="[Tensorflow] Input nodes' name of the graph.")
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--dstNodeName', '-node',
|
'--dstNodeName', '-node',
|
||||||
|
nargs='+',
|
||||||
type=_text_type,
|
type=_text_type,
|
||||||
default=None,
|
default=None,
|
||||||
help="[Tensorflow] Output nodes' name of the graph.")
|
help="[Tensorflow] Output nodes' name of the graph.")
|
||||||
|
@ -160,7 +167,7 @@ def _get_parser():
|
||||||
nargs='+',
|
nargs='+',
|
||||||
type=_text_type,
|
type=_text_type,
|
||||||
default=None,
|
default=None,
|
||||||
help='[MXNet/Caffe2/Torch7] Input shape of model (channel, height, width)')
|
help='[Tensorflow/MXNet/Caffe2/Torch7] Input shape of model (channel, height, width)')
|
||||||
|
|
||||||
|
|
||||||
# Caffe
|
# Caffe
|
||||||
|
|
|
@ -57,6 +57,7 @@ CNTK model file is saved as [caffe_resnet152.dnn], generated by [069867aa7f674b7
|
||||||
|
|
||||||
Then you get the CNTK original model *caffe_resnet152.dnn* converted from Caffe. Temporal files are removed automatically.
|
Then you get the CNTK original model *caffe_resnet152.dnn* converted from Caffe. Temporal files are removed automatically.
|
||||||
|
|
||||||
|
If you want to assume a fixed inputshape, you can use "--inputShape 224,224,3"
|
||||||
---
|
---
|
||||||
|
|
||||||
## Step-by-step conversion (for debugging)
|
## Step-by-step conversion (for debugging)
|
||||||
|
|
|
@ -89,26 +89,48 @@ TensorBoard 0.4.0rc3 at http://kit-station:6006 (Press CTRL+C to quit)
|
||||||
|
|
||||||
## One-step conversion
|
## One-step conversion
|
||||||
|
|
||||||
Above MMdnn@0.1.4, we provide one command to achieve the conversion
|
Above MMdnn@0.1.4, we provide one command to achieve the conversion.
|
||||||
|
For checkpoint format:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ mmconvert -sf tensorflow -in imagenet_inception_v1.ckpt.meta -iw inception_v1.ckpt -df cntk -om tf_inception_v1.dnn --inputShape 3 224 224 --dstNodeName MMdnn_Output
|
$ mmconvert -sf tensorflow -in ./model.ckpt.meta -iw ./model.ckpt -df caffe --inputShape 224,224,3 --dstNodeName FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_5_3x3_s2_128/Relu6 --inNodeName Preprocessor/sub -om mobilenet_v1_caffe
|
||||||
.
|
.
|
||||||
.
|
.
|
||||||
.
|
.
|
||||||
Parse file [imagenet_inception_v1.ckpt.meta] with binary format successfully.
|
Parse file [./model.ckpt.meta] with binary format successfully.
|
||||||
Tensorflow model file [imagenet_inception_v1.ckpt.meta] loaded successfully.
|
Tensorflow model file [./model.ckpt.meta] loaded successfully.
|
||||||
Tensorflow checkpoint file [inception_v1.ckpt] loaded successfully. [231] variables loaded.
|
Tensorflow checkpoint file [./model.ckpt] loaded successfully. [200] variables loaded.
|
||||||
IR network structure is saved as [b136dfb62a964486bb2bc2e27cece6e8.json].
|
IR network structure is saved as [61ecc03803a747429a9d4ff6dc346c21.json].
|
||||||
IR network structure is saved as [b136dfb62a964486bb2bc2e27cece6e8.pb].
|
IR network structure is saved as [61ecc03803a747429a9d4ff6dc346c21.pb].
|
||||||
IR weights are saved as [b136dfb62a964486bb2bc2e27cece6e8.npy].
|
IR weights are saved as [61ecc03803a747429a9d4ff6dc346c21.npy].
|
||||||
Parse file [b136dfb62a964486bb2bc2e27cece6e8.pb] with binary format successfully.
|
Parse file [61ecc03803a747429a9d4ff6dc346c21.pb] with binary format successfully.
|
||||||
Target network code snippet is saved as [b136dfb62a964486bb2bc2e27cece6e8.py].
|
Target network code snippet is saved as [61ecc03803a747429a9d4ff6dc346c21.py].
|
||||||
CNTK model file is saved as [tf_inception_v1.dnn], generated by [b136dfb62a964486bb2bc2e27cece6e8.py] and [b136dfb62a964486bb2bc2e27cece6e8.npy].
|
Target weights are saved as [61ecc03803a747429a9d4ff6dc346c21.npy].
|
||||||
|
Caffe model files are saved as [mobilenet_v1_caffe.prototxt] and [mobilenet_v1_caffe.caffemodel], generated by [61ecc03803a747429a9d4ff6dc346c21.py] and [61ecc03803a747429a9d4ff6dc346c21.npy].
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Then you get the CNTK original model *tf_inception_v1.dnn* converted from Tensorflow. Temporal files are removed automatically.
|
Then you get the Caffe original model *mobilenet_v1_caffe.prototxt* and *mobilenet_v1_caffe.caffemodel* converted from Tensorflow. Temporal files are removed automatically.
|
||||||
|
|
||||||
|
For frozen protobuf format:
|
||||||
|
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ mmconvert -sf tensorflow --frozen_pb entropy.pb -df caffe --inputShape 108,140,1 --dstNodeName Dense2/fc5/BiasAdd --inNodeName X -om entropy_caffe
|
||||||
|
.
|
||||||
|
.
|
||||||
|
.
|
||||||
|
IR network structure is saved as [2217c0216dd445cca7e44255d989c6c3.json].
|
||||||
|
IR network structure is saved as [2217c0216dd445cca7e44255d989c6c3.pb].
|
||||||
|
IR weights are saved as [2217c0216dd445cca7e44255d989c6c3.npy].
|
||||||
|
Parse file [2217c0216dd445cca7e44255d989c6c3.pb] with binary format successfully.
|
||||||
|
Target network code snippet is saved as [2217c0216dd445cca7e44255d989c6c3.py].
|
||||||
|
Target weights are saved as [2217c0216dd445cca7e44255d989c6c3.npy].
|
||||||
|
Caffe model files are saved as [entropy_caffe.prototxt] and [entropy_caffe.caffemodel], generated by [2217c0216dd445cca7e44255d989c6c3.py] and [2217c0216dd445cca7e44255d989c6c3.npy].
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
if there are more than one input nodes, you can use space to seperate them, eg.(--inputShape 224,224,3 4 --inNodeName image style)
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
@ -118,15 +140,16 @@ Then you get the CNTK original model *tf_inception_v1.dnn* converted from Tensor
|
||||||
|
|
||||||
You can convert only network structure to IR for visualization or training in other frameworks.
|
You can convert only network structure to IR for visualization or training in other frameworks.
|
||||||
|
|
||||||
We use resnet_v2_152 model as an example.
|
We use MobilenetV1 model as an example.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ mmtoir -f tensorflow -d resnet152 -n imagenet_resnet_v2_152.ckpt.meta --dstNodeName MMdnn_Output
|
|
||||||
|
|
||||||
Parse file [imagenet_resnet_v2_152.ckpt.meta] with binary format successfully.
|
$ mmtoir -f tensorflow -n ./model.ckpt.meta --dstNodeName FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_5_3x3_s2_128/Relu6 --inputShape 224,224,3 --inNodeName Preprocessor/sub -d mobilenet_v1
|
||||||
Tensorflow model file [imagenet_resnet_v2_152.ckpt.meta] loaded successfully.
|
|
||||||
IR network structure is saved as [resnet152.json].
|
Parse file [./model.ckpt.meta] with binary format successfully.
|
||||||
IR network structure is saved as [resnet152.pb].
|
Tensorflow model file [./model.ckpt.meta] loaded successfully.
|
||||||
|
IR network structure is saved as [mobilenet_v1.json].
|
||||||
|
IR network structure is saved as [mobilenet_v1.pb].
|
||||||
Warning: weights are not loaded.
|
Warning: weights are not loaded.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -135,28 +158,29 @@ Warning: weights are not loaded.
|
||||||
You can use following bash command to convert the checkpoint files to IR architecture file [*resnet152.pb*], [*resnet152.json*] and IR weights file [*resnet152.npy*]
|
You can use following bash command to convert the checkpoint files to IR architecture file [*resnet152.pb*], [*resnet152.json*] and IR weights file [*resnet152.npy*]
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ mmtoir -f tensorflow -d resnet152 -n imagenet_resnet_v2_152.ckpt.meta -w imagenet_resnet_v2_152.ckpt --dstNodeName MMdnn_Output
|
$ mmtoir -f tensorflow -d resnet152 -n ./model.ckpt.meta -w ./model.ckpt -d mobilenet_v1_tf --inputShape 224,224,3 --dstNodeName FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_5_3x3_s2_128/Relu6 --inNodeName Preprocessor/sub
|
||||||
|
|
||||||
Parse file [imagenet_resnet_v2_152.ckpt.meta] with binary format successfully.
|
|
||||||
Tensorflow model file [imagenet_resnet_v2_152.ckpt.meta] loaded successfully.
|
Parse file [./model.ckpt.meta] with binary format successfully.
|
||||||
Tensorflow checkpoint file [imagenet_resnet_v2_152.ckpt] loaded successfully. [816] variables loaded.
|
Tensorflow model file [./model.ckpt.meta] loaded successfully.
|
||||||
IR network structure is saved as [resnet152.json].
|
Tensorflow checkpoint file [./model.ckpt] loaded successfully. [200] variables loaded.
|
||||||
IR network structure is saved as [resnet152.pb].
|
IR network structure is saved as [mobilenet_v1_tf.json].
|
||||||
IR weights are saved as [resnet152.npy].
|
IR network structure is saved as [mobilenet_v1_tf.pb].
|
||||||
|
IR weights are saved as [mobilenet_v1_tf.npy].
|
||||||
```
|
```
|
||||||
|
|
||||||
### Convert frozen protobuf model file from Tensorflow to IR
|
### Convert frozen protobuf model file(.pb) from Tensorflow to IR
|
||||||
|
|
||||||
You can convert frozen protobuf model file to IR for visualization or training in other frameworks.
|
You can convert frozen protobuf model file to IR for visualization or training in other frameworks.
|
||||||
|
|
||||||
We use resnet_v2_152 model as an example.
|
We use inception_v1 as an example.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ mmtoir -f tensorflow --frozen_pb inception_v1_2016_08_28_frozen.pb -d inception_v1 --dstNodeName InceptionV1/Logits/Predictions/Reshape_1 --inputShape 224 224 3
|
$ mmtoir -f tensorflow --frozen_pb ./tests/cache/inception_v1_2016_08_28_frozen.pb -d inception_v1_part --dstNodeName InceptionV1/Logits/Predictions/Reshape_1 --inputShape 28 28 192 --inNodeName InceptionV1/InceptionV1/MaxPool_3a_3x3/MaxPool
|
||||||
|
|
||||||
IR network structure is saved as [inception_v1.json].
|
IR network structure is saved as [inception_v1_part.json].
|
||||||
IR network structure is saved as [inception_v1.pb].
|
IR network structure is saved as [inception_v1_part.pb].
|
||||||
IR weights are saved as [inception_v1.npy].
|
IR weights are saved as [inception_v1_part.npy].
|
||||||
```
|
```
|
||||||
|
|
||||||
### Convert models from IR to Tensorflow code snippet
|
### Convert models from IR to Tensorflow code snippet
|
||||||
|
|
|
@ -95,16 +95,19 @@ class TensorflowParser2(Parser):
|
||||||
original_gdef = tensorflow.GraphDef()
|
original_gdef = tensorflow.GraphDef()
|
||||||
|
|
||||||
original_gdef.ParseFromString(serialized)
|
original_gdef.ParseFromString(serialized)
|
||||||
# model = original_gdef
|
|
||||||
|
in_type_list = {}
|
||||||
|
for n in original_gdef.node:
|
||||||
|
if n.name in in_nodes:
|
||||||
|
in_type_list[n.name] = n.attr['dtype'].type
|
||||||
|
|
||||||
from tensorflow.python.tools import strip_unused_lib
|
from tensorflow.python.tools import strip_unused_lib
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
input_node_names = in_nodes.split(',')
|
|
||||||
output_node_names = dest_nodes.split(',')
|
|
||||||
original_gdef = strip_unused_lib.strip_unused(
|
original_gdef = strip_unused_lib.strip_unused(
|
||||||
input_graph_def = original_gdef,
|
input_graph_def = original_gdef,
|
||||||
input_node_names = input_node_names,
|
input_node_names = in_nodes,
|
||||||
output_node_names = output_node_names,
|
output_node_names = dest_nodes,
|
||||||
placeholder_type_enum = dtypes.float32.as_datatype_enum)
|
placeholder_type_enum = dtypes.float32.as_datatype_enum)
|
||||||
# Save it to an output file
|
# Save it to an output file
|
||||||
frozen_model_file = './frozen.pb'
|
frozen_model_file = './frozen.pb'
|
||||||
|
@ -118,14 +121,24 @@ class TensorflowParser2(Parser):
|
||||||
|
|
||||||
output_shape_map = dict()
|
output_shape_map = dict()
|
||||||
input_shape_map = dict()
|
input_shape_map = dict()
|
||||||
|
|
||||||
with tensorflow.Graph().as_default() as g:
|
with tensorflow.Graph().as_default() as g:
|
||||||
x = tensorflow.placeholder(tensorflow.float32, shape = [None] + inputshape)
|
input_map = {}
|
||||||
tensorflow.import_graph_def(model, name='', input_map={in_nodes + ':0' : x})
|
for i in range(len(inputshape)):
|
||||||
|
if in_type_list[in_nodes[i]] == 1:
|
||||||
|
dtype = tensorflow.float32
|
||||||
|
elif in_type_list[in_nodes[i]] == 3:
|
||||||
|
dtype = tensorflow.int32
|
||||||
|
x = tensorflow.placeholder(dtype, shape = [None] + inputshape[i])
|
||||||
|
input_map[in_nodes[i] + ':0'] = x
|
||||||
|
|
||||||
|
tensorflow.import_graph_def(model, name='', input_map=input_map)
|
||||||
|
|
||||||
with tensorflow.Session(graph = g) as sess:
|
with tensorflow.Session(graph = g) as sess:
|
||||||
meta_graph_def = tensorflow.train.export_meta_graph(filename='./my-model.meta')
|
meta_graph_def = tensorflow.train.export_meta_graph(filename='./my-model.meta')
|
||||||
model = meta_graph_def.graph_def
|
model = meta_graph_def.graph_def
|
||||||
|
|
||||||
|
|
||||||
self.tf_graph = TensorflowGraph(model)
|
self.tf_graph = TensorflowGraph(model)
|
||||||
self.tf_graph.build()
|
self.tf_graph.build()
|
||||||
|
|
||||||
|
@ -510,12 +523,11 @@ class TensorflowParser2(Parser):
|
||||||
self.set_weight(source_node.name, 'mean', mean)
|
self.set_weight(source_node.name, 'mean', mean)
|
||||||
|
|
||||||
def rename_Placeholder(self, source_node):
|
def rename_Placeholder(self, source_node):
|
||||||
# print(source_node.layer.attr["shape"].shape)
|
# print(source_node.layer)
|
||||||
if source_node.layer.attr["shape"].shape.unknown_rank == True:
|
if source_node.layer.attr["shape"].shape.unknown_rank == True:
|
||||||
return
|
return
|
||||||
IR_node = self._convert_identity_operation(source_node, new_op='DataInput')
|
IR_node = self._convert_identity_operation(source_node, new_op='DataInput')
|
||||||
TensorflowParser2._copy_shape(source_node, IR_node)
|
TensorflowParser2._copy_shape(source_node, IR_node)
|
||||||
|
|
||||||
IR_node.attr['shape'].shape.dim[0].size = -1
|
IR_node.attr['shape'].shape.dim[0].size = -1
|
||||||
IR_node.attr['_output_shapes'].list.shape[0].dim[0].size = -1
|
IR_node.attr['_output_shapes'].list.shape[0].dim[0].size = -1
|
||||||
|
|
||||||
|
|
|
@ -197,12 +197,10 @@ class TensorflowParser(Parser):
|
||||||
from tensorflow.python.tools import strip_unused_lib
|
from tensorflow.python.tools import strip_unused_lib
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
input_node_names = in_nodes.split(',')
|
|
||||||
output_node_names = dest_nodes.split(',')
|
|
||||||
model = strip_unused_lib.strip_unused(
|
model = strip_unused_lib.strip_unused(
|
||||||
input_graph_def = model,
|
input_graph_def = model,
|
||||||
input_node_names = input_node_names,
|
input_node_names = in_nodes,
|
||||||
output_node_names = output_node_names,
|
output_node_names = dest_nodes,
|
||||||
placeholder_type_enum = dtypes.float32.as_datatype_enum)
|
placeholder_type_enum = dtypes.float32.as_datatype_enum)
|
||||||
|
|
||||||
input_list = [None]
|
input_list = [None]
|
||||||
|
@ -212,7 +210,7 @@ class TensorflowParser(Parser):
|
||||||
# Build network graph
|
# Build network graph
|
||||||
self.tf_graph = TensorflowGraph(model)
|
self.tf_graph = TensorflowGraph(model)
|
||||||
for node in self.tf_graph.model.node:
|
for node in self.tf_graph.model.node:
|
||||||
if node.name in input_node_names:
|
if node.name in in_nodes:
|
||||||
node.attr['shape'].list.shape.extend([tensor_input.as_proto()])
|
node.attr['shape'].list.shape.extend([tensor_input.as_proto()])
|
||||||
node.attr['_output_shapes'].list.shape.pop() #unknown_rank pop
|
node.attr['_output_shapes'].list.shape.pop() #unknown_rank pop
|
||||||
node.attr['_output_shapes'].list.shape.extend([tensor_input.as_proto()])
|
node.attr['_output_shapes'].list.shape.extend([tensor_input.as_proto()])
|
||||||
|
|
|
@ -131,7 +131,7 @@ class TestModels(CorrectnessTest):
|
||||||
# original to IR
|
# original to IR
|
||||||
IR_file = TestModels.tmpdir + 'tensorflow_frozen_' + architecture_name + "_converted"
|
IR_file = TestModels.tmpdir + 'tensorflow_frozen_' + architecture_name + "_converted"
|
||||||
parser = TensorflowParser2(
|
parser = TensorflowParser2(
|
||||||
TestModels.cachedir + para[0], para[1], para[2].split(':')[0], para[3].split(':')[0])
|
TestModels.cachedir + para[0], [para[1]], [para[2].split(':')[0]], [para[3].split(':')[0]])
|
||||||
parser.run(IR_file)
|
parser.run(IR_file)
|
||||||
del parser
|
del parser
|
||||||
del TensorflowParser2
|
del TensorflowParser2
|
||||||
|
|
Загрузка…
Ссылка в новой задаче