applying --verbose to hide dubug info
This commit is contained in:
Родитель
3d9a505d55
Коммит
448475797c
67
demo.py
67
demo.py
|
@ -2,8 +2,9 @@
|
|||
# Licensed under the MIT license.
|
||||
from nn_meter.utils.utils import try_import_torchvision_models
|
||||
from nn_meter import load_predictor_config, load_latency_predictors
|
||||
import yaml
|
||||
import argparse
|
||||
import os
|
||||
import logging
|
||||
|
||||
|
||||
def test_ir_graphs(args, predictor):
|
||||
|
@ -16,7 +17,7 @@ def test_ir_graphs(args, predictor):
|
|||
models = glob("data/testmodels/**.json")
|
||||
for model in models:
|
||||
latency = predictor.predict(model)
|
||||
print(model.split("/")[-1], latency)
|
||||
logging.debug(os.path.basename(model), latency)
|
||||
|
||||
|
||||
def test_pb_models(args, predictor):
|
||||
|
@ -29,7 +30,7 @@ def test_pb_models(args, predictor):
|
|||
models = glob("data/testmodels/**.pb")
|
||||
for model in models:
|
||||
latency = predictor.predict(model)
|
||||
print(model.split("/")[-1], latency)
|
||||
logging.debug(os.path.basename(model), latency)
|
||||
|
||||
|
||||
def test_onnx_models(args, predictor):
|
||||
|
@ -42,7 +43,7 @@ def test_onnx_models(args, predictor):
|
|||
models = glob("data/testmodels/**.onnx")
|
||||
for model in models:
|
||||
latency = predictor.predict(model)
|
||||
print(model.split("/")[-1], latency)
|
||||
logging.debug(os.path.basename(model), latency)
|
||||
|
||||
|
||||
def test_pytorch_models(args, predictor):
|
||||
|
@ -73,27 +74,21 @@ def test_pytorch_models(args, predictor):
|
|||
models.append(resnext50_32x4d)
|
||||
models.append(wide_resnet50_2)
|
||||
models.append(mnasnet)
|
||||
print("start to test")
|
||||
logging.debug("start to test")
|
||||
for model in models:
|
||||
latency = predictor.predict(
|
||||
model, model_type="torch", input_shape=(1, 3, 224, 224)
|
||||
)
|
||||
print(model.__class__.__name__, latency)
|
||||
logging.debug(model.__class__.__name__, latency)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("predict model latency on device")
|
||||
parser.add_argument(
|
||||
"--input_model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to input model. ONNX, FrozenPB or JSON",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--predictor",
|
||||
type=str,
|
||||
required=True,
|
||||
help="name of target predictor (hardware)"
|
||||
help="name of target predictor (hardware)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--predictor-version",
|
||||
|
@ -107,9 +102,51 @@ if __name__ == "__main__":
|
|||
default="nn_meter/configs/predictors.yaml",
|
||||
help="config file to store current supported edge platform",
|
||||
)
|
||||
group = parser.add_mutually_exclusive_group() # Jiahang: can't handle model_type == "torch" now.
|
||||
group.add_argument(
|
||||
"--tensorflow",
|
||||
type=str,
|
||||
# required=True,
|
||||
help="Path to input Tensorflow model (*.pb)"
|
||||
)
|
||||
group.add_argument(
|
||||
"--onnx",
|
||||
type=str,
|
||||
# required=True,
|
||||
help="Path to input ONNX model (*.onnx)"
|
||||
)
|
||||
group.add_argument(
|
||||
"--nn-meter-ir",
|
||||
type=str,
|
||||
# required=True,
|
||||
help="Path to input nn-Meter IR model (*.json)"
|
||||
)
|
||||
group.add_argument(
|
||||
"--nni-ir",
|
||||
type=str,
|
||||
# required=True,
|
||||
help="Path to input NNI IR model (*.json)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--verbose",
|
||||
help="increase output verbosity",
|
||||
action="store_true"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
pred_info = load_predictor_config(args.config, args.predictor, args.predictor_version)
|
||||
predictor = load_latency_predictors(pred_info)
|
||||
latency = predictor.predict(args.input_model)
|
||||
print('predict latency', latency)
|
||||
if args.tensorflow:
|
||||
input_model, model_type = args.tensorflow, "pb"
|
||||
elif args.onnx:
|
||||
input_model, model_type = args.onnx, "onnx"
|
||||
elif args.nn_meter_ir:
|
||||
input_model, model_type = args.nn_meter_ir, "json"
|
||||
elif args.nni_ir:
|
||||
input_model, model_type = args.nni_ir, "json"
|
||||
|
||||
latency = predictor.predict(input_model, model_type)
|
||||
logging.info('predict latency: %f' % latency)
|
||||
|
|
|
@ -89,13 +89,13 @@ class FrozenPbParser:
|
|||
graph[graph_node]["attr"]["type"] == "Split"
|
||||
and ":" not in graph_node
|
||||
):
|
||||
logging.info("Find split main node %s." % graph_node)
|
||||
logging.debug("Find split main node %s." % graph_node)
|
||||
split_node_name = graph_node
|
||||
for node_name in graph.keys():
|
||||
idx = re.findall(r"%s:(\d+)" % split_node_name, node_name)
|
||||
if len(idx) > 0:
|
||||
idx = int(idx[0])
|
||||
logging.info("Find split child node %s." % node_name)
|
||||
logging.debug("Find split child node %s." % node_name)
|
||||
graph[graph_node]["outbounds"] += graph[node_name][
|
||||
"outbounds"
|
||||
]
|
||||
|
@ -194,7 +194,7 @@ class FrozenPbParser:
|
|||
attr_as_node[node.op]["node_name"](node.name), target_node.name
|
||||
)
|
||||
if len(node_attr) > 0:
|
||||
logging.info("Find regex matching node %s" % node.name)
|
||||
logging.debug("Find regex matching node %s" % node.name)
|
||||
for attr_name in target_node.attr.keys():
|
||||
if (
|
||||
attr_name == "value"
|
||||
|
|
|
@ -82,7 +82,7 @@ class ProtobufHelper:
|
|||
weight_op in graph.keys()
|
||||
and graph[weight_op]["attr"]["type"] != "Identity"
|
||||
):
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Find node %s with its weight op %s."
|
||||
% (node["attr"]["name"], weight_op)
|
||||
)
|
||||
|
|
|
@ -69,7 +69,7 @@ class ShapeInference:
|
|||
Padding type, now support SAME and VALID.
|
||||
"""
|
||||
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Calculating padding shape, input shape: %s, kernel size: %s, strides: %s, padding: %s."
|
||||
% (str(input_shape), str(k_size), str(strides), str(padding))
|
||||
)
|
||||
|
@ -156,7 +156,7 @@ class ShapeInference:
|
|||
node : dict
|
||||
The node in Graph IR in dict format.
|
||||
"""
|
||||
logging.info("Propogate through op %s.", node["attr"]["name"])
|
||||
logging.debug("Propogate through op %s.", node["attr"]["name"])
|
||||
in_shape = [graphe[node["inbounds"][0]]["attr"]["output_shape"][0]]
|
||||
return in_shape, in_shape
|
||||
|
||||
|
@ -279,13 +279,13 @@ class ShapeInference:
|
|||
"""
|
||||
if len(node["inbounds"]) != 1:
|
||||
logging.warning("Failed to get input node of %s." % (node["attr"]["name"]))
|
||||
logging.info(node)
|
||||
logging.debug(node)
|
||||
return
|
||||
|
||||
input_shape = copy.deepcopy(
|
||||
graphe[node["inbounds"][0]]["attr"]["output_shape"][0]
|
||||
)
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Get input shape of %s from %s, input shape:%s."
|
||||
% (node["attr"]["name"], node["inbounds"][0], input_shape)
|
||||
)
|
||||
|
@ -297,12 +297,12 @@ class ShapeInference:
|
|||
"Invalid strides %s of node %s."
|
||||
% (str(node["attr"]["attr"]["strides"]), node["attr"]["name"])
|
||||
)
|
||||
logging.info(node)
|
||||
logging.debug(node)
|
||||
return
|
||||
|
||||
strides = node["attr"]["attr"]["strides"]
|
||||
padding = node["attr"]["attr"]["padding"].decode("utf-8")
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Op:%s, stride:%s, padding:%s."
|
||||
% (node["attr"]["name"], str(strides), str(padding))
|
||||
)
|
||||
|
@ -407,18 +407,18 @@ class ShapeInference:
|
|||
weight_node = ph.find_weights_root(graphe, node)
|
||||
if len(weight_node) != 1:
|
||||
logging.warning("Failed to get shape of node %s." % (node["attr"]["name"]))
|
||||
logging.info(node)
|
||||
logging.debug(node)
|
||||
return
|
||||
|
||||
input_node = [x for x in node["inbounds"] if x != weight_node]
|
||||
input_node = [x for x in input_node if graphe[x]["attr"]["type"] != "Identity"]
|
||||
if len(input_node) != 1:
|
||||
logging.warning("Failed to get input node of %s." % (node["attr"]["name"]))
|
||||
logging.info(node)
|
||||
logging.debug(node)
|
||||
return
|
||||
|
||||
input_shape = copy.deepcopy(graphe[input_node[0]]["attr"]["output_shape"][0])
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Get input shape of %s from %s, input shape:%s."
|
||||
% (node["attr"]["name"], input_node[0], input_shape)
|
||||
)
|
||||
|
@ -429,10 +429,10 @@ class ShapeInference:
|
|||
"Failed to parse weight shape %s of node %s."
|
||||
% (str(weight_shape), node["attr"]["name"])
|
||||
)
|
||||
logging.info(node)
|
||||
logging.debug(node)
|
||||
return
|
||||
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Get weight shape of %s from %s, input shape:%s."
|
||||
% (node["attr"]["name"], weight_node, weight_shape)
|
||||
)
|
||||
|
@ -445,13 +445,13 @@ class ShapeInference:
|
|||
"Invalid strides %s of node %s."
|
||||
% (str(node["attr"]["attr"]["strides"]), node["attr"]["name"])
|
||||
)
|
||||
logging.info(node)
|
||||
logging.debug(node)
|
||||
return
|
||||
|
||||
strides = node["attr"]["attr"]["strides"]
|
||||
dilation = node["attr"]["attr"]["dilations"]
|
||||
padding = node["attr"]["attr"]["padding"].decode("utf-8")
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Op:%s, stride:%s, dilation:%s, padding:%s."
|
||||
% (node["attr"]["name"], str(strides), str(dilation), str(padding))
|
||||
)
|
||||
|
@ -490,18 +490,18 @@ class ShapeInference:
|
|||
weight_node = ph.find_weights_root(graphe, node)
|
||||
if len(weight_node) != 1:
|
||||
logging.warning("Failed to get shape of node %s." % (node["attr"]["name"]))
|
||||
logging.info(node)
|
||||
logging.debug(node)
|
||||
return
|
||||
|
||||
input_node = [x for x in node["inbounds"] if x != weight_node]
|
||||
input_node = [x for x in input_node if graphe[x]["attr"]["type"] != "Identity"]
|
||||
if len(input_node) != 1:
|
||||
logging.warning("Failed to get input node of %s." % (node["attr"]["name"]))
|
||||
logging.info(node)
|
||||
logging.debug(node)
|
||||
return
|
||||
|
||||
input_shape = copy.deepcopy(graphe[input_node[0]]["attr"]["output_shape"][0])
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Get input shape of %s from %s, input shape:%s."
|
||||
% (node["attr"]["name"], input_node[0], input_shape)
|
||||
)
|
||||
|
@ -512,10 +512,10 @@ class ShapeInference:
|
|||
"Failed to parse weight shape %s of node %s."
|
||||
% (str(weight_shape), node["attr"]["name"])
|
||||
)
|
||||
logging.info(node)
|
||||
logging.debug(node)
|
||||
return
|
||||
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Get weight shape of %s from %s, input shape:%s."
|
||||
% (node["attr"]["name"], weight_node, weight_shape)
|
||||
)
|
||||
|
@ -528,14 +528,14 @@ class ShapeInference:
|
|||
"Invalid strides %s of node %s."
|
||||
% (str(node["attr"]["attr"]["strides"]), node["attr"]["name"])
|
||||
)
|
||||
logging.info(node)
|
||||
logging.debug(node)
|
||||
return
|
||||
|
||||
strides = node["attr"]["attr"]["strides"]
|
||||
dilation = node["attr"]["attr"]["dilations"]
|
||||
padding = node["attr"]["attr"]["padding"].decode("utf-8")
|
||||
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Op:%s, stride:%s, dilation:%s, padding:%s."
|
||||
% (node["attr"]["name"], str(strides), str(dilation), str(padding))
|
||||
)
|
||||
|
@ -573,7 +573,7 @@ class ShapeInference:
|
|||
"""
|
||||
input_shape = graphe[node["inbounds"][0]]["attr"]["output_shape"][0]
|
||||
output_shape = input_shape
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Get input shape of %s from %s, input shape:%s."
|
||||
% (node["attr"]["name"], node["inbounds"][0], output_shape)
|
||||
)
|
||||
|
@ -582,7 +582,7 @@ class ShapeInference:
|
|||
output_shape[2] = 0
|
||||
|
||||
reduction_indices = node["attr"]["attr"]["reduction_indices"]
|
||||
logging.info("Get Reduction Indices %s.", str(reduction_indices))
|
||||
logging.debug("Get Reduction Indices %s.", str(reduction_indices))
|
||||
|
||||
reduction_cnt = 0
|
||||
for reduction in sorted(reduction_indices):
|
||||
|
@ -648,7 +648,7 @@ class ShapeInference:
|
|||
weight_node = ph.find_weights_root(graphe, node)
|
||||
if len(weight_node) != 1:
|
||||
logging.warning("Failed to get shape of node %s." % (node["attr"]["name"]))
|
||||
logging.info(node)
|
||||
logging.debug(node)
|
||||
return
|
||||
|
||||
weight_shape = graphe[weight_node[0]]["attr"]["attr"]["tensor_shape"]
|
||||
|
@ -657,10 +657,10 @@ class ShapeInference:
|
|||
"Failed to parse weight shape %s of node %s."
|
||||
% (str(weight_shape), node["attr"]["name"])
|
||||
)
|
||||
logging.info(node)
|
||||
logging.debug(node)
|
||||
return
|
||||
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Get weight shape of %s from %s, input shape:%s."
|
||||
% (node["attr"]["name"], weight_node, weight_shape)
|
||||
)
|
||||
|
@ -669,11 +669,11 @@ class ShapeInference:
|
|||
input_node = [x for x in input_node if graphe[x]["attr"]["type"] != "Identity"]
|
||||
if len(input_node) != 1:
|
||||
logging.warning("Failed to get input node of %s." % (node["attr"]["name"]))
|
||||
logging.info(node)
|
||||
logging.debug(node)
|
||||
return
|
||||
|
||||
input_shape = copy.deepcopy(graphe[input_node[0]]["attr"]["output_shape"][0])
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Get input shape of %s from %s, input shape:%s."
|
||||
% (node["attr"]["name"], input_node[0], input_shape)
|
||||
)
|
||||
|
@ -683,7 +683,7 @@ class ShapeInference:
|
|||
"Weight shape and input shape not matched for %s."
|
||||
% (node["attr"]["name"])
|
||||
)
|
||||
logging.info(node)
|
||||
logging.debug(node)
|
||||
return
|
||||
|
||||
output_shape = copy.deepcopy(input_shape)
|
||||
|
@ -707,7 +707,7 @@ class ShapeInference:
|
|||
The node in Graph IR in dict format.
|
||||
"""
|
||||
if "shape" in node["attr"]["attr"].keys():
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Shape attr find in %s op, propogate with normal.", node["attr"]["name"]
|
||||
)
|
||||
input_shape = copy.deepcopy(
|
||||
|
@ -715,7 +715,7 @@ class ShapeInference:
|
|||
)
|
||||
exp_output_shape = copy.deepcopy(node["attr"]["attr"]["shape"])
|
||||
else:
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Shape attr not find in %s op, try finding the shape node.",
|
||||
node["attr"]["name"],
|
||||
)
|
||||
|
@ -730,7 +730,7 @@ class ShapeInference:
|
|||
for sl in graphe[in_node]["attr"]["attr"]["constant"]
|
||||
for it in sl
|
||||
]
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Fetched expected output shape from Pack op %s"
|
||||
% str(exp_output_shape)
|
||||
)
|
||||
|
@ -768,7 +768,7 @@ class ShapeInference:
|
|||
in_shape = graphe[in_node]["attr"]["output_shape"][0]
|
||||
if in_shape != []:
|
||||
input_shape.append(in_shape)
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Get input shape of %s from %s, input shape:%s."
|
||||
% (node["attr"]["name"], in_node, input_shape[-1])
|
||||
)
|
||||
|
@ -830,7 +830,7 @@ class ShapeInference:
|
|||
input_shape = copy.deepcopy(graphe[in_node]["attr"]["output_shape"][0])
|
||||
|
||||
split_dim = node["attr"]["attr"]["split_dim"][0]
|
||||
logging.info("Fetched Split dim for %s is %s.", node["attr"]["name"], split_dim)
|
||||
logging.debug("Fetched Split dim for %s is %s.", node["attr"]["name"], split_dim)
|
||||
output_node_cnt = len(node["outbounds"])
|
||||
|
||||
output_shape = copy.deepcopy(input_shape)
|
||||
|
@ -854,17 +854,17 @@ class ShapeInference:
|
|||
for in_node in node["inbounds"]:
|
||||
if graphe[in_node]["attr"]["type"] == "Const":
|
||||
perm = copy.deepcopy(graphe[in_node]["attr"]["attr"]["constant"])
|
||||
logging.info("Fetched perm sequence from Const op %s" % str(perm))
|
||||
logging.debug("Fetched perm sequence from Const op %s" % str(perm))
|
||||
elif graphe[in_node]["attr"]["type"] == "Pack":
|
||||
perm = [1] + [
|
||||
it
|
||||
for sl in graphe[in_node]["attr"]["attr"]["constant"]
|
||||
for it in sl
|
||||
]
|
||||
logging.info("Fetched perm sequence from Pack op %s" % str(perm))
|
||||
logging.debug("Fetched perm sequence from Pack op %s" % str(perm))
|
||||
else:
|
||||
input_shape = copy.deepcopy(graphe[in_node]["attr"]["output_shape"][0])
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Fetched input shape from %s, %s" % (in_node, str(input_shape))
|
||||
)
|
||||
|
||||
|
@ -946,17 +946,17 @@ class ShapeInference:
|
|||
)
|
||||
if input_shape is not None:
|
||||
graph[node_name]["attr"]["input_shape"] = copy.deepcopy(input_shape)
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Input shape of %s op is %s." % (node_name, str(input_shape))
|
||||
)
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Output shape of %s op is %s." % (node_name, str(output_shape))
|
||||
)
|
||||
else:
|
||||
logging.error("%s not support yet." % graphe.get_node_type(node_name))
|
||||
logging.info("------ node content --------")
|
||||
logging.info(graph[node_name])
|
||||
logging.info("----------------------------")
|
||||
logging.debug("------ node content --------")
|
||||
logging.debug(graph[node_name])
|
||||
logging.debug("----------------------------")
|
||||
|
||||
# Pass #2
|
||||
# This is a patching for back-end, since backend extract shapes from
|
||||
|
@ -973,11 +973,11 @@ class ShapeInference:
|
|||
)
|
||||
if input_shape is not None:
|
||||
graph[node_name]["attr"]["input_shape"] = copy.deepcopy(input_shape)
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Second Pass: Input shape of %s op is %s."
|
||||
% (node_name, str(input_shape))
|
||||
)
|
||||
logging.info(
|
||||
logging.debug(
|
||||
"Second Pass: Output shape of %s op is %s."
|
||||
% (node_name, str(output_shape))
|
||||
)
|
||||
|
|
|
@ -5,6 +5,7 @@ from nn_meter.kerneldetection.rulelib.rule_splitter import RuleSplitter
|
|||
from nn_meter.utils.graphe_tool import Graphe
|
||||
from nn_meter.kerneldetection.utils.constants import DUMMY_TYPES
|
||||
from nn_meter.kerneldetection.utils.ir_tools import convert_nodes
|
||||
# import logging
|
||||
|
||||
|
||||
class KernelDetector:
|
||||
|
@ -33,7 +34,7 @@ class KernelDetector:
|
|||
|
||||
def _bb_to_kernel(self, bb):
|
||||
types = [self.graph.get_node_type(node) for node in bb]
|
||||
# print(types)
|
||||
# logging.debug(types)
|
||||
types = [t for t in types if t and t not in DUMMY_TYPES]
|
||||
|
||||
if types:
|
||||
|
|
|
@ -11,6 +11,7 @@ import argparse
|
|||
import pkg_resources
|
||||
from shutil import copyfile
|
||||
from packaging import version
|
||||
import logging
|
||||
|
||||
|
||||
__user_config_folder__ = os.path.expanduser('~/.nn_meter/config')
|
||||
|
@ -40,7 +41,7 @@ def list_latency_predictors():
|
|||
with open(fn_pred) as fp:
|
||||
return yaml.load(fp, yaml.FullLoader)
|
||||
except FileNotFoundError:
|
||||
print(f"config file {fn_pred} not found, created")
|
||||
logging.debug(f"config file {fn_pred} not found, created")
|
||||
create_user_configs()
|
||||
return list_latency_predictors()
|
||||
|
||||
|
@ -60,7 +61,7 @@ def load_predictor_config(config, predictor, predictor_version):
|
|||
if version.parse(preds_info[i]['version']) > latest_version:
|
||||
latest_version = version.parse(preds_info[i]['version'])
|
||||
latest_version_idx = i
|
||||
print(f'WARNING: There are multiple version for {predictor}, use the latest one ({str(latest_version)})')
|
||||
logging.warning(f'There are multiple version for {predictor}, use the latest one ({str(latest_version)})')
|
||||
return preds_info[latest_version_idx]
|
||||
else:
|
||||
raise NotImplementedError('No predictor that meet the required version, please try again.')
|
||||
|
@ -84,7 +85,7 @@ class nnMeter:
|
|||
graph = model_file_to_graph(model, model_type)
|
||||
else:
|
||||
graph = model_to_graph(model, model_type, input_shape=input_shape)
|
||||
# print(graph)
|
||||
# logging.debug(graph)
|
||||
self.kd.load_graph(graph)
|
||||
|
||||
py = nn_predict(self.kernel_predictors, self.kd.kernels)
|
||||
|
@ -120,12 +121,12 @@ def nn_meter_cli():
|
|||
|
||||
if args.list_predictors:
|
||||
preds = list_latency_predictors()
|
||||
print("Supported latency predictors:")
|
||||
logging.info("Supported latency predictors:")
|
||||
for p in preds:
|
||||
print(f"{p['name']}: version={p['version']}")
|
||||
logging.info(f"{p['name']}: version={p['version']}")
|
||||
return
|
||||
|
||||
pred_info = load_predictor_config(args.config, args.predictor, args.predictor_version)
|
||||
predictor = load_latency_predictors(pred_info)
|
||||
latency = predictor.predict(args.input_model)
|
||||
print('predict latency', latency)
|
||||
logging.info('predict latency', latency)
|
||||
|
|
|
@ -4,6 +4,7 @@ from glob import glob
|
|||
from zipfile import ZipFile
|
||||
from tqdm import tqdm
|
||||
import requests
|
||||
import logging
|
||||
|
||||
|
||||
def loading_to_local(pred_info, dir="data/predictorzoo"):
|
||||
|
@ -29,11 +30,11 @@ def loading_to_local(pred_info, dir="data/predictorzoo"):
|
|||
for p in ps:
|
||||
pname = os.path.basename(p).replace(".pkl", "")
|
||||
with open(p, "rb") as f:
|
||||
print("load predictor", p)
|
||||
logging.debug("load predictor %s" % p)
|
||||
model = pickle.load(f)
|
||||
predictors[pname] = model
|
||||
fusionrule = os.path.join(ppath, "rule_" + hardware + ".json")
|
||||
print(fusionrule)
|
||||
logging.debug(fusionrule)
|
||||
if not os.path.isfile(fusionrule):
|
||||
raise ValueError(
|
||||
"check your fusion rule path, file " + fusionrule + " does not exist!"
|
||||
|
@ -54,7 +55,7 @@ def download_from_url(urladdr, ppath):
|
|||
if not os.path.isdir(ppath):
|
||||
os.makedirs(ppath)
|
||||
|
||||
print("download from " + urladdr)
|
||||
logging.debug("download from " + urladdr)
|
||||
response = requests.get(urladdr, stream=True)
|
||||
total_size_in_bytes = int(response.headers.get("content-length", 0))
|
||||
block_size = 2048 # 2 Kibibyte
|
||||
|
@ -76,7 +77,7 @@ def check_predictors(ppath, kernel_predictors):
|
|||
|
||||
model: a pytorch/onnx/tensorflow model object or a str containing path to the model file
|
||||
"""
|
||||
print("checking local kernel predictors at " + ppath)
|
||||
logging.debug("checking local kernel predictors at " + ppath)
|
||||
if os.path.isdir(ppath):
|
||||
filenames = glob(os.path.join(ppath, "**.pkl"))
|
||||
# check if all the pkl files are included
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# Licensed under the MIT license.
|
||||
import numpy as np
|
||||
from sklearn.metrics import mean_squared_error
|
||||
|
||||
import logging
|
||||
|
||||
def get_accuracy(y_pred, y_true, threshold=0.01):
|
||||
a = (y_true - y_pred) / y_true
|
||||
|
@ -57,7 +57,7 @@ def get_predict_features(config):
|
|||
mdicts = {}
|
||||
layer = 0
|
||||
for item in config:
|
||||
print(item)
|
||||
logging.debug(item)
|
||||
for item in config:
|
||||
op = item["op"]
|
||||
if "conv" in op or "maxpool" in op or "avgpool" in op:
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
import copy
|
||||
import json
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
|
||||
class NumpyEncoder(json.JSONEncoder):
|
||||
|
@ -136,7 +137,7 @@ class Graphe:
|
|||
if name in self.graph.keys() and "attr" in self.graph[name].keys():
|
||||
return self.graph[name]["attr"]["type"]
|
||||
else:
|
||||
print(name, self.graph[name])
|
||||
logging.debug(name, self.graph[name])
|
||||
return None
|
||||
|
||||
def get_root_node(self, subgraph):
|
||||
|
|
|
@ -5,6 +5,7 @@ from zipfile import ZipFile
|
|||
from tqdm import tqdm
|
||||
import requests
|
||||
from packaging import version
|
||||
import logging
|
||||
|
||||
|
||||
def download_from_url(urladdr, ppath):
|
||||
|
@ -20,7 +21,7 @@ def download_from_url(urladdr, ppath):
|
|||
if not os.path.isdir(ppath):
|
||||
os.makedirs(ppath)
|
||||
|
||||
print("download from " + urladdr)
|
||||
logging.info("download from " + urladdr)
|
||||
response = requests.get(urladdr, stream=True)
|
||||
total_size_in_bytes = int(response.headers.get("content-length", 0))
|
||||
block_size = 2048 # 2 Kibibyte
|
||||
|
@ -39,10 +40,10 @@ def try_import_onnx(require_version = "1.9.0"):
|
|||
try:
|
||||
import onnx
|
||||
if version.parse(onnx.__version__) != version.parse(require_version):
|
||||
print(f'WARNING: onnx=={onnx.__version__} is not well tested now, well tested version: onnx=={require_version}' )
|
||||
logging.warning(f'onnx=={onnx.__version__} is not well tested now, well tested version: onnx=={require_version}' )
|
||||
return onnx
|
||||
except ImportError:
|
||||
print(f'You have not install the onnx package, please install onnx=={require_version} and try again.')
|
||||
logging.error(f'You have not install the onnx package, please install onnx=={require_version} and try again.')
|
||||
exit()
|
||||
|
||||
|
||||
|
@ -50,10 +51,10 @@ def try_import_torch(require_version = "1.8.1"):
|
|||
try:
|
||||
import torch
|
||||
if version.parse(torch.__version__) != version.parse(require_version):
|
||||
print(f'WARNING: torch=={torch.__version__} is not well tested now, well tested version: torch=={require_version}' )
|
||||
logging.warning(f'torch=={torch.__version__} is not well tested now, well tested version: torch=={require_version}' )
|
||||
return torch
|
||||
except ImportError:
|
||||
print(f'You have not install the torch package, please install torch=={require_version} and try again.')
|
||||
logging.error(f'You have not install the torch package, please install torch=={require_version} and try again.')
|
||||
exit()
|
||||
|
||||
|
||||
|
@ -61,10 +62,10 @@ def try_import_tensorflow(require_version = "1.9.0"):
|
|||
try:
|
||||
import tensorflow
|
||||
if version.parse(tensorflow.__version__) != version.parse(require_version):
|
||||
print(f'WARNING: tensorflow=={tensorflow.__version__} is not well tested now, well tested version: tensorflow=={require_version}' )
|
||||
logging.warning(f'tensorflow=={tensorflow.__version__} is not well tested now, well tested version: tensorflow=={require_version}' )
|
||||
return tensorflow
|
||||
except ImportError:
|
||||
print(f'You have not install the tensorflow package, please install tensorflow=={require_version} and try again.')
|
||||
logging.error(f'You have not install the tensorflow package, please install tensorflow=={require_version} and try again.')
|
||||
exit()
|
||||
|
||||
|
||||
|
@ -73,6 +74,6 @@ def try_import_torchvision_models():
|
|||
import torchvision
|
||||
return torchvision.models
|
||||
except ImportError:
|
||||
print(f'You have not install the torchvision package, please install torchvision and try again.')
|
||||
logging.error(f'You have not install the torchvision package, please install torchvision and try again.')
|
||||
exit()
|
||||
|
Загрузка…
Ссылка в новой задаче