Merge branch 'main' into dev/add-error-info
This commit is contained in:
Коммит
8623ab7e92
|
@ -16,7 +16,7 @@ class OnnxConverter:
|
||||||
self.graph = inferred_model.graph
|
self.graph = inferred_model.graph
|
||||||
|
|
||||||
self.tensors = {}
|
self.tensors = {}
|
||||||
for tensor in chain(self.graph.input, self.graph.value_info, self.graph.output):
|
for tensor in chain(self.graph.input, self.graph.value_info, self.graph.initializer, self.graph.output):
|
||||||
self.tensors[tensor.name] = {
|
self.tensors[tensor.name] = {
|
||||||
"shape": get_tensor_shape(tensor),
|
"shape": get_tensor_shape(tensor),
|
||||||
"inputs": [],
|
"inputs": [],
|
||||||
|
|
|
@ -2,8 +2,12 @@
|
||||||
# Licensed under the MIT license.
|
# Licensed under the MIT license.
|
||||||
def get_tensor_shape(tensor):
|
def get_tensor_shape(tensor):
|
||||||
shape = []
|
shape = []
|
||||||
for dim in tensor.type.tensor_type.shape.dim:
|
try:
|
||||||
shape.append(dim.dim_value)
|
for dim in tensor.type.tensor_type.shape.dim:
|
||||||
|
shape.append(dim.dim_value)
|
||||||
|
except AttributeError:
|
||||||
|
# initializer
|
||||||
|
shape += tensor.dims
|
||||||
if len(shape) == 4:
|
if len(shape) == 4:
|
||||||
shape = [shape[0], shape[2], shape[3], shape[1]]
|
shape = [shape[0], shape[2], shape[3], shape[1]]
|
||||||
return shape
|
return shape
|
||||||
|
|
|
@ -28,6 +28,9 @@ class ModelGraph:
|
||||||
self.graph[node]["outbounds"].append(name)
|
self.graph[node]["outbounds"].append(name)
|
||||||
|
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
|
if len(self.graph) <= 1:
|
||||||
|
return
|
||||||
|
|
||||||
last_remove_nodes_cnt = -1
|
last_remove_nodes_cnt = -1
|
||||||
while True:
|
while True:
|
||||||
for name in self.graph.keys():
|
for name in self.graph.keys():
|
||||||
|
|
Загрузка…
Ссылка в новой задаче