Add output shape check
This commit is contained in:
XiaoXYe 2020-08-13 00:39:52 +08:00 коммит произвёл GitHub
Родитель 9199dcbb4c
Коммит 7cd0670725
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 5 добавлений и 4 удалений

Просмотреть файл

@ -129,11 +129,12 @@ class PytorchGraph(Graph):
for node in nodes:
node_id = PytorchGraph.get_node_id(node)
node_name = self.rename_nodes(node, node_id)
output_shape_str = re.findall(r'[^()!]+', node.__str__())[1]
if '%' in output_shape_str:
out_put_shape = None
output_str = node.__str__().split('=')[0]
output_shape_str = re.findall(r'[^()!]+', output_str)
if len(output_shape_str) > 1:
output_shape = [int(x.replace('!', '')) for x in output_shape_str[1].split(',')]
else:
output_shape = [int(x.replace('!', '')) for x in output_shape_str.split(',')]
output_shape = None
self.shape_dict[node_name] = output_shape
self.layer_map[node_name] = self.CreateGraphNode(node)
self.layer_name_map[node_name] = node_name