зеркало из https://github.com/microsoft/MMdnn.git
Merge pull request #371 from namizzz/master
tf frozen parser fix and new extractor
This commit is contained in:
Коммит
dde0b4a757
|
@ -7,6 +7,7 @@ from __future__ import division
|
|||
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
import caffe
|
||||
|
@ -145,13 +146,11 @@ if __name__=='__main__':
|
|||
self.save_weights(self.weights_dict, dstWeightPath)
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _shapeToStr(shapes):
|
||||
return [dim.size if dim.size > 0 else 1 for dim in shapes.dim]
|
||||
|
||||
|
||||
|
||||
def _get_symmetric_padding(self, IR_node):
|
||||
stride_h = IR_node.get_attr('strides')[1]
|
||||
stride_w = IR_node.get_attr('strides')[2]
|
||||
|
@ -166,6 +165,7 @@ if __name__=='__main__':
|
|||
pad_w = pads[2] + (0 if pads[2] == pads[6] else stride_w)
|
||||
return pad_h, pad_w
|
||||
|
||||
|
||||
def check_if_need_transpose(self, IR_node):
|
||||
parent = self.IR_graph.get_parent(IR_node.name, [0])
|
||||
while parent.type == 'Flatten' or parent.type == 'Dropout' or parent.type == 'Reshape':
|
||||
|
@ -388,28 +388,21 @@ if __name__=='__main__':
|
|||
self.parent_variable_name(IR_node),
|
||||
IR_node.get_attr('use_bias', False)
|
||||
))
|
||||
|
||||
if self.weight_loaded:
|
||||
try:
|
||||
self.weights_dict[IR_node.variable_name] = self.weights_dict.pop(IR_node.name)
|
||||
except:
|
||||
self.weights_dict[IR_node.variable_name] = self.weights_dict.pop(IR_node.name + "_second")
|
||||
self.weights_dict[IR_node.variable_name] = self.weights_dict.pop(IR_node.name)
|
||||
|
||||
|
||||
def emit_Constant(self, IR_node):
|
||||
IR_node_after = self.IR_graph.get_son(IR_node.name, [0])
|
||||
shape = IR_node_after.get_attr("_output_shapes")[0]
|
||||
shape = shape_to_list(shape)
|
||||
if IR_node_after.type == 'Mul':
|
||||
return
|
||||
else: #Sub
|
||||
self.add_body(1, "n.{:<15} = L.DummyData(shape=[dict(dim=[1,{},{},{}])], data_filler=dict(type='constant', value={}), ntop=1)".format(
|
||||
IR_node.variable_name,
|
||||
shape[-1],
|
||||
shape[1],
|
||||
shape[2],
|
||||
self.weights_dict[IR_node.name]['value'][0]
|
||||
))
|
||||
self.add_body(1, "n.{:<15} = L.DummyData(shape=[dict(dim=[1,{},{},{}])], data_filler=dict(type='constant', value={}), ntop=1)".format(
|
||||
IR_node.variable_name,
|
||||
shape[-1],
|
||||
shape[1],
|
||||
shape[2],
|
||||
self.weights_dict[IR_node.name]['value'][0]
|
||||
))
|
||||
|
||||
|
||||
def emit_LRN(self, IR_node):
|
||||
|
@ -459,8 +452,10 @@ if __name__=='__main__':
|
|||
axis
|
||||
))
|
||||
|
||||
# def emit_Tanh(self, IR_node):
|
||||
# self._emit_activation(IR_node, 'ops.tanh')
|
||||
def emit_Sigmoid(self, IR_node):
|
||||
self.add_body(1, "n.{:<15} = L.Sigmoid(n.{}, ntop=1)".format(
|
||||
IR_node.variable_name,
|
||||
self.parent_variable_name(IR_node)))
|
||||
|
||||
|
||||
def emit_Relu(self, IR_node):
|
||||
|
@ -568,8 +563,32 @@ if __name__=='__main__':
|
|||
input_layers))
|
||||
|
||||
def emit_Mul(self, IR_node):
|
||||
self.emit_Scale(IR_node)
|
||||
if len(IR_node.in_edges) == 2:
|
||||
input_layers = ', '.join(('n.' + self.IR_graph.get_node(edge).real_variable_name) for edge in IR_node.in_edges)
|
||||
self.add_body(1, "n.{:<15} = L.Eltwise({}, operation=0, ntop=1)".format(
|
||||
IR_node.variable_name,
|
||||
input_layers))
|
||||
elif len(IR_node.in_edges) == 1:
|
||||
self.emit_Scale(IR_node)
|
||||
else:
|
||||
assert False
|
||||
|
||||
def emit_UpSampling2D(self, IR_node):
|
||||
scales = IR_node.get_attr('scales')
|
||||
scale = tuple(scales)[0]
|
||||
|
||||
shape = IR_node.get_attr('_output_shapes')[0]
|
||||
shape = shape_to_list(shape)
|
||||
|
||||
self.add_body(1, "n.{:<15} = L.Deconvolution(n.{}, convolution_param=dict(kernel_size={}, stride={}, pad={}, num_output={}, group={}, bias_term={}), param=[dict(lr_mult=0)], ntop=1)".format(
|
||||
IR_node.variable_name,
|
||||
IR_node.in_edges[0],
|
||||
2 * scale - scale % 2,
|
||||
scale,
|
||||
int(math.ceil((scale - 1) / 2)),
|
||||
shape[-1],
|
||||
shape[-1],
|
||||
False))
|
||||
|
||||
# def emit_Square(self, IR_node):
|
||||
# input_layers = ', '.join(('n.' + self.IR_graph.get_node(edge).real_variable_name) for edge in IR_node.in_edges)
|
||||
|
|
|
@ -471,14 +471,12 @@ def KitModel(weight_file = None):
|
|||
|
||||
|
||||
|
||||
def emit_upsample(self, IR_node):
|
||||
# print(IR_node.layer)
|
||||
# assert False
|
||||
def emit_UpSampling2D(self, IR_node):
|
||||
self.used_layers.add(IR_node.type)
|
||||
self.add_body(1, "{:<15} = Upsampling2D({}, stride = {}, name = '{}')".format(
|
||||
IR_node.variable_name,
|
||||
self.parent_variable_name(IR_node),
|
||||
IR_node.get_attr('strides'),
|
||||
IR_node.get_attr('scales')[0],
|
||||
IR_node.name))
|
||||
|
||||
|
||||
|
|
|
@ -408,7 +408,7 @@ class DarknetGraph(Graph):
|
|||
upsample_layer['type'] = 'upsample'
|
||||
upsample_param = OrderedDict()
|
||||
stride = block['stride']
|
||||
upsample_param['strides'] = int(stride)
|
||||
upsample_param['scales'] = [int(stride), int(stride)]
|
||||
upsample_param['_output_shape'] = [input_shape[0]] + [q*int(stride) for q in input_shape[1:3]] + [input_shape[-1]]
|
||||
upsample_layer['attr'] = upsample_param
|
||||
self.layer_map[upsample_layer['name']] = DarknetGraphNode(upsample_layer)
|
||||
|
|
|
@ -294,10 +294,10 @@ class DarknetParser(Parser):
|
|||
|
||||
|
||||
def rename_upsample(self, source_node):
|
||||
IR_node = self._convert_identity_operation(source_node, new_op='upsample')
|
||||
stride = source_node.get_attr('strides')
|
||||
IR_node = self._convert_identity_operation(source_node, new_op='UpSampling2D')
|
||||
scales = source_node.get_attr('scales')
|
||||
kwargs = {}
|
||||
kwargs['strides'] = stride
|
||||
kwargs['scales'] = scales
|
||||
|
||||
assign_IRnode_values(IR_node, kwargs)
|
||||
|
||||
|
|
|
@ -99,6 +99,7 @@ class TestKit(object):
|
|||
'resnet_v1_101' : lambda path : TestKit.ZeroCenter(path, 224),
|
||||
'resnet_v1_152' : lambda path : TestKit.ZeroCenter(path, 224),
|
||||
'resnet_v2_50' : lambda path : TestKit.Standard(path, 299),
|
||||
'resnet_v2_101' : lambda path : TestKit.Standard(path, 299),
|
||||
'resnet_v2_152' : lambda path : TestKit.Standard(path, 299),
|
||||
'resnet_v2_200' : lambda path : TestKit.Standard(path, 299),
|
||||
'resnet152' : lambda path : TestKit.Standard(path, 299),
|
||||
|
|
|
@ -98,6 +98,14 @@ class tensorflow_extractor(base_extractor):
|
|||
'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
|
||||
'num_classes' : 1001,
|
||||
},
|
||||
'resnet_v2_101' : {
|
||||
'url' : 'http://download.tensorflow.org/models/resnet_v2_101_2017_04_14.tar.gz',
|
||||
'filename' : 'resnet_v2_101.ckpt',
|
||||
'builder' : lambda : resnet_v2.resnet_v2_101,
|
||||
'arg_scope' : resnet_v2.resnet_arg_scope,
|
||||
'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
|
||||
'num_classes' : 1001,
|
||||
},
|
||||
'resnet_v2_152' : {
|
||||
'url' : 'http://download.tensorflow.org/models/resnet_v2_152_2017_04_14.tar.gz',
|
||||
'filename' : 'resnet_v2_152.ckpt',
|
||||
|
|
|
@ -66,15 +66,17 @@ class Keras2Parser(Parser):
|
|||
# Load the model weights
|
||||
|
||||
try:
|
||||
from keras.applications.mobilenet import relu6
|
||||
from keras.applications.mobilenet import DepthwiseConv2D
|
||||
loaded_model = model_from_json(loaded_model_json, custom_objects={
|
||||
'relu6': _keras.applications.mobilenet.relu6,
|
||||
'DepthwiseConv2D': _keras.applications.mobilenet.DepthwiseConv2D})
|
||||
'relu6': _keras.applications.mobilenet.relu6,
|
||||
'DepthwiseConv2D': _keras.applications.mobilenet.DepthwiseConv2D})
|
||||
except:
|
||||
from keras_applications import mobilenet_v2
|
||||
import keras.layers as layers
|
||||
loaded_model = model_from_json(loaded_model_json, custom_objects={
|
||||
'relu6': mobilenet_v2.layers.ReLU(6, name='relu6'),
|
||||
'DepthwiseConv2D': layers.DepthwiseConv2D})
|
||||
'relu6': mobilenet_v2.layers.ReLU(6, name='relu6'),
|
||||
'DepthwiseConv2D': layers.DepthwiseConv2D})
|
||||
|
||||
|
||||
if model_weight_path:
|
||||
|
@ -550,10 +552,7 @@ class Keras2Parser(Parser):
|
|||
self.convert_inedge(source_node, IR_node)
|
||||
|
||||
# size
|
||||
IR_node.attr["size"].list.i.extend(source_node.keras_layer.size)
|
||||
|
||||
|
||||
|
||||
IR_node.attr["scales"].list.i.extend(source_node.keras_layer.size)
|
||||
|
||||
|
||||
def rename_Embedding(self, source_node):
|
||||
|
|
|
@ -301,11 +301,11 @@ def KitModel(weight_file = None):
|
|||
|
||||
|
||||
def emit_UpSampling2D(self, IR_node):
|
||||
size = IR_node.get_attr('size')
|
||||
size = tuple(size)
|
||||
scales = IR_node.get_attr('scales')
|
||||
scales = tuple(scales)
|
||||
self.add_body(1, "{:<15} = tf.keras.layers.UpSampling2D(size={})({})".format(
|
||||
IR_node.variable_name,
|
||||
size,
|
||||
scales,
|
||||
self.parent_variable_name(IR_node)))
|
||||
|
||||
|
||||
|
|
|
@ -57,7 +57,9 @@ class TensorflowParser2(Parser):
|
|||
"Identity",
|
||||
# "Mean",
|
||||
# "Cast"
|
||||
"Pack"
|
||||
"Pack",
|
||||
"CheckNumerics",
|
||||
"Where"
|
||||
])
|
||||
|
||||
|
||||
|
@ -150,12 +152,6 @@ class TensorflowParser2(Parser):
|
|||
|
||||
tensorflow.import_graph_def(model, name='', input_map=input_map)
|
||||
|
||||
# graph_options = tensorflow.GraphOptions(
|
||||
# optimizer_options=tensorflow.OptimizerOptions(
|
||||
# opt_level=tensorflow.OptimizerOptions.L0, do_function_inlining=False))
|
||||
|
||||
# config = tensorflow.ConfigProto(graph_options=graph_options)
|
||||
# with tensorflow.Session(graph = g, config=config) as sess:
|
||||
with tensorflow.Session(graph = g) as sess:
|
||||
|
||||
meta_graph_def = tensorflow.train.export_meta_graph(filename='./my-model.meta')
|
||||
|
@ -210,7 +206,6 @@ class TensorflowParser2(Parser):
|
|||
input_mul_A = self.get_parent(source_node.name, [0, 1])
|
||||
tensor_content = input_mul_A.get_attr('value')
|
||||
A_content = tensor_util.MakeNdarray(tensor_content)
|
||||
# print(A_content)
|
||||
self.set_weight(source_node.name, 'A', A_content)
|
||||
|
||||
# b
|
||||
|
@ -350,10 +345,7 @@ class TensorflowParser2(Parser):
|
|||
if self._skip_node(current_node):
|
||||
continue
|
||||
|
||||
# print(current_node.name)
|
||||
node_type = current_node.type
|
||||
# print(current_node.name)
|
||||
# print(node_type)
|
||||
|
||||
if hasattr(self, "rename_" + node_type):
|
||||
func = getattr(self, "rename_" + node_type)
|
||||
|
@ -390,11 +382,10 @@ class TensorflowParser2(Parser):
|
|||
if source_node.type == 'Enter':
|
||||
IR_node.attr["dtype"].type = TensorflowParser2.dtype_map[6]
|
||||
else:
|
||||
# print(source_node.layer)
|
||||
# print(source_node.layer.attr['T'].type)
|
||||
assert source_node.layer.attr['T'].type in TensorflowParser2.dtype_map, 'type [{}] is unknown.'.format(source_node.layer.attr['dtype'].type)
|
||||
IR_node.attr["dtype"].type = TensorflowParser2.dtype_map[source_node.layer.attr['T'].type]
|
||||
else:
|
||||
# Quantized model type
|
||||
IR_node.attr["dtype"].type = TensorflowParser2.dtype_map[6]
|
||||
|
||||
if '_output_shapes' in source_node.layer.attr:
|
||||
|
@ -448,7 +439,6 @@ class TensorflowParser2(Parser):
|
|||
assert False
|
||||
|
||||
|
||||
|
||||
def _get_bias(self, source_node, IR_node):
|
||||
if not source_node.out_edges:
|
||||
return
|
||||
|
@ -509,7 +499,8 @@ class TensorflowParser2(Parser):
|
|||
|
||||
|
||||
def rename_Merge(self, source_node):
|
||||
# print(source_node.layer)
|
||||
# In facenet or other newtwork using slim.batch_norm,
|
||||
# There are two BN(train, test) skip switch and merge.
|
||||
source_node.real_name = self.src_graph.get_node(source_node.in_edges[0]).real_name
|
||||
|
||||
|
||||
|
@ -561,7 +552,6 @@ class TensorflowParser2(Parser):
|
|||
self.set_weight(source_node.name, 'mean', mean)
|
||||
|
||||
def rename_Placeholder(self, source_node):
|
||||
# print(source_node.layer)
|
||||
if source_node.layer.attr["shape"].shape.unknown_rank == True:
|
||||
return
|
||||
IR_node = self._convert_identity_operation(source_node, new_op='DataInput')
|
||||
|
@ -570,28 +560,24 @@ class TensorflowParser2(Parser):
|
|||
IR_node.attr['_output_shapes'].list.shape[0].dim[0].size = -1
|
||||
|
||||
|
||||
def rename_CheckNumerics(self, source_node):
|
||||
return
|
||||
|
||||
def rename_Mean(self, source_node):
|
||||
# ReduceMean
|
||||
IR_node = self._convert_identity_operation(source_node, new_op='ReduceMean')
|
||||
IR_node = self._convert_identity_operation(source_node, start_idx = 0, end_idx = 1, new_op='ReduceMean')
|
||||
# keep dims
|
||||
IR_node.attr['keepdims'].b = source_node.layer.attr['keep_dims'].b
|
||||
|
||||
|
||||
# axes
|
||||
axes = self.get_parent(source_node.name, [1]).layer.attr['value'].tensor
|
||||
axes = tensor_util.MakeNdarray(axes)
|
||||
IR_node.attr['axes'].list.i.extend(axes)
|
||||
|
||||
|
||||
|
||||
def rename_Reshape(self, source_node):
|
||||
IR_node = self._convert_identity_operation(source_node, end_idx = 1)
|
||||
kwargs = {'shape' : self.tensor_shape_to_list(source_node.get_attr('_output_shapes'))[0]}
|
||||
assign_IRnode_values(IR_node, kwargs)
|
||||
|
||||
|
||||
def rename_MirrorPad(self, source_node):
|
||||
IR_node = self._convert_identity_operation(source_node, new_op = 'MirrorPad')
|
||||
input_node = self.src_graph.get_parent(source_node.name, [1])
|
||||
|
@ -603,6 +589,7 @@ class TensorflowParser2(Parser):
|
|||
|
||||
assign_IRnode_values(IR_node, kwargs)
|
||||
|
||||
|
||||
def rename_Min(self, source_node):
|
||||
IR_node = self._convert_identity_operation(source_node, start_idx=0, end_idx=1, new_op = 'Min')
|
||||
kwargs = {}
|
||||
|
@ -628,21 +615,22 @@ class TensorflowParser2(Parser):
|
|||
def rename_Mul(self, source_node):
|
||||
scopes = self._get_scopes(source_node.name)
|
||||
|
||||
if scopes[-2] == "batchnorm" or scopes[-2].startswith("Assign"):
|
||||
return
|
||||
if len(scopes) >= 2:
|
||||
if scopes[-2] == "batchnorm" or scopes[-2].startswith("Assign"):
|
||||
return
|
||||
|
||||
input_node = self.check_const(self.src_graph.get_parent(source_node.name, [1]))
|
||||
if input_node:
|
||||
tensor_content = input_node.get_attr('value')
|
||||
IR_node = self._convert_identity_operation(source_node, start_idx=0, end_idx=1, new_op='Mul')
|
||||
else:
|
||||
input_node = self.check_const(self.src_graph.get_parent(source_node.name, [1]))
|
||||
if input_node:
|
||||
tensor_content = input_node.get_attr('value')
|
||||
IR_node = self._convert_identity_operation(source_node, start_idx=0, end_idx=1, new_op='Mul')
|
||||
else:
|
||||
input_node = self.check_const(self.src_graph.get_parent(source_node.name, [0]))
|
||||
tensor_content = input_node.get_attr('value')
|
||||
IR_node = self._convert_identity_operation(source_node, start_idx=1, end_idx=2, new_op='Mul')
|
||||
input_node = self.check_const(self.src_graph.get_parent(source_node.name, [0]))
|
||||
tensor_content = input_node.get_attr('value')
|
||||
IR_node = self._convert_identity_operation(source_node, start_idx=1, end_idx=2, new_op='Mul')
|
||||
|
||||
W = tensor_util.MakeNdarray(tensor_content)
|
||||
W = tensor_util.MakeNdarray(tensor_content)
|
||||
|
||||
self.set_weight(source_node.name, 'weights', W)
|
||||
self.set_weight(source_node.name, 'weights', W)
|
||||
|
||||
|
||||
def rename_Add(self, source_node):
|
||||
|
@ -665,8 +653,9 @@ class TensorflowParser2(Parser):
|
|||
|
||||
def rename_Sub(self, source_node):
|
||||
scopes = self._get_scopes(source_node.name)
|
||||
if scopes[-2].startswith('Assign') or scopes[-1].startswith('Assign'):
|
||||
return
|
||||
if len(scopes) > 2:
|
||||
if scopes[-2].startswith('Assign') or scopes[-1].startswith('Assign'):
|
||||
return
|
||||
IR_node = self._convert_identity_operation(source_node, end_idx=1, new_op = "Sub")
|
||||
|
||||
|
||||
|
@ -703,6 +692,7 @@ class TensorflowParser2(Parser):
|
|||
|
||||
assign_IRnode_values(IR_node, kwargs)
|
||||
|
||||
|
||||
def rename_Sigmoid(self, source_node):
|
||||
IR_node = self._convert_identity_operation(source_node)
|
||||
|
||||
|
@ -729,11 +719,10 @@ class TensorflowParser2(Parser):
|
|||
IR_node = self._convert_identity_operation(source_node, new_op = 'Enter')
|
||||
|
||||
def rename_Switch(self, source_node):
|
||||
# skip the node
|
||||
# Skip the node as merge
|
||||
source_node.real_name = self.src_graph.get_node(source_node.in_edges[0]).real_name
|
||||
|
||||
|
||||
|
||||
def rename_Exp(self, source_node):
|
||||
IR_node = self._convert_identity_operation(source_node, new_op = 'Exp')
|
||||
|
||||
|
@ -776,36 +765,16 @@ class TensorflowParser2(Parser):
|
|||
IR_node = self._convert_identity_operation(source_node, new_op = 'Squeeze')
|
||||
|
||||
|
||||
# def rename_Pack(self, source_node):
|
||||
# IR_node = self._convert_identity_operation(source_node, new_op = 'Pack')
|
||||
|
||||
|
||||
def rename_Gather(self, source_node):
|
||||
IR_node = self._convert_identity_operation(source_node, new_op = 'Gather')
|
||||
# input_node = self.src_graph.get_parent(source_node.name, [0])
|
||||
|
||||
input_node_range = self.src_graph.get_parent(source_node.name, [1])
|
||||
# print(input_node.layer)
|
||||
# print(input_node_range.layer)
|
||||
kwargs = {}
|
||||
kwargs['shape'] = self.tensor_shape_to_list(input_node_range.get_attr('_output_shapes'))[0]
|
||||
|
||||
# input_node_indices = self.src_graph.get_parent(source_node.name, [1])
|
||||
|
||||
# print(source_node.layer)
|
||||
# print(input_node_indices.layer)
|
||||
# input1 = self.src_graph.get_parent(input_node_indices.name, [1])
|
||||
# print(input1.layer)
|
||||
# indice_value = input_node_indices.get_attr('value')
|
||||
# shapes = tensor_util.MakeNdarray(indice_value)
|
||||
# c = shapes.tolist()
|
||||
# kwargs['gather_indices'] = c
|
||||
|
||||
assign_IRnode_values(IR_node, kwargs)
|
||||
# print(IR_node)
|
||||
|
||||
|
||||
def rename_StridedSlice(self, source_node):
|
||||
# print(source_node.layer)
|
||||
IR_node = self._convert_identity_operation(source_node, end_idx=1, new_op = 'Slice')
|
||||
kwargs = {}
|
||||
kwargs = {
|
||||
|
@ -826,23 +795,7 @@ class TensorflowParser2(Parser):
|
|||
strides = tensor_util.MakeNdarray(strides).tolist()
|
||||
kwargs['strides'] = strides
|
||||
|
||||
|
||||
# print(kwargs)
|
||||
assign_IRnode_values(IR_node, kwargs)
|
||||
# assert False
|
||||
|
||||
|
||||
# def rename_ExpandDims(self, source_node):
|
||||
|
||||
# IR_node = self._convert_identity_operation(source_node, new_op = 'ExpandDims')
|
||||
# input_node = self.src_graph.get_parent(source_node.name, [0])
|
||||
# kwargs = {}
|
||||
# kwargs['shape'] = self.tensor_shape_to_list(input_node.get_attr('_output_shapes'))[0]
|
||||
|
||||
# input_node_indices = self.src_graph.get_parent(source_node.name, [1])
|
||||
|
||||
# kwargs['exp_dim'] = input_node_indices.get_attr('value').int_val[0]
|
||||
# assign_IRnode_values(IR_node, kwargs)
|
||||
|
||||
|
||||
def rename_ResizeNearestNeighbor(self, source_node):
|
||||
|
@ -856,6 +809,7 @@ class TensorflowParser2(Parser):
|
|||
|
||||
assign_IRnode_values(IR_node, kwargs)
|
||||
|
||||
|
||||
def rename_Conv2D(self, source_node):
|
||||
IR_node = self._convert_identity_operation(source_node, end_idx=1, new_op = 'Conv')
|
||||
kwargs = {}
|
||||
|
@ -887,7 +841,6 @@ class TensorflowParser2(Parser):
|
|||
|
||||
|
||||
def rename_MaxPool(self, source_node):
|
||||
# print(source_node.layer)
|
||||
self._convert_pooling(source_node, b'MAX')
|
||||
|
||||
|
||||
|
@ -953,11 +906,6 @@ class TensorflowParser2(Parser):
|
|||
assign_IRnode_values(IR_node, kwargs)
|
||||
|
||||
|
||||
def rename_Identity(self, source_node):
|
||||
# skip the node
|
||||
source_node.real_name = self.src_graph.get_node(source_node.in_edges[0]).real_name
|
||||
|
||||
|
||||
def rename_BiasAdd(self, source_node):
|
||||
IR_node = self._convert_identity_operation(source_node, end_idx = 1, new_op = "Add")
|
||||
|
||||
|
@ -966,23 +914,25 @@ class TensorflowParser2(Parser):
|
|||
IR_node = self._convert_identity_operation(source_node, new_op = 'QuantizeV2')
|
||||
TensorflowParser2._copy_shape(source_node, IR_node)
|
||||
|
||||
|
||||
def rename_QuantizedRelu(self, source_node):
|
||||
IR_node = self._convert_identity_operation(source_node, new_op = "QuantizedRelu")
|
||||
kwargs = {'shape' : self.tensor_shape_to_list(source_node.get_attr('_output_shapes'))[0]}
|
||||
assign_IRnode_values(IR_node, kwargs)
|
||||
|
||||
|
||||
def rename_QuantizedReshape(self, source_node):
|
||||
IR_node = self._convert_identity_operation(source_node, end_idx = 1)
|
||||
kwargs = {'shape' : self.tensor_shape_to_list(source_node.get_attr('_output_shapes'))[0]}
|
||||
assign_IRnode_values(IR_node, kwargs)
|
||||
|
||||
|
||||
def rename_QuantizedConv2D(self, source_node):
|
||||
IR_node = self._convert_identity_operation(source_node, new_op = 'QConv')
|
||||
kwargs = {}
|
||||
kwargs['strides'] = source_node.get_attr('strides')
|
||||
kwargs['padding'] = source_node.get_attr('padding')
|
||||
|
||||
|
||||
# weights
|
||||
input_node = self.src_graph.get_parent(source_node.name, [1])
|
||||
tensor_content = input_node.get_attr('value')
|
||||
|
@ -991,34 +941,21 @@ class TensorflowParser2(Parser):
|
|||
|
||||
kwargs['kernel_shape'] = self.tensor_shape_to_list(input_node.get_attr('_output_shapes'))[0]
|
||||
|
||||
|
||||
input_node_minw = self.src_graph.get_parent(source_node.name, [4])
|
||||
min_W = input_node_minw.get_attr('value').float_val[0]
|
||||
|
||||
input_node_maxw = self.src_graph.get_parent(source_node.name, [5])
|
||||
max_W = input_node_maxw.get_attr('value').float_val[0]
|
||||
|
||||
if source_node.get_attr('Tfilter') == tensorflow.quint8:
|
||||
W = ((max_W - min_W)/255.0) * W + min_W
|
||||
|
||||
else:
|
||||
assert False, ('Only uint8 weights handled currently by the converter')
|
||||
|
||||
self.set_weight(source_node.name, 'kernel_weights', W)
|
||||
|
||||
assign_IRnode_values(IR_node, kwargs)
|
||||
|
||||
|
||||
# def rename_Dequantize(self, source_node):
|
||||
# IR_node = self._convert_identity_operation(source_node,start_idx=0, end_idx= 1, new_op = 'Dequantize')
|
||||
# kwargs = {}
|
||||
# input_node = self.src_graph.get_parent(source_node.name, [0])
|
||||
# kwargs['shape'] = self.tensor_shape_to_list(input_node.get_attr('_output_shapes'))[0]
|
||||
|
||||
assign_IRnode_values(IR_node, kwargs)
|
||||
|
||||
def rename_Requantize(self, source_node):
|
||||
# print(source_node.layer)
|
||||
input_node = self.get_parent(source_node.name, [0])
|
||||
son_node = self.get_son(source_node.name, [0])
|
||||
|
||||
|
@ -1030,30 +967,17 @@ class TensorflowParser2(Parser):
|
|||
|
||||
|
||||
def rename_ZerosLike(self, source_node):
|
||||
# print(source_node.layer)
|
||||
# assert False
|
||||
IR_node = self._convert_identity_operation(source_node, new_op = 'ZerosLike')
|
||||
|
||||
|
||||
def rename_Rank(self, source_node):
|
||||
# print(source_node.layer)
|
||||
# assert False
|
||||
IR_node = self._convert_identity_operation(source_node, new_op = 'Rank')
|
||||
|
||||
|
||||
def rename_Transpose(self, source_node):
|
||||
# print(source_node.layer)
|
||||
# assert False
|
||||
IR_node = self._convert_identity_operation(source_node, new_op = 'Transpose')
|
||||
|
||||
|
||||
def rename_Where(self, source_node):
|
||||
# print(source_node.layer)
|
||||
# assert False
|
||||
# IR_node = self._convert_identity_operation(source_node, new_op = 'Where')
|
||||
return
|
||||
|
||||
|
||||
def rename_GreaterEqual(self, source_node):
|
||||
IR_node = self._convert_identity_operation(source_node, end_idx=1, new_op = 'GreaterEqual')
|
||||
|
||||
|
@ -1073,31 +997,30 @@ class TensorflowParser2(Parser):
|
|||
def rename_LogicalAnd(self, source_node):
|
||||
IR_node = self._convert_identity_operation(source_node, new_op = 'Mul')
|
||||
|
||||
|
||||
def rename_Pad(self, source_node):
|
||||
# print(source_node.layer)
|
||||
# assert False
|
||||
IR_node = self._convert_identity_operation(source_node, end_idx=1, new_op = 'Pad')
|
||||
kwargs = {}
|
||||
kwargs['mode'] = 'constant'
|
||||
# kwargs['constant_values'] = 0.0
|
||||
|
||||
# # paddings
|
||||
# paddings
|
||||
padding = self.get_parent(source_node.name, [1]).layer.attr['value'].tensor
|
||||
shapes = tensor_util.MakeNdarray(padding)
|
||||
kwargs['pads'] = convert_tf_pad_to_onnx(shapes)
|
||||
|
||||
assign_IRnode_values(IR_node, kwargs)
|
||||
|
||||
def rename_FusedBatchNorm(self, source_node):
|
||||
|
||||
def rename_FusedBatchNorm(self, source_node):
|
||||
scalenode = self.check_const(self.get_parent(source_node.name, [1], True))
|
||||
if scalenode:
|
||||
scale_value = scalenode.get_attr('value')
|
||||
IR_node = self._convert_identity_operation(source_node, end_idx=1, new_op = 'BatchNorm')
|
||||
|
||||
else:
|
||||
# for slim.batch_norm to remove switch
|
||||
# For models built by slim.batch_norm, remove duplicate BN (eg.facenet)
|
||||
return
|
||||
|
||||
scale = tensor_util.MakeNdarray(scale_value)
|
||||
self.set_weight(source_node.name, 'scale', scale)
|
||||
|
||||
|
@ -1125,26 +1048,22 @@ class TensorflowParser2(Parser):
|
|||
|
||||
|
||||
def rename_SpaceToBatchND(self, source_node):
|
||||
# print(source_node.layer)
|
||||
IR_node = self._convert_identity_operation(source_node, end_idx=1, new_op = 'SpaceToBatchND')
|
||||
# assert False
|
||||
|
||||
|
||||
def rename_BatchToSpaceND(self, source_node):
|
||||
# print(source_node.layer)
|
||||
IR_node = self._convert_identity_operation(source_node, end_idx=1, new_op = 'BatchToSpaceND')
|
||||
# assert False
|
||||
|
||||
|
||||
def rename_ArgMax(self, source_node):
|
||||
# print(source_node.layer)
|
||||
IR_node = self._convert_identity_operation(source_node, end_idx=1, new_op = 'ArgMax')
|
||||
# assert False
|
||||
|
||||
|
||||
def rename_Slice(self, source_node):
|
||||
# print(source_node.layer)
|
||||
IR_node = self._convert_identity_operation(source_node, new_op = 'Slice')
|
||||
|
||||
|
||||
def rename_Split(self, source_node):
|
||||
# print(source_node.layer)
|
||||
if source_node.get_attr('num_split') == 1:
|
||||
source_node.real_name = self.get_parent(source_node.name, [1]).real_name
|
||||
|
||||
|
@ -1156,10 +1075,10 @@ class TensorflowParser2(Parser):
|
|||
}
|
||||
assign_IRnode_values(IR_node, kwargs)
|
||||
|
||||
|
||||
def rename_Tile(self, source_node):
|
||||
# print(source_node.layer)
|
||||
IR_node = self._convert_identity_operation(source_node, new_op = 'Tile')
|
||||
|
||||
|
||||
def rename_Sqrt(self, source_node):
|
||||
# print(source_node.layer)
|
||||
IR_node = self._convert_identity_operation(source_node, new_op = 'Sqrt')
|
||||
|
|
Загрузка…
Ссылка в новой задаче