This commit is contained in:
namizzz 2018-08-24 19:18:51 +08:00
Родитель ccd2843678
Коммит 9661b21bcf
5 изменённых файлов: 10 добавлений и 15 удалений

Просмотреть файл

@ -471,14 +471,12 @@ def KitModel(weight_file = None):
def emit_upsample(self, IR_node): def emit_UpSampling2D(self, IR_node):
# print(IR_node.layer)
# assert False
self.used_layers.add(IR_node.type) self.used_layers.add(IR_node.type)
self.add_body(1, "{:<15} = Upsampling2D({}, stride = {}, name = '{}')".format( self.add_body(1, "{:<15} = Upsampling2D({}, stride = {}, name = '{}')".format(
IR_node.variable_name, IR_node.variable_name,
self.parent_variable_name(IR_node), self.parent_variable_name(IR_node),
IR_node.get_attr('strides'), IR_node.get_attr('scales')[0],
IR_node.name)) IR_node.name))

Просмотреть файл

@ -408,7 +408,7 @@ class DarknetGraph(Graph):
upsample_layer['type'] = 'upsample' upsample_layer['type'] = 'upsample'
upsample_param = OrderedDict() upsample_param = OrderedDict()
stride = block['stride'] stride = block['stride']
upsample_param['strides'] = int(stride) upsample_param['scales'] = [int(stride), int(stride)]
upsample_param['_output_shape'] = [input_shape[0]] + [q*int(stride) for q in input_shape[1:3]] + [input_shape[-1]] upsample_param['_output_shape'] = [input_shape[0]] + [q*int(stride) for q in input_shape[1:3]] + [input_shape[-1]]
upsample_layer['attr'] = upsample_param upsample_layer['attr'] = upsample_param
self.layer_map[upsample_layer['name']] = DarknetGraphNode(upsample_layer) self.layer_map[upsample_layer['name']] = DarknetGraphNode(upsample_layer)

Просмотреть файл

@ -294,10 +294,10 @@ class DarknetParser(Parser):
def rename_upsample(self, source_node): def rename_upsample(self, source_node):
IR_node = self._convert_identity_operation(source_node, new_op='upsample') IR_node = self._convert_identity_operation(source_node, new_op='UpSampling2D')
stride = source_node.get_attr('strides') scales = source_node.get_attr('scales')
kwargs = {} kwargs = {}
kwargs['strides'] = stride kwargs['scales'] = scales
assign_IRnode_values(IR_node, kwargs) assign_IRnode_values(IR_node, kwargs)

Просмотреть файл

@ -552,10 +552,7 @@ class Keras2Parser(Parser):
self.convert_inedge(source_node, IR_node) self.convert_inedge(source_node, IR_node)
# size # size
IR_node.attr["size"].list.i.extend(source_node.keras_layer.size) IR_node.attr["scales"].list.i.extend(source_node.keras_layer.size)
def rename_Embedding(self, source_node): def rename_Embedding(self, source_node):

Просмотреть файл

@ -301,11 +301,11 @@ def KitModel(weight_file = None):
def emit_UpSampling2D(self, IR_node): def emit_UpSampling2D(self, IR_node):
size = IR_node.get_attr('size') scales = IR_node.get_attr('scales')
size = tuple(size) scales = tuple(scales)
self.add_body(1, "{:<15} = tf.keras.layers.UpSampling2D(size={})({})".format( self.add_body(1, "{:<15} = tf.keras.layers.UpSampling2D(size={})({})".format(
IR_node.variable_name, IR_node.variable_name,
size, scales,
self.parent_variable_name(IR_node))) self.parent_variable_name(IR_node)))