Merge branch 'main' into dev/add-error-info

This commit is contained in:
jiahangxu 2023-01-06 04:25:08 -05:00
Родитель aef41ddaa3 a1682dd2d2
Коммит 8623ab7e92
3 изменённых файлов: 10 добавлений и 3 удалений

Просмотреть файл

@ -16,7 +16,7 @@ class OnnxConverter:
self.graph = inferred_model.graph self.graph = inferred_model.graph
self.tensors = {} self.tensors = {}
for tensor in chain(self.graph.input, self.graph.value_info, self.graph.output): for tensor in chain(self.graph.input, self.graph.value_info, self.graph.initializer, self.graph.output):
self.tensors[tensor.name] = { self.tensors[tensor.name] = {
"shape": get_tensor_shape(tensor), "shape": get_tensor_shape(tensor),
"inputs": [], "inputs": [],

Просмотреть файл

@ -2,8 +2,12 @@
# Licensed under the MIT license. # Licensed under the MIT license.
def get_tensor_shape(tensor): def get_tensor_shape(tensor):
shape = [] shape = []
for dim in tensor.type.tensor_type.shape.dim: try:
shape.append(dim.dim_value) for dim in tensor.type.tensor_type.shape.dim:
shape.append(dim.dim_value)
except AttributeError:
# initializer
shape += tensor.dims
if len(shape) == 4: if len(shape) == 4:
shape = [shape[0], shape[2], shape[3], shape[1]] shape = [shape[0], shape[2], shape[3], shape[1]]
return shape return shape

Просмотреть файл

@ -28,6 +28,9 @@ class ModelGraph:
self.graph[node]["outbounds"].append(name) self.graph[node]["outbounds"].append(name)
def refresh(self): def refresh(self):
if len(self.graph) <= 1:
return
last_remove_nodes_cnt = -1 last_remove_nodes_cnt = -1
while True: while True:
for name in self.graph.keys(): for name in self.graph.keys():