From 72eee00c082ccfe64c4b357ec246f4144e36c7fe Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Thu, 4 Nov 2021 10:36:03 +0800 Subject: [PATCH] refine package import --- .../hardware-aware-model-design.md | 74 +++++++++---------- docs/{ => predictor}/usage.md | 0 .../requirements} | 0 nn_meter/__init__.py | 22 ++++-- nn_meter/dataset/bench_dataset.py | 9 +-- nn_meter/dataset/gnn_dataloader.py | 10 ++- .../frozenpb_converter/frozenpb_converter.py | 2 +- .../frozenpb_converter/frozenpb_parser.py | 9 ++- .../frozenpb_converter/shape_fetcher.py | 3 +- .../frozenpb_converter/shape_inference.py | 2 +- .../ir_converter/onnx_converter/converter.py | 6 +- nn_meter/ir_converter/onnx_converter/utils.py | 2 - .../ir_converter/torch_converter/converter.py | 6 +- nn_meter/ir_converter/utils.py | 5 +- nn_meter/kernel_detector/__init__.py | 2 +- .../kernel_detector/detection/__init__.py | 2 - nn_meter/kernel_detector/fusionlib/utils.py | 2 +- .../detector.py => kernel_detector.py} | 9 +-- .../kernel_detector/rulelib/rule_reader.py | 2 +- .../kernel_detector/rulelib/rule_splitter.py | 4 +- .../utils/fusion_aware_graph.py | 4 +- nn_meter/nn_meter_cli.py | 8 +- nn_meter/predictor/__init__.py | 5 +- nn_meter/predictor/nn_meter_predictor.py | 14 ++-- .../{predictors => prediction}/__init__.py | 0 .../extract_feature.py | 2 +- .../kernel_predictor.py | 0 .../predict_by_kernel.py | 0 .../{predictors => prediction}/utils.py | 3 +- nn_meter/predictor/utils.py | 2 +- nn_meter/utils/__init__.py | 3 +- nn_meter/utils/config_manager.py | 9 +-- nn_meter/utils/import_package.py | 2 +- nn_meter/utils/utils.py | 2 +- 34 files changed, 110 insertions(+), 115 deletions(-) rename docs/{ => predictor}/hardware-aware-model-design.md (99%) rename docs/{ => predictor}/usage.md (100%) rename docs/{requirements.txt => requirements/requirements} (100%) delete mode 100644 nn_meter/kernel_detector/detection/__init__.py rename nn_meter/kernel_detector/{detection/detector.py => kernel_detector.py} (92%) rename nn_meter/predictor/{predictors => prediction}/__init__.py (100%) rename nn_meter/predictor/{predictors => prediction}/extract_feature.py (100%) rename nn_meter/predictor/{predictors => prediction}/kernel_predictor.py (100%) rename nn_meter/predictor/{predictors => prediction}/predict_by_kernel.py (100%) rename nn_meter/predictor/{predictors => prediction}/utils.py (99%) diff --git a/docs/hardware-aware-model-design.md b/docs/predictor/hardware-aware-model-design.md similarity index 99% rename from docs/hardware-aware-model-design.md rename to docs/predictor/hardware-aware-model-design.md index e858675..efcfbcd 100644 --- a/docs/hardware-aware-model-design.md +++ b/docs/predictor/hardware-aware-model-design.md @@ -1,37 +1,37 @@ -# Hardware-aware DNN Model Design - -In many DNN model deployment scenarios, there are strict inference efficiency constraints as well as the model accuracy. For example, the **inference latency** and **energy consumption** are the most frequently used criteria of efficiencies to determine whether a DNN model could be deployed on a mobile phone or not. Therefore, DNN model designers have to consider the model efficiency. A typical methodology is to train a big model to meet the accuracy requirements first, and then apply model compression algorithms to get a light-weight model with similar accuracy but much smaller size. Due to many reasons, they use the number of parameters and FLOPs in the compression process. 
- -However, as pointed out in our work [[1]](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w40/Zhang_Fast_Hardware-Aware_Neural_Architecture_Search_CVPRW_2020_paper.pdf) and many others, ***neither number of parameters nor number of FLOPs is a good metric of the real inference efficiency (e.g., latency or energy consumption)***. Operators with similar FLOPs may have very different inference latency on different hardware platforms (e.g., CPU, GPU, and ASIC) (shown in work [[1]](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w40/Zhang_Fast_Hardware-Aware_Neural_Architecture_Search_CVPRW_2020_paper.pdf) and [[3]](https://proceedings.mlsys.org/paper/2021/file/02522a2b2726fb0a03bb19f2d8d9524d-Paper.pdf)). This makes the effort of designing efficient DNN models for a target hardware bit of games of opening blind boxes. Recently, many hardware-aware NAS works are proposed to solve this challenge. - -Compared with the conventional NAS algorithms, some recent works (i.e. hardware-aware NAS, aka HW-NAS) integrated hardware-awareness into the search loop and achieves a balanced trade-off between accuracy and hardware efficiencies [[4]](http://arxiv.org/abs/2101.09336). - -Next, we introduce our hardware-aware NAS framework[[1]](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w40/Zhang_Fast_Hardware-Aware_Neural_Architecture_Search_CVPRW_2020_paper.pdf), which combines the nn-Meter, to search high-accuracy DNN models within the latency constraints for target edge devices. - -## Hardware-aware Neural Architecture Search - -drawing - -**Hardware-aware Search Space Generation.** As formulated in many works, the search space is one of the three key aspects of a NAS process (the other two are the search strategy and the evaluation methodology) and matters a lot to the final results. - -Our HW-NAS framework firstly automatically selects the hardware-friendly operators (or blocks) by considering both representation capacity and hardware efficiency. The selected operators could establish a ***hardware-aware search space*** for most of existing NAS algorithms. - -**Latency Prediction in search process by nn-Meter.** Different with other simple predictors (e.g., look-up table for operators/blocks, linear regression models), [nn-Meter](overview.md) conducts kernel-level prediction, which captures the complex model graph optimizations on edge devices. nn-Meter is the first accurate latency prediction tool for DNNs on edge devices. - -Besides the search space specialization, our HW-NAS framework also allows combining nn-Meter with existing NAS algorithms in the optimization objectives and constraints. As described in [[4]](http://arxiv.org/abs/2101.09336), the HW-NAS algorithms often consider hardware efficiency metrics as the constraints of existing NAS formulation or part of the scalarized loss functions (e.g., the loss is weighted sum of both cross entropy loss and hardware-aware penalty). Since the NAS process may sample up to millions of candidate model architectures, the obtaining of hardware metrics must be accurate and efficient. - -nn-Meter is now integrated with [NNI](https://github.com/microsoft/nni), the AutoML framework also published by Microsoft, and could be combined with existing NAS algorithms seamlessly. 
[This doc](https://nni.readthedocs.io/en/stable/NAS/multi_trial_nas.html) show how to construct a latency constraint filter in [random search algorithm](https://arxiv.org/abs/1902.07638) on [SPOS NAS](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf) search space. Users could use this filter in multiple phases of the NAS process, e.g., the architecture searching phase and the super-net training phase.
-
-***Note that current nn-Meter project is limited to the latency prediction. For the other hardware metrics, e.g., energy consumption is another important metric in edge computing. Collaborations and contributions together with nn-Meter are highly welcomed!***
-
-## Other hardware-aware techniques
-
-Besides light weighted NAS, which search for an efficient architecture directly, there are also other techniques to achieve light weight DNN models, such as model compression and knowledge distillation (KD). Both methods tries to get a smaller but similar-performed models from a pre-trained big model. The difference is that model compression removes some of the components in the origin model, while knowledge distillation constructs a new student model and lets it learn the behavior of the origin model. Hardware awareness could also be combined with these methods.
-For example, nn-Meter could help users to construct suitable student architectures for the target hardware platform in the KD task.
-
-## References
-
-1. Li Lyna Zhang, Yuqing Yang, Yuhang Jiang, Wenwu Zhu, Yunxin Liu: ["Fast hardware-aware neural architecture search."](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w40/Zhang_Fast_Hardware-Aware_Neural_Architecture_Search_CVPRW_2020_paper.pdf) Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops. 2020.
-2. Li Lyna Zhang, Shihao Han, Jianyu Wei, Ningxin Zheng, Ting Cao, Yuqing Yang, Yunxin Liu: ["nn-Meter: Towards Accurate Latency Prediction of Deep-Learning Model Inference on Diverse Edge Devices."](https://dl.acm.org/doi/10.1145/3458864.3467882) Proceedings of the 19th ACM International Conference on Mobile Systems, Applications, and Services (MobiSys 2021)
-3. Xiaohu Tang, Shihao Han, Li Lyna Zhang, Ting Cao, Yunxin Liu: ["To Bridge Neural Network Design and Real-World Performance: A Behaviour Study for Neural Networks"](https://proceedings.mlsys.org/paper/2021/file/02522a2b2726fb0a03bb19f2d8d9524d-Paper.pdf) Proceedings of the 4th MLSys Conference (MLSys 2021)
-4. Benmeziane, H., Maghraoui, K. el, Ouarnoughi, H., Niar, S., Wistuba, M., & Wang, N. (2021).[" A Comprehensive Survey on Hardware-Aware Neural Architecture Search."](http://arxiv.org/abs/2101.09336)
+# Hardware-aware DNN Model Design
+
+In many DNN model deployment scenarios, there are strict inference efficiency constraints in addition to the accuracy requirements. For example, **inference latency** and **energy consumption** are the most frequently used criteria for deciding whether a DNN model can be deployed on a mobile phone. Therefore, DNN model designers have to take model efficiency into account. A typical methodology is to first train a big model that meets the accuracy requirements, and then apply model compression algorithms to obtain a light-weight model with similar accuracy but a much smaller size. Because they are easy to compute, the number of parameters and FLOPs are widely used as efficiency proxies in the compression process.
+
+However, as pointed out in our work [[1]](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w40/Zhang_Fast_Hardware-Aware_Neural_Architecture_Search_CVPRW_2020_paper.pdf) and many others, ***neither the number of parameters nor the number of FLOPs is a good metric of real inference efficiency (e.g., latency or energy consumption)***. Operators with similar FLOPs may have very different inference latencies on different hardware platforms (e.g., CPU, GPU, and ASIC), as shown in [[1]](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w40/Zhang_Fast_Hardware-Aware_Neural_Architecture_Search_CVPRW_2020_paper.pdf) and [[3]](https://proceedings.mlsys.org/paper/2021/file/02522a2b2726fb0a03bb19f2d8d9524d-Paper.pdf). This makes designing efficient DNN models for a target hardware platform feel like opening blind boxes. Recently, many hardware-aware NAS works have been proposed to address this challenge.
+
+Compared with conventional NAS algorithms, these recent works (hardware-aware NAS, a.k.a. HW-NAS) integrate hardware awareness into the search loop and achieve a balanced trade-off between accuracy and hardware efficiency [[4]](http://arxiv.org/abs/2101.09336).
+
+Next, we introduce our hardware-aware NAS framework [[1]](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w40/Zhang_Fast_Hardware-Aware_Neural_Architecture_Search_CVPRW_2020_paper.pdf), which uses nn-Meter to search for high-accuracy DNN models under the latency constraints of target edge devices.
+
+## Hardware-aware Neural Architecture Search
+
+*(figure: overview of the hardware-aware NAS framework)*
+
+**Hardware-aware Search Space Generation.** As formulated in many works, the search space is one of the three key aspects of a NAS process (the other two are the search strategy and the evaluation methodology), and it strongly affects the final results.
+
+Our HW-NAS framework first selects hardware-friendly operators (or blocks) automatically, considering both representation capacity and hardware efficiency. The selected operators establish a ***hardware-aware search space*** for most existing NAS algorithms.
+
+**Latency Prediction in the Search Process by nn-Meter.** Unlike simpler predictors (e.g., operator/block-level look-up tables or linear regression models), [nn-Meter](overview.md) performs kernel-level prediction, which captures the complex model-graph optimizations applied on edge devices. nn-Meter is the first accurate latency prediction tool for DNNs on edge devices.
+
+Besides search space specialization, our HW-NAS framework also allows combining nn-Meter with existing NAS algorithms through the optimization objectives and constraints. As described in [[4]](http://arxiv.org/abs/2101.09336), HW-NAS algorithms often treat hardware efficiency metrics as constraints on the standard NAS formulation, or as part of a scalarized loss function (e.g., a weighted sum of the cross-entropy loss and a hardware-aware penalty). Since the NAS process may sample up to millions of candidate architectures, obtaining these hardware metrics must be both accurate and efficient.
+
+nn-Meter is now integrated with [NNI](https://github.com/microsoft/nni), the AutoML framework also published by Microsoft, and can be combined with existing NAS algorithms seamlessly. [This doc](https://nni.readthedocs.io/en/stable/NAS/multi_trial_nas.html) shows how to construct a latency constraint filter with the [random search algorithm](https://arxiv.org/abs/1902.07638) on the [SPOS NAS](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf) search space (a simplified filter sketch also appears after the diffs below). Users can apply this filter in multiple phases of the NAS process, e.g., the architecture search phase and the super-net training phase.
+
+***Note that the current nn-Meter project is limited to latency prediction. Other hardware metrics, e.g., energy consumption, are also important in edge computing. Collaborations and contributions to nn-Meter are highly welcome!***
+
+## Other hardware-aware techniques
+
+Besides lightweight NAS, which searches for an efficient architecture directly, there are other techniques for obtaining lightweight DNN models, such as model compression and knowledge distillation (KD). Both try to derive a smaller model with similar performance from a pre-trained big model. The difference is that model compression removes some components of the original model, while knowledge distillation constructs a new student model and lets it learn the behavior of the original model. Hardware awareness can also be combined with these methods.
+For example, nn-Meter could help users construct student architectures that suit the target hardware platform in a KD task.
+
+## References
+
+1. Li Lyna Zhang, Yuqing Yang, Yuhang Jiang, Wenwu Zhu, Yunxin Liu: ["Fast hardware-aware neural architecture search."](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w40/Zhang_Fast_Hardware-Aware_Neural_Architecture_Search_CVPRW_2020_paper.pdf) Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops. 2020.
+2. Li Lyna Zhang, Shihao Han, Jianyu Wei, Ningxin Zheng, Ting Cao, Yuqing Yang, Yunxin Liu: ["nn-Meter: Towards Accurate Latency Prediction of Deep-Learning Model Inference on Diverse Edge Devices."](https://dl.acm.org/doi/10.1145/3458864.3467882) Proceedings of the 19th ACM International Conference on Mobile Systems, Applications, and Services (MobiSys 2021)
+3. Xiaohu Tang, Shihao Han, Li Lyna Zhang, Ting Cao, Yunxin Liu: ["To Bridge Neural Network Design and Real-World Performance: A Behaviour Study for Neural Networks"](https://proceedings.mlsys.org/paper/2021/file/02522a2b2726fb0a03bb19f2d8d9524d-Paper.pdf) Proceedings of the 4th MLSys Conference (MLSys 2021)
+4. Benmeziane, H., Maghraoui, K. el, Ouarnoughi, H., Niar, S., Wistuba, M., & Wang, N.
(2021).[" A Comprehensive Survey on Hardware-Aware Neural Architecture Search."](http://arxiv.org/abs/2101.09336) diff --git a/docs/usage.md b/docs/predictor/usage.md similarity index 100% rename from docs/usage.md rename to docs/predictor/usage.md diff --git a/docs/requirements.txt b/docs/requirements/requirements similarity index 100% rename from docs/requirements.txt rename to docs/requirements/requirements diff --git a/nn_meter/__init__.py b/nn_meter/__init__.py index 2efbcef..48740fb 100644 --- a/nn_meter/__init__.py +++ b/nn_meter/__init__.py @@ -7,20 +7,26 @@ try: except ModuleNotFoundError: __version__ = 'UNKNOWN' -from .nn_meter import ( - nnMeter, +import logging +from functools import partial, partialmethod + +from .predictor import ( + nnMeterPredictor, load_latency_predictor, list_latency_predictors, + latency_metrics +) +from .ir_converter import ( model_file_to_graph, - model_to_graph, + model_to_graph +) +from .utils import ( create_user_configs, change_user_data_folder ) -from .utils.utils import download_from_url -from .predictor import latency_metrics -from .dataset import bench_dataset # TODO: add GNNDataloader and GNNDataset here @wenxuan -import logging -from functools import partial, partialmethod +from .dataset import bench_dataset +from .utils import download_from_url + logging.KEYINFO = 22 logging.addLevelName(logging.KEYINFO, 'KEYINFO') diff --git a/nn_meter/dataset/bench_dataset.py b/nn_meter/dataset/bench_dataset.py index 7e2b224..428fe78 100644 --- a/nn_meter/dataset/bench_dataset.py +++ b/nn_meter/dataset/bench_dataset.py @@ -1,13 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. import os, sys -from nn_meter.predictor import latency_metrics +import logging +import jsonlines from glob import glob -from nn_meter.nn_meter import list_latency_predictors, load_latency_predictor, get_user_data_folder -from nn_meter import download_from_url -import jsonlines -import logging +from nn_meter.predictor import latency_metrics, list_latency_predictors, load_latency_predictor +from nn_meter.utils import download_from_url, get_user_data_folder __user_dataset_folder__ = os.path.join(get_user_data_folder(), 'dataset') diff --git a/nn_meter/dataset/gnn_dataloader.py b/nn_meter/dataset/gnn_dataloader.py index 44c76ac..33e5bab 100644 --- a/nn_meter/dataset/gnn_dataloader.py +++ b/nn_meter/dataset/gnn_dataloader.py @@ -1,16 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import torch -import jsonlines import os import random +import torch +import jsonlines from .bench_dataset import bench_dataset -from nn_meter.nn_meter import get_user_data_folder -from nn_meter.utils.utils import try_import_dgl +from nn_meter.utils import get_user_data_folder +from nn_meter.utils.import_package import try_import_dgl + RAW_DATA_URL = "https://github.com/microsoft/nn-Meter/releases/download/v1.0-data/datasets.zip" __user_dataset_folder__ = os.path.join(get_user_data_folder(), 'dataset') + hws = [ "cortexA76cpu_tflite21", "adreno640gpu_tflite21", diff --git a/nn_meter/ir_converter/frozenpb_converter/frozenpb_converter.py b/nn_meter/ir_converter/frozenpb_converter/frozenpb_converter.py index aac61a4..3495049 100644 --- a/nn_meter/ir_converter/frozenpb_converter/frozenpb_converter.py +++ b/nn_meter/ir_converter/frozenpb_converter/frozenpb_converter.py @@ -2,10 +2,10 @@ # Licensed under the MIT license. 
 import numpy as np
-from nn_meter.utils.graph_tool import ModelGraph
 from .frozenpb_parser import FrozenPbParser
 from .shape_inference import ShapeInference
 from .shape_fetcher import ShapeFetcher
+from nn_meter.utils.graph_tool import ModelGraph
 
 class FrozenPbConverter:
     def __init__(self, file_name):
diff --git a/nn_meter/ir_converter/frozenpb_converter/frozenpb_parser.py b/nn_meter/ir_converter/frozenpb_converter/frozenpb_parser.py
index 913a051..479522c 100644
--- a/nn_meter/ir_converter/frozenpb_converter/frozenpb_parser.py
+++ b/nn_meter/ir_converter/frozenpb_converter/frozenpb_parser.py
@@ -1,12 +1,13 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
-from nn_meter.utils.utils import try_import_tensorflow
-from .protobuf_helper import ProtobufHelper
-from .shape_fetcher import ShapeFetcher
-import copy
 import re
+import copy
 import logging
+from .protobuf_helper import ProtobufHelper
+from nn_meter.utils.import_package import try_import_tensorflow
+
+
 logging = logging.getLogger(__name__)
diff --git a/nn_meter/ir_converter/frozenpb_converter/shape_fetcher.py b/nn_meter/ir_converter/frozenpb_converter/shape_fetcher.py
index 9c2f594..fc8caec 100644
--- a/nn_meter/ir_converter/frozenpb_converter/shape_fetcher.py
+++ b/nn_meter/ir_converter/frozenpb_converter/shape_fetcher.py
@@ -1,9 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
-from nn_meter.utils.utils import try_import_tensorflow
 import numpy as np
 from typing import List
-
+from nn_meter.utils.import_package import try_import_tensorflow
 class ShapeFetcher:
     def __init__(self, input_graph):
diff --git a/nn_meter/ir_converter/frozenpb_converter/shape_inference.py b/nn_meter/ir_converter/frozenpb_converter/shape_inference.py
index f3b3dae..da05d20 100644
--- a/nn_meter/ir_converter/frozenpb_converter/shape_inference.py
+++ b/nn_meter/ir_converter/frozenpb_converter/shape_inference.py
@@ -1,10 +1,10 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
-from .protobuf_helper import ProtobufHelper as ph
 from functools import reduce
 import copy
 import math
 import logging
+from .protobuf_helper import ProtobufHelper as ph
 
 logging = logging.getLogger(__name__)
diff --git a/nn_meter/ir_converter/onnx_converter/converter.py b/nn_meter/ir_converter/onnx_converter/converter.py
index d4d214f..04786f8 100644
--- a/nn_meter/ir_converter/onnx_converter/converter.py
+++ b/nn_meter/ir_converter/onnx_converter/converter.py
@@ -1,11 +1,11 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
-from nn_meter.utils.utils import try_import_onnx
+import logging
 import networkx as nx
+from itertools import chain
 from .utils import get_tensor_shape
 from .constants import SLICE_TYPE
-from itertools import chain
-import logging
+from nn_meter.utils.import_package import try_import_onnx
 
 
 class OnnxConverter:
diff --git a/nn_meter/ir_converter/onnx_converter/utils.py b/nn_meter/ir_converter/onnx_converter/utils.py
index 9abe045..172cc82 100644
--- a/nn_meter/ir_converter/onnx_converter/utils.py
+++ b/nn_meter/ir_converter/onnx_converter/utils.py
@@ -1,7 +1,5 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
- - def get_tensor_shape(tensor): shape = [] for dim in tensor.type.tensor_type.shape.dim: diff --git a/nn_meter/ir_converter/torch_converter/converter.py b/nn_meter/ir_converter/torch_converter/converter.py index 1af2f43..f595985 100644 --- a/nn_meter/ir_converter/torch_converter/converter.py +++ b/nn_meter/ir_converter/torch_converter/converter.py @@ -1,11 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from nn_meter.utils.utils import try_import_onnx, try_import_torch, try_import_onnxsim, try_import_nni import tempfile -from nn_meter.ir_converter.onnx_converter import OnnxConverter - - +from ..onnx_converter import OnnxConverter from .opset_map import nni_attr_map, nni_type_map +from nn_meter.utils.import_package import try_import_onnx, try_import_torch, try_import_onnxsim, try_import_nni def _nchw_to_nhwc(shapes): diff --git a/nn_meter/ir_converter/utils.py b/nn_meter/ir_converter/utils.py index d435835..de6b952 100644 --- a/nn_meter/ir_converter/utils.py +++ b/nn_meter/ir_converter/utils.py @@ -2,11 +2,12 @@ # Licensed under the MIT license. import json import logging -from nn_meter.utils.utils import try_import_onnx, try_import_torch, try_import_torchvision_models +from nn_meter.utils.import_package import try_import_onnx, try_import_torch, try_import_torchvision_models from .onnx_converter import OnnxConverter from .frozenpb_converter import FrozenPbConverter from .torch_converter import NNIBasedTorchConverter, OnnxBasedTorchConverter, NNIIRConverter + def model_file_to_graph(filename: str, model_type: str, input_shape=(1, 3, 224, 224), apply_nni=False): """ read the given file and convert the model in the file content to nn-Meter IR graph object @@ -106,10 +107,12 @@ def onnx_model_to_graph(model): converter = OnnxConverter(model) return converter.convert() + def nni_model_to_graph(model): converter = NNIIRConverter(model) return converter.convert() + def torch_model_to_graph(model, input_shape=(1, 3, 224, 224), apply_nni=False): torch = try_import_torch() args = torch.randn(*input_shape) diff --git a/nn_meter/kernel_detector/__init__.py b/nn_meter/kernel_detector/__init__.py index 2b31f4e..bd89c7e 100644 --- a/nn_meter/kernel_detector/__init__.py +++ b/nn_meter/kernel_detector/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from .detection.detector import KernelDetector +from .kernel_detector import KernelDetector diff --git a/nn_meter/kernel_detector/detection/__init__.py b/nn_meter/kernel_detector/detection/__init__.py deleted file mode 100644 index 9a04545..0000000 --- a/nn_meter/kernel_detector/detection/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
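
The `model_file_to_graph` signature shown in the `nn_meter/ir_converter/utils.py` hunk above is the entry point for turning a model file into an nn-Meter IR graph. A minimal usage sketch, assuming an installed `nn_meter` package; the `.onnx` path is a placeholder, and `"onnx"` is one of the `model_type` values that function dispatches on:

```python
# Minimal sketch: convert a model file into an nn-Meter IR graph, using the
# signature from ir_converter/utils.py:
#   model_file_to_graph(filename, model_type, input_shape=(1, 3, 224, 224), apply_nni=False)
from nn_meter import model_file_to_graph

ir_graph = model_file_to_graph("path/to/model.onnx", model_type="onnx")
print(type(ir_graph))  # assumption: a dict-like IR graph keyed by node name
```
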
diff --git a/nn_meter/kernel_detector/fusionlib/utils.py b/nn_meter/kernel_detector/fusionlib/utils.py index 7450aae..2ba0502 100644 --- a/nn_meter/kernel_detector/fusionlib/utils.py +++ b/nn_meter/kernel_detector/fusionlib/utils.py @@ -3,7 +3,7 @@ import os import json from nn_meter.utils.graph_tool import ModelGraph -from nn_meter.kernel_detector.utils.ir_tools import convert_nodes +from ..utils.ir_tools import convert_nodes BASE_DIR = os.path.dirname(os.path.abspath(__file__)) diff --git a/nn_meter/kernel_detector/detection/detector.py b/nn_meter/kernel_detector/kernel_detector.py similarity index 92% rename from nn_meter/kernel_detector/detection/detector.py rename to nn_meter/kernel_detector/kernel_detector.py index 8799129..2f14948 100644 --- a/nn_meter/kernel_detector/detection/detector.py +++ b/nn_meter/kernel_detector/kernel_detector.py @@ -1,11 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from nn_meter.kernel_detector.rulelib.rule_reader import RuleReader -from nn_meter.kernel_detector.rulelib.rule_splitter import RuleSplitter from nn_meter.utils.graph_tool import ModelGraph -from nn_meter.kernel_detector.utils.constants import DUMMY_TYPES -from nn_meter.kernel_detector.utils.ir_tools import convert_nodes -# import logging +from .utils.constants import DUMMY_TYPES +from .utils.ir_tools import convert_nodes +from .rulelib.rule_reader import RuleReader +from .rulelib.rule_splitter import RuleSplitter class KernelDetector: diff --git a/nn_meter/kernel_detector/rulelib/rule_reader.py b/nn_meter/kernel_detector/rulelib/rule_reader.py index 8bb54c0..4e2fdfd 100644 --- a/nn_meter/kernel_detector/rulelib/rule_reader.py +++ b/nn_meter/kernel_detector/rulelib/rule_reader.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. import json +from ..fusionlib import get_fusion_unit from nn_meter.utils.graph_tool import ModelGraph -from nn_meter.kernel_detector.fusionlib import get_fusion_unit class RuleReader: diff --git a/nn_meter/kernel_detector/rulelib/rule_splitter.py b/nn_meter/kernel_detector/rulelib/rule_splitter.py index dac4446..084e5af 100644 --- a/nn_meter/kernel_detector/rulelib/rule_splitter.py +++ b/nn_meter/kernel_detector/rulelib/rule_splitter.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. from .rule_reader import RuleReader +from ..utils.match_helper import MatchHelper +from ..utils.fusion_aware_graph import FusionAwareGraph from nn_meter.utils.graph_tool import ModelGraph -from nn_meter.kernel_detector.utils.match_helper import MatchHelper -from nn_meter.kernel_detector.utils.fusion_aware_graph import FusionAwareGraph class RuleSplitter: diff --git a/nn_meter/kernel_detector/utils/fusion_aware_graph.py b/nn_meter/kernel_detector/utils/fusion_aware_graph.py index fc6dd01..ab75f6e 100644 --- a/nn_meter/kernel_detector/utils/fusion_aware_graph.py +++ b/nn_meter/kernel_detector/utils/fusion_aware_graph.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from nn_meter.utils.graph_tool import ModelGraph -from .union_find import UF import networkx as nx +from .union_find import UF +from nn_meter.utils.graph_tool import ModelGraph class FusionAwareGraph: diff --git a/nn_meter/nn_meter_cli.py b/nn_meter/nn_meter_cli.py index 55dccd2..39ce781 100644 --- a/nn_meter/nn_meter_cli.py +++ b/nn_meter/nn_meter_cli.py @@ -5,12 +5,7 @@ import os import sys import argparse import logging -from nn_meter.nn_meter import * - -__user_config_folder__ = os.path.expanduser('~/.nn_meter/config') -__user_data_folder__ = os.path.expanduser('~/.nn_meter/data') - -__predictors_cfg_filename__ = 'predictors.yaml' +from nn_meter import * def list_latency_predictors_cli(): @@ -62,6 +57,7 @@ def apply_latency_predictor_cli(args): return result + def get_nnmeter_ir_cli(args): """convert pb file or onnx file to nn-Meter IR graph according to the command line interface arguments """ diff --git a/nn_meter/predictor/__init__.py b/nn_meter/predictor/__init__.py index 6bb6217..0f9eeef 100644 --- a/nn_meter/predictor/__init__.py +++ b/nn_meter/predictor/__init__.py @@ -1,5 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from .predictors.utils import latency_metrics - - +from .prediction.utils import latency_metrics +from .nn_meter_predictor import nnMeterPredictor, list_latency_predictors, load_latency_predictor diff --git a/nn_meter/predictor/nn_meter_predictor.py b/nn_meter/predictor/nn_meter_predictor.py index 0cdc99b..aed99e2 100644 --- a/nn_meter/predictor/nn_meter_predictor.py +++ b/nn_meter/predictor/nn_meter_predictor.py @@ -1,18 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from glob import glob -from nn_meter.predictor.predictors.predict_by_kernel import nn_predict -from nn_meter.kernel_detector import KernelDetector -from nn_meter.ir_converter import model_file_to_graph, model_to_graph -from .utils import loading_to_local - -import yaml import os +import yaml import pkg_resources from shutil import copyfile from packaging import version import logging + +from .utils import loading_to_local +from .prediction.predict_by_kernel import nn_predict +from nn_meter.kernel_detector import KernelDetector from nn_meter.utils import load_config_file, get_user_data_folder +from nn_meter.ir_converter import model_file_to_graph, model_to_graph + __predictors_cfg_filename__ = 'predictors.yaml' diff --git a/nn_meter/predictor/predictors/__init__.py b/nn_meter/predictor/prediction/__init__.py similarity index 100% rename from nn_meter/predictor/predictors/__init__.py rename to nn_meter/predictor/prediction/__init__.py diff --git a/nn_meter/predictor/predictors/extract_feature.py b/nn_meter/predictor/prediction/extract_feature.py similarity index 100% rename from nn_meter/predictor/predictors/extract_feature.py rename to nn_meter/predictor/prediction/extract_feature.py index 1c2a91d..8a37c76 100644 --- a/nn_meter/predictor/predictors/extract_feature.py +++ b/nn_meter/predictor/prediction/extract_feature.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
+import logging
 import numpy as np
 from sklearn.metrics import mean_squared_error
-import logging
 
 
 def get_flop(input_channel, output_channel, k, H, W, stride):
diff --git a/nn_meter/predictor/predictors/kernel_predictor.py b/nn_meter/predictor/prediction/kernel_predictor.py
similarity index 100%
rename from nn_meter/predictor/predictors/kernel_predictor.py
rename to nn_meter/predictor/prediction/kernel_predictor.py
diff --git a/nn_meter/predictor/predictors/predict_by_kernel.py b/nn_meter/predictor/prediction/predict_by_kernel.py
similarity index 100%
rename from nn_meter/predictor/predictors/predict_by_kernel.py
rename to nn_meter/predictor/prediction/predict_by_kernel.py
diff --git a/nn_meter/predictor/predictors/utils.py b/nn_meter/predictor/prediction/utils.py
similarity index 99%
rename from nn_meter/predictor/predictors/utils.py
rename to nn_meter/predictor/prediction/utils.py
index 2cf40a7..3ffe181 100644
--- a/nn_meter/predictor/predictors/utils.py
+++ b/nn_meter/predictor/prediction/utils.py
@@ -1,9 +1,9 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
-
 import numpy as np
 from sklearn.metrics import mean_squared_error
 
+
 def get_kernel_name(optype):
     """
     for many similar kernels, we use one kernel predictor since their latency difference is negligible,
@@ -34,6 +34,7 @@ def get_kernel_name(optype):
 
     return optype
 
+
 def get_accuracy(y_pred, y_true, threshold=0.01):
     a = (y_true - y_pred) / y_true
     b = np.where(abs(a) <= threshold)
diff --git a/nn_meter/predictor/utils.py b/nn_meter/predictor/utils.py
index add48c2..cf5f580 100644
--- a/nn_meter/predictor/utils.py
+++ b/nn_meter/predictor/utils.py
@@ -7,7 +7,7 @@ from zipfile import ZipFile
 from tqdm import tqdm
 import requests
 import logging
-from nn_meter.utils.utils import download_from_url
+from nn_meter.utils import download_from_url
 
 
 def loading_to_local(pred_info, dir="data/predictorzoo"):
diff --git a/nn_meter/utils/__init__.py b/nn_meter/utils/__init__.py
index b8d4662..a27d1e9 100644
--- a/nn_meter/utils/__init__.py
+++ b/nn_meter/utils/__init__.py
@@ -5,4 +5,5 @@ from .config_manager import (
     get_user_data_folder,
     change_user_data_folder,
     load_config_file
-) \ No newline at end of file
+)
+from .utils import download_from_url \ No newline at end of file
diff --git a/nn_meter/utils/config_manager.py b/nn_meter/utils/config_manager.py
index d1ca03d..f8ace2a 100644
--- a/nn_meter/utils/config_manager.py
+++ b/nn_meter/utils/config_manager.py
@@ -1,14 +1,9 @@
-from glob import glob
-from nn_meter.predictor.predictors.predict_by_kernel import nn_predict
-from nn_meter.kernel_detector import KernelDetector
-from nn_meter.ir_converter import model_file_to_graph, model_to_graph
-from nn_meter.predictor.load_predictors import loading_to_local
-
 import yaml
 import os
+import logging
 import pkg_resources
 from shutil import copyfile
-import logging
+
 __user_config_folder__ = os.path.expanduser('~/.nn_meter/config')
 __default_user_data_folder__ = os.path.expanduser('~/.nn_meter/data')
diff --git a/nn_meter/utils/import_package.py b/nn_meter/utils/import_package.py
index f41def3..08a5cb7 100644
--- a/nn_meter/utils/import_package.py
+++ b/nn_meter/utils/import_package.py
@@ -1,7 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
-from packaging import version import logging +from packaging import version def try_import_onnx(require_version = ["1.9.0"]): diff --git a/nn_meter/utils/utils.py b/nn_meter/utils/utils.py index c7e7649..02cd919 100644 --- a/nn_meter/utils/utils.py +++ b/nn_meter/utils/utils.py @@ -20,7 +20,7 @@ def download_from_url(urladdr, ppath): if not os.path.isdir(ppath): os.makedirs(ppath) - # logging.keyinfo(f'Download from {urladdr}') + logging.keyinfo(f'Download from {urladdr}') response = requests.get(urladdr, stream=True) total_size_in_bytes = int(response.headers.get("content-length", 0)) block_size = 2048 # 2 Kibibyte
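
The final hunk re-enables `logging.keyinfo(...)` in `download_from_url`, which only works because `nn_meter/__init__.py` (the first hunk of this patch) registers a custom KEYINFO level. That hunk is truncated after `addLevelName`, so the two attachment lines below are an assumption inferred from the `partial`/`partialmethod` imports it adds; a self-contained sketch:

```python
import logging
from functools import partial, partialmethod

# Shown in the nn_meter/__init__.py hunk: a custom level between INFO (20)
# and WARNING (30).
logging.KEYINFO = 22
logging.addLevelName(logging.KEYINFO, 'KEYINFO')

# Assumed continuation (this is why partial/partialmethod are imported):
# attach Logger.keyinfo() and a module-level logging.keyinfo() shortcut.
logging.Logger.keyinfo = partialmethod(logging.Logger.log, logging.KEYINFO)
logging.keyinfo = partial(logging.log, logging.KEYINFO)

logging.basicConfig(level=logging.KEYINFO)
logging.keyinfo("Download from https://example.com/predictors.zip")  # placeholder URL
```
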
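The latency constraint filter mentioned in the `hardware-aware-model-design.md` content above can be approximated with the public API this patch re-exports. A hedged sketch, not part of the patch: `within_latency_budget` is a hypothetical helper, the predictor name comes from the `hws` list in `gnn_dataloader.py`, the `predict` call follows nn-Meter's documented usage, and the model path and 30 ms budget are made-up examples:

```python
from nn_meter import load_latency_predictor

def within_latency_budget(model, model_type, threshold_ms, predictor):
    # Hypothetical NAS filter: keep a sampled architecture only if its
    # predicted latency (in milliseconds) fits the budget.
    return predictor.predict(model, model_type=model_type) <= threshold_ms

predictor = load_latency_predictor("cortexA76cpu_tflite21")
if within_latency_budget("path/to/candidate.onnx", "onnx", 30.0, predictor):
    print("candidate fits the 30 ms budget")
```
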
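Since this patch rewires most public names through `nn_meter/__init__.py`, a quick smoke test of the refactored import surface is a useful sanity check. A sketch under the same assumptions as above (placeholder model path; `predict` signature taken from nn-Meter's usage docs rather than from this diff):

```python
# Every name imported here must resolve after the refactoring.
from nn_meter import (
    list_latency_predictors,
    load_latency_predictor,
    model_file_to_graph,
    download_from_url,
)

# List the bundled predictors registered in predictors.yaml.
for entry in list_latency_predictors():
    print(entry)

predictor = load_latency_predictor("adreno640gpu_tflite21")
print(predictor.predict("path/to/model.onnx", model_type="onnx"))
```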