From 448475797cc75fd29715e4c3a4303a39e35e6a00 Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Wed, 21 Jul 2021 17:34:17 +0800 Subject: [PATCH] applying --verbose to hide dubug info --- demo.py | 67 +++++++++++---- .../frozenpb_converter/frozenpb_parser.py | 6 +- .../frozenpb_converter/protobuf_helper.py | 2 +- .../frozenpb_converter/shape_inference.py | 86 +++++++++---------- .../kerneldetection/detection/detector.py | 3 +- nn_meter/nn_meter.py | 13 +-- nn_meter/prediction/load_predictors.py | 9 +- .../prediction/predictors/extract_feature.py | 4 +- nn_meter/utils/graphe_tool.py | 3 +- nn_meter/utils/utils.py | 17 ++-- 10 files changed, 126 insertions(+), 84 deletions(-) diff --git a/demo.py b/demo.py index 76d8d3f..188af21 100644 --- a/demo.py +++ b/demo.py @@ -2,8 +2,9 @@ # Licensed under the MIT license. from nn_meter.utils.utils import try_import_torchvision_models from nn_meter import load_predictor_config, load_latency_predictors -import yaml import argparse +import os +import logging def test_ir_graphs(args, predictor): @@ -16,7 +17,7 @@ def test_ir_graphs(args, predictor): models = glob("data/testmodels/**.json") for model in models: latency = predictor.predict(model) - print(model.split("/")[-1], latency) + logging.debug(os.path.basename(model), latency) def test_pb_models(args, predictor): @@ -29,7 +30,7 @@ def test_pb_models(args, predictor): models = glob("data/testmodels/**.pb") for model in models: latency = predictor.predict(model) - print(model.split("/")[-1], latency) + logging.debug(os.path.basename(model), latency) def test_onnx_models(args, predictor): @@ -42,7 +43,7 @@ def test_onnx_models(args, predictor): models = glob("data/testmodels/**.onnx") for model in models: latency = predictor.predict(model) - print(model.split("/")[-1], latency) + logging.debug(os.path.basename(model), latency) def test_pytorch_models(args, predictor): @@ -73,27 +74,21 @@ def test_pytorch_models(args, predictor): models.append(resnext50_32x4d) models.append(wide_resnet50_2) models.append(mnasnet) - print("start to test") + logging.debug("start to test") for model in models: latency = predictor.predict( model, model_type="torch", input_shape=(1, 3, 224, 224) ) - print(model.__class__.__name__, latency) + logging.debug(model.__class__.__name__, latency) if __name__ == "__main__": parser = argparse.ArgumentParser("predict model latency on device") - parser.add_argument( - "--input_model", - type=str, - required=True, - help="Path to input model. ONNX, FrozenPB or JSON", - ) parser.add_argument( "--predictor", type=str, required=True, - help="name of target predictor (hardware)" + help="name of target predictor (hardware)", ) parser.add_argument( "--predictor-version", @@ -107,9 +102,51 @@ if __name__ == "__main__": default="nn_meter/configs/predictors.yaml", help="config file to store current supported edge platform", ) + group = parser.add_mutually_exclusive_group() # Jiahang: can't handle model_type == "torch" now. + group.add_argument( + "--tensorflow", + type=str, + # required=True, + help="Path to input Tensorflow model (*.pb)" + ) + group.add_argument( + "--onnx", + type=str, + # required=True, + help="Path to input ONNX model (*.onnx)" + ) + group.add_argument( + "--nn-meter-ir", + type=str, + # required=True, + help="Path to input nn-Meter IR model (*.json)" + ) + group.add_argument( + "--nni-ir", + type=str, + # required=True, + help="Path to input NNI IR model (*.json)" + ) + parser.add_argument( + "-v", "--verbose", + help="increase output verbosity", + action="store_true" + ) args = parser.parse_args() + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + pred_info = load_predictor_config(args.config, args.predictor, args.predictor_version) predictor = load_latency_predictors(pred_info) - latency = predictor.predict(args.input_model) - print('predict latency', latency) \ No newline at end of file + if args.tensorflow: + input_model, model_type = args.tensorflow, "pb" + elif args.onnx: + input_model, model_type = args.onnx, "onnx" + elif args.nn_meter_ir: + input_model, model_type = args.nn_meter_ir, "json" + elif args.nni_ir: + input_model, model_type = args.nni_ir, "json" + + latency = predictor.predict(input_model, model_type) + logging.info('predict latency: %f' % latency) diff --git a/nn_meter/ir_converters/frozenpb_converter/frozenpb_parser.py b/nn_meter/ir_converters/frozenpb_converter/frozenpb_parser.py index 59bf4ae..ff00bf4 100644 --- a/nn_meter/ir_converters/frozenpb_converter/frozenpb_parser.py +++ b/nn_meter/ir_converters/frozenpb_converter/frozenpb_parser.py @@ -89,13 +89,13 @@ class FrozenPbParser: graph[graph_node]["attr"]["type"] == "Split" and ":" not in graph_node ): - logging.info("Find split main node %s." % graph_node) + logging.debug("Find split main node %s." % graph_node) split_node_name = graph_node for node_name in graph.keys(): idx = re.findall(r"%s:(\d+)" % split_node_name, node_name) if len(idx) > 0: idx = int(idx[0]) - logging.info("Find split child node %s." % node_name) + logging.debug("Find split child node %s." % node_name) graph[graph_node]["outbounds"] += graph[node_name][ "outbounds" ] @@ -194,7 +194,7 @@ class FrozenPbParser: attr_as_node[node.op]["node_name"](node.name), target_node.name ) if len(node_attr) > 0: - logging.info("Find regex matching node %s" % node.name) + logging.debug("Find regex matching node %s" % node.name) for attr_name in target_node.attr.keys(): if ( attr_name == "value" diff --git a/nn_meter/ir_converters/frozenpb_converter/protobuf_helper.py b/nn_meter/ir_converters/frozenpb_converter/protobuf_helper.py index 105809e..525584f 100644 --- a/nn_meter/ir_converters/frozenpb_converter/protobuf_helper.py +++ b/nn_meter/ir_converters/frozenpb_converter/protobuf_helper.py @@ -82,7 +82,7 @@ class ProtobufHelper: weight_op in graph.keys() and graph[weight_op]["attr"]["type"] != "Identity" ): - logging.info( + logging.debug( "Find node %s with its weight op %s." % (node["attr"]["name"], weight_op) ) diff --git a/nn_meter/ir_converters/frozenpb_converter/shape_inference.py b/nn_meter/ir_converters/frozenpb_converter/shape_inference.py index 4122e84..c7c866c 100644 --- a/nn_meter/ir_converters/frozenpb_converter/shape_inference.py +++ b/nn_meter/ir_converters/frozenpb_converter/shape_inference.py @@ -69,7 +69,7 @@ class ShapeInference: Padding type, now support SAME and VALID. """ - logging.info( + logging.debug( "Calculating padding shape, input shape: %s, kernel size: %s, strides: %s, padding: %s." % (str(input_shape), str(k_size), str(strides), str(padding)) ) @@ -156,7 +156,7 @@ class ShapeInference: node : dict The node in Graph IR in dict format. """ - logging.info("Propogate through op %s.", node["attr"]["name"]) + logging.debug("Propogate through op %s.", node["attr"]["name"]) in_shape = [graphe[node["inbounds"][0]]["attr"]["output_shape"][0]] return in_shape, in_shape @@ -279,13 +279,13 @@ class ShapeInference: """ if len(node["inbounds"]) != 1: logging.warning("Failed to get input node of %s." % (node["attr"]["name"])) - logging.info(node) + logging.debug(node) return input_shape = copy.deepcopy( graphe[node["inbounds"][0]]["attr"]["output_shape"][0] ) - logging.info( + logging.debug( "Get input shape of %s from %s, input shape:%s." % (node["attr"]["name"], node["inbounds"][0], input_shape) ) @@ -297,12 +297,12 @@ class ShapeInference: "Invalid strides %s of node %s." % (str(node["attr"]["attr"]["strides"]), node["attr"]["name"]) ) - logging.info(node) + logging.debug(node) return strides = node["attr"]["attr"]["strides"] padding = node["attr"]["attr"]["padding"].decode("utf-8") - logging.info( + logging.debug( "Op:%s, stride:%s, padding:%s." % (node["attr"]["name"], str(strides), str(padding)) ) @@ -407,18 +407,18 @@ class ShapeInference: weight_node = ph.find_weights_root(graphe, node) if len(weight_node) != 1: logging.warning("Failed to get shape of node %s." % (node["attr"]["name"])) - logging.info(node) + logging.debug(node) return input_node = [x for x in node["inbounds"] if x != weight_node] input_node = [x for x in input_node if graphe[x]["attr"]["type"] != "Identity"] if len(input_node) != 1: logging.warning("Failed to get input node of %s." % (node["attr"]["name"])) - logging.info(node) + logging.debug(node) return input_shape = copy.deepcopy(graphe[input_node[0]]["attr"]["output_shape"][0]) - logging.info( + logging.debug( "Get input shape of %s from %s, input shape:%s." % (node["attr"]["name"], input_node[0], input_shape) ) @@ -429,10 +429,10 @@ class ShapeInference: "Failed to parse weight shape %s of node %s." % (str(weight_shape), node["attr"]["name"]) ) - logging.info(node) + logging.debug(node) return - logging.info( + logging.debug( "Get weight shape of %s from %s, input shape:%s." % (node["attr"]["name"], weight_node, weight_shape) ) @@ -445,13 +445,13 @@ class ShapeInference: "Invalid strides %s of node %s." % (str(node["attr"]["attr"]["strides"]), node["attr"]["name"]) ) - logging.info(node) + logging.debug(node) return strides = node["attr"]["attr"]["strides"] dilation = node["attr"]["attr"]["dilations"] padding = node["attr"]["attr"]["padding"].decode("utf-8") - logging.info( + logging.debug( "Op:%s, stride:%s, dilation:%s, padding:%s." % (node["attr"]["name"], str(strides), str(dilation), str(padding)) ) @@ -490,18 +490,18 @@ class ShapeInference: weight_node = ph.find_weights_root(graphe, node) if len(weight_node) != 1: logging.warning("Failed to get shape of node %s." % (node["attr"]["name"])) - logging.info(node) + logging.debug(node) return input_node = [x for x in node["inbounds"] if x != weight_node] input_node = [x for x in input_node if graphe[x]["attr"]["type"] != "Identity"] if len(input_node) != 1: logging.warning("Failed to get input node of %s." % (node["attr"]["name"])) - logging.info(node) + logging.debug(node) return input_shape = copy.deepcopy(graphe[input_node[0]]["attr"]["output_shape"][0]) - logging.info( + logging.debug( "Get input shape of %s from %s, input shape:%s." % (node["attr"]["name"], input_node[0], input_shape) ) @@ -512,10 +512,10 @@ class ShapeInference: "Failed to parse weight shape %s of node %s." % (str(weight_shape), node["attr"]["name"]) ) - logging.info(node) + logging.debug(node) return - logging.info( + logging.debug( "Get weight shape of %s from %s, input shape:%s." % (node["attr"]["name"], weight_node, weight_shape) ) @@ -528,14 +528,14 @@ class ShapeInference: "Invalid strides %s of node %s." % (str(node["attr"]["attr"]["strides"]), node["attr"]["name"]) ) - logging.info(node) + logging.debug(node) return strides = node["attr"]["attr"]["strides"] dilation = node["attr"]["attr"]["dilations"] padding = node["attr"]["attr"]["padding"].decode("utf-8") - logging.info( + logging.debug( "Op:%s, stride:%s, dilation:%s, padding:%s." % (node["attr"]["name"], str(strides), str(dilation), str(padding)) ) @@ -573,7 +573,7 @@ class ShapeInference: """ input_shape = graphe[node["inbounds"][0]]["attr"]["output_shape"][0] output_shape = input_shape - logging.info( + logging.debug( "Get input shape of %s from %s, input shape:%s." % (node["attr"]["name"], node["inbounds"][0], output_shape) ) @@ -582,7 +582,7 @@ class ShapeInference: output_shape[2] = 0 reduction_indices = node["attr"]["attr"]["reduction_indices"] - logging.info("Get Reduction Indices %s.", str(reduction_indices)) + logging.debug("Get Reduction Indices %s.", str(reduction_indices)) reduction_cnt = 0 for reduction in sorted(reduction_indices): @@ -648,7 +648,7 @@ class ShapeInference: weight_node = ph.find_weights_root(graphe, node) if len(weight_node) != 1: logging.warning("Failed to get shape of node %s." % (node["attr"]["name"])) - logging.info(node) + logging.debug(node) return weight_shape = graphe[weight_node[0]]["attr"]["attr"]["tensor_shape"] @@ -657,10 +657,10 @@ class ShapeInference: "Failed to parse weight shape %s of node %s." % (str(weight_shape), node["attr"]["name"]) ) - logging.info(node) + logging.debug(node) return - logging.info( + logging.debug( "Get weight shape of %s from %s, input shape:%s." % (node["attr"]["name"], weight_node, weight_shape) ) @@ -669,11 +669,11 @@ class ShapeInference: input_node = [x for x in input_node if graphe[x]["attr"]["type"] != "Identity"] if len(input_node) != 1: logging.warning("Failed to get input node of %s." % (node["attr"]["name"])) - logging.info(node) + logging.debug(node) return input_shape = copy.deepcopy(graphe[input_node[0]]["attr"]["output_shape"][0]) - logging.info( + logging.debug( "Get input shape of %s from %s, input shape:%s." % (node["attr"]["name"], input_node[0], input_shape) ) @@ -683,7 +683,7 @@ class ShapeInference: "Weight shape and input shape not matched for %s." % (node["attr"]["name"]) ) - logging.info(node) + logging.debug(node) return output_shape = copy.deepcopy(input_shape) @@ -707,7 +707,7 @@ class ShapeInference: The node in Graph IR in dict format. """ if "shape" in node["attr"]["attr"].keys(): - logging.info( + logging.debug( "Shape attr find in %s op, propogate with normal.", node["attr"]["name"] ) input_shape = copy.deepcopy( @@ -715,7 +715,7 @@ class ShapeInference: ) exp_output_shape = copy.deepcopy(node["attr"]["attr"]["shape"]) else: - logging.info( + logging.debug( "Shape attr not find in %s op, try finding the shape node.", node["attr"]["name"], ) @@ -730,7 +730,7 @@ class ShapeInference: for sl in graphe[in_node]["attr"]["attr"]["constant"] for it in sl ] - logging.info( + logging.debug( "Fetched expected output shape from Pack op %s" % str(exp_output_shape) ) @@ -768,7 +768,7 @@ class ShapeInference: in_shape = graphe[in_node]["attr"]["output_shape"][0] if in_shape != []: input_shape.append(in_shape) - logging.info( + logging.debug( "Get input shape of %s from %s, input shape:%s." % (node["attr"]["name"], in_node, input_shape[-1]) ) @@ -830,7 +830,7 @@ class ShapeInference: input_shape = copy.deepcopy(graphe[in_node]["attr"]["output_shape"][0]) split_dim = node["attr"]["attr"]["split_dim"][0] - logging.info("Fetched Split dim for %s is %s.", node["attr"]["name"], split_dim) + logging.debug("Fetched Split dim for %s is %s.", node["attr"]["name"], split_dim) output_node_cnt = len(node["outbounds"]) output_shape = copy.deepcopy(input_shape) @@ -854,17 +854,17 @@ class ShapeInference: for in_node in node["inbounds"]: if graphe[in_node]["attr"]["type"] == "Const": perm = copy.deepcopy(graphe[in_node]["attr"]["attr"]["constant"]) - logging.info("Fetched perm sequence from Const op %s" % str(perm)) + logging.debug("Fetched perm sequence from Const op %s" % str(perm)) elif graphe[in_node]["attr"]["type"] == "Pack": perm = [1] + [ it for sl in graphe[in_node]["attr"]["attr"]["constant"] for it in sl ] - logging.info("Fetched perm sequence from Pack op %s" % str(perm)) + logging.debug("Fetched perm sequence from Pack op %s" % str(perm)) else: input_shape = copy.deepcopy(graphe[in_node]["attr"]["output_shape"][0]) - logging.info( + logging.debug( "Fetched input shape from %s, %s" % (in_node, str(input_shape)) ) @@ -946,17 +946,17 @@ class ShapeInference: ) if input_shape is not None: graph[node_name]["attr"]["input_shape"] = copy.deepcopy(input_shape) - logging.info( + logging.debug( "Input shape of %s op is %s." % (node_name, str(input_shape)) ) - logging.info( + logging.debug( "Output shape of %s op is %s." % (node_name, str(output_shape)) ) else: logging.error("%s not support yet." % graphe.get_node_type(node_name)) - logging.info("------ node content --------") - logging.info(graph[node_name]) - logging.info("----------------------------") + logging.debug("------ node content --------") + logging.debug(graph[node_name]) + logging.debug("----------------------------") # Pass #2 # This is a patching for back-end, since backend extract shapes from @@ -973,11 +973,11 @@ class ShapeInference: ) if input_shape is not None: graph[node_name]["attr"]["input_shape"] = copy.deepcopy(input_shape) - logging.info( + logging.debug( "Second Pass: Input shape of %s op is %s." % (node_name, str(input_shape)) ) - logging.info( + logging.debug( "Second Pass: Output shape of %s op is %s." % (node_name, str(output_shape)) ) diff --git a/nn_meter/kerneldetection/detection/detector.py b/nn_meter/kerneldetection/detection/detector.py index a140c98..fcb8095 100644 --- a/nn_meter/kerneldetection/detection/detector.py +++ b/nn_meter/kerneldetection/detection/detector.py @@ -5,6 +5,7 @@ from nn_meter.kerneldetection.rulelib.rule_splitter import RuleSplitter from nn_meter.utils.graphe_tool import Graphe from nn_meter.kerneldetection.utils.constants import DUMMY_TYPES from nn_meter.kerneldetection.utils.ir_tools import convert_nodes +# import logging class KernelDetector: @@ -33,7 +34,7 @@ class KernelDetector: def _bb_to_kernel(self, bb): types = [self.graph.get_node_type(node) for node in bb] - # print(types) + # logging.debug(types) types = [t for t in types if t and t not in DUMMY_TYPES] if types: diff --git a/nn_meter/nn_meter.py b/nn_meter/nn_meter.py index 5dbc3a5..f170a07 100644 --- a/nn_meter/nn_meter.py +++ b/nn_meter/nn_meter.py @@ -11,6 +11,7 @@ import argparse import pkg_resources from shutil import copyfile from packaging import version +import logging __user_config_folder__ = os.path.expanduser('~/.nn_meter/config') @@ -40,7 +41,7 @@ def list_latency_predictors(): with open(fn_pred) as fp: return yaml.load(fp, yaml.FullLoader) except FileNotFoundError: - print(f"config file {fn_pred} not found, created") + logging.debug(f"config file {fn_pred} not found, created") create_user_configs() return list_latency_predictors() @@ -60,7 +61,7 @@ def load_predictor_config(config, predictor, predictor_version): if version.parse(preds_info[i]['version']) > latest_version: latest_version = version.parse(preds_info[i]['version']) latest_version_idx = i - print(f'WARNING: There are multiple version for {predictor}, use the latest one ({str(latest_version)})') + logging.warning(f'There are multiple version for {predictor}, use the latest one ({str(latest_version)})') return preds_info[latest_version_idx] else: raise NotImplementedError('No predictor that meet the required version, please try again.') @@ -84,7 +85,7 @@ class nnMeter: graph = model_file_to_graph(model, model_type) else: graph = model_to_graph(model, model_type, input_shape=input_shape) - # print(graph) + # logging.debug(graph) self.kd.load_graph(graph) py = nn_predict(self.kernel_predictors, self.kd.kernels) @@ -120,12 +121,12 @@ def nn_meter_cli(): if args.list_predictors: preds = list_latency_predictors() - print("Supported latency predictors:") + logging.info("Supported latency predictors:") for p in preds: - print(f"{p['name']}: version={p['version']}") + logging.info(f"{p['name']}: version={p['version']}") return pred_info = load_predictor_config(args.config, args.predictor, args.predictor_version) predictor = load_latency_predictors(pred_info) latency = predictor.predict(args.input_model) - print('predict latency', latency) + logging.info('predict latency', latency) diff --git a/nn_meter/prediction/load_predictors.py b/nn_meter/prediction/load_predictors.py index 442eb03..ff37a75 100644 --- a/nn_meter/prediction/load_predictors.py +++ b/nn_meter/prediction/load_predictors.py @@ -4,6 +4,7 @@ from glob import glob from zipfile import ZipFile from tqdm import tqdm import requests +import logging def loading_to_local(pred_info, dir="data/predictorzoo"): @@ -29,11 +30,11 @@ def loading_to_local(pred_info, dir="data/predictorzoo"): for p in ps: pname = os.path.basename(p).replace(".pkl", "") with open(p, "rb") as f: - print("load predictor", p) + logging.debug("load predictor %s" % p) model = pickle.load(f) predictors[pname] = model fusionrule = os.path.join(ppath, "rule_" + hardware + ".json") - print(fusionrule) + logging.debug(fusionrule) if not os.path.isfile(fusionrule): raise ValueError( "check your fusion rule path, file " + fusionrule + " does not exist!" @@ -54,7 +55,7 @@ def download_from_url(urladdr, ppath): if not os.path.isdir(ppath): os.makedirs(ppath) - print("download from " + urladdr) + logging.debug("download from " + urladdr) response = requests.get(urladdr, stream=True) total_size_in_bytes = int(response.headers.get("content-length", 0)) block_size = 2048 # 2 Kibibyte @@ -76,7 +77,7 @@ def check_predictors(ppath, kernel_predictors): model: a pytorch/onnx/tensorflow model object or a str containing path to the model file """ - print("checking local kernel predictors at " + ppath) + logging.debug("checking local kernel predictors at " + ppath) if os.path.isdir(ppath): filenames = glob(os.path.join(ppath, "**.pkl")) # check if all the pkl files are included diff --git a/nn_meter/prediction/predictors/extract_feature.py b/nn_meter/prediction/predictors/extract_feature.py index 6bf75cb..4338612 100644 --- a/nn_meter/prediction/predictors/extract_feature.py +++ b/nn_meter/prediction/predictors/extract_feature.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import numpy as np from sklearn.metrics import mean_squared_error - +import logging def get_accuracy(y_pred, y_true, threshold=0.01): a = (y_true - y_pred) / y_true @@ -57,7 +57,7 @@ def get_predict_features(config): mdicts = {} layer = 0 for item in config: - print(item) + logging.debug(item) for item in config: op = item["op"] if "conv" in op or "maxpool" in op or "avgpool" in op: diff --git a/nn_meter/utils/graphe_tool.py b/nn_meter/utils/graphe_tool.py index fef56d5..203e802 100644 --- a/nn_meter/utils/graphe_tool.py +++ b/nn_meter/utils/graphe_tool.py @@ -3,6 +3,7 @@ import copy import json import numpy as np +import logging class NumpyEncoder(json.JSONEncoder): @@ -136,7 +137,7 @@ class Graphe: if name in self.graph.keys() and "attr" in self.graph[name].keys(): return self.graph[name]["attr"]["type"] else: - print(name, self.graph[name]) + logging.debug(name, self.graph[name]) return None def get_root_node(self, subgraph): diff --git a/nn_meter/utils/utils.py b/nn_meter/utils/utils.py index de5121a..779308a 100644 --- a/nn_meter/utils/utils.py +++ b/nn_meter/utils/utils.py @@ -5,6 +5,7 @@ from zipfile import ZipFile from tqdm import tqdm import requests from packaging import version +import logging def download_from_url(urladdr, ppath): @@ -20,7 +21,7 @@ def download_from_url(urladdr, ppath): if not os.path.isdir(ppath): os.makedirs(ppath) - print("download from " + urladdr) + logging.info("download from " + urladdr) response = requests.get(urladdr, stream=True) total_size_in_bytes = int(response.headers.get("content-length", 0)) block_size = 2048 # 2 Kibibyte @@ -39,10 +40,10 @@ def try_import_onnx(require_version = "1.9.0"): try: import onnx if version.parse(onnx.__version__) != version.parse(require_version): - print(f'WARNING: onnx=={onnx.__version__} is not well tested now, well tested version: onnx=={require_version}' ) + logging.warning(f'onnx=={onnx.__version__} is not well tested now, well tested version: onnx=={require_version}' ) return onnx except ImportError: - print(f'You have not install the onnx package, please install onnx=={require_version} and try again.') + logging.error(f'You have not install the onnx package, please install onnx=={require_version} and try again.') exit() @@ -50,10 +51,10 @@ def try_import_torch(require_version = "1.8.1"): try: import torch if version.parse(torch.__version__) != version.parse(require_version): - print(f'WARNING: torch=={torch.__version__} is not well tested now, well tested version: torch=={require_version}' ) + logging.warning(f'torch=={torch.__version__} is not well tested now, well tested version: torch=={require_version}' ) return torch except ImportError: - print(f'You have not install the torch package, please install torch=={require_version} and try again.') + logging.error(f'You have not install the torch package, please install torch=={require_version} and try again.') exit() @@ -61,10 +62,10 @@ def try_import_tensorflow(require_version = "1.9.0"): try: import tensorflow if version.parse(tensorflow.__version__) != version.parse(require_version): - print(f'WARNING: tensorflow=={tensorflow.__version__} is not well tested now, well tested version: tensorflow=={require_version}' ) + logging.warning(f'tensorflow=={tensorflow.__version__} is not well tested now, well tested version: tensorflow=={require_version}' ) return tensorflow except ImportError: - print(f'You have not install the tensorflow package, please install tensorflow=={require_version} and try again.') + logging.error(f'You have not install the tensorflow package, please install tensorflow=={require_version} and try again.') exit() @@ -73,6 +74,6 @@ def try_import_torchvision_models(): import torchvision return torchvision.models except ImportError: - print(f'You have not install the torchvision package, please install torchvision and try again.') + logging.error(f'You have not install the torchvision package, please install torchvision and try again.') exit() \ No newline at end of file