remove networkx in onnx converter
This commit is contained in:
Родитель
b8c864394e
Коммит
977e2e1516
|
@ -1,7 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
import logging
|
||||
import networkx as nx
|
||||
from itertools import chain
|
||||
from .utils import get_tensor_shape
|
||||
from .constants import SLICE_TYPE
|
||||
|
@ -31,38 +30,6 @@ class OnnxConverter:
|
|||
if output_name in self.tensors:
|
||||
self.tensors[output_name]["inputs"].append(node)
|
||||
|
||||
self.G = self.to_networkx()
|
||||
|
||||
def to_networkx(self):
|
||||
G = nx.DiGraph()
|
||||
|
||||
sliced_tensors = set()
|
||||
selected_slice = set()
|
||||
for node in self.graph.node:
|
||||
if node.op_type == SLICE_TYPE:
|
||||
tensor = node.input[0]
|
||||
if tensor in sliced_tensors:
|
||||
continue
|
||||
else:
|
||||
sliced_tensors.add(tensor)
|
||||
selected_slice.add(node.name)
|
||||
G.add_node(node.name, **self.fetch_attrs(node))
|
||||
|
||||
for node in self.graph.node:
|
||||
if node.op_type == SLICE_TYPE and node.name not in selected_slice:
|
||||
continue
|
||||
for input_name in node.input:
|
||||
if input_name in self.tensors: # remove dummy ops
|
||||
G.add_edge(input_name, node.name)
|
||||
for output_name in node.output:
|
||||
if output_name in self.tensors:
|
||||
G.add_edge(node.name, output_name)
|
||||
if node.op_type == SLICE_TYPE:
|
||||
for tensor_name in self._get_sibling_slice_output_tensors(node):
|
||||
G.add_edge(node.name, tensor_name)
|
||||
|
||||
return G
|
||||
|
||||
def fetch_attrs(self, node):
|
||||
from onnx import AttributeProto
|
||||
attrs = {}
|
||||
|
@ -106,26 +73,40 @@ class OnnxConverter:
|
|||
|
||||
def convert(self):
|
||||
result = {}
|
||||
|
||||
for node in self.G.nodes:
|
||||
node_attrs = self.G.nodes[node]
|
||||
if node in self.tensors or not node_attrs:
|
||||
continue
|
||||
|
||||
|
||||
sliced_tensors = set()
|
||||
selected_slice = set()
|
||||
for node in self.graph.node:
|
||||
outbounds = []
|
||||
inbounds = []
|
||||
for succ in self.G.successors(node):
|
||||
for succ_succ in self.G.successors(succ):
|
||||
outbounds.append(succ_succ)
|
||||
for pred in self.G.predecessors(node):
|
||||
for pred_pred in self.G.predecessors(pred):
|
||||
inbounds.append(pred_pred)
|
||||
|
||||
if node.op_type == SLICE_TYPE:
|
||||
tensor = node.input[0]
|
||||
if tensor in sliced_tensors:
|
||||
continue
|
||||
else:
|
||||
sliced_tensors.add(tensor)
|
||||
selected_slice.add(node.name)
|
||||
|
||||
result[node] = {
|
||||
"attr": node_attrs,
|
||||
"outbounds": outbounds,
|
||||
"inbounds": inbounds,
|
||||
}
|
||||
if node.op_type == SLICE_TYPE and node.name not in selected_slice:
|
||||
continue
|
||||
|
||||
for input_name in node.input:
|
||||
if input_name in self.tensors: # remove dummy ops
|
||||
for pred_pred in self.tensors[input_name]['inputs']:
|
||||
inbounds.append(pred_pred.name)
|
||||
for output_name in node.output:
|
||||
if output_name in self.tensors:
|
||||
for succ_succ in self.tensors[output_name]['outputs']:
|
||||
outbounds.append(succ_succ.name)
|
||||
if node.op_type == SLICE_TYPE:
|
||||
for tensor_name in self._get_sibling_slice_output_tensors(node):
|
||||
outbounds.append(tensor_name)
|
||||
result[node.name] = {
|
||||
"attr": self.fetch_attrs(node),
|
||||
"outbounds": outbounds,
|
||||
"inbounds": inbounds,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче