This commit is contained in:
jiahangxu 2023-02-06 22:15:58 -05:00
Родитель 9d80233cb3
Коммит ca4e92e086
1 изменённых файлов: 29 добавлений и 29 удалений

Просмотреть файл

@ -52,7 +52,7 @@ class ShapeInference:
"Relu6",
"Selu",
"LeakyReLU",
"Elu"
"Elu",
"Softmax",
"NoOp"
@ -68,7 +68,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
input_nodes = node["inbounds"]
@ -170,7 +170,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
return [], [node["attr"]["attr"]["tensor_shape"]]
@ -185,7 +185,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
return [], [graph[node["inbounds"][0]]["attr"]["output_shape"][0]]
@ -200,7 +200,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
in_shape = [graph[node["inbounds"][0]]["attr"]["output_shape"][0]]
@ -220,7 +220,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
return ShapeInference.Pad_get_shape(graph, node)
@ -235,7 +235,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
logging.info("Propagate through op %s.", node["attr"]["name"])
@ -251,7 +251,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
if len(node["inbounds"]) != 1:
@ -307,7 +307,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
return ShapeInference.Pool_get_shape(graph, node)
@ -321,7 +321,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
return ShapeInference.Pool_get_shape(graph, node)
@ -335,7 +335,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
return ShapeInference.Pool_get_shape(graph, node)
@ -349,7 +349,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
return ShapeInference.Pool_get_shape(graph, node)
@ -363,7 +363,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
return ShapeInference.Pool_get_shape(graph, node)
@ -378,7 +378,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
return [], [node["attr"]["attr"]["shape"]]
@ -392,7 +392,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
weight_node = ph.find_weights_root(graph, node)
@ -475,7 +475,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
weight_node = ph.find_weights_root(graph, node)
@ -559,7 +559,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
input_shape = graph[node["inbounds"][0]]["attr"]["output_shape"][0]
@ -591,7 +591,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
return ShapeInference.Reduce_get_shape(graph, node)
@ -605,7 +605,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
return ShapeInference.Reduce_get_shape(graph, node)
@ -619,7 +619,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
return ShapeInference.Reduce_get_shape(graph, node)
@ -633,7 +633,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
weight_node = ph.find_weights_root(graph, node)
@ -694,7 +694,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
if "shape" in node["attr"]["attr"].keys():
@ -751,7 +751,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
input_shape = []
@ -780,7 +780,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
return ShapeInference.Concat_get_shape(graph, node)
@ -794,7 +794,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
return ShapeInference.Concat_get_shape(graph, node)
@ -809,7 +809,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
for in_node in node["inbounds"]:
@ -839,7 +839,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
for in_node in node["inbounds"]:
@ -875,7 +875,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
seq = ph.get_graph_seq(graph, [node["attr"]["name"]])[:5]
@ -898,7 +898,7 @@ class ShapeInference:
----------
graph : dict
The Graph IR in dict format.
node : dict
node : dict
The node in Graph IR in dict format.
"""
seq = ph.get_graph_seq(graph, [node["attr"]["name"]])[:5]