зеркало из https://github.com/microsoft/MMdnn.git
Родитель
9199dcbb4c
Коммит
7cd0670725
|
@ -129,11 +129,12 @@ class PytorchGraph(Graph):
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
node_id = PytorchGraph.get_node_id(node)
|
node_id = PytorchGraph.get_node_id(node)
|
||||||
node_name = self.rename_nodes(node, node_id)
|
node_name = self.rename_nodes(node, node_id)
|
||||||
output_shape_str = re.findall(r'[^()!]+', node.__str__())[1]
|
output_str = node.__str__().split('=')[0]
|
||||||
if '%' in output_shape_str:
|
output_shape_str = re.findall(r'[^()!]+', output_str)
|
||||||
out_put_shape = None
|
if len(output_shape_str) > 1:
|
||||||
|
output_shape = [int(x.replace('!', '')) for x in output_shape_str[1].split(',')]
|
||||||
else:
|
else:
|
||||||
output_shape = [int(x.replace('!', '')) for x in output_shape_str.split(',')]
|
output_shape = None
|
||||||
self.shape_dict[node_name] = output_shape
|
self.shape_dict[node_name] = output_shape
|
||||||
self.layer_map[node_name] = self.CreateGraphNode(node)
|
self.layer_map[node_name] = self.CreateGraphNode(node)
|
||||||
self.layer_name_map[node_name] = node_name
|
self.layer_name_map[node_name] = node_name
|
||||||
|
|
Загрузка…
Ссылка в новой задаче