Merge review to main (fix README.md conflicts)

This commit is contained in:
kalineid 2021-07-11 22:01:49 -04:00
Родитель 5ccce455d3 9e443c0b24
Коммит 14503e1a27
7 изменённых файлов: 179 добавлений и 23 удалений

Просмотреть файл

@ -70,6 +70,21 @@ nn-meter --input_model data/testmodels/alexnet.onnx --backend TFLite-CortexA76
```
Currently we support `ONNX` format (ONNX files of popular CNN models are included in [`data/testmodels`](data/testmodels)) and Tensorflow pb file.
### Hardware-aware NAS by nn-Meter and NNI
Install NNI by following [NNI Doc](https://nni.readthedocs.io/en/stable/Tutorial/InstallationLinux.html#installation).
Install nn-Meter from source code (currently we haven't released this package, so development installation is required).
```bash
python setup.py develop
```
Then run multi-trail SPOS demo:
```bash
python ${NNI_ROOT}/examples/nas/oneshot/spos/multi_trial.py
```
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a

Просмотреть файл

@ -1 +1 @@
from .nn_meter import load_latency_predictors, nnMeter
from .nn_meter import load_latency_predictors, nnMeter, get_default_config

Просмотреть файл

@ -3,18 +3,103 @@ import onnx
import tempfile
from nn_meter.ir_converters.onnx_converter import OnnxConverter
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.converter.graph_gen import GraphConverterWithShape
from nni.retiarii.graph import Model
class TorchConverter(OnnxConverter):
def __init__(self, model, args):
"""
@params
args: model input, refer to
https://pytorch.org/docs/stable/onnx.html#example-end-to-end-alexnet-from-pytorch-to-onnx
for more information.
"""
from .opset_map import nni_attr_map, nni_type_map
class NNIIRConverter:
def __init__(self, ir_model: Model):
self.ir_model = ir_model.fork()
GraphConverterWithShape().flatten(self.ir_model)
def convert(self):
graphe = self._to_graphe_layout()
for _, node in graphe.items():
self._map_opset(node)
self._remove_unshaped_nodes(graphe)
return graphe
def _to_graphe_layout(self):
graphe = {}
for node in self.ir_model.root_graph.hidden_nodes:
node_dict = {
'attr': {
'attr': {
k: v
for k, v in node.operation.parameters.items()
if k not in ['input_shape', 'output_shape']
},
'input_shape': node.operation.parameters['input_shape'],
'output_shape': node.operation.parameters['output_shape'],
'type': node.operation.type,
},
'inbounds': [],
'outbounds': [],
}
incoming_edges = sorted(node.incoming_edges, key=lambda e: e.tail_slot or 0)
for edge in incoming_edges:
node_dict['inbounds'].append(edge.head.name)
outgoing_edges = sorted(node.outgoing_edges, key=lambda e: e.head_slot or 0)
for edge in outgoing_edges:
node_dict['outbounds'].append(edge.tail.name)
graphe[node.name] = node_dict
return graphe
def _map_opset(self, node):
old_type = node['attr']['type']
new_type = nni_type_map.get(old_type, old_type)
new_attr_dict = {}
for attr_name, attr_value in node['attr']['attr'].items():
new_attr_name = attr_name
new_attr_value = attr_value
for type, attr_map in nni_attr_map.items():
if type == '__all__' or type == new_type:
if attr_name in attr_map:
new_attr_name, modifier = attr_map[attr_name]
if modifier is not None:
new_attr_value = modifier(attr_value)
new_attr_dict[new_attr_name] = new_attr_value
node['attr']['type'] = new_type
node['attr']['attr'] = new_attr_dict
def _remove_unshaped_nodes(self, graphe):
for node_name, node_dict in list(graphe.items()):
if not node_dict['attr']['input_shape']:
del graphe[node_name]
class NNIBasedTorchConverter(NNIIRConverter):
def __init__(self, model, example_inputs):
# PyTorch module to NNI IR
script_module = torch.jit.script(model)
converter = GraphConverterWithShape()
ir_model = convert_to_graph(script_module, model, converter, example_inputs=example_inputs)
super().__init__(ir_model)
class OnnxBasedTorchConverter(OnnxConverter):
def __init__(self, model, example_inputs):
with tempfile.TemporaryFile() as fp:
torch.onnx.export(model, args, fp)
torch.onnx.export(model, example_inputs, fp)
fp.seek(0)
model = onnx.load(fp, load_external_data=False)
super().__init__(model)
TorchConverter = NNIBasedTorchConverter

Просмотреть файл

@ -0,0 +1,29 @@
nni_type_map = {
'aten::mul': 'mul',
'aten::floordiv': 'div',
'aten::reshape': 'reshape',
'aten::cat': 'concat',
'__torch__.torch.nn.modules.conv.Conv2d': 'conv',
'__torch__.torch.nn.modules.activation.ReLU': 'relu',
'__torch__.torch.nn.modules.batchnorm.BatchNorm2d': 'bn',
'__torch__.torch.nn.modules.linear.Linea': 'fc',
'__torch__.torch.nn.modules.pooling.AvgPool2d': 'gap'
}
def int_to_list_modifier(attr):
if isinstance(attr, int):
return [attr, attr]
nni_attr_map = {
'__all__': {
'kernel_size': ('ks', int_to_list_modifier),
'padding': ('pads', int_to_list_modifier),
'stride': ('strides', int_to_list_modifier),
},
'concat': {
'dim': ('axis', None)
},
}

Просмотреть файл

@ -6,6 +6,7 @@ import json
from .onnx_converter import OnnxConverter
from .frozenpb_converter import FrozenPbConverter
from .torch_converter import TorchConverter
from .torch_converter.converter import NNIIRConverter
def model_to_graph(model, model_type, input_shape=(1, 3, 224, 224)):
@ -25,6 +26,9 @@ def model_to_graph(model, model_type, input_shape=(1, 3, 224, 224)):
args = args.to("cuda")
converter = TorchConverter(model, args)
result = converter.convert()
elif model_type == 'nni':
converter = NNIIRConverter(model)
result = converter.convert()
else:
raise ValueError(f"Unsupported model type: {model_type}")

Просмотреть файл

@ -5,6 +5,17 @@ from .kerneldetection import KernelDetector
from .ir_converters import model_to_graph, model_file_to_graph
from .prediction.load_predictors import loading_to_local
import yaml
import os
def get_default_config():
config_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'configs/devices.yaml')
with open(config_path, 'r') as fp:
config = yaml.load(fp, yaml.FullLoader)['predictors']
hardware = 'cortexA76cpu_tflite21'
return config, hardware
def load_latency_predictors(config, hardware):
kernel_predictors, fusionrule = loading_to_local(config, hardware)

12
setup.py Normal file
Просмотреть файл

@ -0,0 +1,12 @@
from setuptools import setup, find_packages
setup(
name='nn_meter',
version='1.0',
description='',
author='',
author_email='',
url='',
packages=find_packages(),
)