Merge review to main (fix README.md conflicts)
Commit
14503e1a27
41
README.md
|
@ -1,5 +1,5 @@
|
|||
|
||||
nn-Meter is a novel and efficient system to accurately predict the inference latency of DNN models on diverse edge devices. The key idea is dividing a whole model inference into kernels, i.e., the execution units of fused operators on a device, and conducting kernel-level prediction.
|
||||
nn-Meter is a novel and efficient system to accurately predict the inference latency of DNN models on diverse edge devices. The key idea is dividing a whole model inference into kernels, i.e., the execution units of fused operators on a device, and conducting kernel-level prediction.
|
||||
nn-Meter contains two key techniques: (i) kernel detection to automatically detect the execution unit of model inference via a set of well-designed test cases; (ii) adaptive sampling to efficiently sample the most beneficial configurations from a large space to build accurate kernel-level latency predictors.
|
||||
We currently evaluate four popular platforms on a large dataset of 26k models. It achieves 99.0% (mobile CPU), 99.1% (mobile Adreno 640 GPU), 99.0% (mobile Adreno 630 GPU), and 83.4% (Intel VPU) prediction accuracy.
|
||||
|
||||
|
@ -7,7 +7,7 @@ The current supported hardware and inference frameworks:
|
|||
|
||||
| Abbr. | Device | Framework | Processor | +-10% Accuracy | key in nn-Meter usage |
|
||||
|:----:|:-------------------:|:--------------:|:--------------:|:------------------:|:-------------------:|
|
||||
| CPU | Pixel4 | TFLite v2.1 | CortexA76 CPU | 99.0% | **cortexA76cpu_tflite21** |
|
||||
| CPU | Pixel4 | TFLite v2.1 | CortexA76 CPU | 99.0% | **cortexA76cpu_tflite21** |
|
||||
| GPU | Mi9 | TFLite v2.1 | Adreno 640 GPU | 99.1% | **adreno640gpu_tflite21** |
|
||||
| GPU1 | Pixel3XL | TFLite v2.1 | Adreno 630 GPU | 99.0% | **adreno630gpu_tflite21** |
|
||||
| VPU | Intel Movidius NCS2 | OpenVINO2019R2 | Myriad VPU | 83.4% | **myriadvpu_openvino2019r2** |
|
||||
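The `+-10% Accuracy` column can be read as the share of models whose predicted latency falls within 10% of the measured latency. The sketch below computes that metric under this assumption; the function name `accuracy_within` and the sample latencies are made up for illustration and are not part of nn-Meter.

```python
def accuracy_within(predicted, measured, tolerance=0.10):
    """Fraction of predictions whose relative error is at most `tolerance`."""
    if len(predicted) != len(measured):
        raise ValueError("predicted and measured must have the same length")
    if not measured:
        raise ValueError("at least one measurement is required")
    hits = sum(
        1 for p, m in zip(predicted, measured)
        if abs(p - m) <= tolerance * m
    )
    return hits / len(measured)


# Hypothetical latencies in milliseconds, for illustration only.
predicted_ms = [10.2, 19.0, 33.5, 41.0]
measured_ms = [10.0, 20.0, 30.0, 40.0]
score = accuracy_within(predicted_ms, measured_ms)
print(score)
```

Only the third prediction misses the 10% band here, so three of the four hypothetical models count as accurately predicted.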
|
@ -19,9 +19,9 @@ The current supported hardware and inference frameworks:
|
|||
- Those who want to **build latency predictors for their own devices**.
|
||||
## Installation
|
||||
|
||||
To install nn-Meter, first install Python 3. The test environment uses Anaconda Python 3.6.10. Install the dependencies via:
|
||||
To install nn-Meter, first install Python 3. The test environment uses Anaconda Python 3.6.10. Install the dependencies via:
|
||||
`pip3 install -r requirements.txt`
|
||||
Please also check the versions of numpy and scikit_learn. Different versions may change the prediction accuracy of the kernel predictors.
|
||||
Please also check the versions of numpy and scikit_learn. Different versions may change the prediction accuracy of the kernel predictors.
|
||||
|
||||
## Usage
|
||||
|
||||
|
@ -36,10 +36,10 @@ nn-Meter currently supports prediction on the following four config:
|
|||
|
||||
| hardware |
|
||||
|:-------------------:|
|
||||
| cortexA76cpu_tflite21 |
|
||||
| adreno640gpu_tflite21 |
|
||||
| adreno630gpu_tflite21 |
|
||||
| myriadvpu_openvino2019r2 |
|
||||
| cortexA76cpu_tflite21 |
|
||||
| adreno640gpu_tflite21 |
|
||||
| adreno630gpu_tflite21 |
|
||||
| myriadvpu_openvino2019r2 |
|
||||
|
||||
For the input model file, you can use any of the examples provided under `data/testmodels`.
|
||||
|
||||
|
@ -47,7 +47,7 @@ For the input model file, you can find any example provided under the `data/test
|
|||
|
||||
|
||||
### Predict inference latency (to be implemented)
|
||||
nn-Meter can be seamlessly integrated with existing `PyTorch` code to predict the inference latency of a `torch.nn.Module` object.
|
||||
nn-Meter can be seamlessly integrated with existing `PyTorch` code to predict the inference latency of a `torch.nn.Module` object.
|
||||
```python
|
||||
from nn_meter import load_latency_predictor
|
||||
|
||||
|
@ -58,17 +58,32 @@ model = ... # model is instance of torch.nn.Module
|
|||
|
||||
lat = predictor.predict(model)
|
||||
```
|
||||
By calling `load_latency_predictor`, the user selects the target backend (`Framework-Hardware`) and loads the corresponding predictor. nn-Meter will try to find the right predictor file in `~/.nn_meter/predictors`. If the predictor file doesn't exist, it will be downloaded from the GitHub repo.
|
||||
By calling `load_latency_predictor`, the user selects the target backend (`Framework-Hardware`) and loads the corresponding predictor. nn-Meter will try to find the right predictor file in `~/.nn_meter/predictors`. If the predictor file doesn't exist, it will be downloaded from the GitHub repo.
|
||||
|
||||
Users can view information about all built-in predictors with `list_latency_predictors`, or view the config file at `~/.nn_meter/config.json`.
|
||||
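The lookup described above (use a cached predictor file if present, otherwise download it) can be sketched as follows. This is a minimal illustration, not nn-Meter's implementation: `ensure_predictor`, the `.pkl` file name, and the `download` callable are hypothetical, and the demo uses a temporary directory instead of `~/.nn_meter/predictors`.

```python
import tempfile
from pathlib import Path


def ensure_predictor(name, cache_dir, download):
    """Return the cached predictor path, calling `download` only on a cache miss."""
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    target = cache_dir / f"{name}.pkl"  # hypothetical file naming
    if not target.exists():
        download(name, target)
    return target


calls = []


def fake_download(name, target):
    # Stand-in for a real network fetch; writes placeholder bytes.
    calls.append(name)
    target.write_bytes(b"placeholder")


with tempfile.TemporaryDirectory() as tmp:
    first = ensure_predictor("cortexA76cpu_tflite21", tmp, fake_download)
    second = ensure_predictor("cortexA76cpu_tflite21", tmp, fake_download)
    same_path = first == second
    existed = first.exists()
```

Passing the downloader in as a callable keeps the cache logic testable without network access; the second call finds the file and skips the download.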
|
||||
### Use nn-Meter in commands (to do)
|
||||
To predict the latency of saved models, users can also use the nn-Meter command, for example:
|
||||
To predict the latency of saved models, users can also use the nn-Meter command, for example:
|
||||
|
||||
```bash
|
||||
nn-meter --input_model data/testmodels/alexnet.onnx --backend TFLite-CortexA76
|
||||
```
|
||||
Currently we support the `ONNX` format (ONNX files of popular CNN models are included in [`data/testmodels`](data/testmodels)) and TensorFlow pb files.
|
||||
Currently we support the `ONNX` format (ONNX files of popular CNN models are included in [`data/testmodels`](data/testmodels)) and TensorFlow pb files.
|
||||
|
||||
### Hardware-aware NAS by nn-Meter and NNI
|
||||
Install NNI by following [NNI Doc](https://nni.readthedocs.io/en/stable/Tutorial/InstallationLinux.html#installation).
|
||||
|
||||
Install nn-Meter from source code (currently we haven't released this package, so development installation is required).
|
||||
|
||||
```bash
|
||||
python setup.py develop
|
||||
```
|
||||
|
||||
Then run the multi-trial SPOS demo:
|
||||
|
||||
```bash
|
||||
python ${NNI_ROOT}/examples/nas/oneshot/spos/multi_trial.py
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
|
@ -85,7 +100,7 @@ For more information see the [Code of Conduct FAQ](https://opensource.microsoft.
|
|||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
|
||||
## License
|
||||
The entire codebase is under [MIT license](https://github.com/microsoft/nn-Meter/blob/main/LICENSE)
|
||||
The entire codebase is under [MIT license](https://github.com/microsoft/nn-Meter/blob/main/LICENSE)
|
||||
|
||||
The dataset is under [Open Use of Data Agreement](https://github.com/Community-Data-License-Agreements/Releases/blob/main/O-UDA-1.0.md)
|
||||
|
||||
|
|
|
@ -1 +1 @@
|
|||
from .nn_meter import load_latency_predictors, nnMeter
|
||||
from .nn_meter import load_latency_predictors, nnMeter, get_default_config
|
||||
|
|
|
@ -3,18 +3,103 @@ import onnx
|
|||
import tempfile
|
||||
from nn_meter.ir_converters.onnx_converter import OnnxConverter
|
||||
|
||||
from nni.retiarii.converter import convert_to_graph
|
||||
from nni.retiarii.converter.graph_gen import GraphConverterWithShape
|
||||
from nni.retiarii.graph import Model
|
||||
|
||||
class TorchConverter(OnnxConverter):
|
||||
def __init__(self, model, args):
|
||||
"""
|
||||
@params
|
||||
args: model input, refer to
|
||||
https://pytorch.org/docs/stable/onnx.html#example-end-to-end-alexnet-from-pytorch-to-onnx
|
||||
for more information.
|
||||
"""
|
||||
from .opset_map import nni_attr_map, nni_type_map
|
||||
|
||||
|
||||
class NNIIRConverter:
    def __init__(self, ir_model: Model):
        self.ir_model = ir_model.fork()
        GraphConverterWithShape().flatten(self.ir_model)

    def convert(self):
        graphe = self._to_graphe_layout()

        for _, node in graphe.items():
            self._map_opset(node)

        self._remove_unshaped_nodes(graphe)

        return graphe

    def _to_graphe_layout(self):
        graphe = {}

        for node in self.ir_model.root_graph.hidden_nodes:
            node_dict = {
                'attr': {
                    'attr': {
                        k: v
                        for k, v in node.operation.parameters.items()
                        if k not in ['input_shape', 'output_shape']
                    },
                    'input_shape': node.operation.parameters['input_shape'],
                    'output_shape': node.operation.parameters['output_shape'],
                    'type': node.operation.type,
                },
                'inbounds': [],
                'outbounds': [],
            }

            incoming_edges = sorted(node.incoming_edges, key=lambda e: e.tail_slot or 0)
            for edge in incoming_edges:
                node_dict['inbounds'].append(edge.head.name)

            outgoing_edges = sorted(node.outgoing_edges, key=lambda e: e.head_slot or 0)
            for edge in outgoing_edges:
                node_dict['outbounds'].append(edge.tail.name)

            graphe[node.name] = node_dict

        return graphe

    def _map_opset(self, node):
        old_type = node['attr']['type']
        new_type = nni_type_map.get(old_type, old_type)

        new_attr_dict = {}
        for attr_name, attr_value in node['attr']['attr'].items():
            new_attr_name = attr_name
            new_attr_value = attr_value
            for type, attr_map in nni_attr_map.items():
                if type == '__all__' or type == new_type:
                    if attr_name in attr_map:
                        new_attr_name, modifier = attr_map[attr_name]
                        if modifier is not None:
                            new_attr_value = modifier(attr_value)

            new_attr_dict[new_attr_name] = new_attr_value

        node['attr']['type'] = new_type
        node['attr']['attr'] = new_attr_dict

    def _remove_unshaped_nodes(self, graphe):
        for node_name, node_dict in list(graphe.items()):
            if not node_dict['attr']['input_shape']:
                del graphe[node_name]


class NNIBasedTorchConverter(NNIIRConverter):
    def __init__(self, model, example_inputs):
        # PyTorch module to NNI IR
        script_module = torch.jit.script(model)
        converter = GraphConverterWithShape()
        ir_model = convert_to_graph(script_module, model, converter, example_inputs=example_inputs)

        super().__init__(ir_model)
|
||||
|
||||
|
||||
class OnnxBasedTorchConverter(OnnxConverter):
|
||||
def __init__(self, model, example_inputs):
|
||||
with tempfile.TemporaryFile() as fp:
|
||||
torch.onnx.export(model, args, fp)
|
||||
torch.onnx.export(model, example_inputs, fp)
|
||||
fp.seek(0)
|
||||
model = onnx.load(fp, load_external_data=False)
|
||||
|
||||
super().__init__(model)
|
||||
|
||||
|
||||
TorchConverter = NNIBasedTorchConverter
|
||||
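To make the converter's output concrete, the sketch below builds a small graph dictionary by hand in the layout produced by `_to_graphe_layout` and applies the same pruning rule as `_remove_unshaped_nodes`. The node names and shapes are invented for illustration; no NNI or PyTorch installation is needed.

```python
def remove_unshaped_nodes(graph):
    """Drop nodes whose input shape is missing or empty, in place."""
    for name, node in list(graph.items()):
        if not node['attr']['input_shape']:
            del graph[name]


# Hypothetical three-node graph in the converter's dictionary layout.
graph = {
    'conv1': {
        'attr': {
            'attr': {'ks': [3, 3], 'strides': [1, 1]},
            'input_shape': [[1, 3, 224, 224]],
            'output_shape': [[1, 16, 224, 224]],
            'type': 'conv',
        },
        'inbounds': [],
        'outbounds': ['relu1'],
    },
    'relu1': {
        'attr': {
            'attr': {},
            'input_shape': [[1, 16, 224, 224]],
            'output_shape': [[1, 16, 224, 224]],
            'type': 'relu',
        },
        'inbounds': ['conv1'],
        'outbounds': [],
    },
    'const1': {
        'attr': {
            'attr': {},
            'input_shape': None,  # no shape was recorded for this node
            'output_shape': None,
            'type': 'prim::Constant',
        },
        'inbounds': [],
        'outbounds': [],
    },
}

remove_unshaped_nodes(graph)
remaining = sorted(graph)
```

Iterating over `list(graph.items())` takes a snapshot first, which is what makes deleting entries inside the loop safe.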
|
|
|
@ -0,0 +1,29 @@
|
|||
nni_type_map = {
    'aten::mul': 'mul',
    'aten::floordiv': 'div',
    'aten::reshape': 'reshape',
    'aten::cat': 'concat',
    '__torch__.torch.nn.modules.conv.Conv2d': 'conv',
    '__torch__.torch.nn.modules.activation.ReLU': 'relu',
    '__torch__.torch.nn.modules.batchnorm.BatchNorm2d': 'bn',
    '__torch__.torch.nn.modules.linear.Linear': 'fc',
    '__torch__.torch.nn.modules.pooling.AvgPool2d': 'gap'
}


def int_to_list_modifier(attr):
    if isinstance(attr, int):
        return [attr, attr]


nni_attr_map = {
    '__all__': {
        'kernel_size': ('ks', int_to_list_modifier),
        'padding': ('pads', int_to_list_modifier),
        'stride': ('strides', int_to_list_modifier),
    },
    'concat': {
        'dim': ('axis', None)
    },
}
|
||||
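A short sketch of how these two tables are applied to one node, mirroring the loop in `_map_opset`. The tables are trimmed copies of the ones above, and `map_node` and the sample nodes are hypothetical. It also shows that `int_to_list_modifier`, as written, returns `None` for values that are not plain integers, such as a tuple kernel size.

```python
def int_to_list_modifier(attr):
    if isinstance(attr, int):
        return [attr, attr]


# Trimmed copies of the mapping tables, for illustration.
type_map = {
    '__torch__.torch.nn.modules.conv.Conv2d': 'conv',
    'aten::cat': 'concat',
}
attr_map = {
    '__all__': {
        'kernel_size': ('ks', int_to_list_modifier),
        'padding': ('pads', int_to_list_modifier),
        'stride': ('strides', int_to_list_modifier),
    },
    'concat': {'dim': ('axis', None)},
}


def map_node(op_type, attrs):
    """Rename the op type and its attributes according to the tables."""
    new_type = type_map.get(op_type, op_type)
    new_attrs = {}
    for name, value in attrs.items():
        new_name, new_value = name, value
        for scope, renames in attr_map.items():
            if scope in ('__all__', new_type) and name in renames:
                new_name, modifier = renames[name]
                if modifier is not None:
                    new_value = modifier(value)
        new_attrs[new_name] = new_value
    return new_type, new_attrs


conv = map_node(
    '__torch__.torch.nn.modules.conv.Conv2d',
    {'kernel_size': 3, 'stride': 1, 'padding': 1, 'in_channels': 3},
)
concat = map_node('aten::cat', {'dim': 1})
tuple_case = map_node(
    '__torch__.torch.nn.modules.conv.Conv2d', {'kernel_size': (3, 3)}
)
```

Attributes without an entry in the table, such as `in_channels`, pass through unchanged. The tuple case ends up with `ks` set to `None`, so callers that may pass tuples would need a fallback in the modifier.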
|
|
@ -6,6 +6,7 @@ import json
|
|||
from .onnx_converter import OnnxConverter
|
||||
from .frozenpb_converter import FrozenPbConverter
|
||||
from .torch_converter import TorchConverter
|
||||
from .torch_converter.converter import NNIIRConverter
|
||||
|
||||
|
||||
def model_to_graph(model, model_type, input_shape=(1, 3, 224, 224)):
|
||||
|
@ -25,6 +26,9 @@ def model_to_graph(model, model_type, input_shape=(1, 3, 224, 224)):
|
|||
args = args.to("cuda")
|
||||
converter = TorchConverter(model, args)
|
||||
result = converter.convert()
|
||||
elif model_type == 'nni':
|
||||
converter = NNIIRConverter(model)
|
||||
result = converter.convert()
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {model_type}")
|
||||
|
||||
|
|
|
@ -5,6 +5,17 @@ from .kerneldetection import KernelDetector
|
|||
from .ir_converters import model_to_graph, model_file_to_graph
|
||||
from .prediction.load_predictors import loading_to_local
|
||||
|
||||
import yaml
import os


def get_default_config():
    config_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'configs/devices.yaml')
    with open(config_path, 'r') as fp:
        config = yaml.load(fp, yaml.FullLoader)['predictors']
    hardware = 'cortexA76cpu_tflite21'
    return config, hardware
|
||||
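`get_default_config` locates `configs/devices.yaml` relative to the package: two `dirname` calls climb from the module file to the directory that contains the package. A small sketch of that path arithmetic follows; the module location `/opt/nn-Meter/nn_meter/nn_meter.py` is an assumption for illustration, and `posixpath` is used so the result does not depend on the host OS.

```python
import posixpath

module_file = '/opt/nn-Meter/nn_meter/nn_meter.py'  # hypothetical location
package_dir = posixpath.dirname(module_file)   # directory of the module
repo_root = posixpath.dirname(package_dir)     # one level above the package
config_path = posixpath.join(repo_root, 'configs/devices.yaml')
```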
|
||||
|
||||
def load_latency_predictors(config, hardware):
|
||||
kernel_predictors, fusionrule = loading_to_local(config, hardware)
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
from setuptools import setup, find_packages


setup(
    name='nn_meter',
    version='1.0',
    description='',
    author='',
    author_email='',
    url='',
    packages=find_packages(),
)
|