Merge branch 'main' into dev/torch-converter

This commit is contained in:
Jiahang Xu 2021-08-17 14:45:11 +08:00 коммит произвёл GitHub
Родитель 2095b0d299 e8cc2e8482
Коммит 3bbbe7c832
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 178 добавлений и 137 удалений

Просмотреть файл

@ -0,0 +1,37 @@
# Hardware-aware DNN Model Design
In many DNN model deployment scenarios, there are strict inference efficiency constraints as well as the model accuracy. For example, the **inference latency** and **energy consumption** are the most frequently used criteria of efficiencies to determine whether a DNN model could be deployed on a mobile phone or not. Therefore, DNN model designers have to consider the model efficiency. A typical methodology is to train a big model to meet the accuracy requirements first, and then apply model compression algorithms to get a light-weight model with similar accuracy but much smaller size. Due to many reasons, they use the number of parameters and FLOPs in the compression process.
However, as pointed out in our work [[1]](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w40/Zhang_Fast_Hardware-Aware_Neural_Architecture_Search_CVPRW_2020_paper.pdf) and many others, ***neither number of parameters nor number of FLOPs is a good metric of the real inference efficiency (e.g., latency or energy consumption)***. Operators with similar FLOPs may have very different inference latency on different hardware platforms (e.g., CPU, GPU, and ASIC) (shown in work [[1]](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w40/Zhang_Fast_Hardware-Aware_Neural_Architecture_Search_CVPRW_2020_paper.pdf) and [[3]](https://proceedings.mlsys.org/paper/2021/file/02522a2b2726fb0a03bb19f2d8d9524d-Paper.pdf)). This makes the effort of designing efficient DNN models for a target hardware bit of games of opening blind boxes. Recently, many hardware-aware NAS works are proposed to solve this challenge.
Compared with the conventional NAS algorithms, some recent works (i.e. hardware-aware NAS, aka HW-NAS) integrated hardware-awareness into the search loop and achieves a balanced trade-off between accuracy and hardware efficiencies [[4]](http://arxiv.org/abs/2101.09336).
Next, we introduce our hardware-aware NAS framework[[1]](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w40/Zhang_Fast_Hardware-Aware_Neural_Architecture_Search_CVPRW_2020_paper.pdf), which combines the nn-Meter, to search high-accuracy DNN models within the latency constraints for target edge devices.
## Hardware-aware Neural Architecture Search
<img src="imgs/hw-nas.png" alt="drawing" width="800"/>
**Hardware-aware Search Space Generation.** As formulated in many works, the search space is one of the three key aspects of a NAS process (the other two are the search strategy and the evaluation methodology) and matters a lot to the final results.
Our HW-NAS framework firstly automatically selects the hardware-friendly operators (or blocks) by considering both representation capacity and hardware efficiency. The selected operators could establish a ***hardware-aware search space*** for most of existing NAS algorithms.
**Latency Prediction in search process by nn-Meter.** Different with other simple predictors (e.g., look-up table for operators/blocks, linear regression models), [nn-Meter](overview.md) conducts kernel-level prediction, which captures the complex model graph optimizations on edge devices. nn-Meter is the first accurate latency prediction tool for DNNs on edge devices.
Besides the search space specialization, our HW-NAS framework also allows combining nn-Meter with existing NAS algorithms in the optimization objectives and constraints. As described in [[4]](http://arxiv.org/abs/2101.09336), the HW-NAS algorithms often consider hardware efficiency metrics as the constraints of existing NAS formulation or part of the scalarized loss functions (e.g., the loss is weighted sum of both cross entropy loss and hardware-aware penalty). Since the NAS process may sample up to millions of candidate model architectures, the obtaining of hardware metrics must be accurate and efficient.
nn-Meter is now integrated with [NNI](https://github.com/microsoft/nni), the AutoML framework also published by Microsoft, and could be combined with existing NAS algorithms seamlessly. [This doc](https://nni.readthedocs.io/en/stable/NAS/multi_trial_nas.html) show how to construct a latency constraint filter in [random search algorithm](https://arxiv.org/abs/1902.07638) on [SPOS NAS](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf) search space. Users could use this filter in multiple phases of the NAS process, e.g., the architecture searching phase and the super-net training phase.
***Note that current nn-Meter project is limited to the latency prediction. For the other hardware metrics, e.g., energy consumption is another important metric in edge computing. Collaborations and contributions together with nn-Meter are highly welcomed!***
## Other hardware-aware techniques
Besides light weighted NAS, which search for an efficient architecture directly, there are also other techniques to achieve light weight DNN models, such as model compression and knowledge distillation (KD). Both methods tries to get a smaller but similar-performed models from a pre-trained big model. The difference is that model compression removes some of the components in the origin model, while knowledge distillation constructs a new student model and lets it learn the behavior of the origin model. Hardware awareness could also be combined with these methods.
For example, nn-Meter could help users to construct suitable student architectures for the target hardware platform in the KD task.
## References
1. Li Lyna Zhang, Yuqing Yang, Yuhang Jiang, Wenwu Zhu, Yunxin Liu: [&#34;Fast hardware-aware neural architecture search.&#34;](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w40/Zhang_Fast_Hardware-Aware_Neural_Architecture_Search_CVPRW_2020_paper.pdf) Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops. 2020.
2. Li Lyna Zhang, Shihao Han, Jianyu Wei, Ningxin Zheng, Ting Cao, Yuqing Yang, Yunxin Liu: [&#34;nn-Meter: Towards Accurate Latency Prediction of Deep-Learning Model Inference on Diverse Edge Devices.&#34;](https://dl.acm.org/doi/10.1145/3458864.3467882) Proceedings of the 19th ACM International Conference on Mobile Systems, Applications, and Services (MobiSys 2021)
3. Xiaohu Tang, Shihao Han, Li Lyna Zhang, Ting Cao, Yunxin Liu: [&#34;To Bridge Neural Network Design and Real-World Performance: A Behaviour Study for Neural Networks&#34;](https://proceedings.mlsys.org/paper/2021/file/02522a2b2726fb0a03bb19f2d8d9524d-Paper.pdf) Proceedings of the 4th MLSys Conference (MLSys 2021)
4. Benmeziane, H., Maghraoui, K. el, Ouarnoughi, H., Niar, S., Wistuba, M., & Wang, N. (2021).[&#34; A Comprehensive Survey on Hardware-Aware Neural Architecture Search.&#34;](http://arxiv.org/abs/2101.09336)

Двоичные данные
docs/imgs/hw-nas.png Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 283 KiB

Просмотреть файл

@ -19,4 +19,6 @@ If you have a new hardware to predict DNN latency, a re-run of nn-Meter is requ
## Learn More
- [Get started](quick_start.md)
- [How to use nn-Meter](usage.md)
- [How to use nn-Meter](usage.md)
- [nn-meter in hardware-aware NAS](hardware-aware-model-design.md)

Просмотреть файл

@ -93,7 +93,7 @@ Users could get a nn-Meter IR graph by applying `model_file_to_graph` and `model
## Hardware-aware NAS by nn-Meter and NNI
To empower affordable DNN on the edge and mobile devices, hardware-aware NAS searches both high accuracy and low latency models. In particular, the search algorithm only considers the models within the target latency constraints during the search process.
To empower affordable DNN on the edge and mobile devices, hardware-aware NAS searches both high accuracy and low latency models. In particular, the search algorithm only considers the models within the target latency constraints during the search process. For more theoretical details, please refer to [this doc](hardware-aware-model-design.md).
Currently we provides example of end-to-end [multi-trial NAS](https://nni.readthedocs.io/en/stable/NAS/multi_trial_nas.html), which is a [random search algorithm](https://arxiv.org/abs/1902.07638) on [SPOS NAS](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf) search space. More examples of more hardware-aware NAS and model compression algorithms are coming soon.

Просмотреть файл

@ -1,3 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .predictors.utils import latency_metrics

Просмотреть файл

@ -1,135 +1,135 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import numpy as np
from sklearn.metrics import mean_squared_error
import logging
def get_flop(input_channel, output_channel, k, H, W, stride):
paras = output_channel * (k * k * input_channel + 1)
flops = 2 * H / stride * W / stride * paras
return flops, paras
def get_conv_mem(input_channel, output_channel, k, H, W, stride):
paras = output_channel * (k * k * input_channel + 1)
mem = paras + output_channel * H / stride * W / stride + input_channel * H * W
return mem
def get_depthwise_flop(input_channel, output_channel, k, H, W, stride):
paras = output_channel * (k * k + 1)
flops = 2 * H / stride * W / stride * paras
return flops, paras
def get_flops_params(blocktype, hw, cin, cout, kernelsize, stride):
if "dwconv" in blocktype:
return get_depthwise_flop(cin, cout, kernelsize, hw, hw, stride)
elif "conv" in blocktype:
return get_flop(cin, cout, kernelsize, hw, hw, stride)
elif "fc" in blocktype:
flop = (2 * cin + 1) * cout
return flop, flop
def get_predict_features(config):
"""
get prediction features
"""
mdicts = {}
layer = 0
for item in config:
logging.info(item)
for item in config:
op = item["op"]
if "conv" in op or "maxpool" in op or "avgpool" in op:
cout = item["cout"]
cin = item["cin"]
ks = item["ks"][1]
s = item["strides"][1] if "strides" in item else 1
inputh = item["inputh"]
if op in ["channelshuffle", "split"]:
[b, inputh, inputw, cin] = item["input_tensors"][0]
if "conv" in op:
flops, params = get_flops_params(op, inputh, cin, cout, ks, s)
features = [inputh, cin, cout, ks, s, flops / 2e6, params / 1e6]
elif "fc" in op or "fc-relu" in op:
cout = item["cout"]
cin = item["cin"]
flop = (2 * cin + 1) * cout
features = [cin, cout, flop / 2e6, flop / 1e6]
elif "pool" in op and "global" not in op:
features = [inputh, cin, cout, ks, s]
elif "global-pool" in op or "global-avgpool" in op or "gap" in op:
inputh = 1
cin = item["cin"]
features = [inputh, cin]
elif "channelshuffle" in op:
features = [inputh, cin]
elif "split" in op:
features = [inputh, cin]
elif "se" in op or "SE" in op:
inputh = item["input_tensors"][-1][-2]
cin = item["input_tensors"][-1][-1]
features = [inputh, cin]
elif "concat" in op: # maximum 4 branches
itensors = item["input_tensors"]
inputh = itensors[0][1]
features = [inputh, len(itensors)]
for it in itensors:
co = it[-1]
features.append(co)
if len(features) < 6:
features = features + [0] * (6 - len(features))
elif len(features) > 6:
nf = features[0:6]
features = nf
features[1] = 6
elif op in ["hswish"]:
if "inputh" in item:
inputh = item["inputh"]
else:
inputh = item["input_tensors"][0][1]
cin = item["cin"]
features = [inputh, cin]
elif op in ["bn", "relu", "bn-relu"]:
itensors = item["input_tensors"]
if len(itensors[0]) == 4:
inputh = itensors[0][1]
cin = itensors[0][3]
else:
inputh = itensors[0][0]
cin = itensors[0][1]
features = [inputh, cin]
elif op in ["add-relu", "add"]:
itensors = item["input_tensors"]
inputh = itensors[0][1]
cin1 = itensors[0][3]
cin2 = itensors[1][3]
features = [inputh, cin1, cin2]
else: # indicates that there is no matching predictor for this op
# logging.warning(f'There is no matching predictor for op {op}.')
continue
mdicts[layer] = {}
mdicts[layer][op] = features
layer += 1
return mdicts
def read_model_latency(latency_file):
"""
read model latency csv files. It can provide the benchmarked latency, and compare with the predicted latency
"""
f = open(latency_file, "r")
dicts = {}
while True:
line = f.readline()
if not line:
break
content = line.strip().split(",")
model = content[1]
latency = float(content[2])
dicts[model] = latency
return dicts
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import numpy as np
from sklearn.metrics import mean_squared_error
import logging
def get_flop(input_channel, output_channel, k, H, W, stride):
paras = output_channel * (k * k * input_channel + 1)
flops = 2 * H / stride * W / stride * paras
return flops, paras
def get_conv_mem(input_channel, output_channel, k, H, W, stride):
paras = output_channel * (k * k * input_channel + 1)
mem = paras + output_channel * H / stride * W / stride + input_channel * H * W
return mem
def get_depthwise_flop(input_channel, output_channel, k, H, W, stride):
paras = output_channel * (k * k + 1)
flops = 2 * H / stride * W / stride * paras
return flops, paras
def get_flops_params(blocktype, hw, cin, cout, kernelsize, stride):
if "dwconv" in blocktype:
return get_depthwise_flop(cin, cout, kernelsize, hw, hw, stride)
elif "conv" in blocktype:
return get_flop(cin, cout, kernelsize, hw, hw, stride)
elif "fc" in blocktype:
flop = (2 * cin + 1) * cout
return flop, flop
def get_predict_features(config):
"""
get prediction features
"""
mdicts = {}
layer = 0
for item in config:
logging.info(item)
for item in config:
op = item["op"]
if "conv" in op or "maxpool" in op or "avgpool" in op:
cout = item["cout"]
cin = item["cin"]
ks = item["ks"][1]
s = item["strides"][1] if "strides" in item else 1
inputh = item["inputh"]
if op in ["channelshuffle", "split"]:
[b, inputh, inputw, cin] = item["input_tensors"][0]
if "conv" in op:
flops, params = get_flops_params(op, inputh, cin, cout, ks, s)
features = [inputh, cin, cout, ks, s, flops / 2e6, params / 1e6]
elif "fc" in op or "fc-relu" in op:
cout = item["cout"]
cin = item["cin"]
flop = (2 * cin + 1) * cout
features = [cin, cout, flop / 2e6, flop / 1e6]
elif "pool" in op and "global" not in op:
features = [inputh, cin, cout, ks, s]
elif "global-pool" in op or "global-avgpool" in op or "gap" in op:
inputh = 1
cin = item["cin"]
features = [inputh, cin]
elif "channelshuffle" in op:
features = [inputh, cin]
elif "split" in op:
features = [inputh, cin]
elif "se" in op or "SE" in op:
inputh = item["input_tensors"][-1][-2]
cin = item["input_tensors"][-1][-1]
features = [inputh, cin]
elif "concat" in op: # maximum 4 branches
itensors = item["input_tensors"]
inputh = itensors[0][1]
features = [inputh, len(itensors)]
for it in itensors:
co = it[-1]
features.append(co)
if len(features) < 6:
features = features + [0] * (6 - len(features))
elif len(features) > 6:
nf = features[0:6]
features = nf
features[1] = 6
elif op in ["hswish"]:
if "inputh" in item:
inputh = item["inputh"]
else:
inputh = item["input_tensors"][0][1]
cin = item["cin"]
features = [inputh, cin]
elif op in ["bn", "relu", "bn-relu"]:
itensors = item["input_tensors"]
if len(itensors[0]) == 4:
inputh = itensors[0][1]
cin = itensors[0][3]
else:
inputh = itensors[0][0]
cin = itensors[0][1]
features = [inputh, cin]
elif op in ["add-relu", "add"]:
itensors = item["input_tensors"]
inputh = itensors[0][1]
cin1 = itensors[0][3]
cin2 = itensors[1][3]
features = [inputh, cin1, cin2]
else: # indicates that there is no matching predictor for this op
# logging.warning(f'There is no matching predictor for op {op}.')
continue
mdicts[layer] = {}
mdicts[layer][op] = features
layer += 1
return mdicts
def read_model_latency(latency_file):
"""
read model latency csv files. It can provide the benchmarked latency, and compare with the predicted latency
"""
f = open(latency_file, "r")
dicts = {}
while True:
line = f.readline()
if not line:
break
content = line.strip().split(",")
model = content[1]
latency = float(content[2])
dicts[model] = latency
return dicts