Merge branch 'main' into dev/add-error-info
This commit is contained in:
Коммит
8623ab7e92
|
@ -16,7 +16,7 @@ class OnnxConverter:
|
|||
self.graph = inferred_model.graph
|
||||
|
||||
self.tensors = {}
|
||||
for tensor in chain(self.graph.input, self.graph.value_info, self.graph.output):
|
||||
for tensor in chain(self.graph.input, self.graph.value_info, self.graph.initializer, self.graph.output):
|
||||
self.tensors[tensor.name] = {
|
||||
"shape": get_tensor_shape(tensor),
|
||||
"inputs": [],
|
||||
|
|
|
@ -2,8 +2,12 @@
|
|||
# Licensed under the MIT license.
|
||||
def get_tensor_shape(tensor):
|
||||
shape = []
|
||||
for dim in tensor.type.tensor_type.shape.dim:
|
||||
shape.append(dim.dim_value)
|
||||
try:
|
||||
for dim in tensor.type.tensor_type.shape.dim:
|
||||
shape.append(dim.dim_value)
|
||||
except AttributeError:
|
||||
# initializer
|
||||
shape += tensor.dims
|
||||
if len(shape) == 4:
|
||||
shape = [shape[0], shape[2], shape[3], shape[1]]
|
||||
return shape
|
||||
|
|
|
@ -28,6 +28,9 @@ class ModelGraph:
|
|||
self.graph[node]["outbounds"].append(name)
|
||||
|
||||
def refresh(self):
|
||||
if len(self.graph) <= 1:
|
||||
return
|
||||
|
||||
last_remove_nodes_cnt = -1
|
||||
while True:
|
||||
for name in self.graph.keys():
|
||||
|
|
Загрузка…
Ссылка в новой задаче