diff --git a/nn_meter/ir_converter/frozenpb_converter/shape_inference.py b/nn_meter/ir_converter/frozenpb_converter/shape_inference.py index f2e69b6..921c287 100644 --- a/nn_meter/ir_converter/frozenpb_converter/shape_inference.py +++ b/nn_meter/ir_converter/frozenpb_converter/shape_inference.py @@ -52,7 +52,7 @@ class ShapeInference: "Relu6", "Selu", "LeakyReLU", - "Elu" + "Elu", "Softmax", "NoOp" @@ -68,7 +68,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ input_nodes = node["inbounds"] @@ -170,7 +170,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ return [], [node["attr"]["attr"]["tensor_shape"]] @@ -185,7 +185,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ return [], [graph[node["inbounds"][0]]["attr"]["output_shape"][0]] @@ -200,7 +200,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ in_shape = [graph[node["inbounds"][0]]["attr"]["output_shape"][0]] @@ -220,7 +220,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ return ShapeInference.Pad_get_shape(graph, node) @@ -235,7 +235,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ logging.info("Propagate through op %s.", node["attr"]["name"]) @@ -251,7 +251,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ if len(node["inbounds"]) != 1: @@ -307,7 +307,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ return ShapeInference.Pool_get_shape(graph, node) @@ -321,7 +321,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ return ShapeInference.Pool_get_shape(graph, node) @@ -335,7 +335,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ return ShapeInference.Pool_get_shape(graph, node) @@ -349,7 +349,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ return ShapeInference.Pool_get_shape(graph, node) @@ -363,7 +363,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ return ShapeInference.Pool_get_shape(graph, node) @@ -378,7 +378,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ return [], [node["attr"]["attr"]["shape"]] @@ -392,7 +392,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ weight_node = ph.find_weights_root(graph, node) @@ -475,7 +475,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ weight_node = ph.find_weights_root(graph, node) @@ -559,7 +559,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ input_shape = graph[node["inbounds"][0]]["attr"]["output_shape"][0] @@ -591,7 +591,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ return ShapeInference.Reduce_get_shape(graph, node) @@ -605,7 +605,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ return ShapeInference.Reduce_get_shape(graph, node) @@ -619,7 +619,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ return ShapeInference.Reduce_get_shape(graph, node) @@ -633,7 +633,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ weight_node = ph.find_weights_root(graph, node) @@ -694,7 +694,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ if "shape" in node["attr"]["attr"].keys(): @@ -751,7 +751,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ input_shape = [] @@ -780,7 +780,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ return ShapeInference.Concat_get_shape(graph, node) @@ -794,7 +794,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ return ShapeInference.Concat_get_shape(graph, node) @@ -809,7 +809,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ for in_node in node["inbounds"]: @@ -839,7 +839,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ for in_node in node["inbounds"]: @@ -875,7 +875,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ seq = ph.get_graph_seq(graph, [node["attr"]["name"]])[:5] @@ -898,7 +898,7 @@ class ShapeInference: ---------- graph : dict The Graph IR in dict format. - node : dict + node : dict The node in Graph IR in dict format. """ seq = ph.get_graph_seq(graph, [node["attr"]["name"]])[:5]