This commit is contained in:
namizzz 2018-08-24 19:18:51 +08:00
Родитель ccd2843678
Коммит 9661b21bcf
5 изменённых файлов: 10 добавлений и 15 удалений

Просмотреть файл

@ -471,14 +471,12 @@ def KitModel(weight_file = None):
def emit_upsample(self, IR_node):
# print(IR_node.layer)
# assert False
def emit_UpSampling2D(self, IR_node):
self.used_layers.add(IR_node.type)
self.add_body(1, "{:<15} = Upsampling2D({}, stride = {}, name = '{}')".format(
IR_node.variable_name,
self.parent_variable_name(IR_node),
IR_node.get_attr('strides'),
IR_node.get_attr('scales')[0],
IR_node.name))

Просмотреть файл

@ -408,7 +408,7 @@ class DarknetGraph(Graph):
upsample_layer['type'] = 'upsample'
upsample_param = OrderedDict()
stride = block['stride']
upsample_param['strides'] = int(stride)
upsample_param['scales'] = [int(stride), int(stride)]
upsample_param['_output_shape'] = [input_shape[0]] + [q*int(stride) for q in input_shape[1:3]] + [input_shape[-1]]
upsample_layer['attr'] = upsample_param
self.layer_map[upsample_layer['name']] = DarknetGraphNode(upsample_layer)

Просмотреть файл

@ -294,10 +294,10 @@ class DarknetParser(Parser):
def rename_upsample(self, source_node):
IR_node = self._convert_identity_operation(source_node, new_op='upsample')
stride = source_node.get_attr('strides')
IR_node = self._convert_identity_operation(source_node, new_op='UpSampling2D')
scales = source_node.get_attr('scales')
kwargs = {}
kwargs['strides'] = stride
kwargs['scales'] = scales
assign_IRnode_values(IR_node, kwargs)

Просмотреть файл

@ -552,10 +552,7 @@ class Keras2Parser(Parser):
self.convert_inedge(source_node, IR_node)
# size
IR_node.attr["size"].list.i.extend(source_node.keras_layer.size)
IR_node.attr["scales"].list.i.extend(source_node.keras_layer.size)
def rename_Embedding(self, source_node):

Просмотреть файл

@ -301,11 +301,11 @@ def KitModel(weight_file = None):
def emit_UpSampling2D(self, IR_node):
size = IR_node.get_attr('size')
size = tuple(size)
scales = IR_node.get_attr('scales')
scales = tuple(scales)
self.add_body(1, "{:<15} = tf.keras.layers.UpSampling2D(size={})({})".format(
IR_node.variable_name,
size,
scales,
self.parent_variable_name(IR_node)))