remove networkx in onnx converter

This commit is contained in:
jiahangxu 2021-12-02 08:30:24 -05:00
Родитель b8c864394e
Коммит 977e2e1516
1 изменённых файлов: 31 добавлений и 50 удалений

Просмотреть файл

@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import networkx as nx
from itertools import chain
from .utils import get_tensor_shape
from .constants import SLICE_TYPE
@ -31,38 +30,6 @@ class OnnxConverter:
if output_name in self.tensors:
self.tensors[output_name]["inputs"].append(node)
self.G = self.to_networkx()
def to_networkx(self):
G = nx.DiGraph()
sliced_tensors = set()
selected_slice = set()
for node in self.graph.node:
if node.op_type == SLICE_TYPE:
tensor = node.input[0]
if tensor in sliced_tensors:
continue
else:
sliced_tensors.add(tensor)
selected_slice.add(node.name)
G.add_node(node.name, **self.fetch_attrs(node))
for node in self.graph.node:
if node.op_type == SLICE_TYPE and node.name not in selected_slice:
continue
for input_name in node.input:
if input_name in self.tensors: # remove dummy ops
G.add_edge(input_name, node.name)
for output_name in node.output:
if output_name in self.tensors:
G.add_edge(node.name, output_name)
if node.op_type == SLICE_TYPE:
for tensor_name in self._get_sibling_slice_output_tensors(node):
G.add_edge(node.name, tensor_name)
return G
def fetch_attrs(self, node):
from onnx import AttributeProto
attrs = {}
@ -106,26 +73,40 @@ class OnnxConverter:
def convert(self):
result = {}
for node in self.G.nodes:
node_attrs = self.G.nodes[node]
if node in self.tensors or not node_attrs:
continue
sliced_tensors = set()
selected_slice = set()
for node in self.graph.node:
outbounds = []
inbounds = []
for succ in self.G.successors(node):
for succ_succ in self.G.successors(succ):
outbounds.append(succ_succ)
for pred in self.G.predecessors(node):
for pred_pred in self.G.predecessors(pred):
inbounds.append(pred_pred)
if node.op_type == SLICE_TYPE:
tensor = node.input[0]
if tensor in sliced_tensors:
continue
else:
sliced_tensors.add(tensor)
selected_slice.add(node.name)
result[node] = {
"attr": node_attrs,
"outbounds": outbounds,
"inbounds": inbounds,
}
if node.op_type == SLICE_TYPE and node.name not in selected_slice:
continue
for input_name in node.input:
if input_name in self.tensors: # remove dummy ops
for pred_pred in self.tensors[input_name]['inputs']:
inbounds.append(pred_pred.name)
for output_name in node.output:
if output_name in self.tensors:
for succ_succ in self.tensors[output_name]['outputs']:
outbounds.append(succ_succ.name)
if node.op_type == SLICE_TYPE:
for tensor_name in self._get_sibling_slice_output_tensors(node):
outbounds.append(tensor_name)
result[node.name] = {
"attr": self.fetch_attrs(node),
"outbounds": outbounds,
"inbounds": inbounds,
}
return result